Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
uint128_impl.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Luke], commit: }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#ifdef __i386__
8#pragma once
9#include "../bitop/get_msb.hpp"
10#include "./uint128.hpp"
13namespace bb::numeric {
14
15constexpr std::pair<uint32_t, uint32_t> uint128_t::mul_wide(const uint32_t a, const uint32_t b)
16{
17 const uint32_t a_lo = a & 0xffffULL;
18 const uint32_t a_hi = a >> 16ULL;
19 const uint32_t b_lo = b & 0xffffULL;
20 const uint32_t b_hi = b >> 16ULL;
21
22 const uint32_t lo_lo = a_lo * b_lo;
23 const uint32_t hi_lo = a_hi * b_lo;
24 const uint32_t lo_hi = a_lo * b_hi;
25 const uint32_t hi_hi = a_hi * b_hi;
26
27 const uint32_t cross = (lo_lo >> 16) + (hi_lo & 0xffffULL) + lo_hi;
28
29 return { (cross << 16ULL) | (lo_lo & 0xffffULL), (hi_lo >> 16ULL) + (cross >> 16ULL) + hi_hi };
30}
31
32// compute a + b + carry, returning the carry
33constexpr std::pair<uint32_t, uint32_t> uint128_t::addc(const uint32_t a, const uint32_t b, const uint32_t carry_in)
34{
35 const uint32_t sum = a + b;
36 const auto carry_temp = static_cast<uint32_t>(sum < a);
37 const uint32_t r = sum + carry_in;
38 const uint32_t carry_out = carry_temp + static_cast<unsigned int>(r < carry_in);
39 return { r, carry_out };
40}
41
42constexpr uint32_t uint128_t::addc_discard_hi(const uint32_t a, const uint32_t b, const uint32_t carry_in)
43{
44 return a + b + carry_in;
45}
46
47constexpr std::pair<uint32_t, uint32_t> uint128_t::sbb(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
48{
49 const uint32_t t_1 = a - (borrow_in >> 31ULL);
50 const auto borrow_temp_1 = static_cast<uint32_t>(t_1 > a);
51 const uint32_t t_2 = t_1 - b;
52 const auto borrow_temp_2 = static_cast<uint32_t>(t_2 > t_1);
53
54 return { t_2, 0ULL - (borrow_temp_1 | borrow_temp_2) };
55}
56
57constexpr uint32_t uint128_t::sbb_discard_hi(const uint32_t a, const uint32_t b, const uint32_t borrow_in)
58{
59 return a - b - (borrow_in >> 31ULL);
60}
61
62// {r, carry_out} = a + carry_in + b * c
63constexpr std::pair<uint32_t, uint32_t> uint128_t::mac(const uint32_t a,
64 const uint32_t b,
65 const uint32_t c,
66 const uint32_t carry_in)
67{
68 std::pair<uint32_t, uint32_t> result = mul_wide(b, c);
69 result.first += a;
70 const auto overflow_c = static_cast<uint32_t>(result.first < a);
71 result.first += carry_in;
72 const auto overflow_carry = static_cast<uint32_t>(result.first < carry_in);
73 result.second += (overflow_c + overflow_carry);
74 return result;
75}
76
77constexpr uint32_t uint128_t::mac_discard_hi(const uint32_t a,
78 const uint32_t b,
79 const uint32_t c,
80 const uint32_t carry_in)
81{
82 return (b * c + a + carry_in);
83}
84
85constexpr std::pair<uint128_t, uint128_t> uint128_t::divmod(const uint128_t& b) const
86{
87 if (b == 0) {
88 throw_or_abort("uint128_t::divmod: divisor must be nonzero");
89 }
90 if (*this == 0) {
91 return { 0, 0 };
92 }
93 if (b == 1) {
94 return { *this, 0 };
95 }
96 if (*this == b) {
97 return { 1, 0 };
98 }
99 if (b > *this) {
100 return { 0, *this };
101 }
102
103 uint128_t quotient = 0;
104 uint128_t remainder = *this;
105
106 uint64_t bit_difference = get_msb() - b.get_msb();
107
108 uint128_t divisor = b << bit_difference;
109 uint128_t accumulator = uint128_t(1) << bit_difference;
110
111 // if the divisor is bigger than the remainder, a and b have the same bit length
112 if (divisor > remainder) {
113 divisor >>= 1;
114 accumulator >>= 1;
115 }
116
117 // while the remainder is bigger than our original divisor, we can subtract multiples of b from the remainder,
118 // and add to the quotient
119 while (remainder >= b) {
120
121 // we've shunted 'divisor' up to have the same bit length as our remainder.
122 // If remainder >= divisor, then a is at least '1 << bit_difference' multiples of b
123 if (remainder >= divisor) {
124 remainder -= divisor;
125 // we can use OR here instead of +, as
126 // accumulator is always a nice power of two
127 quotient |= accumulator;
128 }
129 divisor >>= 1;
130 accumulator >>= 1;
131 }
132
133 return { quotient, remainder };
134}
135
136constexpr std::pair<uint128_t, uint128_t> uint128_t::mul_extended(const uint128_t& other) const
137{
138 const auto [r0, t0] = mul_wide(data[0], other.data[0]);
139 const auto [q0, t1] = mac(t0, data[0], other.data[1], 0);
140 const auto [q1, t2] = mac(t1, data[0], other.data[2], 0);
141 const auto [q2, z0] = mac(t2, data[0], other.data[3], 0);
142
143 const auto [r1, t3] = mac(q0, data[1], other.data[0], 0);
144 const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
145 const auto [q4, t5] = mac(q2, data[1], other.data[2], t4);
146 const auto [q5, z1] = mac(z0, data[1], other.data[3], t5);
147
148 const auto [r2, t6] = mac(q3, data[2], other.data[0], 0);
149 const auto [q6, t7] = mac(q4, data[2], other.data[1], t6);
150 const auto [q7, t8] = mac(q5, data[2], other.data[2], t7);
151 const auto [q8, z2] = mac(z1, data[2], other.data[3], t8);
152
153 const auto [r3, t9] = mac(q6, data[3], other.data[0], 0);
154 const auto [r4, t10] = mac(q7, data[3], other.data[1], t9);
155 const auto [r5, t11] = mac(q8, data[3], other.data[2], t10);
156 const auto [r6, r7] = mac(z2, data[3], other.data[3], t11);
157
158 uint128_t lo(r0, r1, r2, r3);
159 uint128_t hi(r4, r5, r6, r7);
160 return { lo, hi };
161}
162
168constexpr uint128_t uint128_t::slice(const uint64_t start, const uint64_t end) const
169{
170 // Plain assert is used here because BB_ASSERT_DEBUG defines a std::ostringstream, which is
171 // a non-literal type and therefore disallowed in the body of a constexpr function before C++23.
172 assert(start <= end);
173 const uint64_t range = end - start;
174 const uint128_t mask = (range == 128) ? -uint128_t(1) : (uint128_t(1) << range) - 1;
175 return ((*this) >> start) & mask;
176}
177
178constexpr uint128_t uint128_t::pow(const uint128_t& exponent) const
179{
180 uint128_t accumulator{ data[0], data[1], data[2], data[3] };
181 uint128_t to_mul{ data[0], data[1], data[2], data[3] };
182 const uint64_t maximum_set_bit = exponent.get_msb();
183
184 for (int i = static_cast<int>(maximum_set_bit) - 1; i >= 0; --i) {
185 accumulator *= accumulator;
186 if (exponent.get_bit(static_cast<uint64_t>(i))) {
187 accumulator *= to_mul;
188 }
189 }
190 if (exponent == uint128_t(0)) {
191 accumulator = uint128_t(1);
192 } else if (*this == uint128_t(0)) {
193 accumulator = uint128_t(0);
194 }
195 return accumulator;
196}
197
198constexpr bool uint128_t::get_bit(const uint64_t bit_index) const
199{
200 BB_ASSERT(bit_index < 128);
201 if (bit_index > 127) {
202 return false;
203 }
204 const auto idx = static_cast<size_t>(bit_index >> 5);
205 const size_t shift = bit_index & 31;
206 return static_cast<bool>((data[idx] >> shift) & 1);
207}
208
209constexpr uint64_t uint128_t::get_msb() const
210{
211 uint64_t idx = numeric::get_msb64(data[3]);
212 idx = (idx == 0 && data[3] == 0) ? numeric::get_msb64(data[2]) : idx + 32;
213 idx = (idx == 0 && data[2] == 0) ? numeric::get_msb64(data[1]) : idx + 32;
214 idx = (idx == 0 && data[1] == 0) ? numeric::get_msb64(data[0]) : idx + 32;
215 return idx;
216}
217
218constexpr uint128_t uint128_t::operator+(const uint128_t& other) const
219{
220 const auto [r0, t0] = addc(data[0], other.data[0], 0);
221 const auto [r1, t1] = addc(data[1], other.data[1], t0);
222 const auto [r2, t2] = addc(data[2], other.data[2], t1);
223 const auto r3 = addc_discard_hi(data[3], other.data[3], t2);
224 return { r0, r1, r2, r3 };
225};
226
227constexpr uint128_t uint128_t::operator-(const uint128_t& other) const
228{
229
230 const auto [r0, t0] = sbb(data[0], other.data[0], 0);
231 const auto [r1, t1] = sbb(data[1], other.data[1], t0);
232 const auto [r2, t2] = sbb(data[2], other.data[2], t1);
233 const auto r3 = sbb_discard_hi(data[3], other.data[3], t2);
234 return { r0, r1, r2, r3 };
235}
236
237constexpr uint128_t uint128_t::operator-() const
238{
239 return uint128_t(0) - *this;
240}
241
242constexpr uint128_t uint128_t::operator*(const uint128_t& other) const
243{
244 const auto [r0, t0] = mac(0, data[0], other.data[0], 0ULL);
245 const auto [q0, t1] = mac(0, data[0], other.data[1], t0);
246 const auto [q1, t2] = mac(0, data[0], other.data[2], t1);
247 const auto q2 = mac_discard_hi(0, data[0], other.data[3], t2);
248
249 const auto [r1, t3] = mac(q0, data[1], other.data[0], 0ULL);
250 const auto [q3, t4] = mac(q1, data[1], other.data[1], t3);
251 const auto q4 = mac_discard_hi(q2, data[1], other.data[2], t4);
252
253 const auto [r2, t5] = mac(q3, data[2], other.data[0], 0ULL);
254 const auto q5 = mac_discard_hi(q4, data[2], other.data[1], t5);
255
256 const auto r3 = mac_discard_hi(q5, data[3], other.data[0], 0ULL);
257
258 return { r0, r1, r2, r3 };
259}
260
261constexpr uint128_t uint128_t::operator/(const uint128_t& other) const
262{
263 return divmod(other).first;
264}
265
266constexpr uint128_t uint128_t::operator%(const uint128_t& other) const
267{
268 return divmod(other).second;
269}
270
271constexpr uint128_t uint128_t::operator&(const uint128_t& other) const
272{
273 return { data[0] & other.data[0], data[1] & other.data[1], data[2] & other.data[2], data[3] & other.data[3] };
274}
275
276constexpr uint128_t uint128_t::operator^(const uint128_t& other) const
277{
278 return { data[0] ^ other.data[0], data[1] ^ other.data[1], data[2] ^ other.data[2], data[3] ^ other.data[3] };
279}
280
281constexpr uint128_t uint128_t::operator|(const uint128_t& other) const
282{
283 return { data[0] | other.data[0], data[1] | other.data[1], data[2] | other.data[2], data[3] | other.data[3] };
284}
285
286constexpr uint128_t uint128_t::operator~() const
287{
288 return { ~data[0], ~data[1], ~data[2], ~data[3] };
289}
290
291constexpr bool uint128_t::operator==(const uint128_t& other) const
292{
293 return data[0] == other.data[0] && data[1] == other.data[1] && data[2] == other.data[2] && data[3] == other.data[3];
294}
295
296constexpr bool uint128_t::operator!=(const uint128_t& other) const
297{
298 return !(*this == other);
299}
300
301constexpr bool uint128_t::operator!() const
302{
303 return *this == uint128_t(0ULL);
304}
305
306constexpr bool uint128_t::operator>(const uint128_t& other) const
307{
308 bool t0 = data[3] > other.data[3];
309 bool t1 = data[3] == other.data[3] && data[2] > other.data[2];
310 bool t2 = data[3] == other.data[3] && data[2] == other.data[2] && data[1] > other.data[1];
311 bool t3 =
312 data[3] == other.data[3] && data[2] == other.data[2] && data[1] == other.data[1] && data[0] > other.data[0];
313 return t0 || t1 || t2 || t3;
314}
315
316constexpr bool uint128_t::operator>=(const uint128_t& other) const
317{
318 return (*this > other) || (*this == other);
319}
320
321constexpr bool uint128_t::operator<(const uint128_t& other) const
322{
323 return other > *this;
324}
325
326constexpr bool uint128_t::operator<=(const uint128_t& other) const
327{
328 return (*this < other) || (*this == other);
329}
330
331constexpr uint128_t uint128_t::operator>>(const uint128_t& other) const
332{
333 uint32_t total_shift = other.data[0];
334
335 if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
336 return 0;
337 }
338
339 if (total_shift == 0) {
340 return *this;
341 }
342
343 uint32_t num_shifted_limbs = total_shift >> 5ULL;
344 uint32_t limb_shift = total_shift & 31ULL;
345
346 std::array<uint32_t, 4> shifted_limbs = { 0, 0, 0, 0 };
347
348 if (limb_shift == 0) {
349 shifted_limbs[0] = data[0];
350 shifted_limbs[1] = data[1];
351 shifted_limbs[2] = data[2];
352 shifted_limbs[3] = data[3];
353 } else {
354 uint32_t remainder_shift = 32ULL - limb_shift;
355
356 shifted_limbs[3] = data[3] >> limb_shift;
357
358 uint32_t remainder = (data[3]) << remainder_shift;
359
360 shifted_limbs[2] = (data[2] >> limb_shift) + remainder;
361
362 remainder = (data[2]) << remainder_shift;
363
364 shifted_limbs[1] = (data[1] >> limb_shift) + remainder;
365
366 remainder = (data[1]) << remainder_shift;
367
368 shifted_limbs[0] = (data[0] >> limb_shift) + remainder;
369 }
370 uint128_t result(0);
371
372 for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
373 result.data[i] = shifted_limbs[static_cast<size_t>(i + num_shifted_limbs)];
374 }
375
376 return result;
377}
378
379constexpr uint128_t uint128_t::operator<<(const uint128_t& other) const
380{
381 uint32_t total_shift = other.data[0];
382
383 if (total_shift >= 128 || (other.data[1] != 0U) || (other.data[2] != 0U) || (other.data[3] != 0U)) {
384 return 0;
385 }
386
387 if (total_shift == 0) {
388 return *this;
389 }
390 uint32_t num_shifted_limbs = total_shift >> 5ULL;
391 uint32_t limb_shift = total_shift & 31ULL;
392
393 std::array<uint32_t, 4> shifted_limbs{ 0, 0, 0, 0 };
394
395 if (limb_shift == 0) {
396 shifted_limbs[0] = data[0];
397 shifted_limbs[1] = data[1];
398 shifted_limbs[2] = data[2];
399 shifted_limbs[3] = data[3];
400 } else {
401 uint32_t remainder_shift = 32ULL - limb_shift;
402
403 shifted_limbs[0] = data[0] << limb_shift;
404
405 uint32_t remainder = data[0] >> remainder_shift;
406
407 shifted_limbs[1] = (data[1] << limb_shift) + remainder;
408
409 remainder = data[1] >> remainder_shift;
410
411 shifted_limbs[2] = (data[2] << limb_shift) + remainder;
412
413 remainder = data[2] >> remainder_shift;
414
415 shifted_limbs[3] = (data[3] << limb_shift) + remainder;
416 }
417 uint128_t result(0);
418
419 for (size_t i = 0; i < 4 - num_shifted_limbs; ++i) {
420 result.data[static_cast<size_t>(i + num_shifted_limbs)] = shifted_limbs[i];
421 }
422
423 return result;
424}
425
426} // namespace bb::numeric
427#endif
#define BB_ASSERT(expression,...)
Definition assert.hpp:70
const std::vector< MemoryValue > data
FF a
FF b
constexpr uint64_t get_msb64(const uint64_t in)
Definition get_msb.hpp:33
constexpr T get_msb(const T in)
Definition get_msb.hpp:50
Inner sum(Cont< Inner, Args... > const &in)
Definition container.hpp:70
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
unsigned __int128 uint128_t
Definition serialize.hpp:45
void throw_or_abort(std::string const &err)