// yggm/libs/fe.cu
// 2025-03-17 05:32:00 +05:00
//
// 409 lines
// 15 KiB
// Plaintext

#include <fe.cuh>
// Sets h to the field element 1: every limb cleared, then the lowest limb set to 1.
void __host__ __device__ fe_1(fe __restrict__ h) {
#pragma unroll
    for (int limb = 0; limb < 10; ++limb) {
        h[limb] = 0;
    }
    h[0] = 1;
}
// Limb-wise addition h = f + g.
// No carry propagation is done here, so limbs grow by up to one bit per call.
// NOTE(review): callers presumably reduce later (e.g. via fe_mul/fe_sq) — confirm.
void __host__ __device__ fe_add(fe __restrict__ h, const fe& __restrict__ f, const fe& __restrict__ g) {
    int limb = 0;
#pragma unroll
    while (limb != 10) {
        h[limb] = f[limb] + g[limb];
        ++limb;
    }
}
// Branch-free conditional move: replaces f with g when b == 1, and leaves f
// untouched when b == 0.
// -b is all ones for b == 1 and all zeros for b == 0. Values of b other than
// 0/1 give a partial bit mix, so callers are expected to pass 0 or 1.
void __host__ __device__ fe_cmov(fe f, const fe& g, unsigned int b) {
    const int mask = -static_cast<int>(b);
#pragma unroll
    for (int limb = 0; limb < 10; ++limb) {
        const auto diff = f[limb] ^ g[limb];
        f[limb] ^= diff & mask;
    }
}
// Copies all ten limbs of f into h.
void __host__ __device__ fe_copy(fe __restrict__ h, const fe& __restrict__ f) {
    const auto* src = f;
#pragma unroll
    for (auto* dst = h; dst != h + 10; ++dst, ++src) {
        *dst = *src;
    }
}
// Field inversion: out = z^(2^255 - 21) = z^(p - 2) with p = 2^255 - 19, i.e.
// z^-1 for non-zero z, computed with a fixed addition chain of squarings and
// multiplications. The running exponent is noted after each step.
// NOTE(review): several fe_mul calls pass the output as one of the inputs even
// though fe_mul's parameters are __restrict__-qualified; this relies on fe_mul
// reading every input limb into locals before writing h.
void fe_invert(fe __restrict__ out, const fe& __restrict__ z) {
    // dst = src^(2^n): one squaring from src, then n - 1 further in-place squarings.
    auto square_n = [](fe& dst, const fe& src, int n) {
        fe_sq(dst, src);
        for (int k = 1; k < n; ++k) {
            fe_sq(dst, dst);
        }
    };

    fe t0, t1, t2, t3;

    square_n(t0, z, 1);    // t0 = z^2
    square_n(t1, t0, 2);   // t1 = z^8
    fe_mul(t1, z, t1);     // t1 = z^9
    fe_mul(t0, t0, t1);    // t0 = z^11
    square_n(t2, t0, 1);   // t2 = z^22
    fe_mul(t1, t1, t2);    // t1 = z^(2^5 - 1)
    square_n(t2, t1, 5);
    fe_mul(t1, t2, t1);    // t1 = z^(2^10 - 1)
    square_n(t2, t1, 10);
    fe_mul(t2, t2, t1);    // t2 = z^(2^20 - 1)
    square_n(t3, t2, 20);
    fe_mul(t2, t3, t2);    // t2 = z^(2^40 - 1)
    square_n(t2, t2, 10);
    fe_mul(t1, t2, t1);    // t1 = z^(2^50 - 1)
    square_n(t2, t1, 50);
    fe_mul(t2, t2, t1);    // t2 = z^(2^100 - 1)
    square_n(t3, t2, 100);
    fe_mul(t2, t3, t2);    // t2 = z^(2^200 - 1)
    square_n(t2, t2, 50);
    fe_mul(t1, t2, t1);    // t1 = z^(2^250 - 1)
    square_n(t1, t1, 5);   // t1 = z^(2^255 - 32)
    fe_mul(out, t1, t0);   // out = z^(2^255 - 21)
}
// Returns the lowest bit (0 or 1) of the 32-byte little-endian encoding that
// fe_tobytes produces for f. That encoding is reduced, so this is the parity
// of the reduced value.
int __host__ __device__ fe_isnegative(const fe& __restrict__ f) {
    unsigned char encoded[32];
    fe_tobytes(encoded, f);
    const int low_bit = encoded[0] & 1;
    return low_bit;
}
// Field multiplication h = f * g modulo p = 2^255 - 19.
//
// Limb layout, as fixed by the carry chain below:
//   - Even-indexed limbs carry at 26 bits and odd-indexed limbs at 25 bits.
//   - Limb i therefore sits at bit offset ceil(25.5 * i).
//   - A product of two odd-indexed limbs lands one bit above its target limb,
//     hence the extra factor 2 (f1_2, f3_2, ...).
//   - Terms that wrap past limb 9 are scaled by 19, because 2^255 == 19 (mod p).
//
// All arithmetic is in `long long` (at least 64 bits). The original used `long`,
// which is only 32 bits on LLP64 targets such as Windows, so the products
// overflowed there. Doubling and carry removal use multiplication rather than
// left shifts, so no negative value is shifted (UB before C++20).
//
// Every limb of f and g is copied into locals before h is written, and every
// output limb depends on every input limb through the carry chain. In-place calls
// such as fe_mul(t, t, u) therefore work in practice.
// NOTE(review): such calls still violate the __restrict__ contract of this signature.
__device__ __host__ void fe_mul(fe __restrict__ h, const fe& __restrict__ f, const fe& __restrict__ g) {
    const long long f0 = f[0], f1 = f[1], f2 = f[2], f3 = f[3], f4 = f[4];
    const long long f5 = f[5], f6 = f[6], f7 = f[7], f8 = f[8], f9 = f[9];
    const long long g0 = g[0], g1 = g[1], g2 = g[2], g3 = g[3], g4 = g[4];
    const long long g5 = g[5], g6 = g[6], g7 = g[7], g8 = g[8], g9 = g[9];

    // Wrap-around factor 19 applied to g once, instead of inside every product.
    const long long g1_19 = 19 * g1, g2_19 = 19 * g2, g3_19 = 19 * g3;
    const long long g4_19 = 19 * g4, g5_19 = 19 * g5, g6_19 = 19 * g6;
    const long long g7_19 = 19 * g7, g8_19 = 19 * g8, g9_19 = 19 * g9;

    // Odd f limbs pre-doubled for the odd x odd products.
    const long long f1_2 = 2 * f1, f3_2 = 2 * f3, f5_2 = 2 * f5, f7_2 = 2 * f7, f9_2 = 2 * f9;

    long long h0 = f0 * g0 + f1_2 * g9_19 + f2 * g8_19 + f3_2 * g7_19 + f4 * g6_19
                 + f5_2 * g5_19 + f6 * g4_19 + f7_2 * g3_19 + f8 * g2_19 + f9_2 * g1_19;
    long long h1 = f0 * g1 + f1 * g0 + f2 * g9_19 + f3 * g8_19 + f4 * g7_19
                 + f5 * g6_19 + f6 * g5_19 + f7 * g4_19 + f8 * g3_19 + f9 * g2_19;
    long long h2 = f0 * g2 + f1_2 * g1 + f2 * g0 + f3_2 * g9_19 + f4 * g8_19
                 + f5_2 * g7_19 + f6 * g6_19 + f7_2 * g5_19 + f8 * g4_19 + f9_2 * g3_19;
    long long h3 = f0 * g3 + f1 * g2 + f2 * g1 + f3 * g0 + f4 * g9_19
                 + f5 * g8_19 + f6 * g7_19 + f7 * g6_19 + f8 * g5_19 + f9 * g4_19;
    long long h4 = f0 * g4 + f1_2 * g3 + f2 * g2 + f3_2 * g1 + f4 * g0
                 + f5_2 * g9_19 + f6 * g8_19 + f7_2 * g7_19 + f8 * g6_19 + f9_2 * g5_19;
    long long h5 = f0 * g5 + f1 * g4 + f2 * g3 + f3 * g2 + f4 * g1
                 + f5 * g0 + f6 * g9_19 + f7 * g8_19 + f8 * g7_19 + f9 * g6_19;
    long long h6 = f0 * g6 + f1_2 * g5 + f2 * g4 + f3_2 * g3 + f4 * g2
                 + f5_2 * g1 + f6 * g0 + f7_2 * g9_19 + f8 * g8_19 + f9_2 * g7_19;
    long long h7 = f0 * g7 + f1 * g6 + f2 * g5 + f3 * g4 + f4 * g3
                 + f5 * g2 + f6 * g1 + f7 * g0 + f8 * g9_19 + f9 * g8_19;
    long long h8 = f0 * g8 + f1_2 * g7 + f2 * g6 + f3_2 * g5 + f4 * g4
                 + f5_2 * g3 + f6 * g2 + f7_2 * g1 + f8 * g0 + f9_2 * g9_19;
    long long h9 = f0 * g9 + f1 * g8 + f2 * g7 + f3 * g6 + f4 * g5
                 + f5 * g4 + f6 * g3 + f7 * g2 + f8 * g1 + f9 * g0;

    // Rounded carries (add half, then arithmetic shift) bring each limb back into
    // roughly +/- half its radix. Two chains are interleaved: one starting at h0,
    // one at h4. The carry out of h9 re-enters h0 multiplied by 19.
    long long carry;
    carry = (h0 + (1LL << 25)) >> 26; h1 += carry; h0 -= carry * (1LL << 26);
    carry = (h4 + (1LL << 25)) >> 26; h5 += carry; h4 -= carry * (1LL << 26);
    carry = (h1 + (1LL << 24)) >> 25; h2 += carry; h1 -= carry * (1LL << 25);
    carry = (h5 + (1LL << 24)) >> 25; h6 += carry; h5 -= carry * (1LL << 25);
    carry = (h2 + (1LL << 25)) >> 26; h3 += carry; h2 -= carry * (1LL << 26);
    carry = (h6 + (1LL << 25)) >> 26; h7 += carry; h6 -= carry * (1LL << 26);
    carry = (h3 + (1LL << 24)) >> 25; h4 += carry; h3 -= carry * (1LL << 25);
    carry = (h7 + (1LL << 24)) >> 25; h8 += carry; h7 -= carry * (1LL << 25);
    carry = (h4 + (1LL << 25)) >> 26; h5 += carry; h4 -= carry * (1LL << 26);
    carry = (h8 + (1LL << 25)) >> 26; h9 += carry; h8 -= carry * (1LL << 26);
    carry = (h9 + (1LL << 24)) >> 25; h0 += carry * 19; h9 -= carry * (1LL << 25);
    carry = (h0 + (1LL << 25)) >> 26; h1 += carry; h0 -= carry * (1LL << 26);

    h[0] = static_cast<int>(h0);
    h[1] = static_cast<int>(h1);
    h[2] = static_cast<int>(h2);
    h[3] = static_cast<int>(h3);
    h[4] = static_cast<int>(h4);
    h[5] = static_cast<int>(h5);
    h[6] = static_cast<int>(h6);
    h[7] = static_cast<int>(h7);
    h[8] = static_cast<int>(h8);
    h[9] = static_cast<int>(h9);
}
// Limb-wise negation h = -f. Each output limb depends only on the matching input
// limb, and the parameters are not __restrict__, so h may be the same array as f.
void __device__ __host__ fe_neg(fe h, const fe& f) {
#pragma unroll
    for (int limb = 0; limb < 10; ++limb) {
        const auto value = f[limb];
        h[limb] = -value;
    }
}
// Field squaring h = f^2 modulo p = 2^255 - 19.
//
// Why this is not fe_mul(h, f, f):
//   - fe_mul's parameters are all __restrict__, so fe_sq(t, t) (used throughout
//     fe_invert) would pass an output that aliases a restrict-qualified input.
//   - A direct square needs 55 products instead of 100.
//
// The coefficient sums are exactly those fe_mul(h, f, f) forms: symmetric pairs
// are merged via the pre-doubled limbs. The carry chain is also identical, so the
// output limbs are bit-identical.
//
// All input limbs are loaded into locals before h is written, so h may alias f.
// Arithmetic is done in `long long`; `long` is only 32 bits on LLP64 targets.
void __device__ __host__ fe_sq(fe h, const fe& f) {
    const long long f0 = f[0], f1 = f[1], f2 = f[2], f3 = f[3], f4 = f[4];
    const long long f5 = f[5], f6 = f[6], f7 = f[7], f8 = f[8], f9 = f[9];

    // Doubled limbs for cross terms f_i * f_j (i != j), which appear twice in a square.
    const long long f0_2 = 2 * f0, f1_2 = 2 * f1, f2_2 = 2 * f2, f3_2 = 2 * f3;
    const long long f4_2 = 2 * f4, f5_2 = 2 * f5, f6_2 = 2 * f6, f7_2 = 2 * f7;

    // Wrap-around factor 19 (2^255 == 19 mod p). It is pre-doubled to 38 for odd
    // limbs, whose products with other odd limbs need the extra factor 2.
    const long long f5_38 = 38 * f5, f6_19 = 19 * f6, f7_38 = 38 * f7;
    const long long f8_19 = 19 * f8, f9_38 = 38 * f9;

    long long h0 = f0 * f0 + f1_2 * f9_38 + f2_2 * f8_19 + f3_2 * f7_38 + f4_2 * f6_19 + f5 * f5_38;
    long long h1 = f0_2 * f1 + f2 * f9_38 + f3_2 * f8_19 + f4 * f7_38 + f5_2 * f6_19;
    long long h2 = f0_2 * f2 + f1_2 * f1 + f3_2 * f9_38 + f4_2 * f8_19 + f5_2 * f7_38 + f6 * f6_19;
    long long h3 = f0_2 * f3 + f1_2 * f2 + f4 * f9_38 + f5_2 * f8_19 + f6 * f7_38;
    long long h4 = f0_2 * f4 + f1_2 * f3_2 + f2 * f2 + f5_2 * f9_38 + f6_2 * f8_19 + f7 * f7_38;
    long long h5 = f0_2 * f5 + f1_2 * f4 + f2_2 * f3 + f6 * f9_38 + f7_2 * f8_19;
    long long h6 = f0_2 * f6 + f1_2 * f5_2 + f2_2 * f4 + f3_2 * f3 + f7_2 * f9_38 + f8 * f8_19;
    long long h7 = f0_2 * f7 + f1_2 * f6 + f2_2 * f5 + f3_2 * f4 + f8 * f9_38;
    long long h8 = f0_2 * f8 + f1_2 * f7_2 + f2_2 * f6 + f3_2 * f5_2 + f4 * f4 + f9 * f9_38;
    long long h9 = f0_2 * f9 + f1_2 * f8 + f2_2 * f7 + f3_2 * f6 + f4_2 * f5;

    // Rounded carries with 26-bit even limbs and 25-bit odd limbs. The carry out of
    // h9 re-enters h0 multiplied by 19. Multiplication is used instead of left
    // shifts because the carries may be negative.
    long long carry;
    carry = (h0 + (1LL << 25)) >> 26; h1 += carry; h0 -= carry * (1LL << 26);
    carry = (h4 + (1LL << 25)) >> 26; h5 += carry; h4 -= carry * (1LL << 26);
    carry = (h1 + (1LL << 24)) >> 25; h2 += carry; h1 -= carry * (1LL << 25);
    carry = (h5 + (1LL << 24)) >> 25; h6 += carry; h5 -= carry * (1LL << 25);
    carry = (h2 + (1LL << 25)) >> 26; h3 += carry; h2 -= carry * (1LL << 26);
    carry = (h6 + (1LL << 25)) >> 26; h7 += carry; h6 -= carry * (1LL << 26);
    carry = (h3 + (1LL << 24)) >> 25; h4 += carry; h3 -= carry * (1LL << 25);
    carry = (h7 + (1LL << 24)) >> 25; h8 += carry; h7 -= carry * (1LL << 25);
    carry = (h4 + (1LL << 25)) >> 26; h5 += carry; h4 -= carry * (1LL << 26);
    carry = (h8 + (1LL << 25)) >> 26; h9 += carry; h8 -= carry * (1LL << 26);
    carry = (h9 + (1LL << 24)) >> 25; h0 += carry * 19; h9 -= carry * (1LL << 25);
    carry = (h0 + (1LL << 25)) >> 26; h1 += carry; h0 -= carry * (1LL << 26);

    h[0] = static_cast<int>(h0);
    h[1] = static_cast<int>(h1);
    h[2] = static_cast<int>(h2);
    h[3] = static_cast<int>(h3);
    h[4] = static_cast<int>(h4);
    h[5] = static_cast<int>(h5);
    h[6] = static_cast<int>(h6);
    h[7] = static_cast<int>(h7);
    h[8] = static_cast<int>(h8);
    h[9] = static_cast<int>(h9);
}
// Doubled field squaring h = 2 * f^2 modulo p = 2^255 - 19.
//
// The unreduced coefficients of f^2 are computed first. They are doubled before
// carrying, so the output limbs match the original carry order.
//
// All input limbs are loaded into locals before h is written, so h may alias f.
//
// Changes from the previous version:
//   - `long` was used for 64-bit accumulators, but it is only 32 bits on LLP64
//     targets (Windows). Now `long long`.
//   - Doubling was done with `<< 1` on possibly negative ints (UB before C++20).
//     Now multiplication.
//   - The x19 / x38 pre-multiplied limbs are now formed in 64 bits, removing a
//     latent 32-bit overflow.
void __host__ __device__ fe_sq2(fe h, const fe& f) {
    const long long f0 = f[0], f1 = f[1], f2 = f[2], f3 = f[3], f4 = f[4];
    const long long f5 = f[5], f6 = f[6], f7 = f[7], f8 = f[8], f9 = f[9];

    // Doubled limbs for cross terms f_i * f_j (i != j).
    const long long f0_2 = 2 * f0, f1_2 = 2 * f1, f2_2 = 2 * f2, f3_2 = 2 * f3;
    const long long f4_2 = 2 * f4, f5_2 = 2 * f5, f6_2 = 2 * f6, f7_2 = 2 * f7;

    // Wrap-around factor 19 (2^255 == 19 mod p), pre-doubled to 38 for odd limbs.
    const long long f5_38 = 38 * f5, f6_19 = 19 * f6, f7_38 = 38 * f7;
    const long long f8_19 = 19 * f8, f9_38 = 38 * f9;

    long long h0 = f0 * f0 + f1_2 * f9_38 + f2_2 * f8_19 + f3_2 * f7_38 + f4_2 * f6_19 + f5 * f5_38;
    long long h1 = f0_2 * f1 + f2 * f9_38 + f3_2 * f8_19 + f4 * f7_38 + f5_2 * f6_19;
    long long h2 = f0_2 * f2 + f1_2 * f1 + f3_2 * f9_38 + f4_2 * f8_19 + f5_2 * f7_38 + f6 * f6_19;
    long long h3 = f0_2 * f3 + f1_2 * f2 + f4 * f9_38 + f5_2 * f8_19 + f6 * f7_38;
    long long h4 = f0_2 * f4 + f1_2 * f3_2 + f2 * f2 + f5_2 * f9_38 + f6_2 * f8_19 + f7 * f7_38;
    long long h5 = f0_2 * f5 + f1_2 * f4 + f2_2 * f3 + f6 * f9_38 + f7_2 * f8_19;
    long long h6 = f0_2 * f6 + f1_2 * f5_2 + f2_2 * f4 + f3_2 * f3 + f7_2 * f9_38 + f8 * f8_19;
    long long h7 = f0_2 * f7 + f1_2 * f6 + f2_2 * f5 + f3_2 * f4 + f8 * f9_38;
    long long h8 = f0_2 * f8 + f1_2 * f7_2 + f2_2 * f6 + f3_2 * f5_2 + f4 * f4 + f9 * f9_38;
    long long h9 = f0_2 * f9 + f1_2 * f8 + f2_2 * f7 + f3_2 * f6 + f4_2 * f5;

    // The "2" in 2 * f^2, applied to the unreduced coefficients.
    h0 *= 2;
    h1 *= 2;
    h2 *= 2;
    h3 *= 2;
    h4 *= 2;
    h5 *= 2;
    h6 *= 2;
    h7 *= 2;
    h8 *= 2;
    h9 *= 2;

    // Rounded carries with 26-bit even limbs and 25-bit odd limbs. The carry out of
    // h9 re-enters h0 multiplied by 19.
    long long carry;
    carry = (h0 + (1LL << 25)) >> 26; h1 += carry; h0 -= carry * (1LL << 26);
    carry = (h4 + (1LL << 25)) >> 26; h5 += carry; h4 -= carry * (1LL << 26);
    carry = (h1 + (1LL << 24)) >> 25; h2 += carry; h1 -= carry * (1LL << 25);
    carry = (h5 + (1LL << 24)) >> 25; h6 += carry; h5 -= carry * (1LL << 25);
    carry = (h2 + (1LL << 25)) >> 26; h3 += carry; h2 -= carry * (1LL << 26);
    carry = (h6 + (1LL << 25)) >> 26; h7 += carry; h6 -= carry * (1LL << 26);
    carry = (h3 + (1LL << 24)) >> 25; h4 += carry; h3 -= carry * (1LL << 25);
    carry = (h7 + (1LL << 24)) >> 25; h8 += carry; h7 -= carry * (1LL << 25);
    carry = (h4 + (1LL << 25)) >> 26; h5 += carry; h4 -= carry * (1LL << 26);
    carry = (h8 + (1LL << 25)) >> 26; h9 += carry; h8 -= carry * (1LL << 26);
    carry = (h9 + (1LL << 24)) >> 25; h0 += carry * 19; h9 -= carry * (1LL << 25);
    carry = (h0 + (1LL << 25)) >> 26; h1 += carry; h0 -= carry * (1LL << 26);

    h[0] = static_cast<int>(h0);
    h[1] = static_cast<int>(h1);
    h[2] = static_cast<int>(h2);
    h[3] = static_cast<int>(h3);
    h[4] = static_cast<int>(h4);
    h[5] = static_cast<int>(h5);
    h[6] = static_cast<int>(h6);
    h[7] = static_cast<int>(h7);
    h[8] = static_cast<int>(h8);
    h[9] = static_cast<int>(h9);
}
// Limb-wise subtraction h = f - g.
// No carry propagation is done here, so limbs may grow in magnitude.
// NOTE(review): callers presumably reduce later (e.g. via fe_mul/fe_sq) — confirm.
void __device__ __host__ fe_sub(fe __restrict__ h, const fe& __restrict__ f, const fe& __restrict__ g) {
    int limb = 0;
#pragma unroll
    while (limb < 10) {
        h[limb] = f[limb] - g[limb];
        limb += 1;
    }
}
// Serializes h into 32 little-endian bytes at s, after reducing it modulo
// p = 2^255 - 19.
//
// Limb i is (i & 1) ? 25 : 26 bits wide, so the ten limbs cover 255 bits.
//
// Steps:
//   1. Estimate q by running 19 * h9 (rounded) through every limb.
//   2. Add 19 * q to the lowest limb.
//   3. Propagate exact floor carries and discard the carry out of the top limb.
//      Together with step 2, this subtracts q * (2^255 - 19).
//   4. Pack the now non-negative limbs bit-by-bit into the output bytes.
// NOTE(review): step 1 presumably yields q in {0, 1} under the usual limb bounds
// for this representation — confirm with callers.
void __host__ __device__ fe_tobytes(unsigned char* s, const fe& h) {
    int t[10];
#pragma unroll
    for (int i = 0; i < 10; ++i) {
        t[i] = h[i];
    }

    int q = (19 * t[9] + (1 << 24)) >> 25;
#pragma unroll
    for (int i = 0; i < 10; ++i) {
        q = (t[i] + q) >> ((i & 1) ? 25 : 26);
    }
    t[0] += 19 * q;

    // Exact (floor) carries leave every limb in [0, 2^width).
#pragma unroll
    for (int i = 0; i < 9; ++i) {
        const int width = (i & 1) ? 25 : 26;
        const int carry = t[i] >> width;
        t[i + 1] += carry;
        t[i] -= carry << width;
    }
    const int top_carry = t[9] >> 25;
    t[9] -= top_carry << 25;  // dropped: this is the 2^255 multiple

    // Bit packer: each limb is appended at the current bit position, and full
    // bytes are emitted as soon as they are available.
    unsigned long long acc = 0;
    int bits = 0;
    int out = 0;
#pragma unroll
    for (int i = 0; i < 10; ++i) {
        acc |= static_cast<unsigned long long>(t[i]) << bits;
        bits += (i & 1) ? 25 : 26;
        while (bits >= 8) {
            s[out++] = static_cast<unsigned char>(acc);
            acc >>= 8;
            bits -= 8;
        }
    }
    // 255 bits = 31 full bytes plus 7 remaining bits; bit 255 is always zero.
    s[out] = static_cast<unsigned char>(acc);
}