package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.DMLCompressionStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.class */
/**
 * Right matrix multiplication {@code m1 %*% that} where the left side {@code m1} is a {@link CompressedMatrixBlock}.
 * <p>
 * Depending on whether overlapping output is allowed, the product is either kept compressed as a list of (possibly
 * overlapping) column groups, or decompressed into an uncompressed {@link MatrixBlock}.
 */
public class CLALibRightMultBy {
    private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName());

    /**
     * Task that right-multiplies a single column group by the right-hand side matrix; used by the multi-threaded
     * path.
     */
    public static class RightMatrixMultTask implements Callable<AColGroup> {
        private final AColGroup _colGroup;
        private final MatrixBlock _b;

        protected RightMatrixMultTask(AColGroup colGroup, MatrixBlock b) {
            _colGroup = colGroup;
            _b = b;
        }

        /**
         * @return the group returned by {@code colGroup.rightMultByMatrix(b)}; may be {@code null}, which callers
         *         skip and only record as "a null result occurred"
         */
        @Override
        public AColGroup call() {
            try {
                return _colGroup.rightMultByMatrix(_b);
            }
            catch(DMLRuntimeException e) {
                // already the exception type callers expect; avoid wrapping it a second time
                throw e;
            }
            catch(Exception e) {
                throw new DMLRuntimeException(e);
            }
        }
    }

    /**
     * Right-multiplies a compressed matrix. Whether the result may stay compressed (with overlapping column groups)
     * is taken from the DML configuration flag {@code DMLConfig.COMPRESSED_OVERLAPPING}.
     *
     * @param m1   the compressed left-hand side
     * @param that the right-hand side; decompressed first if it is itself compressed
     * @param ret  output block; only reused when the product is empty, otherwise a new block is returned
     * @param k    degree of parallelism
     * @return the product {@code m1 %*% that}
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock that, MatrixBlock ret, int k) {
        final boolean allowOverlap = ConfigurationManager.getDMLConfig()
            .getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
        return rightMultByMatrix(m1, that, ret, k, allowOverlap);
    }

    /**
     * Right-multiplies a compressed matrix.
     *
     * @param m1           the compressed left-hand side
     * @param that         the right-hand side; decompressed first if it is itself compressed
     * @param ret          output block; only reused (reset to an empty rows(m1) x cols(that) block) when either input
     *                     is empty, otherwise ignored and a new block is returned
     * @param k            degree of parallelism (values {@code <= 1} run single-threaded)
     * @param allowOverlap if true the result stays compressed as a {@link CompressedMatrixBlock}; if false it is
     *                     decompressed into a regular {@link MatrixBlock}
     * @return the product {@code m1 %*% that}
     */
    public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBlock that, MatrixBlock ret, int k,
        boolean allowOverlap) {
        final int rl = m1.getNumRows();
        final int cr = that.getNumColumns();

        if(m1.isEmpty() || that.isEmpty()) {
            LOG.trace("Empty right multiply");
            if(ret == null) {
                ret = new MatrixBlock(rl, cr, 0L);
            }
            else {
                ret.reset(rl, cr, 0L);
            }
            return ret;
        }

        if(that instanceof CompressedMatrixBlock) {
            that = ((CompressedMatrixBlock) that).getUncompressed("Uncompressed right side of right MM", k);
        }

        if(!allowOverlap) {
            LOG.trace("Overlapping output not allowed in call to Right MM");
            return RMM(m1, that, k);
        }

        final CompressedMatrixBlock retC = RMMOverlapping(m1, that, k);
        if(retC.isEmpty()) {
            return retC;
        }
        if(retC.isOverlapping()) {
            // Presumably the non-zeros of summed, overlapping groups cannot be derived cheaply, so a dense count is
            // reported. Widen before multiplying: rl * cr can exceed Integer.MAX_VALUE.
            retC.setNonZeros((long) rl * cr);
        }
        else {
            retC.recomputeNonZeros();
        }
        return retC;
    }

    /**
     * Computes the product as a compressed block: every column group is multiplied on its own, and the product of the
     * pre-filtered constant row (if any) is folded into one group or appended as a constant group.
     */
    private static CompressedMatrixBlock RMMOverlapping(CompressedMatrixBlock m1, MatrixBlock that, int k) {
        final int rl = m1.getNumRows();
        final int cr = that.getNumColumns();
        final int rr = that.getNumRows(); // shared dimension
        final List<AColGroup> colGroups = m1.getColGroups();
        final List<AColGroup> retCg = new ArrayList<>();
        final CompressedMatrixBlock ret = new CompressedMatrixBlock(rl, cr);

        double[] constV = CLALibUtils.shouldPreFilter(colGroups) ? new double[rr] : null;
        final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
        if(colGroups == filteredGroups) {
            // nothing was filtered out, so there is no constant row to account for
            constV = null;
        }

        final boolean containsNull = multiplyGroups(filteredGroups, that, retCg, k);

        if(constV != null) {
            final MatrixBlock constRet = multiplyConstantRow(constV, that);
            if(!constRet.isEmpty()) {
                addConstant(constRet, retCg);
            }
        }

        // Decide overlap on the multiplied groups only: the empty-column group added below covers disjoint columns.
        final boolean overlapping = retCg.size() > 1;
        // Add the empty group before handing the list over, so correctness does not depend on whether
        // allocateColGroupList keeps a reference to the list or copies it.
        addEmptyColumn(retCg, cr, containsNull);
        ret.allocateColGroupList(retCg);
        if(overlapping) {
            ret.setOverlapping(true);
        }
        return ret;
    }

    /**
     * Adds the constant row {@code constantRow} (1 x nCol) to the result groups. If a DDC group spanning all nCol
     * columns exists, the row is added into the one with the fewest values; otherwise a new constant group is
     * appended.
     */
    private static void addConstant(MatrixBlock constantRow, List<AColGroup> out) {
        final int nCol = constantRow.getNumColumns();
        int bestCandidate = -1;
        long bestCandidateValuesSize = Long.MAX_VALUE;
        for(int j = 0; j < out.size(); j++) {
            final AColGroup g = out.get(j);
            if(g instanceof ColGroupDDC && g.getNumCols() == nCol && g.getNumValues() < bestCandidateValuesSize) {
                bestCandidate = j;
                // track the smallest candidate; previously the threshold was never updated
                bestCandidateValuesSize = g.getNumValues();
            }
        }

        constantRow.sparseToDense();
        final double[] constValues = constantRow.getDenseBlockValues();
        if(bestCandidate == -1) {
            out.add(ColGroupConst.create(constValues));
        }
        else {
            final AColGroup g = out.remove(bestCandidate);
            out.add(g.binaryRowOpRight(new BinaryOperator(Plus.getPlusFnObject(), 1), constValues, true));
        }
    }

    /**
     * Computes the product as an uncompressed block: the groups are multiplied and then decompressed, together with
     * the product of the pre-filtered constant row, into a freshly allocated rows(m1) x cols(that) block.
     */
    private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k) {
        final int rl = m1.getNumRows();
        final int cr = that.getNumColumns();
        final int rr = that.getNumRows(); // shared dimension
        final List<AColGroup> colGroups = m1.getColGroups();
        final List<AColGroup> retCg = new ArrayList<>();

        final boolean shouldFilter = CLALibUtils.shouldPreFilter(colGroups);
        // allocation is started before the multiplication and only awaited right before decompression
        final Future<MatrixBlock> allocation = new MatrixBlock(rl, cr, false).allocateBlockAsync();

        double[] constV = shouldFilter ? new double[rr] : null;
        final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(colGroups, constV);
        if(colGroups == filteredGroups) {
            // nothing was filtered out, so there is no constant row to account for
            constV = null;
        }

        multiplyGroups(filteredGroups, that, retCg, k);

        if(constV != null) {
            // Same computation as in RMMOverlapping; an all-zero product means there is no constant to add.
            final MatrixBlock constRet = multiplyConstantRow(constV, that);
            if(constRet.isEmpty()) {
                constV = null;
            }
            else {
                constRet.sparseToDense();
                constV = constRet.getDenseBlockValues();
            }
        }

        final Timing time = new Timing(true);
        final MatrixBlock ret = asyncRet(allocation);
        // NOTE(review): DEFAULT_DELIM_FILL_VALUE is a decompiler-substituted constant for this numeric argument;
        // presumably it is 0.0 — confirm against CLALibDecompress.decompressDenseMultiThread.
        CLALibDecompress.decompressDenseMultiThread(ret, retCg, constV, DataExpression.DEFAULT_DELIM_FILL_VALUE, k);
        if(DMLScript.STATISTICS) {
            DMLCompressionStatistics.addDecompressTime(time.stop(), k);
        }
        return ret;
    }

    /** Computes the 1 x cols(that) row {@code constV %*% that}. */
    private static MatrixBlock multiplyConstantRow(double[] constV, MatrixBlock that) {
        final MatrixBlock constRow = new MatrixBlock(1, constV.length, constV);
        final MatrixBlock ret = new MatrixBlock(1, that.getNumColumns(), false);
        LibMatrixMult.matrixMult(constRow, that, ret);
        return ret;
    }

    /** Waits for an asynchronous result, rethrowing any failure as a {@link DMLRuntimeException}. */
    private static <T> T asyncRet(Future<T> future) {
        try {
            return future.get();
        }
        catch(InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for callers up the stack
            throw new DMLRuntimeException(e);
        }
        catch(Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Multiplies every group by {@code that}, appending the non-null results to {@code out}.
     *
     * @return true if at least one group produced a {@code null} result
     */
    private static boolean multiplyGroups(List<AColGroup> groups, MatrixBlock that, List<AColGroup> out, int k) {
        return k <= 1 ? RMMSingle(groups, that, out) : RMMParallel(groups, that, out, k);
    }

    /** Single-threaded variant of {@link #multiplyGroups}. */
    private static boolean RMMSingle(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg) {
        boolean containsNull = false;
        for(AColGroup g : filteredGroups) {
            final AColGroup r = g.rightMultByMatrix(that);
            if(r != null) {
                retCg.add(r);
            }
            else {
                containsNull = true;
            }
        }
        return containsNull;
    }

    /** Multi-threaded variant of {@link #multiplyGroups}; results are appended in group order. */
    private static boolean RMMParallel(List<AColGroup> filteredGroups, MatrixBlock that, List<AColGroup> retCg,
        int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<Callable<AColGroup>> tasks = new ArrayList<>(filteredGroups.size());
            for(AColGroup g : filteredGroups) {
                tasks.add(new RightMatrixMultTask(g, that));
            }
            boolean containsNull = false;
            for(Future<AColGroup> f : pool.invokeAll(tasks)) {
                final AColGroup r = f.get();
                if(r != null) {
                    retCg.add(r);
                }
                else {
                    containsNull = true;
                }
            }
            return containsNull;
        }
        catch(InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for callers up the stack
            throw new DMLRuntimeException(e);
        }
        catch(ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            // release the pool on failure as well as on success
            pool.shutdown();
        }
    }

    /**
     * If any group multiplied to {@code null}, appends an empty group covering every column in {@code [0, nCols)}
     * that no group in {@code retCg} covers, so each output column belongs to some group.
     */
    private static void addEmptyColumn(List<AColGroup> retCg, int nCols, boolean containsNull) {
        if(!containsNull) {
            return;
        }
        final ColGroupEmpty empty = findEmptyColumnsAndMakeEmptyColGroup(retCg, nCols);
        if(empty != null) {
            retCg.add(empty);
        }
    }

    /**
     * @return an empty group over the uncovered columns of {@code [0, nCols)} in ascending order, or {@code null} if
     *         every column is covered
     */
    private static ColGroupEmpty findEmptyColumnsAndMakeEmptyColGroup(List<AColGroup> colGroups, int nCols) {
        // boolean mask instead of a HashSet<Integer>: no boxing, and the emitted indices are guaranteed ascending
        final boolean[] covered = new boolean[nCols];
        int nCovered = 0;
        for(AColGroup g : colGroups) {
            for(int c : g.getColIndices()) {
                if(!covered[c]) {
                    covered[c] = true;
                    nCovered++;
                }
            }
        }
        if(nCovered == nCols) {
            return null;
        }
        final int[] emptyCols = new int[nCols - nCovered];
        int p = 0;
        for(int c = 0; c < nCols; c++) {
            if(!covered[c]) {
                emptyCols[p++] = c;
            }
        }
        return new ColGroupEmpty(emptyCols);
    }
}
