package com.fshows.msfpay.utils;

import msfpay.cfca.sadk.algorithm.common.*;
import msfpay.cfca.sadk.algorithm.util.RSAAndItsCloseSymAlgUtil;
import msfpay.cfca.sadk.algorithm.util.SM2AndItsCloseSymAlgUtil;
import msfpay.cfca.sadk.asn1.parser.ASN1Parser;
import msfpay.cfca.sadk.lib.crypto.Session;
import msfpay.cfca.sadk.lib.crypto.jni.JNISoftLib;
import msfpay.cfca.sadk.org.bouncycastle.asn1.*;
import msfpay.cfca.sadk.org.bouncycastle.asn1.cms.*;
import msfpay.cfca.sadk.org.bouncycastle.asn1.x500.X500Name;
import msfpay.cfca.sadk.org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import msfpay.cfca.sadk.org.bouncycastle.cms.CMSEnvelopedData;
import msfpay.cfca.sadk.util.Base64;
import msfpay.cfca.sadk.util.CertUtil;
import msfpay.cfca.sadk.x509.certificate.X509Cert;

import java.math.BigInteger;
import java.security.PrivateKey;

public class SM2EnvelopeUtil {
    private static byte[] IV_16 = new byte[]{50, 51, 52, 53, 54, 55, 56, 57, 56, 55, 54, 53, 52, 51, 50, 49};

    public static final byte[] envelopeMessage(byte[] sourceData, String symmetricAlgorithm, X509Cert[] receiverCerts, Session session) throws Exception {
        byte[] none64 = envelopMessage_None64(sourceData, symmetricAlgorithm, receiverCerts, session);
        return Base64.encode(none64);
    }

    private static byte[] envelopMessage_None64(byte[] sourceData, String symmetricAlgorithm, X509Cert[] receiverCerts, Session session) throws Exception {
        byte[] key = SM2AndItsCloseSymAlgUtil.generateSecretKey();
        IV_16 = SM2AndItsCloseSymAlgUtil.generateIV();
        ASN1EncodableVector recipientInfos = new ASN1EncodableVector();

        for (X509Cert receiverCert : receiverCerts) {
            recipientInfos.add(toRecipientInfoOfIssuerAndSerialNumber(receiverCert, key, session));
        }

        Mechanism contentEncryptionAlg;
        if (symmetricAlgorithm.contains("CBC")) {
            CBCParam cbc = new CBCParam(IV_16);
            contentEncryptionAlg = new Mechanism(symmetricAlgorithm, cbc);
        } else {
            contentEncryptionAlg = new Mechanism(symmetricAlgorithm);
        }

        boolean useJNI = session instanceof JNISoftLib;
        byte[] encryptedData = SM2AndItsCloseSymAlgUtil.crypto(useJNI, true, key, sourceData, contentEncryptionAlg);
        ASN1OctetString encryptedOctet = new BEROctetString(encryptedData);
        ASN1ObjectIdentifier tOID = (ASN1ObjectIdentifier) PKCS7EnvelopedData.MECH_OID.get(symmetricAlgorithm);
        AlgorithmIdentifier algId = getAlgorithmIdentifier(contentEncryptionAlg, tOID);
        EncryptedContentInfo encryptedContentInfo = new EncryptedContentInfo(PKCSObjectIdentifiers.sm2Data, algId, encryptedOctet);
        EnvelopedData envData = new EnvelopedData(null, new DERSet(recipientInfos), encryptedContentInfo, ASN1Set.getInstance(null));
        ContentInfo contentInfo = new ContentInfo(PKCSObjectIdentifiers.sm2EnvelopedData, envData);
        return ASN1Parser.parseDERObj2Bytes((new CMSEnvelopedData(contentInfo)).toASN1Structure());
    }

