package org.apache.sysds.runtime.codegen;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.codegen.SpoofCUDAOperator;
import org.apache.sysds.runtime.codegen.SpoofCellwise;
import org.apache.sysds.runtime.codegen.SpoofOperator;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofMultiAggregate.class */
/**
 * Base class for generated multi-aggregate operators: a single pass over the main input (plus
 * optional side inputs) computes several full aggregates (sum, sum of squares, min, max) at once
 * and writes them into a dense 1 x (number of aggregates) output row.
 *
 * <p>Concrete operators implement the abstract {@code genexec} method, which is invoked once per
 * visited cell with the cell value and the aggregate buffer {@code c}. Because it returns nothing,
 * it presumably folds the cell into {@code c} in place (confirm against the code generator).
 */
public abstract class SpoofMultiAggregate extends SpoofOperator implements Serializable {
    private static final long serialVersionUID = -6164871955591089349L;

    /** Minimum number of input cells (non-zeros if sparse-safe) before multi-threading is used. */
    private static final long PAR_NUMCELL_THRESHOLD = 1024 * 1024;

    /** Aggregation operations, one per output column. */
    private final SpoofCellwise.AggOp[] _aggOps;
    /** If true, zero cells of a sparse or empty main input are not passed to {@code genexec}. */
    private final boolean _sparseSafe;

    /**
     * Task that computes partial aggregates over the row range [rl, ru) of the main input.
     * Each task starts from fresh initial values; the partial results are merged afterwards
     * by {@code aggregatePartialResults}.
     */
    public class ParAggTask implements Callable<double[]> {
        private final MatrixBlock _a;
        private final SpoofOperator.SideInput[] _b;
        private final double[] _scalars;
        private final int _rlen;
        private final int _clen;
        private final boolean _safe;
        private final int _rl;
        private final int _ru;
        /** Offset added to the local row index to form the global row index passed to genexec. */
        private final long _rix;

        /** Creates a task with a global row offset of 0 (kept for backward compatibility). */
        protected ParAggTask(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars,
                int rlen, int clen, boolean safe, int rl, int ru) {
            this(a, b, scalars, rlen, clen, safe, rl, ru, 0L);
        }

        /**
         * @param a       main input
         * @param b       prepared side inputs
         * @param scalars prepared scalar inputs
         * @param rlen    number of rows of the main input
         * @param clen    number of columns of the main input
         * @param safe    whether zero cells may be skipped
         * @param rl      first row (inclusive)
         * @param ru      last row (exclusive)
         * @param rix     global row offset, consistent with the single-threaded path
         */
        protected ParAggTask(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars,
                int rlen, int clen, boolean safe, int rl, int ru, long rix) {
            _a = a;
            _b = b;
            _scalars = scalars;
            _rlen = rlen;
            _clen = clen;
            _safe = safe;
            _rl = rl;
            _ru = ru;
            _rix = rix;
        }

        @Override
        public double[] call() {
            double[] partial = new double[SpoofMultiAggregate.this._aggOps.length];
            SpoofMultiAggregate.this.setInitialOutputValues(partial);
            if (_a.isInSparseFormat()) {
                SpoofMultiAggregate.this.executeSparse(_a.getSparseBlock(), _b, _scalars, partial,
                        _rlen, _clen, _safe, _rl, _ru, _rix);
            } else {
                SpoofMultiAggregate.this.executeDense(_a.getDenseBlock(), _b, _scalars, partial,
                        _rlen, _clen, _safe, _rl, _ru, _rix);
            }
            return partial;
        }
    }

    /**
     * @param sparseSafe whether zero cells of the main input may be skipped
     * @param aggOps     aggregation operations, one per output column
     */
    public SpoofMultiAggregate(boolean sparseSafe, SpoofCellwise.AggOp... aggOps) {
        _sparseSafe = sparseSafe;
        _aggOps = aggOps;
    }

    /** Returns the aggregation operations (the internal array, not a copy). */
    public SpoofCellwise.AggOp[] getAggOps() {
        return _aggOps;
    }

    /** Returns whether zero cells of the main input may be skipped. */
    public boolean isSparseSafe() {
        return _sparseSafe;
    }

    /**
     * Returns "MA" followed by the second dot-separated component of the runtime class name.
     * This assumes generated classes are named like {@code pkg.ClassName}; a class name without
     * a dot would throw ArrayIndexOutOfBoundsException.
     */
    @Override
    public String getSpoofType() {
        return "MA" + getClass().getName().split("\\.")[1];
    }

    /** Multi-aggregate operators have no CUDA implementation here; always returns null. */
    @Override
    public SpoofCUDAOperator createCUDAInstrcution(Integer num, SpoofCUDAOperator.PrecisionProxy precisionProxy) {
        return null;
    }

