Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
keccak.cpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Nishat], commit: 5be53b6f75bac06d6d0132220044b28777021f0f }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#include "keccak.hpp"
15namespace bb::stdlib {
16
17using namespace bb::plookup;
18
// Normalize a base-11 limb and left-rotate it by ROTATIONS[lane_index] bits.
// The limb is split into a "left" slice (the bits that wrap around) and a "right" slice (the bits that do not),
// plookup witnesses are computed for both slices, and the two normalized slices are stitched back together
// in swapped order to produce the rotated output.
// NOTE(review): the signature line (listing line 37) and listing lines 63-102, 160 and 172 are absent from this
// extract, so the declaration of `lookup` and the left-hand sides of two statements are not visible here.
35template <typename Builder>
36template <size_t lane_index>
38{
39 // left_bits = the rotation amount, i.e. the number of bits that wrap around 11^{KECCAK_LANE_SIZE}
40 constexpr size_t left_bits = ROTATIONS[lane_index];
41
42 // right_bits = the number of bits that don't wrap
43 constexpr size_t right_bits = KECCAK_LANE_SIZE - ROTATIONS[lane_index];
44
45 // Matches the maximum bits per slice (Rho<>::MAXIMUM_MULTITABLE_BITS) used by KECCAK_RHO multitables
46 constexpr size_t max_bits_per_table = plookup::keccak_tables::Rho<>::MAXIMUM_MULTITABLE_BITS;
47
48 // compute the number of lookups required for our left and right bit slices
// (ceiling division of the slice width by max_bits_per_table)
49 constexpr size_t num_left_tables = left_bits / max_bits_per_table + (left_bits % max_bits_per_table > 0 ? 1 : 0);
50 constexpr size_t num_right_tables = right_bits / max_bits_per_table + (right_bits % max_bits_per_table > 0 ? 1 : 0);
51
52 // get the numerical value of the left and right bit slices
53 // (lookup table input values derived from left / right)
// left = input / BASE^{right_bits}, right = input mod BASE^{right_bits}
54 uint256_t input = limb.get_value();
55 constexpr uint256_t slice_divisor = BASE.pow(right_bits);
56 const auto [left, right] = input.divmod(slice_divisor);
57
58 // compute the normalized values for the left and right bit slices
59 // (lookup table output values derived from left_normalized / right_normalized)
60 uint256_t left_normalized = normalize_sparse(left);
61 uint256_t right_normalized = normalize_sparse(right);
62
// NOTE(review): listing lines 63-102 are missing from this extract; `lookup`, used below, is presumably
// declared in that gap — confirm against the full file.
103
104 // compute plookup witness values for a given slice
105 // (same lambda can be used to compute witnesses for left and right slices)
// `input` is captured by reference and is consumed slice by slice across both invocations;
// `normalized` is the matching normalized value and is likewise reduced in place.
106 auto compute_lookup_witnesses_for_limb = [&]<size_t limb_bits, size_t num_lookups>(uint256_t& normalized) {
107 // (use a constexpr loop to make some pow and div operations compile-time)
108 bb::constexpr_for<0, num_lookups, 1>([&]<size_t i> {
109 constexpr size_t num_bits_processed = i * max_bits_per_table;
110
111 // How many bits can this slice contain?
112 // We want to implicitly range-constrain `normalized < 11^{limb_bits}`,
113 // which means potentially using a lookup table that is not of size 11^{max_bits_per_table}
114 // for the most-significant slice
115 constexpr size_t bit_slice = (num_bits_processed + max_bits_per_table > limb_bits)
116 ? limb_bits % max_bits_per_table
117 : max_bits_per_table;
118
119 // current column values are tracked via 'input' and 'normalized'
120 lookup[ColumnIdx::C1].push_back(input);
121 lookup[ColumnIdx::C2].push_back(normalized);
122
// divisor = BASE^{bit_slice}; msb_divisor = BASE^{bit_slice - 1} isolates the top base-11 digit of a slice
123 constexpr uint64_t divisor = numeric::pow64(static_cast<uint64_t>(BASE), bit_slice);
124 constexpr uint64_t msb_divisor = divisor / static_cast<uint64_t>(BASE);
125
126 // compute the value of the most significant bit of this slice and store in C3
127 const auto [normalized_quotient, normalized_slice] = normalized.divmod(divisor);
128
129 // 256-bit divisions are expensive! cast to u64s when we don't need the extra bits
130 const uint64_t normalized_msb = (static_cast<uint64_t>(normalized_slice) / msb_divisor);
131 lookup[ColumnIdx::C3].push_back(normalized_msb);
132
133 // We need to provide a key/value object for this lookup in order for the Builder
134 // to compute the plookup sorted list commitment
135 const auto [input_quotient, input_slice] = input.divmod(divisor);
136 lookup.lookup_entries.push_back(
137 { { static_cast<uint64_t>(input_slice), 0 }, { normalized_slice, normalized_msb } });
138
139 // reduce the input and output by 11^{bit_slice}
140 input = input_quotient;
141 normalized = normalized_quotient;
142 });
143 };
144
145 // template lambda syntax is a little funky.
146 // Need to explicitly write `.template operator()` (instead of just `()`).
147 // Otherwise compiler cannot distinguish between `>` symbol referring to closing the template parameter list,
148 // OR `>` being a greater-than operator :/
// The right slice is processed first, so its entries occupy indices [0, num_right_tables) of each column
// and the left slice entries follow from index num_right_tables.
149 compute_lookup_witnesses_for_limb.template operator()<right_bits, num_right_tables>(right_normalized);
150 compute_lookup_witnesses_for_limb.template operator()<left_bits, num_left_tables>(left_normalized);
151
152 // Call builder method to create plookup constraints.
153 // The MultiTable table index can be derived from `lane_idx`
154 // Each lane_idx has a different rotation amount, which changes sizes of left/right slices
155 // and therefore the selector constants required (i.e. the Q1, Q2, Q3 values in the earlier example)
156 const auto accumulator_witnesses = limb.context->create_gates_from_plookup_accumulators(
157 (plookup::MultiTableId)((size_t)KECCAK_NORMALIZE_AND_ROTATE + lane_index), lookup, limb.get_witness_index());
158
159 // extract the most significant bit of the normalized output from the final lookup entry in column C3
// NOTE(review): the start of this statement (listing line 160) is missing from the extract; presumably it
// assigns the `msb` output from this witness — confirm.
161 accumulator_witnesses[ColumnIdx::C3][num_left_tables + num_right_tables - 1]);
162
163 // Extract the witness that maps to the normalized right slice
164 const field_t<Builder> right_output =
165 field_t<Builder>::from_witness_index(limb.get_context(), accumulator_witnesses[ColumnIdx::C2][0]);
166
167 if (num_left_tables == 0) {
168 // if the left slice size is 0 bits (i.e. no rotation), return `right_output`
169 return right_output;
170 } else {
171 // Extract the normalized left slice
// NOTE(review): the start of this statement (listing line 172) is missing from the extract; presumably it
// declares `left_output` via from_witness_index, mirroring `right_output` above — confirm.
173 limb.get_context(), accumulator_witnesses[ColumnIdx::C2][num_right_tables]);
174
175 // Stitch the right/left slices together to create our rotated output
// The right slice is scaled by BASE^{ROTATIONS[lane_index]}, so the left slice lands in the low positions.
176 constexpr uint256_t shift = BASE.pow(ROTATIONS[lane_index]);
177 return (left_output + right_output * shift);
178 }
179}
180
197template <typename Builder> void keccak<Builder>::compute_twisted_state(keccak_state& internal)
198{
199 for (size_t i = 0; i < NUM_KECCAK_LANES; ++i) {
200 internal.twisted_state[i] = ((internal.state[i] * 11) + internal.state_msb[i]).normalize();
201 }
202}
203
// THETA round.
// Visible steps: (1) C[i] accumulates the five twisted_state lanes i, 5+i, 10+i, 15+i, 20+i;
// (2) D[i] = C[(i + 4) % 5] + C[(i + 1) % 5] * BASE; (3) each D[i] is decomposed into hi/mid/lo witnesses
// that are tied back to D[i] and partially range-constrained; (4) D[i] is added into every state lane j * 5 + i.
// NOTE(review): listing lines 253-254 (presumably the declarations of C and D), 260-267, 275-278, 285-300
// and 324 are missing from this extract — confirm against the full file.
251template <typename Builder> void keccak<Builder>::theta(keccak_state& internal)
252{
255
256 auto& state = internal.state;
257 const auto& twisted_state = internal.twisted_state;
// Step 1: sum the twisted lanes that share the same index modulo 5.
258 for (size_t i = 0; i < 5; ++i) {
259
268 C[i] = field_ct::accumulate({ twisted_state[i],
269 twisted_state[5 + i],
270 twisted_state[10 + i],
271 twisted_state[15 + i],
272 twisted_state[20 + i] });
273 }
274
// Step 2: combine the neighbouring sums; the (i + 1) term is multiplied by BASE (one base-11 digit up).
279 for (size_t i = 0; i < 5; ++i) {
280 const auto non_shifted_equivalent = (C[(i + 4) % 5]);
281 const auto shifted_equivalent = C[(i + 1) % 5] * BASE;
282 D[i] = (non_shifted_equivalent + shifted_equivalent);
283 }
284
// Step 3: decompose D[i] as hi * BASE^{KECCAK_LANE_SIZE + 1} + mid * BASE + lo.
301 static constexpr uint256_t divisor = BASE.pow(KECCAK_LANE_SIZE);
302 static constexpr uint256_t multiplicand = BASE.pow(KECCAK_LANE_SIZE + 1);
303 for (size_t i = 0; i < 5; ++i) {
// lo_native = D mod BASE (lowest base-11 digit); hi_native = the part of D / BASE at or above
// BASE^{KECCAK_LANE_SIZE}; mid_native = what remains in between.
304 uint256_t D_native = D[i].get_value();
305 const auto [D_quotient, lo_native] = D_native.divmod(BASE);
306 const uint256_t hi_native = D_quotient / divisor;
307 const uint256_t mid_native = D_quotient - hi_native * divisor;
308
// introduce the three pieces as fresh witnesses
309 field_ct hi(witness_ct(internal.context, hi_native));
310 field_ct mid(witness_ct(internal.context, mid_native));
311 field_ct lo(witness_ct(internal.context, lo_native));
312
313 // assert equal should cost 1 gate (multipliers are all constants)
314 D[i].assert_equal((hi * multiplicand).add_two(mid * 11, lo));
315 // Range-constrain hi and lo to valid base-11 digits [0, 10]. Using BASE - 1 because
316 // create_small_range_constraint(var, N) constrains var to [0, N] inclusive.
317 internal.context->create_small_range_constraint(hi.get_witness_index(), static_cast<uint64_t>(BASE - 1));
318 internal.context->create_small_range_constraint(lo.get_witness_index(), static_cast<uint64_t>(BASE - 1));
319
320 // If number of bits in KECCAK_THETA_OUTPUT table does NOT cleanly divide KECCAK_LANE_SIZE=64,
321 // we need an additional range constraint to ensure that mid < 11^64
322 static_assert(KECCAK_LANE_SIZE % plookup::keccak_tables::Theta::TABLE_BITS == 0,
323 "KECCAK_THETA_OUTPUT TABLE_BITS must divide KECCAK_LANE_SIZE.");
// NOTE(review): listing line 324 is missing here; presumably D[i] is replaced by a KECCAK_THETA_OUTPUT
// lookup on `mid` — confirm.
325 }
326
327 // compute state[j * 5 + i] XOR D[i] in base-11 representation
328 for (size_t i = 0; i < 5; ++i) {
329 for (size_t j = 0; j < 5; ++j) {
330 state[j * 5 + i] = state[j * 5 + i] + D[i];
331 }
332 }
333}
334
361template <typename Builder> void keccak<Builder>::rho(keccak_state& internal)
362{
363 constexpr_for<0, NUM_KECCAK_LANES, 1>(
364 [&]<size_t i>() { internal.state[i] = normalize_and_rotate<i>(internal.state[i], internal.state_msb[i]); });
365}
366
// PI.
// Permutes the lanes of internal.state: the lane at index 5 * y + x is moved to index v * 5 + u,
// where u = y and v = (2 * x + 3 * y) % 5. No arithmetic is performed on the lanes themselves.
// NOTE(review): listing line 378 is missing from this extract; presumably it declares the scratch array `B`
// — confirm against the full file.
376template <typename Builder> void keccak<Builder>::pi(keccak_state& internal)
377{
379
// Snapshot the current state into B so the permutation below reads unmodified lanes.
380 for (size_t j = 0; j < 5; ++j) {
381 for (size_t i = 0; i < 5; ++i) {
382 B[j * 5 + i] = internal.state[j * 5 + i];
383 }
384 }
385
// Scatter each snapshot lane to its permuted position.
386 for (size_t y = 0; y < 5; ++y) {
387 for (size_t x = 0; x < 5; ++x) {
// (u, v) is the destination coordinate of source lane (x, y)
388 size_t u = (0 * x + 1 * y) % 5;
389 size_t v = (2 * x + 3 * y) % 5;
390
391 internal.state[v * 5 + u] = B[5 * y + x];
392 }
393 }
394}
395
412template <typename Builder> void keccak<Builder>::chi(keccak_state& internal)
413{
414 // (cost = 12 * 25 = 300?)
415 auto& state = internal.state;
416
417 for (size_t y = 0; y < 5; ++y) {
418 std::array<field_ct, 5> lane_outputs;
419 for (size_t x = 0; x < 5; ++x) {
420 const auto A = state[y * 5 + x];
421 const auto B = state[y * 5 + ((x + 1) % 5)];
422 const auto C = state[y * 5 + ((x + 2) % 5)];
423
424 // vv should cost 1 gate
425 lane_outputs[x] = (A + A + CHI_OFFSET).add_two(-B, C);
426 }
427 for (size_t x = 0; x < 5; ++x) {
428 // Normalize lane outputs and assign to internal.state
429 auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_CHI_OUTPUT, lane_outputs[x]);
430 internal.state[y * 5 + x] = accumulators[ColumnIdx::C2][0];
431 internal.state_msb[y * 5 + x] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
432 }
433 }
434}
435
445template <typename Builder> void keccak<Builder>::iota(keccak_state& internal, size_t round)
446{
447 const field_ct xor_result = internal.state[0] + SPARSE_RC[round];
448
449 // normalize lane value so that we don't overflow our base11 modulus boundary in the next round
450 internal.state[0] = normalize_and_rotate<0>(xor_result, internal.state_msb[0]);
451
452 // No need to add constraints to compute twisted repr if this is the last round
453 if (round != NUM_KECCAK_ROUNDS - 1) {
454 compute_twisted_state(internal);
455 }
456}
457
458template <typename Builder> void keccak<Builder>::keccakf1600(keccak_state& internal)
459{
460 for (size_t i = 0; i < NUM_KECCAK_ROUNDS; ++i) {
461 theta(internal);
462 rho(internal);
463 pi(internal);
464 chi(internal);
465 iota(internal, i);
466 }
467}
468
469// Returns the keccak f1600 permutation of the input state
470// We first convert the state into 'extended' representation, along with the 'twisted' state
471// and then we call keccakf1600() with this keccak 'internal state'
472// Finally, we convert back the state from the extended representation
// NOTE(review): the line carrying the return type and function name (listing line 474) is missing from this
// extract; the listing's cross-reference names this function permutation_opcode — confirm.
473template <typename Builder>
475 std::array<field_t<Builder>, NUM_KECCAK_LANES> state, Builder* ctx)
476{
477 // populate keccak_state, convert our KECCAK_LANE_SIZE-bit lanes into an extended base-11 representation
478 keccak_state internal;
479 internal.context = ctx;
480 for (size_t i = 0; i < state.size(); ++i) {
// The KECCAK_FORMAT_INPUT lookup yields the converted lane (first entry of column C2)
// and its msb (last entry of column C3).
481 const auto accumulators = plookup_read<Builder>::get_lookup_accumulators(KECCAK_FORMAT_INPUT, state[i]);
482 internal.state[i] = accumulators[ColumnIdx::C2][0];
483 internal.state_msb[i] = accumulators[ColumnIdx::C3][accumulators[ColumnIdx::C3].size() - 1];
484 }
// the twisted state must be populated before the first round runs
485 compute_twisted_state(internal);
486 keccakf1600(internal);
487 // we convert back to the normal lanes
488 return extended_2_normal(internal);
489}
490
491// Convert the 'extended' representation of the internal Keccak state into the usual array of KECCAK_LANE_SIZE bit lanes
// Returns one field element per entry of internal.state, in the same order.
// NOTE(review): the line carrying the return type and function name (listing line 493) is missing from this
// extract; the listing's cross-reference names this function extended_2_normal — confirm.
492template <typename Builder>
494 keccak_state& internal)
495{
496 std::array<field_t<Builder>, NUM_KECCAK_LANES> conversion;
497
498 // Each hash limb represents a little-endian integer.
499 for (size_t i = 0; i < internal.state.size(); ++i) {
// NOTE(review): listing line 500, which defines `output_limb`, is missing from this extract; presumably it is
// a KECCAK_FORMAT_OUTPUT lookup on internal.state[i] — confirm.
501 conversion[i] = output_limb;
502 }
503
504 return conversion;
505}
506
508template class keccak<bb::MegaCircuitBuilder>;
509
510} // namespace bb::stdlib
constexpr uint256_t pow(const uint256_t &exponent) const
constexpr std::pair< uint256_t, uint256_t > divmod(const uint256_t &b) const
Container for lookup accumulator values and table reads.
Definition types.hpp:360
std::vector< BasicTable::LookupEntry > lookup_entries
Definition types.hpp:366
Generate the plookup tables used for the RHO round of the Keccak hash algorithm.
static constexpr size_t TABLE_BITS
static field_t from_witness_index(Builder *ctx, uint32_t witness_index)
Definition field.cpp:67
static field_t accumulate(const std::vector< field_t > &input)
Efficiently compute the sum of vector entries. Using big_add_gate we reduce the number of gates neede...
Definition field.cpp:1180
Builder * context
Definition field.hpp:57
Builder * get_context() const
Definition field.hpp:432
bb::fr get_value() const
Given a := *this, compute its value given by a.v * a.mul + a.add.
Definition field.cpp:838
uint32_t get_witness_index() const
Get the witness index of the current field element.
Definition field.hpp:519
static void rho(keccak_state &state)
RHO round.
Definition keccak.cpp:361
static void pi(keccak_state &state)
PI.
Definition keccak.cpp:376
static void theta(keccak_state &state)
THETA round.
Definition keccak.cpp:251
static void compute_twisted_state(keccak_state &internal)
Compute twisted representation of hash lane.
Definition keccak.cpp:197
static void chi(keccak_state &state)
CHI.
Definition keccak.cpp:412
static field_t< Builder > normalize_and_rotate(const field_ct &limb, field_ct &msb)
Normalize a base-11 limb and left-rotate by keccak::ROTATIONS[lane_index] bits. This method also extr...
Definition keccak.cpp:37
static std::array< field_ct, NUM_KECCAK_LANES > permutation_opcode(std::array< field_ct, NUM_KECCAK_LANES > state, Builder *context)
Definition keccak.cpp:474
static std::array< field_ct, NUM_KECCAK_LANES > extended_2_normal(keccak_state &internal)
Definition keccak.cpp:493
static void keccakf1600(keccak_state &state)
Definition keccak.cpp:458
static void iota(keccak_state &state, size_t round)
IOTA.
Definition keccak.cpp:445
static plookup::ReadData< field_pt > get_lookup_accumulators(const plookup::MultiTableId id, const field_pt &key_a, const field_pt &key_b=0, const bool is_2_to_1_lookup=false)
Definition plookup.cpp:19
static field_pt read_from_1_to_2_table(const plookup::MultiTableId id, const field_pt &key_a)
Definition plookup.cpp:93
bb::avm2::Column C
stdlib::witness_t< Builder > witness_ct
constexpr uint64_t pow64(const uint64_t input, const uint64_t exponent)
Definition pow.hpp:13
@ KECCAK_FORMAT_INPUT
Definition types.hpp:131
@ KECCAK_FORMAT_OUTPUT
Definition types.hpp:132
@ KECCAK_NORMALIZE_AND_ROTATE
Definition types.hpp:133
@ KECCAK_CHI_OUTPUT
Definition types.hpp:130
@ KECCAK_THETA_OUTPUT
Definition types.hpp:129
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::array< field_ct, NUM_KECCAK_LANES > state
Definition keccak.hpp:148
std::array< field_ct, NUM_KECCAK_LANES > twisted_state
Definition keccak.hpp:150
std::array< field_ct, NUM_KECCAK_LANES > state_msb
Definition keccak.hpp:149