package com.fshows.util.fnefpay.gm;

import shaded.org.bouncycastle.asn1.ASN1Encodable;
import shaded.org.bouncycastle.asn1.ASN1Integer;
import shaded.org.bouncycastle.asn1.ASN1ObjectIdentifier;
import shaded.org.bouncycastle.asn1.ASN1Sequence;
import shaded.org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers;
import shaded.org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import shaded.org.bouncycastle.asn1.x509.Certificate;
import shaded.org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
import shaded.org.bouncycastle.asn1.x9.X9ObjectIdentifiers;
import shaded.org.bouncycastle.crypto.params.ECDomainParameters;
import shaded.org.bouncycastle.crypto.params.ECPublicKeyParameters;
import shaded.org.bouncycastle.math.ec.ECPoint;
import shaded.org.bouncycastle.util.encoders.Base64;
import shaded.org.bouncycastle.util.encoders.Hex;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.PublicKey;
import java.util.Arrays;

/**
 * SM2证书类
 */
public class X509Cert_SM2 {

    private static final byte[] headBytes = "-----BEGIN CERTIFICATE-----"
            .getBytes();

    private static final int headLength = headBytes.length;

    private static final byte[] endBytes = "-----END CERTIFICATE-----"
            .getBytes();

    private static final int endLength = endBytes.length;

    private final Certificate cert;

    private PublicKey publicKey;

    public X509Cert_SM2(byte[] certData) throws Exception {
        this(certFrom(certData), 0);
    }

    public X509Cert_SM2(String certFilePath) throws Exception {
        this(certFrom(certFilePath), 0);
    }

    private static final Certificate certFrom(String certFilePath)
            throws Exception {
        if (certFilePath == null)
            throw new Exception("null not allowed for parameters@certFilePath");
        byte[] certData;
        try {
            certData = X509Cert_SM2.read(certFilePath);

        } catch (IOException e) {
            throw e;
        }

        return certFrom(certData);
    }

    X509Cert_SM2(Certificate certificate, int certType) throws Exception {
        if (certificate == null) {
            throw new Exception("null not allowed for parameters@certificate");
        }

        this.cert = certificate;
    }

    private static final Certificate certFrom(byte[] certData) throws Exception {
        Certificate cert = null;
        try {
            byte[] certBytes = filterPEMText(certData);

            if (!isDERSequence(certBytes) && !isBERSequence(certBytes)) {
                certBytes = Base64.decode(certBytes);
            }

            ASN1Sequence seq = ASN1Sequence.getInstance(certBytes);

            cert = Certificate.getInstance(seq);
        } catch (Exception ex) {
            throw new Exception(ex);
        }

        return cert;
    }

    public static final byte[] filterPEMText(byte[] certData) {
        byte[] certHead = new byte[headLength];
        byte[] certEnd = new byte[endLength];

        System.arraycopy(certData, 0, certHead, 0, headLength);
        boolean hasHead = Arrays.equals(certHead, headBytes);

        if (hasHead) {
            certData = deleteCRLF(certData);
        }

        int certDataLength = certData.length;

        System.arraycopy(certData, certDataLength - endLength, certEnd, 0,
                endLength);
        boolean hasEnd = Arrays.equals(certEnd, endBytes);

        int datStarter = 0;
        int datLength = 0;
        byte[] certBytes = null;
        if ((hasHead) && (hasEnd)) {
            datStarter = headLength;
            datLength = certDataLength - headLength - endLength;
        } else if ((!hasHead) && (hasEnd)) {
            datStarter = 0;
            datLength = certDataLength - endLength;
        } else if ((hasHead) && (!hasEnd)) {
            datStarter = headLength;
            datLength = certDataLength - headLength;
        } else {
            certBytes = certData;
        }

        if (certBytes == null) {
            certBytes = new byte[datLength];
            System.arraycopy(certData, datStarter, certBytes, 0,
                    certBytes.length);
        }

        return certBytes;
    }

