Barretenberg
The ZK-SNARK library at the core of Aztec
Loading...
Searching...
No Matches
univariate_coefficient_basis.hpp
Go to the documentation of this file.
1// === AUDIT STATUS ===
2// internal: { status: Complete, auditors: [Nishat], commit: 94f596f8b3bbbc216f9ad7dc33253256141156b2 }
3// external_1: { status: not started, auditors: [], commit: }
4// external_2: { status: not started, auditors: [], commit: }
5// =====================
6
7#pragma once
11#include <span>
12
13namespace bb {
14
/**
 * @brief Univariate polynomial of degree < LENGTH (LENGTH is 2 or 3), stored in coefficient form.
 *
 * @details Storage is always three field elements. How the slots are read depends on the template parameters:
 * - domain_end == 2: represents a0 + a1*X with coefficients[0] = a0 and coefficients[1] = a1. When has_a0_plus_a1
 *   is true, coefficients[2] caches a0 + a1 so that operator* can skip that addition; otherwise coefficients[2]
 *   is unused.
 * - domain_end == 3 (e.g. the result of operator* and sqr): represents a0 + a1*X + a2*X^2 with
 *   coefficients[0] = a0, coefficients[1] = a1 + a2 and coefficients[2] = a2, i.e. the coefficients with respect
 *   to the basis {1, X, X(X - 1)}.
 *
 * Arithmetic that would leave a cached a0 + a1 stale returns the has_a0_plus_a1 = false form.
 *
 * @tparam Fr field type of the coefficients
 * @tparam domain_end logical number of coefficients (2 or 3)
 * @tparam has_a0_plus_a1 whether coefficients[2] caches a0 + a1 for a linear (domain_end == 2) univariate
 */
39template <class Fr, size_t domain_end, bool has_a0_plus_a1> class UnivariateCoefficientBasis {
40 public:
41 static constexpr size_t LENGTH = domain_end;
42 static_assert(LENGTH == 2 || LENGTH == 3);
43 using value_type = Fr; // used to get the type of the elements consistently with std::array
44
// Storage for polynomial coefficients (always 3 elements for uniform layout). Only the first LENGTH entries are
// coefficients of the polynomial; for LENGTH == 2 the third slot is the optional cached a0 + a1.
55 std::array<Fr, 3> coefficients;
56
58
60 requires(!has_a0_plus_a1)
61 {
// Converts from the has_a0_plus_a1 = true form of the same domain_end into the plain form. a0 and a1 are copied;
// coefficients[2] is copied only when domain_end == 3. For domain_end == 2 the cached a0 + a1 is discarded and
// coefficients[2] of *this is left unassigned by this constructor.
62 coefficients[0] = other.coefficients[0];
63 coefficients[1] = other.coefficients[1];
64 if constexpr (domain_end == 3) {
65 coefficients[2] = other.coefficients[2];
66 }
67 }
68
74
// Widening conversion from a univariate with a smaller domain_end (given the static_assert, this is 2 -> 3).
// The linear coefficients are copied and the X^2 slot is zeroed; any cached a0 + a1 in the source is not carried
// over.
75 template <size_t other_domain_end, bool other_has_a0_plus_a1 = true>
77 requires(domain_end > other_domain_end)
78 {
79 coefficients[0] = other.coefficients[0];
80 coefficients[1] = other.coefficients[1];
81 if constexpr (domain_end == 3) {
82 coefficients[2] = 0;
83 }
84 };
85
86 // operator== is deleted. If an equality is needed, care must be taken to define the semantics correctly. The
87 // semantic meaning of coefficients[2] depends on template parameters (unused for domain_end=2,
88 // has_a0_plus_a1=false; a0+a1 Karatsuba precomputation for has_a0_plus_a1=true; the x^2 coefficient for
89 // domain_end=3), so a defaulted comparison can produce false negatives between construction paths that represent
90 // the same polynomial.
// In the domain_end=2, has_a0_plus_a1=false case coefficients[2] may also never have been assigned (the conversion
// from the has_a0_plus_a1=true form skips it), so a memberwise comparison could read a value nobody wrote.
91 bool operator==(const UnivariateCoefficientBasis& other) const = delete;
92
93 template <size_t other_domain_end, bool other_has_a0_plus_a1>
96 {
97 // In-place addition. coefficients[2] is only updated when both operands have domain_end == 3, so a cached
98 // `a0 + a1` (has_a0_plus_a1 with domain_end == 2) is not kept in sync; this is why the result is typed with
99 // has_a0_plus_a1 = false.
// NOTE(review): when domain_end == 2 and other_domain_end == 3, other's degree-2 contribution is not represented in
// the result; presumably that combination is excluded by a constraint on the declaration -- confirm.
100 coefficients[0] += other.coefficients[0];
101 coefficients[1] += other.coefficients[1];
102 if constexpr (other_domain_end == 3 && domain_end == 3) {
103 coefficients[2] += other.coefficients[2];
104 }
105 return *this;
106 }
107
108 template <size_t other_domain_end, bool other_has_a0_plus_a1>
111 {
112 // In-place subtraction. coefficients[2] is only updated when both operands have domain_end == 3, so a cached
113 // `a0 + a1` (has_a0_plus_a1 with domain_end == 2) is not kept in sync; this is why the result is typed with
114 // has_a0_plus_a1 = false.
// NOTE(review): when domain_end == 2 and other_domain_end == 3, other's degree-2 contribution is not represented in
// the result; presumably that combination is excluded by a constraint on the declaration -- confirm.
115 coefficients[0] -= other.coefficients[0];
116 coefficients[1] -= other.coefficients[1];
117 if constexpr (other_domain_end == 3 && domain_end == 3) {
118 coefficients[2] -= other.coefficients[2];
119 }
120 return *this;
121 }
122
// Product of two linear univariates (LENGTH == 2), returned as a domain_end == 3 univariate without a cached sum.
// Karatsuba-style: a0*b0 and a1*b1 are computed directly, and the middle slot
// coefficients[1] = a0*b1 + a1*b0 + a1*b1 = (a0 + a1)(b0 + b1) - a0*b0 costs one more multiplication. When an
// operand has a cached a0 + a1 (coefficients[2] with has_a0_plus_a1), it is used instead of recomputing the sum.
123 template <bool other_has_a0_plus_a1>
126 requires(LENGTH == 2)
127 {
129 // result.coefficients[0] = a0 * b0;
130 // result.coefficients[2] = a1 * b1
131 result.coefficients[0] = coefficients[0] * other.coefficients[0];
132 result.coefficients[2] = coefficients[1] * other.coefficients[1];
133
134 // This is where the cached a0 + a1 (coefficients[2] when has_a0_plus_a1) pays off.
135 // coefficients[1] = sum of X^2 and X coefficients
136 // (a0 + a1X) * (b0 + b1X) = a0b0 + (a0b1 + a1b0)X + a1b1XX
137 // coefficients[1] = a0b1 + a1b0 + a1b1
138 // which is computed as (a0 + a1) * (b0 + b1) - a0b0,
139 // reusing a cached a0 + a1 / b0 + b1 where available.
140 if constexpr (has_a0_plus_a1 && other_has_a0_plus_a1) {
141 result.coefficients[1] = (coefficients[2] * other.coefficients[2] - result.coefficients[0]);
142 } else if constexpr (has_a0_plus_a1 && !other_has_a0_plus_a1) {
143 result.coefficients[1] =
144 coefficients[2] * (other.coefficients[0] + other.coefficients[1]) - result.coefficients[0];
145 } else if constexpr (!has_a0_plus_a1 && other_has_a0_plus_a1) {
146 result.coefficients[1] =
147 (coefficients[0] + coefficients[1]) * other.coefficients[2] - result.coefficients[0];
148 } else {
149 result.coefficients[1] =
150 (coefficients[0] + coefficients[1]) * (other.coefficients[0] + other.coefficients[1]) -
151 result.coefficients[0];
152 }
153 return result;
154 }
155
156 template <size_t other_domain_end, bool other_has_a0_plus_a1>
159 {
161 // Out-of-place addition into res. res.coefficients[2] is only touched when both operands have domain_end == 3,
162 // so it is not maintained as a cached `a0 + a1`; accordingly the result is typed with
163 // has_a0_plus_a1 = false.
// NOTE(review): when domain_end == 2 and other_domain_end == 3, other's degree-2 contribution is not represented in
// the result; presumably that combination is excluded by a constraint on the declaration -- confirm.
164 res.coefficients[0] += other.coefficients[0];
165 res.coefficients[1] += other.coefficients[1];
166 if constexpr (other_domain_end == 3 && domain_end == 3) {
167 res.coefficients[2] += other.coefficients[2];
168 }
169 return res;
170 }
171
172 template <size_t other_domain_end, bool other_has_a0_plus_a1>
175 {
177 // Out-of-place subtraction into res. res.coefficients[2] is only touched when both operands have
178 // domain_end == 3, so it is not maintained as a cached `a0 + a1`; accordingly the result is typed with
179 // has_a0_plus_a1 = false.
// NOTE(review): when domain_end == 2 and other_domain_end == 3, other's degree-2 contribution is not represented in
// the result; presumably that combination is excluded by a constraint on the declaration -- confirm.
180 res.coefficients[0] -= other.coefficients[0];
181 res.coefficients[1] -= other.coefficients[1];
182 if constexpr (other_domain_end == 3 && domain_end == 3) {
183 res.coefficients[2] -= other.coefficients[2];
184 }
185 return res;
186 }
187
189 {
// Negation. For domain_end == 2, res.coefficients[2] is not written here, so it does not hold a valid cached
// a0 + a1; the result is the has_a0_plus_a1 = false form.
191 res.coefficients[0] = -coefficients[0];
192 res.coefficients[1] = -coefficients[1];
193 if constexpr (domain_end == 3) {
194 res.coefficients[2] = -coefficients[2];
195 }
196
197 return res;
198 }
199
201 requires(LENGTH == 2)
202 {
// Square of a linear univariate, returned in the same domain_end == 3 layout as operator*: two field squarings
// plus one multiplication. The doubling of a0*a1 is done with additions, or is folded into (a0 + a1 + a0) * a1
// when a0 + a1 is cached in coefficients[2].
204 result.coefficients[0] = coefficients[0].sqr();
205 result.coefficients[2] = coefficients[1].sqr();
206
207 // (a0 + a1.X)^2 = a0a0 + 2a0a1.X + a1a1.XX
208 // coefficients[0] = a0a0
209 // coefficients[1] = 2a0a1 + a1a1 = (a0 + a0 + a1).a1
210 // coefficients[2] = a1a1
211 // expanded terms: a0a0, a1a1, a0a1 + a1a0
212 if constexpr (has_a0_plus_a1) {
213 result.coefficients[1] = (coefficients[2] + coefficients[0]) * coefficients[1];
214 } else {
215 result.coefficients[1] = coefficients[0] * coefficients[1];
216 result.coefficients[1] += result.coefficients[1];
217 result.coefficients[1] += result.coefficients[2];
218 }
219 return result;
220 }
221
222 // Operations between Univariate and scalar
// The in-place scalar operations below are restricted to the has_a0_plus_a1 = false form: none of them keeps
// coefficients[2] equal to a0 + a1 when domain_end == 2.
224 requires(!has_a0_plus_a1)
225 {
// Adding a scalar only shifts the constant coefficient.
226 coefficients[0] += scalar;
227 return *this;
228 }
229
231 requires(!has_a0_plus_a1)
232 {
// Subtracting a scalar only shifts the constant coefficient.
233 coefficients[0] -= scalar;
234 return *this;
235 }
237 requires(!has_a0_plus_a1)
238 {
// Scales every coefficient; the third slot only carries polynomial data (the X^2 part) when domain_end == 3.
239 coefficients[0] *= scalar;
240 coefficients[1] *= scalar;
241 if constexpr (domain_end == 3) {
242 coefficients[2] *= scalar;
243 }
244 return *this;
245 }
246
248 {
// Returns a copy with scalar added to the constant coefficient (result is the has_a0_plus_a1 = false form).
250 res += scalar;
251 return res;
252 }
253
255 {
// Returns a copy with scalar subtracted from the constant coefficient (result is the has_a0_plus_a1 = false form).
257 res -= scalar;
258 return res;
259 }
260
262 {
// Returns a copy with every coefficient scaled by scalar. For domain_end == 2 the third slot is not scaled, so it
// cannot serve as a cached a0 + a1; the result is the has_a0_plus_a1 = false form.
264 res.coefficients[0] *= scalar;
265 res.coefficients[1] *= scalar;
266 if constexpr (domain_end == 3) {
267 res.coefficients[2] *= scalar;
268 }
269 return res;
270 }
271
272 // Output is immediately parsable as a list of integers by Python.
273 friend std::ostream& operator<<(std::ostream& os, const UnivariateCoefficientBasis& u)
274 {
275 os << "[";
276 os << u.coefficients[0] << "," << std::endl;
277 for (size_t i = 1; i < u.coefficients.size(); i++) {
278 os << " " << u.coefficients[i];
279 if (i + 1 < u.coefficients.size()) {
280 os << "," << std::endl;
281 } else {
282 os << "]";
283 };
284 }
285 return os;
286 }
287};
288
// Deserializes a UnivariateCoefficientBasis by reading its whole three-element coefficients array, i.e. the third
// storage slot is read even when domain_end == 2. The using-declaration plus unqualified call makes both
// serialize::read and any overload found by argument-dependent lookup candidates.
289template <typename B, class Fr, size_t domain_end, bool has_a0_plus_a1>
291{
292 using serialize::read;
293 read(it, univariate.coefficients);
294}
295
// Serializes a UnivariateCoefficientBasis by writing its whole three-element coefficients array, including the
// third slot when domain_end == 2.
// NOTE(review): for domain_end == 2 with has_a0_plus_a1 = false that slot may never have been assigned; confirm the
// default state of Fr makes the serialized bytes deterministic.
296template <typename B, class Fr, size_t domain_end, bool has_a0_plus_a1>
298{
299 using serialize::write;
300 write(it, univariate.coefficients);
301}
302
303} // namespace bb
304
305namespace std {
306template <typename T, size_t N, bool X>
307struct tuple_size<bb::UnivariateCoefficientBasis<T, N, X>> : std::integral_constant<std::size_t, N> {};
308
309} // namespace std
A view of a univariate, also used to truncate univariates.
friend std::ostream & operator<<(std::ostream &os, const UnivariateCoefficientBasis &u)
UnivariateCoefficientBasis & operator=(const UnivariateCoefficientBasis &other)=default
UnivariateCoefficientBasis< Fr, domain_end, false > operator-(const Fr &scalar) const
UnivariateCoefficientBasis(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other)
UnivariateCoefficientBasis< Fr, domain_end, false > & operator*=(const Fr &scalar)
UnivariateCoefficientBasis(UnivariateCoefficientBasis &&other) noexcept=default
UnivariateCoefficientBasis< Fr, 3, false > sqr() const
bool operator==(const UnivariateCoefficientBasis &other) const =delete
UnivariateCoefficientBasis(const UnivariateCoefficientBasis< Fr, domain_end, true > &other)
UnivariateCoefficientBasis & operator+=(const Fr &scalar)
UnivariateCoefficientBasis< Fr, domain_end, false > operator+(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other) const
std::array< Fr, 3 > coefficients
Storage for polynomial coefficients (always 3 elements for uniform layout).
UnivariateCoefficientBasis< Fr, domain_end, false > & operator+=(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other)
UnivariateCoefficientBasis< Fr, domain_end, false > operator-() const
UnivariateCoefficientBasis & operator-=(const Fr &scalar)
UnivariateCoefficientBasis & operator=(UnivariateCoefficientBasis &&other) noexcept=default
UnivariateCoefficientBasis(const UnivariateCoefficientBasis &other)=default
UnivariateCoefficientBasis< Fr, domain_end, false > & operator-=(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other)
UnivariateCoefficientBasis< Fr, 3, false > operator*(const UnivariateCoefficientBasis< Fr, domain_end, other_has_a0_plus_a1 > &other) const
UnivariateCoefficientBasis< Fr, domain_end, false > operator*(const Fr &scalar) const
UnivariateCoefficientBasis< Fr, domain_end, false > operator-(const UnivariateCoefficientBasis< Fr, other_domain_end, other_has_a0_plus_a1 > &other) const
UnivariateCoefficientBasis< Fr, domain_end, false > operator+(const Fr &scalar) const
Entry point for Barretenberg command-line interface.
Definition api.hpp:5
void read(B &it, field2< base_field, Params > &value)
void write(B &buf, field2< base_field, Params > const &value)
void read(auto &it, msgpack_concepts::HasMsgPack auto &obj)
Automatically derived read for any object that defines .msgpack() (implicitly defined by SERIALIZATIO...
void write(auto &buf, const msgpack_concepts::HasMsgPack auto &obj)
Automatically derived write for any object that defines .msgpack() (implicitly defined by SERIALIZATI...
STL namespace.
constexpr decltype(auto) get(::tuplet::tuple< T... > &&t) noexcept
Definition tuple.hpp:13
Curve::ScalarField Fr