/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.gpu;

import java.util.ArrayList;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.lineage.LineageTraceable;
import org.apache.sysds.utils.GPUStatistics;

/**
 * GPU instruction that executes a code-generated ("spoof") operator through its CUDA
 * implementation ({@link SpoofCUDAOperator}).
 *
 * <p>The floating point precision used at execution time is selected globally through
 * {@link #resetFloatingPointPrecision()}, which must have been called before
 * {@link #parseInstruction(String)}.
 */
public class SpoofCUDAInstruction extends GPUInstruction implements LineageTraceable {
    private static final Log LOG = LogFactory.getLog(SpoofCUDAInstruction.class.getName());

    /**
     * Precision-specific dispatcher shared by all instances. It is {@code null} until
     * {@link #resetFloatingPointPrecision()} has been called.
     */
    public static SpoofCUDAOperator.PrecisionProxy proxy = null;

    private final SpoofCUDAOperator _op;
    private final CPOperand[] _in;
    public final CPOperand _out;

    /**
     * (Re)selects {@link #proxy} according to {@code DMLScript.FLOATING_POINT_PRECISION}.
     * The comparison is case-insensitive; the accepted values are "single" and "double".
     *
     * @throws DMLRuntimeException if the configured precision is neither "single" nor
     *         "double" (including when it is {@code null})
     */
    public static void resetFloatingPointPrecision() {
        String precision = DMLScript.FLOATING_POINT_PRECISION;
        // Constant-first comparison: an unset (null) precision yields the descriptive
        // error below instead of a NullPointerException.
        if ("single".equalsIgnoreCase(precision)) {
            proxy = new SinglePrecision();
        } else if ("double".equalsIgnoreCase(precision)) {
            proxy = new DoublePrecision();
        } else {
            throw new DMLRuntimeException("Unsupported floating point precision: " + precision);
        }
    }

    private SpoofCUDAInstruction(SpoofCUDAOperator op, CPOperand[] in, CPOperand out, String opcode, String istr) {
        super(null, opcode, istr);
        this._op = op;
        this._in = in;
        this._out = out;
        this.instString = istr;
        this.instOpcode = opcode;
    }

    /**
     * Parses an instruction string into a {@code SpoofCUDAInstruction}.
     *
     * <p>Layout of the parts as read here: {@code parts[0]} is the base opcode,
     * {@code parts[2]} names the generated operator class, {@code parts[3 .. length-3]} are
     * the input operands and {@code parts[length-2]} is the output operand. {@code parts[1]}
     * and the last part are not read by this method.
     *
     * @param str the instruction string
     * @return the parsed instruction, whose opcode is {@code parts[0] + "CUDA" + spoofType}
     * @throws DMLRuntimeException if {@link #proxy} has not been initialized
     */
    public static SpoofCUDAInstruction parseInstruction(String str) {
        if (proxy == null) {
            throw new DMLRuntimeException("SpoofCUDA Executor has not been initialized");
        }
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // The generated class named in parts[2] provides both the CUDA operator id and the
        // Java operator instance that creates the CUDA operator.
        Integer op_id = CodegenUtils.getCUDAopID(parts[2]);
        Class<?> cla = CodegenUtils.getClass(parts[2]);
        SpoofOperator fallback_java_op = CodegenUtils.createInstance(cla);
        SpoofCUDAOperator op = fallback_java_op.createCUDAInstrcution(op_id, proxy);
        String opcode = parts[0] + "CUDA" + fallback_java_op.getSpoofType();

        ArrayList<CPOperand> inlist = new ArrayList<>();
        for (int i = 3; i < parts.length - 2; ++i) {
            inlist.add(new CPOperand(parts[i]));
        }
        CPOperand out = new CPOperand(parts[parts.length - 2]);
        return new SpoofCUDAInstruction(op, inlist.toArray(new CPOperand[0]), out, opcode, str);
    }

