package org.apache.sysds.hops.rewrite;

import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Set;

import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteForLoopVectorization.class */
/**
 * Statement-block rewrite that replaces a simple {@code for} loop by an equivalent vectorized
 * statement block.
 *
 * <p>The rule only considers a loop whose increment is the literal {@code 1} and whose body is a
 * single last-level statement block (no nested while/if/for) with exactly one root hop. Four body
 * patterns are recognized. In the examples, {@code i} is the loop variable and indexing is
 * row-wise; the column-wise variants are handled symmetrically.
 * <ul>
 *   <li>scalar aggregate: {@code s = s OP as.scalar(X[i, c])} (operands may be swapped) with
 *       OP in {+, *, min, max} becomes {@code s = s OP agg(X[from:to, c])}, where agg is the
 *       matching full aggregate (sum, prod, min, max);</li>
 *   <li>binary: {@code R[i, c1] = A[i, c2] OP B[i, c3]};</li>
 *   <li>unary: {@code R[i, c1] = f(A[i, c2])};</li>
 *   <li>indexed copy: {@code R[i, c1] = A[i, c2]}.</li>
 * </ul>
 * When a pattern matches, every point index {@code i} is replaced by the range {@code from:to}.
 * The rewritten body block is then returned in place of the whole loop.
 *
 * <p>A pattern is accepted only if the index bounds of the non-vectorized dimension do not read
 * the loop variable, so e.g. diagonal access {@code X[i, i]} is rejected. The loop disappears
 * after the rewrite, so any remaining reference to its variable would be dangling or stale.
 */
public class RewriteForLoopVectorization extends StatementBlockRewriteRule {
    /** Scalar binary ops accepted as loop accumulation; index-aligned with the target ops below. */
    private static final Types.OpOp2[] MAP_SCALAR_AGGREGATE_SOURCE_OPS = {Types.OpOp2.PLUS, Types.OpOp2.MULT, Types.OpOp2.MIN, Types.OpOp2.MAX};
    /** Full aggregates replacing the corresponding entry of {@link #MAP_SCALAR_AGGREGATE_SOURCE_OPS}. */
    private static final Types.AggOp[] MAP_SCALAR_AGGREGATE_TARGET_OPS = {Types.AggOp.SUM, Types.AggOp.PROD, Types.AggOp.MIN, Types.AggOp.MAX};

    // Child positions of the lower index bound per dimension; the upper bound is always at lower + 1.
    // IndexingOp inputs: 0 = indexed matrix, 1/2 = row lower/upper, 3/4 = column lower/upper.
    private static final int IX_ROW_LOWER = 1;
    private static final int IX_COL_LOWER = 3;
    // LeftIndexingOp inputs: 0 = target matrix, 1 = source, 2/3 = row lower/upper, 4/5 = column lower/upper.
    private static final int LIX_ROW_LOWER = 2;
    private static final int LIX_COL_LOWER = 4;

