package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
import org.tensorflow.internal.c_api.TFJ_RuntimeLibrary;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Library;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.op.CustomGradient;
import org.tensorflow.op.RawCustomGradient;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.proto.OpList;

/* loaded from: input_file:org/tensorflow/TensorFlow.class */
public final class TensorFlow {
    /** Lazily-built names of all registered stateful ops; guarded by the class lock (see {@link #isOpStateful}). */
    private static Set<String> statefulOps;

    /**
     * Adapters successfully registered as custom gradients. Holding them here presumably keeps each adapter
     * (and the native callback it backs) strongly reachable while TensorFlow may still call it — TODO confirm.
     * Only accessed from {@code synchronized} static methods.
     */
    private static final Set<TFJ_GradFuncAdapter> gradientFuncs = Collections.newSetFromMap(new IdentityHashMap<>());

    /**
     * Returns the version string reported by the native TensorFlow runtime.
     *
     * @return the result of {@code TF_Version()}
     */
    public static String version() {
        return tensorflow.TF_Version().getString();
    }

    /**
     * Returns all ops currently registered with the native TensorFlow runtime.
     *
     * @return the parsed {@link OpList}
     * @throws TensorFlowException if the native buffer cannot be parsed as an {@code OpList} protocol buffer
     */
    public static OpList registeredOpList() {
        TF_Buffer buffer = tensorflow.TF_GetAllOpList();
        try {
            return OpList.parseFrom(buffer.dataAsByteBuffer());
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse OpList protocol buffer", e);
        } finally {
            // Single release point: the buffer is freed exactly once on every path.
            tensorflow.TF_DeleteBuffer(buffer);
        }
    }

    /**
     * Tells whether the op with the given name is marked stateful in the registered op list.
     *
     * <p>The set of stateful ops is computed from {@link #registeredOpList()} on the first call and cached
     * afterwards; ops registered later are not reflected in the cache.
     *
     * @param str the op type name
     * @return {@code true} if a registered op with that name has {@code is_stateful} set
     */
    public static synchronized boolean isOpStateful(String str) {
        if (statefulOps == null) {
            statefulOps = registeredOpList().getOpList().stream()
                    .filter(opDef -> opDef.getIsStateful())
                    .map(opDef -> opDef.getName())
                    .collect(Collectors.toSet());
        }
        return statefulOps.contains(str);
    }

    /**
     * Loads a native library containing custom ops and returns the ops it defines.
     *
     * <p>The library handle is deleted before this method returns, whether or not reading its op list succeeded.
     *
     * @param str path of the library to load
     * @return the ops defined by the library
     * @throws UnsatisfiedLinkError if loading the library or reading its op list fails; the original exception is
     *     attached as the cause
     */
    public static OpList loadLibrary(String str) {
        try {
            TF_Library library = libraryLoad(str);
            try {
                return libraryOpList(library);
            } finally {
                libraryDelete(library);
            }
        } catch (RuntimeException e) {
            // UnsatisfiedLinkError has no (String, Throwable) constructor; keep the cause for diagnostics.
            UnsatisfiedLinkError error = new UnsatisfiedLinkError(e.getMessage());
            error.initCause(e);
            throw error;
        }
    }

