package com.fshows.util.fnefpay.gm;

import shaded.org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import shaded.org.bouncycastle.crypto.generators.ECKeyPairGenerator;
import shaded.org.bouncycastle.crypto.params.ECDomainParameters;
import shaded.org.bouncycastle.crypto.params.ECKeyGenerationParameters;
import shaded.org.bouncycastle.crypto.params.ECPrivateKeyParameters;
import shaded.org.bouncycastle.crypto.params.ECPublicKeyParameters;
import shaded.org.bouncycastle.crypto.util.Pack;
import shaded.org.bouncycastle.math.ec.ECCurve;
import shaded.org.bouncycastle.math.ec.ECFieldElement;
import shaded.org.bouncycastle.math.ec.ECPoint;

import java.math.BigInteger;
import java.security.SecureRandom;

/**
 * SM2 elliptic-curve signature algorithm written after GM/T 0003-2012, together with the helper
 * primitives it needs (user hash ZA, SM3-based KDF, fixed-width integer encoding).
 *
 * <p>Each instance holds one set of curve domain parameters, the base point G and a Bouncy Castle
 * {@link ECKeyPairGenerator} seeded with its own {@link SecureRandom}; the generator is used to
 * draw the per-signature random value k.
 *
 * @date 2021-6-29 3:05:50 PM
 */
@SuppressWarnings("deprecation")
public class SM2 {

    /** Default user ID (ASCII "1234567812345678"), used by {@link #GetZA} when no ID is given. */
    static final byte[] defaultUserId = {49, 50, 51, 52, 53, 54, 55, 56, 49, 50, 51, 52, 53, 54, 55, 56};

