package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.functionobjects.Equals;
import org.apache.sysds.runtime.functionobjects.GreaterThan;
import org.apache.sysds.runtime.functionobjects.GreaterThanEquals;
import org.apache.sysds.runtime.functionobjects.LessThan;
import org.apache.sysds.runtime.functionobjects.LessThanEquals;
import org.apache.sysds.runtime.functionobjects.NotEquals;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.RightScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibRelationalOp.class */
public class CLALibRelationalOp {

    /** Number of rows each parallel {@link RelationalTask} decompresses and evaluates at once. */
    private static final int PARALLEL_BLOCK_ROWS = 32767;

    /**
     * Per-thread scratch block reused by {@link RelationalTask}, so that a thread running several
     * tasks decompresses into the same dense buffer instead of allocating a new one for every task.
     * {@code get()} returns {@code null} until a task installs a block on the current thread.
     *
     * <p>NOTE(review): {@code processNonConstant} only removes the calling thread's entry; blocks
     * installed on pool worker threads stay referenced until those threads reuse or drop them.
     */
    private static final ThreadLocal<MatrixBlock> memPool = new ThreadLocal<>();

    /**
     * A column group paired with the minimum and maximum value it reports via
     * {@code getMin()}/{@code getMax()}. Instances are ordered by value range ({@code max - min});
     * note that this ordering is not consistent with {@code equals}.
     */
    public static class MinMaxGroup implements Comparable<MinMaxGroup> {
        double min;
        double max;
        AColGroup g;
        double[] values;

        /**
         * @param min smallest value reported for the group
         * @param max largest value reported for the group
         * @param g   the column group itself; its dictionary values are cached in {@code values}
         */
        public MinMaxGroup(double min, double max, AColGroup g) {
            this.min = min;
            this.max = max;
            this.g = g;
            this.values = g.getValues();
        }

        @Override
        public int compareTo(MinMaxGroup other) {
            return Double.compare(this.max - this.min, other.max - other.min);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("MMG: ");
            sb.append("[").append(this.min).append(",").append(this.max).append("]");
            sb.append(" ").append(this.g.getClass().getSimpleName());
            return sb.toString();
        }
    }

    /**
     * Evaluates a scalar relational operator over one chunk of {@code blkz} rows. The chunk is first
     * decompressed (summing all overlapping groups) into a per-thread scratch block, and the results
     * are then appended to the shared result block. Each task writes a disjoint row range of the
     * result.
     */
    public static class RelationalTask implements Callable<Object> {
        private final MinMaxGroup[] _minMax;
        private final int _i;
        private final int _blkz;
        private final MatrixBlock _res;
        private final int _rows;
        private final int _cols;
        private final ScalarOperator _sop;

        /**
         * @param minMax column groups to decompress
         * @param i      chunk index; the chunk covers rows {@code [i*blkz, min((i+1)*blkz, rows))}
         * @param blkz   rows per chunk
         * @param res    shared output block
         * @param rows   total number of rows
         * @param cols   total number of columns
         * @param sop    operator applied to every decompressed cell
         */
        protected RelationalTask(MinMaxGroup[] minMax, int i, int blkz, MatrixBlock res, int rows, int cols,
            ScalarOperator sop) {
            this._minMax = minMax;
            this._i = i;
            this._blkz = blkz;
            this._res = res;
            this._rows = rows;
            this._cols = cols;
            this._sop = sop;
        }

        @Override
        public Object call() {
            MatrixBlock tmp = memPool.get();
            if (tmp == null) {
                tmp = new MatrixBlock(this._blkz, this._cols, false, -1L).allocateBlock();
                memPool.set(tmp);
            } else {
                // decompressToBlockUnSafe writes into the existing buffer contents, so a reused
                // buffer must be cleared before decompressing the next chunk.
                tmp.reset(this._blkz, this._cols, false, -1L);
            }
            final int rl = this._i * this._blkz;
            final int ru = Math.min(rl + this._blkz, this._rows);
            for (MinMaxGroup mmg : this._minMax) {
                if (mmg.g.getNumberNonZeros() != 0) {
                    mmg.g.decompressToBlockUnSafe(tmp, rl, Math.min(rl + this._blkz, mmg.g.getNumRows()), 0);
                }
            }
            for (int r = rl; r < ru; r++) {
                final int tr = r - rl;
                for (int c = 0; c < this._cols; c++) {
                    this._res.appendValue(r, c, this._sop.executeScalar(tmp.quickGetValue(tr, c)));
                }
            }
            return null;
        }
    }

