package org.apache.sysds.runtime.matrix.data;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/LibMatrixTercell.class */
/**
 * Cell-wise ternary operations over matrix blocks.
 *
 * <p>For every output cell, computes {@code ret(i,j) = op.fn.execute(m1(i,j), m2(i,j), m3(i,j))}.
 * An input whose dimensions are exactly 1x1 is treated as a scalar and its single value is used
 * for every output cell.
 *
 * <p>Non-instantiable utility class.
 */
public class LibMatrixTercell {
    /** Output size ({@code ret.getLength()}) at or below which the operation always runs single-threaded. */
    private static final long PAR_NUMCELL_THRESHOLD = 8192;

    /**
     * Task that evaluates the ternary operation on the output rows {@code [rl, ru)} and returns the
     * number of non-zero values it wrote.
     */
    public static class TercellTask implements Callable<Long> {
        private final MatrixBlock _m1;
        private final MatrixBlock _m2;
        private final MatrixBlock _m3;
        private final boolean _s1;
        private final boolean _s2;
        private final boolean _s3;
        private final double _d1;
        private final double _d2;
        private final double _d3;
        private final MatrixBlock _ret;
        private final TernaryOperator _op;
        private final int _rl;
        private final int _ru;

        /**
         * @param m1  first input
         * @param m2  second input
         * @param m3  third input
         * @param ret output block (must already be allocated)
         * @param op  ternary operator whose {@code fn} is applied per cell
         * @param s1  whether {@code m1} is treated as the scalar {@code d1}
         * @param s2  whether {@code m2} is treated as the scalar {@code d2}
         * @param s3  whether {@code m3} is treated as the scalar {@code d3}
         * @param d1  scalar value of {@code m1} (only used if {@code s1})
         * @param d2  scalar value of {@code m2} (only used if {@code s2})
         * @param d3  scalar value of {@code m3} (only used if {@code s3})
         * @param rl  first output row (inclusive)
         * @param ru  last output row (exclusive)
         */
        protected TercellTask(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
                TernaryOperator op, boolean s1, boolean s2, boolean s3, double d1, double d2, double d3,
                int rl, int ru) {
            _m1 = m1;
            _m2 = m2;
            _m3 = m3;
            _s1 = s1;
            _s2 = s2;
            _s3 = s3;
            _d1 = d1;
            _d2 = d2;
            _d3 = d3;
            _ret = ret;
            _op = op;
            _rl = rl;
            _ru = ru;
        }

        /** @return the number of non-zero values written to rows {@code [rl, ru)} of the output */
        @Override
        public Long call() {
            return unsafeTernary(_m1, _m2, _m3, _ret, _op, _s1, _s2, _s3, _d1, _d2, _d3, _rl, _ru);
        }
    }

    private LibMatrixTercell() {
        // utility class, not instantiable
    }

    /**
     * Computes {@code ret(i,j) = op.fn.execute(m1(i,j), m2(i,j), m3(i,j))} for every cell of
     * {@code ret} and sets the non-zero count of {@code ret}.
     *
     * <p>An input of size 1x1 is broadcast as a scalar. Non-scalar inputs are read at the same
     * {@code (i,j)} as the output cell. NOTE(review): this presumes the caller has already checked
     * that non-scalar inputs match the output dimensions; nothing here validates it.
     *
     * <p>The operation runs single-threaded when {@code op.getNumThreads() <= 1} or the output has
     * at most {@value #PAR_NUMCELL_THRESHOLD} cells. Otherwise the output rows are split into
     * balanced, disjoint ranges that are evaluated on a {@link CommonThreadPool}.
     *
     * @param m1  first input
     * @param m2  second input
     * @param m3  third input
     * @param ret output block; its dimensions determine the iteration space. It is allocated here.
     * @param op  ternary operator (function and degree of parallelism)
     * @throws DMLRuntimeException if a parallel task fails or the calling thread is interrupted
     */
    public static void tercellOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
            TernaryOperator op) {
        final boolean s1 = isScalar(m1);
        final boolean s2 = isScalar(m2);
        final boolean s3 = isScalar(m3);
        // NaN is only a placeholder: the scalar value is read only when the matching flag is set.
        final double d1 = s1 ? m1.quickGetValue(0, 0) : Double.NaN;
        final double d2 = s2 ? m2.quickGetValue(0, 0) : Double.NaN;
        final double d3 = s3 ? m3.quickGetValue(0, 0) : Double.NaN;

        ret.allocateBlock();

        final int k = op.getNumThreads();
        if (k <= 1 || ret.getLength() <= PAR_NUMCELL_THRESHOLD) {
            ret.setNonZeros(unsafeTernary(m1, m2, m3, ret, op, s1, s2, s3, d1, d2, d3, 0, ret.rlen));
            return;
        }

        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            // Each task writes a disjoint row range of ret.
            // NOTE(review): this assumes ret's underlying block tolerates concurrent writes to
            // disjoint rows; confirm for sparse outputs.
            final List<Integer> blockSizes = UtilFunctions.getBalancedBlockSizesDefault(ret.rlen, k, false);
            final List<TercellTask> tasks = new ArrayList<>(blockSizes.size());
            int rl = 0;
            for (int len : blockSizes) {
                tasks.add(new TercellTask(m1, m2, m3, ret, op, s1, s2, s3, d1, d2, d3, rl, rl + len));
                rl += len;
            }

            long nnz = 0;
            for (Future<Long> task : pool.invokeAll(tasks)) {
                nnz += task.get();
            }
            ret.setNonZeros(nnz);
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        } finally {
            // Shut down on every path, not only on success.
            pool.shutdown();
        }
    }

    /** Returns true if the block has exactly one row and one column, i.e. it is broadcast as a scalar. */
    private static boolean isScalar(MatrixBlock mb) {
        return mb.rlen == 1 && mb.clen == 1;
    }

    /**
     * Evaluates the ternary operation on output rows {@code [rl, ru)} and all columns of
     * {@code ret}, writing each result with {@code appendValuePlain}.
     *
     * <p>This method does not update the non-zero count of {@code ret}. It returns the number of
     * non-zero results written, and the caller sets the count.
     *
     * @return number of written values that are not equal to 0 (NaN counts as non-zero)
     */
    private static long unsafeTernary(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
            TernaryOperator op, boolean s1, boolean s2, boolean s3, double d1, double d2, double d3,
            int rl, int ru) {
        final int n = ret.clen;
        long nnz = 0;
        for (int i = rl; i < ru; i++) {
            for (int j = 0; j < n; j++) {
                final double in1 = s1 ? d1 : m1.quickGetValue(i, j);
                final double in2 = s2 ? d2 : m2.quickGetValue(i, j);
                final double in3 = s3 ? d3 : m3.quickGetValue(i, j);
                final double val = op.fn.execute(in1, in2, in3);
                if (val != 0) {
                    nnz++;
                }
                ret.appendValuePlain(i, j, val);
            }
        }
        return nnz;
    }
}
