diff --git a/.github/workflows/end-to-end.yml b/.github/workflows/end-to-end.yml index 774edeec4..6467536b9 100644 --- a/.github/workflows/end-to-end.yml +++ b/.github/workflows/end-to-end.yml @@ -57,48 +57,90 @@ jobs: - name: Clean stale benchmark artifacts working-directory: noir-examples/noir-passport/merkle_age_check run: | - rm -f ./benchmark-inputs/*.pkp ./benchmark-inputs/*.pkv ./benchmark-inputs/*.np + rm -f ./benchmark-inputs/*.pkp ./benchmark-inputs/*.pkv ./benchmark-inputs/*.np ./benchmark-inputs/*.sp echo "Cleaned stale benchmark artifacts" - - name: Prepare circuits + - name: Start server (prepare + serve) working-directory: noir-examples/noir-passport/merkle_age_check run: | + CIRCUIT_ARGS="" for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do - echo "Preparing $circuit" - cargo run --release --bin provekit-cli prepare ./target/$circuit.json \ - --pkp ./benchmark-inputs/$circuit-prover.pkp \ - --pkv ./benchmark-inputs/$circuit-verifier.pkv - echo "Prepared $circuit" + CIRCUIT_ARGS="$CIRCUIT_ARGS --circuit $circuit:./target/$circuit.json" done + rm -f /tmp/spark.sock + ../../../target/release/provekit-cli serve \ + --socket /tmp/spark.sock \ + --output-dir ./benchmark-inputs \ + $CIRCUIT_ARGS > /tmp/spark-server.log 2>&1 & + SERVER_PID=$! + echo "SERVER_PID=$SERVER_PID" >> "$GITHUB_ENV" + # Wait for READY + for i in $(seq 1 600); do + if [ -S /tmp/spark.sock ]; then + echo "Server ready after ${i}s" + break + fi + if ! kill -0 $SERVER_PID 2>/dev/null; then + echo "Server process died during preparation" + exit 1 + fi + sleep 1 + done + if [ ! 
-S /tmp/spark.sock ]; then + echo "Timed out waiting for server" + exit 1 + fi - - name: Generate proofs for all circuits + - name: Prove and request spark proofs for all circuits working-directory: noir-examples/noir-passport/merkle_age_check run: | for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do echo "Proving $circuit" - cargo run --release --bin provekit-cli prove \ - ./benchmark-inputs/$circuit-prover.pkp \ + ../../../target/release/provekit-cli prove \ + ./benchmark-inputs/$circuit.pkp \ ./benchmark-inputs/tbs_720/$circuit.toml \ - -o ./benchmark-inputs/$circuit-proof.np + -o ./benchmark-inputs/$circuit-proof.np \ + --socket /tmp/spark.sock \ + --circuit $circuit \ + --spark-out ./benchmark-inputs/$circuit-spark-proof.sp echo "Proved $circuit" done + - name: Stop SPARK server + if: always() + run: | + kill $SERVER_PID 2>/dev/null || true + echo "=== SPARK server log ===" + cat /tmp/spark-server.log 2>/dev/null || true + - name: Verify proofs for all circuits working-directory: noir-examples/noir-passport/merkle_age_check run: | for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do echo "Verifying $circuit" - cargo run --release --bin provekit-cli verify \ - ./benchmark-inputs/$circuit-verifier.pkv \ + ../../../target/release/provekit-cli verify \ + ./benchmark-inputs/$circuit.pkv \ ./benchmark-inputs/$circuit-proof.np echo "Verified $circuit" done + - name: Verify SPARK proofs for all circuits + working-directory: noir-examples/noir-passport/merkle_age_check + run: | + for circuit in t_add_dsc_720 t_add_id_data_720 t_add_integrity_commit t_attest; do + echo "SPARK verifying $circuit" + ../../../target/release/provekit-cli verify \ + ./benchmark-inputs/$circuit.pkv \ + ./benchmark-inputs/$circuit-proof.np \ + --spark-proof ./benchmark-inputs/$circuit-spark-proof.sp + echo "SPARK verified $circuit" + done + - name: Generate Gnark inputs working-directory: noir-examples/noir-passport/merkle_age_check 
run: | - cargo run --release --bin provekit-cli generate-gnark-inputs \ - ./benchmark-inputs/t_attest-verifier.pkv \ + ../../../target/release/provekit-cli generate-gnark-inputs \ + ./benchmark-inputs/t_attest.pkv \ ./benchmark-inputs/t_attest-proof.np diff --git a/.gitignore b/.gitignore index ad7ca8c30..c715e8195 100644 --- a/.gitignore +++ b/.gitignore @@ -14,11 +14,13 @@ *.pkp *.pkv *.np +*.sp params_for_recursive_verifier params artifacts/ spartan_vm_debug/ mavros_debug/ +mavros/ # Don't ignore benchmarking artifacts !tooling/provekit-bench/benches/* diff --git a/Cargo.lock b/Cargo.lock index fc5371a21..40e9e7822 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -136,9 +136,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.21" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", "anstyle-parse", @@ -151,15 +151,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" -version = "0.2.7" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" dependencies = [ "utf8parse", ] @@ -207,9 +207,9 @@ checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" [[package]] name = "argh" -version = "0.1.15" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"d32c2462e89541e6687e684d97310015d64a0627b61106fc472156a38f61cd1e" +checksum = "211818e820cda9ca6f167a64a5c808837366a6dfd807157c64c1304c486cd033" dependencies = [ "argh_derive", "argh_shared", @@ -217,9 +217,9 @@ dependencies = [ [[package]] name = "argh_derive" -version = "0.1.15" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccc2a031b364bd099fed016feb1ccfca2c3549d63c16f330cfc40b27b7692231" +checksum = "c442a9d18cef5dde467405d27d461d080d68972d6d0dfd0408265b6749ec427d" dependencies = [ "argh_shared", "proc-macro2", @@ -229,9 +229,9 @@ dependencies = [ [[package]] name = "argh_shared" -version = "0.1.15" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b9abea17ef74821d1d3490aee9e0749d731445d965b7512308b2aa00c90079e" +checksum = "e5ade012bac4db278517a0132c8c10c6427025868dca16c801087c28d5a411f1" dependencies = [ "serde", ] @@ -866,10 +866,11 @@ dependencies = [ [[package]] name = "borsh" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1da5ab77c1437701eeff7c88d968729e7766172279eab0676857b3d63af7a6f" +checksum = "cfd1e3f8955a5d7de9fab72fc8373fade9fb8a703968cb200ae3dc6cf08e185a" dependencies = [ + "bytes", "cfg_aliases", ] @@ -947,9 +948,9 @@ checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "cc" -version = "1.2.56" +version = "1.2.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", "jobserver", @@ -1018,9 +1019,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = 
"b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", "clap_derive", @@ -1028,9 +1029,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstream", "anstyle", @@ -1041,18 +1042,18 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.5.66" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c757a3b7e39161a4e56f9365141ada2a6c915a8622c408ab6bb4b5d047371031" +checksum = "19c9f1dde76b736e3681f28cec9d5a61299cbaae0fce80a68e43724ad56031eb" dependencies = [ "clap", ] [[package]] name = "clap_derive" -version = "4.5.55" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" dependencies = [ "heck 0.5.0", "proc-macro2", @@ -1062,9 +1063,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "clipboard-win" @@ -1147,9 +1148,9 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "combine" @@ -1358,9 +1359,9 @@ dependencies = [ [[package]] name = "crypto-common" 
-version = "0.1.7" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", "typenum", @@ -1400,9 +1401,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" dependencies = [ "darling_core", "darling_macro", @@ -1410,11 +1411,10 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" dependencies = [ - "fnv", "ident_case", "proc-macro2", "quote", @@ -1424,9 +1424,9 @@ dependencies = [ [[package]] name = "darling_macro" -version = "0.21.3" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" dependencies = [ "darling_core", "quote", @@ -1487,9 +1487,9 @@ dependencies = [ [[package]] name = "derive-where" -version = "1.6.0" +version = "1.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f" +checksum = "d08b3a0bcc0d079199cd476b2cae8435016ec11d1c0986c6901c5ac223041534" dependencies = [ "proc-macro2", "quote", @@ -1728,9 +1728,9 @@ dependencies = [ [[package]] name = "env_filter" -version = "1.0.0" +version = "1.0.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +checksum = "32e90c2accc4b07a8456ea0debdc2e7587bdd890680d71173a15d4ae604f6eef" dependencies = [ "log", "regex", @@ -1738,9 +1738,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.9" +version = "0.11.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +checksum = "0621c04f2196ac3f488dd583365b9c09be011a4ab8b9f37248ffcc8f6198b56a" dependencies = [ "env_filter", "log", @@ -1857,7 +1857,7 @@ checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" dependencies = [ "cfg-if", "libc", - "libredox 0.1.14", + "libredox 0.1.15", ] [[package]] @@ -2083,9 +2083,9 @@ dependencies = [ [[package]] name = "generic-array" -version = "0.14.7" +version = "0.14.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" dependencies = [ "typenum", "version_check", @@ -2647,9 +2647,9 @@ checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" [[package]] name = "iri-string" -version = "0.7.10" +version = "0.7.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +checksum = "d8e7418f59cc01c88316161279a7f665217ae316b388e58a0d10e29f54f1e5eb" dependencies = [ "memchr", "serde", @@ -2706,9 +2706,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jni" @@ -2719,7 +2719,7 @@ 
dependencies = [ "cesu8", "cfg-if", "combine", - "jni-sys", + "jni-sys 0.3.1", "log", "thiserror 1.0.69", "walkdir", @@ -2728,9 +2728,31 @@ dependencies = [ [[package]] name = "jni-sys" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" +checksum = "41a652e1f9b6e0275df1f15b32661cf0d4b78d4d87ddec5e0c3c20f097433258" +dependencies = [ + "jni-sys 0.4.1", +] + +[[package]] +name = "jni-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6377a88cb3910bee9b0fa88d4f42e1d2da8e79915598f65fb0c7ee14c878af2" +dependencies = [ + "jni-sys-macros", +] + +[[package]] +name = "jni-sys-macros" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38c0b942f458fe50cdac086d2f946512305e5631e720728f2a61aabcd47a6264" +dependencies = [ + "quote", + "syn 2.0.117", +] [[package]] name = "jobserver" @@ -2927,9 +2949,9 @@ checksum = "82903360c009b816f5ab72a9b68158c27c301ee2c3f20655b55c5e589e7d3bb7" [[package]] name = "libc" -version = "0.2.182" +version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" [[package]] name = "libm" @@ -2950,9 +2972,9 @@ dependencies = [ [[package]] name = "libredox" -version = "0.1.14" +version = "0.1.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1744e39d1d6a9948f4f388969627434e31128196de472883b39f148769bfe30a" +checksum = "7ddbf48fd451246b1f8c2610bd3b4ac0cc6e149d89832867093ab69a17194f08" dependencies = [ "bitflags 2.11.0", "libc", @@ -3724,9 +3746,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cf97ec579c3c42f953ef76dbf8d55ac91fb219dde70e49aa4a6b7d74e9919050" +checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" [[package]] name = "num-integer" @@ -3760,9 +3782,9 @@ dependencies = [ [[package]] name = "num_enum" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1207a7e20ad57b847bbddc6776b968420d38292bbfe2089accff5e19e82454c" +checksum = "5d0bca838442ec211fa11de3a8b0e0e8f3a4522575b5c4c06ed722e005036f26" dependencies = [ "num_enum_derive", "rustversion", @@ -3770,9 +3792,9 @@ dependencies = [ [[package]] name = "num_enum_derive" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff32365de1b6743cb203b710788263c44a03de03802daf96092f2da4fe6ba4d7" +checksum = "680998035259dcfcafe653688bf2aa6d3e2dc05e98be6ab46afb089dc84f1df8" dependencies = [ "proc-macro-crate", "proc-macro2", @@ -3806,9 +3828,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "once_cell_polyfill" @@ -3829,9 +3851,9 @@ dependencies = [ [[package]] name = "openssl" -version = "0.10.75" +version = "0.10.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" dependencies = [ "bitflags 2.11.0", "cfg-if", @@ -3861,9 +3883,9 @@ checksum = "7c87def4c32ab89d880effc9e097653c8da5d6ef28e6b539d313baaacfbafcbe" [[package]] name = "openssl-sys" -version = "0.9.111" +version = "0.9.112" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" dependencies = [ "cc", "libc", @@ -3873,9 +3895,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "5.1.0" +version = "5.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4779c6901a562440c3786d08192c6fbda7c1c2060edd10006b05ee35d10f2d" +checksum = "0218004a4aae742209bee9c3cef05672f6b2708be36a50add8eb613b1f2a4008" dependencies = [ "num-traits", "rand 0.8.5", @@ -4358,7 +4380,7 @@ version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e67ba7e9b2b56446f1d419b1d807906278ffa1a658a8a5d8a39dcb1f5a78614f" dependencies = [ - "toml_edit 0.25.4+spec-1.1.0", + "toml_edit 0.25.8+spec-1.1.0", ] [[package]] @@ -4372,9 +4394,9 @@ dependencies = [ [[package]] name = "proptest" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37566cb3fdacef14c0737f9546df7cfeadbfbc9fef10991038bf5015d0c80532" +checksum = "4b45fcc2344c680f5025fe57779faef368840d0bd1f42f216291f0dc4ace4744" dependencies = [ "bit-set", "bit-vec", @@ -4536,20 +4558,25 @@ dependencies = [ "argh", "ark-ff 0.5.0", "base64", + "bincode 1.3.3", "hex", + "mavros-artifacts", "noirc_abi", "postcard", "provekit-common", "provekit-gnark", "provekit-prover", "provekit-r1cs-compiler", + "provekit-spark", "provekit-verifier", "rayon", + "serde", "serde_json", "tikv-jemallocator", "tracing", "tracing-subscriber", "tracing-tracy", + "whir", ] [[package]] @@ -4652,6 +4679,20 @@ dependencies = [ "whir", ] +[[package]] +name = "provekit-spark" +version = "0.1.0" +dependencies = [ + "anyhow", + "ark-ff 0.5.0", + "ark-std 0.5.0", + "provekit-common", + "rayon", + "serde", + "tracing", + "whir", +] + [[package]] name = "provekit-verifier" version = "0.1.0" @@ -4982,7 +5023,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ "getrandom 0.2.17", - "libredox 0.1.14", + "libredox 0.1.15", "thiserror 1.0.69", ] @@ -5359,9 +5400,9 @@ checksum = "f87165f0995f63a9fbeea62b64d10b4d9d8e78ec6d7d51fb2125fda7bb36788f" [[package]] name = "rustls-webpki" -version = "0.103.9" +version = "0.103.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" dependencies = [ "ring", "rustls-pki-types", @@ -5493,9 +5534,9 @@ dependencies = [ [[package]] name = "schannel" -version = "0.1.28" +version = "0.1.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +checksum = "91c1b7e4904c873ef0710c1f407dde2e6287de2bebc1bbbf7d430bb7cbffd939" dependencies = [ "windows-sys 0.61.2", ] @@ -5711,9 +5752,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "381b283ce7bc6b476d903296fb59d0d36633652b633b27f64db4fb46dcbfc3b9" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64", "chrono", @@ -5730,9 +5771,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6d4e30573c8cb306ed6ab1dca8423eec9a463ea0e155f45399455e0368b27e0" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ "darling", "proc-macro2", @@ -5941,12 +5982,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.6.2" +version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" dependencies = [ "libc", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -6145,9 +6186,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "tempfile" -version = "3.26.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82a72c767771b47409d2345987fda8628641887d5466101319899796367354a0" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", "getrandom 0.4.2", @@ -6178,12 +6219,12 @@ dependencies = [ [[package]] name = "terminal_size" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b8cb979cb11c32ce1603f8137b22262a9d131aaa5c37b5678025f22b8becd0" +checksum = "230a1b821ccbd75b185820a1f1ff7b14d21da1e442e22c0863ea5f08771a8874" dependencies = [ "rustix 1.1.4", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -6449,9 +6490,9 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "1.0.0+spec-1.1.0" +version = "1.1.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c2555c699578a4f59f0cc68e5116c8d7cabbd45e1409b989d4be085b53f13e" +checksum = "97251a7c317e03ad83774a8752a7e81fb6067740609f75ea2b585b569a59198f" dependencies = [ "serde_core", ] @@ -6480,28 +6521,28 @@ dependencies = [ "serde_spanned", "toml_datetime 0.6.11", "toml_write", - "winnow 0.7.14", + "winnow 0.7.15", ] [[package]] name = "toml_edit" -version = "0.25.4+spec-1.1.0" +version = "0.25.8+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7193cbd0ce53dc966037f54351dbbcf0d5a642c7f0038c382ef9e677ce8c13f2" +checksum = "16bff38f1d86c47f9ff0647e6838d7bb362522bdf44006c7068c2b1e606f1f3c" dependencies = [ "indexmap 2.13.0", - 
"toml_datetime 1.0.0+spec-1.1.0", + "toml_datetime 1.1.0+spec-1.1.0", "toml_parser", - "winnow 0.7.14", + "winnow 1.0.0", ] [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.1.0+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "2334f11ee363607eb04df9b8fc8a13ca1715a72ba8662a26ac285c98aabb4011" dependencies = [ - "winnow 0.7.14", + "winnow 1.0.0", ] [[package]] @@ -6636,9 +6677,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "matchers", "nu-ansi-term", @@ -6789,9 +6830,9 @@ checksum = "3b09c83c3c29d37506a3e260c08c03743a6bb66a9cd432c6934ab501a190571f" [[package]] name = "unicode-segmentation" -version = "1.12.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" +checksum = "da36089a805484bcccfffe0739803392c8298778a2d2f09febf76fac5ad9025b" [[package]] name = "unicode-width" @@ -7571,9 +7612,18 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.14" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" +dependencies = [ + "memchr", +] + +[[package]] +name = "winnow" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" +checksum = "a90e88e4667264a994d34e6d1ab2d26d398dcdca8b7f52bec8668957517fc7d8" dependencies = [ "memchr", ] @@ -7732,18 +7782,18 @@ dependencies = [ [[package]] name = "zerocopy" 
-version = "0.8.40" +version = "0.8.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a789c6e490b576db9f7e6b6d661bcc9799f7c0ac8352f56ea20193b2681532e5" +checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.40" +version = "0.8.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f65c489a7071a749c849713807783f70672b28094011623e200cb86dcb835953" +checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c73da926a..4ca3ce09c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ members = [ "tooling/provekit-wasm", "tooling/verifier-server", "ntt", + "provekit/spark", "poseidon2", "playground/passport-input-gen", ] @@ -98,6 +99,7 @@ provekit-ffi = { path = "tooling/provekit-ffi" } provekit-gnark = { path = "tooling/provekit-gnark" } provekit-prover = { path = "provekit/prover", default-features = false } provekit-r1cs-compiler = { path = "provekit/r1cs-compiler" } +provekit-spark = { path = "provekit/spark" } provekit-verifier = { path = "provekit/verifier" } provekit-verifier-server = { path = "tooling/verifier-server" } provekit-wasm = { path = "tooling/provekit-wasm" } diff --git a/provekit/common/src/file/binary_format.rs b/provekit/common/src/file/binary_format.rs index 44ff55717..b331a62ed 100644 --- a/provekit/common/src/file/binary_format.rs +++ b/provekit/common/src/file/binary_format.rs @@ -24,4 +24,7 @@ pub const NOIR_PROOF_SCHEME_FORMAT: [u8; 8] = *b"NrProScm"; pub const NOIR_PROOF_SCHEME_VERSION: (u16, u16) = (1, 2); pub const NOIR_PROOF_FORMAT: [u8; 8] = *b"NPSProof"; -pub const NOIR_PROOF_VERSION: (u16, u16) = (1, 1); +pub const NOIR_PROOF_VERSION: (u16, u16) = (1, 2); + +pub const SPARK_PROOF_FORMAT: [u8; 8] = *b"SprkProf"; +pub const SPARK_PROOF_VERSION: (u16, u16) = (1, 
0); diff --git a/provekit/common/src/file/io/mod.rs b/provekit/common/src/file/io/mod.rs index 4d04680a3..a6f2a072c 100644 --- a/provekit/common/src/file/io/mod.rs +++ b/provekit/common/src/file/io/mod.rs @@ -3,9 +3,10 @@ mod buf_ext; mod counting_writer; mod json; +pub use self::bin::Compression; use { self::{ - bin::{read_bin, read_hash_config as read_hash_config_bin, write_bin, Compression}, + bin::{read_bin, read_hash_config as read_hash_config_bin, write_bin}, buf_ext::BufExt, counting_writer::CountingWriter, json::{read_json, write_json}, @@ -26,7 +27,7 @@ pub trait FileFormat: Serialize + for<'a> Deserialize<'a> { } /// Helper trait to optionally extract hash config. -pub(crate) trait MaybeHashAware { +pub trait MaybeHashAware { fn maybe_hash_config(&self) -> Option; } diff --git a/provekit/common/src/lib.rs b/provekit/common/src/lib.rs index ce0cd4d2f..961dfcf44 100644 --- a/provekit/common/src/lib.rs +++ b/provekit/common/src/lib.rs @@ -1,7 +1,7 @@ pub mod file; pub use file::binary_format; pub mod hash_config; -mod interner; +pub mod interner; mod mavros; mod noir_proof_scheme; pub mod optimize; @@ -9,6 +9,7 @@ pub mod prefix_covector; mod prover; mod r1cs; pub mod skyscraper; +pub mod spark; pub mod sparse_matrix; mod transcript_sponge; pub mod u256_arith; diff --git a/provekit/common/src/noir_proof_scheme.rs b/provekit/common/src/noir_proof_scheme.rs index a43377eae..4273011ed 100644 --- a/provekit/common/src/noir_proof_scheme.rs +++ b/provekit/common/src/noir_proof_scheme.rs @@ -1,5 +1,6 @@ use { crate::{ + spark::R1CSSparkQuery, whir_r1cs::{WhirR1CSProof, WhirR1CSScheme}, witness::{NoirWitnessGenerator, SplitWitnessBuilders}, HashConfig, MavrosSchemeData, NoirElement, PublicInputs, R1CS, @@ -29,8 +30,9 @@ pub enum NoirProofScheme { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct NoirProof { - pub public_inputs: PublicInputs, - pub whir_r1cs_proof: WhirR1CSProof, + pub public_inputs: PublicInputs, + pub whir_r1cs_proof: 
WhirR1CSProof, + pub r1cs_spark_query: R1CSSparkQuery, } impl NoirProofScheme { diff --git a/provekit/common/src/spark.rs b/provekit/common/src/spark.rs new file mode 100644 index 000000000..65071668b --- /dev/null +++ b/provekit/common/src/spark.rs @@ -0,0 +1,21 @@ +use { + crate::{utils::serde_ark, FieldElement}, + serde::{Deserialize, Serialize}, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Point { + #[serde(with = "serde_ark")] + pub row: Vec, + #[serde(with = "serde_ark")] + pub col: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct R1CSSparkQuery { + pub point_to_evaluate: Point, + #[serde(with = "serde_ark")] + pub matrix_batching_randomness: FieldElement, + #[serde(with = "serde_ark")] + pub claimed_value: FieldElement, +} diff --git a/provekit/common/src/utils/sumcheck.rs b/provekit/common/src/utils/sumcheck.rs index 207d76be4..0ab5ecebd 100644 --- a/provekit/common/src/utils/sumcheck.rs +++ b/provekit/common/src/utils/sumcheck.rs @@ -152,6 +152,11 @@ fn eval_eq( } } +/// Evaluates a quadratic polynomial on a value +pub fn eval_quadratic_poly(poly: [FieldElement; 3], point: FieldElement) -> FieldElement { + poly[0] + point * (poly[1] + point * poly[2]) +} + /// Evaluates a cubic polynomial on a value pub fn eval_cubic_poly(poly: [FieldElement; 4], point: FieldElement) -> FieldElement { poly[0] + point * (poly[1] + point * (poly[2] + point * poly[3])) diff --git a/provekit/prover/src/lib.rs b/provekit/prover/src/lib.rs index edcf685ac..1384ee96a 100644 --- a/provekit/prover/src/lib.rs +++ b/provekit/prover/src/lib.rs @@ -230,7 +230,7 @@ impl Prove for NoirProver { .map(|(i, w)| w.ok_or_else(|| anyhow::anyhow!("Witness {i} unsolved after solving"))) .collect::>>()?; - let whir_r1cs_proof = self + let (whir_r1cs_proof, r1cs_spark_query) = self .whir_for_witness .prove_noir(merlin, r1cs, commitments, full_witness, &public_inputs) .context("While proving R1CS instance")?; @@ -238,6 +238,7 
@@ impl Prove for NoirProver { Ok(NoirProof { public_inputs, whir_r1cs_proof, + r1cs_spark_query, }) } } @@ -314,7 +315,7 @@ impl Prove for MavrosProver { PublicInputs::from_vec(witgen_result.out_wit_pre_comm[1..=num_public_inputs].to_vec()) }; - let whir_r1cs_proof = self + let (whir_r1cs_proof, r1cs_spark_query) = self .whir_for_witness .prove_mavros( merlin, @@ -330,6 +331,7 @@ impl Prove for MavrosProver { Ok(NoirProof { public_inputs, whir_r1cs_proof, + r1cs_spark_query, }) } diff --git a/provekit/prover/src/whir_r1cs.rs b/provekit/prover/src/whir_r1cs.rs index a2d55f849..98027077b 100644 --- a/provekit/prover/src/whir_r1cs.rs +++ b/provekit/prover/src/whir_r1cs.rs @@ -1,18 +1,19 @@ use { anyhow::{ensure, Result}, - ark_ff::UniformRand, + ark_ff::{AdditiveGroup, UniformRand}, ark_std::{One, Zero}, provekit_common::{ prefix_covector::{ build_prefix_covectors, compute_alpha_evals, compute_public_eval, expand_powers, make_public_weight, OffsetCovector, }, + spark::{Point, R1CSSparkQuery}, utils::{ pad_to_power_of_two, sumcheck::{ calculate_evaluations_over_boolean_hypercube_for_eq, calculate_witness_bounds, - eval_cubic_poly, multiply_transposed_by_eq_alpha, sumcheck_fold_map_reduce, - transpose_r1cs_matrices, + eval_cubic_poly, eval_quadratic_poly, multiply_transposed_by_eq_alpha, + sumcheck_fold_map_reduce, transpose_r1cs_matrices, }, HALF, }, @@ -23,8 +24,9 @@ use { tracing::instrument, whir::{ algebra::{dot, linear_form::LinearForm}, - protocols::whir_zk::Witness as WhirZkWitness, + protocols::{whir::FinalClaim, whir_zk::Witness as WhirZkWitness}, transcript::{ProverState, VerifierMessage}, + utils::zip_strict, }, }; #[cfg(not(target_arch = "wasm32"))] @@ -61,7 +63,7 @@ pub trait WhirR1CSProver { commitments: Vec, full_witness: Vec, public_inputs: &PublicInputs, - ) -> Result; + ) -> Result<(WhirR1CSProof, R1CSSparkQuery)>; #[cfg(not(target_arch = "wasm32"))] fn prove_mavros( @@ -73,7 +75,7 @@ pub trait WhirR1CSProver { witness_layout: WitnessLayout, 
constraints_layout: ConstraintsLayout, ad_binary: &[u64], - ) -> Result; + ) -> Result<(WhirR1CSProof, R1CSSparkQuery)>; } impl WhirR1CSProver for WhirR1CSScheme { @@ -146,7 +148,7 @@ impl WhirR1CSProver for WhirR1CSScheme { commitments: Vec, full_witness: Vec, public_inputs: &PublicInputs, - ) -> Result { + ) -> Result<(WhirR1CSProof, R1CSSparkQuery)> { ensure!(!commitments.is_empty(), "Need at least one commitment"); let (a, b, c) = calculate_witness_bounds(&r1cs, &full_witness); @@ -173,16 +175,19 @@ impl WhirR1CSProver for WhirR1CSScheme { let blinding_offset = blinding.offset; let blinding_weights = expand_powers::<4>(&alpha); - prove_from_alphas( + let (whir_r1cs_proof, final_claim) = prove_from_alphas( self, merlin, + alpha, alphas, blinding_eval, blinding_offset, blinding_weights, commitments, public_inputs, - ) + )?; + + Ok((whir_r1cs_proof, final_claim)) } #[cfg(not(target_arch = "wasm32"))] @@ -196,7 +201,7 @@ impl WhirR1CSProver for WhirR1CSScheme { witness_layout: WitnessLayout, constraints_layout: ConstraintsLayout, ad_binary: &[u64], - ) -> Result { + ) -> Result<(WhirR1CSProof, R1CSSparkQuery)> { ensure!(!commitments.is_empty(), "Need at least one commitment"); let blinding = commitments[0] @@ -229,16 +234,19 @@ impl WhirR1CSProver for WhirR1CSScheme { let blinding_offset = blinding.offset; let blinding_weights = expand_powers::<4>(&alpha); - prove_from_alphas( + let (whir_r1cs_proof, final_claim) = prove_from_alphas( self, merlin, + alpha, alphas, blinding_eval, blinding_offset, blinding_weights, commitments, public_inputs, - ) + )?; + + Ok((whir_r1cs_proof, final_claim)) } } @@ -246,19 +254,20 @@ impl WhirR1CSProver for WhirR1CSScheme { fn prove_from_alphas( scheme: &WhirR1CSScheme, mut merlin: ProverState, + alpha: Vec, alphas: [Vec; 3], blinding_eval: FieldElement, blinding_offset: usize, blinding_weights: Vec, commitments: Vec, public_inputs: &PublicInputs, -) -> Result { +) -> Result<(WhirR1CSProof, R1CSSparkQuery)> { let is_single = 
commitments.len() == 1; let (x, public_weight) = get_public_weights(public_inputs, &mut merlin, scheme.m); let domain_size = 1usize << scheme.m; - if is_single { + let final_claim = if is_single { // Single commitment path let commitment = commitments .into_iter() @@ -283,19 +292,51 @@ fn prove_from_alphas( let blinding_covector = OffsetCovector::new(blinding_weights, blinding_offset, domain_size); + let alpha_weight_data: Vec<_> = weights + .iter() + .map(|w| (w.vector().to_vec(), w.size())) + .collect(); + let mut boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); boxed_weights.push(Box::new(blinding_covector)); - let _ = scheme.whir_witness.prove( + let public_offset = if public_inputs.is_empty() { 0 } else { 1 }; + + let final_claim = scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(commitment.polynomial.as_slice())], commitment.witness, boxed_weights, Cow::Borrowed(&evaluations), ); + + let rlc = zip_strict( + final_claim.rlc_coefficients[public_offset..(public_offset + 3)].iter(), + alpha_weight_data[public_offset..(public_offset + 3)].iter(), + ) + .map(|(&c, (vec, ds))| { + let w = PrefixCovector::new(vec.clone(), *ds); + c * w.mle_evaluate(&final_claim.evaluation_point) + }) + .sum::(); + + let claimed_batched_spark_value = if !public_inputs.is_empty() { + rlc / final_claim.rlc_coefficients[1] + } else { + rlc + }; + + R1CSSparkQuery { + point_to_evaluate: Point { + row: alpha, + col: final_claim.evaluation_point, + }, + matrix_batching_randomness: final_claim.rlc_coefficients[1], + claimed_value: claimed_batched_spark_value, + } } else { // Dual commitment path let mut commitments = commitments.into_iter(); @@ -307,6 +348,7 @@ fn prove_from_alphas( .expect("dual-commitment path requires second commitment"); let (alphas_1, alphas_2): (Vec<_>, Vec<_>) = alphas + .clone() .into_iter() .map(|mut v| { let v2 = v.split_off(scheme.w1_size); @@ -339,8 +381,8 @@ fn prove_from_alphas( polynomial: p1, .. 
} = c1; - { - let mut weights = build_prefix_covectors(scheme.m, alphas_1); + let final_claim1 = { + let mut weights = build_prefix_covectors(scheme.m, alphas_1.clone()); let mut evaluations: Vec = Vec::new(); if let Some(pe) = public_1 { weights.insert(0, make_public_weight(x, public_inputs.len(), scheme.m)); @@ -358,14 +400,14 @@ fn prove_from_alphas( .collect(); boxed_weights.push(Box::new(blinding_covector)); - let _ = scheme.whir_witness.prove( + scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p1.as_slice())], w1, boxed_weights, Cow::Borrowed(&evaluations), - ); - } + ) + }; drop(p1); let WhirR1CSCommitment { @@ -373,31 +415,115 @@ fn prove_from_alphas( polynomial: p2, .. } = c2; - { - let weights = build_prefix_covectors(scheme.m, alphas_2); + let final_claim2 = { + let weights = build_prefix_covectors(scheme.m, alphas_2.clone()); let evaluations: Vec = evals_2; let boxed_weights: Vec>> = weights .into_iter() .map(|w| Box::new(w) as Box>) .collect(); - let _ = scheme.whir_witness.prove( + scheme.whir_witness.prove( &mut merlin, vec![Cow::Borrowed(p2.as_slice())], w2, boxed_weights, Cow::Borrowed(&evaluations), - ); + ) + }; + + let beta: FieldElement = merlin.verifier_message(); + + let alphas1_padded: Vec> = alphas_1 + .clone() + .into_iter() + .map(|mut alpha| { + alpha.resize(1 << scheme.m, FieldElement::zero()); + alpha + }) + .collect(); + let alphas2_padded: Vec> = alphas_2 + .clone() + .into_iter() + .map(|mut alpha| { + alpha.resize(1 << scheme.m, FieldElement::zero()); + alpha + }) + .collect(); + + let alphas: Vec> = alphas1_padded + .iter() + .zip(alphas2_padded.iter()) + .map(|(a1, a2)| { + let mut combined = a1.clone(); + combined.extend_from_slice(a2); + combined + }) + .collect(); + + let claimed_eval1: Vec = alphas1_padded + .iter() + .map(|alphas1| { + PrefixCovector::new(alphas1.clone(), 1 << scheme.m) + .mle_evaluate(&final_claim1.evaluation_point) + }) + .collect(); + let claimed_eval2: Vec = alphas2_padded + .iter() + 
.map(|alphas2| { + PrefixCovector::new(alphas2.clone(), 1 << scheme.m) + .mle_evaluate(&final_claim2.evaluation_point) + }) + .collect(); + let claimed_evals: [FieldElement; 3] = + std::array::from_fn(|i| claimed_eval1[i] + beta * claimed_eval2[i]); + + let mut eval_point1 = final_claim1.evaluation_point.clone(); + eval_point1.insert(0, FieldElement::zero()); + let hypercube1 = + calculate_evaluations_over_boolean_hypercube_for_eq(&eval_point1, 1 << (scheme.m + 1)); + + let mut eval_point2 = final_claim2.evaluation_point.clone(); + eval_point2.insert(0, FieldElement::one()); + let hypercube2 = + calculate_evaluations_over_boolean_hypercube_for_eq(&eval_point2, 1 << (scheme.m + 1)); + + let hypercube: Vec = hypercube1 + .iter() + .zip(hypercube2) + .map(|(h1, h2)| *h1 + beta * h2) + .collect(); + + let alpha_refs: [&[FieldElement]; 3] = [&alphas[0], &alphas[1], &alphas[2]]; + + let (folded_values, folding_randomness) = + run_two_sumcheck(&mut merlin, &hypercube, alpha_refs, claimed_evals)?; + + let matrix_batching: FieldElement = merlin.verifier_message(); + let claimed_batched = folded_values[1] + + folded_values[2] * matrix_batching + + folded_values[3] * matrix_batching * matrix_batching; + + R1CSSparkQuery { + point_to_evaluate: Point { + row: alpha, + col: folding_randomness, + }, + matrix_batching_randomness: matrix_batching, + claimed_value: claimed_batched, } - } + }; let proof = merlin.proof(); - Ok(WhirR1CSProof { - narg_string: proof.narg_string, - hints: proof.hints, - #[cfg(debug_assertions)] - pattern: proof.pattern, - }) + Ok(( + WhirR1CSProof { + narg_string: proof.narg_string, + hints: proof.hints, + #[cfg(debug_assertions)] + pattern: proof.pattern, + }, + final_claim, + )) } pub fn compute_blinding_coefficients_for_round( @@ -596,6 +722,100 @@ pub fn run_zk_sumcheck_prover( (alpha, blinding_eval) } +pub fn run_two_sumcheck( + merlin: &mut ProverState, + hypercube: &[FieldElement], + alphas: [&[FieldElement]; 3], + mut claimed_values: 
[FieldElement; 3], +) -> Result<([FieldElement; 4], Vec)> { + let mut sumcheck_randomness; + let mut sumcheck_randomness_accumulator = Vec::::new(); + let mut fold = None; + + let mut h_mle = hypercube.to_vec(); + let mut a_mle = alphas[0].to_vec(); + let mut b_mle = alphas[1].to_vec(); + let mut c_mle = alphas[2].to_vec(); + loop { + let [a_hhat_i_at_0, a_highest_coeff, b_hhat_i_at_0, b_highest_coeff, c_hhat_i_at_0, c_highest_coeff] = + sumcheck_fold_map_reduce( + [&mut h_mle, &mut a_mle, &mut b_mle, &mut c_mle], + fold, + |[h_mle, a_mle, b_mle, c_mle]| { + [ + h_mle.0 * a_mle.0, + (h_mle.1 - h_mle.0) * (a_mle.1 - a_mle.0), + h_mle.0 * b_mle.0, + (h_mle.1 - h_mle.0) * (b_mle.1 - b_mle.0), + h_mle.0 * c_mle.0, + (h_mle.1 - h_mle.0) * (c_mle.1 - c_mle.0), + ] + }, + ); + + if fold.is_some() { + h_mle.truncate(h_mle.len() / 2); + a_mle.truncate(a_mle.len() / 2); + b_mle.truncate(b_mle.len() / 2); + c_mle.truncate(c_mle.len() / 2); + } + + let mut a_hhat_i_coeffs = [FieldElement::zero(); 3]; + + a_hhat_i_coeffs[0] = a_hhat_i_at_0; + a_hhat_i_coeffs[2] = a_highest_coeff; + a_hhat_i_coeffs[1] = + claimed_values[0] - a_hhat_i_coeffs[0] - a_hhat_i_coeffs[0] - a_hhat_i_coeffs[2]; + + for a_coeff in &a_hhat_i_coeffs { + merlin.prover_message(a_coeff); + } + + let mut b_hhat_i_coeffs = [FieldElement::zero(); 3]; + + b_hhat_i_coeffs[0] = b_hhat_i_at_0; + b_hhat_i_coeffs[2] = b_highest_coeff; + b_hhat_i_coeffs[1] = + claimed_values[1] - b_hhat_i_coeffs[0] - b_hhat_i_coeffs[0] - b_hhat_i_coeffs[2]; + + for b_coeff in &b_hhat_i_coeffs { + merlin.prover_message(b_coeff); + } + + let mut c_hhat_i_coeffs = [FieldElement::zero(); 3]; + + c_hhat_i_coeffs[0] = c_hhat_i_at_0; + c_hhat_i_coeffs[2] = c_highest_coeff; + c_hhat_i_coeffs[1] = + claimed_values[2] - c_hhat_i_coeffs[0] - c_hhat_i_coeffs[0] - c_hhat_i_coeffs[2]; + + for c_coeff in &c_hhat_i_coeffs { + merlin.prover_message(c_coeff); + } + + sumcheck_randomness = merlin.verifier_message(); + fold = Some(sumcheck_randomness); + 
claimed_values[0] = eval_quadratic_poly(a_hhat_i_coeffs, sumcheck_randomness); + claimed_values[1] = eval_quadratic_poly(b_hhat_i_coeffs, sumcheck_randomness); + claimed_values[2] = eval_quadratic_poly(c_hhat_i_coeffs, sumcheck_randomness); + + sumcheck_randomness_accumulator.push(sumcheck_randomness); + if h_mle.len() <= 2 { + break; + } + } + + let folded_h = h_mle[0] + (h_mle[1] - h_mle[0]) * sumcheck_randomness; + let folded_a = a_mle[0] + (a_mle[1] - a_mle[0]) * sumcheck_randomness; + let folded_b = b_mle[0] + (b_mle[1] - b_mle[0]) * sumcheck_randomness; + let folded_c = c_mle[0] + (c_mle[1] - c_mle[0]) * sumcheck_randomness; + + Ok(( + [folded_h, folded_a, folded_b, folded_c], + sumcheck_randomness_accumulator, + )) +} + fn create_weights_and_evaluations( m: usize, polynomial: &[FieldElement], diff --git a/provekit/spark/Cargo.toml b/provekit/spark/Cargo.toml new file mode 100644 index 000000000..1f6b62f96 --- /dev/null +++ b/provekit/spark/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "provekit-spark" +version = "0.1.0" +edition.workspace = true +rust-version.workspace = true +authors.workspace = true +license.workspace = true +homepage.workspace = true +repository.workspace = true + +[dependencies] +provekit-common.workspace = true +ark-ff.workspace = true +ark-std.workspace = true +anyhow.workspace = true +serde.workspace = true +whir.workspace = true +tracing.workspace = true +rayon.workspace = true + +[lints] +workspace = true diff --git a/provekit/spark/src/gpa.rs b/provekit/spark/src/gpa.rs new file mode 100644 index 000000000..2715c6d66 --- /dev/null +++ b/provekit/spark/src/gpa.rs @@ -0,0 +1,433 @@ +use { + anyhow::{ensure, Context}, + provekit_common::{ + utils::{ + next_power_of_two, + sumcheck::{ + calculate_eq, calculate_evaluations_over_boolean_hypercube_for_eq, eval_cubic_poly, + sumcheck_fold_map_reduce, + }, + HALF, + }, + FieldElement, TranscriptSponge, + }, + tracing::instrument, + whir::transcript::{ProverState, VerifierMessage, 
VerifierState}, +}; + +#[instrument(skip_all)] +pub fn run_gpa2( + merlin: &mut ProverState, + left: &[FieldElement], + right: &[FieldElement], +) -> anyhow::Result> { + let mut concatenated = left.to_vec(); + concatenated.extend_from_slice(right); + let mut layers = calculate_binary_multiplication_tree(concatenated)?; + + let mut drain = layers.drain(1..); + + let first_layer = drain.next().context("GPA tree has fewer than 2 layers")?; + let (accumulated_randomness, mut sumcheck_claim) = add_line_to_transcript(merlin, first_layer); + let mut accumulated_randomness = accumulated_randomness.to_vec(); + + for layer in drain { + (sumcheck_claim, accumulated_randomness) = + run_gpa_sumcheck(merlin, layer, sumcheck_claim, accumulated_randomness)?; + } + + Ok(accumulated_randomness) +} + +#[instrument(skip_all)] +pub fn run_gpa4( + merlin: &mut ProverState, + leaves: Vec, +) -> anyhow::Result> { + let mut layers = calculate_binary_multiplication_tree(leaves)?; + + let mut drain = layers.drain(2..); + + let coeffs = drain.next().context("GPA tree has fewer than 3 layers")?; + let coeffs = [ + coeffs[0], + coeffs[1] - coeffs[0], + coeffs[2] - coeffs[0], + coeffs[3] - coeffs[2] - coeffs[1] + coeffs[0], + ]; + + for c in &coeffs { + merlin.prover_message(c); + } + + let r0: FieldElement = merlin.verifier_message(); + let r1: FieldElement = merlin.verifier_message(); + let mut accumulated_randomness = vec![r0, r1]; + + let mut sumcheck_claim = coeffs[0] + coeffs[1] * r1 + coeffs[2] * r0 + coeffs[3] * r0 * r1; + + for layer in drain { + (sumcheck_claim, accumulated_randomness) = + run_gpa_sumcheck(merlin, layer, sumcheck_claim, accumulated_randomness)?; + } + + Ok(accumulated_randomness) +} + +fn calculate_binary_multiplication_tree( + array_to_prove: Vec, +) -> anyhow::Result>> { + use rayon::prelude::*; + + ensure!( + array_to_prove.len() == (1 << next_power_of_two(array_to_prove.len())), + "Input length must be power of two" + ); + + let mut layers = vec![]; + let mut 
current_layer = array_to_prove; + + while current_layer.len() > 1 { + let next_layer: Vec = current_layer + .par_chunks_exact(2) + .map(|pair| pair[0] * pair[1]) + .collect(); + + layers.push(current_layer); + current_layer = next_layer; + } + + layers.push(current_layer); + layers.reverse(); + Ok(layers) +} + +fn add_line_to_transcript( + merlin: &mut ProverState, + arr: Vec, +) -> ([FieldElement; 1], FieldElement) { + let line_poly = [arr[0], arr[1] - arr[0]]; + + for c in line_poly.iter() { + merlin.prover_message(c); + } + + let challenge: FieldElement = merlin.verifier_message(); + + let next_claim = line_poly[0] + line_poly[1] * challenge; + + ([challenge], next_claim) +} + +fn run_gpa_sumcheck( + merlin: &mut ProverState, + layer: Vec, + mut sumcheck_claim: FieldElement, + accumulated_randomness: Vec, +) -> anyhow::Result<(FieldElement, Vec)> { + let (mut even_layer, mut odd_layer) = split_even_odd(layer); + + let mut eq_evaluations = calculate_evaluations_over_boolean_hypercube_for_eq( + &accumulated_randomness, + 1 << accumulated_randomness.len(), + ); + let mut challenge; + let mut round_randomness = Vec::::new(); + let mut fold = None; + + loop { + let [eval_at_0, eval_at_neg1, eval_at_inf_over_x3] = sumcheck_fold_map_reduce( + [&mut eq_evaluations, &mut even_layer, &mut odd_layer], + fold, + |[eq, v0, v1]| { + [ + eq.0 * v0.0 * v1.0, + (eq.0 + eq.0 - eq.1) * (v0.0 + v0.0 - v0.1) * (v1.0 + v1.0 - v1.1), + (eq.1 - eq.0) * (v0.1 - v0.0) * (v1.1 - v1.0), + ] + }, + ); + + if fold.is_some() { + eq_evaluations.truncate(eq_evaluations.len() / 2); + even_layer.truncate(even_layer.len() / 2); + odd_layer.truncate(odd_layer.len() / 2); + } + + let poly_coeffs = reconstruct_cubic_from_evaluations( + sumcheck_claim, + eval_at_0, + eval_at_neg1, + eval_at_inf_over_x3, + ); + + ensure!( + sumcheck_claim + == poly_coeffs[0] + + poly_coeffs[0] + + poly_coeffs[1] + + poly_coeffs[2] + + poly_coeffs[3], + "Sumcheck binding check failed" + ); + + for coeff in &poly_coeffs 
{ + merlin.prover_message(coeff); + } + challenge = merlin.verifier_message(); + + fold = Some(challenge); + sumcheck_claim = eval_cubic_poly(poly_coeffs, challenge); + round_randomness.push(challenge); + + if eq_evaluations.len() <= 2 { + break; + } + } + + let final_v0 = even_layer[0] + (even_layer[1] - even_layer[0]) * challenge; + let final_v1 = odd_layer[0] + (odd_layer[1] - odd_layer[0]) * challenge; + let final_v2 = eq_evaluations[0] + (eq_evaluations[1] - eq_evaluations[0]) * challenge; + + ensure!( + sumcheck_claim == final_v0 * final_v1 * final_v2, + "GPA sumcheck claim mismatch" + ); + + let line_coeffs = [final_v0, final_v1 - final_v0]; + + for c in &line_coeffs { + merlin.prover_message(c); + } + + let line_challenge: FieldElement = merlin.verifier_message(); + let next_claim = line_coeffs[0] + line_coeffs[1] * line_challenge; + round_randomness.push(line_challenge); + + Ok((next_claim, round_randomness)) +} + +fn reconstruct_cubic_from_evaluations( + binding_value: FieldElement, + at_0: FieldElement, + at_neg1: FieldElement, + at_inf_over_x3: FieldElement, +) -> [FieldElement; 4] { + let mut coeffs = [FieldElement::from(0u64); 4]; + + coeffs[0] = at_0; + coeffs[2] = HALF * (binding_value + at_neg1 - at_0 - at_0 - at_0); + coeffs[3] = at_inf_over_x3; + coeffs[1] = binding_value - coeffs[0] - coeffs[0] - coeffs[3] - coeffs[2]; + + coeffs +} + +fn split_even_odd(input: Vec) -> (Vec, Vec) { + input + .chunks_exact(2) + .map(|chunk| (chunk[0], chunk[1])) + .unzip() +} + +pub struct GPASumcheckResult { + pub claimed_values: Vec, + pub last_sumcheck_value: FieldElement, + pub randomness: Vec, +} + +#[instrument(skip_all)] +pub fn gpa_sumcheck_verifier2( + arthur: &mut VerifierState<'_, TranscriptSponge>, + height_of_binary_tree: usize, +) -> anyhow::Result { + let mut prev_randomness; + let mut current_randomness = Vec::::new(); + + let claimed_0: FieldElement = arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?; + let claimed_1: FieldElement 
= arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?; + let claimed_values = [claimed_0, claimed_1]; + + let line_challenge: FieldElement = arthur.verifier_message(); + + let mut sumcheck_value = eval_line(&claimed_values, &line_challenge); + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + + for layer_idx in 1..height_of_binary_tree - 1 { + for _ in 0..layer_idx { + let cubic_coeffs: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let sumcheck_challenge: FieldElement = arthur.verifier_message(); + + ensure!( + eval_cubic_poly(cubic_coeffs, FieldElement::from(0u64)) + + eval_cubic_poly(cubic_coeffs, FieldElement::from(1u64)) + == sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge); + } + + let line_coeffs: [FieldElement; 2] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let line_challenge: FieldElement = arthur.verifier_message(); + + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::from(0u64)) + * eval_line(&line_coeffs, &FieldElement::from(1u64)); + ensure!( + expected_line_value == sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge); + } + + let claimed_values = [claimed_values[0], claimed_values[0] + claimed_values[1]].to_vec(); + + 
Ok(GPASumcheckResult { + claimed_values, + last_sumcheck_value: sumcheck_value, + randomness: prev_randomness, + }) +} + +#[instrument(skip_all)] +pub fn gpa_sumcheck_verifier4( + arthur: &mut VerifierState<'_, TranscriptSponge>, + height_of_binary_tree: usize, +) -> anyhow::Result { + let claimed_values: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let r0: FieldElement = arthur.verifier_message(); + let r1: FieldElement = arthur.verifier_message(); + let mut prev_randomness = vec![r0, r1]; + let mut current_randomness = Vec::::new(); + + let mut sumcheck_value = claimed_values[0] + + claimed_values[1] * prev_randomness[1] + + claimed_values[2] * prev_randomness[0] + + claimed_values[3] * prev_randomness[0] * prev_randomness[1]; + + for layer_idx in 2..height_of_binary_tree - 1 { + for _ in 0..layer_idx { + let cubic_coeffs: [FieldElement; 4] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + let sumcheck_challenge: FieldElement = arthur.verifier_message(); + + ensure!( + eval_cubic_poly(cubic_coeffs, FieldElement::from(0u64)) + + eval_cubic_poly(cubic_coeffs, FieldElement::from(1u64)) + == sumcheck_value, + "Sumcheck verification failed at layer {layer_idx}" + ); + + current_randomness.push(sumcheck_challenge); + sumcheck_value = eval_cubic_poly(cubic_coeffs, sumcheck_challenge); + } + + let line_coeffs: [FieldElement; 2] = [ + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("{e}"))?, + ]; + 
let line_challenge: FieldElement = arthur.verifier_message(); + + let expected_line_value = calculate_eq(&prev_randomness, ¤t_randomness) + * eval_line(&line_coeffs, &FieldElement::from(0u64)) + * eval_line(&line_coeffs, &FieldElement::from(1u64)); + ensure!( + expected_line_value == sumcheck_value, + "Line evaluation mismatch" + ); + + current_randomness.push(line_challenge); + prev_randomness = current_randomness; + current_randomness = Vec::new(); + sumcheck_value = eval_line(&line_coeffs, &line_challenge); + } + + let claimed_values = [ + claimed_values[0], + claimed_values[0] + claimed_values[1], + claimed_values[0] + claimed_values[2], + claimed_values[0] + claimed_values[1] + claimed_values[2] + claimed_values[3], + ] + .to_vec(); + + Ok(GPASumcheckResult { + claimed_values, + last_sumcheck_value: sumcheck_value, + randomness: prev_randomness, + }) +} + +pub fn eval_line(poly: &[FieldElement], point: &FieldElement) -> FieldElement { + poly[0] + *point * poly[1] +} + +pub fn calculate_adr(randomness: &[FieldElement]) -> FieldElement { + randomness + .iter() + .rev() + .enumerate() + .fold(FieldElement::from(0u64), |acc, (i, &r)| { + acc + r * FieldElement::from(1u64 << i) + }) +} diff --git a/provekit/spark/src/lib.rs b/provekit/spark/src/lib.rs new file mode 100644 index 000000000..42b6085c0 --- /dev/null +++ b/provekit/spark/src/lib.rs @@ -0,0 +1,17 @@ +pub mod gpa; +pub mod memory; +pub mod prover; +pub mod sumcheck; +pub mod types; +pub mod utils; +pub mod verifier; + +pub use { + prover::{SPARKProver, SPARKScheme as SPARKProverScheme}, + types::{ + MatrixDimensions, SPARKProof, SPARKWHIRConfigs, SparkCommitments, SparkPreparedData, + SparkWitnesses, + }, + utils::calculate_memory, + verifier::{SPARKScheme as SPARKVerifierScheme, SPARKVerifier}, +}; diff --git a/provekit/spark/src/memory.rs b/provekit/spark/src/memory.rs new file mode 100644 index 000000000..825176f53 --- /dev/null +++ b/provekit/spark/src/memory.rs @@ -0,0 +1,156 @@ +use { + crate::{ + 
gpa::{calculate_adr, gpa_sumcheck_verifier2, run_gpa2}, + types::WhirWitness, + }, + anyhow::{ensure, Result}, + ark_std::One, + provekit_common::{FieldElement, TranscriptSponge, WhirConfig}, + rayon::prelude::*, + std::borrow::Cow, + tracing::instrument, + whir::{ + algebra::{linear_form::MultilinearExtension, multilinear_extend}, + protocols::irs_commit::Commitment, + transcript::{ProverState, VerifierState}, + }, +}; + +pub struct AxisConfig<'a> { + pub eq_memory: &'a [FieldElement], + pub final_timestamp: &'a [FieldElement], + pub whir_config: &'a WhirConfig, +} + +#[instrument(skip_all)] +pub fn prove_axis( + merlin: &mut ProverState, + config: AxisConfig<'_>, + final_ts_witness: &WhirWitness, + gamma: &FieldElement, + tau: &FieldElement, +) -> Result<()> { + let gamma_sq = *gamma * *gamma; + + let (init_vec, final_vec) = rayon::join( + || { + config + .eq_memory + .par_iter() + .enumerate() + .map(|(i, &v)| { + let a = FieldElement::from(i as u64); + a * gamma_sq + v * gamma - tau + }) + .collect::>() + }, + || { + config + .eq_memory + .par_iter() + .zip(config.final_timestamp.par_iter()) + .enumerate() + .map(|(i, (&v, &t))| { + let a = FieldElement::from(i as u64); + a * gamma_sq + v * gamma + t - tau + }) + .collect::>() + }, + ); + + let gpa_randomness = run_gpa2(merlin, &init_vec, &final_vec)?; + let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(1); + + let final_ts_eval = multilinear_extend(config.final_timestamp, evaluation_randomness); + merlin.prover_hint_ark(&final_ts_eval); + + produce_whir_proof( + merlin, + evaluation_randomness, + &[config.final_timestamp], + config.whir_config, + final_ts_witness, + )?; + + Ok(()) +} + +#[instrument(skip_all)] +pub fn verify_axis( + arthur: &mut VerifierState<'_, TranscriptSponge>, + num_axis_items: usize, + whir_config: &WhirConfig, + finalts_commitment: Commitment, + init_mem_fn: impl Fn(&[FieldElement]) -> FieldElement, + tau: &FieldElement, + gamma: &FieldElement, + claimed_rs: 
&FieldElement, + claimed_ws: &FieldElement, +) -> Result<()> { + let gpa_result = gpa_sumcheck_verifier2( + arthur, + provekit_common::utils::next_power_of_two(num_axis_items) + 2, + )?; + + let claimed_init = gpa_result.claimed_values[0]; + let claimed_final = gpa_result.claimed_values[1]; + let (last_randomness, evaluation_randomness) = gpa_result.randomness.split_at(1); + + let gamma_sq = *gamma * *gamma; + + let init_adr = calculate_adr(evaluation_randomness); + let init_mem = init_mem_fn(evaluation_randomness); + let init_opening = init_adr * gamma_sq + init_mem * gamma - tau; + + let final_cntr: FieldElement = arthur + .prover_hint_ark() + .map_err(|e| anyhow::anyhow!("{e}"))?; + + let eval_weight = MultilinearExtension::new(evaluation_randomness.to_vec()); + let finalts_claim = whir_config + .verify(arthur, &[&finalts_commitment], &[final_cntr]) + .map_err(|e| anyhow::anyhow!("WHIR verify failed: {e}"))?; + finalts_claim + .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm]) + .map_err(|e| anyhow::anyhow!("FinalClaim check failed for final timestamps: {e}"))?; + + let final_opening = init_adr * gamma_sq + init_mem * gamma + final_cntr - tau; + + let evaluated_value = init_opening * (FieldElement::one() - last_randomness[0]) + + final_opening * last_randomness[0]; + + ensure!(evaluated_value == gpa_result.last_sumcheck_value); + + ensure!(claimed_init * claimed_ws == claimed_final * claimed_rs); + + Ok(()) +} + +#[instrument(skip_all)] +pub fn produce_whir_proof( + merlin: &mut ProverState, + evaluation_point: &[FieldElement], + vectors: &[&[FieldElement]], + config: &WhirConfig, + witness: &WhirWitness, +) -> Result<()> { + let lf = MultilinearExtension::new(evaluation_point.to_vec()); + + let evaluations: Vec = vectors + .iter() + .map(|v| multilinear_extend(v, evaluation_point)) + .collect(); + + _ = config.prove( + merlin, + vectors.iter().map(|v| Cow::Borrowed(*v)).collect(), + vec![Cow::Owned(witness.clone())], + vec![Box::new(lf) + as 
Box< + dyn whir::algebra::linear_form::LinearForm, + >], + Cow::Borrowed(&evaluations), + ); + + Ok(()) +} diff --git a/provekit/spark/src/prover.rs b/provekit/spark/src/prover.rs new file mode 100644 index 000000000..37b24b890 --- /dev/null +++ b/provekit/spark/src/prover.rs @@ -0,0 +1,416 @@ +use { + crate::{ + gpa::run_gpa4, + memory::{produce_whir_proof, prove_axis, AxisConfig}, + sumcheck::run_spark_sumcheck, + types::{ + EValuesForMatrix, MatrixDimensions, Memory, SPARKProof, SPARKWHIRConfigs, + SerializableCommitment, SparkMatrix, SparkPreparedData, WhirWitness, + }, + utils::calculate_memory, + }, + anyhow::Result, + ark_ff::Field, + provekit_common::{ + spark::R1CSSparkQuery, utils::next_power_of_two, FieldElement, TranscriptSponge, WhirConfig, + }, + rayon::{join, prelude::*}, + tracing::instrument, + whir::{ + algebra::multilinear_extend, + parameters::ProtocolParameters, + transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierMessage}, + }, +}; + +pub trait SPARKProver { + fn prove(&self, spark_data: &SparkPreparedData, request: &R1CSSparkQuery) + -> Result; +} + +pub struct SPARKScheme { + pub whir_configs: SPARKWHIRConfigs, + pub matrix_dimensions: MatrixDimensions, +} + +pub fn new_whir_config_for_size(log_size: usize, batch_size: usize) -> WhirConfig { + let nv = log_size.max(4); + + let whir_params = ProtocolParameters { + unique_decoding: false, + initial_folding_factor: 3, + security_level: 128, + pow_bits: 10, + folding_factor: 3, + starting_log_inv_rate: 2, + batch_size, + hash_id: whir::hash::SHA2, + }; + + WhirConfig::new(1 << nv, &whir_params) +} + +impl SPARKScheme { + pub fn new_for_r1cs(r1cs: &provekit_common::R1CS) -> Self { + let num_rows = 2 * r1cs.num_constraints(); + let num_cols = 2 * r1cs.num_witnesses(); + let nonzero_terms = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + + Self::new(num_rows, num_cols, nonzero_terms) + } + + pub fn new(num_rows: usize, num_cols: usize, nonzero_terms: 
usize) -> Self {
        // NOTE: padded_num_entries is already a power of two, so
        // next_power_of_two(padded_num_entries) == next_power_of_two(nonzero_terms).
        let padded_num_entries = 1 << next_power_of_two(nonzero_terms);

        let row_config = new_whir_config_for_size(next_power_of_two(num_rows), 1);
        let col_config = new_whir_config_for_size(next_power_of_two(num_cols), 1);
        let num_terms_1batched_config =
            new_whir_config_for_size(next_power_of_two(padded_num_entries), 1);
        let num_terms_2batched_config =
            new_whir_config_for_size(next_power_of_two(padded_num_entries), 2);
        let num_terms_4batched_config =
            new_whir_config_for_size(next_power_of_two(padded_num_entries), 4);

        Self {
            whir_configs: SPARKWHIRConfigs {
                row: row_config,
                col: col_config,
                num_terms_1batched: num_terms_1batched_config,
                num_terms_2batched: num_terms_2batched_config,
                num_terms_4batched: num_terms_4batched_config,
            },
            matrix_dimensions: MatrixDimensions {
                num_rows,
                num_cols,
                nonzero_terms,
            },
        }
    }
}

/// Challenges drawn from the Fiat-Shamir transcript during proving.
struct Challenges {
    gamma: FieldElement,
    tau: FieldElement,
}

impl SPARKProver for SPARKScheme {
    /// Produces a SPARK proof that the committed sparse matrix evaluates to
    /// `request.claimed_value` at `request.point_to_evaluate`.
    ///
    /// The batching randomness is folded out of the claim up front: both the
    /// memory point and the claimed value are rescaled by
    /// `1 / (1 + matrix_batching_randomness)` (twice for the value — once per
    /// coordinate axis), mirroring the verifier-side rescaling in `verify`.
    #[instrument(skip_all)]
    fn prove(
        &self,
        spark_data: &SparkPreparedData,
        request: &R1CSSparkQuery,
    ) -> Result<SPARKProof> {
        let padded_num_entries = spark_data.matrix.coo.val.len();

        let ds = DomainSeparator::protocol(&self.whir_configs).instance(&Empty);
        let mut merlin = ProverState::new(&ds, TranscriptSponge::default());

        // eq-table over the (batching-extended) row and column points.
        let memory = calculate_memory(
            request.matrix_batching_randomness
                / (FieldElement::ONE + request.matrix_batching_randomness),
            &request.point_to_evaluate.row,
            &request.point_to_evaluate.col,
        );

        let claimed_value = (request.claimed_value
            / (FieldElement::ONE + request.matrix_batching_randomness))
            / (FieldElement::ONE + request.matrix_batching_randomness);

        // E_rx[i] = eq(rx, row_i), E_ry[i] = eq(ry, col_i): per-entry memory reads.
        let (e_rx, e_ry) = rayon::join(
            || {
                spark_data.matrix.coo.row[..padded_num_entries]
                    .par_iter()
                    .map(|&r| memory.eq_rx[r])
                    .collect()
            },
            || {
                spark_data.matrix.coo.col[..padded_num_entries]
                    .par_iter()
                    .map(|&c| memory.eq_ry[c])
                    .collect()
            },
        );

        let e_values = EValuesForMatrix { e_rx, e_ry };

        prove_spark(
            &mut merlin,
            spark_data,
            &e_values,
            claimed_value,
            &memory,
            &self.whir_configs,
        )?;

        let proof = merlin.proof();
        Ok(SPARKProof {
            narg_string: proof.narg_string,
            hints: proof.hints,
            #[cfg(debug_assertions)]
            pattern: proof.pattern,
            whir_params: self.whir_configs.clone(),
            matrix_dimensions: self.matrix_dimensions.clone(),
        })
    }
}

/// Core SPARK proving flow. Transcript order is load-bearing and must match
/// `verify_spark_single_matrix` exactly:
/// commitments -> e-values commitment -> sumcheck -> (tau, gamma) -> GPA ->
/// row axis -> col axis.
#[instrument(skip_all)]
fn prove_spark(
    merlin: &mut ProverState,
    data: &SparkPreparedData,
    e_values: &EValuesForMatrix,
    claimed_value: FieldElement,
    memory: &Memory,
    whir_configs: &SPARKWHIRConfigs,
) -> Result<()> {
    // Replay the precomputed commitments into this transcript so prover and
    // verifier sponge states agree.
    replay_commitment(
        merlin,
        &data.commitments.vals,
        &whir_configs.num_terms_1batched,
    );
    replay_commitment(
        merlin,
        &data.commitments.rs_ws,
        &whir_configs.num_terms_4batched,
    );
    replay_commitment(merlin, &data.commitments.final_row_ts, &whir_configs.row);
    replay_commitment(merlin, &data.commitments.final_col_ts, &whir_configs.col);

    let e_values_witness = commit_e_values(merlin, whir_configs, e_values);

    spark_sumcheck(
        merlin,
        &data.matrix,
        e_values,
        claimed_value,
        &e_values_witness,
        &data.witnesses.vals_witness,
        whir_configs,
    )?;

    // Memory-checking challenges; the verifier draws tau then gamma in the
    // same order.
    let tau: FieldElement = merlin.verifier_message();
    let gamma: FieldElement = merlin.verifier_message();
    let challenges = Challenges { tau, gamma };

    run_rs_ws_gpa_and_proofs(
        merlin,
        &data.matrix,
        e_values,
        &e_values_witness,
        &data.witnesses.rs_ws_witness,
        whir_configs,
        &challenges,
    )?;

    prove_axis(
        merlin,
        AxisConfig {
            eq_memory: &memory.eq_rx,
            final_timestamp: &data.matrix.timestamps.final_row,
            whir_config: &whir_configs.row,
        },
        &data.witnesses.final_row_ts_witness,
        &challenges.gamma,
        &challenges.tau,
    )?;

    prove_axis(
        merlin,
        AxisConfig {
            eq_memory: &memory.eq_ry,
            final_timestamp: &data.matrix.timestamps.final_col,
            whir_config: &whir_configs.col,
        },
        &data.witnesses.final_col_ts_witness,
        &challenges.gamma,
        &challenges.tau,
    )?;

    Ok(())
}

/// Sumcheck for sum_i val[i] * E_rx[i] * E_ry[i] == claimed_value, followed by
/// WHIR opening proofs for the final folds of (E_rx, E_ry) and val.
#[instrument(skip_all)]
fn spark_sumcheck(
    merlin: &mut ProverState,
    matrix: &SparkMatrix,
    e_values: &EValuesForMatrix,
    claimed_value: FieldElement,
    e_values_witness: &WhirWitness,
    vals_witness: &WhirWitness,
    whir_configs: &SPARKWHIRConfigs,
) -> Result<()> {
    let mles: [&[FieldElement]; 3] = [&matrix.coo.val, &e_values.e_rx, &e_values.e_ry];
    let (sumcheck_final_folds, folding_randomness) =
        run_spark_sumcheck(merlin, mles, claimed_value)?;

    // Hint the three final folds; the verifier checks their product against
    // the last sumcheck value, then verifies each via WHIR.
    merlin.prover_hint_ark(&[
        sumcheck_final_folds[0],
        sumcheck_final_folds[1],
        sumcheck_final_folds[2],
    ]);

    produce_whir_proof(
        merlin,
        &folding_randomness,
        &[&e_values.e_rx, &e_values.e_ry],
        &whir_configs.num_terms_2batched,
        e_values_witness,
    )?;

    produce_whir_proof(
        merlin,
        &folding_randomness,
        &[&matrix.coo.val],
        &whir_configs.num_terms_1batched,
        vals_witness,
    )?;

    Ok(())
}

/// Builds the read-set / write-set leaves for the offline-memory-checking
/// grand product (fingerprint a*gamma^2 + v*gamma + t - tau; the write set
/// is the read set with timestamp t+1), runs the 4-way GPA, and emits the
/// WHIR openings the verifier needs at the GPA evaluation point.
#[instrument(skip_all)]
fn run_rs_ws_gpa_and_proofs(
    merlin: &mut ProverState,
    matrix: &SparkMatrix,
    e_values: &EValuesForMatrix,
    e_values_witness: &WhirWitness,
    rs_ws_witness: &WhirWitness,
    whir_configs: &SPARKWHIRConfigs,
    challenges: &Challenges,
) -> Result<()> {
    let gamma_sq = challenges.gamma * challenges.gamma;
    let one = FieldElement::from(1u64);

    let row_field = &matrix.coo.row_field;
    let col_field = &matrix.coo.col_field;
    let n = row_field.len();
    let m = col_field.len();

    let (row_pairs, col_pairs) = tracing::info_span!("build_rs_ws_pairs").in_scope(|| {
        join(
            || {
                (0..n)
                    .into_par_iter()
                    .map(|i| {
                        let a = row_field[i];
                        let v = e_values.e_rx[i];
                        let t = matrix.timestamps.read_row[i];
                        let base = a * gamma_sq + v * challenges.gamma + t - challenges.tau;
                        // (read-set leaf, write-set leaf): write timestamp is t + 1.
                        (base, base + one)
                    })
                    .collect::<Vec<_>>()
            },
            || {
                (0..m)
                    .into_par_iter()
                    .map(|i| {
                        let a = col_field[i];
                        let v = e_values.e_ry[i];
                        let t = matrix.timestamps.read_col[i];
                        let base = a * gamma_sq + v * challenges.gamma + t - challenges.tau;
                        (base, base + one)
                    })
                    .collect::<Vec<_>>()
            },
        )
    });
    let (row_rs_vec, row_ws_vec): (Vec<_>, Vec<_>) = row_pairs.into_iter().unzip();
    let (col_rs_vec, col_ws_vec): (Vec<_>, Vec<_>) = col_pairs.into_iter().unzip();

    // Leaf layout [row_rs | row_ws | col_rs | col_ws] must match the
    // verifier's combination of the four claimed products.
    let mut gpa_leaves_flat = Vec::with_capacity(4 * row_rs_vec.len());
    let gpa_leaves = [row_rs_vec, row_ws_vec, col_rs_vec, col_ws_vec];
    gpa_leaves_flat.extend(gpa_leaves.into_iter().flatten());
    let gpa_randomness = run_gpa4(merlin, gpa_leaves_flat)?;

    // First two variables select among the four sets; the rest index entries.
    let (_combination_randomness, evaluation_randomness) = gpa_randomness.split_at(2);

    let ((row_address_eval, row_timestamp_eval), (col_address_eval, col_timestamp_eval)) =
        tracing::info_span!("multilinear_extend_rs_ws").in_scope(|| {
            join(
                || {
                    join(
                        || multilinear_extend(row_field, evaluation_randomness),
                        || multilinear_extend(&matrix.timestamps.read_row, evaluation_randomness),
                    )
                },
                || {
                    join(
                        || multilinear_extend(col_field, evaluation_randomness),
                        || multilinear_extend(&matrix.timestamps.read_col, evaluation_randomness),
                    )
                },
            )
        });

    // Hint order (row addr, row ts, col addr, col ts) must match the verifier.
    merlin.prover_hint_ark(&row_address_eval);
    merlin.prover_hint_ark(&row_timestamp_eval);
    merlin.prover_hint_ark(&col_address_eval);
    merlin.prover_hint_ark(&col_timestamp_eval);

    let rs_ws_vecs: [&[FieldElement]; 4] = [
        &matrix.coo.row_field,
        &matrix.timestamps.read_row,
        &matrix.coo.col_field,
        &matrix.timestamps.read_col,
    ];

    produce_whir_proof(
        merlin,
        evaluation_randomness,
        &rs_ws_vecs,
        &whir_configs.num_terms_4batched,
        rs_ws_witness,
    )?;

    let (row_value_eval, col_value_eval) = tracing::info_span!("multilinear_extend_e_values")
        .in_scope(|| {
            join(
                || multilinear_extend(&e_values.e_rx, evaluation_randomness),
                || multilinear_extend(&e_values.e_ry, evaluation_randomness),
            )
        });
    merlin.prover_hint_ark(&row_value_eval);
    merlin.prover_hint_ark(&col_value_eval);

    produce_whir_proof(
        merlin,
        evaluation_randomness,
        &[&e_values.e_rx, &e_values.e_ry],
        &whir_configs.num_terms_2batched,
        e_values_witness,
    )?;

    Ok(())
}

/// Commits to the per-entry memory reads (E_rx, E_ry) as a 2-batched WHIR
/// polynomial, absorbing the commitment into the transcript.
#[instrument(skip_all)]
fn commit_e_values(
    merlin: &mut ProverState,
    whir_configs: &SPARKWHIRConfigs,
    e_values: &EValuesForMatrix,
) -> WhirWitness {
    whir_configs
        .num_terms_2batched
        .commit(merlin, &[&e_values.e_rx, &e_values.e_ry])
}

/// Replays a precomputed commitment into a fresh transcript so the sponge
/// state matches the one the commitment was originally created against.
fn replay_commitment(
    merlin: &mut ProverState,
    commitment: &SerializableCommitment,
    config: &WhirConfig,
) {
    let ic = &config.initial_committer;

    // Absorb the Merkle root
    merlin.prover_message(&commitment.merkle_root);

    // Draw OOD challenge points (deterministic from transcript state)
    let _oods_points: Vec<FieldElement> = merlin.verifier_message_vec(ic.out_domain_samples);

    // Absorb OOD evaluations
    for eval in &commitment.out_of_domain_evals {
        merlin.prover_message(eval);
    }
}
diff --git a/provekit/spark/src/sumcheck.rs b/provekit/spark/src/sumcheck.rs
new file mode 100644
index 000000000..293ea8309
--- /dev/null
+++ b/provekit/spark/src/sumcheck.rs
@@ -0,0 +1,127 @@
use {
    anyhow::{ensure, Result},
    ark_std::{One, Zero},
    provekit_common::{
        utils::{
            sumcheck::{eval_cubic_poly, sumcheck_fold_map_reduce},
            HALF,
        },
        FieldElement, TranscriptSponge,
    },
    tracing::instrument,
    whir::transcript::{ProverState, VerifierMessage, VerifierState},
};

/// Prover side of the degree-3 sumcheck over the product of three MLEs.
/// Returns the three final folded values and the per-round randomness.
#[instrument(skip_all)]
pub fn run_spark_sumcheck(
    merlin: &mut ProverState,
    mles: [&[FieldElement]; 3],
    mut claimed_value: FieldElement,
) -> Result<([FieldElement; 3], Vec<FieldElement>)> {
    let mut sumcheck_randomness;
    let mut sumcheck_randomness_accumulator = Vec::<FieldElement>::new();
    let mut fold = None;

    let mut m0 = mles[0].to_vec();
    let mut m1 = mles[1].to_vec();
    let mut m2 = mles[2].to_vec();

    loop {
        // Evaluate the round polynomial at 0, -1 and the leading coefficient;
        // the remaining coefficient is recovered from the sum constraint.
        let [hhat_i_at_0, hhat_i_at_em1, hhat_i_at_inf_over_x_cube] =
            sumcheck_fold_map_reduce([&mut m0, &mut m1, &mut m2], fold, |[m0, m1, m2]| {
                [
                    m0.0 * m1.0 * m2.0,
(m0.0 + m0.0 - m0.1) * (m1.0 + m1.0 - m1.1) * (m2.0 + m2.0 - m2.1),
                    (m0.1 - m0.0) * (m1.1 - m1.0) * (m2.1 - m2.0),
                ]
            });

        // The fold from the previous round was applied in-place above;
        // discard the stale upper half.
        if fold.is_some() {
            m0.truncate(m0.len() / 2);
            m1.truncate(m1.len() / 2);
            m2.truncate(m2.len() / 2);
        }

        let mut hhat_i_coeffs = [FieldElement::zero(); 4];

        // Recover cubic coefficients from h(0), h(-1), the x^3 coefficient,
        // and the constraint h(0) + h(1) == claimed_value.
        hhat_i_coeffs[0] = hhat_i_at_0;
        hhat_i_coeffs[2] =
            HALF * (claimed_value + hhat_i_at_em1 - hhat_i_at_0 - hhat_i_at_0 - hhat_i_at_0);
        hhat_i_coeffs[3] = hhat_i_at_inf_over_x_cube;
        hhat_i_coeffs[1] = claimed_value
            - hhat_i_coeffs[0]
            - hhat_i_coeffs[0]
            - hhat_i_coeffs[3]
            - hhat_i_coeffs[2];

        // Self-check: h(0) + h(1) must equal the running claim.
        ensure!(
            claimed_value
                == hhat_i_coeffs[0]
                    + hhat_i_coeffs[0]
                    + hhat_i_coeffs[1]
                    + hhat_i_coeffs[2]
                    + hhat_i_coeffs[3],
            "Sumcheck binding check failed"
        );

        for coeff in &hhat_i_coeffs {
            merlin.prover_message(coeff);
        }
        sumcheck_randomness = merlin.verifier_message();
        fold = Some(sumcheck_randomness);
        claimed_value = eval_cubic_poly(hhat_i_coeffs, sumcheck_randomness);
        sumcheck_randomness_accumulator.push(sumcheck_randomness);
        if m0.len() <= 2 {
            break;
        }
    }

    // Apply the final round's randomness (its fold was never run in-place).
    let folded_v0 = m0[0] + (m0[1] - m0[0]) * sumcheck_randomness;
    let folded_v1 = m1[0] + (m1[1] - m1[0]) * sumcheck_randomness;
    let folded_v2 = m2[0] + (m2[1] - m2[0]) * sumcheck_randomness;

    Ok((
        [folded_v0, folded_v1, folded_v2],
        sumcheck_randomness_accumulator,
    ))
}

/// Verifier side of the degree-3 sumcheck: checks h(0) + h(1) against the
/// running claim each round and returns (challenges, final claimed value).
#[instrument(skip_all)]
pub fn run_sumcheck_verifier_spark(
    arthur: &mut VerifierState<'_, TranscriptSponge>,
    variable_count: usize,
    initial_sumcheck_val: FieldElement,
) -> Result<(Vec<FieldElement>, FieldElement)> {
    let mut saved_val_for_sumcheck_equality_assertion = initial_sumcheck_val;

    let mut alpha = vec![FieldElement::zero(); variable_count];

    for i in 0..variable_count {
        let hhat_i: [FieldElement; 4] = [
            arthur
                .prover_message()
                .map_err(|e| anyhow::anyhow!("{e}"))?,
            arthur
                .prover_message()
                .map_err(|e| anyhow::anyhow!("{e}"))?,
            arthur
                .prover_message()
                .map_err(|e| anyhow::anyhow!("{e}"))?,
            arthur
                .prover_message()
                .map_err(|e| anyhow::anyhow!("{e}"))?,
        ];
        let alpha_i: FieldElement = arthur.verifier_message();
        alpha[i] = alpha_i;

        let hhat_i_at_zero = eval_cubic_poly(hhat_i, FieldElement::zero());
        let hhat_i_at_one = eval_cubic_poly(hhat_i, FieldElement::one());
        ensure!(
            saved_val_for_sumcheck_equality_assertion == hhat_i_at_zero + hhat_i_at_one,
            "Sumcheck equality check failed"
        );
        saved_val_for_sumcheck_equality_assertion = eval_cubic_poly(hhat_i, alpha_i);
    }

    Ok((alpha, saved_val_for_sumcheck_equality_assertion))
}
diff --git a/provekit/spark/src/types.rs b/provekit/spark/src/types.rs
new file mode 100644
index 000000000..5e019f19f
--- /dev/null
+++ b/provekit/spark/src/types.rs
@@ -0,0 +1,125 @@
#[cfg(debug_assertions)]
use whir::transcript::Interaction;
use {
    provekit_common::{
        file::{
            binary_format::{SPARK_PROOF_FORMAT, SPARK_PROOF_VERSION},
            Compression, FileFormat, MaybeHashAware,
        },
        utils::serde_hex,
        FieldElement, HashConfig, WhirConfig,
    },
    serde::{Deserialize, Serialize},
    whir::{hash::Hash, protocols::irs_commit},
};

// NOTE(review): a generic parameter on `irs_commit::Witness` may have been
// lost in transit — confirm against the whir crate.
pub type WhirWitness = irs_commit::Witness;

/// Serialized SPARK proof: the Fiat-Shamir transcript bytes plus the
/// parameters needed to reconstruct the verifier's scheme.
#[derive(Serialize, Deserialize)]
pub struct SPARKProof {
    #[serde(with = "serde_hex")]
    pub narg_string: Vec<u8>,
    #[serde(with = "serde_hex")]
    pub hints: Vec<u8>,
    // Interaction pattern, recorded only in debug builds for transcript
    // debugging.
    #[cfg(debug_assertions)]
    pub pattern: Vec<Interaction>,
    pub whir_params: SPARKWHIRConfigs,
    pub matrix_dimensions: MatrixDimensions,
}

impl FileFormat for SPARKProof {
    const FORMAT: [u8; 8] = SPARK_PROOF_FORMAT;
    const EXTENSION: &'static str = "sp";
    const VERSION: (u16, u16) = SPARK_PROOF_VERSION;
    const COMPRESSION: Compression = Compression::Zstd;
}

/// Impl for SPARKProof (no hash config).
+impl MaybeHashAware for SPARKProof { + fn maybe_hash_config(&self) -> Option { + None + } +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct MatrixDimensions { + pub num_rows: usize, + pub num_cols: usize, + pub nonzero_terms: usize, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct SPARKWHIRConfigs { + pub row: WhirConfig, + pub col: WhirConfig, + pub num_terms_1batched: WhirConfig, + pub num_terms_2batched: WhirConfig, + pub num_terms_4batched: WhirConfig, +} + +#[derive(Debug, Clone)] +pub struct SparkMatrix { + pub coo: COOMatrix, + pub timestamps: TimeStamps, +} + +#[derive(Debug, Clone)] +pub struct COOMatrix { + pub row: Vec, + pub col: Vec, + pub row_field: Vec, + pub col_field: Vec, + pub val: Vec, +} + +#[derive(Debug, Clone)] +pub struct TimeStamps { + pub read_row: Vec, + pub read_col: Vec, + pub final_row: Vec, + pub final_col: Vec, +} + +#[derive(Clone)] +pub struct SparkWitnesses { + pub vals_witness: WhirWitness, + pub rs_ws_witness: WhirWitness, + pub final_row_ts_witness: WhirWitness, + pub final_col_ts_witness: WhirWitness, +} + +#[derive(Clone)] +pub struct SerializableCommitment { + pub merkle_root: Hash, + pub out_of_domain_points: Vec, + pub out_of_domain_evals: Vec, +} + +#[derive(Clone)] +pub struct SparkCommitments { + pub vals: SerializableCommitment, + pub rs_ws: SerializableCommitment, + pub final_row_ts: SerializableCommitment, + pub final_col_ts: SerializableCommitment, +} + +/// All data needed for SPARK proving: the R1CS matrix, witnesses, and +/// commitments. 
+#[derive(Clone)] +pub struct SparkPreparedData { + pub matrix: SparkMatrix, + pub witnesses: SparkWitnesses, + pub commitments: SparkCommitments, +} + +#[derive(Debug, Clone)] +pub struct Memory { + pub eq_rx: Vec, + pub eq_ry: Vec, +} + +#[derive(Debug, Clone)] +pub struct EValuesForMatrix { + pub e_rx: Vec, + pub e_ry: Vec, +} diff --git a/provekit/spark/src/utils.rs b/provekit/spark/src/utils.rs new file mode 100644 index 000000000..ff122e2fa --- /dev/null +++ b/provekit/spark/src/utils.rs @@ -0,0 +1,28 @@ +pub use crate::types::Memory; +use provekit_common::{ + utils::sumcheck::calculate_evaluations_over_boolean_hypercube_for_eq, FieldElement, +}; + +#[tracing::instrument(skip_all)] +pub fn calculate_memory( + b: FieldElement, + point_row: &[FieldElement], + point_col: &[FieldElement], +) -> Memory { + let row_point: Vec<_> = std::iter::once(b) + .chain(point_row.iter().copied()) + .collect(); + let col_point: Vec<_> = std::iter::once(b) + .chain(point_col.iter().copied()) + .collect(); + Memory { + eq_rx: calculate_evaluations_over_boolean_hypercube_for_eq( + &row_point, + 1 << row_point.len(), + ), + eq_ry: calculate_evaluations_over_boolean_hypercube_for_eq( + &col_point, + 1 << col_point.len(), + ), + } +} diff --git a/provekit/spark/src/verifier.rs b/provekit/spark/src/verifier.rs new file mode 100644 index 000000000..694bf5422 --- /dev/null +++ b/provekit/spark/src/verifier.rs @@ -0,0 +1,244 @@ +use { + crate::{ + gpa::gpa_sumcheck_verifier4, + memory::verify_axis, + sumcheck::run_sumcheck_verifier_spark, + types::{MatrixDimensions, SPARKProof, SPARKWHIRConfigs}, + }, + anyhow::{ensure, Context, Result}, + ark_ff::Field, + provekit_common::{ + spark::R1CSSparkQuery, + utils::{next_power_of_two, sumcheck::calculate_eq}, + FieldElement, TranscriptSponge, + }, + tracing::instrument, + whir::{ + algebra::linear_form::MultilinearExtension, + transcript::{codecs::Empty, DomainSeparator, Proof, VerifierMessage, VerifierState}, + }, +}; + +pub trait 
SPARKVerifier {
    fn verify(&self, proof: SPARKProof, request: &R1CSSparkQuery) -> Result<()>;
}

/// Verifier-side scheme, reconstructable from the proof's embedded parameters.
pub struct SPARKScheme {
    pub whir_configs: SPARKWHIRConfigs,
    pub matrix_dimensions: MatrixDimensions,
}

impl SPARKScheme {
    /// Builds a scheme from the parameters serialized inside the proof.
    pub fn from_proof(proof: &SPARKProof) -> Self {
        Self {
            whir_configs: proof.whir_params.clone(),
            matrix_dimensions: proof.matrix_dimensions.clone(),
        }
    }
}

impl SPARKVerifier for SPARKScheme {
    #[instrument(skip_all)]
    fn verify(&self, proof: SPARKProof, request: &R1CSSparkQuery) -> Result<()> {
        let ds = DomainSeparator::protocol(&self.whir_configs).instance(&Empty);
        let whir_proof = Proof {
            narg_string: proof.narg_string,
            hints: proof.hints,
            #[cfg(debug_assertions)]
            pattern: proof.pattern,
        };
        let mut arthur = VerifierState::new(&ds, &whir_proof, TranscriptSponge::default());

        // Rescale the claim exactly as the prover does (once per axis).
        let claimed_value = (request.claimed_value
            / (FieldElement::ONE + request.matrix_batching_randomness))
            / (FieldElement::ONE + request.matrix_batching_randomness);

        // Prepend the batching variable to both query points, mirroring
        // `calculate_memory` on the prover side.
        let mut new_request = request.clone();
        let b1 = request.matrix_batching_randomness
            / (FieldElement::ONE + request.matrix_batching_randomness);
        new_request.point_to_evaluate.row = std::iter::once(b1)
            .chain(new_request.point_to_evaluate.row.clone())
            .collect();
        new_request.point_to_evaluate.col = std::iter::once(b1)
            .chain(new_request.point_to_evaluate.col.clone())
            .collect();

        verify_spark_single_matrix(
            &self.whir_configs,
            self.matrix_dimensions.clone(),
            &mut arthur,
            &new_request,
            &claimed_value,
        )
    }
}

/// Full verification flow for one batched matrix. The transcript order must
/// mirror `prove_spark`:
/// commitments -> e-values commitment -> sumcheck -> (tau, gamma) -> GPA ->
/// row axis -> col axis.
#[instrument(skip_all)]
pub(crate) fn verify_spark_single_matrix(
    whir_params: &SPARKWHIRConfigs,
    matrix_dimensions: MatrixDimensions,
    arthur: &mut VerifierState<'_, TranscriptSponge>,
    request: &R1CSSparkQuery,
    claimed_value: &FieldElement,
) -> Result<()> {
    let val_commitment = whir_params
        .num_terms_1batched
        .receive_commitment(arthur)
        .map_err(|e| anyhow::anyhow!("Failed to receive val commitment: {e}"))?;
    let rsws_commitment = whir_params
        .num_terms_4batched
        .receive_commitment(arthur)
        .map_err(|e| anyhow::anyhow!("Failed to receive rsws commitment: {e}"))?;
    let a_row_finalts_commitment = whir_params
        .row
        .receive_commitment(arthur)
        .map_err(|e| anyhow::anyhow!("Failed to receive row finalts commitment: {e}"))?;
    let a_col_finalts_commitment = whir_params
        .col
        .receive_commitment(arthur)
        .map_err(|e| anyhow::anyhow!("Failed to receive col finalts commitment: {e}"))?;
    let e_values_commitment = whir_params
        .num_terms_2batched
        .receive_commitment(arthur)
        .map_err(|e| anyhow::anyhow!("Failed to receive e_values commitment: {e}"))?;

    let (randomness, last_sumcheck_value) = run_sumcheck_verifier_spark(
        arthur,
        next_power_of_two(matrix_dimensions.nonzero_terms),
        *claimed_value,
    )
    .context("While verifying SPARK sumcheck")?;
    let eval_weight = MultilinearExtension::new(randomness);

    // Hints: final folds of (val, e_rx, e_ry), in that order.
    let sumcheck_hints: [FieldElement; 3] = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;

    ensure!(last_sumcheck_value == sumcheck_hints[0] * sumcheck_hints[1] * sumcheck_hints[2]);

    let e_values_claim = whir_params
        .num_terms_2batched
        .verify(arthur, &[&e_values_commitment], &[
            sumcheck_hints[1],
            sumcheck_hints[2],
        ])
        .map_err(|e| anyhow::anyhow!("WHIR verify failed for e_values (sumcheck): {e}"))?;
    e_values_claim
        .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm])
        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for e_values: {e}"))?;

    let val_claim = whir_params
        .num_terms_1batched
        .verify(arthur, &[&val_commitment], &[sumcheck_hints[0]])
        .map_err(|e| anyhow::anyhow!("WHIR verify failed for val: {e}"))?;
    val_claim
        .verify([&eval_weight as &dyn whir::algebra::linear_form::LinearForm])
        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for val: {e}"))?;

    // Memory-checking challenges: tau then gamma, matching the prover.
    let tau: FieldElement = arthur.verifier_message();
    let gamma: FieldElement = arthur.verifier_message();

    let gpa_result = gpa_sumcheck_verifier4(
        arthur,
        provekit_common::utils::next_power_of_two(matrix_dimensions.nonzero_terms) + 3,
    )?;

    // First two variables select among [row_rs | row_ws | col_rs | col_ws].
    let (combination_randomness, evaluation_randomness) = gpa_result.randomness.split_at(2);

    let claimed_row_rs = gpa_result.claimed_values[0];
    let claimed_row_ws = gpa_result.claimed_values[1];
    let claimed_col_rs = gpa_result.claimed_values[2];
    let claimed_col_ws = gpa_result.claimed_values[3];

    // Hinted evaluations at the GPA point (same order as the prover emits).
    let row_adr: FieldElement = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;
    let row_timestamp: FieldElement = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;
    let col_adr: FieldElement = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;
    let col_timestamp: FieldElement = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;

    let gpa_eval_weight = MultilinearExtension::new(evaluation_randomness.to_vec());
    let gpa_eval_lf: &dyn whir::algebra::linear_form::LinearForm = &gpa_eval_weight;

    let rsws_claim = whir_params
        .num_terms_4batched
        .verify(arthur, &[&rsws_commitment], &[
            row_adr,
            row_timestamp,
            col_adr,
            col_timestamp,
        ])
        .map_err(|e| anyhow::anyhow!("WHIR verify failed for rsws: {e}"))?;
    rsws_claim
        .verify([gpa_eval_lf])
        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for rsws: {e}"))?;

    let row_mem: FieldElement = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;
    let col_mem: FieldElement = arthur
        .prover_hint_ark()
        .map_err(|e| anyhow::anyhow!("{e}"))?;

    let e_values_gpa_claim = whir_params
        .num_terms_2batched
        .verify(arthur, &[&e_values_commitment], &[row_mem, col_mem])
        .map_err(|e| anyhow::anyhow!("WHIR verify failed for e_values (GPA): {e}"))?;
    e_values_gpa_claim
        .verify([gpa_eval_lf])
        .map_err(|e| anyhow::anyhow!("FinalClaim check failed for e_values (GPA): {e}"))?;

    let gamma_sq = gamma * gamma;

    // Reconstruct the four fingerprint openings; write-set timestamps are
    // read-set timestamps plus one.
    let row_rs_opening = row_adr * gamma_sq + row_mem * gamma + row_timestamp - tau;
    let row_ws_opening =
        row_adr * gamma_sq + row_mem * gamma + row_timestamp + FieldElement::from(1) - tau;
    let col_rs_opening = col_adr * gamma_sq + col_mem * gamma + col_timestamp - tau;
    let col_ws_opening =
        col_adr * gamma_sq + col_mem * gamma + col_timestamp + FieldElement::from(1) - tau;

    // Combine with the two selector variables to match the flattened leaf
    // layout [row_rs | row_ws | col_rs | col_ws].
    let evaluated_value = row_rs_opening
        * (FieldElement::from(1) - combination_randomness[0])
        * (FieldElement::from(1) - combination_randomness[1])
        + row_ws_opening
            * (FieldElement::from(1) - combination_randomness[0])
            * combination_randomness[1]
        + col_rs_opening
            * combination_randomness[0]
            * (FieldElement::from(1) - combination_randomness[1])
        + col_ws_opening * combination_randomness[0] * combination_randomness[1];

    ensure!(evaluated_value == gpa_result.last_sumcheck_value);

    verify_axis(
        arthur,
        matrix_dimensions.num_rows,
        &whir_params.row,
        a_row_finalts_commitment,
        |eval_rand| calculate_eq(&request.point_to_evaluate.row, eval_rand),
        &tau,
        &gamma,
        &claimed_row_rs,
        &claimed_row_ws,
    )?;

    verify_axis(
        arthur,
        matrix_dimensions.num_cols,
        &whir_params.col,
        a_col_finalts_commitment,
        |eval_rand| calculate_eq(&request.point_to_evaluate.col, eval_rand),
        &tau,
        &gamma,
        &claimed_col_rs,
        &claimed_col_ws,
    )?;

    Ok(())
}
diff --git a/tooling/cli/Cargo.toml b/tooling/cli/Cargo.toml
index 59369b84d..ea8c67d48 100644
--- a/tooling/cli/Cargo.toml
+++ b/tooling/cli/Cargo.toml
@@ -10,10 +10,12 @@ repository.workspace = true
 [dependencies]
 # Workspace crates
+mavros-artifacts.workspace = true
 provekit-common.workspace = true
 provekit-gnark.workspace = true
 provekit-prover = { workspace = true, features = ["witness-generation", "parallel"] }
 provekit-r1cs-compiler.workspace = true
+provekit-spark.workspace = true
 provekit-verifier.workspace = true

 # Noir language
@@ -25,11 +27,14 @@ ark-ff.workspace = true
 # 3rd party
 anyhow.workspace = true
+bincode.workspace =
true argh.workspace = true base64.workspace = true hex.workspace = true postcard.workspace = true rayon.workspace = true +serde.workspace = true +whir.workspace = true serde_json.workspace = true tikv-jemallocator = { workspace = true, optional = true } tracing.workspace = true diff --git a/tooling/cli/src/cmd/mod.rs b/tooling/cli/src/cmd/mod.rs index faf7297cb..f8592f9b5 100644 --- a/tooling/cli/src/cmd/mod.rs +++ b/tooling/cli/src/cmd/mod.rs @@ -1,9 +1,11 @@ mod analyze_pkp; mod circuit_stats; mod generate_gnark_inputs; -mod prepare; +pub mod prepare; mod prove; +mod serve; mod show_inputs; +mod spark_protocol; mod verify; use {anyhow::Result, argh::FromArgs}; @@ -41,6 +43,7 @@ enum Commands { AnalyzePkp(analyze_pkp::Args), Prepare(prepare::Args), Prove(prove::Args), + Serve(serve::Args), CircuitStats(circuit_stats::Args), Verify(verify::Args), GenerateGnarkInputs(generate_gnark_inputs::Args), @@ -59,6 +62,7 @@ impl Command for Commands { Self::AnalyzePkp(args) => args.run(), Self::Prepare(args) => args.run(), Self::Prove(args) => args.run(), + Self::Serve(args) => args.run(), Self::CircuitStats(args) => args.run(), Self::Verify(args) => args.run(), Self::GenerateGnarkInputs(args) => args.run(), diff --git a/tooling/cli/src/cmd/prepare.rs b/tooling/cli/src/cmd/prepare.rs index 814371eb3..4554164dc 100644 --- a/tooling/cli/src/cmd/prepare.rs +++ b/tooling/cli/src/cmd/prepare.rs @@ -2,14 +2,30 @@ use { super::Command, anyhow::{Context, Result}, argh::FromArgs, - provekit_common::{file::write, HashConfig, Prover, Verifier}, + mavros_artifacts::R1CS as MavrosR1CS, + provekit_common::{ + file::write, utils::next_power_of_two, FieldElement, HashConfig, Prover, TranscriptSponge, + Verifier, WhirConfig, R1CS, + }, provekit_r1cs_compiler::{MavrosCompiler, NoirCompiler}, - std::{path::PathBuf, str::FromStr}, + provekit_spark::{ + prover::new_whir_config_for_size, + types::{COOMatrix, SerializableCommitment, SparkMatrix, TimeStamps}, + SPARKWHIRConfigs, SparkCommitments, 
SparkWitnesses, + }, + std::{ + path::{Path, PathBuf}, + str::FromStr, + }, tracing::instrument, + whir::{ + hash::Hash, + transcript::{ProverState, VerifierMessage, VerifierState}, + }, }; #[derive(PartialEq, Eq, Debug)] -enum Compiler { +pub enum Compiler { Noir, Mavros, } @@ -92,3 +108,282 @@ impl Command for Args { Ok(()) } } + +pub fn build_spark_r1cs_noir( + r1cs: &R1CS, + log_row: usize, + log_col: usize, + w1_size: usize, + num_challenges: usize, +) -> Result { + let is_single_commitment = num_challenges == 0; + + let original_num_entries = + r1cs.a().iter().count() + r1cs.b().iter().count() + r1cs.c().iter().count(); + + let padded_num_entries = 1 << next_power_of_two(original_num_entries); + let to_fill = padded_num_entries - original_num_entries; + + let row_cnt = 1 << log_row; + let col_cnt = if is_single_commitment { + 1 << log_col + } else { + 1 << (1 + log_col) + }; + + let col_witness_split_offset = |c: usize| -> usize { + if !is_single_commitment && (c >= w1_size) { + (1 << log_col) - w1_size + } else { + 0 + } + }; + + let (mut row, mut col, mut val) = ( + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + ); + + for (matrix, row_offset, col_offset) in [ + (r1cs.a(), 0, 0), + (r1cs.b(), 0, col_cnt), + (r1cs.c(), row_cnt, col_cnt), + ] { + for ((r, c), v) in matrix.iter() { + row.push(r + row_offset); + col.push(c + col_offset + col_witness_split_offset(c)); + val.push(v); + } + } + for _ in 0..to_fill { + row.push(0); + col.push(0); + val.push(FieldElement::from(0u64)); + } + + Ok(build_spark_matrix(row, col, val, 2 * row_cnt, 2 * col_cnt)) +} + +pub fn build_spark_r1cs_mavros( + r1cs_path: &Path, + log_row: usize, + log_col: usize, + w1_size: usize, + num_challenges: usize, +) -> Result { + let is_single_commitment = num_challenges == 0; + + let r1cs_bytes = std::fs::read(r1cs_path).context("while reading R1CS file")?; + let r1cs: MavrosR1CS = + 
bincode::deserialize(&r1cs_bytes).context("while deserializing R1CS from bincode")?; + + let row_cnt = 1 << log_row; + let col_cnt = if is_single_commitment { + 1 << log_col + } else { + 1 << (1 + log_col) + }; + + let col_witness_split_offset = |c: usize| -> usize { + if !is_single_commitment && (c >= w1_size) { + (1 << log_col) - w1_size + } else { + 0 + } + }; + + let original_num_entries: usize = r1cs + .constraints + .iter() + .map(|r1c| r1c.a.len() + r1c.b.len() + r1c.c.len()) + .sum(); + + let padded_num_entries = 1 << next_power_of_two(original_num_entries); + let to_fill = padded_num_entries - original_num_entries; + + let (mut row, mut col, mut val) = ( + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + Vec::with_capacity(padded_num_entries), + ); + + for (i, r1c) in r1cs.constraints.iter().enumerate() { + for &(c, v) in &r1c.a { + row.push(i); + col.push(c + col_witness_split_offset(c)); + val.push(v); + } + for &(c, v) in &r1c.b { + row.push(i); + col.push(c + col_cnt + col_witness_split_offset(c)); + val.push(v); + } + for &(c, v) in &r1c.c { + row.push(i + row_cnt); + col.push(c + col_cnt + col_witness_split_offset(c)); + val.push(v); + } + } + + for _ in 0..to_fill { + row.push(0); + col.push(0); + val.push(FieldElement::from(0u64)); + } + + Ok(build_spark_matrix(row, col, val, 2 * row_cnt, 2 * col_cnt)) +} + +pub fn build_spark_matrix( + row: Vec, + col: Vec, + val: Vec, + num_rows: usize, + num_cols: usize, +) -> SparkMatrix { + let len = row.len(); + let mut read_row_counters = vec![0usize; num_rows]; + let mut read_col_counters = vec![0usize; num_cols]; + let mut read_row = Vec::with_capacity(len); + let mut read_col = Vec::with_capacity(len); + + for i in 0..len { + read_row.push(FieldElement::from(read_row_counters[row[i]] as u64)); + read_row_counters[row[i]] += 1; + read_col.push(FieldElement::from(read_col_counters[col[i]] as u64)); + read_col_counters[col[i]] += 1; + } + + let final_row = read_row_counters 
+ .iter() + .map(|&x| FieldElement::from(x as u64)) + .collect(); + let final_col = read_col_counters + .iter() + .map(|&x| FieldElement::from(x as u64)) + .collect(); + + let row_field = row.iter().map(|&r| FieldElement::from(r as u64)).collect(); + let col_field = col.iter().map(|&c| FieldElement::from(c as u64)).collect(); + + SparkMatrix { + coo: COOMatrix { + row, + col, + row_field, + col_field, + val, + }, + timestamps: TimeStamps { + read_row, + read_col, + final_row, + final_col, + }, + } +} + +pub struct SPARKCommitterScheme { + pub whir_configs: SPARKWHIRConfigs, +} + +impl SPARKCommitterScheme { + pub fn new(num_rows: usize, num_cols: usize, nonzero_terms: usize) -> Self { + let padded_num_entries = 1 << next_power_of_two(nonzero_terms); + + let row_config = new_whir_config_for_size(next_power_of_two(num_rows), 1); + let col_config = new_whir_config_for_size(next_power_of_two(num_cols), 1); + let num_terms_1batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 1); + let num_terms_2batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 2); + let num_terms_4batched_config = + new_whir_config_for_size(next_power_of_two(padded_num_entries), 4); + + Self { + whir_configs: SPARKWHIRConfigs { + row: row_config, + col: col_config, + num_terms_1batched: num_terms_1batched_config, + num_terms_2batched: num_terms_2batched_config, + num_terms_4batched: num_terms_4batched_config, + }, + } + } + + pub fn commit( + &self, + merlin: &mut ProverState, + matrix: &SparkMatrix, + ) -> SparkWitnesses { + let vals_witness = self + .whir_configs + .num_terms_1batched + .commit(merlin, &[&matrix.coo.val]); + + let rs_ws_witness = self.whir_configs.num_terms_4batched.commit(merlin, &[ + &matrix.coo.row_field, + &matrix.timestamps.read_row, + &matrix.coo.col_field, + &matrix.timestamps.read_col, + ]); + + let final_row_ts_witness = self + .whir_configs + .row + .commit(merlin, &[&matrix.timestamps.final_row]); + + let 
final_col_ts_witness = self + .whir_configs + .col + .commit(merlin, &[&matrix.timestamps.final_col]); + + SparkWitnesses { + vals_witness, + rs_ws_witness, + final_row_ts_witness, + final_col_ts_witness, + } + } +} + +pub fn extract_single_commitment( + arthur: &mut VerifierState<'_, TranscriptSponge>, + config: &WhirConfig, +) -> Result { + let ic = &config.initial_committer; + let merkle_root: Hash = arthur + .prover_message() + .map_err(|e| anyhow::anyhow!("Failed to read merkle root: {e}"))?; + let out_of_domain_points: Vec = + arthur.verifier_message_vec(ic.out_domain_samples); + let out_of_domain_evals: Vec = arthur + .prover_messages_vec(ic.out_domain_samples * ic.num_vectors) + .map_err(|e| anyhow::anyhow!("Failed to read OOD evaluations: {e}"))?; + Ok(SerializableCommitment { + merkle_root, + out_of_domain_points, + out_of_domain_evals, + }) +} + +pub fn extract_commitments( + arthur: &mut VerifierState<'_, TranscriptSponge>, + configs: &SPARKWHIRConfigs, +) -> Result { + let vals = extract_single_commitment(arthur, &configs.num_terms_1batched) + .context("while extracting vals commitment")?; + let rs_ws = extract_single_commitment(arthur, &configs.num_terms_4batched) + .context("while extracting rs_ws commitment")?; + let final_row_ts = extract_single_commitment(arthur, &configs.row) + .context("while extracting final_row_ts commitment")?; + let final_col_ts = extract_single_commitment(arthur, &configs.col) + .context("while extracting final_col_ts commitment")?; + Ok(SparkCommitments { + vals, + rs_ws, + final_row_ts, + final_col_ts, + }) +} diff --git a/tooling/cli/src/cmd/prove.rs b/tooling/cli/src/cmd/prove.rs index d494f82c8..dec3cd5b0 100644 --- a/tooling/cli/src/cmd/prove.rs +++ b/tooling/cli/src/cmd/prove.rs @@ -1,13 +1,16 @@ use { - super::Command, - anyhow::{Context, Result}, + super::{ + spark_protocol::{self, SparkRequest, SparkResponse}, + Command, + }, + anyhow::{bail, Context, Result}, argh::FromArgs, provekit_common::{ file::{read, 
write}, Prover, }, provekit_prover::Prove, - std::path::PathBuf, + std::{os::unix::net::UnixStream, path::PathBuf}, tracing::{info, instrument}, }; #[cfg(test)] @@ -38,6 +41,22 @@ pub struct Args { default = "PathBuf::from(\"./proof.np\")" )] proof_path: PathBuf, + + /// unix socket path of a running serve instance (enables SPARK proving) + #[argh(option)] + socket: Option, + + /// circuit name on the server (required with --socket) + #[argh(option)] + circuit: Option, + + /// output path for SPARK proof (default: spark_proof.sp) + #[argh( + option, + long = "spark-out", + default = "PathBuf::from(\"spark_proof.sp\")" + )] + spark_proof_path: PathBuf, } impl Command for Args { @@ -66,6 +85,36 @@ impl Command for Args { .context("While verifying Noir proof")?; } + // If a socket is provided, send the proof to the SPARK server + if let Some(socket) = &self.socket { + let circuit = self + .circuit + .as_ref() + .context("--circuit is required when --socket is provided")?; + + info!("Connecting to SPARK server at {socket:?}"); + let mut stream = + UnixStream::connect(socket).with_context(|| format!("connecting to {socket:?}"))?; + + let request = SparkRequest { + circuit: circuit.clone(), + noir_proof: self.proof_path.clone(), + output: self.spark_proof_path.clone(), + }; + + spark_protocol::write_message(&mut stream, &request)?; + let response: SparkResponse = spark_protocol::read_message(&mut stream)?; + + if response.ok { + info!("SPARK proof written to {:?}", self.spark_proof_path); + } else { + bail!( + "SPARK server error: {}", + response.error.unwrap_or_else(|| "unknown".to_string()) + ); + } + } + Ok(()) } } diff --git a/tooling/cli/src/cmd/serve.rs b/tooling/cli/src/cmd/serve.rs new file mode 100644 index 000000000..7863013ad --- /dev/null +++ b/tooling/cli/src/cmd/serve.rs @@ -0,0 +1,219 @@ +use { + super::{ + prepare::{self, Compiler, SPARKCommitterScheme}, + spark_protocol::{self, SparkRequest, SparkResponse}, + Command, + }, + anyhow::{Context, Result}, + 
argh::FromArgs, + provekit_common::{ + file::{read, write}, + HashConfig, NoirProofScheme, Prover, TranscriptSponge, Verifier, + }, + provekit_r1cs_compiler::{MavrosCompiler, NoirCompiler}, + provekit_spark::{SPARKProver as _, SPARKProverScheme, SparkPreparedData}, + std::{ + collections::HashMap, + os::unix::net::UnixListener, + path::{Path, PathBuf}, + str::FromStr, + }, + tracing::{info, instrument}, + whir::transcript::{codecs::Empty, DomainSeparator, ProverState, VerifierState}, +}; + +/// Prepare circuits and serve SPARK proofs on a Unix socket +#[derive(FromArgs, PartialEq, Debug)] +#[argh(subcommand, name = "serve")] +pub struct Args { + /// unix socket path to listen on + #[argh(option)] + socket: PathBuf, + + /// circuit to prepare, format: name:path/to/program.json (repeatable) + #[argh(option)] + circuit: Vec, + + /// compiler backend: "noir" (default) or "mavros" + #[argh(option, long = "compiler", default = "Compiler::Noir")] + compiler: Compiler, + + /// path to R1CS file (required for mavros compiler) + #[argh(option, long = "r1cs")] + r1cs_path: Option, + + /// hash algorithm for Merkle commitments (skyscraper, sha256, keccak, + /// blake3) + #[argh(option, long = "hash", default = "String::from(\"skyscraper\")")] + hash: String, + + /// output directory for .pkp and .pkv files (default: current dir) + #[argh(option, long = "output-dir", default = "PathBuf::from(\".\")")] + output_dir: PathBuf, +} + +impl Command for Args { + #[instrument(skip_all)] + fn run(&self) -> Result<()> { + let hash_config = HashConfig::from_str(&self.hash).map_err(|e| anyhow::anyhow!("{}", e))?; + + provekit_common::register_ntt(); + + let mut circuits: HashMap = HashMap::new(); + + for spec in &self.circuit { + let (name, path) = spec + .split_once(':') + .with_context(|| format!("invalid circuit spec '{spec}', expected name:path"))?; + + info!("Preparing circuit '{name}' from {path:?}"); + let (spark_data, scheme) = prepare_circuit( + Path::new(path), + &self.compiler, 
+ self.r1cs_path.as_deref(), + hash_config, + )?; + + // Write .pkp and .pkv so provekit-prover can load them + let pkp_path = self.output_dir.join(format!("{name}.pkp")); + let pkv_path = self.output_dir.join(format!("{name}.pkv")); + + let prover = Prover::from_noir_proof_scheme(scheme.clone()); + let verifier = Verifier::from_noir_proof_scheme(scheme); + write(&prover, &pkp_path).with_context(|| format!("writing prover for '{name}'"))?; + write(&verifier, &pkv_path) + .with_context(|| format!("writing verifier for '{name}'"))?; + info!("Wrote {pkp_path:?} and {pkv_path:?}"); + + circuits.insert(name.to_string(), spark_data); + info!("Circuit '{name}' ready"); + } + + // Clean up stale socket file + let _ = std::fs::remove_file(&self.socket); + + let listener = UnixListener::bind(&self.socket) + .with_context(|| format!("binding Unix socket at {:?}", self.socket))?; + + info!( + "Server ready on {:?} with {} circuit(s)", + self.socket, + circuits.len() + ); + + for stream in listener.incoming() { + let mut stream = stream.context("accepting connection")?; + + let request: SparkRequest = spark_protocol::read_message(&mut stream)?; + let response = match handle_prove(&circuits, &request) { + Ok(()) => SparkResponse { + ok: true, + error: None, + }, + Err(e) => SparkResponse { + ok: false, + error: Some(format!("{e:?}")), + }, + }; + spark_protocol::write_message(&mut stream, &response)?; + } + + Ok(()) + } +} + +#[instrument(skip_all)] +fn prepare_circuit( + program_path: &Path, + compiler: &Compiler, + r1cs_path: Option<&Path>, + hash_config: HashConfig, +) -> Result<(SparkPreparedData, NoirProofScheme)> { + let scheme = match compiler { + Compiler::Noir => NoirCompiler::from_file(program_path, hash_config) + .context("while compiling Noir program")?, + Compiler::Mavros => { + let r1cs_path = r1cs_path.context("--r1cs is required for mavros compiler")?; + MavrosCompiler::compile(program_path, r1cs_path, hash_config) + .context("while compiling with Mavros")? 
+ } + }; + + let whir_r1cs_scheme = match &scheme { + NoirProofScheme::Noir(s) => s.whir_for_witness.clone(), + NoirProofScheme::Mavros(s) => s.whir_for_witness.clone(), + }; + + let spark_r1cs = match &scheme { + NoirProofScheme::Noir(noir) => prepare::build_spark_r1cs_noir( + &noir.r1cs, + whir_r1cs_scheme.m_0, + whir_r1cs_scheme.m, + whir_r1cs_scheme.w1_size, + whir_r1cs_scheme.num_challenges, + )?, + NoirProofScheme::Mavros(_) => { + let r1cs_path = r1cs_path.context("--r1cs is required for mavros compiler")?; + prepare::build_spark_r1cs_mavros( + r1cs_path, + whir_r1cs_scheme.m_0, + whir_r1cs_scheme.m, + whir_r1cs_scheme.w1_size, + whir_r1cs_scheme.num_challenges, + )? + } + }; + + let num_rows = spark_r1cs.timestamps.final_row.len(); + let num_cols = spark_r1cs.timestamps.final_col.len(); + let num_nz_vals = spark_r1cs.coo.val.len(); + + let spark_committer_scheme = SPARKCommitterScheme::new(num_rows, num_cols, num_nz_vals); + let ds = DomainSeparator::protocol(&spark_committer_scheme.whir_configs).instance(&Empty); + let mut merlin = ProverState::new(&ds, TranscriptSponge::default()); + let witnesses = spark_committer_scheme.commit(&mut merlin, &spark_r1cs); + + let proof = merlin.proof(); + let mut arthur = VerifierState::new(&ds, &proof, TranscriptSponge::default()); + let commitments = + prepare::extract_commitments(&mut arthur, &spark_committer_scheme.whir_configs)?; + + let spark_data = SparkPreparedData { + matrix: spark_r1cs, + witnesses, + commitments, + }; + + Ok((spark_data, scheme)) +} + +#[instrument(skip_all, fields(circuit = %request.circuit))] +fn handle_prove( + circuits: &HashMap, + request: &SparkRequest, +) -> Result<()> { + let spark_data = circuits + .get(&request.circuit) + .with_context(|| format!("unknown circuit '{}'", request.circuit))?; + + info!("Loading NoirProof from {:?}", request.noir_proof); + let noir_proof: provekit_common::NoirProof = + read(&request.noir_proof).context("reading NoirProof")?; + let spark_query = 
noir_proof.r1cs_spark_query; + + let num_constraints = spark_data.matrix.timestamps.final_row.len(); + let num_witnesses = spark_data.matrix.timestamps.final_col.len(); + let num_nonzero = spark_data.matrix.coo.val.len(); + + info!("Proving ({num_constraints} constraints, {num_witnesses} witnesses)"); + let scheme = SPARKProverScheme::new(num_constraints, num_witnesses, num_nonzero); + let proof = scheme + .prove(spark_data, &spark_query) + .context("generating SPARK proof")?; + + info!("Writing proof to {:?}", request.output); + write(&proof, &request.output).context("writing spark proof")?; + + info!("Done"); + Ok(()) +} diff --git a/tooling/cli/src/cmd/spark_protocol.rs b/tooling/cli/src/cmd/spark_protocol.rs new file mode 100644 index 000000000..7dd4bbea7 --- /dev/null +++ b/tooling/cli/src/cmd/spark_protocol.rs @@ -0,0 +1,47 @@ +use { + anyhow::{Context, Result}, + serde::{de::DeserializeOwned, Deserialize, Serialize}, + std::{ + io::{Read, Write}, + path::PathBuf, + }, +}; + +#[derive(Serialize, Deserialize)] +pub struct SparkRequest { + pub circuit: String, + pub noir_proof: PathBuf, + pub output: PathBuf, +} + +#[derive(Serialize, Deserialize)] +pub struct SparkResponse { + pub ok: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +pub fn write_message(stream: &mut impl Write, msg: &impl Serialize) -> Result<()> { + let bytes = serde_json::to_vec(msg).context("serializing message")?; + stream + .write_all(&(bytes.len() as u32).to_le_bytes()) + .context("writing message length")?; + stream.write_all(&bytes).context("writing message body")?; + stream.flush().context("flushing stream")?; + Ok(()) +} + +pub fn read_message(stream: &mut impl Read) -> Result { + let mut len_buf = [0u8; 4]; + stream + .read_exact(&mut len_buf) + .context("reading message length")?; + let len = u32::from_le_bytes(len_buf) as usize; + + let mut buf = vec![0u8; len]; + stream + .read_exact(&mut buf) + .context("reading message body")?; + + 
serde_json::from_slice(&buf).context("parsing message JSON") +} diff --git a/tooling/cli/src/cmd/verify.rs b/tooling/cli/src/cmd/verify.rs index 213cc53df..a24b57f8f 100644 --- a/tooling/cli/src/cmd/verify.rs +++ b/tooling/cli/src/cmd/verify.rs @@ -3,9 +3,10 @@ use { anyhow::{Context, Result}, argh::FromArgs, provekit_common::{file::read, NoirProof, Verifier}, + provekit_spark::{SPARKProof, SPARKVerifier, SPARKVerifierScheme}, provekit_verifier::Verify, std::path::PathBuf, - tracing::instrument, + tracing::{info, instrument}, }; /// Verify a Noir proof @@ -19,6 +20,10 @@ pub struct Args { /// path to the proof file #[argh(positional)] proof_path: PathBuf, + + /// path to the SPARK proof file (optional) + #[argh(option, long = "spark-proof")] + spark_proof_path: Option, } impl Command for Args { @@ -37,6 +42,19 @@ impl Command for Args { .verify(&proof) .context("While verifying Noir proof")?; + // Verify the SPARK proof if provided + if let Some(spark_proof_path) = &self.spark_proof_path { + info!("Verifying SPARK proof from {spark_proof_path:?}"); + let spark_proof: SPARKProof = + read(spark_proof_path).context("while reading SPARK proof")?; + let spark_statement = proof.r1cs_spark_query; + let scheme = SPARKVerifierScheme::from_proof(&spark_proof); + scheme + .verify(spark_proof, &spark_statement) + .context("While verifying SPARK proof")?; + info!("SPARK proof verified successfully"); + } + Ok(()) } }