/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import org.apache.commons.lang3.NotImplementedException;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.data.CorrMatrixBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixCountDistinct;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperator;
import org.apache.sysds.runtime.matrix.operators.CountDistinctOperatorTypes;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.Hash;
import scala.Tuple2;

/**
 * Spark instruction for sketch-based unary count-distinct aggregation ({@code uacd*} opcodes).
 *
 * <p>Every input block is turned into a sketch ({@link CorrMatrixBlock}) with
 * {@link LibMatrixCountDistinct#createSketch}. Sketches are merged with
 * {@link LibMatrixCountDistinct#unionSketch}, and final values are derived with
 * {@link LibMatrixCountDistinct#countDistinctValuesFromSketch}. Only matrix inputs are supported;
 * any other input data type raises {@link NotImplementedException}.
 */
public class AggregateUnarySketchSPInstruction
extends UnarySPInstruction {
    /** How partial results are aggregated across blocks (SINGLE_BLOCK, NONE or MULTI_BLOCK are handled). */
    private final AggBinaryOp.SparkAggType aggtype;
    /** Typed view of the operator passed to the super constructor. */
    private final CountDistinctOperator op;

    protected AggregateUnarySketchSPInstruction(Operator op, CPOperand in, CPOperand out, AggBinaryOp.SparkAggType aggtype, String opcode, String instr) {
        super(SPInstruction.SPType.AggregateUnarySketch, op, in, out, opcode, instr);
        this.op = (CountDistinctOperator) super.getOperator();
        this.aggtype = aggtype;
    }

    /**
     * Parses an instruction string whose fields are {@code opcode, input, output, aggtype}.
     *
     * <p>Supported opcodes:
     * <ul>
     *   <li>{@code uacd}: COUNT over rows and columns;</li>
     *   <li>{@code uacdap}: KMV over rows and columns;</li>
     *   <li>{@code uacdapr}: KMV per row;</li>
     *   <li>{@code uacdapc}: KMV per column.</li>
     * </ul>
     *
     * @param str serialized instruction
     * @return the parsed instruction
     * @throws NotImplementedException for {@code uacdr} and {@code uacdc}
     * @throws DMLException for any other unrecognized opcode
     * @throws IllegalArgumentException if the aggregation type field is not a valid {@code SparkAggType}
     */
    public static AggregateUnarySketchSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        AggBinaryOp.SparkAggType aggtype = AggBinaryOp.SparkAggType.valueOf(parts[3]);

        CountDistinctOperator cdop;
        switch (opcode) {
            case "uacd":
                cdop = new CountDistinctOperator(CountDistinctOperatorTypes.COUNT, Types.Direction.RowCol,
                    ReduceAll.getReduceAllFnObject(), Hash.HashType.LinearHash);
                break;
            case "uacdr":
                throw new NotImplementedException("uacdr has not been implemented yet");
            case "uacdc":
                throw new NotImplementedException("uacdc has not been implemented yet");
            case "uacdap":
                cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.RowCol,
                    ReduceAll.getReduceAllFnObject(), Hash.HashType.LinearHash);
                break;
            case "uacdapr":
                cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.Row,
                    ReduceCol.getReduceColFnObject(), Hash.HashType.LinearHash);
                break;
            case "uacdapc":
                cdop = new CountDistinctOperator(CountDistinctOperatorTypes.KMV, Types.Direction.Col,
                    ReduceRow.getReduceRowFnObject(), Hash.HashType.LinearHash);
                break;
            default:
                throw new DMLException("Unrecognized opcode: " + opcode);
        }
        return new AggregateUnarySketchSPInstruction(cdop, in1, out, aggtype, opcode, str);
    }

    /** Dispatches to the matrix implementation; any other input data type is not supported yet. */
    @Override
    public void processInstruction(ExecutionContext ec) {
        if (this.input1.getDataType() == Types.DataType.MATRIX) {
            this.processMatrixSketch(ec);
        } else {
            this.processTensorSketch(ec);
        }
    }

    /**
     * Computes the count-distinct aggregate of a blocked matrix RDD.
     *
     * <p>Behavior depends on the aggregation type:
     * <ul>
     *   <li>SINGLE_BLOCK: all block sketches are folded into one, and the result is stored as an
     *       in-memory matrix output.</li>
     *   <li>NONE: each block is sketched and evaluated independently, keeping its index.</li>
     *   <li>MULTI_BLOCK: blocks are re-keyed by the operator's index function, and their sketches are
     *       combined per key before evaluation.</li>
     * </ul>
     *
     * @throws DMLRuntimeException for any other aggregation type
     */
    private void processMatrixSketch(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());

        if (this.aggtype == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            JavaRDD<CorrMatrixBlock> sketches = in.map(new AggregateUnarySketchCreateFunction(this.op));
            // The zero value (0x0 value, no correction) acts as the identity of the union function.
            CorrMatrixBlock merged = sketches.fold(new CorrMatrixBlock(new MatrixBlock()), new AggregateUnarySketchUnionAllFunction(this.op));
            MatrixBlock result = LibMatrixCountDistinct.countDistinctValuesFromSketch(this.op, merged);
            sec.setMatrixOutput(this.output.getName(), result);
            return;
        }

        if (this.aggtype != AggBinaryOp.SparkAggType.NONE && this.aggtype != AggBinaryOp.SparkAggType.MULTI_BLOCK) {
            throw new DMLRuntimeException(String.format("Unsupported aggregation type: %s", this.aggtype));
        }

        JavaPairRDD<MatrixIndexes, CorrMatrixBlock> sketches;
        if (this.aggtype == AggBinaryOp.SparkAggType.NONE) {
            sketches = in.mapValues(new AggregateUnarySketchCreateCombinerFunction(this.op));
        } else {
            JavaPairRDD<MatrixIndexes, MatrixBlock> grouped = in.mapToPair(new RowColGroupingFunction(this.op));
            sketches = grouped.combineByKey(new AggregateUnarySketchCreateCombinerFunction(this.op),
                new AggregateUnarySketchMergeValueFunction(this.op),
                new AggregateUnarySketchMergeCombinerFunction(this.op));
        }
        JavaPairRDD<MatrixIndexes, MatrixBlock> result = sketches.mapValues(new CalculateAggregateSketchFunction(this.op));
        this.updateUnaryAggOutputDataCharacteristics(sec, this.op.indexFn);
        sec.setRDDHandleForVariable(this.output.getName(), result);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
    }

    /** Placeholder for tensor inputs; always throws. */
    private void processTensorSketch(ExecutionContext ec) {
        throw new NotImplementedException("Aggregate sketch instruction for tensors has not been implemented yet.");
    }

    /** Turns a (merged) sketch into its count-distinct result block. */
    private static class CalculateAggregateSketchFunction
    implements Function<CorrMatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 7504873483231717138L;
        private final CountDistinctOperator op;

        public CalculateAggregateSketchFunction(CountDistinctOperator op) {
            this.op = op;
        }

        @Override
        public MatrixBlock call(CorrMatrixBlock arg0) throws Exception {
            return LibMatrixCountDistinct.countDistinctValuesFromSketch(this.op, arg0);
        }
    }

    /** combineByKey merge-combiners step: unions two partial sketches. */
    private static class AggregateUnarySketchMergeCombinerFunction
    implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 172215143740379070L;
        private final CountDistinctOperator op;

        public AggregateUnarySketchMergeCombinerFunction(CountDistinctOperator op) {
            this.op = op;
        }

        @Override
        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
            return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1);
        }
    }

    /** combineByKey merge-value step: sketches a new block and unions it into the accumulator. */
    private static class AggregateUnarySketchMergeValueFunction
    implements Function2<CorrMatrixBlock, MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = -7006864809860460549L;
        private final CountDistinctOperator op;

        public AggregateUnarySketchMergeValueFunction(CountDistinctOperator op) {
            this.op = op;
        }

        @Override
        public CorrMatrixBlock call(CorrMatrixBlock arg0, MatrixBlock arg1) throws Exception {
            CorrMatrixBlock arg1WithCorr = LibMatrixCountDistinct.createSketch(this.op, arg1);
            return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1WithCorr);
        }
    }

    /** combineByKey create-combiner step (also used per block for NONE): sketches one block. */
    private static class AggregateUnarySketchCreateCombinerFunction
    implements Function<MatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = 8997980606986435297L;
        private final CountDistinctOperator op;

        private AggregateUnarySketchCreateCombinerFunction(CountDistinctOperator op) {
            this.op = op;
        }

        @Override
        public CorrMatrixBlock call(MatrixBlock arg0) throws Exception {
            return LibMatrixCountDistinct.createSketch(this.op, arg0);
        }
    }

    /** Re-keys each block by the output block index computed by the operator's index function. */
    private static class RowColGroupingFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = -3456633769452405482L;
        private final CountDistinctOperator _op;

        public RowColGroupingFunction(CountDistinctOperator op) {
            this._op = op;
        }

        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes idxOut = new MatrixIndexes();
            this._op.indexFn.execute(arg0._1(), idxOut);
            // The block itself is passed through unchanged; only its key changes.
            return new Tuple2<>(idxOut, arg0._2());
        }
    }

    /**
     * Fold operator that unions two sketches. A sketch without correction, or one with a 0x0 value,
     * is treated as the identity. The first argument may be modified and returned, which Spark's
     * {@code fold} permits.
     */
    private static class AggregateUnarySketchUnionAllFunction
    implements Function2<CorrMatrixBlock, CorrMatrixBlock, CorrMatrixBlock> {
        private static final long serialVersionUID = -3799519241499062936L;
        private final CountDistinctOperator op;

        public AggregateUnarySketchUnionAllFunction(CountDistinctOperator op) {
            this.op = op;
        }

        @Override
        public CorrMatrixBlock call(CorrMatrixBlock arg0, CorrMatrixBlock arg1) throws Exception {
            if (arg0.getCorrection() == null && arg1.getCorrection() == null) {
                // Two fold zero values meet when a partition is empty: Spark returns the zero value for
                // that partition and merges it on the driver into another zero clone. Their union is
                // still empty, so this is not an error.
                if (isEmptyValue(arg0) && isEmptyValue(arg1)) {
                    return arg0;
                }
                throw new DMLRuntimeException("Corrupt sketch: metadata is missing");
            }
            if (isEmptyValue(arg0) || arg0.getCorrection() == null) {
                arg0.set(arg1.getValue(), arg1.getCorrection());
                return arg0;
            }
            if (isEmptyValue(arg1) || arg1.getCorrection() == null) {
                return arg0;
            }
            return LibMatrixCountDistinct.unionSketch(this.op, arg0, arg1);
        }

        /** Returns true if the sketch's value block exists and has zero rows and zero columns. */
        private static boolean isEmptyValue(CorrMatrixBlock blk) {
            MatrixBlock value = blk.getValue();
            return value != null && value.getNumRows() == 0 && value.getNumColumns() == 0;
        }
    }

    /** Sketches a single block; the block index is ignored because all sketches are folded into one. */
    private static class AggregateUnarySketchCreateFunction
    implements Function<Tuple2<MatrixIndexes, MatrixBlock>, CorrMatrixBlock> {
        private static final long serialVersionUID = 7295176181965491548L;
        private final CountDistinctOperator op;

        public AggregateUnarySketchCreateFunction(CountDistinctOperator op) {
            this.op = op;
        }

        @Override
        public CorrMatrixBlock call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            return LibMatrixCountDistinct.createSketch(this.op, arg0._2());
        }
    }
}

