diff --git a/bin/src/dev_setup.rs b/bin/src/dev_setup.rs index ac7379c39..febb06720 100644 --- a/bin/src/dev_setup.rs +++ b/bin/src/dev_setup.rs @@ -45,7 +45,10 @@ fn main() { } } -fn proof_gen() { +fn proof_gen() +where + C::FieldConfig: FieldEngine, +{ let mpi_config = MPIConfig::prover_new(); // load circuit @@ -81,7 +84,7 @@ fn proof_gen() { circuit.evaluate(); let (pcs_params, pcs_proving_key, pcs_verification_key, pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit.log_input_size(), &mpi_config, ); diff --git a/bin/src/executor.rs b/bin/src/executor.rs index 87a6469c8..3e2d3654d 100644 --- a/bin/src/executor.rs +++ b/bin/src/executor.rs @@ -124,13 +124,16 @@ pub fn prove( ) -> ( <::FieldConfig as FieldEngine>::ChallengeField, Proof, -) { +) +where + Cfg::FieldConfig: FieldEngine, +{ let mut prover = Prover::::new(mpi_config.clone()); prover.prepare_mem(circuit); // TODO: Read PCS setup from files let (pcs_params, pcs_proving_key, _, mut pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit.log_input_size(), &mpi_config, ); @@ -148,6 +151,7 @@ pub fn verify( // TODO: Read PCS setup from files let (pcs_params, _, pcs_verification_key, _) = expander_pcs_init_testing_only::< Cfg::FieldConfig, + Cfg::PCSPolyField, Cfg::PCSConfig, >(circuit.log_input_size(), &mpi_config); let verifier = Verifier::::new(mpi_config); @@ -165,7 +169,9 @@ pub fn verify( pub async fn run_command( command: &ExpanderExecArgs, mpi_config: &MPIConfig, -) { +) where + Cfg::FieldConfig: FieldEngine, +{ let subcommands = command.subcommands.clone(); match subcommands { @@ -263,7 +269,7 @@ pub async fn run_command( // TODO: Read PCS setup from files let (pcs_params, pcs_proving_key, pcs_verification_key, pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit.log_input_size(), &prover.mpi_config, ); diff --git a/bin/src/main.rs b/bin/src/main.rs index fe9fed42a..da25cb78b 100644 --- a/bin/src/main.rs +++ b/bin/src/main.rs @@ -127,8 +127,9 @@ fn main() { fn run_benchmark<'a, Cfg: GKREngine>(args: &'a Args, mpi_config: MPIConfig) where - >::ScratchPad: 'a, - >::ScratchPad: 'static, + >::ScratchPad: 'a, + >::ScratchPad: 'static, + Cfg::FieldConfig: FieldEngine, { let partial_proof_cnts = (0..args.threads) .map(|_| Arc::new(Mutex::new(0))) @@ -223,7 +224,7 @@ where println!("Circuit loaded!"); let (pcs_params, pcs_proving_key, _pcs_verification_key, pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit_template.log_input_size(), &mpi_config, ); diff --git a/bin/src/main_mpi.rs b/bin/src/main_mpi.rs index 2145270fe..fc2094059 100644 --- a/bin/src/main_mpi.rs +++ b/bin/src/main_mpi.rs @@ -117,7 +117,10 @@ fn main() { MPIConfig::finalize(); } -fn run_benchmark(args: &Args, mpi_config: MPIConfig) { +fn run_benchmark(args: &Args, mpi_config: MPIConfig) +where + Cfg::FieldConfig: FieldEngine, +{ let pack_size = ::get_field_pack_size(); // load circuit @@ -196,7 +199,7 @@ fn run_benchmark(args: &Args, mpi_config: MPIConfig) { prover.prepare_mem(&circuit); let (pcs_params, pcs_proving_key, _pcs_verification_key, mut pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit.log_input_size(), &mpi_config, ); diff --git a/circuit/src/layered/circuit.rs b/circuit/src/layered/circuit.rs index 2dc594453..1590b02b6 100644 --- a/circuit/src/layered/circuit.rs +++ b/circuit/src/layered/circuit.rs @@ -385,7 +385,7 @@ impl Circuit { // If there will be two claims for the input // Introduce an extra relay layer before the input layer if !self.layers[0].structure_info.skip_sumcheck_phase_two { - match >::PCS_TYPE { + match >::PCS_TYPE { // Raw PCS costs nothing in opening, so no need to add relay layer // But we can probably add it in the future for verifier's convenience PolynomialCommitmentType::Raw => (), diff --git a/gkr/benches/gkr_hashes.rs b/gkr/benches/gkr_hashes.rs index 90f054065..7c776f5d2 100644 --- a/gkr/benches/gkr_hashes.rs +++ b/gkr/benches/gkr_hashes.rs @@ -20,10 +20,12 @@ use transcript::BytesHashTranscript; fn prover_run( mpi_config: &MPIConfig, circuit: &mut Circuit, - pcs_params: &>::Params, - pcs_proving_key: &<>::SRS as StructuredReferenceString>::PKey, - pcs_scratch: &mut >::ScratchPad, -) { + pcs_params: &>::Params, + pcs_proving_key: &<>::SRS as StructuredReferenceString>::PKey, + pcs_scratch: &mut >::ScratchPad, +) where + Cfg::FieldConfig: FieldEngine, +{ let mut prover = Prover::::new(mpi_config.clone()); prover.prepare_mem(circuit); prover.prove(circuit, pcs_params, pcs_proving_key, pcs_scratch); @@ -35,10 +37,10 @@ fn benchmark_setup( ) -> ( MPIConfig, Circuit, - >::Params, - <>::SRS as StructuredReferenceString>::PKey, - >::ScratchPad, -) { + >::Params, + <>::SRS as StructuredReferenceString>::PKey, + >::ScratchPad, +){ let mpi_config = MPIConfig::prover_new(); let mut circuit = Circuit::::single_thread_prover_load_circuit::(circuit_file); @@ -50,7 +52,7 @@ fn benchmark_setup( } let (pcs_params, pcs_proving_key, _pcs_verification_key, pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit.log_input_size(), &mpi_config, ); diff --git a/gkr/src/prover/snark.rs b/gkr/src/prover/snark.rs index 8cea58c04..b6610eebe 100644 --- a/gkr/src/prover/snark.rs +++ b/gkr/src/prover/snark.rs @@ -84,10 +84,13 @@ impl Prover { pub fn prove( &mut self, c: &mut Circuit, - pcs_params: &>::Params, - pcs_proving_key: &<>::SRS as StructuredReferenceString>::PKey, - pcs_scratch: &mut >::ScratchPad, - ) -> (::ChallengeField, Proof) { + pcs_params: &>::Params, + pcs_proving_key: &<>::SRS as StructuredReferenceString>::PKey, + pcs_scratch: &mut >::ScratchPad, + ) -> (::ChallengeField, Proof) + where + Cfg::FieldConfig: FieldEngine, + { let proving_timer = Timer::new("prover", self.mpi_config.is_root()); let mut transcript = Cfg::TranscriptConfig::new(); @@ -195,11 +198,13 @@ impl Prover { &self, inputs: &mut MutRefMultiLinearPoly<::SimdCircuitField>, open_at: &mut ExpanderSingleVarChallenge, - pcs_params: &>::Params, - pcs_proving_key: &<>::SRS as StructuredReferenceString>::PKey, - pcs_scratch: &mut >::ScratchPad, + pcs_params: &>::Params, + pcs_proving_key: &<>::SRS as StructuredReferenceString>::PKey, + pcs_scratch: &mut >::ScratchPad, transcript: &mut impl Transcript, - ) { + ) where + Cfg::FieldConfig: FieldEngine, + { let original_input_vars = inputs.num_vars(); let minimum_vars_for_pcs: usize = pcs_params.num_vars(); diff --git a/gkr/src/tests/gkr_correctness.rs b/gkr/src/tests/gkr_correctness.rs index 48f861f6c..c250911dc 100644 --- a/gkr/src/tests/gkr_correctness.rs +++ b/gkr/src/tests/gkr_correctness.rs @@ -163,7 +163,10 @@ fn test_gkr_correctness() { } #[allow(unreachable_patterns)] -fn test_gkr_correctness_helper(write_proof_to: Option<&str>) { +fn test_gkr_correctness_helper(write_proof_to: Option<&str>) +where + Cfg::FieldConfig: FieldEngine, +{ let mpi_config = MPIConfig::prover_new(); root_println!(mpi_config, "============== start ==============="); @@ -226,7 +229,7 @@ fn test_gkr_correctness_helper(write_proof_to: Option<&str>) { prover.prepare_mem(&circuit); let (pcs_params, pcs_proving_key, pcs_verification_key, mut pcs_scratch) = - expander_pcs_init_testing_only::( + expander_pcs_init_testing_only::( circuit.log_input_size(), &mpi_config, ); diff --git a/gkr/src/verifier/snark.rs b/gkr/src/verifier/snark.rs index 6061fed73..a83ca8208 100644 --- a/gkr/src/verifier/snark.rs +++ b/gkr/src/verifier/snark.rs @@ -48,10 +48,10 @@ impl Verifier { circuit: &mut Circuit, transcript: &mut Cfg::TranscriptConfig, proving_time_mpi_size: usize, - ) -> >::Commitment { + ) -> >::Commitment { let timer = Timer::new("pre_gkr", true); let commitment = - <>::Commitment as ExpSerde>::deserialize_from( + <>::Commitment as ExpSerde>::deserialize_from( &mut proof_reader, ) .unwrap(); @@ -258,9 +258,12 @@ impl Verifier { #[allow(clippy::type_complexity)] pub(crate) fn post_gkr( &self, - pcs_params: &>::Params, - pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, - commitment: &>::Commitment, + pcs_params: &>::Params, + pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, + commitment: &>::Commitment, challenge_x: &mut ExpanderSingleVarChallenge, claim_x: &::ChallengeField, challenge_y: &mut Option>, @@ -300,8 +303,11 @@ impl Verifier { circuit: &mut Circuit, public_input: &[::SimdCircuitField], claimed_v: &::ChallengeField, - pcs_params: &>::Params, - pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, + pcs_params: &>::Params, + pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, proof: &Proof, ) -> bool { let timer = Timer::new("snark verify", true); @@ -343,8 +349,11 @@ impl Verifier { circuit: &mut Circuit, public_input: &[::SimdCircuitField], claimed_v: &::ChallengeField, - pcs_params: &>::Params, - pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, + pcs_params: &>::Params, + pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, proof: &Proof, ) -> bool { let timer = Timer::new("snark verify", true); @@ -385,15 +394,18 @@ impl Verifier { #[allow(clippy::too_many_arguments)] fn get_pcs_opening_from_proof_and_verify( &self, - pcs_params: &>::Params, - pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, - commitment: &>::Commitment, + pcs_params: &>::Params, + pcs_verification_key: &<>::SRS as StructuredReferenceString>::VKey, + commitment: &>::Commitment, open_at: &mut ExpanderSingleVarChallenge, v: &::ChallengeField, transcript: &mut impl Transcript, proof_reader: impl Read, ) -> bool { - let opening = >::Opening::deserialize_from( + let opening = >::Opening::deserialize_from( proof_reader, ) .unwrap(); diff --git a/gkr_engine/src/lib.rs b/gkr_engine/src/lib.rs index 6098dc9b8..9c5bb47a1 100644 --- a/gkr_engine/src/lib.rs +++ b/gkr_engine/src/lib.rs @@ -12,6 +12,7 @@ //! - A Config is a struct that implements the Engine trait and contains the parameters for the GKR //! protocol #![allow(clippy::manual_div_ceil)] +#![feature(associated_type_defaults)] mod errors; mod field_engine; @@ -20,6 +21,7 @@ mod poly_commit; mod scheme; mod transcript; +use arith::Field; pub use errors::*; pub use field_engine::*; pub use mpi_engine::*; @@ -71,7 +73,8 @@ pub trait GKREngine: Send + Sync { type TranscriptConfig: Transcript; /// Configuration for polynomial commitment scheme - type PCSConfig: ExpanderPCS; + type PCSPolyField: Field = <::FieldConfig as FieldEngine>::SimdCircuitField; + type PCSConfig: ExpanderPCS; /// GKR scheme const SCHEME: GKRScheme; diff --git a/gkr_engine/src/poly_commit/definition.rs b/gkr_engine/src/poly_commit/definition.rs index f4a870b7b..6a2c3624f 100644 --- a/gkr_engine/src/poly_commit/definition.rs +++ b/gkr_engine/src/poly_commit/definition.rs @@ -1,3 +1,4 @@ +use arith::Field; use polynomials::MultilinearExtension; use rand::RngCore; use serdes::ExpSerde; @@ -25,7 +26,13 @@ impl PCSParams for usize { } } -pub trait ExpanderPCS { +/// This trait specifies the field used on Expander side. +/// PolyField: the field of the coef of polynomial sent to PCS +/// ChallengeField: the field of the challenge point, should be FieldEngine::ChallengeField +/// EvalField: the field of the evaluation, in the current use case always be FieldEngine::Field +/// since Polyfield has simd Note that it is not necessary that PolyField can be handled by PCS +/// directly. PolyField should be bounded when implemented if necessary. +pub trait ExpanderPCS { const NAME: &'static str; const PCS_TYPE: PolynomialCommitmentType; @@ -65,7 +72,7 @@ pub trait ExpanderPCS { params: &Self::Params, mpi_engine: &impl MPIEngine, proving_key: &::PKey, - poly: &impl MultilinearExtension, + poly: &impl MultilinearExtension, scratch_pad: &mut Self::ScratchPad, ) -> Option; @@ -94,7 +101,7 @@ pub trait ExpanderPCS { params: &Self::Params, mpi_engine: &impl MPIEngine, proving_key: &::PKey, - poly: &impl MultilinearExtension, + poly: &impl MultilinearExtension, x: &ExpanderSingleVarChallenge, transcript: &mut impl Transcript, scratch_pad: &Self::ScratchPad, diff --git a/poly_commit/src/hyrax/expander_api.rs b/poly_commit/src/hyrax/expander_api.rs index 7ebd4e480..d7f3523e1 100644 --- a/poly_commit/src/hyrax/expander_api.rs +++ b/poly_commit/src/hyrax/expander_api.rs @@ -18,11 +18,11 @@ use crate::{ HyraxCommitment, HyraxOpening, HyraxPCS, PedersenParams, }; -impl ExpanderPCS for HyraxPCS +impl ExpanderPCS for HyraxPCS where - G: FieldEngine, + G: FieldEngine, C: CurveAffine + ExpSerde, - C::Scalar: ExtensionField + PrimeField, + C::Scalar: ExtensionField + PrimeField + From, C::ScalarExt: ExtensionField + PrimeField, { const NAME: &'static str = "HyraxPCSForExpanderGKR"; @@ -56,7 +56,7 @@ where _params: &Self::Params, mpi_engine: &impl MPIEngine, proving_key: &::PKey, - poly: &impl polynomials::MultilinearExtension<::SimdCircuitField>, + poly: &impl polynomials::MultilinearExtension, _scratch_pad: &mut Self::ScratchPad, ) -> Option { let local_commit = hyrax_commit(proving_key, poly); @@ -83,7 +83,7 @@ where _params: &Self::Params, mpi_engine: &impl MPIEngine, proving_key: &::PKey, - poly: &impl polynomials::MultilinearExtension<::SimdCircuitField>, + poly: &impl polynomials::MultilinearExtension, x: &ExpanderSingleVarChallenge, _transcript: &mut impl Transcript, _scratch_pad: &Self::ScratchPad, diff --git a/poly_commit/src/kzg/expander_api.rs b/poly_commit/src/kzg/expander_api.rs index 11042c72a..5283d2bc1 100644 --- a/poly_commit/src/kzg/expander_api.rs +++ b/poly_commit/src/kzg/expander_api.rs @@ -13,11 +13,11 @@ use serdes::ExpSerde; use crate::*; -impl ExpanderPCS for HyperKZGPCS +impl ExpanderPCS for HyperKZGPCS where - G: FieldEngine, + G: FieldEngine, E: Engine + MultiMillerLoop, - E::Fr: ExtensionField + PrimeField, + E::Fr: ExtensionField + PrimeField + From, E::G1Affine: ExpSerde + Default + CurveAffine, E::G2Affine: ExpSerde + Default + CurveAffine, { @@ -57,7 +57,7 @@ where _params: &Self::Params, mpi_engine: &impl MPIEngine, proving_key: &::PKey, - poly: &impl polynomials::MultilinearExtension<::SimdCircuitField>, + poly: &impl polynomials::MultilinearExtension, _scratch_pad: &mut Self::ScratchPad, ) -> Option { let local_commitment = @@ -84,7 +84,7 @@ where _params: &Self::Params, mpi_engine: &impl MPIEngine, proving_key: &::PKey, - poly: &impl polynomials::MultilinearExtension<::SimdCircuitField>, + poly: &impl polynomials::MultilinearExtension, x: &ExpanderSingleVarChallenge, transcript: &mut impl Transcript, _scratch_pad: &Self::ScratchPad, diff --git a/poly_commit/src/orion/expander_api.rs b/poly_commit/src/orion/expander_api.rs index 404250ea9..0edca4ede 100644 --- a/poly_commit/src/orion/expander_api.rs +++ b/poly_commit/src/orion/expander_api.rs @@ -13,7 +13,7 @@ use crate::orion::{ ORION_CODE_PARAMETER_INSTANCE, }; -impl ExpanderPCS +impl ExpanderPCS for OrionSIMDFieldPCS where C: FieldEngine, diff --git a/poly_commit/src/raw.rs b/poly_commit/src/raw.rs index ba6968a29..978344eee 100644 --- a/poly_commit/src/raw.rs +++ b/poly_commit/src/raw.rs @@ -121,7 +121,7 @@ pub struct RawExpanderGKR { _phantom: std::marker::PhantomData, } -impl ExpanderPCS for RawExpanderGKR { +impl ExpanderPCS for RawExpanderGKR { const NAME: &'static str = "RawExpanderGKR"; const PCS_TYPE: PolynomialCommitmentType = PolynomialCommitmentType::Raw; diff --git a/poly_commit/src/utils.rs b/poly_commit/src/utils.rs index 158750303..fd9320106 100644 --- a/poly_commit/src/utils.rs +++ b/poly_commit/src/utils.rs @@ -3,7 +3,11 @@ use ark_std::test_rng; use gkr_engine::{ExpanderPCS, FieldEngine, MPIConfig, StructuredReferenceString}; #[allow(clippy::type_complexity)] -pub fn expander_pcs_init_testing_only>( +pub fn expander_pcs_init_testing_only< + FieldConfig: FieldEngine, + PCSPolyField: Field, + PCS: ExpanderPCS, +>( n_input_vars: usize, mpi_config: &MPIConfig, ) -> ( @@ -14,9 +18,13 @@ pub fn expander_pcs_init_testing_only>::gen_params(n_input_vars); - let (pcs_setup, calibrated_num_local_simd_vars) = - >::gen_srs_for_testing(&pcs_params, mpi_config, &mut rng); + let mut pcs_params = >::gen_params(n_input_vars); + let (pcs_setup, calibrated_num_local_simd_vars) = >::gen_srs_for_testing( + &pcs_params, mpi_config, &mut rng + ); if n_input_vars < calibrated_num_local_simd_vars { eprintln!( @@ -26,11 +34,14 @@ pub fn expander_pcs_init_testing_only>::gen_params(calibrated_num_local_simd_vars); + pcs_params = >::gen_params( + calibrated_num_local_simd_vars, + ); } let (pcs_proving_key, pcs_verification_key) = pcs_setup.into_keys(); - let pcs_scratch = >::init_scratch_pad(&pcs_params, mpi_config); + let pcs_scratch = + >::init_scratch_pad(&pcs_params, mpi_config); ( pcs_params, diff --git a/poly_commit/tests/common.rs b/poly_commit/tests/common.rs index 92a073dbb..676f3cc7b 100644 --- a/poly_commit/tests/common.rs +++ b/poly_commit/tests/common.rs @@ -44,7 +44,11 @@ pub fn test_pcs>( +pub fn test_pcs_for_expander_gkr< + C: FieldEngine, + T: Transcript, + P: ExpanderPCS, +>( params: &P::Params, mpi_config: &MPIConfig, transcript: &mut T,