package com.fshows.ark.spring.boot.starter.core.sensitive.encrypt.interceptor;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import com.fshows.ark.spring.boot.starter.annotation.encrypt.LifecircleEncrypt;
import com.fshows.ark.spring.boot.starter.core.sensitive.encrypt.DefaultFieldEncryptExecutor;
import com.fshows.ark.spring.boot.starter.core.sensitive.model.SensitiveTargetMethodEncryptFieldModel;
import com.fshows.ark.spring.boot.starter.util.LogUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.session.defaults.DefaultSqlSession;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.sql.PreparedStatement;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * MyBatis plugin that encrypts sensitive values right before they are bound to a {@link PreparedStatement}.
 *
 * <p>Three parameter shapes are handled:
 * <ul>
 *   <li>{@link MapperMethod.ParamMap}: {@code String} mapper-method arguments annotated with both
 *       {@link Param} and {@link LifecircleEncrypt} are replaced by their encrypted value;</li>
 *   <li>{@link DefaultSqlSession.StrictMap}: every element of the wrapped {@code collection} / {@code list} /
 *       {@code array} has its {@link LifecircleEncrypt}-annotated {@code String} fields encrypted (once per
 *       element, even if the same list is registered under several keys);</li>
 *   <li>any other object: its {@link LifecircleEncrypt}-annotated {@code String} fields are encrypted.</li>
 * </ul>
 *
 * <p>NOTE(review): object parameters are encrypted in place, so the caller's object holds ciphertext afterwards.
 * If {@code setParameters} runs more than once for the same parameter object (for example a
 * {@code <selectKey>} statement or a count query issued by another plugin), its fields would be encrypted again.
 * Confirm this cannot happen for the mappers that rely on this interceptor.
 *
 * @author zhaoxumin
 * @version EncryptInterceptor.java, v 0.1 2024-05-29 21:55 zhaoxumin
 */
@Component
@Intercepts({
        @Signature(type = ParameterHandler.class, method = "setParameters", args = {PreparedStatement.class})
})
@Slf4j
public class EncryptInterceptor implements Interceptor {

    /**
     * Cache of {@code "mapperClassName#methodName"} to the encrypted {@code @Param} names of that method.
     * Methods without encrypted parameters map to {@link #DEFAULT_TARGET_METHOD_ENCRYPT_FIELD_MODEL_MAP}.
     */
    private static final Map<String, Map<String, SensitiveTargetMethodEncryptFieldModel>> ENCRYPT_FIELD_METHOD_CACHE = new ConcurrentHashMap<>();

    /**
     * Immutable marker meaning "this mapper method has no encrypted parameters".
     */
    private static final Map<String, SensitiveTargetMethodEncryptFieldModel> DEFAULT_TARGET_METHOD_ENCRYPT_FIELD_MODEL_MAP = Collections.emptyMap();

    /**
     * Keys under which a {@link DefaultSqlSession.StrictMap} may hold a wrapped collection/array parameter.
     */
    private static final String[] WRAPPED_COLLECTION_KEYS = {"collection", "list", "array"};

    @Autowired
    private DefaultFieldEncryptExecutor defaultFieldEncryptExecutor;

    /**
     * Encrypts the sensitive parts of the statement's parameter object, then lets the original
     * {@code setParameters} call proceed.
     *
     * @param invocation the intercepted {@code ParameterHandler#setParameters} call
     * @return the result of {@link Invocation#proceed()}
     * @throws Throwable anything thrown by reflection or by the proceeding call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object target = invocation.getTarget();
        if (target instanceof ParameterHandler) {
            ParameterHandler parameterHandler = (ParameterHandler) target;
            // Resolves to null (and the handler is skipped) when the target has no such field, e.g. a nested plugin proxy.
            Object mappedStatement = ReflectUtil.getFieldValue(parameterHandler, "mappedStatement");
            if (mappedStatement instanceof MappedStatement) {
                // The interface exposes the bound parameter; no need to reflect on a private field.
                Object parameterObject = parameterHandler.getParameterObject();
                if (parameterObject != null) {
                    encryptParameterObject((MappedStatement) mappedStatement, parameterObject);
                }
            }
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * No configurable properties.
     */
    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * Dispatches on the runtime shape of the parameter object.
     *
     * @param mappedStatement the statement being executed
     * @param parameterObject non-null parameter object bound to the statement
     * @throws IllegalAccessException if an annotated field cannot be read or written
     */
    private void encryptParameterObject(MappedStatement mappedStatement, Object parameterObject) throws IllegalAccessException {
        if (parameterObject instanceof MapperMethod.ParamMap) {
            // Multiple (named) primitive/String arguments.
            @SuppressWarnings("unchecked")
            MapperMethod.ParamMap<Object> paramMap = (MapperMethod.ParamMap<Object>) parameterObject;
            encryptParamMap(mappedStatement.getId(), paramMap);
        } else if (parameterObject instanceof DefaultSqlSession.StrictMap) {
            // A single collection/array argument (e.g. a list of DOs for a batch insert).
            @SuppressWarnings("unchecked")
            DefaultSqlSession.StrictMap<Object> strictMap = (DefaultSqlSession.StrictMap<Object>) parameterObject;
            encryptStrictMap(strictMap);
        } else {
            // A single DO argument.
            resetObjectEncryptField(parameterObject);
        }
    }

