Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
907 changes: 667 additions & 240 deletions provekit/common/src/optimize.rs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions provekit/common/src/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ pub struct R1CS {
pub a: SparseMatrix,
pub b: SparseMatrix,
pub c: SparseMatrix,
/// Virtual witnesses: computation-only columns excluded from matrices/WHIR
/// commitment but needed by witness builders for intermediate calculations.
/// The prover allocates `num_witnesses() + num_virtual` for solving.
#[serde(default)]
pub num_virtual: usize,
}

impl Default for R1CS {
Expand All @@ -32,6 +37,7 @@ impl R1CS {
a: SparseMatrix::new(0, 0),
b: SparseMatrix::new(0, 0),
c: SparseMatrix::new(0, 0),
num_virtual: 0,
}
}

Expand Down
33 changes: 33 additions & 0 deletions provekit/common/src/sparse_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,39 @@ impl SparseMatrix {
}
}

/// Remove columns at the given indices and compact remaining columns.
/// Returns a new SparseMatrix with columns remapped according to the
/// given table. `remap[old_col] = Some(new_col)` keeps the column;
/// `None` removes it (entries are dropped).
pub fn remove_columns(&self, remap: &[Option<usize>]) -> SparseMatrix {
let new_num_cols = remap.iter().filter(|r| r.is_some()).count();
let mut new_row_indices = Vec::with_capacity(self.num_rows);
let mut new_col_indices = Vec::new();
let mut new_values = Vec::new();

for row in 0..self.num_rows {
new_row_indices.push(new_col_indices.len() as u32);
let range = self.row_range(row);
for i in range {
let old_col = self.col_indices[i] as usize;
if let Some(new_col) = remap[old_col] {
new_col_indices.push(new_col as u32);
new_values.push(self.values[i]);
}
// Dead columns should have no entries (zero occurrence),
// so this branch should never drop data.
}
}

SparseMatrix {
num_rows: self.num_rows,
num_cols: new_num_cols,
new_row_indices,
col_indices: new_col_indices,
values: new_values,
}
}

