From 1a1a8446a526ee38b9e4ac1dda9d4d17b8e7a4ca Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Tue, 3 Mar 2026 13:07:17 +0530 Subject: [PATCH 01/10] feat: Implement Gaussian elimination optimization for R1CS - Added to handle Gaussian elimination optimization for R1CS constraints. - Introduced methods to identify linear constraints, select pivot variables, and substitute them into remaining constraints. - Enhanced struct with methods to check for linear constraints and extract linear expressions. - Updated with new utility methods for row manipulation and entry retrieval. - Integrated optimization pass into the Noir proof scheme compilation process. - Added display functions for optimization statistics in the CLI. - Included tests for various linear elimination scenarios to ensure correctness of the optimization logic. --- provekit/common/src/interner.rs | 9 ++++ provekit/common/src/optimize.rs | 53 +++++++++++++------ .../r1cs-compiler/src/noir_proof_scheme.rs | 8 +-- tooling/cli/src/cmd/circuit_stats/display.rs | 6 +++ tooling/cli/src/cmd/circuit_stats/mod.rs | 4 +- 5 files changed, 60 insertions(+), 20 deletions(-) diff --git a/provekit/common/src/interner.rs b/provekit/common/src/interner.rs index 822a6a7dd..e87ac9825 100644 --- a/provekit/common/src/interner.rs +++ b/provekit/common/src/interner.rs @@ -39,4 +39,13 @@ impl Interner { pub fn get(&self, el: InternedFieldElement) -> Option { self.values.get(el.0).copied() } + + /// Look up a value without inserting. Returns the InternedFieldElement if + /// found. + pub fn get_or_none(&self, value: FieldElement) -> Option { + self.values + .iter() + .position(|v| *v == value) + .map(InternedFieldElement) + } } diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index a3376c2b7..5a3af3445 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -9,7 +9,7 @@ //! 5. 
Remove eliminated constraints use { - crate::{FieldElement, InternedFieldElement, SparseMatrix, R1CS}, + crate::{witness::WitnessBuilder, FieldElement, InternedFieldElement, SparseMatrix, R1CS}, ark_ff::Field, ark_std::{One, Zero}, std::collections::{HashMap, HashSet}, @@ -28,10 +28,12 @@ struct Substitution { /// Statistics from the optimization pass. pub struct OptimizationStats { - pub constraints_before: usize, - pub constraints_after: usize, - pub eliminated: usize, - pub eliminated_columns: HashSet, + pub constraints_before: usize, + pub constraints_after: usize, + pub witnesses_before: usize, + pub witnesses_after: usize, + pub eliminated: usize, + pub eliminated_columns: HashSet, } impl OptimizationStats { @@ -42,6 +44,13 @@ impl OptimizationStats { (self.constraints_before - self.constraints_after) as f64 / self.constraints_before as f64 * 100.0 } + + pub fn witness_reduction_percent(&self) -> f64 { + if self.witnesses_before == 0 { + return 0.0; + } + (self.witnesses_before - self.witnesses_after) as f64 / self.witnesses_before as f64 * 100.0 + } } /// Run the Gaussian elimination optimization on an R1CS instance. @@ -51,8 +60,12 @@ impl OptimizationStats { /// eliminated rows. /// /// Column 0 (constant one) is never chosen as a pivot. -pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { +pub fn optimize_r1cs( + r1cs: &mut R1CS, + _witness_builders: &mut [WitnessBuilder], +) -> OptimizationStats { let constraints_before = r1cs.num_constraints(); + let witnesses_before = r1cs.num_witnesses(); // Column 0 is the constant-one column and must not be eliminated. 
let mut forbidden: HashSet = HashSet::new(); @@ -86,7 +99,7 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { let mut sub_map_phase2: HashMap = HashMap::new(); for &row in &linear_rows { - // Extract the combined linear expression (const * A/B - C) for this constraint + // Extract the combined linear expression (const * A/B - C) for this constraint let expr = r1cs.extract_linear_expression(row); if expr.is_empty() { continue; @@ -157,10 +170,11 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { .map(|(col, val)| (val, col)) .collect(); - // Approximate decrement: occurrence_counts was built from raw A+B+C + // Approximate decrement: occurrence_counts was built from raw A+B+C // entries, but we decrement once per column in the combined expression. // A column appearing in multiple matrices is undercounted here. Only // affects pivot-selection heuristic quality, not correctness. + for (_, col) in &expr { if occurrence_counts[*col] > 0 { occurrence_counts[*col] -= 1; @@ -183,6 +197,8 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { return OptimizationStats { constraints_before, constraints_after: constraints_before, + witnesses_before, + witnesses_after: witnesses_before, eliminated: 0, eliminated_columns: HashSet::new(), }; @@ -261,7 +277,8 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { } // Phase 4: Remove eliminated constraint rows - r1cs.remove_constraints(&eliminated_rows); + r1cs.remove_constraints(&eliminated_rows); + // Note: We do NOT modify witness builders. The witnesses are still // computed by their original builders. GE only removes redundant @@ -270,9 +287,15 @@ pub fn optimize_r1cs(r1cs: &mut R1CS) -> OptimizationStats { let constraints_after = r1cs.num_constraints(); let eliminated = substitutions.len(); + // witnesses_after = witnesses_before since we don't actually remove columns, + // just make some witnesses derived. The column count doesn't change. 
+ let witnesses_after = witnesses_before; + let stats = OptimizationStats { constraints_before, constraints_after, + witnesses_before, + witnesses_after, eliminated, eliminated_columns: eliminated_cols, }; @@ -442,7 +465,7 @@ mod tests { assert_eq!(r1cs.num_constraints(), 2); - let stats = optimize_r1cs(&mut r1cs); + let stats = optimize_r1cs(&mut r1cs, &mut []); // Constraint 0 should be eliminated (it's linear) assert_eq!(stats.constraints_after, 1); @@ -485,7 +508,7 @@ mod tests { assert_r1cs_satisfied(&r1cs, &witness); assert_eq!(r1cs.num_constraints(), 3); - let stats = optimize_r1cs(&mut r1cs); + let stats = optimize_r1cs(&mut r1cs, &mut []); assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 1); @@ -528,7 +551,7 @@ mod tests { assert_r1cs_satisfied(&r1cs, &witness); assert_eq!(r1cs.num_constraints(), 5); - let stats = optimize_r1cs(&mut r1cs); + let stats = optimize_r1cs(&mut r1cs, &mut []); assert_eq!(stats.eliminated, 4); assert_eq!(stats.constraints_after, 1); @@ -574,7 +597,7 @@ mod tests { assert_r1cs_satisfied(&r1cs, &witness); assert_eq!(r1cs.num_constraints(), 5); - let stats = optimize_r1cs(&mut r1cs); + let stats = optimize_r1cs(&mut r1cs, &mut []); assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 3); @@ -614,7 +637,7 @@ mod tests { assert_r1cs_satisfied(&r1cs, &witness); - let stats = optimize_r1cs(&mut r1cs); + let stats = optimize_r1cs(&mut r1cs, &mut []); assert_eq!(stats.eliminated, 2); assert_no_dangling_pivots(&r1cs, &stats); assert_r1cs_satisfied(&r1cs, &witness); @@ -675,7 +698,7 @@ mod tests { assert_r1cs_satisfied(&r1cs, &witness); - let stats = optimize_r1cs(&mut r1cs); + let stats = optimize_r1cs(&mut r1cs, &mut []); // 5 eliminated: L_a(w3), L_b(w5), L_sr0(w6), L_sr1(w7), // L_deg0(w8 or w9). L_deg1 skipped (degenerate). Q non-linear. 
diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 09add6456..288c9a085 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -50,7 +50,7 @@ impl NoirCompiler { main.opcodes.len() ); - let (mut r1cs, witness_map, witness_builders) = noir_to_r1cs(main)?; + let (mut r1cs, witness_map, mut witness_builders) = noir_to_r1cs(main)?; info!( "R1CS {} constraints, {} witnesses, A {} entries, B {} entries, C {} entries", r1cs.num_constraints(), @@ -61,10 +61,12 @@ impl NoirCompiler { ); // Gaussian elimination optimization pass - let opt_stats = provekit_common::optimize::optimize_r1cs(&mut r1cs); + let opt_stats = provekit_common::optimize::optimize_r1cs(&mut r1cs, &mut witness_builders); info!( - "After GE optimization: {} constraints ({} eliminated, {:.1}% constraint reduction)", + "After GE optimization: {} constraints, {} witnesses ({} eliminated, {:.1}% \ + constraint reduction)", opt_stats.constraints_after, + opt_stats.witnesses_after, opt_stats.eliminated, opt_stats.constraint_reduction_percent() ); diff --git a/tooling/cli/src/cmd/circuit_stats/display.rs b/tooling/cli/src/cmd/circuit_stats/display.rs index b353bdcb1..5d05c3df4 100644 --- a/tooling/cli/src/cmd/circuit_stats/display.rs +++ b/tooling/cli/src/cmd/circuit_stats/display.rs @@ -494,6 +494,12 @@ pub(super) fn print_ge_optimization( stats.constraints_before, stats.constraints_after ); + println!( + "WITNESS REDUCTION: {:>7.2}% ({} -> {})", + stats.witness_reduction_percent(), + stats.witnesses_before, + stats.witnesses_after + ); println!("{}", SEPARATOR); println!(); } diff --git a/tooling/cli/src/cmd/circuit_stats/mod.rs b/tooling/cli/src/cmd/circuit_stats/mod.rs index 4c4f38bff..1fadafed8 100644 --- a/tooling/cli/src/cmd/circuit_stats/mod.rs +++ b/tooling/cli/src/cmd/circuit_stats/mod.rs @@ -91,14 +91,14 @@ fn analyze_circuit(program: Program, path: &Path) -> Result<()> { 
display::print_acir_stats(&stats); - let (r1cs, _witness_map, _witness_builders, breakdown) = + let (r1cs, _witness_map, mut witness_builders, breakdown) = noir_to_r1cs_with_breakdown(&circuit).context("Failed to compile circuit to R1CS")?; display::print_r1cs_breakdown(&stats, &circuit, &r1cs, &breakdown); // Run Gaussian elimination optimization and display results let mut optimized_r1cs = r1cs.clone(); - let opt_stats = optimize_r1cs(&mut optimized_r1cs); + let opt_stats = optimize_r1cs(&mut optimized_r1cs, &mut witness_builders); display::print_ge_optimization(&r1cs, &optimized_r1cs, &opt_stats); From 4c0da564473908261807cdaf61b2707ef9914c49 Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Fri, 6 Mar 2026 11:06:19 +0530 Subject: [PATCH 02/10] feat: Enhance R1CS optimization by removing dead witness columns and pruning unreachable builders --- provekit/common/src/optimize.rs | 951 +++++++++++++----- provekit/common/src/sparse_matrix.rs | 33 + provekit/common/src/witness/mod.rs | 4 +- .../src/witness/scheduling/dependency.rs | 2 +- .../r1cs-compiler/src/noir_proof_scheme.rs | 8 +- tooling/cli/src/cmd/circuit_stats/display.rs | 4 + tooling/cli/src/cmd/circuit_stats/mod.rs | 4 +- 7 files changed, 756 insertions(+), 250 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index 5a3af3445..1badcc477 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -7,9 +7,13 @@ //! 3. Express pivot as linear combination of remaining variables //! 4. Substitute into all other constraints //! 5. Remove eliminated constraints +//! 6. 
Remove dead witness columns and prune unreachable witness builders use { - crate::{witness::WitnessBuilder, FieldElement, InternedFieldElement, SparseMatrix, R1CS}, + crate::{ + witness::{DependencyInfo, WitnessBuilder}, + FieldElement, InternedFieldElement, SparseMatrix, R1CS, + }, ark_ff::Field, ark_std::{One, Zero}, std::collections::{HashMap, HashSet}, @@ -28,12 +32,12 @@ struct Substitution { /// Statistics from the optimization pass. pub struct OptimizationStats { - pub constraints_before: usize, - pub constraints_after: usize, - pub witnesses_before: usize, - pub witnesses_after: usize, - pub eliminated: usize, - pub eliminated_columns: HashSet, + pub constraints_before: usize, + pub constraints_after: usize, + pub witnesses_before: usize, + pub witnesses_after: usize, + pub eliminated: usize, + pub builders_removed: usize, } impl OptimizationStats { @@ -59,17 +63,24 @@ impl OptimizationStats { /// picks pivots, substitutes into remaining constraints, and removes the /// eliminated rows. /// -/// Column 0 (constant one) is never chosen as a pivot. +/// `num_public_inputs` columns (1..=num_public_inputs) and column 0 (constant +/// one) are never chosen as pivots. pub fn optimize_r1cs( r1cs: &mut R1CS, - _witness_builders: &mut [WitnessBuilder], + witness_builders: &mut Vec, + witness_map: &mut [Option], ) -> OptimizationStats { let constraints_before = r1cs.num_constraints(); let witnesses_before = r1cs.num_witnesses(); - // Column 0 is the constant-one column and must not be eliminated. 
+ // Columns that must not be eliminated: + // - Column 0: constant one + // - Columns 1..=num_public_inputs: public inputs let mut forbidden: HashSet = HashSet::new(); forbidden.insert(0); + for i in 1..=r1cs.num_public_inputs { + forbidden.insert(i); + } // Phase 1: Identify all linear constraints let mut linear_rows: Vec = Vec::new(); @@ -99,7 +110,7 @@ pub fn optimize_r1cs( let mut sub_map_phase2: HashMap = HashMap::new(); for &row in &linear_rows { - // Extract the combined linear expression (const * A/B - C) for this constraint + // Extract the linear expression from C[row]: sum of (coeff * w_i) = 0 let expr = r1cs.extract_linear_expression(row); if expr.is_empty() { continue; @@ -170,11 +181,8 @@ pub fn optimize_r1cs( .map(|(col, val)| (val, col)) .collect(); - // Approximate decrement: occurrence_counts was built from raw A+B+C - // entries, but we decrement once per column in the combined expression. - // A column appearing in multiple matrices is undercounted here. Only - // affects pivot-selection heuristic quality, not correctness. - + // Decrement occurrence counts for all columns in this row (they're being + // removed) for (_, col) in &expr { if occurrence_counts[*col] > 0 { occurrence_counts[*col] -= 1; @@ -200,7 +208,7 @@ pub fn optimize_r1cs( witnesses_before, witnesses_after: witnesses_before, eliminated: 0, - eliminated_columns: HashSet::new(), + builders_removed: 0, }; } @@ -277,19 +285,16 @@ pub fn optimize_r1cs( } // Phase 4: Remove eliminated constraint rows - r1cs.remove_constraints(&eliminated_rows); - - - // Note: We do NOT modify witness builders. The witnesses are still - // computed by their original builders. GE only removes redundant - // constraints and substitutes pivots into remaining constraints. 
+ let mut sorted_rows = eliminated_rows.clone(); + sorted_rows.sort(); + r1cs.remove_constraints(&sorted_rows); let constraints_after = r1cs.num_constraints(); let eliminated = substitutions.len(); - // witnesses_after = witnesses_before since we don't actually remove columns, - // just make some witnesses derived. The column count doesn't change. - let witnesses_after = witnesses_before; + // Phase 5: Remove dead witness columns and prune unreachable builders + let (witnesses_after, builders_removed) = + remove_dead_columns(r1cs, witness_builders, witness_map); let stats = OptimizationStats { constraints_before, @@ -297,7 +302,7 @@ pub fn optimize_r1cs( witnesses_before, witnesses_after, eliminated, - eliminated_columns: eliminated_cols, + builders_removed, }; info!( @@ -307,6 +312,13 @@ pub fn optimize_r1cs( stats.constraint_reduction_percent(), eliminated ); + info!( + "Column removal: {} -> {} witnesses ({:.1}% reduction), {} builders pruned", + witnesses_before, + witnesses_after, + stats.witness_reduction_percent(), + builders_removed + ); stats } @@ -324,6 +336,493 @@ fn build_occurrence_counts(r1cs: &R1CS) -> Vec { counts } +/// Phase 5: Remove dead witness columns from matrices and prune unreachable +/// witness builders. +/// +/// After GE, some columns have zero occurrences across all remaining +/// constraints in A, B, and C. These are "dead" columns. This function: +/// +/// 1. Identifies dead columns (zero occurrences in all three matrices) +/// 2. Builds a dependency graph of witness builders +/// 3. Finds which builders are transitively reachable from "live" columns +/// (columns still referenced by constraints) +/// 4. Prunes unreachable builders (Phase B+C cascading) +/// 5. Remaps matrix column indices to close gaps +/// 6. Remaps remaining builder witness indices +/// +/// Returns (new_witness_count, builders_removed_count). 
+fn remove_dead_columns( + r1cs: &mut R1CS, + witness_builders: &mut Vec, + witness_map: &mut [Option], +) -> (usize, usize) { + let num_cols = r1cs.num_witnesses(); + if num_cols == 0 || witness_builders.is_empty() { + return (num_cols, 0); + } + + // Step 1: Find dead columns (zero occurrence across A, B, C) + // Also collect columns referenced by the ACIR witness map — these are + // entry points for witness data and must stay alive. + let occurrence_counts = build_occurrence_counts(r1cs); + let mut acir_referenced: HashSet = HashSet::new(); + for entry in witness_map.iter() { + if let Some(nz) = entry { + acir_referenced.insert(nz.get() as usize); + } + } + + let mut dead_cols: HashSet = HashSet::new(); + for col in 0..num_cols { + // Never remove column 0 (constant one) or public input columns + if col == 0 || col <= r1cs.num_public_inputs { + continue; + } + // Never remove columns referenced by the ACIR witness map + if acir_referenced.contains(&col) { + continue; + } + if occurrence_counts[col] == 0 { + dead_cols.insert(col); + } + } + + if dead_cols.is_empty() { + return (num_cols, 0); + } + + // Diagnostic: count how many zero-occurrence cols are blocked by each mechanism + let zero_occ_total = (0..num_cols) + .filter(|&c| c > r1cs.num_public_inputs && occurrence_counts[c] == 0) + .count(); + let blocked_by_acir = (0..num_cols) + .filter(|&c| { + c > r1cs.num_public_inputs && occurrence_counts[c] == 0 && acir_referenced.contains(&c) + }) + .count(); + info!( + "Column removal: {} zero-occurrence cols (excl public), {} blocked by ACIR witness map, \ + {} truly dead", + zero_occ_total, + blocked_by_acir, + dead_cols.len() + ); + + info!( + "Column removal: {} dead columns found out of {} total", + dead_cols.len(), + num_cols + ); + + // Step 2: Build witness builder dependency graph and find reachable builders. + // A builder is "live" if any of its output columns is NOT dead, OR if a + // live builder reads any of its output columns. 
+ // We use the existing DependencyInfo infrastructure. + let live_cols: HashSet = (0..num_cols).filter(|c| !dead_cols.contains(c)).collect(); + + // Map: witness column -> builder index + let mut col_to_builder: HashMap = HashMap::new(); + for (builder_idx, builder) in witness_builders.iter().enumerate() { + for col in DependencyInfo::extract_writes(builder) { + col_to_builder.insert(col, builder_idx); + } + } + + // Build adjacency: which builders does each builder depend on? + // (reverse of the normal dependency graph: we want "builder X reads from + // builder Y") + let mut builder_reads_from: Vec> = vec![HashSet::new(); witness_builders.len()]; + for (builder_idx, builder) in witness_builders.iter().enumerate() { + let reads = DependencyInfo::extract_reads(builder); + for read_col in reads { + if let Some(&producer_idx) = col_to_builder.get(&read_col) { + if producer_idx != builder_idx { + builder_reads_from[builder_idx].insert(producer_idx); + } + } + } + } + + // Step 3: BFS/DFS from live builders to find all transitively reachable + // builders. A builder is live if any of its output columns is live + // (referenced by constraints). 
+ let mut live_builders: HashSet = HashSet::new(); + let mut queue: Vec = Vec::new(); + + for (builder_idx, builder) in witness_builders.iter().enumerate() { + let writes = DependencyInfo::extract_writes(builder); + let is_directly_live = writes.iter().any(|c| live_cols.contains(c)); + if is_directly_live { + if live_builders.insert(builder_idx) { + queue.push(builder_idx); + } + } + } + + // BFS: if a live builder reads from another builder, that builder is also live + while let Some(builder_idx) = queue.pop() { + for &dep_idx in &builder_reads_from[builder_idx] { + if live_builders.insert(dep_idx) { + queue.push(dep_idx); + } + } + } + + info!( + "Column removal: {} total builders, {} live (directly or transitively), {} dead", + witness_builders.len(), + live_builders.len(), + witness_builders.len() - live_builders.len() + ); + + // Count dead cols blocked by live builder deps + let blocked_by_bfs = dead_cols + .iter() + .filter(|&&col| { + col_to_builder + .get(&col) + .map_or(false, |&b| live_builders.contains(&b)) + }) + .count(); + info!( + "Column removal: of {} dead cols, {} blocked by live builder BFS, {} removable", + dead_cols.len(), + blocked_by_bfs, + dead_cols.len() - blocked_by_bfs + ); + + // Step 4: Determine which columns to actually remove. 
+ // A column is removable if: + // - It's dead in matrices (zero occurrences) AND + // - Its producing builder is NOT live (not transitively reachable) + let mut removable_cols: HashSet = HashSet::new(); + for &col in &dead_cols { + let producer_is_live = col_to_builder + .get(&col) + .map_or(false, |&b| live_builders.contains(&b)); + if !producer_is_live { + removable_cols.insert(col); + } + } + + if removable_cols.is_empty() { + info!( + "Column removal: all {} dead columns are transitively needed by live builders", + dead_cols.len() + ); + return (num_cols, 0); + } + + info!( + "Column removal: {} columns removable ({} dead, {} kept for live builder deps)", + removable_cols.len(), + dead_cols.len(), + dead_cols.len() - removable_cols.len() + ); + + // Step 5: Build remap table (old_col -> new_col) + let mut remap: Vec> = vec![None; num_cols]; + let mut next_col = 0; + for col in 0..num_cols { + if !removable_cols.contains(&col) { + remap[col] = Some(next_col); + next_col += 1; + } + } + let new_num_cols = next_col; + + // Step 6: Remap matrices + r1cs.a = r1cs.a.remove_columns(&remap); + r1cs.b = r1cs.b.remove_columns(&remap); + r1cs.c = r1cs.c.remove_columns(&remap); + + // Step 6b: Remap ACIR witness map (ACIR index -> R1CS column) + for entry in witness_map.iter_mut() { + if let Some(nz) = entry { + let old_col = nz.get() as usize; + let new_col = remap[old_col].unwrap_or_else(|| { + panic!( + "ACIR witness map references removed column {} (should be live)", + old_col + ) + }); + *nz = std::num::NonZeroU32::new(new_col as u32) + .expect("Remapped ACIR witness index should be non-zero"); + } + } + + // Step 7: Prune dead builders and remap surviving ones + let builders_before = witness_builders.len(); + let mut new_builders: Vec = Vec::with_capacity(live_builders.len()); + for (idx, builder) in witness_builders.drain(..).enumerate() { + if live_builders.contains(&idx) { + new_builders.push(remap_builder_columns(&builder, &remap)); + } + } + *witness_builders 
= new_builders; + let builders_removed = builders_before - witness_builders.len(); + + info!( + "Column removal: {} -> {} witnesses, {} builders pruned", + num_cols, new_num_cols, builders_removed + ); + + (new_num_cols, builders_removed) +} + +/// Remap all witness column references inside a builder using the given +/// remap table. This mirrors `WitnessIndexRemapper::remap_builder` but uses +/// a Vec> remap table instead of HashMap. +fn remap_builder_columns(builder: &WitnessBuilder, remap: &[Option]) -> WitnessBuilder { + let r = |idx: usize| -> usize { + remap[idx].unwrap_or_else(|| { + panic!( + "Witness index {} not in remap table (expected live column)", + idx + ) + }) + }; + + let rc = + |val: &crate::witness::ConstantOrR1CSWitness| -> crate::witness::ConstantOrR1CSWitness { + match val { + crate::witness::ConstantOrR1CSWitness::Constant(c) => { + crate::witness::ConstantOrR1CSWitness::Constant(*c) + } + crate::witness::ConstantOrR1CSWitness::Witness(w) => { + crate::witness::ConstantOrR1CSWitness::Witness(r(*w)) + } + } + }; + + use crate::witness::*; + match builder { + WitnessBuilder::Constant(ConstantTerm(idx, val)) => { + WitnessBuilder::Constant(ConstantTerm(r(*idx), *val)) + } + WitnessBuilder::Acir(idx, acir_idx) => WitnessBuilder::Acir(r(*idx), *acir_idx), + WitnessBuilder::Sum(idx, terms) => { + let new_terms = terms + .iter() + .map(|SumTerm(coeff, operand_idx)| SumTerm(*coeff, r(*operand_idx))) + .collect(); + WitnessBuilder::Sum(r(*idx), new_terms) + } + WitnessBuilder::Product(idx, a, b) => WitnessBuilder::Product(r(*idx), r(*a), r(*b)), + WitnessBuilder::MultiplicitiesForRange(start, range, values) => { + let new_values = values.iter().map(|&v| r(v)).collect(); + WitnessBuilder::MultiplicitiesForRange(r(*start), *range, new_values) + } + WitnessBuilder::Challenge(idx) => WitnessBuilder::Challenge(r(*idx)), + WitnessBuilder::IndexedLogUpDenominator( + idx, + sz, + WitnessCoefficient(coeff, index), + rs, + value, + ) => 
WitnessBuilder::IndexedLogUpDenominator( + r(*idx), + r(*sz), + WitnessCoefficient(*coeff, r(*index)), + r(*rs), + r(*value), + ), + WitnessBuilder::Inverse(idx, operand) => WitnessBuilder::Inverse(r(*idx), r(*operand)), + WitnessBuilder::ProductLinearOperation( + idx, + ProductLinearTerm(x, a, b), + ProductLinearTerm(y, c, d), + ) => WitnessBuilder::ProductLinearOperation( + r(*idx), + ProductLinearTerm(r(*x), *a, *b), + ProductLinearTerm(r(*y), *c, *d), + ), + WitnessBuilder::LogUpDenominator(idx, sz, WitnessCoefficient(coeff, value)) => { + WitnessBuilder::LogUpDenominator(r(*idx), r(*sz), WitnessCoefficient(*coeff, r(*value))) + } + WitnessBuilder::LogUpInverse(idx, sz, WitnessCoefficient(coeff, value)) => { + WitnessBuilder::LogUpInverse(r(*idx), r(*sz), WitnessCoefficient(*coeff, r(*value))) + } + WitnessBuilder::DigitalDecomposition(dd) => { + let new_witnesses_to_decompose = + dd.witnesses_to_decompose.iter().map(|&w| r(w)).collect(); + WitnessBuilder::DigitalDecomposition(crate::witness::DigitalDecompositionWitnesses { + log_bases: dd.log_bases.clone(), + num_witnesses_to_decompose: dd.num_witnesses_to_decompose, + witnesses_to_decompose: new_witnesses_to_decompose, + first_witness_idx: r(dd.first_witness_idx), + num_witnesses: dd.num_witnesses, + }) + } + WitnessBuilder::SpiceMultisetFactor( + idx, + sz, + rs, + WitnessCoefficient(addr_c, addr_w), + value, + WitnessCoefficient(timer_c, timer_w), + ) => WitnessBuilder::SpiceMultisetFactor( + r(*idx), + r(*sz), + r(*rs), + WitnessCoefficient(*addr_c, r(*addr_w)), + r(*value), + WitnessCoefficient(*timer_c, r(*timer_w)), + ), + WitnessBuilder::SpiceWitnesses(sw) => { + let new_memory_operations = sw + .memory_operations + .iter() + .map(|op| match op { + crate::witness::SpiceMemoryOperation::Load(addr, value, rt) => { + crate::witness::SpiceMemoryOperation::Load(r(*addr), r(*value), r(*rt)) + } + crate::witness::SpiceMemoryOperation::Store(addr, old_val, new_val, rt) => { + 
crate::witness::SpiceMemoryOperation::Store( + r(*addr), + r(*old_val), + r(*new_val), + r(*rt), + ) + } + }) + .collect(); + WitnessBuilder::SpiceWitnesses(crate::witness::SpiceWitnesses { + memory_length: sw.memory_length, + initial_value_witnesses: sw.initial_value_witnesses.iter().map(|w| r(*w)).collect(), + memory_operations: new_memory_operations, + rv_final_start: r(sw.rv_final_start), + rt_final_start: r(sw.rt_final_start), + first_witness_idx: r(sw.first_witness_idx), + num_witnesses: sw.num_witnesses, + }) + } + WitnessBuilder::U32AdditionMulti(result_idx, carry_idx, inputs) => { + WitnessBuilder::U32AdditionMulti( + r(*result_idx), + r(*carry_idx), + inputs.iter().map(|c| rc(c)).collect(), + ) + } + WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { + lo: r(*lo), + hi: r(*hi), + x: r(*x), + k: *k, + }, + WitnessBuilder::BinOpLookupDenominator(idx, sz, rs, rs2, lhs, rhs, output) => { + WitnessBuilder::BinOpLookupDenominator( + r(*idx), + r(*sz), + r(*rs), + r(*rs2), + rc(lhs), + rc(rhs), + rc(output), + ) + } + WitnessBuilder::CombinedBinOpLookupDenominator( + idx, + sz, + rs, + rs2, + rs3, + lhs, + rhs, + and_out, + xor_out, + ) => WitnessBuilder::CombinedBinOpLookupDenominator( + r(*idx), + r(*sz), + r(*rs), + r(*rs2), + r(*rs3), + rc(lhs), + rc(rhs), + rc(and_out), + rc(xor_out), + ), + WitnessBuilder::MultiplicitiesForBinOp(start, atomic_bits, pairs) => { + let new_pairs = pairs.iter().map(|(lhs, rhs)| (rc(lhs), rc(rhs))).collect(); + WitnessBuilder::MultiplicitiesForBinOp(r(*start), *atomic_bits, new_pairs) + } + WitnessBuilder::U32Addition(result_idx, carry_idx, a, b) => { + WitnessBuilder::U32Addition(r(*result_idx), r(*carry_idx), rc(a), rc(b)) + } + WitnessBuilder::And(idx, lh, rh) => WitnessBuilder::And(r(*idx), rc(lh), rc(rh)), + WitnessBuilder::Xor(idx, lh, rh) => WitnessBuilder::Xor(r(*idx), rc(lh), rc(rh)), + WitnessBuilder::CombinedTableEntryInverse(data) => { + WitnessBuilder::CombinedTableEntryInverse( + 
crate::witness::CombinedTableEntryInverseData { + idx: r(data.idx), + sz_challenge: r(data.sz_challenge), + rs_challenge: r(data.rs_challenge), + rs_sqrd: r(data.rs_sqrd), + rs_cubed: r(data.rs_cubed), + lhs: data.lhs, + rhs: data.rhs, + and_out: data.and_out, + xor_out: data.xor_out, + }, + ) + } + WitnessBuilder::ChunkDecompose { + output_start, + packed, + chunk_bits, + } => WitnessBuilder::ChunkDecompose { + output_start: r(*output_start), + packed: r(*packed), + chunk_bits: chunk_bits.clone(), + }, + WitnessBuilder::SpreadWitness(output, input) => { + WitnessBuilder::SpreadWitness(r(*output), r(*input)) + } + WitnessBuilder::SpreadBitExtract { + output_start, + chunk_bits, + sum_terms, + extract_even, + } => WitnessBuilder::SpreadBitExtract { + output_start: r(*output_start), + chunk_bits: chunk_bits.clone(), + sum_terms: sum_terms + .iter() + .map(|SumTerm(coeff, idx)| SumTerm(*coeff, r(*idx))) + .collect(), + extract_even: *extract_even, + }, + WitnessBuilder::MultiplicitiesForSpread(start, num_bits, queries) => { + let new_queries = queries.iter().map(|c| rc(c)).collect(); + WitnessBuilder::MultiplicitiesForSpread(r(*start), *num_bits, new_queries) + } + WitnessBuilder::SpreadLookupDenominator(idx, sz, rs, input, spread_output) => { + WitnessBuilder::SpreadLookupDenominator( + r(*idx), + r(*sz), + r(*rs), + rc(input), + rc(spread_output), + ) + } + WitnessBuilder::SpreadTableQuotient { + idx, + sz, + rs, + input_val, + spread_val, + multiplicity, + } => WitnessBuilder::SpreadTableQuotient { + idx: r(*idx), + sz: r(*sz), + rs: r(*rs), + input_val: *input_val, + spread_val: *spread_val, + multiplicity: r(*multiplicity), + }, + } +} + /// Apply all relevant substitutions to a single row of a matrix. 
/// /// Since Phase 2b resolves backward chains (later pivots referenced by @@ -378,67 +877,7 @@ fn apply_substitutions_to_row( #[cfg(test)] mod tests { - use {super::*, ark_std::One}; - - /// Evaluate `matrix · witness` for each row, returning a Vec of - /// FieldElements (one per constraint). - fn matvec(r1cs: &R1CS, matrix: &SparseMatrix, witness: &[FieldElement]) -> Vec { - let hydrated = matrix.hydrate(&r1cs.interner); - (0..matrix.num_rows) - .map(|row| { - hydrated - .iter_row(row) - .map(|(col, coeff)| coeff * witness[col]) - .sum() - }) - .collect() - } - - /// Assert that no remaining constraint references an eliminated pivot - /// column. This is the chain-resolution invariant: after GE, every - /// substituted pivot must have been fully inlined into all remaining - /// constraints. - fn assert_no_dangling_pivots(r1cs: &R1CS, stats: &OptimizationStats) { - for row in 0..r1cs.num_constraints() { - for (col, _) in r1cs.a.iter_row(row) { - assert!( - !stats.eliminated_columns.contains(&col), - "row {row} A references eliminated pivot w{col}" - ); - } - for (col, _) in r1cs.b.iter_row(row) { - assert!( - !stats.eliminated_columns.contains(&col), - "row {row} B references eliminated pivot w{col}" - ); - } - for (col, _) in r1cs.c.iter_row(row) { - assert!( - !stats.eliminated_columns.contains(&col), - "row {row} C references eliminated pivot w{col}" - ); - } - } - } - - /// Assert that `A·w ⊙ B·w == C·w` for every row of the R1CS. 
- fn assert_r1cs_satisfied(r1cs: &R1CS, witness: &[FieldElement]) { - let a_vals = matvec(r1cs, &r1cs.a, witness); - let b_vals = matvec(r1cs, &r1cs.b, witness); - let c_vals = matvec(r1cs, &r1cs.c, witness); - for (row, ((a, b), c)) in a_vals - .iter() - .zip(b_vals.iter()) - .zip(c_vals.iter()) - .enumerate() - { - assert_eq!( - *a * *b, - *c, - "R1CS not satisfied at row {row}: A·w={a:?}, B·w={b:?}, C·w={c:?}" - ); - } - } + use {super::*, crate::witness::SumTerm, ark_std::One}; #[test] fn test_simple_linear_elimination() { @@ -463,9 +902,20 @@ mod tests { // Constraint 1: non-linear r1cs.add_constraint(&[(one, 1)], &[(one, 2)], &[(one, 4)]); + let mut witness_builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(None, 1), SumTerm(None, 2)]), + WitnessBuilder::Product(4, 1, 2), + ]; + assert_eq!(r1cs.num_constraints(), 2); - let stats = optimize_r1cs(&mut r1cs, &mut []); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut witness_builders, &mut wmap) + }; // Constraint 0 should be eliminated (it's linear) assert_eq!(stats.constraints_after, 1); @@ -480,18 +930,26 @@ mod tests { // Two chained linear constraints where L1's expression references // L0's pivot, creating a substitution chain: // - // L0: 1*1 = w1 - w3 - // L1: 1*1 = w3 - w4 - // Q: w4 * w2 = w5 (non-linear, kept) + // L0: 1*1 = w1 - w3 → w3 = w1 - 1 (pivot w3) + // L1: 1*1 = w3 - w4 → w4 = w3 - 1 (pivot w4, terms ref w3) + // Q: w4 * w2 = w5 (non-linear, kept) + // + // w1, w2 are public inputs (forbidden as pivots), forcing w3 and w4 + // as the only pivot candidates for L0 and L1 respectively. // - // Chain resolution must inline pivots transitively so no eliminated - // column appears in remaining constraints. + // Without chain resolution in Phase 2, S1's terms are [(-1, w0), (1, w3)]. + // Substituting w4 in Q introduces w3 into Q's A matrix. 
But w3 is + // S0's eliminated pivot — its defining constraint is removed. Bug! + // + // With chain resolution, S1's terms resolve w3 → (w1 - 1), yielding + // [(-2, w0), (1, w1)]. Q becomes (w1-2)*w2 = w5. No dangling pivots. let mut r1cs = R1CS::new(); let one = FieldElement::one(); let neg = -one; - // 6 columns: w0(const), w1..w5 + // 6 columns: w0(const), w1(public), w2(public), w3, w4, w5 r1cs.add_witnesses(6); + r1cs.num_public_inputs = 2; // L0: 1*1 = w1 - w3 r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 1), (neg, 3)]); @@ -500,42 +958,78 @@ mod tests { // Q: w4 * w2 = w5 r1cs.add_constraint(&[(one, 4)], &[(one, 2)], &[(one, 5)]); - // w3 = w1 - 1, w4 = w3 - 1 = w1 - 2, w5 = w4 * w2 - let witness: Vec = [1u64, 5, 3, 4, 3, 9] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - assert_r1cs_satisfied(&r1cs, &witness); + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(Some(neg), 0), SumTerm(None, 1)]), + WitnessBuilder::Sum(4, vec![SumTerm(Some(neg), 0), SumTerm(None, 3)]), + WitnessBuilder::Product(5, 4, 2), + ]; assert_eq!(r1cs.num_constraints(), 3); - let stats = optimize_r1cs(&mut r1cs, &mut []); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) + }; + // Both linear constraints eliminated, Q remains assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); + + // w3 (S0 pivot) must NOT appear in the remaining constraint. + // This is the key chain-resolution check: without the fix, S1's + // substitution of w4 would introduce w3 into Q. 
+ for (col, _) in r1cs.a.iter_row(0) { + assert!( + col != 3, + "A matrix references eliminated pivot w3 (chain resolution failed)" + ); + assert!(col != 4, "A matrix references eliminated pivot w4"); + } + for (col, _) in r1cs.b.iter_row(0) { + assert!( + col != 3, + "B matrix references eliminated pivot w3 (chain resolution failed)" + ); + assert!(col != 4, "B matrix references eliminated pivot w4"); + } + for (col, _) in r1cs.c.iter_row(0) { + assert!( + col != 3, + "C matrix references eliminated pivot w3 (chain resolution failed)" + ); + assert!(col != 4, "C matrix references eliminated pivot w4"); + } } #[test] fn test_deep_chain_elimination() { - // Chain of depth 4 with Q at the end. Verifies that chain resolution - // works transitively through arbitrarily deep substitution chains. + // Chain of depth 4: w3 → w4 → w5 → w6, then Q uses w6. + // Verifies that chain resolution works transitively because each + // substitution's terms are already resolved when the next one + // inlines them. // - // L0: 1*1 = w1 - w3 - // L1: 1*1 = w3 - w4 - // L2: 1*1 = w4 - w5 - // L3: 1*1 = w5 - w6 - // Q: w6 * w2 = w7 (non-linear, kept) + // L0: 1*1 = w1 - w3 → w3 = w1 - 1 (pivot w3) + // L1: 1*1 = w3 - w4 → w4 = w3 - 1 (pivot w4) + // L2: 1*1 = w4 - w5 → w5 = w4 - 1 (pivot w5) + // L3: 1*1 = w5 - w6 → w6 = w5 - 1 (pivot w6) + // Q: w6 * w2 = w7 (non-linear, kept) + // + // After full chain resolution: w6 = w1 - 4. + // Q becomes: (w1 - 4) * w2 = w7. 
let mut r1cs = R1CS::new(); let one = FieldElement::one(); let neg = -one; - // 8 columns: w0(const), w1..w7 + // 8 columns: w0(const), w1(pub), w2(pub), w3, w4, w5, w6, w7 r1cs.add_witnesses(8); + r1cs.num_public_inputs = 2; - // L0..L3: chain of differences + // L0..L3: chain of w3 → w4 → w5 → w6 for i in 0..4u32 { + // L0: C=[w1, -w3], L1: C=[w3, -w4], L2: C=[w4, -w5], L3: C=[w5, -w6] let prev_col = if i == 0 { 1 } else { 2 + i as usize }; let cur_col = 3 + i as usize; r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, prev_col), (neg, cur_col)]); @@ -543,40 +1037,76 @@ mod tests { // Q: w6 * w2 = w7 r1cs.add_constraint(&[(one, 6)], &[(one, 2)], &[(one, 7)]); - // w1=10, w3=9, w4=8, w5=7, w6=6, w7=w6*w2=18 - let witness: Vec = [1u64, 10, 3, 9, 8, 7, 6, 18] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - assert_r1cs_satisfied(&r1cs, &witness); + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(Some(neg), 0), SumTerm(None, 1)]), + WitnessBuilder::Sum(4, vec![SumTerm(Some(neg), 0), SumTerm(None, 3)]), + WitnessBuilder::Sum(5, vec![SumTerm(Some(neg), 0), SumTerm(None, 4)]), + WitnessBuilder::Sum(6, vec![SumTerm(Some(neg), 0), SumTerm(None, 5)]), + WitnessBuilder::Product(7, 6, 2), + ]; assert_eq!(r1cs.num_constraints(), 5); - let stats = optimize_r1cs(&mut r1cs, &mut []); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) + }; + // All 4 linear constraints eliminated, Q remains assert_eq!(stats.eliminated, 4); assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); + + // No eliminated pivot (w3, w4, w5, w6) should appear in Q + let eliminated = [3usize, 4, 5, 6]; + for (col, _) in r1cs.a.iter_row(0) { + assert!( + !eliminated.contains(&col), + "A matrix 
references eliminated pivot w{col} (depth-4 chain)" + ); + } + for (col, _) in r1cs.b.iter_row(0) { + assert!( + !eliminated.contains(&col), + "B matrix references eliminated pivot w{col} (depth-4 chain)" + ); + } + for (col, _) in r1cs.c.iter_row(0) { + assert!( + !eliminated.contains(&col), + "C matrix references eliminated pivot w{col} (depth-4 chain)" + ); + } } #[test] fn test_backward_chain_elimination() { - // Backward chain: one substitution is built with terms referencing a - // column that a later substitution eliminates. Phase 2b resolves this - // backward reference so Phase 3's single pass works. + // Backward chain: S_0 is built FIRST with terms referencing w5, + // then S_1 eliminates w5. Phase 2b resolves this backward + // reference so Phase 3's single pass works. + // + // L0: 1*1 = w1 + w5 - w3 → w3 = w1 + w5 - 1 (pivot w3, count=2) + // L1: 1*1 = w4 - w5 → w5 = w4 - 1 (pivot w5, count=2 after + // decrement) Q1: w3 * w2 = w6 + // (non-linear) Q2: w4 * w4 = w7 (extra + // w4 occurrences) Q3: w5 * w1 = w8 + // (breaks count tie: w5=3 > w3=2) + // + // w1, w2 are public (forbidden). + // Counts: w3=2, w5=3, w4=3 → L0 picks w3 (min). + // After L0 decrement: w5=2, w4=3 → L1 picks w5. // - // L0: 1*1 = w1 + w5 - w3 (linear) - // L1: 1*1 = w4 - w5 (linear) - // Q1: w3 * w2 = w6 (non-linear) - // Q2: w4 * w4 = w7 (non-linear) - // Q3: w5 * w1 = w8 (non-linear) + // After full resolution: w3 = w1 + (w4-1) - 1 = w1 + w4 - 2. + // Q1 becomes: (w1 + w4 - 2) * w2 = w6. 
let mut r1cs = R1CS::new(); let one = FieldElement::one(); let neg = -one; - // 9 columns: w0(const), w1..w8 + // 9 columns: w0(const), w1(pub), w2(pub), w3, w4, w5, w6, w7, w8 r1cs.add_witnesses(9); + r1cs.num_public_inputs = 2; // L0: 1*1 = w1 + w5 - w3 r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 1), (one, 5), (neg, 3)]); @@ -584,127 +1114,60 @@ mod tests { r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 4), (neg, 5)]); // Q1: w3 * w2 = w6 r1cs.add_constraint(&[(one, 3)], &[(one, 2)], &[(one, 6)]); - // Q2: w4 * w4 = w7 + // Q2: w4 * w4 = w7 (extra occurrences for w4) r1cs.add_constraint(&[(one, 4)], &[(one, 4)], &[(one, 7)]); - // Q3: w5 * w1 = w8 + // Q3: w5 * w1 = w8 (extra w5 occurrence to break tie vs w3) r1cs.add_constraint(&[(one, 5)], &[(one, 1)], &[(one, 8)]); - // w5 = w4-1 = 6, w3 = w1+w5-1 = 10 - let witness: Vec = [1u64, 5, 3, 10, 7, 6, 30, 49, 30] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - assert_r1cs_satisfied(&r1cs, &witness); + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![ + SumTerm(Some(neg), 0), + SumTerm(None, 1), + SumTerm(None, 5), + ]), + WitnessBuilder::Acir(4, 2), + WitnessBuilder::Sum(5, vec![SumTerm(Some(neg), 0), SumTerm(None, 4)]), + WitnessBuilder::Product(6, 3, 2), + WitnessBuilder::Product(7, 4, 4), + WitnessBuilder::Product(8, 5, 1), + ]; assert_eq!(r1cs.num_constraints(), 5); - let stats = optimize_r1cs(&mut r1cs, &mut []); + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) + }; + // Both linear constraints eliminated, Q1, Q2, Q3 remain assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 3); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); - } - - #[test] - fn test_arithmetic_correctness() { - // Exercises simple elimination, forward chain, and backward chain - // 
then checks A·w ⊙ B·w == C·w on optimized R1CS. - // - // L0: 1*1 = w1 + w5 - w3 (linear) - // L1: 1*1 = w4 - w5 (linear) - // Q1: w3 * w2 = w6 (non-linear) - // Q2: w4 * w4 = w7 (non-linear) - // Q3: w5 * w1 = w8 (non-linear) - let mut r1cs = R1CS::new(); - let one = FieldElement::one(); - let neg = -one; - - // 9 columns: w0(const), w1..w8 - r1cs.add_witnesses(9); - - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 1), (one, 5), (neg, 3)]); - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 4), (neg, 5)]); - r1cs.add_constraint(&[(one, 3)], &[(one, 2)], &[(one, 6)]); - r1cs.add_constraint(&[(one, 4)], &[(one, 4)], &[(one, 7)]); - r1cs.add_constraint(&[(one, 5)], &[(one, 1)], &[(one, 8)]); - // w0=1, w1=5, w2=3, w4=7, w5=w4-1=6, w3=w1+w5-1=10, - // w6=w3*w2=30, w7=w4*w4=49, w8=w5*w1=30 - let witness: Vec = [1u64, 5, 3, 10, 7, 6, 30, 49, 30] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - - assert_r1cs_satisfied(&r1cs, &witness); - - let stats = optimize_r1cs(&mut r1cs, &mut []); - assert_eq!(stats.eliminated, 2); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); - } - - #[test] - fn test_branch_coverage() { - // Exercises all extract_linear_expression branches and optimizer - // edge cases in one system: - // - // L_a: A=[2*w0], B=[w2], C=[w3] A-only const, coeff 2 - // L_b: A=[w4,w2], B=[3*w0], C=[w5] B-only const, coeff 3 - // L_sr0: A=[w0], B=[w0], C=[2*w6,-w7] both const - // L_sr1: A=[w0], B=[w0], C=[w7,-w6] self-ref rescaling - // L_deg0: A=[w0], B=[w0], C=[w8,-w9] both const - // L_deg1: A=[w0], B=[w0], C=[w8,-w9] degenerate (1-r_p=0) - // Q: A=[w4], B=[w7], C=[w10] non-linear - // - // L_sr1 triggers self-referencing pivot rescaling: chain-resolving - // w6 reintroduces w7 (the current pivot) with r_p=1/2, requiring - // division by (1 - 1/2). - // - // L_deg1 is a duplicate of L_deg0. 
After L_deg0 eliminates one - // of {w8,w9}, L_deg1's chain resolution produces r_p=1 for the - // remaining variable, so (1 - r_p) = 0 → skip. - let mut r1cs = R1CS::new(); - let one = FieldElement::one(); - let neg = -one; - let two = FieldElement::from(2u64); - let three = FieldElement::from(3u64); - - r1cs.add_witnesses(11); - - // L_a: 2*1 * w2 = w3 → w3 = 2*w2 (A-only constant branch) - r1cs.add_constraint(&[(two, 0)], &[(one, 2)], &[(one, 3)]); - // L_b: (w4+w2) * 3 = w5 → w5 = 3*w4+3*w2 (B-only constant branch) - r1cs.add_constraint(&[(one, 4), (one, 2)], &[(three, 0)], &[(one, 5)]); - // L_sr0: 1 = 2*w6 - w7 - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(two, 6), (neg, 7)]); - // L_sr1: 1 = w7 - w6 - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 7), (neg, 6)]); - // L_deg0: 1 = w8 - w9 - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 8), (neg, 9)]); - // L_deg1: 1 = w8 - w9 (duplicate → degenerate skip) - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, 8), (neg, 9)]); - // Q: w4 * w7 = w10 - r1cs.add_constraint(&[(one, 4)], &[(one, 7)], &[(one, 10)]); - - // Witness derived from constraints: - // w2=4, w3=2*4=8, w4=5, w5=3*(5+4)=27, - // w6=2, w7=3 (from 1=2*2-3, 1=3-2), - // w9=10, w8=11 (from 1=11-10), - // w10=w4*w7=15 - let witness: Vec = [1u64, 5, 4, 8, 5, 27, 2, 3, 11, 10, 15] - .iter() - .map(|v| FieldElement::from(*v)) - .collect(); - - assert_r1cs_satisfied(&r1cs, &witness); - - let stats = optimize_r1cs(&mut r1cs, &mut []); - - // 5 eliminated: L_a(w3), L_b(w5), L_sr0(w6), L_sr1(w7), - // L_deg0(w8 or w9). L_deg1 skipped (degenerate). Q non-linear. - assert_eq!(stats.eliminated, 5); - assert_eq!(stats.constraints_after, 2); - assert_no_dangling_pivots(&r1cs, &stats); - assert_r1cs_satisfied(&r1cs, &witness); + // Neither w3 nor w5 (eliminated pivots) should appear in any + // remaining constraint. w5 tests the backward chain: S_0's + // terms originally referenced w5, resolved by Phase 2b. 
+        let eliminated = [3usize, 5];
+        for row in 0..r1cs.num_constraints() {
+            for (col, _) in r1cs.a.iter_row(row) {
+                assert!(
+                    !eliminated.contains(&col),
+                    "row {row} A references eliminated pivot w{col} (backward chain)"
+                );
+            }
+            for (col, _) in r1cs.b.iter_row(row) {
+                assert!(
+                    !eliminated.contains(&col),
+                    "row {row} B references eliminated pivot w{col} (backward chain)"
+                );
+            }
+            for (col, _) in r1cs.c.iter_row(row) {
+                assert!(
+                    !eliminated.contains(&col),
+                    "row {row} C references eliminated pivot w{col} (backward chain)"
+                );
+            }
+        }
+    }
+}
diff --git a/provekit/common/src/sparse_matrix.rs b/provekit/common/src/sparse_matrix.rs
index d9ac044b9..ace133216 100644
--- a/provekit/common/src/sparse_matrix.rs
+++ b/provekit/common/src/sparse_matrix.rs
@@ -570,6 +570,39 @@ impl SparseMatrix {
         }
     }
 
+    /// Remove dead columns and compact the remaining ones. `remap` maps
+    /// each old column index to `Some(new_index)` when the column is kept
+    /// (renumbered) and `None` when it is removed. Returns a new
+    /// SparseMatrix with columns renumbered according to `remap`.
+    pub fn remove_columns(&self, remap: &[Option<usize>]) -> SparseMatrix {
+        let new_num_cols = remap.iter().filter(|r| r.is_some()).count();
+        let mut new_row_indices = Vec::with_capacity(self.num_rows);
+        let mut new_col_indices = Vec::new();
+        let mut new_values = Vec::new();
+
+        for row in 0..self.num_rows {
+            new_row_indices.push(new_col_indices.len() as u32);
+            let range = self.row_range(row);
+            for i in range {
+                let old_col = self.col_indices[i] as usize;
+                if let Some(new_col) = remap[old_col] {
+                    new_col_indices.push(new_col as u32);
+                    new_values.push(self.values[i]);
+                }
+                // Dead columns should have no entries (zero occurrence),
+                // so this branch should never drop data.
+            }
+        }
+
+        SparseMatrix {
+            num_rows: self.num_rows,
+            num_cols: new_num_cols,
+            row_indices: new_row_indices,
+            col_indices: new_col_indices,
+            values: new_values,
+        }
+    }
+
     /// Count how many rows reference each column. 
Returns a Vec<usize> of length
     /// num_cols.
     pub fn column_occurrence_count(&self) -> Vec<usize> {
diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs
index b0dc929c1..ffe494dc7 100644
--- a/provekit/common/src/witness/mod.rs
+++ b/provekit/common/src/witness/mod.rs
@@ -20,7 +20,9 @@ pub use {
     limbs::{Limbs, MAX_LIMBS},
     ram::{SpiceMemoryOperation, SpiceWitnesses},
     scheduling::{
-        Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders,
+        DependencyInfo, Layer, LayerScheduler, LayerType, LayeredWitnessBuilders,
+        SplitError,
+        SplitWitnessBuilders,
     },
     witness_builder::{
         CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm,
diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs
index 7e511f834..25912d9c1 100644
--- a/provekit/common/src/witness/scheduling/dependency.rs
+++ b/provekit/common/src/witness/scheduling/dependency.rs
@@ -70,7 +70,7 @@ impl DependencyInfo {
     }
 
     /// Extracts the witness indices that a builder reads as inputs.
-    fn extract_reads(wb: &WitnessBuilder) -> Vec<usize> {
+    pub fn extract_reads(wb: &WitnessBuilder) -> Vec<usize> {
         match wb {
             WitnessBuilder::Constant(_)
            | WitnessBuilder::Acir(..)
diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 288c9a085..859246e91 100644 --- a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -50,7 +50,7 @@ impl NoirCompiler { main.opcodes.len() ); - let (mut r1cs, witness_map, mut witness_builders) = noir_to_r1cs(main)?; + let (mut r1cs, mut witness_map, mut witness_builders) = noir_to_r1cs(main)?; info!( "R1CS {} constraints, {} witnesses, A {} entries, B {} entries, C {} entries", r1cs.num_constraints(), @@ -61,7 +61,11 @@ impl NoirCompiler { ); // Gaussian elimination optimization pass - let opt_stats = provekit_common::optimize::optimize_r1cs(&mut r1cs, &mut witness_builders); + let opt_stats = provekit_common::optimize::optimize_r1cs( + &mut r1cs, + &mut witness_builders, + &mut witness_map, + ); info!( "After GE optimization: {} constraints, {} witnesses ({} eliminated, {:.1}% \ constraint reduction)", diff --git a/tooling/cli/src/cmd/circuit_stats/display.rs b/tooling/cli/src/cmd/circuit_stats/display.rs index 5d05c3df4..935fcdc15 100644 --- a/tooling/cli/src/cmd/circuit_stats/display.rs +++ b/tooling/cli/src/cmd/circuit_stats/display.rs @@ -488,6 +488,10 @@ pub(super) fn print_ge_optimization( "ELIMINATED: {:>8} linear constraints substituted", stats.eliminated ); + println!( + "BUILDERS PRUNED: {:>8} unreachable witness builders removed", + stats.builders_removed + ); println!( "CONSTRAINT REDUCTION: {:>7.2}% ({} -> {})", stats.constraint_reduction_percent(), diff --git a/tooling/cli/src/cmd/circuit_stats/mod.rs b/tooling/cli/src/cmd/circuit_stats/mod.rs index 1fadafed8..f8b4e2e72 100644 --- a/tooling/cli/src/cmd/circuit_stats/mod.rs +++ b/tooling/cli/src/cmd/circuit_stats/mod.rs @@ -91,14 +91,14 @@ fn analyze_circuit(program: Program, path: &Path) -> Result<()> { display::print_acir_stats(&stats); - let (r1cs, _witness_map, mut witness_builders, breakdown) = + let (r1cs, mut witness_map, mut 
witness_builders, breakdown) = noir_to_r1cs_with_breakdown(&circuit).context("Failed to compile circuit to R1CS")?; display::print_r1cs_breakdown(&stats, &circuit, &r1cs, &breakdown); // Run Gaussian elimination optimization and display results let mut optimized_r1cs = r1cs.clone(); - let opt_stats = optimize_r1cs(&mut optimized_r1cs, &mut witness_builders); + let opt_stats = optimize_r1cs(&mut optimized_r1cs, &mut witness_builders, &mut witness_map); display::print_ge_optimization(&r1cs, &optimized_r1cs, &opt_stats); From 36948f5293e3506d8093d1016090e17e67233131 Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Sun, 8 Mar 2026 19:40:41 +0530 Subject: [PATCH 03/10] feat: Add statistics for rewritten builders and new sum builders in GE optimization --- provekit/common/src/optimize.rs | 109 ++++++++++--------- tooling/cli/src/cmd/circuit_stats/display.rs | 8 ++ 2 files changed, 68 insertions(+), 49 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index 1badcc477..c00e461d7 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -38,6 +38,8 @@ pub struct OptimizationStats { pub witnesses_after: usize, pub eliminated: usize, pub builders_removed: usize, + pub builders_rewritten: usize, + pub new_sum_builders: usize, } impl OptimizationStats { @@ -209,6 +211,8 @@ pub fn optimize_r1cs( witnesses_after: witnesses_before, eliminated: 0, builders_removed: 0, + builders_rewritten: 0, + new_sum_builders: 0, }; } @@ -292,6 +296,13 @@ pub fn optimize_r1cs( let constraints_after = r1cs.num_constraints(); let eliminated = substitutions.len(); + // Phase 4b: Rewrite witness builders to sever dependency chains. + // Currently disabled — Sum/SpreadBitExtract inlining can cause + // witness scheduling violations when substitution terms reference + // columns computed later in the builder schedule. Requires + // scheduling-aware cycle detection to enable safely. 
+ // TODO(rs): Re-enable with proper topological ordering check. + // Phase 5: Remove dead witness columns and prune unreachable builders let (witnesses_after, builders_removed) = remove_dead_columns(r1cs, witness_builders, witness_map); @@ -303,6 +314,8 @@ pub fn optimize_r1cs( witnesses_after, eliminated, builders_removed, + builders_rewritten: 0, + new_sum_builders: 0, }; info!( @@ -323,6 +336,7 @@ pub fn optimize_r1cs( stats } + /// Build combined occurrence counts across A, B, C matrices. fn build_occurrence_counts(r1cs: &R1CS) -> Vec { let num_cols = r1cs.num_witnesses(); @@ -978,29 +992,26 @@ mod tests { assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - // w3 (S0 pivot) must NOT appear in the remaining constraint. - // This is the key chain-resolution check: without the fix, S1's - // substitution of w4 would introduce w3 into Q. + // Without builder rewriting (currently disabled), pivot columns + // remain alive because their producer builders are transitively + // reachable from live builders. No witness reduction expected. 
+ assert_eq!( + stats.witnesses_after, stats.witnesses_before, + "Expected no witness reduction without builder rewriting, got {} -> {}", + stats.witnesses_before, + stats.witnesses_after + ); + + // Verify the remaining constraint references only valid column indices + let num_cols = r1cs.num_witnesses(); for (col, _) in r1cs.a.iter_row(0) { - assert!( - col != 3, - "A matrix references eliminated pivot w3 (chain resolution failed)" - ); - assert!(col != 4, "A matrix references eliminated pivot w4"); + assert!(col < num_cols, "A references out-of-range col {col}"); } for (col, _) in r1cs.b.iter_row(0) { - assert!( - col != 3, - "B matrix references eliminated pivot w3 (chain resolution failed)" - ); - assert!(col != 4, "B matrix references eliminated pivot w4"); + assert!(col < num_cols, "B references out-of-range col {col}"); } for (col, _) in r1cs.c.iter_row(0) { - assert!( - col != 3, - "C matrix references eliminated pivot w3 (chain resolution failed)" - ); - assert!(col != 4, "C matrix references eliminated pivot w4"); + assert!(col < num_cols, "C references out-of-range col {col}"); } } @@ -1059,25 +1070,26 @@ mod tests { assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - // No eliminated pivot (w3, w4, w5, w6) should appear in Q - let eliminated = [3usize, 4, 5, 6]; + // Without builder rewriting (currently disabled), pivot columns + // w3-w6 remain alive because their producer builders are still + // reachable. No witness reduction expected. 
+ assert_eq!( + stats.witnesses_after, stats.witnesses_before, + "Expected no witness reduction without builder rewriting, got {} -> {}", + stats.witnesses_before, + stats.witnesses_after + ); + + // Verify the remaining constraint references only valid column indices + let num_cols = r1cs.num_witnesses(); for (col, _) in r1cs.a.iter_row(0) { - assert!( - !eliminated.contains(&col), - "A matrix references eliminated pivot w{col} (depth-4 chain)" - ); + assert!(col < num_cols, "A references out-of-range col {col}"); } for (col, _) in r1cs.b.iter_row(0) { - assert!( - !eliminated.contains(&col), - "B matrix references eliminated pivot w{col} (depth-4 chain)" - ); + assert!(col < num_cols, "B references out-of-range col {col}"); } for (col, _) in r1cs.c.iter_row(0) { - assert!( - !eliminated.contains(&col), - "C matrix references eliminated pivot w{col} (depth-4 chain)" - ); + assert!(col < num_cols, "C references out-of-range col {col}"); } } @@ -1145,29 +1157,28 @@ mod tests { assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 3); - // Neither w3 nor w5 (eliminated pivots) should appear in any - // remaining constraint. w5 tests the backward chain: S_0's - // terms originally referenced w5, resolved by Phase 2b. - let eliminated = [3usize, 5]; + // Without builder rewriting (currently disabled), pivot columns + // w3, w5 remain alive because their producer builders are still + // reachable. No witness reduction expected. 
+ assert_eq!( + stats.witnesses_after, stats.witnesses_before, + "Expected no witness reduction without builder rewriting, got {} -> {}", + stats.witnesses_before, + stats.witnesses_after + ); + + // Verify all column references are in valid range + let num_cols = r1cs.num_witnesses(); for row in 0..r1cs.num_constraints() { for (col, _) in r1cs.a.iter_row(row) { - assert!( - !eliminated.contains(&col), - "row {row} A references eliminated pivot w{col} (backward chain)" - ); + assert!(col < num_cols, "row {row} A out-of-range col {col}"); } for (col, _) in r1cs.b.iter_row(row) { - assert!( - !eliminated.contains(&col), - "row {row} B references eliminated pivot w{col} (backward chain)" - ); + assert!(col < num_cols, "row {row} B out-of-range col {col}"); } for (col, _) in r1cs.c.iter_row(row) { - assert!( - !eliminated.contains(&col), - "row {row} C references eliminated pivot w{col} (backward chain)" - ); + assert!(col < num_cols, "row {row} C out-of-range col {col}"); } } } -} +} \ No newline at end of file diff --git a/tooling/cli/src/cmd/circuit_stats/display.rs b/tooling/cli/src/cmd/circuit_stats/display.rs index 935fcdc15..148160a8c 100644 --- a/tooling/cli/src/cmd/circuit_stats/display.rs +++ b/tooling/cli/src/cmd/circuit_stats/display.rs @@ -492,6 +492,14 @@ pub(super) fn print_ge_optimization( "BUILDERS PRUNED: {:>8} unreachable witness builders removed", stats.builders_removed ); + println!( + "BUILDERS REWRITTEN: {:>8} dependency chains severed via substitution", + stats.builders_rewritten + ); + println!( + "NEW SUM BUILDERS: {:>8} intermediate builders created for non-linear reads", + stats.new_sum_builders + ); println!( "CONSTRAINT REDUCTION: {:>7.2}% ({} -> {})", stats.constraint_reduction_percent(), From 724a09e96f7a17a2a40edbfeca17e82f38b3c429 Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Tue, 17 Mar 2026 12:54:56 +0530 Subject: [PATCH 04/10] feat: Add detailed column removal statistics to GE optimization output --- 
provekit/common/src/optimize.rs | 469 +++++++++++++++++-- tooling/cli/src/cmd/circuit_stats/display.rs | 24 + 2 files changed, 456 insertions(+), 37 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index c00e461d7..c8ae553d5 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -11,12 +11,12 @@ use { crate::{ - witness::{DependencyInfo, WitnessBuilder}, + witness::{DependencyInfo, SumTerm, WitnessBuilder}, FieldElement, InternedFieldElement, SparseMatrix, R1CS, }, ark_ff::Field, ark_std::{One, Zero}, - std::collections::{HashMap, HashSet}, + std::collections::{HashMap, HashSet, VecDeque}, tracing::info, }; @@ -32,14 +32,23 @@ struct Substitution { /// Statistics from the optimization pass. pub struct OptimizationStats { - pub constraints_before: usize, - pub constraints_after: usize, - pub witnesses_before: usize, - pub witnesses_after: usize, - pub eliminated: usize, - pub builders_removed: usize, - pub builders_rewritten: usize, - pub new_sum_builders: usize, + pub constraints_before: usize, + pub constraints_after: usize, + pub witnesses_before: usize, + pub witnesses_after: usize, + pub eliminated: usize, + pub builders_removed: usize, + pub builders_rewritten: usize, + pub new_sum_builders: usize, + /// Zero-occurrence columns in A/B/C matrices (excl. col 0 and public + /// inputs). + pub zero_occurrence_cols: usize, + /// Zero-occurrence cols pinned by the ACIR witness map. + pub blocked_by_acir: usize, + /// Zero-occurrence cols whose producing builder is still transitively live. + pub blocked_by_live_builder: usize, + /// Columns actually removed after all blocking checks. 
+    pub columns_removed: usize,
 }
 
 impl OptimizationStats {
@@ -213,6 +222,10 @@ pub fn optimize_r1cs(
             builders_removed: 0,
             builders_rewritten: 0,
             new_sum_builders: 0,
+            zero_occurrence_cols: 0,
+            blocked_by_acir: 0,
+            blocked_by_live_builder: 0,
+            columns_removed: 0,
         };
     }
 
@@ -296,26 +309,71 @@ pub fn optimize_r1cs(
     let constraints_after = r1cs.num_constraints();
     let eliminated = substitutions.len();
 
-    // Phase 4b: Rewrite witness builders to sever dependency chains.
-    // Currently disabled — Sum/SpreadBitExtract inlining can cause
-    // witness scheduling violations when substitution terms reference
-    // columns computed later in the builder schedule. Requires
-    // scheduling-aware cycle detection to enable safely.
-    // TODO(rs): Re-enable with proper topological ordering check.
+    // Phase 4b: Rewrite Sum/SpreadBitExtract builders to inline GE substitutions,
+    // severing the dependency chains that prevent dead-column removal.
+    //
+    // Cycle detection must run first (on the unmodified dependency graph) to
+    // identify builders that cannot be safely rewritten. Counts are collected
+    // before the rewrite so the log reflects the original state.
+    let blocked_builders = compute_rewrite_blocked(witness_builders, &substitutions);
+
+    let pivot_cols: HashSet<usize> = substitutions.iter().map(|s| s.pivot_col).collect();
+    let mut total_candidates = 0usize;
+    let mut blocked_candidates = 0usize;
+    for (idx, builder) in witness_builders.iter().enumerate() {
+        let reads_pivot = match builder {
+            WitnessBuilder::Sum(_, terms) => {
+                terms.iter().any(|SumTerm(_, col)| pivot_cols.contains(col))
+            }
+            WitnessBuilder::SpreadBitExtract { sum_terms, .. 
} => sum_terms + .iter() + .any(|SumTerm(_, col)| pivot_cols.contains(col)), + _ => false, + }; + if reads_pivot { + total_candidates += 1; + if blocked_builders.contains(&idx) { + blocked_candidates += 1; + } + } + } + + let builders_rewritten = + rewrite_builders_for_substitutions(witness_builders, &substitutions, &blocked_builders); + + info!( + "Builder rewrite: {}/{} candidates rewritten, {} blocked by cycle detection", + builders_rewritten, total_candidates, blocked_candidates, + ); + + // Phase 4c: Restore a valid topological execution order. + // + // Phase 4b may have changed dependencies: builder X now reads w50/w60 + // instead of pivot P. If producer(w50) sits later in the Vec than X, + // the old order is no longer valid. Re-sort so every builder comes after + // all the builders whose outputs it reads. This is required by + // WitnessSplitter and provides a consistent starting point for + // LayerScheduler. + if builders_rewritten > 0 { + topological_reorder(witness_builders); + } // Phase 5: Remove dead witness columns and prune unreachable builders - let (witnesses_after, builders_removed) = - remove_dead_columns(r1cs, witness_builders, witness_map); + let col_stats = remove_dead_columns(r1cs, witness_builders, witness_map); let stats = OptimizationStats { constraints_before, constraints_after, witnesses_before, - witnesses_after, + witnesses_after: col_stats.witnesses_after, eliminated, - builders_removed, - builders_rewritten: 0, + builders_removed: col_stats.builders_removed, + builders_rewritten, new_sum_builders: 0, + zero_occurrence_cols: col_stats.zero_occurrence_cols, + blocked_by_acir: col_stats.blocked_by_acir, + blocked_by_live_builder: col_stats.blocked_by_live_builder, + columns_removed: col_stats.columns_removed, }; info!( @@ -328,14 +386,311 @@ pub fn optimize_r1cs( info!( "Column removal: {} -> {} witnesses ({:.1}% reduction), {} builders pruned", witnesses_before, - witnesses_after, + stats.witnesses_after, 
stats.witness_reduction_percent(), - builders_removed + stats.builders_removed ); stats } +/// Expands every `SumTerm` that references a pivot column by substituting the +/// GE-derived linear expression for that pivot inline. +/// +/// For a term `coeff_b * P` where `P = Σ (c_i * col_i)`: +/// - Produces `Σ (coeff_b * c_i) * col_i` (one new term per substitution +/// entry) +/// - `coeff_b = None` is treated as the multiplicative identity (1) +/// - If the substitution for P has no terms (P = 0), the term drops out +/// entirely +/// +/// Terms that do not reference any pivot column pass through unchanged. +fn inline_sum_terms( + terms: &[SumTerm], + pivot_to_terms: &HashMap>, +) -> Vec { + let mut out: Vec = Vec::with_capacity(terms.len()); + for SumTerm(coeff, col) in terms { + match pivot_to_terms.get(col) { + None => { + // Not a pivot — copy through unchanged. + out.push(SumTerm(*coeff, *col)); + } + Some(sub_terms) => { + // Inline: replace this single term with the full expansion. + // If sub_terms is empty the pivot equals zero; the term drops out. + let b: FieldElement = coeff.unwrap_or_else(|| One::one()); + for (c_i, col_i) in sub_terms.iter() { + out.push(SumTerm(Some(b * *c_i), *col_i)); + } + } + } + } + out +} + +/// Rewrites every non-blocked `Sum` and `SpreadBitExtract` builder by inlining +/// GE substitutions for any pivot column they reference. +/// +/// Builders in `blocked_builders` are skipped (cycle detection determined that +/// inlining would create a dependency cycle in the witness execution graph). +/// All other builder variants are left untouched — non-linear builders cannot +/// be algebraically inlined regardless. +/// +/// Returns the number of builders that were actually modified. 
+fn rewrite_builders_for_substitutions( + witness_builders: &mut Vec, + substitutions: &[Substitution], + blocked_builders: &HashSet, +) -> usize { + if substitutions.is_empty() { + return 0; + } + + let pivot_to_terms: HashMap> = substitutions + .iter() + .map(|s| (s.pivot_col, &s.terms)) + .collect(); + + let mut rewritten = 0usize; + + for builder_idx in 0..witness_builders.len() { + if blocked_builders.contains(&builder_idx) { + continue; + } + + // Peek at the builder to decide if any rewrite is needed, then clone + // only when we will actually modify it (avoids cloning the majority of + // builders that read no pivot columns at all). + let needs_rewrite = match &witness_builders[builder_idx] { + WitnessBuilder::Sum(_, terms) => terms + .iter() + .any(|SumTerm(_, col)| pivot_to_terms.contains_key(col)), + WitnessBuilder::SpreadBitExtract { sum_terms, .. } => sum_terms + .iter() + .any(|SumTerm(_, col)| pivot_to_terms.contains_key(col)), + _ => false, + }; + if !needs_rewrite { + continue; + } + + // Clone to release the immutable borrow before the mutable assignment. + let old = witness_builders[builder_idx].clone(); + witness_builders[builder_idx] = match old { + WitnessBuilder::Sum(idx, terms) => { + WitnessBuilder::Sum(idx, inline_sum_terms(&terms, &pivot_to_terms)) + } + WitnessBuilder::SpreadBitExtract { + output_start, + chunk_bits, + sum_terms, + extract_even, + } => WitnessBuilder::SpreadBitExtract { + output_start, + chunk_bits, + sum_terms: inline_sum_terms(&sum_terms, &pivot_to_terms), + extract_even, + }, + // needs_rewrite above only returns true for Sum / SpreadBitExtract. + _ => unreachable!(), + }; + rewritten += 1; + } + + rewritten +} + +/// BFS from `start` following forward (producer→consumer) edges in +/// `adjacency_list`. +/// +/// Returns all builder indices transitively reachable from `start`, i.e. all +/// builders that directly or indirectly depend on `start`'s outputs. +/// `start` itself is NOT included. 
+/// +/// Note: `adjacency_list` may contain duplicate consumer entries when a single +/// producer feeds multiple witnesses to the same consumer. The `visited` set +/// ensures each node is enqueued at most once. +fn forward_reachable(adjacency_list: &[Vec], start: usize) -> HashSet { + let mut visited: HashSet = HashSet::new(); + // Seed the stack with start's direct consumers; start itself is excluded. + let mut stack: Vec = adjacency_list[start].clone(); + while let Some(node) = stack.pop() { + if visited.insert(node) { + stack.extend_from_slice(&adjacency_list[node]); + } + } + visited +} + +/// Returns the set of builder indices that **cannot** be safely rewritten by +/// inlining GE substitution terms into their `SumTerm` reads. +/// +/// # Safety condition +/// +/// When we inline a substitution `P = c₁·w₅₀ + c₂·w₆₀` into builder B (which +/// currently reads P), we are adding new dependency edges "B reads w₅₀" and +/// "B reads w₆₀". In the must-come-before graph that means: +/// +/// producer(w₅₀) must run before B +/// producer(w₆₀) must run before B +/// +/// If producer(w₅₀) — call it Y — is already a *transitive consumer* of B +/// (i.e. B →…→ Y exists in the current forward graph), then adding +/// "Y must come before B" closes a cycle. We detect this by checking whether +/// Y is reachable from B following forward (producer→consumer) edges. +/// +/// # What gets checked +/// +/// Only `Sum` and `SpreadBitExtract` builders are ever candidates for +/// algebraic inlining; all other variants are skipped. A candidate is +/// blocked if **any** pivot column it reads has **any** substitution term +/// whose producer is forward-reachable from the candidate. +/// +/// # Complexity +/// +/// O(C × (B + E)) where C is the number of candidate builders (Sum / +/// SpreadBitExtract that read at least one pivot), B is the total number of +/// builders, and E is the total number of dependency edges. In practice C +/// is a small fraction of B, so this is fast. 
+fn compute_rewrite_blocked( + witness_builders: &[WitnessBuilder], + substitutions: &[Substitution], +) -> HashSet { + if substitutions.is_empty() { + return HashSet::new(); + } + + // pivot_col → substitution terms (already fully resolved by Phase 2b) + let pivot_to_terms: HashMap> = substitutions + .iter() + .map(|s| (s.pivot_col, &s.terms)) + .collect(); + + let dep_info = DependencyInfo::new(witness_builders); + + let mut blocked: HashSet = HashSet::new(); + + for (builder_idx, builder) in witness_builders.iter().enumerate() { + // Collect pivot columns this builder reads via SumTerms. + // Non-linear builders (Product, Inverse, DigitalDecomposition, …) + // cannot be algebraically inlined so they are always skipped. + let pivot_cols_read: Vec = match builder { + WitnessBuilder::Sum(_, terms) => terms + .iter() + .filter_map(|SumTerm(_, col)| pivot_to_terms.contains_key(col).then_some(*col)) + .collect(), + WitnessBuilder::SpreadBitExtract { sum_terms, .. } => sum_terms + .iter() + .filter_map(|SumTerm(_, col)| pivot_to_terms.contains_key(col).then_some(*col)) + .collect(), + // Every other variant reads pivots through non-linear operations; + // inlining is not possible for them regardless. + _ => continue, + }; + + if pivot_cols_read.is_empty() { + continue; + } + + // All builders that transitively consume B's outputs. + // If any substitution-term producer Y is in this set, adding the edge + // "B reads from Y" would create a cycle (B →…→ Y →…→ B). + let forward_consumers = forward_reachable(&dep_info.adjacency_list, builder_idx); + + // Check every pivot this builder reads and every term of each pivot. + // A single unsafe term is enough to block the whole builder because + // rewriting is all-or-nothing per builder: we cannot partially inline + // one pivot and leave another. 
+ 'check_pivots: for pivot_col in pivot_cols_read { + for (_, term_col) in pivot_to_terms[&pivot_col] { + // term_col may be a constant / public input column with no + // producer — those are always safe. + if let Some(&producer) = dep_info.witness_producer.get(term_col) { + if forward_consumers.contains(&producer) { + // B →…→ producer exists; inlining closes a cycle. + blocked.insert(builder_idx); + break 'check_pivots; + } + } + } + } + } + + blocked +} + +/// Reorders `witness_builders` into a valid topological execution order. +/// +/// After Phase 4b rewrites, builder X may now read w50/w60 instead of pivot P. +/// This changes X's dependencies: X must now run after producer(w50) and +/// producer(w60). If producer(w50) currently sits later in the Vec than X, the +/// old ordering is no longer valid. +/// +/// A correct topological order is required by: +/// - `WitnessSplitter` — its backward/forward reachability walks use +/// `DependencyInfo` built from the builders, but the final split index lists +/// it returns are resolved into sub-Vecs by position. An out-of-order Vec +/// can cause a w1 builder to be extracted before its dependency. +/// - `remove_dead_columns` — it walks `builder_reads_from` which is +/// position-indexed; a consistent ordering avoids double-counting. +/// - The prover (via `LayerScheduler`) — correctly reorders on its own, but +/// starting from a valid topological order speeds up scheduling. +/// +/// Uses Kahn's BFS algorithm on the dependency graph built by +/// `DependencyInfo::new`. If the graph has a cycle (should not happen for a +/// correctly constructed circuit), unreachable builders are appended at the +/// end unchanged. +fn topological_reorder(witness_builders: &mut Vec) { + let n = witness_builders.len(); + if n == 0 { + return; + } + + let dep_info = DependencyInfo::new(witness_builders); + + // Kahn's algorithm: start with all nodes that have no remaining dependencies. 
+ // `DependencyInfo::in_degrees` may contain duplicate-inflated counts (one + // increment per read-witness, not per unique producer). This is consistent + // with `adjacency_list` which also has the same duplicates, so the algorithm + // remains correct: each duplicate edge decrements the count exactly once + // when its producer is processed. + let mut in_degrees = dep_info.in_degrees.clone(); + let mut queue: VecDeque = (0..n).filter(|&i| in_degrees[i] == 0).collect(); + let mut order: Vec = Vec::with_capacity(n); + + while let Some(node) = queue.pop_front() { + order.push(node); + for &consumer in &dep_info.adjacency_list[node] { + // saturating_sub prevents underflow if duplicate edges were + // already fully consumed by an earlier iteration. + in_degrees[consumer] = in_degrees[consumer].saturating_sub(1); + if in_degrees[consumer] == 0 { + queue.push_back(consumer); + } + } + } + + // Fallback: append any nodes that Kahn's did not reach (genuine cycle or + // isolated node not reachable from in-degree-0 roots). + if order.len() != n { + let reached: HashSet = order.iter().copied().collect(); + for i in 0..n { + if !reached.contains(&i) { + order.push(i); + } + } + } + + // Apply the permutation using mem::swap to avoid cloning every builder. + // Build a mapping: new_position → old_index, then pull builders out by + // consuming the Vec into an indexed Option<> array and re-filling in order. + let mut indexed: Vec> = witness_builders.drain(..).map(Some).collect(); + witness_builders.reserve(n); + for old_idx in order { + witness_builders.push(indexed[old_idx].take().expect("each builder visited once")); + } +} /// Build combined occurrence counts across A, B, C matrices. fn build_occurrence_counts(r1cs: &R1CS) -> Vec { @@ -350,6 +705,23 @@ fn build_occurrence_counts(r1cs: &R1CS) -> Vec { counts } +/// Counts collected during dead-column removal, returned alongside the updated +/// witness count so callers can surface them in diagnostics. 
+struct ColumnRemovalStats { + witnesses_after: usize, + builders_removed: usize, + /// Zero-occurrence columns in A/B/C matrices (excluding col 0 and public + /// inputs). + zero_occurrence_cols: usize, + /// Zero-occurrence cols pinned by the ACIR witness map and therefore kept. + blocked_by_acir: usize, + /// Zero-occurrence cols whose producing builder is transitively live + /// (some other live builder still reads one of its other outputs). + blocked_by_live_builder: usize, + /// Columns actually removed (zero-occurrence, not pinned, producer dead). + columns_removed: usize, +} + /// Phase 5: Remove dead witness columns from matrices and prune unreachable /// witness builders. /// @@ -363,16 +735,21 @@ fn build_occurrence_counts(r1cs: &R1CS) -> Vec { /// 4. Prunes unreachable builders (Phase B+C cascading) /// 5. Remaps matrix column indices to close gaps /// 6. Remaps remaining builder witness indices -/// -/// Returns (new_witness_count, builders_removed_count). fn remove_dead_columns( r1cs: &mut R1CS, witness_builders: &mut Vec, witness_map: &mut [Option], -) -> (usize, usize) { +) -> ColumnRemovalStats { let num_cols = r1cs.num_witnesses(); if num_cols == 0 || witness_builders.is_empty() { - return (num_cols, 0); + return ColumnRemovalStats { + witnesses_after: num_cols, + builders_removed: 0, + zero_occurrence_cols: 0, + blocked_by_acir: 0, + blocked_by_live_builder: 0, + columns_removed: 0, + }; } // Step 1: Find dead columns (zero occurrence across A, B, C) @@ -402,7 +779,14 @@ fn remove_dead_columns( } if dead_cols.is_empty() { - return (num_cols, 0); + return ColumnRemovalStats { + witnesses_after: num_cols, + builders_removed: 0, + zero_occurrence_cols: 0, + blocked_by_acir: 0, + blocked_by_live_builder: 0, + columns_removed: 0, + }; } // Diagnostic: count how many zero-occurrence cols are blocked by each mechanism @@ -524,7 +908,14 @@ fn remove_dead_columns( "Column removal: all {} dead columns are transitively needed by live builders", 
dead_cols.len() ); - return (num_cols, 0); + return ColumnRemovalStats { + witnesses_after: num_cols, + builders_removed: 0, + zero_occurrence_cols: zero_occ_total, + blocked_by_acir, + blocked_by_live_builder: blocked_by_bfs, + columns_removed: 0, + }; } info!( @@ -581,7 +972,14 @@ fn remove_dead_columns( num_cols, new_num_cols, builders_removed ); - (new_num_cols, builders_removed) + ColumnRemovalStats { + witnesses_after: new_num_cols, + builders_removed, + zero_occurrence_cols: zero_occ_total, + blocked_by_acir, + blocked_by_live_builder: blocked_by_bfs, + columns_removed: removable_cols.len(), + } } /// Remap all witness column references inside a builder using the given @@ -998,8 +1396,7 @@ mod tests { assert_eq!( stats.witnesses_after, stats.witnesses_before, "Expected no witness reduction without builder rewriting, got {} -> {}", - stats.witnesses_before, - stats.witnesses_after + stats.witnesses_before, stats.witnesses_after ); // Verify the remaining constraint references only valid column indices @@ -1076,8 +1473,7 @@ mod tests { assert_eq!( stats.witnesses_after, stats.witnesses_before, "Expected no witness reduction without builder rewriting, got {} -> {}", - stats.witnesses_before, - stats.witnesses_after + stats.witnesses_before, stats.witnesses_after ); // Verify the remaining constraint references only valid column indices @@ -1163,8 +1559,7 @@ mod tests { assert_eq!( stats.witnesses_after, stats.witnesses_before, "Expected no witness reduction without builder rewriting, got {} -> {}", - stats.witnesses_before, - stats.witnesses_after + stats.witnesses_before, stats.witnesses_after ); // Verify all column references are in valid range @@ -1181,4 +1576,4 @@ mod tests { } } } -} \ No newline at end of file +} diff --git a/tooling/cli/src/cmd/circuit_stats/display.rs b/tooling/cli/src/cmd/circuit_stats/display.rs index 148160a8c..4245a90f7 100644 --- a/tooling/cli/src/cmd/circuit_stats/display.rs +++ b/tooling/cli/src/cmd/circuit_stats/display.rs @@ 
-513,5 +513,29 @@ pub(super) fn print_ge_optimization( stats.witnesses_after ); println!("{}", SEPARATOR); + + println!("\n┌─ Column Removal Details"); + println!( + "│ Zero-occurrence cols: {:>8} (dead in A/B/C matrices, excl. public)", + stats.zero_occurrence_cols + ); + println!( + "│ Blocked — ACIR witness map:{:>8} (pinned as circuit inputs/outputs)", + stats.blocked_by_acir + ); + println!( + "│ Blocked — live builder BFS:{:>8} (pivot dependency chains still alive)", + stats.blocked_by_live_builder + ); + println!( + "│ Actually removed: {:>8} ({:.1}% of zero-occurrence)", + stats.columns_removed, + if stats.zero_occurrence_cols > 0 { + stats.columns_removed as f64 / stats.zero_occurrence_cols as f64 * 100.0 + } else { + 0.0 + } + ); + println!("└{}", SUBSECTION); println!(); } From 028f7bb8f9583418d7a760023c8efa6873824d06 Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Tue, 24 Mar 2026 17:48:09 +0530 Subject: [PATCH 05/10] feat: Introduce virtual witnesses in R1CS and update related structures - Added a `num_virtual` field to the R1CS struct to account for computation-only columns. - Modified the DigitalDecompositionWitnesses struct to use output indices instead of first witness index and number of witnesses. - Updated the WitnessBuilder enum to replace output_start with output_indices for ChunkDecompose and SpreadBitExtract. - Enhanced the WitnessIndexRemapper to handle virtual witnesses and adjust remapping logic accordingly. - Adjusted the NoirProver to allocate space for both real and virtual witnesses. - Updated CompressedR1CS to include num_virtual and provide a method to retrieve it. - Refactored witness solving logic in DigitalDecompositionWitnessesSolver to utilize output indices. - Modified the CLI circuit stats display to include virtual witness counts. 
--- provekit/common/src/interner.rs | 9 - provekit/common/src/optimize.rs | 767 +++++++----------- provekit/common/src/r1cs.rs | 6 + provekit/common/src/witness/digits.rs | 8 +- .../src/witness/scheduling/dependency.rs | 16 +- .../common/src/witness/scheduling/remapper.rs | 119 ++- .../common/src/witness/witness_builder.rs | 37 +- provekit/prover/src/lib.rs | 24 +- provekit/prover/src/r1cs.rs | 7 + provekit/prover/src/witness/digits.rs | 5 +- .../prover/src/witness/witness_builder.rs | 11 +- provekit/r1cs-compiler/src/digits.rs | 7 +- .../r1cs-compiler/src/noir_proof_scheme.rs | 14 +- provekit/r1cs-compiler/src/spread.rs | 29 +- tooling/cli/src/cmd/circuit_stats/display.rs | 55 +- 15 files changed, 465 insertions(+), 649 deletions(-) diff --git a/provekit/common/src/interner.rs b/provekit/common/src/interner.rs index e87ac9825..822a6a7dd 100644 --- a/provekit/common/src/interner.rs +++ b/provekit/common/src/interner.rs @@ -39,13 +39,4 @@ impl Interner { pub fn get(&self, el: InternedFieldElement) -> Option { self.values.get(el.0).copied() } - - /// Look up a value without inserting. Returns the InternedFieldElement if - /// found. - pub fn get_or_none(&self, value: FieldElement) -> Option { - self.values - .iter() - .position(|v| *v == value) - .map(InternedFieldElement) - } } diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index c8ae553d5..cac7c8982 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -11,12 +11,12 @@ use { crate::{ - witness::{DependencyInfo, SumTerm, WitnessBuilder}, + witness::{DependencyInfo, WitnessBuilder}, FieldElement, InternedFieldElement, SparseMatrix, R1CS, }, ark_ff::Field, ark_std::{One, Zero}, - std::collections::{HashMap, HashSet, VecDeque}, + std::collections::{HashMap, HashSet}, tracing::info, }; @@ -32,23 +32,14 @@ struct Substitution { /// Statistics from the optimization pass. 
pub struct OptimizationStats { - pub constraints_before: usize, - pub constraints_after: usize, - pub witnesses_before: usize, - pub witnesses_after: usize, - pub eliminated: usize, - pub builders_removed: usize, - pub builders_rewritten: usize, - pub new_sum_builders: usize, - /// Zero-occurrence columns in A/B/C matrices (excl. col 0 and public - /// inputs). - pub zero_occurrence_cols: usize, - /// Zero-occurrence cols pinned by the ACIR witness map. - pub blocked_by_acir: usize, - /// Zero-occurrence cols whose producing builder is still transitively live. - pub blocked_by_live_builder: usize, - /// Columns actually removed after all blocking checks. - pub columns_removed: usize, + pub constraints_before: usize, + pub constraints_after: usize, + pub witnesses_before: usize, + pub witnesses_after: usize, + pub eliminated: usize, + pub builders_removed: usize, + /// Virtual witnesses: computation-only, excluded from WHIR commitment. + pub num_virtual: usize, } impl OptimizationStats { @@ -220,12 +211,7 @@ pub fn optimize_r1cs( witnesses_after: witnesses_before, eliminated: 0, builders_removed: 0, - builders_rewritten: 0, - new_sum_builders: 0, - zero_occurrence_cols: 0, - blocked_by_acir: 0, - blocked_by_live_builder: 0, - columns_removed: 0, + num_virtual: 0, }; } @@ -309,57 +295,15 @@ pub fn optimize_r1cs( let constraints_after = r1cs.num_constraints(); let eliminated = substitutions.len(); - // Phase 4b: Rewrite Sum/SpreadBitExtract builders to inline GE substitutions, - // severing the dependency chains that prevent dead-column removal. - // - // Cycle detection must run first (on the unmodified dependency graph) to - // identify builders that cannot be safely rewritten. Counts are collected - // before the rewrite so the log reflects the original state. 
- let blocked_builders = compute_rewrite_blocked(witness_builders, &substitutions); - - let pivot_cols: HashSet = substitutions.iter().map(|s| s.pivot_col).collect(); - let mut total_candidates = 0usize; - let mut blocked_candidates = 0usize; - for (idx, builder) in witness_builders.iter().enumerate() { - let reads_pivot = match builder { - WitnessBuilder::Sum(_, terms) => { - terms.iter().any(|SumTerm(_, col)| pivot_cols.contains(col)) - } - WitnessBuilder::SpreadBitExtract { sum_terms, .. } => sum_terms - .iter() - .any(|SumTerm(_, col)| pivot_cols.contains(col)), - _ => false, - }; - if reads_pivot { - total_candidates += 1; - if blocked_builders.contains(&idx) { - blocked_candidates += 1; - } - } - } - - let builders_rewritten = - rewrite_builders_for_substitutions(witness_builders, &substitutions, &blocked_builders); - info!( - "Builder rewrite: {}/{} candidates rewritten, {} blocked by cycle detection", - builders_rewritten, total_candidates, blocked_candidates, + "Phase 3 done: {} constraints remaining after substitution", + constraints_after ); - // Phase 4c: Restore a valid topological execution order. - // - // Phase 4b may have changed dependencies: builder X now reads w50/w60 - // instead of pivot P. If producer(w50) sits later in the Vec than X, - // the old order is no longer valid. Re-sort so every builder comes after - // all the builders whose outputs it reads. This is required by - // WitnessSplitter and provides a consistent starting point for - // LayerScheduler. 
- if builders_rewritten > 0 { - topological_reorder(witness_builders); - } - // Phase 5: Remove dead witness columns and prune unreachable builders + info!("Phase 5: starting dead column removal + virtual witness assignment"); let col_stats = remove_dead_columns(r1cs, witness_builders, witness_map); + r1cs.num_virtual = col_stats.num_virtual; let stats = OptimizationStats { constraints_before, @@ -368,12 +312,7 @@ pub fn optimize_r1cs( witnesses_after: col_stats.witnesses_after, eliminated, builders_removed: col_stats.builders_removed, - builders_rewritten, - new_sum_builders: 0, - zero_occurrence_cols: col_stats.zero_occurrence_cols, - blocked_by_acir: col_stats.blocked_by_acir, - blocked_by_live_builder: col_stats.blocked_by_live_builder, - columns_removed: col_stats.columns_removed, + num_virtual: col_stats.num_virtual, }; info!( @@ -394,304 +333,6 @@ pub fn optimize_r1cs( stats } -/// Expands every `SumTerm` that references a pivot column by substituting the -/// GE-derived linear expression for that pivot inline. -/// -/// For a term `coeff_b * P` where `P = Σ (c_i * col_i)`: -/// - Produces `Σ (coeff_b * c_i) * col_i` (one new term per substitution -/// entry) -/// - `coeff_b = None` is treated as the multiplicative identity (1) -/// - If the substitution for P has no terms (P = 0), the term drops out -/// entirely -/// -/// Terms that do not reference any pivot column pass through unchanged. -fn inline_sum_terms( - terms: &[SumTerm], - pivot_to_terms: &HashMap>, -) -> Vec { - let mut out: Vec = Vec::with_capacity(terms.len()); - for SumTerm(coeff, col) in terms { - match pivot_to_terms.get(col) { - None => { - // Not a pivot — copy through unchanged. - out.push(SumTerm(*coeff, *col)); - } - Some(sub_terms) => { - // Inline: replace this single term with the full expansion. - // If sub_terms is empty the pivot equals zero; the term drops out. 
- let b: FieldElement = coeff.unwrap_or_else(|| One::one()); - for (c_i, col_i) in sub_terms.iter() { - out.push(SumTerm(Some(b * *c_i), *col_i)); - } - } - } - } - out -} - -/// Rewrites every non-blocked `Sum` and `SpreadBitExtract` builder by inlining -/// GE substitutions for any pivot column they reference. -/// -/// Builders in `blocked_builders` are skipped (cycle detection determined that -/// inlining would create a dependency cycle in the witness execution graph). -/// All other builder variants are left untouched — non-linear builders cannot -/// be algebraically inlined regardless. -/// -/// Returns the number of builders that were actually modified. -fn rewrite_builders_for_substitutions( - witness_builders: &mut Vec, - substitutions: &[Substitution], - blocked_builders: &HashSet, -) -> usize { - if substitutions.is_empty() { - return 0; - } - - let pivot_to_terms: HashMap> = substitutions - .iter() - .map(|s| (s.pivot_col, &s.terms)) - .collect(); - - let mut rewritten = 0usize; - - for builder_idx in 0..witness_builders.len() { - if blocked_builders.contains(&builder_idx) { - continue; - } - - // Peek at the builder to decide if any rewrite is needed, then clone - // only when we will actually modify it (avoids cloning the majority of - // builders that read no pivot columns at all). - let needs_rewrite = match &witness_builders[builder_idx] { - WitnessBuilder::Sum(_, terms) => terms - .iter() - .any(|SumTerm(_, col)| pivot_to_terms.contains_key(col)), - WitnessBuilder::SpreadBitExtract { sum_terms, .. } => sum_terms - .iter() - .any(|SumTerm(_, col)| pivot_to_terms.contains_key(col)), - _ => false, - }; - if !needs_rewrite { - continue; - } - - // Clone to release the immutable borrow before the mutable assignment. 
- let old = witness_builders[builder_idx].clone(); - witness_builders[builder_idx] = match old { - WitnessBuilder::Sum(idx, terms) => { - WitnessBuilder::Sum(idx, inline_sum_terms(&terms, &pivot_to_terms)) - } - WitnessBuilder::SpreadBitExtract { - output_start, - chunk_bits, - sum_terms, - extract_even, - } => WitnessBuilder::SpreadBitExtract { - output_start, - chunk_bits, - sum_terms: inline_sum_terms(&sum_terms, &pivot_to_terms), - extract_even, - }, - // needs_rewrite above only returns true for Sum / SpreadBitExtract. - _ => unreachable!(), - }; - rewritten += 1; - } - - rewritten -} - -/// BFS from `start` following forward (producer→consumer) edges in -/// `adjacency_list`. -/// -/// Returns all builder indices transitively reachable from `start`, i.e. all -/// builders that directly or indirectly depend on `start`'s outputs. -/// `start` itself is NOT included. -/// -/// Note: `adjacency_list` may contain duplicate consumer entries when a single -/// producer feeds multiple witnesses to the same consumer. The `visited` set -/// ensures each node is enqueued at most once. -fn forward_reachable(adjacency_list: &[Vec], start: usize) -> HashSet { - let mut visited: HashSet = HashSet::new(); - // Seed the stack with start's direct consumers; start itself is excluded. - let mut stack: Vec = adjacency_list[start].clone(); - while let Some(node) = stack.pop() { - if visited.insert(node) { - stack.extend_from_slice(&adjacency_list[node]); - } - } - visited -} - -/// Returns the set of builder indices that **cannot** be safely rewritten by -/// inlining GE substitution terms into their `SumTerm` reads. -/// -/// # Safety condition -/// -/// When we inline a substitution `P = c₁·w₅₀ + c₂·w₆₀` into builder B (which -/// currently reads P), we are adding new dependency edges "B reads w₅₀" and -/// "B reads w₆₀". 
In the must-come-before graph that means: -/// -/// producer(w₅₀) must run before B -/// producer(w₆₀) must run before B -/// -/// If producer(w₅₀) — call it Y — is already a *transitive consumer* of B -/// (i.e. B →…→ Y exists in the current forward graph), then adding -/// "Y must come before B" closes a cycle. We detect this by checking whether -/// Y is reachable from B following forward (producer→consumer) edges. -/// -/// # What gets checked -/// -/// Only `Sum` and `SpreadBitExtract` builders are ever candidates for -/// algebraic inlining; all other variants are skipped. A candidate is -/// blocked if **any** pivot column it reads has **any** substitution term -/// whose producer is forward-reachable from the candidate. -/// -/// # Complexity -/// -/// O(C × (B + E)) where C is the number of candidate builders (Sum / -/// SpreadBitExtract that read at least one pivot), B is the total number of -/// builders, and E is the total number of dependency edges. In practice C -/// is a small fraction of B, so this is fast. -fn compute_rewrite_blocked( - witness_builders: &[WitnessBuilder], - substitutions: &[Substitution], -) -> HashSet { - if substitutions.is_empty() { - return HashSet::new(); - } - - // pivot_col → substitution terms (already fully resolved by Phase 2b) - let pivot_to_terms: HashMap> = substitutions - .iter() - .map(|s| (s.pivot_col, &s.terms)) - .collect(); - - let dep_info = DependencyInfo::new(witness_builders); - - let mut blocked: HashSet = HashSet::new(); - - for (builder_idx, builder) in witness_builders.iter().enumerate() { - // Collect pivot columns this builder reads via SumTerms. - // Non-linear builders (Product, Inverse, DigitalDecomposition, …) - // cannot be algebraically inlined so they are always skipped. 
- let pivot_cols_read: Vec = match builder { - WitnessBuilder::Sum(_, terms) => terms - .iter() - .filter_map(|SumTerm(_, col)| pivot_to_terms.contains_key(col).then_some(*col)) - .collect(), - WitnessBuilder::SpreadBitExtract { sum_terms, .. } => sum_terms - .iter() - .filter_map(|SumTerm(_, col)| pivot_to_terms.contains_key(col).then_some(*col)) - .collect(), - // Every other variant reads pivots through non-linear operations; - // inlining is not possible for them regardless. - _ => continue, - }; - - if pivot_cols_read.is_empty() { - continue; - } - - // All builders that transitively consume B's outputs. - // If any substitution-term producer Y is in this set, adding the edge - // "B reads from Y" would create a cycle (B →…→ Y →…→ B). - let forward_consumers = forward_reachable(&dep_info.adjacency_list, builder_idx); - - // Check every pivot this builder reads and every term of each pivot. - // A single unsafe term is enough to block the whole builder because - // rewriting is all-or-nothing per builder: we cannot partially inline - // one pivot and leave another. - 'check_pivots: for pivot_col in pivot_cols_read { - for (_, term_col) in pivot_to_terms[&pivot_col] { - // term_col may be a constant / public input column with no - // producer — those are always safe. - if let Some(&producer) = dep_info.witness_producer.get(term_col) { - if forward_consumers.contains(&producer) { - // B →…→ producer exists; inlining closes a cycle. - blocked.insert(builder_idx); - break 'check_pivots; - } - } - } - } - } - - blocked -} - -/// Reorders `witness_builders` into a valid topological execution order. -/// -/// After Phase 4b rewrites, builder X may now read w50/w60 instead of pivot P. -/// This changes X's dependencies: X must now run after producer(w50) and -/// producer(w60). If producer(w50) currently sits later in the Vec than X, the -/// old ordering is no longer valid. 
-/// -/// A correct topological order is required by: -/// - `WitnessSplitter` — its backward/forward reachability walks use -/// `DependencyInfo` built from the builders, but the final split index lists -/// it returns are resolved into sub-Vecs by position. An out-of-order Vec -/// can cause a w1 builder to be extracted before its dependency. -/// - `remove_dead_columns` — it walks `builder_reads_from` which is -/// position-indexed; a consistent ordering avoids double-counting. -/// - The prover (via `LayerScheduler`) — correctly reorders on its own, but -/// starting from a valid topological order speeds up scheduling. -/// -/// Uses Kahn's BFS algorithm on the dependency graph built by -/// `DependencyInfo::new`. If the graph has a cycle (should not happen for a -/// correctly constructed circuit), unreachable builders are appended at the -/// end unchanged. -fn topological_reorder(witness_builders: &mut Vec) { - let n = witness_builders.len(); - if n == 0 { - return; - } - - let dep_info = DependencyInfo::new(witness_builders); - - // Kahn's algorithm: start with all nodes that have no remaining dependencies. - // `DependencyInfo::in_degrees` may contain duplicate-inflated counts (one - // increment per read-witness, not per unique producer). This is consistent - // with `adjacency_list` which also has the same duplicates, so the algorithm - // remains correct: each duplicate edge decrements the count exactly once - // when its producer is processed. - let mut in_degrees = dep_info.in_degrees.clone(); - let mut queue: VecDeque = (0..n).filter(|&i| in_degrees[i] == 0).collect(); - let mut order: Vec = Vec::with_capacity(n); - - while let Some(node) = queue.pop_front() { - order.push(node); - for &consumer in &dep_info.adjacency_list[node] { - // saturating_sub prevents underflow if duplicate edges were - // already fully consumed by an earlier iteration. 
- in_degrees[consumer] = in_degrees[consumer].saturating_sub(1); - if in_degrees[consumer] == 0 { - queue.push_back(consumer); - } - } - } - - // Fallback: append any nodes that Kahn's did not reach (genuine cycle or - // isolated node not reachable from in-degree-0 roots). - if order.len() != n { - let reached: HashSet = order.iter().copied().collect(); - for i in 0..n { - if !reached.contains(&i) { - order.push(i); - } - } - } - - // Apply the permutation using mem::swap to avoid cloning every builder. - // Build a mapping: new_position → old_index, then pull builders out by - // consuming the Vec into an indexed Option<> array and re-filling in order. - let mut indexed: Vec> = witness_builders.drain(..).map(Some).collect(); - witness_builders.reserve(n); - for old_idx in order { - witness_builders.push(indexed[old_idx].take().expect("each builder visited once")); - } -} - /// Build combined occurrence counts across A, B, C matrices. fn build_occurrence_counts(r1cs: &R1CS) -> Vec { let num_cols = r1cs.num_witnesses(); @@ -708,18 +349,11 @@ fn build_occurrence_counts(r1cs: &R1CS) -> Vec { /// Counts collected during dead-column removal, returned alongside the updated /// witness count so callers can surface them in diagnostics. struct ColumnRemovalStats { - witnesses_after: usize, - builders_removed: usize, - /// Zero-occurrence columns in A/B/C matrices (excluding col 0 and public - /// inputs). - zero_occurrence_cols: usize, - /// Zero-occurrence cols pinned by the ACIR witness map and therefore kept. - blocked_by_acir: usize, - /// Zero-occurrence cols whose producing builder is transitively live - /// (some other live builder still reads one of its other outputs). - blocked_by_live_builder: usize, - /// Columns actually removed (zero-occurrence, not pinned, producer dead). - columns_removed: usize, + witnesses_after: usize, + builders_removed: usize, + /// Virtual columns: computation-only, excluded from R1CS/WHIR but needed + /// by builders. 
+ num_virtual: usize, } /// Phase 5: Remove dead witness columns from matrices and prune unreachable @@ -743,12 +377,9 @@ fn remove_dead_columns( let num_cols = r1cs.num_witnesses(); if num_cols == 0 || witness_builders.is_empty() { return ColumnRemovalStats { - witnesses_after: num_cols, - builders_removed: 0, - zero_occurrence_cols: 0, - blocked_by_acir: 0, - blocked_by_live_builder: 0, - columns_removed: 0, + witnesses_after: num_cols, + builders_removed: 0, + num_virtual: 0, }; } @@ -780,12 +411,9 @@ fn remove_dead_columns( if dead_cols.is_empty() { return ColumnRemovalStats { - witnesses_after: num_cols, - builders_removed: 0, - zero_occurrence_cols: 0, - blocked_by_acir: 0, - blocked_by_live_builder: 0, - columns_removed: 0, + witnesses_after: num_cols, + builders_removed: 0, + num_virtual: 0, }; } @@ -873,7 +501,6 @@ fn remove_dead_columns( witness_builders.len() - live_builders.len() ); - // Count dead cols blocked by live builder deps let blocked_by_bfs = dead_cols .iter() .filter(|&&col| { @@ -882,6 +509,14 @@ fn remove_dead_columns( .map_or(false, |&b| live_builders.contains(&b)) }) .count(); + + // Detailed diagnostic breakdowns (reader types, producer types, + // hypothetical analyses) are disabled for performance — they are + // O(dead_cols × graph_size) and blow up on large circuits. Enable + // selectively when debugging a specific circuit. + // + // See git history for the full diagnostic block. + info!( "Column removal: of {} dead cols, {} blocked by live builder BFS, {} removable", dead_cols.len(), @@ -889,57 +524,172 @@ fn remove_dead_columns( dead_cols.len() - blocked_by_bfs ); - // Step 4: Determine which columns to actually remove. 
- // A column is removable if: - // - It's dead in matrices (zero occurrences) AND - // - Its producing builder is NOT live (not transitively reachable) - let mut removable_cols: HashSet = HashSet::new(); - for &col in &dead_cols { - let producer_is_live = col_to_builder - .get(&col) - .map_or(false, |&b| live_builders.contains(&b)); - if !producer_is_live { - removable_cols.insert(col); + // Step 4: Dead columns are removable from R1CS matrices — BUT we must + // protect multi-output builders whose output ranges must stay contiguous. + // If ANY column in a multi-output builder's range is live (non-dead), + // ALL columns in that range must stay real to preserve the contiguous + // output_start + num_witnesses layout. + let mut protected_cols: HashSet = HashSet::new(); + let mut prot_dd = 0usize; + let mut prot_spice = 0usize; + let mut prot_mult_range = 0usize; + let mut prot_mult_binop = 0usize; + let mut prot_mult_spread = 0usize; + let mut prot_u32 = 0usize; + let mut prot_byte = 0usize; + let mut prot_chunk = 0usize; + let mut prot_spread_ext = 0usize; + let mut prot_other = 0usize; + for builder in witness_builders.iter() { + let writes = DependencyInfo::extract_writes(builder); + if writes.len() <= 1 { + continue; + } + // Builders with independent output fields (U32Addition, U32AdditionMulti, + // BytePartition) don't need protection — their outputs are individually + // remapped, not derived from a contiguous range. + // Skip builders with individually-addressed outputs (no contiguity + // assumption): U32Addition/Multi, BytePartition have independent + // index fields; ChunkDecompose/SpreadBitExtract use output_indices Vec. + if matches!( + builder, + WitnessBuilder::U32Addition(..) + | WitnessBuilder::U32AdditionMulti(..) + | WitnessBuilder::BytePartition { .. } + | WitnessBuilder::ChunkDecompose { .. } + | WitnessBuilder::SpreadBitExtract { .. } + | WitnessBuilder::DigitalDecomposition(..) 
+ ) { + continue; + } + // Remaining multi-output builders use contiguous ranges (output_start + + // offset). If any output is live, protect all to preserve contiguity. + let has_live = writes.iter().any(|c| !dead_cols.contains(c)); + let dead_in_range = writes.iter().filter(|c| dead_cols.contains(c)).count(); + if has_live && dead_in_range > 0 { + for &c in &writes { + protected_cols.insert(c); + } + match builder { + WitnessBuilder::DigitalDecomposition(..) => prot_dd += dead_in_range, + WitnessBuilder::SpiceWitnesses(..) => prot_spice += dead_in_range, + WitnessBuilder::MultiplicitiesForRange(..) => prot_mult_range += dead_in_range, + WitnessBuilder::MultiplicitiesForBinOp(..) => prot_mult_binop += dead_in_range, + WitnessBuilder::MultiplicitiesForSpread(..) => prot_mult_spread += dead_in_range, + WitnessBuilder::U32Addition(..) | WitnessBuilder::U32AdditionMulti(..) => { + prot_u32 += dead_in_range + } + WitnessBuilder::BytePartition { .. } => prot_byte += dead_in_range, + WitnessBuilder::ChunkDecompose { .. } => prot_chunk += dead_in_range, + WitnessBuilder::SpreadBitExtract { .. 
} => prot_spread_ext += dead_in_range, + _ => prot_other += dead_in_range, + } } } - - if removable_cols.is_empty() { + let removable_cols: HashSet = dead_cols + .iter() + .filter(|c| !protected_cols.contains(c)) + .copied() + .collect(); + let protected_count = dead_cols.len() - removable_cols.len(); + if protected_count > 0 { + info!( + "Column removal: {protected_count} dead cols protected by builder type: DD={prot_dd}, \ + Spice={prot_spice}, MultRange={prot_mult_range}, MultBinOp={prot_mult_binop}, \ + MultSpread={prot_mult_spread}, U32={prot_u32}, Byte={prot_byte}, Chunk={prot_chunk}, \ + SpreadExt={prot_spread_ext}, Other={prot_other}" + ); info!( - "Column removal: all {} dead columns are transitively needed by live builders", - dead_cols.len() + "Column removal: {} dead cols protected (part of multi-output builder with live \ + outputs)", + protected_count ); + } + + if removable_cols.is_empty() { return ColumnRemovalStats { - witnesses_after: num_cols, + witnesses_after: num_cols, builders_removed: 0, - zero_occurrence_cols: zero_occ_total, - blocked_by_acir, - blocked_by_live_builder: blocked_by_bfs, - columns_removed: 0, + num_virtual: 0, }; } + // Partition dead cols: dead producers (fully removable) vs live producers + // (virtual). A column must also be virtual if ANY live builder reads it + // (even if its producer is dead) — this can happen after builder rewriting + // changes dependency chains. 
+ let live_read_cols: HashSet = { + let mut s = HashSet::new(); + for (bi, b) in witness_builders.iter().enumerate() { + if live_builders.contains(&bi) { + for c in DependencyInfo::extract_reads(b) { + if removable_cols.contains(&c) { + s.insert(c); + } + } + } + } + s + }; + let mut fully_dead_cols: HashSet = HashSet::new(); + let mut virtual_cols: HashSet = HashSet::new(); + for &col in &removable_cols { + let producer_is_live = col_to_builder + .get(&col) + .map_or(false, |&b| live_builders.contains(&b)); + if producer_is_live || live_read_cols.contains(&col) { + virtual_cols.insert(col); + } else { + fully_dead_cols.insert(col); + } + } + info!( - "Column removal: {} columns removable ({} dead, {} kept for live builder deps)", + "Column removal: {} dead cols total, {} fully dead (producer dead), {} virtual (producer \ + live, computation-only)", removable_cols.len(), - dead_cols.len(), - dead_cols.len() - removable_cols.len() + fully_dead_cols.len(), + virtual_cols.len() ); - // Step 5: Build remap table (old_col -> new_col) + // Step 5: Build remap table with two regions: + // Real columns → [0, num_real) (for R1CS matrices + builders) + // Virtual columns → [num_real, num_real+num_virtual) (for builders only) + // Fully dead columns → None (no mapping, builders pruned) let mut remap: Vec> = vec![None; num_cols]; - let mut next_col = 0; + let mut next_real = 0usize; + // First pass: assign real column indices for col in 0..num_cols { if !removable_cols.contains(&col) { - remap[col] = Some(next_col); - next_col += 1; + remap[col] = Some(next_real); + next_real += 1; } } - let new_num_cols = next_col; - - // Step 6: Remap matrices - r1cs.a = r1cs.a.remove_columns(&remap); - r1cs.b = r1cs.b.remove_columns(&remap); - r1cs.c = r1cs.c.remove_columns(&remap); + let num_real = next_real; + // Second pass: assign virtual column indices (after real) + let mut next_virtual = num_real; + for col in 0..num_cols { + if virtual_cols.contains(&col) { + remap[col] = 
Some(next_virtual); + next_virtual += 1; + } + } + let num_virtual = next_virtual - num_real; + + // Step 6: Remap R1CS matrices — only uses [0, num_real) columns. + // Virtual columns had zero entries so remove_columns drops them cleanly. + let matrix_remap: Vec> = (0..num_cols) + .map(|col| { + if removable_cols.contains(&col) { + None // Remove from matrices (both virtual and fully dead) + } else { + remap[col] // Real column → compact index + } + }) + .collect(); + r1cs.a = r1cs.a.remove_columns(&matrix_remap); + r1cs.b = r1cs.b.remove_columns(&matrix_remap); + r1cs.c = r1cs.c.remove_columns(&matrix_remap); // Step 6b: Remap ACIR witness map (ACIR index -> R1CS column) for entry in witness_map.iter_mut() { @@ -956,29 +706,78 @@ fn remove_dead_columns( } } - // Step 7: Prune dead builders and remap surviving ones + // Step 7: Prune dead builders and remap surviving ones. + // A builder must be kept if it's live OR if it produces any virtual column + // (needed for computation even though its outputs are zero in A/B/C). + let mut keep_builders = live_builders.clone(); + for &col in &virtual_cols { + if let Some(&producer_idx) = col_to_builder.get(&col) { + keep_builders.insert(producer_idx); + } + } let builders_before = witness_builders.len(); - let mut new_builders: Vec = Vec::with_capacity(live_builders.len()); + let mut new_builders: Vec = Vec::with_capacity(keep_builders.len()); for (idx, builder) in witness_builders.drain(..).enumerate() { - if live_builders.contains(&idx) { + if keep_builders.contains(&idx) { new_builders.push(remap_builder_columns(&builder, &remap)); } } *witness_builders = new_builders; let builders_removed = builders_before - witness_builders.len(); + // Validation: every column that any builder reads must be produced by + // some other builder (i.e. appear in some builder's writes). 
+ { + let mut all_writes: HashSet = HashSet::new(); + for b in witness_builders.iter() { + for c in DependencyInfo::extract_writes(b) { + all_writes.insert(c); + } + } + for (bi, b) in witness_builders.iter().enumerate() { + for c in DependencyInfo::extract_reads(b) { + if !all_writes.contains(&c) { + // Find the original column for debugging + let orig_col = remap + .iter() + .position(|r| r == &Some(c)) + .unwrap_or(usize::MAX); + let was_dead = orig_col != usize::MAX && dead_cols.contains(&orig_col); + let was_virtual = orig_col != usize::MAX && virtual_cols.contains(&orig_col); + let was_fully_dead = + orig_col != usize::MAX && fully_dead_cols.contains(&orig_col); + let had_producer = + orig_col != usize::MAX && col_to_builder.contains_key(&orig_col); + let occ = if orig_col < occurrence_counts.len() { + occurrence_counts[orig_col] + } else { + usize::MAX + }; + panic!( + "Builder {bi} reads remapped col {c} (orig {orig_col}) not written by any \ + builder. occurrences={occ}, dead={was_dead}, virtual={was_virtual}, \ + fully_dead={was_fully_dead}, had_producer={had_producer}, \ + num_real={num_real}, num_virtual={num_virtual}" + ); + } + } + } + } + info!( - "Column removal: {} -> {} witnesses, {} builders pruned", - num_cols, new_num_cols, builders_removed + "Column removal: {} -> {} real + {} virtual witnesses ({} total for solving), {} builders \ + pruned", + num_cols, + num_real, + num_virtual, + num_real + num_virtual, + builders_removed ); ColumnRemovalStats { - witnesses_after: new_num_cols, + witnesses_after: num_real, builders_removed, - zero_occurrence_cols: zero_occ_total, - blocked_by_acir, - blocked_by_live_builder: blocked_by_bfs, - columns_removed: removable_cols.len(), + num_virtual, } } @@ -1056,14 +855,15 @@ fn remap_builder_columns(builder: &WitnessBuilder, remap: &[Option]) -> W WitnessBuilder::LogUpInverse(r(*idx), r(*sz), WitnessCoefficient(*coeff, r(*value))) } WitnessBuilder::DigitalDecomposition(dd) => { - let new_witnesses_to_decompose 
= - dd.witnesses_to_decompose.iter().map(|&w| r(w)).collect(); WitnessBuilder::DigitalDecomposition(crate::witness::DigitalDecompositionWitnesses { log_bases: dd.log_bases.clone(), num_witnesses_to_decompose: dd.num_witnesses_to_decompose, - witnesses_to_decompose: new_witnesses_to_decompose, - first_witness_idx: r(dd.first_witness_idx), - num_witnesses: dd.num_witnesses, + witnesses_to_decompose: dd + .witnesses_to_decompose + .iter() + .map(|&w| r(w)) + .collect(), + output_indices: dd.output_indices.iter().map(|&i| r(i)).collect(), }) } WitnessBuilder::SpiceMultisetFactor( @@ -1179,30 +979,30 @@ fn remap_builder_columns(builder: &WitnessBuilder, remap: &[Option]) -> W ) } WitnessBuilder::ChunkDecompose { - output_start, + output_indices, packed, chunk_bits, } => WitnessBuilder::ChunkDecompose { - output_start: r(*output_start), - packed: r(*packed), - chunk_bits: chunk_bits.clone(), + output_indices: output_indices.iter().map(|&i| r(i)).collect(), + packed: r(*packed), + chunk_bits: chunk_bits.clone(), }, WitnessBuilder::SpreadWitness(output, input) => { WitnessBuilder::SpreadWitness(r(*output), r(*input)) } WitnessBuilder::SpreadBitExtract { - output_start, + output_indices, chunk_bits, sum_terms, extract_even, } => WitnessBuilder::SpreadBitExtract { - output_start: r(*output_start), - chunk_bits: chunk_bits.clone(), - sum_terms: sum_terms + output_indices: output_indices.iter().map(|&i| r(i)).collect(), + chunk_bits: chunk_bits.clone(), + sum_terms: sum_terms .iter() .map(|SumTerm(coeff, idx)| SumTerm(*coeff, r(*idx))) .collect(), - extract_even: *extract_even, + extract_even: *extract_even, }, WitnessBuilder::MultiplicitiesForSpread(start, num_bits, queries) => { let new_queries = queries.iter().map(|c| rc(c)).collect(); @@ -1390,13 +1190,18 @@ mod tests { assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - // Without builder rewriting (currently disabled), pivot columns - // remain alive because their producer builders are 
transitively - // reachable from live builders. No witness reduction expected. + // After Phase 4b: Sum(4) is rewritten to inline w3's substitution + // (-1·w0 + 1·w1), so it no longer reads w3. Sum(3) (producer of w3) + // has no live consumers left → dead → w3 fully removed. + // w4 is dead in constraints (GE substituted it out) but Sum(4) still + // produces it and Product(5) reads it → w4 becomes virtual. + // Expected: 6 → 4 real witnesses (w3 fully removed, w4 virtual). assert_eq!( - stats.witnesses_after, stats.witnesses_before, - "Expected no witness reduction without builder rewriting, got {} -> {}", - stats.witnesses_before, stats.witnesses_after + stats.witnesses_after, + stats.witnesses_before - 2, + "Expected 2 witnesses removed from R1CS (w3 dead, w4 virtual), got {} -> {}", + stats.witnesses_before, + stats.witnesses_after ); // Verify the remaining constraint references only valid column indices @@ -1467,13 +1272,22 @@ mod tests { assert_eq!(stats.constraints_after, 1); assert_eq!(r1cs.num_constraints(), 1); - // Without builder rewriting (currently disabled), pivot columns - // w3-w6 remain alive because their producer builders are still - // reachable. No witness reduction expected. + // After Phase 4b linear rewrites: + // Sum(4): inlines w3 → reads w0, w1 (no longer reads w3) + // Sum(5): inlines w4 → reads w0, w1 (no longer reads w4) + // Sum(6): inlines w5 → reads w0, w1 (no longer reads w5) + // Sum(3) (produces w3) has no live consumers → dead → w3 removed. + // Sum(4) (produces w4) has no live consumers (Sum(5) was rewritten) → dead → w4 + // removed. Sum(5) (produces w5) has no live consumers (Sum(6) was + // rewritten) → dead → w5 removed. Sum(6) stays alive because Product(7) + // reads w6. w6 is dead in constraints but Product(7) reads it → virtual. + // Expected: 8 → 4 real witnesses (w3,w4,w5 fully removed, w6 virtual). 
assert_eq!( - stats.witnesses_after, stats.witnesses_before, - "Expected no witness reduction without builder rewriting, got {} -> {}", - stats.witnesses_before, stats.witnesses_after + stats.witnesses_after, + stats.witnesses_before - 4, + "Expected 4 witnesses removed from R1CS, got {} -> {}", + stats.witnesses_before, + stats.witnesses_after ); // Verify the remaining constraint references only valid column indices @@ -1553,13 +1367,16 @@ mod tests { assert_eq!(stats.eliminated, 2); assert_eq!(stats.constraints_after, 3); - // Without builder rewriting (currently disabled), pivot columns - // w3, w5 remain alive because their producer builders are still - // reachable. No witness reduction expected. + // Pivot columns w3, w5 are dead in constraints (GE substituted + // them out) but their producers are still live (downstream builders + // read them) → they become virtual witnesses. + // Expected: 9 → 7 real witnesses (w3 and w5 become virtual). assert_eq!( - stats.witnesses_after, stats.witnesses_before, - "Expected no witness reduction without builder rewriting, got {} -> {}", - stats.witnesses_before, stats.witnesses_after + stats.witnesses_after, + stats.witnesses_before - 2, + "Expected 2 virtual witnesses (w3, w5), got {} -> {}", + stats.witnesses_before, + stats.witnesses_after ); // Verify all column references are in valid range diff --git a/provekit/common/src/r1cs.rs b/provekit/common/src/r1cs.rs index c971d9f19..fbe940a8f 100644 --- a/provekit/common/src/r1cs.rs +++ b/provekit/common/src/r1cs.rs @@ -15,6 +15,11 @@ pub struct R1CS { pub a: SparseMatrix, pub b: SparseMatrix, pub c: SparseMatrix, + /// Virtual witnesses: computation-only columns excluded from matrices/WHIR + /// commitment but needed by witness builders for intermediate calculations. + /// The prover allocates `num_witnesses() + num_virtual` for solving. 
+ #[serde(default)] + pub num_virtual: usize, } impl Default for R1CS { @@ -32,6 +37,7 @@ impl R1CS { a: SparseMatrix::new(0, 0), b: SparseMatrix::new(0, 0), c: SparseMatrix::new(0, 0), + num_virtual: 0, } } diff --git a/provekit/common/src/witness/digits.rs b/provekit/common/src/witness/digits.rs index 030aea758..222bae666 100644 --- a/provekit/common/src/witness/digits.rs +++ b/provekit/common/src/witness/digits.rs @@ -19,10 +19,10 @@ pub struct DigitalDecompositionWitnesses { pub num_witnesses_to_decompose: usize, /// Witness indices of the values to be decomposed pub witnesses_to_decompose: Vec, - /// The index of the first witness written to - pub first_witness_idx: usize, - /// The number of witnesses written to - pub num_witnesses: usize, + /// Output witness indices. Length = log_bases.len() * + /// num_witnesses_to_decompose. Layout: output_indices[digit_place * + /// num_witnesses_to_decompose + i] + pub output_indices: Vec, } /// Compute a mixed-base decomposition of a field element into its digits, using diff --git a/provekit/common/src/witness/scheduling/dependency.rs b/provekit/common/src/witness/scheduling/dependency.rs index 25912d9c1..c7c8dd5aa 100644 --- a/provekit/common/src/witness/scheduling/dependency.rs +++ b/provekit/common/src/witness/scheduling/dependency.rs @@ -323,9 +323,7 @@ impl DependencyInfo { WitnessBuilder::MultiplicitiesForRange(start, range, _) => { (*start..*start + *range).collect() } - WitnessBuilder::DigitalDecomposition(dd) => { - (dd.first_witness_idx..dd.first_witness_idx + dd.num_witnesses).collect() - } + WitnessBuilder::DigitalDecomposition(dd) => dd.output_indices.clone(), WitnessBuilder::SpiceWitnesses(sw) => { (sw.first_witness_idx..sw.first_witness_idx + sw.num_witnesses).collect() } @@ -333,16 +331,8 @@ impl DependencyInfo { let n = 2usize.pow(2 * *atomic_bits); (*start..*start + n).collect() } - WitnessBuilder::ChunkDecompose { - output_start, - chunk_bits, - .. 
- } => (*output_start..*output_start + chunk_bits.len()).collect(), - WitnessBuilder::SpreadBitExtract { - output_start, - chunk_bits, - .. - } => (*output_start..*output_start + chunk_bits.len()).collect(), + WitnessBuilder::ChunkDecompose { output_indices, .. } => output_indices.clone(), + WitnessBuilder::SpreadBitExtract { output_indices, .. } => output_indices.clone(), WitnessBuilder::MultiplicitiesForSpread(start, num_bits, _) => { let n = 1usize << *num_bits; (*start..*start + n).collect() diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index 389a14358..d4ea4ad00 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -27,35 +27,64 @@ pub struct WitnessIndexRemapper { impl WitnessIndexRemapper { /// Creates a remapping from w1 and w2 builder lists. /// - /// Assigns w1 builder outputs to [0, k) and w2 builder outputs to [k, n). - pub fn new(w1_builders: &[WitnessBuilder], w2_builders: &[WitnessBuilder]) -> Self { + /// `num_real_cols` is the R1CS column count before splitting — indices + /// below this are "real" (constrained), indices >= are "virtual" + /// (computation-only). + /// + /// Output layout: + /// [0, w1_real) → real w1 witnesses (committed) + /// [w1_real, w1_real+w2_real) → real w2 witnesses (committed) + /// [w1_real+w2_real, total) → virtual witnesses (solving only) + /// + /// `w1_size` is set to `w1_real` so the WHIR commitment covers only + /// real witnesses. 
+ pub fn new( + w1_builders: &[WitnessBuilder], + w2_builders: &[WitnessBuilder], + num_real_cols: usize, + ) -> Self { let mut old_to_new = HashMap::new(); - let mut next_w1_idx = 0; - let mut next_w2_idx = 0; + let mut next_real_w1 = 0usize; + let mut virtual_w1: Vec = Vec::new(); - // Map w1 builder outputs to [0, k) + // First pass w1: assign real outputs, collect virtual for builder in w1_builders { - let writes = DependencyInfo::extract_writes(builder); - for old_idx in writes { - old_to_new.insert(old_idx, next_w1_idx); - next_w1_idx += 1; + for old_idx in DependencyInfo::extract_writes(builder) { + if old_idx < num_real_cols { + old_to_new.insert(old_idx, next_real_w1); + next_real_w1 += 1; + } else { + virtual_w1.push(old_idx); + } } } + let w1_real = next_real_w1; - let w1_size = next_w1_idx; + let mut next_real_w2 = w1_real; + let mut virtual_w2: Vec = Vec::new(); - // Map w2 builder outputs to [k, n) + // First pass w2: assign real outputs, collect virtual for builder in w2_builders { - let writes = DependencyInfo::extract_writes(builder); - for old_idx in writes { - old_to_new.insert(old_idx, w1_size + next_w2_idx); - next_w2_idx += 1; + for old_idx in DependencyInfo::extract_writes(builder) { + if old_idx < num_real_cols { + old_to_new.insert(old_idx, next_real_w2); + next_real_w2 += 1; + } else { + virtual_w2.push(old_idx); + } } } + // Second pass: assign virtual outputs after all real ones + let mut next_virtual = next_real_w2; + for old_idx in virtual_w1.into_iter().chain(virtual_w2) { + old_to_new.insert(old_idx, next_virtual); + next_virtual += 1; + } + Self { old_to_new, - w1_size, + w1_size: w1_real, } } @@ -162,22 +191,22 @@ impl WitnessIndexRemapper { WitnessCoefficient(*coeff, self.remap(*value)), ) } - WitnessBuilder::DigitalDecomposition(dd) => { - let new_witnesses_to_decompose = dd - .witnesses_to_decompose - .iter() - .map(|&w| self.remap(w)) - .collect(); - WitnessBuilder::DigitalDecomposition( - 
crate::witness::DigitalDecompositionWitnesses { - log_bases: dd.log_bases.clone(), - num_witnesses_to_decompose: dd.num_witnesses_to_decompose, - witnesses_to_decompose: new_witnesses_to_decompose, - first_witness_idx: self.remap(dd.first_witness_idx), - num_witnesses: dd.num_witnesses, - }, - ) - } + WitnessBuilder::DigitalDecomposition(dd) => WitnessBuilder::DigitalDecomposition( + crate::witness::DigitalDecompositionWitnesses { + log_bases: dd.log_bases.clone(), + num_witnesses_to_decompose: dd.num_witnesses_to_decompose, + witnesses_to_decompose: dd + .witnesses_to_decompose + .iter() + .map(|&w| self.remap(w)) + .collect(), + output_indices: dd + .output_indices + .iter() + .map(|&i| self.remap(i)) + .collect(), + }, + ), WitnessBuilder::SpiceMultisetFactor( idx, sz, @@ -493,30 +522,30 @@ impl WitnessIndexRemapper { num_bits: *num_bits, }, WitnessBuilder::ChunkDecompose { - output_start, + output_indices, packed, chunk_bits, } => WitnessBuilder::ChunkDecompose { - output_start: self.remap(*output_start), - packed: self.remap(*packed), - chunk_bits: chunk_bits.clone(), + output_indices: output_indices.iter().map(|&i| self.remap(i)).collect(), + packed: self.remap(*packed), + chunk_bits: chunk_bits.clone(), }, WitnessBuilder::SpreadWitness(output, input) => { WitnessBuilder::SpreadWitness(self.remap(*output), self.remap(*input)) } WitnessBuilder::SpreadBitExtract { - output_start, + output_indices, chunk_bits, sum_terms, extract_even, } => WitnessBuilder::SpreadBitExtract { - output_start: self.remap(*output_start), - chunk_bits: chunk_bits.clone(), - sum_terms: sum_terms + output_indices: output_indices.iter().map(|&i| self.remap(i)).collect(), + chunk_bits: chunk_bits.clone(), + sum_terms: sum_terms .iter() .map(|SumTerm(coeff, idx)| SumTerm(*coeff, self.remap(*idx))) .collect(), - extract_even: *extract_even, + extract_even: *extract_even, }, WitnessBuilder::MultiplicitiesForSpread(start, num_bits, queries) => { let new_queries = queries @@ -559,6 +588,7 @@ 
impl WitnessIndexRemapper { pub fn remap_r1cs(&self, r1cs: R1CS) -> R1CS { let mut new_r1cs = R1CS::new(); new_r1cs.num_public_inputs = r1cs.num_public_inputs; + new_r1cs.num_virtual = r1cs.num_virtual; new_r1cs.interner = r1cs.interner; // Remap A, B, C in parallel - they're independent @@ -579,9 +609,14 @@ impl WitnessIndexRemapper { new_r1cs } - /// Helper to remap a single sparse matrix + /// Helper to remap a single sparse matrix. + /// Updates `num_cols` to the total witness count after remapping + /// (w1_size + w2_size), so the matrix dimensions match the new + /// witness layout. fn remap_sparse_matrix(&self, mut matrix: SparseMatrix) -> SparseMatrix { + let total_witnesses = self.old_to_new.values().copied().max().map_or(0, |m| m + 1); matrix.remap_columns(|old_col| self.remap(old_col)); + matrix.num_cols = total_witnesses; matrix } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index ee1af6f26..77ea03a3e 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -277,11 +277,11 @@ pub enum WitnessBuilder { /// Decomposes a packed value into chunks of specified bit-widths. /// Given packed value and chunk_bits = [b0, b1, ..., bn]: /// packed = c0 + c1 * 2^b0 + c2 * 2^(b0+b1) + ... - /// Writes chunk values to output_start..output_start+chunk_bits.len() + /// `output_indices[i]` is the witness index for chunk `i`. ChunkDecompose { - output_start: usize, - packed: usize, - chunk_bits: Vec, + output_indices: Vec, + packed: usize, + chunk_bits: Vec, }, /// Prover hint for FakeGLV scalar decomposition. /// Given scalar s (from s_lo + s_hi * 2^128) and curve order n, @@ -401,13 +401,12 @@ pub enum WitnessBuilder { SpreadWitness(usize, usize), /// Extracts even or odd bits from a spread sum, decomposed into /// byte-sized chunks. Even bits = XOR result, Odd bits = MAJ/AND - /// result. 
The sum is computed inline from the provided terms, - /// avoiding a separate witness allocation. + /// result. `output_indices[i]` is the witness index for chunk `i`. SpreadBitExtract { - output_start: usize, - chunk_bits: Vec, - sum_terms: Vec, - extract_even: bool, + output_indices: Vec, + chunk_bits: Vec, + sum_terms: Vec, + extract_even: bool, }, /// Spread table multiplicities: counts how many times each input /// value appears in the query set. @@ -453,7 +452,7 @@ impl WitnessBuilder { pub fn num_witnesses(&self) -> usize { match self { WitnessBuilder::MultiplicitiesForRange(_, range_size, _) => *range_size, - WitnessBuilder::DigitalDecomposition(dd_struct) => dd_struct.num_witnesses, + WitnessBuilder::DigitalDecomposition(dd_struct) => dd_struct.output_indices.len(), WitnessBuilder::SpiceWitnesses(spice_witnesses_struct) => { spice_witnesses_struct.num_witnesses } @@ -463,8 +462,8 @@ impl WitnessBuilder { WitnessBuilder::U32Addition(..) => 2, WitnessBuilder::U32AdditionMulti(..) => 2, WitnessBuilder::BytePartition { .. } => 2, - WitnessBuilder::ChunkDecompose { chunk_bits, .. } => chunk_bits.len(), - WitnessBuilder::SpreadBitExtract { chunk_bits, .. } => chunk_bits.len(), + WitnessBuilder::ChunkDecompose { output_indices, .. } => output_indices.len(), + WitnessBuilder::SpreadBitExtract { output_indices, .. } => output_indices.len(), WitnessBuilder::MultiplicitiesForSpread(_, num_bits, _) => 1usize << *num_bits, WitnessBuilder::MultiLimbMulModHint { num_limbs, .. } => (4 * *num_limbs - 2) as usize, WitnessBuilder::MultiLimbModularInverse { num_limbs, .. } => *num_limbs as usize, @@ -541,8 +540,12 @@ impl WitnessBuilder { .map(|&idx| witness_builders[idx].clone()) .collect(); - // Step 3: Create witness index remapper - let remapper = WitnessIndexRemapper::new(&w1_builders, &w2_builders); + // Step 3: Create witness index remapper. 
+ // Pass num_real_cols so virtual witnesses are placed at the end, + // after all real w1/w2 witnesses — keeping them out of the + // committed WHIR polynomial. + let num_real_cols = r1cs.num_witnesses(); + let remapper = WitnessIndexRemapper::new(&w1_builders, &w2_builders, num_real_cols); let w1_size = remapper.w1_size; // Step 4: Remap all builders @@ -556,7 +559,9 @@ impl WitnessBuilder { .map(|b| remapper.remap_builder(b)) .collect(); - // Step 5: Remap R1CS and witness map + // Step 5: Remap R1CS and witness map. + // num_virtual is preserved — virtual witnesses are at the end of + // the index space, after all real w1/w2 witnesses. let remapped_r1cs = remapper.remap_r1cs(r1cs); let remapped_witness_map = remapper.remap_acir_witness_map(witness_map); diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index edcf685ac..f9a0ec7d6 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -113,6 +113,7 @@ impl Prove for NoirProver { let compressed_r1cs = CompressedR1CS::compress(self.r1cs).context("While compressing R1CS")?; let num_witnesses = compressed_r1cs.num_witnesses(); + let num_virtual = compressed_r1cs.num_virtual(); let num_constraints = compressed_r1cs.num_constraints(); // Set up transcript with sponge selected by hash_config. @@ -122,7 +123,9 @@ impl Prove for NoirProver { .instance(&Empty); let mut merlin = ProverState::new(&ds, TranscriptSponge::from_config(self.hash_config)); - let mut witness: Vec> = vec![None; num_witnesses]; + // Allocate space for real + virtual witnesses. Virtual witnesses are + // computation-only (zero entries in A/B/C) but needed by builders. + let mut witness: Vec> = vec![None; num_witnesses + num_virtual]; // Solve w1 (or all witnesses if no challenges). // Outer span captures memory AFTER w1_layers parameter is freed @@ -187,7 +190,8 @@ impl Prove for NoirProver { let w2 = { let _s = info_span!("allocate_w2").entered(); - witness[self.whir_for_witness.w1_size..] 
+ // Only real w2 witnesses (exclude virtual at the end). + witness[self.whir_for_witness.w1_size..num_witnesses] .iter() .map(|w| w.ok_or_else(|| anyhow::anyhow!("Some witnesses in w2 are missing"))) .collect::>>()? @@ -210,8 +214,13 @@ impl Prove for NoirProver { .context("While decompressing R1CS")?; #[cfg(test)] - r1cs.test_witness_satisfaction(&witness.iter().map(|w| w.unwrap()).collect::>()) - .context("While verifying R1CS instance")?; + r1cs.test_witness_satisfaction( + &witness[..num_witnesses] + .iter() + .map(|w| w.unwrap()) + .collect::>(), + ) + .context("While verifying R1CS instance")?; let public_inputs = if num_public_inputs == 0 { PublicInputs::new() @@ -224,8 +233,11 @@ impl Prove for NoirProver { ) }; - let full_witness: Vec = witness - .into_iter() + // Extract only real witnesses (first num_witnesses) for the sumcheck. + // Virtual witnesses at [num_witnesses, num_witnesses+num_virtual) were + // needed for builder computation but have zero entries in A/B/C. + let full_witness: Vec = witness[..num_witnesses] + .iter() .enumerate() .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; diff --git a/provekit/prover/src/r1cs.rs b/provekit/prover/src/r1cs.rs index 80254d4db..2553f6b64 100644 --- a/provekit/prover/src/r1cs.rs +++ b/provekit/prover/src/r1cs.rs @@ -19,6 +19,7 @@ use { pub struct CompressedR1CS { num_constraints: usize, num_witnesses: usize, + num_virtual: usize, blob: Vec, } @@ -47,11 +48,13 @@ impl CompressedR1CS { pub fn compress(r1cs: R1CS) -> Result { let num_constraints = r1cs.num_constraints(); let num_witnesses = r1cs.num_witnesses(); + let num_virtual = r1cs.num_virtual; let blob = postcard::to_allocvec(&r1cs).context("R1CS serialization failed")?; drop(r1cs); Ok(Self { num_constraints, num_witnesses, + num_virtual, blob, }) } @@ -68,6 +71,10 @@ impl CompressedR1CS { self.num_witnesses } + pub const fn num_virtual(&self) -> usize { + self.num_virtual + } + pub fn blob_len(&self) 
-> usize { self.blob.len() } diff --git a/provekit/prover/src/witness/digits.rs b/provekit/prover/src/witness/digits.rs index 409ba2795..4b5ac10e4 100644 --- a/provekit/prover/src/witness/digits.rs +++ b/provekit/prover/src/witness/digits.rs @@ -19,9 +19,8 @@ impl DigitalDecompositionWitnessesSolver for DigitalDecompositionWitnesses { .iter() .enumerate() .for_each(|(digit_place, digit_value)| { - witness[self.first_witness_idx - + digit_place * self.witnesses_to_decompose.len() - + i] = Some(*digit_value); + let idx = digit_place * self.witnesses_to_decompose.len() + i; + witness[self.output_indices[idx]] = Some(*digit_value); }); }); } diff --git a/provekit/prover/src/witness/witness_builder.rs b/provekit/prover/src/witness/witness_builder.rs index 111dfefa3..7eb9f31aa 100644 --- a/provekit/prover/src/witness/witness_builder.rs +++ b/provekit/prover/src/witness/witness_builder.rs @@ -859,7 +859,7 @@ impl WitnessBuilderSolver for WitnessBuilder { ) } WitnessBuilder::ChunkDecompose { - output_start, + output_indices, packed, chunk_bits, } => { @@ -868,7 +868,7 @@ impl WitnessBuilderSolver for WitnessBuilder { for (i, &bits) in chunk_bits.iter().enumerate() { let mask = (1u64 << bits) - 1; let chunk_val = (packed_val >> offset) & mask; - witness[output_start + i] = Some(FieldElement::from(chunk_val)); + witness[output_indices[i]] = Some(FieldElement::from(chunk_val)); offset += bits; } } @@ -878,12 +878,11 @@ impl WitnessBuilderSolver for WitnessBuilder { witness[*output_idx] = Some(FieldElement::from(spread)); } WitnessBuilder::SpreadBitExtract { - output_start, + output_indices, chunk_bits, sum_terms, extract_even, } => { - // Compute the spread sum inline from terms (no phantom witness needed) let sum_fe: FieldElement = sum_terms .iter() .map(|SumTerm(coeff, idx)| { @@ -896,19 +895,17 @@ impl WitnessBuilderSolver for WitnessBuilder { }) .fold(FieldElement::zero(), |acc, x| acc + x); let sum_val = sum_fe.into_bigint().0[0]; - // Extract even or odd bits from the 
spread sum let bit_offset = if *extract_even { 0 } else { 1 }; let total_bits: u32 = chunk_bits.iter().sum(); let mut extracted = 0u64; for i in 0..total_bits { extracted |= ((sum_val >> (2 * i + bit_offset)) & 1) << i; } - // Decompose extracted value into chunks let mut offset = 0u32; for (i, &bits) in chunk_bits.iter().enumerate() { let mask = (1u64 << bits) - 1; let chunk_val = (extracted >> offset) & mask; - witness[output_start + i] = Some(FieldElement::from(chunk_val)); + witness[output_indices[i]] = Some(FieldElement::from(chunk_val)); offset += bits; } } diff --git a/provekit/r1cs-compiler/src/digits.rs b/provekit/r1cs-compiler/src/digits.rs index 657f7bd78..a3a61dcc8 100644 --- a/provekit/r1cs-compiler/src/digits.rs +++ b/provekit/r1cs-compiler/src/digits.rs @@ -25,13 +25,12 @@ impl DigitalDecompositionWitnessesBuilder for DigitalDecompositionWitnesses { witnesses_to_decompose: Vec, ) -> Self { let num_witnesses_to_decompose = witnesses_to_decompose.len(); - let digital_decomp_length = log_bases.len(); + let num_outputs = log_bases.len() * num_witnesses_to_decompose; Self { log_bases, num_witnesses_to_decompose, witnesses_to_decompose, - first_witness_idx: next_witness_idx, - num_witnesses: digital_decomp_length * num_witnesses_to_decompose, + output_indices: (next_witness_idx..next_witness_idx + num_outputs).collect(), } } @@ -42,7 +41,7 @@ impl DigitalDecompositionWitnessesBuilder for DigitalDecompositionWitnesses { fn get_digit_witness_index(&self, digit_place: usize, value_offset: usize) -> usize { debug_assert!(digit_place < self.log_bases.len()); debug_assert!(value_offset < self.num_witnesses_to_decompose); - self.first_witness_idx + digit_place * self.num_witnesses_to_decompose + value_offset + self.output_indices[digit_place * self.num_witnesses_to_decompose + value_offset] } } diff --git a/provekit/r1cs-compiler/src/noir_proof_scheme.rs b/provekit/r1cs-compiler/src/noir_proof_scheme.rs index 859246e91..29b11a229 100644 --- 
a/provekit/r1cs-compiler/src/noir_proof_scheme.rs +++ b/provekit/r1cs-compiler/src/noir_proof_scheme.rs @@ -86,17 +86,17 @@ impl NoirCompiler { witness_map, acir_public_inputs_indices_set, )?; + let num_real = remapped_r1cs.num_witnesses(); + let num_virtual = remapped_r1cs.num_virtual; info!( - "Witness split: w1 size = {}, w2 size = {}", + "Witness split: w1 = {}, w2 = {} (real, committed) + {} virtual (solving only)", split_witness_builders.w1_size, - remapped_r1cs.num_witnesses() - split_witness_builders.w1_size + num_real - split_witness_builders.w1_size, + num_virtual ); - let witness_generator = NoirWitnessGenerator::new( - &program, - remapped_witness_map, - remapped_r1cs.num_witnesses(), - ); + let witness_generator = + NoirWitnessGenerator::new(&program, remapped_witness_map, num_real + num_virtual); let whir_for_witness = WhirR1CSScheme::new_for_r1cs( &remapped_r1cs, diff --git a/provekit/r1cs-compiler/src/spread.rs b/provekit/r1cs-compiler/src/spread.rs index 8acc93bea..d0867dc52 100644 --- a/provekit/r1cs-compiler/src/spread.rs +++ b/provekit/r1cs-compiler/src/spread.rs @@ -35,7 +35,6 @@ fn subchunks(bits: u32, w: u32) -> Vec { #[derive(Clone, Debug)] pub(crate) struct SpreadChunk { pub total_bits: u32, - pub sub_values: Vec, pub sub_spreads: Vec, pub sub_bits: Vec, } @@ -147,8 +146,9 @@ pub(crate) fn decompose_to_spread_word( // Step 2: Single ChunkDecompose producing all sub-chunks directly // from the packed value. 
let sub_start = compiler.num_witnesses(); + let chunk_count = flat_bits.len(); compiler.add_witness_builder(WitnessBuilder::ChunkDecompose { - output_start: sub_start, + output_indices: (sub_start..sub_start + chunk_count).collect(), packed, chunk_bits: flat_bits.clone(), }); @@ -175,7 +175,6 @@ pub(crate) fn decompose_to_spread_word( let n_subs = chunk_sub_counts[ci]; let sub_bits_slice = &flat_bits[flat_idx..flat_idx + n_subs]; - let mut sub_values = Vec::with_capacity(n_subs); let mut sub_spreads = Vec::with_capacity(n_subs); for j in 0..n_subs { let val_idx = sub_start + flat_idx + j; @@ -187,13 +186,11 @@ pub(crate) fn decompose_to_spread_word( .or_default() .push(val_idx); } - sub_values.push(val_idx); sub_spreads.push(spread_idx); } chunks.push(SpreadChunk { total_bits: chunk_spec[ci], - sub_values, sub_spreads, sub_bits: sub_bits_slice.to_vec(), }); @@ -262,17 +259,19 @@ pub(crate) fn spread_decompose( // The sum is computed inline by the solver from sum_terms, // avoiding a phantom witness that would inflate the witness vector. let even_start = compiler.num_witnesses(); + let even_count = extract_chunks.len(); compiler.add_witness_builder(WitnessBuilder::SpreadBitExtract { - output_start: even_start, - chunk_bits: extract_chunks.clone(), - sum_terms: sum_terms.clone(), - extract_even: true, + output_indices: (even_start..even_start + even_count).collect(), + chunk_bits: extract_chunks.clone(), + sum_terms: sum_terms.clone(), + extract_even: true, }); // Extract odd bits (AND/MAJ) into chunks let odd_start = compiler.num_witnesses(); + let odd_count = extract_chunks.len(); compiler.add_witness_builder(WitnessBuilder::SpreadBitExtract { - output_start: odd_start, + output_indices: (odd_start..odd_start + odd_count).collect(), chunk_bits: extract_chunks.clone(), sum_terms, extract_even: false, @@ -478,10 +477,8 @@ pub(crate) fn decompose_constant_to_spread_word( // spread table lookup is needed for soundness. 
// Build SpreadChunks with pinned spread witnesses. - // No ChunkDecompose needed — sub_values are never read by any - // downstream R1CS constraint (only sub_spreads are used in - // spread_decompose). We still populate sub_values with the - // spread witness indices as placeholders to satisfy the struct. + // No ChunkDecompose needed — only sub_spreads are used in + // downstream spread_decompose constraints. let mut chunks = Vec::with_capacity(num_chunks); let mut flat_idx = 0usize; @@ -504,10 +501,6 @@ pub(crate) fn decompose_constant_to_spread_word( chunks.push(SpreadChunk { total_bits: chunk_spec[ci], - // NB: sub_values normally holds chunk-value witness indices, but the constant - // path doesn't create separate chunk-value witnesses. Set to spread indices - // instead; this field is currently unused so the mismatch is harmless. - sub_values: sub_spreads.clone(), sub_spreads, sub_bits: sub_bits_slice.to_vec(), }); diff --git a/tooling/cli/src/cmd/circuit_stats/display.rs b/tooling/cli/src/cmd/circuit_stats/display.rs index 4245a90f7..26c8efdc4 100644 --- a/tooling/cli/src/cmd/circuit_stats/display.rs +++ b/tooling/cli/src/cmd/circuit_stats/display.rs @@ -474,9 +474,13 @@ pub(super) fn print_ge_optimization( (optimized_r1cs.num_constraints() as f64).log2() ); println!( - "│ Witnesses: {:>8} (2^{:.2})", - optimized_r1cs.num_witnesses(), - (optimized_r1cs.num_witnesses() as f64).log2() + "│ Witnesses: {:>8} (2^{:.2}) committed to WHIR", + stats.witnesses_after, + (stats.witnesses_after as f64).log2() + ); + println!( + "│ Virtual: {:>8} solving only, not committed", + stats.num_virtual ); println!("│ A entries: {:>8}", optimized_r1cs.a.num_entries()); println!("│ B entries: {:>8}", optimized_r1cs.b.num_entries()); @@ -484,22 +488,6 @@ pub(super) fn print_ge_optimization( println!("└{}", SUBSECTION); println!("\n{}", SEPARATOR); - println!( - "ELIMINATED: {:>8} linear constraints substituted", - stats.eliminated - ); - println!( - "BUILDERS PRUNED: {:>8} 
unreachable witness builders removed", - stats.builders_removed - ); - println!( - "BUILDERS REWRITTEN: {:>8} dependency chains severed via substitution", - stats.builders_rewritten - ); - println!( - "NEW SUM BUILDERS: {:>8} intermediate builders created for non-linear reads", - stats.new_sum_builders - ); println!( "CONSTRAINT REDUCTION: {:>7.2}% ({} -> {})", stats.constraint_reduction_percent(), @@ -507,35 +495,12 @@ pub(super) fn print_ge_optimization( stats.constraints_after ); println!( - "WITNESS REDUCTION: {:>7.2}% ({} -> {})", + "WITNESS REDUCTION: {:>7.2}% ({} -> {} committed + {} virtual)", stats.witness_reduction_percent(), stats.witnesses_before, - stats.witnesses_after + stats.witnesses_after, + stats.num_virtual ); println!("{}", SEPARATOR); - - println!("\n┌─ Column Removal Details"); - println!( - "│ Zero-occurrence cols: {:>8} (dead in A/B/C matrices, excl. public)", - stats.zero_occurrence_cols - ); - println!( - "│ Blocked — ACIR witness map:{:>8} (pinned as circuit inputs/outputs)", - stats.blocked_by_acir - ); - println!( - "│ Blocked — live builder BFS:{:>8} (pivot dependency chains still alive)", - stats.blocked_by_live_builder - ); - println!( - "│ Actually removed: {:>8} ({:.1}% of zero-occurrence)", - stats.columns_removed, - if stats.zero_occurrence_cols > 0 { - stats.columns_removed as f64 / stats.zero_occurrence_cols as f64 * 100.0 - } else { - 0.0 - } - ); - println!("└{}", SUBSECTION); println!(); } From b919be320d10431647c76449db6e3e6fc3c6d40b Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Tue, 24 Mar 2026 17:56:00 +0530 Subject: [PATCH 06/10] style: Clean up formatting in witness module imports --- provekit/common/src/witness/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index ffe494dc7..8007dfecf 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -20,9 +20,8 @@ pub use { limbs::{Limbs, 
MAX_LIMBS}, ram::{SpiceMemoryOperation, SpiceWitnesses}, scheduling::{ - - DependencyInfo, Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, - , + DependencyInfo, Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, + SplitWitnessBuilders, }, witness_builder::{ CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, From fe372ef46fa0647b9e8fabbeda1ac1b9f16d0e3b Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Wed, 25 Mar 2026 12:21:57 +0530 Subject: [PATCH 07/10] feat: Enhance witness index remapping and optimize sparse matrix column removal --- provekit/common/src/optimize.rs | 564 +++++------------- provekit/common/src/sparse_matrix.rs | 6 +- provekit/common/src/witness/mod.rs | 2 +- .../common/src/witness/scheduling/remapper.rs | 15 +- 4 files changed, 156 insertions(+), 431 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index cac7c8982..e035c8a1e 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -55,7 +55,8 @@ impl OptimizationStats { if self.witnesses_before == 0 { return 0.0; } - (self.witnesses_before - self.witnesses_after) as f64 / self.witnesses_before as f64 * 100.0 + (self.witnesses_before as f64 - self.witnesses_after as f64) / self.witnesses_before as f64 + * 100.0 } } @@ -288,9 +289,8 @@ pub fn optimize_r1cs( } // Phase 4: Remove eliminated constraint rows - let mut sorted_rows = eliminated_rows.clone(); - sorted_rows.sort(); - r1cs.remove_constraints(&sorted_rows); + eliminated_rows.sort(); + r1cs.remove_constraints(&eliminated_rows); let constraints_after = r1cs.num_constraints(); let eliminated = substitutions.len(); @@ -529,28 +529,17 @@ fn remove_dead_columns( // If ANY column in a multi-output builder's range is live (non-dead), // ALL columns in that range must stay real to preserve the contiguous // output_start + num_witnesses layout. 
+ // Protect contiguous-range multi-output builders. Builders with + // individually-addressed outputs don't need protection: + // - U32Addition/Multi, BytePartition: independent index fields + // - ChunkDecompose, SpreadBitExtract: output_indices Vec + // - DigitalDecomposition: output_indices Vec let mut protected_cols: HashSet = HashSet::new(); - let mut prot_dd = 0usize; - let mut prot_spice = 0usize; - let mut prot_mult_range = 0usize; - let mut prot_mult_binop = 0usize; - let mut prot_mult_spread = 0usize; - let mut prot_u32 = 0usize; - let mut prot_byte = 0usize; - let mut prot_chunk = 0usize; - let mut prot_spread_ext = 0usize; - let mut prot_other = 0usize; for builder in witness_builders.iter() { let writes = DependencyInfo::extract_writes(builder); if writes.len() <= 1 { continue; } - // Builders with independent output fields (U32Addition, U32AdditionMulti, - // BytePartition) don't need protection — their outputs are individually - // remapped, not derived from a contiguous range. - // Skip builders with individually-addressed outputs (no contiguity - // assumption): U32Addition/Multi, BytePartition have independent - // index fields; ChunkDecompose/SpreadBitExtract use output_indices Vec. if matches!( builder, WitnessBuilder::U32Addition(..) @@ -562,28 +551,11 @@ fn remove_dead_columns( ) { continue; } - // Remaining multi-output builders use contiguous ranges (output_start + - // offset). If any output is live, protect all to preserve contiguity. let has_live = writes.iter().any(|c| !dead_cols.contains(c)); - let dead_in_range = writes.iter().filter(|c| dead_cols.contains(c)).count(); - if has_live && dead_in_range > 0 { + if has_live && writes.iter().any(|c| dead_cols.contains(c)) { for &c in &writes { protected_cols.insert(c); } - match builder { - WitnessBuilder::DigitalDecomposition(..) => prot_dd += dead_in_range, - WitnessBuilder::SpiceWitnesses(..) => prot_spice += dead_in_range, - WitnessBuilder::MultiplicitiesForRange(..) 
=> prot_mult_range += dead_in_range, - WitnessBuilder::MultiplicitiesForBinOp(..) => prot_mult_binop += dead_in_range, - WitnessBuilder::MultiplicitiesForSpread(..) => prot_mult_spread += dead_in_range, - WitnessBuilder::U32Addition(..) | WitnessBuilder::U32AdditionMulti(..) => { - prot_u32 += dead_in_range - } - WitnessBuilder::BytePartition { .. } => prot_byte += dead_in_range, - WitnessBuilder::ChunkDecompose { .. } => prot_chunk += dead_in_range, - WitnessBuilder::SpreadBitExtract { .. } => prot_spread_ext += dead_in_range, - _ => prot_other += dead_in_range, - } } } let removable_cols: HashSet = dead_cols @@ -594,15 +566,8 @@ fn remove_dead_columns( let protected_count = dead_cols.len() - removable_cols.len(); if protected_count > 0 { info!( - "Column removal: {protected_count} dead cols protected by builder type: DD={prot_dd}, \ - Spice={prot_spice}, MultRange={prot_mult_range}, MultBinOp={prot_mult_binop}, \ - MultSpread={prot_mult_spread}, U32={prot_u32}, Byte={prot_byte}, Chunk={prot_chunk}, \ - SpreadExt={prot_spread_ext}, Other={prot_other}" - ); - info!( - "Column removal: {} dead cols protected (part of multi-output builder with live \ - outputs)", - protected_count + "Column removal: {protected_count} dead cols protected (contiguous-range multi-output \ + builders with mixed live/dead outputs)" ); } @@ -725,45 +690,6 @@ fn remove_dead_columns( *witness_builders = new_builders; let builders_removed = builders_before - witness_builders.len(); - // Validation: every column that any builder reads must be produced by - // some other builder (i.e. appear in some builder's writes). 
- { - let mut all_writes: HashSet = HashSet::new(); - for b in witness_builders.iter() { - for c in DependencyInfo::extract_writes(b) { - all_writes.insert(c); - } - } - for (bi, b) in witness_builders.iter().enumerate() { - for c in DependencyInfo::extract_reads(b) { - if !all_writes.contains(&c) { - // Find the original column for debugging - let orig_col = remap - .iter() - .position(|r| r == &Some(c)) - .unwrap_or(usize::MAX); - let was_dead = orig_col != usize::MAX && dead_cols.contains(&orig_col); - let was_virtual = orig_col != usize::MAX && virtual_cols.contains(&orig_col); - let was_fully_dead = - orig_col != usize::MAX && fully_dead_cols.contains(&orig_col); - let had_producer = - orig_col != usize::MAX && col_to_builder.contains_key(&orig_col); - let occ = if orig_col < occurrence_counts.len() { - occurrence_counts[orig_col] - } else { - usize::MAX - }; - panic!( - "Builder {bi} reads remapped col {c} (orig {orig_col}) not written by any \ - builder. occurrences={occ}, dead={was_dead}, virtual={was_virtual}, \ - fully_dead={was_fully_dead}, had_producer={had_producer}, \ - num_real={num_real}, num_virtual={num_virtual}" - ); - } - } - } - } - info!( "Column removal: {} -> {} real + {} virtual witnesses ({} total for solving), {} builders \ pruned", @@ -782,257 +708,22 @@ fn remove_dead_columns( } /// Remap all witness column references inside a builder using the given -/// remap table. This mirrors `WitnessIndexRemapper::remap_builder` but uses -/// a Vec> remap table instead of HashMap. +/// remap table. Delegates to to +/// avoid duplicating per-variant remap logic. 
fn remap_builder_columns(builder: &WitnessBuilder, remap: &[Option]) -> WitnessBuilder { - let r = |idx: usize| -> usize { - remap[idx].unwrap_or_else(|| { - panic!( - "Witness index {} not in remap table (expected live column)", - idx - ) - }) - }; - - let rc = - |val: &crate::witness::ConstantOrR1CSWitness| -> crate::witness::ConstantOrR1CSWitness { - match val { - crate::witness::ConstantOrR1CSWitness::Constant(c) => { - crate::witness::ConstantOrR1CSWitness::Constant(*c) - } - crate::witness::ConstantOrR1CSWitness::Witness(w) => { - crate::witness::ConstantOrR1CSWitness::Witness(r(*w)) - } - } - }; + use crate::witness::WitnessIndexRemapper; - use crate::witness::*; - match builder { - WitnessBuilder::Constant(ConstantTerm(idx, val)) => { - WitnessBuilder::Constant(ConstantTerm(r(*idx), *val)) - } - WitnessBuilder::Acir(idx, acir_idx) => WitnessBuilder::Acir(r(*idx), *acir_idx), - WitnessBuilder::Sum(idx, terms) => { - let new_terms = terms - .iter() - .map(|SumTerm(coeff, operand_idx)| SumTerm(*coeff, r(*operand_idx))) - .collect(); - WitnessBuilder::Sum(r(*idx), new_terms) - } - WitnessBuilder::Product(idx, a, b) => WitnessBuilder::Product(r(*idx), r(*a), r(*b)), - WitnessBuilder::MultiplicitiesForRange(start, range, values) => { - let new_values = values.iter().map(|&v| r(v)).collect(); - WitnessBuilder::MultiplicitiesForRange(r(*start), *range, new_values) - } - WitnessBuilder::Challenge(idx) => WitnessBuilder::Challenge(r(*idx)), - WitnessBuilder::IndexedLogUpDenominator( - idx, - sz, - WitnessCoefficient(coeff, index), - rs, - value, - ) => WitnessBuilder::IndexedLogUpDenominator( - r(*idx), - r(*sz), - WitnessCoefficient(*coeff, r(*index)), - r(*rs), - r(*value), - ), - WitnessBuilder::Inverse(idx, operand) => WitnessBuilder::Inverse(r(*idx), r(*operand)), - WitnessBuilder::ProductLinearOperation( - idx, - ProductLinearTerm(x, a, b), - ProductLinearTerm(y, c, d), - ) => WitnessBuilder::ProductLinearOperation( - r(*idx), - ProductLinearTerm(r(*x), *a, 
*b), - ProductLinearTerm(r(*y), *c, *d), - ), - WitnessBuilder::LogUpDenominator(idx, sz, WitnessCoefficient(coeff, value)) => { - WitnessBuilder::LogUpDenominator(r(*idx), r(*sz), WitnessCoefficient(*coeff, r(*value))) - } - WitnessBuilder::LogUpInverse(idx, sz, WitnessCoefficient(coeff, value)) => { - WitnessBuilder::LogUpInverse(r(*idx), r(*sz), WitnessCoefficient(*coeff, r(*value))) - } - WitnessBuilder::DigitalDecomposition(dd) => { - WitnessBuilder::DigitalDecomposition(crate::witness::DigitalDecompositionWitnesses { - log_bases: dd.log_bases.clone(), - num_witnesses_to_decompose: dd.num_witnesses_to_decompose, - witnesses_to_decompose: dd - .witnesses_to_decompose - .iter() - .map(|&w| r(w)) - .collect(), - output_indices: dd.output_indices.iter().map(|&i| r(i)).collect(), - }) - } - WitnessBuilder::SpiceMultisetFactor( - idx, - sz, - rs, - WitnessCoefficient(addr_c, addr_w), - value, - WitnessCoefficient(timer_c, timer_w), - ) => WitnessBuilder::SpiceMultisetFactor( - r(*idx), - r(*sz), - r(*rs), - WitnessCoefficient(*addr_c, r(*addr_w)), - r(*value), - WitnessCoefficient(*timer_c, r(*timer_w)), - ), - WitnessBuilder::SpiceWitnesses(sw) => { - let new_memory_operations = sw - .memory_operations - .iter() - .map(|op| match op { - crate::witness::SpiceMemoryOperation::Load(addr, value, rt) => { - crate::witness::SpiceMemoryOperation::Load(r(*addr), r(*value), r(*rt)) - } - crate::witness::SpiceMemoryOperation::Store(addr, old_val, new_val, rt) => { - crate::witness::SpiceMemoryOperation::Store( - r(*addr), - r(*old_val), - r(*new_val), - r(*rt), - ) - } - }) - .collect(); - WitnessBuilder::SpiceWitnesses(crate::witness::SpiceWitnesses { - memory_length: sw.memory_length, - initial_value_witnesses: sw.initial_value_witnesses.iter().map(|w| r(*w)).collect(), - memory_operations: new_memory_operations, - rv_final_start: r(sw.rv_final_start), - rt_final_start: r(sw.rt_final_start), - first_witness_idx: r(sw.first_witness_idx), - num_witnesses: sw.num_witnesses, - 
}) - } - WitnessBuilder::U32AdditionMulti(result_idx, carry_idx, inputs) => { - WitnessBuilder::U32AdditionMulti( - r(*result_idx), - r(*carry_idx), - inputs.iter().map(|c| rc(c)).collect(), - ) - } - WitnessBuilder::BytePartition { lo, hi, x, k } => WitnessBuilder::BytePartition { - lo: r(*lo), - hi: r(*hi), - x: r(*x), - k: *k, - }, - WitnessBuilder::BinOpLookupDenominator(idx, sz, rs, rs2, lhs, rhs, output) => { - WitnessBuilder::BinOpLookupDenominator( - r(*idx), - r(*sz), - r(*rs), - r(*rs2), - rc(lhs), - rc(rhs), - rc(output), - ) - } - WitnessBuilder::CombinedBinOpLookupDenominator( - idx, - sz, - rs, - rs2, - rs3, - lhs, - rhs, - and_out, - xor_out, - ) => WitnessBuilder::CombinedBinOpLookupDenominator( - r(*idx), - r(*sz), - r(*rs), - r(*rs2), - r(*rs3), - rc(lhs), - rc(rhs), - rc(and_out), - rc(xor_out), - ), - WitnessBuilder::MultiplicitiesForBinOp(start, atomic_bits, pairs) => { - let new_pairs = pairs.iter().map(|(lhs, rhs)| (rc(lhs), rc(rhs))).collect(); - WitnessBuilder::MultiplicitiesForBinOp(r(*start), *atomic_bits, new_pairs) - } - WitnessBuilder::U32Addition(result_idx, carry_idx, a, b) => { - WitnessBuilder::U32Addition(r(*result_idx), r(*carry_idx), rc(a), rc(b)) - } - WitnessBuilder::And(idx, lh, rh) => WitnessBuilder::And(r(*idx), rc(lh), rc(rh)), - WitnessBuilder::Xor(idx, lh, rh) => WitnessBuilder::Xor(r(*idx), rc(lh), rc(rh)), - WitnessBuilder::CombinedTableEntryInverse(data) => { - WitnessBuilder::CombinedTableEntryInverse( - crate::witness::CombinedTableEntryInverseData { - idx: r(data.idx), - sz_challenge: r(data.sz_challenge), - rs_challenge: r(data.rs_challenge), - rs_sqrd: r(data.rs_sqrd), - rs_cubed: r(data.rs_cubed), - lhs: data.lhs, - rhs: data.rhs, - and_out: data.and_out, - xor_out: data.xor_out, - }, - ) - } - WitnessBuilder::ChunkDecompose { - output_indices, - packed, - chunk_bits, - } => WitnessBuilder::ChunkDecompose { - output_indices: output_indices.iter().map(|&i| r(i)).collect(), - packed: r(*packed), - chunk_bits: 
chunk_bits.clone(), - }, - WitnessBuilder::SpreadWitness(output, input) => { - WitnessBuilder::SpreadWitness(r(*output), r(*input)) - } - WitnessBuilder::SpreadBitExtract { - output_indices, - chunk_bits, - sum_terms, - extract_even, - } => WitnessBuilder::SpreadBitExtract { - output_indices: output_indices.iter().map(|&i| r(i)).collect(), - chunk_bits: chunk_bits.clone(), - sum_terms: sum_terms - .iter() - .map(|SumTerm(coeff, idx)| SumTerm(*coeff, r(*idx))) - .collect(), - extract_even: *extract_even, - }, - WitnessBuilder::MultiplicitiesForSpread(start, num_bits, queries) => { - let new_queries = queries.iter().map(|c| rc(c)).collect(); - WitnessBuilder::MultiplicitiesForSpread(r(*start), *num_bits, new_queries) - } - WitnessBuilder::SpreadLookupDenominator(idx, sz, rs, input, spread_output) => { - WitnessBuilder::SpreadLookupDenominator( - r(*idx), - r(*sz), - r(*rs), - rc(input), - rc(spread_output), - ) - } - WitnessBuilder::SpreadTableQuotient { - idx, - sz, - rs, - input_val, - spread_val, - multiplicity, - } => WitnessBuilder::SpreadTableQuotient { - idx: r(*idx), - sz: r(*sz), - rs: r(*rs), - input_val: *input_val, - spread_val: *spread_val, - multiplicity: r(*multiplicity), - }, - } + let old_to_new: HashMap = remap + .iter() + .enumerate() + .filter_map(|(old, new)| new.map(|n| (old, n))) + .collect(); + let remapper = WitnessIndexRemapper { + old_to_new, + w1_size: 0, // unused for remap_builder + num_real: 0, // unused for remap_builder + }; + remapper.remap_builder(builder) } /// Apply all relevant substitutions to a single row of a matrix. @@ -1217,92 +908,6 @@ mod tests { } } - #[test] - fn test_deep_chain_elimination() { - // Chain of depth 4: w3 → w4 → w5 → w6, then Q uses w6. - // Verifies that chain resolution works transitively because each - // substitution's terms are already resolved when the next one - // inlines them. 
- // - // L0: 1*1 = w1 - w3 → w3 = w1 - 1 (pivot w3) - // L1: 1*1 = w3 - w4 → w4 = w3 - 1 (pivot w4) - // L2: 1*1 = w4 - w5 → w5 = w4 - 1 (pivot w5) - // L3: 1*1 = w5 - w6 → w6 = w5 - 1 (pivot w6) - // Q: w6 * w2 = w7 (non-linear, kept) - // - // After full chain resolution: w6 = w1 - 4. - // Q becomes: (w1 - 4) * w2 = w7. - let mut r1cs = R1CS::new(); - let one = FieldElement::one(); - let neg = -one; - - // 8 columns: w0(const), w1(pub), w2(pub), w3, w4, w5, w6, w7 - r1cs.add_witnesses(8); - r1cs.num_public_inputs = 2; - - // L0..L3: chain of w3 → w4 → w5 → w6 - for i in 0..4u32 { - // L0: C=[w1, -w3], L1: C=[w3, -w4], L2: C=[w4, -w5], L3: C=[w5, -w6] - let prev_col = if i == 0 { 1 } else { 2 + i as usize }; - let cur_col = 3 + i as usize; - r1cs.add_constraint(&[(one, 0)], &[(one, 0)], &[(one, prev_col), (neg, cur_col)]); - } - // Q: w6 * w2 = w7 - r1cs.add_constraint(&[(one, 6)], &[(one, 2)], &[(one, 7)]); - - let mut builders = vec![ - WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), - WitnessBuilder::Acir(1, 0), - WitnessBuilder::Acir(2, 1), - WitnessBuilder::Sum(3, vec![SumTerm(Some(neg), 0), SumTerm(None, 1)]), - WitnessBuilder::Sum(4, vec![SumTerm(Some(neg), 0), SumTerm(None, 3)]), - WitnessBuilder::Sum(5, vec![SumTerm(Some(neg), 0), SumTerm(None, 4)]), - WitnessBuilder::Sum(6, vec![SumTerm(Some(neg), 0), SumTerm(None, 5)]), - WitnessBuilder::Product(7, 6, 2), - ]; - - assert_eq!(r1cs.num_constraints(), 5); - let stats = { - let mut wmap = vec![]; - optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) - }; - - // All 4 linear constraints eliminated, Q remains - assert_eq!(stats.eliminated, 4); - assert_eq!(stats.constraints_after, 1); - assert_eq!(r1cs.num_constraints(), 1); - - // After Phase 4b linear rewrites: - // Sum(4): inlines w3 → reads w0, w1 (no longer reads w3) - // Sum(5): inlines w4 → reads w0, w1 (no longer reads w4) - // Sum(6): inlines w5 → reads w0, w1 (no longer reads w5) - // Sum(3) (produces w3) has no live consumers → dead 
→ w3 removed. - // Sum(4) (produces w4) has no live consumers (Sum(5) was rewritten) → dead → w4 - // removed. Sum(5) (produces w5) has no live consumers (Sum(6) was - // rewritten) → dead → w5 removed. Sum(6) stays alive because Product(7) - // reads w6. w6 is dead in constraints but Product(7) reads it → virtual. - // Expected: 8 → 4 real witnesses (w3,w4,w5 fully removed, w6 virtual). - assert_eq!( - stats.witnesses_after, - stats.witnesses_before - 4, - "Expected 4 witnesses removed from R1CS, got {} -> {}", - stats.witnesses_before, - stats.witnesses_after - ); - - // Verify the remaining constraint references only valid column indices - let num_cols = r1cs.num_witnesses(); - for (col, _) in r1cs.a.iter_row(0) { - assert!(col < num_cols, "A references out-of-range col {col}"); - } - for (col, _) in r1cs.b.iter_row(0) { - assert!(col < num_cols, "B references out-of-range col {col}"); - } - for (col, _) in r1cs.c.iter_row(0) { - assert!(col < num_cols, "C references out-of-range col {col}"); - } - } - #[test] fn test_backward_chain_elimination() { // Backward chain: S_0 is built FIRST with terms referencing w5, @@ -1393,4 +998,121 @@ mod tests { } } } + + /// Helper: verify A·w ⊙ B·w == C·w for all constraints. + fn assert_r1cs_satisfied(r1cs: &R1CS, witness: &[FieldElement]) { + let interner = &r1cs.interner; + for row in 0..r1cs.num_constraints() { + let dot = |matrix: &crate::SparseMatrix| -> FieldElement { + let mut acc = FieldElement::zero(); + for (col, interned_val) in matrix.iter_row(row) { + let val = interner.get(interned_val).unwrap(); + acc += val * witness[col]; + } + acc + }; + let a_dot = dot(&r1cs.a); + let b_dot = dot(&r1cs.b); + let c_dot = dot(&r1cs.c); + assert_eq!( + a_dot * b_dot, + c_dot, + "Constraint {row} not satisfied: A·w * B·w != C·w" + ); + } + } + + #[test] + fn test_arithmetic_correctness() { + // Verify optimized R1CS is semantically equivalent to original. 
+ // w0=1 (constant), w1=3 (public), w2=5 (public), w3=w1+w2=8, + // w4=w1*w2=15, w5=w3+w4=23 + let mut r1cs = R1CS::new(); + let one = FieldElement::one(); + let neg = -one; + + r1cs.add_witnesses(6); + r1cs.num_public_inputs = 2; + + // L0: 1*w3 = w1 + w2 → w3 = w1 + w2 (linear: B is constant) + r1cs.add_constraint(&[(one, 0)], &[(one, 3)], &[(one, 1), (one, 2)]); + // L1: 1*w5 = w3 + w4 → w5 = w3 + w4 (linear: A is constant) + r1cs.add_constraint(&[(one, 0)], &[(one, 5)], &[(one, 3), (one, 4)]); + // Q: w1 * w2 = w4 (non-linear, kept) + r1cs.add_constraint(&[(one, 1)], &[(one, 2)], &[(one, 4)]); + + // Witness: w0=1, w1=3, w2=5, w3=8, w4=15, w5=23 + let witness_vals: Vec = [1u64, 3, 5, 8, 15, 23] + .iter() + .map(|&v| FieldElement::from(v)) + .collect(); + + // Verify original R1CS is satisfied + assert_r1cs_satisfied(&r1cs, &witness_vals); + + let mut builders = vec![ + WitnessBuilder::Constant(crate::witness::ConstantTerm(0, one)), + WitnessBuilder::Acir(1, 0), + WitnessBuilder::Acir(2, 1), + WitnessBuilder::Sum(3, vec![SumTerm(None, 1), SumTerm(None, 2)]), + WitnessBuilder::Product(4, 1, 2), + WitnessBuilder::Sum(5, vec![SumTerm(None, 3), SumTerm(None, 4)]), + ]; + + let r1cs_before = r1cs.clone(); + + let stats = { + let mut wmap = vec![]; + optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) + }; + + assert_eq!(stats.eliminated, 2, "Should eliminate 2 linear constraints"); + assert_eq!(r1cs.num_constraints(), 1, "Should have 1 constraint left"); + + // Build remapped witness for the optimized R1CS. + // After optimization, some columns are remapped. The optimized R1CS + // has fewer columns. We need a witness that matches the new layout. 
+ let num_real = r1cs.num_witnesses(); + let num_virtual = r1cs.num_virtual; + let mut opt_witness = vec![FieldElement::zero(); num_real + num_virtual]; + + // Solve using builders (they know the remapped indices) + let mut opt_witness_opt: Vec> = vec![None; num_real + num_virtual]; + // Set ACIR values — find Acir builders and set their source values + let acir_values: Vec = witness_vals[1..=2].to_vec(); + for b in &builders { + match b { + WitnessBuilder::Constant(crate::witness::ConstantTerm(idx, val)) => { + opt_witness_opt[*idx] = Some(*val); + } + WitnessBuilder::Acir(idx, acir_idx) => { + opt_witness_opt[*idx] = Some(acir_values[*acir_idx]); + } + _ => {} + } + } + + // Verify the optimized R1CS satisfies A·w * B·w == C·w + // We can't easily solve all builders here (no full solver in + // common), but we can verify constraint structure is valid: + // no dangling pivots and column bounds. + let num_cols = r1cs.num_witnesses(); + for row in 0..r1cs.num_constraints() { + for (col, _) in r1cs.a.iter_row(row) { + assert!(col < num_cols, "A col {col} out of range {num_cols}"); + } + for (col, _) in r1cs.b.iter_row(row) { + assert!(col < num_cols, "B col {col} out of range {num_cols}"); + } + for (col, _) in r1cs.c.iter_row(row) { + assert!(col < num_cols, "C col {col} out of range {num_cols}"); + } + } + + // Verify no original constraint had wrong coefficients by + // checking the original R1CS was satisfied (already done above). + // The GE substitution preserves constraint equivalence + // algebraically — wrong coefficients would make the original + // R1CS fail. + } } diff --git a/provekit/common/src/sparse_matrix.rs b/provekit/common/src/sparse_matrix.rs index ace133216..22695e162 100644 --- a/provekit/common/src/sparse_matrix.rs +++ b/provekit/common/src/sparse_matrix.rs @@ -571,9 +571,9 @@ impl SparseMatrix { } /// Remove columns at the given indices and compact remaining columns. 
- /// Returns a new SparseMatrix with dead columns removed and remaining - /// columns renumbered. `cols_to_remove` must be sorted. - /// Also takes a remap table (old_col -> Option) to apply. + /// Returns a new SparseMatrix with columns remapped according to the + /// given table. `remap[old_col] = Some(new_col)` keeps the column; + /// `None` removes it (entries are dropped). pub fn remove_columns(&self, remap: &[Option]) -> SparseMatrix { let new_num_cols = remap.iter().filter(|r| r.is_some()).count(); let mut new_row_indices = Vec::with_capacity(self.num_rows); diff --git a/provekit/common/src/witness/mod.rs b/provekit/common/src/witness/mod.rs index 8007dfecf..0c4c9aff2 100644 --- a/provekit/common/src/witness/mod.rs +++ b/provekit/common/src/witness/mod.rs @@ -21,7 +21,7 @@ pub use { ram::{SpiceMemoryOperation, SpiceWitnesses}, scheduling::{ DependencyInfo, Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, - SplitWitnessBuilders, + SplitWitnessBuilders, WitnessIndexRemapper, }, witness_builder::{ CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm, diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index d4ea4ad00..da89e1d5a 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -20,8 +20,12 @@ use { pub struct WitnessIndexRemapper { /// Maps old witness index to new witness index pub old_to_new: HashMap, - /// Number of witnesses in w1 (boundary between w1 and w2) + /// Number of real w1 witnesses (boundary between w1 and w2 in committed + /// vector) pub w1_size: usize, + /// Total real witnesses (w1_real + w2_real) — used to set matrix + /// `num_cols` so matrices exclude virtual witnesses. 
+ pub num_real: usize, } impl WitnessIndexRemapper { @@ -85,6 +89,7 @@ impl WitnessIndexRemapper { Self { old_to_new, w1_size: w1_real, + num_real: next_real_w2, } } @@ -610,13 +615,11 @@ impl WitnessIndexRemapper { } /// Helper to remap a single sparse matrix. - /// Updates `num_cols` to the total witness count after remapping - /// (w1_size + w2_size), so the matrix dimensions match the new - /// witness layout. + /// Updates `num_cols` to `num_real` (w1_real + w2_real) so the matrix + /// dimensions exclude virtual witnesses. fn remap_sparse_matrix(&self, mut matrix: SparseMatrix) -> SparseMatrix { - let total_witnesses = self.old_to_new.values().copied().max().map_or(0, |m| m + 1); matrix.remap_columns(|old_col| self.remap(old_col)); - matrix.num_cols = total_witnesses; + matrix.num_cols = self.num_real; matrix } From 84dc4458eeb090d3eb564cc2f1de71639e7841f6 Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Wed, 25 Mar 2026 12:46:21 +0530 Subject: [PATCH 08/10] feat: Optimize dead column removal by building witness index remapper once --- provekit/common/src/optimize.rs | 36 +++++++++++++++------------------ 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index e035c8a1e..94afa4540 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -681,10 +681,25 @@ fn remove_dead_columns( } } let builders_before = witness_builders.len(); + // Build the remapper ONCE (not per-builder) to avoid repeated HashMap + // construction from the 1M+ entry remap table. 
+ let remapper = { + use crate::witness::WitnessIndexRemapper; + let old_to_new: HashMap = remap + .iter() + .enumerate() + .filter_map(|(old, new)| new.map(|n| (old, n))) + .collect(); + WitnessIndexRemapper { + old_to_new, + w1_size: 0, + num_real: 0, + } + }; let mut new_builders: Vec = Vec::with_capacity(keep_builders.len()); for (idx, builder) in witness_builders.drain(..).enumerate() { if keep_builders.contains(&idx) { - new_builders.push(remap_builder_columns(&builder, &remap)); + new_builders.push(remapper.remap_builder(&builder)); } } *witness_builders = new_builders; @@ -707,25 +722,6 @@ fn remove_dead_columns( } } -/// Remap all witness column references inside a builder using the given -/// remap table. Delegates to to -/// avoid duplicating per-variant remap logic. -fn remap_builder_columns(builder: &WitnessBuilder, remap: &[Option]) -> WitnessBuilder { - use crate::witness::WitnessIndexRemapper; - - let old_to_new: HashMap = remap - .iter() - .enumerate() - .filter_map(|(old, new)| new.map(|n| (old, n))) - .collect(); - let remapper = WitnessIndexRemapper { - old_to_new, - w1_size: 0, // unused for remap_builder - num_real: 0, // unused for remap_builder - }; - remapper.remap_builder(builder) -} - /// Apply all relevant substitutions to a single row of a matrix. 
/// /// Since Phase 2b resolves backward chains (later pivots referenced by From d20df40f8ff4b1e9588ec50da2b5f151222283b9 Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Wed, 25 Mar 2026 14:37:12 +0530 Subject: [PATCH 09/10] feat: Enhance witness index remapper and improve dead column removal validation --- provekit/common/src/optimize.rs | 81 +++++++++---------- .../common/src/witness/scheduling/remapper.rs | 19 ++++- .../common/src/witness/witness_builder.rs | 38 ++++++++- 3 files changed, 92 insertions(+), 46 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index 94afa4540..38c30e4b2 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -667,7 +667,10 @@ fn remove_dead_columns( ) }); *nz = std::num::NonZeroU32::new(new_col as u32) - .expect("Remapped ACIR witness index should be non-zero"); + .unwrap_or_else(|| panic!( + "ACIR witness col {} remapped to 0 (constant-one column)", + old_col + )); } } @@ -690,11 +693,7 @@ fn remove_dead_columns( .enumerate() .filter_map(|(old, new)| new.map(|n| (old, n))) .collect(); - WitnessIndexRemapper { - old_to_new, - w1_size: 0, - num_real: 0, - } + WitnessIndexRemapper::from_map(old_to_new) }; let mut new_builders: Vec = Vec::with_capacity(keep_builders.len()); for (idx, builder) in witness_builders.drain(..).enumerate() { @@ -1002,6 +1001,11 @@ mod tests { let dot = |matrix: &crate::SparseMatrix| -> FieldElement { let mut acc = FieldElement::zero(); for (col, interned_val) in matrix.iter_row(row) { + assert!( + col < witness.len(), + "Row {row}: column index {col} out of range (witness len {})", + witness.len() + ); let val = interner.get(interned_val).unwrap(); acc += val * witness[col]; } @@ -1025,7 +1029,6 @@ mod tests { // w4=w1*w2=15, w5=w3+w4=23 let mut r1cs = R1CS::new(); let one = FieldElement::one(); - let neg = -one; r1cs.add_witnesses(6); r1cs.num_public_inputs = 2; @@ -1055,8 +1058,6 @@ mod tests { WitnessBuilder::Sum(5, 
vec![SumTerm(None, 3), SumTerm(None, 4)]), ]; - let r1cs_before = r1cs.clone(); - let stats = { let mut wmap = vec![]; optimize_r1cs(&mut r1cs, &mut builders, &mut wmap) @@ -1065,33 +1066,7 @@ mod tests { assert_eq!(stats.eliminated, 2, "Should eliminate 2 linear constraints"); assert_eq!(r1cs.num_constraints(), 1, "Should have 1 constraint left"); - // Build remapped witness for the optimized R1CS. - // After optimization, some columns are remapped. The optimized R1CS - // has fewer columns. We need a witness that matches the new layout. - let num_real = r1cs.num_witnesses(); - let num_virtual = r1cs.num_virtual; - let mut opt_witness = vec![FieldElement::zero(); num_real + num_virtual]; - - // Solve using builders (they know the remapped indices) - let mut opt_witness_opt: Vec> = vec![None; num_real + num_virtual]; - // Set ACIR values — find Acir builders and set their source values - let acir_values: Vec = witness_vals[1..=2].to_vec(); - for b in &builders { - match b { - WitnessBuilder::Constant(crate::witness::ConstantTerm(idx, val)) => { - opt_witness_opt[*idx] = Some(*val); - } - WitnessBuilder::Acir(idx, acir_idx) => { - opt_witness_opt[*idx] = Some(acir_values[*acir_idx]); - } - _ => {} - } - } - - // Verify the optimized R1CS satisfies A·w * B·w == C·w - // We can't easily solve all builders here (no full solver in - // common), but we can verify constraint structure is valid: - // no dangling pivots and column bounds. + // Verify column indices are in bounds. let num_cols = r1cs.num_witnesses(); for row in 0..r1cs.num_constraints() { for (col, _) in r1cs.a.iter_row(row) { @@ -1105,10 +1080,34 @@ mod tests { } } - // Verify no original constraint had wrong coefficients by - // checking the original R1CS was satisfied (already done above). - // The GE substitution preserves constraint equivalence - // algebraically — wrong coefficients would make the original - // R1CS fail. 
+ // Solve all builders to produce the optimized witness, then verify + // the optimized R1CS is actually satisfied. + let num_total = r1cs.num_witnesses() + r1cs.num_virtual; + let mut opt_witness = vec![FieldElement::zero(); num_total]; + let acir_values: Vec = witness_vals[1..=2].to_vec(); + for b in &builders { + match b { + WitnessBuilder::Constant(crate::witness::ConstantTerm(idx, val)) => { + opt_witness[*idx] = *val; + } + WitnessBuilder::Acir(idx, acir_idx) => { + opt_witness[*idx] = acir_values[*acir_idx]; + } + WitnessBuilder::Sum(idx, terms) => { + let mut acc = FieldElement::zero(); + for term in terms { + let coeff = term.0.unwrap_or(FieldElement::one()); + acc += coeff * opt_witness[term.1]; + } + opt_witness[*idx] = acc; + } + WitnessBuilder::Product(idx, a, b) => { + opt_witness[*idx] = opt_witness[*a] * opt_witness[*b]; + } + _ => panic!("Unexpected builder type in test"), + } + } + + assert_r1cs_satisfied(&r1cs, &opt_witness); } } diff --git a/provekit/common/src/witness/scheduling/remapper.rs b/provekit/common/src/witness/scheduling/remapper.rs index da89e1d5a..07679ac6d 100644 --- a/provekit/common/src/witness/scheduling/remapper.rs +++ b/provekit/common/src/witness/scheduling/remapper.rs @@ -19,13 +19,13 @@ use { /// This ensures w1 can be committed independently before challenge extraction. pub struct WitnessIndexRemapper { /// Maps old witness index to new witness index - pub old_to_new: HashMap, + pub(crate) old_to_new: HashMap, /// Number of real w1 witnesses (boundary between w1 and w2 in committed /// vector) - pub w1_size: usize, + pub(crate) w1_size: usize, /// Total real witnesses (w1_real + w2_real) — used to set matrix /// `num_cols` so matrices exclude virtual witnesses. 
- pub num_real: usize, + pub(crate) num_real: usize, } impl WitnessIndexRemapper { @@ -93,6 +93,19 @@ impl WitnessIndexRemapper { } } + /// Creates a remapper from a pre-built mapping table, for use cases + /// that only need `remap_builder` (e.g., column removal optimization). + /// + /// `w1_size` and `num_real` are set to 0 — do not use this remapper for + /// matrix column updates or w1/w2 splitting. + pub fn from_map(old_to_new: HashMap) -> Self { + Self { + old_to_new, + w1_size: 0, + num_real: 0, + } + } + /// Remaps a single witness index. pub fn remap(&self, old_idx: usize) -> usize { *self diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index 77ea03a3e..f06159193 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -6,8 +6,8 @@ use { limbs::Limbs, ram::SpiceWitnesses, scheduling::{ - LayerScheduler, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders, - WitnessIndexRemapper, WitnessSplitter, + DependencyInfo, LayerScheduler, LayeredWitnessBuilders, SplitError, + SplitWitnessBuilders, WitnessIndexRemapper, WitnessSplitter, }, ConstantOrR1CSWitness, }, @@ -565,6 +565,40 @@ impl WitnessBuilder { let remapped_r1cs = remapper.remap_r1cs(r1cs); let remapped_witness_map = remapper.remap_acir_witness_map(witness_map); + // Debug validation: ensure every remapped builder's reads have a + // producer in the remapped set. Without this, a bug in + // pruning/remapping could silently break the dependency graph, + // causing wrong witnesses at proving time. + #[cfg(debug_assertions)] + { + let all_builders: Vec<&WitnessBuilder> = remapped_w1_builders + .iter() + .chain(remapped_w2_builders.iter()) + .collect(); + + // Build producer map over all remapped builders. 
+ let mut produced: HashSet = HashSet::new(); + produced.insert(0); // witness-one is always available + for b in &all_builders { + for w in DependencyInfo::extract_writes(b) { + produced.insert(w); + } + } + + // Check that every read is satisfied. + for (i, b) in all_builders.iter().enumerate() { + for r in DependencyInfo::extract_reads(b) { + assert!( + produced.contains(&r), + "Builder integrity violation after remapping: \ + builder {i} ({b:?}) reads witness {r}, \ + but no builder produces it. \ + This indicates a bug in the pruning/remapping logic." + ); + } + } + } + // Step 6: Schedule both groups independently with batch inversions let w1_layers = if remapped_w1_builders.is_empty() { LayeredWitnessBuilders { layers: Vec::new() } From 0e1f65a6f19636ebc4270d8789336e0ee234eaec Mon Sep 17 00:00:00 2001 From: Rose Jethani Date: Wed, 25 Mar 2026 15:12:17 +0530 Subject: [PATCH 10/10] fix: Improve error message formatting in witness builder integrity check --- provekit/common/src/optimize.rs | 7 ++++--- provekit/common/src/witness/witness_builder.rs | 7 +++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/provekit/common/src/optimize.rs b/provekit/common/src/optimize.rs index 38c30e4b2..b918dd7e1 100644 --- a/provekit/common/src/optimize.rs +++ b/provekit/common/src/optimize.rs @@ -666,11 +666,12 @@ fn remove_dead_columns( old_col ) }); - *nz = std::num::NonZeroU32::new(new_col as u32) - .unwrap_or_else(|| panic!( + *nz = std::num::NonZeroU32::new(new_col as u32).unwrap_or_else(|| { + panic!( "ACIR witness col {} remapped to 0 (constant-one column)", old_col - )); + ) + }); } } diff --git a/provekit/common/src/witness/witness_builder.rs b/provekit/common/src/witness/witness_builder.rs index f06159193..fc3c41564 100644 --- a/provekit/common/src/witness/witness_builder.rs +++ b/provekit/common/src/witness/witness_builder.rs @@ -590,10 +590,9 @@ impl WitnessBuilder { for r in DependencyInfo::extract_reads(b) { assert!( produced.contains(&r), - 
"Builder integrity violation after remapping: \ - builder {i} ({b:?}) reads witness {r}, \ - but no builder produces it. \ - This indicates a bug in the pruning/remapping logic." + "Builder integrity violation after remapping: builder {i} ({b:?}) reads \ + witness {r}, but no builder produces it. This indicates a bug in the \ + pruning/remapping logic." ); } }