/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.cocode.PlanningCoCoder;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorFactory;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.utils.DblArrayIntListHashMap;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.DMLCompressionStatistics;

public class CompressedMatrixBlockFactory {
    /** Logger shared by all factory instances. */
    private static final Log LOG = LogFactory.getLog((String)CompressedMatrixBlockFactory.class.getName());
    /** Timer whose {@code stop()} value is read by {@code logPhase()} at the end of every phase. */
    private Timing time = new Timing(true);
    /** Duration most recently stored through {@code setNextTimePhase}. */
    private double lastPhase;
    /** Statistics filled in phase by phase and handed back to the caller together with the result. */
    private CompressionStatistics _stats = new CompressionStatistics();
    /** Matrix being compressed; replaced by a transposed or shallow copy in {@code transposePhase()}. */
    private MatrixBlock mb;
    /** Forwarded to the estimator, co-coding, transpose and compression calls (presumably the parallelism degree - confirm). */
    private int k;
    /** Settings consulted during compression; its {@code transposed} flag is written by {@code transposeHeuristics()}. */
    private CompressionSettings compSettings;
    /** Compressed result under construction; set to {@code null} by {@code cleanupPhase()} to signal an abort. */
    private CompressedMatrixBlock res;
    /** Number of phases logged so far; incremented at the end of {@code logPhase()}. */
    private int phase = 0;
    /** Co-coded column groups; stays {@code null} when {@code classifyPhase()} decides not to continue. */
    private CompressedSizeInfo coCodeColGroups;

    /**
     * Binds a new factory to the matrix and settings of one compression run.
     *
     * @param mb           matrix to compress
     * @param k            value forwarded to the estimation, transpose and compression calls
     *                     (presumably the degree of parallelism - confirm against those APIs)
     * @param compSettings settings consulted, and partly updated, while compressing
     */
    private CompressedMatrixBlockFactory(MatrixBlock mb, int k, CompressionSettings compSettings) {
        this.compSettings = compSettings;
        this.k = k;
        this.mb = mb;
    }

    /**
     * Compresses {@code mb} with settings from a fresh {@link CompressionSettingsBuilder} and {@code k = 1}.
     *
     * @param mb matrix to compress
     * @return pair of the resulting block and the statistics collected for it
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb) {
        final CompressionSettings builderDefaults = new CompressionSettingsBuilder().create();
        return compress(mb, 1, builderDefaults);
    }

    /**
     * Compresses {@code mb} with caller-supplied settings and {@code k = 1}.
     *
     * @param mb             matrix to compress
     * @param customSettings settings to use for this run
     * @return pair of the resulting block and the statistics collected for it
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, CompressionSettings customSettings) {
        final int defaultK = 1;
        return compress(mb, defaultK, customSettings);
    }

    /**
     * Compresses {@code mb} with settings from a fresh {@link CompressionSettingsBuilder} and the given {@code k}.
     *
     * @param mb matrix to compress
     * @param k  value forwarded to the three-argument overload
     * @return pair of the resulting block and the statistics collected for it
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k) {
        final CompressionSettings builderDefaults = new CompressionSettingsBuilder().create();
        return compress(mb, k, builderDefaults);
    }

    /**
     * Runs the full compression pipeline on {@code mb} using a dedicated factory instance.
     *
     * @param mb           matrix to compress
     * @param k            value handed to the factory (see the constructor)
     * @param compSettings settings to use for this run
     * @return pair of the resulting block and the statistics collected for it
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k, CompressionSettings compSettings) {
        return new CompressedMatrixBlockFactory(mb, k, compSettings).compressMatrix();
    }

    /**
     * Builds a compressed block whose only column group is the one produced by
     * {@code ColGroupFactory.genColGroupConst} for the given dimensions and value.
     *
     * @param numRows number of rows of the block
     * @param numCols number of columns of the block
     * @param value   value passed to the constant column group generator
     * @return the new compressed block with its non-zero count recomputed
     */
    public static CompressedMatrixBlock createConstant(int numRows, int numCols, double value) {
        final CompressedMatrixBlock constantBlock = new CompressedMatrixBlock(numRows, numCols);
        constantBlock.allocateColGroup(ColGroupFactory.genColGroupConst(numRows, numCols, value));
        constantBlock.recomputeNonZeros();
        return constantBlock;
    }

