/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shardingsphere.shardingjdbc.executor;

import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Predicate;
import com.google.common.collect.Collections2;
import com.google.common.collect.Iterators;
import com.google.common.collect.Lists;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import org.apache.shardingsphere.core.constant.ConnectionMode;
import org.apache.shardingsphere.core.execute.ShardingExecuteGroup;
import org.apache.shardingsphere.core.execute.StatementExecuteUnit;
import org.apache.shardingsphere.core.execute.sql.execute.SQLExecuteCallback;
import org.apache.shardingsphere.core.execute.sql.execute.threadlocal.ExecutorExceptionHandler;
import org.apache.shardingsphere.core.execute.sql.prepare.SQLExecutePrepareCallback;
import org.apache.shardingsphere.core.route.BatchRouteUnit;
import org.apache.shardingsphere.core.route.RouteUnit;
import org.apache.shardingsphere.core.route.SQLRouteResult;
import org.apache.shardingsphere.shardingjdbc.executor.AbstractStatementExecutor;
import org.apache.shardingsphere.shardingjdbc.jdbc.core.connection.ShardingConnection;

/**
 * Statement executor for sharded {@link PreparedStatement} batches.
 *
 * <p>Each call to {@link #addBatchForRouteUnits(SQLRouteResult)} contributes the route units of one JDBC-level batch
 * entry. A route unit that compares equal to an already-collected {@link BatchRouteUnit} is merged into it: its
 * parameters are appended to the existing unit's parameters. Otherwise it is registered as a new unit. In both
 * cases the current batch index is passed to {@code mapAddBatchCount}.</p>
 *
 * <p>{@link #executeBatch()} runs one physical JDBC batch per prepared statement. When results must be accumulated,
 * it folds the per-statement update counts back onto the JDBC batch indexes.</p>
 */
public final class BatchPreparedStatementExecutor extends AbstractStatementExecutor {

    /** Route units collected so far; each carries the merged parameters of every batch entry routed to it. */
    private final Collection<BatchRouteUnit> routeUnits = new LinkedList<BatchRouteUnit>();

    /** Whether statements are prepared with {@link Statement#RETURN_GENERATED_KEYS}. */
    private final boolean returnGeneratedKeys;

    /** Number of {@link #addBatchForRouteUnits(SQLRouteResult)} calls since construction or the last {@link #clear()}. */
    private int batchCount;

    public BatchPreparedStatementExecutor(int resultSetType, int resultSetConcurrency, int resultSetHoldability, boolean returnGeneratedKeys, ShardingConnection shardingConnection) {
        super(resultSetType, resultSetConcurrency, resultSetHoldability, shardingConnection);
        this.returnGeneratedKeys = returnGeneratedKeys;
    }

    /**
     * Records the SQL statement and prepares execute groups for every route unit collected so far.
     *
     * @param routeResult route result whose SQL statement is recorded; its route units are not read here
     * @throws SQLException propagated from obtaining connections or preparing statements
     */
    public void init(final SQLRouteResult routeResult) throws SQLException {
        setSqlStatement(routeResult.getSqlStatement());
        getExecuteGroups().addAll(obtainExecuteGroups(routeUnits));
    }

    /**
     * Builds execute groups for the given batch route units, preparing one statement per underlying route unit.
     */
    private Collection<ShardingExecuteGroup<StatementExecuteUnit>> obtainExecuteGroups(final Collection<BatchRouteUnit> batchRouteUnits) throws SQLException {
        // Snapshot into an ArrayList so the transformed view is backed by a stable random-access list.
        List<RouteUnit> plainRouteUnits = Lists.transform(new ArrayList<BatchRouteUnit>(batchRouteUnits), new Function<BatchRouteUnit, RouteUnit>() {

            @Override
            public RouteUnit apply(final BatchRouteUnit input) {
                return input.getRouteUnit();
            }
        });
        return getSqlExecutePrepareTemplate().getExecuteUnitGroups(plainRouteUnits, new SQLExecutePrepareCallback() {

            @Override
            public List<Connection> getConnections(final ConnectionMode connectionMode, final String dataSourceName, final int connectionSize) throws SQLException {
                return BatchPreparedStatementExecutor.this.getConnection().getConnections(connectionMode, dataSourceName, connectionSize);
            }

            @Override
            public StatementExecuteUnit createStatementExecuteUnit(final Connection connection, final RouteUnit routeUnit, final ConnectionMode connectionMode) throws SQLException {
                return new StatementExecuteUnit(routeUnit, BatchPreparedStatementExecutor.this.createPreparedStatement(connection, routeUnit.getSqlUnit().getSql()), connectionMode);
            }
        });
    }

