package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibTSMM.class */
public final class CLALibTSMM {
    // Class-scoped logger, obtained by the fully qualified class name of CLALibTSMM.
    private static final Log LOG = LogFactory.getLog(CLALibTSMM.class.getName());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibTSMM$TSMMColGroupTask.class */
    /**
     * Callable that invokes {@code tsmmAColGroup} on one column group, passing a second column group and the
     * result block, so the call can be scheduled on an executor.
     *
     * NOTE(review): presumably this computes the cross term between the two groups for a transpose-self
     * multiplication; confirm against {@code AColGroup.tsmmAColGroup}.
     */
    public static class TSMMColGroupTask implements Callable<MatrixBlock> {
        /** Column group the operation is invoked on. */
        private final AColGroup _g;
        /** Column group passed as the first argument of the operation. */
        private final AColGroup _h;
        /** Result block handed to the operation and returned from {@link #call()}. */
        private final MatrixBlock _ret;

        /**
         * Creates a task for one pair of column groups.
         *
         * @param aColGroup   column group whose {@code tsmmAColGroup} is invoked
         * @param aColGroup2  the other column group handed to that call
         * @param matrixBlock result block handed to that call
         */
        protected TSMMColGroupTask(AColGroup aColGroup, AColGroup aColGroup2, MatrixBlock matrixBlock) {
            this._g = aColGroup;
            this._h = aColGroup2;
            this._ret = matrixBlock;
        }

