package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataGenOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.DnnOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.NaryOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeCell;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeTernary;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.matrix.data.Pair;

/**
 * Code generation template for cellwise fused operators over matrices (elementwise unary,
 * binary, ternary, n-ary min/max/plus, bias_add/bias_mult, replace, column lookups and seq).
 * A cell template is closed as valid by an aggregation in {@code SUPPORTED_AGG} or by a matrix
 * multiply with a 1x1 output (see {@link #close(Hop)}).
 */
public class TemplateCell extends TemplateBase {
    /** Unary aggregations with which a cell template can be closed as valid. */
    private static final Types.AggOp[] SUPPORTED_AGG = {
        Types.AggOp.SUM, Types.AggOp.SUM_SQ, Types.AggOp.MIN, Types.AggOp.MAX};

    /**
     * Orders the input hops of a fused operator.
     *
     * <p>The order is:
     * <ol>
     *   <li>the optional driver hop first;</li>
     *   <li>then inputs with more cells before inputs with fewer cells, where inputs with
     *       unknown dimensions count as largest and scalars as smallest;</li>
     *   <li>equal-size scalars are ordered by hop ID;</li>
     *   <li>equal-size matrices are ordered by ascending non-zero count when both nnz are
     *       known, they differ, and at least one input is sparse; otherwise by hop ID.</li>
     * </ol>
     */
    public static class HopInputComparator implements Comparator<Hop> {
        private final Hop _driver;

        /** Creates a comparator without a designated driver input. */
        public HopInputComparator() {
            this(null);
        }

        /**
         * Creates a comparator that always orders the given driver hop first.
         *
         * @param driver the hop to place first, or {@code null} for none
         */
        public HopInputComparator(Hop driver) {
            _driver = driver;
        }

        @Override
        public int compare(Hop h1, Hop h2) {
            if (h1 == h2) {
                return 0;
            }
            // Check the driver before comparing sizes so that the driver is first regardless of
            // argument order; this keeps compare antisymmetric as required by Comparator.
            if (_driver != null && h1 == _driver) {
                return -1;
            }
            if (_driver != null && h2 == _driver) {
                return 1;
            }
            long ncells1 = orderingCells(h1);
            long ncells2 = orderingCells(h2);
            if (ncells1 != ncells2) {
                return ncells1 > ncells2 ? -1 : 1; // larger inputs first
            }
            if (h1.isScalar() && h2.isScalar()) {
                return Long.compare(h1.getHopID(), h2.getHopID());
            }
            boolean byNnz = h1.dimsKnown(true) && h2.dimsKnown(true)
                && h1.getNnz() != h2.getNnz()
                && (HopRewriteUtils.isSparse(h1, 1.0d) || HopRewriteUtils.isSparse(h2, 1.0d));
            return byNnz
                ? Long.compare(h1.getNnz(), h2.getNnz())
                : Long.compare(h1.getHopID(), h2.getHopID());
        }

        /** Cell count used for ordering: MIN_VALUE for scalars, MAX_VALUE if dims are unknown. */
        private static long orderingCells(Hop hop) {
            if (hop.isScalar()) {
                return Long.MIN_VALUE;
            }
            return hop.dimsKnown() ? hop.getLength() : Long.MAX_VALUE;
        }
    }

    public TemplateCell() {
        super(TemplateBase.TemplateType.CELL);
    }

    public TemplateCell(TemplateBase.CloseType closeType) {
        super(TemplateBase.TemplateType.CELL, closeType);
    }

    public TemplateCell(TemplateBase.TemplateType templateType, TemplateBase.CloseType closeType) {
        super(templateType, closeType);
    }