/// Count how many rows reference each column. Returns a Vec of length
/// num_cols.
pub fn column_occurrence_count(&self) -> Vec<usize> {
Expand Down
8 changes: 4 additions & 4 deletions provekit/common/src/witness/digits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ pub struct DigitalDecompositionWitnesses {
pub num_witnesses_to_decompose: usize,
/// Witness indices of the values to be decomposed
pub witnesses_to_decompose: Vec<usize>,
/// The index of the first witness written to
pub first_witness_idx: usize,
/// The number of witnesses written to
pub num_witnesses: usize,
/// Output witness indices. Length = log_bases.len() *
/// num_witnesses_to_decompose. Layout: output_indices[digit_place *
/// num_witnesses_to_decompose + i]
pub output_indices: Vec<usize>,
}

/// Compute a mixed-base decomposition of a field element into its digits, using
Expand Down
3 changes: 2 additions & 1 deletion provekit/common/src/witness/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ pub use {
limbs::{Limbs, MAX_LIMBS},
ram::{SpiceMemoryOperation, SpiceWitnesses},
scheduling::{
Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError, SplitWitnessBuilders,
DependencyInfo, Layer, LayerScheduler, LayerType, LayeredWitnessBuilders, SplitError,
SplitWitnessBuilders, WitnessIndexRemapper,
},
witness_builder::{
CombinedTableEntryInverseData, ConstantTerm, NonNativeEcOp, ProductLinearTerm, SumTerm,
Expand Down
18 changes: 4 additions & 14 deletions provekit/common/src/witness/scheduling/dependency.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl DependencyInfo {
}

/// Extracts the witness indices that a builder reads as inputs.
fn extract_reads(wb: &WitnessBuilder) -> Vec<usize> {
pub fn extract_reads(wb: &WitnessBuilder) -> Vec<usize> {
match wb {
WitnessBuilder::Constant(_)
| WitnessBuilder::Acir(..)
Expand Down Expand Up @@ -323,26 +323,16 @@ impl DependencyInfo {
WitnessBuilder::MultiplicitiesForRange(start, range, _) => {
(*start..*start + *range).collect()
}
WitnessBuilder::DigitalDecomposition(dd) => {
(dd.first_witness_idx..dd.first_witness_idx + dd.num_witnesses).collect()
}
WitnessBuilder::DigitalDecomposition(dd) => dd.output_indices.clone(),
WitnessBuilder::SpiceWitnesses(sw) => {
(sw.first_witness_idx..sw.first_witness_idx + sw.num_witnesses).collect()
}
WitnessBuilder::MultiplicitiesForBinOp(start, atomic_bits, ..) => {
let n = 2usize.pow(2 * *atomic_bits);
(*start..*start + n).collect()
}
WitnessBuilder::ChunkDecompose {
output_start,
chunk_bits,
..
} => (*output_start..*output_start + chunk_bits.len()).collect(),
WitnessBuilder::SpreadBitExtract {
output_start,
chunk_bits,
..
} => (*output_start..*output_start + chunk_bits.len()).collect(),
WitnessBuilder::ChunkDecompose { output_indices, .. } => output_indices.clone(),
WitnessBuilder::SpreadBitExtract { output_indices, .. } => output_indices.clone(),
WitnessBuilder::MultiplicitiesForSpread(start, num_bits, _) => {
let n = 1usize << *num_bits;
(*start..*start + n).collect()
Expand Down
141 changes: 96 additions & 45 deletions provekit/common/src/witness/scheduling/remapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,90 @@ use {
/// This ensures w1 can be committed independently before challenge extraction.
pub struct WitnessIndexRemapper {
/// Maps old witness index to new witness index
pub old_to_new: HashMap<usize, usize>,
/// Number of witnesses in w1 (boundary between w1 and w2)
pub w1_size: usize,
pub(crate) old_to_new: HashMap<usize, usize>,
/// Number of real w1 witnesses (boundary between w1 and w2 in committed
/// vector)
pub(crate) w1_size: usize,
/// Total real witnesses (w1_real + w2_real) — used to set matrix
/// `num_cols` so matrices exclude virtual witnesses.
pub(crate) num_real: usize,
}

impl WitnessIndexRemapper {
/// Creates a remapping from w1 and w2 builder lists.
///
/// Assigns w1 builder outputs to [0, k) and w2 builder outputs to [k, n).
pub fn new(w1_builders: &[WitnessBuilder], w2_builders: &[WitnessBuilder]) -> Self {
/// `num_real_cols` is the R1CS column count before splitting — indices
/// below this are "real" (constrained), indices >= are "virtual"
/// (computation-only).
///
/// Output layout:
/// [0, w1_real) → real w1 witnesses (committed)
/// [w1_real, w1_real+w2_real) → real w2 witnesses (committed)
/// [w1_real+w2_real, total) → virtual witnesses (solving only)
///
/// `w1_size` is set to `w1_real` so the WHIR commitment covers only
/// real witnesses.
pub fn new(
w1_builders: &[WitnessBuilder],
w2_builders: &[WitnessBuilder],
num_real_cols: usize,
) -> Self {
let mut old_to_new = HashMap::new();
let mut next_w1_idx = 0;
let mut next_w2_idx = 0;
let mut next_real_w1 = 0usize;
let mut virtual_w1: Vec<usize> = Vec::new();

// Map w1 builder outputs to [0, k)
// First pass w1: assign real outputs, collect virtual
for builder in w1_builders {
let writes = DependencyInfo::extract_writes(builder);
for old_idx in writes {
old_to_new.insert(old_idx, next_w1_idx);
next_w1_idx += 1;
for old_idx in DependencyInfo::extract_writes(builder) {
if old_idx < num_real_cols {
old_to_new.insert(old_idx, next_real_w1);
next_real_w1 += 1;
} else {
virtual_w1.push(old_idx);
}
}
}
let w1_real = next_real_w1;

let w1_size = next_w1_idx;
let mut next_real_w2 = w1_real;
let mut virtual_w2: Vec<usize> = Vec::new();

// Map w2 builder outputs to [k, n)
// First pass w2: assign real outputs, collect virtual
for builder in w2_builders {
let writes = DependencyInfo::extract_writes(builder);
for old_idx in writes {
old_to_new.insert(old_idx, w1_size + next_w2_idx);
next_w2_idx += 1;
for old_idx in DependencyInfo::extract_writes(builder) {
if old_idx < num_real_cols {
old_to_new.insert(old_idx, next_real_w2);
next_real_w2 += 1;
} else {
virtual_w2.push(old_idx);
}
}
}

// Second pass: assign virtual outputs after all real ones
let mut next_virtual = next_real_w2;
for old_idx in virtual_w1.into_iter().chain(virtual_w2) {
old_to_new.insert(old_idx, next_virtual);
next_virtual += 1;
}

Self {
old_to_new,
w1_size,
w1_size: w1_real,
num_real: next_real_w2,
}
}

/// Creates a remapper from a pre-built mapping table, for use cases
/// that only need `remap_builder` (e.g., column removal optimization).
///
/// `w1_size` and `num_real` are set to 0 — do not use this remapper for
/// matrix column updates or w1/w2 splitting.
pub fn from_map(old_to_new: HashMap<usize, usize>) -> Self {
Self {
old_to_new,
w1_size: 0,
num_real: 0,
}
}

Expand Down Expand Up @@ -162,22 +209,22 @@ impl WitnessIndexRemapper {
WitnessCoefficient(*coeff, self.remap(*value)),
)
}
WitnessBuilder::DigitalDecomposition(dd) => {
let new_witnesses_to_decompose = dd
.witnesses_to_decompose
.iter()
.map(|&w| self.remap(w))
.collect();
WitnessBuilder::DigitalDecomposition(
crate::witness::DigitalDecompositionWitnesses {
log_bases: dd.log_bases.clone(),
num_witnesses_to_decompose: dd.num_witnesses_to_decompose,
witnesses_to_decompose: new_witnesses_to_decompose,
first_witness_idx: self.remap(dd.first_witness_idx),
num_witnesses: dd.num_witnesses,
},
)
}
WitnessBuilder::DigitalDecomposition(dd) => WitnessBuilder::DigitalDecomposition(
crate::witness::DigitalDecompositionWitnesses {
log_bases: dd.log_bases.clone(),
num_witnesses_to_decompose: dd.num_witnesses_to_decompose,
witnesses_to_decompose: dd
.witnesses_to_decompose
.iter()
.map(|&w| self.remap(w))
.collect(),
output_indices: dd
.output_indices
.iter()
.map(|&i| self.remap(i))
.collect(),
},
),
WitnessBuilder::SpiceMultisetFactor(
idx,
sz,
Expand Down Expand Up @@ -493,30 +540,30 @@ impl WitnessIndexRemapper {
num_bits: *num_bits,
},
WitnessBuilder::ChunkDecompose {
output_start,
output_indices,
packed,
chunk_bits,
} => WitnessBuilder::ChunkDecompose {
output_start: self.remap(*output_start),
packed: self.remap(*packed),
chunk_bits: chunk_bits.clone(),
output_indices: output_indices.iter().map(|&i| self.remap(i)).collect(),
packed: self.remap(*packed),
chunk_bits: chunk_bits.clone(),
},
WitnessBuilder::SpreadWitness(output, input) => {
WitnessBuilder::SpreadWitness(self.remap(*output), self.remap(*input))
}
WitnessBuilder::SpreadBitExtract {
output_start,
output_indices,
chunk_bits,
sum_terms,
extract_even,
} => WitnessBuilder::SpreadBitExtract {
output_start: self.remap(*output_start),
chunk_bits: chunk_bits.clone(),
sum_terms: sum_terms
output_indices: output_indices.iter().map(|&i| self.remap(i)).collect(),
chunk_bits: chunk_bits.clone(),
sum_terms: sum_terms
.iter()
.map(|SumTerm(coeff, idx)| SumTerm(*coeff, self.remap(*idx)))
.collect(),
extract_even: *extract_even,
extract_even: *extract_even,
},
WitnessBuilder::MultiplicitiesForSpread(start, num_bits, queries) => {
let new_queries = queries
Expand Down Expand Up @@ -559,6 +606,7 @@ impl WitnessIndexRemapper {
pub fn remap_r1cs(&self, r1cs: R1CS) -> R1CS {
let mut new_r1cs = R1CS::new();
new_r1cs.num_public_inputs = r1cs.num_public_inputs;
new_r1cs.num_virtual = r1cs.num_virtual;
new_r1cs.interner = r1cs.interner;

// Remap A, B, C in parallel - they're independent
Expand All @@ -579,9 +627,12 @@ impl WitnessIndexRemapper {
new_r1cs
}

/// Helper to remap a single sparse matrix
/// Helper to remap a single sparse matrix.
/// Updates `num_cols` to `num_real` (w1_real + w2_real) so the matrix
/// dimensions exclude virtual witnesses.
fn remap_sparse_matrix(&self, mut matrix: SparseMatrix) -> SparseMatrix {
matrix.remap_columns(|old_col| self.remap(old_col));
matrix.num_cols = self.num_real;
matrix
}

Expand Down
Loading
Loading