#include <cuda_runtime.h>
#include <curand_kernel.h>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <sstream>
#include <vector>
#include <sha512.cuh>
#include <ed25519.cuh>
#include <edsign.cuh>
#include <string.cuh>
// Raw 16-byte address buffer (formatted by getAddr) and 32-byte key buffer.
typedef unsigned char Address[16];
typedef unsigned char Key[32];

// One generated key pair. Field order is PublicKey, then PrivateKey.
struct KeysBox {
    Key PublicKey;   // public half written by ed25519_keygen in KeyGen
    Key PrivateKey;  // private half written by ed25519_keygen in KeyGen
};

// Best leading-zero-bit count seen so far (initially 16). KeyGen reports a key
// only when it beats this value and raises it with atomicMax; parameters()
// overwrites it from the --altitude / -a command-line option.
__device__ unsigned int high = 16u;
// Parses one argument string and applies the "--altitude" / "-a" option to the
// device global `high`.
//
// Accepted forms:
//   "--altitude N" / "-a N"  flag and value in one string, split at the first
//                            space (presumably how `args` joins a bare flag with
//                            the following argv entry via concat, or a single
//                            quoted shell argument)
//   "--altitude" / "-a"      bare flag; the value is in the next argv entry
//
// Returns:
//   0   - value applied, or the string is not an option handled here
//   1   - the value after the flag could not be parsed by cstring_to_ull, or it
//         exceeds 0xFF (KeyGen stores the zero count in one address byte)
//   777 - bare flag: the caller must combine it with the next argument and
//         call again
__device__ int parameters(const char* arg) {
    const int space_index = cstring_find(arg, " ");
    if (space_index != -1) {
        // The option name is exactly the text before the first space. Matching
        // it exactly (instead of searching for "-a" anywhere in the string)
        // keeps unrelated text such as "x-a 5" from being taken as the option.
        const bool is_altitude =
            (space_index == 10 && cstring_find(arg, "--altitude") == 0) ||
            (space_index == 2 && cstring_find(arg, "-a") == 0);
        if (is_altitude) {
            const int substr_start = space_index + 1;
            const int arg_len = cstring_length(arg);
            // +1 so the copy also covers the string terminator.
            int substr_len = arg_len - substr_start + 1;
            char sub_arg[256];
            if (substr_len > 256) substr_len = 256;
            extract_substring(arg, substr_start, sub_arg, substr_len);
            unsigned tmp_high;
            if (cstring_to_ull(sub_arg, &tmp_high) != 0) return 1;
            // The threshold is compared with a byte-sized zero count and later
            // printed as one address byte; larger values can never match.
            if (tmp_high > 0xFFu) return 1;
            high = tmp_high;
            return 0;
        }
    }
    // A bare flag with no value: ask the caller to append the next argument.
    if ((cstring_find(arg, "--altitude") == 0 && cstring_length(arg) == 10) ||
        (cstring_find(arg, "-a") == 0 && cstring_length(arg) == 2)) {
        return 777;
    }
    return 0;
}
// Device-side command-line parser. main() launches it as args<<<1, 1>>>; it
// walks argv[1..argc) sequentially and writes a single status code.
//   argv   - device copy of the host argv array (argc entries)
//   argc   - number of entries in argv
//   result - result[0] receives:
//              0    on success,
//              776  if a bare "--altitude" / "-a" is the last argument,
//              otherwise the non-zero code parameters() returned for the
//              first argument that failed to parse.
__global__ void args(char** argv, int argc, int* result) {
    int err = 0;
    for (int x = 1; x < argc; x++) {
        int res = parameters(argv[x]);
        if (res == 777) {
            // Bare flag: its value is the next argv entry.
            if (++x >= argc) {
                err = 776;
                break;
            }
            char combined[512];
            concat(argv[x - 1], argv[x], combined, 512);
            res = parameters(combined);
        }
        // Report the actual failure code instead of dropping it.
        if (res != 0) {
            err = res;
            break;
        }
    }
    result[0] = err;
}
// Fixed-size, by-value string holders so device helpers can return text.
// ds64: 64 hex characters (a 32-byte key, see ktos) plus the terminating NUL.
struct ds64 { char data[64 + 1]; };
// ds46: buffer for getAddr's colon-separated output (39 characters + NUL fit).
struct ds46 { char data[46]; };
// Hex-encodes 32 bytes starting at `key` as lower-case text, two digits per
// byte with the high nibble first, and NUL-terminates the 64-character result.
__device__ ds64 ktos(const unsigned char* key) noexcept {
    const char* const kHex = "0123456789abcdef";
    ds64 out;
    char* cursor = out.data;
#pragma unroll
    for (int byteIdx = 0; byteIdx < 32; ++byteIdx) {
        const unsigned char value = key[byteIdx];
        *cursor++ = kHex[value >> 4];
        *cursor++ = kHex[value & 0x0F];
    }
    *cursor = '\0';
    return out;
}
// Formats a 16-byte raw address as eight colon-separated groups of four
// lower-case hex digits (no zero compression), NUL-terminated.
__device__ ds46 getAddr(const unsigned char rawAddr[16]) noexcept {
    const char* const kHex = "0123456789abcdef";
    ds46 text;
    unsigned cursor = 0;
#pragma unroll
    for (int byteIdx = 0; byteIdx < 16; ++byteIdx) {
        // Every group after the first is preceded by ':' (bytes 2, 4, ..., 14).
        if (byteIdx > 0 && (byteIdx % 2) == 0) {
            text.data[cursor++] = ':';
        }
        const unsigned char value = rawAddr[byteIdx];
        text.data[cursor++] = kHex[value >> 4];
        text.data[cursor++] = kHex[value & 0x0F];
    }
    text.data[cursor] = '\0';
    return text;
}
// Builds the 16-byte raw address for an inverted public key.
//   rawAddr[0]     = 0x02
//   rawAddr[1]     = lErase
//   rawAddr[2..15] = the bits of InvertedPublicKey that follow its first
//                    lErase + 1 bits, packed MSB-first. KeyGen passes the
//                    leading-zero count of the public key, i.e. the run of
//                    leading ones in the inverted key; the "+ 1" also skips
//                    the zero bit that ends that run.
// Output bytes that the 256 key bits do not fully cover are set to 0. This
// avoids reading or writing past the 32-byte key. For lErase <= 134 every
// output byte is covered, so the result is identical to a plain shift-and-copy.
// InvertedPublicKey is only read; the parameter stays non-const for interface
// compatibility.
__device__ __forceinline__ void getRawAddress(int lErase, Key& InvertedPublicKey, Address& rawAddr) noexcept {
    const int skip = lErase + 1;             // bits to drop from the front
    const int start = skip >> 3;             // first source byte
    const int shift = skip & 7;              // bit offset inside that byte
    const int fullBytes = (256 - skip) / 8;  // output bytes fully backed by key bits
    rawAddr[0] = 0x02;
    rawAddr[1] = static_cast<unsigned char>(lErase);
#pragma unroll
    for (int k = 0; k < 14; ++k) {
        unsigned char out = 0;
        if (k < fullBytes) {
            // k < fullBytes guarantees src <= 31, and src + 1 <= 31 whenever shift != 0.
            const int src = start + k;
            unsigned v = static_cast<unsigned>(InvertedPublicKey[src]) << shift;
            if (shift != 0) {
                v |= static_cast<unsigned>(InvertedPublicKey[src + 1]) >> (8 - shift);
            }
            out = static_cast<unsigned char>(v);
        }
        rawAddr[2 + k] = out;
    }
}
// Number of leading zero bits in a 32-bit word; 32 when the word is zero.
__device__ __forceinline__ unsigned char zeroCounter(unsigned int x) {
    if (x == 0u) {
        return 32;
    }
    return static_cast<unsigned char>(__clz(x));
}
// Counts the leading zero bits of the 32 bytes at `v`, treated as one
// big-endian bit string. Bytes are processed four at a time as 32-bit words.
// NOTE: the count is returned as unsigned char, so an all-zero buffer
// (256 zero bits) wraps around and yields 0.
__device__ __forceinline__ unsigned char getZeros(const unsigned char* v) {
    unsigned char total = 0;
#pragma unroll
    for (int w = 0; w < 8; ++w) {
        const unsigned char* p = v + 4 * w;
        const unsigned word = (unsigned(p[0]) << 24) | (unsigned(p[1]) << 16) |
                              (unsigned(p[2]) << 8) | unsigned(p[3]);
        if (word != 0) {
            // First non-zero word: add its partial count and stop.
            return static_cast<unsigned char>(total + zeroCounter(word));
        }
        total += 32;
    }
    return total;
}
// Initialises one curandState per thread at randStates[global thread id].
// Seed = clock64() + thread id, subsequence = thread id, offset = 0.
// There is no bounds check: the grid must match the number of allocated states.
__global__ void initRand(curandState* randStates) {
    const int tid = threadIdx.x + blockIdx.x * blockDim.x;
    const unsigned long long seed = static_cast<unsigned long long>(clock64()) + tid;
    curand_init(seed, tid, 0, randStates + tid);
}
// One xorshift128+ step (shift constants 23, 17, 26). Advances the two-word
// state in place and returns the new state[1] plus the old state[1].
__device__ __forceinline__ unsigned long long xorshift128plus(unsigned long long* state) noexcept {
    unsigned long long s1 = state[0];
    const unsigned long long s0 = state[1];
    s1 ^= s1 << 23;
    s1 ^= s1 >> 17;
    s1 ^= s0 ^ (s0 >> 26);
    state[0] = s0;
    state[1] = s1;
    return s1 + s0;
}
// Fills buf[0, size) with pseudo-random bytes from the xorshift128+ `state`.
// Each 64-bit output supplies up to 8 bytes, most significant byte first.
// The low bits of xorshift128+ are its statistically weakest, so they are
// used last rather than being the only bits kept from each draw.
__device__ __forceinline__ void rmbytes(unsigned char* buf, unsigned long size, unsigned long long* state) {
    unsigned long i = 0;
    while (i < size) {
        unsigned long long r = xorshift128plus(state);
#pragma unroll
        for (int b = 0; b < 8 && i < size; ++b, ++i) {
            buf[i] = static_cast<unsigned char>(r >> 56);
            r <<= 8;
        }
    }
}
// Writes the bitwise complement of the 32 bytes at `key` into `inverted`.
__device__ __forceinline__ void invertKey(const unsigned char* key, unsigned char* inverted) {
#pragma unroll
    for (int i = 0; i < 32; ++i) {
        inverted[i] = static_cast<unsigned char>(~key[i]);
    }
}
// Search kernel; never returns. Each thread repeatedly:
//   1. draws a 32-byte seed with xorshift128+,
//   2. derives a key pair with ed25519_keygen,
//   3. counts the leading zero bits of the public key.
// When that count beats the global `high`, the thread raises `high` with
// atomicMax and prints the derived address, the public key, the private key
// and their concatenation.
// Expects one initialised curandState per launched thread (no bounds check).
//
// NOTE(review): xorshift128+ is not a cryptographically secure generator, and
// its state comes only from `randStates`. Keys printed here may be predictable
// to anyone who can reconstruct that state. Consider a CSPRNG-derived seed
// before relying on these keys.
__global__ void KeyGen(curandState* randStates) {
    curandState localState = randStates[blockIdx.x * blockDim.x + threadIdx.x];
    unsigned long long xorshiftState[2];
    // curand() yields 32 random bits, so combine two draws per 64-bit word to
    // seed the full 128-bit xorshift state.
    for (int w = 0; w < 2; ++w) {
        const unsigned long long hiBits = curand(&localState);
        const unsigned long long loBits = curand(&localState);
        xorshiftState[w] = (hiBits << 32) | loBits;
    }
    // xorshift128+ must not start from the all-zero state.
    if ((xorshiftState[0] | xorshiftState[1]) == 0ULL) xorshiftState[0] = 1ULL;
    Key seed;
    while (true) {
        rmbytes(seed, sizeof(seed), xorshiftState);
        KeysBox keys;
        ed25519_keygen(keys.PrivateKey, keys.PublicKey, seed);
        const unsigned zeros = getZeros(keys.PublicKey);
        // atomicMax returns the previous value, so only a thread that actually
        // raised `high` reports its key.
        if (zeros > atomicMax(&high, zeros)) {
            Address raw;
            Key inv;
            invertKey(keys.PublicKey, inv);
            getRawAddress(zeros, inv, raw);
            // Format each string once; the keys appear twice in the output.
            const ds46 ip = getAddr(raw);
            const ds64 pk = ktos(keys.PublicKey);
            const ds64 sk = ktos(keys.PrivateKey);
            printf("\nIPv6:\t%s\nPK:\t%s\nSK:\t%s\nFK:\t%s%s\n", ip.data, pk.data, sk.data, sk.data, pk.data);
        }
    }
}
// Checked CUDA call for main(): on failure, prints the call and CUDA's error
// string to stderr and makes main() return 1.
#define MAIN_CUDA_CHECK(call)                                                  \
    do {                                                                       \
        const cudaError_t err_ = (call);                                       \
        if (err_ != cudaSuccess) {                                             \
            fprintf(stderr, "%s:%d: %s failed: %s\n", __FILE__, __LINE__,      \
                    #call, cudaGetErrorString(err_));                          \
            return 1;                                                          \
        }                                                                      \
    } while (0)

// Entry point.
//   1. Copies argv to the device and runs the `args` parser kernel.
//   2. Reports the starting threshold (`high`).
//   3. Sizes the grid from KeyGen's occupancy on the current device.
//   4. Seeds one curandState per thread and launches KeyGen.
// KeyGen never terminates, so in normal operation main() blocks in the final
// synchronisation. Returns 1 on a CUDA error or an argument-parsing error.
int main(int argc, char* argv[]) {
    // Device copies of each argument string, plus a device table of pointers to them.
    std::vector<char*> devStrings(argc, nullptr);
    for (int i = 0; i < argc; i++) {
        const size_t len = strlen(argv[i]) + 1;  // include the terminator
        MAIN_CUDA_CHECK(cudaMalloc(&devStrings[i], len));
        MAIN_CUDA_CHECK(cudaMemcpy(devStrings[i], argv[i], len, cudaMemcpyHostToDevice));
    }
    char** d_argv = nullptr;
    MAIN_CUDA_CHECK(cudaMalloc(&d_argv, argc * sizeof(char*)));
    // One transfer for the whole pointer table instead of one per argument.
    MAIN_CUDA_CHECK(cudaMemcpy(d_argv, devStrings.data(), argc * sizeof(char*), cudaMemcpyHostToDevice));
    int* d_result = nullptr;
    MAIN_CUDA_CHECK(cudaMalloc(&d_result, sizeof(int)));

    args<<<1, 1>>>(d_argv, argc, d_result);
    MAIN_CUDA_CHECK(cudaGetLastError());
    int h_result = 0;
    // Blocking copy on the default stream, so it also waits for `args`.
    MAIN_CUDA_CHECK(cudaMemcpy(&h_result, d_result, sizeof(int), cudaMemcpyDeviceToHost));

    // The argument copies are no longer needed on the device.
    for (char* p : devStrings) cudaFree(p);
    cudaFree(d_argv);
    cudaFree(d_result);

    if (h_result != 0) {
        fprintf(stderr, "Invalid arguments (error %d). Usage: %s [--altitude N | -a N]\n", h_result, argv[0]);
        return 1;
    }

    unsigned h_high = 0;
    MAIN_CUDA_CHECK(cudaMemcpyFromSymbol(&h_high, high, sizeof(unsigned)));
    printf("High addresses (2%02x+)\n", h_high);

    const int threadsPerBlock = 256;
    int device = 0;
    MAIN_CUDA_CHECK(cudaGetDevice(&device));
    cudaDeviceProp prop;
    MAIN_CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
    int mBpSM = 0;
    MAIN_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&mBpSM, KeyGen, threadsPerBlock, 0));
    if (mBpSM <= 0) {
        // A zero-block grid would make both launches below fail.
        fprintf(stderr, "KeyGen cannot run with %d threads per block on this device\n", threadsPerBlock);
        return 1;
    }
    const int SMs = prop.multiProcessorCount;
    const int maxBlocks = mBpSM * SMs;
    const int totalThreads = maxBlocks * threadsPerBlock;
    printf("SMs: %d\n", SMs);
    printf("maxBlocks: %d\n", maxBlocks);
    printf("totalThreads: %d\n", totalThreads);
    printf("MaxBlocksPerSM: %d\n", mBpSM);
    printf("BlocksThreads: %d:%d\n", totalThreads / threadsPerBlock, threadsPerBlock);

    curandState* rst = nullptr;
    MAIN_CUDA_CHECK(cudaMalloc(&rst, totalThreads * sizeof(curandState)));
    initRand<<<maxBlocks, threadsPerBlock>>>(rst);
    MAIN_CUDA_CHECK(cudaGetLastError());
    MAIN_CUDA_CHECK(cudaDeviceSynchronize());
    KeyGen<<<maxBlocks, threadsPerBlock>>>(rst);
    MAIN_CUDA_CHECK(cudaGetLastError());
    // KeyGen loops forever; this only returns if the kernel stops abnormally.
    MAIN_CUDA_CHECK(cudaDeviceSynchronize());
    cudaFree(rst);
    return 0;
}
#undef MAIN_CUDA_CHECK