    /**
     * Registers the filesystem plugin located at the given path with the native runtime.
     *
     * @param str path of the plugin
     * @throws RuntimeException (via {@code TF_Status.throwExceptionIfNotOK}) if the native call reports an error
     */
    public static void registerFilesystemPlugin(String str) {
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            tensorflow.TF_RegisterFilesystemPlugin(str, status);
            status.throwExceptionIfNotOK();
        }
    }

    /** Calls {@code TF_LoadLibrary}, throwing if the native status is not OK. */
    private static TF_Library libraryLoad(String str) {
        try (PointerScope scope = new PointerScope()) {
            TF_Status status = TF_Status.newStatus();
            TF_Library library = tensorflow.TF_LoadLibrary(str, status);
            status.throwExceptionIfNotOK();
            return library;
        }
    }

    /** Releases a library handle; {@code null} or null-pointer handles are ignored. */
    private static void libraryDelete(TF_Library tF_Library) {
        if (tF_Library == null || tF_Library.isNull()) {
            return;
        }
        tensorflow.TF_DeleteLibraryHandle(tF_Library);
    }

    /** Parses the op list exposed by a loaded library handle. */
    private static OpList libraryOpList(TF_Library tF_Library) {
        try {
            return OpList.parseFrom(tensorflow.TF_GetOpList(tF_Library).dataAsByteBuffer());
        } catch (InvalidProtocolBufferException e) {
            throw new TensorFlowException("Cannot parse OpList protocol buffer", e);
        }
    }

    private TensorFlow() {
    }

    /** Tells whether the native runtime already has a gradient registered for the given op type. */
    static synchronized boolean hasGradient(String str) {
        return tensorflow.TFJ_HasGradient(str);
    }

    /**
     * Registers a raw custom gradient for the given op type.
     *
     * @param str the op type
     * @param rawCustomGradient the gradient implementation
     * @return {@code true} if the gradient was registered; {@code false} if the op already has a gradient or the
     *     native registration reported failure
     * @throws UnsupportedOperationException on Windows
     */
    public static synchronized boolean registerCustomGradient(String str, RawCustomGradient rawCustomGradient) {
        requireCustomGradientSupport();
        if (hasGradient(str)) {
            return false;
        }
        return registerAdapter(str, RawCustomGradient.adapter(rawCustomGradient));
    }

    /**
     * Registers a typed custom gradient for the op whose generated inputs class is {@code cls}.
     *
     * @param cls the generated inputs class; must carry {@link OpInputsMetadata} whose outputs class carries
     *     {@link OpMetadata}
     * @param customGradient the gradient implementation
     * @return {@code true} if the gradient was registered; {@code false} if the op already has a gradient or the
     *     native registration reported failure
     * @throws UnsupportedOperationException on Windows
     * @throws IllegalArgumentException if the required metadata annotations are missing
     */
    public static synchronized <T extends RawOpInputs<?>> boolean registerCustomGradient(Class<T> cls, CustomGradient<T> customGradient) {
        requireCustomGradientSupport();
        OpInputsMetadata opInputsMetadata = cls.getAnnotation(OpInputsMetadata.class);
        if (opInputsMetadata == null) {
            throw new IllegalArgumentException("Inputs Class " + cls + " does not have a OpInputsMetadata annotation.  Was it generated by tensorflow/java?  If it was, this is a bug.");
        }
        OpMetadata opMetadata = opInputsMetadata.outputsClass().getAnnotation(OpMetadata.class);
        if (opMetadata == null) {
            throw new IllegalArgumentException("Op Class " + opInputsMetadata.outputsClass() + " does not have a OpMetadata annotation.  Was it generated by tensorflow/java?  If it was, this is a bug.");
        }
        String opType = opMetadata.opType();
        if (hasGradient(opType)) {
            return false;
        }
        return registerAdapter(opType, CustomGradient.adapter(customGradient, cls));
    }

    /** Throws if custom gradient registration is unavailable on this platform (currently Windows). */
    private static void requireCustomGradientSupport() {
        if (isWindowsOs()) {
            throw new UnsupportedOperationException("Custom gradient registration is not supported on Windows systems.");
        }
    }

    /**
     * Registers {@code adapter} natively for {@code opType} and, on success, retains it in {@link #gradientFuncs}.
     * Callers must hold the class lock.
     */
    private static boolean registerAdapter(String opType, TFJ_GradFuncAdapter adapter) {
        if (!tensorflow.TFJ_RegisterCustomGradient(opType, adapter)) {
            return false;
        }
        gradientFuncs.add(adapter);
        return true;
    }

    /** Returns whether the {@code os.name} system property starts with "win" (case-insensitive). */
    private static boolean isWindowsOs() {
        return System.getProperty("os.name", "").toLowerCase(Locale.ENGLISH).startsWith("win");
    }

    static {
        try {
            TFJ_RuntimeLibrary.load();
        } catch (Exception e) {
            // Report eagerly: once static initialization fails, later uses of this class surface only a
            // NoClassDefFoundError, which may not carry this underlying cause.
            System.err.println("Failed to load TensorFlow native library");
            e.printStackTrace();
            throw e;
        }
    }
}
