package org.apache.sysds.runtime.data;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/data/LibTensorAgg.class */
public class LibTensorAgg {

    /**
     * Common base of the tasks that the multi-threaded path of
     * {@link #aggregateUnaryTensor} submits to the thread pool.
     */
    private static abstract class AggTask implements Callable<Object> {
        private AggTask() {
        }
    }

    /**
     * Aggregation kinds understood by this library. {@code INVALID} marks every operator that
     * {@link #getAggType} does not map to a supported kind.
     */
    public enum AggType {
        SUM,
        INVALID
    }

    /**
     * Aggregates the row range {@code [rl, ru)} of an input tensor into a private, freshly
     * allocated partial result, so that concurrently running tasks never write to a shared block.
     */
    public static class PartialAggTask extends AggTask {
        private final BasicTensorBlock _in;
        // Initially the caller's output block, used only as a template (value type and first two
        // dims); call() replaces it with a task-private block.
        private BasicTensorBlock _ret;
        private final AggType _aggtype;
        private final AggregateUnaryOperator _uaop;
        private final int _rl;
        private final int _ru;

        /**
         * @param in      input tensor (only read)
         * @param ret     template whose value type and first two dimensions shape the partial result
         * @param aggType aggregation kind to apply
         * @param uaop    unary aggregate operator supplying the increment function
         * @param rl      first row to aggregate (inclusive)
         * @param ru      last row to aggregate (exclusive)
         */
        protected PartialAggTask(BasicTensorBlock in, BasicTensorBlock ret, AggType aggType,
            AggregateUnaryOperator uaop, int rl, int ru) {
            _in = in;
            _ret = ret;
            _aggtype = aggType;
            _uaop = uaop;
            _rl = rl;
            _ru = ru;
        }

        @Override
        public Object call() {
            _ret = new BasicTensorBlock(_ret._vt, new int[]{_ret.getDim(0), _ret.getDim(1)});
            _ret.allocateDenseBlock();
            LibTensorAgg.aggregateUnaryTensorPartial(_in, _ret, _aggtype, _uaop.aggOp.increOp.fn, _rl, _ru);
            return null;
        }

        /**
         * Returns this task's partial result. Only meaningful after {@link #call()} has completed;
         * before that it is the template block passed to the constructor.
         */
        public BasicTensorBlock getResult() {
            return _ret;
        }
    }

    /**
     * Decides whether {@link #aggregateUnaryTensor} may split the work across threads.
     *
     * @param basicTensorBlock input tensor
     * @param i                requested number of threads
     * @return {@code true} if more than one thread is requested and the input is not BOOLEAN
     *         (the BOOLEAN sum kernel counts non-zeros of the whole block and ignores row ranges)
     */
    public static boolean satisfiesMultiThreadingConstraints(BasicTensorBlock basicTensorBlock, int i) {
        return i > 1 && basicTensorBlock._vt != Types.ValueType.BOOLEAN;
    }

    /**
     * Applies a unary aggregate to {@code in} and writes the result into {@code out}.
     *
     * @param in   input tensor
     * @param out  output tensor receiving the aggregate
     * @param uaop unary aggregate operator; only SUM (increment function {@link Plus}) is supported
     * @throws NotImplementedException if a non-empty input is sparse
     * @throws DMLRuntimeException     if the aggregation fails, e.g. for unsupported operators or
     *                                 value types, or if a parallel task fails or is interrupted
     */
    public static void aggregateUnaryTensor(BasicTensorBlock in, BasicTensorBlock out, AggregateUnaryOperator uaop) {
        AggType aggType = getAggType(uaop);
        if (in.isEmpty(false)) {
            aggregateUnaryTensorEmpty(in, out, aggType);
            return;
        }
        // Rejected up front so that the single- and multi-threaded paths fail the same way.
        if (in.isSparse()) {
            throw new NotImplementedException("Tensor aggregation not supported for sparse tensors.");
        }
        int numThreads = uaop.getNumThreads();
        if (!satisfiesMultiThreadingConstraints(in, numThreads)) {
            aggregateUnaryTensorPartial(in, out, aggType, uaop.aggOp.increOp.fn, 0, in.getDim(0));
            return;
        }
        aggregateUnaryTensorParallel(in, out, aggType, uaop, numThreads);
    }