    /**
     * Prepares {@code sql} on {@code connection}: with generated-key retrieval when {@link #returnGeneratedKeys} is set,
     * otherwise with the configured result set type, concurrency and holdability.
     */
    private PreparedStatement createPreparedStatement(final Connection connection, final String sql) throws SQLException {
        if (returnGeneratedKeys) {
            return connection.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
        }
        return connection.prepareStatement(sql, getResultSetType(), getResultSetConcurrency(), getResultSetHoldability());
    }

    /**
     * Adds one JDBC-level batch entry, described by the route units of {@code routeResult}, and increments the batch count.
     *
     * @param routeResult route result of the batch entry being added
     */
    public void addBatchForRouteUnits(final SQLRouteResult routeResult) {
        // Order matters: existing units are revised first, then unseen ones are registered. Each handler gets its
        // own fresh collection because handleNewRouteUnits mutates the collection it receives (removeAll).
        handleOldRouteUnits(createBatchRouteUnits(routeResult.getRouteUnits()));
        handleNewRouteUnits(createBatchRouteUnits(routeResult.getRouteUnits()));
        batchCount++;
    }

    /** Wraps each route unit in a new {@link BatchRouteUnit}, preserving iteration order. */
    private Collection<BatchRouteUnit> createBatchRouteUnits(final Collection<RouteUnit> routeUnits) {
        Collection<BatchRouteUnit> result = new LinkedList<BatchRouteUnit>();
        for (RouteUnit each : routeUnits) {
            result.add(new BatchRouteUnit(each));
        }
        return result;
    }

    /** For each new unit that equals an already-collected one, merges it into the collected unit. */
    private void handleOldRouteUnits(final Collection<BatchRouteUnit> newRouteUnits) {
        for (final BatchRouteUnit each : newRouteUnits) {
            Optional<BatchRouteUnit> existing = Iterators.tryFind(routeUnits.iterator(), new Predicate<BatchRouteUnit>() {

                @Override
                public boolean apply(final BatchRouteUnit input) {
                    return input.equals(each);
                }
            });
            if (existing.isPresent()) {
                reviseBatchRouteUnit(existing.get(), each);
            }
        }
    }

    /** Appends the new unit's parameters to the old unit's and passes the current batch index to the old unit. */
    private void reviseBatchRouteUnit(final BatchRouteUnit oldBatchRouteUnit, final BatchRouteUnit newBatchRouteUnit) {
        oldBatchRouteUnit.getRouteUnit().getSqlUnit().getParameters().addAll(newBatchRouteUnit.getRouteUnit().getSqlUnit().getParameters());
        oldBatchRouteUnit.mapAddBatchCount(batchCount);
    }

    /** Registers the units not already collected, passing the current batch index to each. Mutates {@code newRouteUnits}. */
    private void handleNewRouteUnits(final Collection<BatchRouteUnit> newRouteUnits) {
        newRouteUnits.removeAll(routeUnits);
        for (BatchRouteUnit each : newRouteUnits) {
            each.mapAddBatchCount(batchCount);
        }
        routeUnits.addAll(newRouteUnits);
    }

    /**
     * Executes the prepared batches.
     *
     * @return when results are accumulated, an array of length {@link #batchCount} holding, per JDBC batch index, the
     *         sum of the mapped update counts of every statement; otherwise the update counts of the first executed
     *         statement, or an empty array when nothing was executed
     * @throws SQLException propagated from statement execution
     */
    public int[] executeBatch() throws SQLException {
        boolean isExceptionThrown = ExecutorExceptionHandler.isExceptionThrown();
        SQLExecuteCallback<int[]> callback = new SQLExecuteCallback<int[]>(getDatabaseType(), isExceptionThrown) {

            @Override
            protected int[] executeSQL(final RouteUnit routeUnit, final Statement statement, final ConnectionMode connectionMode) throws SQLException {
                return statement.executeBatch();
            }
        };
        List<int[]> results = executeCallback(callback);
        if (isAccumulate()) {
            return accumulate(results);
        }
        // Nothing was executed (e.g. no batch entries): report an empty update-count array instead of failing
        // with IndexOutOfBoundsException.
        return results.isEmpty() ? new int[0] : results.get(0);
    }

