/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.driver.executor.batch;

import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import lombok.Generated;
import org.apache.shardingsphere.driver.executor.batch.BatchExecutionUnit;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroup;
import org.apache.shardingsphere.infra.executor.kernel.model.ExecutionGroupContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.ConnectionMode;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.SQLExecutorExceptionHandler;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutionUnit;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutor;
import org.apache.shardingsphere.infra.executor.sql.execute.engine.driver.jdbc.JDBCExecutorCallback;
import org.apache.shardingsphere.infra.rule.ShardingSphereRule;
import org.apache.shardingsphere.infra.rule.identifier.type.DataNodeContainedRule;
import org.apache.shardingsphere.mode.metadata.MetaDataContexts;
import org.apache.shardingsphere.sql.parser.sql.common.statement.SQLStatement;

/**
 * Collects the execution units added to a prepared-statement batch and executes them as JDBC batches.
 *
 * <p>Every call to {@link #addBatchForExecutionUnits(Collection)} counts as one batch entry. Units that are
 * equal to an already collected unit are merged into it; the remaining ones are collected as new units.
 * {@link #executeBatch(SQLStatementContext)} then runs {@link Statement#executeBatch()} on every JDBC execution
 * unit of the execution group context supplied through {@link #init(ExecutionGroupContext)}.
 */
public final class BatchPreparedStatementExecutor {
    private final MetaDataContexts metaDataContexts;
    private final JDBCExecutor jdbcExecutor;
    private ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext;
    private final Collection<BatchExecutionUnit> batchExecutionUnits;
    private int batchCount;
    private final String schemaName;

    /**
     * Creates an executor with an empty execution group context and no collected batch units.
     *
     * @param metaDataContexts meta data contexts used to look up the database type and the rules of the schema
     * @param jdbcExecutor executor that runs the JDBC execution units
     * @param schemaName name of the schema whose meta data is consulted
     */
    public BatchPreparedStatementExecutor(MetaDataContexts metaDataContexts, JDBCExecutor jdbcExecutor, String schemaName) {
        this.schemaName = schemaName;
        this.metaDataContexts = metaDataContexts;
        this.jdbcExecutor = jdbcExecutor;
        this.executionGroupContext = new ExecutionGroupContext<>(new LinkedList<>());
        this.batchExecutionUnits = new LinkedList<>();
    }

    /**
     * Replaces the execution group context that is executed by {@link #executeBatch(SQLStatementContext)}.
     *
     * @param executionGroupContext execution group context to use from now on
     */
    public void init(ExecutionGroupContext<JDBCExecutionUnit> executionGroupContext) {
        this.executionGroupContext = executionGroupContext;
    }

    /**
     * Adds one batch entry made of the given execution units and increments the batch count.
     *
     * @param executionUnits execution units of the batch entry being added
     */
    public void addBatchForExecutionUnits(Collection<ExecutionUnit> executionUnits) {
        Collection<BatchExecutionUnit> newUnits = this.createBatchExecutionUnits(executionUnits);
        // Order matters: merging must see the collected units before the new ones are appended to them.
        this.handleOldBatchExecutionUnits(newUnits);
        this.handleNewBatchExecutionUnits(newUnits);
        ++this.batchCount;
    }

    /**
     * Wraps every execution unit in a new {@link BatchExecutionUnit}.
     *
     * @param executionUnits execution units to wrap
     * @return mutable collection holding one batch execution unit per input, in iteration order
     */
    private Collection<BatchExecutionUnit> createBatchExecutionUnits(Collection<ExecutionUnit> executionUnits) {
        List<BatchExecutionUnit> result = new ArrayList<>(executionUnits.size());
        for (ExecutionUnit each : executionUnits) {
            result.add(new BatchExecutionUnit(each));
        }
        return result;
    }

    /**
     * Merges every new unit into the already collected units that are equal to it.
     *
     * @param newExecutionUnits units of the batch entry being added
     */
    private void handleOldBatchExecutionUnits(Collection<BatchExecutionUnit> newExecutionUnits) {
        newExecutionUnits.forEach(this::reviseBatchExecutionUnits);
    }

    /**
     * Merges the given unit into each collected unit that {@code equals} it.
     *
     * @param batchExecutionUnit unit of the batch entry being added
     */
    private void reviseBatchExecutionUnits(BatchExecutionUnit batchExecutionUnit) {
        for (BatchExecutionUnit each : this.batchExecutionUnits) {
            if (each.equals(batchExecutionUnit)) {
                this.reviseBatchExecutionUnit(each, batchExecutionUnit);
            }
        }
    }