    /**
     * Reports whether the overlapping relational shortcut applies.
     *
     * @return true if the block is overlapping and the operator function is one of
     *         {@code <, <=, >, >=, ==, !=}
     */
    public static boolean isValidForRelationalOperation(ScalarOperator sop, CompressedMatrixBlock m1) {
        return m1.isOverlapping()
            && (sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals || sop.fn instanceof GreaterThan
                || sop.fn instanceof GreaterThanEquals || sop.fn instanceof Equals || sop.fn instanceof NotEquals);
    }

    /**
     * Applies a relational scalar operator to an overlapping compressed block.
     *
     * <p>The sum of all group minima ({@code minS}) and maxima ({@code maxS}) is used as a bound on the
     * cell values. If the constant lies outside {@code [minS, maxS]}, a constant all-zero or all-one
     * block is returned without decompressing; otherwise every cell is evaluated.
     *
     * <p>NOTE(review): the bound assumes every group contributes to every column. A column not
     * covered by a group receives 0 from it rather than that group's minimum. Confirm the groups of
     * overlapping blocks always span all columns; if not, the shortcut can be wrong.
     *
     * @param sop scalar relational operator
     * @param m1  overlapping compressed input
     * @return a 0/1 matrix with the comparison result per cell
     */
    public static MatrixBlock overlappingRelativeRelationalOperation(ScalarOperator sop, CompressedMatrixBlock m1) {
        final List<AColGroup> colGroups = m1.getColGroups();
        final boolean lessFn = sop.fn instanceof LessThan || sop.fn instanceof LessThanEquals;
        final boolean greaterFn = sop.fn instanceof GreaterThan || sop.fn instanceof GreaterThanEquals;
        // Presumably Left* evaluates fn(constant, cell) and Right* fn(cell, constant); under that
        // reading `less` marks comparisons of the form "constant < cell" / "constant <= cell".
        final boolean less = (lessFn && sop instanceof LeftScalarOperator)
            || (greaterFn && sop instanceof RightScalarOperator);
        final double v = sop.getConstant();

        final MinMaxGroup[] minMax = new MinMaxGroup[colGroups.size()];
        double minS = 0.0;
        double maxS = 0.0;
        int id = 0;
        for (AColGroup grp : colGroups) {
            final double min = grp.getMin();
            final double max = grp.getMax();
            minS += min;
            maxS += max;
            minMax[id++] = new MinMaxGroup(min, max, grp);
        }

        final int rows = m1.getNumRows();
        final int cols = m1.getNumColumns();
        if (v < minS || v > maxS) {
            // The constant is outside every possible cell value, so the result is uniform.
            if (sop.fn instanceof Equals)
                return makeConstZero(rows, cols);
            if (sop.fn instanceof NotEquals)
                return makeConstOne(rows, cols);
            if (less)
                // constant below every cell -> "constant < cell" holds everywhere
                return v < minS ? makeConstOne(rows, cols) : makeConstZero(rows, cols);
            // constant above every cell -> "cell < constant" holds everywhere
            return v > maxS ? makeConstOne(rows, cols) : makeConstZero(rows, cols);
        }
        return processNonConstant(sop, minMax, minS, maxS, rows, cols, less);
    }

