#include <stdio.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <sha512.cuh>
#include <ed25519.cuh>
#include <edsign.cuh>
#include <string.cuh>
#include <keymanip.cuh>
// Current leading-zero-bit threshold, default 0x10. parameters() may overwrite it from
// "--altitude N" / "-a N"; main() reads it back with cudaMemcpyFromSymbol; KeyGen
// raises it with atomicMax and only reports keys whose count exceeds it.
__device__ unsigned int d_high = 0x10u;
/*
 * Interpret a single command-line token, or a "flag value" pair re-joined by args().
 *
 * Returns:
 *   777 - `arg` is exactly "--altitude" or "-a"; its value is in the next argument.
 *   1   - `arg` is "--altitude <value>" or "-a <value>" and cstring_to_ull()
 *         rejected <value>.
 *   0   - the value was stored in d_high, or the token is not an altitude flag.
 *
 * Only the text before the first space is treated as the flag, and it must be
 * exactly "--altitude" or "-a". Previously any token that merely contained
 * "-a" somewhere (e.g. "my-app 20") was accepted and silently changed d_high.
 * Assumes cstring_find() returns the index of the first match, or -1 when absent.
 */
__device__ int parameters(const char* arg) noexcept {
    if ((cstring_find(arg, "--altitude") == 0 && cstring_length(arg) == 10) || (cstring_find(arg, "-a") == 0 && cstring_length(arg) == 2)) {
        return 777;
    }
    const int space_index = cstring_find(arg, " ");
    if (space_index == -1) return 0;
    // The whole token before the first space must be the flag itself:
    // "--altitude" is 10 characters and "-a" is 2.
    const bool is_altitude = (space_index == 10 && cstring_find(arg, "--altitude") == 0) ||
                             (space_index == 2 && cstring_find(arg, "-a") == 0);
    if (!is_altitude) return 0;
    char sub_arg[256];
    extract_substring(arg, space_index + 1, sub_arg, 256);
    unsigned tmp_high;
    if (cstring_to_ull(sub_arg, &tmp_high) != 0) return 1;
    d_high = tmp_high;
    return 0;
}
/*
 * Parse the command line on the device. Written for a <<<1, 1>>> launch:
 * every thread would walk the same list and write result[0].
 *
 * argv/argc: device copy of the host argument vector; argv[0] is skipped.
 * result:    result[0] receives
 *              0   on success,
 *              776 when a bare flag ("-a" / "--altitude") is the last argument,
 *              otherwise the non-zero code returned by parameters() for the
 *              offending argument (1 for an unparsable value).
 *
 * Parsing stops at the first error. A bare flag is re-joined with the next
 * argument via concat() and handed back to parameters(), which splits on the
 * first space.
 * NOTE(review): this relies on concat() inserting a separating space — confirm.
 */
__global__ void args(char** argv, int argc, int* result) {
    int err = 0;
    for (int x = 1; x < argc; x++) {
        int res = parameters(argv[x]);
        if (res == 777) {
            // Bare flag: its value is the next argv entry.
            if (++x >= argc) {
                err = 776;
                break;
            }
            char combined[512];
            concat(argv[x - 1], argv[x], combined, 512);
            // Keep the real status. The old code reported the 777 sentinel here.
            res = parameters(combined);
        }
        // Also catches a failed single-token "--altitude xyz", which was ignored before.
        if (res != 0) {
            err = res;
            break;
        }
    }
    result[0] = err;
}
// Number of leading zero bits in a 32-bit word; a zero word yields 32.
__device__ __forceinline__ unsigned char zeroCounter(unsigned int x) noexcept {
    if (x == 0u) {
        return 32;
    }
    return static_cast<unsigned char>(__clz(x));
}
/*
 * Count the leading zero bits of a 32-byte big-endian buffer (v[0] is most significant).
 *
 * The count is capped at 255 so it fits the unsigned char return type. The
 * previous version accumulated in an unsigned char, so an all-zero buffer
 * (256 bits) wrapped around to 0.
 */
