package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.class */
/**
 * Hop rewrite that vectorizes sequences of single-row or single-column indexing
 * operations into fewer row/column indexing operations.
 *
 * <p>While walking the DAG, two rewrites are applied to every input of every hop:
 * <ul>
 *   <li>{@code vectorizeRightLeftIndexingChains}: chains of left indexing ops that copy
 *       consecutive full rows/columns out of the same right-indexed matrix, e.g.
 *       {@code B[,1]=A[,1]; B[,2]=A[,2]}, are merged into one range copy such as
 *       {@code B[,1:2]=A[,1:2]}.</li>
 *   <li>{@code vectorizeLeftIndexing}: chains of single-cell left indexing ops writing into
 *       the same row (or column), e.g. {@code X[i,1]=a; X[i,3]=b}, are rewritten to
 *       extract that row (column) once, update the cells inside the extracted vector, and
 *       write the vector back once.</li>
 * </ul>
 * A right-indexing variant ({@code vectorizeRightIndexing}) is present but not invoked.
 */
public class RewriteIndexingVectorization extends HopRewriteRule {

    /**
     * Applies the indexing vectorization rewrites to all given DAG roots.
     *
     * @param roots the DAG roots, rewritten in place; may be {@code null}
     * @param state the program rewrite status (not used by this rule)
     * @return the same {@code roots} list that was passed in
     */
    @Override // org.apache.sysds.hops.rewrite.HopRewriteRule
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return roots;
        }
        for (Hop root : roots) {
            rule_IndexingVectorization(root);
        }
        return roots;
    }

    /**
     * Applies the indexing vectorization rewrites to the DAG below a single root.
     *
     * @param root the DAG root, rewritten in place; may be {@code null}
     * @param state the program rewrite status (not used by this rule)
     * @return the same {@code root} that was passed in
     */
    @Override // org.apache.sysds.hops.rewrite.HopRewriteRule
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return root;
        }
        rule_IndexingVectorization(root);
        return root;
    }

    /**
     * Recursively applies the vectorization rewrites to all inputs of {@code hop}, then
     * descends into each (possibly replaced) input. Already-visited hops are skipped and
     * {@code hop} is marked visited afterwards.
     */
    private void rule_IndexingVectorization(Hop hop) {
        if (hop.isVisited()) {
            return;
        }
        // index-based loop on purpose: the rewrites below may replace hop's inputs in place
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop input = hop.getInput().get(i);
            input = vectorizeRightLeftIndexingChains(input); // e.g. B[,1]=A[,1]; B[,2]=A[,2] -> B[,1:2]=A[,1:2]
            input = vectorizeLeftIndexing(input);            // e.g. X[i,1]=a; X[i,3]=b -> single row update
            rule_IndexingVectorization(input);
        }
        hop.setVisited();
    }

    /**
     * Merges a chain of left-indexing ops whose sources are right-indexing ops over consecutive
     * full rows (or full columns) of the same matrix into a single left/right indexing pair over
     * the whole index range.
     *
     * <p>The root must be a {@link LeftIndexingOp} whose source (input 1) is an {@link IndexingOp}
     * with a single consumer, indexing a single row or a single column in the same way on both
     * sides. The chain is followed through input 0 (the left-indexing target) as long as
     * {@code isConsecutiveLeftRightIndexing} holds and the intermediate ops have single consumers.
     *
     * @param hop candidate root of the chain
     * @return the new left indexing op if at least two pairs were merged, otherwise {@code hop}
     */
    private static Hop vectorizeRightLeftIndexingChains(Hop hop) {
        // check for a valid root: lix(target, rix(...)) where the rix has no other consumer
        if (!(hop instanceof LeftIndexingOp)
                || !(hop.getInput().get(1) instanceof IndexingOp)
                || hop.getInput().get(1).getParent().size() != 1) {
            return hop;
        }
        LeftIndexingOp lix0 = (LeftIndexingOp) hop;
        IndexingOp rix0 = (IndexingOp) hop.getInput().get(1);
        // require single row or single column access, identical on the left and right side
        if ((!lix0.isRowLowerEqualsUpper() && !lix0.isColLowerEqualsUpper())
                || lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper()
                || lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper()) {
            return hop;
        }
        boolean row = lix0.isRowLowerEqualsUpper();
        boolean lixFull = row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0);
        if (!lixFull) {
            return hop;
        }
        boolean rixFull = row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0);
        if (!rixFull) {
            return hop;
        }

        // collect the chain of consecutive left/right indexing pairs (top-down)
        ArrayList<LeftIndexingOp> lix = new ArrayList<>();
        ArrayList<IndexingOp> rix = new ArrayList<>();
        lix.add(lix0);
        rix.add(rix0);
        LeftIndexingOp clix = lix0;
        IndexingOp crix = rix0;
        while (isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0))
                && clix.getInput().get(0).getParent().size() == 1
                && clix.getInput().get(0).getInput().get(1).getParent().size() == 1) {
            clix = (LeftIndexingOp) clix.getInput().get(0);
            crix = (IndexingOp) clix.getInput().get(1);
            lix.add(clix);
            rix.add(crix);
        }

        // only rewrite if at least two consecutive pairs were found
        if (lix.size() < 2) {
            return hop;
        }
        int n = lix.size(); // lix and rix always have the same length

        // new right indexing over the full range, starting at the bottom-most (first) index
        IndexingOp rixn = rix.get(n - 1);
        Hop rlrix = rixn.getInput().get(1);
        Hop rurix = row
                ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(n - 1), Types.OpOp2.PLUS)
                : rixn.getInput().get(2);
        Hop clrix = rixn.getInput().get(3);
        Hop curix = row
                ? rixn.getInput().get(4)
                : HopRewriteUtils.createBinary(clrix, new LiteralOp(n - 1), Types.OpOp2.PLUS);
        IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix);

        // new left indexing over the full range into the bottom-most target
        LeftIndexingOp lixn = lix.get(n - 1);
        Hop rllix = lixn.getInput().get(2);
        Hop rulix = row
                ? HopRewriteUtils.createBinary(rllix, new LiteralOp(n - 1), Types.OpOp2.PLUS)
                : lixn.getInput().get(3);
        Hop cllix = lixn.getInput().get(4);
        Hop culix = row
                ? lixn.getInput().get(5)
                : HopRewriteUtils.createBinary(cllix, new LiteralOp(n - 1), Types.OpOp2.PLUS);
        LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(
                lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix);

        // rewire ALL consumers of the old root (not only the first one): the old root's inputs
        // are detached below, so any consumer left pointing at it would see a broken hop.
        // Iterate over a copy because replaceChildReference mutates hop's parent list.
        for (Hop parent : new ArrayList<>(hop.getParent())) {
            HopRewriteUtils.replaceChildReference(parent, hop, lixNew);
        }
        for (int i = 0; i < n; i++) {
            HopRewriteUtils.removeAllChildReferences(lix.get(i));
            HopRewriteUtils.removeAllChildReferences(rix.get(i));
        }

        LOG.debug("Applied vectorizeRightLeftIndexingChains (line " + lixNew.getBeginLine() + ")");
        return lixNew;
    }

    /**
     * Checks whether {@code input} is the next lower pair in a left/right indexing chain headed
     * by {@code lix}/{@code rix}: it must be a left indexing op whose source is a right indexing
     * op, both using full row (or full column, matching {@code lix}) indexing, reading from the
     * same matrix as {@code rix}, with indexes that are consecutive to those of {@code lix} and
     * {@code rix} according to {@link HopRewriteUtils#isConsecutiveIndex}.
     */
    private static boolean isConsecutiveLeftRightIndexing(LeftIndexingOp lix, IndexingOp rix, Hop input) {
        if (!(input instanceof LeftIndexingOp) || !(input.getInput().get(1) instanceof IndexingOp)) {
            return false;
        }
        boolean row = lix.isRowLowerEqualsUpper();
        LeftIndexingOp lix2 = (LeftIndexingOp) input;
        IndexingOp rix2 = (IndexingOp) input.getInput().get(1);

        // same kind of full row/column access on both sides
        boolean access = row
                ? HopRewriteUtils.isFullRowIndexing(lix2) && HopRewriteUtils.isFullRowIndexing(rix2)
                : HopRewriteUtils.isFullColumnIndexing(lix2) && HopRewriteUtils.isFullColumnIndexing(rix2);
        if (!access) {
            return false;
        }
        // both right indexing ops read from the same matrix
        if (rix.getInput().get(0) != rix2.getInput().get(0)) {
            return false;
        }
        // consecutive row (inputs 2 / 1) or column (inputs 4 / 3) indexes
        return row
                ? HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(2), lix.getInput().get(2))
                        && HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(1), rix.getInput().get(1))
                : HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(4), lix.getInput().get(4))
                        && HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(3), rix.getInput().get(3));
    }

    /**
     * Vectorizes multiple single-cell right indexing ops on the same row (or, if that does not
     * apply, the same column) of one matrix: a single row (column) extraction is inserted and
     * every candidate is rewired to index the cell within that extracted vector.
     *
     * <p>NOTE: this rewrite is not invoked by {@code rule_IndexingVectorization}; it is kept
     * for reference only.
     */
    @SuppressWarnings("unused")
    private static void vectorizeRightIndexing(Hop hop) {
        if (!(hop instanceof IndexingOp)) {
            return;
        }
        IndexingOp ihop0 = (IndexingOp) hop;
        if (!ihop0.isRowLowerEqualsUpper() || !ihop0.isColLowerEqualsUpper()) {
            return; // only single-cell right indexing is a candidate
        }
        Hop input = ihop0.getInput().get(0);

        // row variant: all single-row right indexing ops on the same input and row expression
        ArrayList<Hop> rowCandidates = new ArrayList<>();
        rowCandidates.add(ihop0);
        for (Hop c : input.getParent()) {
            if (c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
                    && ((IndexingOp) c).isRowLowerEqualsUpper()
                    && c.getInput().get(1) == ihop0.getInput().get(1)) {
                rowCandidates.add(c);
            }
        }
        if (rowCandidates.size() > 1) {
            IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input,
                    ihop0.getInput().get(1), ihop0.getInput().get(1), new LiteralOp(1L),
                    HopRewriteUtils.createValueHop(input, false), true, false);
            HopRewriteUtils.setOutputParameters(newRix, -1L, -1L, input.getBlocksize(), -1L);
            newRix.refreshSizeInformation();
            for (Hop c : rowCandidates) {
                HopRewriteUtils.removeChildReference(c, input);                       // input data
                HopRewriteUtils.addChildReference(c, newRix, 0);
                HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(1), 1); // row lower expr
                HopRewriteUtils.addChildReference(c, new LiteralOp(1L), 1);
                HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2), 2); // row upper expr
                HopRewriteUtils.addChildReference(c, new LiteralOp(1L), 2);
                c.refreshSizeInformation();
            }
            LOG.debug("Applied vectorizeRightIndexingRow");
            return;
        }

        // column variant: all single-column right indexing ops on the same input and column expression
        ArrayList<Hop> colCandidates = new ArrayList<>();
        colCandidates.add(ihop0);
        for (Hop c : input.getParent()) {
            if (c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
                    && ((IndexingOp) c).isColLowerEqualsUpper()
                    && c.getInput().get(3) == ihop0.getInput().get(3)) {
                colCandidates.add(c);
            }
        }
        if (colCandidates.size() > 1) {
            IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input,
                    new LiteralOp(1L), HopRewriteUtils.createValueHop(input, true),
                    ihop0.getInput().get(3), ihop0.getInput().get(3), false, true);
            HopRewriteUtils.setOutputParameters(newRix, -1L, -1L, input.getBlocksize(), -1L);
            newRix.refreshSizeInformation();
            for (Hop c : colCandidates) {
                HopRewriteUtils.removeChildReference(c, input);                                    // input data
                HopRewriteUtils.addChildReference(c, newRix, 0);
                HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1L), 3); // col lower expr
                HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1L), 4); // col upper expr
                c.refreshSizeInformation();
            }
            LOG.debug("Applied vectorizeRightIndexingCol");
        }
    }

    /**
     * Vectorizes a chain of single-cell left indexing ops that all write into the same row
     * (tried first) or, failing that, the same column.
     *
     * @param hop candidate top of the chain
     * @return the new top-level left indexing op if a rewrite was applied, otherwise {@code hop}
     */
    private static Hop vectorizeLeftIndexing(Hop hop) {
        if (!(hop instanceof LeftIndexingOp)) {
            return hop;
        }
        LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
        if (!ihop0.isRowLowerEqualsUpper() || !ihop0.isColLowerEqualsUpper()) {
            return hop; // only single-cell left indexing starts a chain
        }
        // the column variant is only attempted if the row variant did not apply
        Hop ret = vectorizeLeftIndexingChain(ihop0, true);
        if (ret == null) {
            ret = vectorizeLeftIndexingChain(ihop0, false);
        }
        return (ret != null) ? ret : hop;
    }

    /**
     * Applies the row ({@code row == true}) or column variant of the left indexing
     * vectorization to the chain starting at {@code ihop0}.
     *
     * <p>The chain follows input 0 (the target) through left indexing ops that have a single
     * consumer, index a single row (column) using the same index expression hop as
     * {@code ihop0}, and whose own target has more than one column (row) according to its
     * dimension info. With at least two ops in the chain, the bottom target's row (column) is
     * extracted once ("tmp1"), every op in the chain is rewired to write at row (column) 1 of
     * that vector, and a new left indexing op ("tmp2") writes the updated vector back into the
     * bottom target. All former consumers of {@code ihop0} are moved to the new op.
     *
     * @return the new left indexing op, or {@code null} if the chain is too short
     */
    private static Hop vectorizeLeftIndexingChain(LeftIndexingOp ihop0, boolean row) {
        // LeftIndexingOp inputs: 0 target, 1 source, 2/3 row lower/upper, 4/5 col lower/upper
        int lowerPos = row ? 2 : 4;
        int upperPos = lowerPos + 1;
        Hop fixedIndex = ihop0.getInput().get(lowerPos); // captured before the index inputs are reset

        // collect the simple chain (no multiple consumers) of left indexing ops
        ArrayList<LeftIndexingOp> chain = new ArrayList<>();
        chain.add(ihop0);
        LeftIndexingOp bottom = ihop0;
        while (bottom.getInput().get(0) instanceof LeftIndexingOp) {
            LeftIndexingOp next = (LeftIndexingOp) bottom.getInput().get(0);
            boolean singleIndex = row ? next.isRowLowerEqualsUpper() : next.isColLowerEqualsUpper();
            Hop nextTarget = next.getInput().get(0);
            long otherDim = row ? nextTarget.getDim2() : nextTarget.getDim1();
            if (next.getParent().size() > 1            // multiple consumers: not a simple chain
                    || !singleIndex                      // merge along this dimension not applicable
                    || next.getInput().get(lowerPos) != fixedIndex // not the same row/column
                    || otherDim <= 1) {                  // target is a single vector or size unknown
                break;
            }
            chain.add(next);
            bottom = next;
        }
        if (chain.size() < 2) {
            return null;
        }

        Hop input = bottom.getInput().get(0);

        // new right indexing extracting the full row (column) from the bottom target
        Hop rl = row ? fixedIndex : new LiteralOp(1L);
        Hop ru = row ? fixedIndex : HopRewriteUtils.createValueHop(input, true);
        Hop cl = row ? new LiteralOp(1L) : fixedIndex;
        Hop cu = row ? HopRewriteUtils.createValueHop(input, false) : fixedIndex;
        IndexingOp newRix = new IndexingOp("tmp1", input.getDataType(), input.getValueType(),
                input, rl, ru, cl, cu, row, !row);
        HopRewriteUtils.setOutputParameters(newRix, -1L, -1L, input.getBlocksize(), -1L);
        newRix.refreshSizeInformation();
        // reset visit status of reused hops (otherwise hidden behind the left indexing)
        for (Hop c : newRix.getInput()) {
            c.resetVisitStatus();
        }

        // let the bottom left indexing op write into the extracted vector instead of the target
        HopRewriteUtils.removeChildReference(bottom, input);
        HopRewriteUtils.addChildReference(bottom, newRix, 0);

        // reset the row (column) index of all chain ops to 1 and refresh sizes bottom-up
        for (int i = chain.size() - 1; i >= 0; i--) {
            LeftIndexingOp c = chain.get(i);
            HopRewriteUtils.replaceChildReference(c, c.getInput().get(lowerPos), new LiteralOp(1L), lowerPos);
            HopRewriteUtils.replaceChildReference(c, c.getInput().get(upperPos), new LiteralOp(1L), upperPos);
            if (row) {
                c.setRowLowerEqualsUpper(true);
            } else {
                c.setColLowerEqualsUpper(true);
            }
            c.refreshSizeInformation();
        }

        // detach all consumers of ihop0 (copy the parent list BEFORE creating newLix on top of
        // ihop0, since that adds newLix as a further parent), remembering their input positions
        ArrayList<Hop> parents = new ArrayList<>(ihop0.getParent());
        ArrayList<Integer> positions = new ArrayList<>();
        for (Hop parent : parents) {
            int pos = HopRewriteUtils.getChildReferencePos(parent, ihop0);
            HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, pos);
            positions.add(pos);
        }

        // new left indexing writing the updated vector back into the original target
        Hop lrl = row ? fixedIndex : new LiteralOp(1L);
        Hop lru = row ? fixedIndex : HopRewriteUtils.createValueHop(input, true);
        Hop lcl = row ? new LiteralOp(1L) : fixedIndex;
        Hop lcu = row ? HopRewriteUtils.createValueHop(input, false) : fixedIndex;
        LeftIndexingOp newLix = new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(),
                input, ihop0, lrl, lru, lcl, lcu, row, !row);
        HopRewriteUtils.setOutputParameters(newLix, -1L, -1L, input.getBlocksize(), -1L);
        newLix.refreshSizeInformation();
        for (Hop c : newLix.getInput()) {
            c.resetVisitStatus();
        }

        // reattach the former consumers of ihop0 to the new op at their original positions
        for (int i = 0; i < positions.size(); i++) {
            HopRewriteUtils.addChildReference(parents.get(i), newLix, positions.get(i));
        }

        LOG.debug((row ? "Applied vectorizeLeftIndexingRow for hop " : "Applied vectorizeLeftIndexingCol for hop ")
                + ihop0.getHopID());
        return newLix;
    }
}