    public static byte[] deleteCRLF(byte[] data) {
        ByteArrayInputStream bis = new ByteArrayInputStream(data);
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        byte tmp;
        while ((tmp = (byte) bis.read()) != -1) {
            if ((tmp != 10) && (tmp != 13)) {
                bos.write(tmp);
            }
        }
        return bos.toByteArray();
    }

    public Certificate getCert() {
        return cert;
    }

    public String getCertId() {
        return new BigInteger(cert.getSerialNumber().toString(), 10).toString(16);
    }

    public byte[] getPublicKey() {

        byte[] pubData = this.cert.getSubjectPublicKeyInfo().getPublicKeyData()
                .getBytes();

        byte[] pk = new byte[64];

        if (pubData.length == 64) {
            System.arraycopy(pubData, 0, pk, 0, 64);
        } else if (pubData.length == 65) {
            System.arraycopy(pubData, 1, pk, 0, 64);
        } else if (pubData.length == 66) {
            System.arraycopy(pubData, 2, pk, 0, 64);
        } else {
            return null;
        }

        return pk;
    }

    public byte[] getSourceData() throws IOException {
        return this.cert.getTBSCertificate().getEncoded("DER");
    }

    public byte[] getSignature() {
        return this.cert.getSignature().getBytes();
    }

    public String toString() {
        StringBuffer builder = new StringBuffer();
        if (this.cert != null)
            try {

                builder.append("\n SN: ");
                builder.append(this.cert.getSerialNumber().getValue()
                        .toString(16));
                builder.append("\n Issuer: ").append(this.cert.getIssuer());
                builder.append("\n Subject: ").append(this.cert.getSubject());
                builder.append("\n Validate: ");
                builder.append(this.cert.getStartDate());
                builder.append(", ");
                builder.append(this.cert.getEndDate());
                builder.append("\n SignatureAlgorithm: ");
                builder.append(this.cert.getSignatureAlgorithm().getAlgorithm()
                        .getId());
                builder.append("\n encoding: ");
                builder.append(Hex.toHexString(this.cert.getEncoded()));
                builder.append("\n signature: ");
                builder.append(Hex.toHexString(this.cert.getSignature()
                        .getBytes()));

            } catch (Exception e) {
                builder.append("dump cert detail failure: " + e.getMessage());
            }
        else {
            builder.append("\n none content");
        }
        return builder.toString();
    }

    public static final byte[] read(String filePath) throws IOException {
        if (filePath == null) {
            throw new IllegalArgumentException("Illegal Argument: filePath");
        }

        FileInputStream crls = new FileInputStream(filePath);
        try {
            byte[] out = new byte[crls.available()];
            byte[] buffer = new byte[65536];

            int offset = 0;
            int rLength;
            while ((rLength = crls.read(buffer, 0, buffer.length)) != -1) {
                System.arraycopy(buffer, 0, out, offset, rLength);
                offset += rLength;
            }
            byte[] arrayOfByte1 = out;
            return arrayOfByte1;
        } catch (IOException e) {
            throw e;
        } finally {
            crls.close();
        }
    }

    public static final boolean isBERSequence(byte[] encoding)
            throws Exception {
        if (encoding == null) {
            throw new Exception("encoding should not be null");
        }
        if (encoding.length < 4) {
            throw new Exception("encoding length less than 4");
        }
        if (encoding[0] != 48) {
            return false;
        }

        int offset = 1;
        int length = encoding[(offset++)] & 0xFF;
        if (length != 128) {
            return false;
        }
        return (encoding[(encoding.length - 1)] == 0)
                && (encoding[(encoding.length - 2)] == 0);
    }