    /**
     * Decides whether a new cell template can be opened at the given hop.
     *
     * @param hop candidate root of a new template
     * @return true for valid cellwise operations on known-size non-1x1 outputs, column-lookup
     *         indexing, seq with literal inputs and only unary/binary parents, matrix n-ary
     *         min/max/plus, and bias_add/bias_mult with known input dims
     */
    @Override
    public boolean open(Hop hop) {
        return (hop.dimsKnown() && isValidOperation(hop)
                && !(hop.getDim1() == 1 && hop.getDim2() == 1))
            || (hop instanceof IndexingOp && hop.getInput().get(0).getDim2() >= 0
                && (((IndexingOp) hop).isColLowerEqualsUpper() || hop.getDim2() == 1))
            || (HopRewriteUtils.isDataGenOpWithLiteralInputs(hop, Types.OpOpDG.SEQ)
                && HopRewriteUtils.hasOnlyUnaryBinaryParents(hop, true))
            || isMatrixNaryMinMaxPlus(hop)
            || isBiasOpWithKnownInputDims(hop);
    }

    /**
     * Decides whether {@code hop} can be fused into an existing cell template rooted at its input.
     *
     * @param hop   consumer hop
     * @param input the input hop already covered by the template (unused here)
     * @return whether fusion is allowed
     */
    @Override
    public boolean fuse(Hop hop, Hop input) {
        // NOTE(review): the n-ary and bias clauses sit outside the !isClosed() guard, so they
        // are accepted even for a closed template; this matches the original precedence.
        return (!isClosed() && (isValidOperation(hop)
                || HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG)
                || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1() == 1 && hop.getDim2() == 1
                    && HopRewriteUtils.isTransposeOperation(hop.getInput().get(0)))
                || (HopRewriteUtils.isTransposeOperation(hop) && hop.getDim1() == 1 && hop.getDim2() > 1)))
            || isMatrixNaryMinMaxPlus(hop)
            || isBiasOpWithKnownInputDims(hop);
    }

    /**
     * Decides whether a cell template rooted at {@code input} can be merged into {@code hop}.
     *
     * @param hop   consumer hop
     * @param input input hop that roots another cell template
     * @return whether the merge is allowed
     */
    @Override
    public boolean merge(Hop hop, Hop input) {
        // NOTE(review): only the first clause is guarded by !isClosed(); matches original precedence.
        return (!isClosed() && (isValidOperation(hop)
                || (hop instanceof AggBinaryOp && hop.getInput().indexOf(input) == 0
                    && HopRewriteUtils.isTransposeOperation(input))))
            || (HopRewriteUtils.isDataGenOpWithLiteralInputs(input, Types.OpOpDG.SEQ)
                && HopRewriteUtils.hasOnlyUnaryBinaryParents(input, false))
            || isMatrixNaryMinMaxPlus(hop)
            || isBiasOpWithKnownInputDims(hop);
    }

    /**
     * Determines the template state after including the given hop.
     *
     * @param hop the hop just added to the template
     * @return CLOSED_VALID for supported unary aggregations and 1x1 matrix multiplies,
     *         CLOSED_INVALID for any other aggregation or matrix multiply, OPEN_VALID otherwise
     */
    @Override
    public TemplateBase.CloseType close(Hop hop) {
        if (HopRewriteUtils.isAggUnaryOp(hop, SUPPORTED_AGG)
            || (HopRewriteUtils.isMatrixMultiply(hop) && hop.getDim1() == 1 && hop.getDim2() == 1)) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }
        if (hop instanceof AggUnaryOp || hop instanceof AggBinaryOp) {
            return TemplateBase.CloseType.CLOSED_INVALID;
        }
        return TemplateBase.CloseType.OPEN_VALID;
    }

    /**
     * Constructs the cplan (CNodeCell) for the fused operator rooted at {@code hop}.
     *
     * @param hop             root hop of the fused operator
     * @param memo            memo table of candidate plans
     * @param compileLiterals whether literals are compiled into the cplan nodes
     * @return the ordered, non-literal input hops and the cell template, or {@code null}
     *         if all remaining inputs are scalars
     */
    @Override
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable memo, boolean compileLiterals) {
        // recursively build the cplan and collect the template inputs
        HashSet<Hop> inHops = new HashSet<>();
        HashMap<Long, CNode> tmp = new HashMap<>();
        hop.resetVisitStatus();
        rConstructCplan(hop, memo, tmp, inHops, compileLiterals);
        hop.resetVisitStatus();

        // drop scalar literal inputs and order the rest (see HopInputComparator)
        Hop[] sinHops = inHops.stream()
            .filter(h -> !(h.getDataType().isScalar() && tmp.get(h.getHopID()).isLiteral()))
            .sorted(new HopInputComparator())
            .toArray(Hop[]::new);

        ArrayList<CNode> inputs = new ArrayList<>();
        for (Hop in : sinHops) {
            inputs.add(tmp.get(in.getHopID()));
        }
        // no cell template if only scalar inputs remain (also covers an empty input list)
        if (inputs.stream().allMatch(c -> c.getDataType().isScalar())) {
            return null;
        }

        CNodeCell tpl = new CNodeCell(inputs, tmp.get(hop.getHopID()));
        tpl.setCellType(TemplateUtils.getCellType(hop));
        tpl.setAggOp(TemplateUtils.getAggOp(hop));
        // sinHops is non-empty here because of the early return above
        tpl.setSparseSafe(isSparseSafe(Arrays.asList(hop), sinHops[0],
            Arrays.asList(tpl.getOutput()), Arrays.asList(tpl.getAggOp()), false));
        tpl.setContainsSeq(rContainsSeq(tpl.getOutput(), new HashSet<>()));
        tpl.setRequiresCastDtm(hop instanceof AggBinaryOp);
        tpl.setBeginLine(hop.getBeginLine());
        return new Pair<>(sinHops, tpl);
    }

    /**
     * Recursively translates {@code hop} and the part of its input DAG covered by the plan into
     * CNodes, registering them in {@code tmp} and collecting template input hops in {@code inHops}.
     *
     * @throws HopsException if the hop type cannot be translated
     */
    protected void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> tmp,
        HashSet<Hop> inHops, boolean compileLiterals) {
        // each hop is translated at most once
        if (tmp.containsKey(hop.getHopID())) {
            return;
        }

        // hops whose best plan is a row or outer template become plain data inputs
        CPlanMemoTable.MemoTableEntry me = memo.getBest(hop.getHopID(), TemplateBase.TemplateType.CELL);
        if (me != null && me.type.isIn(TemplateBase.TemplateType.ROW, TemplateBase.TemplateType.OUTER)) {
            tmp.put(hop.getHopID(), TemplateUtils.createCNodeData(hop, compileLiterals));
            inHops.add(hop);
            return;
        }

        // process inputs
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop c = hop.getInput().get(i);
            if (me != null && me.isPlanRef(i) && !(c instanceof DataOp)
                && (me.type != TemplateBase.TemplateType.MAGG
                    || memo.contains(c.getHopID(), TemplateBase.TemplateType.CELL))) {
                // input is part of the plan: recurse
                rConstructCplan(c, memo, tmp, inHops, compileLiterals);
            }
            else if (me != null
                && (me.type == TemplateBase.TemplateType.MAGG || me.type == TemplateBase.TemplateType.CELL)
                && HopRewriteUtils.isMatrixMultiply(hop) && i == 0) {
                // bypass input 0 of the matrix multiply itself (presumably the transpose admitted
                // by fuse/merge) and continue with its first input
                Hop c0 = c.getInput().get(0);
                if (c0 instanceof DataOp) {
                    tmp.put(c0.getHopID(), TemplateUtils.createCNodeData(c0, compileLiterals));
                    inHops.add(c0);
                }
                else {
                    rConstructCplan(c0, memo, tmp, inHops, compileLiterals);
                }
            }
            else {
                // input is not part of the plan: plain data input
                tmp.put(c.getHopID(), TemplateUtils.createCNodeData(c, compileLiterals));
                inHops.add(c);
            }
        }

        // construct the cnode for the current hop
        CNode out = null;
        if (hop instanceof UnaryOp) {
            out = new CNodeUnary(wrapInput(tmp, hop, 0),
                CNodeUnary.UnaryType.valueOf(((UnaryOp) hop).getOp().name()));
        }
        else if (hop instanceof BinaryOp) {
            out = new CNodeBinary(wrapInput(tmp, hop, 0), wrapInput(tmp, hop, 1),
                CNodeBinary.BinType.valueOf(((BinaryOp) hop).getOp().name()));
        }
        else if (hop instanceof TernaryOp) {
            out = new CNodeTernary(wrapInput(tmp, hop, 0), wrapInput(tmp, hop, 1), wrapInput(tmp, hop, 2),
                CNodeTernary.TernaryType.valueOf(((TernaryOp) hop).getOp().name()));
        }
        else if (HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)) {
            Hop in0 = hop.getInput().get(0);
            Hop in1 = hop.getInput().get(1);
            // third operand: literal in0.ncol / in1.nrow
            out = new CNodeTernary(wrapInput(tmp, hop, 0), tmp.get(in1.getHopID()),
                TemplateUtils.createCNodeData(new LiteralOp(in0.getDim2() / in1.getDim1()), true),
                CNodeTernary.TernaryType.valueOf(((DnnOp) hop).getOp().name()));
        }
        else if (HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)) {
            // fold the n-ary op into a left-deep chain of binary ops
            CNodeBinary.BinType type = CNodeBinary.BinType.valueOf(((NaryOp) hop).getOp().name());
            CNode[] ins = new CNode[hop.getInput().size()];
            for (int j = 0; j < ins.length; j++) {
                ins[j] = wrapInput(tmp, hop, j);
            }
            out = new CNodeBinary(ins[0], ins[1], type);
            for (int j = 2; j < ins.length; j++) {
                out = new CNodeBinary(out, ins[j], type);
            }
        }
        else if (hop instanceof ParameterizedBuiltinOp) {
            ParameterizedBuiltinOp pbop = (ParameterizedBuiltinOp) hop;
            CNode target = TemplateUtils.wrapLookupIfNecessary(
                tmp.get(pbop.getTargetHop().getHopID()), hop.getInput().get(0));
            CNode pattern = tmp.get(pbop.getParameterHop("pattern").getHopID());
            CNode replacement = tmp.get(pbop.getParameterHop("replacement").getHopID());
            // a literal NaN pattern gets a dedicated replace type
            CNodeTernary.TernaryType type = (pattern.isLiteral() && pattern.getVarname().equals("Double.NaN"))
                ? CNodeTernary.TernaryType.REPLACE_NAN
                : CNodeTernary.TernaryType.REPLACE;
            out = new CNodeTernary(target, pattern, replacement, type);
        }
        else if (hop instanceof IndexingOp) {
            Hop in0 = hop.getInput().get(0);
            out = new CNodeTernary(tmp.get(in0.getHopID()),
                TemplateUtils.createCNodeData(new LiteralOp(in0.getDim2()), true),
                TemplateUtils.createCNodeData(hop.getInput().get(4), true),
                CNodeTernary.TernaryType.LOOKUP_RC1);
        }
        else if (HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.SEQ)) {
            DataGenOp dgop = (DataGenOp) hop;
            CNodeData from = TemplateUtils.getLiteral(tmp.get(dgop.getParam(Statement.SEQ_FROM).getHopID()));
            CNodeData to = TemplateUtils.getLiteral(tmp.get(dgop.getParam(Statement.SEQ_TO).getHopID()));
            CNodeData incr = TemplateUtils.getLiteral(tmp.get(dgop.getParam(Statement.SEQ_INCR).getHopID()));
            // descending sequence (from > to) with a positive increment: negate the increment
            if (Double.parseDouble(from.getVarname()) > Double.parseDouble(to.getVarname())
                && Double.parseDouble(incr.getVarname()) > 0) {
                incr = TemplateUtils.createCNodeData(new LiteralOp("-" + incr.getVarname()), true);
            }
            out = new CNodeBinary(from, incr, CNodeBinary.BinType.SEQ_RIX);
        }
        else if (HopRewriteUtils.isTransposeOperation(hop)) {
            Hop in0 = hop.getInput().get(0);
            out = TemplateUtils.skipTranspose(tmp.get(hop.getHopID()), hop, tmp, compileLiterals);
            // flip existing vector lookups unless some parent is a matrix multiply
            if (!HopRewriteUtils.containsOp(hop.getParent(), AggBinaryOp.class)) {
                TemplateUtils.rFlipVectorLookups(out);
            }
            if (out instanceof CNodeData && !inHops.contains(in0)) {
                inHops.add(in0);
            }
        }
        else if (hop instanceof AggUnaryOp) {
            // reuse the input expression; the template's agg op is set in constructCplan
            out = tmp.get(hop.getInput().get(0).getHopID());
        }
        else if (hop instanceof AggBinaryOp) {
            Hop in0 = hop.getInput().get(0);
            Hop in1 = hop.getInput().get(1);
            if (HopRewriteUtils.isTransposeOfItself(in0, in1)) {
                // input 0 is the transpose of input 1: expressed as cellwise POW2 of input 1
                CNode cdata = tmp.get(in1.getHopID());
                if (TemplateUtils.isColVector(cdata)) {
                    cdata = new CNodeUnary(cdata, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeUnary(cdata, CNodeUnary.UnaryType.POW2);
            }
            else {
                // expressed as cellwise MULT of the (transpose-skipped) left and the right input
                CNode left = TemplateUtils.skipTranspose(tmp.get(in0.getHopID()), in0, tmp, compileLiterals);
                Hop in00 = in0.getInput().get(0);
                if (left instanceof CNodeData && !inHops.contains(in00)) {
                    inHops.add(in00);
                }
                if (TemplateUtils.isColVector(left)) {
                    left = new CNodeUnary(left, CNodeUnary.UnaryType.LOOKUP_R);
                }
                CNode right = tmp.get(in1.getHopID());
                if (TemplateUtils.isColVector(right)) {
                    right = new CNodeUnary(right, CNodeUnary.UnaryType.LOOKUP_R);
                }
                out = new CNodeBinary(left, right, CNodeBinary.BinType.MULT);
            }
        }

        if (out == null) {
            throw new HopsException(hop.getHopID() + " " + hop.getOpString());
        }
        tmp.put(hop.getHopID(), out);
    }

    /**
     * Returns the cnode of the {@code idx}-th input of {@code hop}, wrapped in a lookup if
     * {@link TemplateUtils#wrapLookupIfNecessary} requires it.
     */
    private static CNode wrapInput(HashMap<Long, CNode> tmp, Hop hop, int idx) {
        Hop in = hop.getInput().get(idx);
        return TemplateUtils.wrapLookupIfNecessary(tmp.get(in.getHopID()), in);
    }

    /** True for matrix-valued n-ary min/max/plus. */
    private static boolean isMatrixNaryMinMaxPlus(Hop hop) {
        return HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)
            && hop.isMatrix();
    }

    /** True for bias_add/bias_mult whose first two inputs have known dimensions. */
    private static boolean isBiasOpWithKnownInputDims(Hop hop) {
        return HopRewriteUtils.isDnn(hop, Types.OpOpDnn.BIASADD, Types.OpOpDnn.BIASMULT)
            && hop.getInput().get(0).dimsKnown() && hop.getInput().get(1).dimsKnown();
    }

    /**
     * Checks whether {@code hop} is a matrix-valued operation supported by the cell template.
     *
     * @param hop the hop to check
     * @return true for supported matrix unary ops; non-outer binary ops that are matrix-scalar,
     *         matrix-vector (known dims) or equal-size matrix-matrix (known dims); ternary ops with
     *         (matrix, scalar, matrix) inputs whose matrices are both vectors or equal-size and not
     *         sparse; matrix ifelse; and replace
     */
    protected static boolean isValidOperation(Hop hop) {
        boolean validBinary = false;
        if (hop instanceof BinaryOp && hop.getDataType().isMatrix() && !((BinaryOp) hop).isOuter()) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(1);
            boolean matrixScalar = left.getDataType().isScalar() || right.getDataType().isScalar();
            boolean matrixVector = hop.dimsKnown()
                && ((left.getDataType().isMatrix() && TemplateUtils.isVectorOrScalar(right))
                    || (right.getDataType().isMatrix() && TemplateUtils.isVectorOrScalar(left)));
            boolean matrixMatrix = hop.dimsKnown() && HopRewriteUtils.isEqualSize(left, right)
                && left.getDataType().isMatrix() && right.getDataType().isMatrix();
            validBinary = matrixScalar || matrixVector || matrixMatrix;
        }

        boolean validIfElse = HopRewriteUtils.isTernary(hop, Types.OpOp3.IFELSE) && hop.getDataType().isMatrix();
        boolean validTernary = false;
        if (hop instanceof TernaryOp && hop.getInput().size() == 3 && hop.dimsKnown()
            && HopRewriteUtils.checkInputDataTypes(hop,
                Types.DataType.MATRIX, Types.DataType.SCALAR, Types.DataType.MATRIX)) {
            Hop first = hop.getInput().get(0);
            Hop third = hop.getInput().get(2);
            boolean bothVectors = TemplateUtils.isVector(first) && TemplateUtils.isVector(third);
            boolean equalSizeNotSparse = HopRewriteUtils.isEqualSize(first, third)
                && !HopRewriteUtils.isSparse(first) && !HopRewriteUtils.isSparse(third);
            validTernary = bothVectors || equalSizeNotSparse;
        }

        boolean validReplace = hop instanceof ParameterizedBuiltinOp
            && ((ParameterizedBuiltinOp) hop).getOp() == Types.ParamBuiltinOp.REPLACE;

        return hop.getDataType() == Types.DataType.MATRIX && TemplateUtils.isOperationSupported(hop)
            && (hop instanceof UnaryOp || validBinary || validTernary || validIfElse || validReplace);
    }

    /**
     * Checks whether all given outputs are sparse-safe with respect to {@code mainInput}.
     *
     * @param roots     root hops (for aggregations, their first input is checked)
     * @param mainInput the main (first-ordered) template input
     * @param outputs   cplan outputs, aligned by index with {@code roots}
     * @param aggOps    aggregation types, aligned by index with {@code outputs}
     * @param onlySum   if true, additionally require SUM or SUM_SQ aggregations
     * @return true if every output passes the checks
     */
    protected boolean isSparseSafe(List<Hop> roots, Hop mainInput, List<CNode> outputs,
        List<Types.AggOp> aggOps, boolean onlySum) {
        boolean ret = true;
        for (int i = 0; i < outputs.size() && ret; i++) {
            Hop root = roots.get(i);
            Hop hop = (root instanceof AggUnaryOp || root instanceof AggBinaryOp)
                ? root.getInput().get(0) : root;
            ret &= (HopRewriteUtils.isBinarySparseSafe(hop) && hop.getInput().contains(mainInput))
                || (HopRewriteUtils.isBinary(hop, Types.OpOp2.DIV) && hop.getInput().get(0) == mainInput)
                || (TemplateUtils.rIsSparseSafeOnly(outputs.get(i), CNodeBinary.BinType.MULT)
                    && TemplateUtils.rContainsInput(outputs.get(i), mainInput.getHopID()));
            if (onlySum) {
                ret &= aggOps.get(i) == Types.AggOp.SUM || aggOps.get(i) == Types.AggOp.SUM_SQ;
            }
        }
        return ret;
    }

    /**
     * Recursively checks whether the cplan rooted at {@code node} contains a SEQ_RIX binary node.
     *
     * @param node    current cplan node
     * @param visited IDs of already processed nodes; these contribute {@code false}
     * @return true if a SEQ_RIX node is reachable from {@code node} through unvisited nodes
     */
    protected boolean rContainsSeq(CNode node, HashSet<Long> visited) {
        if (visited.contains(node.getID())) {
            return false;
        }
        boolean ret = TemplateUtils.isBinary(node, CNodeBinary.BinType.SEQ_RIX);
        for (CNode in : node.getInput()) {
            ret |= rContainsSeq(in, visited);
        }
        visited.add(node.getID());
        return ret;
    }
}