    /**
     * Appends the parameters of the new unit to the old unit and registers the current batch count on the old unit.
     *
     * @param oldBatchExecutionUnit already collected unit that receives the parameters
     * @param newBatchExecutionUnit unit whose parameters are appended
     */
    private void reviseBatchExecutionUnit(BatchExecutionUnit oldBatchExecutionUnit, BatchExecutionUnit newBatchExecutionUnit) {
        List<Object> addedParameters = newBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters();
        oldBatchExecutionUnit.getExecutionUnit().getSqlUnit().getParameters().addAll(addedParameters);
        oldBatchExecutionUnit.mapAddBatchCount(this.batchCount);
    }

    /**
     * Collects the units that are not yet known, after registering the current batch count on each of them.
     *
     * <p>The given collection is modified: units equal to an already collected unit are removed from it.
     *
     * @param newExecutionUnits units of the batch entry being added
     */
    private void handleNewBatchExecutionUnits(Collection<BatchExecutionUnit> newExecutionUnits) {
        newExecutionUnits.removeAll(this.batchExecutionUnits);
        for (BatchExecutionUnit each : newExecutionUnits) {
            each.mapAddBatchCount(this.batchCount);
        }
        this.batchExecutionUnits.addAll(newExecutionUnits);
    }

    /**
     * Executes the batch on every JDBC execution unit of the current execution group context.
     *
     * @param sqlStatementContext context of the SQL statement being executed
     * @return an empty array when the executor returned no results; the per-batch-entry accumulated counts when a
     *         rule of the schema requires accumulation; otherwise the first result returned by the executor
     * @throws SQLException if the execution fails
     */
    public int[] executeBatch(SQLStatementContext<?> sqlStatementContext) throws SQLException {
        boolean isExceptionThrown = SQLExecutorExceptionHandler.isExceptionThrown();
        JDBCExecutorCallback<int[]> callback = new JDBCExecutorCallback<int[]>(
                this.metaDataContexts.getMetaData(this.schemaName).getResource().getDatabaseType(), sqlStatementContext.getSqlStatement(), isExceptionThrown) {

            @Override
            protected int[] executeSQL(String sql, Statement statement, ConnectionMode connectionMode) throws SQLException {
                return statement.executeBatch();
            }

            @Override
            protected Optional<int[]> getSaneResult(SQLStatement sqlStatement) {
                return Optional.empty();
            }
        };
        List<int[]> results = this.jdbcExecutor.execute(this.executionGroupContext, callback);
        if (results.isEmpty()) {
            return new int[0];
        }
        return this.isNeedAccumulate(sqlStatementContext) ? this.accumulate(results) : results.get(0);
    }

