/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.lops;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Map;
import org.apache.sysds.common.Opcodes;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.FunctionCallCP;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MMCJ;
import org.apache.sysds.lops.MMRJ;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.lops.MMZip;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.parser.StatementBlock;

public class OperatorOrderingUtils {
    /**
     * Collects every lop reachable from the root lops of a statement block.
     *
     * <p>Lops are gathered in depth-first pre-order (a lop before its inputs). Duplicates are
     * skipped, so a lop shared by several consumers appears only once.
     *
     * @param sb the statement block whose lop DAG is traversed
     * @return the collected lops, or {@code null} if the block has no root lops
     */
    public static ArrayList<Lop> getLopList(StatementBlock sb) {
        ArrayList<Lop> lops = null;
        if (sb.getLops() != null && !sb.getLops().isEmpty()) {
            lops = new ArrayList<Lop>();
            for (Lop root : sb.getLops()) {
                OperatorOrderingUtils.addToLopList(lops, root);
            }
        }
        return lops;
    }

    /**
     * Decides whether a lop should be treated as a DAG root.
     *
     * @param lop the lop to test
     * @return {@code true} if the lop has no consumers, or if it is a {@link FunctionCallCP}
     *         whose namespace equals {@code "_internal"} (case-insensitive)
     */
    public static boolean isLopRoot(Lop lop) {
        if (lop.getOutputs().isEmpty()) {
            return true;
        }
        // Literal-first comparison so a null namespace yields false instead of an NPE.
        return lop instanceof FunctionCallCP
            && "_internal".equalsIgnoreCase(((FunctionCallCP) lop).getFnamespace());
    }

    /**
     * Counts Spark-executed lops below {@code root} and records Spark-triggering lops.
     *
     * <p>The result is memoized per lop ID in {@code sparkOpCount}. A node's total is the sum of
     * its inputs' totals plus one if the node itself executes on Spark. Because memoized totals
     * are summed per input edge, a sub-DAG shared by several inputs contributes once per edge.
     * The result is therefore not a distinct-node count.
     *
     * @param root         the lop at which to start
     * @param sparkOpCount memo table of lop ID to accumulated Spark-op count (mutated)
     * @param sparkRoots   receives every visited lop for which the Spark-trigger check holds
     *                     (mutated)
     * @return the accumulated Spark-op count for {@code root}
     */
    public static int collectSparkRoots(Lop root, Map<Long, Integer> sparkOpCount, HashSet<Lop> sparkRoots) {
        if (sparkOpCount.containsKey(root.getID())) {
            return sparkOpCount.get(root.getID());
        }
        int total = 0;
        for (Lop input : root.getInputs()) {
            total += OperatorOrderingUtils.collectSparkRoots(input, sparkOpCount, sparkRoots);
        }
        if (root.isExecSpark()) {
            total++;
        }
        sparkOpCount.put(root.getID(), total);
        if (OperatorOrderingUtils.isSparkTriggeringOp(root)) {
            sparkRoots.add(root);
        }
        return total;
    }

    /**
     * Counts GPU-executed lops below {@code root} and records lops that need a device-to-host
     * copy.
     *
     * <p>This is the GPU counterpart of {@link #collectSparkRoots}, with the same memoization and
     * per-edge summation behaviour. Recursion stays within this method, so GPU counts and GPU
     * roots are never mixed with Spark criteria.
     *
     * @param root        the lop at which to start
     * @param gpuOpCount  memo table of lop ID to accumulated GPU-op count (mutated)
     * @param gpuRoots    receives every visited lop for which the D2H-copy check holds (mutated)
     * @return the accumulated GPU-op count for {@code root}
     */
    public static int collectGPURoots(Lop root, Map<Long, Integer> gpuOpCount, HashSet<Lop> gpuRoots) {
        if (gpuOpCount.containsKey(root.getID())) {
            return gpuOpCount.get(root.getID());
        }
        int total = 0;
        for (Lop input : root.getInputs()) {
            // Must recurse into the GPU variant; the Spark variant would apply Spark criteria
            // to every input below the root.
            total += OperatorOrderingUtils.collectGPURoots(input, gpuOpCount, gpuRoots);
        }
        if (root.isExecGPU()) {
            total++;
        }
        gpuOpCount.put(root.getID(), total);
        if (OperatorOrderingUtils.isD2HCopyOp(root)) {
            gpuRoots.add(root);
        }
        return total;
    }

    /**
     * Checks whether a lop is a Spark matrix-multiplication operator whose output may be
     * persisted.
     *
     * @param lop the lop to test
     * @return {@code true} if the lop executes on Spark and is a {@link MapMult}, {@link MMCJ},
     *         {@link MMRJ}, {@link MMZip} or {@link WeightedDivMMR}
     */
    public static boolean isPersistableSparkOp(Lop lop) {
        return lop.isExecSpark() && (lop instanceof MapMult || lop instanceof MMCJ || lop instanceof MMRJ || lop instanceof MMZip || lop instanceof WeightedDivMMR);
    }