    public static byte[] openEvelopedMessage(byte[] base64EnvelopeMessage, PrivateKey privateKey, X509Cert recipientCert, Session session) throws Exception {
        if (session == null) {
            throw new IllegalArgumentException("session不能为空");
        }

        try {
            boolean isSM2Type = CertUtil.isSM2Cert(recipientCert);
            byte[] bEnvelop = Base64.decode(base64EnvelopeMessage);
            CMSEnvelopedData cmsEnData = new CMSEnvelopedData(bEnvelop);
            ContentInfo info = cmsEnData.toASN1Structure();
            EnvelopedData enData = EnvelopedData.getInstance(info.getContent());
            ASN1Set receivers = enData.getRecipientInfos();
            X500Name recipientIssuer = recipientCert.getIssuerX500Name();
            BigInteger recipientSN = recipientCert.getSerialNumber();
            byte[] subjectPubKeyID = recipientCert.getSubjectKeyIdentifier().getKeyIdentifier();

            if (receivers == null) {
                throw new Exception("接收者为空");
            }

            ASN1OctetString encryptKey = null;
            AlgorithmIdentifier algId = null;

            for (int i = 0; i < receivers.size(); i++) {
                RecipientInfo recip = RecipientInfo.getInstance(receivers.getObjectAt(i));
                if (recip.getInfo() instanceof KeyTransRecipientInfo) {
                    KeyTransRecipientInfo inf = KeyTransRecipientInfo.getInstance(recip.getInfo());
                    if (hasRecipent(inf, subjectPubKeyID, recipientIssuer, recipientSN)) {
                        encryptKey = inf.getEncryptedKey();
                        algId = inf.getKeyEncryptionAlgorithm();
                        break;
                    }
                }
            }

            if (encryptKey == null || algId == null) {
                throw new Exception("找不到接收者");
            }

            Mechanism contentEncryptionAlg = isSM2Type ? new Mechanism("SM2") : new Mechanism("RSA/ECB/PKCS1PADDING");
            byte[] symmetricKey = session.decrypt(contentEncryptionAlg, privateKey, encryptKey.getOctets());
            EncryptedContentInfo data = enData.getEncryptedContentInfo();
            ASN1OctetString os = data.getEncryptedContent();
            AlgorithmIdentifier symmetricAlgId = data.getContentEncryptionAlgorithm();
            String encryptionAlgStr = (String) PKCS7EnvelopedData.OID_MECH.get(symmetricAlgId.getAlgorithm());

            Mechanism mechanism = null;
            if (encryptionAlgStr.contains("CBC")) {
                DEROctetString doct = (DEROctetString) symmetricAlgId.getParameters();
                CBCParam sourceData = new CBCParam(doct.getOctets());
                if (encryptionAlgStr.equals("DESede/CBC/PKCS7Padding")) {
                    mechanism = new Mechanism("DESede/CBC/PKCS7Padding", sourceData);
                } else if (encryptionAlgStr.equals("SM4/CBC/PKCS7Padding")) {
                    mechanism = new Mechanism("SM4/CBC/PKCS7Padding", sourceData);
                }
            } else if (encryptionAlgStr.contains("ECB")) {
                if (encryptionAlgStr.equals("DESede/ECB/PKCS7Padding")) {
                    mechanism = new Mechanism("DESede/ECB/PKCS7Padding");
                } else if (encryptionAlgStr.equals("SM4/ECB/PKCS7Padding")) {
                    mechanism = new Mechanism("SM4/ECB/PKCS7Padding");
                }
            } else if (encryptionAlgStr.contains("RC4")) {
                mechanism = new Mechanism("RC4");
            }

            if (mechanism == null) {
                throw new Exception("不支持的加密算法: " + encryptionAlgStr);
            }

            boolean useJNI = session instanceof JNISoftLib;
            return isSM2Type ?
                    SM2AndItsCloseSymAlgUtil.crypto(useJNI, false, symmetricKey, os.getOctets(), mechanism) :
                    RSAAndItsCloseSymAlgUtil.crypto(useJNI, false, symmetricKey, os.getOctets(), mechanism);
        } catch (Exception e) {
            throw new Exception("解析消息数字信封失败", e);
        }
    }

    private static AlgorithmIdentifier getAlgorithmIdentifier(Mechanism contentEncryptionAlg, ASN1ObjectIdentifier tOID) throws Exception {
        AlgorithmIdentifier algorithmIdentifier = new AlgorithmIdentifier(tOID);
        if (contentEncryptionAlg.getMechanismType().contains("CBC")) {
            Object param = contentEncryptionAlg.getParam();
            if (param == null) {
                throw new Exception("P7信封CBC参数为空");
            }
            CBCParam cbcParam = (CBCParam) contentEncryptionAlg.getParam();
            DEROctetString doct = new DEROctetString(cbcParam.getIv());
            return new AlgorithmIdentifier(tOID, doct);
        }
        return algorithmIdentifier;
    }

    private static RecipientInfo toRecipientInfoOfIssuerAndSerialNumber(X509Cert cert, byte[] symmetricKey, Session session) throws Exception {
        byte[] encryptedKey = session instanceof JNISoftLib ?
                SM2AndItsCloseSymAlgUtil.sm2EncryptByJNI(true, cert.getPublicKey(), symmetricKey) :
                SM2AndItsCloseSymAlgUtil.sm2Encrypt(true, cert.getPublicKey(), symmetricKey);

        ASN1OctetString encKey = new DEROctetString(encryptedKey);
        X500Name recipientIssuer = cert.getIssuerX500Name();
        BigInteger recipientSN = cert.getSerialNumber();
        IssuerAndSerialNumber issu = new IssuerAndSerialNumber(recipientIssuer, recipientSN);

        AlgorithmIdentifier keyEncAlg = new AlgorithmIdentifier(PKCSObjectIdentifiers.SM2_pubKey_encrypt, DERNull.INSTANCE);
        KeyTransRecipientInfo ktr = new KeyTransRecipientInfo(RecipientIdentifier.getInstance(issu), keyEncAlg, encKey);

        return new RecipientInfo(ktr);
    }

    private static boolean hasRecipent(KeyTransRecipientInfo inf, byte[] subjectPubKeyID, X500Name recipientIssuer, BigInteger recipientSN) {
        RecipientIdentifier id = inf.getRecipientIdentifier();
        DEROctetString oct = new DEROctetString(subjectPubKeyID);
        IssuerAndSerialNumber issu = new IssuerAndSerialNumber(recipientIssuer, recipientSN);
        return id.getId().toASN1Primitive().asn1Equals(oct) || id.getId().toASN1Primitive().asn1Equals(issu.toASN1Primitive());
    }
}