package org.apache.sysds.utils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.FloatBuffer;
import java.util.Iterator;
import java.util.Vector;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.SystemUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.io.IOUtilFunctions;

/**
 * Loads a native BLAS library (Intel MKL or OpenBLAS) together with the matching SystemDS JNI
 * helper library ({@code libsystemds_<blas>-<platform>}) and declares the native kernels.
 *
 * <p>Loading happens lazily via {@link #isNativeLibraryLoaded()} or explicitly via
 * {@link #initialize(String, String)}. The outcome is recorded in {@link #CURRENT_NATIVE_BLAS_STATE}.
 * Native loading is only attempted on x86_64/amd64 Linux or Windows.
 */
public class NativeHelper {
    /** Canonical name of the loaded BLAS ("mkl" or "openblas"); null until a BLAS was loaded. */
    private static String blasType;

    /** Current state of native BLAS loading; starts as NOT_ATTEMPTED_LOADING_NATIVE_BLAS. */
    public static NativeBlasState CURRENT_NATIVE_BLAS_STATE = NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS;

    /** Thread count handed to the native library; -1 until first computed. */
    private static int maxNumThreads = -1;

    /** Whether {@link #maxNumThreads} has already been pushed to the native library. */
    private static boolean setMaxNumThreads = false;

    private static final Log LOG = LogFactory.getLog(NativeHelper.class.getName());

    /** Possible outcomes of loading the native BLAS. */
    public enum NativeBlasState {
        /** Initial state: no load attempt has been made yet. */
        NOT_ATTEMPTED_LOADING_NATIVE_BLAS,
        /** BLAS and helper loaded; the only state in which isNativeLibraryLoaded() returns true. */
        SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE,
        /** BLAS loaded, but disabled afterwards via initialize(..., "none"). */
        SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE,
        /** A load attempt was started but did not complete successfully. */
        ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY
    }

    /**
     * Returns the name of the loaded BLAS.
     *
     * @return the BLAS name recorded during loading, or an empty string if none was recorded
     */
    public static String getCurrentBLAS() {
        return blasType != null ? blasType : "";
    }

    /**
     * Returns whether a native BLAS is loaded and enabled, attempting to load it first if needed.
     *
     * <p>When nothing is loaded yet, the library directory and BLAS type are read from the current
     * {@link DMLConfig}. If no config is available, "none" and "auto" are used. After the first
     * successful load, the maximum thread count is pushed to the native library exactly once.
     *
     * @return true iff the state is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE
     */
    public static boolean isNativeLibraryLoaded() {
        if (!isBLASLoaded()) {
            DMLConfig config = ConfigurationManager.getDMLConfig();
            String customLibPath = (config == null) ? "none" : config.getTextValue(DMLConfig.NATIVE_BLAS_DIR).trim();
            String userSpecifiedBLAS = (config == null) ? "auto" : config.getTextValue(DMLConfig.NATIVE_BLAS).trim().toLowerCase();
            performLoading(customLibPath, userSpecifiedBLAS);
        }
        int numThreads = getMaxNumThreads();
        boolean inUse = CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
        if (inUse && !setMaxNumThreads && numThreads != -1) {
            setMaxNumThreads(numThreads);
            setMaxNumThreads = true;
        }
        return inUse;
    }

    /**
     * Explicitly selects (and if necessary loads) the native BLAS.
     *
     * <ul>
     * <li>Nothing loaded yet and {@code userSpecifiedBLAS} is "auto", "mkl" or "openblas": attempt to load.</li>
     * <li>Already loaded and {@code userSpecifiedBLAS} is "none": keep it loaded but mark it not in use.</li>
     * <li>Already loaded and {@code userSpecifiedBLAS} names the loaded BLAS: mark it in use again.</li>
     * </ul>
     *
     * @param customLibPath directory to load libraries from, or "none"/null for the default lookup
     * @param userSpecifiedBLAS requested BLAS (case-insensitive); must not be null
     * @throws DMLRuntimeException if a BLAS is already loaded and a different supported BLAS is
     *         requested. "auto" also counts as different, because it never equals the loaded name.
     */
    public static void initialize(String customLibPath, String userSpecifiedBLAS) {
        if (isBLASLoaded()) {
            if (isSupportedBLAS(userSpecifiedBLAS) && !blasType.equalsIgnoreCase(userSpecifiedBLAS)) {
                throw new DMLRuntimeException("Cannot replace previously loaded blas \"" + blasType + "\" with \"" + userSpecifiedBLAS + "\".");
            }
            if (userSpecifiedBLAS.equalsIgnoreCase("none")) {
                CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
            } else if (userSpecifiedBLAS.equalsIgnoreCase(blasType)) {
                CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
            }
            // any other value leaves the current state unchanged
        } else if (isSupportedBLAS(userSpecifiedBLAS)) {
            performLoading(customLibPath, userSpecifiedBLAS);
        }
    }

