package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;
import org.apache.sysds.hops.codegen.template.CPlanMemoTable;
import org.apache.sysds.hops.codegen.template.TemplateBase;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
import org.apache.sysds.runtime.matrix.data.Pair;

/* loaded from: input_file:org/apache/sysds/hops/codegen/template/TemplateOuterProduct.class */
/**
 * Code generation template of type {@code OUTER}.
 *
 * <p>The template opens on outer-product-like matrix multiplications and outer binary
 * operations, fuses selected unary/binary/aggregate operations on top of them, and
 * finally builds a {@link CNodeOuterProduct} plan whose leading inputs are the hops
 * registered as {@code X}, {@code U} and {@code V} (in that order, when present).
 */
public class TemplateOuterProduct extends TemplateBase {
    /** Key of the data matrix input found on a sparse-safe binary operation. */
    private static final String KEY_X = "_X";
    /** Key of the left factor of the outer-product-like matrix multiplication. */
    private static final String KEY_U = "_U";
    /** Key of the right factor of the outer-product-like matrix multiplication. */
    private static final String KEY_V = "_V";

    /**
     * Transpose-self type of the outer-product-like matrix multiplication seen during the
     * last plan construction; handed to the {@link CNodeOuterProduct} constructor.
     */
    MMTSJ.MMTSJType mmtsj;

    /** Creates an outer-product template with the default close type of the base class. */
    public TemplateOuterProduct() {
        super(TemplateBase.TemplateType.OUTER);
        this.mmtsj = MMTSJ.MMTSJType.NONE;
    }

    /**
     * Creates an outer-product template with an explicit close type.
     *
     * @param closeType initial close type forwarded to the base class
     */
    public TemplateOuterProduct(TemplateBase.CloseType closeType) {
        super(TemplateBase.TemplateType.OUTER, closeType);
        this.mmtsj = MMTSJ.MMTSJType.NONE;
    }

    /**
     * Decides whether a new template can be opened at the given hop.
     *
     * @param hop candidate hop
     * @return {@code true} if the hop is an outer-product-like matrix multiplication or an
     *         outer binary operation and both of its dimensions exceed 256
     */
    @Override // org.apache.sysds.hops.codegen.template.TemplateBase
    public boolean open(Hop hop) {
        boolean outerLike = HopRewriteUtils.isOuterProductLikeMM(hop)
            || HopRewriteUtils.isOuterBinary(hop);
        return outerLike && hop.getDim1() > 256 && hop.getDim2() > 256;
    }