    @Override // org.apache.sysds.hops.rewrite.StatementBlockRewriteRule
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Vectorizes {@code statementBlock} if it is a for loop of one of the supported shapes.
     *
     * @param statementBlock candidate block
     * @param programRewriteStatus not used by this rule
     * @return a singleton list holding either the unchanged block or the rewritten loop body
     */
    @Override // org.apache.sysds.hops.rewrite.StatementBlockRewriteRule
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock forBlock = (ForStatementBlock) statementBlock;
            ForStatement forStatement = (ForStatement) forBlock.getStatement(0);
            Hop from = forBlock.getFromHops();
            Hop to = forBlock.getToHops();
            Hop increment = forBlock.getIncrementHops();
            String iterVar = forBlock.getIterPredicate().getIterVar().getName();
            // Only a body consisting of exactly one last-level block (no nested control flow) qualifies.
            if (forStatement.getBody() != null && forStatement.getBody().size() == 1) {
                StatementBlock body = forStatement.getBody().get(0);
                if (!(body instanceof WhileStatementBlock || body instanceof IfStatementBlock || body instanceof ForStatementBlock)) {
                    // Each rewrite returns its input unchanged, or the rewritten body (which then replaces the loop).
                    statementBlock = vectorizeScalarAggregate(statementBlock, body, from, to, increment, iterVar);
                    statementBlock = vectorizeElementwiseBinary(statementBlock, body, from, to, increment, iterVar);
                    statementBlock = vectorizeElementwiseUnary(statementBlock, body, from, to, increment, iterVar);
                    statementBlock = vectorizeIndexedCopy(statementBlock, body, from, to, increment, iterVar);
                }
            }
        }
        return Arrays.asList(statementBlock);
    }

    /** Multi-block rewriting is not supported by this rule; the list is returned as-is. */
    @Override // org.apache.sysds.hops.rewrite.StatementBlockRewriteRule
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }

    /**
     * Rewrites {@code s = s OP as.scalar(X[i, c])} (or the operand-swapped / column-wise variant)
     * into {@code s = s OP agg(X[from:to, c])}.
     *
     * @param sb block to return if the rewrite does not apply
     * @param csb the loop body; its hops are modified in place if the rewrite applies
     * @return {@code csb} if the rewrite applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeScalarAggregate(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String iterVar) {
        if (!isUnitIncrement(increment)) {
            return sb;
        }
        Hop root = getSingleRoot(csb);
        if (root == null || root.getDataType() != Types.DataType.SCALAR || !(root.getInput().get(0) instanceof BinaryOp)) {
            return sb;
        }
        BinaryOp bop = (BinaryOp) root.getInput().get(0);
        if (!HopRewriteUtils.isValidOp(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS)) {
            return sb;
        }
        Hop left = bop.getInput().get(0);
        Hop right = bop.getInput().get(1);
        int castPos; // position of the as.scalar(X[..]) operand within bop
        if (isAccumulatorRead(left, root) && isScalarCastOfIndexing(right)) {
            castPos = 1;
        } else if (isAccumulatorRead(right, root) && isScalarCastOfIndexing(left)) {
            castPos = 0;
        } else {
            return sb;
        }
        Hop cast = bop.getInput().get(castPos);
        IndexingOp ix = (IndexingOp) cast.getInput().get(0);
        boolean rowIx;
        if (ix.getInput().get(IX_ROW_LOWER) instanceof DataOp && indexesPointBy(ix, true, iterVar)) {
            rowIx = true;
        } else if (ix.getInput().get(IX_COL_LOWER) instanceof DataOp && indexesPointBy(ix, false, iterVar)) {
            rowIx = false;
        } else {
            return sb;
        }

        // Replace cast(X[i, c]) by agg(X[.., c]) at the same operand position.
        int aggOpPos = HopRewriteUtils.getValidOpPos(bop.getOp(), MAP_SCALAR_AGGREGATE_SOURCE_OPS);
        AggUnaryOp agg = HopRewriteUtils.createAggUnaryOp(ix, MAP_SCALAR_AGGREGATE_TARGET_OPS[aggOpPos], Types.Direction.RowCol);
        HopRewriteUtils.removeChildReference(cast, ix);
        HopRewriteUtils.removeChildReference(bop, cast);
        HopRewriteUtils.addChildReference(bop, agg, castPos);

        // Widen the point index i to the loop range from:to.
        replaceBounds(ix, rowIx ? IX_ROW_LOWER : IX_COL_LOWER, from, to);
        if (rowIx) {
            ix.setRowLowerEqualsUpper(false);
        } else {
            ix.setColLowerEqualsUpper(false);
        }
        ix.refreshSizeInformation();
        LOG.debug("Applied vectorizeScalarSumForLoop.");
        return csb;
    }

    /**
     * Rewrites {@code R[i, c1] = A[i, c2] OP B[i, c3]} (or the column-wise variant) by replacing
     * the point index {@code i} in all three indexing operations with {@code from:to}.
     *
     * <p>NOTE(review): any {@link BinaryOp} is accepted, not only element-wise ones. Confirm that
     * ops whose result is not row/column independent cannot reach this pattern.
     *
     * @return {@code csb} (modified in place) if the rewrite applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeElementwiseBinary(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String iterVar) {
        if (!isUnitIncrement(increment)) {
            return sb;
        }
        LeftIndexingOp lix = getLeftIndexingRoot(csb);
        if (lix == null) {
            return sb;
        }
        Hop source = lix.getInput().get(1);
        if (!(source instanceof BinaryOp) || !isIndexingOfData(source.getInput().get(0)) || !isIndexingOfData(source.getInput().get(1))) {
            return sb;
        }
        IndexingOp leftOperand = (IndexingOp) source.getInput().get(0);
        IndexingOp rightOperand = (IndexingOp) source.getInput().get(1);
        boolean[] match = checkLeftAndRightIndexing(lix, iterVar, leftOperand, rightOperand);
        if (!match[0]) {
            return sb;
        }
        boolean rowIx = match[1];
        int ixLower = rowIx ? IX_ROW_LOWER : IX_COL_LOWER;
        replaceBounds(lix, rowIx ? LIX_ROW_LOWER : LIX_COL_LOWER, from, to);
        replaceBounds(leftOperand, ixLower, from, to);
        replaceBounds(rightOperand, ixLower, from, to);
        updateLeftAndRightIndexingSizes(rowIx, lix, leftOperand, rightOperand);
        source.refreshSizeInformation();
        lix.refreshSizeInformation();
        LOG.debug("Applied vectorizeElementwiseBinaryForLoop.");
        return csb;
    }

    /**
     * Rewrites {@code R[i, c1] = f(A[i, c2])} (or the column-wise variant) by replacing the point
     * index {@code i} in both indexing operations with {@code from:to}.
     *
     * <p>NOTE(review): any {@link UnaryOp} is accepted. Ops that are not row/column independent
     * (e.g. cumulative aggregates applied across rows) may not be safe to vectorize this way; confirm.
     *
     * @return {@code csb} (modified in place) if the rewrite applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeElementwiseUnary(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String iterVar) {
        if (!isUnitIncrement(increment)) {
            return sb;
        }
        LeftIndexingOp lix = getLeftIndexingRoot(csb);
        if (lix == null) {
            return sb;
        }
        Hop source = lix.getInput().get(1);
        if (!(source instanceof UnaryOp) || !isIndexingOfData(source.getInput().get(0))) {
            return sb;
        }
        IndexingOp operand = (IndexingOp) source.getInput().get(0);
        boolean[] match = checkLeftAndRightIndexing(lix, iterVar, operand);
        if (!match[0]) {
            return sb;
        }
        boolean rowIx = match[1];
        replaceBounds(lix, rowIx ? LIX_ROW_LOWER : LIX_COL_LOWER, from, to);
        replaceBounds(operand, rowIx ? IX_ROW_LOWER : IX_COL_LOWER, from, to);
        updateLeftAndRightIndexingSizes(rowIx, lix, operand);
        source.refreshSizeInformation();
        lix.refreshSizeInformation();
        LOG.debug("Applied vectorizeElementwiseUnaryForLoop.");
        return csb;
    }

    /**
     * Rewrites {@code R[i, c1] = A[i, c2]} (or the column-wise variant) by replacing the point
     * index {@code i} in both indexing operations with {@code from:to}.
     *
     * @return {@code csb} (modified in place) if the rewrite applied, otherwise {@code sb}
     */
    private static StatementBlock vectorizeIndexedCopy(StatementBlock sb, StatementBlock csb, Hop from, Hop to, Hop increment, String iterVar) {
        if (!isUnitIncrement(increment)) {
            return sb;
        }
        LeftIndexingOp lix = getLeftIndexingRoot(csb);
        if (lix == null) {
            return sb;
        }
        Hop source = lix.getInput().get(1);
        if (!isIndexingOfData(source)) {
            return sb;
        }
        IndexingOp rix = (IndexingOp) source;
        boolean[] match = checkLeftAndRightIndexing(lix, iterVar, rix);
        if (!match[0]) {
            return sb;
        }
        boolean rowIx = match[1];
        replaceBounds(lix, rowIx ? LIX_ROW_LOWER : LIX_COL_LOWER, from, to);
        replaceBounds(rix, rowIx ? IX_ROW_LOWER : IX_COL_LOWER, from, to);
        updateLeftAndRightIndexingSizes(rowIx, lix, rix);
        LOG.debug("Applied vectorizeIndexedCopy.");
        return csb;
    }

    /**
     * Checks whether the left indexing and all right indexing operations select a single row (or
     * column) addressed by the loop variable. The other dimension must not depend on it.
     *
     * @return {@code {applies, rowIx}}; {@code rowIx} is meaningful only if {@code applies}
     */
    private static boolean[] checkLeftAndRightIndexing(LeftIndexingOp lix, String iterVar, IndexingOp... rixs) {
        if (allIndexPointBy(lix, true, iterVar, rixs)) {
            return new boolean[] {true, true};
        }
        if (allIndexPointBy(lix, false, iterVar, rixs)) {
            return new boolean[] {true, false};
        }
        return new boolean[] {false, false};
    }

    /** True if {@code lix} and every entry of {@code rixs} are point-indexed by {@code iterVar} in the given dimension. */
    private static boolean allIndexPointBy(LeftIndexingOp lix, boolean rowIx, String iterVar, IndexingOp... rixs) {
        boolean point = rowIx ? lix.isRowLowerEqualsUpper() : lix.isColLowerEqualsUpper();
        int lower = rowIx ? LIX_ROW_LOWER : LIX_COL_LOWER;
        int otherLower = rowIx ? LIX_COL_LOWER : LIX_ROW_LOWER;
        if (!point || !hasName(lix.getInput().get(lower), iterVar) || boundsReference(lix, otherLower, iterVar)) {
            return false;
        }
        for (IndexingOp rix : rixs) {
            if (!indexesPointBy(rix, rowIx, iterVar)) {
                return false;
            }
        }
        return true;
    }

    /** True if {@code ix} is point-indexed by {@code iterVar} in the given dimension and its other dimension does not read it. */
    private static boolean indexesPointBy(IndexingOp ix, boolean rowIx, String iterVar) {
        boolean point = rowIx ? ix.isRowLowerEqualsUpper() : ix.isColLowerEqualsUpper();
        int lower = rowIx ? IX_ROW_LOWER : IX_COL_LOWER;
        int otherLower = rowIx ? IX_COL_LOWER : IX_ROW_LOWER;
        return point && hasName(ix.getInput().get(lower), iterVar) && !boundsReference(ix, otherLower, iterVar);
    }

    /** True if either bound at {@code lowerPos} / {@code lowerPos + 1} of {@code op} reads variable {@code varName}. */
    private static boolean boundsReference(Hop op, int lowerPos, String varName) {
        return referencesVariable(op.getInput().get(lowerPos), varName)
            || referencesVariable(op.getInput().get(lowerPos + 1), varName);
    }

    /**
     * True if the hop DAG rooted at {@code hop} contains a {@link DataOp} named {@code varName}.
     * Shared sub-DAGs are visited once (identity-based visited set).
     */
    private static boolean referencesVariable(Hop hop, String varName) {
        if (hop == null) {
            return false;
        }
        Set<Hop> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        Deque<Hop> pending = new ArrayDeque<>();
        pending.push(hop);
        while (!pending.isEmpty()) {
            Hop current = pending.pop();
            if (!visited.add(current)) {
                continue;
            }
            if (isVariableRead(current, varName)) {
                return true;
            }
            for (Hop input : current.getInput()) {
                if (input != null) { // ArrayDeque rejects null elements
                    pending.push(input);
                }
            }
        }
        return false;
    }

    /** Replaces the lower/upper bound children at {@code lowerPos} and {@code lowerPos + 1} by {@code from}/{@code to}. */
    private static void replaceBounds(Hop parent, int lowerPos, Hop from, Hop to) {
        HopRewriteUtils.replaceChildReference(parent, parent.getInput().get(lowerPos), from, lowerPos);
        HopRewriteUtils.replaceChildReference(parent, parent.getInput().get(lowerPos + 1), to, lowerPos + 1);
    }

    /** Returns the only root hop of {@code csb}, or {@code null} if there is not exactly one root or it has no inputs. */
    private static Hop getSingleRoot(StatementBlock csb) {
        List<Hop> hops = csb.getHops();
        if (hops == null || hops.size() != 1) {
            return null;
        }
        Hop root = hops.get(0);
        return root.getInput().isEmpty() ? null : root;
    }

    /** Returns the left indexing feeding the single matrix root of {@code csb} if its target is a {@link DataOp}, else {@code null}. */
    private static LeftIndexingOp getLeftIndexingRoot(StatementBlock csb) {
        Hop root = getSingleRoot(csb);
        if (root == null || root.getDataType() != Types.DataType.MATRIX || !(root.getInput().get(0) instanceof LeftIndexingOp)) {
            return null;
        }
        LeftIndexingOp lix = (LeftIndexingOp) root.getInput().get(0);
        return lix.getInput().get(0) instanceof DataOp ? lix : null;
    }

    /** True if {@code increment} is the literal value 1. */
    private static boolean isUnitIncrement(Hop increment) {
        return increment instanceof LiteralOp && ((LiteralOp) increment).getDoubleValue() == 1.0d;
    }

    /** True if {@code hop} is an {@link IndexingOp} over a {@link DataOp}. */
    private static boolean isIndexingOfData(Hop hop) {
        return hop instanceof IndexingOp && hop.getInput().get(0) instanceof DataOp;
    }

    /** True if {@code hop} is {@code as.scalar(<IndexingOp>)}. */
    private static boolean isScalarCastOfIndexing(Hop hop) {
        return hop instanceof UnaryOp
            && ((UnaryOp) hop).getOp() == Types.OpOp1.CAST_AS_SCALAR
            && hop.getInput().get(0) instanceof IndexingOp;
    }

    /** True if {@code hop} reads the scalar variable that {@code root} assigns (the accumulator). */
    private static boolean isAccumulatorRead(Hop hop, Hop root) {
        return hop instanceof DataOp
            && hop.getDataType() == Types.DataType.SCALAR
            && root.getName() != null
            && root.getName().equals(hop.getName());
    }

    /** True if {@code hop} is a {@link DataOp} named {@code varName}. */
    private static boolean isVariableRead(Hop hop, String varName) {
        return hop instanceof DataOp && hasName(hop, varName);
    }

    /** Null-safe check that {@code hop} is named {@code name}. */
    private static boolean hasName(Hop hop, String name) {
        return hop != null && name.equals(hop.getName());
    }

    /**
     * Clears the point-index flag of the vectorized dimension on all indexing operations and
     * refreshes their size information (right indexing first, then left indexing).
     */
    private static void updateLeftAndRightIndexingSizes(boolean rowIx, LeftIndexingOp lix, IndexingOp... rixs) {
        if (rowIx) {
            lix.setRowLowerEqualsUpper(false);
            for (IndexingOp rix : rixs) {
                rix.setRowLowerEqualsUpper(false);
            }
        } else {
            lix.setColLowerEqualsUpper(false);
            for (IndexingOp rix : rixs) {
                rix.setColLowerEqualsUpper(false);
            }
        }
        for (IndexingOp rix : rixs) {
            rix.refreshSizeInformation();
        }
        lix.refreshSizeInformation();
    }
}