    /** Returns true for "auto", "mkl" or "openblas" (case-insensitive). */
    private static boolean isSupportedBLAS(String blas) {
        return blas.equalsIgnoreCase("auto") || blas.equalsIgnoreCase("mkl") || blas.equalsIgnoreCase("openblas");
    }

    /** Returns true on x86_64/amd64; otherwise logs the architecture and returns false. */
    private static boolean isSupportedArchitecture() {
        if (SystemUtils.OS_ARCH.equals("x86_64") || SystemUtils.OS_ARCH.equals("amd64")) {
            return true;
        }
        LOG.info("Unsupported architecture for native BLAS:" + SystemUtils.OS_ARCH);
        return false;
    }

    /** Returns true on Linux or Windows; otherwise logs the operating system and returns false. */
    private static boolean isSupportedOS() {
        if (SystemUtils.IS_OS_LINUX || SystemUtils.IS_OS_WINDOWS) {
            return true;
        }
        LOG.info("Unsupported OS for native BLAS:" + SystemUtils.OS_NAME);
        return false;
    }

    /** Returns true if a BLAS was loaded successfully, whether or not it is currently in use. */
    private static boolean isBLASLoaded() {
        return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE
            || CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
    }

    /**
     * Decides whether a load attempt should be made.
     *
     * <p>Always true before the first attempt. After that, a failed attempt is only retried when
     * an explicit library directory (not null and not "none") is supplied.
     */
    private static boolean shouldReload(String customLibPath) {
        if (CURRENT_NATIVE_BLAS_STATE == NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS) {
            return true;
        }
        return customLibPath != null && !customLibPath.equalsIgnoreCase("none") && !isBLASLoaded();
    }

    /**
     * Attempts to load the BLAS library and then the matching SystemDS JNI helper.
     *
     * <p>The helper is loaded from the jar's {@code /lib/} resources first and from
     * {@code customLibPath} second. On success the state becomes SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE.
     * Otherwise it stays ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY. A warning is logged if
     * loading takes longer than one second.
     *
     * @param customLibPath library directory, or "none"/null for the default lookup
     * @param userSpecifiedBLAS "auto" (try MKL, then OpenBLAS), "mkl" or "openblas"
     */
    private static void performLoading(String customLibPath, String userSpecifiedBLAS) {
        if (customLibPath != null && customLibPath.equalsIgnoreCase("none")) {
            customLibPath = null;
        }
        if (!shouldReload(customLibPath) || !isSupportedBLAS(userSpecifiedBLAS) || !isSupportedArchitecture() || !isSupportedOS()) {
            if (LOG.isDebugEnabled() && !isSupportedBLAS(userSpecifiedBLAS)) {
                LOG.debug("Using internal Java BLAS as native BLAS support instead of the configuration 'sysds.native.blas'=" + userSpecifiedBLAS + ".");
            }
            return;
        }
        long startNanos = System.nanoTime();
        synchronized (NativeHelper.class) {
            // Re-check while holding the class lock; the unsynchronized check above may be stale.
            if (shouldReload(customLibPath)) {
                CURRENT_NATIVE_BLAS_STATE = NativeBlasState.ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY;
                String[] candidates = userSpecifiedBLAS.equalsIgnoreCase("auto")
                    ? new String[] {"mkl", "openblas"}
                    : new String[] {userSpecifiedBLAS};
                if (checkAndLoadBLAS(customLibPath, candidates)) {
                    String helperLib = "libsystemds_" + blasType
                        + (SystemUtils.IS_OS_WINDOWS ? "-Windows-AMD64.dll" : "-Linux-x86_64.so");
                    if (loadLibraryHelperFromResource(helperLib)
                        || loadBLAS(customLibPath, helperLib, "Loading native helper with customLibPath.")) {
                        LOG.info("Using native blas: " + blasType + getNativeBLASPath());
                        CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
                    }
                }
            }
        }
        double elapsedMillis = (System.nanoTime() - startNanos) * 1.0E-6d;
        if (elapsedMillis > 1000.0d) {
            LOG.warn("Time to load native blas: " + elapsedMillis + " milliseconds.");
        }
    }

