package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.class */
public class CLALibLeftMultBy {
    /** Class-wide commons-logging logger, keyed by this class's fully qualified name. */
    private static final Log LOG = LogFactory.getLog(CLALibLeftMultBy.class.getName());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy$LeftMatrixColGroupMultTask.class */
    /**
     * Callable that runs {@link CLALibLeftMultBy#leftMultByMatrixPrimitive} for one row range
     * of the left-hand matrix against a fixed list of column groups.
     *
     * <p>The task only stores references; it does not copy any of its inputs. The result block
     * handed in at construction is also the value returned by {@link #call()}.
     */
    public static class LeftMatrixColGroupMultTask implements Callable<MatrixBlock> {
        /** Column groups to multiply with. */
        private final List<AColGroup> _groups;
        /** Left-hand side matrix. */
        private final MatrixBlock _that;
        /** Block the products are accumulated into. */
        private final MatrixBlock _ret;
        /** First row (inclusive) of the left-hand matrix handled by this task. */
        private final int _rl;
        /** Last row (exclusive) of the left-hand matrix handled by this task. */
        private final int _ru;
        /** Optional per-row sum accumulator for the left-hand matrix; may be {@code null}. */
        private final double[] _rowSums;

        /**
         * @param list         column groups to multiply with
         * @param matrixBlock  left-hand side matrix
         * @param matrixBlock2 block receiving the result
         * @param i            first row (inclusive)
         * @param i2           last row (exclusive)
         * @param dArr         row-sum accumulator, or {@code null} when row sums are not needed
         */
        protected LeftMatrixColGroupMultTask(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2, double[] dArr) {
            this._groups = list;
            this._that = matrixBlock;
            this._ret = matrixBlock2;
            this._rl = i;
            this._ru = i2;
            this._rowSums = dArr;
        }

        /**
         * Executes the multiplication for the configured row range.
         *
         * @return the result block passed to the constructor
         * @throws DMLRuntimeException wrapping any exception raised by the multiplication
         */
        @Override // java.util.concurrent.Callable
        public MatrixBlock call() {
            try {
                CLALibLeftMultBy.leftMultByMatrixPrimitive(this._groups, this._that, this._ret, this._rl, this._ru, this._rowSums);
                return this._ret;
            } catch (Exception e) {
                // The original cause travels with the wrapper, so nothing is printed here:
                // the caller that unwraps the Future decides how to report the failure.
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Left-multiplies a compressed matrix by the transpose of an uncompressed matrix.
     *
     * <p>If {@code matrixBlock} is empty, {@code matrixBlock2} is handed back untouched (it is
     * neither allocated nor resized in that case). Otherwise {@code matrixBlock} is transposed
     * into a fresh dense-format block and the regular left multiplication is applied.
     *
     * @param compressedMatrixBlock right-hand side compressed matrix
     * @param matrixBlock           matrix whose transpose is the left-hand side
     * @param matrixBlock2          optional block to reuse for the result
     * @param i                     value forwarded to the transpose and the multiplication
     *                              (presumably the degree of parallelism; it ends up in
     *                              {@code CommonThreadPool.get} further down)
     * @return the product, or {@code matrixBlock2} as given when {@code matrixBlock} is empty
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i) {
        if (!matrixBlock.isEmpty()) {
            final int transposedRows = matrixBlock.getNumColumns();
            final int transposedCols = matrixBlock.getNumRows();
            final MatrixBlock transposed = new MatrixBlock(transposedRows, transposedCols, false);
            LibMatrixReorg.transpose(matrixBlock, transposed, i);

            final MatrixBlock product = leftMultByMatrix(compressedMatrixBlock, transposed, matrixBlock2, i);
            product.recomputeNonZeros();
            return product;
        }
        return matrixBlock2;
    }

    /**
     * Left-multiplies a compressed matrix by the transpose of another compressed matrix.
     *
     * @param compressedMatrixBlock  right-hand side compressed matrix
     * @param compressedMatrixBlock2 compressed matrix whose transpose is the left-hand side
     * @param matrixBlock            optional block to reuse for the result
     * @param i                      forwarded to the compressed multiplication
     * @return the dense result block with its non-zero count recomputed
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock compressedMatrixBlock, CompressedMatrixBlock compressedMatrixBlock2, MatrixBlock matrixBlock, int i) {
        // "true": the output row count comes from the left operand's column count.
        final MatrixBlock result = prepareReturnMatrix(compressedMatrixBlock, compressedMatrixBlock2, matrixBlock, true);

        leftMultByCompressedTransposedMatrix(compressedMatrixBlock, compressedMatrixBlock2, result, i);

        result.recomputeNonZeros();
        return result;
    }

    /**
     * Left-multiplies a compressed matrix by an uncompressed matrix.
     *
     * <p>The result block is always prepared (allocated or reset to the output shape) before
     * the emptiness check, so an empty left-hand side still yields a correctly sized block.
     *
     * @param compressedMatrixBlock right-hand side compressed matrix
     * @param matrixBlock           left-hand side matrix
     * @param matrixBlock2          optional block to reuse for the result
     * @param i                     forwarded to the column-group multiplication
     * @return the prepared result block, filled with the product unless the left side is empty
     */
    public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i) {
        final MatrixBlock prepared = prepareReturnMatrix(compressedMatrixBlock, matrixBlock, matrixBlock2, false);
        if (!matrixBlock.isEmpty()) {
            LOG.trace("LeftMultByMatrix Execution");
            final MatrixBlock product = leftMultByMatrix(
                compressedMatrixBlock.getColGroups(), matrixBlock, prepared, i, compressedMatrixBlock.isOverlapping());
            product.recomputeNonZeros();
            return product;
        }
        return prepared;
    }

