/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.Arrays;
import java.util.Collections;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.And;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;

/**
 * Federated execution of the {@code ctable} instruction.
 *
 * <p>One input that is federated drives the request. The remaining (local) matrix inputs are
 * sliced along its federated ranges and broadcast to the workers. Each worker runs the ctable
 * instruction locally.
 *
 * <p>The partial results are then handled in one of two ways:
 * <ul>
 *   <li>If they occupy disjoint row blocks, they are kept on the workers as a new federated
 *   output.</li>
 *   <li>Otherwise they are pulled back and summed locally.</li>
 * </ul>
 */
public class CtableFEDInstruction
extends ComputationFEDInstruction {
    /** Scalar operand (FP64) for the first output dimension, as given in the instruction. */
    private final CPOperand _outDim1;
    /** Scalar operand (FP64) for the second output dimension, as given in the instruction. */
    private final CPOperand _outDim2;

    /**
     * Creates the instruction; {@code isExpand} and {@code ignoreZeros} are accepted but not
     * stored by this class.
     */
    private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, boolean dim2Literal, boolean isExpand, boolean ignoreZeros, String opcode, String istr) {
        super(FEDInstruction.FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
        this._outDim1 = new CPOperand(outputDim1, Types.ValueType.FP64, Types.DataType.SCALAR, dim1Literal);
        this._outDim2 = new CPOperand(outputDim2, Types.ValueType.FP64, Types.DataType.SCALAR, dim2Literal);
    }

    /**
     * Parses a {@code ctable} instruction string.
     *
     * <p>Expected fields after the opcode: in1, in2, in3, outDim1 ({@code name·isLiteral}),
     * outDim2 ({@code name·isLiteral}), out, ignoreZeros.
     *
     * @param inst the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ctable}
     */
    public static CtableFEDInstruction parseInstruction(String inst) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(inst);
        InstructionUtils.checkNumFields(parts, 7);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ctable")) {
            throw new DMLRuntimeException("Unexpected opcode in CtableFEDInstruction: " + inst);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        String[] dim1Fields = parts[4].split("\u00b7");
        String[] dim2Fields = parts[5].split("\u00b7");
        CPOperand out = new CPOperand(parts[6]);
        boolean ignoreZeros = Boolean.parseBoolean(parts[7]);
        return new CtableFEDInstruction(in1, in2, in3, out, dim1Fields[0], Boolean.parseBoolean(dim1Fields[1]), dim2Fields[0], Boolean.parseBoolean(dim2Fields[1]), false, ignoreZeros, opcode, inst);
    }

    /**
     * Chooses the federated input that drives the request and computes the per-worker output
     * dimensions. It then delegates the actual federated execution to {@link #processRequest}.
     *
     * @throws DMLRuntimeException if none of the matrix inputs is federated
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        // The un-swapped inputs define the output shape:
        // rows come from the values of input1, columns from the values of input2.
        MatrixObject in1 = ec.getMatrixObject(this.input1);
        MatrixObject in2 = ec.getMatrixObject(this.input2);
        MatrixObject mo3 = this.input3 != null && this.input3.isMatrix() ? ec.getMatrixObject(this.input3) : null;

        // mo1 is the federated driver; mo2 (and mo3, if present) are broadcast along its ranges.
        MatrixObject mo1 = in1;
        MatrixObject mo2 = in2;
        boolean reversed = false;
        if (!mo1.isFederated() && mo2.isFederated()) {
            mo1 = in2;
            mo2 = in1;
            reversed = true;
        }
        // Only the weights are federated: they drive the request and input1 is broadcast as mo3.
        boolean reversedWeights = mo3 != null && mo3.isFederated() && !mo1.isFederated() && !mo2.isFederated();
        if (reversedWeights) {
            MatrixObject weights = mo3;
            mo3 = mo1;
            mo1 = weights;
        }
        if (!mo1.isFederated()) {
            throw new DMLRuntimeException("Federated ctable requires a federated input: " + this.instString);
        }

        FederatedRange[] ranges = mo1.getFedMapping().getFederatedRanges();
        Long[] dims1 = this.getOutputDimension(in1, this.input1, this._outDim1, ranges);
        Long[] dims2 = this.getOutputDimension(in2, this.input2, this._outDim2, ranges);

        long dim1 = maxOf(dims1);
        // A federated output is only attempted when the row count splits evenly over the workers
        // and every worker reports a distinct row maximum.
        boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0L
            && dims1.length == Arrays.stream(dims1).distinct().count();
        this.processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, fedOutput, dims1, dims2);
    }

    /**
     * Broadcasts the local inputs and executes ctable on every worker.
     *
     * <p>The result is either kept federated or aggregated locally.
     *
     * @param mo1 federated driver input (input1, input2 if {@code reversed}, or input3 if
     *            {@code reversedWeights})
     * @param mo2 input broadcast as the other ctable operand
     * @param mo3 third matrix operand to broadcast, or {@code null}. This is input3 normally, or
     *            input1 when {@code reversedWeights}.
     */
    private void processRequest(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3, boolean reversed, boolean reversedWeights, boolean fedOutput, Long[] dims1, Long[] dims2) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest[] frBroadcast2 = fedMap.broadcastSliced(mo2, false);
        FederatedRequest frCtable;
        Future<FederatedResponse>[] ffr;
        if (mo3 == null) {
            long[] ids = reversed
                ? new long[]{frBroadcast2[0].getID(), fedMap.getID()}
                : new long[]{fedMap.getID(), frBroadcast2[0].getID()};
            frCtable = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2}, ids);
            FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCtable.getID());
            FederatedRequest frClean = fedMap.cleanup(this.getTID(), frBroadcast2[0].getID());
            ffr = fedMap.execute(this.getTID(), true, frBroadcast2, new FederatedRequest[]{frCtable, frGet, frClean});
        } else {
            FederatedRequest[] frBroadcast3 = fedMap.broadcastSliced(mo3, false);
            long[] ids;
            if (reversedWeights) {
                // mo1 holds input3, mo2 holds input2, mo3 holds input1
                ids = new long[]{frBroadcast3[0].getID(), frBroadcast2[0].getID(), fedMap.getID()};
            } else if (reversed) {
                // mo1 holds input2, mo2 holds input1, mo3 holds input3
                ids = new long[]{frBroadcast2[0].getID(), fedMap.getID(), frBroadcast3[0].getID()};
            } else {
                ids = new long[]{fedMap.getID(), frBroadcast2[0].getID(), frBroadcast3[0].getID()};
            }
            frCtable = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2, this.input3}, ids);
            FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCtable.getID());
            FederatedRequest frClean = fedMap.cleanup(this.getTID(), frBroadcast2[0].getID(), frBroadcast3[0].getID());
            ffr = fedMap.execute(this.getTID(), true, frBroadcast2, frBroadcast3, new FederatedRequest[]{frCtable, frGet, frClean});
        }
        if (fedOutput && this.isFedOutput(ffr, dims1)) {
            MatrixObject out = ec.getMatrixObject(this.output);
            // Rewrite the ranges of a copy so the input's federation map stays untouched.
            FederationMap newFedMap = modifyFedRanges(fedMap.copyWithNewID(frCtable.getID()), dims1, dims2);
            setFedOutput(mo1, out, newFedMap, dims1, dims2, frCtable.getID());
        } else {
            ec.setMatrixOutput(this.output.getName(), aggResult(ffr));
        }
    }

    /**
     * Heuristically checks whether the per-worker ctable results occupy disjoint row blocks.
     *
     * <p>When they do, the results can stay federated. The response blocks are only read,
     * never modified, because they are reused by local aggregation.
     *
     * @return {@code true} if a federated output is safe. Returns {@code false} if not, or if
     *         the check itself fails; in that case local aggregation is used, which is always
     *         correct.
     */
    boolean isFedOutput(Future<FederatedResponse>[] ffr, Long[] dims1) {
        long fedSize = maxOf(dims1) / (long)ffr.length;
        try {
            MatrixBlock prev = (MatrixBlock)ffr[0].get().getData()[0];
            BinaryOperator and = new BinaryOperator(And.getAndFnObject());
            for (int i = 1; i < ffr.length; ++i) {
                MatrixBlock curr = (MatrixBlock)ffr[i].get().getData()[0];
                MatrixBlock sliced = curr.slice((int)((long)curr.getNumRows() - fedSize), curr.getNumRows() - 1);
                // shape/nnz pattern already proves there is no overlap with earlier rows
                if (curr.getNumRows() == (i + 1) * prev.getNumRows()
                    && curr.getNonZeros() <= prev.getLength()
                    && curr.getNumRows() - sliced.getNumRows() == i * prev.getNumRows()
                    && curr.getNonZeros() - sliced.getNonZeros() == 0L) {
                    continue;
                }
                // pad prev to curr's shape and intersect; not in place, curr must stay intact
                MatrixBlock prevExtend = new MatrixBlock(curr.getNumRows(), curr.getNumColumns(), 0.0);
                prevExtend.copy(0, prev.getNumRows() - 1, 0, prev.getNumColumns() - 1, prev, true);
                MatrixBlock intersect = (MatrixBlock)curr.binaryOperations(and, prevExtend, new MatrixBlock());
                if (intersect.getNonZeros() != 0L) {
                    return false;
                }
                prev = sliced;
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        catch (Exception e) {
            // Conservative fallback: a failed check must never be treated as "fed output ok".
            return false;
        }
        return true;
    }

    /**
     * Installs {@code fedMap} (re-IDed to {@code outId}) as the output's federation map and
     * sets the output's characteristics.
     *
     * <p>Each worker is then asked to keep only its last {@code fedSize} result rows.
     */
    private static void setFedOutput(MatrixObject mo1, MatrixObject out, FederationMap fedMap, Long[] dims1, Long[] dims2, long outId) {
        long d1 = maxOf(dims1);
        long d2 = maxOf(dims2);
        long fedSize = d1 / (long)dims1.length;
        // the output nnz is not derivable from the inputs here, so mark it unknown (-1)
        out.getDataCharacteristics().set(d1, d2, (int)mo1.getBlocksize(), -1L);
        out.setFedMapping(fedMap.copyWithNewID(outId));
        long varID = FederationUtils.getNextFedDataID();
        out.getFedMapping().mapParallel(varID, (range, data) -> {
            try {
                FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new SliceOutput(data.getVarID(), fedSize))).get();
                if (!response.isSuccessful()) {
                    response.throwExceptionFromResponse();
                }
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
            return null;
        });
    }

    /**
     * Sums all per-worker ctable results into one local block.
     *
     * <p>The block is grown to the largest row/column counts seen so far.
     *
     * @throws DMLRuntimeException if any worker response cannot be retrieved
     */
    private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
        BinaryOperator plus = InstructionUtils.parseBinaryOperator("+");
        MatrixBlock resultBlock = new MatrixBlock(1, 1, 0L);
        int dim1 = 0;
        int dim2 = 0;
        for (Future<FederatedResponse> response : ffr) {
            MatrixBlock mb = (MatrixBlock)getResponseData(response);
            dim1 = Math.max(dim1, mb.getNumRows());
            dim2 = Math.max(dim2, mb.getNumColumns());
            MatrixBlock prev = new MatrixBlock(dim1, dim2, 0.0);
            prev.copy(0, resultBlock.getNumRows() - 1, 0, resultBlock.getNumColumns() - 1, resultBlock, true);
            MatrixBlock next = new MatrixBlock(dim1, dim2, 0.0);
            next.copy(0, mb.getNumRows() - 1, 0, mb.getNumColumns() - 1, mb, true);
            resultBlock = prev.binaryOperationsInPlace(plus, next);
        }
        return resultBlock;
    }

    /**
     * Rewrites the ranges of {@code fedMap} IN PLACE so that they describe the ctable output.
     *
     * <p>Rows of range {@code i} span {@code [end of range i-1, dims1[i])}. Columns span
     * {@code [0, dims2[i])}. Callers must pass a map they own, i.e. a copy.
     *
     * @return the same (mutated) {@code fedMap}
     */
    private static FederationMap modifyFedRanges(FederationMap fedMap, Long[] dims1, Long[] dims2) {
        FederatedRange[] ranges = fedMap.getFederatedRanges();
        for (int i = 0; i < ranges.length; i++) {
            ranges[i].setBeginDim(0, i == 0 ? 0L : ranges[i - 1].getEndDims()[0]);
            ranges[i].setEndDim(0, dims1[i]);
            ranges[i].setBeginDim(1, i == 0 ? 0L : ranges[i - 1].getBeginDims()[1]);
            ranges[i].setEndDim(1, dims2[i]);
        }
        return fedMap;
    }

    /**
     * Computes the maximum value of {@code in} per federated range.
     *
     * <p>For a local input, the matrix is sliced by {@code federatedRanges}. For a federated
     * input, a {@code uamax} instruction is executed on each worker.
     */
    private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, CPOperand outOp, FederatedRange[] federatedRanges) {
        Long[] fedDims = new Long[federatedRanges.length];
        if (!in.isFederated()) {
            MatrixBlock mb = (MatrixBlock)in.acquireReadAndRelease();
            IntStream.range(0, federatedRanges.length).forEach(i -> {
                MatrixBlock sliced = mb.slice(federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] - 1);
                fedDims[i] = (long)sliced.max();
            });
            return fedDims;
        }
        String maxInstString = this.constructMaxInstString(inOp.getName(), outOp.getName());
        FederationMap map = in.getFedMapping();
        FederatedRequest fr1 = FederationUtils.callInstruction(maxInstString, outOp, new CPOperand[]{inOp}, new long[]{map.getID()});
        FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
        FederatedRequest fr3 = map.cleanup(this.getTID(), fr1.getID());
        Future<FederatedResponse>[] tmp = map.execute(this.getTID(), fr1, fr2, fr3);
        return computeOutputDims(tmp);
    }

    /**
     * Extracts the scalar result of every worker as a long.
     *
     * @throws DMLRuntimeException if any response cannot be retrieved, instead of leaving a
     *         {@code null} entry
     */
    private static Long[] computeOutputDims(Future<FederatedResponse>[] tmp) {
        Long[] fedDims = new Long[tmp.length];
        for (int i = 0; i < tmp.length; ++i) {
            fedDims[i] = ((ScalarObject)getResponseData(tmp[i])).getLongValue();
        }
        return fedDims;
    }

    /** Waits for a federated response and returns its first data object, or throws. */
    private static Object getResponseData(Future<FederatedResponse> response) {
        try {
            return response.get().getData()[0];
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /** Returns the largest value of {@code dims}. */
    private static long maxOf(Long[] dims) {
        return Collections.max(Arrays.asList(dims), Long::compare);
    }

    /**
     * Builds a {@code uamax} instruction string that reuses this instruction's first two fields
     * (exec type and, after replacement, the opcode).
     *
     * <p>It reads matrix {@code in} and writes scalar {@code out}.
     */
    private String constructMaxInstString(String in, String out) {
        String maxInstrString = this.instString.replace("ctable", "uamax");
        String[] instParts = maxInstrString.split("\u00b0");
        // NOTE(review): trailing "16" is presumably the thread count of uamax — confirm
        CharSequence[] maxInstParts = new String[]{instParts[0], instParts[1], InstructionUtils.concatOperandParts(in, Types.DataType.MATRIX.name(), Types.ValueType.FP64.name()), InstructionUtils.concatOperandParts(out, Types.DataType.SCALAR.name(), Types.ValueType.FP64.name()), "16"};
        return String.join((CharSequence)"\u00b0", maxInstParts);
    }

    /** Worker-side UDF that replaces a matrix variable by its last {@code _fedSize} rows. */
    private static class SliceOutput
    extends FederatedUDF {
        private static final long serialVersionUID = -2808597461054603816L;
        private final long _fedSize;

        protected SliceOutput(long input, long fedSize) {
            super(new long[]{input});
            this._fedSize = fedSize;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject mo = (MatrixObject)data[0];
            MatrixBlock mb = (MatrixBlock)mo.acquireReadAndRelease();
            MatrixBlock sliced = mb.slice((int)((long)mb.getNumRows() - this._fedSize), mb.getNumRows() - 1);
            mo.acquireModify(sliced);
            mo.release();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[0]);
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

