/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.hops.rewrite;

import java.util.ArrayList;
import org.apache.sysml.conf.ConfigurationManager;
import org.apache.sysml.hops.BinaryOp;
import org.apache.sysml.hops.DataOp;
import org.apache.sysml.hops.Hop;
import org.apache.sysml.hops.LiteralOp;
import org.apache.sysml.hops.UnaryOp;
import org.apache.sysml.hops.recompile.Recompiler;
import org.apache.sysml.hops.rewrite.HopRewriteRule;
import org.apache.sysml.hops.rewrite.HopRewriteUtils;
import org.apache.sysml.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysml.lops.Lop;
import org.apache.sysml.lops.compile.Dag;
import org.apache.sysml.parser.Expression;
import org.apache.sysml.runtime.controlprogram.Program;
import org.apache.sysml.runtime.controlprogram.ProgramBlock;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysml.runtime.instructions.Instruction;
import org.apache.sysml.runtime.instructions.cp.ScalarObject;
import org.apache.sysml.runtime.instructions.cp.ScalarObjectFactory;

/**
 * Hop rewrite that replaces scalar sub-expressions with literal operands.
 *
 * <p>Two kinds of folding are applied, bottom-up over each hop DAG:
 * <ul>
 *   <li>Scalar binary/unary operations whose inputs are all literals are compiled to runtime
 *       instructions, executed in a private program block / execution context, and replaced
 *       by a literal holding the computed result.</li>
 *   <li>Scalar {@code AND} with a literal input that coerces to {@code false} is replaced by
 *       {@code false}; scalar {@code OR} with a literal input that coerces to {@code true} is
 *       replaced by {@code true}.</li>
 * </ul>
 *
 * <p>The temporary program block and execution context are lazily created, cached per instance
 * and not synchronized; presumably an instance is not used from several threads at once
 * (TODO confirm).
 */