    /**
     * Computes the transpose-self product of a compressed matrix into {@code matrixBlock}.
     *
     * <p>Only the upper triangle is produced by the group-wise work; it is mirrored to the
     * lower triangle at the end. When the column groups contain SDC/const groups, the values
     * collected by {@code CLALibUtils.filterGroups} are folded back in through two upper-triangle
     * outer products: {@code c ⊗ s} and {@code (s + rows * c) ⊗ c}, where {@code c} is the
     * collected vector and {@code s} the column sums of the filtered groups.
     *
     * @param compressedMatrixBlock input compressed matrix
     * @param matrixBlock           result block; its dense values are updated in place
     * @param i                     not used by this method
     */
    public static void leftMultByTransposeSelf(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock, int i) {
        final List<AColGroup> groups = compressedMatrixBlock.getColGroups();
        final int nCols = compressedMatrixBlock.getNumColumns();
        final int nRows = compressedMatrixBlock.getNumRows();

        final boolean hasConst = CLALibUtils.containsSDCOrConst(groups);
        double[] constV = null;
        if (hasConst) {
            constV = new double[nCols];
        }
        final List<AColGroup> filtered = CLALibUtils.filterGroups(groups, constV);
        final double[] sums = getColSum(filtered, nCols, nRows, hasConst);

        tsmmColGroups(filtered, matrixBlock, nRows);

        final double[] out = matrixBlock.getDenseBlockValues();
        if (constV != null) {
            outerProductUpperTriangle(constV, sums, out);
            // Shift the sums by rows * c so the second product also covers the c ⊗ c term.
            int c = 0;
            while (c < sums.length) {
                sums[c] += constV[c] * nRows;
                c++;
            }
            outerProductUpperTriangle(sums, constV, out);
        }

        matrixBlock.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(matrixBlock));
        matrixBlock.examSparsity();
    }

