package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupOLE;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibScalar.class */
/**
 * Cell-wise matrix-scalar operations on {@link CompressedMatrixBlock}s.
 *
 * <p>Three strategies are used:
 * <ul>
 * <li>Non-overlapping blocks, and multiply / divide on overlapping blocks, apply the operator to every column group
 * independently (optionally in parallel).</li>
 * <li>Plus / minus on overlapping blocks keep the existing groups (negated for {@code scalar - X}) and fold the scalar,
 * together with any existing constant groups, into one new constant column group.</li>
 * <li>Any other operator on an overlapping block falls back to decompression.</li>
 * </ul>
 */
public final class CLALibScalar {
    private static final Log LOG = LogFactory.getLog(CLALibScalar.class.getName());

    /** Groups with {@code getNumValues() * getNumCols()} at or above this threshold get a dedicated parallel task. */
    private static final int MINIMUM_PARALLEL_SIZE = 8096;

    /** A batch of small groups is emitted as its own task once it holds more than this many groups. */
    private static final int MAX_SMALL_GROUPS_PER_TASK = 10;

    /** Task that applies a scalar operator to a batch of column groups, preserving their order. */
    public static class ScalarTask implements Callable<List<AColGroup>> {
        private final List<AColGroup> _colGroups;
        private final ScalarOperator _sop;

        protected ScalarTask(List<AColGroup> colGroups, ScalarOperator sop) {
            _colGroups = colGroups;
            _sop = sop;
        }

        @Override
        public List<AColGroup> call() {
            return applyToGroups(_colGroups, _sop);
        }
    }

    private CLALibScalar() {
        // static utility class, no instances
    }

