/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLCompressionException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixAgg;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

public class CLALibCompAgg {
    /** Logger shared by all static routines of this class. */
    private static final Log LOG = LogFactory.getLog(CLALibCompAgg.class.getName());

    /** Size threshold (8192) for {@code getExactSizeOnDisk()}; see {@code isValidForParallelProcessing}. */
    private static final long MIN_PAR_AGG_THRESHOLD = 8192L;

    /**
     * Per-thread scratch block reused by {@code UnaryAggregateOverlappingTask} when it decompresses a row range.
     * Starts out as {@code null} for every thread (the default initial value of a {@link ThreadLocal}).
     */
    private static ThreadLocal<MatrixBlock> memPool = new ThreadLocal<>();

    /**
     * Computes a unary aggregate over a compressed matrix block and stores it in {@code outputMatrix}.
     *
     * <p>A {@code KahanPlus} operator is first swapped for a plain {@code Plus}. If the input has column groups,
     * the output is seeded (see {@code fillStart}) and the work is dispatched either to the overlapping path
     * (input is overlapping and the function is {@code KahanPlusSq}, or a {@code Builtin} MIN/MAX) or to the
     * regular compressed path. Non-zeros of the output are recomputed in every case.
     *
     * @param inputMatrix  compressed input
     * @param outputMatrix block receiving the result; also the return value
     * @param op           aggregate operator to apply
     * @param blen         forwarded to the regular compressed path
     * @param indexesIn    forwarded to the internal routines
     * @param inCP         forwarded to the internal routines
     * @return {@code outputMatrix}
     */
    public static MatrixBlock aggregateUnary(CompressedMatrixBlock inputMatrix, MatrixBlock outputMatrix, AggregateUnaryOperator op, int blen, MatrixIndexes indexesIn, boolean inCP) {
        op = CLALibCompAgg.replaceKahnOperations(op);
        if (inputMatrix.getColGroups() == null) {
            outputMatrix.recomputeNonZeros();
            return outputMatrix;
        }

        CLALibCompAgg.fillStart(outputMatrix, op);

        // Only consult the operator when the input is overlapping, as in the original short-circuit condition.
        boolean useOverlappingPath = false;
        if (inputMatrix.isOverlapping()) {
            if (op.aggOp.increOp.fn instanceof KahanPlusSq) {
                useOverlappingPath = true;
            } else if (op.aggOp.increOp.fn instanceof Builtin) {
                final Builtin.BuiltinCode code = ((Builtin) op.aggOp.increOp.fn).getBuiltinCode();
                useOverlappingPath = code == Builtin.BuiltinCode.MIN || code == Builtin.BuiltinCode.MAX;
            }
        }

        if (useOverlappingPath) {
            CLALibCompAgg.aggregateUnaryOverlapping(inputMatrix, outputMatrix, op, indexesIn, inCP);
        } else {
            CLALibCompAgg.aggregateUnaryNormalCompressedMatrixBlock(inputMatrix, outputMatrix, op, blen, indexesIn, inCP);
        }
        outputMatrix.recomputeNonZeros();
        return outputMatrix;
    }

    /**
     * Swaps a {@code KahanPlus} aggregate for a plain {@code Plus} aggregate with initial value 0, keeping the
     * index function and thread count of {@code op}. Any other operator is returned unchanged.
     *
     * @param op operator to inspect
     * @return a new plus-based operator, or {@code op} itself
     */
    private static AggregateUnaryOperator replaceKahnOperations(AggregateUnaryOperator op) {
        if (!(op.aggOp.increOp.fn instanceof KahanPlus)) {
            return op;
        }
        final AggregateOperator plainSum = new AggregateOperator(0.0, Plus.getPlusFnObject());
        return new AggregateUnaryOperator(plainSum, op.indexFn, op.getNumThreads());
    }