__device__ __forceinline__ unsigned char getZeros(const unsigned char* v) noexcept {
    unsigned leadZeros = 0;
#pragma unroll
    for (int i = 0; i < 32; i += 4) {
        // Assemble 4 bytes big-endian so bit 31 of `word` is the first bit of v[i].
        const unsigned word = (static_cast<unsigned>(v[i]) << 24) | (static_cast<unsigned>(v[i + 1]) << 16) | (static_cast<unsigned>(v[i + 2]) << 8) | (static_cast<unsigned>(v[i + 3]));
        if (word != 0) {
            // At most 224 + 31 = 255, so this always fits.
            return static_cast<unsigned char>(leadZeros + zeroCounter(word));
        }
        leadZeros += 32;
    }
    return 255;  // All 256 bits are zero: saturate instead of wrapping to 0.
}
/*
 * Initialise one curandState per thread, at rs[blockIdx.x * blockDim.x + threadIdx.x].
 * There is no bounds check: the grid must contain exactly as many threads as
 * `rs` has entries.
 * Each thread uses seed clock64() + id * 7919 and subsequence `id`, then
 * discards its first 10 outputs.
 *
 * NOTE(security): curandState is cuRAND's XORWOW generator, which is not
 * designed as a cryptographic RNG. The seed comes only from the SM cycle
 * counter and the thread index. KeyGen turns these outputs directly into
 * ed25519 seeds, so the private keys may be far more predictable than 256
 * random bits. Consider seeding from an OS CSPRNG instead.
 */
__global__ void initRand(curandState* rs) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    curandState& state = rs[tid];
    const unsigned long long seed = clock64() + tid * 7919ULL;
    curand_init(seed, tid, 0, &state);
    // Discard the first 10 outputs of each state.
#pragma unroll
    for (int burn = 0; burn < 10; ++burn) {
        curand(&state);
    }
}
/*
 * Fill buf[0..31] with bytes from the cuRAND `state`, advancing it.
 *
 * Each 32-bit curand() draw supplies 4 bytes, least-significant byte first,
 * so 8 draws fill the buffer. The previous version made 32 draws and kept
 * only the low 8 bits of each, discarding 3/4 of the generator output.
 * See the security note on initRand about the quality of these bytes as key material.
 */
__device__ __forceinline__ void rmbytes(unsigned char* buf, curandState* state) {
#pragma unroll
    for (int i = 0; i < 32; i += 4) {
        const unsigned r = curand(state);
        buf[i]     = static_cast<unsigned char>(r);
        buf[i + 1] = static_cast<unsigned char>(r >> 8);
        buf[i + 2] = static_cast<unsigned char>(r >> 16);
        buf[i + 3] = static_cast<unsigned char>(r >> 24);
    }
}
/*
 * Mining kernel.
 *
 * Each thread copies its curandState (index blockIdx.x * blockDim.x +
 * threadIdx.x, no bounds check) and then loops forever:
 *   1. draw a 32-byte seed with rmbytes() and derive a key pair with ed25519_keygen();
 *   2. count the leading zero bits of the public key;
 *   3. raise the device-wide best with atomicMax(&d_high, zeros).
 * Only the thread whose count exceeded the previous d_high prints anything.
 * It prints the IPv6 address built by invertKey()/getRawAddress(), plus both
 * keys as formatted by ktos().
 *
 * The loop has no exit, so the kernel never returns and the advanced RNG
 * state is never written back to randStates.
 *
 * NOTE(review): per the CUDA Programming Guide, device printf output is
 * flushed to the host only at kernel launches, synchronisation, or blocking
 * copies. Confirm that results actually appear while this kernel is still running.
 *
 * NOTE(security): the secret key is written to stdout.
 */
__global__ void KeyGen(curandState* randStates) {
    const unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
    curandState rng = randStates[tid];
    for (;;) {
        KeysBox32 keys;
        Key32 seed;
        rmbytes(seed, &rng);
        ed25519_keygen(keys.PrivateKey, keys.PublicKey, seed);

        const unsigned zeros = getZeros(keys.PublicKey);
        // atomicMax returns the value before the update; strictly greater means this thread set a new record.
        const unsigned previousBest = atomicMax(&d_high, zeros);
        if (zeros <= previousBest) {
            continue;
        }

        Addr16 address;
        Key32 inverted;
        invertKey(keys.PublicKey, inverted);
        getRawAddress(zeros, inverted, address);
        printf("\nIPv6:\t%s\nPK:\t%s\nSK:\t%s\n", getAddr(address).data, ktos(keys.PublicKey).data, ktos(keys.PrivateKey).data);
    }
}
/*
 * Host entry point.
 *
 * 1. Copies argv to the device and runs the single-thread `args` kernel. That
 *    kernel may overwrite the device global d_high from "--altitude N" / "-a N".
 *    Its status word is checked; previously it was never read, so invalid
 *    arguments silently fell back to the default threshold.
 * 2. Sizes the KeyGen grid to the occupancy limit of device 0 at thPerBlock
 *    threads per block. Then it seeds one curandState per thread with
 *    initRand and launches KeyGen.
 *
 * Returns 1 if argument parsing or any CUDA call fails. Error paths return
 * immediately; device allocations are released when the process and its
 * CUDA context exit. On success this blocks in the final
 * cudaDeviceSynchronize() for as long as KeyGen runs (its loop has no exit).
 */