    /**
     * Checks whether a lop qualifies as a Spark-triggering operator.
     *
     * <p>A lop qualifies if it is not the single input of a prefetch and one of the following
     * holds:
     * <ul>
     *   <li>it is one of the listed Spark operator kinds or conditions;</li>
     *   <li>it is collected for a broadcast;</li>
     *   <li>it is itself a prefetch.</li>
     * </ul>
     */
    private static boolean isSparkTriggeringOp(Lop lop) {
        boolean rightSpLop = lop.isExecSpark() && (lop.getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK || lop.getDataType() == Types.DataType.SCALAR || lop instanceof MapMultChain || lop instanceof PickByCount || lop instanceof MMZip || lop instanceof CentralMoment || lop instanceof CoVariance || lop instanceof MMTSJ || lop.isAllOutputsCP());
        boolean isPrefetched = OperatorOrderingUtils.hasSinglePrefetchConsumer(lop);
        boolean col2Bc = OperatorOrderingUtils.isCollectForBroadcast(lop);
        boolean prefetch = OperatorOrderingUtils.isPrefetchOp(lop);
        return (rightSpLop || col2Bc || prefetch) && !isPrefetched;
    }

    /**
     * Checks whether a lop requires a device-to-host copy.
     *
     * <p>A lop qualifies if it is not a GPU lop whose single consumer is a prefetch, and one of
     * the following holds:
     * <ul>
     *   <li>it runs on the GPU and all its outputs are CP;</li>
     *   <li>it is itself a prefetch.</li>
     * </ul>
     */
    private static boolean isD2HCopyOp(Lop lop) {
        boolean rightGpuLop = lop.isExecGPU() && lop.isAllOutputsCP();
        boolean isPrefetched = lop.isExecGPU() && OperatorOrderingUtils.hasSinglePrefetchConsumer(lop);
        boolean prefetch = OperatorOrderingUtils.isPrefetchOp(lop);
        return (rightGpuLop || prefetch) && !isPrefetched;
    }

    /**
     * Checks whether a Spark matrix lop is consumed only as a broadcast input.
     *
     * <p>{@code allMatch} is vacuously true on an empty stream, so a Spark matrix lop with no
     * outputs also returns {@code true}.
     *
     * @param lop the lop to test
     * @return {@code true} if the lop executes on Spark, produces a matrix, and every consumer
     *         references it as its broadcast input
     */
    public static boolean isCollectForBroadcast(Lop lop) {
        boolean isSparkOp = lop.isExecSpark();
        boolean isBc = lop.getOutputs().stream().allMatch(out -> out.getBroadcastInput() == lop);
        return isSparkOp && isBc && lop.getDataType() == Types.DataType.MATRIX;
    }

    /**
     * Counts, for each lop, how many of the given roots reach it through non-broadcast input
     * edges.
     *
     * <p>Each root increments a lop's entry at most once, because visited flags guard the
     * traversal. {@code root.resetVisitStatus()} is called after each root. NOTE(review): this
     * presumably clears the flags for the whole sub-DAG so the next root can revisit shared
     * lops; confirm in {@code Lop}.
     *
     * @param sparkRoots       roots whose DAGs are traversed
     * @param operatorJobCount lop ID to number of roots reaching it (mutated)
     */
    public static void markSharedSparkOps(HashSet<Lop> sparkRoots, Map<Long, Integer> operatorJobCount) {
        for (Lop root : sparkRoots) {
            OperatorOrderingUtils.collectSharedSparkOps(root, operatorJobCount);
            root.resetVisitStatus();
        }
    }

    /**
     * Visits the DAG below {@code root} once, skipping broadcast-input edges, and increments each
     * visited lop's counter by one.
     */
    private static void collectSharedSparkOps(Lop root, Map<Long, Integer> operatorJobCount) {
        if (root.isVisited()) {
            return;
        }
        for (Lop input : root.getInputs()) {
            // A broadcast input is not part of this job's Spark lineage.
            if (root.getBroadcastInput() == input) {
                continue;
            }
            OperatorOrderingUtils.collectSharedSparkOps(input, operatorJobCount);
        }
        operatorJobCount.merge(root.getID(), 1, Integer::sum);
        root.setVisited();
    }

    /** Returns {@code true} if {@code lop} is a {@link UnaryCP} with the prefetch opcode (case-insensitive). */
    private static boolean isPrefetchOp(Lop lop) {
        return lop instanceof UnaryCP
            && ((UnaryCP) lop).getOpCode().equalsIgnoreCase(Opcodes.PREFETCH.toString());
    }

    /** Returns {@code true} if {@code lop} has exactly one consumer and that consumer is a prefetch. */
    private static boolean hasSinglePrefetchConsumer(Lop lop) {
        return lop.getOutputs().size() == 1 && OperatorOrderingUtils.isPrefetchOp(lop.getOutputs().get(0));
    }

    /**
     * Appends {@code node} unless an equal element is already present.
     *
     * <p>{@code ArrayList.contains} is a linear scan, so building a list of n lops costs O(n^2).
     *
     * @return {@code true} if the node was added
     */
    private static boolean addNode(ArrayList<Lop> lops, Lop node) {
        if (lops.contains(node)) {
            return false;
        }
        lops.add(node);
        return true;
    }

    /** Adds {@code lop} and then, recursively, its inputs (pre-order), stopping at lops already listed. */
    private static void addToLopList(ArrayList<Lop> lops, Lop lop) {
        if (OperatorOrderingUtils.addNode(lops, lop)) {
            for (Lop in : lop.getInputs()) {
                OperatorOrderingUtils.addToLopList(lops, in);
            }
        }
    }
}

