/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.fedplanner;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FederatedMemoTable;
import org.apache.sysds.hops.fedplanner.FederatedMemoTablePrinter;
import org.apache.sysds.hops.fedplanner.FederatedPlanCostEstimator;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Enumerates federated execution plans for a {@link Hop} DAG, selects the cheapest root plan,
 * and then detects and resolves conflicts in the chosen plan tree.
 *
 * <p>For every hop, plan variants are created for both {@link FEDInstruction.FederatedOutput#FOUT}
 * and {@link FEDInstruction.FederatedOutput#LOUT}. There is one variant per combination of the
 * hop's inputs' output types, and all variants are stored in a {@link FederatedMemoTable}.
 */
public class FederatedPlanCostEnumerator {
    /**
     * Largest number of inputs a single hop may have.
     *
     * <p>Input combinations are indexed by an {@code int} bit mask, and Java masks {@code int} shift
     * distances to five bits. As a result, {@code 1 << 31} is negative and {@code 1 << 32 == 1},
     * either of which would silently corrupt the enumeration.
     */
    private static final int MAX_ENUMERABLE_INPUTS = Integer.SIZE - 2;

    /**
     * Entry point for plan enumeration.
     *
     * <p>Builds a fresh memo table, enumerates every plan variant reachable from {@code rootHop},
     * picks the minimum-cost root plan, and then detects and resolves conflicts. A conflict is a
     * hop that different parents require with different federated output types.
     *
     * @param rootHop   root of the Hop DAG to plan
     * @param printTree if {@code true}, print the chosen plan tree, together with the additional
     *                  conflict-resolution cost, via {@link FederatedMemoTablePrinter}
     * @return the minimum-cost plan for {@code rootHop}
     * @throws IllegalStateException if some hop has more than {@value #MAX_ENUMERABLE_INPUTS} inputs
     */
    public static FederatedMemoTable.FedPlan enumerateFederatedPlanCost(Hop rootHop, boolean printTree) {
        FederatedMemoTable memoTable = new FederatedMemoTable();
        enumerateFederatedPlanCost(rootHop, memoTable);
        FederatedMemoTable.FedPlan optimalPlan = getMinCostRootFedPlan(rootHop.getHopID(), memoTable);
        // The extra cost from conflict resolution is only reported through the printer;
        // this method does not return it.
        double additionalTotalCost = detectAndResolveConflictFedPlan(optimalPlan, memoTable);
        if (printTree) {
            FederatedMemoTablePrinter.printFedPlanTree(optimalPlan, memoTable, additionalTotalCost);
        }
        return optimalPlan;
    }

    /**
     * Recursively enumerates and costs all plan variants for {@code hop} and its inputs (bottom-up).
     *
     * <p>Inputs not yet present in the memo table are processed first. Then, for each of the
     * {@code 2^numInputs} combinations of input output types, an FOUT plan and a LOUT plan for
     * {@code hop} are added and costed. Bit {@code j} of the combination index selects FOUT (set)
     * or LOUT (clear) for input {@code j}.
     *
     * @param hop       hop whose plan variants are enumerated
     * @param memoTable memo table receiving the variants
     * @throws IllegalStateException if {@code hop} has more inputs than the bit-mask enumeration supports
     */
    private static void enumerateFederatedPlanCost(Hop hop, FederatedMemoTable memoTable) {
        List<Hop> inputs = hop.getInput();
        int numInputs = inputs.size();
        if (numInputs > MAX_ENUMERABLE_INPUTS) {
            throw new IllegalStateException("Cannot enumerate federated plans for hop " + hop.getHopID()
                + ": " + numInputs + " inputs exceed the supported maximum of " + MAX_ENUMERABLE_INPUTS);
        }

        // Memoization: an input that already has plans (of either output type) is shared, not re-enumerated.
        for (Hop inputHop : inputs) {
            long inputId = inputHop.getHopID();
            if (!memoTable.contains(inputId, FEDInstruction.FederatedOutput.FOUT)
                && !memoTable.contains(inputId, FEDInstruction.FederatedOutput.LOUT)) {
                enumerateFederatedPlanCost(inputHop, memoTable);
            }
        }

        int numCombinations = 1 << numInputs;
        for (int mask = 0; mask < numCombinations; mask++) {
            List<Pair<Long, FEDInstruction.FederatedOutput>> planChilds = new ArrayList<>(numInputs);
            for (int j = 0; j < numInputs; j++) {
                FEDInstruction.FederatedOutput childType = (mask & (1 << j)) != 0
                    ? FEDInstruction.FederatedOutput.FOUT
                    : FEDInstruction.FederatedOutput.LOUT;
                planChilds.add(Pair.of(inputs.get(j).getHopID(), childType));
            }
            // Cost both output types of this hop for the current input combination.
            // Both plans receive the same child list instance, as before.
            FederatedMemoTable.FedPlan fOutPlan = memoTable.addFedPlan(hop, FEDInstruction.FederatedOutput.FOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(fOutPlan, memoTable);
            FederatedMemoTable.FedPlan lOutPlan = memoTable.addFedPlan(hop, FEDInstruction.FederatedOutput.LOUT, planChilds);
            FederatedPlanCostEstimator.computeFederatedPlanCost(lOutPlan, memoTable);
        }

        // NOTE(review): presumably reduces each output type's variants for this hop; confirm in FederatedMemoTable.
        memoTable.pruneFedPlan(hop.getHopID(), FEDInstruction.FederatedOutput.LOUT);
        memoTable.pruneFedPlan(hop.getHopID(), FEDInstruction.FederatedOutput.FOUT);
    }

    /**
     * Selects the cheapest plan for the root hop across its FOUT and LOUT variants.
     *
     * @param HopID     ID of the root hop
     * @param memoTable memo table holding the enumerated variants
     * @return the FOUT plan if it is strictly cheaper, otherwise the LOUT plan (ties go to LOUT)
     * @throws NullPointerException if either output type has no plan variant for {@code HopID}
     */
    private static FederatedMemoTable.FedPlan getMinCostRootFedPlan(long HopID, FederatedMemoTable memoTable) {
        FederatedMemoTable.FedPlan minFOutFedPlan = Objects.requireNonNull(
            findMinCostPlan(memoTable.getFedPlanVariants(HopID, FEDInstruction.FederatedOutput.FOUT)),
            () -> "No FOUT plan variant found for root hop " + HopID);
        FederatedMemoTable.FedPlan minLOutFedPlan = Objects.requireNonNull(
            findMinCostPlan(memoTable.getFedPlanVariants(HopID, FEDInstruction.FederatedOutput.LOUT)),
            () -> "No LOUT plan variant found for root hop " + HopID);
        return minFOutFedPlan.getTotalCost() < minLOutFedPlan.getTotalCost() ? minFOutFedPlan : minLOutFedPlan;
    }

    /** Returns the variant with the lowest total cost, or {@code null} if there are no variants. */
    private static FederatedMemoTable.FedPlan findMinCostPlan(FederatedMemoTable.FedPlanVariants variants) {
        return variants._fedPlanVariants.stream()
            .min(Comparator.comparingDouble(FederatedMemoTable.FedPlan::getTotalCost))
            .orElse(null);
    }

    /**
     * Traverses the chosen plan tree breadth-first and resolves output-type conflicts.
     *
     * <p>The first parent that reaches a child hop fixes the output type expected for that hop. If
     * a later parent references the same hop with a different output type, the hop is recorded as
     * a conflict, mapped to the list of all parents seen for it. When a traversal round finishes,
     * the conflicts are handed to {@link FederatedPlanCostEstimator#resolveConflictFedPlan}. That
     * method returns the plans to traverse in the next round. The process repeats until it returns
     * an empty map.
     *
     * @param rootPlan  plan at which the traversal starts
     * @param memoTable memo table used to look up child plans after pruning
     * @return the additional cost accumulated while resolving conflicts
     */
    private static double detectAndResolveConflictFedPlan(FederatedMemoTable.FedPlan rootPlan, FederatedMemoTable memoTable) {
        // Hop ID -> (output type first required for the hop, every parent plan referencing it).
        // This map is intentionally left uncleared between rounds (matches original behavior).
        HashMap<Long, ImmutablePair<FEDInstruction.FederatedOutput, List<FederatedMemoTable.FedPlan>>> conflictCheckMap =
            new HashMap<>();
        // Hop ID -> parent plans, for hops required with differing output types in the current round.
        LinkedHashMap<Long, List<FederatedMemoTable.FedPlan>> conflictLinkedMap = new LinkedHashMap<>();
        // Insertion-ordered, duplicate-free work queue. Only the keys are used; the Boolean values are ignored.
        LinkedHashMap<FederatedMemoTable.FedPlan, Boolean> bfsLinkedMap = new LinkedHashMap<>();
        bfsLinkedMap.put(rootPlan, true);
        // Out-parameter: resolveConflictFedPlan is the only code that writes element 0.
        double[] cumulativeAdditionalCost = {0.0};

        while (!bfsLinkedMap.isEmpty()) {
            while (!bfsLinkedMap.isEmpty()) {
                FederatedMemoTable.FedPlan currentPlan = bfsLinkedMap.keySet().iterator().next();
                bfsLinkedMap.remove(currentPlan);
                for (Pair<Long, FEDInstruction.FederatedOutput> childPlanPair : currentPlan.getChildFedPlans()) {
                    Long childHopId = childPlanPair.getLeft();
                    FEDInstruction.FederatedOutput childType = childPlanPair.getRight();
                    FederatedMemoTable.FedPlan childFedPlan = memoTable.getFedPlanAfterPrune(childPlanPair);

                    ImmutablePair<FEDInstruction.FederatedOutput, List<FederatedMemoTable.FedPlan>> seen =
                        conflictCheckMap.get(childHopId);
                    if (seen == null) {
                        // First visit: fix the expected output type for this hop and enqueue the child plan.
                        List<FederatedMemoTable.FedPlan> parentFedPlanList = new ArrayList<>();
                        parentFedPlanList.add(currentPlan);
                        conflictCheckMap.put(childHopId, new ImmutablePair<>(childType, parentFedPlanList));
                        bfsLinkedMap.put(childFedPlan, true);
                        continue;
                    }

                    // Repeat visit: always remember this parent as well.
                    seen.getRight().add(currentPlan);
                    if (seen.getLeft() != childType && !conflictLinkedMap.containsKey(childHopId)) {
                        // Record the conflict once per round. The parent list is shared with
                        // conflictCheckMap, so later parents keep being appended to it.
                        conflictLinkedMap.put(childHopId, seen.getRight());
                        // NOTE(review): this removes the plan for the *new* output type. The plan enqueued
                        // on the first visit used the first-seen type. So this only has an effect if the
                        // new-type plan was already queued (e.g. re-queued by resolveConflictFedPlan).
                        // TODO confirm intent.
                        bfsLinkedMap.remove(childFedPlan);
                    }
                }
            }
            bfsLinkedMap = FederatedPlanCostEstimator.resolveConflictFedPlan(memoTable, conflictLinkedMap, cumulativeAdditionalCost);
            conflictLinkedMap.clear();
        }
        return cumulativeAdditionalCost[0];
    }
}