    /** Builds a compressed rows x cols block whose every cell is 1, backed by one constant group. */
    private static MatrixBlock makeConstOne(int rows, int cols) {
        final int[] colIndexes = new int[cols];
        for (int c = 0; c < cols; c++) {
            colIndexes[c] = c;
        }
        final double[] ones = new double[cols];
        Arrays.fill(ones, 1.0d);

        final ArrayList<AColGroup> groups = new ArrayList<>(1);
        groups.add(new ColGroupConst(colIndexes, rows, new Dictionary(ones)));

        final CompressedMatrixBlock ret = new CompressedMatrixBlock(rows, cols);
        ret.allocateColGroupList(groups);
        // widen before multiplying: rows * cols can exceed Integer.MAX_VALUE
        ret.setNonZeros((long) rows * cols);
        ret.setOverlapping(false);
        return ret;
    }

    /** Builds an empty (all-zero) rows x cols block. */
    private static MatrixBlock makeConstZero(int rows, int cols) {
        return new MatrixBlock(rows, cols, true, 0L);
    }

    /**
     * Decompresses the block chunk-wise and applies {@code sop} to every cell.
     *
     * @param minS unused; kept for signature compatibility
     * @param maxS unused; kept for signature compatibility
     * @param less unused; kept for signature compatibility
     * @return result block with its non-zero count set
     */
    private static MatrixBlock processNonConstant(ScalarOperator sop, MinMaxGroup[] minMax, double minS,
        double maxS, int rows, int cols, boolean less) {
        final MatrixBlock res = new MatrixBlock(rows, cols, true, 0L).allocateBlock();
        final int k = OptimizerUtils.getConstrainedNumThreads(-1);
        if (k == 1)
            processSingleThreaded(sop, minMax, res, rows, cols);
        else
            processParallel(sop, minMax, res, rows, cols, k);
        memPool.remove();
        return res;
    }

    /** Sequential evaluation using one scratch block of about BITMAP_BLOCK_SZ cells. */
    private static void processSingleThreaded(ScalarOperator sop, MinMaxGroup[] minMax, MatrixBlock res,
        int rows, int cols) {
        if (rows == 0 || cols == 0) {
            res.setNonZeros(0);
            return;
        }
        // At least one row per chunk, otherwise the chunk loop would never advance for very wide inputs.
        final int blkz = Math.max(1, Math.min(rows, CompressionSettings.BITMAP_BLOCK_SZ / cols));
        final MatrixBlock tmp = new MatrixBlock(blkz, cols, false, -1L).allocateBlock();
        long nnz = 0;
        for (int rl = 0; rl < rows; rl += blkz) {
            final int ru = Math.min(rl + blkz, rows);
            if (rl > 0) {
                // Clear the previous chunk: decompression writes into the existing buffer contents.
                tmp.reset(blkz, cols, false, -1L);
            }
            for (MinMaxGroup mmg : minMax) {
                mmg.g.decompressToBlockUnSafe(tmp, rl, ru, 0);
            }
            for (int r = rl; r < ru; r++) {
                final int tr = r - rl;
                for (int c = 0; c < cols; c++) {
                    final double val = sop.executeScalar(tmp.quickGetValue(tr, c));
                    res.quickSetValue(r, c, val);
                    if (val != 0) {
                        nnz++;
                    }
                }
            }
        }
        tmp.reset();
        res.setNonZeros(nnz);
    }

    /** Parallel evaluation with one {@link RelationalTask} per PARALLEL_BLOCK_ROWS rows. */
    private static void processParallel(ScalarOperator sop, MinMaxGroup[] minMax, MatrixBlock res, int rows,
        int cols, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<RelationalTask> tasks = new ArrayList<>();
            for (int i = 0; (long) i * PARALLEL_BLOCK_ROWS < rows; i++) {
                tasks.add(new RelationalTask(minMax, i, PARALLEL_BLOCK_ROWS, res, rows, cols, sop));
            }
            for (Future<Object> f : pool.invokeAll(tasks)) {
                f.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            pool.shutdown();
        }
        // Tasks update the shared non-zero counter without synchronization (via appendValue);
        // recount once every task has finished.
        res.recomputeNonZeros();
    }
}