        /**
         * Runs the pairwise operation.
         *
         * @return the result block this task was constructed with
         * @throws DMLRuntimeException wrapping any exception raised by the column group operation
         */
        @Override
        public MatrixBlock call() {
            try {
                this._g.tsmmAColGroup(this._h, this._ret);
                return this._ret;
            } catch (Exception e) {
                // The cause is preserved in the wrapper, so the failure is reported by whoever
                // consumes the Future; no separate stack trace dump to stderr is needed.
                throw new DMLRuntimeException(e);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibTSMM$TSMMTask.class */
    /**
     * Callable that invokes {@code tsmm} on a single column group, passing the result block and a row count,
     * so the call can be scheduled on an executor.
     *
     * NOTE(review): presumably this computes the group's own contribution to a transpose-self
     * multiplication; confirm against {@code AColGroup.tsmm}.
     */
    public static class TSMMTask implements Callable<MatrixBlock> {
        /** Column group the operation is invoked on. */
        private final AColGroup _g;
        /** Result block handed to the operation and returned from {@link #call()}. */
        private final MatrixBlock _ret;
        /** Row count forwarded to the operation. */
        private final int _nRows;

        /**
         * Creates a task for one column group.
         *
         * @param aColGroup   column group whose {@code tsmm} is invoked
         * @param matrixBlock result block handed to that call
         * @param i           row count forwarded to that call
         */
        protected TSMMTask(AColGroup aColGroup, MatrixBlock matrixBlock, int i) {
            this._g = aColGroup;
            this._ret = matrixBlock;
            this._nRows = i;
        }

        /**
         * Runs the single-group operation.
         *
         * @return the result block this task was constructed with
         * @throws DMLRuntimeException wrapping any exception raised by the column group operation
         */
        @Override
        public MatrixBlock call() {
            try {
                this._g.tsmm(this._ret, this._nRows);
                return this._ret;
            } catch (Exception e) {
                // The cause is preserved in the wrapper, so the failure is reported by whoever
                // consumes the Future; no separate stack trace dump to stderr is needed.
                throw new DMLRuntimeException(e);
            }
        }
    }

    /** Private constructor: this class only exposes static members and is not instantiated. */
    private CLALibTSMM() {
        // intentionally empty
    }

    /**
     * Runs the transpose-self multiplication of the column groups of a compressed block, accumulating into
     * the given result block.
     *
     * Steps: the column groups are processed by {@code tsmmColGroups}. When {@code CLALibUtils.shouldPreFilter}
     * says so, the filtered groups are processed instead and a correction layer is added afterwards. Finally
     * the upper triangle is copied to the lower triangle, the returned count is stored as the number of
     * non-zeros and the sparsity of the result is re-examined.
     *
     * @param compressedMatrixBlock compressed input whose column groups are multiplied
     * @param matrixBlock           result block that is written in place
     * @param i                     degree of parallelism; values of 1 or less select the single-threaded path
     */
    public static void leftMultByTransposeSelf(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock, int i) {
        final List<AColGroup> groups = compressedMatrixBlock.getColGroups();
        final int nCols = compressedMatrixBlock.getNumColumns();
        final int nRows = compressedMatrixBlock.getNumRows();
        final boolean preFilter = CLALibUtils.shouldPreFilter(groups);
        final boolean overlapping = compressedMatrixBlock.isOverlapping();

        if (!preFilter) {
            tsmmColGroups(groups, matrixBlock, nRows, overlapping, i);
        } else {
            // filterGroups receives this per-column buffer; it is consumed by the correction layer below.
            final double[] filteredOut = new double[nCols];
            final List<AColGroup> filtered = CLALibUtils.filterGroups(groups, filteredOut);
            tsmmColGroups(filtered, matrixBlock, nRows, overlapping, i);
            addCorrectionLayer(filtered, matrixBlock, nRows, nCols, filteredOut);
        }

        // Only the upper triangle has been filled so far; mirror it and record the reported count.
        matrixBlock.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(matrixBlock));
        matrixBlock.examSparsity();
    }

    /**
     * Adds the correction layer for pre-filtered column groups to the dense values of the result block.
     *
     * @param list        filtered column groups whose column sums are computed
     * @param matrixBlock result block; its dense value array is updated in place
     * @param i           number of rows
     * @param i2          number of columns
     * @param dArr        per-column values filled in when the groups were filtered
     */
    private static void addCorrectionLayer(List<AColGroup> list, MatrixBlock matrixBlock, int i, int i2, double[] dArr) {
        final double[] colSums = CLALibUtils.getColSum(list, i2, i);
        final double[] retValues = matrixBlock.getDenseBlockValues();
        addCorrectionLayer(dArr, colSums, i, retValues);
    }

    /**
     * Adds two outer-product terms to the upper triangle of a row-major square result array.
     *
     * For every cell (r, c) with c &gt;= r, the cell is increased first by {@code dArr[r] * dArr2[c]} and then
     * by {@code (dArr2[r] + dArr[r] * i) * dArr[c]}.
     *
     * @param dArr  first vector
     * @param dArr2 second vector
     * @param i     scaling factor applied to {@code dArr} in the second term
     * @param dArr3 row-major result values, updated in place
     */
    public static void addCorrectionLayer(double[] dArr, double[] dArr2, int i, double[] dArr3) {
        final double[] first = dArr;
        final double[] second = dArr2;
        // Term 1: first (x) second.
        outerProductUpperTriangle(first, second, dArr3);
        // Term 2: (second + i * first) (x) first.
        outerProductUpperTriangleWithScaling(second, first, i, dArr3);
    }

    /**
     * Dispatches the column group multiplication to the single-threaded, multi-threaded overlapping or
     * multi-threaded implementation.
     *
     * @param list        column groups to process
     * @param matrixBlock result block that is written in place
     * @param i           number of rows
     * @param z           whether the column groups overlap
     * @param i2          degree of parallelism; values of 1 or less select the single-threaded path
     */
    private static void tsmmColGroups(List<AColGroup> list, MatrixBlock matrixBlock, int i, boolean z, int i2) {
        if (i2 <= 1) {
            tsmmColGroupsSingleThread(list, matrixBlock, i);
            return;
        }
        if (z) {
            tsmmColGroupsMultiThreadOverlapping(list, matrixBlock, i, i2);
            return;
        }
        tsmmColGroupsMultiThread(list, matrixBlock, i, i2);
    }

    /**
     * Processes all column groups sequentially on the calling thread.
     *
     * Each group first runs its own {@code tsmm}, then {@code tsmmAColGroup} against every group that
     * follows it in the list, so each unordered pair of groups is visited exactly once.
     *
     * @param list        column groups to process
     * @param matrixBlock result block that is written in place
     * @param i           number of rows
     */
    private static void tsmmColGroupsSingleThread(List<AColGroup> list, MatrixBlock matrixBlock, int i) {
        final int nGroups = list.size();
        for (int a = 0; a < nGroups; a++) {
            final AColGroup current = list.get(a);
            current.tsmm(matrixBlock, i);
            int b = a + 1;
            while (b < nGroups) {
                current.tsmmAColGroup(list.get(b), matrixBlock);
                b++;
            }
        }
    }

    /**
     * Multi-threaded entry point for overlapping column groups. It currently logs a warning and delegates to
     * the single-threaded implementation; the requested parallelism is ignored.
     *
     * @param list        column groups to process
     * @param matrixBlock result block that is written in place
     * @param i           number of rows
     * @param i2          requested degree of parallelism (unused)
     */
    private static void tsmmColGroupsMultiThreadOverlapping(List<AColGroup> list, MatrixBlock matrixBlock, int i, int i2) {
        final String notice = "fallback to single threaded for now";
        LOG.warn(notice);
        tsmmColGroupsSingleThread(list, matrixBlock, i);
    }

    /**
     * Processes all column groups in parallel on a pool obtained from {@code CommonThreadPool}.
     *
     * One {@code TSMMTask} is created per column group and one {@code TSMMColGroupTask} per unordered pair
     * of groups. All tasks are submitted together and awaited. The pool is shut down on every exit path.
     *
     * NOTE(review): every task receives the same result block; presumably each task writes a disjoint
     * region because the groups are non-overlapping on this path. Confirm against the column group
     * implementations.
     *
     * @param list        column groups to process
     * @param matrixBlock result block shared by all tasks and written in place
     * @param i           number of rows
     * @param i2          number of threads requested from the pool
     * @throws DMLRuntimeException if the calling thread is interrupted or any task fails
     */
    private static void tsmmColGroupsMultiThread(List<AColGroup> list, MatrixBlock matrixBlock, int i, int i2) {
        final ExecutorService pool = CommonThreadPool.get(i2);
        try {
            final int nGroups = list.size();
            // n single-group tasks + n*(n-1)/2 pair tasks = n*(n+1)/2 tasks in total.
            final List<Callable<MatrixBlock>> tasks = new ArrayList<>((nGroups * (1 + nGroups)) / 2);
            for (int a = 0; a < nGroups; a++) {
                final AColGroup current = list.get(a);
                tasks.add(new TSMMTask(current, matrixBlock, i));
                for (int b = a + 1; b < nGroups; b++) {
                    tasks.add(new TSMMColGroupTask(current, list.get(b), matrixBlock));
                }
            }
            // get() surfaces any exception thrown inside a task as an ExecutionException.
            for (Future<MatrixBlock> pending : pool.invokeAll(tasks)) {
                pending.get();
            }
        } catch (InterruptedException e) {
            // Restore the interrupt status so code further up the stack can observe it.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            // Runs on success, on checked failures and on unchecked failures (e.g. rejected submission).
            pool.shutdown();
        }
    }

    /**
     * Adds the upper triangle (diagonal included) of the outer product of two vectors to a row-major array.
     *
     * For every row r and column c &gt;= r, {@code dArr3[r * dArr2.length + c]} is increased by
     * {@code dArr[r] * dArr2[c]}.
     *
     * @param dArr  left vector, indexed by row
     * @param dArr2 right vector, indexed by column; its length is also the row stride of the result
     * @param dArr3 row-major result values, updated in place
     */
    private static void outerProductUpperTriangle(double[] dArr, double[] dArr2, double[] dArr3) {
        final int nRows = dArr.length;
        final int nCols = dArr2.length;
        for (int r = 0, rowStart = 0; r < nRows; r++, rowStart += nCols) {
            final double left = dArr[r];
            // Start at the diagonal: cells left of it belong to the lower triangle.
            for (int c = r; c < nCols; c++) {
                dArr3[rowStart + c] += left * dArr2[c];
            }
        }
    }

    /**
     * Adds the upper triangle (diagonal included) of a scaled outer product to a row-major array.
     *
     * For every row r and column c &gt;= r, {@code dArr3[r * dArr2.length + c]} is increased by
     * {@code (dArr[r] + dArr2[r] * i) * dArr2[c]}.
     *
     * @param dArr  vector contributing the unscaled part of the row factor
     * @param dArr2 vector used both scaled by {@code i} in the row factor and as the column factor; its
     *              length is also the row stride of the result
     * @param i     scaling factor applied to {@code dArr2[r]} in the row factor
     * @param dArr3 row-major result values, updated in place
     */
    private static void outerProductUpperTriangleWithScaling(double[] dArr, double[] dArr2, int i, double[] dArr3) {
        final int nRows = dArr.length;
        final int nCols = dArr2.length;
        for (int r = 0, rowStart = 0; r < nRows; r++, rowStart += nCols) {
            final double rowFactor = dArr[r] + (dArr2[r] * i);
            // Start at the diagonal: cells left of it belong to the lower triangle.
            for (int c = r; c < nCols; c++) {
                dArr3[rowStart + c] += rowFactor * dArr2[c];
            }
        }
    }
}
