/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

public class AggregateTernaryCPInstruction
extends ComputationCPInstruction {
    private static final Log LOG = LogFactory.getLog((String)AggregateTernaryCPInstruction.class.getName());

    private AggregateTernaryCPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
        super(CPInstruction.CPType.AggregateTernary, op, in1, in2, in3, out, opcode, istr);
    }

    public static AggregateTernaryCPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase(Opcodes.TAKPM.toString()) || opcode.equalsIgnoreCase(Opcodes.TACKPM.toString())) {
            InstructionUtils.checkNumFields(parts, 5);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[4]);
            int numThreads = Integer.parseInt(parts[5]);
            AggregateTernaryOperator op = InstructionUtils.parseAggregateTernaryOperator(opcode, numThreads);
            return new AggregateTernaryCPInstruction(op, in1, in2, in3, out, opcode, str);
        }
        throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixBlock matBlock1 = ec.getMatrixInput(this.input1.getName());
        MatrixBlock matBlock2 = ec.getMatrixInput(this.input2.getName());
        MatrixBlock matBlock3 = this.input3.isLiteral() ? null : ec.getMatrixInput(this.input3.getName());
        AggregateTernaryOperator ab_op = (AggregateTernaryOperator)this._optr;
        AggregateTernaryCPInstruction.validateInput(matBlock1, matBlock2, matBlock3, ab_op);
        MatrixBlock ret = MatrixBlock.aggregateTernaryOperations(matBlock1, matBlock2, matBlock3, new MatrixBlock(), ab_op, true);
        ec.releaseMatrixInput(this.input1.getName());
        ec.releaseMatrixInput(this.input2.getName());
        if (!this.input3.isLiteral()) {
            ec.releaseMatrixInput(this.input3.getName());
        }
        if (this.output.getDataType().isScalar()) {
            ec.setScalarOutput(this.output.getName(), new DoubleObject(ret.get(0, 0)));
        } else {
            ec.setMatrixOutput(this.output.getName(), ret);
        }
    }

    private static void validateInput(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, AggregateTernaryOperator op) {
        int m3c;
        int m1r = m1.getNumRows();
        int m2r = m2.getNumRows();
        int m3r = m3 == null ? m2r : m3.getNumRows();
        int m1c = m1.getNumColumns();
        int m2c = m2.getNumColumns();
        int n = m3c = m3 == null ? m2c : m3.getNumColumns();
        if (m1r != m2r || m1c != m2c || m2r != m3r || m2c != m3c) {
            if (LOG.isTraceEnabled()) {
                LOG.trace((Object)("matBlock1:" + m1));
                LOG.trace((Object)("matBlock2:" + m2));
                LOG.trace((Object)("matBlock3:" + m3));
            }
            throw new DMLRuntimeException("Invalid dimensions for aggregate ternary (" + m1r + "x" + m1c + ", " + m2r + "x" + m2c + ", " + m3r + "x" + m3c + ").");
        }
        if (!(op.aggOp.increOp.fn instanceof KahanPlus) || !(op.binaryFn instanceof Multiply)) {
            throw new DMLRuntimeException("Unsupported operator for aggregate ternary operations.");
        }
    }
}

