Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
field_impl_generic.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Completed, auditors: [Raju], commit: }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#pragma once
8
9#include <array>
10#include <cstdint>
11
12#include "./field_impl.hpp"
14
15namespace bb {
16
17// NOLINTBEGIN(readability-implicit-bool-conversion)
18template <class T>
19constexpr std::pair<uint64_t, uint64_t> field<T>::mul_wide([[maybe_unused]] uint64_t a,
20 [[maybe_unused]] uint64_t b) noexcept
21{
22#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
23 const uint128_t res = (static_cast<uint128_t>(a) * static_cast<uint128_t>(b));
24 return { static_cast<uint64_t>(res), static_cast<uint64_t>(res >> 64) };
25#else
26 static_assert(false, "mul_wide is not implemented for WASM");
27 return { 0, 0 };
28#endif
29}
33template <class T>
34constexpr uint64_t field<T>::mac([[maybe_unused]] const uint64_t a,
35 [[maybe_unused]] const uint64_t b,
36 [[maybe_unused]] const uint64_t c,
37 [[maybe_unused]] const uint64_t carry_in,
38 [[maybe_unused]] uint64_t& carry_out) noexcept
39{
40#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
41 const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c)) +
42 static_cast<uint128_t>(carry_in);
43 carry_out = static_cast<uint64_t>(res >> 64);
44 return static_cast<uint64_t>(res);
45#else
46 static_assert(false, "mac is not implemented for WASM");
47 return 0;
48#endif
49}
50
55template <class T>
56constexpr void field<T>::mac([[maybe_unused]] const uint64_t a,
57 [[maybe_unused]] const uint64_t b,
58 [[maybe_unused]] const uint64_t c,
59 [[maybe_unused]] const uint64_t carry_in,
60 [[maybe_unused]] uint64_t& out,
61 [[maybe_unused]] uint64_t& carry_out) noexcept
62{
63#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
64 const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c)) +
65 static_cast<uint128_t>(carry_in);
66 out = static_cast<uint64_t>(res);
67 carry_out = static_cast<uint64_t>(res >> 64);
68#else
69 static_assert(false, "mac is not implemented for WASM");
70#endif
71}
72
73template <class T>
74constexpr uint64_t field<T>::mac_mini([[maybe_unused]] const uint64_t a,
75 [[maybe_unused]] const uint64_t b,
76 [[maybe_unused]] const uint64_t c,
77 [[maybe_unused]] uint64_t& carry_out) noexcept
78{
79#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
80 const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
81 carry_out = static_cast<uint64_t>(res >> 64);
82 return static_cast<uint64_t>(res);
83#else
84 static_assert(false, "mac is not implemented for WASM");
85 return 0;
86#endif
87}
88
89template <class T>
90constexpr void field<T>::mac_mini([[maybe_unused]] const uint64_t a,
91 [[maybe_unused]] const uint64_t b,
92 [[maybe_unused]] const uint64_t c,
93 [[maybe_unused]] uint64_t& out,
94 [[maybe_unused]] uint64_t& carry_out) noexcept
95{
96#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
97 const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
98 out = static_cast<uint64_t>(res);
99 carry_out = static_cast<uint64_t>(res >> 64);
100#else
101 static_assert(false, "mac_mini is not implemented for WASM");
102#endif
103}
104
105template <class T>
106constexpr uint64_t field<T>::mac_discard_lo([[maybe_unused]] const uint64_t a,
107 [[maybe_unused]] const uint64_t b,
108 [[maybe_unused]] const uint64_t c) noexcept
109{
110#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
111 const uint128_t res = static_cast<uint128_t>(a) + (static_cast<uint128_t>(b) * static_cast<uint128_t>(c));
112 return static_cast<uint64_t>(res >> 64);
113#else
114 static_assert(false, "mac_discord_lo is not implemented for WASM");
115 return 0;
116#endif
117}
124template <class T>
125constexpr uint64_t field<T>::addc(const uint64_t a,
126 const uint64_t b,
127 const uint64_t carry_in,
128 uint64_t& carry_out) noexcept
129{
130#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
131 uint128_t res = static_cast<uint128_t>(a) + static_cast<uint128_t>(b) + static_cast<uint128_t>(carry_in);
132 carry_out = static_cast<uint64_t>(res >> 64);
133 return static_cast<uint64_t>(res);
134#else
135 uint64_t r = a + b;
136 const uint64_t carry_temp = r < a; // carry_temp == 1 iff a + b overflows (without the carry_in bit)
137 r += carry_in;
138 carry_out = carry_temp +
139 (r < carry_in); // (r < carry_in) iff a + b == 2^64 - 1 and carry_in == 1, which means that (r >= a)
140 return r;
141#endif
142}
151template <class T>
152constexpr uint64_t field<T>::sbb(const uint64_t a,
153 const uint64_t b,
154 const uint64_t borrow_in,
155 uint64_t& borrow_out) noexcept
156{
157#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
158 uint128_t res = static_cast<uint128_t>(a) - (static_cast<uint128_t>(b) + static_cast<uint128_t>(borrow_in >> 63));
159 borrow_out = static_cast<uint64_t>(
160 res >> 64); // consider the set of negative outputs of [0, 2^64 - 1] - [0, 2^64]; then the highest-order 64 bits
161 // are either all 0 or all 1. hence `borrow_out` is in {0, 2^64 - 1}.
162 return static_cast<uint64_t>(res);
163#else
164 uint64_t t_1 = a - (borrow_in >> 63ULL);
165 uint64_t borrow_temp_1 = t_1 > a; // 0 iff a == 0 and borrow_in is non-zero (i.e., 2^64 - 1).
166 uint64_t t_2 = t_1 - b;
167 uint64_t borrow_temp_2 = t_2 > t_1; // 0 iff b > t_1
168 borrow_out = 0ULL - (borrow_temp_1 | borrow_temp_2); // underflow if either staged underflowed.
169 return t_2;
170#endif
171}
181template <class T>
182constexpr uint64_t field<T>::square_accumulate([[maybe_unused]] const uint64_t a,
183 [[maybe_unused]] const uint64_t b,
184 [[maybe_unused]] const uint64_t c,
185 [[maybe_unused]] const uint64_t carry_in_lo,
186 [[maybe_unused]] const uint64_t carry_in_hi,
187 [[maybe_unused]] uint64_t& carry_lo,
188 [[maybe_unused]] uint64_t& carry_hi) noexcept
189{
190#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
191 const uint128_t product = static_cast<uint128_t>(b) * static_cast<uint128_t>(c);
192 const auto r0 = static_cast<uint64_t>(product); // uint64_t(b * c)
193 const auto r1 = static_cast<uint64_t>(product >> 64);
194 uint64_t out = r0 + r0;
195 carry_lo = (out < r0); // 1 iff r_0 + r_0 overflows. (r_0 = uint_64t(b * c))
196 out += a; // uint_64t(a + (2 * b * c))
197 carry_lo += (out < a); // + 1 if a + uint_64t(2 * b * c) overflows
198 out += carry_in_lo; // uint_64t(a + (2 * b * c) + carry_in_lo)
199 carry_lo += (out < carry_in_lo); // + 1 if uint_64t(a + (2 * b * c)) + carry_in_lo overflows.
200 carry_lo += r1; // + r_1 (r_1 == "high order bits of b * c")
201 carry_hi = (carry_lo < r1); // 1 if adding r_1 to carry_lo causes overflow
202 carry_lo += r1; // + r_1 (we do this twice because of 2 * (b * c))
203 carry_hi += (carry_lo < r1); // + 1 if adding r_1 causes overflow
204 carry_lo += carry_in_hi; // finally add in the input "upper bits" contribution carry_in_hi
205 carry_hi += (carry_lo < carry_in_hi); // + 1 if this caused an overflow
206 return out;
207#else
208 static_assert(false, "square_accumulate is not implemented for WASM");
209 return 0;
210#endif
211}
223template <class T> constexpr field<T> field<T>::reduce() const noexcept
224{
225 if constexpr (modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) {
226 uint256_t val{ data[0], data[1], data[2], data[3] };
227 if (val >= modulus) {
228 val -= modulus;
229 }
230 return { val.data[0], val.data[1], val.data[2], val.data[3] };
231 }
232 // not_modulus == 2^256 - modulus
233 // do limb-based add-and-carry with `not_modulus`. this yields a _constant-time_ algorithm.
234 uint64_t t0 = data[0] + not_modulus.data[0];
235 uint64_t c = t0 < data[0];
236 auto t1 = addc(data[1], not_modulus.data[1], c, c);
237 auto t2 = addc(data[2], not_modulus.data[2], c, c);
238 auto t3 = addc(data[3], not_modulus.data[3], c, c);
239 // c != 0 iff val >= modulus.
240 const uint64_t selection_mask = 0ULL - c; // 0xffffffff if we have overflowed.
241 const uint64_t selection_mask_inverse = ~selection_mask;
242 // if c == 0, then the original element is already reduced. if we overflow, we want to return the element whose
243 // limbs are {t0, t1, t2, t3}.
244 return {
245 (data[0] & selection_mask_inverse) | (t0 & selection_mask),
246 (data[1] & selection_mask_inverse) | (t1 & selection_mask),
247 (data[2] & selection_mask_inverse) | (t2 & selection_mask),
248 (data[3] & selection_mask_inverse) | (t3 & selection_mask),
249 };
250}
251
255
256// Both `add` and `subtract` use constexpr branching to distinguish the cases: modulus has <= 254 bits (fields associated to
257// BN-254) and modulus has 256 bits. The former has the so-called "coarse" optimization: we allow the inputs to be in
258// the range [0, 2p) and the outputs will similarly only be constrained to [0, 2p).
259
// Field addition: returns a representative of (*this + other) modulo `modulus`.
//
// Two variants are selected at compile time from the top limb of the modulus:
//   * top limb >= MODULUS_TOP_LIMB_LARGE_THRESHOLD: the limbs are added with carry propagation. Only when the
//     256-bit sum overflows is the modulus subtracted, once, and a second time if the first subtraction did not
//     borrow. Without an overflow the raw sum is returned unchanged.
//   * otherwise: the limbs are added, then `twice_not_modulus` is added on top. The carry out of that second
//     addition selects, via bit masks rather than a branch, between the raw sum and the adjusted sum.
260template <class T> constexpr field<T> field<T>::add(const field& other) const noexcept
261{
262 if constexpr (modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) {
    // 256-bit addition; `c` ends up as the carry out of the top limb.
263 uint64_t r0 = data[0] + other.data[0];
264 uint64_t c = r0 < data[0];
265 auto r1 = addc(data[1], other.data[1], c, c);
266 auto r2 = addc(data[2], other.data[2], c, c);
267 auto r3 = addc(data[3], other.data[3], c, c);
268 if (c) {
    // The sum overflowed 2^256: subtract the modulus. `b` ends up as the final borrow (0 or 2^64 - 1).
269 uint64_t b = 0;
270 r0 = sbb(r0, modulus.data[0], b, b);
271 r1 = sbb(r1, modulus.data[1], b, b);
272 r2 = sbb(r2, modulus.data[2], b, b);
273 r3 = sbb(r3, modulus.data[3], b, b);
274 // Since both values are in [0, 2^256), the result is in [0, 2^257-2). Subtracting one p might not
275 // be enough. We need to ensure that we've underflown the 0 and that might require subtracting an additional
276 // p. This can only happen if at least one of the two arguments has uint256_t-element (derived from limbs)
277 // LARGER than p (i.e., non-reduced).
278 if (!b) {
    // `b` is already 0 in this branch; the assignment below restates the starting borrow for the second pass.
279 b = 0;
280 r0 = sbb(r0, modulus.data[0], b, b);
281 r1 = sbb(r1, modulus.data[1], b, b);
282 r2 = sbb(r2, modulus.data[2], b, b);
283 r3 = sbb(r3, modulus.data[3], b, b);
284 }
285 }
286 // if c == 0, i.e., if there was no carry, we do no additional processing. Note that this means that the output
287 // might be larger than p, even if the original self and other were in the range [0, p). This is witnessed in
288 // the test AddYieldsLimbsBiggerThanModulus.
289 return { r0, r1, r2, r3 };
290 } else {
291 uint64_t r0 = data[0] + other.data[0];
292 uint64_t c = r0 < data[0];
293 auto r1 = addc(data[1], other.data[1], c, c);
294 auto r2 = addc(data[2], other.data[2], c, c);
295 uint64_t r3 = data[3] + other.data[3] +
296 c; // in the small modulus branch so this will satisfy the right size bounds: both self
297 // and other are in the range [0, 2p), which means their sum is in [0, 4p-1).
298
    // Add twice_not_modulus to the raw sum (presumably 2^256 - 2p, judging by the name and the comment below - confirm
    // against its definition).
299 uint64_t t0 = r0 + twice_not_modulus.data[0];
300 c = t0 < twice_not_modulus.data[0];
301 uint64_t t1 = addc(r1, twice_not_modulus.data[1], c, c);
302 uint64_t t2 = addc(r2, twice_not_modulus.data[2], c, c);
303 uint64_t t3 = addc(r3, twice_not_modulus.data[3], c, c);
304 // c == 1 iff self + other >= 2 * p.
305 // if c == 0, then return the r_i (naive sum still in coarse form), if c == 1, return the t_i.
    // selection_mask is all ones when c == 1 and all zeros when c == 0.
306 const uint64_t selection_mask = 0ULL - c;
307 const uint64_t selection_mask_inverse = ~selection_mask;
308
309 field result{
310 (r0 & selection_mask_inverse) | (t0 & selection_mask),
311 (r1 & selection_mask_inverse) | (t1 & selection_mask),
312 (r2 & selection_mask_inverse) | (t2 & selection_mask),
313 (r3 & selection_mask_inverse) | (t3 & selection_mask),
314 };
    // NOTE(review): the statement that opens the scope closed by the `}` two lines below is not visible in this
    // listing (listing line 315 is absent) - confirm against the original source.
316 result.assert_coarse_form();
317 }
318 return result;
319 }
320}
321
// Field subtraction: returns a representative of (*this - other) modulo `modulus`.
//
// The limbs are subtracted with borrow propagation. The final `borrow` is 0 when no underflow occurred and
// 2^64 - 1 (all ones) when it did, so `x & borrow` yields either 0 or x and is used to add a correction only
// when one is needed:
//   * top limb of the modulus >= MODULUS_TOP_LIMB_LARGE_THRESHOLD: add `modulus` once, and a second time if the
//     first addition did not carry out of the top limb.
//   * otherwise: add `twice_modulus` once.
322template <class T> constexpr field<T> field<T>::subtract(const field& other) const noexcept
323{
324 uint64_t borrow = 0;
325 uint64_t r0 = sbb(data[0], other.data[0], borrow, borrow);
326 uint64_t r1 = sbb(data[1], other.data[1], borrow, borrow);
327 uint64_t r2 = sbb(data[2], other.data[2], borrow, borrow);
328 uint64_t r3 = sbb(data[3], other.data[3], borrow, borrow);
329
330 // recall that borrow is in the size-2 set {0, 2^64 - 1}.
331 if constexpr (modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) {
332 // add the modulus if borrow != 0, i.e., if other > self as uint256_t.
333 r0 += (modulus.data[0] & borrow);
334 uint64_t carry = r0 < (modulus.data[0] & borrow);
335 r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
336 r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
337 r3 = addc(r3, modulus.data[3] & borrow, carry, carry);
338 // The value being subtracted is in [0, 2^256); it is possible that adding one copy of
339 // p still leaves us with a negative number. To check if we might need to add another copy of p, we check if
340 // `carry == 0`; this means that (if we are "in the borrow branch"), the addition did not 2^256-overflow, which
341 // means we are still negative. If we are not in the borrow branch (i.e., if `borrow == 0`), `carry == 0` and we add
342 // nothing using the
343 // `& borrow` trick for the `addc` argument.
344 if (!carry) {
345 r0 += (modulus.data[0] & borrow);
    // NOTE: this declares a new `carry` that shadows the outer one. The outer `carry` is not read again after this
    // block, so the shadowing does not affect the result.
346 uint64_t carry = r0 < (modulus.data[0] & borrow);
347 r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
348 r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
349 r3 = addc(r3, (modulus.data[3] & borrow), carry, carry);
350 }
351 return { r0, r1, r2, r3 };
352 }
353 // Recall that in this constexpr branch, we use _coarse representation_, meaning the underlying limbs of both self
354 // and other yield uint256_t are in [0, 2p) . If there is a borrow, then it is possible that adding one copy of p
355 // is insufficient to make the result positive (and adding two copies both preserves the residue mod p and keeps us
356 // in the coarse-range).
357 r0 += (twice_modulus.data[0] & borrow);
358 uint64_t carry = r0 < (twice_modulus.data[0] & borrow);
359 r1 = addc(r1, twice_modulus.data[1] & borrow, carry, carry);
360 r2 = addc(r2, twice_modulus.data[2] & borrow, carry, carry);
    // The carry out of the top limb is deliberately not tracked here.
361 r3 += (twice_modulus.data[3] & borrow) + carry;
362
363 field result{ r0, r1, r2, r3 };
    // NOTE(review): the statement that opens the scope closed by the `}` two lines below is not visible in this
    // listing (listing line 364 is absent) - confirm against the original source.
365 result.assert_coarse_form();
366 }
367 return result;
368}
369
// Montgomery multiplication for "big" moduli, i.e. moduli whose top limb is at least
// MODULUS_TOP_LIMB_LARGE_THRESHOLD (enforced by the static_assert at the top of the body).
//
// Returns the limbs of the Montgomery product of *this and other after one conditional subtraction of the modulus.
//
// Two implementations are selected by the preprocessor:
//   * native 128-bit integers available and not WASM: four 64-bit limbs, interleaving one row of the schoolbook
//     product with one Montgomery reduction step per limb of *this.
//   * otherwise: both operands are converted to nine limbs (see wasm_convert), multiplied and reduced limb by limb
//     with wasm_madd / wasm_reduce, normalised, conditionally reduced, and packed back into four 64-bit limbs.
378template <class T> constexpr field<T> field<T>::montgomery_mul_big(const field& other) const noexcept
379{
380 // only applicable for big moduli
381 static_assert(modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD);
382
383#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
384 uint64_t c = 0;
385 uint64_t t0 = 0;
386 uint64_t t1 = 0;
387 uint64_t t2 = 0;
388 uint64_t t3 = 0;
389 uint64_t t4 = 0;
390 uint64_t t5 = 0;
391 uint64_t k = 0;
392
393 // Montgomery multiplication main loop: iterates 4 times, once per limb of self.data.
394 // We compute self * other in Montgomery form by maintaining a 5-limb running accumulator (t0-t4, with t5 for
395 // overflow). In each iteration:
396 // 1. Accumulate one limb of self multiplied by all limbs of other into (t0, t1, t2, t3, t4, t5)
397 // 2. "Zero out" the lowest limb t0 by computing k = t0 * r_inv (mod 2^64), then adding k * modulus
398 // This shifts the accumulator right by one limb position (t1->t0, t2->t1, etc.)
399 // The value of k is chosen so that (t0 + k * modulus[0]) ≡ 0 (mod 2^64), meaning the shifting of the accumulator
400 // amounts to dividing by 2^64.
401 //
402 // After 4 iterations, we've accumulated the full product and divided by R = 2^256,
403 // leaving the Montgomery form result in (t0, t1, t2, t3, t4).
404 for (const auto& element : data) {
405 c = 0;
406 // element = self.data[j]
407 // ti <- ti + self.data[j] * other.data[i] + carry_in, for i = 0..3.
408 // c is the carry_in for the computation; the carry-out is then written to c at every step.
409 mac(t0, element, other.data[0], c, t0, c);
410 mac(t1, element, other.data[1], c, t1, c);
411 mac(t2, element, other.data[2], c, t2, c);
412 mac(t3, element, other.data[3], c, t3, c);
413 // t4 += c, with carry-out written to t5.
414 // t5 is in {0, 1}.
415 t4 = addc(t4, c, 0, t5);
416
417 // add a multiple of the modulus, so that the result is divisible by 2^64, and then divide. these processes are
418 // done "simultaneously".
419 k = t0 * T::r_inv;
420 // the uint128_t t0 + (t0 * r_inv) * modulus[0] is divisible by 2^64. set c to be the high 64-bits of this
421 // number.
422 c = mac_discard_lo(t0, k, modulus.data[0]);
    // each mac below writes its low word one limb lower (t1 -> t0, t2 -> t1, t3 -> t2): this is the division by 2^64.
423 mac(t1, k, modulus.data[1], c, t0, c);
424 mac(t2, k, modulus.data[2], c, t1, c);
425 mac(t3, k, modulus.data[3], c, t2, c);
426 t3 = addc(c, t4, 0, c); // c is now in {0, 1}
427 t4 = t5 + c;
428 }
429 // The result is now contained in the 64*5-bit number with limbs {t0, t1, t2, t3, t4}. In fact, this number has at
430 // most 257 bits because t4 is in {0, 1}. Proof: we have just computed (aR * bR + \sum_i k_i p)/(2^256), where each
431 // k_i is less than 2^{64i} * (2^64 - 1) for i = 0..3. The numerator is therefore upper-bounded by (2^256 - 1)^2 +
432 // (2^256 - 1) * p, hence the whole quantity is bounded by 2^256 + p - 1. Therefore, t4 is in {0, 1}, and we must do
433 // at most one subtraction to get in range.
434
435 // constant-time "conditional reduction" that computes the following without branches:
436 // `result = (value >= modulus) ? value - modulus : value`
437 uint64_t borrow = 0;
438 uint64_t r0 = sbb(t0, modulus.data[0], borrow, borrow);
439 uint64_t r1 = sbb(t1, modulus.data[1], borrow, borrow);
440 uint64_t r2 = sbb(t2, modulus.data[2], borrow, borrow);
441 uint64_t r3 = sbb(t3, modulus.data[3], borrow, borrow);
442 // if t4 == 1, then from the above upper bound of 2^256 + p - 1, it follows that borrow != 0, i.e., borrow == 2^64
443 // - 1. if t4 == 0, both options for borrow are possible.
444 borrow = borrow ^ (0ULL - t4); // borrow is set to 0 if (t4 == 1 and hence borrow == 2^64 - 1) OR if (borrow == 0
445 // AND t4 == 0). borrow is set to 2^64 - 1 if (t4 == 0 AND borrow == 2^64 - 1)
    // add the modulus back only when borrow is still all ones, i.e. when the 4-limb value was already below the modulus.
446 r0 += (modulus.data[0] & borrow);
447 uint64_t carry = r0 < (modulus.data[0] & borrow);
448 r1 = addc(r1, modulus.data[1] & borrow, carry, carry);
449 r2 = addc(r2, modulus.data[2] & borrow, carry, carry);
450 r3 += (modulus.data[3] & borrow) + carry;
451 return { r0, r1, r2, r3 };
452#else
453
454 // Convert 4 64-bit limbs to 9 29-bit limbs
455 auto left = wasm_convert(data);
456 auto right = wasm_convert(other.data);
457 constexpr uint64_t mask = 0x1fffffff;
458 uint64_t temp_0 = 0;
459 uint64_t temp_1 = 0;
460 uint64_t temp_2 = 0;
461 uint64_t temp_3 = 0;
462 uint64_t temp_4 = 0;
463 uint64_t temp_5 = 0;
464 uint64_t temp_6 = 0;
465 uint64_t temp_7 = 0;
466 uint64_t temp_8 = 0;
467 uint64_t temp_9 = 0;
468 uint64_t temp_10 = 0;
469 uint64_t temp_11 = 0;
470 uint64_t temp_12 = 0;
471 uint64_t temp_13 = 0;
472 uint64_t temp_14 = 0;
473 uint64_t temp_15 = 0;
474 uint64_t temp_16 = 0;
475 uint64_t temp_17 = 0;
476 // Compute left[0] * right and replace with a representative modulo p that zeros out the lowest
477 // 29 bits. In other words, after first reduction: temp_1..temp_8 hold the partial Montgomery product after
478 // processing left[0]. temp_0 has been "consumed" (its information propagated via carry to temp_1).
479 // Multiply-add 0th limb of the left argument by all 9 limbs of the right argument
480 wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
481 // Instantly Montgomery reduce
482 wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
483 // Continue for other limbs
484 wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
485 wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
486 wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
487 wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
488 wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
489 wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
490 wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
491 wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
492 wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
493 wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
494 wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
495 wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
496 wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
497 wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
498 wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
499 wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
500 // MontgomeryMul(left, right) := (left * right) / R mod p.
501 // Then, after the add/reduce sequence, we have the following: MontgomeryMul(left, right) ≡ \sum_{i=0}^8 temp_{i+9}
502 // * 2^{29 * i} mod p. In particular, the information we want is stored in {temp_9, ..., temp_17}. However, these
503 // temp_i are not yet 29 bits.
504 //
505 // Moreover, we claim that the value \sum_{i=0}^8 temp_{i+9} * 2^{29 * i} is less than p + 2^{512-261} = p +
506 // 2^{251}. The reasoning is again generic: we have computed aR * bR + k_{0, 1, .., 8}p. Each aR and bR are, by
507 // assumption, 256 bits, and each k is 29 bits: k_0 is at most 2^29 - 1, k_1 is at most 2^58 - 2^29, etc.
508 // Telescoping, this means that the sum is upper-bounded by 2^512 + (2^261 - 1) * p. As we are taking the "high"
509 // part, we are simply trying to upper-bound this sum divided by 2^261. In particular, this shows that we have to do
510 // at most one subtraction to make the result 256 bits.
511 //
512 // After all multiplications and additions, convert relaxed form to strict (i.e., force all limbs to be
513 // 29 bits)
514 temp_10 += temp_9 >> WASM_LIMB_BITS;
515 temp_9 &= mask;
516 temp_11 += temp_10 >> WASM_LIMB_BITS;
517 temp_10 &= mask;
518 temp_12 += temp_11 >> WASM_LIMB_BITS;
519 temp_11 &= mask;
520 temp_13 += temp_12 >> WASM_LIMB_BITS;
521 temp_12 &= mask;
522 temp_14 += temp_13 >> WASM_LIMB_BITS;
523 temp_13 &= mask;
524 temp_15 += temp_14 >> WASM_LIMB_BITS;
525 temp_14 &= mask;
526 temp_16 += temp_15 >> WASM_LIMB_BITS;
527 temp_15 &= mask;
528 temp_17 += temp_16 >> WASM_LIMB_BITS;
529 temp_16 &= mask;
530
531 uint64_t r_temp_0;
532 uint64_t r_temp_1;
533 uint64_t r_temp_2;
534 uint64_t r_temp_3;
535 uint64_t r_temp_4;
536 uint64_t r_temp_5;
537 uint64_t r_temp_6;
538 uint64_t r_temp_7;
539 uint64_t r_temp_8;
540
    // Limb-wise subtraction of the modulus. Each limb is far below 2^63, so bit 63 of a difference is set exactly
    // when that limb underflowed; it is used as the borrow into the next limb.
541 r_temp_0 = temp_9 - wasm_modulus[0];
542 r_temp_1 = temp_10 - wasm_modulus[1] - ((r_temp_0) >> 63);
543 r_temp_2 = temp_11 - wasm_modulus[2] - ((r_temp_1) >> 63);
544 r_temp_3 = temp_12 - wasm_modulus[3] - ((r_temp_2) >> 63);
545 r_temp_4 = temp_13 - wasm_modulus[4] - ((r_temp_3) >> 63);
546 r_temp_5 = temp_14 - wasm_modulus[5] - ((r_temp_4) >> 63);
547 r_temp_6 = temp_15 - wasm_modulus[6] - ((r_temp_5) >> 63);
548 r_temp_7 = temp_16 - wasm_modulus[7] - ((r_temp_6) >> 63);
549 r_temp_8 = temp_17 - wasm_modulus[8] - ((r_temp_7) >> 63);
550
551 // Depending on whether the subtraction underflowed, choose original value or the result of subtraction
    // new_mask is all ones if the top limb underflowed (keep temp_i). Otherwise inverse_mask keeps the low 29 bits of
    // each difference.
552 uint64_t new_mask = 0 - (r_temp_8 >> 63);
553 uint64_t inverse_mask = (~new_mask) & mask;
554 temp_9 = (temp_9 & new_mask) | (r_temp_0 & inverse_mask);
555 temp_10 = (temp_10 & new_mask) | (r_temp_1 & inverse_mask);
556 temp_11 = (temp_11 & new_mask) | (r_temp_2 & inverse_mask);
557 temp_12 = (temp_12 & new_mask) | (r_temp_3 & inverse_mask);
558 temp_13 = (temp_13 & new_mask) | (r_temp_4 & inverse_mask);
559 temp_14 = (temp_14 & new_mask) | (r_temp_5 & inverse_mask);
560 temp_15 = (temp_15 & new_mask) | (r_temp_6 & inverse_mask);
561 temp_16 = (temp_16 & new_mask) | (r_temp_7 & inverse_mask);
562 temp_17 = (temp_17 & new_mask) | (r_temp_8 & inverse_mask);
563
564 // Convert back to 4 64-bit limbs
565 return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
566 (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
567 (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
568 (temp_15 >> 18) | (temp_16 << 11) | (temp_17 << 40) };
569
570#endif
571}
572
573#if defined(__wasm__) || !defined(__SIZEOF_INT128__)
574
580template <class T>
581constexpr void field<T>::wasm_madd(uint64_t& left_limb,
582 const std::array<uint64_t, WASM_NUM_LIMBS>& right_limbs,
583 uint64_t& result_0,
584 uint64_t& result_1,
585 uint64_t& result_2,
586 uint64_t& result_3,
587 uint64_t& result_4,
588 uint64_t& result_5,
589 uint64_t& result_6,
590 uint64_t& result_7,
591 uint64_t& result_8)
592{
593 result_0 += left_limb * right_limbs[0];
594 result_1 += left_limb * right_limbs[1];
595 result_2 += left_limb * right_limbs[2];
596 result_3 += left_limb * right_limbs[3];
597 result_4 += left_limb * right_limbs[4];
598 result_5 += left_limb * right_limbs[5];
599 result_6 += left_limb * right_limbs[6];
600 result_7 += left_limb * right_limbs[7];
601 result_8 += left_limb * right_limbs[8];
602}
603
624template <class T>
625constexpr void field<T>::wasm_reduce(uint64_t& result_0,
626 uint64_t& result_1,
627 uint64_t& result_2,
628 uint64_t& result_3,
629 uint64_t& result_4,
630 uint64_t& result_5,
631 uint64_t& result_6,
632 uint64_t& result_7,
633 uint64_t& result_8)
634{
635 constexpr uint64_t mask = 0x1fffffff;
636 constexpr uint64_t r_inv = T::r_inv & mask; // -(modulus ^ { -1 }) modulo 2 ^ WASM_LIMB_BITS
637 uint64_t k = (result_0 * r_inv) & mask;
638 result_0 += k * wasm_modulus[0];
639 result_1 += k * wasm_modulus[1] + (result_0 >> WASM_LIMB_BITS);
640 result_2 += k * wasm_modulus[2];
641 result_3 += k * wasm_modulus[3];
642 result_4 += k * wasm_modulus[4];
643 result_5 += k * wasm_modulus[5];
644 result_6 += k * wasm_modulus[6];
645 result_7 += k * wasm_modulus[7];
646 result_8 += k * wasm_modulus[8];
647}
648
669template <class T>
670constexpr void field<T>::wasm_reduce_yuval(uint64_t& result_0,
671 uint64_t& result_1,
672 uint64_t& result_2,
673 uint64_t& result_3,
674 uint64_t& result_4,
675 uint64_t& result_5,
676 uint64_t& result_6,
677 uint64_t& result_7,
678 uint64_t& result_8,
679 uint64_t& result_9)
680{
681 constexpr uint64_t mask = 0x1fffffff;
682 const uint64_t result_0_masked = result_0 & mask;
683 result_1 += result_0_masked * wasm_r_inv[0] + (result_0 >> WASM_LIMB_BITS);
684 result_2 += result_0_masked * wasm_r_inv[1];
685 result_3 += result_0_masked * wasm_r_inv[2];
686 result_4 += result_0_masked * wasm_r_inv[3];
687 result_5 += result_0_masked * wasm_r_inv[4];
688 result_6 += result_0_masked * wasm_r_inv[5];
689 result_7 += result_0_masked * wasm_r_inv[6];
690 result_8 += result_0_masked * wasm_r_inv[7];
691 result_9 += result_0_masked * wasm_r_inv[8];
692}
697template <class T> constexpr std::array<uint64_t, WASM_NUM_LIMBS> field<T>::wasm_convert(const uint64_t* data)
698{
699 return { data[0] & 0x1fffffff,
700 (data[0] >> WASM_LIMB_BITS) & 0x1fffffff,
701 ((data[0] >> 58) & 0x3f) | ((data[1] & 0x7fffff) << 6),
702 (data[1] >> 23) & 0x1fffffff,
703 ((data[1] >> 52) & 0xfff) | ((data[2] & 0x1ffff) << 12),
704 (data[2] >> 17) & 0x1fffffff,
705 ((data[2] >> 46) & 0x3ffff) | ((data[3] & 0x7ff) << 18),
706 (data[3] >> 11) & 0x1fffffff,
707 (data[3] >> 40) & 0x1fffffff };
708}
709#endif
710template <class T> constexpr field<T> field<T>::montgomery_mul(const field& other) const noexcept
711{
712 if constexpr (modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) {
713 return montgomery_mul_big(other);
714 }
715#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
716 // process first limb of self, data[0]
717 auto [t0, c] = mul_wide(data[0], other.data[0]);
718 uint64_t k = t0 * T::r_inv;
719 uint64_t a = mac_discard_lo(t0, k, modulus.data[0]);
720
721 uint64_t t1 = mac_mini(a, data[0], other.data[1], a);
722 mac(t1, k, modulus.data[1], c, t0, c);
723 uint64_t t2 = mac_mini(a, data[0], other.data[2], a);
724 mac(t2, k, modulus.data[2], c, t1, c);
725 uint64_t t3 = mac_mini(a, data[0], other.data[3], a);
726 mac(t3, k, modulus.data[3], c, t2, c);
727 t3 = c + a;
728 // process second limb of self, data[1]
729 mac_mini(t0, data[1], other.data[0], t0, a);
730 k = t0 * T::r_inv;
731 c = mac_discard_lo(t0, k, modulus.data[0]);
732 mac(t1, data[1], other.data[1], a, t1, a);
733 mac(t1, k, modulus.data[1], c, t0, c);
734 mac(t2, data[1], other.data[2], a, t2, a);
735 mac(t2, k, modulus.data[2], c, t1, c);
736 mac(t3, data[1], other.data[3], a, t3, a);
737 mac(t3, k, modulus.data[3], c, t2, c);
738 t3 = c + a;
739 // process third limb of self, data[2]
740 mac_mini(t0, data[2], other.data[0], t0, a);
741 k = t0 * T::r_inv;
742 c = mac_discard_lo(t0, k, modulus.data[0]);
743 mac(t1, data[2], other.data[1], a, t1, a);
744 mac(t1, k, modulus.data[1], c, t0, c);
745 mac(t2, data[2], other.data[2], a, t2, a);
746 mac(t2, k, modulus.data[2], c, t1, c);
747 mac(t3, data[2], other.data[3], a, t3, a);
748 mac(t3, k, modulus.data[3], c, t2, c);
749 t3 = c + a;
750 // process fourth limb of self, data[3]
751 mac_mini(t0, data[3], other.data[0], t0, a);
752 k = t0 * T::r_inv;
753 c = mac_discard_lo(t0, k, modulus.data[0]);
754 mac(t1, data[3], other.data[1], a, t1, a);
755 mac(t1, k, modulus.data[1], c, t0, c);
756 mac(t2, data[3], other.data[2], a, t2, a);
757 mac(t2, k, modulus.data[2], c, t1, c);
758 mac(t3, data[3], other.data[3], a, t3, a);
759 mac(t3, k, modulus.data[3], c, t2, c);
760 t3 = c + a;
761 {
762 field result{ t0, t1, t2, t3 };
764 result.assert_coarse_form();
765 }
766 return result;
767 }
768#else
769
770 // Convert 4 64-bit limbs to 9 29-bit ones
771 auto left = wasm_convert(data);
772 auto right = wasm_convert(other.data);
773 constexpr uint64_t mask = 0x1fffffff;
774
775 // Karatsuba multiplication: split 9 limbs into 5 (lo) + 4 (hi).
776 // P_lo = left[0..4] * right[0..4] (25 muls)
777 // P_hi = left[5..8] * right[5..8] (16 muls)
778 // P_cross = (left_lo + left_hi) * (right_lo + right_hi) (25 muls)
779 // P_mid = P_cross - P_lo - P_hi
780 // Total: 66 muls vs 81 for schoolbook 9x9.
781
782 // P_lo = left[0..4] * right[0..4] — 5x5 schoolbook
783 uint64_t pl0 = left[0] * right[0];
784 uint64_t pl1 = left[0] * right[1] + left[1] * right[0];
785 uint64_t pl2 = left[0] * right[2] + left[1] * right[1] + left[2] * right[0];
786 uint64_t pl3 = left[0] * right[3] + left[1] * right[2] + left[2] * right[1] + left[3] * right[0];
787 uint64_t pl4 =
788 left[0] * right[4] + left[1] * right[3] + left[2] * right[2] + left[3] * right[1] + left[4] * right[0];
789 uint64_t pl5 = left[1] * right[4] + left[2] * right[3] + left[3] * right[2] + left[4] * right[1];
790 uint64_t pl6 = left[2] * right[4] + left[3] * right[3] + left[4] * right[2];
791 uint64_t pl7 = left[3] * right[4] + left[4] * right[3];
792 uint64_t pl8 = left[4] * right[4];
793
794 // P_hi = left[5..8] * right[5..8] — 4x4 schoolbook
795 uint64_t ph0 = left[5] * right[5];
796 uint64_t ph1 = left[5] * right[6] + left[6] * right[5];
797 uint64_t ph2 = left[5] * right[7] + left[6] * right[6] + left[7] * right[5];
798 uint64_t ph3 = left[5] * right[8] + left[6] * right[7] + left[7] * right[6] + left[8] * right[5];
799 uint64_t ph4 = left[6] * right[8] + left[7] * right[7] + left[8] * right[6];
800 uint64_t ph5 = left[7] * right[8] + left[8] * right[7];
801 uint64_t ph6 = left[8] * right[8];
802
803 // Sums for the cross product (left_lo + left_hi, right_lo + right_hi)
804 uint64_t sl0 = left[0] + left[5];
805 uint64_t sl1 = left[1] + left[6];
806 uint64_t sl2 = left[2] + left[7];
807 uint64_t sl3 = left[3] + left[8];
808 uint64_t sl4 = left[4];
809 uint64_t sr0 = right[0] + right[5];
810 uint64_t sr1 = right[1] + right[6];
811 uint64_t sr2 = right[2] + right[7];
812 uint64_t sr3 = right[3] + right[8];
813 uint64_t sr4 = right[4];
814
815 // P_cross = sum_left * sum_right — 5x5 schoolbook
816 uint64_t pc0 = sl0 * sr0;
817 uint64_t pc1 = sl0 * sr1 + sl1 * sr0;
818 uint64_t pc2 = sl0 * sr2 + sl1 * sr1 + sl2 * sr0;
819 uint64_t pc3 = sl0 * sr3 + sl1 * sr2 + sl2 * sr1 + sl3 * sr0;
820 uint64_t pc4 = sl0 * sr4 + sl1 * sr3 + sl2 * sr2 + sl3 * sr1 + sl4 * sr0;
821 uint64_t pc5 = sl1 * sr4 + sl2 * sr3 + sl3 * sr2 + sl4 * sr1;
822 uint64_t pc6 = sl2 * sr4 + sl3 * sr3 + sl4 * sr2;
823 uint64_t pc7 = sl3 * sr4 + sl4 * sr3;
824 uint64_t pc8 = sl4 * sr4;
825
826 // Combine: temp[k] = P_lo[k] + P_mid[k-5] + P_hi[k-10]
827 // where P_mid = P_cross - P_lo - P_hi
828 uint64_t temp_0 = pl0;
829 uint64_t temp_1 = pl1;
830 uint64_t temp_2 = pl2;
831 uint64_t temp_3 = pl3;
832 uint64_t temp_4 = pl4;
833 uint64_t temp_5 = pl5 + (pc0 - pl0 - ph0);
834 uint64_t temp_6 = pl6 + (pc1 - pl1 - ph1);
835 uint64_t temp_7 = pl7 + (pc2 - pl2 - ph2);
836 uint64_t temp_8 = pl8 + (pc3 - pl3 - ph3);
837 uint64_t temp_9 = pc4 - pl4 - ph4;
838 uint64_t temp_10 = (pc5 - pl5 - ph5) + ph0;
839 uint64_t temp_11 = (pc6 - pl6 - ph6) + ph1;
840 uint64_t temp_12 = (pc7 - pl7) + ph2;
841 uint64_t temp_13 = (pc8 - pl8) + ph3;
842 uint64_t temp_14 = ph4;
843 uint64_t temp_15 = ph5;
844 uint64_t temp_16 = ph6;
845
846 // At this point, the value aR * bR is contained in \sum_{i=0}^16 temp_{i}*2^{29*i}. Note that this value is no
847 // greater than 4p^2 as aR and bR are both less than 2p.
848 wasm_reduce_yuval(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
849 wasm_reduce_yuval(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
850 wasm_reduce_yuval(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
851 wasm_reduce_yuval(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
852 wasm_reduce_yuval(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
853 wasm_reduce_yuval(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
854 wasm_reduce_yuval(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
855 wasm_reduce_yuval(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
856
857 // The first 8 limbs are reduced using Yuval's method, the last one is reduced using the regular method
858 // The reason for this is that Yuval's method produces a 10-limb representation of the reduced limb, which is then
859 // added to the higher limbs. If we do this for the last limb we reduce, we'll get a 10-limb representation instead
860 // of a 9-limb one, so we'll have to reduce it again in some other way.
861 wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
862 // We must now reason about the current value of \sum_{i=0}^8 temp_{i+8} from the original assumptions.
863 // Following the algorithm, this is aR * bR + k_{0, 1, ..., 7}*r_inv_wasm + k_8p. Here, k_0 < 2^29-1, k_1 < 2^58 -
864 // 2^29, and so on, until k_8 < 2^261 - 2^232. Moreover, r_inv_wasm < p. (From the definition, it is the value of
865 // 2^{-29} mod p, and our choice of limb-representation is smaller than p. In fact, it is empirically smaller than
866 // p/2 for Fq and Fr.)
867 //
868 // Therefore, this whole sum is bounded by 4p^2 + (2^261 - 1)*p. Dividing by 2^261 and taking
869 // the integral part (corresponding to taking the top half of the limbs), and noting that 4p^2 / 2^261 << 1, we
870 // conclude that the result is in [0, p]. In particular, this implies that we are safely in [0, 2p), as desired.
871 //
872 // Note that the above analysis is soft, and it is overwhelmingly likely that the result is in [0, p). However, the
873 // only guarantee we require is that it is in [0, 2p), as with 254-bit fields we work with the coarse
874 // representation.
875
876 // Convert result to unrelaxed form (all limbs are 29 bits)
877 temp_10 += temp_9 >> WASM_LIMB_BITS;
878 temp_9 &= mask;
879 temp_11 += temp_10 >> WASM_LIMB_BITS;
880 temp_10 &= mask;
881 temp_12 += temp_11 >> WASM_LIMB_BITS;
882 temp_11 &= mask;
883 temp_13 += temp_12 >> WASM_LIMB_BITS;
884 temp_12 &= mask;
885 temp_14 += temp_13 >> WASM_LIMB_BITS;
886 temp_13 &= mask;
887 temp_15 += temp_14 >> WASM_LIMB_BITS;
888 temp_14 &= mask;
889 temp_16 += temp_15 >> WASM_LIMB_BITS;
890 temp_15 &= mask;
891
892 // Convert back to 4 64-bit limbs form
893 return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
894 (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
895 (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
896 (temp_15 >> 18) | (temp_16 << 11) };
897#endif
898}
905template <class T> constexpr field<T> field<T>::montgomery_square() const noexcept
906{
907 if constexpr (modulus.data[3] >= MODULUS_TOP_LIMB_LARGE_THRESHOLD) {
908 return montgomery_mul_big(*this);
909 }
910#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
911 uint64_t carry_hi = 0;
912
913 auto [t0, carry_lo] = mul_wide(data[0], data[0]);
914 uint64_t t1 = square_accumulate(0, data[1], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
915 uint64_t t2 = square_accumulate(0, data[2], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
916 uint64_t t3 = square_accumulate(0, data[3], data[0], carry_lo, carry_hi, carry_lo, carry_hi);
917
918 uint64_t round_carry = carry_lo;
919 uint64_t k = t0 * T::r_inv;
920 carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
921 mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
922 mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
923 mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
924 t3 = carry_lo + round_carry;
925
926 t1 = mac_mini(t1, data[1], data[1], carry_lo);
927 carry_hi = 0;
928 t2 = square_accumulate(t2, data[2], data[1], carry_lo, carry_hi, carry_lo, carry_hi);
929 t3 = square_accumulate(t3, data[3], data[1], carry_lo, carry_hi, carry_lo, carry_hi);
930 round_carry = carry_lo;
931 k = t0 * T::r_inv;
932 carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
933 mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
934 mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
935 mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
936 t3 = carry_lo + round_carry;
937
938 t2 = mac_mini(t2, data[2], data[2], carry_lo);
939 carry_hi = 0;
940 t3 = square_accumulate(t3, data[3], data[2], carry_lo, carry_hi, carry_lo, carry_hi);
941 round_carry = carry_lo;
942 k = t0 * T::r_inv;
943 carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
944 mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
945 mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
946 mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
947 t3 = carry_lo + round_carry;
948
949 t3 = mac_mini(t3, data[3], data[3], carry_lo);
950 k = t0 * T::r_inv;
951 round_carry = carry_lo;
952 carry_lo = mac_discard_lo(t0, k, modulus.data[0]);
953 mac(t1, k, modulus.data[1], carry_lo, t0, carry_lo);
954 mac(t2, k, modulus.data[2], carry_lo, t1, carry_lo);
955 mac(t3, k, modulus.data[3], carry_lo, t2, carry_lo);
956 t3 = carry_lo + round_carry;
957 {
958 field result{ t0, t1, t2, t3 };
960 result.assert_coarse_form();
961 }
962 return result;
963 }
964#else
965 // Convert from 4 64-bit limbs to 9 29-bit ones
966 auto left = wasm_convert(data);
967 constexpr uint64_t mask = 0x1fffffff;
968 uint64_t temp_0 = 0;
969 uint64_t temp_1 = 0;
970 uint64_t temp_2 = 0;
971 uint64_t temp_3 = 0;
972 uint64_t temp_4 = 0;
973 uint64_t temp_5 = 0;
974 uint64_t temp_6 = 0;
975 uint64_t temp_7 = 0;
976 uint64_t temp_8 = 0;
977 uint64_t temp_9 = 0;
978 uint64_t temp_10 = 0;
979 uint64_t temp_11 = 0;
980 uint64_t temp_12 = 0;
981 uint64_t temp_13 = 0;
982 uint64_t temp_14 = 0;
983 uint64_t temp_15 = 0;
984 uint64_t temp_16 = 0;
985 uint64_t acc;
986 // Perform multiplications, but accumulated results for limb k=i+j so that we can double them at the same time
987 temp_0 += left[0] * left[0];
988 acc = 0;
989 acc += left[0] * left[1];
990 temp_1 += (acc << 1);
991 acc = 0;
992 acc += left[0] * left[2];
993 temp_2 += left[1] * left[1];
994 temp_2 += (acc << 1);
995 acc = 0;
996 acc += left[0] * left[3];
997 acc += left[1] * left[2];
998 temp_3 += (acc << 1);
999 acc = 0;
1000 acc += left[0] * left[4];
1001 acc += left[1] * left[3];
1002 temp_4 += left[2] * left[2];
1003 temp_4 += (acc << 1);
1004 acc = 0;
1005 acc += left[0] * left[5];
1006 acc += left[1] * left[4];
1007 acc += left[2] * left[3];
1008 temp_5 += (acc << 1);
1009 acc = 0;
1010 acc += left[0] * left[6];
1011 acc += left[1] * left[5];
1012 acc += left[2] * left[4];
1013 temp_6 += left[3] * left[3];
1014 temp_6 += (acc << 1);
1015 acc = 0;
1016 acc += left[0] * left[7];
1017 acc += left[1] * left[6];
1018 acc += left[2] * left[5];
1019 acc += left[3] * left[4];
1020 temp_7 += (acc << 1);
1021 acc = 0;
1022 acc += left[0] * left[8];
1023 acc += left[1] * left[7];
1024 acc += left[2] * left[6];
1025 acc += left[3] * left[5];
1026 temp_8 += left[4] * left[4];
1027 temp_8 += (acc << 1);
1028 acc = 0;
1029 acc += left[1] * left[8];
1030 acc += left[2] * left[7];
1031 acc += left[3] * left[6];
1032 acc += left[4] * left[5];
1033 temp_9 += (acc << 1);
1034 acc = 0;
1035 acc += left[2] * left[8];
1036 acc += left[3] * left[7];
1037 acc += left[4] * left[6];
1038 temp_10 += left[5] * left[5];
1039 temp_10 += (acc << 1);
1040 acc = 0;
1041 acc += left[3] * left[8];
1042 acc += left[4] * left[7];
1043 acc += left[5] * left[6];
1044 temp_11 += (acc << 1);
1045 acc = 0;
1046 acc += left[4] * left[8];
1047 acc += left[5] * left[7];
1048 temp_12 += left[6] * left[6];
1049 temp_12 += (acc << 1);
1050 acc = 0;
1051 acc += left[5] * left[8];
1052 acc += left[6] * left[7];
1053 temp_13 += (acc << 1);
1054 acc = 0;
1055 acc += left[6] * left[8];
1056 temp_14 += left[7] * left[7];
1057 temp_14 += (acc << 1);
1058 acc = 0;
1059 acc += left[7] * left[8];
1060 temp_15 += (acc << 1);
1061 temp_16 += left[8] * left[8];
1062
1063 // Perform reductions
1064
1065 wasm_reduce_yuval(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
1066 wasm_reduce_yuval(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
1067 wasm_reduce_yuval(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
1068 wasm_reduce_yuval(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
1069 wasm_reduce_yuval(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
1070 wasm_reduce_yuval(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
1071 wasm_reduce_yuval(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
1072 wasm_reduce_yuval(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
1073
1074 // In case there is some unforseen edge case encountered in wasm multiplications, we can quickly restore previous
1075 // functionality. Comment all "wasm_reduce_yuval" and uncomment the following:
1076
1077 // wasm_reduce(temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
1078 // wasm_reduce(temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
1079 // wasm_reduce(temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
1080 // wasm_reduce(temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
1081 // wasm_reduce(temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
1082 // wasm_reduce(temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
1083 // wasm_reduce(temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
1084 // wasm_reduce(temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
1085
1086 // The first 8 limbs are reduced using Yuval's method, the last one is reduced using the regular method
1087 // The reason for this is that Yuval's method produces a 10-limb representation of the reduced limb, which is then
1088 // added to the higher limbs. If we do this for the last limb we reduce, we'll get a 10-limb representation instead
1089 // of a 9-limb one, so we'll have to reduce it again in some other way.
1090 wasm_reduce(temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
1091
1092 // Convert to unrelaxed 29-bit form
1093 temp_10 += temp_9 >> WASM_LIMB_BITS;
1094 temp_9 &= mask;
1095 temp_11 += temp_10 >> WASM_LIMB_BITS;
1096 temp_10 &= mask;
1097 temp_12 += temp_11 >> WASM_LIMB_BITS;
1098 temp_11 &= mask;
1099 temp_13 += temp_12 >> WASM_LIMB_BITS;
1100 temp_12 &= mask;
1101 temp_14 += temp_13 >> WASM_LIMB_BITS;
1102 temp_13 &= mask;
1103 temp_15 += temp_14 >> WASM_LIMB_BITS;
1104 temp_14 &= mask;
1105 temp_16 += temp_15 >> WASM_LIMB_BITS;
1106 temp_15 &= mask;
1107 // Convert to 4 64-bit form
1108 return { (temp_9 << 0) | (temp_10 << 29) | (temp_11 << 58),
1109 (temp_11 >> 6) | (temp_12 << 23) | (temp_13 << 52),
1110 (temp_13 >> 12) | (temp_14 << 17) | (temp_15 << 46),
1111 (temp_15 >> 18) | (temp_16 << 11) };
1112#endif
1113}
1114
1115template <class T> constexpr struct field<T>::wide_array field<T>::mul_512(const field& other) const noexcept
1116{
1117#if defined(__SIZEOF_INT128__) && !defined(__wasm__)
1118 uint64_t carry_2 = 0;
1119 auto [r0, carry] = mul_wide(data[0], other.data[0]);
1120 uint64_t r1 = mac_mini(carry, data[0], other.data[1], carry);
1121 uint64_t r2 = mac_mini(carry, data[0], other.data[2], carry);
1122 uint64_t r3 = mac_mini(carry, data[0], other.data[3], carry_2);
1123
1124 r1 = mac_mini(r1, data[1], other.data[0], carry);
1125 r2 = mac(r2, data[1], other.data[1], carry, carry);
1126 r3 = mac(r3, data[1], other.data[2], carry, carry);
1127 uint64_t r4 = mac(carry_2, data[1], other.data[3], carry, carry_2);
1128
1129 r2 = mac_mini(r2, data[2], other.data[0], carry);
1130 r3 = mac(r3, data[2], other.data[1], carry, carry);
1131 r4 = mac(r4, data[2], other.data[2], carry, carry);
1132 uint64_t r5 = mac(carry_2, data[2], other.data[3], carry, carry_2);
1133
1134 r3 = mac_mini(r3, data[3], other.data[0], carry);
1135 r4 = mac(r4, data[3], other.data[1], carry, carry);
1136 r5 = mac(r5, data[3], other.data[2], carry, carry);
1137 uint64_t r6 = mac(carry_2, data[3], other.data[3], carry, carry_2);
1138
1139 return { r0, r1, r2, r3, r4, r5, r6, carry_2 };
1140#else
1141 // Convert from 4 64-bit limbs to 9 29-bit limbs
1142 auto left = wasm_convert(data);
1143 auto right = wasm_convert(other.data);
1144 constexpr uint64_t mask = 0x1fffffff;
1145 uint64_t temp_0 = 0;
1146 uint64_t temp_1 = 0;
1147 uint64_t temp_2 = 0;
1148 uint64_t temp_3 = 0;
1149 uint64_t temp_4 = 0;
1150 uint64_t temp_5 = 0;
1151 uint64_t temp_6 = 0;
1152 uint64_t temp_7 = 0;
1153 uint64_t temp_8 = 0;
1154 uint64_t temp_9 = 0;
1155 uint64_t temp_10 = 0;
1156 uint64_t temp_11 = 0;
1157 uint64_t temp_12 = 0;
1158 uint64_t temp_13 = 0;
1159 uint64_t temp_14 = 0;
1160 uint64_t temp_15 = 0;
1161 uint64_t temp_16 = 0;
1162
1163 // Multiply-add all limbs
1164 wasm_madd(left[0], right, temp_0, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8);
1165 wasm_madd(left[1], right, temp_1, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9);
1166 wasm_madd(left[2], right, temp_2, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10);
1167 wasm_madd(left[3], right, temp_3, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11);
1168 wasm_madd(left[4], right, temp_4, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12);
1169 wasm_madd(left[5], right, temp_5, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13);
1170 wasm_madd(left[6], right, temp_6, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14);
1171 wasm_madd(left[7], right, temp_7, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15);
1172 wasm_madd(left[8], right, temp_8, temp_9, temp_10, temp_11, temp_12, temp_13, temp_14, temp_15, temp_16);
1173
1174 // Convert to unrelaxed 29-bit form
1175 temp_1 += temp_0 >> WASM_LIMB_BITS;
1176 temp_0 &= mask;
1177 temp_2 += temp_1 >> WASM_LIMB_BITS;
1178 temp_1 &= mask;
1179 temp_3 += temp_2 >> WASM_LIMB_BITS;
1180 temp_2 &= mask;
1181 temp_4 += temp_3 >> WASM_LIMB_BITS;
1182 temp_3 &= mask;
1183 temp_5 += temp_4 >> WASM_LIMB_BITS;
1184 temp_4 &= mask;
1185 temp_6 += temp_5 >> WASM_LIMB_BITS;
1186 temp_5 &= mask;
1187 temp_7 += temp_6 >> WASM_LIMB_BITS;
1188 temp_6 &= mask;
1189 temp_8 += temp_7 >> WASM_LIMB_BITS;
1190 temp_7 &= mask;
1191 temp_9 += temp_8 >> WASM_LIMB_BITS;
1192 temp_8 &= mask;
1193 temp_10 += temp_9 >> WASM_LIMB_BITS;
1194 temp_9 &= mask;
1195 temp_11 += temp_10 >> WASM_LIMB_BITS;
1196 temp_10 &= mask;
1197 temp_12 += temp_11 >> WASM_LIMB_BITS;
1198 temp_11 &= mask;
1199 temp_13 += temp_12 >> WASM_LIMB_BITS;
1200 temp_12 &= mask;
1201 temp_14 += temp_13 >> WASM_LIMB_BITS;
1202 temp_13 &= mask;
1203 temp_15 += temp_14 >> WASM_LIMB_BITS;
1204 temp_14 &= mask;
1205 temp_16 += temp_15 >> WASM_LIMB_BITS;
1206 temp_15 &= mask;
1207
1208 // Convert to 8 64-bit limbs
1209 return { (temp_0 << 0) | (temp_1 << 29) | (temp_2 << 58),
1210 (temp_2 >> 6) | (temp_3 << 23) | (temp_4 << 52),
1211 (temp_4 >> 12) | (temp_5 << 17) | (temp_6 << 46),
1212 (temp_6 >> 18) | (temp_7 << 11) | (temp_8 << 40),
1213 (temp_8 >> 24) | (temp_9 << 5) | (temp_10 << 34) | (temp_11 << 63),
1214 (temp_11 >> 1) | (temp_12 << 28) | (temp_13 << 57),
1215 (temp_13 >> 7) | (temp_14 << 22) | (temp_15 << 51),
1216 (temp_15 >> 13) | (temp_16 << 16) };
1217#endif
1218}
1219
1220// NOLINTEND(readability-implicit-bool-conversion)
1221} // namespace bb
const std::vector< MemoryValue > data
FF a
FF b
#define WASM_LIMB_BITS
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
unsigned __int128 uint128_t
Definition serialize.hpp:45
General class for prime fields see Prime field documentation["field documentation"] for general imple...
void assert_coarse_form() const noexcept
static BB_INLINE constexpr std::array< uint64_t, WASM_NUM_LIMBS > wasm_convert(const uint64_t *data)
Convert 4 64-bit limbs into 9 29-bit limbs.
static BB_INLINE constexpr std::pair< uint64_t, uint64_t > mul_wide(uint64_t a, uint64_t b) noexcept
static BB_INLINE constexpr uint64_t mac_discard_lo(uint64_t a, uint64_t b, uint64_t c) noexcept
static BB_INLINE constexpr uint64_t sbb(uint64_t a, uint64_t b, uint64_t borrow_in, uint64_t &borrow_out) noexcept
unsigned 64-bit subtract-with-borrow that takes in borrow_in value in the size-2 set {0,...
BB_INLINE constexpr field subtract(const field &other) const noexcept
static BB_INLINE constexpr uint64_t mac(uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in, uint64_t &carry_out) noexcept
Compute uint128_t(a + b * c + carry_in), where the inputs are all uint64_t. Return the bottom 64 bits and write the top 64 bits to carry_out.
static BB_INLINE constexpr uint64_t addc(uint64_t a, uint64_t b, uint64_t carry_in, uint64_t &carry_out) noexcept
unsigned 64-bit add-with-carry that takes in a carry_in and a carry_out bit and rewrites the latter.
static BB_INLINE constexpr void wasm_reduce(uint64_t &result_0, uint64_t &result_1, uint64_t &result_2, uint64_t &result_3, uint64_t &result_4, uint64_t &result_5, uint64_t &result_6, uint64_t &result_7, uint64_t &result_8)
Perform 29-bit Montgomery reduction on 1 limb (result_0 should be zero modulo 2^29 after calling this...
BB_INLINE constexpr field montgomery_mul_big(const field &other) const noexcept
Montgomery multiplication for moduli > 2²⁵⁴
static BB_INLINE constexpr void wasm_madd(uint64_t &left_limb, const std::array< uint64_t, WASM_NUM_LIMBS > &right_limbs, uint64_t &result_0, uint64_t &result_1, uint64_t &result_2, uint64_t &result_3, uint64_t &result_4, uint64_t &result_5, uint64_t &result_6, uint64_t &result_7, uint64_t &result_8)
Multiply left limb by a sequence of 9 limbs and accumulate into result variables.
static BB_INLINE constexpr uint64_t square_accumulate(uint64_t a, uint64_t b, uint64_t c, uint64_t carry_in_lo, uint64_t carry_in_hi, uint64_t &carry_lo, uint64_t &carry_hi) noexcept
Computes a + 2 * b * c + carry_in_lo + 2^64 * carry_in_hi, in the form of returning a uint64_t and mo...
static BB_INLINE constexpr void wasm_reduce_yuval(uint64_t &result_0, uint64_t &result_1, uint64_t &result_2, uint64_t &result_3, uint64_t &result_4, uint64_t &result_5, uint64_t &result_6, uint64_t &result_7, uint64_t &result_8, uint64_t &result_9)
Perform 29-bit Montgomery reduction on 1 limb using Yuval's method.
BB_INLINE constexpr field add(const field &other) const noexcept
BB_INLINE constexpr field montgomery_square() const noexcept
Squaring via a variant of the Montgomery algorithm, where we roughly take advantage of the repeated t...
BB_INLINE constexpr field montgomery_mul(const field &other) const noexcept
BB_INLINE constexpr field reduce() const noexcept
reduce once, i.e., if the value is bigger than the modulus, subtract off the modulus once.
static BB_INLINE constexpr uint64_t mac_mini(uint64_t a, uint64_t b, uint64_t c, uint64_t &out) noexcept