    /**
     * Returns a dense-allocated result block of the output shape, reusing {@code matrixBlock3}
     * when possible.
     *
     * <p>The output has as many columns as {@code matrixBlock}. Its row count is the column
     * count of {@code matrixBlock2} when {@code z} is set (transposed left operand) and the row
     * count of {@code matrixBlock2} otherwise. An existing block is reset only if its shape
     * differs or it is not allocated; otherwise its current content is kept.
     *
     * @param matrixBlock  right-hand side operand (supplies the output column count)
     * @param matrixBlock2 left-hand side operand (supplies the output row count)
     * @param matrixBlock3 block to reuse, or {@code null} to create a new one
     * @param z            whether the left operand is to be read as transposed
     * @return the block to write the result into, with a dense block allocated
     */
    private static MatrixBlock prepareReturnMatrix(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, boolean z) {
        final int outRows = z ? matrixBlock2.getNumColumns() : matrixBlock2.getNumRows();
        final int outCols = matrixBlock.getNumColumns();
        // Widen before multiplying: rows * cols can exceed Integer.MAX_VALUE and would
        // otherwise wrap to a bogus (possibly negative) size estimate.
        final long denseCells = (long) outRows * outCols;

        if (matrixBlock3 == null) {
            matrixBlock3 = new MatrixBlock(outRows, outCols, false, denseCells);
        } else if (matrixBlock3.getNumColumns() != outCols || matrixBlock3.getNumRows() != outRows || !matrixBlock3.isAllocated()) {
            matrixBlock3.reset(outRows, outCols, false, denseCells);
        }
        matrixBlock3.allocateDenseBlock();
        return matrixBlock3;
    }

    /**
     * Computes {@code t(left) %*% right} for two compressed matrices into {@code matrixBlock}.
     *
     * <p>{@code compressedMatrixBlock} is the right operand and {@code compressedMatrixBlock2}
     * the left operand whose transpose is taken. The result is laid out row-major as
     * (left columns) x (right columns), matching what {@code prepareReturnMatrix} allocates for
     * the transposed case.
     *
     * <p>SDC/const groups are handled the same way as in the transpose-self path of this class:
     * {@code CLALibUtils.filterGroups} collects per-column values into {@code cL}/{@code cR}
     * (assumed to be additive per-column offsets — confirm against {@code CLALibUtils}), the
     * filtered groups are multiplied pairwise, and the offsets are folded back in as
     * <pre>
     *   cL ⊗ sum(right)  +  (sum(left) + rows * cL) ⊗ cR
     * </pre>
     * The second product has the left-side vector first so that it indexes output rows; the
     * {@code rows * cL} shift supplies the {@code cL ⊗ cR} term when both sides carry offsets.
     *
     * @param compressedMatrixBlock  right operand
     * @param compressedMatrixBlock2 left operand (transposed)
     * @param matrixBlock            dense result block, updated in place
     * @param i                      not used by this method
     * @return {@code matrixBlock}
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrix(CompressedMatrixBlock compressedMatrixBlock, CompressedMatrixBlock compressedMatrixBlock2, MatrixBlock matrixBlock, int i) {
        final int sharedRows = compressedMatrixBlock.getNumRows();
        final int rightCols = compressedMatrixBlock.getNumColumns();
        final int leftCols = compressedMatrixBlock2.getNumColumns();
        final List<AColGroup> rightGroups = compressedMatrixBlock.getColGroups();
        final List<AColGroup> leftGroups = compressedMatrixBlock2.getColGroups();

        final boolean rightHasConst = CLALibUtils.containsSDCOrConst(rightGroups);
        final double[] rightConst = rightHasConst ? new double[rightCols] : null;
        final List<AColGroup> rightFiltered = CLALibUtils.filterGroups(rightGroups, rightConst);

        final boolean leftHasConst = CLALibUtils.containsSDCOrConst(leftGroups);
        final double[] leftConst = leftHasConst ? new double[leftCols] : null;
        final List<AColGroup> leftFiltered = CLALibUtils.filterGroups(leftGroups, leftConst);

        // Column sums of one side are only needed when the other side has offsets.
        final double[] rightSums = getColSum(rightFiltered, rightCols, sharedRows, leftHasConst);
        final double[] leftSums = getColSum(leftFiltered, leftCols, sharedRows, rightHasConst);

        for (int r = 0; r < rightFiltered.size(); r++) {
            for (int l = 0; l < leftFiltered.size(); l++) {
                rightFiltered.get(r).leftMultByAColGroup(leftFiltered.get(l), matrixBlock);
            }
        }

        final double[] out = matrixBlock.getDenseBlockValues();
        if (leftHasConst) {
            outerProduct(leftConst, rightSums, out);
        }
        if (rightHasConst) {
            if (leftHasConst) {
                // Both sides carry offsets: add rows * cL so the product below also
                // contributes the cL ⊗ cR cross term.
                for (int c = 0; c < leftSums.length; c++) {
                    leftSums[c] += leftConst[c] * sharedRows;
                }
            }
            // Left-side vector first: output rows correspond to left columns.
            outerProduct(leftSums, rightConst, out);
        }
        return matrixBlock;
    }

    /**
     * Runs the group-wise transpose-self multiplication: every group against itself, then
     * against each group that follows it in the list (each unordered pair is visited once).
     *
     * @param list        column groups to process
     * @param matrixBlock result block the groups write into
     * @param i           value forwarded to {@code AColGroup.tsmm} (the caller passes the row count)
     */
    private static void tsmmColGroups(List<AColGroup> list, MatrixBlock matrixBlock, int i) {
        final int groupCount = list.size();
        int first = 0;
        while (first < groupCount) {
            final AColGroup current = list.get(first);
            current.tsmm(matrixBlock, i);
            for (int second = first + 1; second < groupCount; second++) {
                current.tsmmAColGroup(list.get(second), matrixBlock);
            }
            first++;
        }
    }