    public static final boolean isDERSequence(byte[] encoding)
            throws Exception {
        if (encoding == null) {
            throw new Exception("encoding should not be null");
        }
        if (encoding.length < 2) {
            throw new Exception("encoding length less than 4");
        }
        if (encoding[0] != 48) {
            return false;
        }

        int offset = 1;
        int length = encoding[(offset++)] & 0xFF;
        if (length == 128) {
            return false;
        }

        if (length > 127) {
            int dLength = length & 0x7F;
            if (dLength > 4) {
                return false;
            }

            length = 0;
            int next = 0;
            for (int i = 0; i < dLength; i++) {
                next = encoding[(offset++)] & 0xFF;
                length = (length << 8) + next;
            }

            if (length < 0) {
                return false;
            }
        }
        return encoding.length == offset + length;
    }

    // 判断证书是否为SM2 证书
    public boolean isSm2CertType() {
        int certType = 0;
        ASN1ObjectIdentifier sm2_old = new ASN1ObjectIdentifier(
                "1.2.156.197.1.301");
        ASN1ObjectIdentifier ecPubKey = X9ObjectIdentifiers.id_ecPublicKey;
        ASN1ObjectIdentifier sm2PubKey = new ASN1ObjectIdentifier(
                "1.2.156.10197.1.301");
        ASN1ObjectIdentifier sm3WithSM2Encryption = new ASN1ObjectIdentifier(
                "1.2.156.10197.1.501");
        ASN1ObjectIdentifier sm3WithSM2Encryption_OLD = new ASN1ObjectIdentifier(
                "1.2.156.197.1.501");
        if (this.cert != null) {
            SubjectPublicKeyInfo subjectPublicKeyInfo = this.cert
                    .getSubjectPublicKeyInfo();
            AlgorithmIdentifier algorithm = subjectPublicKeyInfo.getAlgorithm();

            String keyTypeAlgorithmId = algorithm.getAlgorithm().getId();
            if (keyTypeAlgorithmId.equals(PKCSObjectIdentifiers.rsaEncryption
                    .getId())) {
                certType = 2;
            } else if (keyTypeAlgorithmId.equals(ecPubKey.getId())) {
                ASN1Encodable parameters = algorithm.getParameters();
                if ((parameters != null)
                        && ((parameters instanceof ASN1ObjectIdentifier))) {
                    ASN1ObjectIdentifier param = (ASN1ObjectIdentifier) parameters;
                    if (param.equals(sm2_old))
                        certType = 1;
                    else if (param.equals(sm2PubKey)) {
                        certType = 1;
                    }

                }

                if (certType != 1) {
                    String signAlgorithmId = cert.getSignatureAlgorithm()
                            .getAlgorithm().getId();
                    if (signAlgorithmId.equals(sm3WithSM2Encryption.getId()))
                        certType = 1;
                    else if (signAlgorithmId.equals(sm3WithSM2Encryption_OLD
                            .getId())) {
                        certType = 1;
                    }
                }
            }
        }

        return certType == 1;
    }

    // 获取SM2 证书的publicKey
    public PublicKey getPubKey() throws Exception {
        if (this.publicKey == null) {
            PublicKey publicKey = buildPublicKey(this.cert);
            this.publicKey = publicKey;
        }
        return this.publicKey;
    }

    public final PublicKey buildPublicKey(Certificate cert) throws Exception {
        PublicKey publicKey = null;

        SubjectPublicKeyInfo subjectPublicKeyInfo = cert
                .getSubjectPublicKeyInfo();

        byte[] pubData = subjectPublicKeyInfo.getPublicKeyData().getBytes();
        if ((pubData == null)
                || ((pubData.length != 64) && (pubData.length != 65))) {
            throw new Exception("证书不正确");
        }
        int starter = pubData.length == 65 ? 1 : 0;

        byte[] pubX = new byte[32];
        byte[] pubY = new byte[32];

        System.arraycopy(pubData, starter, pubX, 0, 32);
        System.arraycopy(pubData, starter + 32, pubY, 0, 32);
        publicKey = new SM2PublicKey(pubX, pubY);
        return publicKey;
    }