    /**
     * Replaces the values of encrypted {@code @Param} arguments inside a {@link MapperMethod.ParamMap}.
     *
     * @param mappedStatementId statement id, expected to be {@code mapperClassName.methodName}
     * @param paramMap          the mapper method's named arguments
     */
    private void encryptParamMap(String mappedStatementId, MapperMethod.ParamMap<Object> paramMap) {
        int lastPoint = mappedStatementId.lastIndexOf('.');
        if (lastPoint < 0) {
            // Not a namespaced id, so there is no mapper method to inspect.
            return;
        }
        String className = mappedStatementId.substring(0, lastPoint);
        String methodName = mappedStatementId.substring(lastPoint + 1);

        Map<String, SensitiveTargetMethodEncryptFieldModel> targetMethodEncryptFieldMap = ENCRYPT_FIELD_METHOD_CACHE.computeIfAbsent(
                className + "#" + methodName, key -> getTargetMethodEncryptFieldMap(className, methodName));
        if (targetMethodEncryptFieldMap.isEmpty()) {
            return;
        }

        for (Map.Entry<String, Object> entry : paramMap.entrySet()) {
            SensitiveTargetMethodEncryptFieldModel targetMethodEncryptFieldModel = targetMethodEncryptFieldMap.get(entry.getKey());
            // Null argument values are left as-is instead of failing with a NullPointerException.
            if (targetMethodEncryptFieldModel != null && entry.getValue() != null) {
                LifecircleEncrypt.EncryptContentTypeEnum encryptContentTypeEnum =
                        LifecircleEncrypt.EncryptContentTypeEnum.getByValue(targetMethodEncryptFieldModel.getEncryptContentType());
                entry.setValue(encryptValue(entry.getValue().toString(), encryptContentTypeEnum));
            }
        }
    }

    /**
     * Encrypts the annotated fields of every element wrapped in a {@link DefaultSqlSession.StrictMap}.
     *
     * @param strictMap MyBatis wrapper around a collection or array argument
     * @throws IllegalAccessException if an annotated field cannot be read or written
     */
    private void encryptStrictMap(DefaultSqlSession.StrictMap<Object> strictMap) throws IllegalAccessException {
        // The same List can be registered under both "collection" and "list" (MyBatis DefaultSqlSession#wrapCollection),
        // and an element may occur more than once. Tracking visited elements by identity encrypts each object exactly once.
        Set<Object> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        for (String key : WRAPPED_COLLECTION_KEYS) {
            // StrictMap#get throws for missing keys, hence the containsKey guard.
            if (!strictMap.containsKey(key)) {
                continue;
            }
            Object container = strictMap.get(key);
            if (container instanceof Collection) {
                for (Object element : (Collection<?>) container) {
                    if (element != null && visited.add(element)) {
                        resetObjectEncryptField(element);
                    }
                }
            } else if (container instanceof Object[]) {
                // Primitive arrays (int[] ...) cannot hold annotated objects and are skipped rather than mis-cast.
                for (Object element : (Object[]) container) {
                    if (element != null && visited.add(element)) {
                        resetObjectEncryptField(element);
                    }
                }
            }
        }
    }

