package org.apache.sysds.runtime.matrix.data;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.utils.NativeHelper;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/LibMatrixNative.class */
public class LibMatrixNative {
    /** Logger used for native-fallback messages emitted by this class. */
    private static final Log LOG = LogFactory.getLog(LibMatrixNative.class.getName());
    // Per-thread single-precision staging buffers. toFloatBuffer() lazily allocates a
    // direct buffer for the calling thread and re-allocates only when the cached one
    // is too small, so repeated native calls on the same thread reuse the memory.
    /** Staging buffer for the first operand (left matrix / convolution input). */
    private static ThreadLocal<FloatBuffer> inBuff = new ThreadLocal<>();
    /** Staging buffer for the convolution bias values. */
    private static ThreadLocal<FloatBuffer> biasBuff = new ThreadLocal<>();
    /** Staging buffer for the second operand (right matrix / convolution filter). */
    private static ThreadLocal<FloatBuffer> filterBuff = new ThreadLocal<>();
    /** Staging buffer that receives the single-precision result of a native call. */
    private static ThreadLocal<FloatBuffer> outBuff = new ThreadLocal<>();

    /**
     * Decides whether a matrix multiplication is treated as memory bound.
     *
     * <p>That is the case only when at least one of the three dimensions equals 1 and
     * either operand (8 bytes per cell) is larger than 268435456 bytes.
     *
     * @param i  number of rows of the left operand
     * @param i2 number of columns of the left operand (rows of the right operand)
     * @param i3 number of columns of the right operand
     * @return {@code true} if the multiplication is considered memory bound
     */
    public static boolean isMatMultMemoryBound(int i, int i2, int i3) {
        final boolean hasUnitDimension = i == 1 || i2 == 1 || i3 == 1;
        if (!hasUnitDimension) {
            return false;
        }
        final long byteThreshold = 268435456L;
        // Computed in long arithmetic so that large dimensions do not overflow int.
        final long leftBytes = 8L * i * i2;
        final long rightBytes = 8L * i2 * i3;
        return leftBytes > byteThreshold || rightBytes > byteThreshold;
    }

    /**
     * Convenience overload of the five-argument {@code matrixMult} that always
     * requests the sparsity examination of the result.
     *
     * @param matrixBlock  left operand
     * @param matrixBlock2 right operand
     * @param matrixBlock3 output block
     * @param i            requested number of threads
     */
    public static void matrixMult(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i) {
        final boolean examineSparsity = true;
        matrixMult(matrixBlock, matrixBlock2, matrixBlock3, i, examineSparsity);
    }

