package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeNary;
import org.apache.sysds.hops.codegen.cplan.CNodeRow;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/hops/codegen/template/TemplateUtils.class */
public class TemplateUtils {
    /**
     * Shared template instances, in the order row, cell, outer product.
     * No {@code TemplateMultiAgg} instance is part of this array, although
     * {@link #createTemplate(TemplateBase.TemplateType, TemplateBase.CloseType)} can create one.
     */
    public static final TemplateBase[] TEMPLATES = {new TemplateRow(), new TemplateCell(), new TemplateOuterProduct()};

    /**
     * Checks whether the given hop is a matrix with exactly one dimension of size one.
     *
     * @param hop the hop to inspect
     * @return true for an n x 1 or 1 x n matrix (n != 1), false otherwise
     */
    public static boolean isVector(Hop hop) {
        if (hop.getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        final boolean singleRow = hop.getDim1() == 1;
        final boolean singleCol = hop.getDim2() == 1;
        // exactly one of the two dimensions may be one (a 1 x 1 matrix is not a vector)
        return singleRow != singleCol;
    }

    /**
     * Checks whether the given hop is a matrix with a single column and a row count other than one.
     *
     * @param hop the hop to inspect
     * @return true if the hop is a column vector
     */
    public static boolean isColVector(Hop hop) {
        if (hop.getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        return hop.getDim2() == 1 && hop.getDim1() != 1;
    }

    /**
     * Checks whether the given cplan node is a matrix with a single column and a row count other than one.
     *
     * @param cNode the cplan node to inspect
     * @return true if the node is a column vector
     */
    public static boolean isColVector(CNode cNode) {
        if (cNode.getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        return cNode.getNumCols() == 1 && cNode.getNumRows() != 1;
    }

    /**
     * Checks whether the given cplan node is a matrix with a single row and a column count other than one.
     *
     * @param cNode the cplan node to inspect
     * @return true if the node is a row vector
     */
    public static boolean isRowVector(CNode cNode) {
        if (cNode.getDataType() != Types.DataType.MATRIX) {
            return false;
        }
        return cNode.getNumRows() == 1 && cNode.getNumCols() != 1;
    }

    /**
     * Checks whether the given cplan node is a matrix whose row and column counts both differ from one.
     *
     * @param cNode the cplan node to inspect
     * @return true if the node is a matrix that is neither a vector nor 1 x 1
     */
    public static boolean isMatrix(CNode cNode) {
        return cNode.getDataType() == Types.DataType.MATRIX
            && cNode.getNumRows() != 1
            && cNode.getNumCols() != 1;
    }

    /**
     * Convenience overload of {@link #wrapLookupIfNecessary(CNode, Hop, boolean)} that passes
     * {@code false} for the flag, so matrix data nodes are wrapped into a LOOKUP_RC.
     *
     * @param cNode the cplan node to wrap
     * @param hop the hop the node was created for
     * @return the node itself or a lookup around it
     */
    public static CNode wrapLookupIfNecessary(CNode cNode, Hop hop) {
        final boolean keepMatrixData = false;
        return wrapLookupIfNecessary(cNode, hop, keepMatrixData);
    }

    /**
     * Wraps the given node into a lookup operation depending on its shape:
     * column vectors get LOOKUP_R, row vectors get LOOKUP_C, and data nodes whose hop is
     * a matrix get LOOKUP_RC unless {@code z} is set.
     *
     * @param cNode the cplan node to wrap
     * @param hop the hop the node was created for; only its data type is consulted
     * @param z if true, matrix data nodes are returned unwrapped
     * @return the node itself or a new unary lookup node around it
     */
    public static CNode wrapLookupIfNecessary(CNode cNode, Hop hop, boolean z) {
        if (isColVector(cNode)) {
            return new CNodeUnary(cNode, CNodeUnary.UnaryType.LOOKUP_R);
        }
        if (isRowVector(cNode)) {
            return new CNodeUnary(cNode, CNodeUnary.UnaryType.LOOKUP_C);
        }
        final boolean matrixData = (cNode instanceof CNodeData) && hop.getDataType().isMatrix();
        if (matrixData && !z) {
            return new CNodeUnary(cNode, CNodeUnary.UnaryType.LOOKUP_RC);
        }
        return cNode;
    }

    /**
     * Checks whether the given hop is a matrix whose two dimensions both differ from one.
     *
     * @param hop the hop to inspect
     * @return true if the hop is a matrix that is neither a vector nor 1 x 1
     */
    public static boolean isMatrix(Hop hop) {
        return hop.getDataType() == Types.DataType.MATRIX
            && hop.getDim1() != 1
            && hop.getDim2() != 1;
    }

    /**
     * Checks whether the hop has known dimensions and is either a scalar or a vector.
     *
     * @param hop the hop to inspect
     * @return true for scalars and vectors with known dimensions
     */
    public static boolean isVectorOrScalar(Hop hop) {
        if (!hop.dimsKnown()) {
            return false;
        }
        return hop.getDataType() == Types.DataType.SCALAR || isVector(hop);
    }

    /**
     * Checks whether the hop is a binary operation over two matrix-typed inputs with known
     * dimensions where the first input has more rows than the second.
     *
     * @param hop the hop to inspect
     * @return true if all of the above conditions hold
     */
    public static boolean isBinaryMatrixRowVector(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        final Hop left = hop.getInput().get(0);
        final Hop right = hop.getInput().get(1);
        if (!left.dimsKnown() || !right.dimsKnown()) {
            return false;
        }
        if (!left.getDataType().isMatrix() || !right.getDataType().isMatrix()) {
            return false;
        }
        return left.getDim1() > right.getDim1();
    }

    /**
     * Checks whether the hop is a binary operation over two matrix-typed inputs with known
     * dimensions where the first input has more columns than the second.
     *
     * @param hop the hop to inspect
     * @return true if all of the above conditions hold
     */
    public static boolean isBinaryMatrixColVector(Hop hop) {
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        final Hop left = hop.getInput().get(0);
        final Hop right = hop.getInput().get(1);
        if (!left.dimsKnown() || !right.dimsKnown()) {
            return false;
        }
        if (!left.getDataType().isMatrix() || !right.getDataType().isMatrix()) {
            return false;
        }
        return left.getDim2() > right.getDim2();
    }

    /**
     * Checks whether at least one direct input of the hop satisfies {@link #isMatrix(Hop)}.
     *
     * @param hop the hop whose inputs are inspected
     * @return true if any direct input is a (non-vector) matrix
     */
    public static boolean hasMatrixInput(Hop hop) {
        for (Hop input : hop.getInput()) {
            if (isMatrix(input)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks whether the operation of the given hop has an equally named cplan operation type.
     * Unary ops are matched against {@code UnaryType}, non-outer binary ops against
     * {@code BinType}, and ternary as well as parameterized builtin ops against {@code TernaryType}.
     *
     * @param hop the hop to inspect
     * @return true if a matching type name exists, false otherwise (including all other hop kinds)
     */
    public static boolean isOperationSupported(Hop hop) {
        if (hop instanceof UnaryOp) {
            final String opName = ((UnaryOp) hop).getOp().name();
            return CNodeUnary.UnaryType.contains(opName);
        }
        if (hop instanceof BinaryOp && !((BinaryOp) hop).isOuter()) {
            final String opName = ((BinaryOp) hop).getOp().name();
            return CNodeBinary.BinType.contains(opName);
        }
        if (hop instanceof TernaryOp) {
            final String opName = ((TernaryOp) hop).getOp().name();
            return CNodeTernary.TernaryType.contains(opName);
        }
        if (hop instanceof ParameterizedBuiltinOp) {
            // parameterized builtins are looked up among the ternary types as well
            final String opName = ((ParameterizedBuiltinOp) hop).getOp().name();
            return CNodeTernary.TernaryType.contains(opName);
        }
        return false;
    }

    /**
     * Creates a template of the given type using the close type {@code OPEN_VALID}.
     *
     * @param templateType the template type to instantiate
     * @return the new template, or null for a type without a matching case
     */
    public static TemplateBase createTemplate(TemplateBase.TemplateType templateType) {
        final TemplateBase.CloseType defaultCloseType = TemplateBase.CloseType.OPEN_VALID;
        return createTemplate(templateType, defaultCloseType);
    }

    /**
     * Creates a template instance for the given template type and close type.
     *
     * @param templateType the template type (CELL, ROW, MAGG, or OUTER)
     * @param closeType the close type handed to the template constructor
     * @return the new template, or null for a type without a matching case
     */
    public static TemplateBase createTemplate(TemplateBase.TemplateType templateType, TemplateBase.CloseType closeType) {
        switch (templateType) {
            case CELL:
                return new TemplateCell(closeType);
            case ROW:
                return new TemplateRow(closeType);
            case MAGG:
                return new TemplateMultiAgg(closeType);
            case OUTER:
                return new TemplateOuterProduct(closeType);
            default:
                return null;
        }
    }

    /**
     * Creates the templates considered compatible with the given template type.
     * For CELL this is a cell and a row template; every other handled type maps to a
     * single template of the same type.
     *
     * @param templateType the template type (CELL, ROW, MAGG, or OUTER)
     * @param closeType the close type handed to each template constructor
     * @return the new templates, or null for a type without a matching case
     */
    public static TemplateBase[] createCompatibleTemplates(TemplateBase.TemplateType templateType, TemplateBase.CloseType closeType) {
        switch (templateType) {
            case CELL:
                return new TemplateBase[]{new TemplateCell(closeType), new TemplateRow(closeType)};
            case ROW:
                return new TemplateBase[]{new TemplateRow(closeType)};
            case MAGG:
                return new TemplateBase[]{new TemplateMultiAgg(closeType)};
            case OUTER:
                return new TemplateBase[]{new TemplateOuterProduct(closeType)};
            default:
                return null;
        }
    }

    /**
     * Derives the cellwise aggregation type from the given hop.
     * Aggregate binary ops map to FULL_AGG, aggregate unary ops map according to their
     * direction (Row, Col, RowCol), and everything else maps to NO_AGG.
     *
     * @param hop the hop to inspect
     * @return the cell type for the hop
     */
    public static SpoofCellwise.CellType getCellType(Hop hop) {
        if (hop instanceof AggBinaryOp) {
            return SpoofCellwise.CellType.FULL_AGG;
        }
        if (!(hop instanceof AggUnaryOp)) {
            return SpoofCellwise.CellType.NO_AGG;
        }
        switch (((AggUnaryOp) hop).getDirection()) {
            case Row:
                return SpoofCellwise.CellType.ROW_AGG;
            case Col:
                return SpoofCellwise.CellType.COL_AGG;
            case RowCol:
                return SpoofCellwise.CellType.FULL_AGG;
            default:
                return SpoofCellwise.CellType.NO_AGG;
        }
    }

    /**
     * Derives the rowwise template output type by comparing the dimensions of the output hop
     * with those of the first and (optional) second input. The checks are evaluated in order
     * and the first matching rule wins, so the order of the conditions is significant.
     *
     * @param hop the output hop
     * @param hopArr the inputs; element 0 is the first input (may be null), element 1 is an
     *        optional second input. At least one element must be passed, because element 0 is
     *        accessed unconditionally.
     * @return the row type matching the first applicable rule
     * @throws RuntimeException if none of the rules applies
     */
    public static SpoofRowwise.RowType getRowType(Hop hop, Hop... hopArr) {
        // NOTE(review): an empty varargs call fails here with ArrayIndexOutOfBoundsException
        Hop hop2 = hopArr[0];
        Hop hop3 = hopArr.length > 1 ? hopArr[1] : null;
        // no aggregation: output has the size of the first input, or the first input is missing/unknown
        if ((hop2 != null && HopRewriteUtils.isEqualSize(hop, hop2)) || hop2 == null || !hop2.dimsKnown()) {
            return SpoofRowwise.RowType.NO_AGG;
        }
        // rows of the first input and columns of the second input, or a column range indexing;
        // excluded if the output is an aggregate binary op whose first input is the transpose of the first input
        if (((hop3 != null && hop.getDim1() == hop2.getDim1() && hop.getDim2() == hop3.getDim2()) || ((hop instanceof IndexingOp) && HopRewriteUtils.isColumnRangeIndexing((IndexingOp) hop))) && (!(hop instanceof AggBinaryOp) || !HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop2))) {
            return SpoofRowwise.RowType.NO_AGG_B1;
        }
        // one output value per input row (same transpose exclusion as above)
        if (hop.getDim1() == hop2.getDim1() && hop.getDim2() == 1 && (!(hop instanceof AggBinaryOp) || !HopRewriteUtils.isTransposeOfItself(hop.getInput().get(0), hop2))) {
            return SpoofRowwise.RowType.ROW_AGG;
        }
        // full aggregation over rows and columns
        if ((hop instanceof AggUnaryOp) && ((AggUnaryOp) hop).getDirection() == Types.Direction.RowCol) {
            return SpoofRowwise.RowType.FULL_AGG;
        }
        // column vector with one entry per input column
        if (hop.getDim1() == hop2.getDim2() && hop.getDim2() == 1) {
            return SpoofRowwise.RowType.COL_AGG_T;
        }
        // row vector with one entry per input column
        if (hop.getDim1() == 1 && hop.getDim2() == hop2.getDim2()) {
            return SpoofRowwise.RowType.COL_AGG;
        }
        // (columns of first input) x (columns of second input)
        if (hop3 != null && hop.getDim1() == hop2.getDim2() && hop.getDim2() == hop3.getDim2()) {
            return SpoofRowwise.RowType.COL_AGG_B1_T;
        }
        // (columns of second input) x (columns of first input)
        if (hop3 != null && hop.getDim1() == hop3.getDim2() && hop.getDim2() == hop2.getDim2()) {
            return SpoofRowwise.RowType.COL_AGG_B1;
        }
        // single row with as many columns as the second input
        if (hop3 != null && hop.getDim1() == 1 && hop3.getDim2() == hop.getDim2()) {
            return SpoofRowwise.RowType.COL_AGG_B1R;
        }
        // same number of rows as the first input but a different number of columns
        if (hop2.getDim1() == hop.getDim1() && hop2.getDim2() != hop.getDim2()) {
            return SpoofRowwise.RowType.NO_AGG_CONST;
        }
        // remaining valid case: single output row with a column count different from the first input
        if (hop.getDim1() != 1 || hop2.getDim2() == hop.getDim2()) {
            throw new RuntimeException("Unknown row type for hop " + hop.getHopID() + ".");
        }
        return SpoofRowwise.RowType.COL_AGG_CONST;
    }

    /**
     * Returns the aggregation operation of the given hop.
     *
     * @param hop the hop to inspect
     * @return the op of an aggregate unary hop, SUM for an aggregate binary hop, null otherwise
     */
    public static Types.AggOp getAggOp(Hop hop) {
        if (hop instanceof AggUnaryOp) {
            return ((AggUnaryOp) hop).getOp();
        }
        return (hop instanceof AggBinaryOp) ? Types.AggOp.SUM : null;
    }

    /**
     * Derives the outer product template type from the output hop.
     * The rules are checked in order: scalar output, left outer product, right outer
     * product, cellwise outer product.
     *
     * @param hop not consulted by this method
     * @param hop2 hop compared (by identity) against the first input of an aggregate binary output,
     *        directly or below a transpose
     * @param hop3 hop compared (by identity) against the second input of an aggregate binary output,
     *        directly or below a transpose
     * @param hop4 the output hop
     * @return the outer product type
     * @throws RuntimeException if none of the rules applies
     */
    public static SpoofOuterProduct.OutProdType getOuterProductType(Hop hop, Hop hop2, Hop hop3, Hop hop4) {
        if (hop4.getDataType() == Types.DataType.SCALAR) {
            return SpoofOuterProduct.OutProdType.AGG_OUTER_PRODUCT;
        }

        final boolean aggBinaryOut = hop4 instanceof AggBinaryOp;

        // left: first input is hop2 (possibly transposed), or the output itself is a transpose
        if (aggBinaryOut) {
            final Hop first = hop4.getInput().get(0);
            if (first == hop2 || (HopRewriteUtils.isTransposeOperation(first) && first.getInput().get(0) == hop2)) {
                return SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT;
            }
        }
        if (HopRewriteUtils.isTransposeOperation(hop4)) {
            return SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT;
        }

        // right: second input is hop3 (possibly transposed)
        if (aggBinaryOut) {
            final Hop second = hop4.getInput().get(1);
            if (second == hop3 || (HopRewriteUtils.isTransposeOperation(second) && second.getInput().get(0) == hop3)) {
                return SpoofOuterProduct.OutProdType.RIGHT_OUTER_PRODUCT;
            }
        }

        // cellwise: binary operation over two equally sized inputs
        if ((hop4 instanceof BinaryOp) && HopRewriteUtils.isEqualSize(hop4.getInput().get(0), hop4.getInput().get(1))) {
            return SpoofOuterProduct.OutProdType.CELLWISE_OUTER_PRODUCT;
        }
        throw new RuntimeException("Undefined outer product type for hop " + hop4.getHopID());
    }

    /**
     * Returns the given data node if it is already flagged as literal; otherwise creates a
     * new literal data node from a {@code LiteralOp} built on the node's variable name.
     *
     * @param cNode the node to convert; must be a {@code CNodeData} (it is cast unconditionally)
     * @return a data node flagged as literal
     */
    public static CNodeData getLiteral(CNode cNode) {
        final CNodeData data = (CNodeData) cNode;
        if (data.isLiteral()) {
            return data;
        }
        return createCNodeData(new LiteralOp(cNode.getVarname()), true);
    }

    /**
     * Checks whether the node is a data node flagged as literal.
     *
     * @param cNode the node to inspect
     * @return true for literal data nodes
     */
    public static boolean isLiteral(CNode cNode) {
        if (!(cNode instanceof CNodeData)) {
            return false;
        }
        return ((CNodeData) cNode).isLiteral();
    }

    /**
     * Checks whether the node is a literal data node whose variable name equals the given string.
     *
     * @param cNode the node to inspect
     * @param str the expected variable name
     * @return true if the node is a literal with that variable name
     */
    public static boolean isLiteral(CNode cNode, String str) {
        if (!isLiteral(cNode)) {
            return false;
        }
        final String varname = ((CNodeData) cNode).getVarname();
        return varname.equals(str);
    }

    /**
     * Checks whether the node is a lookup operation.
     * LOOKUP_C and LOOKUP_RC always count; the unary LOOKUP_R and the ternary LOOKUP_RC1
     * count only if {@code z} is set.
     *
     * @param cNode the node to inspect
     * @param z whether LOOKUP_R and LOOKUP_RC1 are accepted as well
     * @return true if the node is one of the accepted lookups
     */
    public static boolean isLookup(CNode cNode, boolean z) {
        if (isUnary(cNode, CNodeUnary.UnaryType.LOOKUP_C, CNodeUnary.UnaryType.LOOKUP_RC)) {
            return true;
        }
        if (!z) {
            return false;
        }
        return isUnary(cNode, CNodeUnary.UnaryType.LOOKUP_R)
            || isTernary(cNode, CNodeTernary.TernaryType.LOOKUP_RC1);
    }

    /**
     * Checks whether the node is a unary node whose type is among the given types.
     *
     * @param cNode the node to inspect
     * @param unaryTypeArr the accepted unary types
     * @return true if the node is unary and its type is contained in the given types
     */
    public static boolean isUnary(CNode cNode, CNodeUnary.UnaryType... unaryTypeArr) {
        if (!(cNode instanceof CNodeUnary)) {
            return false;
        }
        final CNodeUnary.UnaryType actual = ((CNodeUnary) cNode).getType();
        return ArrayUtils.contains(unaryTypeArr, actual);
    }

    /**
     * Checks whether the node is a unary node of type ROW_MAXS or ROW_SUMS.
     *
     * @param cNode the node to inspect
     * @return true for those two unary types
     */
    public static boolean isUnaryRowAgg(CNode cNode) {
        final boolean rowAgg = isUnary(cNode, CNodeUnary.UnaryType.ROW_MAXS, CNodeUnary.UnaryType.ROW_SUMS);
        return rowAgg;
    }

    /**
     * Checks whether the node is a binary node whose type is among the given types.
     *
     * @param cNode the node to inspect
     * @param binTypeArr the accepted binary types
     * @return true if the node is binary and its type is contained in the given types
     */
    public static boolean isBinary(CNode cNode, CNodeBinary.BinType... binTypeArr) {
        if (!(cNode instanceof CNodeBinary)) {
            return false;
        }
        final CNodeBinary.BinType actual = ((CNodeBinary) cNode).getType();
        return ArrayUtils.contains(binTypeArr, actual);
    }

    /**
     * Recursively checks that every node of the expression rooted at {@code cNode} is one of:
     * a binary node of the given types, a data node, or a unary node whose type is a scalar
     * lookup, a sparse-safe scalar operation, POW2, or MULT2.
     *
     * @param cNode root of the expression to check
     * @param binTypeArr the accepted binary types
     * @return true if the root and all of its transitive inputs are accepted
     */
    public static boolean rIsSparseSafeOnly(CNode cNode, CNodeBinary.BinType... binTypeArr) {
        final boolean binaryOrData = isBinary(cNode, binTypeArr) || (cNode instanceof CNodeData);
        if (!binaryOrData) {
            if (!(cNode instanceof CNodeUnary)) {
                return false;
            }
            final CNodeUnary.UnaryType type = ((CNodeUnary) cNode).getType();
            final boolean acceptedUnary = type.isScalarLookup()
                || type.isSparseSafeScalar()
                || type == CNodeUnary.UnaryType.POW2
                || type == CNodeUnary.UnaryType.MULT2;
            if (!acceptedUnary) {
                return false;
            }
        }
        // all inputs are visited; the results are combined without short-circuiting
        boolean allAccepted = true;
        for (CNode input : cNode.getInput()) {
            allAccepted &= rIsSparseSafeOnly(input, binTypeArr);
        }
        return allAccepted;
    }

    /**
     * Recursively checks whether the expression rooted at {@code cNode} contains a data node
     * with the given hop ID.
     *
     * @param cNode root of the expression to search
     * @param j the hop ID to look for
     * @return true if the root or any transitive input is a data node with that hop ID
     */
    public static boolean rContainsInput(CNode cNode, long j) {
        boolean found = (cNode instanceof CNodeData) && ((CNodeData) cNode).getHopID() == j;
        for (CNode input : cNode.getInput()) {
            found |= rContainsInput(input, j);
        }
        return found;
    }

    /**
     * Checks whether the node is a ternary node whose type is among the given types.
     *
     * @param cNode the node to inspect
     * @param ternaryTypeArr the accepted ternary types
     * @return true if the node is ternary and its type is contained in the given types
     */
    public static boolean isTernary(CNode cNode, CNodeTernary.TernaryType... ternaryTypeArr) {
        if (!(cNode instanceof CNodeTernary)) {
            return false;
        }
        final CNodeTernary.TernaryType actual = ((CNodeTernary) cNode).getType();
        return ArrayUtils.contains(ternaryTypeArr, actual);
    }

    /**
     * Creates a data node for the given hop and sets its literal flag.
     * The flag is set only for {@code LiteralOp} hops, and then only if {@code z} is true
     * or the literal's string value is an integer number.
     *
     * @param hop the hop to wrap
     * @param z if true, every literal hop is flagged as literal regardless of its value
     * @return the new data node
     */
    public static CNodeData createCNodeData(Hop hop, boolean z) {
        final CNodeData data = new CNodeData(hop);
        boolean literal = false;
        if (hop instanceof LiteralOp) {
            literal = z || UtilFunctions.isIntegerNumber(((LiteralOp) hop).getStringValue());
        }
        data.setLiteral(literal);
        return data;
    }

    /**
     * Replaces the node of a transpose hop by the node of the transpose's input.
     * If the hop is not a transpose, the given node is returned unchanged. Otherwise the
     * input's node is taken from the map (or created and registered if absent) and is
     * additionally registered under the ID of the transpose hop itself.
     *
     * @param cNode the node to return for non-transpose hops
     * @param hop the hop to inspect
     * @param hashMap map from hop ID to cplan node; updated for transpose hops
     * @param z forwarded to {@link #createCNodeData(Hop, boolean)} when a node is created
     * @return the given node, or the node of the transpose input
     */
    public static CNode skipTranspose(CNode cNode, Hop hop, HashMap<Long, CNode> hashMap, boolean z) {
        if (!HopRewriteUtils.isTransposeOperation(hop)) {
            return cNode;
        }
        final Hop transposeInput = hop.getInput().get(0);
        final Long inputId = Long.valueOf(transposeInput.getHopID());
        CNode inputNode = hashMap.get(inputId);
        if (inputNode == null) {
            inputNode = createCNodeData(transposeInput, z);
            hashMap.put(inputId, inputNode);
        }
        // the transpose itself now resolves to the node of its input
        hashMap.put(Long.valueOf(hop.getHopID()), inputNode);
        return inputNode;
    }

    /**
     * Finds the hop IDs of a pattern {@code (A != B) * C} or {@code C * (A != B)} rooted at
     * the given hop. The first input is checked before the second, and only the first
     * matching side is reported.
     *
     * <p>Hops that are not binary operations (including null) yield an empty list instead of
     * failing on the cast, since such a hop cannot contain the pattern.
     *
     * @param hop the potential multiply hop
     * @return the IDs of the multiply, the not-equal comparison and the comparison's two
     *         inputs (in that order), or an empty list if the pattern is absent
     */
    public static LinkedList<Long> findRemovableConditionalPatternInOuterProduct(Hop hop) {
        LinkedList<Long> removableIds = new LinkedList<>();
        if (!(hop instanceof BinaryOp) || ((BinaryOp) hop).getOp() != Types.OpOp2.MULT) {
            return removableIds;
        }
        // probe the left operand first, then the right one; stop at the first match
        for (int pos = 0; pos < 2; pos++) {
            Hop operand = hop.getInput().get(pos);
            if ((operand instanceof BinaryOp) && ((BinaryOp) operand).getOp() == Types.OpOp2.NOTEQUAL) {
                removableIds.add(Long.valueOf(hop.getHopID()));
                removableIds.add(Long.valueOf(operand.getHopID()));
                removableIds.add(Long.valueOf(operand.getInput().get(0).getHopID()));
                removableIds.add(Long.valueOf(operand.getInput().get(1).getHopID()));
                break;
            }
        }
        return removableIds;
    }

    /**
     * Removes the nodes of a removable conditional pattern (see
     * {@link #findRemovableConditionalPatternInOuterProduct(Hop)}) from the given map, except
     * for IDs that belong to one of the given input hops.
     *
     * @param hop the hop to inspect
     * @param hashMap map from hop ID to cplan node; entries of the pattern are removed
     * @param hashSet hops whose IDs must be kept in the map
     * @return the hop's own ID if no pattern was found; otherwise the ID of the first input
     *         if it is still in the map, else the ID of the second input
     */
    public static long skipConditionalInOuterProduct(Hop hop, HashMap<Long, CNode> hashMap, HashSet<Hop> hashSet) {
        final LinkedList<Long> removableIds = findRemovableConditionalPatternInOuterProduct(hop);
        if (removableIds.isEmpty()) {
            return hop.getHopID();
        }
        for (Long removableId : removableIds) {
            final long id = removableId.longValue();
            final boolean isTemplateInput = hashSet.stream().anyMatch(in -> in.getHopID() == id);
            if (!isTemplateInput) {
                hashMap.remove(removableId);
            }
        }
        final Hop first = hop.getInput().get(0);
        if (hashMap.containsKey(Long.valueOf(first.getHopID()))) {
            return first.getHopID();
        }
        return hop.getInput().get(1).getHopID();
    }

    /**
     * Checks whether the hop has a transpose parent which in turn has a parent that is an
     * outer-product-like matrix multiplication.
     *
     * @param hop the hop whose parents are inspected
     * @return true if such a parent chain exists
     */
    public static boolean hasTransposeParentUnderOuterProduct(Hop hop) {
        for (Hop parent : hop.getParent()) {
            if (!HopRewriteUtils.isTransposeOperation(parent)) {
                continue;
            }
            for (Hop grandParent : parent.getParent()) {
                if (HopRewriteUtils.isOuterProductLikeMM(grandParent)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Checks whether the template's output is one of a set of operations and has only data
     * nodes or lookups as direct inputs (see {@link #hasOnlyDataNodeOrLookupInputs(CNode)}).
     * Accepted outputs are: unary nodes other than EXP, LOG and ROW_COUNTNNZS; binary nodes
     * passing the check below; and ternary IFELSE nodes.
     *
     * @param cNodeTpl the template to inspect
     * @return true if the output is an accepted operation over data/lookup inputs
     */
    public static boolean hasSingleOperation(CNodeTpl cNodeTpl) {
        final CNode output = cNodeTpl.getOutput();

        final boolean unaryMatch = (output instanceof CNodeUnary)
            && !isUnary(output, CNodeUnary.UnaryType.EXP, CNodeUnary.UnaryType.LOG, CNodeUnary.UnaryType.ROW_COUNTNNZS);
        // NOTE(review): as written, this holds only for ROWMAXS_VECTMULT outputs, which makes
        // the VECT_OUTERMULT_ADD exclusion redundant - confirm this is the intended condition
        final boolean binaryMatch = (output instanceof CNodeBinary)
            && !isBinary(output, CNodeBinary.BinType.VECT_OUTERMULT_ADD)
            && isBinary(output, CNodeBinary.BinType.ROWMAXS_VECTMULT);
        final boolean ternaryMatch = (output instanceof CNodeTernary)
            && ((CNodeTernary) output).getType() == CNodeTernary.TernaryType.IFELSE;

        if (!(unaryMatch || binaryMatch || ternaryMatch)) {
            return false;
        }
        return hasOnlyDataNodeOrLookupInputs(output);
    }

    /**
     * Checks whether the hop is one of: n-ary MIN, MAX or PLUS; unary EXP or LOG; or a DNN
     * BIASADD or BIASMULT operation.
     *
     * @param hop the hop to inspect
     * @return true if the hop is one of the listed operations
     */
    public static boolean isValidSingleOperation(Hop hop) {
        if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)) {
            return true;
        }
        if (HopRewriteUtils.isUnary(hop, Types.OpOp1.EXP, Types.OpOp1.LOG)) {
            return true;
        }
        return HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT);
    }

    /**
     * Checks whether the template's output is just a data node or a lookup
     * (including LOOKUP_R and LOOKUP_RC1).
     *
     * @param cNodeTpl the template to inspect
     * @return true if the output is a data node or any lookup
     */
    public static boolean hasNoOperation(CNodeTpl cNodeTpl) {
        if (cNodeTpl.getOutput() instanceof CNodeData) {
            return true;
        }
        return isLookup(cNodeTpl.getOutput(), true);
    }

    /**
     * Checks whether every direct input of the node is either a data node or a unary node
     * of type LOOKUP0, LOOKUP_R or LOOKUP_RC.
     *
     * @param cNode the node whose inputs are inspected
     * @return true if all direct inputs are accepted (also true for nodes without inputs)
     */
    public static boolean hasOnlyDataNodeOrLookupInputs(CNode cNode) {
        boolean allAccepted = true;
        for (CNode input : cNode.getInput()) {
            boolean accepted = input instanceof CNodeData;
            if (!accepted && input instanceof CNodeUnary) {
                final CNodeUnary.UnaryType type = ((CNodeUnary) input).getType();
                accepted = type == CNodeUnary.UnaryType.LOOKUP0
                    || type == CNodeUnary.UnaryType.LOOKUP_R
                    || type == CNodeUnary.UnaryType.LOOKUP_RC;
            }
            allAccepted &= accepted;
        }
        return allAccepted;
    }

    /**
     * Determines the number of vector intermediates for the cplan rooted at {@code cNode},
     * according to {@code SpoofCompiler.REG_ALLOC_POLICY}:
     * <ul>
     * <li>HEURISTIC: the heuristic estimate (see {@code estimateVectorIntermediates}).</li>
     * <li>EXACT_DYNAMIC_BUFF: the result of {@link #getMaxLiveVectorIntermediates}.</li>
     * <li>EXACT_STATIC_BUFF: starts from the heuristic estimate and decrements it as long as
     *     {@link #isValidNumVectorIntermediates} accepts the next smaller count.</li>
     * </ul>
     * The visit status of the cplan is reset before and after the computation.
     *
     * @param cNode root of the cplan
     * @param cNode2 node that is excluded from the exact analyses (compared by identity)
     * @return the determined count, or -1 if the policy matches none of the cases above
     */
    public static int determineMinVectorIntermediates(CNode cNode, CNode cNode2) {
        cNode.resetVisitStatus();
        int count = -1;
        switch (SpoofCompiler.REG_ALLOC_POLICY) {
            case HEURISTIC:
                count = estimateVectorIntermediates(cNode);
                break;
            case EXACT_DYNAMIC_BUFF: {
                Map<Long, Set<Long>> parents = getAllParents(cNode);
                cNode.resetVisitStatus();
                count = getMaxLiveVectorIntermediates(cNode, cNode2, parents, new HashSet<>());
                break;
            }
            case EXACT_STATIC_BUFF: {
                // start from the heuristic estimate and shrink it while it remains valid
                count = estimateVectorIntermediates(cNode);
                Map<Long, Set<Long>> parents = getAllParents(cNode);
                Map<Long, Pair<Long, MutableInt>> inUse = new HashMap<>();
                Set<Long> inUseIds = new HashSet<>();
                while (count > 0 && isValidNumVectorIntermediates(cNode, cNode2, parents, inUse, inUseIds, count - 1)) {
                    count--;
                }
                break;
            }
            default:
                break;
        }
        cNode.resetVisitStatus();
        return count;
    }

    /**
     * Heuristic estimate shared by the HEURISTIC and EXACT_STATIC_BUFF policies: the maximum
     * per-node requirement for unary operator pipelines, otherwise the total count of vector
     * intermediates. Expects a reset visit status and leaves the visited nodes marked.
     */
    private static int estimateVectorIntermediates(CNode cNode) {
        boolean unaryPipeline = isUnaryOperatorPipeline(cNode);
        cNode.resetVisitStatus();
        return unaryPipeline ? getMaxVectorIntermediates(cNode) : countVectorIntermediates(cNode);
    }

    /**
     * Recursively checks whether the cplan rooted at {@code cNode} forms a unary operator
     * pipeline. A node that is reached again after it was marked visited invalidates the
     * pipeline if it is a binary, ternary or n-ary vector primitive.
     * Marks every traversed node as visited; the caller is responsible for resetting the
     * visit status.
     *
     * @param cNode root of the cplan
     * @return true if no vector primitive is reached more than once
     */
    public static boolean isUnaryOperatorPipeline(CNode cNode) {
        if (cNode.isVisited()) {
            final boolean vectorPrimitive =
                ((cNode instanceof CNodeBinary) && ((CNodeBinary) cNode).getType().isVectorPrimitive())
                || ((cNode instanceof CNodeTernary) && ((CNodeTernary) cNode).getType().isVectorPrimitive())
                || ((cNode instanceof CNodeNary) && ((CNodeNary) cNode).getType().isVectorPrimitive());
            return !vectorPrimitive;
        }
        // every input is traversed (no short-circuit) so that all nodes get marked
        boolean pipeline = true;
        for (CNode input : cNode.getInput()) {
            pipeline &= isUnaryOperatorPipeline(input);
        }
        cNode.setVisited();
        return pipeline;
    }

    /**
     * Recursively computes the maximum per-node vector requirement over the cplan rooted at
     * {@code cNode}. Per node the requirement is: 1 for ternary vector primitives; 3, 2 or 1
     * for binary vector-vector, vector-scalar or vector-matrix primitives; 2 for unary
     * vector-scalar primitives; 0 otherwise. Already visited nodes contribute 0.
     * Marks traversed nodes as visited.
     *
     * @param cNode root of the cplan
     * @return the maximum requirement of any not yet visited node
     */
    public static int getMaxVectorIntermediates(CNode cNode) {
        if (cNode.isVisited()) {
            return 0;
        }
        int inputMax = 0;
        for (CNode input : cNode.getInput()) {
            inputMax = Math.max(inputMax, getMaxVectorIntermediates(input));
        }

        final int ternaryReq =
            ((cNode instanceof CNodeTernary) && ((CNodeTernary) cNode).getType().isVectorPrimitive()) ? 1 : 0;

        int binaryReq = 0;
        if (cNode instanceof CNodeBinary) {
            final CNodeBinary.BinType type = ((CNodeBinary) cNode).getType();
            if (type.isVectorVectorPrimitive()) {
                binaryReq = 3;
            } else if (type.isVectorScalarPrimitive()) {
                binaryReq = 2;
            } else if (type.isVectorMatrixPrimitive()) {
                binaryReq = 1;
            }
        }

        final int unaryReq =
            ((cNode instanceof CNodeUnary) && ((CNodeUnary) cNode).getType().isVectorScalarPrimitive()) ? 2 : 0;

        final int result = Math.max(Math.max(Math.max(inputMax, ternaryReq), binaryReq), unaryReq);
        cNode.setVisited();
        return result;
    }

    /**
     * Recursively counts the vector intermediates of the cplan rooted at {@code cNode}.
     * Each not yet visited node contributes one per matching category: binary vector
     * primitives whose type name does not end with "_ADD", unary vector-scalar primitives,
     * ternary vector primitives, and n-ary vector primitives.
     * Marks traversed nodes as visited.
     *
     * @param cNode root of the cplan
     * @return the total count over all not yet visited nodes
     */
    public static int countVectorIntermediates(CNode cNode) {
        if (cNode.isVisited()) {
            return 0;
        }
        cNode.setVisited();

        int total = 0;
        for (CNode input : cNode.getInput()) {
            total += countVectorIntermediates(input);
        }

        if (cNode instanceof CNodeBinary) {
            final CNodeBinary.BinType type = ((CNodeBinary) cNode).getType();
            if (type.isVectorPrimitive() && !type.name().endsWith("_ADD")) {
                total++;
            }
        }
        if ((cNode instanceof CNodeUnary) && ((CNodeUnary) cNode).getType().isVectorScalarPrimitive()) {
            total++;
        }
        if ((cNode instanceof CNodeTernary) && ((CNodeTernary) cNode).getType().isVectorPrimitive()) {
            total++;
        }
        if ((cNode instanceof CNodeNary) && ((CNodeNary) cNode).getType().isVectorPrimitive()) {
            total++;
        }
        return total;
    }

    /**
     * Recursively computes the maximum number of simultaneously tracked non-scalar producers
     * over a depth-first traversal of the cplan. The set holds (consumer ID, producer ID)
     * pairs: after a node's inputs are processed, one pair per parent of the node is added
     * (unless the node is scalar, has no parents, or is {@code cNode2}); then the number of
     * distinct producer IDs is sampled; finally the pairs for this node's own inputs are
     * removed. Marks traversed nodes as visited.
     *
     * @param cNode root of the cplan
     * @param cNode2 node that is never tracked as producer (compared by identity)
     * @param map parent IDs per node ID, see {@link #getAllParents(CNode)}
     * @param set working set of (consumer ID, producer ID) pairs; modified by this call
     * @return the maximum number of distinct producers observed, or -1 for a visited node
     */
    public static int getMaxLiveVectorIntermediates(CNode cNode, CNode cNode2, Map<Long, Set<Long>> map, Set<Pair<Long, Long>> set) {
        if (cNode.isVisited()) {
            return -1;
        }
        int maxLive = -1;
        for (CNode input : cNode.getInput()) {
            maxLive = Math.max(maxLive, getMaxLiveVectorIntermediates(input, cNode2, map, set));
        }

        final Long nodeId = Long.valueOf(cNode.getID());
        final boolean tracked = !cNode.getDataType().isScalar() && map.containsKey(nodeId) && cNode != cNode2;
        if (tracked) {
            for (Long parentId : map.get(nodeId)) {
                set.add(Pair.of(parentId, nodeId));
            }
        }

        // number of distinct producers that still have pending consumers
        final long liveProducers = set.stream().map(Pair::getValue).distinct().count();
        maxLive = Math.max(maxLive, (int) liveProducers);

        // this node has consumed its inputs
        for (CNode input : cNode.getInput()) {
            set.remove(Pair.of(nodeId, Long.valueOf(input.getID())));
        }
        cNode.setVisited();
        return maxLive;
    }

    /**
     * Checks whether the cplan can be processed with the given number of vector
     * intermediates. Clears both working collections, resets the visit status and delegates
     * to {@link #rIsValidNumVectorIntermediates} with an ID sequence created as
     * {@code new IDSequence(true, i - 1)}.
     *
     * <p>NOTE(review): counts of one or less are rejected up front, so a single buffer is
     * never reported as valid - confirm this lower bound is intended.
     *
     * @param cNode root of the cplan
     * @param cNode2 node excluded from the analysis (compared by identity)
     * @param map parent IDs per node ID, see {@link #getAllParents(CNode)}
     * @param map2 working map, cleared and then filled by the recursive check
     * @param set working set, cleared and then filled by the recursive check
     * @param i the number of vector intermediates to validate
     * @return false if {@code i <= 1}, otherwise the result of the recursive check
     */
    public static boolean isValidNumVectorIntermediates(CNode cNode, CNode cNode2, Map<Long, Set<Long>> map, Map<Long, Pair<Long, MutableInt>> map2, Set<Long> set, int i) {
        if (i <= 1) {
            return false;
        }
        map2.clear();
        set.clear();
        cNode.resetVisitStatus();
        final IDSequence bufferIds = new IDSequence(true, i - 1);
        return rIsValidNumVectorIntermediates(cNode, cNode2, map, map2, set, bufferIds);
    }

    /**
     * Recursive part of {@link #isValidNumVectorIntermediates}. After processing its inputs,
     * each non-scalar node with parents (other than {@code cNode2}) draws the next ID from the
     * sequence; if that ID is still in use, the check fails. Otherwise the ID is recorded
     * together with the node's number of parents. Each node then decrements the counter of
     * its inputs and releases an input's ID once the counter reaches zero.
     * Marks traversed nodes as visited.
     *
     * @param cNode root of the (sub-)cplan
     * @param cNode2 node excluded from the analysis (compared by identity)
     * @param map parent IDs per node ID
     * @param map2 per node ID: the assigned sequence ID and the remaining consumer count
     * @param set sequence IDs currently in use
     * @param iDSequence source of the sequence IDs
     * @return false as soon as an assigned ID collides with one still in use, true otherwise
     */
    public static boolean rIsValidNumVectorIntermediates(CNode cNode, CNode cNode2, Map<Long, Set<Long>> map, Map<Long, Pair<Long, MutableInt>> map2, Set<Long> set, IDSequence iDSequence) {
        if (cNode.isVisited()) {
            return true;
        }
        for (CNode input : cNode.getInput()) {
            if (!rIsValidNumVectorIntermediates(input, cNode2, map, map2, set, iDSequence)) {
                return false;
            }
        }

        final Long nodeId = Long.valueOf(cNode.getID());
        if (!cNode.getDataType().isScalar() && map.containsKey(nodeId) && cNode != cNode2) {
            final Long assignedId = Long.valueOf(iDSequence.getNextID());
            if (set.contains(assignedId)) {
                return false; // conflict: the drawn ID is still occupied
            }
            final MutableInt consumers = new MutableInt(map.get(nodeId).size());
            map2.put(nodeId, Pair.of(assignedId, consumers));
            set.add(assignedId);
        }

        // release the inputs whose last consumer has now been processed
        for (CNode input : cNode.getInput()) {
            final Long inputId = Long.valueOf(input.getID());
            final Pair<Long, MutableInt> usage = map2.get(inputId);
            if (usage == null) {
                continue;
            }
            final MutableInt remaining = usage.getValue();
            remaining.decrement();
            if (remaining.intValue() <= 0) {
                map2.remove(inputId);
                set.remove(usage.getKey());
            }
        }
        cNode.setVisited();
        return true;
    }

    /**
     * Collects, for every node reachable from {@code cNode}, the IDs of its parents.
     *
     * @param cNode root of the cplan
     * @return a new map from node ID to the set of IDs of the nodes using it as input;
     *         the root has an entry only if it is also an input of another reachable node
     */
    public static Map<Long, Set<Long>> getAllParents(CNode cNode) {
        Map<Long, Set<Long>> parents = new HashMap<>();
        getAllParents(cNode, parents);
        return parents;
    }

    /**
     * Adds the parent relationships of the cplan rooted at {@code cNode} to the given map.
     * Existing entries of the map are kept and extended.
     *
     * @param cNode root of the cplan
     * @param map map from node ID to the set of parent IDs; modified in place
     */
    public static void getAllParents(CNode cNode, Map<Long, Set<Long>> map) {
        collectParents(cNode, map, new HashSet<Long>());
    }

    /**
     * Records the inputs of each node once. Without the {@code expanded} guard, a sub-DAG
     * shared by several consumers would be re-traversed for every path leading to it, which
     * yields the same map but can take exponential time. Nodes are identified by their ID,
     * as in the parent map itself.
     */
    private static void collectParents(CNode cNode, Map<Long, Set<Long>> map, Set<Long> expanded) {
        if (!expanded.add(Long.valueOf(cNode.getID()))) {
            return; // inputs of this node were already recorded
        }
        for (CNode input : cNode.getInput()) {
            map.computeIfAbsent(Long.valueOf(input.getID()), id -> new HashSet<Long>())
                .add(Long.valueOf(cNode.getID()));
            collectParents(input, map, expanded);
        }
    }

    /**
     * Checks whether the given template type is among the listed types.
     *
     * @param templateType the type to look for
     * @param templateTypeArr the accepted types
     * @return true if the type is contained in the accepted types
     */
    public static boolean isType(TemplateBase.TemplateType templateType, TemplateBase.TemplateType... templateTypeArr) {
        final boolean listed = ArrayUtils.contains(templateTypeArr, templateType);
        return listed;
    }

    /**
     * Checks whether two hops share the same row-template matrix input.
     * If the memo table holds no ROW entry for the second hop, the result is always true.
     *
     * @param hop the first hop
     * @param hop2 the second hop
     * @param cPlanMemoTable the memo table to consult
     * @return true if the second hop has no row template or both hops resolve to the same
     *         matrix input ID via {@link #getRowTemplateMatrixInput(Hop, CPlanMemoTable)}
     */
    public static boolean hasCommonRowTemplateMatrixInput(Hop hop, Hop hop2, CPlanMemoTable cPlanMemoTable) {
        if (!cPlanMemoTable.contains(hop2.getHopID(), TemplateBase.TemplateType.ROW)) {
            return true;
        }
        final long firstInput = getRowTemplateMatrixInput(hop, cPlanMemoTable);
        final long secondInput = getRowTemplateMatrixInput(hop2, cPlanMemoTable);
        return firstInput == secondInput;
    }

    /**
     * Finds the ID of the matrix input of the row template containing the given hop.
     * The hop's inputs are scanned in order: inputs referenced by the best ROW plan (and
     * having a ROW entry themselves) are searched recursively; inputs not referenced by
     * the plan are taken directly if they satisfy {@link #isMatrix(Hop)}. The scan stops
     * as soon as a non-negative ID was found.
     *
     * @param hop the hop to start from; the memo table is expected to hold a best ROW entry for it
     * @param cPlanMemoTable the memo table to consult
     * @return the hop ID of the matrix input, or -1 if none was found
     */
    public static long getRowTemplateMatrixInput(Hop hop, CPlanMemoTable cPlanMemoTable) {
        final CPlanMemoTable.MemoTableEntry best = cPlanMemoTable.getBest(hop.getHopID(), TemplateBase.TemplateType.ROW);
        long matrixInputId = -1;
        int pos = 0;
        while (matrixInputId < 0 && pos < hop.getInput().size()) {
            final Hop input = hop.getInput().get(pos);
            if (best.isPlanRef(pos)) {
                if (cPlanMemoTable.contains(input.getHopID(), TemplateBase.TemplateType.ROW)) {
                    matrixInputId = getRowTemplateMatrixInput(input, cPlanMemoTable);
                }
            } else if (isMatrix(input)) {
                matrixInputId = input.getHopID();
            }
            pos++;
        }
        return matrixInputId;
    }

    /**
     * Checks whether the cplan rooted at {@code cNode} contains a binary node of the given type.
     * The visit status of the cplan is reset before and after the search.
     *
     * @param cNode root of the cplan
     * @param binType the binary type to look for
     * @return true if such a node exists
     */
    public static boolean containsBinary(CNode cNode, CNodeBinary.BinType binType) {
        cNode.resetVisitStatus();
        final boolean found = rContainsBinary(cNode, binType);
        cNode.resetVisitStatus();
        return found;
    }

    /**
     * Recursive part of {@link #containsBinary(CNode, CNodeBinary.BinType)}.
     * Already visited nodes yield false; all inputs of a node are traversed and marked visited.
     *
     * @param cNode root of the (sub-)cplan
     * @param binType the binary type to look for
     * @return true if the node or any not yet visited transitive input is a binary of that type
     */
    public static boolean rContainsBinary(CNode cNode, CNodeBinary.BinType binType) {
        if (cNode.isVisited()) {
            return false;
        }
        boolean found = isBinary(cNode, binType);
        for (CNode input : cNode.getInput()) {
            // non-short-circuit on purpose: every input is traversed and marked
            found |= rContainsBinary(input, binType);
        }
        cNode.setVisited();
        return found;
    }

    /**
     * Checks whether the hop DAG rooted at {@code hop} contains an outer-product-like
     * matrix multiplication. The visit status is reset before and after the search.
     *
     * @param hop root of the hop DAG
     * @return true if such an operation exists
     */
    public static boolean containsOuterProduct(Hop hop) {
        hop.resetVisitStatus();
        final boolean found = rContainsOuterProduct(hop);
        hop.resetVisitStatus();
        return found;
    }

    /**
     * Checks whether the hop DAG rooted at {@code hop} contains an outer-product-like matrix
     * multiplication that has an input which is the transpose of {@code hop2} (as determined
     * by {@code HopRewriteUtils.isTransposeOfItself}). The visit status is reset before and
     * after the search.
     *
     * @param hop root of the hop DAG
     * @param hop2 the hop whose transpose must be an input of the outer product
     * @return true if such an operation exists
     */
    public static boolean containsOuterProduct(Hop hop, Hop hop2) {
        hop.resetVisitStatus();
        final boolean found = rContainsOuterProduct(hop, hop2);
        hop.resetVisitStatus();
        return found;
    }

    /**
     * Recursive search for an outer-product-like matrix multiplication.
     * Already visited hops yield false; the inputs are searched in order until a match is
     * found, and the hop is marked visited afterwards.
     */
    private static boolean rContainsOuterProduct(Hop hop) {
        if (hop.isVisited()) {
            return false;
        }
        boolean found = HopRewriteUtils.isOuterProductLikeMM(hop);
        for (Hop input : hop.getInput()) {
            if (found) {
                break;
            }
            found = rContainsOuterProduct(input);
        }
        hop.setVisited();
        return found;
    }

    /**
     * Recursive search for an outer-product-like matrix multiplication that has the
     * transpose of {@code hop2} among its inputs.
     * Already visited hops yield false; the inputs are searched in order until a match is
     * found, and the hop is marked visited afterwards.
     */
    private static boolean rContainsOuterProduct(Hop hop, Hop hop2) {
        if (hop.isVisited()) {
            return false;
        }
        boolean found = HopRewriteUtils.isOuterProductLikeMM(hop) && checkContainment(hop.getInput(), hop2, true);
        for (Hop input : hop.getInput()) {
            if (found) {
                break;
            }
            found = rContainsOuterProduct(input, hop2);
        }
        hop.setVisited();
        return found;
    }

    /**
     * Checks whether the list contains the given hop. With {@code z} unset this is a plain
     * {@code contains}; with {@code z} set, an element matches if
     * {@code HopRewriteUtils.isTransposeOfItself(element, hop)} holds.
     */
    private static boolean checkContainment(ArrayList<Hop> arrayList, Hop hop, boolean z) {
        if (!z) {
            return arrayList.contains(hop);
        }
        for (Hop candidate : arrayList) {
            if (HopRewriteUtils.isTransposeOfItself(candidate, hop)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Recursively swaps the unary lookup types LOOKUP_C and LOOKUP_R in the cplan rooted at
     * {@code cNode}.
     *
     * <p>NOTE(review): the traversal does not track visited nodes, so a lookup node that is
     * reachable through several paths is flipped once per path - confirm that callers only
     * pass cplans where this cannot cancel out.
     *
     * @param cNode root of the cplan; lookup nodes are modified in place
     */
    public static void rFlipVectorLookups(CNode cNode) {
        if (cNode instanceof CNodeUnary) {
            final CNodeUnary unary = (CNodeUnary) cNode;
            final CNodeUnary.UnaryType type = unary.getType();
            if (type == CNodeUnary.UnaryType.LOOKUP_C) {
                unary.setType(CNodeUnary.UnaryType.LOOKUP_R);
            } else if (type == CNodeUnary.UnaryType.LOOKUP_R) {
                unary.setType(CNodeUnary.UnaryType.LOOKUP_C);
            }
        }
        for (CNode input : cNode.getInput()) {
            rFlipVectorLookups(input);
        }
    }

    /**
     * Checks whether a row template has a ROWMAXS_VECTMULT binary node as its output or as
     * a direct input of its output.
     *
     * @param cNodeTpl the template to inspect
     * @return true only for {@code CNodeRow} templates with such a node, false otherwise
     */
    public static boolean containsFusedRowVecAgg(CNodeTpl cNodeTpl) {
        if (!(cNodeTpl instanceof CNodeRow)) {
            return false;
        }
        final CNode output = cNodeTpl.getOutput();
        if (isBinary(output, CNodeBinary.BinType.ROWMAXS_VECTMULT)) {
            return true;
        }
        for (CNode input : output.getInput()) {
            if (isBinary(input, CNodeBinary.BinType.ROWMAXS_VECTMULT)) {
                return true;
            }
        }
        return false;
    }
}