    /**
     * Aggregates a compressed block directly on its column groups.
     *
     * <p>A {@code Mean} is computed as a plain sum (via a substitute {@code Plus} operator) followed by the
     * division performed in {@code postProcessAggregate}. The parallel/sequential decision and the post-processing
     * both use the original {@code op}.
     *
     * @param m         compressed input
     * @param o         output block
     * @param op        requested aggregate
     * @param blen      unused here
     * @param indexesIn unused here
     * @param inCP      unused here
     */
    private static void aggregateUnaryNormalCompressedMatrixBlock(CompressedMatrixBlock m, MatrixBlock o, AggregateUnaryOperator op, int blen, MatrixIndexes indexesIn, boolean inCP) {
        final int k = op.getNumThreads();

        final AggregateUnaryOperator opm;
        if (op.aggOp.increOp.fn instanceof Mean) {
            opm = new AggregateUnaryOperator(new AggregateOperator(0.0, Plus.getPlusFnObject()), op.indexFn);
        } else {
            opm = op;
        }

        if (!CLALibCompAgg.isValidForParallelProcessing(m, op)) {
            CLALibCompAgg.aggregateUnaryOperations(opm, m.getColGroups(), o.getDenseBlockValues(), 0, m.getNumRows(), m.getNumColumns());
        } else {
            CLALibCompAgg.aggregateInParallel(m, o, opm, k);
        }
        CLALibCompAgg.postProcessAggregate(m, o, op);
    }

    /**
     * Decides whether the aggregate should run on multiple threads.
     *
     * @param m1 compressed input
     * @param op operator carrying the requested thread count
     * @return {@code true} when more than one thread is requested and the exact on-disk size of {@code m1}
     *         exceeds {@code MIN_PAR_AGG_THRESHOLD}
     */
    private static boolean isValidForParallelProcessing(CompressedMatrixBlock m1, AggregateUnaryOperator op) {
        if (op.getNumThreads() <= 1) {
            return false;
        }
        return m1.getExactSizeOnDisk() > MIN_PAR_AGG_THRESHOLD;
    }