    /**
     * Multiplies two matrix blocks, using the native library when it is loaded and
     * the inputs qualify, and the Java implementation in {@code LibMatrixMult}
     * otherwise (including when the native call reports failure).
     *
     * <p>The native path is taken only for dense, contiguous inputs that are not
     * memory bound (see {@link #isMatMultMemoryBound(int, int, int)}) and whose
     * output size stays below {@code OptimizerUtils.MAX_NUMCELLS_CP_DENSE}.
     *
     * @param matrixBlock  left operand
     * @param matrixBlock2 right operand
     * @param matrixBlock3 output block; its non-zero count is maintained
     * @param i            requested number of threads; values {@code <= 0} select
     *                     {@code NativeHelper.getMaxNumThreads()}
     * @param z            whether to call {@code examSparsity()} on the output
     */
    public static void matrixMult(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i, boolean z) {
        boolean dmmdd;
        int maxNumThreads = i <= 0 ? NativeHelper.getMaxNumThreads() : i;
        // An empty operand yields an empty product; no computation required.
        if (matrixBlock.isEmptyBlock(false) || matrixBlock2.isEmptyBlock(false)) {
            matrixBlock3.setNonZeros(0L);
            if (z) {
                matrixBlock3.examSparsity();
            }
            return;
        }
        if (NativeHelper.isNativeLibraryLoaded() && !isMatMultMemoryBound(matrixBlock.rlen, matrixBlock.clen, matrixBlock2.clen) && !matrixBlock.isInSparseFormat() && !matrixBlock2.isInSparseFormat() && matrixBlock.getDenseBlock().isContiguous() && matrixBlock2.getDenseBlock().isContiguous() && 8 * matrixBlock3.getLength() < OptimizerUtils.MAX_NUMCELLS_CP_DENSE) {
            matrixBlock3.sparse = false;
            matrixBlock3.allocateDenseBlock();
            long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            if (isSinglePrecision()) {
                FloatBuffer left = toFloatBuffer(matrixBlock.getDenseBlockValues(), inBuff, true);
                FloatBuffer right = toFloatBuffer(matrixBlock2.getDenseBlockValues(), filterBuff, true);
                // The output buffer is not initialised from the output block, so it may
                // still hold values of an earlier call on this thread.
                FloatBuffer result = toFloatBuffer(matrixBlock3.getDenseBlockValues(), outBuff, false);
                dmmdd = NativeHelper.smmdd(left, right, result, matrixBlock.getNumRows(), matrixBlock.getNumColumns(), matrixBlock2.getNumColumns(), maxNumThreads);
                // Copy back only on success: after a failed call the buffer content is
                // undefined and must not overwrite the output before the Java fallback.
                if (dmmdd) {
                    fromFloatBuffer(result, matrixBlock3.getDenseBlockValues());
                }
            } else {
                dmmdd = NativeHelper.dmmdd(matrixBlock.getDenseBlockValues(), matrixBlock2.getDenseBlockValues(), matrixBlock3.getDenseBlockValues(), matrixBlock.getNumRows(), matrixBlock.getNumColumns(), matrixBlock2.getNumColumns(), maxNumThreads);
            }
            if (dmmdd) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeLibMatrixMultTime += System.nanoTime() - nanoTime;
                    Statistics.numNativeLibMatrixMultCalls.increment();
                }
                matrixBlock3.recomputeNonZeros();
                if (z) {
                    matrixBlock3.examSparsity();
                }
                return;
            }
            // Native call failed: record it and fall through to the Java implementation.
            Statistics.incrementNativeFailuresCounter();
        }
        if (maxNumThreads == 1) {
            LibMatrixMult.matrixMult(matrixBlock, matrixBlock2, matrixBlock3, !z);
        } else {
            LibMatrixMult.matrixMult(matrixBlock, matrixBlock2, matrixBlock3, maxNumThreads);
        }
    }

    /**
     * Computes the transpose-self matrix multiplication of {@code matrixBlock} into
     * {@code matrixBlock2}, preferring the native kernel and falling back to
     * {@code LibMatrixMult.matrixMultTransposeSelf} when the native path is not
     * applicable or reports failure. Empty inputs return immediately.
     *
     * @param matrixBlock  dense or sparse input block
     * @param matrixBlock2 output block
     * @param z            transpose flag forwarded unchanged to the kernels
     * @param i            number of threads; the multi-threaded Java fallback is used
     *                     only when this is greater than 1
     */
    public static void tsmm(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, boolean z, int i) {
        if (matrixBlock.isEmptyBlock(false)) {
            return;
        }
        final boolean nativeEligible = NativeHelper.isNativeLibraryLoaded()
            && ((matrixBlock2.clen > 1 || matrixBlock2.getLength() == 1)
                && !matrixBlock.sparse
                && matrixBlock.getDenseBlock().isContiguous());
        if (nativeEligible) {
            matrixBlock2.sparse = false;
            matrixBlock2.allocateDenseBlock();
            final boolean nativeOk = NativeHelper.tsmm(matrixBlock.getDenseBlockValues(), matrixBlock2.getDenseBlockValues(), matrixBlock.rlen, matrixBlock.clen, z, i);
            if (nativeOk) {
                final long nnz;
                if (matrixBlock2.clen == 1) {
                    nnz = matrixBlock2.recomputeNonZeros();
                } else {
                    // The helper mirrors the upper triangle and reports the resulting count.
                    nnz = LibMatrixMult.copyUpperToLowerTriangle(matrixBlock2);
                }
                matrixBlock2.setNonZeros(nnz);
                matrixBlock2.examSparsity();
                return;
            }
            LOG.info("Falling back to java TSMM()");
            Statistics.incrementNativeFailuresCounter();
        }
        if (i <= 1) {
            LibMatrixMult.matrixMultTransposeSelf(matrixBlock, matrixBlock2, z);
        } else {
            LibMatrixMult.matrixMultTransposeSelf(matrixBlock, matrixBlock2, z, i);
        }
    }

    /**
     * Runs a 2D convolution, preferring the native kernels for dense inputs and
     * falling back to {@code LibMatrixDNN.conv2d} when the native library is not
     * loaded, an input is sparse, or the native call returns {@code -1}.
     *
     * <p>Side effect: {@code dnnParameters.numThreads} is normalised (non-positive
     * values become {@code NativeHelper.getMaxNumThreads()}) and may be further
     * constrained before the native call; a sparse bias is converted to dense.
     *
     * @param matrixBlock   convolution input
     * @param matrixBlock2  convolution filter
     * @param matrixBlock3  output block; receives the non-zero count reported by the
     *                      native kernel on success
     * @param dnnParameters convolution parameters (shapes, strides, padding, optional
     *                      bias, thread count)
     */
    public static void conv2d(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, DnnParameters dnnParameters) {
        int dconv2dBiasAddDense;
        LibMatrixDNN.checkInputsConv2d(matrixBlock, matrixBlock2, matrixBlock3, dnnParameters);
        dnnParameters.numThreads = dnnParameters.numThreads <= 0 ? NativeHelper.getMaxNumThreads() : dnnParameters.numThreads;
        if (NativeHelper.isNativeLibraryLoaded() && !matrixBlock.isInSparseFormat() && !matrixBlock2.isInSparseFormat()) {
            setNumThreads(dnnParameters);
            long nanoTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
            // True when the native kernel writes straight into the output's dense array
            // (double-precision kernels); false when it writes into the float staging
            // buffer, which is only copied back on success.
            boolean wroteOutputDirectly = true;
            if (dnnParameters.bias == null) {
                dconv2dBiasAddDense = NativeHelper.conv2dDense(matrixBlock.getDenseBlockValues(), matrixBlock2.getDenseBlockValues(), matrixBlock3.getDenseBlockValues(), dnnParameters.N, dnnParameters.C, dnnParameters.H, dnnParameters.W, dnnParameters.K, dnnParameters.R, dnnParameters.S, dnnParameters.stride_h, dnnParameters.stride_w, dnnParameters.pad_h, dnnParameters.pad_w, dnnParameters.P, dnnParameters.Q, dnnParameters.numThreads);
            } else {
                if (dnnParameters.bias.isInSparseFormat()) {
                    dnnParameters.bias.sparseToDense();
                }
                // The single-precision kernel is used only when single precision is
                // configured and the current BLAS is not "mkl".
                if (!isSinglePrecision() || NativeHelper.getCurrentBLAS().equalsIgnoreCase("mkl")) {
                    dconv2dBiasAddDense = NativeHelper.dconv2dBiasAddDense(matrixBlock.getDenseBlockValues(), dnnParameters.bias.getDenseBlockValues(), matrixBlock2.getDenseBlockValues(), matrixBlock3.getDenseBlockValues(), dnnParameters.N, dnnParameters.C, dnnParameters.H, dnnParameters.W, dnnParameters.K, dnnParameters.R, dnnParameters.S, dnnParameters.stride_h, dnnParameters.stride_w, dnnParameters.pad_h, dnnParameters.pad_w, dnnParameters.P, dnnParameters.Q, dnnParameters.numThreads);
                } else {
                    wroteOutputDirectly = false;
                    dconv2dBiasAddDense = NativeHelper.sconv2dBiasAddDense(toFloatBuffer(matrixBlock.getDenseBlockValues(), inBuff, true), toFloatBuffer(dnnParameters.bias.getDenseBlockValues(), biasBuff, true), toFloatBuffer(matrixBlock2.getDenseBlockValues(), filterBuff, true), toFloatBuffer(matrixBlock3.getDenseBlockValues(), outBuff, false), dnnParameters.N, dnnParameters.C, dnnParameters.H, dnnParameters.W, dnnParameters.K, dnnParameters.R, dnnParameters.S, dnnParameters.stride_h, dnnParameters.stride_w, dnnParameters.pad_h, dnnParameters.pad_w, dnnParameters.P, dnnParameters.Q, dnnParameters.numThreads);
                    if (dconv2dBiasAddDense != -1) {
                        fromFloatBuffer(outBuff.get(), matrixBlock3.getDenseBlockValues());
                    }
                }
            }
            if (dconv2dBiasAddDense != -1) {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dTime += System.nanoTime() - nanoTime;
                    Statistics.numNativeConv2dCalls.increment();
                }
                matrixBlock3.setNonZeros(dconv2dBiasAddDense);
                return;
            }
            LOG.warn("Native conv2d call returned with error - falling back to java operator.");
            // Reset the output whenever the failed kernel had direct access to it. The
            // decision follows the path actually taken, so the single-precision + "mkl"
            // case (which runs the double-precision kernel) is covered as well.
            if (wroteOutputDirectly) {
                matrixBlock3.reset();
            }
            Statistics.incrementNativeFailuresCounter();
        }
        LibMatrixDNN.conv2d(matrixBlock, matrixBlock2, matrixBlock3, dnnParameters);
    }

    /**
     * Constrains {@code dnnParameters.numThreads} via
     * {@code OptimizerUtils.getConstrainedNumThreads} and forces it to 1 when the
     * output is not thread safe or the constrained value is not greater than 1.
     *
     * @param dnnParameters parameters whose {@code numThreads} field is updated in place
     */
    private static void setNumThreads(DnnParameters dnnParameters) {
        dnnParameters.numThreads = OptimizerUtils.getConstrainedNumThreads(dnnParameters.numThreads);
        final boolean mayRunParallel = dnnParameters.isOutputThreadSafe() && dnnParameters.numThreads > 1;
        if (mayRunParallel) {
            return;
        }
        dnnParameters.numThreads = 1;
    }

    /**
     * Computes the convolution backward-filter operation, using the native dense
     * kernel when the library is loaded and neither of the first two blocks is
     * sparse, and {@code LibMatrixDNN.conv2dBackwardFilter} otherwise or when the
     * native call returns {@code -1}.
     *
     * <p>Side effect: {@code dnnParameters.numThreads} is normalised and possibly
     * constrained before the native call.
     *
     * @param matrixBlock   first operand passed to the kernel
     * @param matrixBlock2  second operand passed to the kernel
     * @param matrixBlock3  output block; receives the native non-zero count on success
     * @param dnnParameters convolution parameters
     */
    public static void conv2dBackwardFilter(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, DnnParameters dnnParameters) {
        LibMatrixDNN.checkInputsConv2dBackwardFilter(matrixBlock, matrixBlock2, matrixBlock3, dnnParameters);
        if (dnnParameters.numThreads <= 0) {
            dnnParameters.numThreads = NativeHelper.getMaxNumThreads();
        }
        final boolean nativeEligible = NativeHelper.isNativeLibraryLoaded()
            && !matrixBlock2.isInSparseFormat()
            && !matrixBlock.isInSparseFormat();
        if (nativeEligible) {
            setNumThreads(dnnParameters);
            long startNanos = 0L;
            if (DMLScript.STATISTICS) {
                startNanos = System.nanoTime();
            }
            final int nnz = NativeHelper.conv2dBackwardFilterDense(matrixBlock.getDenseBlockValues(), matrixBlock2.getDenseBlockValues(), matrixBlock3.getDenseBlockValues(), dnnParameters.N, dnnParameters.C, dnnParameters.H, dnnParameters.W, dnnParameters.K, dnnParameters.R, dnnParameters.S, dnnParameters.stride_h, dnnParameters.stride_w, dnnParameters.pad_h, dnnParameters.pad_w, dnnParameters.P, dnnParameters.Q, dnnParameters.numThreads);
            if (nnz == -1) {
                // Native failure: count it and continue with the Java operator below.
                Statistics.incrementNativeFailuresCounter();
            } else {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dBwdFilterTime += System.nanoTime() - startNanos;
                    Statistics.numNativeConv2dBwdFilterCalls.increment();
                }
                matrixBlock3.setNonZeros(nnz);
                return;
            }
        }
        LibMatrixDNN.conv2dBackwardFilter(matrixBlock, matrixBlock2, matrixBlock3, dnnParameters);
    }

    /**
     * Computes the convolution backward-data operation, using the native dense
     * kernel when the library is loaded and neither of the first two blocks is
     * sparse, and {@code LibMatrixDNN.conv2dBackwardData} otherwise or when the
     * native call returns {@code -1}.
     *
     * <p>Side effect: {@code dnnParameters.numThreads} is normalised and possibly
     * constrained before the native call.
     *
     * @param matrixBlock   first operand passed to the kernel
     * @param matrixBlock2  second operand passed to the kernel
     * @param matrixBlock3  output block; receives the native non-zero count on success
     * @param dnnParameters convolution parameters
     */
    public static void conv2dBackwardData(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, DnnParameters dnnParameters) {
        LibMatrixDNN.checkInputsConv2dBackwardData(matrixBlock, matrixBlock2, matrixBlock3, dnnParameters);
        if (dnnParameters.numThreads <= 0) {
            dnnParameters.numThreads = NativeHelper.getMaxNumThreads();
        }
        final boolean nativeEligible = NativeHelper.isNativeLibraryLoaded()
            && !matrixBlock2.isInSparseFormat()
            && !matrixBlock.isInSparseFormat();
        if (nativeEligible) {
            setNumThreads(dnnParameters);
            long startNanos = 0L;
            if (DMLScript.STATISTICS) {
                startNanos = System.nanoTime();
            }
            final int nnz = NativeHelper.conv2dBackwardDataDense(matrixBlock.getDenseBlockValues(), matrixBlock2.getDenseBlockValues(), matrixBlock3.getDenseBlockValues(), dnnParameters.N, dnnParameters.C, dnnParameters.H, dnnParameters.W, dnnParameters.K, dnnParameters.R, dnnParameters.S, dnnParameters.stride_h, dnnParameters.stride_w, dnnParameters.pad_h, dnnParameters.pad_w, dnnParameters.P, dnnParameters.Q, dnnParameters.numThreads);
            if (nnz == -1) {
                // Native failure: count it and continue with the Java operator below.
                Statistics.incrementNativeFailuresCounter();
            } else {
                if (DMLScript.STATISTICS) {
                    Statistics.nativeConv2dBwdDataTime += System.nanoTime() - startNanos;
                    Statistics.numNativeConv2dBwdDataCalls.increment();
                }
                matrixBlock3.setNonZeros(nnz);
                return;
            }
        }
        LibMatrixDNN.conv2dBackwardData(matrixBlock, matrixBlock2, matrixBlock3, dnnParameters);
    }

    /**
     * Reports whether the current DML configuration selects single precision.
     *
     * @return {@code true} if the configured floating point precision equals
     *         {@code "single"}
     */
    private static boolean isSinglePrecision() {
        final DMLConfig config = ConfigurationManager.getDMLConfig();
        final String precision = config.getTextValue(DMLConfig.FLOATING_POINT_PRECISION);
        return precision.equals("single");
    }

    /**
     * Returns a direct, native-ordered float buffer with room for every element of
     * {@code dArr}, reusing the calling thread's cached buffer when it is large
     * enough and otherwise allocating a new one and caching it.
     *
     * @param dArr        source values
     * @param threadLocal per-thread cache holding the buffer to reuse or replace
     * @param z           if {@code true}, the values of {@code dArr} are narrowed to
     *                    float and written to the buffer (in parallel); if
     *                    {@code false}, the buffer content is left as is
     * @return the buffer now stored in {@code threadLocal}
     */
    private static FloatBuffer toFloatBuffer(double[] dArr, ThreadLocal<FloatBuffer> threadLocal, boolean z) {
        final FloatBuffer cached = threadLocal.get();
        final FloatBuffer target;
        if (cached != null && cached.capacity() >= dArr.length) {
            target = cached;
        } else {
            // 4 bytes per float; native byte order so native code can read it directly.
            target = ByteBuffer.allocateDirect(4 * dArr.length).order(ByteOrder.nativeOrder()).asFloatBuffer();
            threadLocal.set(target);
        }
        if (!z) {
            return target;
        }
        // Absolute puts: each index is written independently, position is untouched.
        IntStream.range(0, dArr.length).parallel().forEach(idx -> target.put(idx, (float) dArr[idx]));
        return target;
    }

    /**
     * Copies the first {@code dArr.length} floats of {@code floatBuffer} into
     * {@code dArr}, widening each value to double. The copy runs in parallel and uses
     * absolute reads, so the buffer's position is not changed.
     *
     * @param floatBuffer source buffer; must hold at least {@code dArr.length} values
     * @param dArr        destination array, overwritten in place
     */
    public static void fromFloatBuffer(FloatBuffer floatBuffer, double[] dArr) {
        IntStream.range(0, dArr.length)
            .parallel()
            .forEach(idx -> dArr[idx] = floatBuffer.get(idx));
    }
}
