diff --git a/Cargo.lock b/Cargo.lock index 56e8bf727..9b135e680 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -858,6 +858,7 @@ dependencies = [ "gkr_engine", "gkr_hashers", "goldilocks", + "gpu", "halo2curves", "log", "mersenne31", @@ -922,6 +923,17 @@ dependencies = [ "serdes", ] +[[package]] +name = "gpu" +version = "0.1.0" +dependencies = [ + "circuit", + "gkr_engine", + "gkr_hashers", + "thiserror", + "transcript", +] + [[package]] name = "group" version = "0.13.0" diff --git a/Cargo.toml b/Cargo.toml index 829df8fd9..2b2d27fd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "config_macros", # proc macros used to declare a new config, this has to a separate crate due to rust compilation issues "gkr", "gkr_engine", # definitions of GKR engine and associated types + "gpu", # GPU support and circuit serialization "hasher", # definitions of FiatShamirFieldHasher, FiatShamirBytesHash, and associated types "poly_commit", "serdes", # serialization and deserialization of various data structures diff --git a/bin/src/main.rs b/bin/src/main.rs index a41aa76e8..c3175e7b6 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -8,9 +8,9 @@ use circuit::Circuit; use clap::Parser; use gkr::{ BN254ConfigMIMC5KZG, BN254ConfigSha2Hyrax, BN254ConfigSha2Raw, GF2ExtConfigSha2Orion, - GF2ExtConfigSha2Raw, Goldilocksx8ConfigSha2Orion, Goldilocksx8ConfigSha2Raw, - M31x1ConfigSha2RawVanilla, M31x16ConfigSha2OrionSquare, M31x16ConfigSha2OrionVanilla, - M31x16ConfigSha2RawSquare, M31x16ConfigSha2RawVanilla, Prover, + GF2ExtConfigSha2Raw, Goldilocksx1ConfigSha2Raw, Goldilocksx8ConfigSha2Orion, + Goldilocksx8ConfigSha2Raw, M31x1ConfigSha2RawVanilla, M31x16ConfigSha2OrionSquare, + M31x16ConfigSha2OrionVanilla, M31x16ConfigSha2RawSquare, M31x16ConfigSha2RawVanilla, Prover, utils::{ KECCAK_BABYBEAR_CIRCUIT, KECCAK_BABYBEAR_WITNESS, KECCAK_BN254_CIRCUIT, KECCAK_BN254_WITNESS, KECCAK_GF2_CIRCUIT, KECCAK_GF2_WITNESS, KECCAK_GOLDILOCKS_CIRCUIT, @@ -69,7 +69,13 @@ fn main() { 
"m31ext3" => match pcs_type { PolynomialCommitmentType::Raw => match args.circuit.as_str() { - "keccak" => run_benchmark::(&args, mpi_config.clone()), + "keccak" => { + if std::env::var("EXPANDER_GPU").is_ok_and(|v| v == "1") { + run_benchmark::(&args, mpi_config.clone()) + } else { + run_benchmark::(&args, mpi_config.clone()) + } + } "poseidon" => run_benchmark::(&args, mpi_config.clone()), _ => unreachable!(), }, @@ -112,7 +118,13 @@ fn main() { }, "goldilocks" => match pcs_type { PolynomialCommitmentType::Raw => match args.circuit.as_str() { - "keccak" => run_benchmark::(&args, mpi_config.clone()), + "keccak" => { + if std::env::var("EXPANDER_GPU").is_ok_and(|v| v == "1") { + run_benchmark::(&args, mpi_config.clone()) + } else { + run_benchmark::(&args, mpi_config.clone()) + } + } _ => unreachable!(), }, PolynomialCommitmentType::Orion => match args.circuit.as_str() { @@ -206,6 +218,7 @@ where (FieldType::M31x1, "keccak") => 2, (FieldType::M31x16, "keccak") => 2, (FieldType::M31x16, "poseidon") => 120, + (FieldType::Goldilocksx1, "keccak") => 2, (FieldType::Goldilocksx8, "keccak") => 2, (FieldType::BabyBearx16, "keccak") => 2, (FieldType::BN254, "keccak") => 2, diff --git a/gkr/Cargo.toml b/gkr/Cargo.toml index 6f395570f..6c89da793 100644 --- a/gkr/Cargo.toml +++ b/gkr/Cargo.toml @@ -13,6 +13,7 @@ gf2_128 = { path = "../arith/gf2_128" } gkr_engine = { path = "../gkr_engine" } gkr_hashers = { path = "../hasher" } goldilocks = { path = "../arith/goldilocks" } +gpu = { path = "../gpu" } mersenne31 = { path = "../arith/mersenne31" } poly_commit = { path = "../poly_commit" } polynomials = { path = "../arith/polynomials" } diff --git a/gkr/src/prover/gkr_vanilla.rs b/gkr/src/prover/gkr_vanilla.rs index cacb6d29f..69efd3e5b 100644 --- a/gkr/src/prover/gkr_vanilla.rs +++ b/gkr/src/prover/gkr_vanilla.rs @@ -14,7 +14,11 @@ pub fn gkr_prove( sp: &mut ProverScratchPad, transcript: &mut impl Transcript, mpi_config: &MPIConfig, -) -> (F::ChallengeField, 
ExpanderDualVarChallenge) { +) -> (F::ChallengeField, ExpanderDualVarChallenge) +where + F::CircuitField: std::fmt::Debug, + F::SimdCircuitField: std::fmt::Debug, +{ let layer_num = circuit.layers.len(); let mut challenge: ExpanderDualVarChallenge = @@ -36,6 +40,21 @@ pub fn gkr_prove( mpi_config, ); + // Serialize circuit to file if EXPANDER_GPU environment variable is set to 1 + if std::env::var("EXPANDER_GPU").is_ok_and(|v| v == "1") { + // Only let rank 0 process handle serialization + if mpi_config.is_root() { + if let Err(e) = + gpu::serdes::serial_circuit_witness_as_plaintext(circuit, transcript, &challenge) + { + println!("Failed to serialize circuit: {e}"); + } + } + } + + let mut final_vx_claim = None; + let mut final_vy_claim = None; + for i in (0..layer_num).rev() { let timer = Timer::new( &format!( @@ -47,7 +66,7 @@ pub fn gkr_prove( mpi_config.is_root(), ); - (_, _) = sumcheck_prove_gkr_layer( + let (vx_claim, vy_claim) = sumcheck_prove_gkr_layer( &circuit.layers[i], &mut challenge, alpha, @@ -57,6 +76,12 @@ pub fn gkr_prove( i == layer_num - 1, ); + // Store the final layer claims for later use + if i == 0 { + final_vx_claim = Some(vx_claim); + final_vy_claim = vy_claim; + } + if challenge.rz_1.is_some() { // TODO: try broadcast beta.unwrap directly let mut tmp = transcript.generate_field_element::(); @@ -68,5 +93,18 @@ pub fn gkr_prove( timer.stop(); } + // Print final claims if EXPANDER_GPU environment variable is set to 1 + if std::env::var("EXPANDER_GPU").is_ok_and(|v| v == "1") { + // Only let rank 0 process handle printing final claims + if mpi_config.is_root() { + if let Some(vx) = final_vx_claim { + gpu::serdes::print_final_claims::(&vx, &final_vy_claim); + println!("GKR final proof claims as shown above."); + } + } + // For GPU mode, we'll let the program continue and exit naturally + // This allows MPI to properly clean up + } + (claimed_v, challenge) } diff --git a/gpu/.gitignore b/gpu/.gitignore new file mode 100644 index 
000000000..1269488f7 --- /dev/null +++ b/gpu/.gitignore @@ -0,0 +1 @@ +data diff --git a/gpu/Cargo.toml b/gpu/Cargo.toml new file mode 100644 index 000000000..a022535fc --- /dev/null +++ b/gpu/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "gpu" +version = "0.1.0" +edition = "2021" + +[dependencies] +circuit = { path = "../circuit" } +gkr_engine = { path = "../gkr_engine" } +gkr_hashers = { path = "../hasher" } +transcript = { path = "../transcript" } +thiserror.workspace = true diff --git a/gpu/Makefile b/gpu/Makefile new file mode 100644 index 000000000..934dc88f8 --- /dev/null +++ b/gpu/Makefile @@ -0,0 +1,55 @@ +# Makefile for Expander-GPU project +PROFILE_LEVEL ?= 0 +DO_ALL ?= 1 +FIELD_TYPE ?= m31ext3 + +# Supported field types +FIELDS := m31ext3 goldilocksext2 bn254 + +# Generate circuit file names for all supported field types +CIRCUIT_FILES := $(addprefix data/keccak_,$(addsuffix .gpu.circuit,$(FIELDS))) + +# Base command template +# Usage: $(call run_field,FIELD,MPI_LEN,RUN_FLAGS) +define run_field +./expander-gpu --field $(1) --circuit data/keccak_$(1).gpu.circuit --log_level 0 --mpi_len $(2) $(3) +endef + +# Helper function to run expander with different field types +# Usage: $(call run_expander,MPI_LEN,RUN_FLAGS) +define run_expander + $(if $(filter 1,$(DO_ALL)), \ + $(call run_field,m31ext3,$(1),$(2)); \ + $(call run_field,goldilocksext2,$(1),$(2)); \ + $(call run_field,bn254,$(1),$(2)), \ + $(if $(filter $(FIELD_TYPE),$(FIELDS)), \ + $(call run_field,$(FIELD_TYPE),$(1),$(2)), \ + $(error Invalid FIELD_TYPE '$(FIELD_TYPE)'. Must be one of: $(FIELDS)))) +endef + +.PHONY: clean data test profile mpi-profile mpi-test + +# When circuit files are missing, guide user to run the prepare script. +data/keccak_%.gpu.circuit: + @echo "Error: Circuit file '$@' is missing." + @echo "Please run './prepare-data.sh' to generate necessary circuit files." 
+ @exit 1 + +# Prepare all circuit data files (only if they don't exist) +data: $(CIRCUIT_FILES) + @echo "All circuit data files are ready." + +test: data + $(call run_expander,128,) + +profile: data + $(call run_expander,8192,--enable-same-input --profile $(PROFILE_LEVEL)) + +mpi-test: data + $(call run_expander,128,--enable-mpi-merge) + +mpi-profile: data + $(call run_expander,16384,--enable-same-input --profile $(PROFILE_LEVEL) --enable-mpi-merge) + +clean: + rm -rf data diff --git a/gpu/expander-gpu b/gpu/expander-gpu new file mode 100755 index 000000000..73c9d3ead Binary files /dev/null and b/gpu/expander-gpu differ diff --git a/gpu/prepare-data.sh b/gpu/prepare-data.sh new file mode 100755 index 000000000..ccc43a9fa --- /dev/null +++ b/gpu/prepare-data.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +# Generate the circuit and witness files for Expander-GPU in ./data. +# Stop at the first failing command so a broken step cannot produce partial data. +set -euo pipefail + +# Create data folder +mkdir -p data +cd data + +# Download the two repos used to generate the circuit and witness for GPU +git clone git@github.com:PolyhedraZK/ExpanderCompilerCollection.git +git clone git@github.com:PolyhedraZK/Expander.git + +# Use Expander Compiler to generate Circuit and Witness for Expander +cd ExpanderCompilerCollection +cargo test --release keccak + +# Move data to Expander +mkdir -p ../Expander/data +cp expander_compiler/*.txt ../Expander/data +cd ../Expander + +# Use Expander's GPU serialization to produce circuit and witness for GPU usage +git checkout gpu-expander +EXPANDER_GPU=1 RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo run --release --bin=gkr -- --circuit keccak --pcs Raw --threads 1 --field m31ext3 +EXPANDER_GPU=1 RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo run --release --bin=gkr -- --circuit keccak --pcs Raw --threads 1 --field goldilocks +EXPANDER_GPU=1 RUSTFLAGS="-C target-cpu=native -C target-feature=+avx512f" cargo run --release --bin=gkr -- --circuit keccak --pcs Raw --threads 1 --field fr +mv data/*.gpu.* .. +cd .. + +# Remove the two cloned repos +rm -rf ExpanderCompilerCollection Expander +cd .. 
diff --git a/gpu/readme.md b/gpu/readme.md new file mode 100644 index 000000000..1835d9209 --- /dev/null +++ b/gpu/readme.md @@ -0,0 +1,131 @@ +
+ Expander Logo +
+ + +# Expander GPU Acceleration + +Expander is a proof generation backend for Polyhedra Network. It aims to support fast proof generation. + +Expander now includes a high-performance GPU backend powered by CUDA, designed to dramatically accelerate proof generation. This backend is optimized for NVIDIA GPUs and offers significant speedups, especially for complex circuits and large-scale computations. + +### Key Features + +- **Massive Parallelism**: Leverages the full power of modern GPUs to process thousands of proofs in parallel. +- **MPI Merge**: Introduces an innovative "MPI Merge" feature that can compress proofs from thousands of independent computations into a single, compact proof. In our tests, we've achieved a compression ratio of up to `16384:1`. This is particularly useful in scenarios with large batches of similar computations. +- **Broad Field Support**: The GPU backend supports multiple field types, including `BN254`, `Goldilocks`, and `M31`. + +### System Requirements + +- **NVIDIA GPU**: A CUDA-enabled NVIDIA GPU with compute capability 7.0+ is recommended. +- **CUDA Toolkit**: Version 12.5 or newer. +- **Compiler**: `clang` and `clang++`. +- **Build Tools**: `cmake` (version 3.18+) and `ninja`. + +### Build Instructions + +The current release of Expander-GPU is in binary form. Please contact us if you are interested in source code access. + +## GPU Benchmarks + +The GPU backend delivers substantial performance improvements over the CPU implementation. The following benchmarks were run on an NVIDIA GPU, showcasing the throughput for various configurations. 
+ +### Performance Results + +| Field | Throughput (8192 proofs) | Throughput (16384 proofs, MPI Merged) | +|------------------|--------------------------|---------------------------------------| +| `m31ext3` | ~2788 proofs/sec | ~3040 computations/sec | +| `goldilocksext2` | ~2597 proofs/sec | ~2255 computations/sec | +| `bn254` | ~1313 proofs/sec | ~1525 computations/sec | + +**Note on BN254 Performance**: The GPU acceleration is particularly impactful for the `BN254` field. Compared to our highly optimized AVX512 CPU backend running on an AMD 9950X3D, **the GPU implementation provides a 7-10x speedup**, achieving over 1500 merged computations per second. This makes Expander an ideal choice for ZK applications built on Ethereum-friendly curves. + +### Running Benchmarks Manually and Profiling + +You can reproduce these benchmarks using the `Makefile`: + +```sh +# Run standard benchmark with 8192 parallel proofs +make profile + +# Run benchmark with 16384 parallel proofs and MPI merge enabled +make mpi-profile + +# Run standard benchmark with detailed profiling data +make profile PROFILE_LEVEL=2 + +# Run MPI merge benchmark with detailed profiling data +make mpi-profile PROFILE_LEVEL=2 +``` + +You can override the `FIELD_TYPE` and `PROFILE_LEVEL` variables of the `Makefile` to test different configurations; `FIELD_TYPE` only takes effect with `DO_ALL=0`, since by default every supported field is run. You should see a detailed profiling report like the one below. 
+ +``` +====== GKR System Initialization ====== +Parsed RZ0 Challenge: 0x128e207ced0a98b1401e2e521465544111847e131de192a5f527ecbd1611d6b0 + +GPU Memory Allocation Summary: + Circuit: 29.47 MB (30898000 B) + Transcript: 4.03 GB (4330817408 B) + Scratchpad: 14.25 GB (15303180288 B) + Total: 18.31 GB (19664895696 B) + +MPI Merge Status: + MPI Length: 8192 (independent computations) + Number of Proofs: 8192 (final transcripts) + MPI Merge Enabled: NO + +System Configuration: + Circuit Layers: 144 layers + MPI Length: 8192 + Enable MPI Merge: false + Field Type: bn254 + Fiat-Shamir Type: sha2-256 + Max Input Variables: 13 + Max Output Variables: 13 + +Prove Done! Final Claims: + vx_claim = [0x08d2107f3419f056dda4310fd9de72a8eca95840b26a20068d70262ea9495086] + vy_claim = [0x18e99e28f39df8da3cea05e6382991fa69fb57053d500112de6dab091267656c] + +====== GKR Hierarchical Profiling Results (with GPU timing) ====== +Function Name Call Count Total Time (s) Avg Time (ms) % of Total +---------------------------------------- ------------ --------------- --------------- ---------- +Sumcheck 287 2.790811 9.724 49.89 % + - receive_challange 3513 1.292797 0.368 23.11 % + - poly_eval_at 3513 1.140716 0.325 20.39 % + - Fiat-shamir(sumcheck) 3513 0.329160 0.094 5.88 % + - Apply phase 2 coef 1754 0.021809 0.012 0.39 % +Prepare H(x) 144 1.603928 11.138 28.67 % + - eq_eval_at 287 0.682776 2.379 12.21 % + - eq_eval_combine 287 0.388053 1.352 6.94 % + - scatter_to_build_eq_buf 3504 0.241123 0.069 4.31 % + - scatter_to_first_element 574 0.034887 0.061 0.62 % + - build_hgx_mult_and_add 144 0.616053 4.278 11.01 % + - build_hgx_mult 143 0.376732 2.634 6.74 % + - build_hgx_add 144 0.237986 1.653 4.25 % + - acc_from_rx_to_rz0 142 0.379750 2.674 6.79 % + - memset_clear_x_vals 143 0.275147 1.924 4.92 % +Prepare H(y) 143 1.182382 8.268 21.14 % + - build_hgy_mult_only 143 0.463457 3.241 8.29 % + - memset_clear_y_vals 143 0.367486 2.570 6.57 % +Fiat-shamir(gkr) 717 0.016354 0.023 0.29 % +TOTAL - 5.593475 
- 100.00% +============================================= + +====== Expander-GPU Performance Metrics ====== +Field element type: bn254 +Fiat-shamir type: sha2-256 +GKR proof size: 379232 bytes +GKR proof time: 5.594314 seconds +Proofs per second: 1464.34 proof/sec +``` + +## Acknowledgments + +The code of Expander-GPU is derived from the [ICICLE project](https://github.com/ingonyama-zk/icicle). +We are grateful to the ICICLE team for their contributions to the community, providing efficient field element operations on GPU that enable high-performance cryptographic computations. diff --git a/gpu/src/lib.rs b/gpu/src/lib.rs new file mode 100644 index 000000000..a65e9cfc3 --- /dev/null +++ b/gpu/src/lib.rs @@ -0,0 +1 @@ +pub mod serdes; diff --git a/gpu/src/serdes.rs b/gpu/src/serdes.rs new file mode 100644 index 000000000..de011d2aa --- /dev/null +++ b/gpu/src/serdes.rs @@ -0,0 +1,494 @@ +//! Circuit serialization functionality for GPU processing + +use circuit::Circuit; +use gkr_engine::{ExpanderDualVarChallenge, FieldEngine, Transcript}; +use std::fs::{File, OpenOptions}; +use std::io::{BufWriter, Write}; + +#[derive(thiserror::Error, Debug)] +pub enum SerializationError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Failed to parse field element: {0}")] + FieldParse(String), +} + +// Helper function to serialize a field element to a string representation +fn serialize_field_element(element: &C::CircuitField) -> String { + // Get element size for auxiliary decision making + let element_size = std::mem::size_of::(); + + // For BN254 case (element size is typically 32 bytes, i.e., 8 u32) + if element_size == 32 { + // Convert to u8 slice to inspect actual content + let element_bytes = unsafe { + std::slice::from_raw_parts( + (element as *const C::CircuitField) as *const u8, + element_size, + ) + }; + + // Convert bytes to hexadecimal representation + let mut hex_str = String::with_capacity(2 + element_size * 2); // "0x" + two hexadecimal 
characters per byte + hex_str.push_str("0x"); + + // Append in little-endian order (from most significant byte) + for i in (0..element_size).rev() { + let byte = element_bytes[i]; + hex_str.push_str(&format!("{byte:02x}")); + } + + // Check if it's zero value (all bytes are 0) + let is_zero = element_bytes.iter().all(|&b| b == 0); + if is_zero { + return "0x0000000000000000000000000000000000000000000000000000000000000000" + .to_string(); + } + + hex_str + } else { + // For other field types: m31ext3 and goldilocks + let debug_str = format!("{element:?}"); + + // Try to extract value from "FieldType { v: value }" format + if let Some(v_pos) = debug_str.find("v: ") { + let after_v = &debug_str[v_pos + 3..]; + if let Some(end_pos) = after_v.find(" }") { + let value_str = &after_v[..end_pos]; + return value_str.trim().to_string(); + } else if let Some(end_pos) = after_v.find("}") { + let value_str = &after_v[..end_pos]; + return value_str.trim().to_string(); + } + } + + // Panic if parsing fails + panic!("Failed to parse field element value from debug string: {debug_str}") + } +} + +// Helper function to serialize a SIMD field element to a string representation +fn serialize_simd_field_element(element: &C::SimdCircuitField) -> String +where + C::SimdCircuitField: std::fmt::Debug, +{ + // Get element size for auxiliary decision making + let element_size = std::mem::size_of::(); + + // For BN254 case (element size is typically 32 bytes * SIMD width) + if element_size >= 32 && element_size % 32 == 0 { + // Convert to u8 slice to inspect actual content + let element_bytes = unsafe { + std::slice::from_raw_parts( + (element as *const C::SimdCircuitField) as *const u8, + element_size, + ) + }; + + // For SIMD, we might have multiple 32-byte field elements + let num_elements = element_size / 32; + if num_elements == 1 { + // Single element case + let mut hex_str = String::with_capacity(2 + 32 * 2); + hex_str.push_str("0x"); + + // Append in little-endian order (from most 
significant byte) + for i in (0..32).rev() { + let byte = element_bytes[i]; + hex_str.push_str(&format!("{byte:02x}")); + } + + // Check if it's zero value (all bytes are 0) + let is_zero = element_bytes[..32].iter().all(|&b| b == 0); + if is_zero { + return "0x0000000000000000000000000000000000000000000000000000000000000000" + .to_string(); + } + + hex_str + } else { + // Multiple elements case - format as array + let mut result = String::from("["); + for elem_idx in 0..num_elements { + if elem_idx > 0 { + result.push(','); + } + + let start_byte = elem_idx * 32; + let end_byte = start_byte + 32; + + let mut hex_str = String::with_capacity(2 + 32 * 2); + hex_str.push_str("0x"); + + // Append in little-endian order (from most significant byte) + for i in (start_byte..end_byte).rev() { + let byte = element_bytes[i]; + hex_str.push_str(&format!("{byte:02x}")); + } + + // Check if it's zero value + let is_zero = element_bytes[start_byte..end_byte].iter().all(|&b| b == 0); + if is_zero { + result.push_str( + "0x0000000000000000000000000000000000000000000000000000000000000000", + ); + } else { + result.push_str(&hex_str); + } + } + result.push(']'); + result + } + } else { + // For other field types: m31ext3 and goldilocks + let debug_str = format!("{element:?}"); + + // Handle array format: [Type { v: val1 }, Type { v: val2 }, ...] 
+ if debug_str.contains('[') && debug_str.contains(']') { + if let Some(start) = debug_str.find('[') { + if let Some(end) = debug_str.rfind(']') { + let array_content = &debug_str[start + 1..end]; + + // Split by comma and extract values from each element + let elements: Vec<&str> = array_content.split(',').collect(); + let mut extracted_values = Vec::new(); + + for element_str in elements { + let element_str = element_str.trim(); + // Try to extract value from "FieldType { v: value }" format + if let Some(v_pos) = element_str.find("v: ") { + let after_v = &element_str[v_pos + 3..]; + if let Some(end_pos) = after_v.find(" }") { + let value_str = &after_v[..end_pos]; + extracted_values.push(value_str.trim().to_string()); + } else if let Some(end_pos) = after_v.find("}") { + let value_str = &after_v[..end_pos]; + extracted_values.push(value_str.trim().to_string()); + } else { + // Fallback to original element string if parsing fails + extracted_values.push(element_str.to_string()); + } + } else { + // Fallback to original element string if no "v: " found + extracted_values.push(element_str.to_string()); + } + } + + return format!("[{}]", extracted_values.join(",")); + } + } + } else { + // Single element case - try to extract value from "FieldType { v: value }" format + if let Some(v_pos) = debug_str.find("v: ") { + let after_v = &debug_str[v_pos + 3..]; + if let Some(end_pos) = after_v.find(" }") { + let value_str = &after_v[..end_pos]; + return value_str.trim().to_string(); + } else if let Some(end_pos) = after_v.find("}") { + let value_str = &after_v[..end_pos]; + return value_str.trim().to_string(); + } + } + } + + // Panic if parsing fails + panic!("Failed to parse SIMD field element value from debug string: {debug_str}") + } +} + +/// Serialize circuit to a file compatible with circuit.cuh +pub fn serialize_circuit_to_file( + circuit: &Circuit, + filepath: &str, +) -> Result<(), SerializationError> +where + C::CircuitField: std::fmt::Debug, + 
C::SimdCircuitField: std::fmt::Debug, +{ + // Create data directory if it doesn't exist + if let Some(parent) = std::path::Path::new(filepath).parent() { + std::fs::create_dir_all(parent)?; + } + + // Use BufWriter to improve write efficiency + let file = File::create(filepath)?; + let mut writer = BufWriter::with_capacity(8 * 1024 * 1024, file); // 8MB buffer + + // Determine field type based on field size + let field_type = match std::mem::size_of::() { + 32 => "bn254", // BN254 field + 4 => "m31ext3", // m31ext3 field + 8 => "goldilocksext2", // goldilocks field + _ => "unknown", // unknown field + }; + + // Write header: TotalLayer [layer_count] [field_type] + writeln!(writer, "TotalLayer {} {}", circuit.layers.len(), field_type)?; + writer.flush()?; // Immediately flush header information + + // Count total items to serialize for progress calculation (gates + values) + let total_gates: usize = circuit + .layers + .iter() + .map(|layer| layer.add.len() + layer.mul.len()) + .sum(); + let total_values: usize = circuit + .layers + .iter() + .map(|layer| layer.input_vals.len() + layer.output_vals.len()) + .sum(); + let total_items = total_gates + total_values; + + let mut items_processed = 0; + let mut last_percent = 0; + + // Process each layer + for (layer_idx, layer) in circuit.layers.iter().enumerate() { + // Write layer header: Layer [num_gate_add] [num_gate_mul] [input_var_num] [output_var_num] + // [input_vals_count] [output_vals_count] + writeln!( + writer, + "Layer[{}] {} {} {} {} {} {}", + layer_idx, + layer.add.len(), + layer.mul.len(), + layer.input_var_num, + layer.output_var_num, + layer.input_vals.len(), + layer.output_vals.len() + )?; + + // Write input values + if !layer.input_vals.is_empty() { + writeln!(writer, "=====Input Values=====")?; + for (idx, input_val) in layer.input_vals.iter().enumerate() { + let val_str = serialize_simd_field_element::(input_val); + writeln!(writer, "InputVal[{idx}] {val_str}")?; + items_processed += 1; + } + } + + // 
Write output values + if !layer.output_vals.is_empty() { + writeln!(writer, "=====Output Values=====")?; + for (idx, output_val) in layer.output_vals.iter().enumerate() { + let val_str = serialize_simd_field_element::(output_val); + writeln!(writer, "OutputVal[{idx}] {val_str}")?; + items_processed += 1; + } + } + + // Write gates section marker + if !layer.add.is_empty() || !layer.mul.is_empty() { + writeln!(writer, "=====Gates=====")?; + } + + // Every 10 layers or large layers force flush buffer + if layer_idx % 10 == 0 || layer.add.len() + layer.mul.len() > 10000 { + writer.flush()?; + } + + // Write add gates + for add_gate in &layer.add { + // Serialize coef to appropriate string format + let coef_str = serialize_field_element::(&add_gate.coef); + + // Write add gate: Add [input_idx] [output_idx] [coef] + writeln!( + writer, + "Add {} {} {}", + add_gate.i_ids[0], add_gate.o_id, coef_str + )?; + + items_processed += 1; + } + + // Write mul gates + for mul_gate in &layer.mul { + // Serialize coef to appropriate string format + let coef_str = serialize_field_element::(&mul_gate.coef); + + // Write mul gate: Mul [input_left_idx],[input_right_idx] [output_idx] [coef] + writeln!( + writer, + "Mul {},{} {} {}", + mul_gate.i_ids[0], mul_gate.i_ids[1], mul_gate.o_id, coef_str + )?; + + items_processed += 1; + } + + // Flush buffer after each layer to ensure data written to disk + writer.flush()?; + + // Calculate and display progress + let percent = if total_items > 0 { + (items_processed * 100) / total_items + } else { + 100 + }; + if percent > last_percent && percent % 5 == 0 { + println!( + "Serialization progress: {percent}% (processed {items_processed}/{total_items} items: {total_gates} gates + {total_values} values)" + ); + last_percent = percent; + } + } + + // Final flush and close file + writer.flush()?; + drop(writer); // Explicitly close file + + // Output first and last layer gate counts and values for verification + if !circuit.layers.is_empty() { + let 
first_layer = &circuit.layers[0]; + let last_layer = &circuit.layers[circuit.layers.len() - 1]; + + println!("First layer: {} addition gates, {} multiplication gates, {} input values, {} output values", + first_layer.add.len(), first_layer.mul.len(), first_layer.input_vals.len(), first_layer.output_vals.len()); + println!("Last layer: {} addition gates, {} multiplication gates, {} input values, {} output values", + last_layer.add.len(), last_layer.mul.len(), last_layer.input_vals.len(), last_layer.output_vals.len()); + } + + // Verify file write success + match std::fs::metadata(filepath) { + Ok(metadata) => { + // File size should be related to total items (gates + values) + let expected_min_size = total_items * 15; // Roughly estimate each item at least 15 bytes + if metadata.len() < expected_min_size as u64 { + println!("Warning: file size may be insufficient, please check if fully written"); + } + println!( + "Successfully serialized {total_gates} gates and {total_values} values to file (total {total_items} items)" + ); + } + Err(e) => println!("Unable to verify file: {e}"), + } + + Ok(()) +} + +/// Serialize witness as plaintext to a file +pub fn serial_circuit_witness_as_plaintext( + circuit: &Circuit, + transcript: &mut impl Transcript, + challenge: &ExpanderDualVarChallenge, +) -> Result<(), SerializationError> +where + F::CircuitField: std::fmt::Debug, + F::SimdCircuitField: std::fmt::Debug, +{ + // Determine field type and construct filename + let field_type = match std::mem::size_of::() { + 32 => "bn254", + 4 => "m31ext3", + 8 => "goldilocksext2", + _ => "unknown", + }; + let filepath = format!("data/keccak_{field_type}.gpu.circuit"); + + // Check if file already exists + if std::path::Path::new(&filepath).exists() { + println!("Circuit file {filepath} already exists, skipping serialization"); + return Ok(()); + } + + // Perform serialization + println!("GPU enabled, serializing circuit to {filepath}"); + serialize_circuit_to_file(circuit, &filepath)?; + 
println!("Successfully serialized circuit to {filepath}"); + + // Get digest, proof bytes, and hash_start_index directly from memory using unsafe code + let (digest_bytes, proof_bytes, hash_start_index) = unsafe { + use gkr_hashers::SHA256hasher; + use transcript::BytesHashTranscript; + + // Cast to BytesHashTranscript - the hasher type doesn't matter since we only access + // digest, proof, and hash_start_index fields which have the same layout regardless of H + let transcript_ptr = transcript as *mut _ as *mut BytesHashTranscript; + let bytes_transcript = &*transcript_ptr; + + ( + &bytes_transcript.digest, + &bytes_transcript.proof.bytes, + bytes_transcript.hash_start_index, + ) + }; + + // Check if digest has reasonable size (we expect at least 16 bytes for security) + if digest_bytes.len() < 16 { + panic!("Transcript digest too small: {} bytes", digest_bytes.len()); + } + + // Append to the circuit file + let mut file = OpenOptions::new() + .append(true) + .create(true) + .open(&filepath)?; + + // Write transcript start marker + writeln!(file, "=====Transcript Start=====")?; + writeln!(file, "TranscriptDigestByte={}", digest_bytes.len())?; + writeln!(file, "TranscriptProofByte={}", proof_bytes.len())?; + writeln!(file, "TranscriptHashStartIndex={hash_start_index}")?; + + // Write digest bytes first + writeln!(file, "=====Digest Bytes=====")?; + for (i, chunk) in digest_bytes.chunks(40).enumerate() { + let start_idx = i * 40; + let end_idx = start_idx + chunk.len() - 1; + + // Format bytes with leading zeros and join with commas + let formatted_bytes: Vec = chunk.iter().map(|&byte| format!("{byte:03}")).collect(); + let line = formatted_bytes.join(","); + + // Write line with range annotation + writeln!(file, "{line} //digest[{start_idx}-{end_idx}]")?; + } + + // Write proof bytes + writeln!(file, "=====Proof Bytes=====")?; + for (i, chunk) in proof_bytes.chunks(40).enumerate() { + let start_idx = i * 40; + let end_idx = start_idx + chunk.len() - 1; + + // 
Format bytes with leading zeros and join with commas + let formatted_bytes: Vec = chunk.iter().map(|&byte| format!("{byte:03}")).collect(); + let line = formatted_bytes.join(","); + + // Write line with range annotation + writeln!(file, "{line} //proof[{start_idx}-{end_idx}]")?; + } + + // Write transcript end marker + writeln!(file, "=====Transcript End=====")?; + + // Write the challenge + writeln!(file, "Challenge: {challenge:?}")?; + + file.flush()?; + + println!("Transcript digest and proof bytes written to {filepath}"); + println!( + "Total digest bytes: {}, Total proof bytes: {}, Hash start index: {}", + digest_bytes.len(), + proof_bytes.len(), + hash_start_index + ); + + Ok(()) +} + +/// Print final claims to console (not to file) +pub fn print_final_claims( + vx_claim: &F::ChallengeField, + vy_claim: &Option, +) { + println!("=====Final Claims====="); + println!("vx_claim = {vx_claim:?}"); + if let Some(vy) = vy_claim { + println!("vy_claim = {vy:?}"); + } else { + println!("vy_claim = None"); + } +} diff --git a/transcript/src/byte_hash_transcript.rs b/transcript/src/byte_hash_transcript.rs index d690dc724..cf7f1c18f 100644 --- a/transcript/src/byte_hash_transcript.rs +++ b/transcript/src/byte_hash_transcript.rs @@ -15,10 +15,10 @@ pub struct BytesHashTranscript { pub digest: Vec, /// The proof bytes. - proof: Proof, + pub proof: Proof, /// The pointer to the proof bytes indicating where the hash starts. - hash_start_index: usize, + pub hash_start_index: usize, /// locking point proof_locked: bool,