    /** Splits the rows of {@code in} into balanced ranges, aggregates them in parallel and merges the partials. */
    private static void aggregateUnaryTensorParallel(BasicTensorBlock in, BasicTensorBlock out, AggType aggType,
        AggregateUnaryOperator uaop, int numThreads) {
        ExecutorService pool = null;
        try {
            pool = CommonThreadPool.get(numThreads);
            List<PartialAggTask> tasks = new ArrayList<>();
            List<Integer> blockSizes = UtilFunctions.getBalancedBlockSizesDefault(in.getDim(0), numThreads, false);
            int lb = 0;
            for (int size : blockSizes) {
                tasks.add(new PartialAggTask(in, out, aggType, uaop, lb, lb + size));
                lb += size;
            }
            // invokeAll alone captures task exceptions inside the futures; Future.get() rethrows
            // them and also makes each task's writes visible before its result is read below.
            for (Future<Object> future : pool.invokeAll(tasks)) {
                future.get();
            }
            out.copy(tasks.get(0).getResult());
            for (int i = 1; i < tasks.size(); i++) {
                aggregateFinalResult(uaop.aggOp, out, tasks.get(i).getResult());
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        } finally {
            if (pool != null) {
                pool.shutdown();
            }
        }
    }

    /** Writes the aggregate of an empty input into cell (0, 0): 0 for SUM, NaN otherwise. */
    private static void aggregateUnaryTensorEmpty(BasicTensorBlock basicTensorBlock, BasicTensorBlock basicTensorBlock2, AggType aggType) {
        basicTensorBlock2.set(new int[]{0, 0}, Double.valueOf(aggType == AggType.SUM ? 0.0d : Double.NaN));
    }

    /**
     * Adds {@code in1} into {@code in2} in place.
     *
     * <p>Only the cell at the all-zero index is aggregated. NOTE(review): this presumably suffices
     * because the supported unary aggregates produce single-cell results; any other cells are
     * ignored despite the length check below. Confirm before using this on multi-cell blocks.
     *
     * @param in1 tensor to add (only read)
     * @param in2 tensor that receives the sum
     * @param aop aggregate operator; must be {@link Plus} without correction
     * @throws DMLRuntimeException on mismatched lengths, corrections, or non-Plus operators
     */
    public static void aggregateBinaryTensor(BasicTensorBlock in1, BasicTensorBlock in2, AggregateOperator aop) {
        if (in1.getLength() != in2.getLength()) {
            throw new DMLRuntimeException("Binary tensor aggregation requires consistent numbers of cells (" + Arrays.toString(in1._dims) + ", " + Arrays.toString(in2._dims) + ").");
        }
        if (aop.existsCorrection()) {
            throw new DMLRuntimeException("Corrections not supported for tensors yet");
        }
        if (!(aop.increOp.fn instanceof Plus)) {
            throw new DMLRuntimeException("Binary aggregation of this type not supported for tensors yet");
        }
        int[] ix = new int[in1.getNumDims()];
        switch (in1.getValueType()) {
            case INT64:
                in2.set(ix, Long.valueOf(((Long) in1.get(ix)).longValue() + ((Long) in2.get(ix)).longValue()));
                return;
            case INT32:
                in2.set(ix, Integer.valueOf(((Integer) in1.get(ix)).intValue() + ((Integer) in2.get(ix)).intValue()));
                return;
            default:
                in2.set(0, 0, in1.get(0, 0) + in2.get(0, 0));
                return;
        }
    }

    /** Maps an operator to its aggregation kind: SUM for a {@link Plus} increment function, INVALID otherwise. */
    private static AggType getAggType(AggregateUnaryOperator aggregateUnaryOperator) {
        return aggregateUnaryOperator.aggOp.increOp.fn instanceof Plus ? AggType.SUM : AggType.INVALID;
    }

    /** Returns whether {@link #aggregateUnaryTensor} supports the given operator. */
    public static boolean isSupportedUnaryAggregateOperator(AggregateUnaryOperator aggregateUnaryOperator) {
        return getAggType(aggregateUnaryOperator) != AggType.INVALID;
    }

    /**
     * Aggregates rows {@code [rl, ru)} of {@code in} into {@code out}.
     *
     * @throws DMLRuntimeException for aggregation kinds other than SUM (previously a silent no-op)
     */
    private static void aggregateUnaryTensorPartial(BasicTensorBlock in, BasicTensorBlock out, AggType aggType, ValueFunction fn, int rl, int ru) {
        if (aggType != AggType.SUM) {
            throw new DMLRuntimeException("Unsupported unary aggregation type for tensors: " + aggType);
        }
        sum(in, out, (Plus) fn, rl, ru);
    }

    /** Merges a partial result into the running result {@code out}; corrections are not supported. */
    private static void aggregateFinalResult(AggregateOperator aop, BasicTensorBlock out, BasicTensorBlock partial) {
        if (aop.existsCorrection()) {
            throw new NotImplementedException();
        }
        out.incrementalAggregate(aop, partial);
    }

    /**
     * Sums rows {@code [rl, ru)} of a dense tensor into {@code out}.
     *
     * <p>BOOLEAN inputs ignore the row range and count the non-zeros of the whole block. That is
     * only correct because {@link #satisfiesMultiThreadingConstraints} keeps BOOLEAN inputs on the
     * single-threaded path, which always passes the full row range.
     */
    private static void sum(BasicTensorBlock in, BasicTensorBlock out, Plus plus, int rl, int ru) {
        if (in.isSparse()) {
            throw new DMLRuntimeException("Sparse aggregation not implemented for Tensor");
        }
        switch (in.getValueType()) {
            case INT64:
            case INT32:
            case UINT8: {
                // Accumulated as long via getLong, presumably to avoid the precision loss of the
                // double-based get(r, c) used for floating point below.
                DenseBlock a = in.getDenseBlock();
                int rowLen = a.getCumODims(0);
                int[] ix = new int[a.numDims()];
                long total = 0;
                for (int i = rl; i < ru; i++) {
                    ix[0] = i;
                    for (int j = 0; j < rowLen; j++) {
                        // NOTE(review): for >2 dims this stores a row-relative offset in the last
                        // index while middle indices stay 0; presumably relies on DenseBlock's
                        // row-major position computation. Confirm.
                        ix[ix.length - 1] = j;
                        total += a.getLong(ix);
                    }
                }
                out.set(new int[out.getNumDims()], Long.valueOf(total));
                return;
            }
            case BOOLEAN:
                out.set(0, 0, in.getDenseBlock().countNonZeros());
                return;
            case STRING:
                throw new DMLRuntimeException("Sum over string tensor is not supported.");
            case FP64:
            case FP32: {
                DenseBlock a = in.getDenseBlock();
                int rowLen = a.getCumODims(0);
                double total = 0.0d;
                for (int i = rl; i < ru; i++) {
                    for (int j = 0; j < rowLen; j++) {
                        total = plus.execute(total, a.get(i, j));
                    }
                }
                out.set(0, 0, total);
                return;
            }
            case UNKNOWN:
                throw new NotImplementedException();
            default:
                // Previously a silent no-op that left the output untouched.
                throw new NotImplementedException("Sum over tensors of value type " + in.getValueType() + " is not supported.");
        }
    }
}
