/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatementBlock;

public class IPAPassRemoveUnnecessaryCheckpoints
extends IPAPass {
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return OptimizerUtils.isSparkExecutionMode();
    }

    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        IPAPassRemoveUnnecessaryCheckpoints.removeCheckpointBeforeUpdate(prog);
        IPAPassRemoveUnnecessaryCheckpoints.moveCheckpointAfterUpdate(prog);
        IPAPassRemoveUnnecessaryCheckpoints.removeCheckpointReadWrite(prog);
        return false;
    }

    private static void removeCheckpointBeforeUpdate(DMLProgram dmlp) {
        HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
        for (StatementBlock sb : dmlp.getStatementBlocks()) {
            Iterator cand2;
            HashSet cands = new HashSet(chkpointCand.keySet());
            for (Iterator cand2 : cands) {
                if (!sb.variablesRead().containsVariable((String)((Object)cand2)) || sb.variablesUpdated().containsVariable((String)((Object)cand2))) continue;
                boolean skipRemove = false;
                if (sb.getHops() != null) {
                    Hop.resetVisitStatus(sb.getHops());
                    skipRemove = true;
                    for (Hop root : sb.getHops()) {
                        skipRemove &= !HopRewriteUtils.rContainsRead(root, cand2, false);
                    }
                }
                if (skipRemove) continue;
                chkpointCand.remove(cand2);
            }
            HashSet cands2 = new HashSet(chkpointCand.keySet());
            if (sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
                cand2 = cands2.iterator();
                while (cand2.hasNext()) {
                    String cand3 = (String)cand2.next();
                    if (!sb.variablesUpdated().containsVariable(cand3)) continue;
                    chkpointCand.remove(cand3);
                }
            } else {
                cand2 = cands2.iterator();
                while (cand2.hasNext()) {
                    String cand4 = (String)cand2.next();
                    if (!sb.variablesUpdated().containsVariable(cand4) || sb.getHops() == null) continue;
                    Hop.resetVisitStatus(sb.getHops());
                    for (Hop root : sb.getHops()) {
                        if (!root.getName().equals(cand4) || HopRewriteUtils.rHasSimpleReadChain(root, cand4)) continue;
                        chkpointCand.remove(cand4);
                    }
                }
            }
            if (!HopRewriteUtils.isLastLevelStatementBlock(sb)) continue;
            List<Hop> tmp = IPAPassRemoveUnnecessaryCheckpoints.collectCheckpoints(sb.getHops());
            for (Hop chkpoint : tmp) {
                if (chkpointCand.containsKey(chkpoint.getName())) {
                    ((Hop)chkpointCand.get(chkpoint.getName())).setRequiresCheckpoint(false);
                }
                chkpointCand.put(chkpoint.getName(), chkpoint);
            }
        }
    }

    private static void moveCheckpointAfterUpdate(DMLProgram dmlp) {
        HashMap<String, Hop> chkpointCand = new HashMap<String, Hop>();
        for (StatementBlock sb : dmlp.getStatementBlocks()) {
            Iterator cand2;
            HashSet cands = new HashSet(chkpointCand.keySet());
            for (Iterator cand2 : cands) {
                if (!sb.variablesRead().containsVariable((String)((Object)cand2)) || sb.variablesUpdated().containsVariable((String)((Object)cand2))) continue;
                boolean skipRemove = false;
                if (sb.getHops() != null) {
                    Hop.resetVisitStatus(sb.getHops());
                    skipRemove = true;
                    for (Hop root : sb.getHops()) {
                        skipRemove &= !HopRewriteUtils.rContainsRead(root, cand2, false);
                    }
                }
                if (skipRemove) continue;
                chkpointCand.remove(cand2);
            }
            HashSet cands2 = new HashSet(chkpointCand.keySet());
            if (sb instanceof IfStatementBlock || sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
                cand2 = cands2.iterator();
                while (cand2.hasNext()) {
                    String cand3 = (String)cand2.next();
                    if (!sb.variablesUpdated().containsVariable(cand3)) continue;
                    chkpointCand.remove(cand3);
                }
            } else {
                cand2 = cands2.iterator();
                while (cand2.hasNext()) {
                    String cand4 = (String)cand2.next();
                    if (!sb.variablesUpdated().containsVariable(cand4) || sb.getHops() == null) continue;
                    Hop.resetVisitStatus(sb.getHops());
                    for (Hop root : sb.getHops()) {
                        if (!root.getName().equals(cand4)) continue;
                        if (HopRewriteUtils.rHasSimpleReadChain(root, cand4)) {
                            ((Hop)chkpointCand.get(cand4)).setRequiresCheckpoint(false);
                            root.getInput().get(0).setRequiresCheckpoint(true);
                            chkpointCand.put(cand4, root.getInput().get(0));
                            continue;
                        }
                        chkpointCand.remove(cand4);
                    }
                }
            }
            if (!HopRewriteUtils.isLastLevelStatementBlock(sb)) continue;
            List<Hop> tmp = IPAPassRemoveUnnecessaryCheckpoints.collectCheckpoints(sb.getHops());
            for (Hop chkpoint : tmp) {
                chkpointCand.put(chkpoint.getName(), chkpoint);
            }
        }
    }

    private static void removeCheckpointReadWrite(DMLProgram dmlp) {
        ArrayList<StatementBlock> sbs = dmlp.getStatementBlocks();
        if (!(sbs.size() != 1 || sbs.get(0) instanceof IfStatementBlock || sbs.get(0) instanceof WhileStatementBlock || sbs.get(0) instanceof ForStatementBlock || ((StatementBlock)sbs.get(0)).getHops() == null)) {
            Hop.resetVisitStatus(((StatementBlock)sbs.get(0)).getHops());
            for (Hop root : ((StatementBlock)sbs.get(0)).getHops()) {
                IPAPassRemoveUnnecessaryCheckpoints.rRemoveCheckpointReadWrite(root);
            }
        }
    }

    private static List<Hop> collectCheckpoints(List<Hop> roots) {
        ArrayList<Hop> ret = new ArrayList<Hop>();
        if (roots != null) {
            Hop.resetVisitStatus(roots);
            for (Hop root : roots) {
                IPAPassRemoveUnnecessaryCheckpoints.rCollectCheckpoints(root, ret);
            }
        }
        return ret;
    }

    private static void rCollectCheckpoints(Hop hop, List<Hop> checkpoints) {
        if (hop.isVisited()) {
            return;
        }
        if (hop.requiresCheckpoint() && hop.getParent().size() == 1 && hop.getParent().get(0) instanceof DataOp && ((DataOp)hop.getParent().get(0)).getOp() == Types.OpOpData.TRANSIENTWRITE) {
            checkpoints.add(hop);
        }
        for (Hop c : hop.getInput()) {
            IPAPassRemoveUnnecessaryCheckpoints.rCollectCheckpoints(c, checkpoints);
        }
        hop.setVisited();
    }

    public static void rRemoveCheckpointReadWrite(Hop hop) {
        if (hop.isVisited()) {
            return;
        }
        if (hop instanceof DataOp && ((DataOp)hop).getOp() == Types.OpOpData.PERSISTENTWRITE || hop instanceof AggUnaryOp) {
            Hop c0 = hop.getInput().get(0);
            if (c0.requiresCheckpoint() && c0.getParent().size() == 1 && c0 instanceof DataOp && ((DataOp)c0).getOp() == Types.OpOpData.PERSISTENTREAD) {
                c0.setRequiresCheckpoint(false);
            }
            if (c0 instanceof UnaryOp && c0.getParent().size() == 1 && (((UnaryOp)c0).getOp() == Types.OpOp1.CAST_AS_FRAME || ((UnaryOp)c0).getOp() == Types.OpOp1.CAST_AS_MATRIX) && c0.getInput().get(0).requiresCheckpoint() && c0.getInput().get(0).getParent().size() == 1 && c0.getInput().get(0) instanceof DataOp && ((DataOp)c0.getInput().get(0)).getOp() == Types.OpOpData.PERSISTENTREAD) {
                c0.getInput().get(0).setRequiresCheckpoint(false);
            }
        }
        for (Hop c : hop.getInput()) {
            IPAPassRemoveUnnecessaryCheckpoints.rRemoveCheckpointReadWrite(c);
        }
        hop.setVisited();
    }
}