    /** Single-threaded execution with a global row offset of 0. */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out) {
        return execute(inputs, scalarObjects, out, 1, 0L);
    }

    /** Execution with up to {@code k} threads and a global row offset of 0. */
    @Override
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k) {
        return execute(inputs, scalarObjects, out, k, 0L);
    }

    /**
     * Computes all aggregates over the main input {@code inputs.get(0)} and writes them into
     * {@code out}, which is reset to a 1 x (number of aggregates) dense row.
     *
     * @param inputs        main input followed by side inputs; must be non-empty
     * @param scalarObjects scalar inputs
     * @param out           output block, reset and overwritten
     * @param k             requested degree of parallelism; forced to 1 for small inputs
     * @param rix           offset added to each row index to form the global row index passed to genexec
     * @return {@code out}
     * @throws RuntimeException    if {@code inputs} is null or empty
     * @throws DMLRuntimeException if multi-threaded execution fails
     */
    public MatrixBlock execute(ArrayList<MatrixBlock> inputs, ArrayList<ScalarObject> scalarObjects, MatrixBlock out, int k, long rix) {
        if (inputs == null || inputs.size() < 1) {
            throw new RuntimeException("Invalid input arguments.");
        }
        // Small inputs do not amortize the threading overhead.
        if ((isSparseSafe() ? getTotalInputNnz(inputs) : getTotalInputSize(inputs)) < PAR_NUMCELL_THRESHOLD) {
            k = 1;
        }

        // Dense 1 x #aggregates output, initialized with each aggregate's neutral element.
        out.reset(1, _aggOps.length, false);
        out.allocateDenseBlock();
        double[] c = out.getDenseBlockValues();
        setInitialOutputValues(c);

        SpoofOperator.SideInput[] b = prepInputMatrices(inputs);
        double[] scalars = prepInputScalars(scalarObjects);
        MatrixBlock a = inputs.get(0);
        int m = a.getNumRows();
        int n = a.getNumColumns();
        boolean sparseSafe = isSparseSafe();

        if (k > 1) {
            executeParallel(a, b, scalars, c, m, n, sparseSafe, k, rix);
        } else if (a.isInSparseFormat()) {
            executeSparse(a.getSparseBlock(), b, scalars, c, m, n, sparseSafe, 0, m, rix);
        } else {
            executeDense(a.getDenseBlock(), b, scalars, c, m, n, sparseSafe, 0, m, rix);
        }

        out.recomputeNonZeros();
        out.examSparsity();
        return out;
    }

    /**
     * Splits the rows of the main input into contiguous ranges and aggregates them in parallel,
     * then merges the partial results into {@code c}.
     */
    private void executeParallel(MatrixBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c,
            int m, int n, boolean sparseSafe, int k, long rix) {
        ExecutorService pool = null;
        try {
            pool = CommonThreadPool.get(k);
            int nk = UtilFunctions.roundToNext(Math.min(8 * k, m / 32), k);
            // Floating-point division is required: integer division would truncate the block
            // length, leaving the trailing rows unprocessed (or all rows if m < nk).
            int blklen = (int) Math.ceil((double) m / nk);
            ArrayList<ParAggTask> tasks = new ArrayList<>();
            for (int i = 0; i < nk && i * blklen < m; i++) {
                tasks.add(new ParAggTask(a, b, scalars, m, n, sparseSafe,
                        i * blklen, Math.min((i + 1) * blklen, m), rix));
            }
            ArrayList<double[]> partials = new ArrayList<>();
            for (Future<double[]> task : pool.invokeAll(tasks)) {
                partials.add(task.get());
            }
            aggregatePartialResults(c, partials);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt(); // preserve interrupt status for callers
            }
            throw new DMLRuntimeException(e);
        } finally {
            if (pool != null) {
                pool.shutdown();
            }
        }
    }

    /**
     * Visits rows [rl, ru) of a dense (or empty, when {@code a} is null) main input and calls
     * genexec once per cell. An empty input is visited as all zeros unless {@code sparseSafe}.
     */
    public void executeDense(DenseBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c,
            int m, int n, boolean sparseSafe, int rl, int ru, long rix) {
        SpoofOperator.SideInput[] lb = createSparseSideInputs(b);
        if (a == null) {
            if (!sparseSafe) {
                for (int i = rl; i < ru; i++) {
                    for (int j = 0; j < n; j++) {
                        genexec(0, lb, scalars, c, m, n, rix + i, i, j);
                    }
                }
            }
            return;
        }
        for (int i = rl; i < ru; i++) {
            double[] avals = a.values(i);
            int apos = a.pos(i);
            for (int j = 0; j < n; j++) {
                genexec(avals[apos + j], lb, scalars, c, m, n, rix + i, i, j);
            }
        }
    }

    /**
     * Visits rows [rl, ru) of a sparse (or empty, when {@code a} is null) main input. Non-zeros
     * are always passed to genexec in column order. Unless {@code sparseSafe}, the zero cells
     * between and after them are passed as well.
     */
    public void executeSparse(SparseBlock a, SpoofOperator.SideInput[] b, double[] scalars, double[] c,
            int m, int n, boolean sparseSafe, int rl, int ru, long rix) {
        if (a == null && sparseSafe) {
            return;
        }
        SpoofOperator.SideInput[] lb = createSparseSideInputs(b);
        for (int i = rl; i < ru; i++) {
            int lastCol = -1; // column of the last visited non-zero in this row
            if (a != null && !a.isEmpty(i)) {
                int apos = a.pos(i);
                int alen = a.size(i);
                int[] aix = a.indexes(i);
                double[] avals = a.values(i);
                for (int p = apos; p < apos + alen; p++) {
                    if (!sparseSafe) {
                        // emit the zeros preceding this non-zero
                        for (int j = lastCol + 1; j < aix[p]; j++) {
                            genexec(0, lb, scalars, c, m, n, rix + i, i, j);
                        }
                    }
                    lastCol = aix[p];
                    genexec(avals[p], lb, scalars, c, m, n, rix + i, i, lastCol);
                }
            }
            if (!sparseSafe) {
                // emit trailing zeros (or the whole row if it had no non-zeros)
                for (int j = lastCol + 1; j < n; j++) {
                    genexec(0, lb, scalars, c, m, n, rix + i, i, j);
                }
            }
        }
    }

    /** Compatibility overload: uses {@code rix} as both the global and the local row index. */
    protected final void genexec(double a, SpoofOperator.SideInput[] b, double[] scalars, double[] c,
            int m, int n, int rix, int cix) {
        genexec(a, b, scalars, c, m, n, rix, rix, cix);
    }

    /**
     * Generated per-cell kernel.
     *
     * @param a       cell value of the main input
     * @param b       side inputs
     * @param scalars scalar inputs
     * @param c       aggregate buffer, one entry per aggregation op
     * @param m       number of rows of the main input
     * @param n       number of columns of the main input
     * @param gix     global row index (row offset plus local row)
     * @param rix     local row index
     * @param cix     column index
     */
    protected abstract void genexec(double a, SpoofOperator.SideInput[] b, double[] scalars, double[] c,
            int m, int n, long gix, int rix, int cix);

    /** Fills {@code c} with the initial value of each aggregation op. */
    public void setInitialOutputValues(double[] c) {
        for (int i = 0; i < _aggOps.length; i++) {
            c[i] = getInitialValue(_aggOps[i]);
        }
    }

    /**
     * Returns the starting value for an aggregation: 0 for sums, +infinity for min,
     * -infinity for max, and 0 for any other op.
     */
    public static double getInitialValue(SpoofCellwise.AggOp aggOp) {
        switch (aggOp) {
            case MIN:
                return Double.POSITIVE_INFINITY;
            case MAX:
                return Double.NEGATIVE_INFINITY;
            case SUM:
            case SUM_SQ:
            default:
                return 0;
        }
    }

    /**
     * Merges per-task partial aggregates into {@code c}. Sum-type partials are already (squared)
     * sums, so they are combined with plain Kahan addition rather than their original function.
     */
    private void aggregatePartialResults(double[] c, ArrayList<double[]> partials) {
        ValueFunction[] fns = getAggFunctions(_aggOps);
        for (int i = 0; i < _aggOps.length; i++) {
            if (fns[i] instanceof KahanFunction) {
                KahanObject kbuff = new KahanObject(0.0, 0.0);
                KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                for (double[] partial : partials) {
                    kplus.execute2(kbuff, partial[i]);
                }
                c[i] = kbuff._sum;
            } else {
                for (double[] partial : partials) {
                    c[i] = fns[i].execute(c[i], partial[i]);
                }
            }
        }
    }

    /**
     * Merges the aggregates of {@code in} into row 0 of {@code out}, column by column. Sum-type
     * columns use Kahan addition and the other columns use their min/max function.
     */
    public static void aggregatePartialResults(SpoofCellwise.AggOp[] aggOps, MatrixBlock out, MatrixBlock in) {
        ValueFunction[] fns = getAggFunctions(aggOps);
        for (int i = 0; i < aggOps.length; i++) {
            if (fns[i] instanceof KahanFunction) {
                KahanObject kbuff = new KahanObject(out.quickGetValue(0, i), 0.0);
                KahanPlus.getKahanPlusFnObject().execute2(kbuff, in.quickGetValue(0, i));
                out.quickSetValue(0, i, kbuff._sum);
            } else {
                out.quickSetValue(0, i, fns[i].execute(out.quickGetValue(0, i), in.quickGetValue(0, i)));
            }
        }
    }

    /**
     * Maps each aggregation op to its value function.
     *
     * @throws RuntimeException for unsupported aggregation ops
     */
    public static ValueFunction[] getAggFunctions(SpoofCellwise.AggOp[] aggOps) {
        ValueFunction[] fns = new ValueFunction[aggOps.length];
        for (int i = 0; i < aggOps.length; i++) {
            switch (aggOps[i]) {
                case SUM:
                    fns[i] = KahanPlus.getKahanPlusFnObject();
                    break;
                case SUM_SQ:
                    fns[i] = KahanPlusSq.getKahanPlusSqFnObject();
                    break;
                case MIN:
                    fns[i] = Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN);
                    break;
                case MAX:
                    fns[i] = Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX);
                    break;
                default:
                    throw new RuntimeException("Unsupported aggregation type: " + aggOps[i].name());
            }
        }
        return fns;
    }
}