public class RewriteConstantFolding
extends HopRewriteRule {
    /** Variable name the folded result is written to and read back from. */
    private static final String TMP_VARNAME = "__cf_tmp";
    /** Lazily created program block used to run folding instructions. */
    private ProgramBlock _tmpPB = null;
    /** Lazily created execution context used to run folding instructions. */
    private ExecutionContext _tmpEC = null;

    /**
     * Applies constant folding to every root in place.
     *
     * @param roots DAG roots; each entry may be replaced by a literal if the root itself folds
     * @param state rewrite status (unused by this rule)
     * @return the same list instance, or {@code null} if {@code roots} is {@code null}
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (int i = 0; i < roots.size(); ++i) {
            Hop h = roots.get(i);
            roots.set(i, this.rule_ConstantFolding(h));
        }
        return roots;
    }

    /**
     * Applies constant folding to a single DAG.
     *
     * @param root DAG root, may be {@code null}
     * @param state rewrite status (unused by this rule)
     * @return the (possibly replaced) root, or {@code null} if {@code root} is {@code null}
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        return this.rule_ConstantFolding(root);
    }

    private Hop rule_ConstantFolding(Hop hop) {
        return this.rConstantFoldingExpression(hop);
    }

    /**
     * Recursively folds the sub-DAG rooted at {@code root}, children first, so that a node
     * sees already-folded literal inputs. Nodes already marked visited are skipped; resetting
     * visit status beforehand is presumably the caller's responsibility (TODO confirm).
     *
     * @param root node to process
     * @return {@code root}, or the replacing literal if {@code root} had no parents and folded
     */
    private Hop rConstantFoldingExpression(Hop root) {
        if (root.isVisited()) {
            return root;
        }
        // Index-based loop rather than an iterator: folding a child rewires it in its parents
        // (including this root) via replaceChildReference, which modifies the input list.
        for (int i = 0; i < root.getInput().size(); ++i) {
            Hop h = root.getInput().get(i);
            this.rConstantFoldingExpression(h);
        }
        LiteralOp literal = null;
        if (root.getDataType() == Expression.DataType.SCALAR && (RewriteConstantFolding.isApplicableBinaryOp(root) || RewriteConstantFolding.isApplicableUnaryOp(root))) {
            try {
                literal = this.evalScalarOperation(root);
            }
            catch (Exception ex) {
                // Best effort: a failed fold leaves the expression unchanged.
                LOG.error("Failed to execute constant folding instructions. No abort.", ex);
            }
        } else if (RewriteConstantFolding.isApplicableFalseConjunctivePredicate(root)) {
            literal = new LiteralOp(false);
        } else if (RewriteConstantFolding.isApplicableTrueDisjunctivePredicate(root)) {
            literal = new LiteralOp(true);
        }
        if (literal != null) {
            if (!root.getParent().isEmpty()) {
                // Copy the parent list: replaceChildReference removes root from it.
                ArrayList<Hop> parents = new ArrayList<Hop>(root.getParent());
                for (Hop parent : parents) {
                    HopRewriteUtils.replaceChildReference(parent, root, literal);
                }
            } else {
                // A parentless node (e.g. a DAG root) is replaced through the return value.
                root = literal;
            }
        }
        root.setVisited();
        return root;
    }

    /**
     * Compiles {@code bop} into runtime instructions, executes them, and returns the scalar
     * result as a literal.
     *
     * <p>The temporary transient write, the program block's instructions and the execution
     * context's variables are cleaned up on every exit path. The caller treats exceptions from
     * here as non-fatal, so the DAG must not keep the temporary write attached to {@code bop}.
     *
     * <p>NOTE(review): on failure, lops constructed for {@code bop} are not cleared here;
     * confirm they are rebuilt before later lop construction reuses them.
     *
     * @param bop scalar hop whose inputs are all literals
     * @return literal holding the computed value
     */
    private LiteralOp evalScalarOperation(Hop bop) {
        ExecutionContext ec = this.getExecutionContext();
        ProgramBlock pb = this.getProgramBlock();
        // Constructing tmpWrite with bop as input links it into bop's parent list; the
        // finally block below undoes that link whether or not execution succeeds.
        DataOp tmpWrite = new DataOp(TMP_VARNAME, bop.getDataType(), bop.getValueType(), bop, Hop.DataOpTypes.TRANSIENTWRITE, TMP_VARNAME);
        LiteralOp literal;
        try {
            Dag<Lop> dag = new Dag<Lop>();
            // Clear existing lops so they are reconstructed for this temporary DAG.
            Recompiler.rClearLops(tmpWrite);
            Lop lops = tmpWrite.constructLops();
            lops.addToDag(dag);
            ArrayList<Instruction> inst = dag.getJobs(null, ConfigurationManager.getDMLConfig());
            pb.setInstructions(inst);
            pb.execute(ec);
            ScalarObject so = (ScalarObject) ec.getVariable(TMP_VARNAME);
            literal = ScalarObjectFactory.createLiteralOp(so);
        } finally {
            tmpWrite.getInput().clear();
            bop.getParent().remove(tmpWrite);
            pb.setInstructions(null);
            ec.getVariables().removeAll();
        }
        HopRewriteUtils.setOutputParametersForScalar(literal);
        return literal;
    }

    /** Returns the cached temporary program block, creating it on first use. */
    private ProgramBlock getProgramBlock() {
        if (this._tmpPB == null) {
            this._tmpPB = new ProgramBlock(new Program());
        }
        return this._tmpPB;
    }

    /** Returns the cached temporary execution context, creating it on first use. */
    private ExecutionContext getExecutionContext() {
        if (this._tmpEC == null) {
            this._tmpEC = ExecutionContextFactory.createContext();
        }
        return this._tmpEC;
    }

    /**
     * Returns true for a binary op with two literal inputs, excluding {@code CBIND}/{@code RBIND}.
     */
    private static boolean isApplicableBinaryOp(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        ArrayList<Hop> in = hop.getInput();
        BinaryOp binary = (BinaryOp) hop;
        return in.get(0) instanceof LiteralOp && in.get(1) instanceof LiteralOp
            && binary.getOp() != Hop.OpOp2.CBIND && binary.getOp() != Hop.OpOp2.RBIND;
    }

    /**
     * Returns true for a scalar unary op with a literal input, excluding
     * {@code EXISTS}, {@code PRINT}, {@code ASSERT} and {@code STOP}. These are presumably
     * excluded because they have side effects or depend on runtime state.
     */
    private static boolean isApplicableUnaryOp(Hop hop) {
        ArrayList<Hop> in = hop.getInput();
        if (!(hop instanceof UnaryOp) || !(in.get(0) instanceof LiteralOp)) {
            return false;
        }
        UnaryOp unary = (UnaryOp) hop;
        return unary.getOp() != Hop.OpOp1.EXISTS && unary.getOp() != Hop.OpOp1.PRINT
            && unary.getOp() != Hop.OpOp1.ASSERT && unary.getOp() != Hop.OpOp1.STOP
            && hop.getDataType() == Expression.DataType.SCALAR;
    }

    /** Returns true for a scalar {@code AND} where either input is a literal coercing to false. */
    private static boolean isApplicableFalseConjunctivePredicate(Hop hop) {
        ArrayList<Hop> in = hop.getInput();
        return HopRewriteUtils.isBinary(hop, Hop.OpOp2.AND) && hop.getDataType().isScalar()
            && (isBooleanLiteral(in.get(0), false) || isBooleanLiteral(in.get(1), false));
    }

    /** Returns true for a scalar {@code OR} where either input is a literal coercing to true. */
    private static boolean isApplicableTrueDisjunctivePredicate(Hop hop) {
        ArrayList<Hop> in = hop.getInput();
        return HopRewriteUtils.isBinary(hop, Hop.OpOp2.OR) && hop.getDataType().isScalar()
            && (isBooleanLiteral(in.get(0), true) || isBooleanLiteral(in.get(1), true));
    }

    /** Returns true iff {@code h} is a literal whose boolean value equals {@code value}. */
    private static boolean isBooleanLiteral(Hop h, boolean value) {
        return h instanceof LiteralOp && ((LiteralOp) h).getBooleanValue() == value;
    }
}