    /**
     * Decides whether {@code hop} can be fused into an open template coming from
     * {@code hop2}.
     *
     * @param hop  hop to be fused
     * @param hop2 input of {@code hop} through which the template is extended
     * @return {@code true} if the template is not closed and {@code hop} matches one of the
     *         supported unary, binary, transpose, matrix-multiply or full-aggregate patterns
     */
    @Override // org.apache.sysds.hops.codegen.template.TemplateBase
    public boolean fuse(Hop hop, Hop hop2) {
        if (isClosed()) {
            return false;
        }
        // supported unary operation
        if (hop instanceof UnaryOp && TemplateUtils.isOperationSupported(hop)) {
            return true;
        }
        // supported binary operation of a matrix with a vector, scalar or matrix
        if (hop instanceof BinaryOp && TemplateUtils.isOperationSupported(hop)
            && (TemplateUtils.isBinaryMatrixColVector(hop)
                || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
                || HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)
                || TemplateUtils.isBinaryMatrixRowVector(hop))) {
            return true;
        }
        // transpose on top of a matrix multiplication that is not outer-product-like
        if (HopRewriteUtils.isTransposeOperation(hop) && hop2 instanceof AggBinaryOp
            && !HopRewriteUtils.isOuterProductLikeMM(hop2)) {
            return true;
        }
        // non-outer matrix multiplication over an input that contains an outer product
        if (hop instanceof AggBinaryOp && !HopRewriteUtils.isOuterProductLikeMM(hop)
            && TemplateUtils.containsOuterProduct(hop2, HopRewriteUtils.getOtherInput(hop, hop2))) {
            return true;
        }
        // aggregation over both rows and columns
        return hop instanceof AggUnaryOp
            && ((AggUnaryOp) hop).getDirection() == Types.Direction.RowCol;
    }

    /**
     * Decides whether the plan of {@code hop2} can be merged into the template at
     * {@code hop}.
     *
     * @param hop  hop at which the template is currently open
     * @param hop2 input hop whose plan is a merge candidate
     * @return {@code true} if the template is not closed and {@code hop} is a matrix/row-vector
     *         or matrix/scalar binary operation, or a multiplication whose candidate input is
     *         sparse-safe and does not itself contain an outer product
     */
    @Override // org.apache.sysds.hops.codegen.template.TemplateBase
    public boolean merge(Hop hop, Hop hop2) {
        if (isClosed()) {
            return false;
        }
        if (TemplateUtils.isBinaryMatrixRowVector(hop)
            || HopRewriteUtils.isBinaryMatrixScalarOperation(hop)) {
            return true;
        }
        return HopRewriteUtils.isBinary(hop, Types.OpOp2.MULT)
            && HopRewriteUtils.isBinarySparseSafe(hop2)
            && !TemplateUtils.containsOuterProduct(hop2);
    }

    /**
     * Determines the close state of the template after {@code hop} has been added.
     *
     * @param hop hop that was just fused
     * @return {@code CLOSED_INVALID} for aggregates over unsuitable inputs,
     *         {@code CLOSED_VALID} for terminating aggregates/multiplications/transposes,
     *         {@code OPEN_VALID} for matrix-matrix multiply or divide, and
     *         {@code OPEN_INVALID} otherwise
     */
    @Override // org.apache.sysds.hops.codegen.template.TemplateBase
    public TemplateBase.CloseType close(Hop hop) {
        // unary aggregate directly over the outer product or over a non-sparse-safe input
        if (hop instanceof AggUnaryOp) {
            Hop aggInput = hop.getInput().get(0);
            if (HopRewriteUtils.isOuterProductLikeMM(aggInput)
                || !HopRewriteUtils.isBinarySparseSafe(aggInput)) {
                return TemplateBase.CloseType.CLOSED_INVALID;
            }
        }
        // matrix multiplication consuming an outer product directly, or a non-outer
        // multiplication whose largest input is not sparse-safe
        if (hop instanceof AggBinaryOp
            && (HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0))
                || HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(1))
                || (!HopRewriteUtils.isOuterProductLikeMM(hop)
                    && !HopRewriteUtils.isBinarySparseSafe(HopRewriteUtils.getLargestInput(hop))))) {
            return TemplateBase.CloseType.CLOSED_INVALID;
        }

        boolean terminates = hop instanceof AggUnaryOp
            || (hop instanceof AggBinaryOp
                && !HopRewriteUtils.isOuterProductLikeMM(hop)
                && !HopRewriteUtils.isTransposeOperation(hop.getParent().get(0)))
            || (HopRewriteUtils.isTransposeOperation(hop)
                && hop.getInput().get(0) instanceof AggBinaryOp
                && !HopRewriteUtils.isOuterProductLikeMM(hop.getInput().get(0)));
        if (terminates) {
            return TemplateBase.CloseType.CLOSED_VALID;
        }
        return isMatrixMatrixMultOrDiv(hop)
            ? TemplateBase.CloseType.OPEN_VALID
            : TemplateBase.CloseType.OPEN_INVALID;
    }

    /**
     * Builds the code generation plan rooted at {@code hop}.
     *
     * @param hop            root hop of the fused operator
     * @param cPlanMemoTable memo table providing the selected plan references
     * @param z              flag forwarded to {@code TemplateUtils.createCNodeData} and
     *                       {@code TemplateUtils.skipTranspose}
     * @return pair of the ordered input hops (may contain {@code null} placeholders for
     *         missing X/U/V) and the constructed outer-product template node
     */
    @Override // org.apache.sysds.hops.codegen.template.TemplateBase
    public Pair<Hop[], CNodeTpl> constructCplan(Hop hop, CPlanMemoTable cPlanMemoTable, boolean z) {
        HashSet<Hop> inputHops = new HashSet<>();
        HashMap<String, Hop> namedInputs = new HashMap<>();
        HashMap<Long, CNode> cnodes = new HashMap<>();
        hop.resetVisitStatus();
        rConstructCplan(hop, cPlanMemoTable, cnodes, inputHops, namedInputs, z);
        hop.resetVisitStatus();

        long outputId = hop.getHopID();
        if (hop instanceof BinaryOp) {
            outputId = TemplateUtils.skipConditionalInOuterProduct(hop, cnodes, inputHops);
        }

        Hop x = namedInputs.get(KEY_X);
        Hop u = namedInputs.get(KEY_U);
        Hop v = namedInputs.get(KEY_V);

        // Reorder so that X, U, V lead the input list. All removals happen before the
        // insertions; a LinkedList is used because absent entries are inserted as null.
        LinkedList<Hop> orderedInputs = new LinkedList<>(inputHops);
        orderedInputs.remove(v);
        orderedInputs.remove(u);
        orderedInputs.remove(x);
        orderedInputs.addFirst(v);
        orderedInputs.addFirst(u);
        orderedInputs.addFirst(x);

        // null placeholders have no cplan node and are skipped here
        ArrayList<CNode> inputNodes = new ArrayList<>();
        for (Hop in : orderedInputs) {
            if (in != null) {
                inputNodes.add(cnodes.get(in.getHopID()));
            }
        }

        CNodeOuterProduct tpl = new CNodeOuterProduct(inputNodes, cnodes.get(outputId), this.mmtsj);
        tpl.setOutProdType(TemplateUtils.getOuterProductType(x, u, v, hop));
        tpl.setTransposeOutput(!HopRewriteUtils.isTransposeOperation(hop)
            && tpl.getOutProdType() == SpoofOuterProduct.OutProdType.LEFT_OUTER_PRODUCT);
        tpl.setBeginLine(hop.getBeginLine());
        return new Pair<>(orderedInputs.toArray(new Hop[0]), tpl);
    }

    /**
     * Recursively creates cplan nodes for {@code hop} and all inputs that are plan references.
     *
     * @param hop             hop to process
     * @param memo            memo table used to pick the best OUTER/CELL entry per hop
     * @param cnodes          hop id to cplan node map; also serves as memoization
     * @param inputHops       collects hops that become data inputs of the fused operator
     * @param namedInputs     collects the hops registered as X, U and V
     * @param compileLiterals flag forwarded to the {@code TemplateUtils} node factories
     */
    private void rConstructCplan(Hop hop, CPlanMemoTable memo, HashMap<Long, CNode> cnodes,
            HashSet<Hop> inputHops, HashMap<String, Hop> namedInputs, boolean compileLiterals) {
        // already processed (shared sub-expression)
        if (cnodes.containsKey(hop.getHopID())) {
            return;
        }

        // descend into plan references, materialize all other inputs as data nodes
        CPlanMemoTable.MemoTableEntry best = memo.getBest(
            hop.getHopID(), TemplateBase.TemplateType.OUTER, TemplateBase.TemplateType.CELL);
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop child = hop.getInput().get(i);
            if (best.isPlanRef(i)) {
                rConstructCplan(child, memo, cnodes, inputHops, namedInputs, compileLiterals);
            } else {
                cnodes.put(child.getHopID(), TemplateUtils.createCNodeData(child, compileLiterals));
                inputHops.add(child);
            }
        }

        // stays null for hop types not handled below
        CNode out = null;
        if (hop instanceof UnaryOp) {
            CNode in = cnodes.get(hop.getInput().get(0).getHopID());
            out = new CNodeUnary(in, CNodeUnary.UnaryType.valueOf(((UnaryOp) hop).getOp().name()));
        } else if (hop instanceof BinaryOp) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(1);
            CNode leftNode = cnodes.get(left.getHopID());
            CNode rightNode = cnodes.get(right.getHopID());
            String opName = ((BinaryOp) hop).getOp().name();
            // a matrix data input of a sparse-safe operation becomes X (right side wins)
            if (HopRewriteUtils.isBinarySparseSafe(hop)) {
                if (TemplateUtils.isMatrix(left) && leftNode instanceof CNodeData) {
                    namedInputs.put(KEY_X, left);
                }
                if (TemplateUtils.isMatrix(right) && rightNode instanceof CNodeData) {
                    namedInputs.put(KEY_X, right);
                }
            }
            out = new CNodeBinary(
                TemplateUtils.wrapLookupIfNecessary(leftNode, left),
                TemplateUtils.wrapLookupIfNecessary(rightNode, right),
                CNodeBinary.BinType.valueOf(opName));
        } else if (hop instanceof AggBinaryOp) {
            Hop left = hop.getInput().get(0);
            Hop right = hop.getInput().get(1);
            CNode leftNode = TemplateUtils.skipTranspose(
                cnodes.get(left.getHopID()), left, cnodes, compileLiterals);
            CNode rightNode = TemplateUtils.skipTranspose(
                cnodes.get(right.getHopID()), right, cnodes, compileLiterals);
            if (HopRewriteUtils.isOuterProductLikeMM(hop)) {
                // register the (un-transposed) factors of the outer product
                namedInputs.put(KEY_U, unwrapTranspose(left));
                namedInputs.put(KEY_V, unwrapTranspose(right));
                this.mmtsj = ((AggBinaryOp) hop).checkTransposeSelf();
                out = new CNodeBinary(leftNode, rightNode, CNodeBinary.BinType.DOT_PRODUCT);
            } else if (leftNode.getDataType().isScalar()) {
                // scalar operand goes second
                out = new CNodeBinary(rightNode, leftNode, CNodeBinary.BinType.VECT_MULT_ADD);
            } else {
                out = new CNodeBinary(leftNode, rightNode, CNodeBinary.BinType.VECT_MULT_ADD);
            }
        } else if (HopRewriteUtils.isTransposeOperation(hop)) {
            // transpose is passed through to its input node
            out = cnodes.get(hop.getInput().get(0).getHopID());
        } else if (hop instanceof AggUnaryOp
                && ((AggUnaryOp) hop).getOp() == Types.AggOp.SUM
                && ((AggUnaryOp) hop).getDirection() == Types.Direction.RowCol) {
            // full sum is passed through to its input node
            out = cnodes.get(hop.getInput().get(0).getHopID());
        }
        cnodes.put(hop.getHopID(), out);
    }

    /**
     * Selects one of two alternative memo entries based on the hops they reference.
     *
     * <p>NOTE(review): judging by the method name the returned entry is presumably the one
     * the caller drops; confirm against the caller.
     *
     * @param cPlanMemoTable  memo table holding the hop references
     * @param memoTableEntry  first alternative
     * @param memoTableEntry2 second alternative
     * @return the first entry whose single referenced hop is a matrix-matrix multiply or
     *         divide; {@code null} if the entries do not each have exactly one plan
     *         reference at different positions, if either referenced hop has no OUTER
     *         entry in the memo table, or if neither referenced hop qualifies
     */
    public static CPlanMemoTable.MemoTableEntry dropAlternativePlan(CPlanMemoTable cPlanMemoTable,
            CPlanMemoTable.MemoTableEntry memoTableEntry, CPlanMemoTable.MemoTableEntry memoTableEntry2) {
        if (memoTableEntry.countPlanRefs() != 1
            || memoTableEntry2.countPlanRefs() != 1
            || memoTableEntry.getPlanRefIndex() == memoTableEntry2.getPlanRefIndex()) {
            return null;
        }
        Hop ref1 = cPlanMemoTable._hopRefs.get(
            Long.valueOf(memoTableEntry.input(memoTableEntry.getPlanRefIndex())));
        Hop ref2 = cPlanMemoTable._hopRefs.get(
            Long.valueOf(memoTableEntry2.input(memoTableEntry2.getPlanRefIndex())));
        if (!cPlanMemoTable.contains(ref1.getHopID(), TemplateBase.TemplateType.OUTER)
            || !cPlanMemoTable.contains(ref2.getHopID(), TemplateBase.TemplateType.OUTER)) {
            return null;
        }
        if (isMatrixMatrixMultOrDiv(ref1)) {
            return memoTableEntry;
        }
        if (isMatrixMatrixMultOrDiv(ref2)) {
            return memoTableEntry2;
        }
        return null;
    }

    /** Returns whether the hop is a matrix-matrix binary operation of type MULT or DIV. */
    private static boolean isMatrixMatrixMultOrDiv(Hop hop) {
        return HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)
            && HopRewriteUtils.isBinary(hop, Types.OpOp2.MULT, Types.OpOp2.DIV);
    }

    /** Returns the input of a transpose operation, or the hop itself if it is no transpose. */
    private static Hop unwrapTranspose(Hop hop) {
        return HopRewriteUtils.isTransposeOperation(hop) ? hop.getInput().get(0) : hop;
    }
}