    // Recommended SM2 curve parameters (GM/T 0003.5-2012), used by the no-argument constructor.
    private static final BigInteger SM2_P = new BigInteger(
            "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16);
    private static final BigInteger SM2_A = new BigInteger(
            "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFC", 16);
    private static final BigInteger SM2_B = new BigInteger(
            "28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", 16);
    private static final BigInteger SM2_N = new BigInteger(
            "FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16);
    private static final BigInteger SM2_GX = new BigInteger(
            "32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", 16);
    private static final BigInteger SM2_GY = new BigInteger(
            "BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", 16);

    /** Prime modulus p of the underlying field Fp. */
    public final BigInteger ecc_p;
    /** Curve coefficient a. */
    public final BigInteger ecc_a;
    /** Curve coefficient b. */
    public final BigInteger ecc_b;
    /** Order n of the base point G. */
    public final BigInteger ecc_n;
    /** Affine x coordinate of G. */
    public final BigInteger ecc_xG;
    /** Affine y coordinate of G. */
    public final BigInteger ecc_yG;

    public final ECCurve ecc_curve;
    public final ECPoint ecc_point_g;

    public final ECDomainParameters ecc_bc_spec;

    /** Generator used to draw the random k (and kG) for each signature. */
    public final ECKeyPairGenerator ecc_key_pair_generator;

    /** Byte length of one field element: ceil(fieldSize / 8). */
    private final int _byteLen;

    /**
     * @return the algorithm name
     */
    public String getName() {
        return "SM2_Fq";
    }

    /**
     * @return the size of the curve's field, in bits
     */
    public int getFieldSize() {
        return ecc_curve.getFieldSize();
    }

    /**
     * @return the byte length of keys and parameters (one encoded field element)
     */
    public int getPointByteLength() {
        return _byteLen;
    }

    /**
     * Creates an SM2 instance over the recommended curve defined by GM/T 0003.5-2012.
     */
    public SM2() {
        this(SM2_P, SM2_A, SM2_B, SM2_N, SM2_GX, SM2_GY);
    }

    /**
     * Creates an SM2 instance over caller-supplied curve parameters, with the default cofactor.
     *
     * @param p  prime modulus p
     * @param a  curve coefficient a
     * @param b  curve coefficient b
     * @param n  order n of the base point G
     * @param xG x coordinate of G
     * @param yG y coordinate of G
     */
    public SM2(BigInteger p, BigInteger a, BigInteger b, BigInteger n,
               BigInteger xG, BigInteger yG) {
        this(p, a, b, n, xG, yG, null);
    }

    /**
     * Creates an SM2 instance over caller-supplied curve parameters.
     *
     * @param p  prime modulus p
     * @param a  curve coefficient a
     * @param b  curve coefficient b
     * @param n  order n of the base point G
     * @param xG x coordinate of G
     * @param yG y coordinate of G
     * @param h  cofactor h; {@code null} selects Bouncy Castle's default cofactor
     */
    public SM2(BigInteger p, BigInteger a, BigInteger b, BigInteger n,
               BigInteger xG, BigInteger yG, BigInteger h) {
        ecc_p = p;
        ecc_a = a;
        ecc_b = b;
        ecc_n = n;
        ecc_xG = xG;
        ecc_yG = yG;

        ecc_curve = new ECCurve.Fp(ecc_p, ecc_a, ecc_b);
        ecc_point_g = new ECPoint.Fp(ecc_curve,
                new ECFieldElement.Fp(ecc_p, ecc_xG),
                new ECFieldElement.Fp(ecc_p, ecc_yG));

        ecc_bc_spec = (h == null)
                ? new ECDomainParameters(ecc_curve, ecc_point_g, ecc_n)
                : new ECDomainParameters(ecc_curve, ecc_point_g, ecc_n, h);

        ecc_key_pair_generator = new ECKeyPairGenerator();
        ecc_key_pair_generator.init(new ECKeyGenerationParameters(ecc_bc_spec, new SecureRandom()));

        _byteLen = (ecc_curve.getFieldSize() + 7) / 8;
    }

    /**
     * Encodes a non-negative integer as a big-endian byte array, left-padded with zeros to
     * {@code length} bytes.
     *
     * <p>The sign byte that {@link BigInteger#toByteArray()} may prepend is dropped when it would
     * make the result too long. A value that genuinely needs more than {@code length} bytes is
     * returned unpadded (longer than {@code length}); it is never truncated.
     *
     * @param bi     the integer to encode
     * @param length the desired array length
     * @return the encoded bytes
     */
    public static byte[] BigIntegerToByteArray(BigInteger bi, int length) {
        byte[] bytes = bi.toByteArray();

        if (bytes[0] == 0 && bytes.length > length) {
            byte[] stripped = new byte[bytes.length - 1];
            System.arraycopy(bytes, 1, stripped, 0, stripped.length);
            bytes = stripped;
        }

        if (bytes.length >= length) {
            return bytes;
        }

        byte[] padded = new byte[length];
        System.arraycopy(bytes, 0, padded, length - bytes.length, bytes.length);
        return padded;
    }

    /**
     * Encodes a non-negative integer as a big-endian byte array left-padded to the field byte
     * length of this curve ({@link #getPointByteLength()}).
     *
     * @param bi the integer to encode
     * @return the encoded bytes
     */
    public byte[] BigIntegerToByteArray(BigInteger bi) {
        return BigIntegerToByteArray(bi, _byteLen);
    }

    /**
     * Builds the curve point with the given affine coordinates. The point is not checked to lie
     * on the curve.
     *
     * @param x x coordinate
     * @param y y coordinate
     * @return the point
     */
    public ECPoint GetPoint(BigInteger x, BigInteger y) {
        return new ECPoint.Fp(ecc_curve,
                new ECFieldElement.Fp(ecc_p, x),
                new ECFieldElement.Fp(ecc_p, y));
    }

    /**
     * Derives the public key P = d·G from a private key d.
     *
     * @param d private key
     * @return public key P
     */
    public ECPoint GetPublicKey(BigInteger d) {
        return ecc_point_g.multiply(d);
    }

    /**
     * Encodes ENTL_A, the bit length of the user ID, as two big-endian bytes.
     *
     * @param ID_A user ID bytes
     * @return the two-byte ENTL_A
     * @throws IllegalArgumentException if the ID's bit length does not fit in 16 bits
     */
    private byte[] GetENTLA(byte[] ID_A) {
        // ENTL_A is a 16-bit field, so IDs longer than 8191 bytes cannot be represented.
        if (ID_A.length > 0xFFFF / 8) {
            throw new IllegalArgumentException(
                    "user ID too long: " + ID_A.length + " bytes (max " + (0xFFFF / 8) + ")");
        }
        byte[] bitLength = Pack.intToBigEndian(ID_A.length * 8);
        return new byte[]{bitLength[2], bitLength[3]};
    }

    /**
     * Computes the user hash ZA = SM3(ENTL_A || ID_A || a || b || xG || yG || xA || yA).
     *
     * @param userId    user ID; {@code null} selects {@link #defaultUserId}
     * @param publicKey the user's public key
     * @return ZA (one SM3 digest)
     */
    public byte[] GetZA(byte[] userId, ECPoint publicKey) {
        if (userId == null) {
            userId = defaultUserId;
        }

        SM3Digest sm3 = new SM3Digest();

        byte[] entl = GetENTLA(userId);
        sm3.update(entl, 0, entl.length);
        sm3.update(userId, 0, userId.length);

        // Curve coefficients, base point and public key, each as a fixed-width field element.
        BigInteger[] elements = {
                ecc_a, ecc_b,
                ecc_xG, ecc_yG,
                publicKey.getX().toBigInteger(), publicKey.getY().toBigInteger()
        };
        for (BigInteger element : elements) {
            byte[] encoded = BigIntegerToByteArray(element);
            sm3.update(encoded, 0, encoded.length);
        }

        byte[] za = new byte[sm3.getDigestSize()];
        sm3.doFinal(za, 0);
        return za;
    }

    /**
     * Computes e = SM3(ZA || M) interpreted as an unsigned big-endian integer.
     */
    private BigInteger computeE(byte[] ZA, byte[] M) {
        SM3Digest sm3 = new SM3Digest();
        // Streaming both parts is equivalent to hashing their concatenation.
        sm3.update(ZA, 0, ZA.length);
        sm3.update(M, 0, M.length);
        byte[] digest = new byte[sm3.getDigestSize()];
        sm3.doFinal(digest, 0);
        return new BigInteger(1, digest);
    }

    /**
     * Returns true when {@code v} lies in [1, n-1].
     */
    private boolean inSignatureRange(BigInteger v) {
        return v.signum() > 0 && v.compareTo(ecc_n) < 0;
    }

    /**
     * Signs a message with SM2.
     *
     * @param userId     user ID; {@code null} selects {@link #defaultUserId}
     * @param M          message to sign
     * @param privateKey the signer's private key d
     * @param k          fixed random value k for reproducible signatures; pass {@code null} (or
     *                   zero) in normal use to draw k randomly. If a supplied k yields r = 0,
     *                   r + k = n or s = 0, a fresh random k is drawn instead.
     * @return the signature (r, s)
     * @throws ArithmeticException if 1 + d is not invertible modulo n
     */
    public SM2_Result Sign(byte[] userId, byte[] M, BigInteger privateKey,
                           BigInteger k) {
        byte[] ZA = GetZA(userId, GetPublicKey(privateKey));
        BigInteger e = computeE(ZA, M);

        // (1 + dA)^-1 mod n does not depend on k, so compute it once.
        BigInteger da_1 = privateKey.add(BigInteger.ONE).modInverse(ecc_n);

        while (true) {
            ECPoint kp;
            if (k == null || BigInteger.ZERO.equals(k)) {
                AsymmetricCipherKeyPair keypair = ecc_key_pair_generator.generateKeyPair();
                k = ((ECPrivateKeyParameters) keypair.getPrivate()).getD();
                kp = ((ECPublicKeyParameters) keypair.getPublic()).getQ();
            } else {
                kp = ecc_point_g.multiply(k);
            }

            // r = (e + x1) mod n
            BigInteger r = e.add(kp.getX().toBigInteger()).mod(ecc_n);

            if (!r.equals(BigInteger.ZERO) && !r.add(k).equals(ecc_n)) {
                // s = ((1 + dA)^-1 * (k - r*dA)) mod n
                BigInteger s = k.subtract(r.multiply(privateKey)).mod(ecc_n);
                s = da_1.multiply(s).mod(ecc_n);

                if (!s.equals(BigInteger.ZERO)) {
                    SM2_Result sm2Ret = new SM2_Result();
                    sm2Ret.r = r;
                    sm2Ret.s = s;
                    return sm2Ret;
                }
            }

            // Retrying with the same k would reproduce the same r and s forever; draw a new one.
            k = null;
        }
    }

    /**
     * Verifies an SM2 signature.
     *
     * @param M_sq      message to verify
     * @param userId    user ID; {@code null} selects {@link #defaultUserId}
     * @param publicKey the signer's public key
     * @param r_sq      signature component r
     * @param s_sq      signature component s
     * @return {@code true} if the signature is valid. {@code false} if r or s is outside
     *         [1, n-1], if t = (r + s) mod n is zero, if s·G + t·P is the point at infinity,
     *         or if the recomputed R differs from r.
     */
    public Boolean Verify(byte[] M_sq, byte[] userId, ECPoint publicKey,
                          BigInteger r_sq, BigInteger s_sq) {
        // r' and s' must both lie in [1, n-1].
        if (!inSignatureRange(r_sq) || !inSignatureRange(s_sq)) {
            return false;
        }

        BigInteger e_sq = computeE(GetZA(userId, publicKey), M_sq);

        // t = (r' + s') mod n
        BigInteger t = r_sq.add(s_sq).mod(ecc_n);
        if (t.equals(BigInteger.ZERO)) {
            return false;
        }

        // (x1', y1') = s'G + tP
        ECPoint p1_sq = ecc_point_g.multiply(s_sq).add(publicKey.multiply(t));
        if (p1_sq.isInfinity()) {
            // The point at infinity has no affine x coordinate; such a signature is invalid.
            return false;
        }

        // R = (e' + x1') mod n
        BigInteger R = e_sq.add(p1_sq.getX().toBigInteger()).mod(ecc_n);

        return r_sq.compareTo(R) == 0;
    }

    /**
     * SM3-based key derivation function:
     * K = Ha_1 || Ha_2 || ..., where Ha_i = SM3(Z || ct_i) and ct_i is a 32-bit big-endian
     * counter starting at 1.
     *
     * @param Z    input bytes
     * @param klen desired key length in bits; only whole bytes are produced (klen / 8 bytes,
     *             rounded down), so klen = 0 yields an empty array
     * @return the derived key
     * @throws IllegalArgumentException if klen is negative
     */
    public byte[] KDF(byte[] Z, int klen) {
        if (klen < 0) {
            throw new IllegalArgumentException("klen must be non-negative: " + klen);
        }

        SM3Digest sm3_base = new SM3Digest();
        sm3_base.update(Z, 0, Z.length);

        byte[] K = new byte[klen / 8];
        byte[] block = new byte[sm3_base.getDigestSize()];

        int ct = 1;
        for (int key_off = 0; key_off < K.length; key_off += block.length, ct++) {
            // Copy the shared Z prefix so it is hashed only once.
            SM3Digest sm3_ha = new SM3Digest(sm3_base);
            sm3_ha.update((byte) (ct >>> 24));
            sm3_ha.update((byte) (ct >>> 16));
            sm3_ha.update((byte) (ct >>> 8));
            sm3_ha.update((byte) ct);
            sm3_ha.doFinal(block, 0);

            // The last block may contribute only part of its bytes.
            System.arraycopy(block, 0, K, key_off, Math.min(block.length, K.length - key_off));
        }

        return K;
    }

    /**
     * Concatenates several byte arrays into one; {@code null} elements are skipped.
     *
     * @param args source arrays
     * @return the concatenation, or {@code null} if {@code args} is {@code null} or empty
     */
    public byte[] CombineByteArray(byte[]... args) {
        if (args == null || args.length == 0) {
            return null;
        }

        int len = 0;
        for (byte[] arg : args) {
            if (arg != null) {
                len += arg.length;
            }
        }

        byte[] ret = new byte[len];
        int pos = 0;
        for (byte[] arg : args) {
            if (arg != null) {
                System.arraycopy(arg, 0, ret, pos, arg.length);
                pos += arg.length;
            }
        }

        return ret;
    }

}