    /**
     * Folds per-statement update counts onto JDBC batch indexes.
     *
     * <p>{@code results} is read positionally, one element per execute unit in group order. A {@code null} element
     * contributes zero. A unit with no matching collected route unit contributes nothing, but still consumes its
     * position in {@code results}.</p>
     */
    private int[] accumulate(final List<int[]> results) {
        int[] result = new int[batchCount];
        int count = 0;
        for (ShardingExecuteGroup<StatementExecuteUnit> each : getExecuteGroups()) {
            for (StatementExecuteUnit eachUnit : each.getInputs()) {
                Map<Integer, Integer> jdbcAndActualAddBatchCallTimesMap = findJdbcAndActualAddBatchCallTimesMap(eachUnit.getRouteUnit());
                if (null != jdbcAndActualAddBatchCallTimesMap) {
                    // Key: JDBC batch index (slot in result); value: index into this statement's update counts.
                    for (Map.Entry<Integer, Integer> entry : jdbcAndActualAddBatchCallTimesMap.entrySet()) {
                        int[] unitResult = results.get(count);
                        int value = null == unitResult ? 0 : unitResult[entry.getValue()];
                        result[entry.getKey()] += value;
                    }
                }
                count++;
            }
        }
        return result;
    }

    /**
     * Returns the batch-index map of the collected route unit whose underlying route unit equals {@code routeUnit},
     * or {@code null} when none matches.
     */
    private Map<Integer, Integer> findJdbcAndActualAddBatchCallTimesMap(final RouteUnit routeUnit) {
        for (BatchRouteUnit each : routeUnits) {
            if (each.getRouteUnit().equals(routeUnit)) {
                return each.getJdbcAndActualAddBatchCallTimesMap();
            }
        }
        return null;
    }

    /**
     * Returns every prepared statement across all execute groups, in group order.
     */
    @Override
    public List<Statement> getStatements() {
        List<Statement> result = new LinkedList<Statement>();
        for (ShardingExecuteGroup<StatementExecuteUnit> each : getExecuteGroups()) {
            for (StatementExecuteUnit eachUnit : each.getInputs()) {
                result.add(eachUnit.getStatement());
            }
        }
        return result;
    }

    /**
     * Returns the parameter sets bound to the route unit that {@code statement} was prepared for.
     *
     * @param statement a statement previously returned by {@link #getStatements()}
     * @return the parameter sets, or an empty list when {@code statement} belongs to no execute group
     */
    public List<List<Object>> getParameterSet(final Statement statement) {
        List<List<Object>> result = new LinkedList<List<Object>>();
        for (ShardingExecuteGroup<StatementExecuteUnit> each : getExecuteGroups()) {
            Optional<StatementExecuteUnit> target = getStatementExecuteUnit(statement, each);
            if (target.isPresent()) {
                result = getParameterSets(target.get());
                break;
            }
        }
        return result;
    }

    /** Finds the execute unit in {@code executeGroup} whose statement equals {@code statement}. */
    private Optional<StatementExecuteUnit> getStatementExecuteUnit(final Statement statement, final ShardingExecuteGroup<StatementExecuteUnit> executeGroup) {
        return Iterators.tryFind(executeGroup.getInputs().iterator(), new Predicate<StatementExecuteUnit>() {

            @Override
            public boolean apply(final StatementExecuteUnit input) {
                return input.getStatement().equals(statement);
            }
        });
    }

    /**
     * Returns the parameter sets of the collected route unit matching {@code executeUnit}'s route unit.
     *
     * @throws java.util.NoSuchElementException if no collected route unit matches
     */
    private List<List<Object>> getParameterSets(final StatementExecuteUnit executeUnit) {
        BatchRouteUnit target = Iterators.find(routeUnits.iterator(), new Predicate<BatchRouteUnit>() {

            @Override
            public boolean apply(final BatchRouteUnit input) {
                return input.getRouteUnit().equals(executeUnit.getRouteUnit());
            }
        });
        return target.getParameterSets();
    }

    /** Clears inherited state, resets the batch count, and drops all collected route units. */
    @Override
    public void clear() throws SQLException {
        super.clear();
        batchCount = 0;
        routeUnits.clear();
    }

    /** Returns whether statements are prepared with {@link Statement#RETURN_GENERATED_KEYS}. */
    public boolean isReturnGeneratedKeys() {
        return returnGeneratedKeys;
    }
}

