/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import scala.Tuple2;

public class QuaternarySPInstruction
extends ComputationSPInstruction {
    /** Optional fourth input operand; the constructor stores whatever the caller passes, including {@code null}. */
    private CPOperand _input4;
    /** Caching flag for the U factor as handed to the constructor. */
    private boolean _cacheU;
    /** Caching flag for the V factor as handed to the constructor. */
    private boolean _cacheV;

    /**
     * Creates a quaternary Spark instruction.
     *
     * @param op     operator forwarded to the superclass
     * @param in1    first input operand
     * @param in2    second input operand
     * @param in3    third input operand
     * @param in4    fourth input operand, may be {@code null}
     * @param out    output operand
     * @param cacheU stored in {@code _cacheU}
     * @param cacheV stored in {@code _cacheV}
     * @param opcode instruction opcode
     * @param str    full instruction string
     */
    private QuaternarySPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, boolean cacheU, boolean cacheV, String opcode, String str) {
        super(SPInstruction.SPType.Quaternary, op, in1, in2, in3, out, opcode, str);
        _cacheU = cacheU;
        _cacheV = cacheV;
        _input4 = in4;
    }

    /**
     * Parses the string form of a distributed quaternary instruction.
     *
     * <p>Field layout (index into the split instruction, index 0 is the opcode):
     * <ul>
     *   <li>{@code mapwsloss}/{@code redwsloss}: 1-4 inputs, 5 output, 6 weights type</li>
     *   <li>{@code mapwumm}/{@code redwumm}: 1 unary opcode, 2-4 inputs, 5 output, 6 type</li>
     *   <li>{@code mapwdivmm}/{@code redwdivmm}: 1-4 inputs, 5 output, 6 type; when the type
     *       has a scalar, the name of input 4 is parsed as a double</li>
     *   <li>{@code *wsigmoid}: 1-3 inputs, 4 output, 5 type</li>
     *   <li>{@code *wcemm}: 1-4 inputs, 5 output, 6 type; when the type has four inputs,
     *       the name of input 4 is parsed as a double</li>
     * </ul>
     * For the "red" variants two additional trailing fields carry the cacheU and cacheV
     * flags; for the "map" variants both flags are fixed to {@code true}.
     *
     * @param str instruction string
     * @return the parsed instruction, never {@code null}
     * @throws DMLRuntimeException if the opcode is not a distributed quaternary opcode or
     *         is not handled by any of the branches below
     */
    public static QuaternarySPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!InstructionUtils.isDistQuaternaryOpcode(opcode)) {
            throw new DMLRuntimeException("Quaternary.parseInstruction():: Unknown opcode " + opcode);
        }

        // weighted squared loss
        if ("mapwsloss".equalsIgnoreCase(opcode) || "redwsloss".equalsIgnoreCase(opcode)) {
            boolean isRed = "redwsloss".equalsIgnoreCase(opcode);
            InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            WeightedSquaredLoss.WeightsType wtype = WeightedSquaredLoss.WeightsType.valueOf(parts[6]);
            // map variants always use the cached factors; red variants read the flags
            boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
            boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
            return new QuaternarySPInstruction(new QuaternaryOperator(wtype), in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
        }

        // weighted unary matrix multiply
        if ("mapwumm".equalsIgnoreCase(opcode) || "redwumm".equalsIgnoreCase(opcode)) {
            boolean isRed = "redwumm".equalsIgnoreCase(opcode);
            InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
            String uopcode = parts[1];
            CPOperand in1 = new CPOperand(parts[2]);
            CPOperand in2 = new CPOperand(parts[3]);
            CPOperand in3 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            WeightedUnaryMM.WUMMType wtype = WeightedUnaryMM.WUMMType.valueOf(parts[6]);
            boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
            boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
            return new QuaternarySPInstruction(new QuaternaryOperator(wtype, uopcode), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
        }

        // weighted divide matrix multiply
        if ("mapwdivmm".equalsIgnoreCase(opcode) || "redwdivmm".equalsIgnoreCase(opcode)) {
            // detect the "red" variant the same (case-insensitive) way the branch was matched
            boolean isRed = "redwdivmm".equalsIgnoreCase(opcode);
            InstructionUtils.checkNumFields(parts, isRed ? 8 : 6);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            boolean cacheU = !isRed || Boolean.parseBoolean(parts[7]);
            boolean cacheV = !isRed || Boolean.parseBoolean(parts[8]);
            WeightedDivMM.WDivMMType wt = WeightedDivMM.WDivMMType.valueOf(parts[6]);
            QuaternaryOperator qop = wt.hasScalar()
                ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
                : new QuaternaryOperator(wt);
            return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
        }

        // remaining opcodes: wsigmoid (three inputs) and wcemm (four inputs)
        boolean isRed = opcode.startsWith("red");
        boolean isWcemm = opcode.endsWith("wcemm");
        int addInput4 = isWcemm ? 1 : 0; // wcemm shifts all trailing fields by one
        InstructionUtils.checkNumFields(parts, (isRed ? 7 : 5) + addInput4);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4 + addInput4]);
        boolean cacheU = !isRed || Boolean.parseBoolean(parts[6 + addInput4]);
        boolean cacheV = !isRed || Boolean.parseBoolean(parts[7 + addInput4]);
        if (opcode.endsWith("wsigmoid")) {
            return new QuaternarySPInstruction(new QuaternaryOperator(WeightedSigmoid.WSigmoidType.valueOf(parts[5])), in1, in2, in3, null, out, cacheU, cacheV, opcode, str);
        }
        if (isWcemm) {
            CPOperand in4 = new CPOperand(parts[4]);
            WeightedCrossEntropy.WCeMMType wt = WeightedCrossEntropy.WCeMMType.valueOf(parts[6]);
            QuaternaryOperator qop = wt.hasFourInputs()
                ? new QuaternaryOperator(wt, Double.parseDouble(in4.getName()))
                : new QuaternaryOperator(wt);
            return new QuaternarySPInstruction(qop, in1, in2, in3, in4, out, cacheU, cacheV, opcode, str);
        }

        // fail fast instead of handing a null instruction to the caller
        throw new DMLRuntimeException("Quaternary.parseInstruction():: Unsupported opcode " + opcode);
    }

    /**
     * Override of the superclass {@code processInstruction} hook (presumably the
     * execution entry point of this instruction — confirm against the superclass).
     *
     * <p>NOTE: the real body is not available in this file. The decompiler raised an
     * exception while processing this method (stack trace preserved in the body), so
     * the method as written does no work and unconditionally throws.
     *
     * @param ec execution context; not used by the placeholder body
     * @throws IllegalStateException always, with the message "Decompilation failed"
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * java.lang.UnsupportedOperationException
         *     at org.benf.cfr.reader.bytecode.analysis.parse.expression.NewAnonymousArray.getDimSize(NewAnonymousArray.java:142)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.isNewArrayLambda(LambdaRewriter.java:455)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.rewriteDynamicExpression(LambdaRewriter.java:409)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.rewriteDynamicExpression(LambdaRewriter.java:167)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.rewriteExpression(LambdaRewriter.java:105)
         *     at org.benf.cfr.reader.bytecode.analysis.parse.rewriters.ExpressionRewriterHelper.applyForwards(ExpressionRewriterHelper.java:12)
         *     at org.benf.cfr.reader.bytecode.analysis.parse.expression.AbstractMemberFunctionInvokation.applyExpressionRewriterToArgs(AbstractMemberFunctionInvokation.java:101)
         *     at org.benf.cfr.reader.bytecode.analysis.parse.expression.AbstractMemberFunctionInvokation.applyExpressionRewriter(AbstractMemberFunctionInvokation.java:88)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.rewriteExpression(LambdaRewriter.java:103)
         *     at org.benf.cfr.reader.bytecode.analysis.parse.expression.TernaryExpression.applyExpressionRewriter(TernaryExpression.java:106)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.rewriteExpression(LambdaRewriter.java:103)
         *     at org.benf.cfr.reader.bytecode.analysis.structured.statement.StructuredAssignment.rewriteExpressions(StructuredAssignment.java:146)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.op4rewriters.LambdaRewriter.rewrite(LambdaRewriter.java:88)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.rewriteLambdas(Op04StructuredStatement.java:1137)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:912)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        // Placeholder emitted by the decompiler in lieu of the original logic.
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Updates the data characteristics registered for the output variable, based on
     * the characteristics of the inputs and the operator type.
     *
     * <ul>
     *   <li>If {@code qop.wtype2} or {@code qop.wtype5} is set, the output takes the
     *       rows and columns of input 1.</li>
     *   <li>Otherwise, if {@code qop.wtype3} is set, the output dimensions are obtained
     *       from {@code wtype3.computeOutputCharacteristics}, using the column count of
     *       input 3 (left) or input 2 (otherwise) as rank.</li>
     *   <li>In every other case the output characteristics are left untouched.</li>
     * </ul>
     *
     * @param sec Spark execution context used to look up variable characteristics
     * @param qop quaternary operator of this instruction
     */
    private void updateOutputDataCharacteristics(SparkExecutionContext sec, QuaternaryOperator qop) {
        DataCharacteristics dcMain = sec.getDataCharacteristics(input1.getName());
        DataCharacteristics dcSecond = sec.getDataCharacteristics(input2.getName());
        DataCharacteristics dcThird = sec.getDataCharacteristics(input3.getName());
        DataCharacteristics dcResult = sec.getDataCharacteristics(output.getName());

        // NOTE(review): both set(...) calls below pass the block size twice; confirm the
        // meaning of the fourth parameter against DataCharacteristics.set.
        if (qop.wtype2 != null || qop.wtype5 != null) {
            dcResult.set(dcMain.getRows(), dcMain.getCols(), dcMain.getBlocksize(), dcMain.getBlocksize());
            return;
        }
        if (qop.wtype3 == null) {
            return;
        }

        long rank;
        if (qop.wtype3.isLeft()) {
            rank = dcThird.getCols();
        } else {
            rank = dcSecond.getCols();
        }
        DataCharacteristics dcComputed = qop.wtype3.computeOutputCharacteristics(dcMain.getRows(), dcMain.getCols(), rank);
        dcResult.set(dcComputed.getRows(), dcComputed.getCols(), dcMain.getBlocksize(), dcMain.getBlocksize());
    }

    /**
     * Returns the fourth input operand.
     *
     * @return the operand stored at construction time; may be {@code null}
     */
    public CPOperand getInput4() {
        return _input4;
    }

    /**
     * Compiler-generated lambda body (per its name, belonging to {@code processInstruction}):
     * passes the given block array and a {@code null} element to {@code ArrayUtils.add}
     * and returns the result.
     */
    private static /* synthetic */ MatrixBlock[] lambda$processInstruction$4447ded8$2(MatrixBlock[] mb) throws Exception {
        Object[] withNullSlot = ArrayUtils.add((Object[])mb, null);
        return (MatrixBlock[])withNullSlot;
    }

    /**
     * Compiler-generated lambda body (per its name, belonging to {@code processInstruction}):
     * passes the given block array and a {@code null} element to {@code ArrayUtils.add}
     * and returns the result.
     */
    private static /* synthetic */ MatrixBlock[] lambda$processInstruction$4447ded8$1(MatrixBlock[] mb) throws Exception {
        Object[] withNullSlot = ArrayUtils.add((Object[])mb, null);
        return (MatrixBlock[])withNullSlot;
    }

    private static class Unpack
    implements PairFunction<Tuple2<Long, Tuple2<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixBlock>>, MatrixIndexes, MatrixBlock[]> {
        private static final long serialVersionUID = 3812660351709830714L;

        private Unpack() {
        }

        public Tuple2<MatrixIndexes, MatrixBlock[]> call(Tuple2<Long, Tuple2<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixBlock>> arg) throws Exception {
            return new Tuple2((Object)((MatrixIndexes)((Tuple2)((Tuple2)arg._2())._1())._1()), (Object)((MatrixBlock[])ArrayUtils.addAll((Object[])((MatrixBlock[])((Tuple2)((Tuple2)arg._2())._1())._2()), (Object[])new MatrixBlock[]{(MatrixBlock)((Tuple2)arg._2())._2()})));
        }
    }

    private static class ToArray
    implements PairFunction<Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>>, MatrixIndexes, MatrixBlock[]> {
        private static final long serialVersionUID = -4856316007590144978L;

        private ToArray() {
        }

        public Tuple2<MatrixIndexes, MatrixBlock[]> call(Tuple2<MatrixIndexes, Tuple2<MatrixBlock, MatrixBlock>> arg) throws Exception {
            return new Tuple2((Object)((MatrixIndexes)arg._1()), (Object)new MatrixBlock[]{(MatrixBlock)((Tuple2)arg._2())._1(), (MatrixBlock)((Tuple2)arg._2())._2()});
        }
    }

    private static class ExtractIndexWith
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, Long, Tuple2<MatrixIndexes, MatrixBlock[]>> {
        private static final long serialVersionUID = -966212318512764461L;
        private final boolean _row;

        public ExtractIndexWith(boolean row) {
            this._row = row;
        }

        public Tuple2<Long, Tuple2<MatrixIndexes, MatrixBlock[]>> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg) throws Exception {
            return new Tuple2((Object)(this._row ? ((MatrixIndexes)arg._1()).getRowIndex() : ((MatrixIndexes)arg._1()).getColumnIndex()), arg);
        }
    }

    private static class ExtractIndex
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, Long, MatrixBlock> {
        private static final long serialVersionUID = -6542246824481788376L;
        private final boolean _row;

        public ExtractIndex(boolean row) {
            this._row = row;
        }

        public Tuple2<Long, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception {
            return new Tuple2((Object)(this._row ? ((MatrixIndexes)arg._1()).getRowIndex() : ((MatrixIndexes)arg._1()).getColumnIndex()), (Object)((MatrixBlock)arg._2()));
        }
    }

    private static class RDDQuaternaryFunction2
    extends RDDQuaternaryBaseFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock[]>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 7493974462943080693L;

        public RDDQuaternaryFunction2(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            super(qop, bcU, bcV);
        }

        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock[]> arg0) {
            MatrixIndexes ixIn = (MatrixIndexes)arg0._1();
            MatrixBlock[] blks = (MatrixBlock[])arg0._2();
            MatrixBlock mbU = this._pmU != null ? (MatrixBlock)this._pmU.getBlock((int)ixIn.getRowIndex(), 1) : blks[2];
            MatrixBlock mbV = this._pmV != null ? (MatrixBlock)this._pmV.getBlock((int)ixIn.getColumnIndex(), 1) : blks[3];
            MatrixBlock mbW = this._qop.hasFourInputs() ? blks[1] : null;
            MatrixBlock blkOut = blks[0].quaternaryOperations(this._qop, mbU, mbV, mbW, new MatrixBlock());
            MatrixIndexes ixOut = this.createOutputIndexes(ixIn);
            return new Tuple2((Object)ixOut, (Object)blkOut);
        }
    }

    /**
     * Partition-level quaternary function for single-block records: both U and V are
     * always read from the broadcasts, and no fourth operand is passed to the operation.
     */
    private static class RDDQuaternaryFunction1
    extends RDDQuaternaryBaseFunction
    implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, MatrixBlock>>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -8209188316939435099L;

        public RDDQuaternaryFunction1(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            super(qop, bcU, bcV);
        }

        /** Wraps the partition iterator so that each record is transformed by {@code computeNext}. */
        @Override
        public LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> arg) {
            return new RDDQuaternaryPartitionIterator(arg);
        }

        /** Iterator over one partition that applies the quaternary operation to each input block. */
        private class RDDQuaternaryPartitionIterator
        extends LazyIterableIterator<Tuple2<MatrixIndexes, MatrixBlock>> {
            public RDDQuaternaryPartitionIterator(Iterator<Tuple2<MatrixIndexes, MatrixBlock>> in) {
                super(in);
            }

            @Override
            protected Tuple2<MatrixIndexes, MatrixBlock> computeNext(Tuple2<MatrixIndexes, MatrixBlock> arg) {
                MatrixIndexes inIx = arg._1();
                MatrixBlock input = arg._2();

                // U is addressed by the row index, V by the column index of the input block
                int uRow = (int)inIx.getRowIndex();
                MatrixBlock uBlock = (MatrixBlock)_pmU.getBlock(uRow, 1);
                int vRow = (int)inIx.getColumnIndex();
                MatrixBlock vBlock = (MatrixBlock)_pmV.getBlock(vRow, 1);

                MatrixBlock result = input.quaternaryOperations(_qop, uBlock, vBlock, null, new MatrixBlock());
                return new Tuple2<>(createOutputIndexes(inIx), result);
            }
        }
    }

    /**
     * Shared state and helper for the quaternary RDD functions: holds the operator and
     * the (possibly {@code null}) broadcasts of U and V.
     */
    private static abstract class RDDQuaternaryBaseFunction
    implements Serializable {
        private static final long serialVersionUID = -3175397651350954930L;
        protected QuaternaryOperator _qop;
        protected PartitionedBroadcast<MatrixBlock> _pmU;
        protected PartitionedBroadcast<MatrixBlock> _pmV;

        public RDDQuaternaryBaseFunction(QuaternaryOperator qop, PartitionedBroadcast<MatrixBlock> bcU, PartitionedBroadcast<MatrixBlock> bcV) {
            _pmV = bcV;
            _pmU = bcU;
            _qop = qop;
        }

        /**
         * Derives the output block indexes from the input block indexes.
         *
         * @param in indexes of the input block
         * @return {@code in} itself unless {@code wtype3} is set and not basic; in that
         *         case new indexes with column 1 and, as row, the input column index
         *         (left) or the input row index (otherwise)
         */
        protected MatrixIndexes createOutputIndexes(MatrixIndexes in) {
            if (_qop.wtype3 == null || _qop.wtype3.isBasic()) {
                return in;
            }
            long outRow = _qop.wtype3.isLeft() ? in.getColumnIndex() : in.getRowIndex();
            return new MatrixIndexes(outRow, 1L);
        }
    }
}

