diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 492008e..625be26 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,6 +73,8 @@ jobs: preset: ci - os: macos-latest preset: ci-mac + - os: windows-latest + preset: ci-win steps: - uses: actions/checkout@v3 @@ -87,6 +89,7 @@ jobs: run: ctest --preset all-test --output-on-failure - name: Test Custom Params + if: matrix.preset != 'ci-win' run: | rm -rf build/CMakeCache.txt cmake --preset ${{ matrix.preset }}-custom-param diff --git a/CMakeLists.txt b/CMakeLists.txt index 93ffde0..51be247 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.21) project( deb - VERSION 0.2.1 + VERSION 0.3.0 LANGUAGES CXX DESCRIPTION "CryptoLab's official cryptosystem library for FHE.") @@ -40,9 +40,16 @@ set(CMAKE_INSTALL_PREFIX set(DEB_MEMORY_ALIGN_SIZE "0" - CACHE STRING - "Memory alignment size in bytes (default: 256). 0 means no alignment." + CACHE STRING "Memory alignment size in bytes. 0 means no alignment.") +set(DEB_EXT_LIB_FOR_SECURE_ZERO + "NONE" + CACHE + STRING + "External library for secure zeroing of memory. Options: NATIVE, LIBSODIUM, OPENSSL, NONE." ) +message(STATUS "Memory alignment size: ${DEB_MEMORY_ALIGN_SIZE} bytes") +message( + STATUS "External library for secure zeroing: ${DEB_EXT_LIB_FOR_SECURE_ZERO}") set(PRE_BUILD_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/prebuild) set(PRE_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include/generated) @@ -52,7 +59,6 @@ message(STATUS "Generate pre-build directory: ${PRE_BUILD_DIR}") include(cmake/CPM.cmake) include(cmake/warnings.cmake) -include(cmake/bundling.cmake) option(BUILD_SHARED_LIBS "Build a shared library instead of a static one." OFF) option(DEB_BUILD_BENCHMARK "Build the benchmark suite." OFF) @@ -84,22 +90,25 @@ add_subdirectory(external) add_subdirectory(prebuild) set(DEB_SRC - src/ModArith.cpp - src/Context.cpp + src/AleaRandomGenerator.cpp src/CKKSTypes.cpp - src/NTT.cpp + src/Decryptor.cpp + src/Encryptor.cpp src/FFT.cpp - src/SeedGenerator.cpp - src/Serialize.cpp - src/SecretKeyGenerator.cpp src/KeyGenerator.cpp - src/Encryptor.cpp - src/Decryptor.cpp) + src/ModArith.cpp + src/NTT.cpp + src/OmpUtils.cpp + src/Preset.cpp + src/RandomGenerator.cpp + src/SecretKeyGenerator.cpp + src/SeedGenerator.cpp + src/Serialize.cpp) add_library(${PROJECT_NAME}_obj OBJECT ${DEB_SRC}) add_dependencies(${PROJECT_NAME}_obj generate_flatbuffers generated_param_header) -set_my_project_warnings(${PROJECT_NAME}_obj) +set_deb_warnings(${PROJECT_NAME}_obj) set_property(TARGET ${PROJECT_NAME}_obj PROPERTY POSITION_INDEPENDENT_CODE ON) if(NOT MSVC) @@ -110,6 +119,19 @@ target_link_libraries( ${PROJECT_NAME}_obj PRIVATE $ $) +string(TOUPPER "${DEB_EXT_LIB_FOR_SECURE_ZERO}" _deb_secure_zero_backend) +if(_deb_secure_zero_backend STREQUAL "LIBSODIUM") + target_compile_definitions(${PROJECT_NAME}_obj + PUBLIC DEB_SECURE_ZERO_LIBSODIUM) + target_link_libraries(${PROJECT_NAME}_obj PRIVATE sodium) +elseif(_deb_secure_zero_backend STREQUAL "OPENSSL") + target_compile_definitions(${PROJECT_NAME}_obj PUBLIC DEB_SECURE_ZERO_OPENSSL) + target_link_libraries(${PROJECT_NAME}_obj PRIVATE OpenSSL::Crypto) +elseif(_deb_secure_zero_backend STREQUAL "NATIVE") + target_compile_definitions(${PROJECT_NAME}_obj PUBLIC DEB_SECURE_ZERO_NATIVE) +endif() +unset(_deb_secure_zero_backend) + target_include_directories( ${PROJECT_NAME}_obj PRIVATE $ @@ -122,7 +144,9 @@ endif() target_compile_definitions(${PROJECT_NAME}_obj PUBLIC DEB_ALINAS_LEN=${DEB_MEMORY_ALIGN_SIZE}) -add_library(${PROJECT_NAME}) +add_library(${PROJECT_NAME} STATIC $ + $) + add_library(${PROJECT_NAME}::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) target_link_libraries(${PROJECT_NAME} PUBLIC ${PROJECT_NAME}_obj) target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_17) @@ -160,8 +184,18 @@ if(DEB_BUILD_DOXYGEN) set(DOXYGEN_USE_MDFILE_AS_MAINPAGE ${CMAKE_CURRENT_SOURCE_DIR}/README.md) doxygen_add_docs( - deb_doc ${CMAKE_CURRENT_SOURCE_DIR}/include/cdeb/deb.h - ${CMAKE_CURRENT_SOURCE_DIR}/README.md ALL + deb_doc + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/CKKSTypes.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/Constant.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/Context.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/Decryptor.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/Encryptor.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/KeyGenerator.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/SecretKeyGenerator.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/Serialize.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/include/deb/Types.hpp + ${CMAKE_CURRENT_SOURCE_DIR}/README.md + ALL COMMENT "Generate documentation with Doxygen") else() message(STATUS "Doxygen not found, documentation will not be generated.") @@ -181,9 +215,6 @@ if(DEB_BUILD_TEST) add_subdirectory(test) endif() -merge_archive_if_static(${PROJECT_NAME} alea) -merge_archive_if_static(${PROJECT_NAME} flatbuffers) - if(APPLE) set_target_properties(${PROJECT_NAME} PROPERTIES INSTALL_RPATH "@loader_path" BUILD_RPATH "@loader_path") diff --git a/CMakePresets.json b/CMakePresets.json index 5c5e59f..508189c 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -25,7 +25,8 @@ "installDir": "${sourceDir}/build/install", "cacheVariables": { "DEB_BUILD_TEST": "ON", - "DEB_BUILD_BENCHMARK": "ON", + "DEB_BUILD_EXAMPLES": "OFF", + "DEB_BUILD_BENCHMARK": "OFF", "DEB_BUILD_WITH_OMP": "ON" } }, @@ -36,6 +37,14 @@ "DEB_BUILD_WITH_OMP": "OFF" } }, + { + "name": "ci-win", + "inherits": "ci", + "generator": "Ninja", + "cacheVariables": { + "DEB_BUILD_WITH_OMP": "OFF" + } + }, { "name": "ci-custom-param", "inherits": "ci", @@ -66,6 +75,8 @@ "inherits": "release", "cacheVariables": { "DEB_RUNTIME_RESOURCE_CHECK": "OFF", + "DEB_BUILD_EXAMPLES": "OFF", + "DEB_BUILD_TEST": "OFF", "DEB_BUILD_BENCHMARK": "ON", "DEB_BUILD_WITH_OMP": "ON" } @@ -94,6 +105,12 @@ "configuration": "release", "jobs": 4 }, + { + "name": "ci-win", + "configurePreset": "ci-win", + "configuration": "release", + "jobs": 4 + }, { "name": "ci-custom-param", "configurePreset": "ci-custom-param", @@ -124,7 +141,7 @@ "noTestsAction": "error" }, "environment": { - "OMP_NUM_THREADS": "8" + "OMP_NUM_THREADS": "1" } }, { diff --git a/benchmark/Blake3RandomGenerator.hpp b/benchmark/Blake3RandomGenerator.hpp new file mode 100644 index 0000000..644a46c --- /dev/null +++ b/benchmark/Blake3RandomGenerator.hpp @@ -0,0 +1,216 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "utils/Constant.hpp" +#include "utils/RandomGenerator.hpp" +#include "blake3.h" + +#include +#include +#include +#include +#include +#include + +namespace deb { + +class Blake3RandomGenerator : public RandomGenerator { +public: + explicit Blake3RandomGenerator(const RNGSeed& seed) { + reseed(reinterpret_cast(seed.data()), + DEB_RNG_SEED_BYTE_SIZE); + } + + Blake3RandomGenerator(const Blake3RandomGenerator&) = delete; + Blake3RandomGenerator& operator=(const Blake3RandomGenerator&) = delete; + + void getRandomUint64Array(u64* dst, size_t len) override { + for (size_t i = 0; i < len; ++i) { + u8 tmp[8]; + getBytes(tmp, sizeof(tmp)); + dst[i] = load_le_u64(tmp); + } + } + void getRandomUint64ArrayInRange(u64* dst, size_t len, u64 range) override { + if (range == 0) { + std::memset(dst, 0, sizeof(u64) * len); + return; + } + for (size_t i = 0; i < len; ++i) { + dst[i] = uniform_u64_range(range); + } + } + + void sampleGaussianInt64Array(i64* dst, size_t len, double stdev) override { + if (!(stdev > 0.0)) throw std::invalid_argument("stdev must be > 0"); + + // Box-Muller + size_t i = 0; + while (i < len) { + double u1 = uniform01(); + double u2 = uniform01(); + // avoid log(0) + if (u1 <= 0.0) continue; + + double r = std::sqrt(-2.0 * std::log(u1)); + double theta = 2.0 * utils::REAL_PI * u2; + + double z0 = r * std::cos(theta) * stdev; + double z1 = r * std::sin(theta) * stdev; + + dst[i++] = static_cast(std::llround(z0)); + if (i < len) dst[i++] = static_cast(std::llround(z1)); + } + } + void sampleHwtInt8Array(i8* dst, size_t len, int hwt) override { + if (hwt < 0 || static_cast(hwt) > len) + throw std::invalid_argument("hwt must be in [0, len]"); + + std::memset(dst, 0, len); + + std::vector idx(len); + for (size_t i = 0; i < len; ++i) + idx[i] = i; + + // Fisher-Yates shuffle for first hwt indices, and assign +1/-1 randomly + for (int i = 0; i < hwt; ++i) { + u64 r = uniform_u64_range(len - i); + size_t k = i + static_cast(r); + std::swap(idx[i], idx[k]); + dst[idx[i]] = uniform_u64_range(2) ? 1 : -1; + } + } + + void reseed(const u8* seed, size_t seed_len) override { + // hard-coded context string + static constexpr const char* kContext = + "example.random 2026-02-13 blake3-prng v1"; + + blake3_hasher h; + blake3_hasher_init_derive_key(&h, kContext); + if (seed && seed_len) blake3_hasher_update(&h, seed, seed_len); + blake3_hasher_finalize(&h, key_.data(), key_.size()); + + counter_ = 0; + buf_pos_ = kBufSize; // flush + // internal use counter=0 when first refill + } + +private: + // 32 bytes key (keyed hashing mode requirement) + std::array key_{}; + + // Counter-based block generation + u64 counter_ = 0; + + // Byte buffering (performance defense against many small calls) + static constexpr size_t kBufSize = 4096; + std::array buf_{}; + size_t buf_pos_ = kBufSize; // empty + + void refill() { + // hashing (domain || counter) message with keyed hashing, + // and output kBufSize bytes stream with finalize_seek. + // Using seek=0 and counter to make independent stream. + const u8 domain[] = { + 'B','L','A','K','E','3','-','P','R','N','G','-','v','1', 0x00 + }; + + u8 ctr_le[8]; + store_le_u64(ctr_le, counter_); + + blake3_hasher h; + blake3_hasher_init_keyed(&h, key_.data()); + blake3_hasher_update(&h, domain, sizeof(domain)); + blake3_hasher_update(&h, ctr_le, sizeof(ctr_le)); + + // from seek=0, kBufSize bytes + blake3_hasher_finalize_seek(&h, /*seek=*/0, buf_.data(), buf_.size()); + + counter_++; + buf_pos_ = 0; + + // optionally rekey every refill to enhance forward-secrecy-ish property. + // Here we do not do light rekeying every refill, + // let caller do periodic reseed or call rekey() if needed. + } + void getBytes(void* out, size_t n) { + u8* p = static_cast(out); + while (n) { + if (buf_pos_ >= buf_.size()) refill(); + size_t avail = buf_.size() - buf_pos_; + size_t take = (n < avail) ? n : avail; + std::memcpy(p, buf_.data() + buf_pos_, take); + buf_pos_ += take; + p += take; + n -= take; + } + } + + static u64 load_le_u64(const u8* p) { + u64 x = 0; + for (int i = 7; i >= 0; --i) x = (x << 8) | p[i]; + return x; + } + + static void store_le_u64(u8* p, u64 x) { + for (int i = 0; i < 8; ++i) { p[i] = static_cast(x & 0xFF); x >>= 8; } + } + + // [0,1) double (53-bit) generation + double uniform01() { + u8 tmp[8]; + getBytes(tmp, sizeof(tmp)); + u64 x = load_le_u64(tmp); + x >>= 11; // 64-53 + // [0, 2^53) / 2^53 + constexpr double denom = 9007199254740992.0; // 2^53 + return static_cast(x) / denom; + } + + // [0, range) unbiased + u64 uniform_u64_range(u64 range) { + if (range == 0) return 0; // caller bug defense + // rejection sampling (unbiased) + const u64 max = std::numeric_limits::max(); + const u64 limit = max - (max % range); + while (true) { + u8 tmp[8]; + getBytes(tmp, sizeof(tmp)); + u64 x = load_le_u64(tmp); + if (x < limit) return x % range; + } + } + + // Rekey (re-extract key from output stream to enhance forward-secrecy-ish property) + void rekey() { + std::array new_key{}; + getBytes(new_key.data(), new_key.size()); + key_ = new_key; + + // Flush buffer and start fresh with new key + buf_pos_ = kBufSize; + counter_ = 0; + } +}; + +std::shared_ptr createBlake3RandomGenerator(const RNGSeed& seed) { + return std::make_shared(seed); +}; + +} // namespace deb diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index c2f99ff..7b00073 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,4 +20,19 @@ cpmaddpackage(URI "gh:google/benchmark@1.9.4" OPTIONS add_executable(deb-benchmark benchmark.cpp) target_link_libraries(deb-benchmark PRIVATE deb_obj deb_DevHeader benchmark::benchmark_main) -set_my_project_warnings(deb-benchmark) + +enable_language(C) +cpmaddpackage( + NAME + blake3 + GITHUB_REPOSITORY + BLAKE3-team/BLAKE3 + GIT_TAG + 1.8.1 + SOURCE_SUBDIR + c) + +add_executable(deb-benchmark-blake3 benchmark_blake3.cpp) +target_link_libraries( + deb-benchmark-blake3 PRIVATE deb_obj deb_DevHeader benchmark::benchmark_main + BLAKE3::blake3) diff --git a/benchmark/benchmark.cpp b/benchmark/benchmark.cpp index 5b05465..e967039 100644 --- a/benchmark/benchmark.cpp +++ b/benchmark/benchmark.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,6 +29,7 @@ std::random_device rd; std::mt19937 gen{rd()}; std::uniform_real_distribution dist{-1.0, 1.0}; +std::uniform_int_distribution dist_u64; using namespace deb; @@ -52,16 +53,15 @@ static CoeffMessage gen_random_coeff(const Size degree) { template static void bm_seckey_encryption(benchmark::State &state) { const Preset preset = T; - const auto context = getContext(preset); - const auto ns = context->get_num_secret(); + const auto ns = get_num_secret(preset); std::vector msg_v; for (Size i = 0; i < ns; ++i) { - msg_v.push_back(gen_random_message(context->get_num_slots())); + msg_v.push_back(gen_random_message(get_num_slots(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - Encryptor encryptor(preset); - Ciphertext ctxt(context); + EncryptorT encryptor; + Ciphertext ctxt(preset); for (auto _ : state) { encryptor.encrypt(msg_v, sk, ctxt); @@ -73,43 +73,37 @@ static void bm_seckey_encryption(benchmark::State &state) { template static void bm_enckey_encryption(benchmark::State &state) { const Preset preset = T; - const auto context = getContext(preset); - const auto ns = context->get_num_secret(); + const auto ns = get_num_secret(preset); std::vector msg_v; for (Size i = 0; i < ns; ++i) { - msg_v.push_back(gen_random_message(context->get_num_slots())); + msg_v.push_back(gen_random_message(get_num_slots(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); KeyGenerator keygen(preset); SwitchKey enckey = keygen.genEncKey(sk); - Encryptor encryptor(preset); - - Ciphertext ctxt(context); - // std::vector ctxt_v; + EncryptorT encryptor; + Ciphertext ctxt(preset); for (auto _ : state) { - encryptor.encrypt(msg_v, sk, ctxt); + encryptor.encrypt(msg_v, enckey, ctxt); benchmark::DoNotOptimize(ctxt); benchmark::ClobberMemory(); - // ctxt_v.push_back(ctxt); } } template static void bm_decryption(benchmark::State &state) { const Preset preset = T; - const auto context = getContext(preset); - const auto ns = context->get_num_secret(); + const auto ns = get_num_secret(preset); std::vector msg_v; for (Size i = 0; i < ns; ++i) { - msg_v.push_back(gen_random_message(context->get_num_slots())); + msg_v.push_back(gen_random_message(get_num_slots(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - Encryptor encryptor(preset); - Decryptor decryptor(preset); - - Ciphertext ctxt(context); + EncryptorT encryptor; + DecryptorT decryptor; + Ciphertext ctxt(preset); encryptor.encrypt(msg_v, sk, ctxt); for (auto _ : state) { @@ -122,66 +116,57 @@ template static void bm_decryption(benchmark::State &state) { template static void bm_seckey_coeff_encryption(benchmark::State &state) { const Preset preset = T; - const auto context = getContext(preset); - const auto ns = context->get_num_secret(); + const auto ns = get_num_secret(preset); std::vector msg_v; for (Size i = 0; i < ns; ++i) { - msg_v.push_back(gen_random_coeff(context->get_degree())); + msg_v.push_back(gen_random_coeff(get_degree(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - Encryptor encryptor(preset); - - Ciphertext ctxt(context); - // std::vector ctxt_v; + EncryptorT encryptor; + Ciphertext ctxt(preset); for (auto _ : state) { encryptor.encrypt(msg_v, sk, ctxt); benchmark::DoNotOptimize(ctxt); benchmark::ClobberMemory(); - // ctxt_v.push_back(ctxt); } } template static void bm_enckey_coeff_encryption(benchmark::State &state) { const Preset preset = T; - const auto context = getContext(preset); - const auto ns = context->get_num_secret(); + const auto ns = get_num_secret(preset); std::vector msg_v; for (Size i = 0; i < ns; ++i) { - msg_v.push_back(gen_random_coeff(context->get_degree())); + msg_v.push_back(gen_random_coeff(get_degree(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); KeyGenerator keygen(preset); SwitchKey enckey = keygen.genEncKey(sk); - Encryptor encryptor(preset); - - Ciphertext ctxt(context); - // std::vector ctxt_v; + EncryptorT encryptor; + Ciphertext ctxt(preset); for (auto _ : state) { - encryptor.encrypt(msg_v, sk, ctxt); + encryptor.encrypt(msg_v, enckey, ctxt); benchmark::DoNotOptimize(ctxt); benchmark::ClobberMemory(); - // ctxt_v.push_back(ctxt); } } template static void bm_coeff_decryption(benchmark::State &state) { const Preset preset = T; - const auto context = getContext(preset); - const auto ns = context->get_num_secret(); + const auto ns = get_num_secret(preset); std::vector msg_v; for (Size i = 0; i < ns; ++i) { - msg_v.push_back(gen_random_coeff(context->get_degree())); + msg_v.push_back(gen_random_coeff(get_degree(preset))); } SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - Encryptor encryptor(preset); - Decryptor decryptor(preset); + EncryptorT encryptor(preset); + DecryptorT decryptor(preset); - Ciphertext ctxt(context); + Ciphertext ctxt(preset); encryptor.encrypt(msg_v, sk, ctxt); for (auto _ : state) { @@ -191,19 +176,58 @@ static void bm_coeff_decryption(benchmark::State &state) { } } -#define X(PRESET) \ - BENCHMARK_TEMPLATE(bm_seckey_encryption, Preset::PRESET_##PRESET) \ - ->Unit(benchmark::kMicrosecond); \ - BENCHMARK_TEMPLATE(bm_enckey_encryption, Preset::PRESET_##PRESET) \ - ->Unit(benchmark::kMicrosecond); \ - BENCHMARK_TEMPLATE(bm_decryption, Preset::PRESET_##PRESET) \ - ->Unit(benchmark::kMicrosecond); \ +template +static void bm_forward_ntt(benchmark::State &state) { + utils::NTT ntt(degree, prime); + u64 data[degree]; + + for (auto _ : state) { + state.PauseTiming(); + for (Size i = 0; i < degree; ++i) { + data[i] = dist_u64(gen) % prime; + } + state.ResumeTiming(); + ntt.computeForward(data); + benchmark::DoNotOptimize(data); + benchmark::ClobberMemory(); + } +} + +template +static void bm_backward_ntt(benchmark::State &state) { + utils::NTT ntt(degree, prime); + u64 data[degree]; + + for (auto _ : state) { + state.PauseTiming(); + for (Size i = 0; i < degree; ++i) { + data[i] = dist_u64(gen) % prime; + } + ntt.computeForward(data); + state.ResumeTiming(); + ntt.computeBackward(data); + benchmark::DoNotOptimize(data); + benchmark::ClobberMemory(); + } +} + +#define X(PRESET) \ + BENCHMARK_TEMPLATE(bm_seckey_encryption, Preset::PRESET_##PRESET) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(bm_enckey_encryption, Preset::PRESET_##PRESET) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(bm_decryption, Preset::PRESET_##PRESET) \ + ->Unit(benchmark::kMicrosecond); \ BENCHMARK_TEMPLATE(bm_seckey_coeff_encryption, Preset::PRESET_##PRESET) \ - ->Unit(benchmark::kMicrosecond); \ + ->Unit(benchmark::kMicrosecond); \ BENCHMARK_TEMPLATE(bm_enckey_coeff_encryption, Preset::PRESET_##PRESET) \ - ->Unit(benchmark::kMicrosecond); \ - BENCHMARK_TEMPLATE(bm_coeff_decryption, Preset::PRESET_##PRESET) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(bm_coeff_decryption, Preset::PRESET_##PRESET) \ ->Unit(benchmark::kMicrosecond); PRESET_LIST #undef X + +BENCHMARK_TEMPLATE(bm_forward_ntt, 65536, 288230376147386369)->Unit(benchmark::kMicrosecond); +BENCHMARK_TEMPLATE(bm_backward_ntt, 65536, 288230376147386369) + ->Unit(benchmark::kMicrosecond); diff --git a/benchmark/benchmark_blake3.cpp b/benchmark/benchmark_blake3.cpp new file mode 100644 index 0000000..0ab7ed0 --- /dev/null +++ b/benchmark/benchmark_blake3.cpp @@ -0,0 +1,110 @@ +/* +* Copyright 2026 CryptoLab, Inc. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +#include "benchmark/benchmark.h" +#include "Decryptor.hpp" +#include "Encryptor.hpp" +#include "KeyGenerator.hpp" +#include "SecretKeyGenerator.hpp" +#include "Blake3RandomGenerator.hpp" +#include "Types.hpp" + +#include +#include +#include + +std::random_device rd; +std::mt19937 gen{rd()}; +std::uniform_real_distribution dist{-1.0, 1.0}; +std::uniform_int_distribution dist_u64{0, UINT64_MAX}; + +using namespace deb; + +static Message gen_random_message(const Size num_slots) { + Message msg(num_slots); + for (Size i = 0; i < msg.size(); ++i) { + msg.data()[i].real(dist(gen)); + msg.data()[i].imag(dist(gen)); + } + return msg; +} + +static CoeffMessage gen_random_coeff(const Size degree) { + CoeffMessage coeff(degree); + for (Size i = 0; i < degree; ++i) { + coeff[i] = dist(gen); + } + return coeff; +} + +// ------------------------------------------------------------------------- +// Blake3 RNG benchmarks with Blake3RandomGenerator +// ------------------------------------------------------------------------- + +template +static void bm_blake3_seckey_encryption(benchmark::State &state) { + setRandomGeneratorFactory(createBlake3RandomGenerator); + const Preset preset = T; + const auto ns = get_num_secret(preset); + std::vector msg_v; + for (Size i = 0; i < ns; ++i) { + msg_v.push_back(gen_random_message(get_num_slots(preset))); + } + + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + Encryptor encryptor(preset); + Ciphertext ctxt(preset); + + for (auto _ : state) { + encryptor.encrypt(msg_v, sk, ctxt); + benchmark::DoNotOptimize(ctxt); + benchmark::ClobberMemory(); + } + setRandomGeneratorFactory(nullptr); +} + +template +static void bm_blake3_enckey_encryption(benchmark::State &state) { + setRandomGeneratorFactory(createBlake3RandomGenerator); + const Preset preset = T; + const auto ns = get_num_secret(preset); + std::vector msg_v; + for (Size i = 0; i < ns; ++i) { + msg_v.push_back(gen_random_message(get_num_slots(preset))); + } + + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + Encryptor encryptor(preset); + Ciphertext ctxt(preset); + + for (auto _ : state) { + encryptor.encrypt(msg_v, enckey, ctxt); + benchmark::DoNotOptimize(ctxt); + benchmark::ClobberMemory(); + } + setRandomGeneratorFactory(nullptr); +} + +#define X(PRESET) \ + BENCHMARK_TEMPLATE(bm_blake3_seckey_encryption, Preset::PRESET_##PRESET) \ + ->Unit(benchmark::kMicrosecond); \ + BENCHMARK_TEMPLATE(bm_blake3_enckey_encryption, Preset::PRESET_##PRESET) \ + ->Unit(benchmark::kMicrosecond); + +PRESET_LIST +#undef X diff --git a/cmake/bundling.cmake b/cmake/bundling.cmake deleted file mode 100644 index 9bfc884..0000000 --- a/cmake/bundling.cmake +++ /dev/null @@ -1,65 +0,0 @@ -# ~~~ -# Copyright 2025 CryptoLab, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ~~~ - -# Merge the object files of `dependency` target with that of `target`. Both -# should be static or object libraries; otherwise this is no-op. -function(merge_archive_if_static target dependency) - # Check if the target has object files - list(APPEND TYPES_HAVING_OBJECTS "STATIC_LIBRARY" "OBJECT_LIBRARY") - get_target_property(IS_STATIC ${target} TYPE) - if(NOT IS_STATIC IN_LIST TYPES_HAVING_OBJECTS) - return() - endif() - - # Check if the dependency is a target and has object files - if(NOT TARGET ${dependency}) - return() - endif() - get_target_property(IS_STATIC ${dependency} TYPE) - if(NOT IS_STATIC IN_LIST TYPES_HAVING_OBJECTS) - return() - endif() - - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") - add_custom_command( - TARGET ${target} - POST_BUILD - COMMAND rm -rf ${target}_objs && mkdir ${target}_objs - COMMAND rm -rf ${dependency}_objs && mkdir ${dependency}_objs - COMMAND ${CMAKE_COMMAND} -E chdir ${target}_objs ${CMAKE_AR} -x - $ - COMMAND ${CMAKE_COMMAND} -E chdir ${dependency}_objs ${CMAKE_AR} -x - $ - COMMAND ar -qcs $ ${target}_objs/*.o - ${dependency}_objs/*.o - COMMAND rm -rf ${target}_objs ${dependency}_objs - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # DEPENDS ${target} - # ${dependency}) - elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") - add_custom_command( - TARGET ${target} - POST_BUILD - COMMAND lib.exe /OUT:$ $ - $ - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) # DEPENDS ${target} - # ${dependency}) - else() - message( - WARNING - "Failed merging ${target} target with ${dependency}: unsupported compiler" - ) - endif() -endfunction() diff --git a/cmake/debConfig.cmake.in b/cmake/debConfig.cmake.in index ed4288d..8e870db 100644 --- a/cmake/debConfig.cmake.in +++ b/cmake/debConfig.cmake.in @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/cmake/warnings.cmake b/cmake/warnings.cmake index 8667cb6..caa1b93 100644 --- a/cmake/warnings.cmake +++ b/cmake/warnings.cmake @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,7 @@ # limitations under the License. # ~~~ -function(set_my_project_warnings target) +function(set_deb_warnings target) target_compile_options( ${target} PRIVATE diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d472f3a..02cb3fd 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,3 +37,7 @@ target_link_libraries(Serialization PRIVATE ${${PROJECT_NAME}_example}) add_executable(KeyGeneration KeyGeneration.cpp) target_link_libraries(KeyGeneration PRIVATE ${${PROJECT_NAME}_example}) + +# Custom RNG example +add_executable(CustomRNG CustomRNG.cpp) +target_link_libraries(CustomRNG PRIVATE ${${PROJECT_NAME}_example} sodium) diff --git a/examples/CustomRNG.cpp b/examples/CustomRNG.cpp new file mode 100644 index 0000000..c2b180f --- /dev/null +++ b/examples/CustomRNG.cpp @@ -0,0 +1,97 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ExampleUtils.hpp" +#include "SodiumRandomGenerator.hpp" + +#include + +using namespace deb; + +int main() { + Preset preset = static_cast(0); + if (preset == PRESET_EMPTY) { + std::cerr << "No preset with single secret found." << std::endl; + return -1; + } + std::cout << "Preset: " << get_preset_name(preset) << std::endl; + + // ----------------------------------------------------------------- + // Test RNG + // ----------------------------------------------------------------- + { + RNGSeed seed = SeedGenerator::Gen(); + auto sodium_rng = std::make_shared(seed); + + const size_t len = 16; + std::vector rand_u64(len); + sodium_rng->getRandomUint64Array(rand_u64.data(), len); + + std::cout << "Random u64 array:" << std::endl; + for (size_t i = 0; i < len; ++i) { + std::cout << " " << rand_u64[i] << std::endl; + } + } + // ----------------------------------------------------------------- + // Encrypt using SodiumRandomGenerator directly + // ----------------------------------------------------------------- + { + std::shared_ptr rng = std::make_shared(SeedGenerator::Gen()); + Encryptor enc(preset, rng); + Decryptor dec(preset); + + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + Message msg = generateRandomMessage(preset); + Message decrypted_msg(preset); + Ciphertext ctxt(preset); + + DebTimer::start("Encrypt/Decrypt (sodium RNG direct)"); + enc.encrypt(msg, sk, ctxt); + dec.decrypt(ctxt, sk, decrypted_msg); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) + << " bits" << std::endl; + } + // ----------------------------------------------------------------- + // Encrypt using the global factory so every RNG is Sodium-based + // ----------------------------------------------------------------- + { + setRandomGeneratorFactory([](const RNGSeed &seed) { + return std::make_shared(seed); + }); + + Encryptor enc(preset); + KeyGenerator keygen(preset); + Decryptor dec(preset); + + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + Message msg = generateRandomMessage(preset); + Message decrypted_msg(preset); + Ciphertext ctxt(preset); + + DebTimer::start("Encrypt/Decrypt (sodium RNG global)"); + enc.encrypt(msg, sk, ctxt); + dec.decrypt(ctxt, sk, decrypted_msg); + DebTimer::end(); + std::cout << "log2 error = " << compareMessage(msg, decrypted_msg) + << " bits" << std::endl; + + setRandomGeneratorFactory(nullptr); + } + + std::cout << "\nDone." << std::endl; + return 0; +} diff --git a/examples/EnDecryption-MultiSecret.cpp b/examples/EnDecryption-MultiSecret.cpp index 51c14be..44479af 100644 --- a/examples/EnDecryption-MultiSecret.cpp +++ b/examples/EnDecryption-MultiSecret.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,7 +24,7 @@ int main() { Preset preset = PRESET_EMPTY; // Retrieve preset with single secret for (auto p : Presets) { - if (getContext(p)->get_num_secret() > 1) { + if (get_num_secret(p) > 1) { preset = p; break; } @@ -33,9 +33,8 @@ int main() { std::cerr << "No preset with multiple secrets found." << std::endl; return -1; } - const Size num_secret = getContext(preset)->get_num_secret(); - - std::cout << "Preset: " << getContext(preset)->get_preset_name() << std::endl; + const Size num_secret = get_num_secret(preset); + std::cout << "Preset: " << get_preset_name(preset) << std::endl; Encryptor enc(preset); // Create encryptor Decryptor dec(preset); // Create decryptor std::vector msg; // Message to be encrypted @@ -63,7 +62,7 @@ int main() { } // Scaled encryption and decryption - u64 base_bit = utils::bitWidth(getContext(preset)->get_primes()[0]); // Example scale + u64 base_bit = utils::bitWidth(get_primes(preset)[0]); // Example scale Real scale = std::pow(2.0, base_bit - 3); { auto opt = EncryptOptions().Scale(scale); @@ -75,7 +74,7 @@ int main() { } // Encrypt with custom level - Size custom_level = getContext(preset)->get_encryption_level() / 2; + Size custom_level = get_encryption_level(preset) / 2; { auto opt = EncryptOptions().Level(custom_level); DebTimer::start("Custom Level EnDecryption"); @@ -138,7 +137,7 @@ int main() { // (Coefficient) Message encryption with encryption key // --------------------------------------------------------------------- // Generate encryption key from secret key - KeyGenerator keygen(sk); + KeyGenerator keygen(preset); SwitchKey ek = keygen.genEncKey(sk); // Basic encryption with encryption key diff --git a/examples/EnDecryption.cpp b/examples/EnDecryption.cpp index 1ffba99..e7d0b3f 100644 --- a/examples/EnDecryption.cpp +++ b/examples/EnDecryption.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,7 +25,7 @@ int main() { Preset preset = PRESET_EMPTY; // Retrieve preset with single secret for (auto p : Presets) { - if (getContext(p)->get_num_secret() == 1) { + if (get_num_secret(p) == 1) { preset = p; break; } @@ -34,7 +34,7 @@ int main() { std::cerr << "No preset with single secret found." << std::endl; return -1; } - std::cout << "Preset: " << getContext(preset)->get_preset_name() << std::endl; + std::cout << "Preset: " << get_preset_name(preset) << std::endl; Encryptor enc(preset); // Create encryptor Decryptor dec(preset); // Create decryptor Message msg(preset); // Message to be encrypted @@ -60,7 +60,7 @@ int main() { } // Scaled encryption and decryption - u64 base_bit = utils::bitWidth(getContext(preset)->get_primes()[0]); // Example scale + u64 base_bit = utils::bitWidth(get_primes(preset)[0]); // Example scale Real scale = std::pow(2.0, base_bit - 3); { auto opt = EncryptOptions().Scale(scale); @@ -72,7 +72,7 @@ int main() { } // Encrypt with custom level - Size custom_level = getContext(preset)->get_encryption_level() / 2; + Size custom_level = get_encryption_level(preset) / 2; { auto opt = EncryptOptions().Level(custom_level); DebTimer::start("Custom Level EnDecryption"); @@ -130,7 +130,7 @@ int main() { // (Coefficient) Message encryption with encryption key // --------------------------------------------------------------------- // Generate encryption key from secret key - KeyGenerator keygen(sk); + KeyGenerator keygen(preset); SwitchKey ek = keygen.genEncKey(sk); // Basic encryption with encryption key diff --git a/examples/ExampleUtils.cpp b/examples/ExampleUtils.cpp index cb05e1d..74512e2 100644 --- a/examples/ExampleUtils.cpp +++ b/examples/ExampleUtils.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/examples/ExampleUtils.hpp b/examples/ExampleUtils.hpp index 3e6594e..df897ab 100644 --- a/examples/ExampleUtils.hpp +++ b/examples/ExampleUtils.hpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,6 @@ #pragma once #include "CKKSTypes.hpp" -#include "Context.hpp" #include "Decryptor.hpp" #include "Encryptor.hpp" #include "KeyGenerator.hpp" diff --git a/examples/KeyGeneration.cpp b/examples/KeyGeneration.cpp index 3e34561..7ed04f0 100644 --- a/examples/KeyGeneration.cpp +++ b/examples/KeyGeneration.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,44 +23,44 @@ int main() { Preset preset; for(auto p : Presets) { // fine 0 level preset - if(getContext(p)->get_num_base() == 1 && getContext(p)->get_num_qp() == 0) { + if(get_num_base(p) == 1 && get_num_qp(p) == 0) { preset = p; break; } } - std::cout << "Preset: " << getContext(preset)->get_preset_name() << std::endl; + std::cout << "Preset: " << get_preset_name(preset) << std::endl; SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - KeyGenerator keygen(sk); + KeyGenerator keygen(preset); // Generate encryption key { SwitchKey enckey(preset, SwitchKeyKind::SWK_ENC); DebTimer::start("Encryption Key Generation"); - keygen.genEncKeyInplace(enckey); // inplace keygen + keygen.genEncKeyInplace(enckey, sk); // inplace keygen DebTimer::end(); - SwitchKey enckey2 = keygen.genEncKey(); // outplace keygen + SwitchKey enckey2 = keygen.genEncKey(sk); // outplace keygen } // Generate multiplication key { SwitchKey mulkey(preset, SwitchKeyKind::SWK_MULT); DebTimer::start("Multiplication Key Generation"); - keygen.genMultKeyInplace(mulkey); // inplace keygen + keygen.genMultKeyInplace(mulkey, sk); // inplace keygen DebTimer::end(); - SwitchKey mulkey2 = keygen.genMultKey(); // outplace keygen + SwitchKey mulkey2 = keygen.genMultKey(sk); // outplace keygen } // Generate conjugation key { SwitchKey conjkey(preset, SwitchKeyKind::SWK_CONJ); DebTimer::start("Conjugation Key Generation"); - keygen.genConjKeyInplace(conjkey); // inplace keygen + keygen.genConjKeyInplace(conjkey, sk); // inplace keygen DebTimer::end(); - SwitchKey conjkey2 = keygen.genConjKey(); // outplace keygen + SwitchKey conjkey2 = keygen.genConjKey(sk); // outplace keygen } // Generate left rotation key @@ -68,31 +68,31 @@ int main() { { SwitchKey lrotkey(preset, SwitchKeyKind::SWK_ROT); DebTimer::start("Left Rotation Key Generation"); - keygen.genLeftRotKeyInplace(rot, lrotkey); // inplace keygen + keygen.genLeftRotKeyInplace(rot, lrotkey, sk); // inplace keygen DebTimer::end(); std::cout << "rotation index: " << rot << std::endl; - SwitchKey lrotkey2 = keygen.genLeftRotKey(rot); // outplace keygen + SwitchKey lrotkey2 = keygen.genLeftRotKey(rot, sk); // outplace keygen } // Generate right rotation key { SwitchKey rrotkey(preset, SwitchKeyKind::SWK_ROT); DebTimer::start("Right Rotation Key Generation"); - keygen.genRightRotKeyInplace(rot, rrotkey); // inplace keygen + keygen.genRightRotKeyInplace(rot, rrotkey, sk); // inplace keygen DebTimer::end(); - SwitchKey rrotkey2 = keygen.genRightRotKey(rot); // outplace keygen + SwitchKey rrotkey2 = keygen.genRightRotKey(rot, sk); // outplace keygen } // Generate automorphism key { SwitchKey autokey(preset, SwitchKeyKind::SWK_AUTO); DebTimer::start("Automorphism Key Generation"); - keygen.genAutoKeyInplace(rot, autokey); // inplace keygen + keygen.genAutoKeyInplace(rot, autokey, sk); // inplace keygen DebTimer::end(); - SwitchKey autokey2 = keygen.genAutoKey(rot); // outplace keygen + SwitchKey autokey2 = keygen.genAutoKey(rot, sk); // outplace keygen } // Generate composition key @@ -100,10 +100,10 @@ int main() { { SwitchKey composekey(preset, SwitchKeyKind::SWK_COMPOSE); DebTimer::start("Composition Key Generation"); - keygen.genComposeKeyInplace(sk_from, composekey); // inplace keygen + keygen.genComposeKeyInplace(sk_from, composekey, sk); // inplace keygen DebTimer::end(); - SwitchKey composekey2 = keygen.genComposeKey(sk_from); // outplace keygen + SwitchKey composekey2 = keygen.genComposeKey(sk_from, sk); // outplace keygen } // Generate decomposition key @@ -111,21 +111,21 @@ int main() { { SwitchKey decompkey(preset, SwitchKeyKind::SWK_DECOMPOSE); DebTimer::start("Decomposition Key Generation"); - keygen.genDecomposeKeyInplace(sk_to, decompkey); // inplace keygen + keygen.genDecomposeKeyInplace(sk_to, decompkey, sk); // inplace keygen DebTimer::end(); - SwitchKey decompkey2 = keygen.genDecomposeKey(sk_to); // outplace keygen + SwitchKey decompkey2 = keygen.genDecomposeKey(sk_to, sk); // outplace keygen } - // Generate decomposition key with switching context + // Generate decomposition key with switching preset { Preset preset_swk = preset; // for simplicity, use the same preset SwitchKey decompkey_swk(preset_swk, SwitchKeyKind::SWK_DECOMPOSE); - DebTimer::start("Decomposition Key Generation with switching context"); - keygen.genDecomposeKeyInplace(preset_swk, sk_to, decompkey_swk); // inplace keygen + DebTimer::start("Decomposition Key Generation with switching preset"); + keygen.genDecomposeKeyInplace(preset_swk, sk_to, decompkey_swk, sk); // inplace keygen DebTimer::end(); - SwitchKey decompkey_swk2 = keygen.genDecomposeKey(preset_swk, sk_to); // outplace keygen + SwitchKey decompkey_swk2 = keygen.genDecomposeKey(preset_swk, sk_to, sk); // outplace keygen } // Generate modpack keys @@ -142,16 +142,16 @@ int main() { // Generate self modpack key with pad_rank { - const Size pad_rank = 1U << (getContext(preset)->get_log_degree() / 2); - const Size num_p = getContext(preset)->get_num_p(); + const Size pad_rank = 1U << (get_log_degree(preset) / 2); + const Size num_p = get_num_p(preset); SwitchKey self_modkey(preset, SwitchKeyKind::SWK_MODPACK_SELF); self_modkey.addAx(num_p, pad_rank, true); - self_modkey.addBx(num_p, pad_rank * getContext(preset)->get_num_secret(), true); + self_modkey.addBx(num_p, pad_rank * get_num_secret(preset), true); DebTimer::start("Self ModPack Key Bundle Generation"); - keygen.genModPackKeyBundleInplace(pad_rank, self_modkey); // inplace keygen + keygen.genModPackKeyBundleInplace(pad_rank, self_modkey, sk); // inplace keygen DebTimer::end(); - self_modkey = keygen.genModPackKeyBundle(pad_rank); // outplace keygen + self_modkey = keygen.genModPackKeyBundle(pad_rank, sk); // outplace keygen } return 0; diff --git a/examples/SeedOnlySecretKey.cpp b/examples/SeedOnlySecretKey.cpp index 7d60c1e..f2b62ee 100644 --- a/examples/SeedOnlySecretKey.cpp +++ b/examples/SeedOnlySecretKey.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ using namespace deb; int main() { Preset preset = static_cast(0); - std::cout << "Preset: " << getContext(preset)->get_preset_name() << std::endl; + std::cout << "Preset: " << get_preset_name(preset) << std::endl; // Generate seed for secret key RNGSeed sk_seed = SeedGenerator::Gen(); diff --git a/examples/Serialization.cpp b/examples/Serialization.cpp index 463f6fc..84de9fc 100644 --- a/examples/Serialization.cpp +++ b/examples/Serialization.cpp @@ -1,5 +1,5 @@ /* -* Copyright 2025 CryptoLab, Inc. +* Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,19 +26,19 @@ int main() { // Define presets to test Preset preset; for (auto p : Presets) { - if (getContext(p)->get_num_secret() == 1) { + if (get_num_secret(p) == 1) { preset = p; break; } } - std::cout << "Preset: " << getContext(preset)->get_preset_name() << std::endl; + std::cout << "Preset: " << get_preset_name(preset) << std::endl; // Generate resources Message msg = generateRandomMessage(preset); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); Encryptor encryptor(preset); Decryptor decryptor(preset); - KeyGenerator keygen(sk); + KeyGenerator keygen(preset); std::string tmp_dir = "./example_data/"; std::filesystem::create_directories(tmp_dir); @@ -46,7 +46,7 @@ int main() { // Serialize message, sk, key, ciphertext { - SwitchKey enckey = keygen.genEncKey(); + SwitchKey enckey = keygen.genEncKey(sk); Ciphertext cipher(preset); encryptor.encrypt(msg, enckey, cipher); diff --git a/examples/SodiumRandomGenerator.hpp b/examples/SodiumRandomGenerator.hpp new file mode 100644 index 0000000..e87a3f8 --- /dev/null +++ b/examples/SodiumRandomGenerator.hpp @@ -0,0 +1,162 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "utils/Constant.hpp" +#include "utils/RandomGenerator.hpp" + +#include + +#include +#include +#include +#include +#include +#include + +namespace deb { + +// --------------------------------------------------------------------------- +// SodiumRandomGenerator — a custom RandomGenerator backed by libsodium. +// +// Internally it derives a 32-byte ChaCha20 key from the 64-byte RNGSeed +// via BLAKE2b (crypto_generichash) and uses crypto_stream_xchacha20 with +// an incrementing 24-byte nonce to produce a deterministic byte stream. +// --------------------------------------------------------------------------- +class SodiumRandomGenerator : public RandomGenerator { +public: + explicit SodiumRandomGenerator(const RNGSeed &seed) { + if (sodium_init() < 0) { + throw std::runtime_error( + "[SodiumRandomGenerator] sodium_init failed"); + } + deriveSeedMaterial(reinterpret_cast(seed.data()), + DEB_RNG_SEED_BYTE_SIZE); + } + + SodiumRandomGenerator(const SodiumRandomGenerator &) = delete; + SodiumRandomGenerator &operator=(const SodiumRandomGenerator &) = delete; + + ~SodiumRandomGenerator() override { sodium_memzero(key_, sizeof(key_)); } + + // -- basic random generation -------------------------------------------- + + void getRandomUint64Array(u64 *dst, size_t len) override { + fillBytes(reinterpret_cast(dst), len * sizeof(u64)); + } + + void getRandomUint64ArrayInRange(u64 *dst, size_t len, + u64 range) override { + for (size_t i = 0; i < len; ++i) { + dst[i] = uniformUint64InRange(range); + } + } + + // -- distribution sampling ---------------------------------------------- + + void sampleGaussianInt64Array(i64 *dst, size_t len, + double stdev) override { + // Box-Muller transform, consuming two uniform doubles per pair. + size_t i = 0; + while (i + 1 < len) { + double u1, u2; + uniformDouble01(u1, u2); + double r = stdev * std::sqrt(-2.0 * std::log(u1)); + dst[i] = + static_cast(std::round(r * std::cos(2.0 * utils::REAL_PI * u2))); + dst[i + 1] = + static_cast(std::round(r * std::sin(2.0 * utils::REAL_PI * u2))); + i += 2; + } + if (i < len) { + double u1, u2; + uniformDouble01(u1, u2); + double r = stdev * std::sqrt(-2.0 * std::log(u1)); + dst[i] = + static_cast(std::round(r * std::cos(2.0 * utils::REAL_PI * u2))); + } + } + + void sampleHwtInt8Array(i8 *dst, size_t len, int hwt) override { + // Fill with zeros, place +1/-1 at `hwt` random positions. + std::memset(dst, 0, len); + + // Build index array and shuffle first `hwt` positions (Fisher-Yates). + std::vector indices(len); + std::iota(indices.begin(), indices.end(), 0); + + for (int i = 0; i < hwt; ++i) { + u64 j_raw; + fillBytes(reinterpret_cast(&j_raw), sizeof(j_raw)); + size_t j = i + static_cast(j_raw % (len - i)); + std::swap(indices[i], indices[j]); + + // Random sign: use one bit. + u8 sign_byte; + fillBytes(&sign_byte, 1); + dst[indices[i]] = (sign_byte & 1) ? 1 : -1; + } + } + + void reseed(const u8 *seed, size_t seed_len) override { + deriveSeedMaterial(seed, seed_len); + } + +private: + void deriveSeedMaterial(const u8 *seed, size_t seed_len) { + // BLAKE2b-512 -> first 32 bytes = key, next 24 bytes = initial nonce. + u8 hash[64]; + crypto_generichash(hash, sizeof(hash), seed, seed_len, nullptr, 0); + std::memcpy(key_, hash, crypto_stream_xchacha20_KEYBYTES); + std::memset(nonce_, 0, sizeof(nonce_)); + std::memcpy(nonce_, hash + 32, + std::min(24, sizeof(hash) - 32)); + sodium_memzero(hash, sizeof(hash)); + } + + void fillBytes(u8 *buf, size_t buflen) { + crypto_stream_xchacha20(buf, buflen, nonce_, key_); + incrementNonce(); + } + + void incrementNonce() { sodium_increment(nonce_, sizeof(nonce_)); } + + u64 uniformUint64InRange(u64 range) { + // Rejection sampling for uniform distribution in [0, range). + if (range <= 1) + return 0; + u64 limit = (UINT64_MAX / range) * range; + u64 val; + do { + fillBytes(reinterpret_cast(&val), sizeof(val)); + } while (val >= limit); + return val % range; + } + + void uniformDouble01(double &a, double &b) { + // Generate two doubles in (0, 1] using 52-bit mantissa. + u64 raw[2]; + fillBytes(reinterpret_cast(raw), sizeof(raw)); + a = static_cast((raw[0] >> 12) + 1) / 4503599627370496.0; + b = static_cast((raw[1] >> 12) + 1) / 4503599627370496.0; + } + + u8 key_[crypto_stream_xchacha20_KEYBYTES]; + u8 nonce_[crypto_stream_xchacha20_NONCEBYTES]; +}; + +} // namespace deb diff --git a/external/CMakeLists.txt b/external/CMakeLists.txt index 870b5d9..3bc3494 100644 --- a/external/CMakeLists.txt +++ b/external/CMakeLists.txt @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ cpmaddpackage( GITHUB_REPOSITORY CryptoLabInc/alea GIT_TAG - v0.1.6 + v0.1.7 OPTIONS "ALEA_BUILD_TEST OFF" "ALEA_BUILD_DOXYGEN ${DEB_BUILD_DOXYGEN}" @@ -44,3 +44,14 @@ cpmaddpackage( set(flatbuffers_SOURCE_DIR ${flatbuffers_SOURCE_DIR} CACHE PATH "" FORCE) + +string(TOUPPER "${DEB_EXT_LIB_FOR_SECURE_ZERO}" _deb_secure_zero_backend) + +if(_deb_secure_zero_backend STREQUAL "LIBSODIUM" OR DEB_BUILD_EXAMPLES) + set(SODIUM_DISABLE_TESTS ON) + cpmaddpackage(NAME sodium GITHUB_REPOSITORY robinlinden/libsodium-cmake + GIT_TAG master) +endif() +if(_deb_secure_zero_backend STREQUAL "OPENSSL") + find_package(OpenSSL REQUIRED COMPONENTS Crypto) +endif() diff --git a/include/deb/CKKSTypes.hpp b/include/deb/CKKSTypes.hpp index ba89364..a763d55 100644 --- a/include/deb/CKKSTypes.hpp +++ b/include/deb/CKKSTypes.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,8 @@ #pragma once -#include "Context.hpp" +#include "Preset.hpp" #include "SeedGenerator.hpp" -#include "Types.hpp" #include #include @@ -34,7 +33,7 @@ namespace deb { */ template using ComplexT = std::complex; /** - * @brief Default complex type using @ref Real precision. + * @brief Default complex type using @ref Real (double) precision. */ using Complex = ComplexT; @@ -55,20 +54,9 @@ template class MessageBase { */ explicit MessageBase(const Preset preset) { if constexpr (EncodeT == EncodingType::SLOT) { - data_.resize(getContext(preset)->get_num_slots()); + data_.resize(get_num_slots(preset)); } else if constexpr (EncodeT == EncodingType::COEFF) { - data_.resize(getContext(preset)->get_degree()); - } - } - /** - * @brief Allocates storage based on a shared context. - * @param context Shared context that exposes preset metadata. - */ - explicit MessageBase(const Context &context) { - if constexpr (EncodeT == EncodingType::SLOT) { - data_.resize(context->get_num_slots()); - } else if constexpr (EncodeT == EncodingType::COEFF) { - data_.resize(context->get_degree()); + data_.resize(get_degree(preset)); } } /** @@ -126,7 +114,9 @@ template using CoeffMessageImpl = MessageBase; using Message = MessageImpl; +using FMessage = MessageImpl; using CoeffMessage = CoeffMessageImpl; +using FCoeffMessage = CoeffMessageImpl; #define DECL_MESSAGE_TEMPLATE(encode_t, data_t, prefix) \ prefix template MessageBase::MessageBase( \ @@ -148,7 +138,9 @@ using CoeffMessage = CoeffMessageImpl; #define MESSAGE_TYPE_TEMPLATE(prefix) \ DECL_MESSAGE_TEMPLATE(EncodingType::SLOT, ComplexT, prefix) \ - DECL_MESSAGE_TEMPLATE(EncodingType::COEFF, Real, prefix) + DECL_MESSAGE_TEMPLATE(EncodingType::SLOT, ComplexT, prefix) \ + DECL_MESSAGE_TEMPLATE(EncodingType::COEFF, Real, prefix) \ + DECL_MESSAGE_TEMPLATE(EncodingType::COEFF, float, prefix) MESSAGE_TYPE_TEMPLATE(extern) @@ -163,20 +155,20 @@ class PolyUnit { * @brief Initializes the unit for a preset at a specific modulus level. * @param preset Preset describes modulus chain metadata. * @param level Target modulus index. + * @param alloc True to allocate storage, false to create an zero-allocated + * object. */ - explicit PolyUnit(const Preset preset, const Size level); - /** - * @brief Initializes the unit using a shared context. - * @param context Shared context that exposes metadata. - * @param level Target modulus index. - */ - explicit PolyUnit(const Context &context, const Size level); + explicit PolyUnit(const Preset preset, const Size level, + const bool alloc = true); + /** * @brief Constructs a unit with explicit modulus and degree configuration. * @param prime Prime modulus value. * @param degree Number of coefficients. + * @param alloc True to allocate storage, false to create an zero-allocated + * object. */ - explicit PolyUnit(u64 prime, Size degree); + explicit PolyUnit(u64 prime, Size degree, const bool alloc = true); /** * @brief Creates a full copy of the unit including coefficient storage. @@ -208,16 +200,16 @@ class PolyUnit { * @brief Mutable coefficient accessor without bounds checks. * @param index Coefficient index. */ - u64 &operator[](Size index) noexcept; + u64 &operator[](Size index) noexcept { return data_ptr_.get()[index]; } /** * @brief Const coefficient accessor without bounds checks. * @param index Coefficient index. */ - u64 operator[](Size index) const noexcept; + u64 operator[](Size index) const noexcept { return data_ptr_.get()[index]; } /** * @brief Returns a mutable pointer to coefficient storage. */ - u64 *data() const noexcept; + u64 *data() const noexcept { return data_ptr_.get(); } /** * @brief Sets an externally managed storage buffer. * @param new_data Pointer to caller-managed coefficients. @@ -228,7 +220,8 @@ class PolyUnit { private: u64 prime_; bool ntt_state_; - std::shared_ptr> data_; + Size degree_; + std::shared_ptr data_ptr_; }; /** @@ -246,19 +239,12 @@ class Polynomial { * false to allocate only the default encryption level. */ explicit Polynomial(const Preset preset, const bool full_level = false); - /** - * @brief Constructs with a shared context handle. - * @param context Shared context that exposes metadata. - * @param full_level True to allocate every modulus level, - * false to allocate only the default encryption level. - */ - explicit Polynomial(Context context, const bool full_level = false); /** * @brief Constructs with a custom number of PolyUnit entries. - * @param context Shared context that exposes metadata. + * @param preset Preset that describes modulus chain metadata. * @param custom_size Number of PolyUnit slots. */ - explicit Polynomial(Context context, const Size custom_size); + explicit Polynomial(const Preset preset, const Size custom_size); /** * @brief Copies slices of another polynomial. * @param other Source polynomial. @@ -301,23 +287,32 @@ class Polynomial { * @brief Mutable PolyUnit accessor. * @param index PolyUnit index. */ - PolyUnit &operator[](size_t index) noexcept; + PolyUnit &operator[](size_t index) noexcept { return polyunits_[index]; } /** * @brief Read-only PolyUnit accessor. * @param index PolyUnit index. */ - const PolyUnit &operator[](size_t index) const noexcept; + const PolyUnit &operator[](size_t index) const noexcept { + return polyunits_[index]; + } /** * @brief Mutable pointer to the first PolyUnit. */ - PolyUnit *data() noexcept; + PolyUnit *data() noexcept { return polyunits_.data(); } /** * @brief Const pointer to the first PolyUnit. */ - const PolyUnit *data() const noexcept; + const PolyUnit *data() const noexcept { return polyunits_.data(); } private: - std::vector data_; + std::vector polyunits_; + + /** + * @brief Consecutive data pointer to allocated and deallocated by this + * Polynomial. If nullptr, the data is managed externally (or by polyunits_ + * themselves). + */ + std::shared_ptr dealloc_ptr_; }; /** @@ -330,10 +325,6 @@ class Ciphertext { * @brief Allocates ciphertext metadata for a preset with default level. */ explicit Ciphertext(const Preset preset); - /** - * @brief Allocates ciphertext metadata using a shared context. - */ - explicit Ciphertext(Context context); /** * @brief Allocates ciphertext with explicit level and number of * polynomials. @@ -343,14 +334,6 @@ class Ciphertext { */ explicit Ciphertext(const Preset preset, const Size level, std::optional num_poly = std::nullopt); - /** - * @brief Context-based overload selecting level and size. - * @param context Shared context that exposes metadata. - * @param level Target modulus index. - * @param num_poly Optional number of component polynomials. - */ - explicit Ciphertext(Context context, const Size level, - std::optional num_poly = std::nullopt); /** * @brief Copies a subset of another ciphertext. * @param other Source ciphertext. @@ -413,20 +396,22 @@ class Ciphertext { * @brief Mutable polynomial accessor. * @param index Polynomial index. */ - Polynomial &operator[](size_t index) noexcept; + Polynomial &operator[](size_t index) noexcept { return polys_[index]; } /** * @brief Const polynomial accessor. * @param index Polynomial index. */ - const Polynomial &operator[](size_t index) const noexcept; + const Polynomial &operator[](size_t index) const noexcept { + return polys_[index]; + } /** * @brief Mutable pointer to polynomial storage. */ - Polynomial *data() noexcept; + Polynomial *data() noexcept { return polys_.data(); } /** * @brief Const pointer to polynomial storage. */ - const Polynomial *data() const noexcept; + const Polynomial *data() const noexcept { return polys_.data(); } private: Preset preset_; @@ -439,7 +424,21 @@ class Ciphertext { */ class SecretKey { public: + /** + * @brief Default constructor is deleted to prevent accidental creation of + * sensitive key material. + */ SecretKey() = delete; + /** + * @brief Copy constructor is deleted to prevent accidental copying of + * sensitive key material. + */ + SecretKey(const SecretKey &other) = delete; + /** + * @brief Copy assignment operator is deleted to prevent accidental copying + * of sensitive key material. + */ + SecretKey &operator=(const SecretKey &other) = delete; /** * @brief Deterministic constructor from preset and PRNG seed. * @param preset Preset that describes modulus chain metadata. @@ -453,6 +452,13 @@ class SecretKey { */ explicit SecretKey(Preset preset, bool embedding = true); + SecretKey(SecretKey &&) noexcept; + SecretKey &operator=(SecretKey &&) noexcept; + + /** + * @brief Destructor that securely zeroes all sensitive key material. + */ + ~SecretKey() noexcept; /** * @brief Preset used to generate this key. */ @@ -474,7 +480,20 @@ class SecretKey { * @brief Removes the stored seed to prevent future regenerations. */ void flushSeed() noexcept; - + /** + * @brief Securely zeroes all sensitive key material in place. + * + * Zeros @ref coeffs_, the stored @ref seed_, and every PolyUnit + * coefficient buffer inside @ref polys_ using the secure-zero backend + * selected by @c DEB_EXT_LIB_FOR_SECURE_ZERO at build time: + * - @c LIBSODIUM → @c sodium_memzero + * - @c OPENSSL → @c OPENSSL_cleanse + * - @c NATIVE → @c explicit_bzero / @c SecureZeroMemory / volatile + * byte loop + * - @c NONE → @c memset (not secure, included for testing purposes + * only) + */ + void zeroize() noexcept; /** * @brief Number of raw coefficients currently allocated. */ @@ -501,7 +520,6 @@ class SecretKey { * @brief Const pointer to coefficient array. */ const i8 *coeffs() const noexcept; - /** * @brief Number of stored polynomial components. */ @@ -515,20 +533,20 @@ class SecretKey { * @brief Mutable polynomial accessor. * @param index Polynomial index. */ - Polynomial &operator[](Size index); + Polynomial &operator[](Size index) { return polys_[index]; } /** * @brief Const polynomial accessor. * @param index Polynomial index. */ - const Polynomial &operator[](Size index) const; + const Polynomial &operator[](Size index) const { return polys_[index]; } /** * @brief Mutable pointer to polynomial array. */ - Polynomial *data() noexcept; + Polynomial *data() noexcept { return polys_.data(); } /** * @brief Const pointer to polynomial array. */ - const Polynomial *data() const noexcept; + const Polynomial *data() const noexcept { return polys_.data(); } private: Preset preset_; @@ -551,15 +569,6 @@ class SwitchKey { */ explicit SwitchKey(Preset preset, const SwitchKeyKind type, const std::optional rot_idx = std::nullopt); - /** - * @brief Constructs a switching key from a context and key kind. - * @param context Shared context that exposes metadata. - * @param type SwitchKeyKind (SWK_MULT, SWK_ROT, etc). - * @param rot_idx Optional rotation index. - */ - explicit SwitchKey(const Context &context, const SwitchKeyKind type, - const std::optional rot_idx = std::nullopt); - /** * @brief Returns the preset metadata for this key. */ @@ -689,7 +698,7 @@ class SwitchKey { * @throws std::out_of_range if indices are invalid. */ inline u64 *getData(const Ciphertext &cipher, const Size polyunit_idx, - const Size poly_idx = 0) { + const Size poly_idx) { if (poly_idx >= cipher.numPoly() || polyunit_idx >= cipher[poly_idx].size()) { throw std::out_of_range("Index out of range in getData"); @@ -697,13 +706,13 @@ inline u64 *getData(const Ciphertext &cipher, const Size polyunit_idx, return cipher[poly_idx][polyunit_idx].data(); } -inline u64 *getData(const Polynomial &poly, const Size polyunit_idx = 0) { +inline u64 *getData(const Polynomial &poly, const Size polyunit_idx) { if (polyunit_idx >= poly.size()) { throw std::out_of_range("Index out of range in getData"); } return poly[polyunit_idx].data(); } -inline u64 getData(const u64 *data, const Size idx = 0) { return data[idx]; } +inline u64 getData(const u64 *data, const Size idx) { return data[idx]; } } // namespace deb diff --git a/include/deb/Context.hpp b/include/deb/Context.hpp deleted file mode 100644 index 90bb341..0000000 --- a/include/deb/Context.hpp +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Copyright 2025 CryptoLab, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "Types.hpp" - -#include "DebParam.hpp" -#include "Macro.hpp" -#include "utils/Span.hpp" - -#include -#include - -// Define preset values and precomputed values from preset values. -#define CONST_LIST \ - CV(Preset, preset) \ - CV(Preset, parent) \ - CV(const char *, preset_name) \ - CV(Size, rank) \ - CV(Size, num_secret) \ - CV(Size, log_degree) \ - CV(Size, degree) \ - CV(Size, num_slots) \ - CV(Size, gadget_rank) \ - CV(Size, num_base) \ - CV(Size, num_qp) \ - CV(Size, num_tp) \ - CV(Size, num_p) \ - CV(Size, encryption_level) \ - CV(Size, hamming_weight) \ - CV(Real, gaussian_error_stdev) \ - CV(const u64 *, primes) \ - CV(const Real *, scale_factors) - -namespace deb { - -/** - * @brief Compile-time context wrapper exposing preset constants via accessors. - */ -template struct ContextT : public PRESET { -#define CV(type, var_name) \ - static constexpr type get_##var_name() { return PRESET::var_name; } - CONST_LIST -#undef CV -}; - -/** - * @brief Variant capable of holding any preset-specific context view. - */ -using VariantCtx = std::variant< -#define X(PRESET) ContextT, - PRESET_LIST -#undef X - ContextT>; - -/** - * @brief Runtime context content that dispatches to preset-specific variants. - */ -struct ContextContent { - VariantCtx v; -#define CV(type, var_name) \ - constexpr type get_##var_name() const { \ - return std::visit( \ - [](auto &&ctx) -> type { return ctx.get_##var_name(); }, v); \ - } - CONST_LIST -#undef CV -}; - -/** - * @brief Shared pointer alias referencing runtime context data. - */ -using Context = std::shared_ptr; - -// Singleton ContextPool to manage Context instances -/** - * @brief Provides singleton access to preset contexts. - */ -class ContextPool { -public: - /** - * @brief Accesses the singleton context pool. - * @return Reference to the singleton instance. - */ - static ContextPool &GetInstance() { - static ContextPool instance; - return instance; - } - - /** - * @brief Retrieves the shared context for a preset. - * @param preset Requested preset. - * @return Shared context pointer. - * @throws std::runtime_error When the preset is unknown. - */ - Context get(Preset preset) { - if (auto it = map_.find(preset); it != map_.end()) { - return it->second; - } - throw std::runtime_error("Preset not found in ContextPool"); - } - -private: - ContextPool() { -#define X(PRESET) \ - map_[PRESET_##PRESET] = std::make_shared( \ - ContextContent{VariantCtx{ContextT{}}}); - PRESET_LIST -#undef X - } - std::unordered_map map_; -}; - -/** - * @brief Retrieves the shared context for a preset. - * @param preset Requested preset. - * @return Shared context pointer. - */ -Context getContext(Preset preset); - -/** - * @brief Checks whether a preset enum value is supported. - * @param preset Preset to validate. - * @return True if the preset exists. - */ -bool isValidPreset(Preset preset); - -/** - * @brief Sets an OpenMP thread limit for the current process. - * @param max_threads Maximum number of threads; implementation-defined. - */ -void setOmpThreadLimit(int max_threads); -/** - * @brief Removes any OpenMP thread limit previously applied. - */ -void unsetOmpThreadLimit(); - -} // namespace deb diff --git a/include/deb/Decryptor.hpp b/include/deb/Decryptor.hpp index a13c313..7838e61 100644 --- a/include/deb/Decryptor.hpp +++ b/include/deb/Decryptor.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,8 @@ #pragma once #include "CKKSTypes.hpp" -#include "Context.hpp" #include "utils/FFT.hpp" -#include "utils/ModArith.hpp" +#include "utils/PresetTraits.hpp" #include @@ -29,14 +28,19 @@ namespace deb { /** * @brief Provides CKKS decryption and decoding utilities. */ -class Decryptor { +template class DecryptorT : public PresetTraits

{ +#define CV(type, var_name) using PresetTraits

::var_name; + CONST_LIST +#undef CV + using PresetTraits

::modarith; + public: /** * @brief Creates a decryptor for the given preset. * @param preset Target preset that defines polynomial sizes and moduli. */ - explicit Decryptor(const Preset preset); - // explicit Encryptor(const deb_shared_context_t &context); + explicit DecryptorT(); + explicit DecryptorT(const Preset preset); template >, int> = 0> @@ -76,41 +80,74 @@ class Decryptor { */ void decrypt(const Ciphertext &ctxt, const SecretKey &sk, std::vector &msg, Real scale = 0) const { - deb_assert(msg.size() == context_->get_num_secret(), + deb_assert(msg.size() == num_secret, "[Decryptor::decrypt] Message size mismatch"); decrypt(ctxt, sk, msg.data(), scale); } private: Polynomial - innerDecrypt(const Ciphertext &ctxt, const SecretKey &sk, + innerDecrypt(const Ciphertext &ctxt, const Polynomial &sx, const std::optional &ax = std::nullopt) const; - void decodeWithSinglePoly(const Polynomial &ptxt, CoeffMessage &coeff, + template + void decodeWithSinglePoly(const Polynomial &ptxt, CMSG &coeff, Real scale) const; - void decodeWithPolyPair(const Polynomial &ptxt, CoeffMessage &coeff, + template + void decodeWithPolyPair(const Polynomial &ptxt, CMSG &coeff, Real scale) const; - void decodeWithoutFFT(const Polynomial &ptxt, CoeffMessage &coeff, + template + void decodeWithoutFFT(const Polynomial &ptxt, CMSG &coeff, Real scale) const; - void decode(const Polynomial &ptxt, Message &msg, Real scale) const; + template + void decode(const Polynomial &ptxt, MSG &msg, Real scale) const; - Context context_; - // TODO: move to Context - std::vector modarith_; - utils::FFTImpl fft_; + utils::FFT fft_; }; -#define DECL_DECRYPT_TEMPLATE_MSG(msg_t, prefix) \ - prefix template void Decryptor::decrypt( \ +using Decryptor = DecryptorT<>; + +#define DECL_DECRYPT_TEMPLATE_MSG(preset, msg_t, prefix) \ + prefix template void DecryptorT::decrypt( \ const Ciphertext &ctxt, const SecretKey &sk, msg_t &msg, Real scale) \ const; \ - prefix template void Decryptor::decrypt( \ + prefix template void DecryptorT::decrypt( \ const Ciphertext &ctxt, const SecretKey &sk, msg_t *msg, Real scale) \ - const; + const; \ + prefix template void DecryptorT::decrypt( \ + const Ciphertext &ctxt, const SecretKey &sk, std::vector &msg, \ + Real scale) const; + +#define DECL_DECRYPT_TEMPLATE_DECODE(preset, prefix) \ + prefix template void \ + DecryptorT::decodeWithSinglePoly( \ + const Polynomial &ptxt, CoeffMessage &coeff, Real scale) const; \ + prefix template void \ + DecryptorT::decodeWithSinglePoly( \ + const Polynomial &ptxt, FCoeffMessage &coeff, Real scale) const; \ + prefix template void DecryptorT::decodeWithPolyPair( \ + const Polynomial &ptxt, CoeffMessage &coeff, Real scale) const; \ + prefix template void \ + DecryptorT::decodeWithPolyPair( \ + const Polynomial &ptxt, FCoeffMessage &coeff, Real scale) const; \ + prefix template void DecryptorT::decodeWithoutFFT( \ + const Polynomial &ptxt, CoeffMessage &coeff, Real scale) const; \ + prefix template void DecryptorT::decodeWithoutFFT( \ + const Polynomial &ptxt, FCoeffMessage &coeff, Real scale) const; \ + prefix template void DecryptorT::decode( \ + const Polynomial &ptxt, Message &msg, Real scale) const; \ + prefix template void DecryptorT::decode( \ + const Polynomial &ptxt, FMessage &msg, Real scale) const; -#define DECRYPT_TYPE_TEMPLATE(prefix) \ - DECL_DECRYPT_TEMPLATE_MSG(Message, prefix) \ - DECL_DECRYPT_TEMPLATE_MSG(CoeffMessage, prefix) +#define DECRYPT_TYPE_TEMPLATE(preset, prefix) \ + prefix template class DecryptorT; \ + DECL_DECRYPT_TEMPLATE_MSG(preset, Message, prefix) \ + DECL_DECRYPT_TEMPLATE_MSG(preset, FMessage, prefix) \ + DECL_DECRYPT_TEMPLATE_MSG(preset, CoeffMessage, prefix) \ + DECL_DECRYPT_TEMPLATE_MSG(preset, FCoeffMessage, prefix) \ + DECL_DECRYPT_TEMPLATE_DECODE(preset, prefix) -DECRYPT_TYPE_TEMPLATE(extern) +#define X(preset) DECRYPT_TYPE_TEMPLATE(PRESET_##preset, extern) +PRESET_LIST_WITH_EMPTY +#undef X } // namespace deb diff --git a/include/deb/Encryptor.hpp b/include/deb/Encryptor.hpp index 0722430..8e709e8 100644 --- a/include/deb/Encryptor.hpp +++ b/include/deb/Encryptor.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,14 +17,11 @@ #pragma once #include "CKKSTypes.hpp" -#include "Constant.hpp" -#include "Context.hpp" -#include "SeedGenerator.hpp" #include "utils/Basic.hpp" +#include "utils/Constant.hpp" #include "utils/FFT.hpp" -#include "utils/ModArith.hpp" - -#include "alea/alea.h" +#include "utils/PresetTraits.hpp" +#include "utils/RandomGenerator.hpp" #include #include @@ -38,8 +35,8 @@ namespace deb { * @brief Configures optional behaviors for encryption routines. */ struct EncryptOptions { - Real scale = 0; /**< Requested plaintext scale (0 = auto). */ - Size level = DEB_MAX_SIZE; /**< Encryption level override. */ + Real scale = 0; /**< Requested plaintext scale (0 = auto). */ + Size level = utils::DEB_MAX_SIZE; /**< Encryption level override. */ bool ntt_out = true; /**< Whether ciphertext output stays in NTT form. */ /** * @brief Sets the desired scale value. @@ -72,20 +69,31 @@ struct EncryptOptions { [[maybe_unused]] static EncryptOptions default_opt; -// TODO: make template for Encryptor -// to support constexpr functions with various presets /** * @brief Provides CKKS encoding and encryption routines. */ -class Encryptor { +template class EncryptorT : public PresetTraits

{ +#define CV(type, var_name) using PresetTraits

::var_name; + CONST_LIST +#undef CV + using PresetTraits

::modarith; + public: /** * @brief Constructs an encryptor bound to a preset and optional RNG seed. * @param preset Target preset. * @param seeds Optional deterministic seed. */ - explicit Encryptor(const Preset preset, - std::optional seeds = std::nullopt); + explicit EncryptorT(std::optional seeds = std::nullopt); + explicit EncryptorT(Preset actual_preset, + std::optional seeds = std::nullopt); + /** + * @brief Constructs an encryptor with a custom random generator. + * @param actual_preset Target preset. + * @param rng Custom random generator instance. + */ + explicit EncryptorT(Preset actual_preset, + std::shared_ptr rng); template >, int> = 0> @@ -128,10 +136,7 @@ class Encryptor { void innerEncrypt([[maybe_unused]] const Polynomial &ptxt, [[maybe_unused]] const KEY &key, [[maybe_unused]] Size num_polyunit, - [[maybe_unused]] Ciphertext &ctxt) const { - throw std::runtime_error( - "Encryptor::innerEncrypt: Not implemented for this key type"); - } + [[maybe_unused]] Ciphertext &ctxt) const; template void embeddingToN(const MSG &msg, const Real &delta, Polynomial &ptxt, @@ -143,45 +148,60 @@ class Encryptor { void sampleZO(const Size num_polyunit) const; - void sampleGaussian(const Size idx, const Size num_polyunit, - const bool do_ntt) const; + void sampleGaussian(const Size num_polyunit, const bool do_ntt) const; - Context context_; - std::shared_ptr as_; + std::shared_ptr rng_; // compute buffers mutable Polynomial ptxt_buffer_; mutable Polynomial vx_buffer_; - mutable std::vector ex_buffers_; + mutable Polynomial ex_buffer_; + mutable std::vector samples_; + mutable std::vector mask_; + mutable std::vector i_samples_; - // TODO: move to Context - std::vector modarith_; - utils::FFTImpl fft_; + utils::FFT fft_; }; +using Encryptor = EncryptorT<>; + // NOLINTBEGIN -#define DECL_ENCRYPT_TEMPLATE_MSG_KEY(msg_t, key_t, prefix) \ - prefix template void Encryptor::encrypt( \ +#define DECL_ENCRYPT_TEMPLATE_MSG_KEY(preset, msg_t, key_t, prefix) \ + prefix template void EncryptorT::encrypt( \ const msg_t &msg, const key_t &key, Ciphertext &ctxt, \ const EncryptOptions &opt) const; \ - prefix template void Encryptor::encrypt( \ + prefix template void EncryptorT::encrypt( \ const std::vector &msg, const key_t &key, Ciphertext &ctxt, \ const EncryptOptions &opt) const; \ - prefix template void Encryptor::encrypt( \ + prefix template void EncryptorT::encrypt( \ const msg_t *msg, const key_t &key, Ciphertext &ctxt, \ const EncryptOptions &opt) const; -#define DECL_ENCRYPT_TEMPLATE_MSG(msg_t, prefix) \ - DECL_ENCRYPT_TEMPLATE_MSG_KEY(msg_t, SecretKey, prefix) \ - DECL_ENCRYPT_TEMPLATE_MSG_KEY(msg_t, SwitchKey, prefix) \ - prefix template void Encryptor::embeddingToN( \ +#define DECL_ENCRYPT_TEMPLATE_MSG(preset, msg_t, prefix) \ + DECL_ENCRYPT_TEMPLATE_MSG_KEY(preset, msg_t, SecretKey, prefix) \ + DECL_ENCRYPT_TEMPLATE_MSG_KEY(preset, msg_t, SwitchKey, prefix) \ + prefix template void EncryptorT::embeddingToN( \ const msg_t &msg, const Real &delta, Polynomial &ptxt, \ const Size size) const; \ - prefix template void Encryptor::encodeWithoutNTT( \ + prefix template void EncryptorT::encodeWithoutNTT( \ const msg_t &msg, Polynomial &ptxt, const Size size, const Real scale) \ const; + +#define DECL_ENCRYPT_TEMPLATE(preset, prefix) \ + prefix template class EncryptorT; \ + DECL_ENCRYPT_TEMPLATE_MSG(preset, Message, prefix) \ + DECL_ENCRYPT_TEMPLATE_MSG(preset, FMessage, prefix) \ + DECL_ENCRYPT_TEMPLATE_MSG(preset, CoeffMessage, prefix) \ + DECL_ENCRYPT_TEMPLATE_MSG(preset, FCoeffMessage, prefix) \ + prefix template void EncryptorT::innerEncrypt( \ + const Polynomial &ptxt, const SecretKey &key, const Size num_polyunit, \ + Ciphertext &ctxt) const; \ + prefix template void EncryptorT::innerEncrypt( \ + const Polynomial &ptxt, const SwitchKey &key, const Size num_polyunit, \ + Ciphertext &ctxt) const; // NOLINTEND -DECL_ENCRYPT_TEMPLATE_MSG(Message, extern) -DECL_ENCRYPT_TEMPLATE_MSG(CoeffMessage, extern) +#define X(preset) DECL_ENCRYPT_TEMPLATE(PRESET_##preset, extern) +PRESET_LIST_WITH_EMPTY +#undef X } // namespace deb diff --git a/include/deb/KeyGenerator.hpp b/include/deb/KeyGenerator.hpp index 7c5fd6b..954dd4a 100644 --- a/include/deb/KeyGenerator.hpp +++ b/include/deb/KeyGenerator.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,11 +17,9 @@ #pragma once #include "CKKSTypes.hpp" -#include "Context.hpp" #include "utils/FFT.hpp" -#include "utils/ModArith.hpp" - -#include "alea/alea.h" +#include "utils/PresetTraits.hpp" +#include "utils/RandomGenerator.hpp" #include #include @@ -32,27 +30,34 @@ namespace deb { /** * @brief Generates an encryption key and switching keys for CKKS presets. */ -class KeyGenerator { +template +class KeyGeneratorT : public PresetTraits

{ +#define CV(type, var_name) using PresetTraits

::var_name; + CONST_LIST +#undef CV + using PresetTraits

::modarith; + public: /** * @brief Builds a key generator for a preset when no secret key is - * provided. + * provided. An external secret key must be given for key generation calls. * @param preset Target preset whose parameters drive key sizes. - * @param seeds Optional deterministic RNG seed material. - */ - explicit KeyGenerator(const Preset preset, - std::optional seeds = std::nullopt); - /** - * @brief Builds a key generator around an existing secret key. - * @param sk Secret key that serves as the default source of secret data. * @param seeds Optional deterministic RNG seed material used when new * samples are required. */ - explicit KeyGenerator(const SecretKey &sk, - std::optional seeds = std::nullopt); + explicit KeyGeneratorT(std::optional seeds = std::nullopt); + explicit KeyGeneratorT(const Preset preset, + std::optional seeds = std::nullopt); + /** + * @brief Builds a key generator with a custom random generator. + * @param preset Target preset whose parameters drive key sizes. + * @param rng Custom random generator instance. + */ + explicit KeyGeneratorT(const Preset preset, + std::shared_ptr rng); - KeyGenerator(const KeyGenerator &) = delete; - ~KeyGenerator() = default; + KeyGeneratorT(const KeyGeneratorT &) = delete; + ~KeyGeneratorT() = default; /** * @brief Generates a switching key that maps one polynomial basis to @@ -69,118 +74,114 @@ class KeyGenerator { const Size bx_size = 0) const; /** - * @brief Generates a fresh encryption key. - * @param sk Optional override secret key; if empty the internally managed - * key is used. + * @brief Generates an encryption key. + * @param sk Secret key to generate public key. * @return Newly created encryption key. */ - SwitchKey genEncKey(std::optional sk = std::nullopt) const; + SwitchKey genEncKey(const SecretKey &sk) const; /** - * @brief Generates an encryption key in-place. + * @brief Generates an encryption key directly into an existing object. * @param enckey Output storage for encryption key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ - void genEncKeyInplace(SwitchKey &enckey, - std::optional sk = std::nullopt) const; + void genEncKeyInplace(SwitchKey &enckey, const SecretKey &sk) const; /** * @brief Generates a multiplication key used for ciphertext-ciphertext * products. - * @param sk Optional override secret key. - * @return Switching key specialized for multiplication. + * @param sk Secret key to generate public key. + * @return Multiplication key. */ - SwitchKey genMultKey(std::optional sk = std::nullopt) const; + SwitchKey genMultKey(const SecretKey &sk) const; /** - * @brief Generates a multiplication key in-place. + * @brief Generates a multiplication key directly into an existing object. * @param mulkey Output storage for multiplication key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ - void genMultKeyInplace(SwitchKey &mulkey, - std::optional sk = std::nullopt) const; + void genMultKeyInplace(SwitchKey &mulkey, const SecretKey &sk) const; /** * @brief Generates a conjugation key for complex conjugate operations. - * @param sk Optional override secret key. - * @return Switching key for conjugation. + * @param sk Secret key to generate public key. + * @return Conjugation key. */ - SwitchKey genConjKey(std::optional sk = std::nullopt) const; + SwitchKey genConjKey(const SecretKey &sk) const; /** - * @brief Generates a conjugation key in-place. + * @brief Generates a conjugation key directly into an existing object. * @param conjkey Output storage for conjugation key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ - void genConjKeyInplace(SwitchKey &conjkey, - std::optional sk = std::nullopt) const; + void genConjKeyInplace(SwitchKey &conjkey, const SecretKey &sk) const; /** - * @brief Generates a left rotation key for slot rotations. - * @param rot Rotation step expressed in slots. - * @param sk Optional override secret key. - * @return Switching key bound to the requested rotation. + * @brief Generates a left rotation key for specific rotate operation. + * @param rot Rotation index. + * @param sk Secret key to generate public key. + * @return Left rotation key of rotation index @p rot. */ - SwitchKey genLeftRotKey(const Size rot, - std::optional sk = std::nullopt) const; + SwitchKey genLeftRotKey(const Size rot, const SecretKey &sk) const; /** - * @brief Generates a left rotation key and writes it to an existing - * structure. - * @param rot Rotation step expressed in slots. + * @brief Generates a left rotation key directly into an existing object. + * @param rot Rotation index. * @param rotkey Output storage for left rotation key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ void genLeftRotKeyInplace(const Size rot, SwitchKey &rotkey, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** - * @brief Generates a right rotation key for slot rotations. - * @param rot Rotation step expressed in slots. - * @param sk Optional override secret key. - * @return Switching key bound to the requested rotation. + * @brief Generates a right rotation key for specific rotate operation. + * @param rot Rotation index. + * @param sk Secret key to generate public key. + * @return Right rotation key of rotation index @p rot. */ - SwitchKey genRightRotKey(const Size rot, - std::optional sk = std::nullopt) const; + SwitchKey genRightRotKey(const Size rot, const SecretKey &sk) const; /** - * @brief Generates a right rotation key into an existing object. - * @param rot Rotation step expressed in slots. + * @brief Generates a right rotation key directly into an existing object. + * @param rot Rotation index. * @param rotkey Output storage for right rotation key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ - void - genRightRotKeyInplace(const Size rot, SwitchKey &rotkey, - std::optional sk = std::nullopt) const; + void genRightRotKeyInplace(const Size rot, SwitchKey &rotkey, + const SecretKey &sk) const; /** * @brief Generates an automorphism key identified by the exponent sig. * @param sig The power index of the automorphism. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. * @return Switching key that realizes the automorphism. */ - SwitchKey genAutoKey(const Size sig, - std::optional sk = std::nullopt) const; + SwitchKey genAutoKey(const Size sig, const SecretKey &sk) const; /** - * @brief Generates an automorphism key into an existing object. + * @brief Generates an automorphism key directly into an existing object. * @param sig Automorphism identifier. * @param autokey Output storage for automorphism key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ void genAutoKeyInplace(const Size sig, SwitchKey &autokey, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** - * @brief Generates a composition switch key from an input secret key. + * @brief Generates a composition switch key from an input secret key @p + * sk_from. * @param sk_from Source secret key to be composed into the managed key. * @param sk Optional target secret key override. - * @return Switching key that composes @p sk_from into @p sk. + * @return Composition key from @p sk_from. */ SwitchKey genComposeKey(const SecretKey &sk_from, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Coefficient vector that describes the source secret key. + * @param sk Optional target secret key override. + * @return Composition key from the secret key from @p coeffs. */ SwitchKey genComposeKey(const std::vector coeffs, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Pointer to coefficient data. * @param size Number of coefficients provided. + * @param sk Optional target secret key override. + * @return Composition key from the secret key from @p coeffs. */ SwitchKey genComposeKey(const i8 *coeffs, Size size, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief Generates a composition key directly into an existing object. * @param sk_from Source secret key to be composed. @@ -188,7 +189,7 @@ class KeyGenerator { * @param sk Optional target secret key override. */ void genComposeKeyInplace(const SecretKey &sk_from, SwitchKey &composekey, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Coefficient vector describing the source secret key. @@ -196,8 +197,7 @@ class KeyGenerator { * @param sk Optional target secret key override. */ void genComposeKeyInplace(const std::vector coeffs, - SwitchKey &composekey, - std::optional sk = std::nullopt) const; + SwitchKey &composekey, const SecretKey &sk) const; /** * @brief @overload * @param coeffs Pointer to coefficient data. @@ -206,50 +206,51 @@ class KeyGenerator { * @param sk Optional target secret key override. */ void genComposeKeyInplace(const i8 *coeffs, Size size, - SwitchKey &composekey, - std::optional sk = std::nullopt) const; + SwitchKey &composekey, const SecretKey &sk) const; /** * @brief Generates a decomposition key that maps to the provided target - * secret key. + * secret key @p sk_to. * @param sk_to Destination secret key. * @param sk Optional source secret key override. - * @return Switching key used for decomposition. + * @return Decomposition key maps to @p sk_to. */ SwitchKey genDecomposeKey(const SecretKey &sk_to, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Coefficient vector describing the destination secret key. + * @param sk Optional source secret key override. + * @return Decomposition key maps to the secret key from @p coeffs. */ SwitchKey genDecomposeKey(const std::vector coeffs, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Pointer to coefficient data. * @param coeffs_size Number of coefficients supplied. + * @param sk Optional source secret key override. + * @return Decomposition key maps to the secret key from @p coeffs. */ SwitchKey genDecomposeKey(const i8 *coeffs, Size coeffs_size, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** - * @brief Fills an existing switch key with decomposition data targeted at - * @p sk_to. + * @brief Generates a decomposition key directly into an existing object. * @param sk_to Destination secret key. * @param decompkey Output storage for decomposition key. * @param sk Optional source secret key override. */ - void - genDecomposeKeyInplace(const SecretKey &sk_to, SwitchKey &decompkey, - std::optional sk = std::nullopt) const; + void genDecomposeKeyInplace(const SecretKey &sk_to, SwitchKey &decompkey, + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Destination secret key coefficients. * @param decompkey Output storage for decomposition key. * @param sk Optional source secret key override. */ - void - genDecomposeKeyInplace(const std::vector coeffs, SwitchKey &decompkey, - std::optional sk = std::nullopt) const; + void genDecomposeKeyInplace(const std::vector coeffs, + SwitchKey &decompkey, + const SecretKey &sk) const; /** * @brief @overload * @param coeffs Destination secret key coefficients buffer. @@ -257,82 +258,88 @@ class KeyGenerator { * @param decompkey Output storage for decomposition key. * @param sk Optional source secret key override. */ - void - genDecomposeKeyInplace(const i8 *coeffs, Size coeffs_size, - SwitchKey &decompkey, - std::optional sk = std::nullopt) const; + void genDecomposeKeyInplace(const i8 *coeffs, Size coeffs_size, + SwitchKey &decompkey, + const SecretKey &sk) const; /** * @brief Generates a decomposition key using preset-specific parameters. * @param preset_swk Preset that controls switching key layout. * @param sk_to Destination secret key. * @param sk Optional source secret key override. - * @return Switching key configured for @p preset_swk. + * @return Decomposition key configured for @p preset_swk. */ SwitchKey genDecomposeKey(const Preset preset_swk, const SecretKey &sk_to, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload + * @param preset_swk Preset that controls switching key layout. * @param coeffs Destination secret key coefficients. + * @param sk Optional source secret key override. + * @return Decomposition key configured for @p preset_swk. */ SwitchKey genDecomposeKey(const Preset preset_swk, const std::vector coeffs, - std::optional sk = std::nullopt) const; + const SecretKey &sk) const; /** * @brief @overload + * @param preset_swk Preset that controls switching key layout. * @param coeffs Pointer to coefficient data. * @param coeffs_size Number of coefficients supplied. + * @param sk Optional source secret key override. + * @return Decomposition key configured for @p preset_swk. */ SwitchKey genDecomposeKey(const Preset preset_swk, const i8 *coeffs, - Size coeffs_size, - std::optional sk = std::nullopt) const; + Size coeffs_size, const SecretKey &sk) const; /** - * @brief Fills an existing switch key using preset-specific parameters. + * @brief Generate a decomposition key directly into an existing object + * using preset-specific parameters. * @param preset_swk Preset that controls the generated layout. * @param sk_to Destination secret key. * @param decompkey Output storage for decomposition key. * @param sk Optional source secret key override. */ - void - genDecomposeKeyInplace(const Preset preset_swk, const SecretKey &sk_to, - SwitchKey &decompkey, - std::optional sk = std::nullopt) const; + void genDecomposeKeyInplace(const Preset preset_swk, const SecretKey &sk_to, + SwitchKey &decompkey, + const SecretKey &sk) const; /** * @brief @overload + * @param preset_swk Preset that controls the generated layout. * @param coeffs Destination secret key coefficients. * @param decompkey Output storage for decomposition key. * @param sk Optional source secret key override. */ - void - genDecomposeKeyInplace(const Preset preset_swk, - const std::vector coeffs, SwitchKey &decompkey, - std::optional sk = std::nullopt) const; + void genDecomposeKeyInplace(const Preset preset_swk, + const std::vector coeffs, + SwitchKey &decompkey, + const SecretKey &sk) const; /** * @brief @overload + * @param preset_swk Preset that controls the generated layout. * @param coeffs Pointer to destination secret key coefficients. * @param coeffs_size Number of coefficients supplied. * @param decompkey Output storage for decomposition key. * @param sk Optional source secret key override. */ - void - genDecomposeKeyInplace(const Preset preset_swk, const i8 *coeffs, - Size coeffs_size, SwitchKey &decompkey, - std::optional sk = std::nullopt) const; + void genDecomposeKeyInplace(const Preset preset_swk, const i8 *coeffs, + Size coeffs_size, SwitchKey &decompkey, + const SecretKey &sk) const; /** * @brief Generates a bundle of modulus packing keys between two secret * keys. * @param sk_from Source secret key. * @param sk_to Destination secret key. - * @return Vector of switching keys implementing the mod-pack bundle. + * @return Vector of modpack keys from @p sk_from to @p sk_to. */ std::vector genModPackKeyBundle(const SecretKey &sk_from, const SecretKey &sk_to) const; /** - * @brief Populates an existing bundle with modulus packing keys. + * @brief Generate a bundle of modulus packing keys directly into an + * existing object. * @param sk_from Source secret key. * @param sk_to Destination secret key. - * @param key_bundle Output vector to populate. + * @param key_bundle Output storage for modpack key bundle. */ void genModPackKeyBundleInplace(const SecretKey &sk_from, const SecretKey &sk_to, @@ -341,22 +348,20 @@ class KeyGenerator { // For self modpack /** * @brief Generates a modulus packing key for self mod-pack operations. - * @param pad_rank Rank padding parameter. - * @param sk Optional override secret key. - * @return Switching key configured for self mod-pack. + * @param pad_rank Rank parameter, assumed to be padded power of two. + * @param sk Secret key to generate public key. + * @return Modpack keys with @p pad_rank. */ - SwitchKey - genModPackKeyBundle(const Size pad_rank, - std::optional sk = std::nullopt) const; + SwitchKey genModPackKeyBundle(const Size pad_rank, + const SecretKey &sk) const; /** * @brief Generates a self mod-pack key in-place. - * @param pad_rank Rank padding parameter. + * @param pad_rank Rank parameter, assumed to be padded power of two. * @param modkey Output storage for mod-pack key. - * @param sk Optional override secret key. + * @param sk Secret key to generate public key. */ - void genModPackKeyBundleInplace( - const Size pad_rank, SwitchKey &modkey, - std::optional sk = std::nullopt) const; + void genModPackKeyBundleInplace(const Size pad_rank, SwitchKey &modkey, + const SecretKey &sk) const; private: void frobeniusMapInNTT(const Polynomial &op, const i32 pow, @@ -368,16 +373,18 @@ class KeyGenerator { void sampleUniform(Polynomial &poly) const; void computeConst(); - Context context_; - std::optional sk_; - std::shared_ptr as_; + std::shared_ptr rng_; // TODO: move to Context std::vector p_mod_; std::vector hat_q_i_mod_; std::vector hat_q_i_inv_mod_; - std::vector modarith_; utils::FFT fft_; }; +using KeyGenerator = KeyGeneratorT<>; + +#define X(preset) extern template class KeyGeneratorT; +PRESET_LIST_WITH_EMPTY +#undef X } // namespace deb diff --git a/include/deb/Preset.hpp b/include/deb/Preset.hpp new file mode 100644 index 0000000..8415110 --- /dev/null +++ b/include/deb/Preset.hpp @@ -0,0 +1,67 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "DebParam.hpp" +#include "Types.hpp" +#include "utils/Macro.hpp" + +#include +#include +#include + +// Define preset values and precomputed values from preset values. +#define CONST_LIST \ + CV(Preset, preset) \ + CV(Preset, parent) \ + CV(const char *, preset_name) \ + CV(Size, rank) \ + CV(Size, num_secret) \ + CV(Size, log_degree) \ + CV(Size, degree) \ + CV(Size, num_slots) \ + CV(Size, gadget_rank) \ + CV(Size, num_base) \ + CV(Size, num_qp) \ + CV(Size, num_tp) \ + CV(Size, num_p) \ + CV(Size, encryption_level) \ + CV(Size, hamming_weight) \ + CV(Real, gaussian_error_stdev) \ + CV(const u64 *, primes) \ + CV(const Real *, scale_factors) + +namespace deb { + +using PresetVariant = std::variant< +#define X(p) p, + PRESET_LIST +#undef X + EMPTY>; + +inline std::unordered_map preset_map = { +#define X(p) {PRESET_##p, p{}}, + PRESET_LIST_WITH_EMPTY +#undef X +}; + +// Getter functions for constant values from presets. +#define CV(type, var_name) type get_##var_name(Preset preset); +CONST_LIST +#undef CV + +} // namespace deb diff --git a/include/deb/SecretKeyGenerator.hpp b/include/deb/SecretKeyGenerator.hpp index 3cb6a0c..29af5be 100644 --- a/include/deb/SecretKeyGenerator.hpp +++ b/include/deb/SecretKeyGenerator.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,14 +17,9 @@ #pragma once #include "CKKSTypes.hpp" -#include "Constant.hpp" -#include "Context.hpp" -#include "SeedGenerator.hpp" +#include "utils/Constant.hpp" #include "utils/NTT.hpp" - -#include "DebFBType.h" -#include "alea/alea.h" -#include "alea/algorithms.h" +#include "utils/RandomGenerator.hpp" #include #include @@ -34,7 +29,6 @@ namespace deb { -// template /** * @brief Generates secret keys and secret coefficients for CKKS presets. */ @@ -145,7 +139,7 @@ class SecretKeyGenerator { static void GenSecretKeyFromCoeffInplace(SecretKey &sk, const i8 *coeffs); private: - Preset preset_; + const Preset preset_; }; /** diff --git a/include/deb/SeedGenerator.hpp b/include/deb/SeedGenerator.hpp index d045357..e8f9313 100644 --- a/include/deb/SeedGenerator.hpp +++ b/include/deb/SeedGenerator.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,16 @@ #pragma once -#include "Context.hpp" +#include "utils/RandomGenerator.hpp" -#include "alea/alea.h" -#include "alea/algorithms.h" - -#include +#include #include namespace deb { /** - * @brief Number of 64-bit words in a CKKS RNG seed. - */ -constexpr size_t DEB_U64_SEED_SIZE = ALEA_SEED_SIZE_SHAKE256 / sizeof(u64); -/** - * @brief Deterministic seed material shared across RNG utilities. - */ -using RNGSeed = std::array; - -/** - * @brief Converts the library seed format to ALEA's byte-oriented seed. - * @param seed Source seed material. - * @return Pointer suitable for ALEA APIs. - */ -const u8 *to_alea_seed(const RNGSeed &seed); - -/** - * @brief Singleton wrapper over ALEA to provide deterministic RNG streams. + * @brief Singleton wrapper over RandomGenerator to provide deterministic RNG + * streams. */ class SeedGenerator { public: @@ -75,10 +57,10 @@ class SeedGenerator { SeedGenerator(std::optional seeds); /** - * @brief Internal helper that produces a new seed from the ALEA state. + * @brief Internal helper that produces a new seed from the RNG state. */ RNGSeed genSeed(); - std::unique_ptr as_; + std::shared_ptr rng_; }; } // namespace deb diff --git a/include/deb/Serialize.hpp b/include/deb/Serialize.hpp index 366b9d2..593f472 100644 --- a/include/deb/Serialize.hpp +++ b/include/deb/Serialize.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,57 +29,78 @@ namespace deb { template using Vector = flatbuffers::Vector; /** - * @brief Converts a buffer of @ref Complex into a vector usable by FlatBuffers. - * @param data Pointer to complex values. - * @param size Number of elements to convert. + * @brief Converts a double-precision complex into FlatBuffers format. + * @param data Pointer to double-precision complex values. + * @param size Number of elements in @p data. * @return Vector of FlatBuffer-compatible complex values. */ std::vector toComplexVector(const Complex *data, const Size size); /** - * @brief Converts FlatBuffers complex data back into @ref Complex values. - * @param data FlatBuffers vector pointer. + * @brief Converts FlatBuffers double-precision complex data back into @ref + * Complex values. + * @param data FlatBuffers double-precision complex vector pointer. * @return Vector with decoded complex values. */ std::vector toDebComplexVector(const Vector *data); /** - * @brief Converts single-precision complex data into FlatBuffers format. - * @param data Pointer to complex32 values. - * @param size Number of elements. - * @return Vector of FlatBuffer-compatible complex32 values. + * @brief Converts single-precision complex into FlatBuffers format. + * @param data Pointer to single-precision complex values. + * @param size Number of elements in @p data. + * @return Vector of FlatBuffer-compatible complex values. */ std::vector toComplex32Vector(const ComplexT *data, const Size size); /** - * @brief Converts FlatBuffers complex32 data back into @ref Complex values. - * @param data FlatBuffers vector pointer. + * @brief Converts FlatBuffers single-precision complex data back into @ref + * ComplexT values. + * @param data FlatBuffers single-precision complex vector pointer. * @return Vector with decoded complex values. */ -std::vector +std::vector> toDebComplex32Vector(const Vector *data); /** - * @brief Serializes a high-precision slot message into FlatBuffers format. - * @param builder FlatBuffer builder used to allocate the payload. - * @param message Plaintext message container. + * @brief Serializes a double-precision slot message into FlatBuffers format. + * @param builder FlatBuffer builder. + * @param message Double-precision deb message object. * @return Offset into the builder pointing to the serialized object. */ flatbuffers::Offset serializeMessage(flatbuffers::FlatBufferBuilder &builder, const Message &message); /** - * @brief Deserializes a FlatBuffers message into @ref Message. - * @param message FlatBuffers object. - * @return Plaintext message container. + * @brief Deserializes a FlatBuffers double-precision message into @ref Message. + * @param message FlatBuffers double-precision message object. + * @return Deserialized deb format message. */ Message deserializeMessage(const deb_fb::Message *message); /** - * @brief Serializes coefficient-domain plaintexts. + * @brief Serializes a single-precision slot message into FlatBuffers format. * @param builder FlatBuffer builder. - * @param coeff Coefficient-domain message container. + * @param message Single-precision message object. + * @return Offset into the builder pointing to the serialized object. + */ +flatbuffers::Offset +serializeFMessage(flatbuffers::FlatBufferBuilder &builder, + const FMessage &message); + +/** + * @brief Deserializes a FlatBuffers single-precision message into @ref + * FMessage. + * @param message FlatBuffers single-precision message object. + * @return Deserialized deb format message. + */ +FMessage deserializeFMessage(const deb_fb::Message32 *message); + +/** + * @brief Serializes a double-precision coefficient message into FlatBuffers + * format. + * @param builder FlatBuffer builder. + * @param coeff Double-precision coefficient message object. * @return Offset pointing to serialized coefficients. */ flatbuffers::Offset @@ -87,12 +108,32 @@ serializeCoeff(flatbuffers::FlatBufferBuilder &builder, const CoeffMessage &coeff); /** - * @brief Deserializes coefficient-domain plaintexts. - * @param coeff FlatBuffers coefficient object. - * @return Coefficient-domain message container. + * @brief Deserializes a FlatBuffers double-precision coefficient message into + * @ref CoeffMessage. + * @param coeff FlatBuffers double-precision coefficient object. + * @return Deserialized deb format coefficient message. */ CoeffMessage deserializeCoeff(const deb_fb::Coeff *coeff); +/** + * @brief Serializes a single-precision coefficient message into FlatBuffers + * format. + * @param builder FlatBuffer builder. + * @param coeff Single-precision coefficient message object. + * @return Offset pointing to serialized coefficients. + */ +flatbuffers::Offset +serializeFCoeff(flatbuffers::FlatBufferBuilder &builder, + const FCoeffMessage &coeff); + +/** + * @brief Deserializes a FlatBuffers single-precision coefficient message into + * @ref FCoeffMessage. + * @param coeff FlatBuffers single-precision coefficient object. + * @return Deserialized deb format coefficient message. + */ +FCoeffMessage deserializeFCoeff(const deb_fb::Coeff32 *coeff); + /** * @brief Serializes a poly unit into FlatBuffers form. * @param builder FlatBuffer builder. @@ -200,8 +241,12 @@ void appendOffsetToVector(const flatbuffers::Offset &offset, type_vec.push_back(deb_fb::DebUnion_PolyUnit); } else if constexpr (std::is_same_v) { type_vec.push_back(deb_fb::DebUnion_Message); + } else if constexpr (std::is_same_v) { + type_vec.push_back(deb_fb::DebUnion_Message32); } else if constexpr (std::is_same_v) { type_vec.push_back(deb_fb::DebUnion_Coeff); + } else if constexpr (std::is_same_v) { + type_vec.push_back(deb_fb::DebUnion_Coeff32); } else { throw std::runtime_error( "[appendOffsetToVector] Unsupported type for serialization"); @@ -250,8 +295,12 @@ template void serializeToStream(const T &data, std::ostream &os) { builder.Finish(toDeb(builder, serializePolyUnit(builder, data))); } else if constexpr (std::is_same_v) { builder.Finish(toDeb(builder, serializeMessage(builder, data))); + } else if constexpr (std::is_same_v) { + builder.Finish(toDeb(builder, serializeFMessage(builder, data))); } else if constexpr (std::is_same_v) { builder.Finish(toDeb(builder, serializeCoeff(builder, data))); + } else if constexpr (std::is_same_v) { + builder.Finish(toDeb(builder, serializeFCoeff(builder, data))); } else { throw std::runtime_error( "[serializeToStream] Unsupported type for serialization"); @@ -290,7 +339,7 @@ void deserializeFromStream(std::istream &is, T &data, if constexpr (std::is_same_v) { data = deserializeSwk(deb->list()->GetAs(0)); } else if constexpr (std::is_same_v) { - data = deserializeSk(deb->list()->GetAs(0)); + data = std::move(deserializeSk(deb->list()->GetAs(0))); } else if constexpr (std::is_same_v) { data = deserializeCipher(deb->list()->GetAs(0)); } else if constexpr (std::is_same_v) { @@ -304,8 +353,12 @@ void deserializeFromStream(std::istream &is, T &data, data = deserializePolyUnit(deb->list()->GetAs(0)); } else if constexpr (std::is_same_v) { data = deserializeMessage(deb->list()->GetAs(0)); + } else if constexpr (std::is_same_v) { + data = deserializeFMessage(deb->list()->GetAs(0)); } else if constexpr (std::is_same_v) { data = deserializeCoeff(deb->list()->GetAs(0)); + } else if constexpr (std::is_same_v) { + data = deserializeFCoeff(deb->list()->GetAs(0)); } else { throw std::runtime_error( "[deserializeFromStream] Unsupported type for deserialization"); diff --git a/include/deb/Types.hpp b/include/deb/Types.hpp index a807e40..97ae78e 100644 --- a/include/deb/Types.hpp +++ b/include/deb/Types.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/include/deb/utils/AleaRandomGenerator.hpp b/include/deb/utils/AleaRandomGenerator.hpp new file mode 100644 index 0000000..c5f032a --- /dev/null +++ b/include/deb/utils/AleaRandomGenerator.hpp @@ -0,0 +1,45 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "utils/RandomGenerator.hpp" + +#include + +namespace deb { + +class AleaRandomGenerator : public RandomGenerator { +public: + explicit AleaRandomGenerator(const RNGSeed &seed); + ~AleaRandomGenerator() override; + + AleaRandomGenerator(const AleaRandomGenerator &) = delete; + AleaRandomGenerator &operator=(const AleaRandomGenerator &) = delete; + + void getRandomUint64Array(u64 *dst, size_t len) override; + void getRandomUint64ArrayInRange(u64 *dst, size_t len, u64 range) override; + + void sampleGaussianInt64Array(i64 *dst, size_t len, double stdev) override; + void sampleHwtInt8Array(i8 *dst, size_t len, int hwt) override; + + void reseed(const u8 *seed, size_t seed_len) override; + +private: + void *state_; +}; + +} // namespace deb diff --git a/include/deb/utils/Basic.hpp b/include/deb/utils/Basic.hpp index 627e655..2f21596 100644 --- a/include/deb/utils/Basic.hpp +++ b/include/deb/utils/Basic.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,14 +17,242 @@ #pragma once #include "CKKSTypes.hpp" - #include +#ifdef _MSC_VER +#include +#endif + namespace deb::utils { +// --------------------------------------------------------------------------- +// 128-bit integer types +// +// GCC/Clang provide __int128 natively. MSVC does not, so we supply +// lightweight struct wrappers that expose the same set of operations used +// throughout deb (arithmetic, shifts, comparisons, casts). +// --------------------------------------------------------------------------- +#ifdef _MSC_VER + +struct i128; // forward declaration + +struct u128 { + u64 lo; + u64 hi; + + constexpr u128() : lo(0), hi(0) {} + constexpr u128(u64 val) : lo(val), hi(0) {} // NOLINT(implicit) + constexpr u128(u64 hi_, u64 lo_) : lo(lo_), hi(hi_) {} + + constexpr explicit operator u64() const { return lo; } + constexpr explicit operator bool() const { return lo || hi; } + + // u128 -> double (used by Decryptor: static_cast(u128_val)) + explicit operator Real() const { + constexpr Real two64 = 18446744073709551616.0; // 2^64 + return static_cast(hi) * two64 + static_cast(lo); + } + + // -- shifts ----------------------------------------------------------- + constexpr u128 operator>>(u64 n) const { + if (n == 0) + return *this; + if (n >= 128) + return u128(); + if (n >= 64) + return u128(0, hi >> (n - 64)); + return u128(hi >> n, (lo >> n) | (hi << (64 - n))); + } + constexpr u128 operator<<(u64 n) const { + if (n == 0) + return *this; + if (n >= 128) + return u128(); + if (n >= 64) + return u128(lo << (n - 64), 0); + return u128((hi << n) | (lo >> (64 - n)), lo << n); + } + + // -- bitwise ---------------------------------------------------------- + constexpr u128 operator|(const u128 &o) const { + return u128(hi | o.hi, lo | o.lo); + } + constexpr u128 operator&(const u128 &o) const { + return u128(hi & o.hi, lo & o.lo); + } + constexpr u128 operator^(const u128 &o) const { + return u128(hi ^ o.hi, lo ^ o.lo); + } + constexpr u128 operator~() const { return u128(~hi, ~lo); } + + // -- arithmetic ------------------------------------------------------- + constexpr u128 operator+(const u128 &o) const { + u64 r = lo + o.lo; + return u128(hi + o.hi + (r < lo ? 1 : 0), r); + } + constexpr u128 operator-(const u128 &o) const { + u64 r = lo - o.lo; + return u128(hi - o.hi - (lo < o.lo ? 1 : 0), r); + } + + // Multiplication – schoolbook 32-bit (constexpr-compatible). + constexpr u128 operator*(const u128 &o) const { + u64 a0 = lo & 0xFFFFFFFF, a1 = lo >> 32; + u64 b0 = o.lo & 0xFFFFFFFF, b1 = o.lo >> 32; + u64 p00 = a0 * b0; + u64 p01 = a0 * b1; + u64 p10 = a1 * b0; + u64 p11 = a1 * b1; + u64 mid = (p00 >> 32) + (p01 & 0xFFFFFFFF) + (p10 & 0xFFFFFFFF); + u64 l = (p00 & 0xFFFFFFFF) | ((mid & 0xFFFFFFFF) << 32); + u64 h = p11 + (p01 >> 32) + (p10 >> 32) + (mid >> 32); + h += lo * o.hi + hi * o.lo; + return u128(h, l); + } + + // Division and modulo by u64 – binary long division (constexpr-compatible). + constexpr u128 operator/(u64 d) const { + if (hi == 0) + return u128(0, lo / d); + u64 qh = hi / d; + u64 rem = hi % d; + u64 ql = 0; + for (int i = 63; i >= 0; i--) { + bool overflow = (rem >> 63) != 0; + rem = (rem << 1) | ((lo >> i) & 1); + if (overflow || rem >= d) { + rem -= d; + ql |= (u64(1) << i); + } + } + return u128(qh, ql); + } + constexpr u128 operator%(u64 d) const { + if (hi == 0) + return u128(0, lo % d); + u64 rem = hi % d; + for (int i = 63; i >= 0; i--) { + bool overflow = (rem >> 63) != 0; + rem = (rem << 1) | ((lo >> i) & 1); + if (overflow || rem >= d) { + rem -= d; + } + } + return u128(0, rem); + } + + // -- comparisons ------------------------------------------------------ + constexpr bool operator==(const u128 &o) const { + return hi == o.hi && lo == o.lo; + } + constexpr bool operator!=(const u128 &o) const { return !(*this == o); } + constexpr bool operator<(const u128 &o) const { + return hi < o.hi || (hi == o.hi && lo < o.lo); + } + constexpr bool operator>(const u128 &o) const { return o < *this; } + constexpr bool operator<=(const u128 &o) const { return !(o < *this); } + constexpr bool operator>=(const u128 &o) const { return !(*this < o); } +}; + +struct i128 { + u64 lo; + i64 hi; // signed + + constexpr i128() : lo(0), hi(0) {} + constexpr i128(int val) // NOLINT(implicit) + : lo(static_cast(static_cast(val))), + hi(val < 0 ? i64(-1) : i64(0)) {} + constexpr i128(i64 val) // NOLINT(implicit) + : lo(static_cast(val)), hi(val < 0 ? i64(-1) : i64(0)) {} + constexpr i128(u64 val) // NOLINT(implicit) + : lo(val), hi(0) {} + constexpr i128(i64 hi_, u64 lo_) : lo(lo_), hi(hi_) {} + constexpr i128(u128 val) : lo(val.lo), hi(static_cast(val.hi)) {} + + // double -> i128 (used in CKKS encoding: static_cast(double)) + i128(Real val) { // NOLINT(implicit) + bool neg = val < 0; + Real a = neg ? -val : val; + constexpr Real two64 = 18446744073709551616.0; + u64 h = (a >= two64) ? static_cast(a / two64) : 0; + u64 l = static_cast(a - static_cast(h) * two64); + if (neg) { + l = ~l + 1; + h = ~h + (l == 0 ? 1 : 0); + } + lo = l; + hi = static_cast(h); + } + + constexpr explicit operator u64() const { return lo; } + + // Reinterpret as unsigned. + constexpr explicit operator u128() const { + return u128(static_cast(hi), lo); + } + + // Arithmetic right shift (sign-extending). + constexpr i128 operator>>(u64 n) const { + if (n == 0) + return *this; + if (n >= 128) + return i128(hi < 0 ? i64(-1) : i64(0)); + if (n >= 64) { + i64 s = hi >> static_cast(n - 64); // arithmetic + return i128(hi < 0 ? i64(-1) : i64(0), static_cast(s)); + } + return i128(hi >> static_cast(n), + (lo >> n) | (static_cast(hi) << (64 - n))); + } + constexpr i128 operator<<(u64 n) const { + if (n == 0) + return *this; + if (n >= 128) + return i128(0, 0); + if (n >= 64) + return i128(static_cast(lo << (n - 64)), 0); + return i128((hi << n) | (lo >> (64 - n)), lo << n); + } + + // -- arithmetic ------------------------------------------------------- + constexpr i128 operator+(const i128 &o) const { + u64 r = lo + o.lo; + return i128(hi + o.hi + (r < lo ? 1 : 0), r); + } + constexpr i128 operator-(const i128 &o) const { + u64 r = lo - o.lo; + return i128(hi - o.hi - (lo < o.lo ? 1 : 0), r); + } + constexpr i128 operator*(const i128 &o) const { + return static_cast(static_cast(*this) * + static_cast(o)); + } + constexpr i128 operator-() const { + u64 nl = ~lo + 1; + return i128(static_cast(~static_cast(hi) + (nl == 0 ? 1 : 0)), + nl); + } + + // -- comparisons ------------------------------------------------------ + constexpr bool operator==(const i128 &o) const { + return hi == o.hi && lo == o.lo; + } + constexpr bool operator!=(const i128 &o) const { return !(*this == o); } + constexpr bool operator<(const i128 &o) const { + return hi < o.hi || (hi == o.hi && lo < o.lo); + } + constexpr bool operator>(const i128 &o) const { return o < *this; } + constexpr bool operator<=(const i128 &o) const { return !(o < *this); } + constexpr bool operator>=(const i128 &o) const { return !(*this < o); } +}; + +#else // GCC / Clang + using u128 = unsigned __int128; using i128 = __int128; +#endif + /** * @brief Returns the upper 64 bits of a 128-bit integer. * @param value 128-bit value. @@ -153,11 +381,7 @@ inline u64 countLeftZeroes(u64 op) { } inline u64 bitWidth(const u64 op) { -#ifdef __cpp_lib_int_pow2 - return std::bit_width(op); -#else return op ? UINT64_C(64) - countLeftZeroes(op) : UINT64_C(0); -#endif } // Integral log2 with log2floor(0) := 0 @@ -193,7 +417,13 @@ template void bitReverseArray(T *data, u64 n) { /** * @brief Subtracts b from a when a is greater or equal, otherwise returns a. */ -inline u64 subIfGE(u64 a, u64 b) { return (a >= b ? a - b : a); } +inline u64 subIfGE(u64 a, u64 b) { return (a >= b) ? (a - b) : a; } + +inline u64 subIfGEConst(u64 a, u64 b) { + // mask is 0xFFFFFFFF if a >= b, 0x0 if a < b + const u64 mask = ((a - b) >> 63) - 1; + return a - (b & mask); +} /** * @brief Computes a modular inverse using Fermat's little theorem. diff --git a/include/deb/Constant.hpp b/include/deb/utils/Constant.hpp similarity index 94% rename from include/deb/Constant.hpp rename to include/deb/utils/Constant.hpp index c57d8a8..a8c4709 100644 --- a/include/deb/Constant.hpp +++ b/include/deb/utils/Constant.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #include "CKKSTypes.hpp" -namespace deb { +namespace deb::utils { /** * @brief Maximum representable value for @ref Size within this library. @@ -54,4 +54,4 @@ constexpr Complex COMPLEX_ZERO(REAL_ZERO, REAL_ZERO); * @brief Imaginary unit constant (0 + 1i). */ constexpr Complex COMPLEX_IMAG_UNIT(REAL_ZERO, REAL_ONE); -} // namespace deb +} // namespace deb::utils diff --git a/include/deb/utils/FFT.hpp b/include/deb/utils/FFT.hpp index 459d084..729aa95 100644 --- a/include/deb/utils/FFT.hpp +++ b/include/deb/utils/FFT.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,24 +26,24 @@ namespace deb::utils { * @brief FFT implementation specialized for CKKS message containers. * @tparam T Floating-point type. */ -template class FFTImpl { +class FFT { public: /** * @brief Precomputes FFT twiddle factors for the requested degree. * @param degree Polynomial degree. */ - FFTImpl(const u64 degree); + FFT(const u64 degree); /** * @brief Applies the forward FFT to a message in-place. * @param msg Slot-encoded message container. */ - void forwardFFT(MessageImpl &msg) const; + template void forwardFFT(MessageImpl &msg) const; /** * @brief Applies the inverse FFT to a message in-place. * @param msg Slot-encoded message container. */ - void backwardFFT(MessageImpl &msg) const; + template void backwardFFT(MessageImpl &msg) const; /** * @brief Returns powers of the generator used for slot rotations. @@ -54,19 +54,18 @@ template class FFTImpl { private: // u64 degree_; // a.k.a. PolyUnit degree N std::vector powers_of_five_; - std::vector> complex_roots_; - std::vector> roots_; - std::vector> inv_roots_; + std::vector> complex_roots_; + std::vector> roots_; + std::vector> inv_roots_; }; -using FFT = FFTImpl; - #define DECL_FFT_TEMPLATE(T, prefix) \ - prefix template FFTImpl::FFTImpl(const u64 degree); \ - prefix template void FFTImpl::forwardFFT(MessageImpl &msg) const; \ - prefix template void FFTImpl::backwardFFT(MessageImpl &msg) const; + prefix template void FFT::forwardFFT(MessageImpl &msg) const; \ + prefix template void FFT::backwardFFT(MessageImpl &msg) const; -#define FFT_TYPE_TEMPLATE(prefix) DECL_FFT_TEMPLATE(Real, prefix) +#define FFT_TYPE_TEMPLATE(prefix) \ + DECL_FFT_TEMPLATE(Real, prefix) \ + DECL_FFT_TEMPLATE(float, prefix) FFT_TYPE_TEMPLATE(extern) diff --git a/include/deb/Macro.hpp b/include/deb/utils/Macro.hpp similarity index 58% rename from include/deb/Macro.hpp rename to include/deb/utils/Macro.hpp index 0d03c68..0517ce0 100644 --- a/include/deb/Macro.hpp +++ b/include/deb/utils/Macro.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,26 @@ #pragma once +#if !defined(__STDC_WANT_LIB_EXT1__) +#define __STDC_WANT_LIB_EXT1__ 1 +#endif + #include #include +#include // explicit_bzero / memset #include +#if defined(DEB_SECURE_ZERO_LIBSODIUM) +#include +#elif defined(DEB_SECURE_ZERO_OPENSSL) +#include +#elif defined(DEB_SECURE_ZERO_NATIVE) +#if defined(_WIN32) || defined(_WIN64) +#define NOMINMAX +#include +#endif +#endif + /** * @brief Helper macro exposing the GCC version as a single integer. * https://gcc.gnu.org/onlinedocs/cpp/Common-Predefined-Macros.html @@ -110,3 +126,60 @@ do { \ } while (0) #endif + +/* Compile-time detection of secure memory-zeroing primitive. + * Priority: + * 1. explicit_bzero – glibc >= 2.25, OpenBSD, FreeBSD >= 11, NetBSD, macOS + * 2. SecureZeroMemory – Windows + * 3. memset_s – C11 Annex K (implementation defines __STDC_LIB_EXT1__) + * 4. volatile loop – fallback + * + * On glibc without _GNU_SOURCE, string.h does not declare explicit_bzero even + * though the symbol exists in libc; supply a forward declaration in that case. + */ +#if defined(__GLIBC__) && \ + (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 25)) +#define DEB_HAVE_EXPLICIT_BZERO +#ifndef _GNU_SOURCE +extern void explicit_bzero(void *, size_t); +#endif +#elif defined(__OpenBSD__) || (defined(__FreeBSD__) && __FreeBSD__ >= 11) || \ + defined(__NetBSD__) || defined(__APPLE__) +#define DEB_HAVE_EXPLICIT_BZERO +#elif defined(_WIN32) || defined(_WIN64) +#define DEB_HAVE_SECURE_ZERO_MEMORY +#elif defined(__STDC_LIB_EXT1__) +#define DEB_HAVE_MEMSET_S +#endif + +/** + * @brief Securely zeroes memory to prevent sensitive data leakage. + * @param ptr Pointer to the memory region. + * @param len Length of the memory region in bytes. + */ +inline void deb_secure_zero(void *ptr, std::size_t len) noexcept { + if (ptr == nullptr || len == 0) + return; +#if defined(DEB_SECURE_ZERO_LIBSODIUM) + sodium_memzero(ptr, len); +#elif defined(DEB_SECURE_ZERO_OPENSSL) + OPENSSL_cleanse(ptr, len); +#elif defined(DEB_SECURE_ZERO_NATIVE) +#if defined(DEB_HAVE_SECURE_ZERO_MEMORY) + SecureZeroMemory(ptr, len); +#elif defined(DEB_HAVE_EXPLICIT_BZERO) + explicit_bzero(ptr, len); +#elif defined(DEB_HAVE_MEMSET_S) + memset_s(ptr, len, 0, len); +#else + // volatile byte loop — best-effort against compiler optimisation + volatile unsigned char *p = static_cast(ptr); + for (std::size_t i = 0; i < len; ++i) + p[i] = 0; +#endif +#else // DEB_SECURE_ZERO is not seted (NONE) + // Fallback to memset, does not guarantee zeroing against compiler + // optimizations, but better than nothing + std::memset(ptr, 0, len); +#endif +} diff --git a/include/deb/utils/ModArith.hpp b/include/deb/utils/ModArith.hpp index 4a40d08..c975fc0 100644 --- a/include/deb/utils/ModArith.hpp +++ b/include/deb/utils/ModArith.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,8 @@ #pragma once #include "CKKSTypes.hpp" -#include "Macro.hpp" #include "utils/Basic.hpp" +#include "utils/Macro.hpp" #include "utils/NTT.hpp" #include @@ -26,10 +26,23 @@ namespace deb::utils { +template struct DegreeTrait { + static constexpr Size degree = D; + DegreeTrait() = default; + DegreeTrait(Size) {} +}; +template <> struct DegreeTrait<1> { + Size degree; + DegreeTrait() = default; + DegreeTrait(Size deg) : degree(deg) {} +}; + /** * @brief Provides modular arithmetic utilities bound to a specific modulus. */ -class ModArith { +template class ModArith : public DegreeTrait { + using DegreeTrait::degree; + public: explicit ModArith() = default; /** @@ -37,7 +50,8 @@ class ModArith { * @param size Default vector size (poly degree). * @param prime Prime modulus. */ - explicit ModArith(Size size, u64 prime); + explicit ModArith(u64 prime); + explicit ModArith(Size degree, u64 prime); /** * @brief Returns the modulus associated with this instance. @@ -252,7 +266,8 @@ class ModArith { * @param num_polyunit Optional cap on processed units (0 = all). * @param expected_ntt_state Hint used to avoid redundant transforms. */ -void forwardNTT(const std::vector &modarith, Polynomial &poly, +template +void forwardNTT(const std::vector> &modarith, Polynomial &poly, Size num_polyunit = 0, [[maybe_unused]] bool expected_ntt_state = false); @@ -263,7 +278,8 @@ void forwardNTT(const std::vector &modarith, Polynomial &poly, * @param num_polyunit Optional cap on processed units. * @param expected_ntt_state Hint used to avoid redundant transforms. */ -void backwardNTT(const std::vector &modarith, Polynomial &poly, +template +void backwardNTT(const std::vector> &modarith, Polynomial &poly, Size num_polyunit = 0, [[maybe_unused]] bool expected_ntt_state = true); @@ -275,8 +291,13 @@ void backwardNTT(const std::vector &modarith, Polynomial &poly, * @param res Result polynomial. * @param num_polyunit Optional cap on processed units. */ -void addPoly(const std::vector &modarith, const Polynomial &op1, +template +void addPoly(const std::vector> &modarith, const Polynomial &op1, const Polynomial &op2, Polynomial &res, Size num_polyunit = 0); +template +void addPolyConst(const std::vector> &modarith, + const Polynomial &op1, const Polynomial &op2, Polynomial &res, + Size num_polyunit = 0); /** * @brief Subtracts @p op2 from @p op1 coefficient-wise. * @param modarith Per-prime helpers. @@ -285,7 +306,8 @@ void addPoly(const std::vector &modarith, const Polynomial &op1, * @param res Result polynomial. * @param num_polyunit Optional cap on processed units. */ -void subPoly(const std::vector &modarith, const Polynomial &op1, +template +void subPoly(const std::vector> &modarith, const Polynomial &op1, const Polynomial &op2, Polynomial &res, Size num_polyunit = 0); /** * @brief Multiplies two polynomials in the NTT domain. @@ -295,8 +317,13 @@ void subPoly(const std::vector &modarith, const Polynomial &op1, * @param res Result polynomial. * @param num_polyunit Optional cap on processed units. */ -void mulPoly(const std::vector &modarith, const Polynomial &op1, +template +void mulPoly(const std::vector> &modarith, const Polynomial &op1, const Polynomial &op2, Polynomial &res, Size num_polyunit = 0); +template +void mulPolyConst(const std::vector> &modarith, + const Polynomial &op1, const Polynomial &op2, Polynomial &res, + Size num_polyunit = 0); /** * @brief Multiplies a polynomial by a scalar vector within index range. * @param modarith Per-prime helpers. @@ -306,7 +333,37 @@ void mulPoly(const std::vector &modarith, const Polynomial &op1, * @param s_id Start index. * @param e_id End index (exclusive). */ -void constMulPoly(const std::vector &modarith, const Polynomial &op1, - const u64 *op2, Polynomial &res, Size s_id, Size e_id); - +template +void constMulPoly(const std::vector> &modarith, + const Polynomial &op1, const u64 *op2, Polynomial &res, + Size s_id, Size e_id); + +#define DECL_MODARITH_HELPER(degree, prefix) \ + prefix template class ModArith; \ + prefix template void forwardNTT(const std::vector> &, \ + Polynomial &, Size, bool); \ + prefix template void backwardNTT(const std::vector> &, \ + Polynomial &, Size, bool); \ + prefix template void addPoly(const std::vector> &, \ + const Polynomial &, const Polynomial &, \ + Polynomial &, Size); \ + prefix template void addPolyConst(const std::vector> &, \ + const Polynomial &, const Polynomial &, \ + Polynomial &, Size); \ + prefix template void subPoly(const std::vector> &, \ + const Polynomial &, const Polynomial &, \ + Polynomial &, Size); \ + prefix template void mulPoly(const std::vector> &, \ + const Polynomial &, const Polynomial &, \ + Polynomial &, Size); \ + prefix template void mulPolyConst(const std::vector> &, \ + const Polynomial &, const Polynomial &, \ + Polynomial &, Size); \ + prefix template void constMulPoly(const std::vector> &, \ + const Polynomial &, const u64 *, \ + Polynomial &, Size, Size); + +#define D(degree) DECL_MODARITH_HELPER(degree, extern) +DEGREE_SET +#undef D } // namespace deb::utils diff --git a/include/deb/utils/NTT.hpp b/include/deb/utils/NTT.hpp index dbc005d..12a3d8e 100644 --- a/include/deb/utils/NTT.hpp +++ b/include/deb/utils/NTT.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -63,11 +63,10 @@ class NTT { void computeBackward(u64 *op) const; private: - u64 prime_; - u64 two_prime_; - u64 degree_; + const u64 prime_; + const u64 two_prime_; + const u64 degree_; - // TODO(juny): make support constexpr for NTT // roots of unity (bit reversed) std::vector psi_rev_; std::vector psi_inv_rev_; @@ -84,4 +83,5 @@ class NTT { void computeBackwardNativeSingleStep(u64 *op, const u64 t) const; void computeBackwardNativeLast(u64 *op) const; }; + } // namespace deb::utils diff --git a/include/deb/utils/OmpUtils.hpp b/include/deb/utils/OmpUtils.hpp new file mode 100644 index 0000000..42a0b5a --- /dev/null +++ b/include/deb/utils/OmpUtils.hpp @@ -0,0 +1,28 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace deb::utils { +/** + * @brief Sets an OpenMP thread limit for the current process. + * @param max_threads Maximum number of threads; implementation-defined. + */ +void setOmpThreadLimit(int max_threads); +/** + * @brief Removes any OpenMP thread limit previously applied. + */ +void unsetOmpThreadLimit(); + +} // namespace deb::utils diff --git a/include/deb/utils/PresetTraits.hpp b/include/deb/utils/PresetTraits.hpp new file mode 100644 index 0000000..ba38abf --- /dev/null +++ b/include/deb/utils/PresetTraits.hpp @@ -0,0 +1,46 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "utils/ModArith.hpp" + +namespace deb { + +template struct PresetTraits; +#define X(preset) \ + template <> struct PresetTraits : public preset { \ + PresetTraits() = delete; \ + PresetTraits([[maybe_unused]] Preset p) {} \ + std::vector> modarith; \ + }; +PRESET_LIST +#undef X + +template <> struct PresetTraits { +#define CV(type, var_name) type var_name; + CONST_LIST +#undef CV + PresetTraits() = delete; + PresetTraits(Preset p) { +#define CV(type, var_name) this->var_name = get_##var_name(p); + CONST_LIST +#undef CV + } + std::vector> modarith; +}; + +} // namespace deb diff --git a/include/deb/utils/RandomGenerator.hpp b/include/deb/utils/RandomGenerator.hpp new file mode 100644 index 0000000..453bd55 --- /dev/null +++ b/include/deb/utils/RandomGenerator.hpp @@ -0,0 +1,53 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "Types.hpp" + +#include +#include +#include + +namespace deb { + +constexpr size_t DEB_RNG_SEED_BYTE_SIZE = 64; +constexpr size_t DEB_U64_SEED_SIZE = DEB_RNG_SEED_BYTE_SIZE / sizeof(u64); + +using RNGSeed = std::array; + +class RandomGenerator { +public: + virtual ~RandomGenerator() = default; + + virtual void getRandomUint64Array(u64 *dst, size_t len) = 0; + virtual void getRandomUint64ArrayInRange(u64 *dst, size_t len, + u64 range) = 0; + + virtual void sampleGaussianInt64Array(i64 *dst, size_t len, + double stdev) = 0; + virtual void sampleHwtInt8Array(i8 *dst, size_t len, int hwt) = 0; + + virtual void reseed(const u8 *seed, size_t seed_len) = 0; +}; + +using RandomGeneratorFactory = + std::function(const RNGSeed &seed)>; + +void setRandomGeneratorFactory(RandomGeneratorFactory factory); +std::shared_ptr createRandomGenerator(const RNGSeed &seed); + +} // namespace deb diff --git a/include/deb/utils/Span.hpp b/include/deb/utils/Span.hpp deleted file mode 100644 index 16d533b..0000000 --- a/include/deb/utils/Span.hpp +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright 2025 CryptoLab, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace deb { -/** - * @brief Lightweight view over contiguous memory similar to std::span. - * @tparam T Element type. - */ -template class span { -public: - /** - * @brief Creates a span pointing to @p ptr with @p size elements. - */ - constexpr span(const T *ptr, std::size_t size) noexcept - : ptr_(ptr), size_(size) {} - - /** - * @brief Creates a single-element span around @p ptr. - */ - constexpr span(const T *ptr) : ptr_(ptr), size_(1) {} - - /** - * @brief Creates a span that views the contents of a std::vector. - */ - constexpr span(const std::vector &vec) - : ptr_(vec.data()), size_(vec.size()) {} - - template - /** - * @brief Creates a span over a std::array with @p N elements. - */ - constexpr span(const std::array &arr) noexcept - : ptr_(arr.data()), size_(N) {} - - /** - * @brief Returns a const iterator to the first element. - */ - constexpr const T *begin() const noexcept { return ptr_; } - /** - * @brief Returns a const iterator past the last element. - */ - constexpr const T *end() const noexcept { return ptr_ + size_; } - - /** - * @brief Returns a mutable iterator to the first element. - */ - constexpr T *begin() noexcept { return const_cast(ptr_); } - /** - * @brief Returns a mutable iterator past the last element. - */ - constexpr T *end() noexcept { return const_cast(ptr_ + size_); } - - /** - * @brief Number of elements referenced by the span. - */ - constexpr std::size_t size() const noexcept { return size_; } - - /** - * @brief Provides mutable element access with no bounds checks. - */ - constexpr T &operator[](std::size_t index) { - return const_cast(ptr_[index]); - } - - /** - * @brief Provides read-only element access with no bounds checks. - */ - const T &operator[](std::size_t index) const { return ptr_[index]; } - - /** - * @brief Returns the underlying pointer. - */ - constexpr T *data() const noexcept { return const_cast(ptr_); } - - /** - * @brief Returns a subspan starting at @p offset with at most @p count - * elements. - * @param offset Starting index relative to this span. - * @param count Maximum number of elements to include (-1 for remainder). - * @return Span referencing the requested region (may be empty). - */ - constexpr span - subspan(std::size_t offset, - std::size_t count = static_cast(-1)) const { - if (offset >= size_) - return span(ptr_, 0); - std::size_t new_size = - (count == static_cast(-1)) ? (size_ - offset) : count; - return span(ptr_ + offset, std::min(new_size, size_ - offset)); - } - -private: - const T *ptr_; - const std::size_t size_; -}; -} // namespace deb diff --git a/prebuild/CMakeLists.txt b/prebuild/CMakeLists.txt index af4121d..9b9d5dc 100644 --- a/prebuild/CMakeLists.txt +++ b/prebuild/CMakeLists.txt @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -84,11 +84,12 @@ if(CMAKE_CROSSCOMPILING) add_custom_command( OUTPUT ${HOST_GENPARAM_DIR}/CMakeCache.txt COMMAND - ${CMAKE_COMMAND} -S ${PRE_BUILD_DIR} -B ${HOST_GENPARAM_DIR} + ${CMAKE_COMMAND} -S ${PRE_BUILD_SRC_DIR} -B ${HOST_GENPARAM_DIR} -DCMAKE_BUILD_TYPE=Release -DCMAKE_TOOLCHAIN_FILE= -DANDROID=OFF -DCMAKE_C_COMPILER=${HOST_CC} -DCMAKE_CXX_COMPILER=${HOST_CXX} -DCMAKE_CXX_STANDARD=17 -DCMAKE_CXX_STANDARD_REQUIRED=ON -DPREBUILD_HOST_MODE=ON -DPRE_BUILD_DIR=${PRE_BUILD_DIR} + -DPRE_BUILD_SRC_DIR=${PRE_BUILD_SRC_DIR} COMMENT "Configuring host GenParam with GCC" VERBATIM COMMAND_EXPAND_LISTS) diff --git a/prebuild/DebGenParam.cpp b/prebuild/DebGenParam.cpp index 10f181e..eaca953 100644 --- a/prebuild/DebGenParam.cpp +++ b/prebuild/DebGenParam.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/prebuild/DebPreComputeUtils.hpp b/prebuild/DebPreComputeUtils.hpp index 49b68cf..c41a930 100644 --- a/prebuild/DebPreComputeUtils.hpp +++ b/prebuild/DebPreComputeUtils.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -421,7 +421,6 @@ resolve_one(const std::string &name, out.RANK = cr.RANK; out.NUM_SECRET = cr.NUM_SECRET; out.GADGET_RANK = cr.GADGET_RANK; - const uint32_t degree = 1u << out.LOG_DEGREE; out.HWT = cr.HWT; out.PRIMES = @@ -499,17 +498,15 @@ static void write_header(const std::string &out_path, os << "#pragma once\n\n"; os << "#include \n"; os << "#include \n"; - os << "#include \n"; - os << "#include \n"; - os << "#include \n"; os << "\n"; os << "#include \"Types.hpp\"\n"; + os << "namespace deb {\n\n"; os << "#define PRESET_LIST"; for (const auto &final : finals) { os << " "; os << "X(" << final.NAME << ")"; } - os << "\nnamespace deb {\n\n"; + os << "\n#define PRESET_LIST_WITH_EMPTY PRESET_LIST X(EMPTY)\n\n"; os << "enum Preset {\n"; for (const auto &final : finals) { @@ -580,6 +577,25 @@ static void write_header(const std::string &out_path, << "};\n\n"; } + // PresetTraits + // os << "template struct PresetTraits;\n"; + // os << "#define X(preset) template <> struct PresetTraits + // " + // ": public preset {};\n"; + // os << "PRESET_LIST\n"; + // os << "#undef X\n"; + + // Degree set + std::unordered_set degrees; + for (const auto &p : finals_copy) { + degrees.insert(1u << p.LOG_DEGREE); + } + // degrees.erase(1); // remove degree 1 + os << "#define DEGREE_SET \\\n"; + for (const auto &d : degrees) { + os << "D(" << d << ") "; + } + os << "\n"; os << "} // namespace deb\n"; os.close(); } diff --git a/src/AleaRandomGenerator.cpp b/src/AleaRandomGenerator.cpp new file mode 100644 index 0000000..6383e39 --- /dev/null +++ b/src/AleaRandomGenerator.cpp @@ -0,0 +1,59 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/AleaRandomGenerator.hpp" +#include "utils/Macro.hpp" + +#include "alea/alea.h" +#include "alea/algorithms.h" + +namespace deb { + +AleaRandomGenerator::AleaRandomGenerator(const RNGSeed &seed) + : state_(alea_init(reinterpret_cast(seed.data()), + ALEA_ALGORITHM_SHAKE256)) { + deb_assert(state_ != nullptr, "Failed to initialize Alea RNG"); +} + +AleaRandomGenerator::~AleaRandomGenerator() { + if (state_) { + alea_free(state_); + } +} + +void AleaRandomGenerator::getRandomUint64Array(u64 *dst, size_t len) { + alea_get_random_uint64_array(state_, dst, len); +} + +void AleaRandomGenerator::getRandomUint64ArrayInRange(u64 *dst, size_t len, + u64 range) { + alea_get_random_uint64_array_in_range(state_, dst, len, range); +} + +void AleaRandomGenerator::sampleGaussianInt64Array(i64 *dst, size_t len, + double stdev) { + alea_sample_gaussian_int64_array(state_, dst, len, stdev); +} + +void AleaRandomGenerator::sampleHwtInt8Array(i8 *dst, size_t len, int hwt) { + alea_sample_hwt_int8_array(state_, dst, len, hwt); +} + +void AleaRandomGenerator::reseed(const u8 *seed, size_t /*seed_len*/) { + alea_reseed(state_, seed); +} + +} // namespace deb diff --git a/src/CKKSTypes.cpp b/src/CKKSTypes.cpp index a438267..f5137de 100644 --- a/src/CKKSTypes.cpp +++ b/src/CKKSTypes.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,12 +15,11 @@ */ #include "CKKSTypes.hpp" - namespace deb { -// --------------------------------------------------------------------- -// Implementation of Message -// --------------------------------------------------------------------- +//// --------------------------------------------------------------------- +//// Implementation of Message +//// --------------------------------------------------------------------- template MessageBase::MessageBase(const Size size) : data_(size) {} template @@ -58,103 +57,151 @@ MESSAGE_TYPE_TEMPLATE() // --------------------------------------------------------------------- // Implementation of PolyUnit // --------------------------------------------------------------------- -PolyUnit::PolyUnit(const Preset preset, const Size level) - : PolyUnit(getContext(preset), level) {} - -PolyUnit::PolyUnit(const Context &context, const Size level) - : prime_(context->get_primes()[level]), ntt_state_(false) { +PolyUnit::PolyUnit(const Preset preset, const Size level, const bool alloc) + : prime_(get_primes(preset)[level]), ntt_state_(false), + degree_(get_degree(preset)) { + if (!alloc) { + data_ptr_ = nullptr; + degree_ = 0; + return; + } #if DEB_ALINAS_LEN == 0 - data_ = std::shared_ptr>( - new span(new u64[context->get_degree()], context->get_degree()), - [](span *p) { - delete[] p->data(); - delete p; - }); + data_ptr_ = + std::shared_ptr(new u64[degree_], std::default_delete()); #else auto *buf = static_cast(::operator new[]( - sizeof(u64) * context->get_degree(), std::align_val_t(DEB_ALINAS_LEN))); - data_ = std::shared_ptr>( - new span(buf, context->get_degree()), [](span *p) { - ::operator delete[](p->data(), std::align_val_t(DEB_ALINAS_LEN)); - delete p; - }); + sizeof(u64) * degree_, std::align_val_t(DEB_ALINAS_LEN))); + data_ptr_ = std::shared_ptr(buf, [](u64 *p) { + ::operator delete[](p, std::align_val_t(DEB_ALINAS_LEN)); + }); #endif } -PolyUnit::PolyUnit(u64 prime, Size degree) : prime_(prime), ntt_state_(false) { +PolyUnit::PolyUnit(u64 prime, Size degree, const bool alloc) + : prime_(prime), ntt_state_(false), degree_(degree) { + if (!alloc) { + data_ptr_ = nullptr; + degree_ = 0; + return; + } #if DEB_ALINAS_LEN == 0 - data_ = std::shared_ptr>(new span(new u64[degree], degree), - [](span *p) { - delete[] p->data(); - delete p; - }); + data_ptr_ = + std::shared_ptr(new u64[degree_], std::default_delete()); #else auto *buf = static_cast(::operator new[]( - sizeof(u64) * degree, std::align_val_t(DEB_ALINAS_LEN))); - data_ = std::shared_ptr>( - new span(buf, degree), [](span *p) { - ::operator delete[](p->data(), std::align_val_t(DEB_ALINAS_LEN)); - delete p; - }); + sizeof(u64) * degree_, std::align_val_t(DEB_ALINAS_LEN))); + data_ptr_ = std::shared_ptr(buf, [](u64 *p) { + ::operator delete[](p, std::align_val_t(DEB_ALINAS_LEN)); + }); #endif } PolyUnit PolyUnit::deepCopy() const { - PolyUnit copy(prime_, degree()); - for (Size i = 0; i < degree(); ++i) { - copy[i] = (*this)[i]; + const bool alloc = data_ptr_ != nullptr && degree_ != 0; + PolyUnit copy(prime_, degree_, alloc); + if (alloc) { + for (Size i = 0; i < degree_; ++i) { + copy[i] = (*this)[i]; + } } copy.setNTT(ntt_state_); return copy; } + void PolyUnit::setPrime(u64 prime) noexcept { prime_ = prime; } u64 PolyUnit::prime() const noexcept { return prime_; } void PolyUnit::setNTT(bool ntt_state) noexcept { ntt_state_ = ntt_state; } bool PolyUnit::isNTT() const noexcept { return ntt_state_; } -Size PolyUnit::degree() const noexcept { - return static_cast(data_->size()); -} -u64 &PolyUnit::operator[](Size index) noexcept { return (*data_)[index]; } -u64 PolyUnit::operator[](Size index) const noexcept { return (*data_)[index]; } -u64 *PolyUnit::data() const noexcept { return data_->data(); } +Size PolyUnit::degree() const noexcept { return degree_; } void PolyUnit::setData(u64 *new_data, Size size) { - data_ = std::shared_ptr>(new span(new_data, size), - [](span *p) { delete p; }); + data_ptr_ = std::shared_ptr(new_data, [](u64 *p) { + // do nothing, external data + }); + degree_ = size; } // --------------------------------------------------------------------- // Implementation of Polynomial // --------------------------------------------------------------------- -Polynomial::Polynomial(const Preset preset, const bool full_level) - : Polynomial(getContext(preset), full_level) {} -Polynomial::Polynomial(Context context, const bool full_level) { - Size num_poly = - full_level ? context->get_num_p() : context->get_encryption_level() + 1; +Polynomial::Polynomial(const Preset preset, const bool full_level) { + const Size degree = get_degree(preset); + const Size num_poly = + full_level ? get_num_p(preset) : get_encryption_level(preset) + 1; +#if DEB_ALINAS_LEN == 0 + dealloc_ptr_ = std::shared_ptr(new u64[num_poly * degree], + std::default_delete()); +#else + auto *buf = static_cast( + std::aligned_alloc(DEB_ALINAS_LEN, sizeof(u64) * num_poly * degree)); + dealloc_ptr_ = std::shared_ptr(buf, [](u64 *p) { std::free(p); }); +#endif for (Size l = 0; l < num_poly; ++l) { - data_.emplace_back(context, l); + polyunits_.emplace_back(preset, l, false); + polyunits_[l].setData(dealloc_ptr_.get() + l * degree, degree); } } -Polynomial::Polynomial(Context context, const Size custom_size) { +Polynomial::Polynomial(const Preset preset, const Size custom_size) { + const Size degree = get_degree(preset); +#if DEB_ALINAS_LEN == 0 + dealloc_ptr_ = std::shared_ptr(new u64[custom_size * degree], + std::default_delete()); +#else + auto *buf = static_cast( + std::aligned_alloc(DEB_ALINAS_LEN, sizeof(u64) * custom_size * degree)); + dealloc_ptr_ = std::shared_ptr(buf, [](u64 *p) { std::free(p); }); +#endif for (Size l = 0; l < custom_size; ++l) { - data_.emplace_back(context, l); + polyunits_.emplace_back(preset, l, false); + polyunits_[l].setData(dealloc_ptr_.get() + l * degree, degree); } } Polynomial::Polynomial(const Polynomial &other, Size others_idx, Size custom_size) - : data_(&other.data_[others_idx], &other.data_[others_idx] + custom_size) {} + : polyunits_(&other.polyunits_[others_idx], + &other.polyunits_[others_idx] + custom_size), + dealloc_ptr_(nullptr) {} Polynomial Polynomial::deepCopy(std::optional num_polyunit) const { const auto num_polyunit_val = num_polyunit.value_or(this->size()); - Polynomial copy(*this); - copy.data_.clear(); - for (Size i = 0; i < num_polyunit_val; ++i) { - copy.data_.push_back(data_[i].deepCopy()); + deb_assert( + num_polyunit_val <= this->size(), + "[Polynomial::deepCopy] Requested number of polyunits exceeds size."); + Polynomial copy(*this, 0, 0); + copy.polyunits_.clear(); + if (dealloc_ptr_ != nullptr) { +#if DEB_ALINAS_LEN == 0 + copy.dealloc_ptr_ = std::shared_ptr( + new u64[num_polyunit_val * polyunits_[0].degree()], + std::default_delete()); +#else + auto *buf = static_cast(::operator new[]( + sizeof(u64) * num_polyunit_val * polyunits_[0].degree(), + std::align_val_t(DEB_ALINAS_LEN))); + copy.dealloc_ptr_ = std::shared_ptr(buf, [buf](u64 *p) { + ::operator delete[](buf, std::align_val_t(DEB_ALINAS_LEN)); + }); +#endif + for (Size i = 0; i < num_polyunit_val; ++i) { + copy.polyunits_.emplace_back(polyunits_[i].prime(), 0, false); + copy.polyunits_[i].setNTT(polyunits_[i].isNTT()); + copy.polyunits_[i].setData(copy.dealloc_ptr_.get() + + i * polyunits_[i].degree(), + polyunits_[i].degree()); + for (Size j = 0; j < polyunits_[i].degree(); ++j) { + copy.polyunits_[i][j] = polyunits_[i][j]; + } + } + } else { + copy.dealloc_ptr_ = nullptr; + for (Size i = 0; i < num_polyunit_val; ++i) { + copy.polyunits_.push_back(polyunits_[i].deepCopy()); + } } return copy; } void Polynomial::setNTT(bool ntt_state) noexcept { - for (auto &poly : data_) { + for (auto &poly : polyunits_) { poly.setNTT(ntt_state); } } @@ -164,53 +211,41 @@ void Polynomial::setLevel(Preset preset, Size level) { } Size Polynomial::level() const noexcept { - return static_cast(data_.size()) - 1; + return static_cast(polyunits_.size()) - 1; } void Polynomial::setSize(Preset preset, Size size) { - const auto context = getContext(preset); if (size <= this->size()) { - data_.erase(data_.begin() + size, data_.end()); + polyunits_.erase(polyunits_.begin() + size, polyunits_.end()); } else { - const auto max_len = context->get_num_p(); + const auto max_len = get_num_p(preset); for (Size l = this->size(); l < size; ++l) { - data_.emplace_back(context->get_primes()[l % max_len], - context->get_degree()); + polyunits_.emplace_back(get_primes(preset)[l % max_len], + get_degree(preset)); } } } Size Polynomial::size() const noexcept { - return static_cast(data_.size()); -} -PolyUnit &Polynomial::operator[](size_t index) noexcept { return data_[index]; } -const PolyUnit &Polynomial::operator[](size_t index) const noexcept { - return data_[index]; + return static_cast(polyunits_.size()); } -PolyUnit *Polynomial::data() noexcept { return data_.data(); } -const PolyUnit *Polynomial::data() const noexcept { return data_.data(); } // --------------------------------------------------------------------- // Implementation of Ciphertext // --------------------------------------------------------------------- -Ciphertext::Ciphertext(const Preset preset) : Ciphertext(getContext(preset)) {} -Ciphertext::Ciphertext(Context context) - : preset_(context->get_preset()), encoding_(SLOT) { - const Size num_polys = context->get_rank() * context->get_num_secret() + 1; +Ciphertext::Ciphertext(const Preset preset) : preset_(preset), encoding_(SLOT) { + const Size num_polys = get_rank(preset) * get_num_secret(preset) + 1; for (Size i = 0; i < num_polys; ++i) { - polys_.emplace_back(context); + polys_.emplace_back(preset); } } Ciphertext::Ciphertext(const Preset preset, const Size level, std::optional num_poly) - : Ciphertext(getContext(preset), level, num_poly) {} -Ciphertext::Ciphertext(Context context, const Size level, - std::optional num_poly) - : preset_(context->get_preset()), encoding_(UNKNOWN) { + : preset_(preset), encoding_(UNKNOWN) { const auto num_polys = - num_poly.value_or(context->get_rank() * context->get_num_secret() + 1); + num_poly.value_or(get_rank(preset) * get_num_secret(preset) + 1); for (Size i = 0; i < num_polys; ++i) { - polys_.emplace_back(context, level + 1); + polys_.emplace_back(preset, level + 1); } } Ciphertext::Ciphertext(const Ciphertext &other, Size others_idx) @@ -260,15 +295,6 @@ Size Ciphertext::numPoly() const noexcept { return static_cast(polys_.size()); } -Polynomial &Ciphertext::operator[](size_t index) noexcept { - return polys_[index]; -} -const Polynomial &Ciphertext::operator[](size_t index) const noexcept { - return polys_[index]; -} -Polynomial *Ciphertext::data() noexcept { return polys_.data(); } -const Polynomial *Ciphertext::data() const noexcept { return polys_.data(); } - // --------------------------------------------------------------------- // Implementation of SecretKey // --------------------------------------------------------------------- @@ -276,20 +302,33 @@ SecretKey::SecretKey(Preset preset, const RNGSeed seed) : preset_(preset), seed_(seed) {} SecretKey::SecretKey(Preset preset, bool embedding) : preset_(preset) { - Context context = getContext(preset); - coeffs_.resize(context->get_rank() * context->get_num_secret() * - context->get_degree(), - 0); + coeffs_.resize( + get_rank(preset) * get_num_secret(preset) * get_degree(preset), 0); if (embedding) { - const Size num_poly = context->get_rank() * context->get_num_secret(); + const Size num_poly = get_rank(preset) * get_num_secret(preset); for (Size i = 0; i < num_poly; ++i) { polys_.emplace_back(preset, true); } } } - +SecretKey::SecretKey(SecretKey &&other) noexcept + : preset_(other.preset_), seed_(std::move(other.seed_)), + coeffs_(std::move(other.coeffs_)), polys_(std::move(other.polys_)) { + other.zeroize(); +} +SecretKey &SecretKey::operator=(SecretKey &&other) noexcept { + if (this != &other) { + zeroize(); + preset_ = other.preset_; + seed_ = std::move(other.seed_); + coeffs_ = std::move(other.coeffs_); + polys_ = std::move(other.polys_); + other.zeroize(); + } + return *this; +} +SecretKey::~SecretKey() noexcept { zeroize(); } Preset SecretKey::preset() const noexcept { return preset_; } - bool SecretKey::hasSeed() const noexcept { return seed_.has_value(); } RNGSeed SecretKey::getSeed() const noexcept { return seed_.value(); } void SecretKey::setSeed(const RNGSeed &seed) noexcept { seed_.emplace(seed); } @@ -299,11 +338,9 @@ Size SecretKey::coeffsSize() const noexcept { return static_cast(coeffs_.size()); } void SecretKey::allocCoeffs() { - auto context = getContext(preset_); coeffs_.clear(); - coeffs_.resize(context->get_rank() * context->get_num_secret() * - context->get_degree(), - 0); + coeffs_.resize( + get_rank(preset_) * get_num_secret(preset_) * get_degree(preset_), 0); } i8 &SecretKey::coeff(Size index) noexcept { return coeffs_[index]; } i8 SecretKey::coeff(Size index) const noexcept { return coeffs_[index]; } @@ -312,50 +349,43 @@ const i8 *SecretKey::coeffs() const noexcept { return coeffs_.data(); } Size SecretKey::numPoly() const noexcept { return static_cast(polys_.size()); } +void SecretKey::zeroize() noexcept { + if (!coeffs_.empty()) { + deb_secure_zero(coeffs_.data(), coeffs_.size() * sizeof(i8)); + } + if (seed_.has_value()) { + deb_secure_zero(seed_->data(), seed_->size() * sizeof(u64)); + seed_.reset(); + } + for (auto &poly : polys_) { + for (Size i = 0; i < poly.size(); ++i) { + deb_secure_zero(poly[i].data(), poly[i].degree() * sizeof(u64)); + } + } +} + void SecretKey::allocPolys(std::optional num_polyunit) { - const auto context = getContext(preset_); - num_polyunit = num_polyunit.value_or(context->get_num_p()); - const Size num_poly = context->get_rank() * context->get_num_secret(); + num_polyunit = num_polyunit.value_or(get_num_p(preset_)); + const Size num_poly = get_rank(preset_) * get_num_secret(preset_); polys_.clear(); for (Size i = 0; i < num_poly; ++i) { - polys_.emplace_back(context, num_polyunit.value()); + polys_.emplace_back(preset_, num_polyunit.value()); } } -Polynomial &SecretKey::operator[](Size index) { return polys_[index]; } -const Polynomial &SecretKey::operator[](Size index) const { - return polys_[index]; -} -Polynomial *SecretKey::data() noexcept { return polys_.data(); } -const Polynomial *SecretKey::data() const noexcept { return polys_.data(); } -// SwitchKey Implementation +// --------------------------------------------------------------------- +// Implementation of SwitchKey +// --------------------------------------------------------------------- SwitchKey::SwitchKey(Preset preset, const SwitchKeyKind type, const std::optional rot_idx) - : SwitchKey(getContext(preset), type, rot_idx) {} -SwitchKey::SwitchKey(const Context &context, const SwitchKeyKind type, - const std::optional rot_idx) - : preset_(context->get_preset()), type_(type), rot_idx_(rot_idx), - dnum_(context->get_gadget_rank()) { - switch (type_) { - case SWK_ENC: - addAx(context->get_num_p(), 1, true); - addBx(context->get_num_p(), context->get_num_secret(), true); - break; - case SWK_MULT: - case SWK_CONJ: - case SWK_ROT: - case SWK_AUTO: - case SWK_MODPACK: - case SWK_COMPOSE: - case SWK_DECOMPOSE: - addAx(context->get_num_p(), dnum_, true); - addBx(context->get_num_p(), dnum_ * context->get_num_secret(), true); - break; - case SWK_MODPACK_SELF: - case SWK_GENERIC: - default: - break; + : preset_(preset), type_(type), rot_idx_(rot_idx), + dnum_(get_gadget_rank(preset)) { + if (type_ == SWK_MODPACK_SELF || type_ == SWK_GENERIC) { + return; } + const Size size = (type_ == SWK_ENC) ? 1 : dnum_; + addAx(get_num_p(preset), size, true); + addBx(get_num_p(preset), size * get_num_secret(preset), true); } Preset SwitchKey::preset() const noexcept { return preset_; } @@ -379,8 +409,7 @@ void SwitchKey::addAx(const Size num_polyunit, std::optional size, void SwitchKey::addAx(const Polynomial &poly) { ax_.push_back(poly); } void SwitchKey::addBx(const Size num_polyunit, std::optional size, const bool ntt_state) { - const auto num_poly = - size.value_or(dnum_ * getContext(preset_)->get_num_secret()); + const auto num_poly = size.value_or(dnum_ * get_num_secret(preset_)); for (Size i = 0; i < num_poly; ++i) { bx_.emplace_back(preset_, num_polyunit); } diff --git a/src/Decryptor.cpp b/src/Decryptor.cpp index bc4ed1a..725f013 100644 --- a/src/Decryptor.cpp +++ b/src/Decryptor.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ */ #include "Decryptor.hpp" -#include "CKKSTypes.hpp" #include "utils/Basic.hpp" #include "utils/NTT.hpp" +#include "utils/OmpUtils.hpp" #include #ifdef DEB_OPENMP @@ -28,24 +28,37 @@ namespace deb { constexpr Size MAX_DECRYPT_SIZE = 2; -Decryptor::Decryptor(const Preset preset) - : context_(getContext(preset)), fft_(context_->get_degree()) { +template +DecryptorT

::DecryptorT() : PresetTraits

(preset), fft_(degree) { + if constexpr (P == PRESET_EMPTY) { + throw std::runtime_error("[Decryptor] Preset template must be " + "specified when preset is not given"); + } for (Size i = 0; i < MAX_DECRYPT_SIZE; ++i) { - modarith_.emplace_back(context_->get_degree(), - context_->get_primes()[i]); + modarith.emplace_back(primes[i]); } } +template +DecryptorT

::DecryptorT(const Preset preset) + : PresetTraits

(preset), fft_(degree) { + for (Size i = 0; i < MAX_DECRYPT_SIZE; ++i) { + modarith.emplace_back(degree, primes[i]); + } +} + +template template >, int>> -void Decryptor::decrypt(const Ciphertext &ctxt, const SecretKey &sk, MSG &msg, - Real scale) const { +void DecryptorT

::decrypt(const Ciphertext &ctxt, const SecretKey &sk, + MSG &msg, Real scale) const { decrypt(ctxt, sk, &msg, scale); } +template template -void Decryptor::decrypt(const Ciphertext &ctxt, const SecretKey &sk, MSG *msg, - Real scale) const { +void DecryptorT

::decrypt(const Ciphertext &ctxt, const SecretKey &sk, + MSG *msg, Real scale) const { deb_assert(ctxt.numPoly() > 0, "[Decryptor::decrypt] Ciphertext size is zero"); deb_assert(sk.numPoly() > 0, @@ -54,79 +67,85 @@ void Decryptor::decrypt(const Ciphertext &ctxt, const SecretKey &sk, MSG *msg, "[Decryptor::decrypt] Level of secret key must be greater than " "or equal to ciphertext level"); if (scale == 0) - scale = - std::pow(2.0, -context_->get_scale_factors()[ctxt[0].size() - 1]); + scale = std::pow(2.0, -scale_factors[ctxt[0].size() - 1]); else scale = 1.0 / scale; const int max_num_threads = - static_cast(ctxt[0].size() * (context_->get_degree() >> 10)); - setOmpThreadLimit(max_num_threads); + static_cast(ctxt[0].size() * (degree >> 10)); + utils::setOmpThreadLimit(max_num_threads); Ciphertext ctxt_copy = ctxt.deepCopy(std::min(ctxt[0].size(), MAX_DECRYPT_SIZE)); Polynomial &ax = ctxt_copy[ctxt_copy.numPoly() - 1]; if (!ax[0].isNTT()) { - forwardNTT(modarith_, ax); + forwardNTT(modarith, ax); } - for (Size i = 0; i < context_->get_num_secret(); ++i) { + for (Size i = 0; i < num_secret; ++i) { Ciphertext ctxt_tmp(ctxt_copy, i); for (Size j = 0; j < ctxt_tmp.numPoly(); ++j) { if (!ctxt_tmp[j][0].isNTT()) { - forwardNTT(modarith_, ctxt_tmp[j]); + forwardNTT(modarith, ctxt_tmp[j]); } } - if constexpr (std::is_same_v) { - Polynomial ptxt_tmp = innerDecrypt(ctxt_tmp, sk, ax); + if constexpr (std::is_same_v || + std::is_same_v) { + Polynomial ptxt_tmp = innerDecrypt(ctxt_tmp, sk[i], ax); decode(ptxt_tmp, msg[i], scale); - } else if constexpr (std::is_same_v) { - Polynomial ptxt_tmp = innerDecrypt(ctxt_tmp, sk, ax); + } else if constexpr (std::is_same_v || + std::is_same_v) { + Polynomial ptxt_tmp = innerDecrypt(ctxt_tmp, sk[i], ax); decodeWithoutFFT(ptxt_tmp, msg[i], scale); } else { throw std::runtime_error( "[Decryptor::decrypt] Unsupported message type"); } } - unsetOmpThreadLimit(); + utils::unsetOmpThreadLimit(); } -DECRYPT_TYPE_TEMPLATE() - -Polynomial Decryptor::innerDecrypt(const Ciphertext &ctxt, const SecretKey &sk, - const std::optional &ax) const { - Polynomial ptxt(context_, std::min(ctxt[0].size(), MAX_DECRYPT_SIZE)); +template +Polynomial +DecryptorT

::innerDecrypt(const Ciphertext &ctxt, const Polynomial &sx, + const std::optional &ax) const { + Polynomial ptxt(preset, std::min(ctxt[0].size(), MAX_DECRYPT_SIZE)); for (u64 i = 0; i < ptxt.size(); ++i) { ptxt[i].setNTT(ctxt[0][i].isNTT()); } // m = c_0 + (c_1 + ... + (c_{n-1} + c_n * s) * s ... ) * s - u64 idx = ctxt.numPoly() - 1; - const Polynomial &tmp = (ax.has_value()) ? ax.value() : ctxt[idx--]; + u64 last_idx = ctxt.numPoly() - 1; + const Polynomial &tmp = (ax.has_value()) ? ax.value() : ctxt[last_idx--]; + PRAGMA_OMP(omp parallel) { - mulPoly(modarith_, tmp, sk[0], ptxt); - addPoly(modarith_, ptxt, ctxt[idx], ptxt); + u64 idx = last_idx; + mulPolyConst(modarith, tmp, sx, ptxt); + addPoly(modarith, ptxt, ctxt[idx], ptxt); while (idx != 0) { - mulPoly(modarith_, ptxt, sk[0], ptxt); - addPoly(modarith_, ptxt, ctxt[--idx], ptxt); + mulPolyConst(modarith, ptxt, sx, ptxt); + addPoly(modarith, ptxt, ctxt[--idx], ptxt); } } return ptxt; } -void Decryptor::decodeWithSinglePoly(const Polynomial &ptxt, - CoeffMessage &coeff, Real scale) const { + +template +template +void DecryptorT

::decodeWithSinglePoly(const Polynomial &ptxt, CMSG &coeff, + Real scale) const { const u64 ptxt_degree = ptxt[0].degree(); - const auto full_degree = static_cast(context_->get_degree()); + const auto full_degree = static_cast(degree); deb_assert(coeff.size() >= ptxt_degree, "[Decryptor::decodeWithSinglePoly] Coeff size is too small"); - const u64 prime = context_->get_primes()[0]; + const u64 prime = primes[0]; const u64 half_prime = prime >> 1; const auto gap = static_cast(full_degree / ptxt_degree); u64 *interim = ptxt[0].data(); if (ptxt[0].isNTT()) { - modarith_[0].backwardNTT(interim); + modarith[0].backwardNTT(interim); } Real tmp; @@ -137,37 +156,46 @@ void Decryptor::decodeWithSinglePoly(const Polynomial &ptxt, } else { tmp = static_cast(interim[idx]); } - coeff[i] = tmp * scale; + if constexpr (std::is_same_v) { + coeff[i] = tmp * scale; + } else if constexpr (std::is_same_v) { + coeff[i] = static_cast(tmp * scale); + } } } -void Decryptor::decodeWithPolyPair(const Polynomial &ptxt, CoeffMessage &coeff, - Real scale) const { - // const Real scale_factor = context_->get_scale_factors()[ptxt.size - - // 1]; - const auto full_degree = static_cast(context_->get_degree()); +template +template +void DecryptorT

::decodeWithPolyPair(const Polynomial &ptxt, CMSG &coeff, + Real scale) const { + const auto full_degree = static_cast(degree); const auto ptxt_degree = static_cast(ptxt[0].degree()); deb_assert(coeff.size() >= ptxt_degree, "[Decryptor::decodeWithPolyPair] Coeff size is too small"); - const auto prime0 = context_->get_primes()[0]; - const auto prime1 = context_->get_primes()[1]; + const auto prime0 = primes[0]; + const auto prime1 = primes[1]; const utils::u128 prod_prime = utils::mul64To128(prime0, prime1); const utils::u128 half_prod_prime = prod_prime >> 1; - const u64 bezout0 = modarith_[1].inverse(prime0); - const u64 bezout1 = modarith_[0].inverse(prime1); + const u64 bezout0 = modarith[1].inverse(prime0); + const u64 bezout1 = modarith[0].inverse(prime1); u64 *ptxt0 = ptxt[0].data(); u64 *ptxt1 = ptxt[1].data(); if (ptxt[0].isNTT()) { - modarith_[0].backwardNTT(ptxt0); - modarith_[1].backwardNTT(ptxt1); + modarith[0].backwardNTT(ptxt0); + modarith[1].backwardNTT(ptxt1); } - modarith_[0].constMultInPlace(ptxt0, bezout1); - modarith_[1].constMultInPlace(ptxt1, bezout0); + modarith[0].constMultInPlace(ptxt0, bezout1); + modarith[1].constMultInPlace(ptxt1, bezout0); std::vector interim(full_degree); + + Real tmp; + auto gap = static_cast(full_degree / ptxt_degree); + + PRAGMA_OMP(omp parallel for schedule(static)) for (Size i = 0; i < full_degree; i++) { interim[i] = utils::mul64To128(ptxt0[i], prime1) + utils::mul64To128(ptxt1[i], prime0); @@ -175,21 +203,27 @@ void Decryptor::decodeWithPolyPair(const Polynomial &ptxt, CoeffMessage &coeff, (interim[i] >= prod_prime) ? interim[i] - prod_prime : interim[i]; } - Real tmp; - auto gap = static_cast(full_degree / ptxt_degree); - for (Size i = 0, idx = 0; i < ptxt_degree; i++, idx += gap) { if (interim[idx] > half_prod_prime) { tmp = -1.0 * static_cast(prod_prime - interim[idx]); } else { tmp = static_cast(interim[idx]); } - coeff[i] = tmp * scale; + if constexpr (std::is_same_v) { + coeff[i] = tmp * scale; + } else if constexpr (std::is_same_v) { + coeff[i] = static_cast(tmp * scale); + } else { + throw std::runtime_error( + "[Decryptor::decodeWithPolyPair] Unsupported message type"); + } } } -void Decryptor::decodeWithoutFFT(const Polynomial &ptxt, CoeffMessage &coeff, - Real scale) const { +template +template +void DecryptorT

::decodeWithoutFFT(const Polynomial &ptxt, CMSG &coeff, + Real scale) const { if (ptxt.size() != 1) { decodeWithPolyPair(ptxt, coeff, scale); } else { @@ -197,19 +231,40 @@ void Decryptor::decodeWithoutFFT(const Polynomial &ptxt, CoeffMessage &coeff, } } -void Decryptor::decode(const Polynomial &ptxt, Message &msg, Real scale) const { +template +template +void DecryptorT

::decode(const Polynomial &ptxt, MSG &msg, Real scale) const { - deb_assert(msg.size() >= context_->get_num_slots(), + deb_assert(msg.size() >= num_slots, "[Decryptor::decode] Message size is too small"); - CoeffMessage coeff(context_); - decodeWithoutFFT(ptxt, coeff, scale); + if constexpr (std::is_same_v) { + CoeffMessage coeff(preset); + decodeWithoutFFT(ptxt, coeff, scale); + + const auto half_degree = num_slots; + for (Size i = 0; i < msg.size(); ++i) { + msg[i].real(coeff[i]); + msg[i].imag(coeff[i + half_degree]); + } + fft_.forwardFFT(msg); + } else if constexpr (std::is_same_v) { + FCoeffMessage coeff(preset); + decodeWithoutFFT(ptxt, coeff, scale); - const auto half_degree = context_->get_num_slots(); - for (Size i = 0; i < msg.size(); ++i) { - msg[i].real(coeff[i]); - msg[i].imag(coeff[i + half_degree]); + const auto half_degree = num_slots; + for (Size i = 0; i < msg.size(); ++i) { + msg[i].real(coeff[i]); + msg[i].imag(coeff[i + half_degree]); + } + fft_.forwardFFT(msg); + } else { + throw std::runtime_error( + "[Decryptor::decode] Unsupported message type"); } - fft_.forwardFFT(msg); } +#define X(preset) DECRYPT_TYPE_TEMPLATE(PRESET_##preset, ) +PRESET_LIST_WITH_EMPTY +#undef X + } // namespace deb diff --git a/src/Encryptor.cpp b/src/Encryptor.cpp index 0f9f0e7..347230f 100644 --- a/src/Encryptor.cpp +++ b/src/Encryptor.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,13 +15,7 @@ */ #include "Encryptor.hpp" -#include "CKKSTypes.hpp" -#include "SecretKeyGenerator.hpp" -#include "utils/Basic.hpp" - -#include "alea/algorithms.h" - -#include +#include "utils/OmpUtils.hpp" #ifdef DEB_OPENMP #include @@ -29,68 +23,102 @@ namespace deb { -Encryptor::Encryptor(const Preset preset, std::optional seeds) - : context_(getContext(preset)), - ptxt_buffer_(context_, - context_->get_num_p() * context_->get_num_secret()), - vx_buffer_(context_, true), fft_(context_->get_degree()) { +template +EncryptorT

::EncryptorT(std::optional seeds) + : PresetTraits

(preset), ptxt_buffer_(preset, num_p * num_secret), + vx_buffer_(preset, true), ex_buffer_(preset, true), samples_(degree), + mask_(degree), i_samples_(degree), fft_(degree) { + if constexpr (P == PRESET_EMPTY) { + throw std::runtime_error( + "[Encryptor] Preset template must be specified when using this " + "constructor"); + } - for (Size i = 0; i < context_->get_num_p(); ++i) { - modarith_.emplace_back(context_->get_degree(), - context_->get_primes()[i]); + for (Size i = 0; i < num_p; ++i) { + modarith.emplace_back(primes[i]); } - for (Size i = 0; i < context_->get_num_secret() + 1; ++i) { - ex_buffers_.emplace_back(context_, true); + + if (!seeds) { + seeds.emplace(SeedGenerator::Gen()); } + rng_ = createRandomGenerator(seeds.value()); +} + +template +EncryptorT

::EncryptorT(Preset actual_preset, + std::optional seeds) + : PresetTraits

(actual_preset), + ptxt_buffer_(actual_preset, num_p * num_secret), + vx_buffer_(actual_preset, true), ex_buffer_(actual_preset, true), + samples_(degree), mask_(degree), i_samples_(degree), fft_(degree) { + + for (Size i = 0; i < num_p; ++i) { + modarith.emplace_back(degree, primes[i]); + } + if (!seeds) { seeds.emplace(SeedGenerator::Gen()); } - as_ = std::shared_ptr( - alea_init(to_alea_seed(seeds.value()), ALEA_ALGORITHM_SHAKE256), - [](void *p) { alea_free(static_cast(p)); }); + rng_ = createRandomGenerator(seeds.value()); +} + +template +EncryptorT

::EncryptorT(Preset actual_preset, + std::shared_ptr rng) + : PresetTraits

(actual_preset), + ptxt_buffer_(actual_preset, num_p * num_secret), + vx_buffer_(actual_preset, true), ex_buffer_(actual_preset, true), + samples_(degree), mask_(degree), i_samples_(degree), rng_(std::move(rng)), + fft_(degree) { + + for (Size i = 0; i < num_p; ++i) { + modarith.emplace_back(degree, primes[i]); + } } +template template >, int>> -void Encryptor::encrypt(const MSG &msg, const KEY &key, Ciphertext &ctxt, - const EncryptOptions &opt) const { - deb_assert(context_->get_num_secret() == 1, +void EncryptorT

::encrypt(const MSG &msg, const KEY &key, Ciphertext &ctxt, + const EncryptOptions &opt) const { + deb_assert(num_secret == 1, "[Encryptor::encrypt] NumSecret must be 1 for a single message " "encryption"); encrypt(&msg, key, ctxt, opt); } +template template -void Encryptor::encrypt(const std::vector &msg, const KEY &key, - Ciphertext &ctxt, const EncryptOptions &opt) const { - deb_assert(msg.size() == context_->get_num_secret(), +void EncryptorT

::encrypt(const std::vector &msg, const KEY &key, + Ciphertext &ctxt, const EncryptOptions &opt) const { + deb_assert(msg.size() == num_secret, "[Encryptor::encrypt] Message vector size must match NumSecret"); encrypt(msg.data(), key, ctxt, opt); } +template template -void Encryptor::encrypt(const MSG *msg, const KEY &key, Ciphertext &ctxt, - const EncryptOptions &opt) const { - const Size single_num_polyunit = (opt.level == DEB_MAX_SIZE) - ? context_->get_encryption_level() + 1 +void EncryptorT

::encrypt(const MSG *msg, const KEY &key, Ciphertext &ctxt, + const EncryptOptions &opt) const { + const Size single_num_polyunit = (opt.level == utils::DEB_MAX_SIZE) + ? encryption_level + 1 : opt.level + 1; - const Size num_secret = context_->get_num_secret(); const Size num_polyunit = single_num_polyunit * num_secret; - deb_assert( - single_num_polyunit - 1 <= context_->get_num_p(), - "[Encryptor::encrypt] Encryption level cannot exceed number of primes"); - deb_assert((num_secret == 1 || context_->get_rank() == 1), + deb_assert(single_num_polyunit - 1 <= num_p, + "[Encryptor::encrypt] Encryption level cannot exceed number of " + "primes"); + deb_assert((num_secret == 1 || rank == 1), "[Encryptor::encrypt] Rank must be 1 when NumSecret > 1" " or NumSecret must be 1 when Rank > 1"); const int max_num_threads = - static_cast(single_num_polyunit * (context_->get_degree() >> 10)); - setOmpThreadLimit(max_num_threads); + static_cast(single_num_polyunit * (degree >> 10)); + utils::setOmpThreadLimit(max_num_threads); Polynomial ptxt(ptxt_buffer_, 0, num_polyunit); for (Size i = 0; i < num_polyunit; ++i) { - ptxt[i].setPrime(context_->get_primes()[i % single_num_polyunit]); + ptxt[i].setPrime(primes[i % single_num_polyunit]); } if (num_secret > 1) { @@ -104,9 +132,11 @@ void Encryptor::encrypt(const MSG *msg, const KEY &key, Ciphertext &ctxt, } innerEncrypt(ptxt, key, single_num_polyunit, ctxt); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v || + std::is_same_v) { ctxt.setEncoding(SLOT); - } else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v || + std::is_same_v) { ctxt.setEncoding(COEFF); } else { throw std::runtime_error( @@ -115,257 +145,266 @@ void Encryptor::encrypt(const MSG *msg, const KEY &key, Ciphertext &ctxt, if (!opt.ntt_out) { for (u64 i = 0; i < ctxt.numPoly(); ++i) { - backwardNTT(modarith_, ctxt[i]); + backwardNTT(modarith, ctxt[i]); } } - unsetOmpThreadLimit(); + utils::unsetOmpThreadLimit(); } -template <> -void Encryptor::innerEncrypt(const Polynomial &ptxt, - const SecretKey &secretkey, - Size num_polyunit, - Ciphertext &ctxt) const { - const Size rank = context_->get_rank(); - const Size num_secret = context_->get_num_secret(); +template +template +void EncryptorT

::innerEncrypt(const Polynomial &ptxt, const KEY &key, + Size num_polyunit, Ciphertext &ctxt) const { deb_assert(ptxt.size() >= num_polyunit * num_secret, "[Encryptor::innerEncrypt] Level of an input Plaintext " "must be greater than or equal to encryption level"); - deb_assert( - secretkey.numPoly() == num_secret * rank, - "[Encryptor::innerEncrypt] Secret key has no embedded polynomials."); - deb_assert( - rank == 1 || num_secret == 1, - "[Encryptor::innerEncrypt] Rank must be 1 or NumSecret must be 1"); + deb_assert(rank == 1 || num_secret == 1, + "[Encryptor::innerEncrypt] Rank must be 1 or NumSecret must be " + "1"); bool isNTT = ptxt[0].isNTT(); - ctxt.setNumPolyunit(num_polyunit); ctxt.setNTT(true); - for (u64 i = 0; i < num_polyunit; ++i) { - alea_get_random_uint64_array_in_range( - as_.get(), ctxt[num_secret][i].data(), context_->get_degree(), - context_->get_primes()[i]); - } - - if (rank == 1) { - // std::vector ex_vec; - std::vector ptxt_vec; - for (Size i = 0; i < num_secret; ++i) { - sampleGaussian(i, num_polyunit, isNTT); - if (i == 0) - ptxt_vec.push_back(ptxt); - else - ptxt_vec.emplace_back(ptxt, i * num_polyunit, num_polyunit); + if constexpr (std::is_same_v) { + deb_assert(key.numPoly() == num_secret * rank, + "[Encryptor::innerEncrypt] Secret key has no embedded " + "polynomials."); + for (u64 i = 0; i < num_polyunit; ++i) { + rng_->getRandomUint64ArrayInRange(ctxt[num_secret][i].data(), + degree, primes[i]); } - PRAGMA_OMP(omp parallel) { + if (rank == 1) { + std::vector ptxt_vec; for (Size i = 0; i < num_secret; ++i) { - // e = e + m - addPoly(modarith_, ex_buffers_[i], ptxt_vec[i], ex_buffers_[i], - num_polyunit); - // perform delayed NTT - if (!isNTT) { - forwardNTT(modarith_, ex_buffers_[i], num_polyunit); - } - mulPoly(modarith_, ctxt[num_secret], secretkey[i], ctxt[i]); - subPoly(modarith_, ex_buffers_[i], ctxt[i], ctxt[i]); + if (i == 0) + ptxt_vec.push_back(ptxt); + else + ptxt_vec.emplace_back(ptxt, i * num_polyunit, num_polyunit); } - } - } else { - sampleGaussian(0, num_polyunit, isNTT); - - // e = e + m - addPoly(modarith_, ex_buffers_[0], ptxt, ex_buffers_[0], num_polyunit); - // perform delayed NTT - if (!isNTT) { - forwardNTT(modarith_, ex_buffers_[0], num_polyunit); - } - // TODO: not tested yet since no preset of rank > 1 - // b = - \sigma a_i * s_i + e + m - Polynomial bx(ctxt[0], 0, num_polyunit); - Polynomial tmp(context_, num_polyunit); - for (Size idx = 1; idx < ctxt.numPoly(); ++idx) { - mulPoly(modarith_, ctxt[idx], secretkey[idx - 1], tmp); - subPoly(modarith_, bx, tmp, bx); - } - } -} + PRAGMA_OMP(omp parallel) { + for (Size i = 0; i < num_secret; ++i) { + sampleGaussian(num_polyunit, isNTT); + // e = e + m + addPoly(modarith, ex_buffer_, ptxt_vec[i], ex_buffer_, + num_polyunit); + // perform delayed NTT + if (!isNTT) { + forwardNTT(modarith, ex_buffer_, num_polyunit); + } + mulPolyConst(modarith, ctxt[num_secret], key[i], ctxt[i]); + subPoly(modarith, ex_buffer_, ctxt[i], ctxt[i]); + } + } + } else { + Polynomial bx(ctxt[0], 0, num_polyunit); + Polynomial tmp(preset, num_polyunit); -template <> -void Encryptor::innerEncrypt(const Polynomial &ptxt, - const SwitchKey &enckey, - Size num_polyunit, - Ciphertext &ctxt) const { - const auto rank = context_->get_rank(); - const auto num_secret = context_->get_num_secret(); - deb_assert(ptxt.size() >= num_polyunit * num_secret, - "[Encryptor::innerEncrypt] Level of an input Plaintext " - "must be greater than or equal to encryption level"); - deb_assert( - rank == 1 || num_secret == 1, - "[Encryptor::innerEncrypt] Rank must be 1 or NumSecret must be 1"); + PRAGMA_OMP(omp parallel) { + sampleGaussian(num_polyunit, isNTT); - bool isNTT = ptxt[0].isNTT(); - ctxt.setNumPolyunit(num_polyunit); - ctxt.setNTT(true); + // e = e + m + addPoly(modarith, ex_buffer_, ptxt, ex_buffer_, num_polyunit); - sampleZO(num_polyunit); - sampleGaussian(num_secret, num_polyunit, true); - if (rank == 1) { - std::vector ptxt_vec; - for (Size i = 0; i < num_secret; ++i) { - sampleGaussian(i, num_polyunit, isNTT); - if (i == 0) - ptxt_vec.push_back(ptxt); - else - ptxt_vec.emplace_back(ptxt, i * num_polyunit, num_polyunit); + // perform delayed NTT + if (!isNTT) { + forwardNTT(modarith, ex_buffer_, num_polyunit); + } + // TODO: not tested yet since no preset of rank > 1 + // b = - \sigma a_i * s_i + e + m + for (Size idx = 1; idx < ctxt.numPoly(); ++idx) { + mulPolyConst(modarith, ctxt[idx], key[idx - 1], tmp); + subPoly(modarith, bx, tmp, bx); + } + } } - - PRAGMA_OMP(omp parallel) { - mulPoly(modarith_, vx_buffer_, enckey.ax(0), ctxt[num_secret], - num_polyunit); - addPoly(modarith_, ctxt[num_secret], ex_buffers_[num_secret], - ctxt[num_secret]); + } else if constexpr (std::is_same_v) { + if (rank == 1) { + std::vector ptxt_vec; for (Size i = 0; i < num_secret; ++i) { + if (i == 0) + ptxt_vec.push_back(ptxt); + else + ptxt_vec.emplace_back(ptxt, i * num_polyunit, num_polyunit); + } - mulPoly(modarith_, vx_buffer_, enckey.bx(i), ctxt[i], - num_polyunit); - addPoly(modarith_, ex_buffers_[i], ptxt_vec[i], ex_buffers_[i], - num_polyunit); - - if (!isNTT) { - forwardNTT(modarith_, ex_buffers_[i], num_polyunit); + PRAGMA_OMP(omp parallel) { + sampleZO(num_polyunit); + sampleGaussian(num_polyunit, true); + mulPolyConst(modarith, vx_buffer_, key.ax(0), ctxt[num_secret], + num_polyunit); + addPoly(modarith, ctxt[num_secret], ex_buffer_, + ctxt[num_secret]); + for (Size i = 0; i < num_secret; ++i) { + sampleGaussian(num_polyunit, isNTT); + mulPoly(modarith, vx_buffer_, key.bx(i), ctxt[i], + num_polyunit); + addPoly(modarith, ex_buffer_, ptxt_vec[i], ex_buffer_, + num_polyunit); + + if (!isNTT) { + forwardNTT(modarith, ex_buffer_, num_polyunit); + } + + addPoly(modarith, ctxt[i], ex_buffer_, ctxt[i]); } - - addPoly(modarith_, ctxt[i], ex_buffers_[i], ctxt[i]); } + } else { + // not implemented yet } } else { - // not implemented yet + throw std::runtime_error( + "[Encryptor::innerEncrypt] Unsupported key type"); } } +template template -void Encryptor::embeddingToN(const MSG &msg, const Real &delta, - Polynomial &ptxt, const Size size) const { +void EncryptorT

::embeddingToN(const MSG &msg, const Real &delta, + Polynomial &ptxt, const Size size) const { const auto msg_size = msg.size(); - const auto degree = context_->get_degree(); Size gap = degree / msg_size; if constexpr (std::is_same_v) { gap /= 2; } - std::vector interim(degree); - - PRAGMA_OMP(omp parallel for schedule(static)) - for (Size i = 0; i < msg_size; i++) { - if constexpr (std::is_same_v) { - interim[i] = static_cast( - utils::addZeroPointFive(msg[i].real() * delta)); - interim[msg_size + i] = static_cast( - utils::addZeroPointFive(msg[i].imag() * delta)); - } else if constexpr (std::is_same_v) { - interim[i] = static_cast( - utils::addZeroPointFive(msg[i] * delta)); - } - } + std::vector interim(msg_size * + ((std::is_same_v) ? 2 : 1)); + for (Size i = 0; i < size; i++) { ptxt[i].setNTT(false); - if (gap > 1) + if (degree > msg_size * ((std::is_same_v) ? 2 : 1)) std::fill_n(ptxt[i].data(), degree, UINT64_C(0)); } - PRAGMA_OMP(omp parallel for collapse(2) schedule(static)) - for (Size i = 0; i < size; i++) { - for (Size j = 0; j < degree; j += gap) { - auto input = interim[j]; - bool is_positive = input >= 0; - auto abs = is_positive ? input : -input; - u64 res = modarith_[i].reduceBarrett(static_cast(abs)); - ptxt[i][j] = is_positive ? res : ptxt[i].prime() - res; + PRAGMA_OMP(omp parallel) { + PRAGMA_OMP(omp for schedule(static)) + for (Size i = 0; i < msg_size; i++) { + if constexpr (std::is_same_v || + std::is_same_v) { + interim[i] = static_cast( + utils::addZeroPointFive(msg[i].real() * delta)); + interim[msg_size + i] = static_cast( + utils::addZeroPointFive(msg[i].imag() * delta)); + } else if constexpr (std::is_same_v || + std::is_same_v) { + interim[i] = static_cast( + utils::addZeroPointFive(msg[i] * delta)); + } + } + + PRAGMA_OMP(omp for collapse(2) schedule(static)) + for (Size i = 0; i < size; i++) { + for (Size j = 0; j < degree / gap; j++) { + const utils::u128 input = static_cast(interim[j]); + utils::u128 sign_mask; + if constexpr ((utils::i128(-1) >> 1) == utils::i128(-1)) { + sign_mask = static_cast(interim[j] >> 127); + } else { + sign_mask = ~((input >> 127) - static_cast(1)); + } + const u64 res = + modarith[i].reduceBarrett((input ^ sign_mask) - sign_mask); + const u64 sign_mask_64 = static_cast(sign_mask); + ptxt[i][j * gap] = (res & ~sign_mask_64) | + ((ptxt[i].prime() - res) & sign_mask_64); + } } } } +template template -void Encryptor::encodeWithoutNTT(const MSG &msg, Polynomial &ptxt, - const Size size, const Real scale) const { - const Real delta{ - scale == 0 ? std::pow(static_cast(2), - context_->get_scale_factors()[ptxt.size() - 1]) - : scale}; - if constexpr (std::is_same_v) { +void EncryptorT

::encodeWithoutNTT(const MSG &msg, Polynomial &ptxt, + const Size size, const Real scale) const { + const Real delta{scale == 0 ? std::pow(static_cast(2), + scale_factors[ptxt.size() - 1]) + : scale}; + if constexpr (std::is_same_v || + std::is_same_v) { embeddingToN(msg, delta, ptxt, size); } else if constexpr (std::is_same_v) { Message tmp(msg.size(), msg.data()); fft_.backwardFFT(tmp); embeddingToN(tmp, delta, ptxt, size); + } else if constexpr (std::is_same_v) { + Message tmp(msg.size()); + for (Size i = 0; i < msg.size(); ++i) { + tmp[i] = ComplexT(static_cast(msg[i].real()), + static_cast(msg[i].imag())); + } + fft_.backwardFFT(tmp); + embeddingToN(tmp, delta, ptxt, size); } else { throw std::runtime_error( "[Encryptor::encodeWithoutNTT] Unsupported message type"); } } -DECL_ENCRYPT_TEMPLATE_MSG(Message, ) -DECL_ENCRYPT_TEMPLATE_MSG(CoeffMessage, ) - -void Encryptor::sampleZO(Size num_polyunit) const { - const auto degree = context_->get_degree(); - - Polynomial &poly = vx_buffer_; - poly.setNTT(false); +template void EncryptorT

::sampleZO(Size num_polyunit) const { - const auto pad_degree = (degree + 31) / 32 * 32; - std::vector random_vector(pad_degree); + // const auto pad_degree = std::max(degree, Size(32)); + const auto pad_num = std::max(degree, Size(32)) / 32; + // std::vector random_vector(pad_degree); - for (Size i = 0; i < pad_degree; i += 32) { - u64 rnd = alea_get_random_uint64(as_.get()); - for (Size j = 0; j < 32; j++, rnd >>= 2) { - // random_vector[i + j] = (rnd & 2) ? (rnd & 1) : -(rnd & 1); - random_vector[i + j] = ((rnd & 2) - 1) * (rnd & 1); - } + PRAGMA_OMP(omp single) { + vx_buffer_.setNTT(false); + rng_->getRandomUint64Array(samples_.data() + degree - pad_num, pad_num); } - const auto *const primes = context_->get_primes(); + PRAGMA_OMP(omp for schedule(static)) + for (Size i = 0; i < degree; ++i) { + u64 &rnd = samples_[i / 32]; + // mask is 0xFFFFFFFF if bit is 1, 0x0 if bit is 0 + mask_[i] = 0UL - ((rnd & 2) >> 1); + samples_[i] = (rnd & 1); + rnd >>= 2; + } - PRAGMA_OMP(omp parallel for collapse(2) schedule(static)) + PRAGMA_OMP(omp for collapse(2) schedule(static)) for (Size i = 0; i < num_polyunit; ++i) { for (Size j = 0; j < degree; ++j) { - // poly[i][j] = (random_vector[j] == -1) ? (primes[i] - 1) : - // random_vector[j]; - poly[i][j] = - ((1 - random_vector[j]) >> 1) * primes[i] + random_vector[j]; + const u64 mask = mask_[j]; + const u64 bit = samples_[j]; + vx_buffer_[i][j] = (bit & mask) | ((primes[i] - bit) & ~mask); } } - forwardNTT(modarith_, poly, num_polyunit); -} -void Encryptor::sampleGaussian(const Size idx, const Size num_polyunit, - const bool do_ntt) const { - const auto degree = context_->get_degree(); - const auto *const primes = context_->get_primes(); + forwardNTT(modarith, vx_buffer_, num_polyunit); +} - std::vector samples(degree); - alea_sample_gaussian_int64_array(as_.get(), samples.data(), degree, - context_->get_gaussian_error_stdev()); +template +void EncryptorT

::sampleGaussian(const Size num_polyunit, + const bool do_ntt) const { - Polynomial &poly = ex_buffers_[idx]; - poly.setNTT(false); + PRAGMA_OMP(omp single) { + rng_->sampleGaussianInt64Array(i_samples_.data(), degree, + gaussian_error_stdev); + ex_buffer_.setNTT(false); + } - PRAGMA_OMP(omp parallel for schedule(static) collapse(2)) + PRAGMA_OMP(omp for collapse(2) schedule(static)) for (Size i = 0; i < num_polyunit; ++i) { - for (Size j = 0; j < context_->get_degree(); ++j) { - // Convert int64_t sample to u64 - poly[i][j] = (samples[j] >= 0) - ? static_cast(samples[j]) - : primes[i] - static_cast(-samples[j]); + for (Size j = 0; j < degree; ++j) { + const u64 prime = primes[i]; + const u64 sample = static_cast(i_samples_[j]); + + // sign_mask_rev is -1(0xFFFFFFFF) if i_samples_[j] positive, + // 0(0x0) if negative + const u64 sign_mask_rev = (sample >> 63) - 1u; + + ex_buffer_[i][j] = + (sample & sign_mask_rev) | ((prime + sample) & ~sign_mask_rev); } } if (do_ntt) { - forwardNTT(modarith_, poly, num_polyunit); + forwardNTT(modarith, ex_buffer_, num_polyunit); } } +#define X(preset) DECL_ENCRYPT_TEMPLATE(PRESET_##preset, ) +PRESET_LIST_WITH_EMPTY +#undef X + } // namespace deb diff --git a/src/FFT.cpp b/src/FFT.cpp index 87e8027..5778096 100644 --- a/src/FFT.cpp +++ b/src/FFT.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,9 @@ */ #include "utils/FFT.hpp" -#include "CKKSTypes.hpp" -#include "Constant.hpp" -#include "Macro.hpp" #include "utils/Basic.hpp" +#include "utils/Constant.hpp" +#include "utils/Macro.hpp" #include namespace { @@ -29,25 +28,25 @@ template void bitReverseMessage(deb::MessageImpl &m) { // Direction = true : forward FFT // Direction = false : backward FFT -template +template inline void butterfly(deb::ComplexT &u, deb::ComplexT &v, - const deb::ComplexT root) { + const deb::ComplexT root) { if constexpr (Direction) { - deb::ComplexT u0 = u; - deb::ComplexT v0 = v * root; - u = u0 + v0; - v = u0 - v0; + deb::ComplexT u0 = u; + deb::ComplexT v0 = static_cast>(v) * root; + u = static_cast>(u0 + v0); + v = static_cast>(u0 - v0); } else { deb::ComplexT u0 = u; deb::ComplexT v0 = v; u = u0 + v0; - v = (u0 - v0) * root; + v = (u0 - v0) * static_cast>(root); } } -template +template void computeSingleStep(deb::ComplexT *op, deb::Size size, deb::Size gap, - const deb::ComplexT *roots_ptr) { + const deb::ComplexT *roots_ptr) { if (gap > 4 || size < 8) { deb::ComplexT *x_ptr = op; deb::ComplexT *y_ptr = op + gap; @@ -106,15 +105,16 @@ void computeSingleStep(deb::ComplexT *op, deb::Size size, deb::Size gap, namespace deb::utils { -template void FFTImpl::forwardFFT(MessageImpl &msg) const { +template void FFT::forwardFFT(MessageImpl &msg) const { const Size sz{msg.size()}; const auto *roots_ptr = roots_.data(); bitReverseMessage(msg); - for (Size gap = 1; gap <= sz / 2; gap <<= 1) + for (Size gap = 1; gap <= sz / 2; gap <<= 1) { computeSingleStep(msg.data(), sz, gap, roots_ptr + gap); + } } -template void FFTImpl::backwardFFT(MessageImpl &msg) const { +template void FFT::backwardFFT(MessageImpl &msg) const { const Size sz{msg.size()}; const auto *roots_ptr = inv_roots_.data(); for (Size gap = sz / 2; gap != 0; gap >>= 1) @@ -125,8 +125,7 @@ template void FFTImpl::backwardFFT(MessageImpl &msg) const { msg[i].imag() / static_cast(sz)}; } -template -FFTImpl::FFTImpl(const u64 degree) { //: degree_(degree) { +FFT::FFT(const u64 degree) { // pre-compute the power of five const u64 half_degree = degree >> 1; const u64 double_degree = degree << 1; @@ -139,8 +138,8 @@ FFTImpl::FFTImpl(const u64 degree) { //: degree_(degree) { complex_roots_.resize(double_degree + 1); for (u64 i = 0; i < double_degree; ++i) { Real angle = REAL_PI * static_cast(i) / static_cast(degree); - const ComplexT w{0.0, 1.0}; - const auto tmp = std::exp(w * static_cast(angle)); + const ComplexT w{0.0, 1.0}; + const auto tmp = std::exp(w * angle); complex_roots_[i] = {tmp.real(), tmp.imag()}; } complex_roots_[double_degree] = complex_roots_[0]; diff --git a/src/KeyGenerator.cpp b/src/KeyGenerator.cpp index f0ae70e..169fb1a 100644 --- a/src/KeyGenerator.cpp +++ b/src/KeyGenerator.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,53 +16,45 @@ #include "KeyGenerator.hpp" #include "SecretKeyGenerator.hpp" - #include "utils/Basic.hpp" -#include - -#include - namespace { -inline void checkSecretKey(const deb::Context &context, - const std::optional &sk) { - deb_assert(sk.has_value(), "[KeyGenerator] Secret key is not set."); - deb_assert(context->get_preset() == sk->preset(), +inline void checkSecretKey(const deb::Preset preset, const deb::SecretKey &sk) { + deb_assert(preset == sk.preset(), "[KeyGenerator] Preset mismatch between KeyGenerator and " "SecretKey."); - deb_assert(sk->numPoly() == context->get_rank() * context->get_num_secret(), + deb_assert(get_rank(preset) * get_num_secret(preset) == sk.numPoly(), "[KeyGenerator] Secret key has no embedded polynomials."); // Maybe we can remove this check to allow non-NTT secret keys - deb_assert((*sk)[0][0].isNTT(), + deb_assert(sk[0][0].isNTT(), "[KeyGenerator] Secret key polynomials are not in NTT domain."); }; -inline void checkSwk(const deb::Context &context, const deb::SwitchKey &swk, +inline void checkSwk(const deb::Preset &preset, const deb::SwitchKey &swk, const deb::SwitchKeyKind expected_type) { - deb_assert(context->get_preset() == swk.preset(), + deb_assert(preset == swk.preset(), "[KeyGenerator] Preset mismatch between KeyGenerator and " "SwitchingKey."); - deb_assert(swk.type() == expected_type, + deb_assert(expected_type == swk.type(), "[KeyGenerator] The provided switching key has invalid type."); }; -inline void checkModPackKeyBundleCondition(const deb::Context &context, - const deb::Context &context_from, - const deb::Context &context_to) { - - [[maybe_unused]] const deb::Size from_degree = context_from->get_degree(); - [[maybe_unused]] const deb::Size from_rank = context_from->get_rank(); - [[maybe_unused]] const deb::Size to_degree = context_to->get_degree(); - [[maybe_unused]] const deb::Size to_rank = context_to->get_rank(); - [[maybe_unused]] const deb::Size degree = context->get_degree(); +inline void checkModPackKeyBundleCondition(const deb::Preset &preset, + const deb::Preset &preset_from, + const deb::Preset &preset_to) { + [[maybe_unused]] const deb::Size from_degree = get_degree(preset_from); + [[maybe_unused]] const deb::Size from_rank = get_rank(preset_from); + [[maybe_unused]] const deb::Size to_degree = get_degree(preset_to); + [[maybe_unused]] const deb::Size to_rank = get_rank(preset_to); + [[maybe_unused]] const deb::Size degree = get_degree(preset); // check dimension is compatible // check output ctxt dimension could be resulted by a single key switching - deb_assert( - to_degree * to_rank == degree, - "[genModPackKeyBundle] Total dimension of output secret key is not " - "equal to the RLWE encryption dimension"); + deb_assert(to_degree * to_rank == degree, + "[genModPackKeyBundle] Total dimension of output secret key is " + "not " + "equal to the RLWE encryption dimension"); // check input ctxt entries can be combined to the output ctxt entries deb_assert(to_degree % from_degree == 0, "[genModPackKeyBundle] The degree of input secret key does not " @@ -92,39 +84,48 @@ inline void automorphism(const deb::i8 *op, deb::i8 *res, const deb::Size sig, namespace deb { -KeyGenerator::KeyGenerator(const Preset preset, - std::optional seeds) - : context_(getContext(preset)), sk_(std::nullopt), - fft_(context_->get_degree()) { - for (u64 i = 0; i < context_->get_num_p(); ++i) { - modarith_.emplace_back(context_->get_degree(), - context_->get_primes()[i]); +template +KeyGeneratorT

::KeyGeneratorT(std::optional seeds) + : KeyGeneratorT(P, std::move(seeds)) { + if constexpr (P == PRESET_EMPTY) { + throw std::runtime_error( + "[KeyGenerator] Preset must be specified for EMPTY preset."); + } +} + +template +KeyGeneratorT

::KeyGeneratorT(const Preset preset, + std::optional seeds) + : PresetTraits

(preset), fft_(degree) { + for (u64 i = 0; i < num_p; ++i) { + modarith.emplace_back(degree, primes[i]); } if (!seeds) { seeds.emplace(SeedGenerator::Gen()); } - as_ = std::shared_ptr( - alea_init(to_alea_seed(seeds.value()), ALEA_ALGORITHM_SHAKE256), - [](void *p) { alea_free(static_cast(p)); }); + rng_ = createRandomGenerator(seeds.value()); computeConst(); } -KeyGenerator::KeyGenerator(const SecretKey &sk, - std::optional seeds) - : KeyGenerator(sk.preset(), std::move(seeds)) { - sk_ = sk; +template +KeyGeneratorT

::KeyGeneratorT(const Preset preset, + std::shared_ptr rng) + : PresetTraits

(preset), rng_(std::move(rng)), fft_(degree) { + for (u64 i = 0; i < num_p; ++i) { + modarith.emplace_back(degree, primes[i]); + } + computeConst(); } -void KeyGenerator::genSwitchingKey(const Polynomial *from, const Polynomial *to, - Polynomial *ax, Polynomial *bx, - const Size ax_size, - const Size bx_size) const { - const Size num_secret = context_->get_num_secret(); - const Size degree = context_->get_degree(); - const Size length = context_->get_num_base() + context_->get_num_qp(); - const Size max_length = context_->get_num_p(); - const Size dnum = context_->get_gadget_rank(); +template +void KeyGeneratorT

::genSwitchingKey(const Polynomial *from, + const Polynomial *to, Polynomial *ax, + Polynomial *bx, const Size ax_size, + const Size bx_size) const { + const Size length = num_base + num_qp; + const Size max_length = num_p; + const Size dnum = gadget_rank; const Size alpha = (length + dnum - 1) / dnum; Size a_size = ax_size == 0 ? dnum : ax_size; Size b_size = bx_size == 0 ? dnum * num_secret : bx_size; @@ -133,7 +134,7 @@ void KeyGenerator::genSwitchingKey(const Polynomial *from, const Polynomial *to, sampleUniform(ax[i]); } - Polynomial tmp(context_, max_length); + Polynomial tmp(preset, max_length); const Size s_size = b_size / a_size; for (Size idx = 0; idx < a_size; ++idx) { @@ -142,8 +143,8 @@ void KeyGenerator::genSwitchingKey(const Polynomial *from, const Polynomial *to, auto &b = bx[idx + sid * a_size]; auto ex = sampleGaussian(max_length, true); - mulPoly(modarith_, a, to[sid], b); - subPoly(modarith_, ex, b, b); + mulPoly(modarith, a, to[sid], b); + subPoly(modarith, ex, b, b); for (Size tdx = 0; tdx < max_length; ++tdx) { if (tdx < idx * alpha || @@ -153,37 +154,36 @@ void KeyGenerator::genSwitchingKey(const Polynomial *from, const Polynomial *to, } } } - constMulPoly(modarith_, from[sid], p_mod_.data(), tmp, idx * alpha, + constMulPoly(modarith, from[sid], p_mod_.data(), tmp, idx * alpha, std::min((idx + 1) * alpha, length)); - constMulPoly(modarith_, tmp, hat_q_i_mod_.data(), tmp, idx * alpha, + constMulPoly(modarith, tmp, hat_q_i_mod_.data(), tmp, idx * alpha, std::min((idx + 1) * alpha, length)); // TODO: optimize inplace addition - // addPoly(modarith_, b, tmp, b); + // addPoly(modarith, b, tmp, b); // Polynomial tmp_copy(tmp, idx * alpha, // std::min(alpha, length - idx * alpha)); // Polynomial b_copy(b, idx * alpha, // std::min(alpha, length - idx * alpha)); - // addPoly(modarith_, b_copy, tmp_copy, b_copy); - addPoly(modarith_, b, tmp, b); + // addPoly(modarith, b_copy, tmp_copy, b_copy); + addPolyConst(modarith, b, tmp, b); } } } -SwitchKey KeyGenerator::genEncKey(std::optional sk) const { - SwitchKey enckey(context_, SWK_ENC); +template +SwitchKey KeyGeneratorT

::genEncKey(const SecretKey &sk) const { + SwitchKey enckey(preset, SWK_ENC); genEncKeyInplace(enckey, sk); return enckey; } -void KeyGenerator::genEncKeyInplace(SwitchKey &enckey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, enckey, SWK_ENC); +template +void KeyGeneratorT

::genEncKeyInplace(SwitchKey &enckey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, enckey, SWK_ENC); const bool ntt_state = true; // currently only support ntt state keys - const Size num_poly = context_->get_num_p(); - const Size num_secret = context_->get_num_secret(); + const Size num_poly = num_p; deb_assert(enckey.bxSize() == num_secret && enckey.axSize() == 1, "[KeyGenerator::genEncKeyInplace] " "The provided switching key has invalid size."); @@ -192,26 +192,25 @@ void KeyGenerator::genEncKeyInplace(SwitchKey &enckey, auto ex = sampleGaussian(num_poly, ntt_state); for (Size i = 0; i < num_secret; ++i) { - mulPoly(modarith_, enckey.ax(), (*sk)[i], enckey.bx(i)); - subPoly(modarith_, ex, enckey.bx(i), enckey.bx(i)); + mulPoly(modarith, enckey.ax(), sk[i], enckey.bx(i)); + subPoly(modarith, ex, enckey.bx(i), enckey.bx(i)); } } -SwitchKey KeyGenerator::genMultKey(std::optional sk) const { - SwitchKey mulkey(context_, SWK_MULT); +template +SwitchKey KeyGeneratorT

::genMultKey(const SecretKey &sk) const { + SwitchKey mulkey(preset, SWK_MULT); genMultKeyInplace(mulkey, sk); return mulkey; } -void KeyGenerator::genMultKeyInplace(SwitchKey &mulkey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, mulkey, SWK_MULT); +template +void KeyGeneratorT

::genMultKeyInplace(SwitchKey &mulkey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, mulkey, SWK_MULT); const bool ntt_state = true; // currently only support ntt state keys - const Size num_secret = context_->get_num_secret(); - const Size max_length = context_->get_num_p(); + const Size max_length = num_p; deb_assert(mulkey.bxSize() == num_secret * mulkey.dnum() && mulkey.axSize() == mulkey.dnum(), "[KeyGenerator::genMultKeyInplace] " @@ -219,31 +218,35 @@ void KeyGenerator::genMultKeyInplace(SwitchKey &mulkey, std::vector sx2; for (Size i = 0; i < num_secret; ++i) { - sx2.emplace_back(context_, max_length); + sx2.emplace_back(preset, max_length); sx2[i].setNTT(ntt_state); - mulPoly(modarith_, (*sk)[i], (*sk)[i], sx2[i]); + mulPoly(modarith, sk[i], sk[i], sx2[i]); } - genSwitchingKey(sx2.data(), sk->data(), mulkey.getAx().data(), + genSwitchingKey(sx2.data(), sk.data(), mulkey.getAx().data(), mulkey.getBx().data()); + for (Size i = 0; i < sx2.size(); ++i) { + for (Size j = 0; j < sx2[i].size(); ++j) { + deb_secure_zero(sx2[i][j].data(), sx2[i][j].degree() * sizeof(u64)); + } + } } -SwitchKey KeyGenerator::genConjKey(std::optional sk) const { - SwitchKey conjkey(context_, SWK_CONJ); +template +SwitchKey KeyGeneratorT

::genConjKey(const SecretKey &sk) const { + SwitchKey conjkey(preset, SWK_CONJ); genConjKeyInplace(conjkey, sk); return conjkey; } -void KeyGenerator::genConjKeyInplace(SwitchKey &conjkey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, conjkey, SWK_CONJ); - const bool ntt_state = (*sk)[0][0].isNTT(); +template +void KeyGeneratorT

::genConjKeyInplace(SwitchKey &conjkey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, conjkey, SWK_CONJ); + const bool ntt_state = sk[0][0].isNTT(); - const Size num_secret = context_->get_num_secret(); - const Size max_length = context_->get_num_p(); + const Size max_length = num_p; deb_assert(conjkey.bxSize() == num_secret * conjkey.dnum() && conjkey.axSize() == conjkey.dnum(), "[KeyGenerator::genConjKeyInplace] " @@ -251,36 +254,39 @@ void KeyGenerator::genConjKeyInplace(SwitchKey &conjkey, std::vector sx; for (Size i = 0; i < num_secret; ++i) { - sx.emplace_back(context_, max_length); + sx.emplace_back(preset, max_length); sx[i].setNTT(ntt_state); // frobenius map in NTT - frobeniusMapInNTT((*sk)[i], -1, sx[i]); + frobeniusMapInNTT(sk[i], -1, sx[i]); } - genSwitchingKey(sx.data(), sk->data(), conjkey.getAx().data(), + genSwitchingKey(sx.data(), sk.data(), conjkey.getAx().data(), conjkey.getBx().data()); + for (Size i = 0; i < sx.size(); ++i) { + for (Size j = 0; j < sx[i].size(); ++j) { + deb_secure_zero(sx[i][j].data(), sx[i][j].degree() * sizeof(u64)); + } + } } -SwitchKey KeyGenerator::genLeftRotKey(const Size rot, - std::optional sk) const { - SwitchKey rotkey(context_, SWK_ROT); +template +SwitchKey KeyGeneratorT

::genLeftRotKey(const Size rot, + const SecretKey &sk) const { + SwitchKey rotkey(preset, SWK_ROT); genLeftRotKeyInplace(rot, rotkey, sk); return rotkey; } -void KeyGenerator::genLeftRotKeyInplace(const Size rot, SwitchKey &rotkey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, rotkey, SWK_ROT); - deb_assert(rot < context_->get_num_slots(), - "[KeyGenerator::genLeftRotKeyInplace] " - "Rotation value exceeds number of slots."); +template +void KeyGeneratorT

::genLeftRotKeyInplace(const Size rot, SwitchKey &rotkey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, rotkey, SWK_ROT); + deb_assert(rot < num_slots, "[KeyGenerator::genLeftRotKeyInplace] " + "Rotation value exceeds number of slots."); const auto ntt_state = true; // currently only support ntt state keys - const Size num_secret = context_->get_num_secret(); - const Size max_length = context_->get_num_p(); + const Size max_length = num_p; deb_assert(rotkey.bxSize() == num_secret * rotkey.dnum() && rotkey.axSize() == rotkey.dnum(), "[KeyGenerator::genLeftRotKeyInplace] " @@ -290,285 +296,320 @@ void KeyGenerator::genLeftRotKeyInplace(const Size rot, SwitchKey &rotkey, std::vector sx; for (Size i = 0; i < num_secret; ++i) { - sx.emplace_back(context_, max_length); + sx.emplace_back(preset, max_length); sx[i].setNTT(ntt_state); - frobeniusMapInNTT((*sk)[i], static_cast(fft_.getPowerOfFive(rot)), + frobeniusMapInNTT(sk[i], static_cast(fft_.getPowerOfFive(rot)), sx[i]); } - genSwitchingKey(sx.data(), sk->data(), rotkey.getAx().data(), + genSwitchingKey(sx.data(), sk.data(), rotkey.getAx().data(), rotkey.getBx().data()); + for (Size i = 0; i < sx.size(); ++i) { + for (Size j = 0; j < sx[i].size(); ++j) { + deb_secure_zero(sx[i][j].data(), sx[i][j].degree() * sizeof(u64)); + } + } } -SwitchKey KeyGenerator::genRightRotKey(const Size rot, - std::optional sk) const { - const Size left_rot_id = context_->get_num_slots() - rot; - SwitchKey rotkey(context_, SWK_ROT); +template +SwitchKey KeyGeneratorT

::genRightRotKey(const Size rot, + const SecretKey &sk) const { + const Size left_rot_id = num_slots - rot; + SwitchKey rotkey(preset, SWK_ROT); genLeftRotKeyInplace(left_rot_id, rotkey, sk); return rotkey; } -void KeyGenerator::genRightRotKeyInplace(const Size rot, SwitchKey &rotkey, - std::optional sk) const { - genLeftRotKeyInplace(context_->get_num_slots() - rot, rotkey, sk); +template +void KeyGeneratorT

::genRightRotKeyInplace(const Size rot, SwitchKey &rotkey, + const SecretKey &sk) const { + genLeftRotKeyInplace(num_slots - rot, rotkey, sk); } -SwitchKey KeyGenerator::genAutoKey(const Size sig, - std::optional sk) const { - SwitchKey autokey(context_, SWK_AUTO); +template +SwitchKey KeyGeneratorT

::genAutoKey(const Size sig, + const SecretKey &sk) const { + SwitchKey autokey(preset, SWK_AUTO); genAutoKeyInplace(sig, autokey, sk); return autokey; } -void KeyGenerator::genAutoKeyInplace(const Size sig, SwitchKey &autokey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, autokey, SWK_AUTO); - deb_assert(sig < context_->get_degree(), - "[KeyGenerator::genAutoKey] " - "Signature value exceeds polynomial degree."); +template +void KeyGeneratorT

::genAutoKeyInplace(const Size sig, SwitchKey &autokey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, autokey, SWK_AUTO); + deb_assert(sig < degree, "[KeyGenerator::genAutoKey] " + "Signature value exceeds polynomial degree."); - const Size num_secret = context_->get_num_secret(); - const Size degree = context_->get_degree(); deb_assert(autokey.bxSize() == num_secret * autokey.dnum() && autokey.axSize() == autokey.dnum(), "[KeyGenerator::genAutoKey] " "The provided switching key has invalid size."); autokey.setRotIdx(sig); - std::vector coeff_sig(degree); + std::vector coeff_sig(degree * num_secret); - automorphism(sk->coeffs(), coeff_sig.data(), sig, degree); - SecretKey sk_sig = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_->get_preset(), coeff_sig.data()); - genSwitchingKey(sk_sig.data(), sk->data(), autokey.getAx().data(), + for (Size i = 0; i < num_secret; ++i) { + automorphism(sk.coeffs() + i * degree, coeff_sig.data() + i * degree, + sig, degree); + } + SecretKey sk_sig = + SecretKeyGenerator::GenSecretKeyFromCoeff(preset, coeff_sig.data()); + genSwitchingKey(sk_sig.data(), sk.data(), autokey.getAx().data(), autokey.getBx().data()); + deb_secure_zero(coeff_sig.data(), coeff_sig.size() * sizeof(i8)); + // sk_sig.zeroize(); // automatically zeroized when going out of scope } -SwitchKey KeyGenerator::genComposeKey(const SecretKey &sk_from, - std::optional sk) const { +template +SwitchKey KeyGeneratorT

::genComposeKey(const SecretKey &sk_from, + const SecretKey &sk) const { // TODO: check prime compatibility return genComposeKey(sk_from.coeffs(), sk_from.coeffsSize(), sk); } -SwitchKey KeyGenerator::genComposeKey(const std::vector coeffs, - std::optional sk) const { +template +SwitchKey KeyGeneratorT

::genComposeKey(const std::vector coeffs, + const SecretKey &sk) const { return genComposeKey(coeffs.data(), static_cast(coeffs.size()), sk); } -SwitchKey KeyGenerator::genComposeKey(const i8 *coeffs, const Size coeffs_size, - std::optional sk) const { - SwitchKey composekey(context_, SWK_COMPOSE); +template +SwitchKey KeyGeneratorT

::genComposeKey(const i8 *coeffs, + const Size coeffs_size, + const SecretKey &sk) const { + SwitchKey composekey(preset, SWK_COMPOSE); genComposeKeyInplace(coeffs, coeffs_size, composekey, sk); return composekey; } -void KeyGenerator::genComposeKeyInplace(const SecretKey &sk_from, - SwitchKey &composekey, - std::optional sk) const { +template +void KeyGeneratorT

::genComposeKeyInplace(const SecretKey &sk_from, + SwitchKey &composekey, + const SecretKey &sk) const { genComposeKeyInplace(sk_from.coeffs(), sk_from.coeffsSize(), composekey, sk); } -void KeyGenerator::genComposeKeyInplace(const std::vector coeffs, - SwitchKey &composekey, - std::optional sk) const { +template +void KeyGeneratorT

::genComposeKeyInplace(const std::vector coeffs, + SwitchKey &composekey, + const SecretKey &sk) const { genComposeKeyInplace(coeffs.data(), static_cast(coeffs.size()), composekey, sk); } -void KeyGenerator::genComposeKeyInplace(const i8 *coeffs, - const Size coeffs_size, - SwitchKey &composekey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, composekey, SWK_COMPOSE); - - const Size num_secret = context_->get_num_secret(); - const Size deg_ratio = context_->get_degree() / coeffs_size; - deb_assert(coeffs_size * deg_ratio == context_->get_degree(), +template +void KeyGeneratorT

::genComposeKeyInplace(const i8 *coeffs, + const Size coeffs_size, + SwitchKey &composekey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, composekey, SWK_COMPOSE); + + const Size deg_ratio = degree / coeffs_size; + deb_assert(coeffs_size * deg_ratio == degree, "[KeyGenerator::genComposeKey] " "The provided secret key has invalid size."); - deb_assert(composekey.bxSize() == num_secret * composekey.dnum() && + deb_assert(num_secret == 1, "[KeyGenerator::genComposeKey] " + "Composition key generation is only supported " + "for single-secret presets."); + deb_assert(composekey.bxSize() == composekey.dnum() && composekey.axSize() == composekey.dnum(), "[KeyGenerator::genComposeKeyInplace] " "The provided switching key has invalid size."); - std::vector coeffs_embed(context_->get_degree(), 0); + std::vector coeffs_embed(degree, 0); for (Size i = 0; i < coeffs_size; ++i) { coeffs_embed[i * deg_ratio] = coeffs[i]; } - SecretKey sk_from = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_->get_preset(), coeffs_embed.data()); + SecretKey sk_from = + SecretKeyGenerator::GenSecretKeyFromCoeff(preset, coeffs_embed.data()); - genSwitchingKey(sk_from.data(), sk->data(), composekey.getAx().data(), + genSwitchingKey(sk_from.data(), sk.data(), composekey.getAx().data(), composekey.getBx().data()); + // sk_from.zeroize(); // automatically zeroized when going out of scope } -SwitchKey KeyGenerator::genDecomposeKey(const SecretKey &sk_to, - std::optional sk) const { +template +SwitchKey KeyGeneratorT

::genDecomposeKey(const SecretKey &sk_to, + const SecretKey &sk) const { return genDecomposeKey(sk_to.coeffs(), sk_to.coeffsSize(), sk); } -SwitchKey KeyGenerator::genDecomposeKey(const std::vector coeffs, - std::optional sk) const { +template +SwitchKey KeyGeneratorT

::genDecomposeKey(const std::vector coeffs, + const SecretKey &sk) const { return genDecomposeKey(coeffs.data(), static_cast(coeffs.size()), sk); } -SwitchKey KeyGenerator::genDecomposeKey(const i8 *coeffs, - const Size coeffs_size, - std::optional sk) const { - SwitchKey decompkey(context_, SWK_DECOMPOSE); +template +SwitchKey KeyGeneratorT

::genDecomposeKey(const i8 *coeffs, + const Size coeffs_size, + const SecretKey &sk) const { + SwitchKey decompkey(preset, SWK_DECOMPOSE); genDecomposeKeyInplace(coeffs, coeffs_size, decompkey, sk); return decompkey; } -void KeyGenerator::genDecomposeKeyInplace(const SecretKey &sk_to, - SwitchKey &decompkey, - std::optional sk) const { +template +void KeyGeneratorT

::genDecomposeKeyInplace(const SecretKey &sk_to, + SwitchKey &decompkey, + const SecretKey &sk) const { genDecomposeKeyInplace(sk_to.coeffs(), sk_to.coeffsSize(), decompkey, sk); } -void KeyGenerator::genDecomposeKeyInplace(const std::vector coeffs, - SwitchKey &decompkey, - std::optional sk) const { +template +void KeyGeneratorT

::genDecomposeKeyInplace(const std::vector coeffs, + SwitchKey &decompkey, + const SecretKey &sk) const { genDecomposeKeyInplace(coeffs.data(), static_cast(coeffs.size()), decompkey, sk); } -void KeyGenerator::genDecomposeKeyInplace(const i8 *coeffs, - const Size coeffs_size, - SwitchKey &decompkey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, decompkey, SWK_DECOMPOSE); - const Size num_secret = context_->get_num_secret(); - const Size deg_ratio = context_->get_degree() / coeffs_size; - deb_assert(coeffs_size * deg_ratio == context_->get_degree(), +template +void KeyGeneratorT

::genDecomposeKeyInplace(const i8 *coeffs, + const Size coeffs_size, + SwitchKey &decompkey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, decompkey, SWK_DECOMPOSE); + const Size deg_ratio = degree / coeffs_size; + deb_assert(coeffs_size * deg_ratio == degree, "[KeyGenerator::genDecomposeKey] " "The provided secret key has invalid size."); - - deb_assert(decompkey.bxSize() == num_secret * decompkey.dnum() && + deb_assert(num_secret == 1, "[KeyGenerator::genDecomposeKey] " + "Decomposition key generation is only " + "supported for single-secret presets."); + deb_assert(decompkey.bxSize() == decompkey.dnum() && decompkey.axSize() == decompkey.dnum(), "[KeyGenerator::genDecomposeKeyInplace] " "The provided switching key has invalid size."); - std::vector coeffs_embed(context_->get_degree(), 0); + std::vector coeffs_embed(degree, 0); for (Size i = 0; i < coeffs_size; ++i) { coeffs_embed[i * deg_ratio] = coeffs[i]; } - SecretKey sk_to = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_->get_preset(), coeffs_embed.data()); - genSwitchingKey(sk->data(), sk_to.data(), decompkey.getAx().data(), + SecretKey sk_to = + SecretKeyGenerator::GenSecretKeyFromCoeff(preset, coeffs_embed.data()); + genSwitchingKey(sk.data(), sk_to.data(), decompkey.getAx().data(), decompkey.getBx().data()); + // sk_to.zeroize(); // automatically zeroized when going out of scope } -SwitchKey KeyGenerator::genDecomposeKey(const Preset preset_swk, - const SecretKey &sk_to, - std::optional sk) const { +template +SwitchKey KeyGeneratorT

::genDecomposeKey(const Preset preset_swk, + const SecretKey &sk_to, + const SecretKey &sk) const { return genDecomposeKey(preset_swk, sk_to.coeffs(), sk_to.coeffsSize(), sk); } -SwitchKey KeyGenerator::genDecomposeKey(const Preset preset_swk, - const std::vector coeffs, - std::optional sk) const { +template +SwitchKey KeyGeneratorT

::genDecomposeKey(const Preset preset_swk, + const std::vector coeffs, + const SecretKey &sk) const { return genDecomposeKey(preset_swk, coeffs.data(), static_cast(coeffs.size()), sk); } -SwitchKey KeyGenerator::genDecomposeKey(const Preset preset_swk, - const i8 *coeffs, Size coeffs_size, - std::optional sk) const { - Context context_swk = getContext(preset_swk); - SwitchKey decompkey(context_swk, SWK_DECOMPOSE); +template +SwitchKey KeyGeneratorT

::genDecomposeKey(const Preset preset_swk, + const i8 *coeffs, Size coeffs_size, + const SecretKey &sk) const { + SwitchKey decompkey(preset_swk, SWK_DECOMPOSE); genDecomposeKeyInplace(preset_swk, coeffs, coeffs_size, decompkey, sk); return decompkey; } -void KeyGenerator::genDecomposeKeyInplace(const Preset preset_swk, - const SecretKey &sk_to, - SwitchKey &decompkey, - std::optional sk) const { +template +void KeyGeneratorT

::genDecomposeKeyInplace(const Preset preset_swk, + const SecretKey &sk_to, + SwitchKey &decompkey, + const SecretKey &sk) const { genDecomposeKeyInplace(preset_swk, sk_to.coeffs(), sk_to.coeffsSize(), decompkey, sk); } -void KeyGenerator::genDecomposeKeyInplace(const Preset preset_swk, - const std::vector coeffs, - SwitchKey &decompkey, - std::optional sk) const { +template +void KeyGeneratorT

::genDecomposeKeyInplace(const Preset preset_swk, + const std::vector coeffs, + SwitchKey &decompkey, + const SecretKey &sk) const { genDecomposeKeyInplace(preset_swk, coeffs.data(), static_cast(coeffs.size()), decompkey, sk); } -void KeyGenerator::genDecomposeKeyInplace(const Preset preset_swk, - const i8 *coeffs, Size coeffs_size, - SwitchKey &decompkey, - std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - Context context_swk = getContext(preset_swk); - checkSecretKey(context_, sk); - checkSwk(context_swk, decompkey, SWK_DECOMPOSE); - deb_assert( - context_->get_degree() == context_swk->get_degree(), - "[KeyGenerator::genDecomposeKey] " - "Degree mismatch between KeyGenerator and switching key preset."); +template +void KeyGeneratorT

::genDecomposeKeyInplace(const Preset preset_swk, + const i8 *coeffs, + Size coeffs_size, + SwitchKey &decompkey, + const SecretKey &sk) const { + checkSecretKey(preset_swk, sk); + checkSwk(preset_swk, decompkey, SWK_DECOMPOSE); + deb_assert(degree == get_degree(preset_swk), + "[KeyGenerator::genDecomposeKey] " + "Degree mismatch between KeyGenerator and switching key " + "preset."); - const Size num_secret = context_swk->get_num_secret(); - const Size deg_ratio = context_swk->get_degree() / coeffs_size; - deb_assert(coeffs_size * deg_ratio == context_->get_degree(), + const Size num_secret = get_num_secret(preset_swk); + const Size deg_ratio = get_degree(preset_swk) / coeffs_size; + deb_assert(coeffs_size * deg_ratio == degree, "[KeyGenerator::genDecomposeKey] " "The provided secret key has invalid size."); - deb_assert(decompkey.bxSize() == num_secret * decompkey.dnum() && + deb_assert(num_secret == 1, "[KeyGenerator::genDecomposeKey] " + "Decomposition key generation is only " + "supported for single-secret presets."); + deb_assert(decompkey.bxSize() == decompkey.dnum() && decompkey.axSize() == decompkey.dnum(), "[KeyGenerator::genDecomposeKeyInplace] " "The provided switching key has invalid size."); - std::vector coeffs_embed(context_->get_degree(), 0); + std::vector coeffs_embed(degree, 0); for (Size i = 0; i < coeffs_size; ++i) { coeffs_embed[i * deg_ratio] = coeffs[i]; } SecretKey sk_to = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_swk->get_preset(), coeffs_embed.data()); - SecretKey sk_from = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_swk->get_preset(), sk->coeffs()); + preset_swk, coeffs_embed.data()); + SecretKey sk_from = + SecretKeyGenerator::GenSecretKeyFromCoeff(preset_swk, sk.coeffs()); KeyGenerator keygen_swk(preset_swk); keygen_swk.genSwitchingKey(sk_from.data(), sk_to.data(), decompkey.getAx().data(), decompkey.getBx().data()); + // sk_to.zeroize(); // automatically zeroized when going out of scope + // sk_from.zeroize(); // automatically zeroized when going out of scope } +template std::vector -KeyGenerator::genModPackKeyBundle(const SecretKey &sk_from, - const SecretKey &sk_to) const { +KeyGeneratorT

::genModPackKeyBundle(const SecretKey &sk_from, + const SecretKey &sk_to) const { std::vector key_bundle; - const auto num_key = getContext(sk_from.preset())->get_rank() / - getContext(sk_to.preset())->get_rank(); + const auto num_key = get_rank(sk_from.preset()) / get_rank(sk_to.preset()); for (u64 i = 0; i < num_key; ++i) { - key_bundle.emplace_back(context_, SWK_MODPACK); + key_bundle.emplace_back(preset, SWK_MODPACK); } genModPackKeyBundleInplace(sk_from, sk_to, key_bundle); return key_bundle; } -void KeyGenerator::genModPackKeyBundleInplace( +template +void KeyGeneratorT

::genModPackKeyBundleInplace( const SecretKey &sk_from, const SecretKey &sk_to, std::vector &key_bundle) const { deb_assert(sk_from[0][0].isNTT() == sk_to[0][0].isNTT(), "[KeyGenerator::genModPackKeyBundle] " "NTT state mismatch between input secret keys."); - - const auto context_from = getContext(sk_from.preset()); - const auto context_to = getContext(sk_to.preset()); - checkModPackKeyBundleCondition(context_, context_from, context_to); - - [[maybe_unused]] const Size num_secret = context_->get_num_secret(); - const u64 from_deg = context_from->get_degree(); - const u64 from_rank = context_from->get_rank(); - const u64 to_deg = context_to->get_degree(); - const u64 to_rank = context_to->get_rank(); - const u64 rlwe_deg = context_->get_degree(); + deb_assert( + get_num_secret(sk_from.preset()) * get_num_secret(sk_to.preset()) == 1, + "[KeyGenerator::genModPackKeyBundle] " + "ModPackKeyBundle is only supported for single-secret presets."); + + const auto preset_from = sk_from.preset(); + const auto preset_to = sk_to.preset(); + checkModPackKeyBundleCondition(preset, preset_from, preset_to); + + const u64 from_deg = get_degree(preset_from); + const u64 from_rank = get_rank(preset_from); + const u64 to_deg = get_degree(preset_to); + const u64 to_rank = get_rank(preset_to); + const u64 rlwe_deg = degree; const u64 num_keys = from_rank / to_rank; const u64 deg_ratio = rlwe_deg / from_deg; - deb_assert(key_bundle.size() == num_keys, "[KeyGenerator::genModPackKeyBundle] " "The provided switching key bundle has invalid size."); + const i8 *sk_from_coeff = sk_from.coeffs(); const i8 *sk_to_coeff = sk_to.coeffs(); auto *rlwe_coeff = new i8[rlwe_deg]; @@ -578,8 +619,8 @@ void KeyGenerator::genModPackKeyBundleInplace( for (u64 k = 0; k < to_deg; ++k) rlwe_coeff[j + to_rank * k] = sk_to_coeff[k + to_deg * j]; - SecretKey sk_to_rlwe = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_->get_preset(), rlwe_coeff); + SecretKey sk_to_rlwe = + SecretKeyGenerator::GenSecretKeyFromCoeff(preset, rlwe_coeff); for (u64 i = 0; i < num_keys; ++i) { // from_deg * (from_rank / num_keys) -> rlwe_deg ; embed and combine @@ -597,36 +638,39 @@ void KeyGenerator::genModPackKeyBundleInplace( for (u64 k = 0; k < from_deg; ++k) rlwe_coeff[j + deg_ratio * k] = sk_from_coeff[k + from_deg * (j + to_rank * i)]; - SecretKey sk_from_rlwe = SecretKeyGenerator::GenSecretKeyFromCoeff( - context_->get_preset(), rlwe_coeff); + SecretKey sk_from_rlwe = + SecretKeyGenerator::GenSecretKeyFromCoeff(preset, rlwe_coeff); genSwitchingKey(sk_from_rlwe.data(), sk_to_rlwe.data(), key_bundle[i].getAx().data(), key_bundle[i].getBx().data()); + // sk_from_rlwe.zeroize(); // automatically zeroized when going out of + // scope } + deb_secure_zero(rlwe_coeff, rlwe_deg * sizeof(i8)); delete[] rlwe_coeff; } -SwitchKey KeyGenerator::genModPackKeyBundle(const Size pad_rank, - std::optional sk) const { - SwitchKey modkey(context_, SWK_MODPACK_SELF); - const auto max_length = context_->get_num_p(); +template +SwitchKey KeyGeneratorT

::genModPackKeyBundle(const Size pad_rank, + const SecretKey &sk) const { + SwitchKey modkey(preset, SWK_MODPACK_SELF); + const auto max_length = num_p; modkey.addAx(max_length, pad_rank, true); - modkey.addBx(max_length, pad_rank * context_->get_num_secret(), true); + modkey.addBx(max_length, pad_rank * num_secret, true); genModPackKeyBundleInplace(pad_rank, modkey, sk); return modkey; } -void KeyGenerator::genModPackKeyBundleInplace( - const Size pad_rank, SwitchKey &modkey, std::optional sk) const { - if (!sk.has_value()) - sk = sk_; - checkSecretKey(context_, sk); - checkSwk(context_, modkey, SWK_MODPACK_SELF); - const Size items_per_ctxt = context_->get_degree() / pad_rank; - const Size degree = context_->get_degree(); - deb_assert( - utils::isPowerOfTwo(pad_rank), - "[KeyGenerator::genModPackKeyBundle] pad_rank must be a power of two."); - deb_assert(modkey.bxSize() == pad_rank * context_->get_num_secret() && +template +void KeyGeneratorT

::genModPackKeyBundleInplace(const Size pad_rank, + SwitchKey &modkey, + const SecretKey &sk) const { + checkSecretKey(preset, sk); + checkSwk(preset, modkey, SWK_MODPACK_SELF); + const Size items_per_ctxt = degree / pad_rank; + deb_assert(utils::isPowerOfTwo(pad_rank), + "[KeyGenerator::genModPackKeyBundle] pad_rank must be a power " + "of two."); + deb_assert(modkey.bxSize() == pad_rank * num_secret && modkey.axSize() == pad_rank, "[KeyGenerator::genModPackKeyBundle] The provided switching key " "has invalid size."); @@ -636,24 +680,26 @@ void KeyGenerator::genModPackKeyBundleInplace( std::memset(from_coeff, 0, degree); for (Size j = 0; j < items_per_ctxt; ++j) { from_coeff[pad_rank * j] = - sk->coeffs()[j * pad_rank + pad_rank - 1 - i]; + sk.coeffs()[j * pad_rank + pad_rank - 1 - i]; } SecretKey sk_from = - SecretKeyGenerator::GenSecretKeyFromCoeff(sk->preset(), from_coeff); - genSwitchingKey(sk_from.data(), sk->data(), &(modkey.ax(i)), - &(modkey.bx(i)), 1, context_->get_num_secret()); + SecretKeyGenerator::GenSecretKeyFromCoeff(sk.preset(), from_coeff); + genSwitchingKey(sk_from.data(), sk.data(), &(modkey.ax(i)), + &(modkey.bx(i)), 1, num_secret); + deb_secure_zero(from_coeff, degree * sizeof(i8)); delete[] from_coeff; + // sk_from.zeroize(); // automatically zeroized when going out of scope } } -void KeyGenerator::frobeniusMapInNTT(const Polynomial &op, const i32 pow, - Polynomial res) const { +template +void KeyGeneratorT

::frobeniusMapInNTT(const Polynomial &op, const i32 pow, + Polynomial res) const { deb_assert(op[0].isNTT(), "[KeyGenerator::frobeniusMapInNTT] " "Input polynomial must be in NTT state."); deb_assert(pow % 2 != 0, "[KeyGenerator::frobeniusMapInNTT] " "Frobenius map power must be odd."); - Size degree = context_->get_degree(); u64 log_degree = utils::log2floor(static_cast(degree)); if (pow == 1) { @@ -688,52 +734,53 @@ void KeyGenerator::frobeniusMapInNTT(const Polynomial &op, const i32 pow, } } -Polynomial KeyGenerator::sampleGaussian(const Size num_polyunit, - bool do_ntt) const { - const auto degree = context_->get_degree(); +template +Polynomial KeyGeneratorT

::sampleGaussian(const Size num_polyunit, + bool do_ntt) const { std::vector samples(degree); - alea_sample_gaussian_int64_array(as_.get(), samples.data(), degree, - context_->get_gaussian_error_stdev()); - Polynomial poly(context_, num_polyunit); + rng_->sampleGaussianInt64Array(samples.data(), degree, + gaussian_error_stdev); + Polynomial poly(preset, num_polyunit); for (Size i = 0; i < poly.size(); ++i) { - poly[i].setPrime(context_->get_primes()[i]); - for (Size j = 0; j < context_->get_degree(); ++j) { + poly[i].setPrime(primes[i]); + for (Size j = 0; j < degree; ++j) { // Convert int64_t sample to u64 - poly[i][j] = (samples[j] >= 0) ? static_cast(samples[j]) - : context_->get_primes()[i] - - static_cast(-samples[j]); + poly[i][j] = (samples[j] >= 0) + ? static_cast(samples[j]) + : primes[i] - static_cast(-samples[j]); } } if (do_ntt) { - forwardNTT(modarith_, poly); + forwardNTT(modarith, poly); } return poly; } -void KeyGenerator::sampleUniform(Polynomial &poly) const { +template +void KeyGeneratorT

::sampleUniform(Polynomial &poly) const { // TODO: add reseed controller for (u64 i = 0; i < poly.size(); ++i) { - alea_get_random_uint64_array_in_range( - as_.get(), poly[i].data(), context_->get_degree(), poly[i].prime()); + rng_->getRandomUint64ArrayInRange(poly[i].data(), degree, + poly[i].prime()); } } -void KeyGenerator::computeConst() { - const Size length = context_->get_num_base() + context_->get_num_qp(); - const Size dnum = context_->get_gadget_rank(); +template void KeyGeneratorT

::computeConst() { + const Size length = num_base + num_qp; + const Size dnum = gadget_rank; const Size alpha = (length + dnum - 1) / dnum; p_mod_.resize(length); for (Size i = 0; i < length; ++i) { - const u64 prime = context_->get_primes()[i]; + const u64 prime = primes[i]; const u64 two_prime = prime << 1; u64 p = UINT64_C(1); - for (Size j = 0; j < context_->get_num_tp(); ++j) { - const u64 pp = modarith_[i].reduceBarrett<2>( - context_->get_primes()[j + length]); - p = modarith_[i].mul(p, pp); + for (Size j = 0; j < num_tp; ++j) { + const u64 pp = + modarith[i].template reduceBarrett<2>(primes[j + length]); + p = modarith[i].mul(p, pp); } p = utils::subIfGE(p, two_prime); p_mod_[i] = utils::subIfGE(p, prime); @@ -744,15 +791,14 @@ void KeyGenerator::computeConst() { for (Size i = 0; i < length; ++i) { const u64 beta = i / alpha; - const u64 prime = context_->get_primes()[i]; + const u64 prime = primes[i]; const u64 two_prime = prime << 1; u64 hat_q = UINT64_C(1); for (Size j = 0; j < length; ++j) { if (j < beta * alpha || j >= (beta + 1) * alpha) { - u64 pp = - modarith_[i].reduceBarrett<2>(context_->get_primes()[j]); - hat_q = modarith_[i].mul(hat_q, pp); + u64 pp = modarith[i].template reduceBarrett<2>(primes[j]); + hat_q = modarith[i].mul(hat_q, pp); } } @@ -760,7 +806,11 @@ void KeyGenerator::computeConst() { hat_q = utils::subIfGE(hat_q, prime); hat_q_i_mod_[i] = hat_q; - hat_q_i_inv_mod_[i] = modarith_[i].inverse(hat_q); + hat_q_i_inv_mod_[i] = modarith[i].inverse(hat_q); } } + +#define X(preset) template class KeyGeneratorT; +PRESET_LIST_WITH_EMPTY +#undef X } // namespace deb diff --git a/src/ModArith.cpp b/src/ModArith.cpp index 2b58622..e0d023e 100644 --- a/src/ModArith.cpp +++ b/src/ModArith.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,8 +15,6 @@ */ #include "utils/ModArith.hpp" -#include "Macro.hpp" -#include "utils/Basic.hpp" #include #include @@ -27,18 +25,37 @@ namespace deb::utils { -ModArith::ModArith(Size size, u64 prime) +template +ModArith::ModArith(u64 prime) : prime_(prime), two_prime_(prime << 1), barrett_expt_(bitWidth(prime) - 1), barrett_ratio_(static_cast( (static_cast(1) << (barrett_expt_ + 63)) / prime)), - default_array_size_(size), + default_array_size_(degree), barrett_ratio_for_u64_(divide128By64Lo(UINT64_C(1), UINT64_C(0), prime)), two_to_64_(powModSimple(2, 64, prime)), two_to_64_shoup_(divide128By64Lo(two_to_64_, UINT64_C(0), prime)), - ntt_(std::make_unique(size, prime)) {} + ntt_(std::make_unique(degree, prime)) { + if constexpr (D == 1) { + throw std::runtime_error("[ModArith] Degree template parameter must be " + "non-zero when degree is not specified"); + } +} + +template +ModArith::ModArith(Size actual_degree, u64 prime) + : DegreeTrait(actual_degree), prime_(prime), two_prime_(prime << 1), + barrett_expt_(bitWidth(prime) - 1), + barrett_ratio_(static_cast( + (static_cast(1) << (barrett_expt_ + 63)) / prime)), + default_array_size_(actual_degree), + barrett_ratio_for_u64_(divide128By64Lo(UINT64_C(1), UINT64_C(0), prime)), + two_to_64_(powModSimple(2, 64, prime)), + two_to_64_shoup_(divide128By64Lo(two_to_64_, UINT64_C(0), prime)), + ntt_(std::make_unique(actual_degree, prime)) {} -void ModArith::constMult(const u64 *op1, const u64 op2_big, u64 *res, - Size array_size) const { +template +void ModArith::constMult(const u64 *op1, const u64 op2_big, u64 *res, + Size array_size) const { const u64 op2 = reduceBarrett(op2_big); u64 approx_quotient = divide128By64Lo(op2, UINT64_C(0), prime_); @@ -49,8 +66,9 @@ void ModArith::constMult(const u64 *op1, const u64 op2_big, u64 *res, } } -void ModArith::mulVector(u64 *res, const u64 *op1, const u64 *op2, - Size array_size) const { +template +void ModArith::mulVector(u64 *res, const u64 *op1, const u64 *op2, + Size array_size) const { const auto barr = this->barrett_ratio_; const int k_1 = static_cast(this->barrett_expt_) - 1; @@ -59,14 +77,14 @@ void ModArith::mulVector(u64 *res, const u64 *op1, const u64 *op2, u64 c1 = u128Lo(prod >> (k_1)); u64 c2 = mul64To128Hi(c1, barr); u64 c3 = u128Lo(prod) - c2 * prime_; - res[i] = subIfGE(c3, prime_); + res[i] = subIfGEConst(c3, prime_); } } namespace { -template -inline void for_each_modarith(const std::vector &modarith, Func func, - Size size, Args... args) { +template +inline void for_each_modarith(const std::vector> &modarith, + Func func, Size size, Args... args) { PRAGMA_OMP(omp for schedule(static)) for (Size i = 0; i < size; ++i) { func(modarith[i], getData(std::forward(args), i)...); @@ -74,33 +92,36 @@ inline void for_each_modarith(const std::vector &modarith, Func func, }; } // namespace -void forwardNTT(const std::vector &modarith, Polynomial &poly, +template +void forwardNTT(const std::vector> &modarith, Polynomial &poly, Size num_polyunit, [[maybe_unused]] bool expected_ntt_state) { deb_assert(poly[0].isNTT() == expected_ntt_state, "[forwardNTT] NTT state mismatch"); num_polyunit = num_polyunit ? num_polyunit : poly.size(); for_each_modarith( - modarith, [](const ModArith &ma, u64 *p) { ma.forwardNTT(p); }, + modarith, [](const ModArith &ma, u64 *p) { ma.forwardNTT(p); }, num_polyunit, poly); for (Size i = 0; i < num_polyunit; ++i) { poly[i].setNTT(true); } } -void backwardNTT(const std::vector &modarith, Polynomial &poly, +template +void backwardNTT(const std::vector> &modarith, Polynomial &poly, Size num_polyunit, [[maybe_unused]] bool expected_ntt_state) { deb_assert(poly[0].isNTT() == expected_ntt_state, "[backwardNTT] NTT state mismatch"); num_polyunit = num_polyunit ? num_polyunit : poly.size(); for_each_modarith( - modarith, [](const ModArith &ma, u64 *p) { ma.backwardNTT(p); }, + modarith, [](const ModArith &ma, u64 *p) { ma.backwardNTT(p); }, num_polyunit, poly); for (Size i = 0; i < num_polyunit; ++i) { poly[i].setNTT(false); } } -void addPoly(const std::vector &modarith, const Polynomial &op1, +template +void addPoly(const std::vector> &modarith, const Polynomial &op1, const Polynomial &op2, Polynomial &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[addPoly] operands NTT state mismatch"); @@ -117,7 +138,28 @@ void addPoly(const std::vector &modarith, const Polynomial &op1, } } -void subPoly(const std::vector &modarith, const Polynomial &op1, +template +void addPolyConst(const std::vector> &modarith, + const Polynomial &op1, const Polynomial &op2, Polynomial &res, + Size num_polyunit) { + deb_assert(op1[0].isNTT() == op2[0].isNTT(), + "[addPoly] operands NTT state mismatch"); + res.setNTT(op1[0].isNTT()); + + const auto degree = res[0].degree(); + num_polyunit = num_polyunit ? num_polyunit : res.size(); + + PRAGMA_OMP(omp for collapse(2) schedule(static)) + for (Size i = 0; i < num_polyunit; ++i) { + for (Size j = 0; j < degree; ++j) { + res[i][j] = + subIfGEConst(op1[i][j] + op2[i][j], modarith[i].getPrime()); + } + } +} + +template +void subPoly(const std::vector> &modarith, const Polynomial &op1, const Polynomial &op2, Polynomial &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[subPoly] operands NTT state mismatch"); @@ -129,14 +171,16 @@ void subPoly(const std::vector &modarith, const Polynomial &op1, PRAGMA_OMP(omp for collapse(2) schedule(static)) for (Size i = 0; i < num_polyunit; ++i) { for (Size j = 0; j < degree; ++j) { - res[i][j] = (op1[i][j] >= op2[i][j]) - ? op1[i][j] - op2[i][j] - : modarith[i].getPrime() - op2[i][j] + op1[i][j]; + const u64 tmp = op1[i][j] - op2[i][j]; + // mask is 0xFFFFFFFF if op1[i][j] < op2[i][j], 0x0 otherwise + const u64 mask = ~((tmp >> 63) - 1); + res[i][j] = tmp + (modarith[i].getPrime() & mask); } } } -void mulPoly(const std::vector &modarith, const Polynomial &op1, +template +void mulPoly(const std::vector> &modarith, const Polynomial &op1, const Polynomial &op2, Polynomial &res, Size num_polyunit) { deb_assert(op1[0].isNTT() == op2[0].isNTT(), "[mulPoly] operands NTT state mismatch"); @@ -157,8 +201,33 @@ void mulPoly(const std::vector &modarith, const Polynomial &op1, } } -void constMulPoly(const std::vector &modarith, const Polynomial &op1, - const u64 *op2, Polynomial &res, Size s_id, Size e_id) { +template +void mulPolyConst(const std::vector> &modarith, + const Polynomial &op1, const Polynomial &op2, Polynomial &res, + Size num_polyunit) { + deb_assert(op1[0].isNTT() == op2[0].isNTT(), + "[mulPoly] operands NTT state mismatch"); + res.setNTT(op1[0].isNTT()); + + const auto degree = res[0].degree(); + num_polyunit = num_polyunit ? num_polyunit : res.size(); + + PRAGMA_OMP(omp for collapse(2) schedule(static)) + for (Size i = 0; i < num_polyunit; ++i) { + for (Size j = 0; j < degree; ++j) { + u128 prod = mul64To128(op1[i][j], op2[i][j]); + u64 c1 = u128Lo(prod >> (modarith[i].get_barrett_expt() - 1)); + u64 c2 = mul64To128Hi(c1, modarith[i].get_barrett_ratio()); + u64 c3 = u128Lo(prod) - c2 * modarith[i].getPrime(); + res[i][j] = subIfGEConst(c3, modarith[i].getPrime()); + } + } +} + +template +void constMulPoly(const std::vector> &modarith, + const Polynomial &op1, const u64 *op2, Polynomial &res, + Size s_id, Size e_id) { res.setNTT(op1[0].isNTT()); PRAGMA_OMP(omp for schedule(static)) @@ -166,4 +235,7 @@ void constMulPoly(const std::vector &modarith, const Polynomial &op1, modarith[i].constMult(op1[i].data(), op2[i], res[i].data()); } } +#define D(degree) DECL_MODARITH_HELPER(degree, ) +DEGREE_SET +#undef D } // namespace deb::utils diff --git a/src/NTT.cpp b/src/NTT.cpp index 0d7340a..b6bef49 100644 --- a/src/NTT.cpp +++ b/src/NTT.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,22 +67,22 @@ u64 findPrimitiveRoot(u64 prime) { } } // namespace utils + namespace { -inline void butterfly(u64 &x, u64 &y, const u64 w, const u64 ws, const u64 p1, - const u64 p2) { - u64 tx = subIfGE(x, p2); - u64 ty = mulModLazy(y, w, ws, p1); - x = tx + ty; - y = tx + p2 - ty; +static inline void butterfly(u64 &x, u64 &y, const u64 &w, const u64 &ws, + const u64 &p1, const u64 &p2) { + const u64 ty = mulModLazy(y, w, ws, p1); + x = subIfGE(x, p2); + y = x + p2 - ty; + x += ty; } -inline void butterflyInv(u64 &x, u64 &y, const u64 w, const u64 ws, - const u64 p1, const u64 p2) { - u64 tx = x + y; - u64 ty = x + p2 - y; - x = subIfGE(tx, p2); - y = mulModLazy(ty, w, ws, p1); +static inline void butterflyInv(u64 &x, u64 &y, const u64 &w, const u64 &ws, + const u64 &p1, const u64 &p2) { + const u64 tx = subIfGE(x + y, p2); + y = mulModLazy(x + p2 - y, w, ws, p1); + x = tx; } } // anonymous namespace @@ -92,7 +92,7 @@ NTT::NTT(u64 degree, u64 prime) psi_rev_(degree_), psi_inv_rev_(degree_), psi_rev_shoup_(degree_), psi_inv_rev_shoup_(degree_) { - const u64 num_roots = degree; + const u64 num_roots = degree_; if (prime % (2 * num_roots) != 1) throw std::runtime_error("Not an NTT-friendly prime given."); @@ -339,6 +339,7 @@ void NTT::computeBackwardNativeSingleStep(u64 *op, const u64 t) const { } break; case 8: + DEB_LOOP_UNROLL_8 for (u64 i = 0; i < (degree >> 4); ++i) { butterflyInv(op[16 * i + 0], op[16 * i + 8], w_ptr[i], ws_ptr[i], prime, two_prime); diff --git a/src/Context.cpp b/src/OmpUtils.cpp similarity index 59% rename from src/Context.cpp rename to src/OmpUtils.cpp index d20227a..a6338de 100644 --- a/src/Context.cpp +++ b/src/OmpUtils.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,40 +14,22 @@ * limitations under the License. */ -#include "Context.hpp" -#include "utils/Basic.hpp" +#include "utils/OmpUtils.hpp" -#include +#include #ifdef DEB_OPENMP #include #endif -namespace deb { -// Mapping from preset enum to preset struct -Context getContext(Preset preset) { - return ContextPool::GetInstance().get(preset); -} - -bool isValidPreset([[maybe_unused]] Preset preset) { -#ifdef DEB_RESOURCE_CHECK - switch (preset) { -#define X(NAME) case PRESET_##NAME: - PRESET_LIST - return true; -#undef X - case PRESET_EMPTY: - default: - return false; - } - return false; -#else - return true; -#endif -} +namespace deb::utils { +static int g_omp_threads = -1; void setOmpThreadLimit([[__maybe_unused__]] int max_threads) { #ifdef DEB_OPENMP int current = omp_get_max_threads(); + if (g_omp_threads == -1) { + g_omp_threads = current; + } if (max_threads < current) { omp_set_num_threads(max_threads); } @@ -56,8 +38,17 @@ void setOmpThreadLimit([[__maybe_unused__]] int max_threads) { void unsetOmpThreadLimit() { #ifdef DEB_OPENMP - omp_set_num_threads(omp_get_max_threads()); + if (g_omp_threads != -1) { + omp_set_num_threads(g_omp_threads); + g_omp_threads = -1; + } else { + const char *env_p = std::getenv("OMP_NUM_THREADS"); + if (env_p != nullptr) { + int env_threads = std::atoi(env_p); + omp_set_num_threads(env_threads); + } + } #endif } -} // namespace deb +} // namespace deb::utils diff --git a/src/Preset.cpp b/src/Preset.cpp new file mode 100644 index 0000000..4a52174 --- /dev/null +++ b/src/Preset.cpp @@ -0,0 +1,33 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "Preset.hpp" + +namespace deb { + +#define CV(type, var_name) \ + type get_##var_name(Preset preset) { \ + return std::visit( \ + [](auto &&p) -> type { \ + using T = std::decay_t; \ + return T::var_name; \ + }, \ + preset_map[preset]); \ + } +CONST_LIST +#undef CV + +} // namespace deb diff --git a/src/RandomGenerator.cpp b/src/RandomGenerator.cpp new file mode 100644 index 0000000..e37f77f --- /dev/null +++ b/src/RandomGenerator.cpp @@ -0,0 +1,53 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/RandomGenerator.hpp" +#include "utils/AleaRandomGenerator.hpp" + +#include + +namespace deb { + +namespace { +std::mutex g_factory_mutex; +RandomGeneratorFactory g_factory; +} // namespace + +// Sets a custom random generator factory. +// The factory should be a function that takes an RNGSeed and +// returns a shared pointer to a RandomGenerator. +// The custom generator must ensure that the unpredictability, +// unbiased, forward/backward security properties hold for the +// generated random values. To reset to the default random +// generator, call setRandomGeneratorFactory with an empty factory or nullptr. +void setRandomGeneratorFactory(RandomGeneratorFactory factory) { + std::lock_guard lock(g_factory_mutex); + g_factory = std::move(factory); +} + +std::shared_ptr createRandomGenerator(const RNGSeed &seed) { + RandomGeneratorFactory factory_copy; + { + std::lock_guard lock(g_factory_mutex); + factory_copy = g_factory; + } + if (factory_copy) { + return factory_copy(seed); + } + return std::make_shared(seed); +} + +} // namespace deb diff --git a/src/SecretKeyGenerator.cpp b/src/SecretKeyGenerator.cpp index 285cbaa..f9f5740 100644 --- a/src/SecretKeyGenerator.cpp +++ b/src/SecretKeyGenerator.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,10 +40,9 @@ void SecretKeyGenerator::genSecretKeyFromCoeffInplace(SecretKey &sk, } i8 *SecretKeyGenerator::GenCoeff(const Preset preset, const RNGSeed seed) { - const auto context = getContext(preset); - const auto dim = context->get_degree(); - const auto num_secret = context->get_num_secret(); - const auto section_size = context->get_rank() * dim; + const auto dim = get_degree(preset); + const auto num_secret = get_num_secret(preset); + const auto section_size = get_rank(preset) * dim; const auto size = section_size * num_secret; i8 *coeffs = new i8[size]; GenCoeffInplace(preset, coeffs, seed); @@ -52,30 +51,25 @@ i8 *SecretKeyGenerator::GenCoeff(const Preset preset, const RNGSeed seed) { RNGSeed SecretKeyGenerator::GenCoeffInplace(const Preset preset, i8 *coeffs, std::optional seed) { - const auto context = getContext(preset); - const auto dim = context->get_degree(); - const auto num_secret = context->get_num_secret(); - const auto section_size = context->get_rank() * dim; + const auto dim = get_degree(preset); + const auto num_secret = get_num_secret(preset); + const auto section_size = get_rank(preset) * dim; if (!seed) { seed.emplace(SeedGenerator::Gen()); } - alea_state *as = alea_init(reinterpret_cast(seed->data()), - ALEA_ALGORITHM_SHAKE256); - // Sample Hamming weight + auto rng = createRandomGenerator(seed.value()); for (Size i = 0; i < num_secret; ++i) { - alea_sample_hwt_int8_array( - as, coeffs + i * section_size, section_size, - static_cast(context->get_hamming_weight())); + rng->sampleHwtInt8Array(coeffs + i * section_size, section_size, + static_cast(get_hamming_weight(preset))); } - alea_free(as); return seed.value(); } SecretKey SecretKeyGenerator::ComputeEmbedding(const Preset preset, const i8 *coeffs, std::optional level) { - level = level.value_or(getContext(preset)->get_num_p() - 1); + level = level.value_or(get_num_p(preset) - 1); SecretKey sk(preset); sk.allocPolys(level.value() + 1); ComputeEmbeddingInplace(sk, coeffs); @@ -84,40 +78,34 @@ SecretKey SecretKeyGenerator::ComputeEmbedding(const Preset preset, void SecretKeyGenerator::ComputeEmbeddingInplace(SecretKey &sk, const i8 *coeffs) { - const auto context = getContext(sk.preset()); - const auto dim = context->get_degree(); - const auto num_secret = context->get_num_secret(); - const auto rank = context->get_rank(); - const auto section_size = rank * dim; + const auto dim = get_degree(sk.preset()); + const auto num_secret = get_num_secret(sk.preset()); + const auto rank = get_rank(sk.preset()); deb_assert(coeffs != nullptr, "[SecretKeyGenerator::ComputeEmbeddingInplace] Coefficients are " "not allocated."); if (sk.coeffs() != coeffs) { sk.allocCoeffs(); - memcpy(sk.coeffs(), coeffs, section_size * num_secret * sizeof(i8)); + memcpy(sk.coeffs(), coeffs, rank * dim * num_secret * sizeof(i8)); } - if (sk.numPoly() != num_secret * rank) { + if (sk.numPoly() != rank * num_secret) { sk.allocPolys(); } - for (Size ns_id = 0; ns_id < num_secret; ++ns_id) { - for (Size i = 0; i < rank; ++i) { - const Size idx = ns_id * rank + i; - for (Size j = 0; j < sk[idx].size(); ++j) { - u64 *ptr = sk[idx][j].data(); - for (Size k = 0; k < dim; ++k) { - ptr[k] = - (sk.coeffs()[i * dim + k] >= 0) - ? static_cast(sk.coeffs()[i * dim + k]) - : context->get_primes()[j] - - static_cast(-sk.coeffs()[i * dim + k]); - } - // TODO: reuse NTT object - utils::NTT ntt(context->get_degree(), context->get_primes()[j]); - ntt.computeForward(sk[idx][j].data()); - sk[idx][j].setNTT(true); + for (Size i = 0; i < rank * num_secret; ++i) { + for (Size j = 0; j < sk[i].size(); ++j) { + u64 *ptr = sk[i][j].data(); + for (Size k = 0; k < dim; ++k) { + ptr[k] = (sk.coeffs()[i * dim + k] >= 0) + ? static_cast(sk.coeffs()[i * dim + k]) + : get_primes(sk.preset())[j] - + static_cast(-sk.coeffs()[i * dim + k]); } + // TODO: reuse NTT object + utils::NTT ntt(dim, get_primes(sk.preset())[j]); + ntt.computeForward(sk[i][j].data()); + sk[i][j].setNTT(true); } } } @@ -150,20 +138,19 @@ void SecretKeyGenerator::GenSecretKeyFromCoeffInplace(SecretKey &sk, } void completeSecretKey(SecretKey &sk, std::optional level) { - const auto context = getContext(sk.preset()); - const auto rank = context->get_rank(); - const auto num_secret = context->get_num_secret(); - const auto degree = context->get_degree(); + const auto rank = get_rank(sk.preset()); + const auto num_secret = get_num_secret(sk.preset()); + const auto degree = get_degree(sk.preset()); if (sk.coeffsSize() != rank * num_secret * degree) { sk.allocCoeffs(); if (!sk.hasSeed()) { throw std::runtime_error( "[completeSecretKey] Secret key has no seed."); } - SecretKeyGenerator::GenCoeffInplace(context->get_preset(), sk.coeffs(), + SecretKeyGenerator::GenCoeffInplace(sk.preset(), sk.coeffs(), sk.getSeed()); } - level = level.value_or(context->get_num_p() - 1); + level = level.value_or(get_num_p(sk.preset()) - 1); if (sk.numPoly() != num_secret * rank || sk[0].size() != level.value() + 1) { sk.allocPolys(level.value() + 1); diff --git a/src/SeedGenerator.cpp b/src/SeedGenerator.cpp index 7f6bd84..efa561f 100644 --- a/src/SeedGenerator.cpp +++ b/src/SeedGenerator.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,22 +22,19 @@ namespace deb { -const u8 *to_alea_seed(const RNGSeed &seed) { - return reinterpret_cast(seed.data()); -} - SeedGenerator &SeedGenerator::GetInstance(std::optional seeds) { static SeedGenerator instance(seeds); return instance; } void SeedGenerator::Reseed(const std::optional &seeds) { - alea_reseed(GetInstance().as_.get(), to_alea_seed(seeds.value())); + const auto &s = seeds.value(); + GetInstance().rng_->reseed(reinterpret_cast(s.data()), + DEB_RNG_SEED_BYTE_SIZE); } RNGSeed SeedGenerator::Gen() { return GetInstance().genSeed(); } -SeedGenerator::SeedGenerator(std::optional seeds) - : as_(nullptr, &alea_free) { +SeedGenerator::SeedGenerator(std::optional seeds) { if (!seeds) { std::random_device rd; RNGSeed nseeds; @@ -49,14 +46,12 @@ SeedGenerator::SeedGenerator(std::optional seeds) } seeds.emplace(nseeds); } - alea_state *p = alea_init(reinterpret_cast(seeds->data()), - ALEA_ALGORITHM_SHAKE256); - as_ = std::unique_ptr(p, &alea_free); + rng_ = createRandomGenerator(seeds.value()); } RNGSeed SeedGenerator::genSeed() { RNGSeed seeds; - alea_get_random_uint64_array(as_.get(), seeds.data(), DEB_U64_SEED_SIZE); + rng_->getRandomUint64Array(seeds.data(), DEB_U64_SEED_SIZE); return seeds; } diff --git a/src/Serialize.cpp b/src/Serialize.cpp index d200518..cc058eb 100644 --- a/src/Serialize.cpp +++ b/src/Serialize.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,15 +15,14 @@ */ #include "Serialize.hpp" -#include "Context.hpp" namespace deb { std::vector toComplexVector(const Complex *data, const Size size) { - std::vector complex_vec; + std::vector complex_vec(size); for (Size i = 0; i < size; ++i) { - complex_vec.emplace_back(data[i].real(), data[i].imag()); + complex_vec[i] = {data[i].real(), data[i].imag()}; } return complex_vec; } @@ -31,30 +30,28 @@ std::vector toComplexVector(const Complex *data, std::vector toDebComplexVector(const Vector *data) { const Size size = data->size(); - std::vector Complex_vec; + std::vector Complex_vec(size); for (Size i = 0; i < size; ++i) { - Complex_vec.emplace_back(data->Get(i)->real(), data->Get(i)->imag()); + Complex_vec[i] = {data->Get(i)->real(), data->Get(i)->imag()}; } return Complex_vec; } std::vector toComplex32Vector(const ComplexT *data, const Size size) { - std::vector complex_vec; + std::vector complex_vec(size); for (Size i = 0; i < size; ++i) { - complex_vec.emplace_back(static_cast(data[i].real()), - static_cast(data[i].imag())); + complex_vec[i] = {data[i].real(), data[i].imag()}; } return complex_vec; } -std::vector +std::vector> toDebComplex32Vector(const Vector *data) { const Size size = data->size(); - std::vector Complex_vec; + std::vector> Complex_vec(size); for (Size i = 0; i < size; ++i) { - Complex_vec.emplace_back(static_cast(data->Get(i)->real()), - static_cast(data->Get(i)->imag())); + Complex_vec[i] = {data->Get(i)->real(), data->Get(i)->imag()}; } return Complex_vec; } @@ -68,10 +65,19 @@ serializeMessage(flatbuffers::FlatBufferBuilder &builder, } Message deserializeMessage(const deb_fb::Message *message) { - Message msg(message->size()); - memcpy(msg.data(), toDebComplexVector(message->data()).data(), - message->size() * sizeof(Complex)); - return msg; + return Message(toDebComplexVector(message->data())); +} + +flatbuffers::Offset +serializeFMessage(flatbuffers::FlatBufferBuilder &builder, + const FMessage &message) { + auto complex_offset = builder.CreateVectorOfStructs( + toComplex32Vector(message.data(), message.size())); + return CreateMessage32(builder, message.size(), complex_offset); +} + +FMessage deserializeFMessage(const deb_fb::Message32 *message) { + return FMessage(toDebComplex32Vector(message->data())); } flatbuffers::Offset @@ -89,6 +95,21 @@ CoeffMessage deserializeCoeff(const deb_fb::Coeff *coeff) { return coeff_t; } +flatbuffers::Offset +serializeFCoeff(flatbuffers::FlatBufferBuilder &builder, + const FCoeffMessage &coeff) { + return deb_fb::CreateCoeff32( + builder, coeff.size(), + builder.CreateVector(coeff.data(), coeff.size())); +} + +FCoeffMessage deserializeFCoeff(const deb_fb::Coeff32 *coeff) { + FCoeffMessage coeff_t(coeff->size()); + std::memcpy(coeff_t.data(), coeff->data()->data(), + coeff_t.size() * sizeof(float)); + return coeff_t; +} + flatbuffers::Offset serializePolyUnit(flatbuffers::FlatBufferBuilder &builder, const PolyUnit &polyunit) { @@ -117,7 +138,7 @@ serializePoly(flatbuffers::FlatBufferBuilder &builder, const Polynomial &poly) { } Polynomial deserializePoly(const Preset preset, const deb_fb::Poly *poly) { - Polynomial poly_t(getContext(preset), poly->size()); + Polynomial poly_t(preset, poly->size()); for (Size i = 0; i < poly_t.size(); ++i) { poly_t[i] = deserializePolyUnit(poly->rnspolys()->Get(i)); } @@ -139,7 +160,7 @@ serializeCipher(flatbuffers::FlatBufferBuilder &builder, Ciphertext deserializeCipher(const deb_fb::Cipher *cipher) { auto preset = static_cast(cipher->preset()); - Ciphertext cipher_t(getContext(preset), cipher->bigpolys()->Get(0)->size(), + Ciphertext cipher_t(preset, cipher->bigpolys()->Get(0)->size(), cipher->size()); cipher_t.setEncoding(static_cast(cipher->encoding())); for (Size i = 0; i < cipher_t.numPoly(); ++i) { @@ -150,7 +171,6 @@ Ciphertext deserializeCipher(const deb_fb::Cipher *cipher) { flatbuffers::Offset serializeSk(flatbuffers::FlatBufferBuilder &builder, const SecretKey &sk) { - auto context = getContext(sk.preset()); auto seed_offset = builder.CreateVector(sk.hasSeed() ? sk.getSeed().data() : nullptr, sk.hasSeed() ? sk.getSeed().size() : 0); @@ -188,7 +208,6 @@ SecretKey deserializeSk(const deb_fb::Sk *sk) { flatbuffers::Offset serializeSwk(flatbuffers::FlatBufferBuilder &builder, const SwitchKey &swk) { - auto context = getContext(swk.preset()); std::vector> ax_offsets, bx_offsets; ax_offsets.reserve(swk.axSize()); bx_offsets.reserve(swk.bxSize()); @@ -207,8 +226,7 @@ serializeSwk(flatbuffers::FlatBufferBuilder &builder, const SwitchKey &swk) { SwitchKey deserializeSwk(const deb_fb::Swk *swk) { const auto preset = static_cast(swk->preset()); - SwitchKey swk_t(getContext(preset), - static_cast(swk->type())); + SwitchKey swk_t(preset, static_cast(swk->type())); swk_t.getAx().clear(); for (Size i = 0; i < swk->ax()->size(); ++i) { Polynomial tmp = deserializePoly(preset, swk->ax()->Get(i)); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bce7d4a..310a68f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,5 @@ # ~~~ -# Copyright 2025 CryptoLab, Inc. +# Copyright 2026 CryptoLab, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -34,22 +34,28 @@ cpmaddpackage( # --------------------------------------------------------------------------- set(deb_test deb_obj deb_DevHeader gtest_main) -# Ntt Test -add_executable(NTT-test NTT-test.cpp) -target_link_libraries(NTT-test PRIVATE ${deb_test}) -add_gtest_target_to_ctest(NTT-test) + +# Operation Test +add_executable(Operation-test Operation-test.cpp) +target_link_libraries(Operation-test PRIVATE ${deb_test}) +add_gtest_target_to_ctest(Operation-test) # EnDecryption Test add_executable(EnDecryption-test EnDecryption-test.cpp) target_link_libraries(EnDecryption-test PRIVATE ${deb_test}) add_gtest_target_to_ctest(EnDecryption-test) +# KeyGen Test +add_executable(KeyGen-test KeyGen-test.cpp) +target_link_libraries(KeyGen-test PRIVATE ${deb_test}) +add_gtest_target_to_ctest(KeyGen-test) + # Serialize Test add_executable(Serialize-test Serialize-test.cpp) target_link_libraries(Serialize-test PRIVATE ${deb_test}) add_gtest_target_to_ctest(Serialize-test) -# KeyGen Test -add_executable(KeyGen-test KeyGen-test.cpp) -target_link_libraries(KeyGen-test PRIVATE ${deb_test}) -add_gtest_target_to_ctest(KeyGen-test) +# Ntt Test +add_executable(NTT-test NTT-test.cpp) +target_link_libraries(NTT-test PRIVATE ${deb_test}) +add_gtest_target_to_ctest(NTT-test) diff --git a/test/EnDecryption-test.cpp b/test/EnDecryption-test.cpp index 5983955..8b2bd07 100644 --- a/test/EnDecryption-test.cpp +++ b/test/EnDecryption-test.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,39 +16,40 @@ #include "DebParam.hpp" #include "TestBase.hpp" +#include "utils/OmpUtils.hpp" using namespace deb; class EnDecrypt : public DebTestBase {}; TEST_P(EnDecrypt, EncryptWithEmptySecretKey) { - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); SecretKey sk(preset, SeedGenerator::Gen()); - Ciphertext ctxt(context); - DEB_EXPECT(encryptor.encrypt(msg, sk, ctxt)); + Ciphertext ctxt(preset); + DEB_TEST_EXPECT(encryptor.encrypt(msg, sk, ctxt)); } TEST_P(EnDecrypt, DecryptWithEmptySecretKey) { - setOmpThreadLimit(1); - MSGS msg = gen_random_message(); + utils::setOmpThreadLimit(1); + MSGS msg = gen_random_message(); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset, SeedGenerator::Gen()); - Ciphertext ctxt(context); + Ciphertext ctxt(preset); encryptor.encrypt(msg, sk, ctxt); sk.allocPolys(0); - DEB_EXPECT(decryptor.decrypt(ctxt, sk, msg)); - unsetOmpThreadLimit(); + DEB_TEST_EXPECT(decryptor.decrypt(ctxt, sk, msg)); + utils::unsetOmpThreadLimit(); } TEST_P(EnDecrypt, EncryptAndDecryptWithSecretKey) { - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - MSGS decrypted_msg = gen_empty_message(); + MSGS decrypted_msg = gen_empty_message(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); MSGS scaled_msg = scale_message(msg, l); encryptor.encrypt(scaled_msg, sk, ctxt, EncryptOptions().Level(l)); decryptor.decrypt(ctxt, sk, decrypted_msg); @@ -57,16 +58,31 @@ TEST_P(EnDecrypt, EncryptAndDecryptWithSecretKey) { } } +TEST_P(EnDecrypt, EncryptAndDecryptFloatWithSecretKey) { + FMSGS msg = gen_random_message(); + + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + FMSGS decrypted_msg = gen_empty_message(); + + for (Size l = 0; l < std::min(2U, get_num_p(preset)); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(msg, sk, ctxt, EncryptOptions().Level(l)); + decryptor.decrypt(ctxt, sk, decrypted_msg); + + compare_msg(msg, decrypted_msg, scale_error(sk_err_f, l)); + } +} + TEST_P(EnDecrypt, EncryptAndDecryptWithEncKey) { - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); KeyGenerator keygen(preset); SwitchKey enckey = keygen.genEncKey(sk); - MSGS decrypted_msg = gen_empty_message(); + MSGS decrypted_msg = gen_empty_message(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); MSGS scaled_msg = scale_message(msg, l); encryptor.encrypt(scaled_msg, enckey, ctxt, EncryptOptions().Level(l)); decryptor.decrypt(ctxt, sk, decrypted_msg); @@ -75,21 +91,38 @@ TEST_P(EnDecrypt, EncryptAndDecryptWithEncKey) { } } +TEST_P(EnDecrypt, EncryptAndDecryptFloatWithEncKey) { + FMSGS msg = gen_random_message(); + + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + FMSGS decrypted_msg = gen_empty_message(); + + for (Size l = 0; l < std::min(2U, get_num_p(preset)); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(msg, enckey, ctxt, EncryptOptions().Level(l)); + decryptor.decrypt(ctxt, sk, decrypted_msg); + + compare_msg(msg, decrypted_msg, scale_error(enc_err_f, l)); + } +} + TEST_P(EnDecrypt, ScaleEncryptAndDecryptWithSecretKey) { - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); const int max_scale_bit = - static_cast(utils::bitWidth(context->get_primes()[0]) - 2); + static_cast(utils::bitWidth(get_primes(preset)[0]) - 2); const double min_scale_bit = static_cast(30 + log_error); const double scale_bit = min_scale_bit + abs(dist(gen)) * (max_scale_bit - min_scale_bit); const double scale = std::pow(2.0, scale_bit); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - MSGS decrypted_msg = gen_empty_message(); + MSGS decrypted_msg = gen_empty_message(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); encryptor.encrypt(msg, sk, ctxt, EncryptOptions().Level(l).Scale(scale)); decryptor.decrypt(ctxt, sk, decrypted_msg, scale); @@ -99,10 +132,10 @@ TEST_P(EnDecrypt, ScaleEncryptAndDecryptWithSecretKey) { TEST_P(EnDecrypt, ScaleEncryptAndDecryptWithEncKey) { - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); const int max_scale_bit = - static_cast(utils::bitWidth(context->get_primes()[0]) - 2); + static_cast(utils::bitWidth(get_primes(preset)[0]) - 2); const double min_scale_bit = static_cast(30 + log_error); const double scale_bit = min_scale_bit + abs(dist(gen)) * (max_scale_bit - min_scale_bit); @@ -110,10 +143,10 @@ TEST_P(EnDecrypt, ScaleEncryptAndDecryptWithEncKey) { SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); SwitchKey enckey = KeyGenerator(preset).genEncKey(sk); - MSGS decrypted_msg = gen_empty_message(); + MSGS decrypted_msg = gen_empty_message(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); encryptor.encrypt(msg, enckey, ctxt, EncryptOptions().Level(l).Scale(scale)); decryptor.decrypt(ctxt, sk, decrypted_msg, scale); @@ -123,13 +156,13 @@ TEST_P(EnDecrypt, ScaleEncryptAndDecryptWithEncKey) { TEST_P(EnDecrypt, EncryptAndDecryptCoeffWithSecretKey) { - COEFFS coeff = gen_random_coeff(); + COEFFS coeff = gen_random_coeff(); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - COEFFS decrypted_coeff = gen_empty_coeff(); + COEFFS decrypted_coeff = gen_empty_coeff(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); COEFFS scaled_coeff = scale_coeff(coeff, l); encryptor.encrypt(scaled_coeff, sk, ctxt, EncryptOptions().Level(l)); decryptor.decrypt(ctxt, sk, decrypted_coeff); @@ -137,17 +170,32 @@ TEST_P(EnDecrypt, EncryptAndDecryptCoeffWithSecretKey) { } } +TEST_P(EnDecrypt, EncryptAndDecryptFloatCoeffWithSecretKey) { + + FCOEFFS coeff = gen_random_coeff(); + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + + FCOEFFS decrypted_coeff = gen_empty_coeff(); + + for (Size l = 0; l < std::min(2U, get_num_p(preset)); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(coeff, sk, ctxt, EncryptOptions().Level(l)); + decryptor.decrypt(ctxt, sk, decrypted_coeff); + compare_coeff(coeff, decrypted_coeff, scale_error(sk_err_f, l)); + } +} + TEST_P(EnDecrypt, EncryptAndDecryptCoeffWithEncKey) { - COEFFS coeff = gen_random_coeff(); + COEFFS coeff = gen_random_coeff(); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); KeyGenerator keygen(preset); SwitchKey enckey = keygen.genEncKey(sk); - COEFFS decrypted_coeff = gen_empty_coeff(); + COEFFS decrypted_coeff = gen_empty_coeff(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); COEFFS scaled_coeff = scale_coeff(coeff, l); encryptor.encrypt(scaled_coeff, enckey, ctxt, EncryptOptions().Level(l)); @@ -156,22 +204,39 @@ TEST_P(EnDecrypt, EncryptAndDecryptCoeffWithEncKey) { } } +TEST_P(EnDecrypt, EncryptAndDecryptFloatCoeffWithEncKey) { + + FCOEFFS coeff = gen_random_coeff(); + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); + KeyGenerator keygen(preset); + SwitchKey enckey = keygen.genEncKey(sk); + + FCOEFFS decrypted_coeff = gen_empty_coeff(); + + for (Size l = 0; l < std::min(2U, get_num_p(preset)); ++l) { + Ciphertext ctxt(preset, l); + encryptor.encrypt(coeff, enckey, ctxt, EncryptOptions().Level(l)); + decryptor.decrypt(ctxt, sk, decrypted_coeff); + compare_coeff(coeff, decrypted_coeff, scale_error(enc_err_f, l)); + } +} + TEST_P(EnDecrypt, ScaleEncryptAndDecryptCoeffWithSecretKey) { - COEFFS coeff = gen_random_coeff(); + COEFFS coeff = gen_random_coeff(); const int max_scale_bit = - static_cast(utils::bitWidth(context->get_primes()[0]) - 2); + static_cast(utils::bitWidth(get_primes(preset)[0]) - 2); const double min_scale_bit = static_cast(30 + log_error); const double scale_bit = min_scale_bit + abs(dist(gen)) * (max_scale_bit - min_scale_bit); const double scale = std::pow(2.0, scale_bit); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - COEFFS decrypted_coeff = gen_empty_coeff(); + COEFFS decrypted_coeff = gen_empty_coeff(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); encryptor.encrypt(coeff, sk, ctxt, EncryptOptions().Level(l).Scale(scale)); decryptor.decrypt(ctxt, sk, decrypted_coeff, scale); @@ -181,10 +246,10 @@ TEST_P(EnDecrypt, ScaleEncryptAndDecryptCoeffWithSecretKey) { TEST_P(EnDecrypt, ScaleEncryptAndDecryptCoeffWithEncKey) { - COEFFS coeff = gen_random_coeff(); + COEFFS coeff = gen_random_coeff(); const int max_scale_bit = - static_cast(utils::bitWidth(context->get_primes()[0]) - 2); + static_cast(utils::bitWidth(get_primes(preset)[0]) - 2); const double min_scale_bit = static_cast(30 + log_error); const double scale_bit = min_scale_bit + abs(dist(gen)) * (max_scale_bit - min_scale_bit); @@ -192,10 +257,10 @@ TEST_P(EnDecrypt, ScaleEncryptAndDecryptCoeffWithEncKey) { SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); SwitchKey enckey = KeyGenerator(preset).genEncKey(sk); - COEFFS decrypted_coeff = gen_empty_coeff(); + COEFFS decrypted_coeff = gen_empty_coeff(); - for (Size l = 0; l < context->get_num_p(); ++l) { - Ciphertext ctxt(context, l); + for (Size l = 0; l < get_num_p(preset); ++l) { + Ciphertext ctxt(preset, l); encryptor.encrypt(coeff, enckey, ctxt, EncryptOptions().Level(l).Scale(scale)); decryptor.decrypt(ctxt, sk, decrypted_coeff, scale); diff --git a/test/KeyGen-test.cpp b/test/KeyGen-test.cpp index 5184343..8ebfec3 100644 --- a/test/KeyGen-test.cpp +++ b/test/KeyGen-test.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,115 +22,94 @@ using namespace deb; class KeyGen : public DebTestBase { public: - SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - KeyGenerator keygen{sk}; - - void compareArray(const u64 *arr1, const u64 *arr2, const Size size) { - for (Size i = 0; i < size; ++i) { - ASSERT_EQ(arr1[i], arr2[i]); - } - } - - void comparePoly(const PolyUnit &poly1, const PolyUnit &poly2) { - ASSERT_EQ(poly1.prime(), poly2.prime()); - ASSERT_EQ(poly1.degree(), poly2.degree()); - ASSERT_EQ(poly1.isNTT(), poly2.isNTT()); - compareArray(poly1.data(), poly2.data(), poly1.degree()); - } + const Size gadget_rank = get_gadget_rank(preset); + const Size num_p = get_num_p(preset); - void compareBigPoly(const Polynomial &bigpoly1, - const Polynomial &bigpoly2) { - ASSERT_EQ(bigpoly1.size(), bigpoly2.size()); - for (Size i = 0; i < bigpoly1.size(); ++i) { - comparePoly(bigpoly1[i], bigpoly2[i]); - } - } + KeyGenerator keygen{preset}; + SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); }; TEST_P(KeyGen, GenEncryptionKey) { - SwitchKey enckey(context, SwitchKeyKind::SWK_ENC); - ASSERT_NO_THROW(enckey = keygen.genEncKey()); - ASSERT_NO_THROW(keygen.genEncKeyInplace(enckey)); + SwitchKey enckey(preset, SwitchKeyKind::SWK_ENC); + ASSERT_NO_THROW(enckey = keygen.genEncKey(sk)); + ASSERT_NO_THROW(keygen.genEncKeyInplace(enckey, sk)); ASSERT_EQ(enckey.axSize(), 1); - ASSERT_EQ(enckey.bxSize(), context->get_num_secret()); - ASSERT_EQ(enckey.ax().size(), context->get_num_p()); - ASSERT_EQ(enckey.bx().size(), context->get_num_p()); + ASSERT_EQ(enckey.bxSize(), num_secret); + ASSERT_EQ(enckey.ax().size(), num_p); + ASSERT_EQ(enckey.bx().size(), num_p); } TEST_P(KeyGen, GenMultiplicationKey) { - SwitchKey mulkey(context, SwitchKeyKind::SWK_MULT); - ASSERT_NO_THROW(mulkey = keygen.genMultKey()); - ASSERT_NO_THROW(keygen.genMultKeyInplace(mulkey)); - - ASSERT_EQ(mulkey.axSize(), context->get_gadget_rank()); - ASSERT_EQ(mulkey.bxSize(), - context->get_gadget_rank() * context->get_num_secret()); - ASSERT_EQ(mulkey.ax().size(), context->get_num_p()); - ASSERT_EQ(mulkey.bx().size(), context->get_num_p()); + SwitchKey mulkey(preset, SwitchKeyKind::SWK_MULT); + ASSERT_NO_THROW(mulkey = keygen.genMultKey(sk)); + ASSERT_NO_THROW(keygen.genMultKeyInplace(mulkey, sk)); + + ASSERT_EQ(mulkey.axSize(), gadget_rank); + ASSERT_EQ(mulkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(mulkey.ax().size(), num_p); + ASSERT_EQ(mulkey.bx().size(), num_p); } TEST_P(KeyGen, GenConjugationKey) { - SwitchKey conjkey(context, SwitchKeyKind::SWK_CONJ); - ASSERT_NO_THROW(conjkey = keygen.genConjKey()); - ASSERT_NO_THROW(keygen.genConjKeyInplace(conjkey)); - - ASSERT_EQ(conjkey.axSize(), context->get_gadget_rank()); - ASSERT_EQ(conjkey.bxSize(), - context->get_gadget_rank() * context->get_num_secret()); - ASSERT_EQ(conjkey.ax().size(), context->get_num_p()); - ASSERT_EQ(conjkey.bx().size(), context->get_num_p()); + SwitchKey conjkey(preset, SwitchKeyKind::SWK_CONJ); + ASSERT_NO_THROW(conjkey = keygen.genConjKey(sk)); + ASSERT_NO_THROW(keygen.genConjKeyInplace(conjkey, sk)); + + ASSERT_EQ(conjkey.axSize(), gadget_rank); + ASSERT_EQ(conjkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(conjkey.ax().size(), num_p); + ASSERT_EQ(conjkey.bx().size(), num_p); } TEST_P(KeyGen, GenRotationKeys) { - const Size num_slots = context->get_num_slots(); const Size rot = dist_u64(gen) % (num_slots - 1) + 1; const RNGSeed seed = SeedGenerator::Gen(); - KeyGenerator keygen_same1(sk, seed); - KeyGenerator keygen_same2(sk, seed); - SwitchKey left_rotkey(context, SwitchKeyKind::SWK_ROT, rot); - ASSERT_NO_THROW(left_rotkey = keygen.genLeftRotKey(rot)); - ASSERT_NO_THROW(keygen_same1.genLeftRotKeyInplace(rot, left_rotkey)); + KeyGenerator keygen_same1(preset, seed); + KeyGenerator keygen_same2(preset, seed); + SwitchKey left_rotkey(preset, SwitchKeyKind::SWK_ROT, rot); + ASSERT_NO_THROW(left_rotkey = keygen.genLeftRotKey(rot, sk)); + ASSERT_NO_THROW(keygen_same1.genLeftRotKeyInplace(rot, left_rotkey, sk)); - SwitchKey right_rotkey(context, SwitchKeyKind::SWK_ROT, num_slots - rot); - ASSERT_NO_THROW(right_rotkey = keygen.genRightRotKey(num_slots - rot)); + SwitchKey right_rotkey(preset, SwitchKeyKind::SWK_ROT, num_slots - rot); + ASSERT_NO_THROW(right_rotkey = keygen.genRightRotKey(num_slots - rot, sk)); ASSERT_NO_THROW( - keygen_same2.genRightRotKeyInplace(num_slots - rot, right_rotkey)); - ASSERT_EQ(left_rotkey.axSize(), context->get_gadget_rank()); - ASSERT_EQ(left_rotkey.bxSize(), - context->get_gadget_rank() * context->get_num_secret()); - ASSERT_EQ(left_rotkey.ax().size(), context->get_num_p()); - ASSERT_EQ(left_rotkey.bx().size(), context->get_num_p()); - - ASSERT_EQ(right_rotkey.axSize(), context->get_gadget_rank()); - ASSERT_EQ(right_rotkey.bxSize(), - context->get_gadget_rank() * context->get_num_secret()); - ASSERT_EQ(right_rotkey.ax().size(), context->get_num_p()); - ASSERT_EQ(right_rotkey.bx().size(), context->get_num_p()); + keygen_same2.genRightRotKeyInplace(num_slots - rot, right_rotkey, sk)); + ASSERT_EQ(left_rotkey.axSize(), gadget_rank); + ASSERT_EQ(left_rotkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(left_rotkey.ax().size(), num_p); + ASSERT_EQ(left_rotkey.bx().size(), num_p); + + ASSERT_EQ(right_rotkey.axSize(), gadget_rank); + ASSERT_EQ(right_rotkey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(right_rotkey.ax().size(), num_p); + ASSERT_EQ(right_rotkey.bx().size(), num_p); for (Size i = 0; i < left_rotkey.axSize(); ++i) { - compareBigPoly(left_rotkey.ax(i), right_rotkey.ax(i)); + comparePoly(left_rotkey.ax(i), right_rotkey.ax(i)); } for (Size i = 0; i < left_rotkey.bxSize(); ++i) { - compareBigPoly(left_rotkey.bx(i), right_rotkey.bx(i)); + comparePoly(left_rotkey.bx(i), right_rotkey.bx(i)); } } TEST_P(KeyGen, GenAutomorphismKey) { const Size sig = dist_u64(gen) % (degree - 1) + 1; - SwitchKey autokey(context, SwitchKeyKind::SWK_AUTO); - ASSERT_NO_THROW(autokey = keygen.genAutoKey(sig)); - ASSERT_NO_THROW(keygen.genAutoKeyInplace(sig, autokey)); - - ASSERT_EQ(autokey.axSize(), context->get_gadget_rank()); - ASSERT_EQ(autokey.bxSize(), - context->get_gadget_rank() * context->get_num_secret()); - ASSERT_EQ(autokey.ax().size(), context->get_num_p()); - ASSERT_EQ(autokey.bx().size(), context->get_num_p()); + SwitchKey autokey(preset, SwitchKeyKind::SWK_AUTO); + ASSERT_NO_THROW(autokey = keygen.genAutoKey(sig, sk)); + ASSERT_NO_THROW(keygen.genAutoKeyInplace(sig, autokey, sk)); + + ASSERT_EQ(autokey.axSize(), gadget_rank); + ASSERT_EQ(autokey.bxSize(), gadget_rank * num_secret); + ASSERT_EQ(autokey.ax().size(), num_p); + ASSERT_EQ(autokey.bx().size(), num_p); } TEST_P(KeyGen, GenModPackKey) { + if (num_secret != 1) { + GTEST_SKIP() << "MODPACK key generation is only for single secret."; + } std::vector modkey; ASSERT_NO_THROW(modkey = keygen.genModPackKeyBundle(sk, sk)); ASSERT_NO_THROW(keygen.genModPackKeyBundleInplace(sk, sk, modkey)); @@ -141,16 +120,15 @@ TEST_P(KeyGen, GenModPackKeySelf) { GTEST_SKIP() << "MODPACK_SELF key generation is only for single secret."; } - const Size pad_rank = 1U - << (dist_u64(gen) % (context->get_log_degree() / 2)); - SwitchKey modevikey(context, SwitchKeyKind::SWK_MODPACK_SELF); - ASSERT_NO_THROW(modevikey = keygen.genModPackKeyBundle(pad_rank)); - ASSERT_NO_THROW(keygen.genModPackKeyBundleInplace(pad_rank, modevikey)); + const Size pad_rank = 1U << (dist_u64(gen) % (get_log_degree(preset) / 2)); + SwitchKey modevikey(preset, SwitchKeyKind::SWK_MODPACK_SELF); + ASSERT_NO_THROW(modevikey = keygen.genModPackKeyBundle(pad_rank, sk)); + ASSERT_NO_THROW(keygen.genModPackKeyBundleInplace(pad_rank, modevikey, sk)); ASSERT_EQ(modevikey.axSize(), pad_rank); - ASSERT_EQ(modevikey.bxSize(), pad_rank * context->get_num_secret()); - ASSERT_EQ(modevikey.ax().size(), context->get_num_p()); - ASSERT_EQ(modevikey.bx().size(), context->get_num_p()); + ASSERT_EQ(modevikey.bxSize(), pad_rank * num_secret); + ASSERT_EQ(modevikey.ax().size(), num_p); + ASSERT_EQ(modevikey.bx().size(), num_p); } #define X(PRESET) Preset::PRESET_##PRESET, diff --git a/test/NTT-test.cpp b/test/NTT-test.cpp index 762bd2e..9438576 100644 --- a/test/NTT-test.cpp +++ b/test/NTT-test.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/test/Operation-test.cpp b/test/Operation-test.cpp new file mode 100644 index 0000000..9450b29 --- /dev/null +++ b/test/Operation-test.cpp @@ -0,0 +1,390 @@ +/* + * Copyright 2026 CryptoLab, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/Basic.hpp" +#include "utils/Macro.hpp" + +#include +#include + +using namespace deb; +using namespace deb::utils; + +// RNG helpers +static std::mt19937_64 rng{std::random_device{}()}; +static u128 randomU128() { + return (static_cast(rng()) << 64) | static_cast(rng()); +} +static i128 randomI128() { return static_cast(randomU128()); } + +// --------------------------------------------- +// U128 tests +// --------------------------------------------- +class U128ArithTest : public ::testing::Test {}; + +// KAT tests +TEST_F(U128ArithTest, KAT_HiLo) { + constexpr u128 val = + (static_cast(UINT64_C(0xDEADBEEFCAFEBABE)) << 64) | + UINT64_C(0x0102030405060708); + EXPECT_EQ(u128Hi(val), UINT64_C(0xDEADBEEFCAFEBABE)); + EXPECT_EQ(u128Lo(val), UINT64_C(0x0102030405060708)); +} + +TEST_F(U128ArithTest, KAT_HiLo_Zero) { + EXPECT_EQ(u128Hi(static_cast(0)), UINT64_C(0)); + EXPECT_EQ(u128Lo(static_cast(0)), UINT64_C(0)); +} + +TEST_F(U128ArithTest, KAT_Mul64To128_MaxTimesMax) { + // UINT64_MAX * UINT64_MAX = (2^64-1)^2 = 2^128 - 2*2^64 + 1 + // hi = UINT64_MAX - 1, lo = 1 + u128 result = mul64To128(UINT64_MAX, UINT64_MAX); + EXPECT_EQ(u128Hi(result), UINT64_MAX - 1); + EXPECT_EQ(u128Lo(result), UINT64_C(1)); +} + +TEST_F(U128ArithTest, KAT_Mul64To128_PowersOfTwo) { + // (1 << 63) * 2 = 2^64 => hi = 1, lo = 0 + u128 result = mul64To128(UINT64_C(1) << 63, UINT64_C(2)); + EXPECT_EQ(u128Hi(result), UINT64_C(1)); + EXPECT_EQ(u128Lo(result), UINT64_C(0)); +} + +TEST_F(U128ArithTest, KAT_Mul64To128_ZeroOperand) { + EXPECT_EQ(mul64To128(UINT64_C(0), UINT64_MAX), static_cast(0)); + EXPECT_EQ(mul64To128(UINT64_MAX, UINT64_C(0)), static_cast(0)); +} + +TEST_F(U128ArithTest, KAT_Mul64To128Hi) { + EXPECT_EQ(mul64To128Hi(UINT64_MAX, UINT64_MAX), UINT64_MAX - 1); + // (2^32) * (2^32) = 2^64 => hi = 1 + EXPECT_EQ(mul64To128Hi(UINT64_C(1) << 32, UINT64_C(1) << 32), UINT64_C(1)); + EXPECT_EQ(mul64To128Hi(UINT64_C(0), UINT64_MAX), UINT64_C(0)); +} + +TEST_F(U128ArithTest, KAT_Divide128By64Lo_Simple) { + // 100 / 10 = 10 + EXPECT_EQ(divide128By64Lo(UINT64_C(0), UINT64_C(100), UINT64_C(10)), + UINT64_C(10)); + // 2^64 / 2 = 2^63 + EXPECT_EQ(divide128By64Lo(UINT64_C(1), UINT64_C(0), UINT64_C(2)), + UINT64_C(1) << 63); +} + +// Edge-value tests +TEST_F(U128ArithTest, Edge_Zero) { + constexpr u128 zero = static_cast(0); + EXPECT_EQ(u128Hi(zero), UINT64_C(0)); + EXPECT_EQ(u128Lo(zero), UINT64_C(0)); + EXPECT_EQ(mul64To128(UINT64_C(0), UINT64_C(0)), zero); + EXPECT_EQ(mul64To128Hi(UINT64_C(0), UINT64_C(0)), UINT64_C(0)); +} + +TEST_F(U128ArithTest, Edge_MaxValue) { + // ~0 : all 128 bits set — hi = UINT64_MAX, lo = UINT64_MAX + constexpr u128 u128_max = ~static_cast(0); + EXPECT_EQ(u128Hi(u128_max), UINT64_MAX); + EXPECT_EQ(u128Lo(u128_max), UINT64_MAX); + // wraps to 0 on +1 + EXPECT_EQ(u128_max + static_cast(1), static_cast(0)); + // wraps to u128_max on -1 + EXPECT_EQ(static_cast(0) - static_cast(1), u128_max); +} + +TEST_F(U128ArithTest, Edge_Boundary64) { + // 2^64 - 1 : hi = 0, lo = UINT64_MAX + constexpr u128 just_below = static_cast(UINT64_MAX); + EXPECT_EQ(u128Hi(just_below), UINT64_C(0)); + EXPECT_EQ(u128Lo(just_below), UINT64_MAX); + // 2^64 : hi = 1, lo = 0 + constexpr u128 exactly_2_64 = just_below + static_cast(1); + EXPECT_EQ(u128Hi(exactly_2_64), UINT64_C(1)); + EXPECT_EQ(u128Lo(exactly_2_64), UINT64_C(0)); +} + +TEST_F(U128ArithTest, Edge_Mul64To128_OneTimesOne) { + u128 result = mul64To128(UINT64_C(1), UINT64_C(1)); + EXPECT_EQ(u128Hi(result), UINT64_C(0)); + EXPECT_EQ(u128Lo(result), UINT64_C(1)); +} + +TEST_F(U128ArithTest, Edge_Mul64To128_OneTimesMaxHi) { + // 1 * UINT64_MAX : no carry into hi + u128 result = mul64To128(UINT64_C(1), UINT64_MAX); + EXPECT_EQ(u128Hi(result), UINT64_C(0)); + EXPECT_EQ(u128Lo(result), UINT64_MAX); + EXPECT_EQ(mul64To128Hi(UINT64_C(1), UINT64_MAX), UINT64_C(0)); +} + +TEST_F(U128ArithTest, Edge_Divide128By64Lo_DivideZero) { + // 0 / nonzero = 0 + EXPECT_EQ(divide128By64Lo(UINT64_C(0), UINT64_C(0), UINT64_MAX), + UINT64_C(0)); + EXPECT_EQ(divide128By64Lo(UINT64_C(0), UINT64_C(0), UINT64_C(1)), + UINT64_C(0)); +} + +TEST_F(U128ArithTest, Edge_Divide128By64Lo_DivideByOne) { + // (0 : lo) / 1 = lo + EXPECT_EQ(divide128By64Lo(UINT64_C(0), UINT64_C(0), UINT64_C(1)), + UINT64_C(0)); + EXPECT_EQ(divide128By64Lo(UINT64_C(0), UINT64_MAX, UINT64_C(1)), + UINT64_MAX); +} + +// Random tests +TEST_F(U128ArithTest, Random_HiLoReconstruct) { + for (int i = 0; i < 1000; ++i) { + u128 val = randomU128(); + u128 reconstructed = (static_cast(u128Hi(val)) << 64) | + static_cast(u128Lo(val)); + EXPECT_EQ(val, reconstructed); + } +} + +TEST_F(U128ArithTest, Random_Mul64To128_HiMatchesHiFunc) { + for (int i = 0; i < 1000; ++i) { + u64 a = rng(), b = rng(); + u128 full = mul64To128(a, b); + EXPECT_EQ(u128Hi(full), mul64To128Hi(a, b)); + } +} + +TEST_F(U128ArithTest, Random_Mul64To128_Commutativity) { + for (int i = 0; i < 1000; ++i) { + u64 a = rng(), b = rng(); + EXPECT_EQ(mul64To128(a, b), mul64To128(b, a)); + } +} + +TEST_F(U128ArithTest, Random_Divide128By64Lo_InverseOfMul) { + // Use a 32-bit divisor so that a*b always fits (hi word < b is guaranteed) + for (int i = 0; i < 1000; ++i) { + u64 b = (rng() & 0xFFFFFFFF) + 1; // 32-bit non-zero divisor + u64 a = rng() >> 1; // 63-bit quotient + u128 product = static_cast(a) * b; + EXPECT_EQ(divide128By64Lo(u128Hi(product), u128Lo(product), b), a); + } +} + +TEST_F(U128ArithTest, Random_Mul64To128_MultiplyByOne) { + // a * 1 = a (hi = 0, lo = a) + for (int i = 0; i < 200; ++i) { + u64 a = rng(); + u128 result = mul64To128(a, UINT64_C(1)); + EXPECT_EQ(u128Hi(result), UINT64_C(0)); + EXPECT_EQ(u128Lo(result), a); + } +} + +// --------------------------------------------- +// I128 tests +// --------------------------------------------- +class I128ArithTest : public ::testing::Test {}; + +// KAT tests +TEST_F(I128ArithTest, KAT_SignedOverflowBeyond64Bit) { + i128 a = static_cast(INT64_MAX); + i128 result = a + 1; + // 2^63 doesn't fit in i64, but is positive in i128 + EXPECT_GT(result, a); + EXPECT_EQ(result, static_cast(1) << 63); +} + +TEST_F(I128ArithTest, KAT_NegativeMultiplication) { + // (-1) * INT64_MIN => positive value, beyond i64 range + i128 neg_one = static_cast(-1); + i128 imin = static_cast(INT64_MIN); + EXPECT_EQ(neg_one * imin, -imin); + EXPECT_GT(neg_one * imin, static_cast(0)); +} + +TEST_F(I128ArithTest, KAT_LargeNegativeExtension) { + // One step below INT64_MIN must remain negative and smaller + i128 a = static_cast(INT64_MIN) - 1; + EXPECT_LT(a, static_cast(INT64_MIN)); +} + +TEST_F(I128ArithTest, KAT_AddSubInverse) { + i128 a = static_cast(INT64_MAX) + 1; // 2^63 + i128 b = static_cast(INT64_MIN); // -2^63 + EXPECT_EQ(a + b, static_cast(0)); +} + +TEST_F(I128ArithTest, KAT_NegationAndDoubling) { + i128 val = static_cast(INT64_C(0x123456789ABCDEF0)); + EXPECT_EQ(val + (-val), static_cast(0)); + EXPECT_EQ(val * 2, val + val); +} + +// Edge-value tests +// I128_MAX = 0x7FFF...FFFF, I128_MIN = 0x8000...0000 +static constexpr i128 I128_MAX = static_cast(~static_cast(0) >> 1); +static constexpr i128 I128_MIN = static_cast(static_cast(1) << 127); + +TEST_F(I128ArithTest, Edge_Zero) { + constexpr i128 zero = static_cast(0); + EXPECT_EQ(zero + zero, zero); + EXPECT_EQ(zero * I128_MAX, zero); + EXPECT_EQ(zero * I128_MIN, zero); + EXPECT_EQ(-zero, zero); +} + +TEST_F(I128ArithTest, Edge_One_NegOne) { + constexpr i128 one = static_cast(1); + constexpr i128 neg_one = static_cast(-1); + EXPECT_GT(one, static_cast(0)); + EXPECT_LT(neg_one, static_cast(0)); + EXPECT_EQ(one + neg_one, static_cast(0)); + EXPECT_EQ(one * neg_one, neg_one); + EXPECT_EQ(neg_one * neg_one, one); +} + +TEST_F(I128ArithTest, Edge_MaxValue) { + // I128_MAX is positive and greater than INT64_MAX + EXPECT_GT(I128_MAX, static_cast(INT64_MAX)); + EXPECT_EQ(I128_MAX - I128_MAX, static_cast(0)); + EXPECT_EQ(I128_MAX * static_cast(1), I128_MAX); + EXPECT_EQ(I128_MAX * static_cast(-1), -I128_MAX); + // One below max is still positive + EXPECT_GT(I128_MAX - static_cast(1), static_cast(0)); +} + +TEST_F(I128ArithTest, Edge_MinValue) { + // I128_MIN is negative and less than INT64_MIN + EXPECT_LT(I128_MIN, static_cast(INT64_MIN)); + EXPECT_EQ(I128_MIN - I128_MIN, static_cast(0)); + EXPECT_LT(I128_MIN + static_cast(1), static_cast(0)); + EXPECT_EQ(I128_MIN * static_cast(1), I128_MIN); +} + +TEST_F(I128ArithTest, Edge_MaxPlusMinIsNegOne) { + // I128_MAX + I128_MIN = (2^127 - 1) + (-2^127) = -1 + EXPECT_EQ(I128_MAX + I128_MIN, static_cast(-1)); +} + +TEST_F(I128ArithTest, Edge_MaxMinusMinIsAllOnes) { + // I128_MAX - I128_MIN wraps: (2^127-1) - (-2^127) = 2^128-1 ≡ -1 mod 2^128 + // cast back to i128 → -1 (wraps) + EXPECT_EQ(I128_MAX - I128_MIN, static_cast(-1)); +} + +// Random tests +TEST_F(I128ArithTest, Random_AddSubInverse) { + for (int i = 0; i < 1000; ++i) { + i128 a = randomI128(), b = randomI128(); + EXPECT_EQ((a + b) - b, a); + } +} + +TEST_F(I128ArithTest, Random_AddCommutativity) { + for (int i = 0; i < 1000; ++i) { + i128 a = randomI128(), b = randomI128(); + EXPECT_EQ(a + b, b + a); + } +} + +TEST_F(I128ArithTest, Random_MulCommutativity) { + for (int i = 0; i < 1000; ++i) { + i128 a = randomI128(), b = randomI128(); + EXPECT_EQ(a * b, b * a); + } +} + +TEST_F(I128ArithTest, Random_NegationAddInverse) { + for (int i = 0; i < 1000; ++i) { + i128 a = randomI128(); + EXPECT_EQ(a + (-a), static_cast(0)); + } +} + +TEST_F(I128ArithTest, Random_MulByOne) { + for (int i = 0; i < 200; ++i) { + i128 a = randomI128(); + EXPECT_EQ(a * static_cast(1), a); + EXPECT_EQ(a * static_cast(-1), -a); + } +} + +// --------------------------------------------- +// Zeroization tests +// --------------------------------------------- +class ZeroizationTest : public ::testing::Test {}; + +TEST_F(ZeroizationTest, U64Array) { + constexpr std::size_t N = 16; + u64 *arr = new u64[N]; + for (Size i = 0; i < N; ++i) + arr[i] = UINT64_MAX; // all bits set +#if defined(DEB_SECURE_ZERO_LIBSODIUM) + sodium_memzero(arr, N * sizeof(u64)); + for (Size i = 0; i < N; ++i) { + EXPECT_EQ(arr[i], static_cast(0)); + arr[i] = UINT64_MAX; + } +#elif defined(DEB_SECURE_ZERO_OPENSSL) + OPENSSL_cleanse(arr, N * sizeof(u64)); + for (Size i = 0; i < N; ++i) { + EXPECT_EQ(arr[i], static_cast(0)); + arr[i] = UINT64_MAX; + } +#elif defined(DEB_SECURE_ZERO_NATIVE) +#if defined(DEB_HAVE_SECURE_ZERO_MEMORY) + SecureZeroMemory(arr, N * sizeof(u64)); +#elif defined(DEB_HAVE_EXPLICIT_BZERO) + explicit_bzero(arr, N * sizeof(u64)); +#elif defined(DEB_HAVE_MEMSET_S) + memset_s(arr, N * sizeof(u64), 0, N * sizeof(u64)); +#else + // volatile byte loop — best-effort against compiler optimisation + volatile u64 *p = static_cast(arr); + for (Size i = 0; i < N; ++i) + p[i] = 0; +#endif + for (Size i = 0; i < N; ++i) + EXPECT_EQ(arr[i], static_cast(0)); +#endif + delete[] arr; +} + +TEST_F(ZeroizationTest, NullptrIsNoOp) { + EXPECT_NO_FATAL_FAILURE(deb_secure_zero(nullptr, 0)); + EXPECT_NO_FATAL_FAILURE(deb_secure_zero(nullptr, 16)); +} + +TEST_F(ZeroizationTest, ZeroLengthIsNoOp) { + u64 val = ~static_cast(0); + deb_secure_zero(&val, 0); + EXPECT_NE(val, static_cast(0)); // must NOT have been cleared +} + +TEST_F(ZeroizationTest, PartialOverwrite) { + // Zeroing only the first sizeof(u64) bytes should affect exactly those + // bytes + u64 val[2] = {UINT64_C(0xAAAAAAAAAAAAAAAA), UINT64_C(0xBBBBBBBBBBBBBBBB)}; + deb_secure_zero(val, sizeof(u64)); + + unsigned char raw[sizeof(u64)]; + std::memcpy(raw, val, sizeof(u64)); + bool any_zero = false; + for (std::size_t i = 0; i < sizeof(u64); ++i) + any_zero |= (raw[i] == 0); +#if defined(DEB_SECURE_ZERO_LIBSODIUM) || defined(DEB_SECURE_ZERO_OPENSSL) || \ + defined(DEB_SECURE_ZERO_NATIVE) + EXPECT_TRUE(any_zero); +#endif +} diff --git a/test/Serialize-test.cpp b/test/Serialize-test.cpp index 39bde28..3f1ae1d 100644 --- a/test/Serialize-test.cpp +++ b/test/Serialize-test.cpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,39 +24,18 @@ using namespace deb; class Serialize : public DebTestBase { public: - template - void compareArray(const T *arr1, const T *arr2, const Size size) { - for (Size i = 0; i < size; ++i) { - ASSERT_EQ(arr1[i], arr2[i]); - } - } - - void comparePoly(const PolyUnit &poly1, const PolyUnit &poly2) { - ASSERT_EQ(poly1.prime(), poly2.prime()); - ASSERT_EQ(poly1.degree(), poly2.degree()); - ASSERT_EQ(poly1.isNTT(), poly2.isNTT()); - compareArray(poly1.data(), poly2.data(), poly1.degree()); - } - - void compareBigPoly(const Polynomial &bigpoly1, - const Polynomial &bigpoly2) { - ASSERT_EQ(bigpoly1.size(), bigpoly2.size()); - for (Size i = 0; i < bigpoly1.size(); ++i) { - comparePoly(bigpoly1[i], bigpoly2[i]); - } - } - void compareCipher(const Ciphertext &cipher1, const Ciphertext &cipher2) { + ASSERT_EQ(cipher1.preset(), cipher2.preset()); ASSERT_EQ(cipher1.numPoly(), cipher2.numPoly()); ASSERT_EQ(cipher1.encoding(), cipher2.encoding()); for (Size i = 0; i < cipher1.numPoly(); ++i) { - compareBigPoly(cipher1[i], cipher2[i]); + comparePoly(cipher1[i], cipher2[i]); } } }; TEST_P(Serialize, MessageSerializationTest) { - Message msg = gen_random_message()[0]; + Message msg = gen_random_message()[0]; std::ostringstream os; serializeToStream(msg, os); std::istringstream is(os.str()); @@ -70,8 +49,20 @@ TEST_P(Serialize, MessageSerializationTest) { } } +TEST_P(Serialize, FMessageSerializationTest) { + FMessage msg = gen_random_message()[0]; + std::ostringstream os; + serializeToStream(msg, os); + std::istringstream is(os.str()); + FMessage deserialized_msg(0); + deserializeFromStream(is, deserialized_msg); + + EXPECT_EQ(msg.size(), deserialized_msg.size()); + compareArray(msg.data(), deserialized_msg.data(), msg.size()); +} + TEST_P(Serialize, CoeffSerializationTest) { - CoeffMessage coeff = gen_random_coeff()[0]; + CoeffMessage coeff = gen_random_coeff()[0]; std::ostringstream os; serializeToStream(coeff, os); std::istringstream is(os.str()); @@ -82,8 +73,20 @@ TEST_P(Serialize, CoeffSerializationTest) { compareArray(coeff.data(), deserialized_coeff.data(), coeff.size()); } -TEST_P(Serialize, PolySerializationTest) { - const auto prime = context->get_primes()[0]; +TEST_P(Serialize, FCoeffSerializationTest) { + FCoeffMessage coeff = gen_random_coeff()[0]; + std::ostringstream os; + serializeToStream(coeff, os); + std::istringstream is(os.str()); + FCoeffMessage deserialized_coeff(0); + deserializeFromStream(is, deserialized_coeff); + + EXPECT_EQ(coeff.size(), deserialized_coeff.size()); + compareArray(coeff.data(), deserialized_coeff.data(), coeff.size()); +} + +TEST_P(Serialize, PolyUnitSerializationTest) { + const auto prime = get_primes(preset)[0]; PolyUnit poly(prime, degree); for (Size i = 0; i < degree; ++i) { poly[i] = static_cast(dist(gen) * static_cast(prime)); @@ -95,12 +98,12 @@ TEST_P(Serialize, PolySerializationTest) { PolyUnit deserialized_poly(prime, 0); deserializeFromStream(is, deserialized_poly); - comparePoly(poly, deserialized_poly); + comparePolyUnit(poly, deserialized_poly); } -TEST_P(Serialize, BigPolySerializationTest) { - Polynomial bigpoly(context); - const auto *const primes = context->get_primes(); +TEST_P(Serialize, PolySerializationTest) { + Polynomial bigpoly(preset); + const auto *const primes = get_primes(preset); for (Size i = 0; i < bigpoly.size(); ++i) { for (Size j = 0; j < degree; ++j) { bigpoly[i][j] = @@ -111,15 +114,15 @@ TEST_P(Serialize, BigPolySerializationTest) { std::ostringstream os; serializeToStream(bigpoly, os); std::istringstream is(os.str()); - Polynomial deserialized_bigpoly(context, static_cast(0)); + Polynomial deserialized_bigpoly(preset, static_cast(0)); deserializeFromStream(is, deserialized_bigpoly, preset); - compareBigPoly(bigpoly, deserialized_bigpoly); + comparePoly(bigpoly, deserialized_bigpoly); } TEST_P(Serialize, CipherSerializationTest) { - Ciphertext ctxt(context, context->get_encryption_level(), - context->get_num_secret()); + Ciphertext ctxt(preset, get_encryption_level(preset), + get_num_secret(preset)); for (Size i = 0; i < ctxt.numPoly(); ++i) { for (Size j = 0; j < ctxt[i].size(); ++j) { for (Size k = 0; k < degree; ++k) { @@ -133,7 +136,7 @@ TEST_P(Serialize, CipherSerializationTest) { std::ostringstream os; serializeToStream(ctxt, os); std::istringstream is(os.str()); - Ciphertext deserialized_ctxt(context, 0, 1); + Ciphertext deserialized_ctxt(preset, 0, 1); deserializeFromStream(is, deserialized_ctxt); compareCipher(ctxt, deserialized_ctxt); @@ -151,13 +154,13 @@ TEST_P(Serialize, SecretKeySerializationTest) { EXPECT_EQ(sk.numPoly(), deserialized_sk.numPoly()); compareArray(sk.coeffs(), deserialized_sk.coeffs(), sk.coeffsSize()); for (Size i = 0; i < sk.numPoly(); ++i) { - compareBigPoly(sk[i], deserialized_sk[i]); + comparePoly(sk[i], deserialized_sk[i]); } } TEST_P(Serialize, SwkSerializationTest) { const SwitchKeyKind kind = SWK_ROT; - SwitchKey swk(context, kind, dist_u64(gen) % ((degree >> 1) - 1) + 1); + SwitchKey swk(preset, kind, dist_u64(gen) % ((degree >> 1) - 1) + 1); for (Size i = 0; i < swk.axSize(); ++i) { for (Size j = 0; j < swk.ax(i).size(); ++j) { @@ -171,7 +174,7 @@ TEST_P(Serialize, SwkSerializationTest) { std::ostringstream os; serializeToStream(swk, os); std::istringstream is(os.str()); - SwitchKey deserialized_swk(context, kind); + SwitchKey deserialized_swk(preset, kind); deserializeFromStream(is, deserialized_swk); EXPECT_EQ(swk.preset(), deserialized_swk.preset()); @@ -181,21 +184,21 @@ TEST_P(Serialize, SwkSerializationTest) { EXPECT_EQ(swk.axSize(), deserialized_swk.axSize()); EXPECT_EQ(swk.bxSize(), deserialized_swk.bxSize()); for (Size i = 0; i < swk.axSize(); ++i) { - compareBigPoly(swk.ax(i), deserialized_swk.ax(i)); + comparePoly(swk.ax(i), deserialized_swk.ax(i)); } for (Size i = 0; i < swk.bxSize(); ++i) { - compareBigPoly(swk.bx(i), deserialized_swk.bx(i)); + comparePoly(swk.bx(i), deserialized_swk.bx(i)); } } TEST_P(Serialize, EndecryptionSerializationTest) { - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); msg = scale_message(msg, 0); - MSGS decrypted_msg = gen_empty_message(); + MSGS decrypted_msg = gen_empty_message(); SecretKey sk = SecretKeyGenerator::GenSecretKey(preset); - Ciphertext ctxt(context); + Ciphertext ctxt(preset); encryptor.encrypt(msg, sk, ctxt); std::ostringstream os; @@ -203,12 +206,22 @@ TEST_P(Serialize, EndecryptionSerializationTest) { serializeToStream(sk, os); std::istringstream is(os.str()); - Ciphertext deserialized_ctxt(context); + Ciphertext deserialized_ctxt(preset); deserializeFromStream(is, deserialized_ctxt); SecretKey deserialized_sk(preset); deserializeFromStream(is, deserialized_sk); + decryptor.decrypt(ctxt, sk, msg); decryptor.decrypt(deserialized_ctxt, deserialized_sk, decrypted_msg); + compareCipher(ctxt, deserialized_ctxt); + EXPECT_EQ(sk.preset(), deserialized_sk.preset()); + EXPECT_EQ(sk.numPoly(), deserialized_sk.numPoly()); + compareArray(sk.getSeed().data(), deserialized_sk.getSeed().data(), + sk.getSeed().size()); + compareArray(sk.coeffs(), deserialized_sk.coeffs(), sk.coeffsSize()); + for (Size i = 0; i < sk.numPoly(); ++i) { + comparePoly(sk[i], deserialized_sk[i]); + } compare_msg(msg, decrypted_msg, scale_error(sk_err, 0)); } @@ -219,14 +232,14 @@ TEST_P(Serialize, EndecryptionWithEncKeySerializationTest) { std::ostringstream os; serializeToStream(swk, os); std::istringstream is(os.str()); - SwitchKey deserialized_swk(context, SWK_ENC); + SwitchKey deserialized_swk(preset, SWK_ENC); deserializeFromStream(is, deserialized_swk); - MSGS msg = gen_random_message(); + MSGS msg = gen_random_message(); msg = scale_message(msg, 0); - MSGS decrypted_msg = gen_empty_message(); + MSGS decrypted_msg = gen_empty_message(); - Ciphertext ctxt(context); + Ciphertext ctxt(preset); encryptor.encrypt(msg, deserialized_swk, ctxt); decryptor.decrypt(ctxt, sk, decrypted_msg); compare_msg(msg, decrypted_msg, scale_error(enc_err, 0)); @@ -258,7 +271,7 @@ TEST_P(Serialize, MinimalSecretKeySerializationTest) { completeSecretKey(sk); completeSecretKey(deserialized_sk); for (Size i = 0; i < sk.numPoly(); ++i) { - compareBigPoly(sk[i], deserialized_sk[i]); + comparePoly(sk[i], deserialized_sk[i]); } } diff --git a/test/TestBase.hpp b/test/TestBase.hpp index b2bf629..11a97a4 100644 --- a/test/TestBase.hpp +++ b/test/TestBase.hpp @@ -1,5 +1,5 @@ /* - * Copyright 2025 CryptoLab, Inc. + * Copyright 2026 CryptoLab, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,10 @@ #pragma once #include "CKKSTypes.hpp" -#include "Context.hpp" #include "Decryptor.hpp" #include "Encryptor.hpp" #include "KeyGenerator.hpp" +#include "Preset.hpp" #include "SecretKeyGenerator.hpp" #include "SeedGenerator.hpp" @@ -33,23 +33,24 @@ using namespace deb; #if defined(DEB_RESOURCE_CHECK) && defined(NDEBUG) -#define DEB_ASSERT(statement) ASSERT_THROW(statement, std::runtime_error) -#define DEB_EXPECT(statement) EXPECT_THROW(statement, std::runtime_error) +#define DEB_TEST_ASSERT(statement) ASSERT_THROW(statement, std::runtime_error) +#define DEB_TEST_EXPECT(statement) EXPECT_THROW(statement, std::runtime_error) #else -#define DEB_ASSERT(statement) ASSERT_DEATH(statement, ".*") -#define DEB_EXPECT(statement) EXPECT_DEATH(statement, ".*") +#define DEB_TEST_ASSERT(statement) ASSERT_DEATH(statement, ".*") +#define DEB_TEST_EXPECT(statement) EXPECT_DEATH(statement, ".*") #endif using MSGS = std::vector; +using FMSGS = std::vector; using COEFFS = std::vector; +using FCOEFFS = std::vector; class DebTestBase : public ::testing::TestWithParam { public: const Preset preset{GetParam()}; - Context context{getContext(preset)}; - const Size num_slots{context->get_num_slots()}; - const Size degree{context->get_degree()}; - const Size num_secret{context->get_num_secret()}; + const Size num_slots{get_num_slots(preset)}; + const Size degree{get_degree(preset)}; + const Size num_secret{get_num_secret(preset)}; Encryptor encryptor{preset}; Decryptor decryptor{preset}; @@ -62,21 +63,23 @@ class DebTestBase : public ::testing::TestWithParam { // 50 bit prime -> sk_err = 2^-26.3, enc_err = 2^-13.3 // 40 bit prime -> sk_err = 2^-24.6, enc_err = 2^-11.6 const double log_error = - static_cast(utils::bitWidth(context->get_primes()[0])) / 6.0; + static_cast(utils::bitWidth(get_primes(preset)[0])) / 6.0; const double sk_err = std::pow(2.0, -18 - log_error); const double enc_err = std::pow(2.0, -5 - log_error); + const double sk_err_f = std::pow(2.0, -18); + const double enc_err_f = std::pow(2.0, -10); void SetUp() override { // Initialize any necessary resources or state before each test } void TearDown() override { // Clean up any resources or state after each tests } - MSGS scale_message(MSGS &msg, uint32_t level) { - const double scale = context->get_scale_factors()[level]; + template T scale_message(T &msg, uint32_t level) { + const double scale = get_scale_factors(preset)[level]; if (scale == 0.0) { const double scale = - std::pow(2.0, utils::bitWidth(context->get_primes()[0]) - 4); - MSGS scale_msg = gen_empty_message(); + std::pow(2.0, utils::bitWidth(get_primes(preset)[0]) - 4); + T scale_msg = gen_empty_message(); for (Size i = 0; i < num_secret; ++i) { for (Size j = 0; j < num_slots; ++j) { scale_msg[i][j].real(msg[i][j].real() * scale); @@ -88,11 +91,11 @@ class DebTestBase : public ::testing::TestWithParam { return msg; } COEFFS scale_coeff(COEFFS &coeffs, uint32_t level) { - const double scale = context->get_scale_factors()[level]; + const double scale = get_scale_factors(preset)[level]; if (scale == 0.0) { const double scale = - std::pow(2.0, utils::bitWidth(context->get_primes()[0]) - 4); - COEFFS scale_coeffs = gen_empty_coeff(); + std::pow(2.0, utils::bitWidth(get_primes(preset)[0]) - 4); + COEFFS scale_coeffs = gen_empty_coeff(); for (Size i = 0; i < num_secret; ++i) { for (Size j = 0; j < degree; ++j) { scale_coeffs[i][j] = coeffs[i][j] * scale; @@ -103,53 +106,69 @@ class DebTestBase : public ::testing::TestWithParam { return coeffs; } double scale_error(double err, uint32_t level) { - const double scale = context->get_scale_factors()[level]; + const double scale = get_scale_factors(preset)[level]; if (scale == 0.0) { const double scale = - std::pow(2.0, utils::bitWidth(context->get_primes()[0]) - 4); + std::pow(2.0, utils::bitWidth(get_primes(preset)[0]) - 4); return err * scale; } return err; } - MSGS gen_empty_message() { - MSGS msg; + template T gen_empty_message() { + T msg; for (Size i = 0; i < num_secret; ++i) { msg.emplace_back(num_slots); } return msg; } - MSGS gen_random_message() { - MSGS msg; + template T gen_random_message() { + T msg; for (Size i = 0; i < num_secret; ++i) { - Message m(num_slots); - for (Size j = 0; j < num_slots; ++j) { - m[j].real(dist(gen)); - m[j].imag(dist(gen)); + if constexpr (std::is_same_v) { + FMessage m(num_slots); + for (Size j = 0; j < num_slots; ++j) { + m[j].real(static_cast(dist(gen))); + m[j].imag(static_cast(dist(gen))); + } + msg.emplace_back(std::move(m)); + } else if constexpr (std::is_same_v) { + Message m(num_slots); + for (Size j = 0; j < num_slots; ++j) { + m[j].real(dist(gen)); + m[j].imag(dist(gen)); + } + msg.emplace_back(std::move(m)); } - msg.emplace_back(std::move(m)); } return msg; } - COEFFS gen_empty_coeff() { - COEFFS coeffs; + template T gen_empty_coeff() { + T coeffs; for (Size i = 0; i < num_secret; ++i) { coeffs.emplace_back(degree); } return coeffs; } - COEFFS gen_random_coeff() { - COEFFS coeffs; + template T gen_random_coeff() { + T coeffs; for (Size i = 0; i < num_secret; ++i) { - CoeffMessage coeff(degree); - for (Size j = 0; j < coeff.size(); ++j) { - coeff[j] = dist(gen); + if constexpr (std::is_same_v) { + FCoeffMessage coeff(degree); + for (Size j = 0; j < coeff.size(); ++j) { + coeff[j] = static_cast(dist(gen)); + } + coeffs.emplace_back(std::move(coeff)); + } else if constexpr (std::is_same_v) { + CoeffMessage coeff(degree); + for (Size j = 0; j < coeff.size(); ++j) { + coeff[j] = dist(gen); + } + coeffs.emplace_back(std::move(coeff)); } - coeffs.emplace_back(std::move(coeff)); } return coeffs; } - - void compare_msg(MSGS &msg1, MSGS &msg2, double tol) const { + template void compare_msg(T &msg1, T &msg2, double tol) const { for (Size i = 0; i < num_secret; ++i) { for (Size j = 0; j < num_slots; ++j) { ASSERT_NEAR(msg1[i][j].real(), msg2[i][j].real(), tol); @@ -157,11 +176,32 @@ class DebTestBase : public ::testing::TestWithParam { } } } - void compare_coeff(COEFFS &coeff1, COEFFS &coeff2, double tol) const { + template + void compare_coeff(T &coeff1, T &coeff2, double tol) const { for (Size i = 0; i < num_secret; ++i) { for (Size j = 0; j < degree; ++j) { ASSERT_NEAR(coeff1[i][j], coeff2[i][j], tol); } } } + template + void compareArray(const T *arr1, const T *arr2, const Size size) { + for (Size i = 0; i < size; ++i) { + ASSERT_EQ(arr1[i], arr2[i]); + } + } + + void comparePolyUnit(const PolyUnit &poly1, const PolyUnit &poly2) { + ASSERT_EQ(poly1.prime(), poly2.prime()); + ASSERT_EQ(poly1.degree(), poly2.degree()); + ASSERT_EQ(poly1.isNTT(), poly2.isNTT()); + compareArray(poly1.data(), poly2.data(), poly1.degree()); + } + + void comparePoly(const Polynomial &bigpoly1, const Polynomial &bigpoly2) { + ASSERT_EQ(bigpoly1.size(), bigpoly2.size()); + for (Size i = 0; i < bigpoly1.size(); ++i) { + comparePolyUnit(bigpoly1[i], bigpoly2[i]); + } + } };