package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.utils.LinearAlgebraUtils;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibLeftMultBy.class */
/**
 * Left matrix-multiplication kernels for compressed matrices.
 * <p>
 * Provides {@code left %*% C} and {@code t(left) %*% C} (for an uncompressed or compressed {@code left}),
 * plus {@code t(C) %*% C} computed directly from the column groups of {@code C}. Every kernel runs
 * sequentially when the requested degree of parallelism {@code k <= 1}. Otherwise it submits tasks to
 * {@link CommonThreadPool}.
 */
public class CLALibLeftMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibLeftMultBy.class.getName());

    /**
     * Task that multiplies the row range {@code [rl, ru)} of the left matrix with a single column group.
     */
    public static class LeftMatrixColGroupMultTask implements Callable<Object> {
        private final AColGroup _group;
        private final MatrixBlock _that;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;
        private final Pair<Integer, int[]> _v;

        /**
         * @param group column group to multiply with
         * @param that  uncompressed left-hand side matrix
         * @param ret   result block the group writes into
         * @param rl    first row of {@code that} to process (inclusive)
         * @param ru    last row of {@code that} to process (exclusive)
         * @param v     pair whose left value sizes the thread-local scratch memory
         */
        protected LeftMatrixColGroupMultTask(AColGroup group, MatrixBlock that, MatrixBlock ret, int rl, int ru,
            Pair<Integer, int[]> v) {
            _group = group;
            _that = that;
            _ret = ret;
            _rl = rl;
            _ru = ru;
            _v = v;
        }

        @Override
        public Object call() {
            try {
                // scratch memory scales with the number of rows handled by this task
                ColGroupValue.setupThreadLocalMemory(_v.getLeft() * (_ru - _rl));
                _group.leftMultByMatrix(_that, _ret, _rl, _ru);
                return null;
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Task that multiplies the row range {@code [rl, ru)} of the left matrix with every column group of a list,
     * processing the groups one after another.
     */
    public static class LeftMatrixMatrixMultTask implements Callable<Object> {
        private final List<AColGroup> _group;
        private final MatrixBlock _that;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;
        private final Pair<Integer, int[]> _v;

        /**
         * @param groups column groups to multiply with, processed in list order
         * @param that   uncompressed left-hand side matrix
         * @param ret    result block the groups write into
         * @param rl     first row of {@code that} to process (inclusive)
         * @param ru     last row of {@code that} to process (exclusive)
         * @param v      pair whose left value sizes the thread-local scratch memory
         */
        protected LeftMatrixMatrixMultTask(List<AColGroup> groups, MatrixBlock that, MatrixBlock ret, int rl, int ru,
            Pair<Integer, int[]> v) {
            _group = groups;
            _that = that;
            _ret = ret;
            _rl = rl;
            _ru = ru;
            _v = v;
        }

        @Override
        public Object call() {
            try {
                // Sized like LeftMatrixColGroupMultTask (value count x rows in range), so that multi-row
                // ranges get enough scratch space; identical for the single-row ranges used in this class.
                ColGroupValue.setupThreadLocalMemory(_v.getLeft() * (_ru - _rl));
                for (AColGroup group : _group) {
                    group.leftMultByMatrix(_that, _ret, _rl, _ru);
                }
                return null;
            }
            catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Task that multiplies one left column group with the groups {@code [start, end)} of a list.
     * See {@link CLALibLeftMultBy#leftMultByCompressedTransposedMatrix(AColGroup, List, MatrixBlock, int, int)}.
     */
    public static class LeftMultByCompressedTransposedMatrixTask implements Callable<Object> {
        private final List<AColGroup> _groups;
        private final AColGroup _left;
        private final MatrixBlock _ret;
        private final int _start;
        private final int _end;

        protected LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup left, MatrixBlock ret,
            int start, int end) {
            _groups = groups;
            _left = left;
            _ret = ret;
            _start = start;
            _end = end;
        }

        /** Convenience constructor covering all groups of {@code groups}. */
        protected LeftMultByCompressedTransposedMatrixTask(List<AColGroup> groups, AColGroup left, MatrixBlock ret) {
            this(groups, left, ret, 0, groups.size());
        }

        @Override
        public Object call() {
            try {
                leftMultByCompressedTransposedMatrix(_left, _groups, _ret, _start, _end);
                return null;
            }
            catch (Exception e) {
                // cause is preserved; the caller reports it, so no printStackTrace here
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Computes {@code t(left) %*% cmb} by transposing {@code left} and delegating to
     * {@link #leftMultByMatrix(CompressedMatrixBlock, MatrixBlock, MatrixBlock, int)}.
     *
     * @param cmb  compressed right-hand side
     * @param left uncompressed matrix whose transpose is the left-hand side
     * @param ret  optional result block; may be {@code null}
     * @param k    degree of parallelism
     * @return the result block, sized {@code left.numCols x cmb.numCols}; never {@code null}
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock cmb, MatrixBlock left,
        MatrixBlock ret, int k) {
        if (left.isEmpty()) {
            // still hand back a correctly sized block instead of a possibly null / unsized one
            return prepareReturnMatrix(cmb, left, ret, true);
        }
        final MatrixBlock transposed = new MatrixBlock(left.getNumColumns(), left.getNumRows(), false);
        LibMatrixReorg.transpose(left, transposed);
        // leftMultByMatrix already recomputes the non-zero count of its result
        return leftMultByMatrix(cmb, transposed, ret, k);
    }

    /**
     * Computes {@code t(left) %*% cmb} where both operands are compressed.
     *
     * @param cmb  compressed right-hand side
     * @param left compressed matrix whose transpose is the left-hand side
     * @param ret  optional result block; may be {@code null}
     * @param k    degree of parallelism
     * @return the result block, sized {@code left.numCols x cmb.numCols}; never {@code null}
     */
    public static MatrixBlock leftMultByMatrixTransposed(CompressedMatrixBlock cmb, CompressedMatrixBlock left,
        MatrixBlock ret, int k) {
        // keep the prepared block: it is newly created when the caller passed null
        final MatrixBlock out = prepareReturnMatrix(cmb, left, ret, true);
        // the kernel recomputes the non-zero count before returning
        return leftMultByCompressedTransposedMatrix(cmb.getColGroups(), left, out, k, cmb.getNumColumns(),
            cmb.getMaxNumValues(), cmb.isOverlapping());
    }

    /**
     * Computes {@code left %*% cmb}.
     *
     * @param cmb  compressed right-hand side
     * @param left uncompressed left-hand side
     * @param ret  optional result block; may be {@code null}
     * @param k    degree of parallelism
     * @return the result block, sized {@code left.numRows x cmb.numCols}; never {@code null}
     */
    public static MatrixBlock leftMultByMatrix(CompressedMatrixBlock cmb, MatrixBlock left, MatrixBlock ret, int k) {
        // keep the prepared block: it is newly created when the caller passed null
        final MatrixBlock out = prepareReturnMatrix(cmb, left, ret, false);
        if (left.isEmpty()) {
            return out;
        }
        // the kernel recomputes the non-zero count before returning
        return leftMultByMatrix(cmb.getColGroups(), left, out, k, cmb.getNumColumns(), cmb.getMaxNumValues(),
            cmb.isOverlapping());
    }

    /**
     * Returns a dense result block of the output shape, reusing {@code ret} when possible.
     * <p>
     * The output has {@code m1.numCols} columns. It has {@code m2.numCols} rows when {@code doTranspose} is set,
     * and {@code m2.numRows} rows otherwise. An existing {@code ret} that already has that shape and is allocated
     * is returned as is. Otherwise it is reset. A new block is created when {@code ret} is {@code null}.
     */
    private static MatrixBlock prepareReturnMatrix(MatrixBlock m1, MatrixBlock m2, MatrixBlock ret,
        boolean doTranspose) {
        final int rows = doTranspose ? m2.getNumColumns() : m2.getNumRows();
        final int cols = m1.getNumColumns();
        // long arithmetic: rows * cols can exceed Integer.MAX_VALUE for large outputs
        final long estNnz = (long) rows * cols;
        if (ret == null) {
            return new MatrixBlock(rows, cols, false, estNnz);
        }
        if (ret.getNumColumns() != cols || ret.getNumRows() != rows || !ret.isAllocated()) {
            ret.reset(rows, cols, false, estNnz);
        }
        return ret;
    }

    /**
     * Computes the symmetric product {@code t(C) %*% C} from the column groups of {@code C} into {@code ret}.
     * <p>
     * For non-overlapping groups, only the pairs {@code (i, j >= i)} are computed. The missing triangle is then
     * filled by mirroring. Overlapping groups are processed sequentially over all pairs.
     *
     * @param groups      column groups of the compressed matrix
     * @param ret         result block; a dense block is allocated and its sparsity re-examined at the end
     * @param k           degree of parallelism
     * @param numColumns  side length of the square result (presumably the number of columns of {@code C})
     * @param v           unused here; kept for interface compatibility
     * @param overlapping whether the column groups overlap
     */
    public static void leftMultByTransposeSelf(List<AColGroup> groups, MatrixBlock ret, int k, int numColumns,
        Pair<Integer, int[]> v, boolean overlapping) {
        ret.allocateDenseBlock();
        final int numGroups = groups.size();
        if (overlapping) {
            LOG.warn("Inefficient TSMM with overlapping matrix could be implemented multi-threaded but is not yet.");
            leftMultByCompressedTransposedMatrix(groups, groups, ret);
        }
        else if (k <= 1) {
            for (int i = 0; i < numGroups; i++) {
                leftMultByCompressedTransposedMatrix(groups.get(i), groups, ret, i, numGroups);
            }
        }
        else {
            final List<Callable<Object>> tasks = new ArrayList<>(numGroups);
            for (int i = 0; i < numGroups; i++) {
                tasks.add(new LeftMultByCompressedTransposedMatrixTask(groups, groups.get(i), ret, i, numGroups));
            }
            runTasks(k, tasks);
        }
        copyToUpperTriangle(ret.getDenseBlockValues(), numColumns);
        ret.setNonZeros(LinearAlgebraUtils.copyUpperToLowerTriangle(ret));
        ret.examSparsity();
    }

    /**
     * Fills zero cells on or above the diagonal of a row-major {@code n x n} array with their transposed
     * counterpart, i.e. {@code values[r][c] = values[c][r]} for {@code c >= r} whenever {@code values[r][c] == 0}.
     *
     * @param values row-major dense values, at least {@code n * n} long
     * @param n      side length of the square matrix
     */
    private static void copyToUpperTriangle(double[] values, int n) {
        for (int row = 0, rowOffset = 0; row < n; row++, rowOffset += n) {
            for (int col = row, colOffset = row * n; col < n; col++, colOffset += n) {
                if (values[rowOffset + col] == 0) {
                    values[rowOffset + col] = values[colOffset + row];
                }
            }
        }
    }

    /**
     * Computes {@code t(that) %*% C} where {@code groups} are the column groups of {@code C}. Falls back to a
     * sequential loop when {@code k <= 1} or when either side has overlapping groups.
     *
     * @param groups      column groups of the right-hand side
     * @param that        compressed matrix whose transpose is the left-hand side
     * @param ret         prepared result block
     * @param k           degree of parallelism
     * @param numColumns  unused here; kept for call-site compatibility
     * @param v           unused here; kept for call-site compatibility
     * @param overlapping whether {@code groups} overlap
     * @return {@code ret} with its non-zero count recomputed
     */
    private static MatrixBlock leftMultByCompressedTransposedMatrix(List<AColGroup> groups,
        CompressedMatrixBlock that, MatrixBlock ret, int k, int numColumns, Pair<Integer, int[]> v,
        boolean overlapping) {
        ret.allocateDenseBlock();
        final List<AColGroup> leftGroups = that.getColGroups();
        final boolean anyOverlap = overlapping || that.isOverlapping();
        if (k <= 1 || anyOverlap) {
            if (anyOverlap) {
                LOG.warn("Inefficient Compressed multiplication with overlapping matrix could be implemented "
                    + "multi-threaded but is not yet.");
            }
            leftMultByCompressedTransposedMatrix(groups, leftGroups, ret);
        }
        else {
            final List<Callable<Object>> tasks = new ArrayList<>(leftGroups.size());
            for (AColGroup left : leftGroups) {
                tasks.add(new LeftMultByCompressedTransposedMatrixTask(groups, left, ret));
            }
            runTasks(k, tasks);
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /** Sequentially multiplies every group of {@code leftGroups} with all groups of {@code groups}. */
    private static void leftMultByCompressedTransposedMatrix(List<AColGroup> groups, List<AColGroup> leftGroups,
        MatrixBlock ret) {
        for (AColGroup left : leftGroups) {
            leftMultByCompressedTransposedMatrix(left, groups, ret, 0, groups.size());
        }
    }

    /**
     * Multiplies {@code left} with each column group in {@code groups[start, end)} and writes the contributions
     * into {@code ret}. When a group is the very same object as {@code left}, its {@code tsmm} kernel is used
     * instead of {@code leftMultByAColGroup}.
     * <p>
     * Public only for access from the nested task class; treat as internal.
     */
    public static void leftMultByCompressedTransposedMatrix(AColGroup left, List<AColGroup> groups,
        MatrixBlock ret, int start, int end) {
        for (int i = start; i < end; i++) {
            final AColGroup right = groups.get(i);
            if (right == left) {
                right.tsmm(ret.getDenseBlockValues(), ret.getNumColumns());
            }
            else {
                right.leftMultByAColGroup(left, ret);
            }
        }
    }

    /**
     * Computes {@code that %*% C} where {@code groups} are the column groups of {@code C}.
     *
     * @param groups      column groups of the right-hand side
     * @param that        uncompressed left-hand side
     * @param ret         prepared result block
     * @param k           degree of parallelism; {@code k <= 1} runs sequentially
     * @param numColumns  unused here; kept for call-site compatibility
     * @param v           pair whose left value sizes the per-task thread-local scratch memory
     * @param overlapping whether the column groups overlap
     * @return {@code ret} with its non-zero count recomputed
     */
    private static MatrixBlock leftMultByMatrix(List<AColGroup> groups, MatrixBlock that, MatrixBlock ret, int k,
        int numColumns, Pair<Integer, int[]> v, boolean overlapping) {
        if (that.isEmpty()) {
            ret.setNonZeros(0L);
            return ret;
        }
        ret.allocateDenseBlock();
        if (k <= 1) {
            for (AColGroup group : groups) {
                group.leftMultByMatrix(that, ret);
            }
        }
        else {
            final int numRows = that.getNumRows();
            final List<Callable<Object>> tasks = new ArrayList<>();
            if (overlapping) {
                // one task per row over all groups (presumably because overlapping groups may write the same
                // output cells, so they must not run concurrently on the same row)
                for (int r = 0; r < numRows; r++) {
                    tasks.add(new LeftMatrixMatrixMultTask(groups, that, ret, r, r + 1, v));
                }
            }
            else {
                for (AColGroup group : groups) {
                    for (int r = 0; r < numRows; r++) {
                        tasks.add(new LeftMatrixColGroupMultTask(group, that, ret, r, r + 1, v));
                    }
                }
            }
            runTasks(k, tasks);
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Runs all tasks on a pool of {@code k} threads, waits for completion and rethrows the first failure.
     * The pool is always shut down, and the interrupt flag is restored if the wait is interrupted.
     */
    private static void runTasks(int k, List<Callable<Object>> tasks) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            for (Future<Object> future : pool.invokeAll(tasks)) {
                future.get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }
}
