package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.commons.logging.Log;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.estim.EstimatorMatrixHistogram;
import org.apache.sysds.hops.estim.MMNode;
import org.apache.sysds.hops.estim.SparsityEstimator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Variant of {@link RewriteMatrixMultChainOptimization} that chooses the multiplication
 * order of a matrix multiplication chain with a sparsity-aware cost.
 * <p>
 * Each candidate split (i..k)(k+1..j) is costed by the dot product of the left operand's
 * per-column non-zero counts and the right operand's per-row non-zero counts, taken from
 * {@link EstimatorMatrixHistogram} synopses. The synopses are derived from the actual input
 * data, so the rewrite only applies when every chain input is a transient read whose value
 * is a matrix in the program's variable map.
 */
public class RewriteMatrixMultChainOptimizationSparse extends RewriteMatrixMultChainOptimization {

    /**
     * Reorders the given multiplication chain using the sparsity-aware cost.
     * <p>
     * The chain is left unchanged unless all dimensions are known and every input is an
     * available transient-read matrix.
     *
     * @param hop the root hop of the chain
     * @param mmChain the chain inputs, in multiplication order
     * @param mmOperators the chain's operator hops; their links are cleared and rebuilt
     * @param state rewrite status providing the live variables
     * @throws HopsException if a chain input variable is not a matrix
     */
    @Override
    protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators,
        ProgramRewriteStatus state) {
        int chainSize = mmChain.size();

        // Check dimensions first so input matrices are only acquired when the rewrite can apply.
        double[] dimsArray = new double[chainSize + 1];
        if (!getDimsArray(hop, mmChain, dimsArray)) {
            return;
        }
        MMNode[] sketchArray = new MMNode[chainSize];
        if (!getInputMatrices(hop, mmChain, sketchArray, state)) {
            return;
        }

        clearLinksWithinChain(hop, mmOperators);
        int[][] split = mmChainDPSparse(sketchArray, chainSize);
        LOG.trace("Optimal MM Chain: ");
        mmChainRelinkHops(mmOperators.get(0), 0, chainSize - 1, mmChain, mmOperators,
            new MutableInt(1), split, 1);
    }

    /**
     * Computes the optimal parenthesization of the chain by dynamic programming over
     * sub-chains of increasing length.
     *
     * @param sketchArray leaf sketches, one per chain input
     * @param size number of chain inputs
     * @return matrix where {@code split[i][j]} is the index k of the optimal split of the
     *         sub-chain i..j into (i..k)(k+1..j); -1 where no split applies
     */
    private static int[][] mmChainDPSparse(MMNode[] sketchArray, int size) {
        // costs[i][j]: minimal cost of sub-chain i..j; Java zero-initializes, giving 0 on the diagonal
        double[][] costs = new double[size][size];
        // sketches[i][j]: node describing the best product found for sub-chain i..j
        MMNode[][] sketches = new MMNode[size][size];
        int[][] split = new int[size][size];
        for (int i = 0; i < size; i++) {
            Arrays.fill(split[i], -1);
            sketches[i][i] = sketchArray[i];
        }

        EstimatorMatrixHistogram estimator = new EstimatorMatrixHistogram(true);
        for (int len = 2; len <= size; len++) {
            for (int i = 0; i <= size - len; i++) {
                int j = i + len - 1;
                costs[i][j] = Double.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    MMNode candidate = new MMNode(sketches[i][k], sketches[k + 1][j],
                        SparsityEstimator.OpCode.MM);
                    // NOTE(review): estim runs before the operand synopses are read below;
                    // presumably it populates them (e.g. for leaf nodes) - keep this order.
                    estimator.estim(candidate, false);
                    EstimatorMatrixHistogram.MatrixHistogram left =
                        (EstimatorMatrixHistogram.MatrixHistogram) sketches[i][k].getSynopsis();
                    EstimatorMatrixHistogram.MatrixHistogram right =
                        (EstimatorMatrixHistogram.MatrixHistogram) sketches[k + 1][j].getSynopsis();
                    double cost = costs[i][k] + costs[k + 1][j]
                        + dotProduct(left.getColCounts(), right.getRowCounts());
                    if (cost < costs[i][j]) {
                        costs[i][j] = cost;
                        sketches[i][j] = candidate;
                        split[i][j] = k;
                    }
                }
                if (LOG.isTraceEnabled()) {
                    LOG.trace("mmchainopt [i=" + (i + 1) + ",j=" + (j + 1) + "]: costs = "
                        + costs[i][j] + ", split = " + (split[i][j] + 1));
                }
            }
        }
        return split;
    }

    /**
     * Builds a leaf sketch for every chain input.
     *
     * @param hop the root hop of the chain (unused)
     * @param mmChain the chain inputs
     * @param sketchArray output array receiving one leaf {@link MMNode} per input
     * @param state rewrite status providing the live variables
     * @return true if every input is a transient read and its sketch was created
     * @throws HopsException if an input variable is not a matrix
     */
    private static boolean getInputMatrices(Hop hop, ArrayList<Hop> mmChain, MMNode[] sketchArray,
        ProgramRewriteStatus state) {
        LocalVariableMap vars = state.getVariables();
        for (int i = 0; i < mmChain.size(); i++) {
            Hop input = mmChain.get(i);
            // Each input must be checked (not just the first) before its name is looked up.
            if (!HopRewriteUtils.isData(input, Types.OpOpData.TRANSIENTREAD)) {
                return false;
            }
            sketchArray[i] = new MMNode(getMatrix(input.getName(), vars));
        }
        return true;
    }

    /**
     * Returns the matrix block bound to {@code name} in {@code vars}.
     *
     * @throws HopsException if the variable is missing or not a matrix
     */
    private static MatrixBlock getMatrix(String name, LocalVariableMap vars) {
        Data data = vars.get(name);
        if (data instanceof MatrixObject) {
            return ((MatrixObject) data).acquireReadAndRelease();
        }
        throw new HopsException("Input '" + name + "' not a matrix: "
            + (data != null ? data.getDataType() : "null"));
    }

    /**
     * Returns the dot product of two count vectors, iterating over {@code left.length}
     * entries ({@code right} is assumed to be at least as long).
     */
    private static double dotProduct(int[] left, int[] right) {
        long sum = 0;
        for (int i = 0; i < left.length; i++) {
            // Widen before multiplying: int * int overflows for large non-zero counts.
            sum += (long) left[i] * right[i];
        }
        return sum;
    }
}
