/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;

/**
 * Runtime program block representing the body of a DML function.
 *
 * <p>Holds the function's child program blocks together with copies of the declared input and
 * output parameters. Executing this block optionally recompiles the body once, runs every child
 * block in order, and finally checks the produced output variables against the declared outputs
 * (mismatches are logged, not thrown).
 */
public class FunctionProgramBlock extends ProgramBlock implements Types.FunctionBlock {
    public String _functionName;
    public String _namespace;
    protected ArrayList<ProgramBlock> _childBlocks = new ArrayList<>();
    protected ArrayList<DataIdentifier> _inputParams = new ArrayList<>();
    protected ArrayList<DataIdentifier> _outputParams;
    private boolean _recompileOnce = false;
    private boolean _nondeterministic = false;

    /**
     * Creates a function program block.
     *
     * @param prog the owning program (passed to the {@link ProgramBlock} constructor)
     * @param inputParams declared inputs; each entry is copied via {@code new DataIdentifier(id)}
     * @param outputParams declared outputs; each entry is copied via {@code new DataIdentifier(id)}
     */
    public FunctionProgramBlock(Program prog, List<DataIdentifier> inputParams, List<DataIdentifier> outputParams) {
        super(prog);
        for (DataIdentifier id : inputParams) {
            _inputParams.add(new DataIdentifier(id));
        }
        _outputParams = new ArrayList<>(outputParams.size());
        for (DataIdentifier id : outputParams) {
            _outputParams.add(new DataIdentifier(id));
        }
    }

    /**
     * Looks up a declared input parameter by name.
     *
     * @param name the parameter name to search for
     * @return the first input parameter with that name, or {@code null} if there is none
     */
    public DataIdentifier getInputParam(String name) {
        return _inputParams.stream()
            .filter(d -> d.getName().equals(name))
            .findFirst()
            .orElse(null);
    }

    /** Returns the names of the declared input parameters, in declaration order, as a new list. */
    public List<String> getInputParamNames() {
        return _inputParams.stream().map(DataIdentifier::getName).collect(Collectors.toList());
    }

    /** Returns the names of the declared output parameters, in declaration order, as a new list. */
    public List<String> getOutputParamNames() {
        return _outputParams.stream().map(DataIdentifier::getName).collect(Collectors.toList());
    }

    /** Returns the internal (mutable) list of declared input parameters. */
    public ArrayList<DataIdentifier> getInputParams() {
        return _inputParams;
    }

    /** Returns the internal (mutable) list of declared output parameters. */
    public ArrayList<DataIdentifier> getOutputParams() {
        return _outputParams;
    }

    /** Appends a child program block to the function body. */
    public void addProgramBlock(ProgramBlock childBlock) {
        _childBlocks.add(childBlock);
    }

    /** Replaces the function body's child blocks with the given list (stored by reference). */
    public void setChildBlocks(ArrayList<ProgramBlock> pbs) {
        _childBlocks = pbs;
    }

    @Override
    public ArrayList<ProgramBlock> getChildBlocks() {
        return _childBlocks;
    }

    @Override
    public boolean isNested() {
        return true;
    }

    /**
     * Executes the function body.
     *
     * <p>If dynamic recompilation is enabled and this block is marked recompile-once, the child
     * blocks are first recompiled against a clone of the current variable map. The reset type is
     * {@code RESET_KNOWN_DIMS} when codegen is enabled or the global exec mode is single-node,
     * and {@code RESET} otherwise. When statistics are enabled, the recompilation time and count
     * are recorded. The child blocks are then executed in order with {@code ec}. Finally, the
     * declared outputs are checked by {@link #checkOutputParameters(LocalVariableMap)}.
     *
     * @param ec the execution context passed to every child block
     * @throws DMLRuntimeException if recompilation fails, or if a child block throws anything
     *         other than a {@link DMLScriptException} (which is rethrown unchanged)
     */
    @Override
    public void execute(ExecutionContext ec) {
        try {
            if (ConfigurationManager.isDynamicRecompilation() && isRecompileOnce()) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
                // Recompile against a copy of the variable map rather than the live one.
                LocalVariableMap tmp = (LocalVariableMap) ec.getVariables().clone();
                boolean codegen = ConfigurationManager.isCodegenEnabled();
                boolean singlenode = DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE;
                Recompiler.ResetType reset = (codegen || singlenode)
                    ? Recompiler.ResetType.RESET_KNOWN_DIMS
                    : Recompiler.ResetType.RESET;
                Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false, reset);
                if (DMLScript.STATISTICS) {
                    long t1 = System.nanoTime();
                    Statistics.incrementFunRecompileTime(t1 - t0);
                    Statistics.incrementFunRecompiles();
                }
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException("Error recompiling function body.", ex);
        }

        try {
            // Index-based on purpose: re-reads size() each iteration rather than using an iterator.
            for (int i = 0; i < _childBlocks.size(); i++) {
                _childBlocks.get(i).execute(ec);
            }
        }
        catch (DMLScriptException e) {
            // Script-level errors propagate unchanged so their message reaches the user as-is.
            throw e;
        }
        catch (Exception e) {
            throw new DMLRuntimeException(printBlockErrorLocation() + "Error evaluating function program block", e);
        }

        checkOutputParameters(ec.getVariables());
    }

    /**
     * Checks that every declared output exists in {@code vars} with the declared data type and,
     * unless the declared value type is {@code UNKNOWN}, the declared value type.
     *
     * <p>This only logs. A missing output is logged at error level, and a data-type or value-type
     * mismatch is logged at warn level. No exception is thrown. At most one message is logged
     * per output.
     *
     * @param vars the variable map to inspect
     */
    protected void checkOutputParameters(LocalVariableMap vars) {
        for (DataIdentifier diOut : _outputParams) {
            String varName = diOut.getName();
            Data dat = vars.get(varName);
            if (dat == null) {
                LOG.error("Function output " + varName + " is missing.");
            }
            else if (dat.getDataType() != diOut.getDataType()) {
                LOG.warn("Function output " + varName + " has wrong data type: " + dat.getDataType() + ".");
            }
            else if (diOut.getValueType() != Types.ValueType.UNKNOWN
                && dat.getValueType() != diOut.getValueType()) {
                LOG.warn("Function output " + varName + " has wrong value type: " + dat.getValueType() + ".");
            }
        }
    }

    /** Sets whether the function body is recompiled at the start of {@link #execute}. */
    public void setRecompileOnce(boolean flag) {
        _recompileOnce = flag;
    }

    /** Returns whether the function body is recompiled at the start of {@link #execute}. */
    public boolean isRecompileOnce() {
        return _recompileOnce;
    }

    /** Sets the nondeterministic flag; this class only stores it. */
    public void setNondeterministic(boolean flag) {
        _nondeterministic = flag;
    }

    /** Returns the nondeterministic flag set via {@link #setNondeterministic(boolean)}. */
    public boolean isNondeterministic() {
        return _nondeterministic;
    }

    /**
     * Returns a deep copy of this block, produced by
     * {@code ProgramConverter.createDeepCopyFunctionProgramBlock} with two fresh empty sets.
     */
    @Override
    public Types.FunctionBlock cloneFunctionBlock() {
        return ProgramConverter.createDeepCopyFunctionProgramBlock(this, new HashSet<>(), new HashSet<>());
    }

    @Override
    public String printBlockErrorLocation() {
        return "ERROR: Runtime error in function program block generated from function statement block between lines "
            + _beginLine + " and " + _endLine + " -- ";
    }
}