    /**
     * Tries the given BLAS candidates in order and stops at the first one that loads.
     *
     * <p>On success, {@link #blasType} is set to the canonical lower-case name ("mkl" or
     * "openblas"). The helper library file name is derived from it, so the user's spelling
     * (e.g. "MKL") must not leak into it.
     *
     * @param customLibPath library directory, or "none"/null for the default lookup
     * @param candidates BLAS names to try (case-insensitive); unknown names are skipped with a warning
     * @return true if one of the candidates was loaded
     */
    private static boolean checkAndLoadBLAS(String customLibPath, String[] candidates) {
        if (customLibPath != null && customLibPath.equalsIgnoreCase("none")) {
            customLibPath = null;
        }
        for (String candidate : candidates) {
            if (candidate.equalsIgnoreCase("mkl")) {
                if (loadBLAS(customLibPath, "mkl_rt", "")) {
                    blasType = "mkl";
                    return true;
                }
            } else if (candidate.equalsIgnoreCase("openblas")) {
                if (loadBLAS(customLibPath, "libopenblas", "") || loadBLAS(customLibPath, "openblas", "")) {
                    blasType = "openblas";
                    return true;
                }
            } else {
                LOG.warn("Not trying to load unknown blas type " + candidate);
            }
        }
        return false;
    }

    /**
     * Best-effort lookup of the path the BLAS was loaded from, for the log message only.
     *
     * <p>Only runs when debug logging is enabled. It reads a private JDK field via reflection, so
     * any failure (missing field, access denied, unexpected type) is logged at debug level and
     * ignored.
     *
     * @return " from the path &lt;lib&gt;" if a loaded mkl_rt/libopenblas entry was found, otherwise ""
     */
    private static String getNativeBLASPath() {
        if (!LOG.isDebugEnabled()) {
            return "";
        }
        try {
            Field loadedLibsField = ClassLoader.class.getDeclaredField("loadedLibraryNames");
            loadedLibsField.setAccessible(true);
            // NOTE: the type and presence of this private field differ between JDK releases,
            // so only Iterable is assumed rather than a concrete collection class.
            Iterable<?> loadedLibs = (Iterable<?>) loadedLibsField.get(ClassLoader.getSystemClassLoader());
            LOG.debug("List of native libraries loaded:" + loadedLibs);
            for (Object lib : loadedLibs) {
                String libName = String.valueOf(lib);
                if (libName.contains("mkl_rt") || libName.contains("libopenblas")) {
                    return " from the path " + libName;
                }
            }
        } catch (ReflectiveOperationException | RuntimeException e) {
            // Diagnostic only: never let this lookup break native library loading.
            LOG.debug("Error while finding list of native libraries:" + e.getMessage());
        }
        return "";
    }

