/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.controlprogram.paramserv;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysml.api.DMLScript;
import org.apache.sysml.parser.DataIdentifier;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.parser.Statement;
import org.apache.sysml.runtime.DMLRuntimeException;
import org.apache.sysml.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysml.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysml.runtime.instructions.cp.ListObject;
import org.apache.sysml.utils.Statistics;

public abstract class ParamServer {
    protected static final Log LOG = LogFactory.getLog(ParamServer.class.getName());
    protected static final boolean ACCRUE_BSP_GRADIENTS = true;
    protected Map<Integer, BlockingQueue<ListObject>> _modelMap;
    private ListObject _model;
    protected ExecutionContext _ec;
    private Statement.PSUpdateType _updateType;
    private FunctionCallCPInstruction _inst;
    private String _outputName;
    private boolean[] _finishedStates;
    private ListObject _accGradients = null;

    protected ParamServer() {
    }

    protected ParamServer(ListObject model, String aggFunc, Statement.PSUpdateType updateType, ExecutionContext ec, int workerNum) {
        this._modelMap = new HashMap<Integer, BlockingQueue<ListObject>>(workerNum);
        IntStream.range(0, workerNum).forEach(i -> this._modelMap.put(i, new ArrayBlockingQueue(1)));
        this._model = model;
        this._ec = ec;
        this._updateType = updateType;
        this._finishedStates = new boolean[workerNum];
        this.setupAggFunc(this._ec, aggFunc);
        this.broadcastModel(true);
    }

    protected void setupAggFunc(ExecutionContext ec, String aggFunc) {
        String[] cfn = ParamservUtils.getCompleteFuncName(aggFunc, "_ps_");
        String ns = cfn[0];
        String fname = cfn[1];
        FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname);
        ArrayList<DataIdentifier> inputs = func.getInputParams();
        ArrayList<DataIdentifier> outputs = func.getOutputParams();
        if (outputs.size() != 1) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the updated model.", aggFunc));
        }
        if (outputs.get(0).getDataType() != Expression.DataType.LIST) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", aggFunc));
        }
        this._outputName = outputs.get(0).getName();
        CPOperand[] boundInputs = (CPOperand[])inputs.stream().map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType())).toArray(CPOperand[]::new);
        ArrayList inputNames = inputs.stream().map(DataIdentifier::getName).collect(Collectors.toCollection(ArrayList::new));
        ArrayList outputNames = outputs.stream().map(DataIdentifier::getName).collect(Collectors.toCollection(ArrayList::new));
        this._inst = new FunctionCallCPInstruction(ns, fname, boundInputs, inputNames, func.getInputParamNames(), outputNames, "aggregate function");
    }

    public abstract void push(int var1, ListObject var2);

    public abstract ListObject pull(int var1);

    public ListObject getResult() {
        return this._model;
    }

    protected synchronized void updateGlobalModel(int workerID, ListObject gradients) {
        try {
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("Successfully pulled the gradients [size:%d kb] of worker_%d.", gradients.getDataSize() / 1024L, workerID));
            }
            switch (this._updateType) {
                case BSP: {
                    this.setFinishedState(workerID);
                    this._accGradients = ParamservUtils.accrueGradients(this._accGradients, gradients, true);
                    if (this.allFinished()) {
                        this.updateGlobalModel(this._accGradients);
                        this._accGradients = null;
                        this.resetFinishedStates();
                        this.broadcastModel(true);
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("Global parameter is broadcasted successfully.");
                        }
                    }
                    break;
                }
                case ASP: {
                    this.updateGlobalModel(gradients);
                    this.broadcastModel(workerID);
                    break;
                }
                default: {
                    throw new DMLRuntimeException("Unsupported update: " + this._updateType.name());
                }
            }
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Aggregation service failed: ", e);
        }
    }

    private void updateGlobalModel(ListObject gradients) {
        Timing tAgg = DMLScript.STATISTICS ? new Timing(true) : null;
        this._model = this.updateLocalModel(this._ec, gradients, this._model);
        if (DMLScript.STATISTICS) {
            Statistics.accPSAggregationTime((long)tAgg.stop());
        }
    }

    protected ListObject updateLocalModel(ExecutionContext ec, ListObject gradients, ListObject model) {
        ec.setVariable("gradients", gradients);
        ec.setVariable("model", model);
        this._inst.processInstruction(ec);
        ListObject newModel = ec.getListObject(this._outputName);
        ParamservUtils.cleanupListObject(ec, "model", newModel.getStatus());
        ParamservUtils.cleanupListObject(ec, "gradients");
        return newModel;
    }

    private boolean allFinished() {
        return !ArrayUtils.contains(this._finishedStates, false);
    }

    private void resetFinishedStates() {
        Arrays.fill(this._finishedStates, false);
    }

    private void setFinishedState(int workerID) {
        this._finishedStates[workerID] = true;
    }

    private void broadcastModel(boolean par) {
        IntStream stream = IntStream.range(0, this._modelMap.size());
        (par ? stream.parallel() : stream).forEach(workerID -> {
            try {
                this.broadcastModel(workerID);
            }
            catch (InterruptedException e) {
                throw new DMLRuntimeException("Paramserv func: some error occurred when broadcasting model", e);
            }
        });
    }

    private void broadcastModel(int workerID) throws InterruptedException {
        Timing tBroad = DMLScript.STATISTICS ? new Timing(true) : null;
        this._modelMap.get(workerID).put(ParamservUtils.copyList(this._model, false));
        if (DMLScript.STATISTICS) {
            Statistics.accPSModelBroadcastTime((long)tBroad.stop());
        }
    }
}

