diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 6f395570f..dfef844e4 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -41,6 +41,7 @@ default = [] grinding = [ ] recursion = [ "transcript/recursion" ] profile = [ "utils/profile", "sumcheck/profile" ] +low-memory = [ "poly_commit/low-memory" ] [[bench]] name = "gkr-hashes" diff --git a/poly_commit/Cargo.toml b/poly_commit/Cargo.toml index 413e22bfe..cccea1129 100644 --- a/poly_commit/Cargo.toml +++ b/poly_commit/Cargo.toml @@ -51,3 +51,4 @@ harness = false default = [ ] profile = [ "utils/profile" ] cuda_msm = [ "msm_cuda" ] +low-memory = [ ] diff --git a/poly_commit/src/batching.rs b/poly_commit/src/batching.rs index 8e39081c2..77183ac0c 100644 --- a/poly_commit/src/batching.rs +++ b/poly_commit/src/batching.rs @@ -18,11 +18,16 @@ use utils::timer::Timer; /// - the new point for evaluation /// - the new polynomial that is merged via sumcheck /// - the proof of the sumcheck +/// +/// When `low_memory` is true, tilde_gs are moved (not cloned) into sumcheck and +/// recomputed afterwards for g_prime construction. This reduces peak memory at +/// the cost of redundant computation. #[allow(clippy::type_complexity)] pub fn prover_merge_points( polys: &[impl MultilinearExtension], points: &[impl AsRef<[C::Scalar]>], transcript: &mut impl Transcript, + low_memory: bool, ) -> ( Vec, MultiLinearPoly, @@ -46,20 +51,7 @@ where // \tilde g_i(b) = eq(t, i) * f_i(b) let timer = Timer::new("Building tilde g_i(b)", true); - let tilde_gs = polys - .par_iter() - .enumerate() - .map(|(index, f_i)| { - let mut tilde_g_eval = vec![C::Scalar::zero(); 1 << f_i.num_vars()]; - for (j, &f_i_eval) in f_i.hypercube_basis_ref().iter().enumerate() { - tilde_g_eval[j] = f_i_eval * eq_t_i[index]; - } - - MultiLinearPoly { - coeffs: tilde_g_eval, - } - }) - .collect::>(); + let tilde_gs = build_tilde_gs(polys, &eq_t_i); timer.stop(); // built the virtual polynomial for SumCheck @@ -77,8 +69,18 @@ where let timer = Timer::new("Sumcheck merging points", true); let mut sumcheck_poly = SumOfProductsPoly::new(); - for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) { - sumcheck_poly.add_pair(tilde_g.clone(), tilde_eq); + // Use Option to let the compiler track ownership across the two branches. + let mut tilde_gs = Some(tilde_gs); + if low_memory { + // Move tilde_gs into sumcheck_poly to avoid holding two copies simultaneously + for (tilde_g, tilde_eq) in tilde_gs.take().unwrap().into_iter().zip(tilde_eqs.into_iter()) + { + sumcheck_poly.add_pair(tilde_g, tilde_eq); + } + } else { + for (tilde_g, tilde_eq) in tilde_gs.as_ref().unwrap().iter().zip(tilde_eqs.into_iter()) { + sumcheck_poly.add_pair(tilde_g.clone(), tilde_eq); + } } let proof = SumCheck::::prove(sumcheck_poly, transcript); timer.stop(); @@ -99,7 +101,11 @@ where }) .collect::>(); - for (tilde_g, eq_i_a2) in tilde_gs.iter().zip(eq_i_a2_polys.iter()) { + // In low_memory mode, tilde_gs were moved and consumed. + // Recompute them for g_prime construction. + let tilde_gs_for_gprime = tilde_gs.unwrap_or_else(|| build_tilde_gs(polys, &eq_t_i)); + + for (tilde_g, eq_i_a2) in tilde_gs_for_gprime.iter().zip(eq_i_a2_polys.iter()) { for (j, &tilde_g_eval) in tilde_g.coeffs.iter().enumerate() { g_prime_evals[j] += tilde_g_eval * eq_i_a2; } @@ -111,6 +117,26 @@ where (a2, g_prime, proof) } +/// Build tilde_gs: \tilde g_i(b) = eq(t, i) * f_i(b) +fn build_tilde_gs( + polys: &[impl MultilinearExtension], + eq_t_i: &[S], +) -> Vec> { + polys + .par_iter() + .enumerate() + .map(|(index, f_i)| { + let mut tilde_g_eval = vec![S::zero(); 1 << f_i.num_vars()]; + for (j, &f_i_eval) in f_i.hypercube_basis_ref().iter().enumerate() { + tilde_g_eval[j] = f_i_eval * eq_t_i[index]; + } + MultiLinearPoly { + coeffs: tilde_g_eval, + } + }) + .collect() +} + pub fn verifier_merge_points( commitments: &[impl AsRef<[C]>], points: &[impl AsRef<[C::Scalar]>], diff --git a/poly_commit/src/hyrax/hyrax_impl.rs b/poly_commit/src/hyrax/hyrax_impl.rs index f17fa4cba..9ab7b6105 100644 --- a/poly_commit/src/hyrax/hyrax_impl.rs +++ b/poly_commit/src/hyrax/hyrax_impl.rs @@ -358,7 +358,7 @@ where eval_timer.stop(); let merger_timer = Timer::new("merging points", true); - let (new_point, g_prime, proof) = prover_merge_points::(polys, points, transcript); + let (new_point, g_prime, proof) = prover_merge_points::(polys, points, transcript, cfg!(feature = "low-memory")); merger_timer.stop(); // open g'(X) at point (a2) diff --git a/poly_commit/src/kzg/uni_kzg/hyper_kzg.rs b/poly_commit/src/kzg/uni_kzg/hyper_kzg.rs index a3da52d0f..ffd900095 100644 --- a/poly_commit/src/kzg/uni_kzg/hyper_kzg.rs +++ b/poly_commit/src/kzg/uni_kzg/hyper_kzg.rs @@ -289,7 +289,7 @@ where let merger_timer = Timer::new("merging points", true); let (new_point, g_prime, proof) = - prover_merge_points::(polys, &points, transcript); + prover_merge_points::(polys, &points, transcript, cfg!(feature = "low-memory")); merger_timer.stop(); let pcs_timer = Timer::new("kzg_open", true);