diff --git a/poly_commit/src/batching.rs b/poly_commit/src/batching.rs index c259aa836..8e39081c2 100644 --- a/poly_commit/src/batching.rs +++ b/poly_commit/src/batching.rs @@ -80,7 +80,7 @@ where for (tilde_g, tilde_eq) in tilde_gs.iter().zip(tilde_eqs.into_iter()) { sumcheck_poly.add_pair(tilde_g.clone(), tilde_eq); } - let proof = SumCheck::::prove(&sumcheck_poly, transcript); + let proof = SumCheck::::prove(sumcheck_poly, transcript); timer.stop(); let a2 = proof.export_point_to_expander(); diff --git a/sumcheck/src/sumcheck_generic.rs b/sumcheck/src/sumcheck_generic.rs index e2f8e43e9..832725670 100644 --- a/sumcheck/src/sumcheck_generic.rs +++ b/sumcheck/src/sumcheck_generic.rs @@ -93,13 +93,14 @@ impl SumCheck { /// Generate proof of the sum of polynomial over {0,1}^`num_vars` /// /// The polynomial is represented in the form of a VirtualPolynomial. + /// Takes ownership of poly_list to avoid cloning internally. pub fn prove( - poly_list: &SumOfProductsPoly, + poly_list: SumOfProductsPoly, transcript: &mut impl Transcript, ) -> IOPProof { let num_vars = poly_list.num_vars(); - let mut prover_state = IOPProverState::prover_init(poly_list); + let mut prover_state = IOPProverState::prover_init_owned(poly_list); let mut challenge = None; let mut prover_msgs = Vec::with_capacity(num_vars); for _ in 0..num_vars { diff --git a/sumcheck/src/sumcheck_generic/prover.rs b/sumcheck/src/sumcheck_generic/prover.rs index e19d62f57..19a547a62 100644 --- a/sumcheck/src/sumcheck_generic/prover.rs +++ b/sumcheck/src/sumcheck_generic/prover.rs @@ -31,6 +31,32 @@ impl IOPProverState { } } + /// Initialize the prover state by taking ownership of the polynomials, + /// avoiding the clone in `prover_init`. + pub fn prover_init_owned(polynomials: SumOfProductsPoly) -> Self { + let num_vars = polynomials.num_vars(); + let init_sum_of_vals: Vec = polynomials + .f_and_g_pairs + .par_iter() + .map(|(f, g)| { + f.coeffs + .iter() + .zip(g.coeffs.iter()) + .map(|(&f, &g)| f * g) + .sum::() + }) + .collect(); + let eq_prefix = vec![F::one(); polynomials.f_and_g_pairs.len()]; + Self { + challenges: Vec::with_capacity(num_vars), + round: 0, + init_num_vars: num_vars, + mle_list: polynomials, + init_sum_of_vals, + eq_prefix, + } + } + /// Receive message from verifier, generate prover message, and proceed to /// next round. /// diff --git a/sumcheck/src/sumcheck_generic/tests.rs b/sumcheck/src/sumcheck_generic/tests.rs index 8a1de31d7..fba3ef8b3 100644 --- a/sumcheck/src/sumcheck_generic/tests.rs +++ b/sumcheck/src/sumcheck_generic/tests.rs @@ -89,7 +89,7 @@ fn test_sumcheck_e2e() { // prover let mut transcript = BytesHashTranscript::::new(); - let proof = SumCheck::::prove(&mle_list, &mut transcript); + let proof = SumCheck::::prove(mle_list.clone(), &mut transcript); // verifier let mut transcript = BytesHashTranscript::::new(); @@ -119,7 +119,7 @@ fn test_sumcheck_generic_padding_helper() { }; let claimed_sum = mle_list.sum(); - let proof = SumCheck::prove(&mle_list, &mut T::new()); + let proof = SumCheck::prove(mle_list.clone(), &mut T::new()); let padded_mle_list = SumOfProductsPoly { f_and_g_pairs: mle_list @@ -135,7 +135,7 @@ fn test_sumcheck_generic_padding_helper() { .collect(), }; - let proof_with_padded_mle_list = SumCheck::prove(&padded_mle_list, &mut T::new()); + let proof_with_padded_mle_list = SumCheck::prove(padded_mle_list, &mut T::new()); assert_eq!(proof, proof_with_padded_mle_list);