/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.dml.DmlSyntacticValidator;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.BuiltinNaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.DataConverter;

public class EvalNaryCPInstruction
extends BuiltinNaryCPInstruction {
    /**
     * Creates an eval instruction.
     *
     * @param op     operator of the instruction
     * @param opcode instruction opcode
     * @param istr   instruction string
     * @param output output operand; a matrix variable of this name is read at runtime and used
     *               as the template for the (converted) result
     * @param inputs input operands: the first holds the name of the function to call, the
     *               remaining ones are the arguments passed to it
     */
    public EvalNaryCPInstruction(Operator op, String opcode, String istr, CPOperand output, CPOperand ... inputs) {
        super(op, opcode, istr, output, inputs);
    }

    /**
     * Calls a function whose name is only known at runtime.
     * <p>
     * Steps: (1) resolve the function name, lazily compiling a DML-bodied builtin if needed;
     * (2) if the only argument is a list and the callee does not itself take exactly one list
     * parameter, expand the list into individual arguments; (3) invoke the function;
     * (4) convert a scalar or frame result into a matrix; (5) remove the temporary variables
     * created by the list expansion.
     *
     * @param ec execution context providing the symbol table and program
     * @throws DMLRuntimeException if the function name is namespaced, the function cannot be
     *         compiled, the list arguments do not match the function signature, or the result
     *         cannot be converted to a matrix
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        String funcName = ec.getScalarInput(this.inputs[0]).getStringValue();
        if (funcName.contains("::")) {
            throw new DMLRuntimeException("Eval calls to '" + funcName + "', i.e., a function outside the default namespace, are not supported yet. Please call the function directly.");
        }
        // Arguments are passed by operand (i.e., by variable name) to the function call.
        CPOperand[] boundInputs = Arrays.copyOfRange(this.inputs, 1, this.inputs.length);
        ArrayList<String> boundOutputNames = new ArrayList<String>();
        boundOutputNames.add(this.output.getName());
        // Copy of the output matrix object; used as the target if the result must be converted.
        MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(this.output.getName()));

        funcName = resolveFunctionName(ec.getProgram(), funcName, boundInputs);
        FunctionProgramBlock fpb = ec.getProgram().getFunctionProgramBlock(null, funcName, false);

        CPOperand[] expandedInputs = expandListArgument(ec, boundInputs, fpb);
        if (expandedInputs != null) {
            boundInputs = expandedInputs;
        }

        FunctionCallCPInstruction fcpi = new FunctionCallCPInstruction(null, funcName, false, boundInputs, fpb.getInputParamNames(), boundOutputNames, "eval func");
        fcpi.processInstruction(ec);

        convertOutputToMatrix(ec, outputMO, funcName);

        // Remove the temporary variables that were bound for the expanded list elements.
        if (expandedInputs != null) {
            for (CPOperand op : expandedInputs) {
                VariableCPInstruction.processRmvarInstruction(ec, op.getName());
            }
        }
    }

    /**
     * Resolves the name under which the called function is registered in {@code prog}.
     * <p>
     * If no function program block exists under {@code funcName}, the data-type-specific
     * internal builtin name ({@link Builtins#getInternalFName}) is used instead, and the
     * builtin is compiled first if it is not registered under that name either. The data type
     * is taken from the first argument. List arguments and calls without arguments default
     * to {@link Types.DataType#MATRIX}.
     *
     * @param prog     program holding the function program blocks
     * @param funcName function name as given to eval
     * @param args     call arguments (without the function name)
     * @return the name to look the function program block up with
     */
    private static String resolveFunctionName(Program prog, String funcName, CPOperand[] args) {
        Types.DataType dt1 = (args.length == 0 || args[0].getDataType().isList())
            ? Types.DataType.MATRIX : args[0].getDataType();
        if (prog.containsFunctionProgramBlock(null, funcName)) {
            return funcName;
        }
        String internalName = Builtins.getInternalFName(funcName, dt1);
        if (!prog.containsFunctionProgramBlock(null, internalName)) {
            compileFunctionProgramBlock(funcName, dt1, prog);
        }
        return internalName;
    }

    /**
     * Expands a single list argument into individual function arguments.
     * <p>
     * Expansion only applies if {@code args} is exactly one list and the callee does not
     * itself take exactly one list parameter. The list is validated against the function
     * signature, and named lists are reordered to the signature order. Each element is then
     * bound to a fresh unique variable in the symbol table.
     *
     * @param ec   execution context whose symbol table receives the temporary variables
     * @param args call arguments (without the function name)
     * @param fpb  the function program block to be called
     * @return operands for the expanded arguments, or {@code null} if no expansion applies;
     *         the caller must remove the temporary variables after the call
     */
    private static CPOperand[] expandListArgument(ExecutionContext ec, CPOperand[] args, FunctionProgramBlock fpb) {
        if (args.length != 1 || !args[0].getDataType().isList()) {
            return null;
        }
        if (fpb.getInputParams().size() == 1 && fpb.getInputParams().get(0).getDataType().isList()) {
            // The callee expects the list itself; pass it through unchanged.
            return null;
        }
        ListObject lo = ec.getListObject(args[0]);
        List<String> fArgNames = fpb.getInputParamNames();
        checkValidArguments(lo.getData(), lo.getNames(), fArgNames);
        if (lo.isNamedList()) {
            lo = reorderNamedListForFunctionCall(lo, fArgNames);
        }
        CPOperand[] expanded = new CPOperand[lo.getLength()];
        for (int i = 0; i < lo.getLength(); ++i) {
            Data in = lo.getData(i);
            String varName = Dag.getNextUniqueVarname(in.getDataType());
            ec.getVariables().put(varName, in);
            expanded[i] = new CPOperand(varName, in);
        }
        return expanded;
    }

    /**
     * Ensures the output variable holds a matrix.
     * <p>
     * Matrix results are left untouched. Scalar results are wrapped into a
     * {@link MatrixBlock}. Frame results are converted via {@link DataConverter}. In both
     * conversion cases the block is written into {@code outputMO}, which is then bound to the
     * output variable.
     *
     * @param ec       execution context holding the function result
     * @param outputMO matrix object that receives the converted data
     * @param funcName name of the called function (for error messages only)
     * @throws DMLRuntimeException if the result is of any other type (or missing)
     */
    private void convertOutputToMatrix(ExecutionContext ec, MatrixObject outputMO, String funcName) {
        Data newOutput = ec.getVariable(this.output);
        if (newOutput instanceof MatrixObject) {
            return;
        }
        MatrixBlock mb;
        if (newOutput instanceof ScalarObject) {
            mb = new MatrixBlock(((ScalarObject)newOutput).getDoubleValue());
        } else if (newOutput instanceof FrameObject) {
            FrameObject fo = (FrameObject)newOutput;
            FrameBlock fb = (FrameBlock)fo.acquireRead();
            try {
                mb = DataConverter.convertToMatrixBlock(fb);
            } finally {
                // Pair every acquireRead with a release so the frame is not left pinned.
                fo.release();
            }
            ec.cleanupCacheableData(fo);
        } else {
            String type = (newOutput == null) ? "null" : newOutput.getClass().getSimpleName();
            throw new DMLRuntimeException("Eval of function '" + funcName + "' returned an output of type "
                + type + ", which cannot be converted to a matrix (supported: matrix, scalar, frame).");
        }
        outputMO.acquireModify(mb);
        outputMO.release();
        ec.setVariable(this.output.getName(), outputMO);
    }

    /**
     * Loads, parses, and compiles the DML-bodied builtin {@code name}, and registers every
     * function returned by the loader as a function program block in {@code prog}.
     * <p>
     * Functions already registered (under flag {@code false}) are skipped. Each new block is
     * added under both values of the {@code addFunctionProgramBlock} flag (presumably the
     * optimized and unoptimized function tables — confirm against {@link Program}).
     *
     * @param name builtin function name as given to eval
     * @param dt   data type used to derive the internal builtin name
     * @param prog program to register the compiled functions with
     * @throws DMLRuntimeException if no function could be loaded for {@code name}
     */
    private static void compileFunctionProgramBlock(String name, Types.DataType dt, Program prog) {
        Map<String, FunctionStatementBlock> fsbs = DmlSyntacticValidator.loadAndParseBuiltinFunction(name, ".defaultNS");
        if (fsbs.isEmpty()) {
            throw new DMLRuntimeException("Failed to compile function '" + name + "'.");
        }
        DMLProgram dmlp = prog.getDMLProg();
        if (dmlp == null) {
            // No DML program attached to the runtime program: borrow the one of the parsed builtin.
            String internalName = Builtins.getInternalFName(name, dt);
            FunctionStatementBlock entryFsb = fsbs.get(internalName);
            if (entryFsb == null) {
                throw new DMLRuntimeException("Failed to compile function '" + name
                    + "': parsed builtin does not define '" + internalName + "'.");
            }
            dmlp = entryFsb.getDMLProg();
        }
        for (Map.Entry<String, FunctionStatementBlock> fsb : fsbs.entrySet()) {
            if (!dmlp.getDefaultFunctionDictionary().containsFunction(fsb.getKey())) {
                dmlp.addFunctionStatementBlock(fsb.getKey(), fsb.getValue());
            }
            fsb.getValue().setDMLProg(dmlp);
        }
        DMLTranslator dmlt = new DMLTranslator(dmlp);
        ProgramRewriter rewriter = new ProgramRewriter(true, false);
        ProgramRewriter rewriter2 = new ProgramRewriter(false, true);
        // Analyze and validate all loaded functions before constructing any HOPs.
        for (FunctionStatementBlock functionStatementBlock : fsbs.values()) {
            dmlt.liveVariableAnalysisFunction(dmlp, functionStatementBlock);
            dmlt.validateFunction(dmlp, functionStatementBlock);
        }
        // HOP construction, rewrites, and LOP construction per function.
        for (FunctionStatementBlock functionStatementBlock : fsbs.values()) {
            dmlt.constructHops(functionStatementBlock);
            rewriter.rewriteHopDAGsFunction(functionStatementBlock, false);
            DMLTranslator.resetHopsDAGVisitStatus(functionStatementBlock);
            rewriter.rewriteHopDAGsFunction(functionStatementBlock, true);
            DMLTranslator.resetHopsDAGVisitStatus(functionStatementBlock);
            rewriter2.rewriteHopDAGsFunction(functionStatementBlock, true);
            DMLTranslator.resetHopsDAGVisitStatus(functionStatementBlock);
            HopRewriteUtils.setUnoptimizedFunctionCalls(functionStatementBlock);
            DMLTranslator.resetHopsDAGVisitStatus(functionStatementBlock);
            DMLTranslator.refreshMemEstimates(functionStatementBlock);
            dmlt.constructLops(functionStatementBlock);
        }
        // Create runtime program blocks for all functions not yet registered.
        for (Map.Entry<String, FunctionStatementBlock> entry : fsbs.entrySet()) {
            String fname = entry.getKey();
            if (prog.containsFunctionProgramBlock(null, fname, false)) {
                continue;
            }
            FunctionProgramBlock fpb = (FunctionProgramBlock)dmlt.createRuntimeProgramBlock(prog, (StatementBlock)entry.getValue(), ConfigurationManager.getDMLConfig());
            prog.addFunctionProgramBlock(null, fname, fpb, true);
            prog.addFunctionProgramBlock(null, fname, fpb, false);
        }
    }

    /**
     * Validates that a list argument can be expanded into the given function parameters.
     * <p>
     * The number of list entries must equal the number of parameters. For named lists
     * ({@code loNames != null}), every name must be a parameter and may appear only once.
     *
     * @param loData    list elements
     * @param loNames   list element names, or {@code null} for an unnamed list
     * @param fArgNames function parameter names
     * @throws DMLRuntimeException if the list does not match the function signature
     */
    private static void checkValidArguments(List<Data> loData, List<String> loNames, List<String> fArgNames) {
        int listSize = (loNames != null) ? loNames.size() : loData.size();
        if (listSize != fArgNames.size()) {
            throw new DMLRuntimeException("Failed to expand list for function call (mismatching number of arguments: " + listSize + " vs. " + fArgNames.size() + ").");
        }
        if (loNames != null) {
            HashSet<String> probe = new HashSet<String>(fArgNames);
            HashSet<String> seen = new HashSet<String>();
            for (String var : loNames) {
                if (!probe.contains(var)) {
                    throw new DMLRuntimeException("List argument named '" + var + "' not in function signature.");
                }
                // A duplicate name would leave another parameter unbound after reordering.
                if (!seen.add(var)) {
                    throw new DMLRuntimeException("List argument named '" + var + "' specified multiple times.");
                }
            }
        }
    }

    /**
     * Returns a new named list whose elements follow the order of {@code fArgNames}.
     *
     * @param in        named list to reorder (assumed validated by checkValidArguments)
     * @param fArgNames function parameter names defining the target order
     * @return reordered list named by {@code fArgNames}
     */
    private static ListObject reorderNamedListForFunctionCall(ListObject in, List<String> fArgNames) {
        ArrayList<Data> sortedData = new ArrayList<Data>(fArgNames.size());
        for (String name : fArgNames) {
            sortedData.add(in.getData(name));
        }
        return new ListObject(sortedData, new ArrayList<String>(fArgNames));
    }
}

