/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops;

import org.apache.sysml.api.DMLScript;
import org.apache.sysml.hops.AggUnaryOp;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.HopsException;
import org.apache.sysml.hops.MemoTable;
import org.apache.sysml.hops.MultiThreadedHop;
import org.apache.sysml.hops.OptimizerUtils;
import org.apache.sysml.hops.ReorgOp;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.lops.Aggregate;
import org.apache.sysml.lops.Binary;
import org.apache.sysml.lops.DataPartition;
import org.apache.sysml.lops.Group;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.LopProperties;
import org.apache.sysml.lops.MMCJ;
import org.apache.sysml.lops.MMRJ;
import org.apache.sysml.lops.MMTSJ;
import org.apache.sysml.lops.MMZip;
import org.apache.sysml.lops.MapMult;
import org.apache.sysml.lops.MapMultChain;
import org.apache.sysml.lops.PMMJ;
import org.apache.sysml.lops.PMapMult;
import org.apache.sysml.lops.PartialAggregate;
import org.apache.sysml.lops.Transform;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.controlprogram.ParForProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.mapred.MMCJMRReducerWithAggregator;
import org.apache.sysml.runtime.util.UtilFunctions;

public class AggBinaryOp
extends MultiThreadedHop {
    // Memory multiplier constant for map-side matrix multiply; not referenced in this part of the file.
    public static final double MAPMULT_MEM_MULTIPLIER = 1.0;
    // Global override of the matrix-multiply method.
    // NOTE(review): presumably null means "no override"; its consumers are outside this view -- confirm.
    public static MMultMethod FORCED_MMULT_METHOD = null;
    // Element-wise inner operation (MULT for a plain matrix multiply, see isMatrixMultiply()).
    private Hop.OpOp2 innerOp;
    // Outer aggregation (SUM for a plain matrix multiply, see isMatrixMultiply()).
    private Hop.AggOp outerOp;
    // Physical matrix-multiply method last selected by isGPUEnabled() or constructLops().
    private MMultMethod _method = null;
    // Flag forwarded to the optFindMMultMethod* selectors.
    // NOTE(review): name suggests "left input is a permutation matrix (PM)" -- confirm.
    private boolean _hasLeftPMInput = false;

    /** No-argument constructor; private, so only usable from within this class. */
    private AggBinaryOp() {
        super();
    }

    public AggBinaryOp(String l, Expression.DataType dt, Expression.ValueType vt, Hop.OpOp2 innOp, Hop.AggOp outOp, Hop in1, Hop in2) {
        super(l, dt, vt);
        this.innerOp = innOp;
        this.outerOp = outOp;
        this.getInput().add(0, in1);
        this.getInput().add(1, in2);
        in1.getParent().add(this);
        in2.getParent().add(this);
        this.refreshSizeInformation();
    }

    /** Validates that this hop has exactly two inputs. */
    @Override
    public void checkArity() {
        int arity = _input.size();
        HopsException.check(arity == 2, this, "should have arity 2 but has arity %d", arity);
    }

    /**
     * Sets the left-PM-input flag that is forwarded to matrix-multiply method selection.
     *
     * @param flag new flag value
     */
    public void setHasLeftPMInput(boolean flag) {
        _hasLeftPMInput = flag;
    }

    /** @return the left-PM-input flag used during matrix-multiply method selection */
    public boolean hasLeftPMInput() {
        return _hasLeftPMInput;
    }

    /** @return the last selected matrix-multiply method, or null if none was selected yet */
    public MMultMethod getMMultMethod() {
        return _method;
    }

    /**
     * Reports whether this hop can run on the GPU.
     *
     * <p>Side effect: when the accelerator is enabled, the CP matrix-multiply method is
     * (re)selected and stored in {@code _method}. Only the plain MM method is GPU-capable.
     *
     * @return true iff the accelerator is enabled and the selected method is MM
     * @throws RuntimeException if the selected method is not one of TSMM, MAPMM_CHAIN, PMM, MM
     */
    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR) {
            return false;
        }
        Hop left = getInput().get(0);
        Hop right = getInput().get(1);
        MMTSJ.MMTSJType tsmmType = checkTransposeSelf();
        MapMultChain.ChainType chainType = checkMapMultChain();
        _method = optFindMMultMethodCP(left.getDim1(), left.getDim2(), right.getDim1(), right.getDim2(), tsmmType, chainType, _hasLeftPMInput);
        switch (_method) {
            case MM:
                return true;
            case TSMM:
            case MAPMM_CHAIN:
            case PMM:
                return false;
            default:
                throw new RuntimeException("Unsupported method:" + _method);
        }
    }

    /**
     * Builds (once) the lop for this hop according to the chosen execution type and
     * matrix-multiply method.
     *
     * @return the constructed (or previously cached) lop
     * @throws HopsException if this is not a sum-of-products (matrix multiply) hop, or if the
     *         selected method is not supported by the chosen backend
     */
    @Override
    public Lop constructLops() {
        // Lops are built once; later calls return the cached lop.
        if (getLops() != null) {
            return getLops();
        }
        if (!isMatrixMultiply()) {
            throw new HopsException(printErrorLocation() + "Invalid operation in AggBinary Hop, aggBin(" + innerOp + "," + outerOp + ") while constructing lops.");
        }

        Hop input1 = getInput().get(0);
        Hop input2 = getInput().get(1);
        LopProperties.ExecType et = optFindExecType();
        MMTSJ.MMTSJType mmtsj = checkTransposeSelf();
        MapMultChain.ChainType chain = checkMapMultChain();

        if (et == LopProperties.ExecType.CP || et == LopProperties.ExecType.GPU) {
            // Local (CP or GPU) execution.
            _method = optFindMMultMethodCP(input1.getDim1(), input1.getDim2(), input2.getDim1(), input2.getDim2(), mmtsj, chain, _hasLeftPMInput);
            switch (_method) {
                case TSMM:
                    constructCPLopsTSMM(mmtsj, et);
                    break;
                case MAPMM_CHAIN:
                    constructCPLopsMMChain(chain);
                    break;
                case PMM:
                    constructCPLopsPMM();
                    break;
                case MM:
                    constructCPLopsMM(et);
                    break;
                default:
                    throw new HopsException(printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing CP lops.");
            }
        } else if (et == LopProperties.ExecType.SPARK) {
            // Whether the left input is a transpose is passed on to the Spark method selector.
            boolean tmmRewrite = HopRewriteUtils.isTransposeOperation(input1);
            _method = optFindMMultMethodSpark(input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput, tmmRewrite);
            switch (_method) {
                case TSMM:
                case TSMM2:
                    constructSparkLopsTSMM(mmtsj, _method == MMultMethod.TSMM2);
                    break;
                case MAPMM_L:
                case MAPMM_R:
                    constructSparkLopsMapMM(_method);
                    break;
                case MAPMM_CHAIN:
                    constructSparkLopsMapMMChain(chain);
                    break;
                case PMAPMM:
                    constructSparkLopsPMapMM();
                    break;
                case CPMM:
                    constructSparkLopsCPMM();
                    break;
                case RMM:
                    constructSparkLopsRMM();
                    break;
                case PMM:
                    constructSparkLopsPMM();
                    break;
                case ZIPMM:
                    constructSparkLopsZIPMM();
                    break;
                default:
                    throw new HopsException(printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing SPARK lops.");
            }
        } else if (et == LopProperties.ExecType.MR) {
            _method = optFindMMultMethodMR(input1.getDim1(), input1.getDim2(), input1.getRowsInBlock(), input1.getColsInBlock(), input1.getNnz(), input2.getDim1(), input2.getDim2(), input2.getRowsInBlock(), input2.getColsInBlock(), input2.getNnz(), mmtsj, chain, _hasLeftPMInput);
            switch (_method) {
                case MAPMM_L:
                case MAPMM_R:
                    constructMRLopsMapMM(_method);
                    break;
                case MAPMM_CHAIN:
                    constructMRLopsMapMMChain(chain);
                    break;
                case CPMM:
                    constructMRLopsCPMM();
                    break;
                case RMM:
                    constructMRLopsRMM();
                    break;
                case TSMM:
                    constructMRLopsTSMM(mmtsj);
                    break;
                case PMM:
                    constructMRLopsPMM();
                    break;
                default:
                    throw new HopsException(printErrorLocation() + "Invalid Matrix Mult Method (" + _method + ") while constructing MR lops.");
            }
        }
        // Any other exec type constructs no lop here and falls through unchanged.
        constructAndSetLopsDataFlowProperties();
        return getLops();
    }

    /** @return display string {@code ba(<outer-agg><inner-op>)} built from the op-to-string maps */
    @Override
    public String getOpString() {
        String outer = (String) HopsAgg2String.get(outerOp);
        String inner = (String) HopsOpOp2String.get(innerOp);
        return "ba(" + outer + inner + ")";
    }

    /**
     * Computes the memory estimate via the superclass and then corrects it for left
     * transpose-self multiplies.
     *
     * <p>For a left tsmm only the non-transposed operand feeds the lop (see
     * constructCPLopsTSMM), so the left (transpose) input's output estimate is subtracted when
     * the right input has known dimensions and more than one column.
     */
    @Override
    public void computeMemEstimate(MemoTable memo) {
        super.computeMemEstimate(memo);
        boolean leftTsmm = checkTransposeSelf().isLeft();
        Hop rhs = getInput().get(1);
        if (leftTsmm && rhs.dimsKnown() && rhs.getDim2() > 1) {
            _memEstimate -= getInput().get(0)._outputMemEstimate;
        }
    }

    /**
     * Estimates the output size assuming a fully dense dim1 x dim2 result; the nnz argument is
     * not used.
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
    }

    /**
     * Estimates intermediate memory for this multiply.
     *
     * <p>On the GPU path, a sparse (sparsity &lt; 0.4) left input combined with a non-sparse
     * right input adds a dense dim1 x dim2 estimate. Independently, outputs with at least two
     * columns add a sparse in-memory estimate at sparsity just below 0.4.
     *
     * <p>Note: isGPUEnabled() also (re)selects {@code _method} as a side effect.
     */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        double estimate = 0.0;
        if (isGPUEnabled()) {
            Hop left = getInput().get(0);
            Hop right = getInput().get(1);
            double leftSparsity = OptimizerUtils.getSparsity(left.getDim1(), left.getDim2(), left.getNnz());
            double rightSparsity = OptimizerUtils.getSparsity(right.getDim1(), right.getDim2(), right.getNnz());
            boolean leftSparse = leftSparsity < 0.4;
            boolean rightSparse = rightSparsity < 0.4;
            if (leftSparse && !rightSparse) {
                estimate += OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, 1.0);
            }
        }
        if (dim2 >= 2) {
            estimate += MatrixBlock.estimateSizeSparseInMemory(dim1, dim2, 0.4 - UtilFunctions.DOUBLE_EPS);
        }
        return estimate;
    }

    /**
     * Infers {rows, cols, nnz} of the product from memoized input statistics.
     *
     * @return null unless the left rows and right columns are known
     */
    @Override
    protected long[] inferOutputCharacteristics(MemoTable memo) {
        MatrixCharacteristics[] stats = memo.getAllInputStats(getInput());
        MatrixCharacteristics left = stats[0];
        if (!left.rowsKnown()) {
            return null;
        }
        MatrixCharacteristics right = stats[1];
        if (!right.colsKnown()) {
            return null;
        }
        long rows = left.getRows();
        long cols = right.getCols();
        // Inputs with no positive nnz count are treated as dense.
        double leftSp = left.getNonZeros() > 0 ? OptimizerUtils.getSparsity(left.getRows(), left.getCols(), left.getNonZeros()) : 1.0;
        double rightSp = right.getNonZeros() > 0 ? OptimizerUtils.getSparsity(right.getRows(), right.getCols(), right.getNonZeros()) : 1.0;
        double outSp = OptimizerUtils.getMatMultSparsity(leftSp, rightSp, rows, left.getCols(), cols, true);
        long nnz = (long) ((double) (rows * cols) * outSp);
        return new long[] {rows, cols, nnz};
    }

    /** @return true iff this hop is sum(mult(...)), i.e. a plain matrix multiply */
    public boolean isMatrixMultiply() {
        boolean multInner = innerOp == Hop.OpOp2.MULT;
        return multInner && outerOp == Hop.AggOp.SUM;
    }

    /**
     * Checks whether this hop is a vector outer product: a column vector (n x 1, n &gt; 1)
     * times a row vector (1 x m, m &gt; 1), producing an n x m result.
     *
     * <p>Used by optFindExecType to exclude such products from the vector-vector CP shortcut.
     * The left input's column count must be checked. Testing {@code getDim1()} twice
     * ({@code == 1 && > 1}) can never succeed, and the row-times-column shape is an inner
     * product anyway.
     *
     * @return true iff both inputs are vectors forming an outer product
     */
    private boolean isOuterProduct() {
        Hop left = getInput().get(0);
        Hop right = getInput().get(1);
        if (!left.isVector() || !right.isVector()) {
            return false;
        }
        return left.getDim1() > 1 && left.getDim2() == 1
            && right.getDim1() == 1 && right.getDim2() > 1;
    }

    /** @return always true: this hop places no restriction on its execution type */
    @Override
    public boolean allowsAllExecTypes() {
        final boolean anyExecType = true;
        return anyExecType;
    }

    /**
     * Chooses the execution type for this hop and stores it in {@code _etype}.
     *
     * <p>Order of decisions:
     * <ol>
     *   <li>a forced exec type wins;</li>
     *   <li>otherwise use the memory-based choice, or (without memory estimates) CP for small
     *       inputs or non-outer vector-vector products;</li>
     *   <li>a CP map-mult chain whose X exceeds the local budget goes to the remote backend;</li>
     *   <li>a non-forced CP choice is promoted to SPARK if an input qualifies for transitive
     *       Spark execution.</li>
     * </ol>
     *
     * @return the selected execution type
     */
    @Override
    protected LopProperties.ExecType optFindExecType() {
        checkAndSetForcedPlatform();
        // Distributed backend used whenever CP is rejected.
        LopProperties.ExecType remote = OptimizerUtils.isSparkExecutionMode()
            ? LopProperties.ExecType.SPARK
            : LopProperties.ExecType.MR;

        if (_etypeForced != null) {
            _etype = _etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                _etype = findExecTypeByMemEstimate();
            } else {
                Hop in1 = getInput().get(0);
                Hop in2 = getInput().get(1);
                boolean runLocal = (in1.areDimsBelowThreshold() && in2.areDimsBelowThreshold())
                    || (in1.isVector() && in2.isVector() && !isOuterProduct());
                _etype = runLocal ? LopProperties.ExecType.CP : remote;
            }
            // Map-mult chain in CP: fall back to remote if X's output estimate exceeds the local budget.
            boolean chainExceedsBudget = _etype == LopProperties.ExecType.CP
                && checkMapMultChain() != MapMultChain.ChainType.NONE
                && OptimizerUtils.getLocalMemBudget() < getInput().get(0).getInput().get(0).getOutputMemEstimate();
            if (chainExceedsBudget) {
                _etype = remote;
            }
            checkAndSetInvalidCPDimsAndSize();
        }

        boolean promoteToSpark = _etype == LopProperties.ExecType.CP
            && _etypeForced != LopProperties.ExecType.CP
            && (isApplicableForTransitiveSparkExecType(true) || isApplicableForTransitiveSparkExecType(false));
        if (promoteToSpark) {
            _etype = LopProperties.ExecType.SPARK;
        }
        setRequiresRecompileIfNecessary();
        return _etype;
    }

    /**
     * Decides whether the left (or right) input justifies running this hop in SPARK as well.
     *
     * <p>Conditions are checked in order, stopping at the first failure. Note that
     * {@code optFindExecType()} on the input is only reached if all earlier checks pass.
     *
     * @param left true to inspect input 0, false to inspect input 1
     * @return true iff the input qualifies
     */
    private boolean isApplicableForTransitiveSparkExecType(boolean left) {
        Hop candidate = getInput().get(left ? 0 : 1);
        // DataOp inputs that require a checkpoint are excluded.
        if (candidate instanceof DataOp && ((DataOp) candidate).requiresCheckpoint()) {
            return false;
        }
        // Transposed inputs are only acceptable on the left when the left-transpose rewrite does not apply.
        if (HopRewriteUtils.isTransposeOperation(candidate) && (!left || isLeftTransposeRewriteApplicable(true, false))) {
            return false;
        }
        if (candidate.getParent().size() != 1) {
            return false;
        }
        if (candidate.areDimsBelowThreshold()) {
            return false;
        }
        if (candidate.optFindExecType() != LopProperties.ExecType.SPARK) {
            return false;
        }
        return candidate.getOutputMemEstimate() > getOutputMemEstimate();
    }

    /**
     * Detects a transpose-self multiply.
     *
     * @return RIGHT for {@code X %*% t(X)}, LEFT for {@code t(X) %*% X}, NONE otherwise
     *         (RIGHT takes precedence if both patterns matched)
     */
    public MMTSJ.MMTSJType checkTransposeSelf() {
        Hop lhs = getInput().get(0);
        Hop rhs = getInput().get(1);
        if (HopRewriteUtils.isTransposeOperation(rhs) && rhs.getInput().get(0) == lhs) {
            return MMTSJ.MMTSJType.RIGHT;
        }
        if (HopRewriteUtils.isTransposeOperation(lhs) && lhs.getInput().get(0) == rhs) {
            return MMTSJ.MMTSJType.LEFT;
        }
        return MMTSJ.MMTSJType.NONE;
    }

    /**
     * Detects map-mult chain patterns whose left input is {@code t(X)}.
     *
     * <ul>
     *   <li>XtwXv: {@code t(X) %*% (w * (X %*% v))}</li>
     *   <li>XtXvy: {@code t(X) %*% ((X %*% v) - y)} with y a matrix</li>
     *   <li>XtXv: {@code t(X) %*% (X %*% v)}</li>
     * </ul>
     *
     * @return the detected chain type, or NONE
     */
    public MapMultChain.ChainType checkMapMultChain() {
        Hop lhs = getInput().get(0);
        Hop rhs = getInput().get(1);
        if (!HopRewriteUtils.isTransposeOperation(lhs)) {
            return MapMultChain.ChainType.NONE;
        }
        Hop X = lhs.getInput().get(0);
        boolean rhsIsBinary = rhs instanceof BinaryOp;

        if (rhsIsBinary && ((BinaryOp) rhs).getOp() == Hop.OpOp2.MULT) {
            Hop product = rhs.getInput().get(1);
            boolean matches = product instanceof AggBinaryOp && product.getInput().get(0) == X;
            return matches ? MapMultChain.ChainType.XtwXv : MapMultChain.ChainType.NONE;
        }
        if (rhsIsBinary && ((BinaryOp) rhs).getOp() == Hop.OpOp2.MINUS) {
            Hop product = rhs.getInput().get(0);
            Hop subtrahend = rhs.getInput().get(1);
            boolean matches = product instanceof AggBinaryOp
                && subtrahend.getDataType() == Expression.DataType.MATRIX
                && product.getInput().get(0) == X;
            return matches ? MapMultChain.ChainType.XtXvy : MapMultChain.ChainType.NONE;
        }
        if (rhs instanceof AggBinaryOp && rhs.getInput().get(0) == X) {
            return MapMultChain.ChainType.XtXv;
        }
        return MapMultChain.ChainType.NONE;
    }

    /**
     * Builds a multi-threaded transpose-self multiply lop. Only the non-transposed operand is
     * consumed: input 1 for LEFT, input 0 otherwise.
     */
    private void constructCPLopsTSMM(MMTSJ.MMTSJType mmtsj, LopProperties.ExecType et) {
        int numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        Hop operand = getInput().get(mmtsj.isLeft() ? 1 : 0);
        MMTSJ tsmm = new MMTSJ(operand.constructLops(), getDataType(), getValueType(), et, mmtsj, false, numThreads);
        tsmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(tsmm);
        setLops(tsmm);
    }

    /**
     * Builds a CP map-mult chain lop for a pattern detected by checkMapMultChain().
     * XtXv uses (X, v); XtwXv and XtXvy additionally pass the weight/offset operand.
     */
    private void constructCPLopsMMChain(MapMultChain.ChainType chain) {
        Hop hX = getInput().get(0).getInput().get(0);
        Hop rhs = getInput().get(1);
        MapMultChain mapmmchain;
        if (chain == MapMultChain.ChainType.XtXv) {
            Hop hv = rhs.getInput().get(1);
            mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), LopProperties.ExecType.CP);
        } else {
            // XtwXv: rhs = w * (X %*% v); XtXvy: rhs = (X %*% v) - y
            boolean weighted = chain == MapMultChain.ChainType.XtwXv;
            Hop hw = rhs.getInput().get(weighted ? 0 : 1);
            Hop hv = rhs.getInput().get(weighted ? 1 : 0).getInput().get(1);
            mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), LopProperties.ExecType.CP);
        }
        mapmmchain.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        setOutputDimensions(mapmmchain);
        setLineNumbers(mapmmchain);
        setLops(mapmmchain);
    }

    /**
     * Builds a CP permutation-matrix multiply (PMMJ) lop. A temporary CP value hop over the
     * left input supplies the third PMMJ operand and is detached again at the end.
     */
    private void constructCPLopsPMM() {
        Hop permutation = getInput().get(0);
        Hop rhs = getInput().get(1);
        Hop targetRows = HopRewriteUtils.createValueHop(permutation, true);
        targetRows.setOutputBlocksizes(0, 0);
        targetRows.setForcedExecType(LopProperties.ExecType.CP);
        HopRewriteUtils.copyLineNumbers(this, targetRows);
        Lop targetRowsLop = targetRows.constructLops();
        PMMJ pmm = new PMMJ(permutation.constructLops(), rhs.constructLops(), targetRowsLop, getDataType(), getValueType(), false, false, LopProperties.ExecType.CP);
        pmm.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
        pmm.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(pmm);
        setLops(pmm);
        HopRewriteUtils.removeChildReference(permutation, targetRows);
    }

    /**
     * Builds a plain matrix-multiply lop for CP or GPU execution. On the GPU both
     * transpose flags are passed as false. On CP the left-transpose rewrite is used when
     * applicable, otherwise a multi-threaded MATMULT.
     */
    private void constructCPLopsMM(LopProperties.ExecType et) {
        Lop matmult;
        if (et == LopProperties.ExecType.GPU) {
            Lop left = getInput().get(0).constructLops();
            Lop right = getInput().get(1).constructLops();
            matmult = new Binary(left, right, Binary.OperationTypes.MATMULT, getDataType(), getValueType(), et, false, false);
        } else if (isLeftTransposeRewriteApplicable(true, false)) {
            matmult = constructCPLopsMMWithLeftTransposeRewrite();
        } else {
            int numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
            matmult = new Binary(getInput().get(0).constructLops(), getInput().get(1).constructLops(), Binary.OperationTypes.MATMULT, getDataType(), getValueType(), et, numThreads);
        }
        setOutputDimensions(matmult);
        setLineNumbers(matmult);
        setLops(matmult);
    }

    /**
     * Computes {@code t(X) %*% Y} as {@code t(t(Y) %*% X)} in CP. If Y's lop is itself a
     * transpose, its input is reused as t(Y) instead of adding another transpose.
     *
     * @return the outer transpose lop (its dimensions are set by the caller)
     */
    private Lop constructCPLopsMMWithLeftTransposeRewrite() {
        Hop X = getInput().get(0).getInput().get(0);
        Hop Y = getInput().get(1);
        int k = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
        Lop lY = Y.constructLops();
        Lop tY;
        if (lY instanceof Transform && ((Transform) lY).getOperationType() == Transform.OperationTypes.Transpose) {
            tY = lY.getInputs().get(0);
        } else {
            tY = new Transform(lY, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP, k);
        }
        tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getRowsInBlock(), getColsInBlock(), Y.getNnz());
        setLineNumbers(tY);
        Binary product = new Binary(tY, X.constructLops(), Binary.OperationTypes.MATMULT, getDataType(), getValueType(), LopProperties.ExecType.CP, k);
        product.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(product);
        return new Transform((Lop) product, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP, k);
    }

    /**
     * Builds a Spark transpose-self multiply lop over the non-transposed operand.
     *
     * @param multiPass passed through to the MMTSJ lop (true for the TSMM2 method)
     */
    private void constructSparkLopsTSMM(MMTSJ.MMTSJType mmtsj, boolean multiPass) {
        int operandIndex = mmtsj.isLeft() ? 1 : 0;
        Lop operand = getInput().get(operandIndex).constructLops();
        MMTSJ tsmm = new MMTSJ(operand, getDataType(), getValueType(), LopProperties.ExecType.SPARK, mmtsj, multiPass);
        setOutputDimensions(tsmm);
        setLineNumbers(tsmm);
        setLops(tsmm);
    }

    /**
     * Builds a Spark map-side matrix multiply (MAPMM_L or MAPMM_R), using the left-transpose
     * rewrite when applicable.
     */
    private void constructSparkLopsMapMM(MMultMethod method) {
        final Lop mapmult;
        if (isLeftTransposeRewriteApplicable(false, false)) {
            mapmult = constructSparkLopsMapMMWithLeftTransposeRewrite();
        } else {
            SparkAggType aggtype = getSparkMMAggregationType(requiresAggregation(method));
            _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
            boolean isRightVariant = method == MMultMethod.MAPMM_R;
            mapmult = new MapMult(getInput().get(0).constructLops(), getInput().get(1).constructLops(), getDataType(), getValueType(), isRightVariant, false, _outputEmptyBlocks, aggtype);
        }
        setOutputDimensions(mapmult);
        setLineNumbers(mapmult);
        setLops(mapmult);
    }

    /**
     * Computes {@code t(X) %*% Y} as {@code t(t(Y) %*% X)}: CP transpose of Y, Spark map-mult,
     * then CP transpose of the result.
     *
     * @return the outer transpose lop (its dimensions are set by the caller)
     */
    private Lop constructSparkLopsMapMMWithLeftTransposeRewrite() {
        Hop X = getInput().get(0).getInput().get(0);
        Hop Y = getInput().get(1);
        Transform transposedY = new Transform(Y.constructLops(), Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        transposedY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getRowsInBlock(), getColsInBlock(), Y.getNnz());
        setLineNumbers(transposedY);
        SparkAggType aggtype = getSparkMMAggregationType(requiresAggregation(MMultMethod.MAPMM_R));
        _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        MapMult product = new MapMult(transposedY, X.constructLops(), getDataType(), getValueType(), false, false, _outputEmptyBlocks, aggtype);
        product.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(product);
        return new Transform(product, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
    }

    /**
     * Builds a Spark map-mult chain lop for a pattern detected by checkMapMultChain().
     * XtXv uses (X, v); XtwXv and XtXvy additionally pass the weight/offset operand.
     */
    private void constructSparkLopsMapMMChain(MapMultChain.ChainType chain) {
        Hop hX = getInput().get(0).getInput().get(0);
        Hop rhs = getInput().get(1);
        MapMultChain mapmmchain;
        if (chain == MapMultChain.ChainType.XtXv) {
            Hop hv = rhs.getInput().get(1);
            mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), getDataType(), getValueType(), LopProperties.ExecType.SPARK);
        } else {
            // XtwXv: rhs = w * (X %*% v); XtXvy: rhs = (X %*% v) - y
            boolean weighted = chain == MapMultChain.ChainType.XtwXv;
            Hop hw = rhs.getInput().get(weighted ? 0 : 1);
            Hop hv = rhs.getInput().get(weighted ? 1 : 0).getInput().get(1);
            mapmmchain = new MapMultChain(hX.constructLops(), hv.constructLops(), hw.constructLops(), chain, getDataType(), getValueType(), LopProperties.ExecType.SPARK);
        }
        setOutputDimensions(mapmmchain);
        setLineNumbers(mapmmchain);
        setLops(mapmmchain);
    }

    /** Builds a Spark PMapMult lop over the two inputs. */
    private void constructSparkLopsPMapMM() {
        Lop left = getInput().get(0).constructLops();
        Lop right = getInput().get(1).constructLops();
        PMapMult pmapmult = new PMapMult(left, right, getDataType(), getValueType());
        setOutputDimensions(pmapmult);
        setLineNumbers(pmapmult);
        setLops(pmapmult);
    }

    /**
     * Builds a Spark CPMM lop (MMCJ), delegating to the left-transpose rewrite when applicable.
     * The aggregation type is always requested with aggregation enabled.
     */
    private void constructSparkLopsCPMM() {
        if (isLeftTransposeRewriteApplicable(false, false)) {
            setLops(constructSparkLopsCPMMWithLeftTransposeRewrite());
            return;
        }
        SparkAggType aggtype = getSparkMMAggregationType(true);
        _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        Lop left = getInput().get(0).constructLops();
        Lop right = getInput().get(1).constructLops();
        MMCJ cpmm = new MMCJ(left, right, getDataType(), getValueType(), _outputEmptyBlocks, aggtype, LopProperties.ExecType.SPARK);
        setOutputDimensions(cpmm);
        setLineNumbers(cpmm);
        setLops(cpmm);
    }

    /**
     * Computes {@code t(X) %*% Y} as {@code t(t(Y) %*% X)}: CP transpose of Y, Spark CPMM
     * (MMCJ), then CP transpose of the result.
     *
     * <p>The intermediate MMCJ produces {@code t(Y) %*% X}, i.e. ncol(Y) x ncol(X). That is the
     * transpose of this hop's shape, matching the CP and MapMM rewrite variants. Only the final
     * transpose carries this hop's (ncol(X) x ncol(Y)) dimensions.
     *
     * @return the final transpose lop, with dimensions and line numbers set
     */
    private Lop constructSparkLopsCPMMWithLeftTransposeRewrite() {
        SparkAggType aggtype = getSparkMMAggregationType(true);
        Hop X = getInput().get(0).getInput().get(0);
        Hop Y = getInput().get(1);

        // t(Y) in CP
        Transform tY = new Transform(Y.constructLops(), Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        tY.getOutputParameters().setDimensions(Y.getDim2(), Y.getDim1(), getRowsInBlock(), getColsInBlock(), Y.getNnz());
        setLineNumbers(tY);

        // t(Y) %*% X in Spark: ncol(Y) x ncol(X)
        _outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        MMCJ mmcj = new MMCJ(tY, X.constructLops(), getDataType(), getValueType(), _outputEmptyBlocks, aggtype, LopProperties.ExecType.SPARK);
        mmcj.getOutputParameters().setDimensions(Y.getDim2(), X.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(mmcj);

        // t(t(Y) %*% X) in CP: this hop's output shape ncol(X) x ncol(Y)
        Transform out = new Transform(mmcj, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        out.getOutputParameters().setDimensions(X.getDim2(), Y.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        // The caller sets this lop directly, so line numbers must be attached here.
        setLineNumbers(out);
        return out;
    }

    /** Builds a Spark replication-based matrix multiply lop (MMRJ) over the two inputs. */
    private void constructSparkLopsRMM() {
        Lop left = getInput().get(0).constructLops();
        Lop right = getInput().get(1).constructLops();
        MMRJ rmm = new MMRJ(left, right, getDataType(), getValueType(), LopProperties.ExecType.SPARK);
        setOutputDimensions(rmm);
        setLineNumbers(rmm);
        setLops(rmm);
    }

    /**
     * Builds the Spark permutation-matrix multiply (PMMJ) lop for this hop.
     *
     * <p>Two modes, depending on the left input:
     * (a) ncol != 1: the input is condensed into a vector as
     *     rowMaxIndex(t(pm)) * rowMax(t(pm)), with all temporary hops forced to SPARK, and a
     *     CP value hop over pmInput provides the third PMMJ operand;
     * (b) ncol == 1: the input is used as-is and max(pmInput) provides the third operand.
     * The temporary hops are detached from pmInput via HopRewriteUtils.removeChildReference.
     */
    private void constructSparkLopsPMM() {
        LopProperties.ExecType etVect;
        Hop pmInput = this.getInput().get(0);
        Hop rightInput = this.getInput().get(1);
        Lop lpmInput = pmInput.constructLops();
        Hop nrow = null;
        // Size of a dim1 x 1 vector decides where max() runs in mode (b).
        // NOTE(review): the fallback is MR even on this Spark path; `execType` is unused -- confirm intended.
        double mestPM = OptimizerUtils.estimateSize(pmInput.getDim1(), 1L);
        LopProperties.ExecType execType = etVect = mestPM > OptimizerUtils.getLocalMemBudget() ? LopProperties.ExecType.MR : LopProperties.ExecType.CP;
        if (pmInput.getDim2() != 1L) {
            // (a) full matrix input: v = rowMaxIndex(t(pm)) * rowMax(t(pm))
            ReorgOp transpose = HopRewriteUtils.createTranspose(pmInput);
            transpose.setForcedExecType(LopProperties.ExecType.SPARK);
            AggUnaryOp agg1 = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAXINDEX, Hop.Direction.Row);
            agg1.setForcedExecType(LopProperties.ExecType.SPARK);
            AggUnaryOp agg2 = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAX, Hop.Direction.Row);
            agg2.setForcedExecType(LopProperties.ExecType.SPARK);
            BinaryOp mult = HopRewriteUtils.createBinary(agg1, agg2, Hop.OpOp2.MULT);
            mult.setForcedExecType(LopProperties.ExecType.SPARK);
            // Value hop over pmInput (presumably nrow(pm) -- confirm against createValueHop), forced to CP.
            nrow = HopRewriteUtils.createValueHop(pmInput, true);
            nrow.setOutputBlocksizes(0, 0);
            nrow.setForcedExecType(LopProperties.ExecType.CP);
            HopRewriteUtils.copyLineNumbers(this, nrow);
            lpmInput = mult.constructLops();
            // Remove the reference between pmInput and the temporary transpose hop.
            HopRewriteUtils.removeChildReference(pmInput, transpose);
        } else {
            // (b) vector input: third operand is max(pmInput) over all cells.
            nrow = HopRewriteUtils.createAggUnaryOp(pmInput, Hop.AggOp.MAX, Hop.Direction.RowCol);
            nrow.setOutputBlocksizes(0, 0);
            nrow.setForcedExecType(etVect);
            HopRewriteUtils.copyLineNumbers(this, nrow);
        }
        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        PMMJ pmm = new PMMJ(lpmInput, rightInput.constructLops(), nrow.constructLops(), this.getDataType(), this.getValueType(), false, this._outputEmptyBlocks, LopProperties.ExecType.SPARK);
        this.setOutputDimensions(pmm);
        this.setLineNumbers(pmm);
        this.setLops(pmm);
        // Remove the reference between pmInput and the temporary nrow hop.
        HopRewriteUtils.removeChildReference(pmInput, nrow);
    }

    /**
     * Builds a Spark zip-based multiply (MMZip) for {@code t(left) %*% right}. The tRewrite flag
     * is set when left has at least as many cells (dim1 * dim2) as right.
     */
    private void constructSparkLopsZIPMM() {
        Hop left = getInput().get(0).getInput().get(0);
        Hop right = getInput().get(1);
        long leftCells = left.getDim1() * left.getDim2();
        long rightCells = right.getDim1() * right.getDim2();
        boolean tRewrite = leftCells >= rightCells;
        MMZip zipmm = new MMZip(left.constructLops(), right.constructLops(), getDataType(), getValueType(), tRewrite, LopProperties.ExecType.SPARK);
        setOutputDimensions(zipmm);
        setLineNumbers(zipmm);
        setLops(zipmm);
    }

    /**
     * Constructs the MR lops for a map-side (MAPMM_L / MAPMM_R) matrix multiplication.
     *
     * <p>For MAPMM_R, the left-transpose rewrite plan is used when applicable. Otherwise a MapMult
     * lop is built over both inputs. If partitioning is required, the left input (MAPMM_L) is
     * wrapped column-blockwise, or the right input (MAPMM_R) row-blockwise. If aggregation is
     * required, a Group + Aggregate pair is appended.
     *
     * @param method the selected map-side method (MAPMM_L or MAPMM_R)
     */
    private void constructMRLopsMapMM(MMultMethod method) {
        if (method == MMultMethod.MAPMM_R && this.isLeftTransposeRewriteApplicable(false, true)) {
            this.setLops(this.constructMRLopsMapMMWithLeftTransposeRewrite());
            return;
        }
        boolean needAgg = this.requiresAggregation(method);
        boolean needPart = this.requiresPartitioning(method, false);
        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        Lop leftInput = this.getInput().get(0).constructLops();
        Lop rightInput = this.getInput().get(1).constructLops();
        if (needPart) {
            if (method == MMultMethod.MAPMM_L) {
                leftInput = this.constructMRLopsMapMMPartitionedInput(this.getInput().get(0), ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N);
            } else {
                rightInput = this.constructMRLopsMapMMPartitionedInput(this.getInput().get(1), ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N);
            }
        }
        MapMult mapmult = new MapMult(leftInput, rightInput, this.getDataType(), this.getValueType(), method == MMultMethod.MAPMM_R, needPart, this._outputEmptyBlocks);
        mapmult.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        this.setLineNumbers(mapmult);
        if (!needAgg) {
            this.setLops(mapmult);
            return;
        }
        Group grp = new Group(mapmult, Group.OperationTypes.Sort, this.getDataType(), this.getValueType());
        Aggregate agg1 = new Aggregate(grp, (Aggregate.OperationTypes)((Object)HopsAgg2Lops.get((Object)this.outerOp)), this.getDataType(), this.getValueType(), LopProperties.ExecType.MR);
        grp.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        agg1.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        // attach source line numbers to the Group lop as well, as the CPMM constructors do
        this.setLineNumbers(grp);
        this.setLineNumbers(agg1);
        agg1.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        this.setLops(agg1);
    }

    /**
     * Wraps the lops of {@code input} in a DataPartition lop with the given partition format.
     * Partitioning runs in CP when the input's exact-sparsity size estimate is below the local
     * memory budget, otherwise in MR.
     *
     * @param input the hop whose output is partitioned
     * @param format the partition format (column- or row-blockwise)
     * @return the DataPartition lop, with the input's dimensions and nnz
     */
    private Lop constructMRLopsMapMMPartitionedInput(Hop input, ParForProgramBlock.PDataPartitionFormat format) {
        double sparsity = OptimizerUtils.getSparsity(input.getDim1(), input.getDim2(), input.getNnz());
        LopProperties.ExecType etPart = (double)OptimizerUtils.estimateSizeExactSparsity(input.getDim1(), input.getDim2(), sparsity) < OptimizerUtils.getLocalMemBudget() ? LopProperties.ExecType.CP : LopProperties.ExecType.MR;
        Lop partitioned = new DataPartition(input.constructLops(), Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, etPart, format);
        partitioned.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), input.getNnz());
        this.setLineNumbers(partitioned);
        return partitioned;
    }

    /**
     * Builds the MR map-side plan for a left-transposed product as t(t(Y) %*% X). X is the input of
     * this operator's first (transpose) input and Y is the second input. Y is transposed in CP. The
     * transposed Y is optionally partitioned column-blockwise and multiplied with X via MapMult. The
     * product is optionally grouped and aggregated, then transposed back in CP.
     *
     * @return the final CP transpose lop (the caller registers it via setLops)
     */
    private Lop constructMRLopsMapMMWithLeftTransposeRewrite() {
        final Hop hopX = getInput().get(0).getInput().get(0);
        final Hop hopY = getInput().get(1);

        // t(Y), computed in CP
        final Transform transY = new Transform(hopY.constructLops(), Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        transY.getOutputParameters().setDimensions(hopY.getDim2(), hopY.getDim1(), getRowsInBlock(), getColsInBlock(), hopY.getNnz());
        setLineNumbers(transY);

        // aggregation can be skipped only if X has a known, positive row count within one row block
        final boolean aggregate = !(hopX.getDim1() > 0L && hopX.getDim1() <= (long) hopX.getRowsInBlock());
        final boolean partition = requiresPartitioning(MMultMethod.MAPMM_R, true);

        Lop cachedInput = transY;
        if (partition) {
            final double transYSparsity = OptimizerUtils.getSparsity(hopY.getDim2(), hopY.getDim1(), hopY.getNnz());
            final boolean fitsLocally = (double) OptimizerUtils.estimateSizeExactSparsity(hopY.getDim2(), hopY.getDim1(), transYSparsity) < OptimizerUtils.getLocalMemBudget();
            final LopProperties.ExecType partitionExec = fitsLocally ? LopProperties.ExecType.CP : LopProperties.ExecType.MR;
            cachedInput = new DataPartition(transY, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, partitionExec, ParForProgramBlock.PDataPartitionFormat.COLUMN_BLOCK_WISE_N);
            cachedInput.getOutputParameters().setDimensions(hopY.getDim2(), hopY.getDim1(), getRowsInBlock(), getColsInBlock(), hopY.getNnz());
            setLineNumbers(cachedInput);
        }

        final MapMult product = new MapMult(cachedInput, hopX.constructLops(), getDataType(), getValueType(), false, partition, false);
        product.getOutputParameters().setDimensions(hopY.getDim2(), hopX.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(product);

        Lop beforeTranspose = product;
        if (aggregate) {
            final Group grouped = new Group(product, Group.OperationTypes.Sort, getDataType(), getValueType());
            grouped.getOutputParameters().setDimensions(hopY.getDim2(), hopX.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
            setLineNumbers(grouped);
            final Aggregate aggregated = new Aggregate(grouped, (Aggregate.OperationTypes) HopsAgg2Lops.get(outerOp), getDataType(), getValueType(), LopProperties.ExecType.MR);
            aggregated.getOutputParameters().setDimensions(hopY.getDim2(), hopX.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
            setLineNumbers(aggregated);
            aggregated.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
            beforeTranspose = aggregated;
        }

        // transpose the (n x m) product back to the requested (m x n) result, in CP
        final Transform result = new Transform(beforeTranspose, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        result.getOutputParameters().setDimensions(hopX.getDim2(), hopY.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        return result;
    }

    /**
     * Constructs the MR lops for a MAPMM_CHAIN multiplication: a MapMultChain lop followed by a
     * Group and an Aggregate.
     *
     * <p>For XtXv, X is input 0 of input 0 and v is input 1 of input 1. For any other chain type,
     * w and the hop containing v are both found under input 1. Their positions are swapped between
     * XtwXv and the other types. w is wrapped in a row-blockwise DataPartition lop unless its
     * dimensions are known and it has at most 4,000,000 cells.
     *
     * @param chainType the chain pattern being compiled
     */
    private void constructMRLopsMapMMChain(MapMultChain.ChainType chainType) {
        MapMultChain mapmult = null;
        if (chainType == MapMultChain.ChainType.XtXv) {
            Hop hX = this.getInput().get(0).getInput().get(0);
            Hop hv = this.getInput().get(1).getInput().get(1);
            mapmult = new MapMultChain(hX.constructLops(), hv.constructLops(), this.getDataType(), this.getValueType(), LopProperties.ExecType.MR);
            mapmult.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
            this.setLineNumbers(mapmult);
        } else {
            // child positions of w and of v's parent under input 1 depend on the chain type
            int wix = chainType == MapMultChain.ChainType.XtwXv ? 0 : 1;
            int vix = chainType == MapMultChain.ChainType.XtwXv ? 1 : 0;
            Hop hX = this.getInput().get(0).getInput().get(0);
            Hop hw = this.getInput().get(1).getInput().get(wix);
            Hop hv = this.getInput().get(1).getInput().get(vix).getInput().get(1);
            double mestW = OptimizerUtils.estimateSize(hw.getDim1(), hw.getDim2());
            // worst case: partition w when its size is unknown or it exceeds 4,000,000 cells
            boolean needPart = !hw.dimsKnown() || hw.getDim1() * hw.getDim2() > 4000000L;
            Lop X = hX.constructLops();
            Lop v = hv.constructLops();
            Lop w = null;
            if (needPart) {
                // partition in MR only if w's estimate exceeds the local memory budget
                w = new DataPartition(hw.constructLops(), Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, mestW > OptimizerUtils.getLocalMemBudget() ? LopProperties.ExecType.MR : LopProperties.ExecType.CP, ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N);
                w.getOutputParameters().setDimensions(hw.getDim1(), hw.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), hw.getNnz());
                this.setLineNumbers(w);
            } else {
                w = hw.constructLops();
            }
            mapmult = new MapMultChain(X, v, w, chainType, this.getDataType(), this.getValueType(), LopProperties.ExecType.MR);
            mapmult.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
            this.setLineNumbers(mapmult);
        }
        Group grp = new Group(mapmult, Group.OperationTypes.Sort, this.getDataType(), this.getValueType());
        grp.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        // attach source line numbers to the Group lop as well, as the CPMM constructors do
        this.setLineNumbers(grp);
        Aggregate agg1 = new Aggregate(grp, (Aggregate.OperationTypes)((Object)HopsAgg2Lops.get((Object)this.outerOp)), this.getDataType(), this.getValueType(), LopProperties.ExecType.MR);
        agg1.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        agg1.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        this.setLineNumbers(agg1);
        this.setLops(agg1);
    }

    /**
     * Constructs the MR lops for a CPMM multiplication: MMCJ, then Group, then Aggregate, all
     * carrying this hop's output dimensions. If the left-transpose rewrite is applicable, the plan
     * from constructMRLopsCPMMWithLeftTransposeRewrite() is registered instead.
     */
    private void constructMRLopsCPMM() {
        if (isLeftTransposeRewriteApplicable(false, false)) {
            setLops(constructMRLopsCPMMWithLeftTransposeRewrite());
            return;
        }
        final Hop lhs = getInput().get(0);
        final Hop rhs = getInput().get(1);
        final MMCJ.MMCJType aggregationType = getMMCJAggregationType(lhs, rhs);

        final MMCJ cross = new MMCJ(lhs.constructLops(), rhs.constructLops(), getDataType(), getValueType(), aggregationType, LopProperties.ExecType.MR);
        setOutputDimensions(cross);
        setLineNumbers(cross);

        final Group grouped = new Group(cross, Group.OperationTypes.Sort, getDataType(), getValueType());
        setOutputDimensions(grouped);
        setLineNumbers(grouped);

        final Aggregate aggregated = new Aggregate(grouped, (Aggregate.OperationTypes) HopsAgg2Lops.get(outerOp), getDataType(), getValueType(), LopProperties.ExecType.MR);
        setOutputDimensions(aggregated);
        setLineNumbers(aggregated);
        aggregated.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        setLops(aggregated);
    }

    /**
     * Builds the MR CPMM plan for a left-transposed product as t(t(Y) %*% X). X is the input of
     * this operator's first (transpose) input and Y is the second input. The plan is a CP transpose
     * of Y, then MMCJ + Group + Aggregate in MR, then a CP transpose of the aggregated result.
     *
     * @return the final CP transpose lop (the caller registers it via setLops)
     */
    private Lop constructMRLopsCPMMWithLeftTransposeRewrite() {
        final Hop hopX = getInput().get(0).getInput().get(0);
        final Hop hopY = getInput().get(1);

        final Transform transY = new Transform(hopY.constructLops(), Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        transY.getOutputParameters().setDimensions(hopY.getDim2(), hopY.getDim1(), getRowsInBlock(), getColsInBlock(), hopY.getNnz());
        setLineNumbers(transY);

        // the aggregation type is derived from the original X and Y hops, not from t(Y)
        final MMCJ.MMCJType aggregationType = getMMCJAggregationType(hopX, hopY);
        // NOTE(review): the MMCJ/Group/Aggregate lops below receive this hop's (m x n) dims via
        // setOutputDimensions although they compute t(Y) %*% X (n x m); the MapMM variant of
        // this rewrite sets transposed dims explicitly - confirm which is intended
        final MMCJ cross = new MMCJ(transY, hopX.constructLops(), getDataType(), getValueType(), aggregationType, LopProperties.ExecType.MR);
        setOutputDimensions(cross);
        setLineNumbers(cross);

        final Group grouped = new Group(cross, Group.OperationTypes.Sort, getDataType(), getValueType());
        setOutputDimensions(grouped);
        setLineNumbers(grouped);

        final Aggregate aggregated = new Aggregate(grouped, (Aggregate.OperationTypes) HopsAgg2Lops.get(outerOp), getDataType(), getValueType(), LopProperties.ExecType.MR);
        setOutputDimensions(aggregated);
        setLineNumbers(aggregated);
        aggregated.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);

        final Transform result = new Transform(aggregated, Transform.OperationTypes.Transpose, getDataType(), getValueType(), LopProperties.ExecType.CP);
        result.getOutputParameters().setDimensions(hopX.getDim2(), hopY.getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        return result;
    }

    /** Constructs the single MR lop (MMRJ) used for an RMM multiplication of the two inputs. */
    private void constructMRLopsRMM() {
        final Lop lhs = getInput().get(0).constructLops();
        final Lop rhs = getInput().get(1).constructLops();
        final MMRJ joined = new MMRJ(lhs, rhs, getDataType(), getValueType(), LopProperties.ExecType.MR);
        setOutputDimensions(joined);
        setLineNumbers(joined);
        setLops(joined);
    }

    /**
     * Constructs the MR lops for a TSMM multiplication: an MMTSJ lop followed by an Aggregate, both
     * carrying this hop's output dimensions.
     *
     * @param mmtsj the TSMM variant; input 1 is used when {@code mmtsj.isLeft()}, input 0 otherwise
     */
    private void constructMRLopsTSMM(MMTSJ.MMTSJType mmtsj) {
        final Hop source = mmtsj.isLeft() ? getInput().get(1) : getInput().get(0);

        final MMTSJ selfProduct = new MMTSJ(source.constructLops(), getDataType(), getValueType(), LopProperties.ExecType.MR, mmtsj);
        selfProduct.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        setLineNumbers(selfProduct);

        final Aggregate aggregated = new Aggregate(selfProduct, (Aggregate.OperationTypes) HopsAgg2Lops.get(outerOp), getDataType(), getValueType(), LopProperties.ExecType.MR);
        aggregated.getOutputParameters().setDimensions(getDim1(), getDim2(), getRowsInBlock(), getColsInBlock(), getNnz());
        aggregated.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        setLineNumbers(aggregated);
        setLops(aggregated);
    }

    /**
     * Constructs the MR lops for a PMM multiplication: a PMMJ lop followed by an Aggregate.
     *
     * <p>Input 0 is the PM input and input 1 the right input. If the PM input does not have exactly
     * one column, it is replaced by an MR-forced expression. That expression multiplies a MAXINDEX
     * row aggregate by a MAX row aggregate, both taken over the PM input's transpose. The
     * {@code nrow} input then comes from {@code HopRewriteUtils.createValueHop(pmInput, true)},
     * forced to CP. Otherwise {@code nrow} is a MAX aggregate over the whole PM input. The
     * resulting vector lop is partitioned row-blockwise when the PM input's dims are unknown or it
     * has more than 4,000,000 rows. The temporary hops are detached from the PM input at the end.
     */
    private void constructMRLopsPMM() {
        Hop pmInput = this.getInput().get(0);
        Hop rightInput = this.getInput().get(1);
        Lop lpmInput = pmInput.constructLops();
        Hop nrow = null;
        // exec type for the PM vector (rows x 1): MR if it exceeds the local memory budget
        double mestPM = OptimizerUtils.estimateSize(pmInput.getDim1(), 1L);
        LopProperties.ExecType etVect = mestPM > OptimizerUtils.getLocalMemBudget() ? LopProperties.ExecType.MR : LopProperties.ExecType.CP;
        if (pmInput.getDim2() != 1L) {
            ReorgOp transpose = HopRewriteUtils.createTranspose(pmInput);
            transpose.setForcedExecType(LopProperties.ExecType.MR);
            AggUnaryOp agg1 = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAXINDEX, Hop.Direction.Row);
            agg1.setForcedExecType(LopProperties.ExecType.MR);
            AggUnaryOp agg2 = HopRewriteUtils.createAggUnaryOp(transpose, Hop.AggOp.MAX, Hop.Direction.Row);
            agg2.setForcedExecType(LopProperties.ExecType.MR);
            BinaryOp mult = HopRewriteUtils.createBinary(agg1, agg2, Hop.OpOp2.MULT);
            mult.setForcedExecType(LopProperties.ExecType.MR);
            nrow = HopRewriteUtils.createValueHop(pmInput, true);
            nrow.setOutputBlocksizes(0, 0);
            nrow.setForcedExecType(LopProperties.ExecType.CP);
            HopRewriteUtils.copyLineNumbers(this, nrow);
            lpmInput = mult.constructLops();
            // detach the temporary transpose from the PM input's parent list
            HopRewriteUtils.removeChildReference(pmInput, transpose);
        } else {
            nrow = HopRewriteUtils.createAggUnaryOp(pmInput, Hop.AggOp.MAX, Hop.Direction.RowCol);
            nrow.setOutputBlocksizes(0, 0);
            nrow.setForcedExecType(etVect);
            HopRewriteUtils.copyLineNumbers(this, nrow);
        }
        // worst case: partition the PM vector when its size is unknown or it exceeds 4,000,000 rows
        boolean needPart = !pmInput.dimsKnown() || pmInput.getDim1() > 4000000L;
        if (needPart) {
            lpmInput = new DataPartition(lpmInput, Expression.DataType.MATRIX, Expression.ValueType.DOUBLE, etVect, ParForProgramBlock.PDataPartitionFormat.ROW_BLOCK_WISE_N);
            lpmInput.getOutputParameters().setDimensions(pmInput.getDim1(), 1L, this.getRowsInBlock(), this.getColsInBlock(), pmInput.getDim1());
            this.setLineNumbers(lpmInput);
        }
        this._outputEmptyBlocks = !OptimizerUtils.allowsToFilterEmptyBlockOutputs(this);
        PMMJ pmm = new PMMJ(lpmInput, rightInput.constructLops(), nrow.constructLops(), this.getDataType(), this.getValueType(), needPart, this._outputEmptyBlocks, LopProperties.ExecType.MR);
        pmm.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        this.setLineNumbers(pmm);
        Aggregate aggregate = new Aggregate(pmm, (Aggregate.OperationTypes)((Object)HopsAgg2Lops.get((Object)this.outerOp)), this.getDataType(), this.getValueType(), LopProperties.ExecType.MR);
        aggregate.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getRowsInBlock(), this.getColsInBlock(), this.getNnz());
        aggregate.setupCorrectionLocation(PartialAggregate.CorrectionLocationType.NONE);
        this.setLineNumbers(aggregate);
        this.setLops(aggregate);
        // detach the temporary nrow hop from the PM input's parent list
        HopRewriteUtils.removeChildReference(pmInput, nrow);
    }

    /**
     * Decides whether a product whose first input is a transpose (dims m x cd times cd x n)
     * should be compiled with the left-transpose rewrite.
     *
     * <p>Never applies under a forced HADOOP or SPARK runtime platform, or when the first input is
     * not a transpose. Otherwise it requires three things. First, known positive m, cd and n.
     * Second, m*cd greater than cd*n + m*n. Third, twice the dense cd x n and m x n estimates each
     * below the local memory budget.
     *
     * @param CP true for CP plans. The estimated memory of the rewritten operation (input of the
     *           transpose, dense n x cd, dense n x m) must also fit the local budget. On success
     *           that estimate is stored in _memEstimate.
     * @param checkMemMR for non-CP plans, additionally require the dense cd x n estimate to fit the
     *           remote map memory budget
     * @return true if the rewrite should be applied
     */
    private boolean isLeftTransposeRewriteApplicable(boolean CP, boolean checkMemMR) {
        final boolean forcedDistributed = DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.HADOOP
                || DMLScript.rtplatform == DMLScript.RUNTIME_PLATFORM.SPARK;
        if (forcedDistributed) {
            return false;
        }
        final Hop lhs = getInput().get(0);
        final Hop rhs = getInput().get(1);
        if (!HopRewriteUtils.isTransposeOperation(lhs)) {
            return false;
        }
        final long m = lhs.getDim1();
        final long cd = lhs.getDim2();
        final long n = rhs.getDim2();

        if (!CP) {
            return m > 0L && cd > 0L && n > 0L
                    && m * cd > cd * n + m * n
                    && (double) (2L * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0)) < OptimizerUtils.getLocalMemBudget()
                    && (double) (2L * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0)) < OptimizerUtils.getLocalMemBudget()
                    && (!checkMemMR || (double) OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0) < OptimizerUtils.getRemoteMemBudgetMap(true));
        }

        // CP path: every sub-check is evaluated, then combined with non-short-circuit '&'
        final boolean dimsPositive = m > 0L && cd > 0L && n > 0L;
        final double memX = lhs.getInput().get(0).getOutputMemEstimate();
        final double memTransY = OptimizerUtils.estimateSizeExactSparsity(n, cd, 1.0);
        final double memOut = OptimizerUtils.estimateSizeExactSparsity(n, m, 1.0);
        final double rewrittenMem = memTransY + memX + memOut;
        final boolean rewriteFits = rewrittenMem < OptimizerUtils.getLocalMemBudget();
        final boolean worthwhile = m * cd > cd * n + m * n
                && (double) (2L * OptimizerUtils.estimateSizeExactSparsity(cd, n, 1.0)) < OptimizerUtils.getLocalMemBudget()
                && (double) (2L * OptimizerUtils.estimateSizeExactSparsity(m, n, 1.0)) < OptimizerUtils.getLocalMemBudget();
        final boolean applicable = dimsPositive & rewriteFits & worthwhile;
        if (applicable) {
            _memEstimate = rewrittenMem;
        }
        return applicable;
    }

    /**
     * Chooses the MMCJ aggregation type. NO_AGG is chosen only when three conditions hold: this
     * hop's dims are known; twice its output size estimate exceeds the remote reduce memory budget;
     * and a column-block slice of X or a row-block slice of Y is smaller than
     * MMCJMRReducerWithAggregator.MIN_CACHE_SIZE. AGG is chosen otherwise.
     *
     * @param X left input hop
     * @param Y right input hop
     * @return NO_AGG or AGG
     */
    private MMCJ.MMCJType getMMCJAggregationType(Hop X, Hop Y) {
        final double sliceX = OptimizerUtils.estimateSize(X.getDim1(), Math.min(X.getDim2(), (long) X.getColsInBlock()));
        final double sliceY = OptimizerUtils.estimateSize(Math.min(Y.getDim1(), (long) Y.getRowsInBlock()), Y.getDim2());
        if (!dimsKnown()) {
            return MMCJ.MMCJType.AGG;
        }
        final boolean outputExceedsReduceBudget = (double) (2L * OptimizerUtils.estimateSize(getDim1(), getDim2())) > OptimizerUtils.getRemoteMemBudgetReduce();
        if (!outputExceedsReduceBudget) {
            return MMCJ.MMCJType.AGG;
        }
        final boolean smallSlice = sliceX < (double) MMCJMRReducerWithAggregator.MIN_CACHE_SIZE
                || sliceY < (double) MMCJMRReducerWithAggregator.MIN_CACHE_SIZE;
        return smallSlice ? MMCJ.MMCJType.NO_AGG : MMCJ.MMCJType.AGG;
    }

    /**
     * Maps the need for aggregation to a Spark aggregation type.
     *
     * @param agg whether the multiplication result needs aggregation
     * @return NONE if no aggregation is needed, SINGLE_BLOCK if this hop's known dims fit into one
     *         block, MULTI_BLOCK otherwise
     */
    private SparkAggType getSparkMMAggregationType(boolean agg) {
        if (agg) {
            final boolean fitsOneBlock = dimsKnown()
                    && getDim1() <= (long) getRowsInBlock()
                    && getDim2() <= (long) getColsInBlock();
            return fitsOneBlock ? SparkAggType.SINGLE_BLOCK : SparkAggType.MULTI_BLOCK;
        }
        return SparkAggType.NONE;
    }

    /**
     * Decides whether a map-side multiplication needs a subsequent aggregation (worst case: yes).
     *
     * @param method the selected method
     * @return false for MAPMM_R when the left input has a known number of columns within one column
     *         block, false for MAPMM_L when the right input has a known number of rows within one
     *         row block, true otherwise
     */
    private boolean requiresAggregation(MMultMethod method) {
        if (method == MMultMethod.MAPMM_R) {
            final Hop lhs = getInput().get(0);
            return !(lhs.getDim2() >= 0L && lhs.getDim2() <= (long) lhs.getColsInBlock());
        }
        if (method == MMultMethod.MAPMM_L) {
            final Hop rhs = getInput().get(1);
            return !(rhs.getDim1() >= 0L && rhs.getDim1() <= (long) rhs.getRowsInBlock());
        }
        return true;
    }

    /**
     * Decides whether the input partitioned by a map-side multiplication must be partitioned
     * (worst case: yes). With known dims of that input (input 1 for MAPMM_R, input 0 for MAPMM_L),
     * partitioning is required only when it has more than 4,000,000 cells.
     *
     * @param method the selected method
     * @param rewrite currently unused; kept for call-site compatibility
     * @return true if the input must be partitioned
     */
    private boolean requiresPartitioning(MMultMethod method, boolean rewrite) {
        final long partitionThresholdCells = 4000000L;
        boolean ret = true;
        Hop input1 = this.getInput().get(0);
        Hop input2 = this.getInput().get(1);
        if (method == MMultMethod.MAPMM_R && input2.dimsKnown()) {
            ret = input2.getDim1() * input2.getDim2() > partitionThresholdCells;
        }
        if (method == MMultMethod.MAPMM_L && input1.dimsKnown()) {
            ret = input1.getDim1() * input1.getDim2() > partitionThresholdCells;
        }
        return ret;
    }

    /**
     * Estimates the memory footprint of a map-side matrix multiplication.
     *
     * @param cachedInputIndex 1 if the first input is held fully (partitioned size), otherwise the
     *        second input is
     * @param pmm true for the PMM estimate: partitioned first input plus three blocks of the second
     * @return the estimated footprint in bytes
     */
    public static double getMapmmMemEstimate(long m1_rows, long m1_cols, long m1_rpb, long m1_cpb, long m1_nnz, long m2_rows, long m2_cols, long m2_rpb, long m2_cpb, long m2_nnz, int cachedInputIndex, boolean pmm) {
        if (pmm) {
            final double firstPartitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz);
            final double secondBlock = OptimizerUtils.estimateSize(Math.min(m2_rows, m2_rpb), Math.min(m2_cols, m2_cpb));
            return firstPartitioned + 3.0 * secondBlock;
        }
        if (cachedInputIndex == 1) {
            // whole first input, one block of the second, one column block of the output
            final double firstPartitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz);
            final double secondBlock = OptimizerUtils.estimateSize(Math.min(m2_rows, m2_rpb), Math.min(m2_cols, m2_cpb));
            final double outColBlock = OptimizerUtils.estimateSize(m1_rows, Math.min(m2_cols, m2_cpb));
            return firstPartitioned + secondBlock + outColBlock;
        }
        // one block of the first input, whole second input, one row block of the output
        final double firstBlock = OptimizerUtils.estimateSize(Math.min(m1_rows, m1_rpb), Math.min(m1_cols, m1_cpb));
        final double secondPartitioned = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz);
        final double outRowBlock = OptimizerUtils.estimateSize(Math.min(m1_rows, m1_rpb), m2_cols);
        return firstBlock + secondPartitioned + outRowBlock;
    }

    /**
     * Selects the matrix multiplication method for MR execution.
     *
     * <p>Decision order:
     * <ol>
     *   <li>FORCED_MMULT_METHOD, if set.</li>
     *   <li>TSMM.</li>
     *   <li>MAPMM_CHAIN.</li>
     *   <li>PMM, for a left PM input.</li>
     *   <li>MAPMM_L / MAPMM_R, when one input can be cached within the map memory budget.</li>
     *   <li>CPMM, if any dimension is -1 (unknown).</li>
     *   <li>Otherwise the cheaper of CPMM and RMM by cost estimate.</li>
     * </ol>
     *
     * @param m1_rows rows of the first input (-1 when unknown)
     * @param m1_cols columns of the first input (-1 when unknown)
     * @param m1_rpb rows per block of the first input
     * @param m1_cpb columns per block of the first input
     * @param m1_nnz non-zeros of the first input
     * @param m2_rows rows of the second input (-1 when unknown)
     * @param m2_cols columns of the second input (-1 when unknown)
     * @param m2_rpb rows per block of the second input
     * @param m2_cpb columns per block of the second input
     * @param m2_nnz non-zeros of the second input
     * @param mmtsj TSMM variant detected for this operation
     * @param chainType MapMultChain pattern detected for this operation
     * @param leftPMInput whether the first input is a PM input
     * @return the selected method
     */
    private static MMultMethod optFindMMultMethodMR(long m1_rows, long m1_cols, long m1_rpb, long m1_cpb, long m1_nnz, long m2_rows, long m2_cols, long m2_rpb, long m2_cpb, long m2_nnz, MMTSJ.MMTSJType mmtsj, MapMultChain.ChainType chainType, boolean leftPMInput) {
        // remote map-task budget; the 1.0 factor matches MAPMM_MEM_MULTIPLIER (presumably inlined by the decompiler)
        double memBudget = 1.0 * OptimizerUtils.getRemoteMemBudgetMap(true);
        if (FORCED_MMULT_METHOD != null) {
            return FORCED_MMULT_METHOD;
        }
        // TSMM if m2 (LEFT) has a single column block or m1 (RIGHT) a single row block
        if (mmtsj == MMTSJ.MMTSJType.LEFT && m2_cols >= 0L && m2_cols <= m2_cpb || mmtsj == MMTSJ.MMTSJType.RIGHT && m1_rows >= 0L && m1_rows <= m1_rpb) {
            return MMultMethod.TSMM;
        }
        // MAPMM_CHAIN: sum-product rewrites enabled, m1 within one row block, m2 a single column;
        // the vector size estimate(s) must fit the map budget
        if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && chainType != MapMultChain.ChainType.NONE && m1_rows >= 0L && m1_rows <= m1_rpb && m2_cols == 1L) {
            if (chainType == MapMultChain.ChainType.XtXv && m1_rows >= 0L && m2_cols >= 0L && (double)OptimizerUtils.estimateSize(m1_rows, m2_cols) < memBudget) {
                return MMultMethod.MAPMM_CHAIN;
            }
            if ((chainType == MapMultChain.ChainType.XtwXv || chainType == MapMultChain.ChainType.XtXvy) && m1_rows >= 0L && m2_cols >= 0L && m1_cols >= 0L && (double)(OptimizerUtils.estimateSize(m1_rows, m2_cols) + OptimizerUtils.estimateSize(m1_cols, m2_cols)) < memBudget) {
                return MMultMethod.MAPMM_CHAIN;
            }
        }
        // PMM footprints with a single-column first input
        // NOTE(review): footprintPM2 passes m2_rows as the first input's rows but m1's block sizes/nnz - confirm intended
        double footprintPM1 = AggBinaryOp.getMapmmMemEstimate(m1_rows, 1L, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, true);
        double footprintPM2 = AggBinaryOp.getMapmmMemEstimate(m2_rows, 1L, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, true);
        if ((footprintPM1 < memBudget && m1_rows >= 0L || footprintPM2 < memBudget && m2_rows >= 0L) && leftPMInput) {
            return MMultMethod.PMM;
        }
        // map-side multiply if either input can be cached; prefer caching the smaller partitioned
        // input (first input -> MAPMM_L) when its dims are known
        double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz);
        double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz);
        double footprint1 = AggBinaryOp.getMapmmMemEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, false);
        double footprint2 = AggBinaryOp.getMapmmMemEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 2, false);
        if (footprint1 < memBudget && m1_rows >= 0L && m1_cols >= 0L || footprint2 < memBudget && m2_rows >= 0L && m2_cols >= 0L) {
            if (m1SizeP < m2SizeP && m1_rows >= 0L && m1_cols >= 0L) {
                return MMultMethod.MAPMM_L;
            }
            return MMultMethod.MAPMM_R;
        }
        // unknown dimensions: fall back to CPMM
        if (m1_rows == -1L || m1_cols == -1L || m2_rows == -1L || m2_cols == -1L) {
            return MMultMethod.CPMM;
        }
        // otherwise pick the cheaper of CPMM and RMM
        double rmm_costs = AggBinaryOp.getRMMCostEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m2_rows, m2_cols, m2_rpb, m2_cpb);
        double cpmm_costs = AggBinaryOp.getCPMMCostEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m2_rows, m2_cols, m2_rpb, m2_cpb);
        if (cpmm_costs < rmm_costs) {
            return MMultMethod.CPMM;
        }
        return MMultMethod.RMM;
    }

    /**
     * Selects the matrix multiplication method for CP execution.
     *
     * @param m1_rows rows of the first input (not consulted)
     * @param m1_cols columns of the first input
     * @param m2_rows rows of the second input
     * @param m2_cols columns of the second input
     * @param mmtsj TSMM variant; any value other than NONE selects TSMM
     * @param chainType MapMultChain pattern
     * @param leftPM whether the first input is a PM input
     * @return TSMM, MAPMM_CHAIN, PMM or MM, checked in that order
     */
    private static MMultMethod optFindMMultMethodCP(long m1_rows, long m1_cols, long m2_rows, long m2_cols, MMTSJ.MMTSJType mmtsj, MapMultChain.ChainType chainType, boolean leftPM) {
        final boolean selfTranspose = mmtsj != MMTSJ.MMTSJType.NONE;
        final boolean chainToVector = chainType != MapMultChain.ChainType.NONE
                && OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES
                && m2_cols == 1L;
        final boolean permutation = leftPM && m1_cols == 1L && m2_rows != 1L;
        if (selfTranspose) {
            return MMultMethod.TSMM;
        }
        if (chainToVector) {
            return MMultMethod.MAPMM_CHAIN;
        }
        return permutation ? MMultMethod.PMM : MMultMethod.MM;
    }

    /**
     * Selects the matrix multiplication method for Spark execution. Also records the associated
     * broadcast memory estimate in _spBroadcastMemEstimate. The estimate is reset to 0 on entry and
     * set only when returning a weighted MAPMM_CHAIN, PMM, MAPMM_L or MAPMM_R.
     *
     * <p>Decision order:
     * <ol>
     *   <li>FORCED_MMULT_METHOD, if set.</li>
     *   <li>TSMM.</li>
     *   <li>MAPMM_CHAIN.</li>
     *   <li>PMM, for a left PM input.</li>
     *   <li>MAPMM_L / MAPMM_R, when an input fits both the broadcast and the local budget.</li>
     *   <li>TSMM2.</li>
     *   <li>CPMM, if any dimension is -1 (unknown).</li>
     *   <li>ZIPMM, for the transpose rewrite with single-block outer dimensions.</li>
     *   <li>Otherwise the cheaper of CPMM and RMM.</li>
     * </ol>
     *
     * @param tmmRewrite whether the transpose-multiply rewrite (ZIPMM candidate) applies
     * @return the selected method
     */
    private MMultMethod optFindMMultMethodSpark(long m1_rows, long m1_cols, long m1_rpb, long m1_cpb, long m1_nnz, long m2_rows, long m2_cols, long m2_rpb, long m2_cpb, long m2_nnz, MMTSJ.MMTSJType mmtsj, MapMultChain.ChainType chainType, boolean leftPMInput, boolean tmmRewrite) {
        double memBudgetExec = 1.0 * SparkExecutionContext.getBroadcastMemoryBudget();
        double memBudgetLocal = OptimizerUtils.getLocalMemBudget();
        this._spBroadcastMemEstimate = 0.0;
        if (FORCED_MMULT_METHOD != null) {
            return FORCED_MMULT_METHOD;
        }
        // TSMM if m2 (LEFT) has a single column block or m1 (RIGHT) a single row block
        if (mmtsj == MMTSJ.MMTSJType.LEFT && m2_cols >= 0L && m2_cols <= m2_cpb || mmtsj == MMTSJ.MMTSJType.RIGHT && m1_rows >= 0L && m1_rows <= m1_rpb) {
            return MMultMethod.TSMM;
        }
        // MAPMM_CHAIN: the broadcast vector(s) must fit the broadcast budget
        // (weighted variants: twice their size must also fit the local budget)
        if (OptimizerUtils.ALLOW_SUM_PRODUCT_REWRITES && chainType != MapMultChain.ChainType.NONE && m1_rows >= 0L && m1_rows <= m1_rpb && m2_cols == 1L) {
            if (chainType == MapMultChain.ChainType.XtXv && m1_rows >= 0L && m2_cols >= 0L && (double)OptimizerUtils.estimateSize(m1_rows, m2_cols) < memBudgetExec) {
                return MMultMethod.MAPMM_CHAIN;
            }
            if ((chainType == MapMultChain.ChainType.XtwXv || chainType == MapMultChain.ChainType.XtXvy) && m1_rows >= 0L && m2_cols >= 0L && m1_cols >= 0L && (double)(OptimizerUtils.estimateSize(m1_rows, m2_cols) + OptimizerUtils.estimateSize(m1_cols, m2_cols)) < memBudgetExec && (double)(2L * (OptimizerUtils.estimateSize(m1_rows, m2_cols) + OptimizerUtils.estimateSize(m1_cols, m2_cols))) < memBudgetLocal) {
                this._spBroadcastMemEstimate = 2L * (OptimizerUtils.estimateSize(m1_rows, m2_cols) + OptimizerUtils.estimateSize(m1_cols, m2_cols));
                return MMultMethod.MAPMM_CHAIN;
            }
        }
        // PMM for a left PM input whose single-column footprint fits the broadcast budget
        double footprintPM1 = AggBinaryOp.getMapmmMemEstimate(m1_rows, 1L, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, true);
        double footprintPM2 = AggBinaryOp.getMapmmMemEstimate(m2_rows, 1L, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, true);
        if ((footprintPM1 < memBudgetExec && m1_rows >= 0L || footprintPM2 < memBudgetExec && m2_rows >= 0L) && (double)(2L * OptimizerUtils.estimateSize(m1_rows, 1L)) < memBudgetLocal && leftPMInput) {
            this._spBroadcastMemEstimate = 2L * OptimizerUtils.estimateSize(m1_rows, 1L);
            return MMultMethod.PMM;
        }
        // map-side multiply: the broadcast input (plain + partitioned) must also fit the local budget
        double m1Size = OptimizerUtils.estimateSizeExactSparsity(m1_rows, m1_cols, m1_nnz);
        double m2Size = OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols, m2_nnz);
        double m1SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz);
        double m2SizeP = OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz);
        double footprint1 = AggBinaryOp.getMapmmMemEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 1, false);
        double footprint2 = AggBinaryOp.getMapmmMemEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m1_nnz, m2_rows, m2_cols, m2_rpb, m2_cpb, m2_nnz, 2, false);
        if (footprint1 < memBudgetExec && m1Size + m1SizeP < memBudgetLocal && m1_rows >= 0L && m1_cols >= 0L || footprint2 < memBudgetExec && m2Size + m2SizeP < memBudgetLocal && m2_rows >= 0L && m2_cols >= 0L) {
            // prefer broadcasting the smaller partitioned input; ties broken by output memory estimate
            double em1Size = this.getInput().get(0).getOutputMemEstimate();
            double em2Size = this.getInput().get(1).getOutputMemEstimate();
            if ((m1SizeP < m2SizeP || m1SizeP == m2SizeP && em1Size < em2Size) && m1_rows >= 0L && m1_cols >= 0L && OptimizerUtils.isValidCPDimensions(m1_rows, m1_cols)) {
                this._spBroadcastMemEstimate = m1Size + m1SizeP;
                return MMultMethod.MAPMM_L;
            }
            if (OptimizerUtils.isValidCPDimensions(m2_rows, m2_cols)) {
                this._spBroadcastMemEstimate = m2Size + m2SizeP;
                return MMultMethod.MAPMM_R;
            }
        }
        // TSMM2: the input minus its first block must fit both budgets, the input must span
        // at most two blocks, and the partitioned size must stay below 2^31 bytes
        if (mmtsj != MMTSJ.MMTSJType.NONE && m1_rows >= 0L && m1_cols >= 0L && m2_rows >= 0L && m2_cols >= 0L) {
            boolean tsmmLeft = mmtsj == MMTSJ.MMTSJType.LEFT;
            double mSize = tsmmLeft ? (double)OptimizerUtils.estimateSizeExactSparsity(m2_rows, m2_cols - m2_cpb, 1.0) : (double)OptimizerUtils.estimateSizeExactSparsity(m1_rows - m1_rpb, m1_cols, 1.0);
            double mSizeP = tsmmLeft ? (double)OptimizerUtils.estimatePartitionedSizeExactSparsity(m2_rows, m2_cols - m2_cpb, m2_rpb, m2_cpb, 1.0) : (double)OptimizerUtils.estimatePartitionedSizeExactSparsity(m1_rows - m1_rpb, m1_cols, m1_rpb, m1_cpb, 1.0);
            if (mSizeP < memBudgetExec && mSize + mSizeP < memBudgetLocal && (tsmmLeft ? m2_cols <= 2L * m2_cpb : m1_rows <= 2L * m1_rpb) && mSizeP < 2.147483648E9) {
                return MMultMethod.TSMM2;
            }
        }
        // unknown dimensions: fall back to CPMM
        if (m1_rows == -1L || m1_cols == -1L || m2_rows == -1L || m2_cols == -1L) {
            return MMultMethod.CPMM;
        }
        if (tmmRewrite && m1_rows >= 0L && m1_rows <= m1_rpb && m2_cols >= 0L && m2_cols <= m2_cpb) {
            return MMultMethod.ZIPMM;
        }
        // otherwise pick the cheaper of CPMM and RMM
        double rmm_costs = AggBinaryOp.getRMMCostEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m2_rows, m2_cols, m2_rpb, m2_cpb);
        double cpmm_costs = AggBinaryOp.getCPMMCostEstimate(m1_rows, m1_cols, m1_rpb, m1_cpb, m2_rows, m2_cols, m2_rpb, m2_cpb);
        if (cpmm_costs < rmm_costs) {
            return MMultMethod.CPMM;
        }
        return MMultMethod.RMM;
    }

    /**
     * Estimates the cost of an RMM multiplication as (shuffle + I/O) divided by the effective
     * number of reducers. The reducer count is bounded by the number of output blocks,
     * m1_nrb * m2_ncb.
     *
     * <p>Sizes are in cells. They are multiplied in double arithmetic because the long products of
     * large dimensions can overflow (and wrap to negative values) before widening to double.
     *
     * @return the estimated (unitless) cost
     */
    private static double getRMMCostEstimate(long m1_rows, long m1_cols, long m1_rpb, long m1_cpb, long m2_rows, long m2_cols, long m2_rpb, long m2_cpb) {
        long m1_nrb = (long)Math.ceil((double)m1_rows / (double)m1_rpb);
        long m2_ncb = (long)Math.ceil((double)m2_cols / (double)m2_cpb);
        double m1_size = (double)m1_rows * (double)m1_cols;
        double m2_size = (double)m2_rows * (double)m2_cols;
        double result_size = (double)m1_rows * (double)m2_cols;
        int numReducersRMM = OptimizerUtils.getNumReducers(true);
        // shuffle: m1 counted once per column block of m2, m2 once per row block of m1
        double rmm_shuffle = (double)m2_ncb * m1_size + (double)m1_nrb * m2_size;
        double rmm_io = m1_size + m2_size + result_size;
        double rmm_nred = Math.min(m1_nrb * m2_ncb, (long)numReducersRMM);
        double rmm_costs = (rmm_shuffle + rmm_io) / rmm_nred;
        return rmm_costs;
    }

    /**
     * Estimates the cost of a CPMM multiplication as the sum of two phases. Each phase costs
     * (shuffle + I/O) / reducers. Phase 1 is the cross product, using at most m1_ncb reducers.
     * Phase 2 is the aggregation of the partial results, using at most m1_nrb * m2_ncb reducers.
     *
     * <p>Sizes are in cells. They are multiplied in double arithmetic because the long products of
     * large dimensions can overflow (and wrap to negative values) before widening to double.
     *
     * @return the estimated (unitless) cost
     */
    private static double getCPMMCostEstimate(long m1_rows, long m1_cols, long m1_rpb, long m1_cpb, long m2_rows, long m2_cols, long m2_rpb, long m2_cpb) {
        long m1_nrb = (long)Math.ceil((double)m1_rows / (double)m1_rpb);
        long m1_ncb = (long)Math.ceil((double)m1_cols / (double)m1_cpb);
        long m2_ncb = (long)Math.ceil((double)m2_cols / (double)m2_cpb);
        double m1_size = (double)m1_rows * (double)m1_cols;
        double m2_size = (double)m2_rows * (double)m2_cols;
        double result_size = (double)m1_rows * (double)m2_cols;
        int numReducersCPMM = OptimizerUtils.getNumReducers(false);
        // phase 1: each reducer writes one partial result of output size
        double cpmm_shuffle1 = m1_size + m2_size;
        double cpmm_nred1 = Math.min(m1_ncb, (long)numReducersCPMM);
        double cpmm_io1 = m1_size + m2_size + cpmm_nred1 * result_size;
        // phase 2: the partial results are shuffled and aggregated into the final output
        double cpmm_shuffle2 = cpmm_nred1 * result_size;
        double cpmm_io2 = cpmm_nred1 * result_size + result_size;
        double cpmm_nred2 = Math.min(m1_nrb * m2_ncb, (long)numReducersCPMM);
        double cpmm_costs = (cpmm_shuffle1 + cpmm_io1) / cpmm_nred1 + (cpmm_shuffle2 + cpmm_io2) / cpmm_nred2;
        return cpmm_costs;
    }

    /**
     * Refreshes this hop's dimensions from its inputs. For a matrix multiply, rows come from the
     * first input and columns from the second; otherwise the dimensions are left unchanged.
     */
    @Override
    public void refreshSizeInformation() {
        final Hop lhs = getInput().get(0);
        final Hop rhs = getInput().get(1);
        if (!isMatrixMultiply()) {
            return;
        }
        setDim1(lhs.getDim1());
        setDim2(rhs.getDim2());
    }

    /**
     * Creates a copy of this operator. The generic hop state is copied via clone(this, false), then
     * the operator fields innerOp, outerOp, _hasLeftPMInput and _maxNumThreads. The selected
     * _method is not copied.
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        final AggBinaryOp copy = new AggBinaryOp();
        copy.clone(this, false);
        copy._maxNumThreads = _maxNumThreads;
        copy._hasLeftPMInput = _hasLeftPMInput;
        copy.outerOp = outerOp;
        copy.innerOp = innerOp;
        return copy;
    }

    /**
     * Returns true if {@code that} is an AggBinaryOp with the same inner/outer operators, the same
     * input hops (compared by identity), the same left-PM flag and the same thread count.
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof AggBinaryOp)) {
            return false;
        }
        final AggBinaryOp other = (AggBinaryOp) that;
        if (innerOp != other.innerOp || outerOp != other.outerOp) {
            return false;
        }
        // inputs must be the very same hop instances, not merely equal ones
        if (getInput().get(0) != other.getInput().get(0) || getInput().get(1) != other.getInput().get(1)) {
            return false;
        }
        return _hasLeftPMInput == other._hasLeftPMInput && _maxNumThreads == other._maxNumThreads;
    }

    /**
     * Aggregation needed after a Spark matrix multiplication, as chosen by
     * getSparkMMAggregationType.
     */
    public static enum SparkAggType {
        // no aggregation needed
        NONE,
        // aggregation into an output whose known dimensions fit a single block
        SINGLE_BLOCK,
        // aggregation over (potentially) multiple output blocks
        MULTI_BLOCK;

    }

    /**
     * Physical matrix multiplication methods chosen by optFindMMultMethodMR, optFindMMultMethodCP
     * and optFindMMultMethodSpark (or forced via FORCED_MMULT_METHOD).
     * NOTE(review): declaration order defines ordinal(); do not reorder without checking for
     * ordinal-dependent uses.
     */
    public static enum MMultMethod {
        // MR: MMCJ + Group + Aggregate (constructMRLopsCPMM)
        CPMM,
        // MR: MMRJ (constructMRLopsRMM)
        RMM,
        // map-side MapMult; the left input is partitioned column-blockwise when required
        MAPMM_L,
        // map-side MapMult; the right input is partitioned row-blockwise when required
        MAPMM_R,
        // MapMultChain for the XtXv / XtwXv / XtXvy chain patterns
        MAPMM_CHAIN,
        // not selected by the optimizer methods in this part of the file; presumably the PMapMult lop (TODO confirm)
        PMAPMM,
        // multiplication with a left PM input, built via PMMJ
        PMM,
        // multiplication via MMTSJ
        TSMM,
        // TSMM variant selected by optFindMMultMethodSpark
        TSMM2,
        // Spark MMZip for the transpose-multiply rewrite
        ZIPMM,
        // default CP matrix multiplication (optFindMMultMethodCP)
        MM;

    }
}

