/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Static helpers for applying a {@link ScalarOperator} to a {@link CompressedMatrixBlock}.
 *
 * <p>Where possible the operator is pushed into the individual column groups so the result stays
 * compressed. Operators that cannot be expressed on an overlapping block fall back to decompressing
 * the input first.
 */
public class CLALibScalar {
    private static final Log LOG = LogFactory.getLog(CLALibScalar.class.getName());

    /**
     * Size threshold (number of values times number of columns of a group). Groups below this
     * threshold are batched together into a shared task instead of getting a task of their own.
     */
    private static final int MINIMUM_PARALLEL_SIZE = 8096;

    /** A batch of small groups is turned into a task as soon as it holds more than this many groups. */
    private static final int SMALL_GROUP_BATCH_SIZE = 10;

    /**
     * Applies the scalar operator to the compressed matrix.
     *
     * <p>If the input is overlapping and the operator is not supported on overlapping groups (see
     * {@link #isInvalidForCompressedOutput}), the input is decompressed and the operation is
     * delegated to the uncompressed {@code MatrixBlock}. Otherwise a compressed result is built.
     *
     * @param sop    the scalar operator to apply
     * @param m1     the compressed input
     * @param result optional output block; reused only if it is a {@link CompressedMatrixBlock}
     * @return the block holding the result
     */
    public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) {
        if (isInvalidForCompressedOutput(m1, sop)) {
            LOG.warn("scalar overlapping not supported for op: " + sop.fn);
            MatrixBlock m1d = m1.decompress(sop.getNumThreads());
            return m1d.scalarOperations(sop, result);
        }
        CompressedMatrixBlock ret = setupRet(m1, result);
        List<AColGroup> colGroups = m1.getColGroups();
        boolean isScaling = sop.fn instanceof Multiply || sop.fn instanceof Divide;
        if (m1.isOverlapping() && !isScaling) {
            // Additive operator on overlapping groups: the operator must not be applied to every
            // group. Its effect on a zero cell is captured once in a separate constant group.
            double v0 = sop.executeScalar(0.0);
            ColGroupConst c = v0 != 0.0 ? constOverlap(m1, v0) : null;
            boolean isMinus = sop instanceof LeftScalarOperator && sop.fn instanceof Minus;
            List<AColGroup> newColGroups = isMinus
                ? copyGroupsAndMultMinus(m1, sop, c, ret)
                : copyGroups(m1, sop, c, ret);
            ret.allocateColGroupList(newColGroups);
            ret.setOverlapping(true);
        } else {
            int threadsAvailable = sop.getNumThreads() > 1
                ? sop.getNumThreads()
                : OptimizerUtils.getConstrainedNumThreads(-1);
            if (threadsAvailable > 1) {
                parallelScalarOperations(sop, colGroups, ret, threadsAvailable);
            } else {
                List<AColGroup> newColGroups = new ArrayList<>();
                for (AColGroup grp : colGroups) {
                    newColGroups.add(grp.scalarOperation(sop));
                }
                ret.allocateColGroupList(newColGroups);
            }
            ret.setOverlapping(m1.isOverlapping());
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Chooses the output block: reuses {@code result} when it is a {@link CompressedMatrixBlock}
     * (after setting its dimensions to those of {@code m1}), otherwise allocates a new one.
     */
    private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) {
        // instanceof is false for null, so no separate null check is required.
        if (!(result instanceof CompressedMatrixBlock)) {
            return new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
        }
        CompressedMatrixBlock ret = (CompressedMatrixBlock) result;
        ret.setNumColumns(m1.getNumColumns());
        ret.setNumRows(m1.getNumRows());
        return ret;
    }

    /** Creates a constant column group with value {@code v} over all columns of {@code m1}. */
    private static ColGroupConst constOverlap(CompressedMatrixBlock m1, double v) {
        return (ColGroupConst) ColGroupConst.create(m1.getNumColumns(), v);
    }

    /**
     * Builds the group list for an additive operator on an overlapping block.
     *
     * <p>Empty groups are dropped, non-constant groups are copied, and constant groups are folded
     * into {@code c}. If {@code c} is {@code null} (the operator maps 0 to 0) there is nothing to
     * fold into, so constant groups are copied like any other group rather than discarded.
     *
     * @param c constant group carrying the operator's offset, or {@code null} if the offset is zero;
     *          its values are updated in place (relies on {@code getValues()} exposing the backing
     *          array, as the folding writes through it)
     */
    private static List<AColGroup> copyGroups(CompressedMatrixBlock m1, ScalarOperator sop, ColGroupConst c, CompressedMatrixBlock ret) {
        double[] constV = c != null ? c.getValues() : null;
        List<AColGroup> newColGroups = new ArrayList<>();
        for (AColGroup grp : m1.getColGroups()) {
            if (grp instanceof ColGroupEmpty) {
                continue;
            }
            if (grp instanceof ColGroupConst && constV != null) {
                double[] gv = ((ColGroupConst) grp).getValues();
                int[] colIdx = grp.getColIndices();
                for (int i = 0; i < colIdx.length; ++i) {
                    constV[colIdx[i]] += gv[i];
                }
                continue;
            }
            newColGroups.add(grp.copy());
        }
        if (c != null) {
            newColGroups.add(c);
        }
        return newColGroups;
    }

    /**
     * Builds the group list for a left-scalar minus ({@code scalar - X}) on an overlapping block.
     *
     * <p>Empty groups are dropped, non-constant groups are multiplied by -1, and constant groups are
     * subtracted from {@code c}. If {@code c} is {@code null} (scalar is zero) constant groups are
     * negated like the other groups.
     *
     * @param c constant group carrying the scalar, or {@code null} if it is zero; its values are
     *          updated in place
     */
    private static List<AColGroup> copyGroupsAndMultMinus(CompressedMatrixBlock m1, ScalarOperator sop, ColGroupConst c, CompressedMatrixBlock ret) {
        double[] constV = c != null ? c.getValues() : null;
        // Loop-invariant: the same negation operator is applied to every retained group.
        ScalarOperator negate = new RightScalarOperator(Multiply.getMultiplyFnObject(), -1.0);
        List<AColGroup> newColGroups = new ArrayList<>();
        for (AColGroup grp : m1.getColGroups()) {
            if (grp instanceof ColGroupEmpty) {
                continue;
            }
            if (grp instanceof ColGroupConst && constV != null) {
                double[] gv = ((ColGroupConst) grp).getValues();
                int[] colIdx = grp.getColIndices();
                for (int i = 0; i < colIdx.length; ++i) {
                    constV[colIdx[i]] -= gv[i];
                }
                continue;
            }
            newColGroups.add(grp.scalarOperation(negate));
        }
        if (c != null) {
            newColGroups.add(c);
        }
        return newColGroups;
    }

    /**
     * Returns {@code true} when the input is overlapping and the operator is none of: multiply,
     * right-scalar divide, plus, or minus. In that case the caller decompresses the input.
     */
    private static boolean isInvalidForCompressedOutput(CompressedMatrixBlock m1, ScalarOperator sop) {
        if (!m1.isOverlapping()) {
            return false;
        }
        boolean supported = sop.fn instanceof Multiply
            || (sop.fn instanceof Divide && sop instanceof RightScalarOperator)
            || sop.fn instanceof Plus
            || sop.fn instanceof Minus;
        return !supported;
    }

    /**
     * Applies the operator to all column groups using a thread pool of size {@code k} and stores
     * the resulting groups in {@code ret}. Does nothing if {@code colGroups} is {@code null}.
     *
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static void parallelScalarOperations(ScalarOperator sop, List<AColGroup> colGroups, CompressedMatrixBlock ret, int k) {
        if (colGroups == null) {
            return;
        }
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            List<ScalarTask> tasks = partition(sop, colGroups);
            List<Future<List<AColGroup>>> rtasks = pool.invokeAll(tasks);
            List<AColGroup> newColGroups = new ArrayList<>();
            for (Future<List<AColGroup>> f : rtasks) {
                newColGroups.addAll(f.get());
            }
            ret.allocateColGroupList(newColGroups);
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            // Also reached on unchecked exceptions, so the pool is always released.
            pool.shutdown();
        }
    }

    /**
     * Splits the column groups into tasks. Uncompressed groups, OLE groups and groups at or above
     * {@link #MINIMUM_PARALLEL_SIZE} each get their own task; the remaining small groups are
     * batched, with a batch emitted once it holds more than {@link #SMALL_GROUP_BATCH_SIZE} groups.
     */
    private static List<ScalarTask> partition(ScalarOperator sop, List<AColGroup> colGroups) {
        List<ScalarTask> tasks = new ArrayList<>();
        List<AColGroup> small = new ArrayList<>();
        for (AColGroup grp : colGroups) {
            if (isScheduledAlone(grp)) {
                List<AColGroup> single = new ArrayList<>(1);
                single.add(grp);
                tasks.add(new ScalarTask(single, sop));
            } else {
                small.add(grp);
            }
            if (small.size() > SMALL_GROUP_BATCH_SIZE) {
                tasks.add(new ScalarTask(small, sop));
                small = new ArrayList<>();
            }
        }
        if (!small.isEmpty()) {
            tasks.add(new ScalarTask(small, sop));
        }
        return tasks;
    }

    /** Decides whether a group is processed in a task of its own rather than in a batch. */
    private static boolean isScheduledAlone(AColGroup grp) {
        if (grp instanceof ColGroupUncompressed || grp instanceof ColGroupOLE) {
            return true;
        }
        // Widen before multiplying so a very large group cannot overflow into the "small" class.
        long nv = (long) grp.getNumValues() * grp.getColIndices().length;
        return nv >= MINIMUM_PARALLEL_SIZE;
    }

    /** Task applying one scalar operator to a list of column groups, in list order. */
    private static class ScalarTask
    implements Callable<List<AColGroup>> {
        private final List<AColGroup> _colGroups;
        private final ScalarOperator _sop;

        protected ScalarTask(List<AColGroup> colGroups, ScalarOperator sop) {
            this._colGroups = colGroups;
            this._sop = sop;
        }

        @Override
        public List<AColGroup> call() {
            List<AColGroup> res = new ArrayList<>(this._colGroups.size());
            for (AColGroup x : this._colGroups) {
                res.add(x.scalarOperation(this._sop));
            }
            return res;
        }
    }
}