    /**
     * Adds the upper triangle (diagonal included) of the outer product {@code dArr ⊗ dArr2}
     * to {@code dArr3}.
     *
     * <p>{@code dArr3} is addressed row-major with a row stride of {@code dArr2.length}; only
     * cells with column index &gt;= row index are touched.
     *
     * @param dArr  row factor
     * @param dArr2 column factor
     * @param dArr3 accumulator, updated in place
     */
    private static void outerProductUpperTriangle(double[] dArr, double[] dArr2, double[] dArr3) {
        int row = 0;
        while (row < dArr.length) {
            final int width = dArr2.length;
            final int rowOffset = width * row;
            final double scale = dArr[row];
            for (int col = row; col < width; col++) {
                dArr3[rowOffset + col] += scale * dArr2[col];
            }
            row++;
        }
    }

    /**
     * Multiplies an uncompressed left-hand matrix with a list of column groups.
     *
     * <p>SDC/const groups are first reduced through {@code CLALibUtils.filterGroups}, which
     * fills a per-output-column vector. The filtered groups are multiplied either on the
     * calling thread ({@code i == 1}) or through the parallel path. If a vector was collected,
     * the outer product of the left matrix's row sums with that vector is added to the result.
     *
     * @param list         column groups of the right-hand side
     * @param matrixBlock  left-hand side matrix
     * @param matrixBlock2 result block, updated in place
     * @param i            1 selects the single-threaded path; any other value is forwarded to
     *                     the parallel path
     * @param z            overlapping flag forwarded to the parallel path
     * @return {@code matrixBlock2} with its non-zero count recomputed
     */
    private static MatrixBlock leftMultByMatrix(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, boolean z) {
        if (matrixBlock.isEmpty()) {
            matrixBlock2.setNonZeros(0L);
            return matrixBlock2;
        }

        final int outCols = matrixBlock2.getNumColumns();
        final boolean hasConst = CLALibUtils.containsSDCOrConst(list);
        final int leftRows = matrixBlock.getNumRows();

        double[] constV = hasConst ? new double[outCols] : null;
        final List<AColGroup> filtered = CLALibUtils.filterGroups(list, constV);
        if (filtered == list) {
            // Same list instance came back: the collected vector is discarded.
            constV = null;
        }

        final double[] rowSums;
        if (!filtered.isEmpty()) {
            if (i == 1) {
                final double[] sumBuffer = hasConst ? new double[leftRows] : null;
                rowSums = leftMultByMatrixPrimitive(filtered, matrixBlock, matrixBlock2, 0, leftRows, sumBuffer);
            } else {
                rowSums = leftMultByMatrixParallel(filtered, matrixBlock, matrixBlock2, hasConst, z, i);
            }
        } else if (constV != null) {
            rowSums = matrixBlock.rowSum(i).getDenseBlockValues();
        } else {
            rowSums = null;
        }

        if (rowSums != null && constV != null) {
            matrixBlock2.sparseToDense();
            outerProduct(rowSums, constV, matrixBlock2.getDenseBlockValues());
        }
        matrixBlock2.recomputeNonZeros();
        return matrixBlock2;
    }