    /**
     * Drives the compression pipeline: classify, (co-code), transpose, compress, share and cleanup.
     *
     * <p>The run is aborted, returning the uncompressed matrix, when the classify phase does not
     * produce co-coded groups or when the cleanup phase discards the result.
     *
     * @return the resulting block paired with the collected statistics; when the input is already a
     *         {@link CompressedMatrixBlock} it is returned as-is and the statistics element is {@code null}
     */
    private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {
        if (this.mb instanceof CompressedMatrixBlock) {
            LOG.info("MatrixBlock already compressed");
            // Nothing was measured for an already compressed input, hence no statistics object.
            return new ImmutablePair<MatrixBlock, CompressionStatistics>(this.mb, null);
        }

        this._stats.denseSize = MatrixBlock.estimateSizeInMemory(this.mb.getNumRows(), this.mb.getNumColumns(), 1.0);
        this._stats.originalSize = this.mb.getInMemorySize();
        this.res = new CompressedMatrixBlock(this.mb);

        this.classifyPhase();
        if (this.coCodeColGroups == null) {
            // classifyPhase() skipped co-coding, so there is nothing to compress.
            return this.abortCompression();
        }

        this.transposePhase();
        this.compressPhase();
        this.sharePhase();
        this.cleanupPhase();
        if (this.res == null) {
            // cleanupPhase() dropped the result because the compression did not pay off.
            return this.abortCompression();
        }

        this.res.recomputeNonZeros();
        return new ImmutablePair<MatrixBlock, CompressionStatistics>(this.res, this._stats);
    }

    /**
     * Estimates the compressed size of the columns and, when that estimate is below the original
     * size or the partitioner is {@code COST_MATRIX_MULT}, continues with the co-coding phase.
     * Otherwise only the two sizes are logged and {@code coCodeColGroups} is left untouched.
     */
    private void classifyPhase() {
        final CompressedSizeEstimator estimator = CompressedSizeEstimatorFactory.getSizeEstimator(mb, compSettings);
        final CompressedSizeInfo columnInfos = estimator.computeCompressedSizeInfos(k);
        _stats.estimatedSizeCols = columnInfos.memoryEstimate();
        logPhase();

        final boolean continueWithCoCoding = _stats.estimatedSizeCols < _stats.originalSize
            || compSettings.columnPartitioner == PlanningCoCoder.PartitionerType.COST_MATRIX_MULT;
        if (!continueWithCoCoding) {
            LOG.info("Estimated Size of singleColGroups: " + _stats.estimatedSizeCols);
            LOG.info("Original size                    : " + _stats.originalSize);
            return;
        }
        coCodePhase(estimator, columnInfos, mb.getNumRows());
    }

    /**
     * Groups columns through {@code PlanningCoCoder.findCoCodesByPartitioning}, stores the grouping
     * in {@code coCodeColGroups} and records its memory estimate.
     *
     * @param sizeEstimator estimator created in the classify phase
     * @param sizeInfos     size information computed in the classify phase
     * @param numRows       number of rows of the input matrix
     */
    private void coCodePhase(CompressedSizeEstimator sizeEstimator, CompressedSizeInfo sizeInfos, int numRows) {
        coCodeColGroups = PlanningCoCoder.findCoCodesByPartitioning(
            sizeEstimator, sizeInfos, numRows, k, compSettings);
        _stats.estimatedSizeCoCoded = coCodeColGroups.memoryEstimate();
        logPhase();
    }