int main(int argc, char* argv[]) {
    const int thPerBlock = 256;

    // Report a failed CUDA call on stderr; returns true when `err` is cudaSuccess.
    auto ok = [](cudaError_t err, const char* what) -> bool {
        if (err == cudaSuccess) return true;
        fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
        return false;
    };

    // --- Copy argv to the device. ---
    int* d_result = nullptr;
    char** d_argv = nullptr;
    if (!ok(cudaMalloc((void**)&d_result, sizeof(int)), "cudaMalloc(d_result)")) return 1;
    if (!ok(cudaMalloc((void**)&d_argv, argc * sizeof(char*)), "cudaMalloc(d_argv)")) return 1;
    char** d_strs = new char*[argc]();  // host-side table of device string pointers
    for (int i = 0; i < argc; i++) {
        const size_t len = strlen(argv[i]) + 1;  // include the terminating NUL
        if (!ok(cudaMalloc((void**)&d_strs[i], len), "cudaMalloc(argv string)")) return 1;
        if (!ok(cudaMemcpy(d_strs[i], argv[i], len, cudaMemcpyHostToDevice), "cudaMemcpy(argv string)")) return 1;
    }
    // Upload the whole pointer table in one copy rather than one cudaMemcpy per entry.
    if (!ok(cudaMemcpy(d_argv, d_strs, argc * sizeof(char*), cudaMemcpyHostToDevice), "cudaMemcpy(d_argv)")) return 1;

    // --- Parse the arguments on the device and check the outcome. ---
    args<<<1, 1>>>(d_argv, argc, d_result);
    if (!ok(cudaGetLastError(), "args launch")) return 1;
    if (!ok(cudaDeviceSynchronize(), "args")) return 1;
    int h_result = 0;
    if (!ok(cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost), "cudaMemcpy(d_result)")) return 1;
    for (int i = 0; i < argc; i++) cudaFree(d_strs[i]);
    delete[] d_strs;
    cudaFree(d_argv);
    cudaFree(d_result);
    if (h_result != 0) {
        // args reports 776 only when a bare flag is the last argument.
        if (h_result == 776)
            fprintf(stderr, "Missing value after %s\n", argv[argc - 1]);
        else
            fprintf(stderr, "Invalid argument value (code %d)\n", h_result);
        fprintf(stderr, "Usage: %s [--altitude N | -a N]\n", argv[0]);
        return 1;
    }

    unsigned h_high = 0;  // d_high is unsigned; also matches the %x conversion below
    if (!ok(cudaMemcpyFromSymbol(&h_high, d_high, sizeof(h_high)), "cudaMemcpyFromSymbol(d_high)")) return 1;

    // --- Size the grid to the occupancy limit of device 0. ---
    cudaDeviceProp prop;
    if (!ok(cudaGetDeviceProperties(&prop, 0), "cudaGetDeviceProperties")) return 1;
    int mBpSM = 0;
    if (!ok(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&mBpSM, KeyGen, thPerBlock, 0), "cudaOccupancyMaxActiveBlocksPerMultiprocessor")) return 1;
    if (mBpSM <= 0) {
        // Otherwise the launches below would use a zero-sized grid.
        fprintf(stderr, "KeyGen cannot be resident with %d threads per block on this device\n", thPerBlock);
        return 1;
    }
    const int blocks = mBpSM * prop.multiProcessorCount;
    const int totalTh = blocks * thPerBlock;
    printf("High addrs: 2%02x+\nSMs: %d\nMaxBlocksPerSM: %d\nTotalTh: %d\nBlocksThreads: %d:%d\n", h_high, prop.multiProcessorCount, mBpSM, totalTh, blocks, thPerBlock);

    // --- Seed one RNG state per thread, then mine. ---
    curandState* rst = nullptr;
    if (!ok(cudaMalloc((void**)&rst, static_cast<size_t>(totalTh) * sizeof(curandState)), "cudaMalloc(rst)")) return 1;
    initRand<<<blocks, thPerBlock>>>(rst);
    if (!ok(cudaGetLastError(), "initRand launch")) return 1;
    if (!ok(cudaDeviceSynchronize(), "initRand")) return 1;
    KeyGen<<<blocks, thPerBlock>>>(rst);
    if (!ok(cudaGetLastError(), "KeyGen launch")) return 1;
    // KeyGen's loop never exits, so this normally blocks until the process is killed.
    const bool finished = ok(cudaDeviceSynchronize(), "KeyGen");
    cudaFree(rst);
    return finished ? 0 : 1;
}