package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.logging.Log;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.utils.Explain;

/**
 * Hop rewrite that reorders chains of matrix multiplications into the
 * parenthesization with the lowest estimated cost.
 * <p>
 * A chain is rooted at a matrix-multiply hop without a left permutation-matrix
 * input. Starting from the root's two inputs, every operand that is itself such a
 * matrix multiply, has not been visited yet, and is consumed exactly once by its
 * parent is expanded into its own two inputs. Expansion stops at the first matrix
 * multiply whose result is shared. The resulting operand list is reordered with the
 * classic O(n^3) matrix-chain-ordering dynamic program, whose cost model counts
 * scalar multiplications of dense operands. Chains of only two operands, or chains
 * with an operand of unknown (non-positive) dimensions, are left unchanged.
 */
public class RewriteMatrixMultChainOptimization extends HopRewriteRule {
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (Hop root : roots) {
            rule_OptimizeMMChains(root, state);
        }
        return roots;
    }

    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        rule_OptimizeMMChains(root, state);
        return root;
    }

    /**
     * Depth-first traversal that optimizes every matrix-multiply chain whose root is
     * reached before any of its members has been visited.
     *
     * @param hop   current hop
     * @param state rewrite status (passed through)
     */
    private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state) {
        if (hop.isVisited()) {
            return;
        }
        // Optimize the chain rooted here before descending: members absorbed into the
        // chain are marked visited and therefore not re-optimized as separate roots.
        if (HopRewriteUtils.isMatrixMultiply(hop) && !((AggBinaryOp) hop).hasLeftPMInput()) {
            prepAndOptimizeMMChain(hop, state);
        }
        for (Hop input : hop.getInput()) {
            rule_OptimizeMMChains(input, state);
        }
        hop.setVisited();
    }

    /**
     * Identifies the matrix-multiply chain rooted at {@code hop} and, if it has more
     * than two operands, reorders it.
     *
     * @param hop   chain root; a matrix multiply without left permutation-matrix input
     * @param state rewrite status (passed through)
     * @throws HopsException if an absorbed matrix multiply does not have exactly two inputs
     */
    private void prepAndOptimizeMMChain(Hop hop, ProgramRewriteStatus state) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("MM Chain Optimization for HOP: (" + hop.getClass().getSimpleName()
                + ", " + hop.getHopID() + ", " + hop.getName() + ")");
        }
        // chain operands in left-to-right multiplication order
        ArrayList<Hop> mmChain = new ArrayList<>(hop.getInput());
        // matrix-multiply operators absorbed into the chain; index 0 is the root
        ArrayList<Hop> mmOperators = new ArrayList<>();
        mmOperators.add(hop);

        int i = 0;
        while (i < mmChain.size()) {
            Hop candidate = mmChain.get(i);
            boolean expandable = false;
            if (HopRewriteUtils.isMatrixMultiply(candidate)
                && !((AggBinaryOp) candidate).hasLeftPMInput()
                && !candidate.isVisited()) {
                // Inner operators may only be absorbed if their result is consumed exactly
                // once (clearLinksWithinChain rejects inner operators with several parents).
                expandable = candidate.getParent().size() <= 1
                    && inputCount(candidate.getParent().get(0), candidate) <= 1;
                if (!expandable) {
                    // stop expanding; the operand list collected so far is still a valid chain
                    break;
                }
            }
            // NOTE(review): plain leaf operands are marked visited too, so
            // rule_OptimizeMMChains will not descend into their subtrees in this pass;
            // confirm this is intended.
            candidate.setVisited();
            if (expandable) {
                ArrayList<Hop> inputs = candidate.getInput();
                if (inputs.size() != 2) {
                    throw new HopsException(hop.printErrorLocation()
                        + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
                }
                mmOperators.add(candidate);
                mmChain.set(i, inputs.get(0));
                mmChain.add(i + 1, inputs.get(1));
                // do not advance i: the new operand at position i may itself be expandable
            } else {
                i++;
            }
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Identified MM Chain: ");
            for (Hop operand : mmChain) {
                logTraceHop(operand, 1);
            }
        }
        // two operands admit only one multiplication order
        if (mmChain.size() == 2) {
            return;
        }
        optimizeMMChain(hop, mmChain, mmOperators, state);
    }

    /**
     * Reorders an identified chain if all operand dimensions are known.
     *
     * @param hop         chain root (used for error locations)
     * @param mmChain     chain operands in multiplication order
     * @param mmOperators operator hops of the chain; index 0 is the root
     * @param state       rewrite status (unused here, kept for overriding subclasses)
     */
    protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
        ProgramRewriteStatus state) {
        int size = mmChain.size();
        double[] dimsArray = new double[size + 1];
        if (!getDimsArray(hop, mmChain, dimsArray)) {
            return; // unknown dimensions: keep the original order
        }
        clearLinksWithinChain(hop, mmOperators);
        int[][] split = mmChainDP(dimsArray, size);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Optimal MM Chain: ");
        }
        mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators,
            new MutableInt(1), split, 1);
    }

    /**
     * Matrix-chain-ordering dynamic program.
     *
     * @param dimArray operand {@code k} has shape {@code dimArray[k] x dimArray[k+1]}
     * @param size     number of operands
     * @return split table: {@code split[i][j] = k} means operands {@code i..j} are
     *         best computed as {@code (i..k) x (k+1..j)}; {@code -1} where unset
     */
    private static int[][] mmChainDP(double[] dimArray, int size) {
        // cost[i][j]: minimal scalar multiplications to compute operands i..j
        double[][] cost = new double[size][size];
        int[][] split = new int[size][size];
        for (int[] row : split) {
            Arrays.fill(row, -1);
        }
        for (int len = 2; len <= size; len++) {
            for (int i = 0; i <= size - len; i++) {
                int j = i + len - 1;
                cost[i][j] = Double.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    double c = cost[i][k] + cost[k + 1][j]
                        + dimArray[i] * dimArray[k + 1] * dimArray[j + 1];
                    if (c < cost[i][j]) {
                        cost[i][j] = c;
                        split[i][j] = k;
                    }
                }
                if (LOG.isTraceEnabled()) {
                    LOG.trace("mmchainopt [i=" + (i + 1) + ",j=" + (j + 1) + "]: costs = "
                        + cost[i][j] + ", split = " + (split[i][j] + 1));
                }
            }
        }
        return split;
    }

    /**
     * Recursively rebuilds the operator tree for operands {@code i..j} below
     * {@code opHop}, following the split table. Existing operator hops are reused in
     * order via {@code opIndex} instead of creating new hops.
     *
     * @param opHop       operator hop for this sub-chain (the operand itself when {@code i == j})
     * @param i           index of the first operand of the sub-chain
     * @param j           index of the last operand of the sub-chain
     * @param mmChain     chain operands in multiplication order
     * @param mmOperators operator hops of the chain; index 0 is the root
     * @param opIndex     cursor to the next unused entry of {@code mmOperators}
     * @param split       split table produced by the dynamic program
     * @param level       indentation level for trace logging
     */
    public final void mmChainRelinkHops(Hop opHop, int i, int j, ArrayList<Hop> mmChain,
        ArrayList<Hop> mmOperators, MutableInt opIndex, int[][] split, int level) {
        if (i == j) {
            logTraceHop(opHop, level);
            return;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + "(");
        }
        int k = split[i][j];
        // left child first, then right child: operator cursor order matters
        linkInput(opHop, (i == k) ? mmChain.get(i) : nextChainOperator(mmOperators, opIndex));
        linkInput(opHop, (k + 1 == j) ? mmChain.get(j) : nextChainOperator(mmOperators, opIndex));

        mmChainRelinkHops(opHop.getInput().get(0), i, k, mmChain, mmOperators, opIndex, split, level + 1);
        mmChainRelinkHops(opHop.getInput().get(1), k + 1, j, mmChain, mmOperators, opIndex, split, level + 1);
        opHop.refreshSizeInformation();
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + ")");
        }
    }

    /** Returns the next unused operator hop and advances the cursor. */
    private static Hop nextChainOperator(ArrayList<Hop> mmOperators, MutableInt opIndex) {
        Hop op = mmOperators.get(opIndex.intValue());
        opIndex.increment();
        return op;
    }

    /** Appends {@code child} as the next input of {@code parent} and adds the reverse parent link. */
    private static void linkInput(Hop parent, Hop child) {
        parent.getInput().add(child);
        child.getParent().add(parent);
    }

    /**
     * Removes all input/parent links of the chain's operator hops, so that they can be
     * relinked in the optimal order.
     *
     * @param hop       chain root (used for error locations)
     * @param operators operator hops of the chain; index 0 is the root
     * @throws HopsException if an operator does not have exactly two inputs, or an
     *                       inner operator (index &gt; 0) has more than one parent
     */
    public static void clearLinksWithinChain(Hop hop, ArrayList<Hop> operators) {
        for (int i = 0; i < operators.size(); i++) {
            Hop op = operators.get(i);
            if (op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1)) {
                throw new HopsException(hop.printErrorLocation()
                    + "Unexpected error while applying optimization on matrix-mult chain. \n");
            }
            Hop left = op.getInput().get(0);
            Hop right = op.getInput().get(1);
            op.getInput().clear();
            left.getParent().remove(op);
            right.getParent().remove(op);
        }
    }

    /**
     * Fills {@code dimsArray} so that operand {@code k} has shape
     * {@code dimsArray[k] x dimsArray[k+1]}.
     *
     * @param hop       chain root (used for error locations)
     * @param chain     chain operands in multiplication order
     * @param dimsArray output array of length {@code chain.size() + 1}
     * @return {@code false} (leaving {@code dimsArray} untouched) if any operand has a
     *         non-positive (unknown) dimension, {@code true} otherwise
     * @throws HopsException if adjacent operands have incompatible inner dimensions
     */
    public static boolean getDimsArray(Hop hop, ArrayList<Hop> chain, double[] dimsArray) {
        for (Hop operand : chain) {
            if (operand.getDim1() <= 0 || operand.getDim2() <= 0) {
                return false;
            }
        }
        for (int i = 0; i < chain.size(); i++) {
            Hop operand = chain.get(i);
            if (i == 0) {
                dimsArray[0] = operand.getDim1();
            } else {
                long prevCols = chain.get(i - 1).getDim2();
                long rows = operand.getDim1();
                if (prevCols != rows) {
                    throw new HopsException(hop.printErrorLocation()
                        + "Hops::optimizeMMChain() : Matrix Dimension Mismatch: " + prevCols + " != " + rows);
                }
            }
            dimsArray[i + 1] = operand.getDim2();
        }
        return true;
    }

    /** Returns how often {@code input} occurs among the inputs of {@code parent}. */
    private static int inputCount(Hop parent, Hop input) {
        return CollectionUtils.cardinality(input, parent.getInput());
    }

    /** Logs name, type, id, and dimensions of {@code hop} at trace level with the given indentation. */
    private static void logTraceHop(Hop hop, int level) {
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + "Hop " + hop.getName() + "("
                + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ") "
                + hop.getDim1() + "x" + hop.getDim2());
        }
    }
}