    /**
     * Builds the encrypted-parameter model map for a mapper method.
     *
     * <p>A parameter is included when its declared type is {@code String} and it carries both a non-blank
     * {@link Param} name and a {@link LifecircleEncrypt} annotation. All overloads of the method name,
     * including ones inherited from super-interfaces, are inspected.
     *
     * @param className  fully qualified mapper class name
     * @param methodName mapper method name
     * @return {@code @Param} name to model; {@link #DEFAULT_TARGET_METHOD_ENCRYPT_FIELD_MODEL_MAP} when none is found
     *         or the class cannot be loaded
     */
    private Map<String, SensitiveTargetMethodEncryptFieldModel> getTargetMethodEncryptFieldMap(String className, String methodName) {
        // Created once so that every encrypted parameter of the method is kept, not only the last one found.
        Map<String, SensitiveTargetMethodEncryptFieldModel> targetMethodEncryptFieldModelMap = new HashMap<>();
        try {
            Class<?> aClass = Class.forName(className);
            // getMethods() also returns methods inherited from super-interfaces (e.g. a generic base mapper).
            for (Method method : aClass.getMethods()) {
                if (!StrUtil.equals(method.getName(), methodName)) {
                    continue;
                }
                Class<?>[] parameterTypes = method.getParameterTypes();
                Annotation[][] parameterAnnotations = method.getParameterAnnotations();
                for (int paramIndex = 0; paramIndex < parameterTypes.length; paramIndex++) {
                    if (!String.class.isAssignableFrom(parameterTypes[paramIndex])) {
                        continue;
                    }
                    String name = null;
                    LifecircleEncrypt lifecircleEncrypt = null;
                    for (Annotation annotation : parameterAnnotations[paramIndex]) {
                        if (annotation instanceof Param) {
                            name = ((Param) annotation).value();
                        } else if (annotation instanceof LifecircleEncrypt) {
                            lifecircleEncrypt = (LifecircleEncrypt) annotation;
                        }
                    }
                    if (StrUtil.isNotBlank(name) && lifecircleEncrypt != null) {
                        targetMethodEncryptFieldModelMap.put(name,
                                SensitiveTargetMethodEncryptFieldModel.builder()
                                        .paramName(name)
                                        .secretName(lifecircleEncrypt.value())
                                        .encryptContentType(lifecircleEncrypt.encryptContentType().getValue())
                                        .build()
                        );
                    }
                }
            }
        } catch (ClassNotFoundException e) {
            // Best effort: an unloadable class means there is nothing to encrypt. The sentinel result is cached, so this logs once.
            LogUtil.error(log, "ark-spring-boot-starter >> getTargetMethodEncryptFieldMap >> 获取DOMapper目标方法的加密字段map发生异常 >> className={}, methodName={}", e, className, methodName);
        }
        if (targetMethodEncryptFieldModelMap.isEmpty()) {
            return DEFAULT_TARGET_METHOD_ENCRYPT_FIELD_MODEL_MAP;
        }
        return Collections.unmodifiableMap(targetMethodEncryptFieldModelMap);
    }

    /**
     * Replaces every non-static {@link LifecircleEncrypt}-annotated {@code String} field of the object
     * (including fields declared in superclasses) with its encrypted value.
     *
     * @param object the object to update in place; {@code null} is ignored
     * @throws IllegalAccessException if an annotated field cannot be read or written
     */
    private void resetObjectEncryptField(Object object) throws IllegalAccessException {
        if (object == null) {
            return;
        }
        for (Class<?> clazz = object.getClass(); clazz != null && clazz != Object.class; clazz = clazz.getSuperclass()) {
            for (Field field : clazz.getDeclaredFields()) {
                LifecircleEncrypt encryptField = field.getAnnotation(LifecircleEncrypt.class);
                // Static fields are shared state, not per-row data; encrypting them would re-encrypt on every call.
                if (encryptField == null || Modifier.isStatic(field.getModifiers())) {
                    continue;
                }
                field.setAccessible(true);
                Object fieldValue = field.get(object);
                if (fieldValue instanceof String) {
                    field.set(object, encryptValue((String) fieldValue, encryptField.encryptContentType()));
                }
            }
        }
    }

    /**
     * Encrypts a plain-text value according to its content type.
     *
     * @param plainText          value to encrypt
     * @param encryptContentType content type; {@code SEARCH_KEYWORD} uses the search-keyword encryption,
     *                           anything else (including {@code null}) uses the default encryption
     * @return the encrypted value
     */
    private String encryptValue(String plainText, LifecircleEncrypt.EncryptContentTypeEnum encryptContentType) {
        if (encryptContentType == LifecircleEncrypt.EncryptContentTypeEnum.SEARCH_KEYWORD) {
            return defaultFieldEncryptExecutor.encryptSearchKeywords(plainText);
        }
        // TODO(zxm): pass the annotation's secret name once the executor supports per-secret encryption,
        //  i.e. defaultFieldEncryptExecutor.encrypt(plainText, secretName).
        return defaultFieldEncryptExecutor.encrypt(plainText);
    }
}
