Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gkr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ default = []
grinding = [ ]
recursion = [ "transcript/recursion" ]
profile = [ "utils/profile", "sumcheck/profile" ]
low-memory = [ "poly_commit/low-memory" ]

[[bench]]
name = "gkr-hashes"
Expand Down
1 change: 1 addition & 0 deletions poly_commit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ harness = false
default = [ ]
profile = [ "utils/profile" ]
cuda_msm = [ "msm_cuda" ]
low-memory = [ ]
60 changes: 43 additions & 17 deletions poly_commit/src/batching.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,16 @@ use utils::timer::Timer;
/// - the new point for evaluation
/// - the new polynomial that is merged via sumcheck
/// - the proof of the sumcheck
///
/// When `low_memory` is true, tilde_gs are moved (not cloned) into sumcheck and
/// recomputed afterwards for g_prime construction. This reduces peak memory at
/// the cost of redundant computation.
#[allow(clippy::type_complexity)]
pub fn prover_merge_points<C>(
polys: &[impl MultilinearExtension<C::Scalar>],
points: &[impl AsRef<[C::Scalar]>],
transcript: &mut impl Transcript,
low_memory: bool,
) -> (
Vec<C::Scalar>,
MultiLinearPoly<C::Scalar>,
Expand All @@ -46,20 +51,7 @@ where
// \tilde g_i(b) = eq(t, i) * f_i(b)
let timer = Timer::new("Building tilde g_i(b)", true);

let tilde_gs = polys
.par_iter()
.enumerate()
.map(|(index, f_i)| {
let mut tilde_g_eval = vec![C::Scalar::zero(); 1 << f_i.num_vars()];
for (j, &f_i_eval) in f_i.hypercube_basis_ref().iter().enumerate() {
tilde_g_eval[j] = f_i_eval * eq_t_i[index];
}

MultiLinearPoly {
coeffs: tilde_g_eval,
}
})
.collect::<Vec<_>>();
let tilde_gs = build_tilde_gs(polys, &eq_t_i);
timer.stop();

// built the virtual polynomial for SumCheck
Expand All @@ -77,8 +69,18 @@ where

let timer = Timer::new("Sumcheck merging points", true);
let mut sumcheck_poly = SumOfProductsPoly::new();
for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) {
sumcheck_poly.add_pair(tilde_g.clone(), tilde_eq);
// Use Option to let the compiler track ownership across the two branches.
let mut tilde_gs = Some(tilde_gs);
if low_memory {
// Move tilde_gs into sumcheck_poly to avoid holding two copies simultaneously
for (tilde_g, tilde_eq) in tilde_gs.take().unwrap().into_iter().zip(tilde_eqs.into_iter())
{
sumcheck_poly.add_pair(tilde_g, tilde_eq);
}
} else {
for (tilde_g, tilde_eq) in tilde_gs.as_ref().unwrap().iter().zip(tilde_eqs.into_iter()) {
sumcheck_poly.add_pair(tilde_g.clone(), tilde_eq);
}
}
let proof = SumCheck::<C::Scalar>::prove(sumcheck_poly, transcript);
timer.stop();
Expand All @@ -99,7 +101,11 @@ where
})
.collect::<Vec<_>>();

for (tilde_g, eq_i_a2) in tilde_gs.iter().zip(eq_i_a2_polys.iter()) {
// In low_memory mode, tilde_gs were moved and consumed.
// Recompute them for g_prime construction.
let tilde_gs_for_gprime = tilde_gs.unwrap_or_else(|| build_tilde_gs(polys, &eq_t_i));

for (tilde_g, eq_i_a2) in tilde_gs_for_gprime.iter().zip(eq_i_a2_polys.iter()) {
for (j, &tilde_g_eval) in tilde_g.coeffs.iter().enumerate() {
g_prime_evals[j] += tilde_g_eval * eq_i_a2;
}
Expand All @@ -111,6 +117,26 @@ where
(a2, g_prime, proof)
}

/// Build tilde_gs: \tilde g_i(b) = eq(t, i) * f_i(b)
fn build_tilde_gs<S: ExtensionField + PrimeField>(
polys: &[impl MultilinearExtension<S>],
eq_t_i: &[S],
) -> Vec<MultiLinearPoly<S>> {
polys
.par_iter()
.enumerate()
.map(|(index, f_i)| {
let mut tilde_g_eval = vec![S::zero(); 1 << f_i.num_vars()];
for (j, &f_i_eval) in f_i.hypercube_basis_ref().iter().enumerate() {
tilde_g_eval[j] = f_i_eval * eq_t_i[index];
}
MultiLinearPoly {
coeffs: tilde_g_eval,
}
})
.collect()
}

pub fn verifier_merge_points<C>(
commitments: &[impl AsRef<[C]>],
points: &[impl AsRef<[C::Scalar]>],
Expand Down
2 changes: 1 addition & 1 deletion poly_commit/src/hyrax/hyrax_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ where
eval_timer.stop();

let merger_timer = Timer::new("merging points", true);
let (new_point, g_prime, proof) = prover_merge_points::<C>(polys, points, transcript);
let (new_point, g_prime, proof) = prover_merge_points::<C>(polys, points, transcript, cfg!(feature = "low-memory"));
merger_timer.stop();

// open g'(X) at point (a2)
Expand Down
2 changes: 1 addition & 1 deletion poly_commit/src/kzg/uni_kzg/hyper_kzg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ where

let merger_timer = Timer::new("merging points", true);
let (new_point, g_prime, proof) =
prover_merge_points::<E::G1Affine>(polys, &points, transcript);
prover_merge_points::<E::G1Affine>(polys, &points, transcript, cfg!(feature = "low-memory"));
merger_timer.stop();

let pcs_timer = Timer::new("kzg_open", true);
Expand Down
Loading