    /**
     * Acquires the matrix inputs on the GPU, reads the scalar inputs, and executes the CUDA
     * operator. The result is written either as a matrix or as a scalar output.
     *
     * <p>Matrix inputs that were acquired are always released, including when the operator
     * fails, so a failing instruction does not leave inputs pinned.
     *
     * @param ec the execution context
     * @throws DMLRuntimeException if the operator fails to execute
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        ArrayList<MatrixObject> inputs = new ArrayList<>();
        ArrayList<ScalarObject> scalars = new ArrayList<>();
        // Names of matrix inputs acquired so far. Exactly these (duplicates included) are
        // released in the finally block.
        ArrayList<String> acquired = new ArrayList<>();
        try {
            for (CPOperand input : _in) {
                if (input.getDataType() == Types.DataType.MATRIX) {
                    inputs.add(ec.getMatrixInputForGPUInstruction(input.getName(), getExtendedOpcode()));
                    acquired.add(input.getName());
                } else if (input.getDataType() == Types.DataType.SCALAR) {
                    scalars.add(ec.getScalarInput(input));
                }
            }
            runOperator(ec, inputs, scalars);
        } finally {
            for (String name : acquired) {
                ec.releaseMatrixInputForGPUInstruction(name);
            }
        }
    }

    /**
     * Runs the CUDA operator and publishes its result according to the output data type.
     * Outputs of any other data type are not written. Any failure is logged and rethrown
     * wrapped in a {@link DMLRuntimeException}.
     */
    private void runOperator(ExecutionContext ec, ArrayList<MatrixObject> inputs, ArrayList<ScalarObject> scalars) {
        try {
            if (_out.getDataType() == Types.DataType.MATRIX) {
                _op.execute(ec, inputs, scalars, _out.getName());
                ec.releaseMatrixOutputForGPUInstruction(_out.getName());
            } else if (_out.getDataType() == Types.DataType.SCALAR) {
                ScalarObject out = _op.execute(ec, inputs, scalars);
                ec.setScalarOutput(_out.getName(), out);
            }
            _op.releaseScalarGPUMemory(ec);
        } catch (Exception ex) {
            // There is no Java fallback yet, so the failure is reported and propagated.
            LOG.error("SpoofCUDAInstruction: " + _op.getName()
                + " operator failed to execute (Java fallback not implemented yet).", ex);
            throw new DMLRuntimeException(ex);
        }
    }

    /**
     * Returns the lineage of this instruction: the output name paired with a lineage item
     * built from this instruction's opcode and the lineage of its inputs.
     */
    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        return Pair.of(_out.getName(), new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, _in)));
    }

    /**
     * Double precision dispatcher. It transfers scalars as {@code double}-sized elements and
     * invokes {@code execute_dp}.
     */
    public static class DoublePrecision extends SpoofCUDAOperator.PrecisionProxy {
        @Override
        public int exec(ExecutionContext ec, SpoofCUDAOperator op, int opID, long[] in, long[] sides, long[] out, ArrayList<ScalarObject> scalarObjects, long grix) {
            if (!scalarObjects.isEmpty()) {
                op.setScalarPtr(this.transferScalars(ec, op, Double.BYTES, scalarObjects));
            }
            // Metadata layout: {opID, grix, #inputs, #sides, #outputs, #scalars}
            long[] _metadata = new long[]{opID, grix, in.length, sides.length, out.length, scalarObjects.size()};
            return op.execute_dp(this.ctx, _metadata, in, sides, out, GPUObject.getPointerAddress(op.getScalarPtr()));
        }
    }

    /**
     * Single precision dispatcher. It transfers scalars as {@code float}-sized elements and
     * invokes {@code execute_sp}.
     */
    public static class SinglePrecision extends SpoofCUDAOperator.PrecisionProxy {
        @Override
        public int exec(ExecutionContext ec, SpoofCUDAOperator op, int opID, long[] in, long[] sides, long[] out, ArrayList<ScalarObject> scalarObjects, long grix) {
            // NOTE(review): unlike DoublePrecision, scalars are transferred even when the list
            // is empty. Confirm whether this asymmetry is intended.
            op.setScalarPtr(this.transferScalars(ec, op, Float.BYTES, scalarObjects));
            // Metadata layout: {opID, grix, #inputs, #sides, #outputs, #scalars}
            long[] _metadata = new long[]{opID, grix, in.length, sides.length, out.length, scalarObjects.size()};
            return op.execute_sp(this.ctx, _metadata, in, sides, out, GPUObject.getPointerAddress(op.getScalarPtr()));
        }
    }
}

