diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index ce0cd4d2..42be9f11 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -27,7 +27,7 @@ pub use { hash_config::HashConfig, mavros::{MavrosProver, MavrosSchemeData}, noir_proof_scheme::{NoirProof, NoirProofScheme, NoirSchemeData}, - prefix_covector::{OffsetCovector, PrefixCovector}, + prefix_covector::{OffsetCovector, PrefixCovector, SparseCovector}, prover::{NoirProver, Prover}, r1cs::R1CS, transcript_sponge::TranscriptSponge, diff --git a/provekit/common/src/prefix_covector.rs b/provekit/common/src/prefix_covector.rs index 2b1ed678..1e332893 100644 --- a/provekit/common/src/prefix_covector.rs +++ b/provekit/common/src/prefix_covector.rs @@ -238,6 +238,109 @@ pub fn compute_public_eval( eval } +/// A covector with non-zero weights at arbitrary scattered positions within a +/// `domain_size`-length domain. Used for challenge binding where challenge +/// positions in w2 may not be contiguous. +pub struct SparseCovector { + /// (position, weight) pairs. + entries: Vec<(usize, FieldElement)>, + domain_size: usize, +} + +impl SparseCovector { + /// Create a new `SparseCovector` from position-weight pairs. + /// + /// # Panics + /// + /// Asserts that `domain_size` is a power of two and all positions are + /// within bounds. + #[must_use] + pub fn new(entries: Vec<(usize, FieldElement)>, domain_size: usize) -> Self { + debug_assert!(domain_size.is_power_of_two()); + for &(pos, _) in &entries { + assert!( + pos < domain_size, + "SparseCovector: position {pos} >= domain_size {domain_size}" + ); + } + Self { + entries, + domain_size, + } + } +} + +impl LinearForm for SparseCovector { + fn size(&self) -> usize { + self.domain_size + } + + fn mle_evaluate(&self, point: &[FieldElement]) -> FieldElement { + let n = point.len(); + let mut result = FieldElement::zero(); + for &(idx, w) in &self.entries { + if w.is_zero() { + continue; + } + let mut basis = FieldElement::one(); + for (k, pk) in point.iter().enumerate() { + if (idx >> (n - 1 - k)) & 1 == 1 { + basis *= pk; + } else { + basis *= FieldElement::one() - pk; + } + } + result += w * basis; + } + result + } + + fn accumulate(&self, accumulator: &mut [FieldElement], scalar: FieldElement) { + for &(pos, w) in &self.entries { + accumulator[pos] += scalar * w; + } + } +} + +/// Create a challenge-binding weight [`SparseCovector`] from Fiat-Shamir +/// randomness `x`. +/// +/// Places `[1, x, x², …]` at the given `challenge_offsets` positions within a +/// `2^m`-length domain. +#[must_use] +pub fn make_challenge_weight( + x: FieldElement, + challenge_offsets: &[usize], + m: usize, +) -> SparseCovector { + let domain_size = 1 << m; + let mut entries = Vec::with_capacity(challenge_offsets.len()); + let mut x_pow = FieldElement::one(); + for &offset in challenge_offsets { + entries.push((offset, x_pow)); + x_pow *= x; + } + SparseCovector::new(entries, domain_size) +} + +/// Compute the challenge weight evaluation +/// `⟨[1, x, x², …], poly[offsets[0]], poly[offsets[1]], …⟩` without +/// allocating a [`SparseCovector`]. +#[must_use] +pub fn compute_challenge_eval( + x: FieldElement, + challenge_offsets: &[usize], + polynomial: &[FieldElement], +) -> FieldElement { + let mut eval = FieldElement::zero(); + let mut x_pow = FieldElement::one(); + for &offset in challenge_offsets { + eval += x_pow * polynomial[offset]; + x_pow *= x; + } + eval +} + #[cfg(test)] mod tests { use {super::*, whir::algebra::multilinear_extend}; @@ -433,4 +536,83 @@ mod tests { // offset + weights.len() = 7 + 2 = 9 > 8 let _ = OffsetCovector::new(vec![fe(1), fe(2)], 7, 8); } + + fn sparse_full_vector( + entries: &[(usize, FieldElement)], + domain_size: usize, + ) -> Vec { + let mut v = vec![FieldElement::zero(); domain_size]; + for &(pos, w) in entries { + v[pos] = w; + } + v + } + + #[test] + fn sparse_mle_evaluate_matches_full_vector() { + let domain_size = 16; + let entries = vec![(2, fe(7)), (5, fe(3)), (11, fe(13))]; + let point = vec![fe(2), fe(5), fe(13), fe(17)]; + + let covector = SparseCovector::new(entries.clone(), domain_size); + let full = sparse_full_vector(&entries, domain_size); + + let expected = multilinear_extend(&full, &point); + let actual = covector.mle_evaluate(&point); + assert_eq!(actual, expected); + } + + #[test] + fn sparse_accumulate_writes_correct_positions() { + let domain_size = 16; + let entries = vec![(2, fe(7)), (5, fe(3)), (11, fe(13))]; + let scalar = fe(4); + + let covector = SparseCovector::new(entries.clone(), domain_size); + let mut accumulator = vec![FieldElement::zero(); domain_size]; + covector.accumulate(&mut accumulator, scalar); + + let expected = sparse_full_vector(&entries, domain_size); + for i in 0..domain_size { + assert_eq!(accumulator[i], scalar * expected[i], "mismatch at {i}"); + } + } + + #[test] + fn sparse_mle_and_accumulate_are_consistent() { + let domain_size = 8; + let entries = vec![(1, fe(5)), (3, fe(11)), (6, fe(7))]; + + let covector = SparseCovector::new(entries.clone(), domain_size); + + let mut full_weights = vec![FieldElement::zero(); domain_size]; + covector.accumulate(&mut full_weights, FieldElement::one()); + assert_eq!(full_weights, sparse_full_vector(&entries, domain_size)); + + let point = vec![fe(3), fe(7), fe(11)]; + let mle_from_full = multilinear_extend(&full_weights, &point); + let mle_from_covector = covector.mle_evaluate(&point); + assert_eq!(mle_from_full, mle_from_covector); + } + + #[test] + #[should_panic(expected = "position 8 >= domain_size 8")] + fn sparse_panics_on_out_of_bounds() { + let _ = SparseCovector::new(vec![(8, fe(1))], 8); + } + + #[test] + fn compute_challenge_eval_matches_weight() { + let offsets = vec![1, 5, 11]; + let x = fe(7); + + let mut poly = vec![FieldElement::zero(); 16]; + poly[1] = fe(42); + poly[5] = fe(99); + poly[11] = fe(17); + + let eval = compute_challenge_eval(x, &offsets, &poly); + let expected = fe(42) + fe(7) * fe(99) + fe(49) * fe(17); + assert_eq!(eval, expected); + } } diff --git a/provekit/common/src/whir_r1cs.rs b/provekit/common/src/whir_r1cs.rs index e36713a5..fa1bab70 100644 --- a/provekit/common/src/whir_r1cs.rs +++ b/provekit/common/src/whir_r1cs.rs @@ -25,6 +25,7 @@ pub struct WhirR1CSScheme { pub m_0: usize, pub a_num_terms: usize, pub num_challenges: usize, + pub challenge_offsets: Vec, pub has_public_inputs: bool, pub whir_witness: WhirZkConfig, } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index ee1af6f2..f43e1a51 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -512,7 +512,15 @@ impl WitnessBuilder { r1cs: R1CS, witness_map: Vec>, acir_public_inputs_indices_set: HashSet, - ) -> Result<(SplitWitnessBuilders, R1CS, Vec>, usize), SplitError> { + ) -> Result< + ( + SplitWitnessBuilders, + R1CS, + Vec>, + Vec, + ), + SplitError, + > { if witness_builders.is_empty() { return Ok(( SplitWitnessBuilders { @@ -522,7 +530,7 @@ impl WitnessBuilder { }, r1cs, witness_map, - 0, + Vec::new(), )); } @@ -575,12 +583,15 @@ impl WitnessBuilder { scheduler.build_layers() }; - let num_challenges = w2_layers + let challenge_offsets: Vec = w2_layers .layers .iter() .flat_map(|layer| &layer.witness_builders) - .filter(|b| matches!(b, WitnessBuilder::Challenge(_))) - .count(); + .filter_map(|b| match b { + WitnessBuilder::Challenge(idx) => Some(*idx - w1_size), + _ => None, + }) + .collect(); Ok(( SplitWitnessBuilders { @@ -590,7 +601,7 @@ impl WitnessBuilder { }, remapped_r1cs, remapped_witness_map, - num_challenges, + challenge_offsets, )) } } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index a2d55f84..2dec8d4a 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -4,8 +4,9 @@ use { ark_std::{One, Zero}, provekit_common::{ prefix_covector::{ - build_prefix_covectors, compute_alpha_evals, compute_public_eval, expand_powers, - make_public_weight, OffsetCovector, + build_prefix_covectors, compute_alpha_evals, compute_challenge_eval, + compute_public_eval, expand_powers, make_challenge_weight, make_public_weight, + OffsetCovector, }, utils::{ pad_to_power_of_two, @@ -334,6 +335,16 @@ fn prove_from_alphas( None }; + // Challenge binding: prove that w2 contains the correct Fiat-Shamir + // challenge values at the expected positions. + let challenge_eval = if !scheme.challenge_offsets.is_empty() { + let ce = compute_challenge_eval(x, &scheme.challenge_offsets, &c2.polynomial); + merlin.prover_message(&ce); + Some(ce) + } else { + None + }; + let WhirR1CSCommitment { witness: w1, polynomial: p1, @@ -375,12 +386,19 @@ fn prove_from_alphas( } = c2; { let weights = build_prefix_covectors(scheme.m, alphas_2); - let evaluations: Vec = evals_2; + let mut evaluations: Vec = evals_2; - let boxed_weights: Vec>> = weights + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); + + if let Some(ce) = challenge_eval { + let cw = make_challenge_weight(x, &scheme.challenge_offsets, scheme.m); + evaluations.push(ce); + boxed_weights.push(Box::new(cw)); + } + let _ = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p2.as_slice())], diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 09add645..3de70372 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -73,13 +73,14 @@ impl NoirCompiler { main.public_inputs().indices().iter().cloned().collect(); let has_public_inputs = !acir_public_inputs_indices_set.is_empty(); - let (split_witness_builders, remapped_r1cs, remapped_witness_map, num_challenges) = + let (split_witness_builders, remapped_r1cs, remapped_witness_map, challenge_offsets) = WitnessBuilder::split_and_prepare_layers( &witness_builders, r1cs, witness_map, acir_public_inputs_indices_set, )?; + let num_challenges = challenge_offsets.len(); info!( "Witness split: w1 size = {}, w2 size = {}", split_witness_builders.w1_size, @@ -96,6 +97,7 @@ impl NoirCompiler { &remapped_r1cs, split_witness_builders.w1_size, num_challenges, + challenge_offsets, has_public_inputs, hash_config.engine_id(), ); @@ -158,10 +160,15 @@ impl MavrosCompiler { } } + let challenges_size = mavros_r1cs.witness_layout.challenges_size; + // In Mavros, challenges occupy the first `challenges_size` positions of + // w2 (immediately after the pre-commitment boundary). + let challenge_offsets: Vec = (0..challenges_size).collect(); let whir_for_witness = WhirR1CSScheme::new_from_mavros_r1cs( &mavros_r1cs, mavros_r1cs.witness_layout.pre_commitment_size(), - mavros_r1cs.witness_layout.challenges_size, + challenges_size, + challenge_offsets, num_public_inputs > 0, hash_config.engine_id(), ); diff --git a/provekit/r1cs-compiler/src/whir_r1cs.rs b/provekit/r1cs-compiler/src/whir_r1cs.rs index 0b18b19f..c36ce27e 100644 --- a/provekit/r1cs-compiler/src/whir_r1cs.rs +++ b/provekit/r1cs-compiler/src/whir_r1cs.rs @@ -12,6 +12,7 @@ pub trait WhirR1CSSchemeBuilder { r1cs: &R1CS, w1_size: usize, num_challenges: usize, + challenge_offsets: Vec, has_public_inputs: bool, hash_id: EngineId, ) -> Self; @@ -20,6 +21,7 @@ pub trait WhirR1CSSchemeBuilder { r1cs: &MavrosR1CS, w1_size: usize, num_challenges: usize, + challenge_offsets: Vec, has_public_inputs: bool, hash_id: EngineId, ) -> Self; @@ -30,6 +32,7 @@ pub trait WhirR1CSSchemeBuilder { a_num_entries: usize, w1_size: usize, num_challenges: usize, + challenge_offsets: Vec, has_public_inputs: bool, hash_id: EngineId, ) -> Self; @@ -46,9 +49,16 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { r1cs: &R1CS, w1_size: usize, num_challenges: usize, + challenge_offsets: Vec, has_public_inputs: bool, hash_id: EngineId, ) -> Self { + assert_eq!( + num_challenges, + challenge_offsets.len(), + "num_challenges ({num_challenges}) != challenge_offsets.len() ({})", + challenge_offsets.len() + ); let total_witnesses = r1cs.num_witnesses(); assert!( w1_size <= total_witnesses, @@ -74,6 +84,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { m_0, a_num_terms: next_power_of_two(r1cs.a().iter().count()), num_challenges, + challenge_offsets, whir_witness: Self::new_whir_zk_config_for_size(m_raw, 1, hash_id), has_public_inputs, } @@ -108,6 +119,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { r1cs: &MavrosR1CS, w1_size: usize, num_challenges: usize, + challenge_offsets: Vec, has_public_inputs: bool, hash_id: EngineId, ) -> Self { @@ -121,6 +133,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { a_num_entries, w1_size, num_challenges, + challenge_offsets, has_public_inputs, hash_id, ) @@ -132,9 +145,16 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { a_num_entries: usize, w1_size: usize, num_challenges: usize, + challenge_offsets: Vec, has_public_inputs: bool, hash_id: EngineId, ) -> Self { + debug_assert_eq!( + num_challenges, + challenge_offsets.len(), + "num_challenges ({num_challenges}) != challenge_offsets.len() ({})", + challenge_offsets.len() + ); let m_raw = next_power_of_two(num_witnesses); let m0_raw = next_power_of_two(num_constraints); @@ -153,6 +173,7 @@ impl WhirR1CSSchemeBuilder for WhirR1CSScheme { whir_witness: Self::new_whir_zk_config_for_size(m, 1, hash_id), w1_size, num_challenges, + challenge_offsets, has_public_inputs, } } diff --git a/provekit/verifier/src/whir_r1cs.rs b/provekit/verifier/src/whir_r1cs.rs index 84bd0c27..1ca0e4bc 100644 --- a/provekit/verifier/src/whir_r1cs.rs +++ b/provekit/verifier/src/whir_r1cs.rs @@ -3,7 +3,8 @@ use { ark_std::{One, Zero}, provekit_common::{ prefix_covector::{ - build_prefix_covectors, expand_powers, make_public_weight, OffsetCovector, + build_prefix_covectors, expand_powers, make_challenge_weight, make_public_weight, + OffsetCovector, }, utils::sumcheck::{ calculate_eq, eval_cubic_poly, multiply_transposed_by_eq_alpha, transpose_r1cs_matrices, @@ -59,16 +60,15 @@ impl WhirR1CSVerifier for WhirR1CSScheme { .receive_commitments(&mut arthur, 1) .map_err(|_| anyhow::anyhow!("Failed to parse commitment 1"))?; - let commitment_2 = if self.num_challenges > 0 { - let _logup_challenges: Vec = - arthur.verifier_message_vec(self.num_challenges); - Some( - self.whir_witness - .receive_commitments(&mut arthur, 1) - .map_err(|_| anyhow::anyhow!("Failed to parse commitment 2"))?, - ) + let (commitment_2, logup_challenges) = if self.num_challenges > 0 { + let challenges: Vec = arthur.verifier_message_vec(self.num_challenges); + let commitment = self + .whir_witness + .receive_commitments(&mut arthur, 1) + .map_err(|_| anyhow::anyhow!("Failed to parse commitment 2"))?; + (Some(commitment), Some(challenges)) } else { - None + (None, None) }; let (transposed, sumcheck_result) = rayon::join( @@ -145,7 +145,21 @@ impl WhirR1CSVerifier for WhirR1CSScheme { evals_1.to_vec() }; evaluations_1.push(blinding_eval); - let evaluations_2 = evals_2.to_vec(); + let mut evaluations_2 = evals_2.to_vec(); + + // Challenge binding: verify that w2 contains the correct + // Fiat-Shamir challenge values at the expected positions. + let challenge_covector = if let Some(ref challenges) = logup_challenges { + let challenge_eval: FieldElement = arthur + .prover_message() + .map_err(|_| anyhow::anyhow!("Failed to read challenge_eval"))?; + verify_challenge_binding(challenge_eval, x, challenges)?; + let cw = make_challenge_weight(x, &self.challenge_offsets, self.m); + evaluations_2.push(challenge_eval); + Some(cw) + } else { + None + }; let mut weight_refs_1: Vec<&dyn LinearForm> = weights_1 .iter() @@ -157,10 +171,13 @@ impl WhirR1CSVerifier for WhirR1CSScheme { .verify(&mut arthur, &weight_refs_1, &evaluations_1, &commitment_1) .map_err(|_| anyhow::anyhow!("WHIR verification failed for c1"))?; - let weight_refs_2: Vec<&dyn LinearForm> = weights_2 + let mut weight_refs_2: Vec<&dyn LinearForm> = weights_2 .iter() .map(|w| w as &dyn LinearForm) .collect(); + if let Some(ref cw) = challenge_covector { + weight_refs_2.push(cw as &dyn LinearForm); + } self.whir_witness .verify(&mut arthur, &weight_refs_2, &evaluations_2, &commitment_2) .map_err(|_| anyhow::anyhow!("WHIR verification failed for c2"))?; @@ -276,6 +293,27 @@ pub fn run_sumcheck_verifier( }) } +/// Verify that the prover's claimed challenge evaluation matches the +/// Fiat-Shamir challenges sampled by the verifier. This binds the committed +/// w2 polynomial to the transcript-derived challenge values. +fn verify_challenge_binding( + challenge_eval: FieldElement, + x: FieldElement, + challenges: &[FieldElement], +) -> Result<()> { + let mut expected = FieldElement::zero(); + let mut x_pow = FieldElement::one(); + for &ch in challenges { + expected += x_pow * ch; + x_pow *= x; + } + ensure!( + challenge_eval == expected, + "Challenge binding check failed: expected {expected:?}, got {challenge_eval:?}" + ); + Ok(()) +} + /// Verify that the prover's claimed public evaluation matches the known public /// inputs. The weight covers positions `[0, 1, ..., N]` where position 0 is the /// R1CS constant `1` and positions `1..=N` are the public inputs.