package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.class */
public class CLALibRightMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());

    /**
     * Task that right-multiplies a single column group by a matrix, so that the per-group products can be
     * computed concurrently on a thread pool.
     */
    public static class RightMatrixMultTask implements Callable<AColGroup> {
        private final AColGroup _colGroup;
        private final MatrixBlock _b;

        protected RightMatrixMultTask(AColGroup aColGroup, MatrixBlock matrixBlock) {
            this._colGroup = aColGroup;
            this._b = matrixBlock;
        }

        /**
         * @return the result of {@code _colGroup.rightMultByMatrix(_b)}; may be {@code null}, which the caller
         *         accepts (see the null handling in the multiplication loop)
         * @throws DMLRuntimeException wrapping any exception raised by the multiplication
         */
        @Override
        public AColGroup call() {
            try {
                return _colGroup.rightMultByMatrix(_b);
            } catch (Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Right-multiplies a compressed matrix by {@code matrixBlock}.
     *
     * @param compressedMatrixBlock the compressed left-hand side; its column groups are multiplied individually
     * @param matrixBlock           the right-hand side; decompressed first if it is itself compressed
     * @param matrixBlock2          forwarded to internal helpers but not read by this implementation
     * @param i                     degree of parallelism; {@code 1} runs sequentially, anything else uses a pool
     * @param z                     whether an overlapping compressed result may be returned; if {@code false}
     *                              the result is decompressed before returning
     * @return the product, with its non-zero count recomputed
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock,
        MatrixBlock matrixBlock2, int i, boolean z) {
        final MatrixBlock result = rightMultByMatrix(compressedMatrixBlock.getColGroups(), matrixBlock, matrixBlock2,
            i, compressedMatrixBlock.getMaxNumValues(), z);
        result.recomputeNonZeros();
        return result;
    }

    /**
     * Decompresses a compressed right-hand side if needed, multiplies, and decompresses the result when an
     * overlapping compressed output is not allowed.
     */
    private static MatrixBlock rightMultByMatrix(List<AColGroup> list, MatrixBlock matrixBlock,
        MatrixBlock matrixBlock2, int i, Pair<Integer, int[]> pair, boolean z) {
        final MatrixBlock right;
        if (matrixBlock instanceof CompressedMatrixBlock) {
            LOG.warn("Decompression Right matrix");
            right = ((CompressedMatrixBlock) matrixBlock).decompress(i);
        } else {
            right = matrixBlock;
        }
        final MatrixBlock result = rightMultByMatrixOverlapping(list, right, matrixBlock2, i, pair);
        if (result instanceof CompressedMatrixBlock && !allowOverlappingOutput(list, z)) {
            return ((CompressedMatrixBlock) result).decompress(i);
        }
        return result;
    }

    /** Returns {@code z}; logs at debug level when overlapping output is disallowed. */
    private static boolean allowOverlappingOutput(List<AColGroup> list, boolean z) {
        if (z) {
            return true;
        }
        LOG.debug("Not Overlapping because it is not allowed");
        return false;
    }

    /**
     * Allocates a compressed output block sized (rows of the first column group) x (columns of the right-hand
     * side) and fills it with the per-group products. {@code matrixBlock2} is not read.
     */
    private static MatrixBlock rightMultByMatrixOverlapping(List<AColGroup> list, MatrixBlock matrixBlock,
        MatrixBlock matrixBlock2, int i, Pair<Integer, int[]> pair) {
        final CompressedMatrixBlock out = new CompressedMatrixBlock(list.get(0).getNumRows(),
            matrixBlock.getNumColumns());
        return rightMultByMatrixCompressed(list, matrixBlock, out, i, pair);
    }

    /**
     * Multiplies every column group by {@code matrixBlock} and installs the non-null products as the column
     * groups of {@code compressedMatrixBlock}. If any group produced {@code null}, an empty column group
     * covering all otherwise-uncovered output columns is appended. {@code pair} is not read.
     */
    private static MatrixBlock rightMultByMatrixCompressed(List<AColGroup> list, MatrixBlock matrixBlock,
        CompressedMatrixBlock compressedMatrixBlock, int i, Pair<Integer, int[]> pair) {
        final List<AColGroup> products = new ArrayList<>(list.size());
        final boolean anyNull = (i == 1)
            ? multiplySequential(list, matrixBlock, products)
            : multiplyParallel(list, matrixBlock, i, products);

        // Decide overlap from the real products only: the empty group covers disjoint columns by construction.
        final boolean overlapping = products.size() > 1;

        // Complete the list before handing it to the block, so correctness does not depend on whether
        // allocateColGroupList keeps a reference to or copies the list.
        if (anyNull) {
            final ColGroupEmpty empty = findEmptyColumnsAndMakeEmptyColGroup(products,
                compressedMatrixBlock.getNumColumns(), compressedMatrixBlock.getNumRows());
            if (empty != null) {
                products.add(empty);
            }
        }

        compressedMatrixBlock.allocateColGroupList(products);
        if (overlapping) {
            compressedMatrixBlock.setOverlapping(true);
        }
        return compressedMatrixBlock;
    }

    /** Multiplies groups on the calling thread; adds non-null results to {@code out}; returns true if any was null. */
    private static boolean multiplySequential(List<AColGroup> groups, MatrixBlock right, List<AColGroup> out) {
        boolean anyNull = false;
        for (AColGroup g : groups) {
            anyNull |= addIfNotNull(g.rightMultByMatrix(right), out);
        }
        return anyNull;
    }

    /** Multiplies groups on a pool of {@code k} threads; same contract as {@link #multiplySequential}. */
    private static boolean multiplyParallel(List<AColGroup> groups, MatrixBlock right, int k,
        List<AColGroup> out) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<RightMatrixMultTask> tasks = new ArrayList<>(groups.size());
            for (AColGroup g : groups) {
                tasks.add(new RightMatrixMultTask(g, right));
            }
            boolean anyNull = false;
            // Results are consumed in task order, matching the sequential path.
            for (Future<AColGroup> f : pool.invokeAll(tasks)) {
                anyNull |= addIfNotNull(f.get(), out);
            }
            return anyNull;
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            // The original never released the pool. NOTE(review): assumes CommonThreadPool's shutdown leaves
            // a shared/common pool usable — confirm against CommonThreadPool.
            pool.shutdown();
        }
    }

    /** Adds {@code g} to {@code out} if non-null; returns {@code true} iff {@code g} was null. */
    private static boolean addIfNotNull(AColGroup g, List<AColGroup> out) {
        if (g == null) {
            return true;
        }
        out.add(g);
        return false;
    }

    /**
     * Builds an empty column group over every column in {@code [0, nCols)} not referenced by any group.
     *
     * @param list  the groups whose column indices count as covered (indices outside the range are ignored)
     * @param nCols number of output columns
     * @param nRows number of rows for the empty group
     * @return the empty group with ascending column indices, or {@code null} if every column is covered
     */
    private static ColGroupEmpty findEmptyColumnsAndMakeEmptyColGroup(List<AColGroup> list, int nCols, int nRows) {
        final boolean[] covered = new boolean[nCols];
        int numCovered = 0;
        for (AColGroup g : list) {
            for (int c : g.getColIndices()) {
                if (c >= 0 && c < nCols && !covered[c]) {
                    covered[c] = true;
                    numCovered++;
                }
            }
        }
        if (numCovered == nCols) {
            return null;
        }
        // A linear scan yields ascending indices deterministically (HashSet iteration order is unspecified).
        final int[] emptyCols = new int[nCols - numCovered];
        int p = 0;
        for (int c = 0; c < nCols; c++) {
            if (!covered[c]) {
                emptyCols[p++] = c;
            }
        }
        return new ColGroupEmpty(emptyCols, nRows);
    }
}