    /**
     * Applies a scalar operator to every cell of a compressed matrix.
     *
     * @param sop    the scalar operator (left or right variant) to apply
     * @param m1     the compressed input matrix
     * @param result optional output container; reused if it is a {@link CompressedMatrixBlock}
     * @return a compressed result, or an uncompressed one if the operator is unsupported on an overlapping input
     */
    public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixBlock m1, MatrixValue result) {
        if (isInvalidForCompressedOutput(m1, sop)) {
            LOG.warn("scalar overlapping not supported for op: " + sop.fn.getClass().getSimpleName());
            return m1.decompress(sop.getNumThreads()).scalarOperations(sop, result);
        }

        final CompressedMatrixBlock ret = setupRet(m1, result);
        final List<AColGroup> colGroups = m1.getColGroups();

        if (!m1.isOverlapping() || sop.fn instanceof Multiply || sop.fn instanceof Divide) {
            // The operator can be applied to each group on its own.
            final int k = sop.getNumThreads() > 1 ? sop.getNumThreads() : OptimizerUtils.getConstrainedNumThreads(-1);
            if (k > 1)
                parallelScalarOperations(sop, colGroups, ret, k);
            else
                ret.allocateColGroupList(applyToGroups(colGroups, sop));
            ret.setOverlapping(m1.isOverlapping());
        }
        else {
            // Overlapping plus/minus: the operator evaluated at zero is the constant each cell is shifted by.
            final double shift = sop.executeScalar(0.0);
            final ColGroupConst constOverlap = shift != 0.0 ? constOverlap(m1, shift) : null;
            final boolean scalarMinusMatrix = sop instanceof LeftScalarOperator && sop.fn instanceof Minus;
            ret.allocateColGroupList(
                scalarMinusMatrix ? copyGroupsAndMultMinus(m1, constOverlap) : copyGroups(m1, constOverlap));
            ret.setOverlapping(true);
        }
        ret.recomputeNonZeros();
        return ret;
    }

    /** Applies {@code sop} to every group in {@code groups}, returning the new groups in the same order. */
    private static List<AColGroup> applyToGroups(List<AColGroup> groups, ScalarOperator sop) {
        final List<AColGroup> out = new ArrayList<>(groups.size());
        for (AColGroup g : groups)
            out.add(g.scalarOperation(sop));
        return out;
    }

    /** Reuses {@code result} if it is a compressed block (resized to m1's shape), otherwise allocates a new one. */
    private static CompressedMatrixBlock setupRet(CompressedMatrixBlock m1, MatrixValue result) {
        if (result instanceof CompressedMatrixBlock) {
            final CompressedMatrixBlock ret = (CompressedMatrixBlock) result;
            ret.setNumColumns(m1.getNumColumns());
            ret.setNumRows(m1.getNumRows());
            return ret;
        }
        return new CompressedMatrixBlock(m1.getNumRows(), m1.getNumColumns());
    }

    /**
     * Creates a constant group spanning all columns of {@code m1} with value {@code v}.
     * NOTE(review): only called with {@code v != 0}; the cast assumes ColGroupConst.create returns a ColGroupConst
     * for non-zero values — confirm.
     */
    private static ColGroupConst constOverlap(CompressedMatrixBlock m1, double v) {
        return (ColGroupConst) ColGroupConst.create(m1.getNumColumns(), v);
    }

    /**
     * Adds (or subtracts) the per-column values of constant group {@code g} into {@code target}, which is indexed by
     * absolute column.
     */
    private static void foldConst(double[] target, ColGroupConst g, boolean subtract) {
        final double[] vals = g.getValues();
        final IColIndex cols = g.getColIndices();
        for (int i = 0; i < cols.size(); i++) {
            final int c = cols.get(i);
            target[c] = subtract ? target[c] - vals[i] : target[c] + vals[i];
        }
    }

    /**
     * Builds the groups for {@code X + c} / {@code X - c} on an overlapping block. Empty groups are dropped. When a
     * shift constant exists, existing constant groups are folded into it; otherwise they are kept unchanged.
     */
    private static List<AColGroup> copyGroups(CompressedMatrixBlock m1, ColGroupConst constOverlap) {
        final double[] constV = constOverlap != null ? constOverlap.getValues() : null;
        final List<AColGroup> colGroups = m1.getColGroups();
        final List<AColGroup> out = new ArrayList<>(colGroups.size() + 1);
        for (AColGroup g : colGroups) {
            if (g instanceof ColGroupEmpty)
                continue;
            if (g instanceof ColGroupConst && constV != null)
                foldConst(constV, (ColGroupConst) g, false);
            else
                out.add(g); // includes constant groups when there is nothing to fold them into (shift == 0)
        }
        if (constOverlap != null)
            out.add(constOverlap);
        return out;
    }

    /**
     * Builds the groups for {@code c - X} on an overlapping block. Every non-empty group is negated. When a shift
     * constant exists, existing constant groups are subtracted into it instead.
     */
    private static List<AColGroup> copyGroupsAndMultMinus(CompressedMatrixBlock m1, ColGroupConst constOverlap) {
        // constOverlap is null when c == 0, i.e. 0 - X: then simply negate every group, constants included.
        final double[] constV = constOverlap != null ? constOverlap.getValues() : null;
        final ScalarOperator negate = new RightScalarOperator(Multiply.getMultiplyFnObject(), -1.0d);
        final List<AColGroup> colGroups = m1.getColGroups();
        final List<AColGroup> out = new ArrayList<>(colGroups.size() + 1);
        for (AColGroup g : colGroups) {
            if (g instanceof ColGroupEmpty)
                continue;
            if (g instanceof ColGroupConst && constV != null)
                foldConst(constV, (ColGroupConst) g, true);
            else
                out.add(g.scalarOperation(negate));
        }
        if (constOverlap != null)
            out.add(constOverlap);
        return out;
    }

    /**
     * Returns true if the block is overlapping and the operator is none of: multiply, right divide, plus, minus.
     * Such operators cannot be applied in compressed form here.
     */
    private static boolean isInvalidForCompressedOutput(CompressedMatrixBlock m1, ScalarOperator sop) {
        if (!m1.isOverlapping())
            return false;
        final boolean supported = sop.fn instanceof Multiply
            || (sop.fn instanceof Divide && sop instanceof RightScalarOperator)
            || sop.fn instanceof Plus
            || sop.fn instanceof Minus;
        return !supported;
    }

    /**
     * Applies {@code sop} to {@code colGroups} on a pool of {@code k} threads and installs the resulting groups, in
     * task order, into {@code ret}.
     *
     * @throws DMLRuntimeException if a task fails or the calling thread is interrupted
     */
    private static void parallelScalarOperations(ScalarOperator sop, List<AColGroup> colGroups,
        CompressedMatrixBlock ret, int k) {
        if (colGroups == null)
            return;
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<AColGroup> out = new ArrayList<>(colGroups.size());
            for (Future<List<AColGroup>> f : pool.invokeAll(partition(sop, colGroups)))
                out.addAll(f.get());
            ret.allocateColGroupList(out);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve interrupt status for callers
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown(); // always release the pool, also on unchecked failures
        }
    }

    /**
     * Splits the groups into tasks. Uncompressed, OLE and large groups each get their own task. Remaining small
     * groups are batched, and a batch is emitted once it exceeds {@link #MAX_SMALL_GROUPS_PER_TASK} groups.
     */
    private static List<ScalarTask> partition(ScalarOperator sop, List<AColGroup> colGroups) {
        final List<ScalarTask> tasks = new ArrayList<>();
        List<AColGroup> small = new ArrayList<>();
        for (AColGroup g : colGroups) {
            // long arithmetic avoids int overflow for very large groups
            if (g instanceof ColGroupUncompressed
                || (long) g.getNumValues() * g.getNumCols() >= MINIMUM_PARALLEL_SIZE
                || g instanceof ColGroupOLE) {
                final List<AColGroup> single = new ArrayList<>(1);
                single.add(g);
                tasks.add(new ScalarTask(single, sop));
            }
            else
                small.add(g);

            if (small.size() > MAX_SMALL_GROUPS_PER_TASK) {
                tasks.add(new ScalarTask(small, sop));
                small = new ArrayList<>();
            }
        }
        if (!small.isEmpty())
            tasks.add(new ScalarTask(small, sop));
        return tasks;
    }
}
