Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
keccak.cpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Nishat], commit: }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#include "keccak.hpp"
15namespace bb::stdlib {
16
17using namespace bb::plookup;
18
/**
 * @brief Normalize a base-11 lane and left-rotate it by ROTATIONS[lane_index] bits using one plookup MultiTable.
 *
 * @details The lane is split as limb = left * 11^{right_bits} + right, where `left` holds the top
 * ROTATIONS[lane_index] base-11 digits (the digits that wrap around under the rotation). Both slices are pushed
 * through the KECCAK_NORMALIZE_AND_ROTATE + lane_index MultiTable (right slice first, then left), which normalizes
 * them and implicitly range-constrains each slice. The rotated lane is then reassembled as
 * left_normalized + right_normalized * 11^{ROTATIONS[lane_index]}.
 *
 * @param limb Base-11 lane to normalize and rotate.
 * @param msb  Output: the C3 value of the final lookup entry, i.e. the most significant bit of the normalized
 *             *input* lane (before rotation).
 * @return The normalized, rotated lane.
 */
35template <typename Builder>
36template <size_t lane_index>
38{
39 // left_bits = number of most-significant digits that wrap around to the bottom of the lane under the rotation
40 constexpr size_t left_bits = ROTATIONS[lane_index];
41
42 // right_bits = number of less-significant digits that are shifted up without wrapping
43 constexpr size_t right_bits = KECCAK_LANE_SIZE - ROTATIONS[lane_index];
44
45 // Matches the maximum bits per slice (Rho<>::MAXIMUM_MULTITABLE_BITS) used by KECCAK_RHO multitables
46 constexpr size_t max_bits_per_table = plookup::keccak_tables::Rho<>::MAXIMUM_MULTITABLE_BITS;
47
48 // compute the number of lookups required for our left and right bit slices (ceil(bits / max_bits_per_table))
49 constexpr size_t num_left_tables = left_bits / max_bits_per_table + (left_bits % max_bits_per_table > 0 ? 1 : 0);
50 constexpr size_t num_right_tables = right_bits / max_bits_per_table + (right_bits % max_bits_per_table > 0 ? 1 : 0);
51
52 // get the numerical value of the left and right bit slices
53 // (lookup table input values derived from left / right)
54 uint256_t input = limb.get_value();
55 constexpr uint256_t slice_divisor = BASE.pow(right_bits);
56 const auto [left, right] = input.divmod(slice_divisor);
57
58 // compute the normalized values for the left and right bit slices
59 // (lookup table output values derived from left_normalized / right_normalized)
60 uint256_t left_normalized = normalize_sparse(left);
61 uint256_t right_normalized = normalize_sparse(right);
62
103
104 // compute plookup witness values for a given slice
105 // (same lambda can be used to compute witnesses for left and right slices)
 // NOTE: `input` is captured by reference and shared between the two calls below, so after the right slice
 // has been processed it has been divided by 11^{right_bits} and holds the left slice.
106 auto compute_lookup_witnesses_for_limb = [&]<size_t limb_bits, size_t num_lookups>(uint256_t& normalized) {
107 // (use a constexpr loop to make some pow and div operations compile-time)
108 bb::constexpr_for<0, num_lookups, 1>([&]<size_t i> {
109 constexpr size_t num_bits_processed = i * max_bits_per_table;
110
111 // How many bits can this slice contain?
112 // We want to implicitly range-constrain `normalized < 11^{limb_bits}`,
113 // which means potentially using a lookup table that is not of size 11^{max_bits_per_table}
114 // for the most-significant slice
115 constexpr size_t bit_slice = (num_bits_processed + max_bits_per_table > limb_bits)
116 ? limb_bits % max_bits_per_table
117 : max_bits_per_table;
118
119 // current column values are tracked via 'input' and 'normalized'
120 lookup[ColumnIdx::C1].push_back(input);
121 lookup[ColumnIdx::C2].push_back(normalized);
122
123 constexpr uint64_t divisor = numeric::pow64(static_cast<uint64_t>(BASE), bit_slice);
124 constexpr uint64_t msb_divisor = divisor / static_cast<uint64_t>(BASE);
125
126 // compute the value of the most significant bit (top base-11 digit) of this slice and store in C3
127 const auto [normalized_quotient, normalized_slice] = normalized.divmod(divisor);
128
129 // 256-bit divisions are expensive! cast to u64s when we don't need the extra bits
130 const uint64_t normalized_msb = (static_cast<uint64_t>(normalized_slice) / msb_divisor);
131 lookup[ColumnIdx::C3].push_back(normalized_msb);
132
133 // We need to provide a key/value object for this lookup in order for the Builder
134 // to compute the plookup sorted list commitment
135 const auto [input_quotient, input_slice] = input.divmod(divisor);
136 lookup.lookup_entries.push_back(
137 { { static_cast<uint64_t>(input_slice), 0 }, { normalized_slice, normalized_msb } });
138
139 // reduce the input and output by 11^{bit_slice}
140 input = input_quotient;
141 normalized = normalized_quotient;
142 });
143 };
144
145 // template lambda syntax is a little funky.
146 // Need to explicitly write `.template operator()` (instead of just `()`).
147 // Otherwise compiler cannot distinguish between `>` symbol referring to closing the template parameter list,
148 // OR `>` being a greater-than operator :/
 // Order matters: right slice first (lookup entries [0, num_right_tables)), then left slice
 // (entries [num_right_tables, num_right_tables + num_left_tables)).
149 compute_lookup_witnesses_for_limb.template operator()<right_bits, num_right_tables>(right_normalized);
150 compute_lookup_witnesses_for_limb.template operator()<left_bits, num_left_tables>(left_normalized);
151
152 // Call builder method to create plookup constraints.
153 // The MultiTable table index can be derived from `lane_index`
154 // Each lane_index has a different rotation amount, which changes sizes of left/right slices
155 // and therefore the selector constants required (i.e. the Q1, Q2, Q3 values in the earlier example)
156 const auto accumulator_witnesses = limb.context->create_gates_from_plookup_accumulators(
157 (plookup::MultiTableId)((size_t)KECCAK_NORMALIZE_AND_ROTATE + lane_index), lookup, limb.get_witness_index());
158
159 // extract C3 of the final lookup entry: the top digit of the most significant left-slice lookup, i.e. the
 // most significant bit of the normalized *input* lane (before rotation).
 // NOTE(review): when ROTATIONS[lane_index] > 0 this bit ends up at position ROTATIONS[lane_index] - 1 of the
 // rotated output, so it is NOT the output's msb. Within keccakf1600 this is harmless because chi() overwrites
 // every state_msb entry before compute_twisted_state() reads them, and iota() only uses lane 0.
161 accumulator_witnesses[ColumnIdx::C3][num_left_tables + num_right_tables - 1]);
162
163 // Extract the witness that maps to the normalized right slice
 // (C2[0] is the first accumulator of the right slice, i.e. the full right_normalized value)
164 const field_t<Builder> right_output =
165 field_t<Builder>::from_witness_index(limb.get_context(), accumulator_witnesses[ColumnIdx::C2][0]);
166
 // num_left_tables is a compile-time constant, so only one branch is ever taken for a given lane_index
167 if (num_left_tables == 0) {
168 // if the left slice size is 0 bits (i.e. no rotation), return `right_output`
169 return right_output;
170 } else {
171 // Extract the normalized left slice (first accumulator of the left-slice lookups)
173 limb.get_context(), accumulator_witnesses[ColumnIdx::C2][num_right_tables]);
174
175 // Stitch the right/left slices together to create our rotated output:
 // the wrapped-around left slice becomes the low digits, the right slice is shifted up by ROTATIONS[lane_index]
176 constexpr uint256_t shift = BASE.pow(ROTATIONS[lane_index]);
177 return (left_output + right_output * shift);
178 }
179}
180
197template <typename Builder> void keccak<Builder>::compute_twisted_state(keccak_state& internal)
198{
199 for (size_t i = 0; i < NUM_KECCAK_LANES; ++i) {
200 internal.twisted_state[i] = ((internal.state[i] * 11) + internal.state_msb[i]).normalize();
201 }
202}
203
/**
 * @brief Keccak theta step on the base-11 "extended" state.
 *
 * @details Standard theta: C[x] = XOR over y of A[x, y]; D[x] = C[x - 1] ^ rotl(C[x + 1], 1); A[x, y] ^= D[x].
 * Lane (x, y) lives at index y * 5 + x. XOR is done by adding base-11 values, so each digit holds a small sum whose
 * parity is the XOR. The 1-bit rotation comes from the twisted lanes (twisted = state * BASE + msb, computed by
 * compute_twisted_state):
 *  - C[i] is the sum of the five twisted lanes of column i (65 base-11 digits).
 *  - D[i] = C[i - 1] + BASE * C[i + 1]. Writing D[i] = hi * 11^65 + mid * 11 + lo, digit j of `mid` equals
 *    (digit j of the plain column sum of i - 1) + (digit j of rotl(plain column sum of i + 1, 1)).
 *    `lo` and `hi` are the two leftover digits and are discarded.
 * This split is digit-wise only if no carries occur. That holds when every lane digit is 0 or 1 (assumed here):
 * column-sum digits are then at most 5 and D digits at most 10.
 *
 * NOTE(review): the declarations of C and D, and the statement that consumes `mid`, are not visible in this listing.
 * That statement is expected to be a KECCAK_THETA_OUTPUT lookup that normalizes `mid`, range-constrains it to 64
 * base-11 digits, and assigns D[i]. Confirm against the full source.
 */