    /*
     * WARNING - void declaration
     */
    private static void aggregateInParallel(CompressedMatrixBlock m1, MatrixBlock ret, AggregateUnaryOperator op, int k) {
        ExecutorService pool = CommonThreadPool.get(k);
        ArrayList<UnaryAggregateTask> tasks = new ArrayList<UnaryAggregateTask>();
        try {
            if (op.indexFn instanceof ReduceCol) {
                void var8_14;
                ret.allocateDenseBlock();
                int blkz = 65535;
                int blklen = Math.max((int)Math.ceil((double)m1.getNumRows() / (double)(k * 2)), 65535);
                boolean bl = false;
                while (var8_14 * blklen < m1.getNumRows()) {
                    tasks.add(new UnaryAggregateTask(m1.getColGroups(), ret, (int)(var8_14 * blklen), Math.min((int)((var8_14 + true) * blklen), m1.getNumRows()), op, m1.getNumColumns()));
                    ++var8_14;
                }
            } else {
                List<List<AColGroup>> grpParts = CLALibCompAgg.createTaskPartition(m1.getColGroups(), k);
                for (List<AColGroup> list : grpParts) {
                    tasks.add(new UnaryAggregateTask(list, ret, 0, m1.getNumRows(), op, m1.getNumColumns(), m1.isOverlapping()));
                }
            }
            List<Future<MatrixBlock>> futures = pool.invokeAll(tasks);
            pool.shutdown();
            if (op.indexFn instanceof ReduceAll) {
                if (op.aggOp.increOp.fn instanceof Builtin) {
                    CLALibCompAgg.aggregateResults(ret, futures, op);
                } else {
                    CLALibCompAgg.sumResults(ret, futures);
                }
            } else if (op.indexFn instanceof ReduceRow && m1.isOverlapping()) {
                if (op.aggOp.increOp.fn instanceof Builtin) {
                    CLALibCompAgg.aggregateResultVectors(ret, futures, op);
                } else {
                    CLALibCompAgg.sumResultVectors(ret, futures);
                }
            } else {
                for (Future future : futures) {
                    future.get();
                }
            }
        }
        catch (InterruptedException | ExecutionException e) {
            LOG.error((Object)"Aggregate In parallel failed.");
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Adds cell (0,0) of every partial result to cell (0,0) of {@code ret}.
     *
     * @param ret     block holding the running total
     * @param futures partial results, consumed in list order
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void sumResults(MatrixBlock ret, List<Future<MatrixBlock>> futures) throws InterruptedException, ExecutionException {
        double total = ret.quickGetValue(0, 0);
        for (final Future<MatrixBlock> partial : futures) {
            total += partial.get().quickGetValue(0, 0);
        }
        ret.quickSetValue(0, 0, total);
    }

    /**
     * Adds the dense values of every partial result element-wise onto the dense values of {@code ret}, then sets
     * the non-zero count of {@code ret} to its number of columns.
     *
     * @param ret     block accumulating the sums
     * @param futures partial result vectors, consumed in list order
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void sumResultVectors(MatrixBlock ret, List<Future<MatrixBlock>> futures) throws InterruptedException, ExecutionException {
        final double[] acc = ret.getDenseBlockValues();
        for (final Future<MatrixBlock> partial : futures) {
            final double[] part = partial.get().getDenseBlockValues();
            for (int c = 0; c < acc.length; c++) {
                acc[c] += part[c];
            }
        }
        ret.setNonZeros(ret.getNumColumns());
    }

    /**
     * Folds cell (0,0) of every partial result into cell (0,0) of {@code ret} using the operator's increment
     * function.
     *
     * @param ret     block holding the running value
     * @param futures partial results, consumed in list order
     * @param op      operator whose {@code aggOp.increOp.fn} combines two values
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void aggregateResults(MatrixBlock ret, List<Future<MatrixBlock>> futures, AggregateUnaryOperator op) throws InterruptedException, ExecutionException {
        double current = ret.quickGetValue(0, 0);
        for (final Future<MatrixBlock> partial : futures) {
            final double next = partial.get().quickGetValue(0, 0);
            current = op.aggOp.increOp.fn.execute(current, next);
        }
        ret.quickSetValue(0, 0, current);
    }

    /**
     * Folds the dense values of every partial result element-wise into the dense values of {@code ret} using the
     * operator's increment function, then sets the non-zero count of {@code ret} to its number of columns.
     *
     * @param ret     block accumulating the result
     * @param futures partial result vectors, consumed in list order
     * @param op      operator whose {@code aggOp.increOp.fn} combines two values
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void aggregateResultVectors(MatrixBlock ret, List<Future<MatrixBlock>> futures, AggregateUnaryOperator op) throws InterruptedException, ExecutionException {
        final double[] acc = ret.getDenseBlockValues();
        for (final Future<MatrixBlock> partial : futures) {
            final double[] part = partial.get().getDenseBlockValues();
            for (int c = 0; c < acc.length; c++) {
                acc[c] = op.aggOp.increOp.fn.execute(acc[c], part[c]);
            }
        }
        ret.setNonZeros(ret.getNumColumns());
    }

    /**
     * Turns previously computed sums into means by dispatching on the index function: {@code ReduceAll} divides
     * the single cell, {@code ReduceCol} divides per row, {@code ReduceRow} divides per column. Any other index
     * function leaves {@code ret} untouched.
     *
     * @param m1    input whose dimensions supply the divisors
     * @param ret   block holding the sums, updated in place
     * @param idxFn index function of the aggregate
     */
    private static void divideByNumberOfCellsForMean(CompressedMatrixBlock m1, MatrixBlock ret, IndexFunction idxFn) {
        if (idxFn instanceof ReduceAll) {
            CLALibCompAgg.divideByNumberOfCellsForMeanAll(m1, ret);
            return;
        }
        if (idxFn instanceof ReduceCol) {
            CLALibCompAgg.divideByNumberOfCellsForMeanRows(m1, ret);
            return;
        }
        if (idxFn instanceof ReduceRow) {
            CLALibCompAgg.divideByNumberOfCellsForMeanCols(m1, ret);
        }
    }

    /**
     * Divides the first {@code m1.getNumRows()} dense values of {@code ret} by the number of columns of
     * {@code m1}.
     *
     * @param m1  input supplying row count and divisor
     * @param ret block holding per-row sums, updated in place
     */
    private static void divideByNumberOfCellsForMeanRows(CompressedMatrixBlock m1, MatrixBlock ret) {
        final double[] sums = ret.getDenseBlockValues();
        final int rows = m1.getNumRows();
        final double divisor = m1.getNumColumns();
        for (int r = 0; r < rows; r++) {
            sums[r] /= divisor;
        }
    }

    /**
     * Divides the first {@code m1.getNumColumns()} dense values of {@code ret} by the number of rows of
     * {@code m1}.
     *
     * @param m1  input supplying column count and divisor
     * @param ret block holding per-column sums, updated in place
     */
    private static void divideByNumberOfCellsForMeanCols(CompressedMatrixBlock m1, MatrixBlock ret) {
        final double[] sums = ret.getDenseBlockValues();
        final int cols = m1.getNumColumns();
        final double divisor = m1.getNumRows();
        for (int c = 0; c < cols; c++) {
            sums[c] /= divisor;
        }
    }

    /**
     * Divides cell (0,0) of {@code ret} by the total number of cells of {@code m1}.
     *
     * @param m1  input supplying the dimensions
     * @param ret block holding the full sum in cell (0,0), updated in place
     */
    private static void divideByNumberOfCellsForMeanAll(CompressedMatrixBlock m1, MatrixBlock ret) {
        // Multiply as doubles: rows * cols in int arithmetic wraps around once the block has more than
        // Integer.MAX_VALUE cells, which would silently produce a wrong (possibly negative) mean.
        final double cells = (double) m1.getNumColumns() * (double) m1.getNumRows();
        ret.quickSetValue(0, 0, ret.quickGetValue(0, 0) / cells);
    }

    /**
     * Finishes an aggregate after the column groups have been processed. Only {@code Mean} needs work here: the
     * accumulated sums are divided by the matching cell counts.
     *
     * @param m1  input supplying the dimensions
     * @param ret block holding the aggregate, updated in place
     * @param op  originally requested operator
     */
    private static void postProcessAggregate(CompressedMatrixBlock m1, MatrixBlock ret, AggregateUnaryOperator op) {
        final boolean isMean = op.aggOp.increOp.fn instanceof Mean;
        if (!isMean) {
            return;
        }
        CLALibCompAgg.divideByNumberOfCellsForMean(m1, ret, op.indexFn);
    }

    /**
     * Aggregates an overlapping compressed block by launching row-block tasks and reducing their results into
     * {@code ret}.
     *
     * @param m1        overlapping compressed input
     * @param ret       output block
     * @param op        aggregate operator
     * @param indexesIn unused here
     * @param inCP      unused here
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static void aggregateUnaryOverlapping(CompressedMatrixBlock m1, MatrixBlock ret, AggregateUnaryOperator op, MatrixIndexes indexesIn, boolean inCP) {
        try {
            final List<Future<MatrixBlock>> rtasks = CLALibCompAgg.generateUnaryAggregateOverlappingFutures(m1, ret, op);
            CLALibCompAgg.reduceOverlappingFutures(rtasks, ret, op);
        }
        catch (InterruptedException | ExecutionException e) {
            if (e instanceof InterruptedException) {
                // wrapping the exception must not hide the interruption from callers
                Thread.currentThread().interrupt();
            }
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Picks the reduction that matches the aggregate: full reduction when {@code isReduceAll} holds, column-wise
     * combination for {@code ReduceRow}, and a plain wait on the tasks otherwise.
     *
     * @param rtasks futures of the overlapping tasks
     * @param ret    output block
     * @param op     aggregate operator
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void reduceOverlappingFutures(List<Future<MatrixBlock>> rtasks, MatrixBlock ret, AggregateUnaryOperator op) throws InterruptedException, ExecutionException {
        if (CLALibCompAgg.isReduceAll(ret, op.indexFn)) {
            CLALibCompAgg.reduceAllOverlappingFutures(rtasks, ret, op);
            return;
        }
        if (op.indexFn instanceof ReduceRow) {
            CLALibCompAgg.reduceColOverlappingFutures(rtasks, ret, op);
            return;
        }
        CLALibCompAgg.reduceRowOverlappingFutures(rtasks, ret, op);
    }

    /**
     * Combines every task result into {@code ret} in place with a cell-wise binary operation. Kahan functions are
     * combined with a plain {@code Plus}; any other function uses the operator's own increment operator.
     *
     * @param rtasks futures of the overlapping tasks, consumed in list order
     * @param ret    output block, updated in place
     * @param op     aggregate operator
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void reduceColOverlappingFutures(List<Future<MatrixBlock>> rtasks, MatrixBlock ret, AggregateUnaryOperator op) throws InterruptedException, ExecutionException {
        final BinaryOperator combine;
        if (op.aggOp.increOp.fn instanceof KahanFunction) {
            combine = new BinaryOperator(Plus.getPlusFnObject());
        } else {
            combine = op.aggOp.increOp;
        }
        for (final Future<MatrixBlock> partial : rtasks) {
            LibMatrixBincell.bincellOpInPlace(ret, partial.get(), combine);
        }
    }

    /**
     * Waits for every task and discards the returned value; the only effect is that a task failure is rethrown
     * as an {@link ExecutionException}.
     *
     * @param rtasks futures of the overlapping tasks
     * @param ret    unused here
     * @param op     unused here
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void reduceRowOverlappingFutures(List<Future<MatrixBlock>> rtasks, MatrixBlock ret, AggregateUnaryOperator op) throws InterruptedException, ExecutionException {
        for (final Future<MatrixBlock> pending : rtasks) {
            pending.get();
        }
    }

    /**
     * @param ret   output block
     * @param idxFn index function of the aggregate
     * @return {@code true} if {@code idxFn} is a {@code ReduceAll}, or if {@code ret} is a 1x1 block
     */
    private static boolean isReduceAll(MatrixBlock ret, IndexFunction idxFn) {
        if (idxFn instanceof ReduceAll) {
            return true;
        }
        return ret.getNumColumns() == 1 && ret.getNumRows() == 1;
    }

    /**
     * Reduces the (0,0) cells of all task results into cell (0,0) of {@code ret}. Kahan functions are accumulated
     * with a {@code KahanObject}/{@code KahanPlus} pair; every other function is folded with the operator's
     * increment function.
     *
     * @param rtasks futures of the overlapping tasks, consumed in list order
     * @param ret    output block; its (0,0) cell is the starting value and receives the result
     * @param op     aggregate operator
     * @throws InterruptedException if waiting on a future is interrupted
     * @throws ExecutionException   if a task failed
     */
    private static void reduceAllOverlappingFutures(List<Future<MatrixBlock>> rtasks, MatrixBlock ret, AggregateUnaryOperator op) throws InterruptedException, ExecutionException {
        if (!(op.aggOp.increOp.fn instanceof KahanFunction)) {
            double current = ret.quickGetValue(0, 0);
            for (final Future<MatrixBlock> partial : rtasks) {
                current = op.aggOp.increOp.fn.execute(current, partial.get().quickGetValue(0, 0));
            }
            ret.quickSetValue(0, 0, current);
            return;
        }

        final KahanObject buffer = new KahanObject(ret.quickGetValue(0, 0), 0.0);
        final KahanPlus adder = KahanPlus.getKahanPlusFnObject();
        for (final Future<MatrixBlock> partial : rtasks) {
            adder.execute2(buffer, partial.get().quickGetValue(0, 0));
        }
        ret.quickSetValue(0, 0, buffer._sum);
    }

    /**
     * Splits the rows of {@code m1} into blocks, runs one {@code UnaryAggregateOverlappingTask} per block on a
     * thread pool and returns the completed futures.
     *
     * <p>The block length is {@code rows / threads}, capped at 65535 and never smaller than one row.
     *
     * @param m1  overlapping compressed input
     * @param ret output block handed to every task
     * @param op  aggregate operator; its thread count sizes the pool and the row blocks
     * @return futures of all tasks, in row-block order
     * @throws InterruptedException if the calling thread is interrupted while the tasks run
     */
    private static List<Future<MatrixBlock>> generateUnaryAggregateOverlappingFutures(CompressedMatrixBlock m1, MatrixBlock ret, AggregateUnaryOperator op) throws InterruptedException {
        final ExecutorService pool = CommonThreadPool.get(op.getNumThreads());
        try {
            final int nRows = m1.getNumRows();
            // Guard the divisor, and keep blklen >= 1: with fewer rows than threads the integer division
            // yields 0, and a zero block length would make the loop below add tasks forever.
            final int threads = Math.max(1, op.getNumThreads());
            final int blklen = Math.max(1, Math.min(nRows / threads, 65535));

            final List<UnaryAggregateOverlappingTask> tasks = new ArrayList<>();
            // long arithmetic so that the block bounds cannot overflow for very tall inputs
            for (long rl = 0; rl < nRows; rl += blklen) {
                final int ru = (int) Math.min(rl + blklen, (long) nRows);
                tasks.add(new UnaryAggregateOverlappingTask(m1, ret, (int) rl, ru, op));
            }
            return pool.invokeAll(tasks);
        }
        finally {
            // release the pool even when invokeAll is interrupted
            pool.shutdown();
        }
    }

    /**
     * Distributes the column groups round-robin over {@code min(k, colGroups.size())} partitions.
     *
     * @param colGroups groups to distribute, visited in list order
     * @param k         upper bound on the number of partitions
     * @return the partitions, each a list of column groups
     */
    private static List<List<AColGroup>> createTaskPartition(List<AColGroup> colGroups, int k) {
        final int parts = Math.min(k, colGroups.size());
        final List<List<AColGroup>> buckets = new ArrayList<>();
        while (buckets.size() < parts) {
            buckets.add(new ArrayList<AColGroup>());
        }

        int slot = 0;
        for (final AColGroup group : colGroups) {
            buckets.get(slot).add(group);
            slot = (slot + 1) % parts;
        }
        return buckets;
    }

    /**
     * Aggregates the given column groups over rows {@code [rl, ru)} into {@code ret}. A {@code Builtin} function
     * combined with {@code ReduceCol} takes the row-wise builtin path; everything else takes the normal path.
     *
     * @param op         aggregate operator
     * @param groups     column groups to process
     * @param ret        dense result array
     * @param rl         first row (inclusive)
     * @param ru         last row (exclusive)
     * @param numColumns total number of columns of the matrix
     */
    private static void aggregateUnaryOperations(AggregateUnaryOperator op, List<AColGroup> groups, double[] ret, int rl, int ru, int numColumns) {
        final boolean rowWiseBuiltin = op.indexFn instanceof ReduceCol && op.aggOp.increOp.fn instanceof Builtin;
        if (!rowWiseBuiltin) {
            CLALibCompAgg.aggregateUnaryNormalOperation(op, groups, ret, rl, ru, numColumns);
            return;
        }
        CLALibCompAgg.aggregateUnaryBuiltinRowOperation(op, groups, ret, rl, ru, numColumns);
    }

    /**
     * Lets every column group aggregate rows {@code [rl, ru)} into {@code ret}, in list order.
     *
     * @param op         aggregate operator
     * @param groups     column groups to process
     * @param ret        dense result array
     * @param rl         first row (inclusive)
     * @param ru         last row (exclusive)
     * @param numColumns unused here
     */
    private static void aggregateUnaryNormalOperation(AggregateUnaryOperator op, List<AColGroup> groups, double[] ret, int rl, int ru, int numColumns) {
        for (final AColGroup group : groups) {
            group.unaryAggregateOperations(op, ret, rl, ru);
        }
    }

    /**
     * Row-wise builtin aggregate over the column groups.
     *
     * <p>Every group aggregates rows {@code [rl, ru)} into {@code ret}. If at least one group is not dense, the
     * non-zeros per row of the non-dense groups are counted as well; a row whose counted non-zeros plus the
     * columns of the dense groups stay below {@code numColumns} is additionally combined with the value 0.
     *
     * @param op         aggregate operator
     * @param groups     column groups to process
     * @param ret        dense result array, indexed by absolute row
     * @param rl         first row (inclusive)
     * @param ru         last row (exclusive)
     * @param numColumns total number of columns of the matrix
     */
    private static void aggregateUnaryBuiltinRowOperation(AggregateUnaryOperator op, List<AColGroup> groups, double[] ret, int rl, int ru, int numColumns) {
        boolean allDense = true;
        for (final AColGroup group : groups) {
            if (!group.isDense()) {
                allDense = false;
            }
        }

        // The per-row counters are only needed when some group is not dense.
        final int[] rowNnz = allDense ? null : new int[ru - rl];
        int denseColumns = 0;
        for (final AColGroup group : groups) {
            group.unaryAggregateOperations(op, ret, rl, ru);
            if (allDense) {
                continue;
            }
            if (group.isDense()) {
                denseColumns += group.getNumCols();
            } else {
                group.countNonZerosPerRow(rowNnz, rl, ru);
            }
        }
        if (allDense) {
            return;
        }

        for (int row = rl; row < ru; row++) {
            // fewer accounted cells than columns: the row is combined with a zero as well
            if (rowNnz[row - rl] + denseColumns < numColumns) {
                ret[row] = op.aggOp.increOp.fn.execute(ret[row], 0.0);
            }
        }
    }

    /**
     * Seeds the dense block of {@code ret} for builtin aggregates: negative infinity for MAX, positive infinity
     * for MIN. Other builtin codes and non-builtin functions leave {@code ret} untouched.
     *
     * @param ret output block to seed
     * @param op  aggregate operator
     */
    private static void fillStart(MatrixBlock ret, AggregateUnaryOperator op) {
        if (!(op.aggOp.increOp.fn instanceof Builtin)) {
            return;
        }
        switch (((Builtin) op.aggOp.increOp.fn).getBuiltinCode()) {
            case MAX:
                ret.getDenseBlock().set(Double.NEGATIVE_INFINITY);
                break;
            case MIN:
                ret.getDenseBlock().set(Double.POSITIVE_INFINITY);
                break;
            default:
                break;
        }
    }

    /**
     * Task that decompresses rows {@code [_rl, _ru)} of an overlapping compressed block into a per-thread scratch
     * block and aggregates that scratch block with {@code LibMatrixAgg}.
     *
     * <p>For {@code ReduceCol} the partial result is written straight into the shared output and {@code call}
     * returns {@code null}; for every other index function the partial result block is returned.
     */
    private static class UnaryAggregateOverlappingTask
    implements Callable<MatrixBlock> {
        private final CompressedMatrixBlock _m1;
        private final int _rl;
        private final int _ru;
        private final MatrixBlock _ret;
        private final AggregateUnaryOperator _op;

        /**
         * @param m1  overlapping compressed input
         * @param ret shared output block (written only for {@code ReduceCol})
         * @param rl  first row (inclusive)
         * @param ru  last row (exclusive)
         * @param op  aggregate operator
         */
        protected UnaryAggregateOverlappingTask(CompressedMatrixBlock m1, MatrixBlock ret, int rl, int ru, AggregateUnaryOperator op) {
            this._m1 = m1;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
            this._ret = ret;
        }

        /** Returns this thread's scratch block, created on first use and reset to the task's shape afterwards. */
        private MatrixBlock getTmp() {
            MatrixBlock tmp = memPool.get();
            if (tmp == null) {
                tmp = new MatrixBlock(this._ru - this._rl, this._m1.getNumColumns(), false, -1L).allocateBlock();
                memPool.set(tmp);
            } else {
                tmp.reset(this._ru - this._rl, this._m1.getNumColumns(), false, -1L);
            }
            return tmp;
        }

        /** Decompresses all column groups for rows {@code [_rl, _ru)} into the scratch block. */
        private MatrixBlock decompressToTemp() {
            final MatrixBlock tmp = this.getTmp();
            for (AColGroup g : this._m1.getColGroups()) {
                g.decompressToBlockUnSafe(tmp, this._rl, this._ru, 0);
            }
            // NOTE(review): _rl + _ru is not a cell count; presumably it only marks the block as
            // non-empty for the aggregation that follows. Kept as in the original; confirm intent.
            tmp.setNonZeros(this._rl + this._ru);
            return tmp;
        }

        /**
         * @return the partial result block, or {@code null} for {@code ReduceCol} (result already stored in the
         *         shared output)
         * @throws DMLCompressionException if a {@code ReduceCol} partial result is in sparse format
         */
        @Override
        public MatrixBlock call() {
            final MatrixBlock tmp = this.decompressToTemp();
            final MatrixBlock outputBlock = tmp.prepareAggregateUnaryOutput(this._op, null, Math.max(tmp.getNumColumns(), tmp.getNumRows()));
            LibMatrixAgg.aggregateUnaryMatrix(tmp, outputBlock, this._op);
            outputBlock.dropLastRowsOrColumns(this._op.aggOp.correction);
            if (!(this._op.indexFn instanceof ReduceCol)) {
                return outputBlock;
            }

            final double[] retValues = this._ret.getDenseBlockValues();
            final int retCols = this._ret.getNumColumns();
            final int offset = this._rl * retCols;
            if (outputBlock.isEmpty()) {
                // An empty partial result is taken to mean that every aggregated row value is zero.
                // Write those zeros explicitly: skipping the write would leave whatever the shared
                // output held before (e.g. the +/-Infinity seed used for MIN/MAX) in these rows.
                if (retValues != null) {
                    final int end = Math.min(retValues.length, offset + (this._ru - this._rl) * retCols);
                    for (int i = offset; i < end; i++) {
                        retValues[i] = 0.0;
                    }
                }
                return null;
            }
            if (outputBlock.isInSparseFormat()) {
                throw new DMLCompressionException("Not implemented sparse and not something that should ever happen because we dont use sparse for column matrices");
            }
            final double[] outputBlockValues = outputBlock.getDenseBlockValues();
            System.arraycopy(outputBlockValues, 0, retValues, offset, outputBlockValues.length);
            return null;
        }
    }

    /**
     * Task that aggregates a list of column groups over rows {@code [_rl, _ru)}.
     *
     * <p>Depending on the operator the task either writes into the shared output block or into a private block
     * allocated in the constructor; {@code call} returns whichever block was written.
     */
    private static class UnaryAggregateTask
    implements Callable<MatrixBlock> {
        private final List<AColGroup> _groups;
        private final int _rl;
        private final int _ru;
        private final MatrixBlock _ret;
        private final int _numColumns;
        private final AggregateUnaryOperator _op;

        /**
         * Creates a task that uses a private 1x1 result for {@code ReduceAll} and the shared {@code ret}
         * otherwise.
         *
         * @param groups     column groups to aggregate
         * @param ret        shared output block
         * @param rl         first row (inclusive)
         * @param ru         last row (exclusive)
         * @param op         aggregate operator
         * @param numColumns total number of columns of the matrix
         */
        protected UnaryAggregateTask(List<AColGroup> groups, MatrixBlock ret, int rl, int ru, AggregateUnaryOperator op, int numColumns) {
            this._groups = groups;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
            this._numColumns = numColumns;
            this._ret = (op.indexFn instanceof ReduceAll) ? newPrivateResult(ret, 1, 1, op) : ret;
        }

        /**
         * Creates a task that uses a private result with the dimensions of {@code ret} for {@code ReduceAll} and
         * for overlapping {@code ReduceRow}, and the shared {@code ret} otherwise.
         *
         * @param groups      column groups to aggregate
         * @param ret         shared output block
         * @param rl          first row (inclusive)
         * @param ru          last row (exclusive)
         * @param op          aggregate operator
         * @param numColumns  total number of columns of the matrix
         * @param overlapping whether the input matrix is overlapping
         */
        protected UnaryAggregateTask(List<AColGroup> groups, MatrixBlock ret, int rl, int ru, AggregateUnaryOperator op, int numColumns, boolean overlapping) {
            this._groups = groups;
            this._op = op;
            this._rl = rl;
            this._ru = ru;
            this._numColumns = numColumns;
            final boolean needsPrivate = op.indexFn instanceof ReduceAll || op.indexFn instanceof ReduceRow && overlapping;
            this._ret = needsPrivate ? newPrivateResult(ret, ret.getNumRows(), ret.getNumColumns(), op) : ret;
        }

        /**
         * Allocates a dense {@code rows x cols} block; for {@code Builtin} functions the dense values of
         * {@code ret} (all {@code rows * columns} of them) are copied in as the starting values.
         */
        private static MatrixBlock newPrivateResult(MatrixBlock ret, int rows, int cols, AggregateUnaryOperator op) {
            final MatrixBlock priv = new MatrixBlock(rows, cols, false);
            priv.allocateDenseBlock();
            if (op.aggOp.increOp.fn instanceof Builtin) {
                System.arraycopy(ret.getDenseBlockValues(), 0, priv.getDenseBlockValues(), 0, ret.getNumRows() * ret.getNumColumns());
            }
            return priv;
        }

        /** Aggregates the task's column groups into its result block and returns that block. */
        @Override
        public MatrixBlock call() {
            final double[] target = this._ret.getDenseBlockValues();
            CLALibCompAgg.aggregateUnaryOperations(this._op, this._groups, target, this._rl, this._ru, this._numColumns);
            return this._ret;
        }
    }
}