    /**
     * Returns the thread count for native operations. On the first call it is computed via
     * {@code OptimizerUtils.getConstrainedNumThreads(-1)} and then cached.
     *
     * @return the cached maximum number of threads
     */
    public static int getMaxNumThreads() {
        if (maxNumThreads == -1) {
            maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);
        }
        return maxNumThreads;
    }

    /**
     * Loads a native library. The explicit directory is tried first, then {@code java.library.path}.
     *
     * <p>When {@code customLibPath} is given, the platform file name is built with
     * {@link System#mapLibraryName(String)}. Duplicated prefixes and suffixes that arise when
     * {@code blas} is already a full file name are collapsed (e.g. "liblibsystemds", ".so.so").
     *
     * @param customLibPath directory to load from, or "none"/null to skip the explicit-path attempt
     * @param blas library name (without or with platform prefix/suffix)
     * @param optionalMsg extra context appended to the debug message on failure; may be null
     * @return true if the library was loaded by either method
     */
    public static boolean loadBLAS(String customLibPath, String blas, String optionalMsg) {
        if (customLibPath != null && !customLibPath.equalsIgnoreCase("none")) {
            String libPath = (customLibPath + File.separator + System.mapLibraryName(blas))
                .replace("liblibsystemds", "libsystemds")
                .replace(".dll.dll", ".dll")
                .replace(".so.so", ".so");
            try {
                System.load(libPath);
                LOG.info("Loaded the library:" + libPath);
                return true;
            } catch (UnsatisfiedLinkError e) {
                LOG.warn("Unable to load " + blas + " from " + libPath + ". Trying once more with System.loadLibrary(" + blas + ") \n Message from exception was: " + e.getMessage());
            }
        }
        try {
            System.loadLibrary(blas);
            return true;
        } catch (UnsatisfiedLinkError e) {
            LOG.debug("java.library.path: " + System.getProperty("java.library.path"));
            LOG.debug("Unable to load " + blas + (optionalMsg == null ? "" : " (" + optionalMsg + ")") + " \n Message from exception was: " + e.getMessage());
            return false;
        }
    }

    /**
     * Extracts a shared library from the jar's {@code /lib/} resources to a temporary file and loads it.
     *
     * <p>The temporary file is marked delete-on-exit. Both streams are always closed. A failed
     * close of the extracted file is treated as a failure, because the file could be incomplete
     * and must not be loaded.
     *
     * @param libFileName file name of the library inside {@code /lib/}
     * @return true if the library was extracted and loaded; false otherwise (errors are logged)
     */
    public static boolean loadLibraryHelperFromResource(String libFileName) {
        LOG.info("Loading JNI shared library: " + libFileName);
        try (InputStream in = NativeHelper.class.getResourceAsStream("/lib/" + libFileName)) {
            if (in == null) {
                LOG.error("No lib available in the jar:" + libFileName);
                return false;
            }
            File tempLib = File.createTempFile(libFileName, "");
            tempLib.deleteOnExit();
            try (FileOutputStream out = FileUtils.openOutputStream(tempLib)) {
                IOUtils.copy(in, out);
            }
            System.load(tempLib.getAbsolutePath());
            return true;
        } catch (IOException | UnsatisfiedLinkError e) {
            LOG.error("Unable to load library " + libFileName + " from resource:" + e.getMessage());
            return false;
        }
    }

    // Native kernels. They presumably are implemented by the libsystemds_* JNI helper loaded in
    // performLoading (TODO confirm against the native sources); calling them without a loaded
    // implementation fails with UnsatisfiedLinkError. Parameter names are decompiler placeholders.

    public static native long dmmdd(double[] dArr, double[] dArr2, double[] dArr3, int i, int i2, int i3, int i4);

    public static native long smmdd(FloatBuffer floatBuffer, FloatBuffer floatBuffer2, FloatBuffer floatBuffer3, int i, int i2, int i3, int i4);

    public static native long tsmm(double[] dArr, double[] dArr2, int i, int i2, boolean z, int i3);

    public static native long conv2dDense(double[] dArr, double[] dArr2, double[] dArr3, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14);

    public static native long dconv2dBiasAddDense(double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14);

    public static native long sconv2dBiasAddDense(FloatBuffer floatBuffer, FloatBuffer floatBuffer2, FloatBuffer floatBuffer3, FloatBuffer floatBuffer4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14);

    public static native long conv2dBackwardFilterDense(double[] dArr, double[] dArr2, double[] dArr3, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14);

    public static native long conv2dBackwardDataDense(double[] dArr, double[] dArr2, double[] dArr3, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14);

    public static native boolean conv2dBackwardFilterSparseDense(int i, int i2, int[] iArr, double[] dArr, double[] dArr2, double[] dArr3, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14, int i15, int i16);

    public static native boolean conv2dSparse(int i, int i2, int[] iArr, double[] dArr, double[] dArr2, double[] dArr3, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, int i14, int i15, int i16);

    private static native void setMaxNumThreads(int i);
}