    /**
     * Multi-threaded variant of the left multiplication.
     *
     * <p>The rows of the left-hand matrix are cut into small blocks (1 row when there are no
     * more rows than {@code i}, otherwise between 2 and 8). If that alone yields enough tasks,
     * each task handles all groups for one row block. Otherwise the groups are additionally
     * distributed over several lists via {@link #split}, giving one task per (row block, group
     * list) pair; only the first list's task of each row block receives the row-sum buffer.
     *
     * <p>When {@code z2} is set and there is more than one group, every split task writes into
     * its own temporary block and the temporaries are added into {@code matrixBlock2} after
     * completion; otherwise all tasks write into {@code matrixBlock2} directly.
     *
     * <p>The executor is shut down on every exit path, and an interrupt is restored on the
     * current thread before being reported.
     *
     * @param list         column groups to multiply with
     * @param matrixBlock  left-hand side matrix
     * @param matrixBlock2 result block
     * @param z            whether per-row sums of the left matrix must be collected
     * @param z2           overlapping flag (see above)
     * @param i            value passed to {@code CommonThreadPool.get} and used to size the tasks
     * @return the row-sum buffer, or {@code null} when {@code z} is false
     * @throws DMLRuntimeException if a task fails or the waiting thread is interrupted
     */
    private static double[] leftMultByMatrixParallel(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, boolean z, boolean z2, int i) {
        LOG.debug("Parallel left matrix multiplication");
        final ExecutorService pool = CommonThreadPool.get(i);
        try {
            final List<LeftMatrixColGroupMultTask> tasks = new ArrayList<>();
            final int numRows = matrixBlock.getNumRows();
            final int rowBlockSize = numRows <= i ? 1 : Math.min(Math.max((numRows / i) * 2, 1), 8);
            final double[] rowSums = z ? new double[numRows] : null;
            final int numberSplits = Math.max(i / (numRows / rowBlockSize), 1);

            if (numberSplits == 1) {
                for (int rl = 0; rl < numRows; rl += rowBlockSize) {
                    final int ru = Math.min(rl + rowBlockSize, numRows);
                    tasks.add(new LeftMatrixColGroupMultTask(list, matrixBlock, matrixBlock2, rl, ru, rowSums));
                }
                for (Future<MatrixBlock> done : pool.invokeAll(tasks)) {
                    done.get();
                }
            } else {
                final List<List<AColGroup>> groupLists = split(list, numberSplits);
                final boolean useTemporaries = z2 && list.size() > 1;

                for (int rl = 0; rl < numRows; rl += rowBlockSize) {
                    final int ru = Math.min(rl + rowBlockSize, numRows);
                    for (int s = 0; s < groupLists.size(); s++) {
                        final MatrixBlock target = useTemporaries
                            ? new MatrixBlock(numRows, matrixBlock2.getNumColumns(), false)
                            : matrixBlock2;
                        if (target.getDenseBlock() == null) {
                            target.allocateDenseBlock();
                        }
                        // Row sums are independent of the groups, so only one task per
                        // row block accumulates them.
                        final double[] sumsForTask = s == 0 ? rowSums : null;
                        tasks.add(new LeftMatrixColGroupMultTask(groupLists.get(s), matrixBlock, target, rl, ru, sumsForTask));
                    }
                }

                if (useTemporaries) {
                    final BinaryOperator plus = new BinaryOperator(Plus.getPlusFnObject());
                    for (Future<MatrixBlock> done : pool.invokeAll(tasks)) {
                        matrixBlock2.binaryOperationsInPlace(plus, (MatrixValue) done.get());
                    }
                } else {
                    for (Future<MatrixBlock> done : pool.invokeAll(tasks)) {
                        done.get();
                    }
                }
            }
            return rowSums;
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            pool.shutdown();
        }
    }