    /**
     * Tells whether any {@link DataNodeContainedRule} of the schema asks for accumulation for the tables of the statement.
     *
     * @param sqlStatementContext context of the SQL statement being executed
     * @return {@code true} if at least one such rule requires accumulation
     */
    private boolean isNeedAccumulate(SQLStatementContext<?> sqlStatementContext) {
        for (ShardingSphereRule each : this.metaDataContexts.getMetaData(this.schemaName).getRuleMetaData().getRules()) {
            if (each instanceof DataNodeContainedRule
                    && ((DataNodeContainedRule) each).isNeedAccumulate(sqlStatementContext.getTablesContext().getTableNames())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Sums the per-unit batch results into one array with a slot per batch entry.
     *
     * <p>The results are consumed positionally: the n-th JDBC execution unit, in iteration order over the input
     * groups, is paired with {@code results.get(n)}. For the batch unit matching that JDBC unit, each mapping entry's
     * key is the index in the returned array and its value is the index in that unit's result array.
     *
     * @param results one result array per JDBC execution unit; a {@code null} element contributes zero
     * @return array of length {@code batchCount} holding the summed counts
     */
    private int[] accumulate(List<int[]> results) {
        int[] result = new int[this.batchCount];
        int count = 0;
        for (ExecutionGroup<JDBCExecutionUnit> each : this.executionGroupContext.getInputGroups()) {
            for (JDBCExecutionUnit eachUnit : each.getInputs()) {
                for (Map.Entry<Integer, Integer> entry : this.findAddBatchCallTimes(eachUnit).entrySet()) {
                    // Looked up inside the loop on purpose: with no mapping entries the result list is never touched.
                    int[] unitResult = results.get(count);
                    int value = null == unitResult ? 0 : unitResult[entry.getValue()];
                    result[entry.getKey()] += value;
                }
                ++count;
            }
        }
        return result;
    }

    /**
     * Finds the add-batch mapping of the first collected unit with the same data source and SQL as the JDBC unit.
     *
     * @param jdbcExecutionUnit JDBC execution unit to match
     * @return mapping of the matching batch unit, or an empty map when no collected unit matches
     */
    private Map<Integer, Integer> findAddBatchCallTimes(JDBCExecutionUnit jdbcExecutionUnit) {
        for (BatchExecutionUnit each : this.batchExecutionUnits) {
            if (this.isSameDataSourceAndSQL(each, jdbcExecutionUnit)) {
                return each.getJdbcAndActualAddBatchCallTimesMap();
            }
        }
        return Collections.emptyMap();
    }

    /**
     * Compares a batch unit and a JDBC unit by data source name and SQL text.
     *
     * @param batchExecutionUnit collected batch unit
     * @param jdbcExecutionUnit JDBC execution unit
     * @return {@code true} if both the data source names and the SQL strings are equal
     */
    private boolean isSameDataSourceAndSQL(BatchExecutionUnit batchExecutionUnit, JDBCExecutionUnit jdbcExecutionUnit) {
        ExecutionUnit batchUnit = batchExecutionUnit.getExecutionUnit();
        ExecutionUnit jdbcUnit = jdbcExecutionUnit.getExecutionUnit();
        return batchUnit.getDataSourceName().equals(jdbcUnit.getDataSourceName())
                && batchUnit.getSqlUnit().getSql().equals(jdbcUnit.getSqlUnit().getSql());
    }

    /**
     * Gets the storage resources of all JDBC execution units of the current execution group context.
     *
     * @return newly created list of statements; changes to it do not affect this executor
     */
    public List<Statement> getStatements() {
        List<Statement> result = new LinkedList<>();
        for (ExecutionGroup<JDBCExecutionUnit> eachGroup : this.executionGroupContext.getInputGroups()) {
            for (JDBCExecutionUnit eachUnit : eachGroup.getInputs()) {
                result.add(eachUnit.getStorageResource());
            }
        }
        return result;
    }

    /**
     * Gets the parameter sets collected for the given statement.
     *
     * @param statement statement to look up among the storage resources of the current execution groups
     * @return parameter sets of the matching batch unit, or an empty list when no JDBC execution unit uses the statement
     * @throws IllegalStateException if a JDBC execution unit uses the statement but no collected batch unit matches it
     */
    public List<List<Object>> getParameterSet(Statement statement) {
        for (ExecutionGroup<JDBCExecutionUnit> each : this.executionGroupContext.getInputGroups()) {
            Optional<JDBCExecutionUnit> found = this.findJDBCExecutionUnit(statement, each);
            if (found.isPresent()) {
                return this.getParameterSets(found.get());
            }
        }
        return Collections.emptyList();
    }

    /**
     * Finds the first JDBC execution unit of the group whose storage resource equals the statement.
     *
     * @param statement statement to look up
     * @param executionGroup group to search
     * @return the matching unit, or empty when the group holds none
     */
    private Optional<JDBCExecutionUnit> findJDBCExecutionUnit(Statement statement, ExecutionGroup<JDBCExecutionUnit> executionGroup) {
        for (JDBCExecutionUnit each : executionGroup.getInputs()) {
            if (each.getStorageResource().equals(statement)) {
                return Optional.of(each);
            }
        }
        return Optional.empty();
    }

    /**
     * Gets the parameter sets of the first collected unit with the same data source and SQL as the JDBC unit.
     *
     * @param executionUnit JDBC execution unit to match
     * @return parameter sets of the matching batch unit
     * @throws IllegalStateException if no collected batch unit matches
     */
    private List<List<Object>> getParameterSets(JDBCExecutionUnit executionUnit) {
        for (BatchExecutionUnit each : this.batchExecutionUnits) {
            if (this.isSameDataSourceAndSQL(each, executionUnit)) {
                return each.getParameterSets();
            }
        }
        throw new IllegalStateException(
                "Can not find batch execution unit for data source `" + executionUnit.getExecutionUnit().getDataSourceName() + "`.");
    }

    /**
     * Resets this executor: empties the input groups of the current execution group context, resets the batch
     * count to zero and drops all collected batch units.
     *
     * <p>The statements themselves are not closed here. The former {@code getStatements().clear()} call was dropped
     * because {@link #getStatements()} builds a new list on every call, so clearing it changed nothing.
     *
     * @throws SQLException declared for callers of the existing signature; nothing in this body throws it
     */
    public void clear() throws SQLException {
        this.executionGroupContext.getInputGroups().clear();
        this.batchCount = 0;
        this.batchExecutionUnits.clear();
    }

    /**
     * Gets the collected batch execution units.
     *
     * @return the live internal collection, not a copy
     */
    @Generated
    public Collection<BatchExecutionUnit> getBatchExecutionUnits() {
        return this.batchExecutionUnits;
    }
}