251template <typename Builder> void keccak<Builder>::theta(keccak_state& internal)
252{
255
256 auto& state = internal.state;
257 const auto& twisted_state = internal.twisted_state;
 // C[i] = sum of the twisted lanes in column i (one accumulate per column)
258 for (size_t i = 0; i < 5; ++i) {
259
268 C[i] = field_ct::accumulate({ twisted_state[i],
269 twisted_state[5 + i],
270 twisted_state[10 + i],
271 twisted_state[15 + i],
272 twisted_state[20 + i] });
273 }
274
 // D[i] = C[i - 1] + BASE * C[i + 1]; the BASE shift of the twisted C[i + 1] yields the 1-bit rotation
 // once the lowest digit of D is dropped (see the function comment)
279 for (size_t i = 0; i < 5; ++i) {
280 const auto non_shifted_equivalent = (C[(i + 4) % 5]);
281 const auto shifted_equivalent = C[(i + 1) % 5] * BASE;
282 D[i] = (non_shifted_equivalent + shifted_equivalent);
283 }
284
 // Split each D[i] into hi (digit 65), mid (digits 1..64) and lo (digit 0)
301 static constexpr uint256_t divisor = BASE.pow(KECCAK_LANE_SIZE);
302 static constexpr uint256_t multiplicand = BASE.pow(KECCAK_LANE_SIZE + 1);
303 for (size_t i = 0; i < 5; ++i) {
304 uint256_t D_native = D[i].get_value();
305 const auto [D_quotient, lo_native] = D_native.divmod(BASE);
306 const uint256_t hi_native = D_quotient / divisor;
307 const uint256_t mid_native = D_quotient - hi_native * divisor;
308
309 field_ct hi(witness_ct(internal.context, hi_native));
310 field_ct mid(witness_ct(internal.context, mid_native));
311 field_ct lo(witness_ct(internal.context, lo_native));
312
313 // assert equal should cost 1 gate (multipliers are all constants)
314 D[i].assert_equal((hi * multiplicand).add_two(mid * 11, lo));
 // hi and lo must each be a single base-11 digit for the decomposition above to be unique.
 // NOTE(review): confirm create_small_range_constraint enforces value < BASE (exclusive). If the bound were
 // inclusive, lo = 11 together with mid - 1 would also satisfy the assert_equal above.
315 internal.context->create_small_range_constraint(hi.get_witness_index(), static_cast<uint64_t>(BASE));
316 internal.context->create_small_range_constraint(lo.get_witness_index(), static_cast<uint64_t>(BASE));
317
318 // If number of bits in KECCAK_THETA_OUTPUT table does NOT cleanly divide KECCAK_LANE_SIZE=64,
319 // we need an additional range constraint to ensure that mid < 11^64
320 static_assert(KECCAK_LANE_SIZE % plookup::keccak_tables::Theta::TABLE_BITS == 0,
321 "KECCAK_THETA_OUTPUT TABLE_BITS must divide KECCAK_LANE_SIZE.");
323 }
324
325 // compute state[j * 5 + i] XOR D[i] in base-11 representation
 // (plain addition; the sum is normalized later by rho's normalize_and_rotate)
326 for (size_t i = 0; i < 5; ++i) {
327 for (size_t j = 0; j < 5; ++j) {
328 state[j * 5 + i] = state[j * 5 + i] + D[i];
329 }
330 }
331}
332
359template <typename Builder> void keccak<Builder>::rho(keccak_state& internal)
360{
361 constexpr_for<0, NUM_KECCAK_LANES, 1>(
362 [&]<size_t i>() { internal.state[i] = normalize_and_rotate<i>(internal.state[i], internal.state_msb[i]); });
363}
364
374template <typename Builder> void keccak<Builder>::pi(keccak_state& internal)
375{
377
378 for (size_t j = 0; j < 5; ++j) {
379 for (size_t i = 0; i < 5; ++i) {
380 B[j * 5 + i] = internal.state[j * 5 + i];
381 }
382 }
383
384 for (size_t y = 0; y < 5; ++y) {
385 for (size_t x = 0; x < 5; ++x) {
386 size_t u = (0 * x + 1 * y) % 5;
387 size_t v = (2 * x + 3 * y) % 5;
388
389 internal.state[v * 5 + u] = B[5 * y + x];
390 }
391 }
392}
393
410template <typename Builder> void keccak<Builder>::chi(keccak_state& internal)
411{
412 // (cost = 12 * 25 = 300?)
413 auto& state = internal.state;
414
415 for (size_t y = 0; y < 5; ++y) {
416 std::array<field_ct, 5> lane_outputs;
417 for (size_t x = 0; x < 5; ++x) {
418 const auto A = state[y * 5 + x];
419 const auto B = state[y * 5 + ((x + 1) % 5)];
420 const auto C = state[y * 5 + ((x + 2) % 5)];
421
422 // vv should cost 1 gate
423 lane_outputs[x] = (A + A + CHI_OFFSET).add_two(-B, C);
424 }
425 for (size_t x = 0; x < 5; ++x) {
426 // Normalize lane outputs and assign to internal.state
427 auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_CHI_OUTPUT, lane_outputs[x]);
428 internal.state[y * 5 + x] = accumulators[ColumnIdx::C2][0];
429 internal.state_msb[y * 5 + x] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
430 }
431 }
432}
433
443template <typename Builder> void keccak<Builder>::iota(keccak_state& internal, size_t round)
444{
445 const field_ct xor_result = internal.state[0] + SPARSE_RC[round];
446
447 // normalize lane value so that we don't overflow our base11 modulus boundary in the next round
448 internal.state[0] = normalize_and_rotate<0>(xor_result, internal.state_msb[0]);
449
450 // No need to add constraints to compute twisted repr if this is the last round
451 if (round != NUM_KECCAK_ROUNDS - 1) {
452 compute_twisted_state(internal);
453 }
454}
455
456template <typename Builder> void keccak<Builder>::keccakf1600(keccak_state& internal)
457{
458 for (size_t i = 0; i < NUM_KECCAK_ROUNDS; ++i) {
459 theta(internal);
460 rho(internal);
461 pi(internal);
462 chi(internal);
463 iota(internal, i);
464 }
465}
466
467// Returns the Keccak-f[1600] permutation of the input state (NUM_KECCAK_LANES lanes of KECCAK_LANE_SIZE bits).
468// The lanes are first converted into the 'extended' base-11 representation (recording each lane's msb), from which
469// the 'twisted' state is derived; keccakf1600() is then applied to this keccak 'internal state'.
470// Finally, the state is converted back from the extended representation into normal lanes.
471template <typename Builder>
473 std::array<field_t<Builder>, NUM_KECCAK_LANES> state, Builder* ctx)
474{
475 // populate keccak_state, convert our KECCAK_LANE_SIZE-bit lanes into an extended base-11 representation
476 keccak_state internal;
477 internal.context = ctx;
478 for (size_t i = 0; i < state.size(); ++i) {
479 const auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_FORMAT_INPUT, state[i]);
 // C2[0]: the full converted lane; C3 of the final lookup entry: the lane's most significant bit
480 internal.state[i] = accumulators[ColumnIdx::C2][0];
481 internal.state_msb[i] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
482 }
 // the first theta() reads twisted_state, so it must be built before the first round