    /**
     * Distributes column groups round-robin over {@code i} lists, largest
     * {@code getNumValues()} first, so the lists receive a similar mix of group sizes.
     *
     * <p>The input list is not modified: sorting happens on a private copy, because the list
     * handed in may be the column-group list owned by a compressed matrix block.
     *
     * @param list column groups to distribute
     * @param i    number of lists to produce; expected to be at least 1
     * @return {@code i} lists that together contain every element of {@code list} exactly once
     */
    private static List<List<AColGroup>> split(List<AColGroup> list, int i) {
        final List<AColGroup> bySizeDescending = new ArrayList<>(list);
        bySizeDescending.sort(Comparator.comparingInt((AColGroup g) -> g.getNumValues()).reversed());

        final List<List<AColGroup>> buckets = new ArrayList<>();
        for (int b = 0; b < i; b++) {
            buckets.add(new ArrayList<>());
        }
        for (int g = 0; g < bySizeDescending.size(); g++) {
            buckets.get(g % i).add(bySizeDescending.get(g));
        }
        return buckets;
    }

    /**
     * Adds the outer product {@code dArr ⊗ dArr2} to {@code dArr3}.
     *
     * <p>{@code dArr3} is addressed row-major with {@code dArr.length} rows and a row stride
     * of {@code dArr2.length}.
     *
     * @param dArr  row factor
     * @param dArr2 column factor
     * @param dArr3 accumulator, updated in place
     */
    private static void outerProduct(double[] dArr, double[] dArr2, double[] dArr3) {
        int row = 0;
        while (row < dArr.length) {
            final int width = dArr2.length;
            final int rowOffset = width * row;
            final double scale = dArr[row];
            for (int col = 0; col < width; col++) {
                dArr3[rowOffset + col] += scale * dArr2[col];
            }
            row++;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Multiplies rows {@code [i, i2)} of the left-hand matrix with the given column groups,
     * choosing the sparse or dense kernel from the left matrix's storage format.
     *
     * <p>Afterwards the result's non-zero count is set to its full cell count (rows x columns),
     * i.e. the output is flagged as fully dense; callers in this class recompute the exact
     * count later.
     *
     * @param list         column groups to multiply with
     * @param matrixBlock  left-hand side matrix
     * @param matrixBlock2 result block, updated in place
     * @param i            first row (inclusive)
     * @param i2           last row (exclusive)
     * @param dArr         optional accumulator receiving the sum of each processed row of
     *                     {@code matrixBlock}; may be {@code null}
     * @return {@code dArr}
     */
    public static double[] leftMultByMatrixPrimitive(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2, double[] dArr) {
        if (matrixBlock.isInSparseFormat()) {
            leftMultByMatrixPrimitiveSparse(list, matrixBlock, matrixBlock2, i, i2, dArr);
        } else {
            leftMultByMatrixPrimitiveDense(list, matrixBlock, matrixBlock2, i, i2, dArr);
        }
        // Multiply as long: rows * cols may not fit into an int.
        final long cellCount = (long) matrixBlock2.getNumRows() * matrixBlock2.getNumColumns();
        matrixBlock2.setNonZeros(cellCount);
        return dArr;
    }

    /**
     * Sparse-input kernel: processes the left-hand matrix one row at a time, letting every
     * column group multiply that single row into the result.
     *
     * <p>If {@code dArr} is given, the stored values of each non-empty sparse row are added to
     * {@code dArr[row]}.
     *
     * @param list         column groups to multiply with
     * @param matrixBlock  left-hand side matrix in sparse format
     * @param matrixBlock2 result block, updated in place
     * @param i            first row (inclusive)
     * @param i2           last row (exclusive)
     * @param dArr         optional per-row sum accumulator; may be {@code null}
     */
    private static void leftMultByMatrixPrimitiveSparse(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2, double[] dArr) {
        for (int row = i; row < i2; row++) {
            final int nextRow = row + 1;
            for (int g = 0; g < list.size(); g++) {
                list.get(g).leftMultByMatrix(matrixBlock, matrixBlock2, row, nextRow);
            }

            if (dArr == null) {
                continue;
            }
            final SparseBlock rows = matrixBlock.getSparseBlock();
            if (rows.isEmpty(row)) {
                continue;
            }
            final int begin = rows.pos(row);
            final int end = rows.size(row) + begin;
            final double[] stored = rows.values(row);
            for (int k = begin; k < end; k++) {
                dArr[row] += stored[k];
            }
        }
    }

    /**
     * Dense-input kernel for rows {@code [i, i2)} of the left-hand matrix.
     *
     * <p>Groups that are not {@code ColGroupValue} are multiplied directly by
     * {@link #preFilterAndMultiply}. The remaining value groups are processed in batches of 16
     * or 20 (20 when {@code size % 16 < 4}, which avoids a very small trailing batch). For each
     * batch and each row, the row is pre-aggregated per group in column blocks of 32000, then
     * multiplied with the pre-aggregate and added to the result.
     *
     * <p>Row sums of the left matrix do not depend on the groups, so they are accumulated
     * during the first batch only; accumulating them in every batch would count each row once
     * per batch. If there are no value groups at all, the row sums are computed in a separate
     * pass at the end.
     *
     * @param list         column groups to multiply with
     * @param matrixBlock  left-hand side matrix in dense format
     * @param matrixBlock2 result block, updated in place
     * @param i            first row (inclusive)
     * @param i2           last row (exclusive)
     * @param dArr         optional per-row sum accumulator; may be {@code null}
     */
    private static void leftMultByMatrixPrimitiveDense(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2, double[] dArr) {
        final int colBlockSize = 32000;
        final int outCols = matrixBlock2.getNumColumns();
        final List<ColGroupValue> valueGroups = preFilterAndMultiply(list, matrixBlock, matrixBlock2, i, i2);
        final int groupCount = valueGroups.size();
        final int batchSize = groupCount % 16 < 4 ? 20 : 16;
        final MatrixBlock[] preAggregates = populatePreAggregate(batchSize);
        final MatrixBlock scratch = new MatrixBlock(1, outCols, false);
        final int leftCols = matrixBlock.getNumColumns();

        for (int gl = 0; gl < groupCount; gl += batchSize) {
            final int gu = Math.min(gl + batchSize, groupCount);
            for (int g = gl; g < gu; g++) {
                preAggregates[g % batchSize].reset(1, valueGroups.get(g).getNumValues(), false);
            }
            // Only the first batch contributes to the row sums (see method comment).
            final boolean sumRows = dArr != null && gl == 0;

            for (int row = i; row < i2; row++) {
                final int rowEnd = Math.min(row + 1, i2);

                for (int cl = 0; cl < leftCols; cl += colBlockSize) {
                    final int cu = Math.min(cl + colBlockSize, leftCols);
                    for (int g = gl; g < gu; g++) {
                        valueGroups.get(g).preAggregateDense(matrixBlock, preAggregates[g % batchSize], row, rowEnd, cl, cu);
                    }
                    if (sumRows) {
                        addDenseRowSums(matrixBlock.getDenseBlockValues(), dArr, row, rowEnd, leftCols, cl, cu);
                    }
                }

                for (int g = gl; g < gu; g++) {
                    final ColGroupValue group = valueGroups.get(g);
                    final MatrixBlock preAggregate = preAggregates[g % batchSize];
                    preAggregate.recomputeNonZeros();
                    scratch.reset(1, group.getNumCols(), false);
                    group.addMatrixToResult(group.leftMultByPreAggregateMatrix(preAggregate, scratch), matrixBlock2, row, rowEnd);
                    preAggregate.reset();
                }
            }
        }

        if (groupCount == 0 && dArr != null) {
            // No value groups ran, so the row sums were never touched above.
            addDenseRowSums(matrixBlock.getDenseBlockValues(), dArr, i, i2, leftCols, 0, leftCols);
        }
    }

    /**
     * Adds the values of columns {@code [cl, cu)} of each row in {@code [rl, ru)} of a
     * row-major dense value array to {@code rowSums[row]}.
     */
    private static void addDenseRowSums(double[] values, double[] rowSums, int rl, int ru, int rowStride, int cl, int cu) {
        for (int row = rl; row < ru; row++) {
            final int offset = row * rowStride;
            for (int idx = offset + cl; idx < offset + cu; idx++) {
                rowSums[row] += values[idx];
            }
        }
    }

    /**
     * Creates {@code i} independent 1x1 dense-allocated blocks to be used as reusable
     * pre-aggregation buffers.
     *
     * @param i number of buffers to create
     * @return array of freshly allocated blocks
     */
    private static MatrixBlock[] populatePreAggregate(int i) {
        final MatrixBlock[] buffers = new MatrixBlock[i];
        int slot = 0;
        while (slot < i) {
            buffers[slot] = new MatrixBlock(1, 1, false);
            buffers[slot].allocateDenseBlock();
            slot++;
        }
        return buffers;
    }

    /**
     * Separates the {@code ColGroupValue} groups from the rest.
     *
     * <p>Every group that is not a {@code ColGroupValue} is multiplied into the result right
     * away for rows {@code [i, i2)}. The {@code ColGroupValue} groups are collected and
     * returned ordered by {@code getNumValues()}, largest first.
     *
     * @param list         all column groups
     * @param matrixBlock  left-hand side matrix
     * @param matrixBlock2 result block, updated in place by the non-value groups
     * @param i            first row (inclusive)
     * @param i2           last row (exclusive)
     * @return the value groups, sorted descending by number of values
     */
    private static List<ColGroupValue> preFilterAndMultiply(List<AColGroup> list, MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2) {
        final List<ColGroupValue> valueGroups = new ArrayList<>(list.size());
        for (int g = 0; g < list.size(); g++) {
            final AColGroup group = list.get(g);
            if (!(group instanceof ColGroupValue)) {
                group.leftMultByMatrix(matrixBlock, matrixBlock2, i, i2);
                continue;
            }
            valueGroups.add((ColGroupValue) group);
        }
        valueGroups.sort(Comparator.comparingInt((ColGroupValue v) -> v.getNumValues()).reversed());
        return valueGroups;
    }

    /**
     * Returns the column sums of the given groups when requested.
     *
     * @param list column groups to sum
     * @param i    length of the freshly allocated sum vector handed to {@code AColGroup.colSum}
     * @param i2   value forwarded to {@code AColGroup.colSum} (callers pass the row count)
     * @param z    whether the sums are needed at all
     * @return the result of {@code AColGroup.colSum}, or {@code null} when {@code z} is false
     */
    private static double[] getColSum(List<AColGroup> list, int i, int i2, boolean z) {
        if (!z) {
            return null;
        }
        final double[] sums = new double[i];
        return AColGroup.colSum(list, sums, i2);
    }
}
