diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index a3376c2b7..b918dd7e1 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -7,9 +7,13 @@ //! 3. Express pivot as linear combination of remaining variables //! 4. Substitute into all other constraints //! 5. Remove eliminated constraints +//! 6. Remove dead witness columns and prune unreachable witness builders use { - crate::{FieldElement, InternedFieldElement, SparseMatrix, R1CS}, + crate::{ + witness::{DependencyInfo, WitnessBuilder}, + FieldElement, InternedFieldElement, SparseMatrix, R1CS, + }, ark_ff::Field, ark_std::{One, Zero}, std::collections::{HashMap, HashSet}, @@ -30,8 +34,12 @@ struct Substitution { pub struct OptimizationStats { pub constraints_before: usize, pub constraints_after: usize, + pub witnesses_before: usize, + pub witnesses_after: usize, pub eliminated: usize, - pub eliminated_columns: HashSet, + pub builders_removed: usize, + /// Virtual witnesses: computation-only, excluded from WHIR commitment. + pub num_virtual: usize, } impl OptimizationStats { @@ -42,6 +50,14 @@ impl OptimizationStats { (self.constraints_before - self.constraints_after) as f64 / self.constraints_before as f64 * 100.0 } + + pub fn witness_reduction_percent(&self) -> f64 { + if self.witnesses_before == 0 { + return 0.0; + } + (self.witnesses_before as f64 - self.witnesses_after as f64) / self.witnesses_before as f64 + * 100.0 + } } /// Run the Gaussian elimination optimization on an R1CS instance. @@ -50,13 +66,24 @@ impl OptimizationStats { /// picks pivots, substitutes into remaining constraints, and removes the /// eliminated rows. /// -/// Column 0 (constant one) is never chosen as a pivot. -pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { +/// `num_public_inputs` columns (1..=num_public_inputs) and column 0 (constant +/// one) are never chosen as pivots. 
+pub fn optimize_r1cs( + r1cs: &mut R1CS, + witness_builders: &mut Vec, + witness_map: &mut [Option], +) -> OptimizationStats { let constraints_before = r1cs.num_constraints(); + let witnesses_before = r1cs.num_witnesses(); - // Column 0 is the constant-one column and must not be eliminated. + // Columns that must not be eliminated: + // - Column 0: constant one + // - Columns 1..=num_public_inputs: public inputs let mut forbidden: HashSet = HashSet::new(); forbidden.insert(0); + for i in 1..=r1cs.num_public_inputs { + forbidden.insert(i); + } // Phase 1: Identify all linear constraints let mut linear_rows: Vec = Vec::new(); @@ -86,7 +113,7 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { let mut sub_map_phase2: HashMap = HashMap::new(); for &row in &linear_rows { - // Extract the combined linear expression (const * A/B - C) for this constraint + // Extract the linear expression from C[row]: sum of (coeff * w_i) = 0 let expr = r1cs.extract_linear_expression(row); if expr.is_empty() { continue; @@ -157,10 +184,8 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { .map(|(col, val)| (val, col)) .collect(); - // Approximate decrement: occurrence_counts was built from raw A+B+C - // entries, but we decrement once per column in the combined expression. - // A column appearing in multiple matrices is undercounted here. Only - // affects pivot-selection heuristic quality, not correctness. 
+ // Decrement occurrence counts for all columns in this row (they're being + // removed) for (_, col) in &expr { if occurrence_counts[*col] > 0 { occurrence_counts[*col] -= 1; @@ -183,8 +208,11 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { return OptimizationStats { constraints_before, constraints_after: constraints_before, + witnesses_before, + witnesses_after: witnesses_before, eliminated: 0, - eliminated_columns: HashSet::new(), + builders_removed: 0, + num_virtual: 0, }; } @@ -261,20 +289,30 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { } // Phase 4: Remove eliminated constraint rows + eliminated_rows.sort(); r1cs.remove_constraints(&eliminated_rows); - // Note: We do NOT modify witness builders. The witnesses are still - // computed by their original builders. GE only removes redundant - // constraints and substitutes pivots into remaining constraints. - let constraints_after = r1cs.num_constraints(); let eliminated = substitutions.len(); + info!( + "Phase 3 done: {} constraints remaining after substitution", + constraints_after + ); + + // Phase 5: Remove dead witness columns and prune unreachable builders + info!("Phase 5: starting dead column removal + virtual witness assignment"); + let col_stats = remove_dead_columns(r1cs, witness_builders, witness_map); + r1cs.num_virtual = col_stats.num_virtual; + let stats = OptimizationStats { constraints_before, constraints_after, + witnesses_before, + witnesses_after: col_stats.witnesses_after, eliminated, - eliminated_columns: eliminated_cols, + builders_removed: col_stats.builders_removed, + num_virtual: col_stats.num_virtual, }; info!( @@ -284,6 +322,13 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { stats.constraint_reduction_percent(), eliminated ); + info!( + "Column removal: {} -> {} witnesses ({:.1}% reduction), {} builders pruned", + witnesses_before, + stats.witnesses_after, + stats.witness_reduction_percent(), + stats.builders_removed + ); stats } @@ 
-301,6 +346,382 @@ fn build_occurrence_counts(r1cs: &R1CS) -> Vec { counts } +/// Counts collected during dead-column removal, returned alongside the updated +/// witness count so callers can surface them in diagnostics. +struct ColumnRemovalStats { + witnesses_after: usize, + builders_removed: usize, + /// Virtual columns: computation-only, excluded from R1CS/WHIR but needed + /// by builders. + num_virtual: usize, +} + +/// Phase 5: Remove dead witness columns from matrices and prune unreachable +/// witness builders. +/// +/// After GE, some columns have zero occurrences across all remaining +/// constraints in A, B, and C. These are "dead" columns. This function: +/// +/// 1. Identifies dead columns (zero occurrences in all three matrices) +/// 2. Builds a dependency graph of witness builders +/// 3. Finds which builders are transitively reachable from "live" columns +/// (columns still referenced by constraints) +/// 4. Prunes unreachable builders (Phase B+C cascading) +/// 5. Remaps matrix column indices to close gaps +/// 6. Remaps remaining builder witness indices +fn remove_dead_columns( + r1cs: &mut R1CS, + witness_builders: &mut Vec, + witness_map: &mut [Option], +) -> ColumnRemovalStats { + let num_cols = r1cs.num_witnesses(); + if num_cols == 0 || witness_builders.is_empty() { + return ColumnRemovalStats { + witnesses_after: num_cols, + builders_removed: 0, + num_virtual: 0, + }; + } + + // Step 1: Find dead columns (zero occurrence across A, B, C) + // Also collect columns referenced by the ACIR witness map — these are + // entry points for witness data and must stay alive. 
+ let occurrence_counts = build_occurrence_counts(r1cs); + let mut acir_referenced: HashSet = HashSet::new(); + for entry in witness_map.iter() { + if let Some(nz) = entry { + acir_referenced.insert(nz.get() as usize); + } + } + + let mut dead_cols: HashSet = HashSet::new(); + for col in 0..num_cols { + // Never remove column 0 (constant one) or public input columns + if col == 0 || col <= r1cs.num_public_inputs { + continue; + } + // Never remove columns referenced by the ACIR witness map + if acir_referenced.contains(&col) { + continue; + } + if occurrence_counts[col] == 0 { + dead_cols.insert(col); + } + } + + if dead_cols.is_empty() { + return ColumnRemovalStats { + witnesses_after: num_cols, + builders_removed: 0, + num_virtual: 0, + }; + } + + // Diagnostic: count how many zero-occurrence cols are blocked by each mechanism + let zero_occ_total = (0..num_cols) + .filter(|&c| c > r1cs.num_public_inputs && occurrence_counts[c] == 0) + .count(); + let blocked_by_acir = (0..num_cols) + .filter(|&c| { + c > r1cs.num_public_inputs && occurrence_counts[c] == 0 && acir_referenced.contains(&c) + }) + .count(); + info!( + "Column removal: {} zero-occurrence cols (excl public), {} blocked by ACIR witness map, \ + {} truly dead", + zero_occ_total, + blocked_by_acir, + dead_cols.len() + ); + + info!( + "Column removal: {} dead columns found out of {} total", + dead_cols.len(), + num_cols + ); + + // Step 2: Build witness builder dependency graph and find reachable builders. + // A builder is "live" if any of its output columns is NOT dead, OR if a + // live builder reads any of its output columns. + // We use the existing DependencyInfo infrastructure. 
+ let live_cols: HashSet = (0..num_cols).filter(|c| !dead_cols.contains(c)).collect(); + + // Map: witness column -> builder index + let mut col_to_builder: HashMap = HashMap::new(); + for (builder_idx, builder) in witness_builders.iter().enumerate() { + for col in DependencyInfo::extract_writes(builder) { + col_to_builder.insert(col, builder_idx); + } + } + + // Build adjacency: which builders does each builder depend on? + // (reverse of the normal dependency graph: we want "builder X reads from + // builder Y") + let mut builder_reads_from: Vec> = vec![HashSet::new(); witness_builders.len()]; + for (builder_idx, builder) in witness_builders.iter().enumerate() { + let reads = DependencyInfo::extract_reads(builder); + for read_col in reads { + if let Some(&producer_idx) = col_to_builder.get(&read_col) { + if producer_idx != builder_idx { + builder_reads_from[builder_idx].insert(producer_idx); + } + } + } + } + + // Step 3: BFS/DFS from live builders to find all transitively reachable + // builders. A builder is live if any of its output columns is live + // (referenced by constraints). 
+ let mut live_builders: HashSet = HashSet::new(); + let mut queue: Vec = Vec::new(); + + for (builder_idx, builder) in witness_builders.iter().enumerate() { + let writes = DependencyInfo::extract_writes(builder); + let is_directly_live = writes.iter().any(|c| live_cols.contains(c)); + if is_directly_live { + if live_builders.insert(builder_idx) { + queue.push(builder_idx); + } + } + } + + // BFS: if a live builder reads from another builder, that builder is also live + while let Some(builder_idx) = queue.pop() { + for &dep_idx in &builder_reads_from[builder_idx] { + if live_builders.insert(dep_idx) { + queue.push(dep_idx); + } + } + } + + info!( + "Column removal: {} total builders, {} live (directly or transitively), {} dead", + witness_builders.len(), + live_builders.len(), + witness_builders.len() - live_builders.len() + ); + + let blocked_by_bfs = dead_cols + .iter() + .filter(|&&col| { + col_to_builder + .get(&col) + .map_or(false, |&b| live_builders.contains(&b)) + }) + .count(); + + // Detailed diagnostic breakdowns (reader types, producer types, + // hypothetical analyses) are disabled for performance — they are + // O(dead_cols × graph_size) and blow up on large circuits. Enable + // selectively when debugging a specific circuit. + // + // See git history for the full diagnostic block. + + info!( + "Column removal: of {} dead cols, {} blocked by live builder BFS, {} removable", + dead_cols.len(), + blocked_by_bfs, + dead_cols.len() - blocked_by_bfs + ); + + // Step 4: Dead columns are removable from R1CS matrices — BUT we must + // protect multi-output builders whose output ranges must stay contiguous. + // If ANY column in a multi-output builder's range is live (non-dead), + // ALL columns in that range must stay real to preserve the contiguous + // output_start + num_witnesses layout. + // Protect contiguous-range multi-output builders. 
Builders with + // individually-addressed outputs don't need protection: + // - U32Addition/Multi, BytePartition: independent index fields + // - ChunkDecompose, SpreadBitExtract: output_indices Vec + // - DigitalDecomposition: output_indices Vec + let mut protected_cols: HashSet = HashSet::new(); + for builder in witness_builders.iter() { + let writes = DependencyInfo::extract_writes(builder); + if writes.len() <= 1 { + continue; + } + if matches!( + builder, + WitnessBuilder::U32Addition(..) + | WitnessBuilder::U32AdditionMulti(..) + | WitnessBuilder::BytePartition { .. } + | WitnessBuilder::ChunkDecompose { .. } + | WitnessBuilder::SpreadBitExtract { .. } + | WitnessBuilder::DigitalDecomposition(..) + ) { + continue; + } + let has_live = writes.iter().any(|c| !dead_cols.contains(c)); + if has_live && writes.iter().any(|c| dead_cols.contains(c)) { + for &c in &writes { + protected_cols.insert(c); + } + } + } + let removable_cols: HashSet = dead_cols + .iter() + .filter(|c| !protected_cols.contains(c)) + .copied() + .collect(); + let protected_count = dead_cols.len() - removable_cols.len(); + if protected_count > 0 { + info!( + "Column removal: {protected_count} dead cols protected (contiguous-range multi-output \ + builders with mixed live/dead outputs)" + ); + } + + if removable_cols.is_empty() { + return ColumnRemovalStats { + witnesses_after: num_cols, + builders_removed: 0, + num_virtual: 0, + }; + } + + // Partition dead cols: dead producers (fully removable) vs live producers + // (virtual). A column must also be virtual if ANY live builder reads it + // (even if its producer is dead) — this can happen after builder rewriting + // changes dependency chains. 
+ let live_read_cols: HashSet = { + let mut s = HashSet::new(); + for (bi, b) in witness_builders.iter().enumerate() { + if live_builders.contains(&bi) { + for c in DependencyInfo::extract_reads(b) { + if removable_cols.contains(&c) { + s.insert(c); + } + } + } + } + s + }; + let mut fully_dead_cols: HashSet = HashSet::new(); + let mut virtual_cols: HashSet = HashSet::new(); + for &col in &removable_cols { + let producer_is_live = col_to_builder + .get(&col) + .map_or(false, |&b| live_builders.contains(&b)); + if producer_is_live || live_read_cols.contains(&col) { + virtual_cols.insert(col); + } else { + fully_dead_cols.insert(col); + } + } + + info!( + "Column removal: {} dead cols total, {} fully dead (producer dead), {} virtual (producer \ + live, computation-only)", + removable_cols.len(), + fully_dead_cols.len(), + virtual_cols.len() + ); + + // Step 5: Build remap table with two regions: + // Real columns → [0, num_real) (for R1CS matrices + builders) + // Virtual columns → [num_real, num_real+num_virtual) (for builders only) + // Fully dead columns → None (no mapping, builders pruned) + let mut remap: Vec> = vec![None; num_cols]; + let mut next_real = 0usize; + // First pass: assign real column indices + for col in 0..num_cols { + if !removable_cols.contains(&col) { + remap[col] = Some(next_real); + next_real += 1; + } + } + let num_real = next_real; + // Second pass: assign virtual column indices (after real) + let mut next_virtual = num_real; + for col in 0..num_cols { + if virtual_cols.contains(&col) { + remap[col] = Some(next_virtual); + next_virtual += 1; + } + } + let num_virtual = next_virtual - num_real; + + // Step 6: Remap R1CS matrices — only uses [0, num_real) columns. + // Virtual columns had zero entries so remove_columns drops them cleanly. 
+ let matrix_remap: Vec> = (0..num_cols) + .map(|col| { + if removable_cols.contains(&col) { + None // Remove from matrices (both virtual and fully dead) + } else { + remap[col] // Real column → compact index + } + }) + .collect(); + r1cs.a = r1cs.a.remove_columns(&matrix_remap); + r1cs.b = r1cs.b.remove_columns(&matrix_remap); + r1cs.c = r1cs.c.remove_columns(&matrix_remap); + + // Step 6b: Remap ACIR witness map (ACIR index -> R1CS column) + for entry in witness_map.iter_mut() { + if let Some(nz) = entry { + let old_col = nz.get() as usize; + let new_col = remap[old_col].unwrap_or_else(|| { + panic!( + "ACIR witness map references removed column {} (should be live)", + old_col + ) + }); + *nz = std::num::NonZeroU32::new(new_col as u32).unwrap_or_else(|| { + panic!( + "ACIR witness col {} remapped to 0 (constant-one column)", + old_col + ) + }); + } + } + + // Step 7: Prune dead builders and remap surviving ones. + // A builder must be kept if it's live OR if it produces any virtual column + // (needed for computation even though its outputs are zero in A/B/C). + let mut keep_builders = live_builders.clone(); + for &col in &virtual_cols { + if let Some(&producer_idx) = col_to_builder.get(&col) { + keep_builders.insert(producer_idx); + } + } + let builders_before = witness_builders.len(); + // Build the remapper ONCE (not per-builder) to avoid repeated HashMap + // construction from the 1M+ entry remap table. 
+ let remapper = { + use crate::witness::WitnessIndexRemapper; + let old_to_new: HashMap = remap + .iter() + .enumerate() + .filter_map(|(old, new)| new.map(|n| (old, n))) + .collect(); + WitnessIndexRemapper::from_map(old_to_new) + }; + let mut new_builders: Vec = Vec::with_capacity(keep_builders.len()); + for (idx, builder) in witness_builders.drain(..).enumerate() { + if keep_builders.contains(&idx) { + new_builders.push(remapper.remap_builder(&builder)); + } + } + *witness_builders = new_builders; + let builders_removed = builders_before - witness_builders.len(); + + info!( + "Column removal: {} -> {} real + {} virtual witnesses ({} total for solving), {} builders \ + pruned", + num_cols, + num_real, + num_virtual, + num_real + num_virtual, + builders_removed + ); + + ColumnRemovalStats { + witnesses_after: num_real, + builders_removed, + num_virtual, + } +} + /// Apply all relevant substitutions to a single row of a matrix. /// /// Since Phase 2b resolves backward chains (later pivots referenced by @@ -355,67 +776,7 @@ fn apply_substitutions_to_row( #[cfg(test)] mod tests { - use {super::*, ark_std::One}; - - /// Evaluate `matrix · witness` for each row, returning a Vec of - /// FieldElements (one per constraint). - fn matvec(r1cs: &R1CS, matrix: &SparseMatrix, witness: &[FieldElement]) -> Vec { - let hydrated = matrix.hydrate(&r1cs.interner); - (0..matrix.num_rows) - .map(|row| { - hydrated - .iter_row(row) - .map(|(col, coeff)| coeff * witness[col]) - .sum() - }) - .collect() - } - - /// Assert that no remaining constraint references an eliminated pivot - /// column. This is the chain-resolution invariant: after GE, every - /// substituted pivot must have been fully inlined into all remaining - /// constraints. 
- fn assert_no_dangling_pivots(r1cs: &R1CS, stats: &OptimizationStats) { - for row in 0..r1cs.num_constraints() { - for (col, _) in r1cs.a.iter_row(row) { - assert!( - !stats.eliminated_columns.contains(&col), - "row {row} A references eliminated pivot w{col}" - ); - } - for (col, _) in r1cs.b.iter_row(row) { - assert!( - !stats.eliminated_columns.contains(&col), - "row {row} B references eliminated pivot w{col}" - ); - } - for (col, _) in r1cs.c.iter_row(row) { - assert!( - !stats.eliminated_columns.contains(&col), - "row {row} C references eliminated pivot w{col}" - ); - } - } - } - - /// Assert that `A·w ⊙ B·w == C·w` for every row of the R1CS. - fn assert_r1cs_satisfied(r1cs: &R1CS, witness: &[FieldElement]) { - let a_vals = matvec(r1cs, &r1cs.a, witness); - let b_vals = matvec(r1cs, &r1cs.b, witness); - let c_vals = matvec(r1cs, &r1cs.c, witness); - for (row, ((a, b), c)) in a_vals - .iter() - .zip(b_vals.iter()) - .zip(c_vals.iter()) - .enumerate() - { - assert_eq!( - *a * *b, - *c, - "R1CS not satisfied at row {row}: A·w={a:?}, B·w={b:?}, C·w={c:?}" - ); - } - } + use {super::*, crate::witness::SumTerm, ark_std::One}; #[test] fn test_simple_linear_elimination() { @@ -440,9 +801,20 @@ mod tests { // Constraint 1: non-linear r1cs.add_constraint(&[(one, 1)], &[(one, 2)], &[(one, 4)]); + let mut witness_builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(None, 1), SumTerm(None, 2)]), + WitnessBuilder::Product(4, 1, 2), + ]; + assert_eq!(r1cs.num_constraints(), 2); - let stats = optimize_r1cs(&mut r1cs); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut witness_builders, &mut wmap) + }; // Constraint 0 should be eliminated (it's linear) assert_eq!(stats.constraints_after, 1); @@ -457,18 +829,26 @@ mod tests { // Two chained linear constraints where L1's expression references // L0's pivot, creating a substitution 
chain: // - // L0: 1*1 = w1 - w3 - // L1: 1*1 = w3 - w4 - // Q: w4 * w2 = w5 (non-linear, kept) + // L0: 1*1 = w1 - w3 → w3 = w1 - 1 (pivot w3) + // L1: 1*1 = w3 - w4 → w4 = w3 - 1 (pivot w4, terms ref w3) + // Q: w4 * w2 = w5 (non-linear, kept) // - // Chain resolution must inline pivots transitively so no eliminated - // column appears in remaining constraints. + // w1, w2 are public inputs (forbidden as pivots), forcing w3 and w4 + // as the only pivot candidates for L0 and L1 respectively. + // + // Without chain resolution in Phase 2, S1's terms are [(-1, w0), (1, w3)]. + // Substituting w4 in Q introduces w3 into Q's A matrix. But w3 is + // S0's eliminated pivot — its defining constraint is removed. Bug! + // + // With chain resolution, S1's terms resolve w3 → (w1 - 1), yielding + // [(-2, w0), (1, w1)]. Q becomes (w1-2)*w2 = w5. No dangling pivots. let mut r1cs = R1CS::new(); let one = FieldElement::one(); let neg = -one; - // 6 columns: w0(const), w1..w5 + // 6 columns: w0(const), w1(public), w2(public), w3, w4, w5 r1cs.add_witnesses(6); + r1cs.num_public_inputs = 2; // L0: 1*1 = w1 - w3 r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 1), (neg, 3)]); @@ -477,83 +857,79 @@ mod tests { // Q: w4 * w2 = w5 r1cs.add_constraint(&[(one, 4)], &[(one, 2)], &[(one, 5)]); - // w3 = w1 - 1, w4 = w3 - 1 = w1 - 2, w5 = w4 * w2 - let witness: Vec = [1u64, 5, 3, 4, 3, 9] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - assert_r1cs_satisfied(&r1cs, &witness); + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(Some(neg), 0), SumTerm(None, 1)]), + WitnessBuilder::Sum(4, vec![SumTerm(Some(neg), 0), SumTerm(None, 3)]), + WitnessBuilder::Product(5, 4, 2), + ]; assert_eq!(r1cs.num_constraints(), 3); - let stats = optimize_r1cs(&mut r1cs); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, 
&mut wmap) + }; + // Both linear constraints eliminated, Q remains assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); - } - - #[test] - fn test_deep_chain_elimination() { - // Chain of depth 4 with Q at the end. Verifies that chain resolution - // works transitively through arbitrarily deep substitution chains. - // - // L0: 1*1 = w1 - w3 - // L1: 1*1 = w3 - w4 - // L2: 1*1 = w4 - w5 - // L3: 1*1 = w5 - w6 - // Q: w6 * w2 = w7 (non-linear, kept) - let mut r1cs = R1CS::new(); - let one = FieldElement::one(); - let neg = -one; - // 8 columns: w0(const), w1..w7 - r1cs.add_witnesses(8); + // After Phase 4b: Sum(4) is rewritten to inline w3's substitution + // (-1·w0 + 1·w1), so it no longer reads w3. Sum(3) (producer of w3) + // has no live consumers left → dead → w3 fully removed. + // w4 is dead in constraints (GE substituted it out) but Sum(4) still + // produces it and Product(5) reads it → w4 becomes virtual. + // Expected: 6 → 4 real witnesses (w3 fully removed, w4 virtual). 
+ assert_eq!( + stats.witnesses_after, + stats.witnesses_before - 2, + "Expected 2 witnesses removed from R1CS (w3 dead, w4 virtual), got {} -> {}", + stats.witnesses_before, + stats.witnesses_after + ); - // L0..L3: chain of differences - for i in 0..4u32 { - let prev_col = if i == 0 { 1 } else { 2 + i as usize }; - let cur_col = 3 + i as usize; - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, prev_col), (neg, cur_col)]); + // Verify the remaining constraint references only valid column indices + let num_cols = r1cs.num_witnesses(); + for (col, _) in r1cs.a.iter_row(0) { + assert!(col < num_cols, "A references out-of-range col {col}"); + } + for (col, _) in r1cs.b.iter_row(0) { + assert!(col < num_cols, "B references out-of-range col {col}"); + } + for (col, _) in r1cs.c.iter_row(0) { + assert!(col < num_cols, "C references out-of-range col {col}"); } - // Q: w6 * w2 = w7 - r1cs.add_constraint(&[(one, 6)], &[(one, 2)], &[(one, 7)]); - - // w1=10, w3=9, w4=8, w5=7, w6=6, w7=w6*w2=18 - let witness: Vec = [1u64, 10, 3, 9, 8, 7, 6, 18] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - assert_r1cs_satisfied(&r1cs, &witness); - - assert_eq!(r1cs.num_constraints(), 5); - let stats = optimize_r1cs(&mut r1cs); - - assert_eq!(stats.eliminated, 4); - assert_eq!(stats.constraints_after, 1); - assert_eq!(r1cs.num_constraints(), 1); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); } #[test] fn test_backward_chain_elimination() { - // Backward chain: one substitution is built with terms referencing a - // column that a later substitution eliminates. Phase 2b resolves this - // backward reference so Phase 3's single pass works. + // Backward chain: S_0 is built FIRST with terms referencing w5, + // then S_1 eliminates w5. Phase 2b resolves this backward + // reference so Phase 3's single pass works. 
+ // + // L0: 1*1 = w1 + w5 - w3 → w3 = w1 + w5 - 1 (pivot w3, count=2) + // L1: 1*1 = w4 - w5 → w5 = w4 - 1 (pivot w5, count=2 after + // decrement) Q1: w3 * w2 = w6 + // (non-linear) Q2: w4 * w4 = w7 (extra + // w4 occurrences) Q3: w5 * w1 = w8 + // (breaks count tie: w5=3 > w3=2) // - // L0: 1*1 = w1 + w5 - w3 (linear) - // L1: 1*1 = w4 - w5 (linear) - // Q1: w3 * w2 = w6 (non-linear) - // Q2: w4 * w4 = w7 (non-linear) - // Q3: w5 * w1 = w8 (non-linear) + // w1, w2 are public (forbidden). + // Counts: w3=2, w5=3, w4=3 → L0 picks w3 (min). + // After L0 decrement: w5=2, w4=3 → L1 picks w5. + // + // After full resolution: w3 = w1 + (w4-1) - 1 = w1 + w4 - 2. + // Q1 becomes: (w1 + w4 - 2) * w2 = w6. let mut r1cs = R1CS::new(); let one = FieldElement::one(); let neg = -one; - // 9 columns: w0(const), w1..w8 + // 9 columns: w0(const), w1(pub), w2(pub), w3, w4, w5, w6, w7, w8 r1cs.add_witnesses(9); + r1cs.num_public_inputs = 2; // L0: 1*1 = w1 + w5 - w3 r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 1), (one, 5), (neg, 3)]); @@ -561,127 +937,178 @@ mod tests { r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 4), (neg, 5)]); // Q1: w3 * w2 = w6 r1cs.add_constraint(&[(one, 3)], &[(one, 2)], &[(one, 6)]); - // Q2: w4 * w4 = w7 + // Q2: w4 * w4 = w7 (extra occurrences for w4) r1cs.add_constraint(&[(one, 4)], &[(one, 4)], &[(one, 7)]); - // Q3: w5 * w1 = w8 + // Q3: w5 * w1 = w8 (extra w5 occurrence to break tie vs w3) r1cs.add_constraint(&[(one, 5)], &[(one, 1)], &[(one, 8)]); - // w5 = w4-1 = 6, w3 = w1+w5-1 = 10 - let witness: Vec = [1u64, 5, 3, 10, 7, 6, 30, 49, 30] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - assert_r1cs_satisfied(&r1cs, &witness); + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![ + SumTerm(Some(neg), 0), + SumTerm(None, 1), + SumTerm(None, 5), + ]), + WitnessBuilder::Acir(4, 2), + 
WitnessBuilder::Sum(5, vec![SumTerm(Some(neg), 0), SumTerm(None, 4)]), + WitnessBuilder::Product(6, 3, 2), + WitnessBuilder::Product(7, 4, 4), + WitnessBuilder::Product(8, 5, 1), + ]; assert_eq!(r1cs.num_constraints(), 5); - let stats = optimize_r1cs(&mut r1cs); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) + }; + // Both linear constraints eliminated, Q1, Q2, Q3 remain assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 3); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); + + // Pivot columns w3, w5 are dead in constraints (GE substituted + // them out) but their producers are still live (downstream builders + // read them) → they become virtual witnesses. + // Expected: 9 → 7 real witnesses (w3 and w5 become virtual). + assert_eq!( + stats.witnesses_after, + stats.witnesses_before - 2, + "Expected 2 virtual witnesses (w3, w5), got {} -> {}", + stats.witnesses_before, + stats.witnesses_after + ); + + // Verify all column references are in valid range + let num_cols = r1cs.num_witnesses(); + for row in 0..r1cs.num_constraints() { + for (col, _) in r1cs.a.iter_row(row) { + assert!(col < num_cols, "row {row} A out-of-range col {col}"); + } + for (col, _) in r1cs.b.iter_row(row) { + assert!(col < num_cols, "row {row} B out-of-range col {col}"); + } + for (col, _) in r1cs.c.iter_row(row) { + assert!(col < num_cols, "row {row} C out-of-range col {col}"); + } + } + } + + /// Helper: verify A·w ⊙ B·w == C·w for all constraints. 
+ fn assert_r1cs_satisfied(r1cs: &R1CS, witness: &[FieldElement]) { + let interner = &r1cs.interner; + for row in 0..r1cs.num_constraints() { + let dot = |matrix: &crate::SparseMatrix| -> FieldElement { + let mut acc = FieldElement::zero(); + for (col, interned_val) in matrix.iter_row(row) { + assert!( + col < witness.len(), + "Row {row}: column index {col} out of range (witness len {})", + witness.len() + ); + let val = interner.get(interned_val).unwrap(); + acc += val * witness[col]; + } + acc + }; + let a_dot = dot(&r1cs.a); + let b_dot = dot(&r1cs.b); + let c_dot = dot(&r1cs.c); + assert_eq!( + a_dot * b_dot, + c_dot, + "Constraint {row} not satisfied: A·w * B·w != C·w" + ); + } } #[test] fn test_arithmetic_correctness() { - // Exercises simple elimination, forward chain, and backward chain - // then checks A·w ⊙ B·w == C·w on optimized R1CS. - // - // L0: 1*1 = w1 + w5 - w3 (linear) - // L1: 1*1 = w4 - w5 (linear) - // Q1: w3 * w2 = w6 (non-linear) - // Q2: w4 * w4 = w7 (non-linear) - // Q3: w5 * w1 = w8 (non-linear) + // Verify optimized R1CS is semantically equivalent to original. 
+ // w0=1 (constant), w1=3 (public), w2=5 (public), w3=w1+w2=8, + // w4=w1*w2=15, w5=w3+w4=23 let mut r1cs = R1CS::new(); let one = FieldElement::one(); - let neg = -one; - // 9 columns: w0(const), w1..w8 - r1cs.add_witnesses(9); + r1cs.add_witnesses(6); + r1cs.num_public_inputs = 2; - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 1), (one, 5), (neg, 3)]); - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 4), (neg, 5)]); - r1cs.add_constraint(&[(one, 3)], &[(one, 2)], &[(one, 6)]); - r1cs.add_constraint(&[(one, 4)], &[(one, 4)], &[(one, 7)]); - r1cs.add_constraint(&[(one, 5)], &[(one, 1)], &[(one, 8)]); + // L0: 1*w3 = w1 + w2 → w3 = w1 + w2 (linear: B is constant) + r1cs.add_constraint(&[(one, 0)], &[(one, 3)], &[(one, 1), (one, 2)]); + // L1: 1*w5 = w3 + w4 → w5 = w3 + w4 (linear: A is constant) + r1cs.add_constraint(&[(one, 0)], &[(one, 5)], &[(one, 3), (one, 4)]); + // Q: w1 * w2 = w4 (non-linear, kept) + r1cs.add_constraint(&[(one, 1)], &[(one, 2)], &[(one, 4)]); - // w0=1, w1=5, w2=3, w4=7, w5=w4-1=6, w3=w1+w5-1=10, - // w6=w3*w2=30, w7=w4*w4=49, w8=w5*w1=30 - let witness: Vec = [1u64, 5, 3, 10, 7, 6, 30, 49, 30] + // Witness: w0=1, w1=3, w2=5, w3=8, w4=15, w5=23 + let witness_vals: Vec = [1u64, 3, 5, 8, 15, 23] .iter() - .map(|v| FieldElement::from(*v)) + .map(|&v| FieldElement::from(v)) .collect(); - assert_r1cs_satisfied(&r1cs, &witness); - - let stats = optimize_r1cs(&mut r1cs); - assert_eq!(stats.eliminated, 2); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); - } + // Verify original R1CS is satisfied + assert_r1cs_satisfied(&r1cs, &witness_vals); + + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(None, 1), SumTerm(None, 2)]), + WitnessBuilder::Product(4, 1, 2), + WitnessBuilder::Sum(5, vec![SumTerm(None, 3), SumTerm(None, 4)]), + ]; + + let stats = { + let mut 
wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) + }; - #[test] - fn test_branch_coverage() { - // Exercises all extract_linear_expression branches and optimizer - // edge cases in one system: - // - // L_a: A=[2*w0], B=[w2], C=[w3] A-only const, coeff 2 - // L_b: A=[w4,w2], B=[3*w0], C=[w5] B-only const, coeff 3 - // L_sr0: A=[w0], B=[w0], C=[2*w6,-w7] both const - // L_sr1: A=[w0], B=[w0], C=[w7,-w6] self-ref rescaling - // L_deg0: A=[w0], B=[w0], C=[w8,-w9] both const - // L_deg1: A=[w0], B=[w0], C=[w8,-w9] degenerate (1-r_p=0) - // Q: A=[w4], B=[w7], C=[w10] non-linear - // - // L_sr1 triggers self-referencing pivot rescaling: chain-resolving - // w6 reintroduces w7 (the current pivot) with r_p=1/2, requiring - // division by (1 - 1/2). - // - // L_deg1 is a duplicate of L_deg0. After L_deg0 eliminates one - // of {w8,w9}, L_deg1's chain resolution produces r_p=1 for the - // remaining variable, so (1 - r_p) = 0 → skip. - let mut r1cs = R1CS::new(); - let one = FieldElement::one(); - let neg = -one; - let two = FieldElement::from(2u64); - let three = FieldElement::from(3u64); - - r1cs.add_witnesses(11); - - // L_a: 2*1 * w2 = w3 → w3 = 2*w2 (A-only constant branch) - r1cs.add_constraint(&[(two, 0)], &[(one, 2)], &[(one, 3)]); - // L_b: (w4+w2) * 3 = w5 → w5 = 3*w4+3*w2 (B-only constant branch) - r1cs.add_constraint(&[(one, 4), (one, 2)], &[(three, 0)], &[(one, 5)]); - // L_sr0: 1 = 2*w6 - w7 - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(two, 6), (neg, 7)]); - // L_sr1: 1 = w7 - w6 - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 7), (neg, 6)]); - // L_deg0: 1 = w8 - w9 - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 8), (neg, 9)]); - // L_deg1: 1 = w8 - w9 (duplicate → degenerate skip) - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 8), (neg, 9)]); - // Q: w4 * w7 = w10 - r1cs.add_constraint(&[(one, 4)], &[(one, 7)], &[(one, 10)]); - - // Witness derived from constraints: - // w2=4, w3=2*4=8, w4=5, w5=3*(5+4)=27, - 
// w6=2, w7=3 (from 1=2*2-3, 1=3-2), - // w9=10, w8=11 (from 1=11-10), - // w10=w4*w7=15 - let witness: Vec = [1u64, 5, 4, 8, 5, 27, 2, 3, 11, 10, 15] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); + assert_eq!(stats.eliminated, 2, "Should eliminate 2 linear constraints"); + assert_eq!(r1cs.num_constraints(), 1, "Should have 1 constraint left"); - assert_r1cs_satisfied(&r1cs, &witness); + // Verify column indices are in bounds. + let num_cols = r1cs.num_witnesses(); + for row in 0..r1cs.num_constraints() { + for (col, _) in r1cs.a.iter_row(row) { + assert!(col < num_cols, "A col {col} out of range {num_cols}"); + } + for (col, _) in r1cs.b.iter_row(row) { + assert!(col < num_cols, "B col {col} out of range {num_cols}"); + } + for (col, _) in r1cs.c.iter_row(row) { + assert!(col < num_cols, "C col {col} out of range {num_cols}"); + } + } - let stats = optimize_r1cs(&mut r1cs); + // Solve all builders to produce the optimized witness, then verify + // the optimized R1CS is actually satisfied. + let num_total = r1cs.num_witnesses() + r1cs.num_virtual; + let mut opt_witness = vec![FieldElement::zero(); num_total]; + let acir_values: Vec = witness_vals[1..=2].to_vec(); + for b in &builders { + match b { + WitnessBuilder::Constant(crate::witness::ConstantTerm(idx, val)) => { + opt_witness[*idx] = *val; + } + WitnessBuilder::Acir(idx, acir_idx) => { + opt_witness[*idx] = acir_values[*acir_idx]; + } + WitnessBuilder::Sum(idx, terms) => { + let mut acc = FieldElement::zero(); + for term in terms { + let coeff = term.0.unwrap_or(FieldElement::one()); + acc += coeff * opt_witness[term.1]; + } + opt_witness[*idx] = acc; + } + WitnessBuilder::Product(idx, a, b) => { + opt_witness[*idx] = opt_witness[*a] * opt_witness[*b]; + } + _ => panic!("Unexpected builder type in test"), + } + } - // 5 eliminated: L_a(w3), L_b(w5), L_sr0(w6), L_sr1(w7), - // L_deg0(w8 or w9). L_deg1 skipped (degenerate). Q non-linear. 
- assert_eq!(stats.eliminated, 5); - assert_eq!(stats.constraints_after, 2); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); + assert_r1cs_satisfied(&r1cs, &opt_witness); } } diff --git a/provekit/common/src/r1cs.rs b/provekit/common/src/r1cs.rs index c971d9f19..fbe940a8f 100644 --- a/provekit/common/src/r1cs.rs +++ b/provekit/common/src/r1cs.rs @@ -15,6 +15,11 @@ pub struct R1CS { pub a: SparseMatrix, pub b: SparseMatrix, pub c: SparseMatrix, + /// Virtual witnesses: computation-only columns excluded from matrices/WHIR + /// commitment but needed by witness builders for intermediate calculations. + /// The prover allocates `num_witnesses() + num_virtual` for solving. + #[serde(default)] + pub num_virtual: usize, } impl Default for R1CS { @@ -32,6 +37,7 @@ impl R1CS { a: SparseMatrix::new(0, 0), b: SparseMatrix::new(0, 0), c: SparseMatrix::new(0, 0), + num_virtual: 0, } } diff --git a/provekit/common/src/sparse_matrix.rs b/provekit/common/src/sparse_matrix.rs index d9ac044b9..22695e162 100644 --- a/provekit/common/src/sparse_matrix.rs +++ b/provekit/common/src/sparse_matrix.rs @@ -570,6 +570,39 @@ impl SparseMatrix { } } + /// Remove columns at the given indices and compact remaining columns. + /// Returns a new SparseMatrix with columns remapped according to the + /// given table. `remap[old_col] = Some(new_col)` keeps the column; + /// `None` removes it (entries are dropped). 
+    pub fn remove_columns(&self, remap: &[Option<usize>]) -> SparseMatrix {
+        let new_num_cols = remap.iter().filter(|r| r.is_some()).count();
+        let mut new_row_indices = Vec::with_capacity(self.num_rows);
+        let mut new_col_indices = Vec::new();
+        let mut new_values = Vec::new();
+
+        for row in 0..self.num_rows {
+            new_row_indices.push(new_col_indices.len() as u32);
+            let range = self.row_range(row);
+            for i in range {
+                let old_col = self.col_indices[i] as usize;
+                if let Some(new_col) = remap[old_col] {
+                    new_col_indices.push(new_col as u32);
+                    new_values.push(self.values[i]);
+                }
+                // Dead columns should have no entries (zero occurrence),
+                // so this branch should never drop data.
+            }
+        }
+
+        SparseMatrix {
+            num_rows: self.num_rows,
+            num_cols: new_num_cols,
+            row_indices: new_row_indices,
+            col_indices: new_col_indices,
+            values: new_values,
+        }
+    }
+
     /// Count how many rows reference each column. Returns a Vec of length
     /// num_cols.
     pub fn column_occurrence_count(&self) -> Vec<usize> {
diff --git a/provekit/common/src/witness/digits.rs b/provekit/common/src/witness/digits.rs
index 030aea758..222bae666 100644
--- a/provekit/common/src/witness/digits.rs
+++ b/provekit/common/src/witness/digits.rs
@@ -19,10 +19,10 @@ pub struct DigitalDecompositionWitnesses {
     pub num_witnesses_to_decompose: usize,
     /// Witness indices of the values to be decomposed
     pub witnesses_to_decompose: Vec<usize>,
-    /// The index of the first witness written to
-    pub first_witness_idx: usize,
-    /// The number of witnesses written to
-    pub num_witnesses: usize,
+    /// Output witness indices. Length = log_bases.len() *
+    /// num_witnesses_to_decompose. Layout: output_indices[digit_place *
+    /// num_witnesses_to_decompose + i]
+    pub output_indices: Vec<usize>,
 }
 
 /// Compute a mixed-base decomposition of a field element into its digits, using
diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs
index b0dc929c1..0c4c9aff2 100644
--- a/provekit/common/src/witness/mod.rs
+++ b/provekit/common/src/witness/mod.rs
@@ -20,7 +20,8 @@ pub use {
     limbs::{Limbs, MAX_LIMBS},
     ram::{SpiceMemoryOperation, SpiceWitnesses},
     scheduling::{
-        Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders,
+        DependencyInfo, Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError,
+        SplitWitnessBuilders, WitnessIndexRemapper,
     },
     witness_builder::{
         CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm,
diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs
index 7e511f834..c7c8dd5aa 100644
--- a/provekit/common/src/witness/scheduling/dependency.rs
+++ b/provekit/common/src/witness/scheduling/dependency.rs
@@ -70,7 +70,7 @@ impl DependencyInfo {
     }
 
     /// Extracts the witness indices that a builder reads as inputs.
-    fn extract_reads(wb: &WitnessBuilder) -> Vec<usize> {
+    pub fn extract_reads(wb: &WitnessBuilder) -> Vec<usize> {
         match wb {
             WitnessBuilder::Constant(_)
             | WitnessBuilder::Acir(..)
@@ -323,9 +323,7 @@ impl DependencyInfo { WitnessBuilder::MultiplicitiesForRange(start, range, _) => { (*start..*start + *range).collect() } - WitnessBuilder::DigitalDecomposition(dd) => { - (dd.first_witness_idx..dd.first_witness_idx + dd.num_witnesses).collect() - } + WitnessBuilder::DigitalDecomposition(dd) => dd.output_indices.clone(), WitnessBuilder::SpiceWitnesses(sw) => { (sw.first_witness_idx..sw.first_witness_idx + sw.num_witnesses).collect() } @@ -333,16 +331,8 @@ impl DependencyInfo { let n = 2usize.pow(2 * *atomic_bits); (*start..*start + n).collect() } - WitnessBuilder::ChunkDecompose { - output_start, - chunk_bits, - .. - } => (*output_start..*output_start + chunk_bits.len()).collect(), - WitnessBuilder::SpreadBitExtract { - output_start, - chunk_bits, - .. - } => (*output_start..*output_start + chunk_bits.len()).collect(), + WitnessBuilder::ChunkDecompose { output_indices, .. } => output_indices.clone(), + WitnessBuilder::SpreadBitExtract { output_indices, .. } => output_indices.clone(), WitnessBuilder::MultiplicitiesForSpread(start, num_bits, _) => { let n = 1usize << *num_bits; (*start..*start + n).collect() diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 389a14358..07679ac6d 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -19,43 +19,90 @@ use { /// This ensures w1 can be committed independently before challenge extraction. pub struct WitnessIndexRemapper { /// Maps old witness index to new witness index - pub old_to_new: HashMap, - /// Number of witnesses in w1 (boundary between w1 and w2) - pub w1_size: usize, + pub(crate) old_to_new: HashMap, + /// Number of real w1 witnesses (boundary between w1 and w2 in committed + /// vector) + pub(crate) w1_size: usize, + /// Total real witnesses (w1_real + w2_real) — used to set matrix + /// `num_cols` so matrices exclude virtual witnesses. 
+ pub(crate) num_real: usize, } impl WitnessIndexRemapper { /// Creates a remapping from w1 and w2 builder lists. /// - /// Assigns w1 builder outputs to [0, k) and w2 builder outputs to [k, n). - pub fn new(w1_builders: &[WitnessBuilder], w2_builders: &[WitnessBuilder]) -> Self { + /// `num_real_cols` is the R1CS column count before splitting — indices + /// below this are "real" (constrained), indices >= are "virtual" + /// (computation-only). + /// + /// Output layout: + /// [0, w1_real) → real w1 witnesses (committed) + /// [w1_real, w1_real+w2_real) → real w2 witnesses (committed) + /// [w1_real+w2_real, total) → virtual witnesses (solving only) + /// + /// `w1_size` is set to `w1_real` so the WHIR commitment covers only + /// real witnesses. + pub fn new( + w1_builders: &[WitnessBuilder], + w2_builders: &[WitnessBuilder], + num_real_cols: usize, + ) -> Self { let mut old_to_new = HashMap::new(); - let mut next_w1_idx = 0; - let mut next_w2_idx = 0; + let mut next_real_w1 = 0usize; + let mut virtual_w1: Vec = Vec::new(); - // Map w1 builder outputs to [0, k) + // First pass w1: assign real outputs, collect virtual for builder in w1_builders { - let writes = DependencyInfo::extract_writes(builder); - for old_idx in writes { - old_to_new.insert(old_idx, next_w1_idx); - next_w1_idx += 1; + for old_idx in DependencyInfo::extract_writes(builder) { + if old_idx < num_real_cols { + old_to_new.insert(old_idx, next_real_w1); + next_real_w1 += 1; + } else { + virtual_w1.push(old_idx); + } } } + let w1_real = next_real_w1; - let w1_size = next_w1_idx; + let mut next_real_w2 = w1_real; + let mut virtual_w2: Vec = Vec::new(); - // Map w2 builder outputs to [k, n) + // First pass w2: assign real outputs, collect virtual for builder in w2_builders { - let writes = DependencyInfo::extract_writes(builder); - for old_idx in writes { - old_to_new.insert(old_idx, w1_size + next_w2_idx); - next_w2_idx += 1; + for old_idx in DependencyInfo::extract_writes(builder) { + if old_idx < 
num_real_cols { + old_to_new.insert(old_idx, next_real_w2); + next_real_w2 += 1; + } else { + virtual_w2.push(old_idx); + } } } + // Second pass: assign virtual outputs after all real ones + let mut next_virtual = next_real_w2; + for old_idx in virtual_w1.into_iter().chain(virtual_w2) { + old_to_new.insert(old_idx, next_virtual); + next_virtual += 1; + } + Self { old_to_new, - w1_size, + w1_size: w1_real, + num_real: next_real_w2, + } + } + + /// Creates a remapper from a pre-built mapping table, for use cases + /// that only need `remap_builder` (e.g., column removal optimization). + /// + /// `w1_size` and `num_real` are set to 0 — do not use this remapper for + /// matrix column updates or w1/w2 splitting. + pub fn from_map(old_to_new: HashMap) -> Self { + Self { + old_to_new, + w1_size: 0, + num_real: 0, } } @@ -162,22 +209,22 @@ impl WitnessIndexRemapper { WitnessCoefficient(*coeff, self.remap(*value)), ) } - WitnessBuilder::DigitalDecomposition(dd) => { - let new_witnesses_to_decompose = dd - .witnesses_to_decompose - .iter() - .map(|&w| self.remap(w)) - .collect(); - WitnessBuilder::DigitalDecomposition( - crate::witness::DigitalDecompositionWitnesses { - log_bases: dd.log_bases.clone(), - num_witnesses_to_decompose: dd.num_witnesses_to_decompose, - witnesses_to_decompose: new_witnesses_to_decompose, - first_witness_idx: self.remap(dd.first_witness_idx), - num_witnesses: dd.num_witnesses, - }, - ) - } + WitnessBuilder::DigitalDecomposition(dd) => WitnessBuilder::DigitalDecomposition( + crate::witness::DigitalDecompositionWitnesses { + log_bases: dd.log_bases.clone(), + num_witnesses_to_decompose: dd.num_witnesses_to_decompose, + witnesses_to_decompose: dd + .witnesses_to_decompose + .iter() + .map(|&w| self.remap(w)) + .collect(), + output_indices: dd + .output_indices + .iter() + .map(|&i| self.remap(i)) + .collect(), + }, + ), WitnessBuilder::SpiceMultisetFactor( idx, sz, @@ -493,30 +540,30 @@ impl WitnessIndexRemapper { num_bits: *num_bits, }, 
WitnessBuilder::ChunkDecompose { - output_start, + output_indices, packed, chunk_bits, } => WitnessBuilder::ChunkDecompose { - output_start: self.remap(*output_start), - packed: self.remap(*packed), - chunk_bits: chunk_bits.clone(), + output_indices: output_indices.iter().map(|&i| self.remap(i)).collect(), + packed: self.remap(*packed), + chunk_bits: chunk_bits.clone(), }, WitnessBuilder::SpreadWitness(output, input) => { WitnessBuilder::SpreadWitness(self.remap(*output), self.remap(*input)) } WitnessBuilder::SpreadBitExtract { - output_start, + output_indices, chunk_bits, sum_terms, extract_even, } => WitnessBuilder::SpreadBitExtract { - output_start: self.remap(*output_start), - chunk_bits: chunk_bits.clone(), - sum_terms: sum_terms + output_indices: output_indices.iter().map(|&i| self.remap(i)).collect(), + chunk_bits: chunk_bits.clone(), + sum_terms: sum_terms .iter() .map(|SumTerm(coeff, idx)| SumTerm(*coeff, self.remap(*idx))) .collect(), - extract_even: *extract_even, + extract_even: *extract_even, }, WitnessBuilder::MultiplicitiesForSpread(start, num_bits, queries) => { let new_queries = queries @@ -559,6 +606,7 @@ impl WitnessIndexRemapper { pub fn remap_r1cs(&self, r1cs: R1CS) -> R1CS { let mut new_r1cs = R1CS::new(); new_r1cs.num_public_inputs = r1cs.num_public_inputs; + new_r1cs.num_virtual = r1cs.num_virtual; new_r1cs.interner = r1cs.interner; // Remap A, B, C in parallel - they're independent @@ -579,9 +627,12 @@ impl WitnessIndexRemapper { new_r1cs } - /// Helper to remap a single sparse matrix + /// Helper to remap a single sparse matrix. + /// Updates `num_cols` to `num_real` (w1_real + w2_real) so the matrix + /// dimensions exclude virtual witnesses. 
fn remap_sparse_matrix(&self, mut matrix: SparseMatrix) -> SparseMatrix { matrix.remap_columns(|old_col| self.remap(old_col)); + matrix.num_cols = self.num_real; matrix } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index ee1af6f26..fc3c41564 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -6,8 +6,8 @@ use { limbs::Limbs, ram::SpiceWitnesses, scheduling::{ - LayerScheduler, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, - WitnessIndexRemapper, WitnessSplitter, + DependencyInfo, LayerScheduler, LayeredWitnessBuilders, SplitError, + SplitWitnessBuilders, WitnessIndexRemapper, WitnessSplitter, }, ConstantOrR1CSWitness, }, @@ -277,11 +277,11 @@ pub enum WitnessBuilder { /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... - /// Writes chunk values to output_start..output_start+chunk_bits.len() + /// `output_indices[i]` is the witness index for chunk `i`. ChunkDecompose { - output_start: usize, - packed: usize, - chunk_bits: Vec, + output_indices: Vec, + packed: usize, + chunk_bits: Vec, }, /// Prover hint for FakeGLV scalar decomposition. /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, @@ -401,13 +401,12 @@ pub enum WitnessBuilder { SpreadWitness(usize, usize), /// Extracts even or odd bits from a spread sum, decomposed into /// byte-sized chunks. Even bits = XOR result, Odd bits = MAJ/AND - /// result. The sum is computed inline from the provided terms, - /// avoiding a separate witness allocation. + /// result. `output_indices[i]` is the witness index for chunk `i`. 
SpreadBitExtract { - output_start: usize, - chunk_bits: Vec, - sum_terms: Vec, - extract_even: bool, + output_indices: Vec, + chunk_bits: Vec, + sum_terms: Vec, + extract_even: bool, }, /// Spread table multiplicities: counts how many times each input /// value appears in the query set. @@ -453,7 +452,7 @@ impl WitnessBuilder { pub fn num_witnesses(&self) -> usize { match self { WitnessBuilder::MultiplicitiesForRange(_, range_size, _) => *range_size, - WitnessBuilder::DigitalDecomposition(dd_struct) => dd_struct.num_witnesses, + WitnessBuilder::DigitalDecomposition(dd_struct) => dd_struct.output_indices.len(), WitnessBuilder::SpiceWitnesses(spice_witnesses_struct) => { spice_witnesses_struct.num_witnesses } @@ -463,8 +462,8 @@ impl WitnessBuilder { WitnessBuilder::U32Addition(..) => 2, WitnessBuilder::U32AdditionMulti(..) => 2, WitnessBuilder::BytePartition { .. } => 2, - WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), - WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), + WitnessBuilder::ChunkDecompose { output_indices, .. } => output_indices.len(), + WitnessBuilder::SpreadBitExtract { output_indices, .. } => output_indices.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, @@ -541,8 +540,12 @@ impl WitnessBuilder { .map(|&idx| witness_builders[idx].clone()) .collect(); - // Step 3: Create witness index remapper - let remapper = WitnessIndexRemapper::new(&w1_builders, &w2_builders); + // Step 3: Create witness index remapper. + // Pass num_real_cols so virtual witnesses are placed at the end, + // after all real w1/w2 witnesses — keeping them out of the + // committed WHIR polynomial. 
+ let num_real_cols = r1cs.num_witnesses(); + let remapper = WitnessIndexRemapper::new(&w1_builders, &w2_builders, num_real_cols); let w1_size = remapper.w1_size; // Step 4: Remap all builders @@ -556,10 +559,45 @@ impl WitnessBuilder { .map(|b| remapper.remap_builder(b)) .collect(); - // Step 5: Remap R1CS and witness map + // Step 5: Remap R1CS and witness map. + // num_virtual is preserved — virtual witnesses are at the end of + // the index space, after all real w1/w2 witnesses. let remapped_r1cs = remapper.remap_r1cs(r1cs); let remapped_witness_map = remapper.remap_acir_witness_map(witness_map); + // Debug validation: ensure every remapped builder's reads have a + // producer in the remapped set. Without this, a bug in + // pruning/remapping could silently break the dependency graph, + // causing wrong witnesses at proving time. + #[cfg(debug_assertions)] + { + let all_builders: Vec<&WitnessBuilder> = remapped_w1_builders + .iter() + .chain(remapped_w2_builders.iter()) + .collect(); + + // Build producer map over all remapped builders. + let mut produced: HashSet = HashSet::new(); + produced.insert(0); // witness-one is always available + for b in &all_builders { + for w in DependencyInfo::extract_writes(b) { + produced.insert(w); + } + } + + // Check that every read is satisfied. + for (i, b) in all_builders.iter().enumerate() { + for r in DependencyInfo::extract_reads(b) { + assert!( + produced.contains(&r), + "Builder integrity violation after remapping: builder {i} ({b:?}) reads \ + witness {r}, but no builder produces it. This indicates a bug in the \ + pruning/remapping logic." 
+ ); + } + } + } + // Step 6: Schedule both groups independently with batch inversions let w1_layers = if remapped_w1_builders.is_empty() { LayeredWitnessBuilders { layers: Vec::new() } diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index edcf685ac..f9a0ec7d6 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -113,6 +113,7 @@ impl Prove for NoirProver { let compressed_r1cs = CompressedR1CS::compress(self.r1cs).context("While compressing R1CS")?; let num_witnesses = compressed_r1cs.num_witnesses(); + let num_virtual = compressed_r1cs.num_virtual(); let num_constraints = compressed_r1cs.num_constraints(); // Set up transcript with sponge selected by hash_config. @@ -122,7 +123,9 @@ impl Prove for NoirProver { .instance(&Empty); let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); - let mut witness: Vec> = vec![None; num_witnesses]; + // Allocate space for real + virtual witnesses. Virtual witnesses are + // computation-only (zero entries in A/B/C) but needed by builders. + let mut witness: Vec> = vec![None; num_witnesses + num_virtual]; // Solve w1 (or all witnesses if no challenges). // Outer span captures memory AFTER w1_layers parameter is freed @@ -187,7 +190,8 @@ impl Prove for NoirProver { let w2 = { let _s = info_span!("allocate_w2").entered(); - witness[self.whir_for_witness.w1_size..] + // Only real w2 witnesses (exclude virtual at the end). + witness[self.whir_for_witness.w1_size..num_witnesses] .iter() .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w2 are missing"))) .collect::>>()? 
@@ -210,8 +214,13 @@ impl Prove for NoirProver { .context("While decompressing R1CS")?; #[cfg(test)] - r1cs.test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) - .context("While verifying R1CS instance")?; + r1cs.test_witness_satisfaction( + &witness[..num_witnesses] + .iter() + .map(|w| w.unwrap()) + .collect::>(), + ) + .context("While verifying R1CS instance")?; let public_inputs = if num_public_inputs == 0 { PublicInputs::new() @@ -224,8 +233,11 @@ impl Prove for NoirProver { ) }; - let full_witness: Vec = witness - .into_iter() + // Extract only real witnesses (first num_witnesses) for the sumcheck. + // Virtual witnesses at [num_witnesses, num_witnesses+num_virtual) were + // needed for builder computation but have zero entries in A/B/C. + let full_witness: Vec = witness[..num_witnesses] + .iter() .enumerate() .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index 80254d4db..2553f6b64 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -19,6 +19,7 @@ use { pub struct CompressedR1CS { num_constraints: usize, num_witnesses: usize, + num_virtual: usize, blob: Vec, } @@ -47,11 +48,13 @@ impl CompressedR1CS { pub fn compress(r1cs: R1CS) -> Result { let num_constraints = r1cs.num_constraints(); let num_witnesses = r1cs.num_witnesses(); + let num_virtual = r1cs.num_virtual; let blob = postcard::to_allocvec(&r1cs).context("R1CS serialization failed")?; drop(r1cs); Ok(Self { num_constraints, num_witnesses, + num_virtual, blob, }) } @@ -68,6 +71,10 @@ impl CompressedR1CS { self.num_witnesses } + pub const fn num_virtual(&self) -> usize { + self.num_virtual + } + pub fn blob_len(&self) -> usize { self.blob.len() } diff --git a/provekit/prover/src/witness/digits.rs b/provekit/prover/src/witness/digits.rs index 409ba2795..4b5ac10e4 100644 --- a/provekit/prover/src/witness/digits.rs +++ 
b/provekit/prover/src/witness/digits.rs @@ -19,9 +19,8 @@ impl DigitalDecompositionWitnessesSolver for DigitalDecompositionWitnesses { .iter() .enumerate() .for_each(|(digit_place, digit_value)| { - witness[self.first_witness_idx - + digit_place * self.witnesses_to_decompose.len() - + i] = Some(*digit_value); + let idx = digit_place * self.witnesses_to_decompose.len() + i; + witness[self.output_indices[idx]] = Some(*digit_value); }); }); } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 111dfefa3..7eb9f31aa 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -859,7 +859,7 @@ impl WitnessBuilderSolver for WitnessBuilder { ) } WitnessBuilder::ChunkDecompose { - output_start, + output_indices, packed, chunk_bits, } => { @@ -868,7 +868,7 @@ impl WitnessBuilderSolver for WitnessBuilder { for (i, &bits) in chunk_bits.iter().enumerate() { let mask = (1u64 << bits) - 1; let chunk_val = (packed_val >> offset) & mask; - witness[output_start + i] = Some(FieldElement::from(chunk_val)); + witness[output_indices[i]] = Some(FieldElement::from(chunk_val)); offset += bits; } } @@ -878,12 +878,11 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*output_idx] = Some(FieldElement::from(spread)); } WitnessBuilder::SpreadBitExtract { - output_start, + output_indices, chunk_bits, sum_terms, extract_even, } => { - // Compute the spread sum inline from terms (no phantom witness needed) let sum_fe: FieldElement = sum_terms .iter() .map(|SumTerm(coeff, idx)| { @@ -896,19 +895,17 @@ impl WitnessBuilderSolver for WitnessBuilder { }) .fold(FieldElement::zero(), |acc, x| acc + x); let sum_val = sum_fe.into_bigint().0[0]; - // Extract even or odd bits from the spread sum let bit_offset = if *extract_even { 0 } else { 1 }; let total_bits: u32 = chunk_bits.iter().sum(); let mut extracted = 0u64; for i in 0..total_bits { extracted |= ((sum_val >> (2 * i + 
bit_offset)) & 1) << i; } - // Decompose extracted value into chunks let mut offset = 0u32; for (i, &bits) in chunk_bits.iter().enumerate() { let mask = (1u64 << bits) - 1; let chunk_val = (extracted >> offset) & mask; - witness[output_start + i] = Some(FieldElement::from(chunk_val)); + witness[output_indices[i]] = Some(FieldElement::from(chunk_val)); offset += bits; } } diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 657f7bd78..a3a61dcc8 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -25,13 +25,12 @@ impl DigitalDecompositionWitnessesBuilder for DigitalDecompositionWitnesses { witnesses_to_decompose: Vec, ) -> Self { let num_witnesses_to_decompose = witnesses_to_decompose.len(); - let digital_decomp_length = log_bases.len(); + let num_outputs = log_bases.len() * num_witnesses_to_decompose; Self { log_bases, num_witnesses_to_decompose, witnesses_to_decompose, - first_witness_idx: next_witness_idx, - num_witnesses: digital_decomp_length * num_witnesses_to_decompose, + output_indices: (next_witness_idx..next_witness_idx + num_outputs).collect(), } } @@ -42,7 +41,7 @@ impl DigitalDecompositionWitnessesBuilder for DigitalDecompositionWitnesses { fn get_digit_witness_index(&self, digit_place: usize, value_offset: usize) -> usize { debug_assert!(digit_place < self.log_bases.len()); debug_assert!(value_offset < self.num_witnesses_to_decompose); - self.first_witness_idx + digit_place * self.num_witnesses_to_decompose + value_offset + self.output_indices[digit_place * self.num_witnesses_to_decompose + value_offset] } } diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 09add6456..29b11a229 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -50,7 +50,7 @@ impl NoirCompiler { main.opcodes.len() ); - let (mut r1cs, witness_map, witness_builders) = 
noir_to_r1cs(main)?; + let (mut r1cs, mut witness_map, mut witness_builders) = noir_to_r1cs(main)?; info!( "R1CS {} constraints, {} witnesses, A {} entries, B {} entries, C {} entries", r1cs.num_constraints(), @@ -61,10 +61,16 @@ impl NoirCompiler { ); // Gaussian elimination optimization pass - let opt_stats = provekit_common::optimize::optimize_r1cs(&mut r1cs); + let opt_stats = provekit_common::optimize::optimize_r1cs( + &mut r1cs, + &mut witness_builders, + &mut witness_map, + ); info!( - "After GE optimization: {} constraints ({} eliminated, {:.1}% constraint reduction)", + "After GE optimization: {} constraints, {} witnesses ({} eliminated, {:.1}% \ + constraint reduction)", opt_stats.constraints_after, + opt_stats.witnesses_after, opt_stats.eliminated, opt_stats.constraint_reduction_percent() ); @@ -80,17 +86,17 @@ impl NoirCompiler { witness_map, acir_public_inputs_indices_set, )?; + let num_real = remapped_r1cs.num_witnesses(); + let num_virtual = remapped_r1cs.num_virtual; info!( - "Witness split: w1 size = {}, w2 size = {}", + "Witness split: w1 = {}, w2 = {} (real, committed) + {} virtual (solving only)", split_witness_builders.w1_size, - remapped_r1cs.num_witnesses() - split_witness_builders.w1_size + num_real - split_witness_builders.w1_size, + num_virtual ); - let witness_generator = NoirWitnessGenerator::new( - &program, - remapped_witness_map, - remapped_r1cs.num_witnesses(), - ); + let witness_generator = + NoirWitnessGenerator::new(&program, remapped_witness_map, num_real + num_virtual); let whir_for_witness = WhirR1CSScheme::new_for_r1cs( &remapped_r1cs, diff --git a/provekit/r1cs-compiler/src/spread.rs b/provekit/r1cs-compiler/src/spread.rs index 8acc93bea..d0867dc52 100644 --- a/provekit/r1cs-compiler/src/spread.rs +++ b/provekit/r1cs-compiler/src/spread.rs @@ -35,7 +35,6 @@ fn subchunks(bits: u32, w: u32) -> Vec { #[derive(Clone, Debug)] pub(crate) struct SpreadChunk { pub total_bits: u32, - pub sub_values: Vec, pub sub_spreads: Vec, pub 
sub_bits: Vec, } @@ -147,8 +146,9 @@ pub(crate) fn decompose_to_spread_word( // Step 2: Single ChunkDecompose producing all sub-chunks directly // from the packed value. let sub_start = compiler.num_witnesses(); + let chunk_count = flat_bits.len(); compiler.add_witness_builder(WitnessBuilder::ChunkDecompose { - output_start: sub_start, + output_indices: (sub_start..sub_start + chunk_count).collect(), packed, chunk_bits: flat_bits.clone(), }); @@ -175,7 +175,6 @@ pub(crate) fn decompose_to_spread_word( let n_subs = chunk_sub_counts[ci]; let sub_bits_slice = &flat_bits[flat_idx..flat_idx + n_subs]; - let mut sub_values = Vec::with_capacity(n_subs); let mut sub_spreads = Vec::with_capacity(n_subs); for j in 0..n_subs { let val_idx = sub_start + flat_idx + j; @@ -187,13 +186,11 @@ pub(crate) fn decompose_to_spread_word( .or_default() .push(val_idx); } - sub_values.push(val_idx); sub_spreads.push(spread_idx); } chunks.push(SpreadChunk { total_bits: chunk_spec[ci], - sub_values, sub_spreads, sub_bits: sub_bits_slice.to_vec(), }); @@ -262,17 +259,19 @@ pub(crate) fn spread_decompose( // The sum is computed inline by the solver from sum_terms, // avoiding a phantom witness that would inflate the witness vector. 
let even_start = compiler.num_witnesses(); + let even_count = extract_chunks.len(); compiler.add_witness_builder(WitnessBuilder::SpreadBitExtract { - output_start: even_start, - chunk_bits: extract_chunks.clone(), - sum_terms: sum_terms.clone(), - extract_even: true, + output_indices: (even_start..even_start + even_count).collect(), + chunk_bits: extract_chunks.clone(), + sum_terms: sum_terms.clone(), + extract_even: true, }); // Extract odd bits (AND/MAJ) into chunks let odd_start = compiler.num_witnesses(); + let odd_count = extract_chunks.len(); compiler.add_witness_builder(WitnessBuilder::SpreadBitExtract { - output_start: odd_start, + output_indices: (odd_start..odd_start + odd_count).collect(), chunk_bits: extract_chunks.clone(), sum_terms, extract_even: false, @@ -478,10 +477,8 @@ pub(crate) fn decompose_constant_to_spread_word( // spread table lookup is needed for soundness. // Build SpreadChunks with pinned spread witnesses. - // No ChunkDecompose needed — sub_values are never read by any - // downstream R1CS constraint (only sub_spreads are used in - // spread_decompose). We still populate sub_values with the - // spread witness indices as placeholders to satisfy the struct. + // No ChunkDecompose needed — only sub_spreads are used in + // downstream spread_decompose constraints. let mut chunks = Vec::with_capacity(num_chunks); let mut flat_idx = 0usize; @@ -504,10 +501,6 @@ pub(crate) fn decompose_constant_to_spread_word( chunks.push(SpreadChunk { total_bits: chunk_spec[ci], - // NB: sub_values normally holds chunk-value witness indices, but the constant - // path doesn't create separate chunk-value witnesses. Set to spread indices - // instead; this field is currently unused so the mismatch is harmless. 
- sub_values: sub_spreads.clone(), sub_spreads, sub_bits: sub_bits_slice.to_vec(), }); diff --git a/tooling/cli/src/cmd/circuit_stats/display.rs b/tooling/cli/src/cmd/circuit_stats/display.rs index b353bdcb1..26c8efdc4 100644 --- a/tooling/cli/src/cmd/circuit_stats/display.rs +++ b/tooling/cli/src/cmd/circuit_stats/display.rs @@ -474,9 +474,13 @@ pub(super) fn print_ge_optimization( (optimized_r1cs.num_constraints() as f64).log2() ); println!( - "│ Witnesses: {:>8} (2^{:.2})", - optimized_r1cs.num_witnesses(), - (optimized_r1cs.num_witnesses() as f64).log2() + "│ Witnesses: {:>8} (2^{:.2}) committed to WHIR", + stats.witnesses_after, + (stats.witnesses_after as f64).log2() + ); + println!( + "│ Virtual: {:>8} solving only, not committed", + stats.num_virtual ); println!("│ A entries: {:>8}", optimized_r1cs.a.num_entries()); println!("│ B entries: {:>8}", optimized_r1cs.b.num_entries()); @@ -484,16 +488,19 @@ pub(super) fn print_ge_optimization( println!("└{}", SUBSECTION); println!("\n{}", SEPARATOR); - println!( - "ELIMINATED: {:>8} linear constraints substituted", - stats.eliminated - ); println!( "CONSTRAINT REDUCTION: {:>7.2}% ({} -> {})", stats.constraint_reduction_percent(), stats.constraints_before, stats.constraints_after ); + println!( + "WITNESS REDUCTION: {:>7.2}% ({} -> {} committed + {} virtual)", + stats.witness_reduction_percent(), + stats.witnesses_before, + stats.witnesses_after, + stats.num_virtual + ); println!("{}", SEPARATOR); println!(); } diff --git a/tooling/cli/src/cmd/circuit_stats/mod.rs b/tooling/cli/src/cmd/circuit_stats/mod.rs index 4c4f38bff..f8b4e2e72 100644 --- a/tooling/cli/src/cmd/circuit_stats/mod.rs +++ b/tooling/cli/src/cmd/circuit_stats/mod.rs @@ -91,14 +91,14 @@ fn analyze_circuit(program: Program, path: &Path) -> Result<()> { display::print_acir_stats(&stats); - let (r1cs, _witness_map, _witness_builders, breakdown) = + let (r1cs, mut witness_map, mut witness_builders, breakdown) = 
noir_to_r1cs_with_breakdown(&circuit).context("Failed to compile circuit to R1CS")?; display::print_r1cs_breakdown(&stats, &circuit, &r1cs, &breakdown); // Run Gaussian elimination optimization and display results let mut optimized_r1cs = r1cs.clone(); - let opt_stats = optimize_r1cs(&mut optimized_r1cs); + let opt_stats = optimize_r1cs(&mut optimized_r1cs, &mut witness_builders, &mut witness_map); display::print_ge_optimization(&r1cs, &optimized_r1cs, &opt_stats);