/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import org.apache.sysds.common.Opcodes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Control-program instruction for matrix multiplication ({@code Opcodes.MMULT}).
 *
 * <p>Execution is routed to one of three paths depending on the inputs:
 * <ul>
 *   <li>a compressed kernel when either operand is a {@link CompressedMatrixBlock}
 *       (the transpose flags are handed to that kernel directly);</li>
 *   <li>an explicit-transpose path when either transpose flag is set;</li>
 *   <li>a plain uncompressed multiply otherwise.</li>
 * </ul>
 */
public class AggregateBinaryCPInstruction extends BinaryCPInstruction {
    /** Whether the left operand enters the multiplication transposed. */
    public final boolean transposeLeft;
    /** Whether the right operand enters the multiplication transposed. */
    public final boolean transposeRight;

    /** Creates an instruction in which neither operand is transposed. */
    private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        this(op, in1, in2, out, opcode, istr, false, false);
    }

    /** Creates an instruction with explicit transpose flags for the left and right operands. */
    private AggregateBinaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode,
        String istr, boolean transposeLeft, boolean transposeRight) {
        super(CPInstruction.CPType.AggregateBinary, op, in1, in2, out, opcode, istr);
        this.transposeLeft = transposeLeft;
        this.transposeRight = transposeRight;
    }

    /**
     * Parses a serialized matrix-multiply instruction.
     *
     * <p>After the opcode the string carries either 4 fields ({@code in1, in2, out, k}) or
     * 6 fields ({@code in1, in2, out, k, transposeLeft, transposeRight}); validation of the
     * field count is delegated to {@code InstructionUtils.checkNumFields}. {@code k} is
     * forwarded to {@code InstructionUtils.getMatMultOperator}. The transpose flags are read
     * with {@link Boolean#parseBoolean(String)}, so anything other than {@code "true"}
     * (ignoring case) yields {@code false}.
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code MMULT}
     * @throws NumberFormatException if the {@code k} field is not an integer
     */
    public static AggregateBinaryCPInstruction parseInstruction(String str) {
        final String[] fields = InstructionUtils.getInstructionPartsWithValueType(str);
        final String opcode = fields[0];
        if (!opcode.equalsIgnoreCase(Opcodes.MMULT.toString())) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }

        final int fieldCount = InstructionUtils.checkNumFields(fields, 4, 6);
        final CPOperand left = new CPOperand(fields[1]);
        final CPOperand right = new CPOperand(fields[2]);
        final CPOperand result = new CPOperand(fields[3]);
        final int parallelism = Integer.parseInt(fields[4]);
        final AggregateBinaryOperator mmOp = InstructionUtils.getMatMultOperator(parallelism);

        // Short form: no transpose flags were serialized.
        if (fieldCount != 6) {
            return new AggregateBinaryCPInstruction(mmOp, left, right, result, opcode, str);
        }
        final boolean leftT = Boolean.parseBoolean(fields[5]);
        final boolean rightT = Boolean.parseBoolean(fields[6]);
        return new AggregateBinaryCPInstruction(mmOp, left, right, result, opcode, str, leftT, rightT);
    }

    /**
     * Acquires both matrix inputs and dispatches to the compressed, transposed or plain
     * multiplication path. Every path releases the inputs and sets the output itself.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        final MatrixBlock left = ec.getMatrixInput(input1.getName());
        final MatrixBlock right = ec.getMatrixInput(input2.getName());
        final boolean leftCompressed = left instanceof CompressedMatrixBlock;
        final boolean rightCompressed = right instanceof CompressedMatrixBlock;

        // Compression is checked first; the compressed kernel receives the transpose flags directly.
        if (leftCompressed || rightCompressed) {
            processCompressedAggregateBinary(ec, left, right, leftCompressed, rightCompressed);
            return;
        }
        if (transposeLeft || transposeRight) {
            processTransposedFusedAggregateBinary(ec, left, right);
            return;
        }
        processNormal(ec, left, right);
    }

    /** Multiplies two uncompressed, untransposed blocks, then releases inputs and sets the output. */
    private void processNormal(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2) {
        final AggregateBinaryOperator mmOp = (AggregateBinaryOperator) _optr;
        final MatrixBlock product = matBlock1.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), mmOp);
        releaseInputsAndSetOutput(ec, product);
    }

    /**
     * Materializes the requested transposes and multiplies the (possibly transposed) blocks.
     * An input that gets transposed is released immediately after its transposed copy is built;
     * the remaining inputs are released only after the multiplication.
     */
    private void processTransposedFusedAggregateBinary(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2) {
        final AggregateBinaryOperator mmOp = (AggregateBinaryOperator) _optr;
        final MatrixBlock lhs = transposeLeft ? transposeAndRelease(ec, input1, matBlock1, mmOp) : matBlock1;
        final MatrixBlock rhs = transposeRight ? transposeAndRelease(ec, input2, matBlock2, mmOp) : matBlock2;

        final MatrixBlock product = lhs.aggregateBinaryOperations(lhs, rhs, new MatrixBlock(), mmOp);

        // Transposed inputs were already released inside transposeAndRelease.
        if (!transposeLeft) {
            ec.releaseMatrixInput(input1.getName());
        }
        if (!transposeRight) {
            ec.releaseMatrixInput(input2.getName());
        }
        ec.setMatrixOutput(output.getName(), product);
    }

    /**
     * Delegates to the compressed multiplication kernel of whichever operand is compressed
     * (the left one if both are), passing along the transpose flags.
     * {@code c2} is not read here: the caller only invokes this when {@code c1 || c2}.
     */
    private void processCompressedAggregateBinary(ExecutionContext ec, MatrixBlock matBlock1, MatrixBlock matBlock2, boolean c1, boolean c2) {
        final AggregateBinaryOperator mmOp = (AggregateBinaryOperator) _optr;
        final CompressedMatrixBlock compressed = (CompressedMatrixBlock) (c1 ? matBlock1 : matBlock2);
        final MatrixBlock product = compressed.aggregateBinaryOperations(matBlock1, matBlock2, new MatrixBlock(), mmOp,
            transposeLeft, transposeRight);
        releaseInputsAndSetOutput(ec, product);
    }

    /** Transposes {@code block} with the operator's thread count, then releases the input bound to {@code operand}. */
    private static MatrixBlock transposeAndRelease(ExecutionContext ec, CPOperand operand, MatrixBlock block,
        AggregateBinaryOperator mmOp) {
        final MatrixBlock transposed = LibMatrixReorg.transpose(block, mmOp.getNumThreads());
        ec.releaseMatrixInput(operand.getName());
        return transposed;
    }

    /** Releases both matrix inputs (left, then right) and binds {@code result} to the output variable. */
    private void releaseInputsAndSetOutput(ExecutionContext ec, MatrixBlock result) {
        ec.releaseMatrixInput(input1.getName());
        ec.releaseMatrixInput(input2.getName());
        ec.setMatrixOutput(output.getName(), result);
    }
}

