/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.BinaryMatrixScalarSPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

public class BinaryMatrixScalarFEDInstruction
extends BinaryFEDInstruction {
    protected BinaryMatrixScalarFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.Binary, op, in1, in2, out, opcode, istr, fedOut);
    }

    public static BinaryMatrixScalarFEDInstruction parseInstruction(BinaryMatrixScalarCPInstruction instr) {
        return new BinaryMatrixScalarFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output, instr.getOpcode(), instr.getInstructionString(), FEDInstruction.FederatedOutput.NONE);
    }

    public static BinaryMatrixScalarFEDInstruction parseInstruction(BinaryMatrixScalarSPInstruction instr) {
        String instrStr = BinaryMatrixScalarFEDInstruction.rewriteSparkInstructionToCP(instr.getInstructionString());
        String opcode = InstructionUtils.getInstructionPartsWithValueType(instrStr)[0];
        return new BinaryMatrixScalarFEDInstruction(instr.getOperator(), instr.input1, instr.input2, instr.output, opcode, instrStr, FEDInstruction.FederatedOutput.NONE);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        CPOperand matrix = this.input1.isMatrix() ? this.input1 : this.input2;
        CPOperand scalar = this.input2.isScalar() ? this.input2 : this.input1;
        MatrixObject mo = ec.getMatrixObject(matrix);
        FederatedRequest fr1 = !scalar.isLiteral() ? mo.getFedMapping().broadcast(ec.getScalarInput(scalar)) : null;
        FederatedRequest fr2 = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{matrix, fr1 != null ? scalar : null}, new long[]{mo.getFedMapping().getID(), fr1 != null ? fr1.getID() : -1L}, true);
        Future<FederatedResponse>[] ffr = null;
        if (fr1 != null) {
            FederatedRequest fr3 = mo.getFedMapping().cleanup(this.getTID(), fr1.getID());
            ffr = mo.getFedMapping().execute(this.getTID(), true, fr1, fr2, fr3);
        } else {
            ffr = mo.getFedMapping().execute(this.getTID(), true, fr2);
        }
        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().set(mo.getDataCharacteristics()).setNonZeros(FederationUtils.sumNonZeros(ffr));
        out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
    }
}

