Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
msm_builder.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Raju], commit: 2a49eb6 }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#pragma once
8
9#include <cstddef>
10
15
16namespace bb {
17
19 public:
22 using Element = typename CycleGroup::element;
23 using AffineElement = typename CycleGroup::affine_element;
25
26 static constexpr size_t ADDITIONS_PER_ROW = bb::eccvm::ADDITIONS_PER_ROW;
27 static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR = bb::eccvm::NUM_WNAF_DIGITS_PER_SCALAR;
28
29 struct alignas(64) MSMRow { // one row of the Straus MSM section of the ECCVM execution trace
30 uint32_t pc = 0; // decreasing point-counter, over all half-length (128 bit) scalar muls used to compute
31 // the required MSMs. however, this value is _constant_ on a given MSM: it equals
32 // `total_number_of_muls` minus the number of scalar muls contained in the MSMs that
33 // were processed before the current MSM started.
34 uint32_t msm_size = 0; // the number of points in (a.k.a. the length of) the MSM in whose computation
35 // this VM row participates
36 uint32_t msm_count = 0; // number of multiplications processed so far (not including this row) in current MSM
37 // round (a.k.a. wNAF digit slot). this specifically refers to the number of wNAF-digit
38 // * point scalar products we have looked up and accumulated.
39 uint32_t msm_round = 0; // current "round" of MSM, in {0, ..., 32 = `NUM_WNAF_DIGITS_PER_SCALAR`}. With the
40 // Straus algorithm, we proceed wNAF digit by wNAF digit, from left to right. (final
41 // round deals with the `skew` bit.)
42 bool msm_transition = false; // is 1 if the current row *starts* the processing of a different MSM, else 0.
43 bool q_add = false; // set on rows that add (up to 4) looked-up wNAF-digit multiples into the accumulator
44 bool q_double = false; // set on rows that perform the doublings between two consecutive wNAF digit slots
45 bool q_skew = false; // set on rows that process the final `skew` digit of the scalars
46
47 // Each row in the MSM portion of the ECCVM can handle (up to) 4 point-additions.
48 // For each row in the VM we represent the point addition data via a size-4 array of
49 // AddState objects.
50 struct AddState {
51 bool add = false; // are we adding a point at this location in the VM?
52 // e.g if the MSM is of size-2 then the 3rd and 4th AddState objects will have this set
53 // to `false`.
54 int slice = 0; // compressed wNAF slice value. This has values in {0, ..., 15} and corresponds to an odd
55 // number `d` in the range {-15, -13, ..., 15} via the monotonic bijection d -> (d + 15) / 2.
56 AffineElement point{ 0, 0 }; // point being added into the accumulator. (This is of the form nP,
57 // where n is in {-15, -13, ..., 15}.)
58 FF lambda = 0; // when adding `point` into the accumulator via Affine point addition, the value of `lambda`
59 // (i.e., the slope of the line). (we need this as a witness in the circuit.)
60 FF collision_inverse = 0; // `collision_inverse` is used to validate we are not hitting point addition edge
61 // case exceptions, i.e., we want the VM proof to fail if we're doing a point
62 // addition where (x1 == x2). to do this, we simply provide an inverse to x1 - x2.
63 };
64 std::array<AddState, 4> add_state{ AddState{ false, 0, { 0, 0 }, 0, 0 },
65 AddState{ false, 0, { 0, 0 }, 0, 0 },
66 AddState{ false, 0, { 0, 0 }, 0, 0 },
67 AddState{ false, 0, { 0, 0 }, 0, 0 } };
68 // The accumulator here is, in general, the result of four EC additions: A + Q_1 + Q_2 + Q_3 + Q_4.
69 // We do not explicitly store the intermediate values A + Q_1, A + Q_1 + Q_2, and A + Q_1 + Q_2 + Q_3, although
70 // these values are implicitly used in the values of `AddState.lambda` and `AddState.collision_inverse`.
71
72 FF accumulator_x = 0; // x-coordinate of the accumulator `(accumulator_x, accumulator_y)`, i.e. of the point
73 // to which the points in `add_state` are potentially added.
74 FF accumulator_y = 0; // y-coordinate of the accumulator `(accumulator_x, accumulator_y)`, i.e. of the point
75 // to which the points in `add_state` are potentially added.
76 };
77
91 const std::vector<MSM>& msms, const uint32_t total_number_of_muls, const size_t num_msm_rows)
92 {
93 // To perform a scalar multiplication of a point P by a scalar x, we precompute a table of points
94 // -15P, -13P, ..., -3P, -P, P, 3P, ..., 15P
95 // When we perform a scalar multiplication, we decompose x into base-16 wNAF digits then look these precomputed
96 // values up digit-by-digit. As we are performing lookups with the log-derivative argument, we have to
97 // record read counts. We record read counts in a table with the following structure:
98 // 1st write column = positive wNAF digits
99 // 2nd write column = negative wNAF digits
100 // the row number is a function of pc and wnaf digit:
101 // point_idx = total_number_of_muls - pc
102 // row = point_idx * rows_per_point_table + (some function of the slice value)
103 //
104 // Illustration:
105 // Block Structure:
106 // | 0 | 1 |
107 // | - | - |
108 // 1 | # | # | -1
109 // 3 | # | # | -3
110 // 5 | # | # | -5
111 // 7 | # | # | -7
112 // 9 | # | # | -9
113 // 11 | # | # | -11
114 // 13 | # | # | -13
115 // 15 | # | # | -15
116 //
117 // Table structure:
118 // | Block_{0} | <-- pc = total_number_of_muls
119 // | Block_{1} | <-- pc = total_number_of_muls-(num muls in msm 0)
120 // | ... | ...
121 // | Block_{total_number_of_muls-1} | <-- pc = num muls in last msm
122
123 const size_t num_rows_in_read_counts_table =
124 static_cast<size_t>(total_number_of_muls) *
125 (eccvm::POINT_TABLE_SIZE >> 1); // `POINT_TABLE_SIZE` is 2ʷ, where in our case w = 4. As noted above, with
126 // respect to *read counts*, we are recording lookups of the positive and
127 // negative odd multiples of [P] in two separate columns, each of size 2ʷ⁻¹.
128 std::array<std::vector<size_t>, 2> point_table_read_counts;
129 point_table_read_counts[0].reserve(num_rows_in_read_counts_table);
130 point_table_read_counts[1].reserve(num_rows_in_read_counts_table);
131 for (size_t i = 0; i < num_rows_in_read_counts_table; ++i) {
132 point_table_read_counts[0].emplace_back(0);
133 point_table_read_counts[1].emplace_back(0);
134 }
135
136 const auto update_read_count = [&point_table_read_counts](const size_t point_idx, const int slice) {
146 const size_t row_index_offset = point_idx * 8;
147 if (slice < 0) {
148 // negative table: T[0] = -15P, T[1] = -13P, ..., T[7] = -P
149 const auto table_index = static_cast<size_t>((slice + 15) / 2);
150 point_table_read_counts[1][row_index_offset + table_index]++;
151 } else {
152 // positive table: T[0] = 15P, T[1] = 13P, ..., T[7] = P
153 const auto table_index = static_cast<size_t>((15 - slice) / 2);
154 point_table_read_counts[0][row_index_offset + table_index]++;
155 }
156 };
157
158 // compute which row index each multiscalar multiplication will start at.
159 std::vector<size_t> msm_row_counts;
160 msm_row_counts.reserve(msms.size() + 1);
161 msm_row_counts.push_back(1);
162 // compute the point counter (i.e. the index among all single scalar muls) that each multiscalar
163 // multiplication will start at.
164 std::vector<size_t> pc_values;
165 pc_values.reserve(msms.size() + 1);
166 pc_values.push_back(total_number_of_muls);
167 for (const auto& msm : msms) {
168 const size_t num_rows_required = EccvmRowTracker::num_eccvm_msm_rows(msm.size());
169 msm_row_counts.push_back(msm_row_counts.back() + num_rows_required);
170 pc_values.push_back(pc_values.back() - msm.size());
171 }
172 BB_ASSERT_EQ(pc_values.back(), 0U);
173
174 // compute the MSM rows
175
176 std::vector<MSMRow> msm_rows(num_msm_rows);
177 // start with empty row (shiftable polynomials must have 0 as first coefficient)
178 msm_rows[0] = (MSMRow{});
179 // compute "read counts" so that we can determine the number of times entries in our log-derivative lookup
180 // tables are called.
181 // Note: this part is single-threaded. The amount of compute is low, however, so this is likely not a big
182 // concern.
183 for (size_t msm_idx = 0; msm_idx < msms.size(); ++msm_idx) {
184 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
185 auto pc = static_cast<uint32_t>(pc_values[msm_idx]);
186 const auto& msm = msms[msm_idx];
187 const size_t msm_size = msm.size();
188 const size_t num_rows_per_digit =
189 (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0);
190
191 for (size_t relative_row_idx = 0; relative_row_idx < num_rows_per_digit; ++relative_row_idx) {
192 const size_t num_points_in_row = (relative_row_idx + 1) * ADDITIONS_PER_ROW > msm_size
193 ? (msm_size % ADDITIONS_PER_ROW)
195 const size_t offset = relative_row_idx * ADDITIONS_PER_ROW;
196 for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW; ++relative_point_idx) {
197 const size_t point_idx = offset + relative_point_idx;
198 const bool add = num_points_in_row > relative_point_idx;
199 if (add) {
200 int slice = msm[point_idx].wnaf_digits[digit_idx];
201 // pc starts at total_number_of_muls and decreases non-uniformly to 0
202 update_read_count((total_number_of_muls - pc) + point_idx, slice);
203 }
204 }
205 }
206
207 // update the log-derivative read count for the lookup associated with WNAF skew
208 if (digit_idx == NUM_WNAF_DIGITS_PER_SCALAR - 1) {
209 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
210 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
211 ? (msm_size % ADDITIONS_PER_ROW)
213 const size_t offset = row_idx * ADDITIONS_PER_ROW;
214 for (size_t relative_point_idx = 0; relative_point_idx < ADDITIONS_PER_ROW;
215 ++relative_point_idx) {
216 bool add = num_points_in_row > relative_point_idx;
217 const size_t point_idx = offset + relative_point_idx;
218 if (add) {
219 // `pc` starts at total_number_of_muls and decreases non-uniformly to 0.
220 // -15 maps to the 1st point in the lookup table (array element 0)
221 // -1 maps to the point in the lookup table that corresponds to the negation of the
222 // original input point (i.e. the point we need to add into the accumulator if wnaf_skew
223 // is positive)
224 int slice = msm[point_idx].wnaf_skew ? -1 : -15;
225 update_read_count((total_number_of_muls - pc) + point_idx, slice);
226 }
227 }
228 }
229 }
230 }
231 }
232
233 // The execution trace data for the MSM columns requires knowledge of intermediate values from *affine* point
234 // addition. The naive solution to compute this data requires 2 field inversions per in-circuit group addition
235 // evaluation. This is bad! To avoid this, we split the witness computation algorithm into 3 steps.
236 // Step 1: compute the execution trace group operations in *projective* coordinates. (these will be stored in
237 // `p1_trace`, `p2_trace`, and `p3_trace`)
238 // Step 2: use batch inversion trick to convert all points into affine coordinates
239 // Step 3: populate the full execution trace, including the intermediate values from affine group
240 // operations
241 // This section sets up the data structures we need to store all intermediate ECC operations in projective form
242
243 const size_t num_point_adds_and_doubles =
244 (num_msm_rows - 2) * 4; // `num_msm_rows - 2` is the actual number of rows in the table required to compute
245 // the MSM; the msm table itself has a dummy row at the beginning and an extra row
246 // with the `x` and `y` coordinates of the accumulator at the end. (In general, the
247 // output of the accumulator from the computation at row `i` is present on row
248 // `i+1`.) We multiply by 4 because each "row" of the VM processes 4 point-additions
249 // (and the fact that w = 4 means we must interleave with 4 doublings). This
250 // "corresponds" to the fact that `MSMRow.add_state` has 4 entries.
251 const size_t num_accumulators = num_msm_rows - 1; // for every row after the first row, we have an accumulator.
252 // In what follows, either p1 + p2 = p3, or p1.dbl() = p3
253 // We create 1 vector to store the entire point trace. We split into multiple containers using std::span
254 // (we want 1 vector object to more efficiently batch-normalize points)
255 static constexpr size_t NUM_POINTS_IN_ADDITION_RELATION = 3;
256 const size_t num_points_to_normalize =
257 (num_point_adds_and_doubles * NUM_POINTS_IN_ADDITION_RELATION) + num_accumulators;
258 std::vector<Element> points_to_normalize(num_points_to_normalize);
259 std::span<Element> p1_trace(&points_to_normalize[0], num_point_adds_and_doubles);
260 std::span<Element> p2_trace(&points_to_normalize[num_point_adds_and_doubles], num_point_adds_and_doubles);
261 std::span<Element> p3_trace(&points_to_normalize[num_point_adds_and_doubles * 2], num_point_adds_and_doubles);
262 // `is_double_or_add` records whether an entry in the p1/p2/p3 trace represents a point addition or
263 // doubling. if it is `true`, then we are doubling (i.e., the condition is that `p3 = p1.dbl()`), else we are
264 // adding (i.e., the condition is that `p3 = p1 + p2`).
265 std::vector<bool> is_double_or_add(num_point_adds_and_doubles);
266 // accumulator_trace tracks the value of the ECCVM accumulator for each row
267 std::span<Element> accumulator_trace(&points_to_normalize[num_point_adds_and_doubles * 3], num_accumulators);
268
269 // we start the accumulator at the offset generator point
270 constexpr auto offset_generator = get_precomputed_generators<g1, "ECCVM_OFFSET_GENERATOR", 1>()[0];
271 accumulator_trace[0] = offset_generator;
272
273 // TODO(https://github.com/AztecProtocol/barretenberg/issues/973): Reinstate multithreading?
274 // populate point trace, and the components of the MSM execution trace that do not relate to affine point
275 // operations
276 for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
277 Element accumulator = offset_generator; // for every MSM, we start with the same `offset_generator`
278 const auto& msm = msms[msm_idx]; // which MSM we are processing. This is of type `std::vector<ScalarMul>`.
279 size_t msm_row_index = msm_row_counts[msm_idx]; // the row where the given MSM starts
280 const size_t msm_size = msm.size();
281 const size_t num_rows_per_digit =
282 (msm_size / ADDITIONS_PER_ROW) +
283 (msm_size % ADDITIONS_PER_ROW !=
284 0); // the Straus algorithm proceeds by incrementing through the digit-slots and doing
285 // computations *across* the `ScalarMul`s that make up our MSM. Each digit-slot therefore
286 // contributes the *ceiling* of `msm_size`/`ADDITIONS_PER_ROW`.
287 size_t trace_index =
288 (msm_row_counts[msm_idx] - 1) * 4; // tracks the index in the traces of `p1`, `p2`, `p3`, and
289 // `accumulator_trace` that we are filling out
290
291 // for each digit-slot (`digit_idx`), and then for each row of the VM (which does `ADDITIONS_PER_ROW` point
292 // additions), we either enter in/process (`ADDITIONS_PER_ROW`) `AddState` objects, and then if necessary
293 // (i.e., if not at the last wNAF digit), process the four doublings.
294 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
295 const auto pc = static_cast<uint32_t>(pc_values[msm_idx]); // pc that our msm starts at
296
297 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
298 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
299 ? (msm_size % ADDITIONS_PER_ROW)
301 auto& row = msm_rows[msm_row_index]; // actual `MSMRow` we will fill out in the body of this loop
302 const size_t offset = row_idx * ADDITIONS_PER_ROW;
303 row.msm_transition = (digit_idx == 0) && (row_idx == 0);
304 // each iteration of this loop processes/enters in one of the `AddState` objects in `row.add_state`.
305 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
306 auto& add_state = row.add_state[point_idx];
307 add_state.add = num_points_in_row > point_idx;
308 int slice = add_state.add ? msm[offset + point_idx].wnaf_digits[digit_idx] : 0;
309 // In the MSM columns in the ECCVM circuit, we can add up to 4 points per row.
310 // If `row.add_state[point_idx].add = 1`, this indicates that we want to add the
311 // `point_idx`'th point in the MSM columns into the MSM accumulator.
312 // `add_state.slice` is a 4-bit WNAF slice of the scalar multiplier associated with the point we
313 // are adding (the specific slice chosen depends on the value of msm_round). (WNAF = our version
314 // of windowed-non-adjacent-form. Value range is `-15, -13, ..., 15`.)
315 // If `add_state.add = 1`, we want `add_state.slice` to be the *compressed*
316 // form of the WNAF slice value. (compressed = no gaps in the value range, i.e. -15,
317 // -13, ..., 15 maps to 0, ..., 15.)
318 add_state.slice = add_state.add ? (slice + 15) / 2 : 0;
319 add_state.point =
320 add_state.add
321 ? msm[offset + point_idx].precomputed_table[static_cast<size_t>(add_state.slice)]
322 : AffineElement{ 0, 0 };
323
324 Element p1(accumulator);
325 Element p2(add_state.point);
326 accumulator = add_state.add ? (accumulator + add_state.point) : Element(p1);
327 p1_trace[trace_index] = p1;
328 p2_trace[trace_index] = p2;
329 p3_trace[trace_index] = accumulator;
330 is_double_or_add[trace_index] = false;
331 trace_index++;
332 }
333 // Now, `row.add_state` has been fully processed and we fill in the rest of the members of `row`.
334 accumulator_trace[msm_row_index] = accumulator;
335 row.q_add = true;
336 row.q_double = false;
337 row.q_skew = false;
338 row.msm_round = static_cast<uint32_t>(digit_idx);
339 row.msm_size = static_cast<uint32_t>(msm_size);
340 row.msm_count = static_cast<uint32_t>(offset);
341 row.pc = pc;
342 msm_row_index++;
343 }
344 // after processing each digit-slot, we now take care of doubling (as long as we are not at the last
345 // digit). We add an `MSMRow`, `row`, whose four `AddState` objects in `row.add_state`
346 // are null, but we also populate `p1_trace`, `p2_trace`, `p3_trace`, and `is_double_or_add` for four
347 // indices, corresponding to the w=4 doubling operations we need to perform. This embodies the numerical
348 // "coincidence" that `ADDITIONS_PER_ROW == NUM_WNAF_DIGIT_BITS`
349 if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) {
350 auto& row = msm_rows[msm_row_index];
351 row.msm_transition = false;
352 row.msm_round = static_cast<uint32_t>(digit_idx + 1);
353 row.msm_size = static_cast<uint32_t>(msm_size);
354 row.msm_count = static_cast<uint32_t>(0);
355 row.q_add = false;
356 row.q_double = true;
357 row.q_skew = false;
358 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
359 auto& add_state = row.add_state[point_idx];
360 add_state.add = false;
361 add_state.slice = 0;
362 add_state.point = { 0, 0 };
363 add_state.collision_inverse = 0;
364
365 p1_trace[trace_index] = accumulator;
366 p2_trace[trace_index] = accumulator; // dummy
367 accumulator = accumulator.dbl();
368 p3_trace[trace_index] = accumulator;
369 is_double_or_add[trace_index] = true;
370 trace_index++;
371 }
372 accumulator_trace[msm_row_index] = accumulator;
373 msm_row_index++;
374 } else // process `wnaf_skew`, i.e., the skew digit.
375 {
376 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
377 auto& row = msm_rows[msm_row_index];
378
379 const size_t num_points_in_row = (row_idx + 1) * ADDITIONS_PER_ROW > msm_size
380 ? msm_size % ADDITIONS_PER_ROW
382 const size_t offset = row_idx * ADDITIONS_PER_ROW;
383 row.msm_transition = false;
384 Element acc_expected = accumulator;
385 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
386 auto& add_state = row.add_state[point_idx];
387 add_state.add = num_points_in_row > point_idx;
388 add_state.slice = add_state.add ? msm[offset + point_idx].wnaf_skew ? 7 : 0 : 0;
389
390 add_state.point =
391 add_state.add
392 ? msm[offset + point_idx].precomputed_table[static_cast<size_t>(add_state.slice)]
394 0, 0
395 }; // if the skew_bit is on, `slice == 7`. Then `precomputed_table[7] == -[P]`, as
396 // required for the skew logic.
397 bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false;
398 auto p1 = accumulator;
399 accumulator = add_predicate ? accumulator + add_state.point : accumulator;
400 p1_trace[trace_index] = p1;
401 p2_trace[trace_index] = add_state.point;
402 p3_trace[trace_index] = accumulator;
403 is_double_or_add[trace_index] = false;
404 trace_index++;
405 }
406 row.q_add = false;
407 row.q_double = false;
408 row.q_skew = true;
409 row.msm_round = static_cast<uint32_t>(digit_idx + 1);
410 row.msm_size = static_cast<uint32_t>(msm_size);
411 row.msm_count = static_cast<uint32_t>(offset);
412 row.pc = pc;
413 accumulator_trace[msm_row_index] = accumulator;
414 msm_row_index++;
415 }
416 }
417 }
418 }
419
420 // Normalize the points in the point trace
421 parallel_for_range(points_to_normalize.size(), [&](size_t start, size_t end) {
422 Element::batch_normalize(&points_to_normalize[start], end - start);
423 });
424
425 // inverse_trace is used to compute the value of the `collision_inverse` column in the ECCVM.
426 std::vector<FF> inverse_trace(num_point_adds_and_doubles);
427 if (num_point_adds_and_doubles > 0) {
428 parallel_for_range(num_point_adds_and_doubles, [&](size_t start, size_t end) {
429 for (size_t operation_idx = start; operation_idx < end; ++operation_idx) {
430 if (is_double_or_add[operation_idx]) {
431 inverse_trace[operation_idx] = (p1_trace[operation_idx].y + p1_trace[operation_idx].y);
432 } else {
433 inverse_trace[operation_idx] = (p2_trace[operation_idx].x - p1_trace[operation_idx].x);
434 }
435 }
436 FF::batch_invert(&inverse_trace[start], end - start);
437 });
438 }
439
440 // complete the computation of the ECCVM execution trace, by adding the affine intermediate point data
441 // i.e. row.accumulator_x, row.accumulator_y, row.add_state[0...3].collision_inverse,
442 // row.add_state[0...3].lambda
443 for (size_t msm_idx = 0; msm_idx < msms.size(); msm_idx++) {
444 const auto& msm = msms[msm_idx];
445 size_t trace_index = ((msm_row_counts[msm_idx] - 1) * ADDITIONS_PER_ROW);
446 size_t msm_row_index = msm_row_counts[msm_idx];
447 // 1st MSM row will have accumulator equal to the previous MSM output (or the offset generator for first MSM)
448 size_t accumulator_index = msm_row_counts[msm_idx] - 1;
449 const size_t msm_size = msm.size();
450 const size_t num_rows_per_digit =
451 (msm_size / ADDITIONS_PER_ROW) + ((msm_size % ADDITIONS_PER_ROW != 0) ? 1 : 0);
452
453 for (size_t digit_idx = 0; digit_idx < NUM_WNAF_DIGITS_PER_SCALAR; ++digit_idx) {
454 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
455 auto& row = msm_rows[msm_row_index];
456 // note that we do not store the "intermediate accumulators" that are implicit *within* a row (i.e.,
457 // within a given `add_state` object). This is the reason why accumulator_index only increments once
458 // per `row_idx`.
459 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
460 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
461 row.accumulator_x = normalized_accumulator.x;
462 row.accumulator_y = normalized_accumulator.y;
463 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
464 auto& add_state = row.add_state[point_idx];
465
466 const auto& inverse = inverse_trace[trace_index];
467 const auto& p1 = p1_trace[trace_index];
468 const auto& p2 = p2_trace[trace_index];
469 add_state.collision_inverse = add_state.add ? inverse : 0;
470 add_state.lambda = add_state.add ? (p2.y - p1.y) * inverse : 0;
471 trace_index++;
472 }
473 accumulator_index++;
474 msm_row_index++;
475 }
476
477 // if digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1 we have to fill out our doubling row (which in fact
478 // amounts to 4 doublings)
479 if (digit_idx < NUM_WNAF_DIGITS_PER_SCALAR - 1) {
480 MSMRow& row = msm_rows[msm_row_index];
481 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
482 const FF& acc_x = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.x;
483 const FF& acc_y = normalized_accumulator.is_point_at_infinity() ? 0 : normalized_accumulator.y;
484 row.accumulator_x = acc_x;
485 row.accumulator_y = acc_y;
486 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
487 auto& add_state = row.add_state[point_idx];
488 add_state.collision_inverse = 0; // no notion of "different x values" for a point doubling
489 const FF& dx = p1_trace[trace_index].x;
490 const FF& inverse = inverse_trace[trace_index]; // here, 2y
491 add_state.lambda = ((dx + dx + dx) * dx) * inverse;
492 trace_index++;
493 }
494 accumulator_index++;
495 msm_row_index++;
496 } else // this row corresponds to performing point additions to handle WNAF skew
497 // i.e. iterate over all the points in the MSM - if for a given point, `wnaf_skew == 1`,
498 // subtract the original point from the accumulator. if `digit_idx == NUM_WNAF_DIGITS_PER_SCALAR
499 // - 1` we have finished executing our double-and-add algorithm.
500 {
501 for (size_t row_idx = 0; row_idx < num_rows_per_digit; ++row_idx) {
502 MSMRow& row = msm_rows[msm_row_index];
503 const Element& normalized_accumulator = accumulator_trace[accumulator_index];
504 BB_ASSERT_EQ(normalized_accumulator.is_point_at_infinity(), 0);
505 const size_t offset = row_idx * ADDITIONS_PER_ROW;
506 row.accumulator_x = normalized_accumulator.x;
507 row.accumulator_y = normalized_accumulator.y;
508 for (size_t point_idx = 0; point_idx < ADDITIONS_PER_ROW; ++point_idx) {
509 auto& add_state = row.add_state[point_idx];
510 bool add_predicate = add_state.add ? msm[offset + point_idx].wnaf_skew : false;
511
512 const auto& inverse = inverse_trace[trace_index];
513 const auto& p1 = p1_trace[trace_index];
514 const auto& p2 = p2_trace[trace_index];
515 add_state.collision_inverse = add_predicate ? inverse : 0;
516 add_state.lambda = add_predicate ? (p2.y - p1.y) * inverse : 0;
517 trace_index++;
518 }
519 accumulator_index++;
520 msm_row_index++;
521 }
522 }
523 }
524 }
525
526 // populate the final row in the MSM execution trace.
527 // we always require 1 extra row at the end of the trace, because the x and y coordinates of the accumulator for
528 // row `i` are present at row `i+1`
529 Element final_accumulator(accumulator_trace.back());
530 MSMRow& final_row = msm_rows.back();
531 final_row.pc = static_cast<uint32_t>(pc_values.back());
532 final_row.msm_transition = true;
533 final_row.accumulator_x = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.x;
534 final_row.accumulator_y = final_accumulator.is_point_at_infinity() ? 0 : final_accumulator.y;
535 final_row.msm_size = 0;
536 final_row.msm_count = 0;
537 final_row.q_add = false;
538 final_row.q_double = false;
539 final_row.q_skew = false;
540 final_row.add_state = { typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
541 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
542 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 },
543 typename MSMRow::AddState{ false, 0, AffineElement{ 0, 0 }, 0, 0 } };
544
545 return { msm_rows, point_table_read_counts };
546 }
547};
548} // namespace bb
#define BB_ASSERT_EQ(actual, expected,...)
Definition assert.hpp:83
static constexpr size_t ADDITIONS_PER_ROW
static constexpr size_t NUM_WNAF_DIGITS_PER_SCALAR
static std::tuple< std::vector< MSMRow >, std::array< std::vector< size_t >, 2 > > compute_rows(const std::vector< MSM > &msms, const uint32_t total_number_of_muls, const size_t num_msm_rows)
Computes the row values for the Straus MSM columns of the ECCVM.
curve::BN254::Group CycleGroup
typename CycleGroup::affine_element AffineElement
bb::eccvm::MSM< CycleGroup > MSM
typename CycleGroup::element Element
static uint32_t num_eccvm_msm_rows(const size_t msm_size)
Get the number of rows in the 'msm' column section of the ECCVM associated with a single multiscalar ...
typename bb::g1 Group
Definition bn254.hpp:20
ssize_t offset
Definition engine.cpp:62
std::vector< ScalarMul< CycleGroup > > MSM
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
group< fq, fr, Bn254G1Params > g1
Definition g1.hpp:34
C slice(C const &container, size_t start)
Definition container.hpp:9
constexpr std::span< const typename Group::affine_element > get_precomputed_generators()
void parallel_for_range(size_t num_points, const std::function< void(size_t, size_t)> &func, size_t no_multhreading_if_less_or_equal)
Split a loop into several loops running in parallel.
Definition thread.cpp:141
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
std::array< AddState, 4 > add_state
static void batch_invert(C &coeffs) noexcept
Batch invert a collection of field elements using Montgomery's trick.