    // 证书校验
    public boolean verify(PublicKey publicKey) throws Exception {

        boolean verifyResult = false;
        byte[] sourceData = getSourceData();
        byte[] signature = getSignature();
        // byte[] defaultUserId = { 49, 50, 51, 52, 53, 54, 55, 56, 49, 50, 51,
        // 52, 53, 54, 55, 56 };
        byte[] userId = null;
        // SM2Signature signer = new SM2Signature();
        if ((sourceData == null) || (signature == null)) {
            return false;
        }
        // signer.initVerify(pubKey);
        SM2PublicKey sm2PubKey = (SM2PublicKey) publicKey;
        byte[] zvalue = sm2PubKey.calcZ(userId);

        // signer.update 实际调用 了digest.update
        SM3Digest digest = new SM3Digest();
        digest.update(zvalue, 0, zvalue.length);

        // return signer.verify(signature, sourceData);
        byte[] out = new byte[32];
        byte[] r = new byte[32];
        byte[] s = new byte[32];
        SM2_Result sm2Ret = new SM2_Result();
        if (signature.length == 64) {
            System.arraycopy(signature, 0, r, 0, 32);
            System.arraycopy(signature, 32, s, 0, 32);
        } else if (signature.length > 64) {
            ASN1Sequence sequence = ASN1Sequence.getInstance(signature);
            ASN1Integer R = (ASN1Integer) sequence.getObjectAt(0);
            ASN1Integer S = (ASN1Integer) sequence.getObjectAt(1);
            r = SM2PublicKey.asUnsignedNByteArray(R.getPositiveValue(), 32);
            s = SM2PublicKey.asUnsignedNByteArray(S.getPositiveValue(), 32);
        } else {
            return false;
        }

        digest.update(sourceData, 0, sourceData.length);
        digest.doFinal(out, 0);
        sm2Ret.r = new BigInteger(1, r);
        sm2Ret.s = new BigInteger(1, s);

        // return this.sm2.verify(out, sm2PubKey.getQ(),
        // sm2Ret);->BCSoftSM2.verify
        ECPoint userkey = sm2PubKey.getQ();
        if ((out == null) || (out.length != 32) || (userkey == null)
                || (r == null) || (s == null)) {
            verifyResult = false;
        } else {

            ECPublicKeyParameters key;
            key = new ECPublicKeyParameters(userkey,
                    SM2Params.sm2DomainParameters);
            verifyResult = verifySignature(out, sm2Ret.r, sm2Ret.s, key);
        }
        return verifyResult;

    }

    public static boolean verifySignature(byte[] message, BigInteger r,
                                          BigInteger s, ECPublicKeyParameters key) {
        if (message == null) {
            throw new SecurityException("null not allowed for message");
        }
        if ((r == null) || (s == null)) {
            throw new SecurityException("null not allowed for r/s");
        }

        if (key == null) {
            throw new SecurityException("not Initialization");
        }

        if (!(key instanceof ECPublicKeyParameters)) {
            throw new SecurityException("key not ECPublicKeyParameters");
        }

        ECDomainParameters ec = key.getParameters();
        BigInteger n = ec.getN();

        if ((r.compareTo(BigInteger.ONE) <= 0) || (r.compareTo(n) > 0)) {
            return false;
        }

        if ((s.compareTo(BigInteger.ONE) <= 0) || (s.compareTo(n) > 0)) {
            return false;
        }

        BigInteger e = calculateE(n, message);

        BigInteger t = r.add(s).mod(n);

        if (t.equals(BigInteger.ZERO)) {
            return false;
        }

        ECPoint P = ((ECPublicKeyParameters) key).getQ();

        ECPoint point = ec.getG().multiply(s).add(P.multiply(t));

        if (point.isInfinity()) {
            return false;
        }

        BigInteger v = e.add(point.normalize().getXCoord().toBigInteger()).mod(
                n);

        return v.equals(r);
    }

    protected static BigInteger calculateE(BigInteger n, byte[] message) {
        BigInteger e = new BigInteger(1, message);

        int messageBitLength = message.length * 8;
        int log2n = n.bitLength();
        if (log2n < messageBitLength) {
            e = e.shiftRight(messageBitLength - log2n);
        }
        return e;
    }
}