    /**
     * Decides whether to transpose and then replaces {@code mb} either with a transposed copy or
     * with a new block that shallow-copies the current one.
     */
    private void transposePhase() {
        // Read the format first; the output block is allocated with the same sparsity flag.
        final boolean sparse = mb.isInSparseFormat();
        transposeHeuristics();

        if (compSettings.transposed) {
            final MatrixBlock transposeTarget = new MatrixBlock(mb.getNumColumns(), mb.getNumRows(), sparse);
            mb = LibMatrixReorg.transpose(mb, transposeTarget, k);
        } else {
            final MatrixBlock shallowTarget = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), sparse);
            mb = shallowTarget.copyShallow(mb);
        }
        logPhase();
    }

    /**
     * Sets {@code compSettings.transposed}. The values {@code "true"} and {@code "false"} of
     * {@code compSettings.transposeInput} force the decision; any other value transposes only
     * sparse inputs for which {@link #sparseInputFavorsTranspose()} holds.
     */
    private void transposeHeuristics() {
        final boolean decision;
        switch (compSettings.transposeInput) {
            case "true":
                decision = true;
                break;
            case "false":
                decision = false;
                break;
            default:
                decision = mb.isInSparseFormat() && sparseInputFavorsTranspose();
                break;
        }
        compSettings.transposed = decision;
    }

    /**
     * Checks the two thresholds applied to sparse inputs: more than 500000 rows, or more column
     * groups than half the number of columns. Both conditions are always evaluated.
     */
    private boolean sparseInputFavorsTranspose() {
        final boolean aboveRowThreshold = mb.getNumRows() > 500000;
        final boolean aboveGroupToColumnRatio = coCodeColGroups.getNumberColGroups() > mb.getNumColumns() / 2;
        return aboveRowThreshold || aboveGroupToColumnRatio;
    }

    /**
     * Compresses the co-coded column groups, installs them in the result block and records the
     * in-memory size of that block.
     */
    private void compressPhase() {
        res.allocateColGroupList(
            ColGroupFactory.compressColGroups(mb, coCodeColGroups, compSettings, k));
        _stats.compressedInitialSize = res.getInMemorySize();
        logPhase();
    }

    /**
     * Merges all empty column groups into one group and all constant column groups into another.
     * The remaining groups keep their order; the merged empty group, then the merged constant
     * group, are appended after them.
     */
    private void sharePhase() {
        final List<AColGroup> emptyGroups = new ArrayList<>();
        final List<AColGroup> constGroups = new ArrayList<>();
        final List<AColGroup> remaining = new ArrayList<>();

        for (AColGroup group : res.getColGroups()) {
            if (group instanceof ColGroupEmpty) {
                emptyGroups.add(group);
            } else if (group instanceof ColGroupConst) {
                constGroups.add(group);
            } else {
                remaining.add(group);
            }
        }

        if (!emptyGroups.isEmpty()) {
            remaining.add(combineEmpty(emptyGroups));
        }
        if (!constGroups.isEmpty()) {
            remaining.add(combineConst(constGroups));
        }

        res.allocateColGroupList(remaining);
        logPhase();
    }

    /**
     * Creates one empty column group covering the columns of all given groups.
     *
     * @param e empty column groups to merge; the row count is taken from the first element
     * @return the merged empty column group
     */
    private static AColGroup combineEmpty(List<AColGroup> e) {
        final int[] mergedCols = combineColIndexes(e);
        final int numRows = e.get(0).getNumRows();
        return new ColGroupEmpty(mergedCols, numRows);
    }

    /**
     * Creates one constant column group covering the columns of all given groups.
     *
     * <p>For every merged column the groups are scanned in list order and the dictionary value of
     * the first group containing that column is used.
     *
     * @param c constant column groups to merge; every element must be a {@link ColGroupConst} and
     *          the row count is taken from the first element
     * @return the merged constant column group
     */
    private static AColGroup combineConst(List<AColGroup> c) {
        final int[] mergedCols = combineColIndexes(c);
        final double[] mergedValues = new double[mergedCols.length];

        for (int pos = 0; pos < mergedCols.length; pos++) {
            final int wantedCol = mergedCols[pos];
            for (AColGroup candidate : c) {
                final ColGroupConst constGroup = (ColGroupConst) candidate;
                final int hit = Arrays.binarySearch(constGroup.getColIndices(), wantedCol);
                if (hit >= 0) {
                    mergedValues[pos] = constGroup.getDictionary().getValue(hit);
                    // First owner wins; move on to the next merged column.
                    break;
                }
            }
        }

        final Dictionary mergedDict = new Dictionary(mergedValues);
        return new ColGroupConst(mergedCols, c.get(0).getNumRows(), mergedDict);
    }

    /**
     * Collects the column indexes of all given groups into one ascending array.
     *
     * @param gs column groups whose indexes are gathered
     * @return sorted array sized by the sum of the groups' {@code getNumCols()}
     */
    private static int[] combineColIndexes(List<AColGroup> gs) {
        int totalCols = 0;
        for (AColGroup group : gs) {
            totalCols += group.getNumCols();
        }

        final int[] merged = new int[totalCols];
        int offset = 0;
        for (AColGroup group : gs) {
            final int[] indices = group.getColIndices();
            System.arraycopy(indices, 0, merged, offset, indices.length);
            offset += indices.length;
        }

        Arrays.sort(merged);
        return merged;
    }

    /**
     * Cleans up the result block and records its compressed size. When the compression ratio is
     * below 1 and the partitioner is not {@code COST_MATRIX_MULT}, the result is discarded by
     * setting {@code res} to {@code null}; otherwise the input block is cleaned up as well and the
     * column group counts are recorded.
     */
    private void cleanupPhase() {
        res.cleanupBlock(true, true);
        _stats.size = res.estimateCompressedSizeInMemory();

        final double ratio = _stats.getRatio();
        final boolean notWorthIt = ratio < 1.0
            && compSettings.columnPartitioner != PlanningCoCoder.PartitionerType.COST_MATRIX_MULT;
        if (notWorthIt) {
            logRatioAbort(ratio);
            res = null;
            return;
        }

        mb.cleanupBlock(true, true);
        _stats.setColGroupsCounts(res.getColGroups());
        logPhase();
    }

    /** Logs the sizes and the ratio that caused the compressed result to be discarded. */
    private void logRatioAbort(double ratio) {
        LOG.info("--dense size:        " + _stats.denseSize);
        LOG.info("--original size:     " + _stats.originalSize);
        LOG.info("--compressed size:   " + _stats.size);
        LOG.info("--compression ratio: " + ratio);
        LOG.info("Abort block compression because compression ratio is less than 1.");
    }

    /**
     * Ends the run without a compressed result and returns the uncompressed matrix.
     *
     * <p>The matrix is transposed back only when the transpose phase of this run has actually
     * completed. {@code compSettings.transposed} alone is not sufficient evidence: the flag lives on
     * the settings object and may still hold a value written before this run reached its own
     * transpose phase, in which case the caller's untouched input would be transposed in place.
     *
     * @return the uncompressed matrix paired with the statistics collected so far
     */
    private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
        LOG.warn("Compression aborted at phase: " + this.phase);

        // logPhase() increments the counter once each for classify, co-code and transpose,
        // so a value of at least 3 means transposePhase() has replaced this.mb.
        final int phasesLoggedOnceTransposed = 3;
        final boolean transposePhaseCompleted = this.phase >= phasesLoggedOnceTransposed;
        if (transposePhaseCompleted && this.compSettings.transposed) {
            LibMatrixReorg.transposeInPlace(this.mb, this.k);
        }
        return new ImmutablePair<MatrixBlock, CompressionStatistics>(this.mb, this._stats);
    }

    /**
     * Closes the current phase: stores the elapsed time reported by the timer, forwards it to
     * {@code DMLCompressionStatistics}, emits phase-specific debug output when debug logging is
     * enabled, and finally advances the phase counter.
     */
    private void logPhase() {
        setNextTimePhase(time.stop());
        DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase);
        if (LOG.isDebugEnabled()) {
            logPhaseDetails();
        }
        phase++;
    }

    /** Builds the common first debug line of a phase from the phase number, a label and the stored time. */
    private String phaseHeader(String label) {
        return "--compression phase " + phase + label + getLastTimePhase();
    }

    /** Writes the debug output belonging to the current phase number; other phase numbers log nothing. */
    private void logPhaseDetails() {
        switch (phase) {
            case 0:
                LOG.debug(phaseHeader(" Classify  : "));
                LOG.debug("--Individual Columns Estimated Compression: " + _stats.estimatedSizeCols);
                break;
            case 1:
                LOG.debug(phaseHeader(" Grouping  : "));
                LOG.debug("Grouping using: " + compSettings.columnPartitioner);
                LOG.debug("--Cocoded Columns estimated Compression:" + _stats.estimatedSizeCoCoded);
                break;
            case 2:
                LOG.debug(phaseHeader(" Transpose : "));
                LOG.debug("Did transpose: " + compSettings.transposed);
                break;
            case 3:
                LOG.debug(phaseHeader(" Compress  : "));
                LOG.debug("--compression Hash collisions:" + DblArrayIntListHashMap.hashMissCount);
                // The counter is reset only on this debug path, right after it has been reported.
                DblArrayIntListHashMap.hashMissCount = 0;
                LOG.debug("--compressed initial actual size:" + _stats.compressedInitialSize);
                break;
            case 4:
                LOG.debug(phaseHeader(" Share     : "));
                break;
            case 5:
                logCleanupDetails();
                break;
            default:
                break;
        }
    }

    /** Writes the summary of the finished compression, plus per-group details at trace level. */
    private void logCleanupDetails() {
        LOG.debug("--num col groups: " + res.getColGroups().size());
        LOG.debug(phaseHeader(" Cleanup   : "));
        LOG.debug("--col groups types " + _stats.getGroupsTypesString());
        LOG.debug("--col groups sizes " + _stats.getGroupsSizesString());
        LOG.debug("--dense size:        " + _stats.denseSize);
        LOG.debug("--original size:     " + _stats.originalSize);
        LOG.debug("--compressed size:   " + _stats.size);
        LOG.debug("--compression ratio: " + _stats.getRatio());

        final int[] dictionarySizes = new int[res.getColGroups().size()];
        int next = 0;
        for (AColGroup group : res.getColGroups()) {
            dictionarySizes[next] = group.getNumValues();
            next++;
        }
        LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(dictionarySizes));

        if (!LOG.isTraceEnabled()) {
            return;
        }
        for (AColGroup group : res.getColGroups()) {
            final String numValuesPart;
            if (group instanceof ColGroupValue) {
                numValuesPart = "  numValues :" + ((ColGroupValue) group).getNumValues();
            } else {
                numValuesPart = "";
            }
            LOG.trace("--colGroups type       : " + group.getClass().getSimpleName()
                + " size: " + group.estimateInMemorySize()
                + numValuesPart
                + "  colIndexes : " + Arrays.toString(group.getColIndices()));
        }
    }

    /**
     * Stores the duration that {@link #getLastTimePhase()} will report.
     *
     * @param time duration to remember; {@code logPhase()} passes the value returned by the timer
     */
    public void setNextTimePhase(double time) {
        lastPhase = time;
    }

    /**
     * Returns the duration most recently stored through {@link #setNextTimePhase(double)}.
     *
     * @return the stored duration, or {@code 0.0} when none has been stored yet
     */
    public double getLastTimePhase() {
        return lastPhase;
    }
}