483 compute_twisted_state(internal);
484 keccakf1600(internal);
485 // we convert back to the normal lanes
486 return extended_2_normal(internal);
487}
488
489// Convert the 'extended' representation of the internal Keccak state into the usual array of KECCAK_LANE_SIZE bit lanes
490template <typename Builder>
492 keccak_state& internal)
493{
494 std::array<field_t<Builder>, NUM_KECCAK_LANES> conversion;
495
496 // Each hash limb represents a little-endian integer.
 // NOTE(review): `output_limb` is defined on a line not visible in this listing; presumably it is
 // plookup_read<Builder>::read_from_1_to_2_table(KECCAK_FORMAT_OUTPUT, internal.state[i]) — confirm.
497 for (size_t i = 0; i < internal.state.size(); ++i) {
499 conversion[i] = output_limb;
500 }
501
502 return conversion;
503}
504
// Explicit instantiation of keccak for the Mega circuit builder.
506template class keccak<bb::MegaCircuitBuilder>;
507
508} // namespace bb::stdlib
constexpr uint256_t pow(const uint256_t &exponent) const
constexpr std::pair< uint256_t, uint256_t > divmod(const uint256_t &b) const
Container for lookup accumulator values and table reads.
Definition types.hpp:357
std::vector< BasicTable::LookupEntry > lookup_entries
Definition types.hpp:363
Generate the plookup tables used for the RHO round of the Keccak hash algorithm.
static constexpr size_t TABLE_BITS
static field_t from_witness_index(Builder *ctx, uint32_t witness_index)
Definition field.cpp:63
static field_t accumulate(const std::vector< field_t > &input)
Efficiently compute the sum of vector entries. Using big_add_gate we reduce the number of gates neede...
Definition field.cpp:1168
Builder * context
Definition field.hpp:57
Builder * get_context() const
Definition field.hpp:420
bb::fr get_value() const
Given a := *this, compute its value given by a.v * a.mul + a.add.
Definition field.cpp:829
uint32_t get_witness_index() const
Get the witness index of the current field element.
Definition field.hpp:507
static void rho(keccak_state &state)
RHO round.
Definition keccak.cpp:359
static void pi(keccak_state &state)
PI.
Definition keccak.cpp:374
static void theta(keccak_state &state)
THETA round.
Definition keccak.cpp:251
static void compute_twisted_state(keccak_state &internal)
Compute twisted representation of hash lane.
Definition keccak.cpp:197
static void chi(keccak_state &state)
CHI.
Definition keccak.cpp:410
static field_t< Builder > normalize_and_rotate(const field_ct &limb, field_ct &msb)
Normalize a base-11 limb and left-rotate by keccak::ROTATIONS[lane_index] bits. This method also extracts the most significant bit of the normalized (pre-rotation) limb into `msb`.
Definition keccak.cpp:37
static std::array< field_ct, NUM_KECCAK_LANES > permutation_opcode(std::array< field_ct, NUM_KECCAK_LANES > state, Builder *context)
Definition keccak.cpp:472
static std::array< field_ct, NUM_KECCAK_LANES > extended_2_normal(keccak_state &internal)
Definition keccak.cpp:491
static void keccakf1600(keccak_state &state)
Definition keccak.cpp:456
static void iota(keccak_state &state, size_t round)
IOTA.
Definition keccak.cpp:443
static plookup::ReadData< field_pt > get_lookup_accumulators(const plookup::MultiTableId id, const field_pt &key_a, const field_pt &key_b=0, const bool is_2_to_1_lookup=false)
Definition plookup.cpp:19
static field_pt read_from_1_to_2_table(const plookup::MultiTableId id, const field_pt &key_a)
Definition plookup.cpp:89
bb::avm2::Column C
bn254::witness_ct witness_ct
constexpr uint64_t pow64(const uint64_t input, const uint64_t exponent)
Definition pow.hpp:13
@ KECCAK_FORMAT_INPUT
Definition types.hpp:128
@ KECCAK_FORMAT_OUTPUT
Definition types.hpp:129
@ KECCAK_NORMALIZE_AND_ROTATE
Definition types.hpp:130
@ KECCAK_CHI_OUTPUT
Definition types.hpp:127
@ KECCAK_THETA_OUTPUT
Definition types.hpp:126
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::array< field_ct, NUM_KECCAK_LANES > state
Definition keccak.hpp:148
std::array< field_ct, NUM_KECCAK_LANES > twisted_state
Definition keccak.hpp:150
std::array< field_ct, NUM_KECCAK_LANES > state_msb
Definition keccak.hpp:149