Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
sumcheck.test.cpp
Go to the documentation of this file.
1#include "sumcheck.hpp"
4
7#include <gtest/gtest.h>
8
9using namespace bb;
10
11namespace {
12
/**
 * @brief Build a tiny trace that satisfies the arithmetic relation of the given test Flavor.
 * @details Row 1 carries an addition gate (1 + 1 = 2) and row 2 a multiplication gate (2 * 2 = 4); all other rows
 * are zero. For ZK flavors the witness values of the last few rows are replaced by random field elements, on the
 * assumption that those rows are masked by the row-disabling polynomial and so need not satisfy the relation.
 * Shifted entities are derived at the end via set_shifted().
 * @param circuit_size Number of rows allocated for every polynomial of the trace.
 * @return The populated prover polynomials.
 */
template <typename Flavor> typename Flavor::ProverPolynomials create_satisfiable_trace(size_t circuit_size)
{
    using FF = typename Flavor::FF;
    using Polynomial = typename Flavor::Polynomial;
    using ProverPolynomials = typename Flavor::ProverPolynomials;

    ProverPolynomials trace;

    // Selectors: plain polynomials of the full size.
    for (auto& selector : trace.get_precomputed()) {
        selector = Polynomial(circuit_size);
    }
    // Witnesses: shiftable (start_index = 1) so that their shifts can be taken below.
    for (auto& wire : trace.get_witness()) {
        wire = Polynomial::shiftable(circuit_size);
    }
    // Shifted entities: placeholders, overwritten by set_shifted().
    for (auto& shifted_wire : trace.get_shifted()) {
        shifted_wire = Polynomial(circuit_size);
    }

    // Addition gate: q_l * w_l + q_r * w_r + q_o * w_o = 1 + 1 - 2 = 0
    constexpr size_t ADDITION_ROW = 1;
    if (ADDITION_ROW < circuit_size) {
        trace.w_l.at(ADDITION_ROW) = FF(1);
        trace.w_r.at(ADDITION_ROW) = FF(1);
        trace.w_o.at(ADDITION_ROW) = FF(2);
        trace.q_l.at(ADDITION_ROW) = FF(1);
        trace.q_r.at(ADDITION_ROW) = FF(1);
        trace.q_o.at(ADDITION_ROW) = FF(-1);
        trace.q_arith.at(ADDITION_ROW) = FF(1);
    }

    // Multiplication gate: q_m * w_l * w_r + q_o * w_o = 2 * 2 - 4 = 0
    constexpr size_t MULTIPLICATION_ROW = 2;
    if (MULTIPLICATION_ROW < circuit_size) {
        trace.w_l.at(MULTIPLICATION_ROW) = FF(2);
        trace.w_r.at(MULTIPLICATION_ROW) = FF(2);
        trace.w_o.at(MULTIPLICATION_ROW) = FF(4);
        trace.q_m.at(MULTIPLICATION_ROW) = FF(1);
        trace.q_o.at(MULTIPLICATION_ROW) = FF(-1);
        trace.q_arith.at(MULTIPLICATION_ROW) = FF(1);
    }

    if constexpr (Flavor::HasZK) {
        // Number of trailing rows whose witnesses are randomized; intended to match the rows disabled in ZK
        // sumcheck. NOTE(review): the failure test's comment speaks of 4 disabled rows - confirm which is right.
        constexpr size_t NUM_DISABLED_ROWS = 3;
        if (NUM_DISABLED_ROWS < circuit_size) {
            for (size_t row = circuit_size - NUM_DISABLED_ROWS; row != circuit_size; ++row) {
                for (auto* wire : { &trace.w_l, &trace.w_r, &trace.w_o, &trace.w_4, &trace.w_test_1, &trace.w_test_2 }) {
                    wire->at(row) = FF::random_element();
                }
            }
        }
    }

    trace.set_shifted();
    return trace;
}
92
93template <typename Flavor> class SumcheckTests : public ::testing::Test {
94 public:
95 using FF = typename Flavor::FF;
97 using ZKData = ZKSumcheckData<Flavor>;
98
99 const size_t NUM_POLYNOMIALS = Flavor::NUM_ALL_ENTITIES;
100 static void SetUpTestSuite() { bb::srs::init_file_crs_factory(bb::srs::bb_crs_path()); }
101
102 Polynomial<FF> random_poly(size_t size)
103 {
104 auto poly = bb::Polynomial<FF>(size);
105 for (auto& coeff : poly.coeffs()) {
106 coeff = FF::random_element();
107 }
108 return poly;
109 }
110
111 ProverPolynomials construct_ultra_full_polynomials(auto& input_polynomials)
112 {
113 ProverPolynomials full_polynomials;
114 for (auto [full_poly, input_poly] : zip_view(full_polynomials.get_all(), input_polynomials)) {
115 full_poly = input_poly.share();
116 }
117 return full_polynomials;
118 }
119
120 void test_polynomial_normalization()
121 {
122 // TODO(#225)(Cody): We should not use real constants like this in the tests, at least not in so many of them.
123 const size_t multivariate_d(3);
124 const size_t multivariate_n(1 << multivariate_d);
125
126 // Randomly construct the prover polynomials that are input to Sumcheck.
127 // Note: ProverPolynomials are defined as spans so the polynomials they point to need to exist in memory.
128 std::vector<bb::Polynomial<FF>> random_polynomials(NUM_POLYNOMIALS);
129 for (auto& poly : random_polynomials) {
130 poly = random_poly(multivariate_n);
131 }
132 auto full_polynomials = construct_ultra_full_polynomials(random_polynomials);
133
134 auto transcript = Flavor::Transcript::prover_init_empty();
135
136 FF alpha = transcript->template get_challenge<FF>("Sumcheck:alpha");
137
138 std::vector<FF> gate_challenges(multivariate_d);
139 for (size_t idx = 0; idx < multivariate_d; idx++) {
140 gate_challenges[idx] =
141 transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
142 }
143
144 SumcheckProver<Flavor> sumcheck(
145 multivariate_n, full_polynomials, transcript, alpha, gate_challenges, {}, multivariate_d);
146
147 auto output = sumcheck.prove();
148
149 FF u_0 = output.challenge[0];
150 FF u_1 = output.challenge[1];
151 FF u_2 = output.challenge[2];
152
153 /* sumcheck.prove() terminates with sumcheck.multivariates.folded_polynoimals as an array such that
154 * sumcheck.multivariates.folded_polynoimals[i][0] is the evaluatioin of the i'th multivariate at the vector of
155 challenges u_i. What does this mean?
156
157 Here we show that if the multivariate is F(X0, X1, X2) defined as above, then what we get is F(u0, u1, u2) and
158 not, say F(u2, u1, u0). This is in accordance with Adrian's thesis (cf page 9).
159 */
160
161 // Check the correctness of the multilinear evaluations produced by Sumcheck by directly evaluating
162 // the full polynomials at challenge u via the evaluate_mle() function
163 std::vector<FF> u_challenge = { u_0, u_1, u_2 };
164 for (auto [full_poly, claimed_eval] :
165 zip_view(full_polynomials.get_all(), output.claimed_evaluations.get_all())) {
166 Polynomial<FF> poly(full_poly);
167 auto v_expected = poly.evaluate_mle(u_challenge);
168 EXPECT_EQ(v_expected, claimed_eval);
169 }
170 }
171
172 void test_prover()
173 {
174 const size_t multivariate_d(2);
175 const size_t multivariate_n(1 << multivariate_d);
176
177 // Randomly construct the prover polynomials that are input to Sumcheck.
178 // Note: ProverPolynomials are defined as spans so the polynomials they point to need to exist in memory.
179 std::vector<Polynomial<FF>> random_polynomials(NUM_POLYNOMIALS);
180 for (auto& poly : random_polynomials) {
181 poly = random_poly(multivariate_n);
182 }
183 auto full_polynomials = construct_ultra_full_polynomials(random_polynomials);
184
185 auto transcript = Flavor::Transcript::prover_init_empty();
186
187 FF alpha = transcript->template get_challenge<FF>("Sumcheck:alpha");
188
189 std::vector<FF> gate_challenges(multivariate_d);
190 for (size_t idx = 0; idx < gate_challenges.size(); idx++) {
191 gate_challenges[idx] =
192 transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
193 }
194
195 SumcheckProver<Flavor> sumcheck(
196 multivariate_n, full_polynomials, transcript, alpha, gate_challenges, {}, CONST_PROOF_SIZE_LOG_N);
197
199
200 if constexpr (Flavor::HasZK) {
201 ZKData zk_sumcheck_data = ZKData(multivariate_d, transcript);
202 output = sumcheck.prove(zk_sumcheck_data);
203 } else {
204 output = sumcheck.prove();
205 }
206 FF u_0 = output.challenge[0];
207 FF u_1 = output.challenge[1];
208 std::vector<FF> expected_values;
209 for (auto& polynomial_ptr : full_polynomials.get_all()) {
210 auto& polynomial = polynomial_ptr;
211 // using knowledge of inputs here to derive the evaluation
212 FF expected_lo = polynomial[0] * (FF(1) - u_0) + polynomial[1] * u_0;
213 expected_lo *= (FF(1) - u_1);
214 FF expected_hi = polynomial[2] * (FF(1) - u_0) + polynomial[3] * u_0;
215 expected_hi *= u_1;
216 expected_values.emplace_back(expected_lo + expected_hi);
217 }
218
219 for (auto [eval, expected] : zip_view(output.claimed_evaluations.get_all(), expected_values)) {
220 eval = expected;
221 }
222 }
223
224 // TODO(#225): make the inputs to this test more interesting, e.g. non-trivial permutations
225 void test_prover_verifier_flow()
226 {
227 const size_t multivariate_d(3);
228 const size_t multivariate_n(1 << multivariate_d);
229
230 const size_t virtual_log_n = 6;
231
232 auto full_polynomials = create_satisfiable_trace<Flavor>(multivariate_n);
233
234 // SumcheckTestFlavor doesn't need complex relation parameters (no permutation, lookup, etc.)
235 RelationParameters<FF> relation_parameters{};
236 auto prover_transcript = Flavor::Transcript::prover_init_empty();
237 FF prover_alpha = prover_transcript->template get_challenge<FF>("Sumcheck:alpha");
238
239 std::vector<FF> prover_gate_challenges(virtual_log_n);
240 prover_gate_challenges =
241 prover_transcript->template get_dyadic_powers_of_challenge<FF>("Sumcheck:gate_challenge", virtual_log_n);
242
243 SumcheckProver<Flavor> sumcheck_prover(multivariate_n,
244 full_polynomials,
245 prover_transcript,
246 prover_alpha,
247 prover_gate_challenges,
248 relation_parameters,
249 virtual_log_n);
250
252 if constexpr (Flavor::HasZK) {
253 ZKData zk_sumcheck_data = ZKData(multivariate_d, prover_transcript);
254 output = sumcheck_prover.prove(zk_sumcheck_data);
255 } else {
256 output = sumcheck_prover.prove();
257 }
258
259 auto verifier_transcript = Flavor::Transcript::verifier_init_empty(prover_transcript);
260
261 FF verifier_alpha = verifier_transcript->template get_challenge<FF>("Sumcheck:alpha");
262
263 auto sumcheck_verifier = SumcheckVerifier<Flavor>(verifier_transcript, verifier_alpha, virtual_log_n);
264
265 std::vector<FF> verifier_gate_challenges(virtual_log_n);
266 verifier_gate_challenges =
267 verifier_transcript->template get_dyadic_powers_of_challenge<FF>("Sumcheck:gate_challenge", virtual_log_n);
268
269 std::vector<FF> padding_indicator_array(virtual_log_n, 1);
270 if constexpr (Flavor::HasZK) {
271 for (size_t idx = 0; idx < virtual_log_n; idx++) {
272 padding_indicator_array[idx] = (idx < multivariate_d) ? FF{ 1 } : FF{ 0 };
273 }
274 }
275
276 auto verifier_output =
277 sumcheck_verifier.verify(relation_parameters, verifier_gate_challenges, padding_indicator_array);
278
279 auto verified = verifier_output.verified;
280
281 EXPECT_EQ(verified, true);
282 };
283
284 void test_failure_prover_verifier_flow()
285 {
286 // Since the last 4 rows in ZK Flavors are disabled, we extend an invalid circuit of size 4 to size 8 by padding
287 // with 0.
288 const size_t multivariate_d(3);
289 const size_t multivariate_n(1 << multivariate_d);
290
291 // Start with a satisfiable trace, then break it
292 auto full_polynomials = create_satisfiable_trace<Flavor>(multivariate_n);
293
294 // Break the circuit by changing w_l[1] from 1 to 0
295 // This makes the arithmetic relation unsatisfied:
296 // q_arith[1] * (q_l[1] * w_l[1] + q_r[1] * w_r[1] + q_o[1] * w_o[1]) = 1 * (1 * 0 + 1 * 1 + (-1) * 2) = -1 ≠
297 // 0
298 full_polynomials.w_l.at(1) = FF(0);
299
300 // SumcheckTestFlavor doesn't need complex relation parameters
301 RelationParameters<FF> relation_parameters{};
302 auto prover_transcript = Flavor::Transcript::prover_init_empty();
303 FF prover_alpha = prover_transcript->template get_challenge<FF>("Sumcheck:alpha");
304
305 auto prover_gate_challenges =
306 prover_transcript->template get_dyadic_powers_of_challenge<FF>("Sumcheck:gate_challenge", multivariate_d);
307
308 SumcheckProver<Flavor> sumcheck_prover(multivariate_n,
309 full_polynomials,
310 prover_transcript,
311 prover_alpha,
312 prover_gate_challenges,
313 relation_parameters,
314 multivariate_d);
315
317 if constexpr (Flavor::HasZK) {
318 // construct libra masking polynomials and compute auxiliary data
319 ZKData zk_sumcheck_data = ZKData(multivariate_d, prover_transcript);
320 output = sumcheck_prover.prove(zk_sumcheck_data);
321 } else {
322 output = sumcheck_prover.prove();
323 }
324
325 auto verifier_transcript = Flavor::Transcript::verifier_init_empty(prover_transcript);
326
327 FF verifier_alpha = verifier_transcript->template get_challenge<FF>("Sumcheck:alpha");
328
329 SumcheckVerifier<Flavor> sumcheck_verifier(verifier_transcript, verifier_alpha, multivariate_d);
330
331 std::vector<FF> verifier_gate_challenges(multivariate_d);
332 for (size_t idx = 0; idx < multivariate_d; idx++) {
333 verifier_gate_challenges[idx] =
334 verifier_transcript->template get_challenge<FF>("Sumcheck:gate_challenge_" + std::to_string(idx));
335 }
336
337 std::vector<FF> padding_indicator_array(multivariate_d);
338 std::ranges::fill(padding_indicator_array, FF{ 1 });
339 auto verifier_output =
340 sumcheck_verifier.verify(relation_parameters, verifier_gate_challenges, padding_indicator_array);
341
342 auto verified = verifier_output.verified;
343
344 EXPECT_EQ(verified, false);
345 };
346};
347
348// Define the FlavorTypes using SumcheckTestFlavor variants
349// Note: Only testing short monomials since full barycentric adds complexity without testing sumcheck-specific logic
350// Note: Grumpkin sumcheck requires ZK mode for commitment-based protocol (used in ECCVM/IVC)
351using FlavorTypes = testing::Types<SumcheckTestFlavor, // BN254, non-ZK, short monomials
352 SumcheckTestFlavorZK, // BN254, ZK, short monomials
353 SumcheckTestFlavorGrumpkinZK>; // Grumpkin, ZK, short monomials
354
355TYPED_TEST_SUITE(SumcheckTests, FlavorTypes);
356
357TYPED_TEST(SumcheckTests, PolynomialNormalization)
358{
359 if constexpr (!TypeParam::HasZK) {
360 this->test_polynomial_normalization();
361 } else {
362 GTEST_SKIP() << "Skipping test for ZK-enabled flavors";
363 }
364}
365// Test the prover
366TYPED_TEST(SumcheckTests, Prover)
367{
368 this->test_prover();
369}
370// Tests the prover-verifier flow
371TYPED_TEST(SumcheckTests, ProverAndVerifierSimple)
372{
373 this->test_prover_verifier_flow();
374}
375// This tests is fed an invalid circuit and checks that the verifier would output false.
376TYPED_TEST(SumcheckTests, ProverAndVerifierSimpleFailure)
377{
378 this->test_failure_prover_verifier_flow();
379}
380
381} // namespace
A container for the prover polynomials.
static constexpr bool HasZK
typename Curve::ScalarField FF
static constexpr size_t NUM_ALL_ENTITIES
Structured polynomial class that represents the coefficients 'a' of a_0 + a_1 x .....
static Polynomial shiftable(size_t virtual_size)
Utility to create a shiftable polynomial of given virtual size.
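The implementation of the sumcheck Prover for statements of the form \(\sum_{\vec \ell \in \{0,1\}^d} \mathrm{pow}_{\beta}(\vec \ell) \cdot F\left(P_1(\vec \ell), \ldots, P_N(\vec \ell)\right) = 0\) for multilinear polynomials \(P_1, \ldots, P_N\).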
Definition sumcheck.hpp:289
SumcheckOutput< Flavor > prove()
Non-ZK version: Compute round univariate, place it in transcript, compute challenge,...
Definition sumcheck.hpp:387
A flexible, minimal test flavor for sumcheck testing.
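Implementation of the sumcheck Verifier for statements of the form \(\sum_{\vec \ell \in \{0,1\}^d} \mathrm{pow}_{\beta}(\vec \ell) \cdot F\left(P_1(\vec \ell), \ldots, P_N(\vec \ell)\right) = 0\) for multilinear polynomials \(P_1, \ldots, P_N\).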
Definition sumcheck.hpp:719
typename ECCVMFlavor::ProverPolynomials ProverPolynomials
testing::Types< MegaFlavor, UltraFlavor, UltraZKFlavor, UltraRollupFlavor > FlavorTypes
std::filesystem::path bb_crs_path()
void init_file_crs_factory(const std::filesystem::path &path)
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
TYPED_TEST_SUITE(CommitmentKeyTest, Curves)
SumcheckTestFlavor_< curve::BN254, true, true > SumcheckTestFlavorZK
Zero-knowledge variant.
TYPED_TEST(CommitmentKeyTest, CommitToZeroPoly)
SumcheckTestFlavor_< curve::BN254, false, true > SumcheckTestFlavor
Base test flavor (BN254, non-ZK, short monomials)
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::string to_string(bb::avm2::ValueTag tag)
Container for parameters used by the grand product (permutation, lookup) Honk relations.
Contains the evaluations of multilinear polynomials \(P_1, \ldots, P_N\) at the challenge point \(\vec u = (u_0, \ldots, u_{d-1})\). These are computed by S...
This structure is created to contain various polynomials and constants required by ZK Sumcheck.
static field random_element(numeric::RNG *engine=nullptr) noexcept
Minimal test flavors for sumcheck testing without UltraFlavor dependencies.