package org.nd4j.autodiff.validation;

import java.io.IOException;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.config.ND4JClassLoading;
import org.nd4j.common.function.Function;
import org.nd4j.common.primitives.Pair;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.api.ops.custom.BarnesHutGains;
import org.nd4j.linalg.api.ops.custom.BarnesHutSymmetrize;
import org.nd4j.linalg.api.ops.custom.BitCast;
import org.nd4j.linalg.api.ops.custom.SpTreeCell;
import org.nd4j.linalg.api.ops.custom.ToggleBits;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMax;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMin;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastRSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual;
import org.nd4j.linalg.api.ops.impl.grid.FreeGridOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex;
import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNormDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Col2Im;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv3DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.SConv2DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Upsampling2dDerivative;
import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp;
import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp;
import org.nd4j.linalg.api.ops.impl.meta.InvertedPredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PostulateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.PredicateMetaOp;
import org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp;
import org.nd4j.linalg.api.ops.impl.nlp.CbowRound;
import org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsInf;
import org.nd4j.linalg.api.ops.impl.reduce.bool.IsNaN;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.reduce3.EqualsWithEps;
import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative;
import org.nd4j.linalg.api.ops.impl.scalar.ScalarRemainder;
import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue;
import org.nd4j.linalg.api.ops.impl.shape.BroadcastDynamicShape;
import org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix;
import org.nd4j.linalg.api.ops.impl.shape.Eye;
import org.nd4j.linalg.api.ops.impl.shape.MergeSum;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.api.ops.impl.shape.ReductionShape;
import org.nd4j.linalg.api.ops.impl.shape.Shape;
import org.nd4j.linalg.api.ops.impl.shape.ShapeN;
import org.nd4j.linalg.api.ops.impl.shape.SizeAt;
import org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp;
import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.Histogram;
import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot;
import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform;
import org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InTopK;
import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalAnd;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalNot;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalOr;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogicalXor;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits;
import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp;
import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SwishDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanDerivative;
import org.nd4j.linalg.api.ops.persistence.RestoreV2;
import org.nd4j.linalg.api.ops.persistence.SaveV2;
import org.nd4j.linalg.api.ops.random.compat.RandomStandardNormal;
import org.nd4j.linalg.api.ops.random.custom.DistributionUniform;
import org.nd4j.linalg.api.ops.random.impl.AlphaDropOut;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution;
import org.nd4j.linalg.api.ops.random.impl.BinomialDistributionEx;
import org.nd4j.linalg.api.ops.random.impl.Choice;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.ops.random.impl.Linspace;
import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge;
import org.nd4j.linalg.api.ops.random.impl.Range;
import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution;
import org.nd4j.linalg.api.ops.random.impl.UniformDistribution;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.shade.guava.collect.ImmutableSet;
import org.nd4j.shade.guava.collect.UnmodifiableIterator;
import org.nd4j.shade.guava.reflect.ClassPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.OpDef;

/* loaded from: input_file:org/nd4j/autodiff/validation/OpValidation.class */
public class OpValidation {
    // Op-class inventory and libnd4j op bookkeeping; not assigned anywhere in the visible section of this file,
    // presumably populated by coverage-reporting code further down (TODO confirm).
    private static List<Class> allOps;
    private static List<Long> nonMappedLibnd4jOps;
    private static Map<Long, Pair<List<String>, CustomOpDescriptor>> dedupedCustomOps;
    private static int countTotalLibnd4jOps;
    private static final Logger log = LoggerFactory.getLogger(OpValidation.class);
    // Per op class: number of validated TestCases whose SameDiff graph contained at least one op of that class
    // (incremented once per distinct class per test case by collectCoverageInformation(TestCase), whether or not
    // the test case enables gradient checking). LinkedHashMap keeps first-seen insertion order.
    private static Map<Class, Integer> gradCheckCoverageCountPerClass = new LinkedHashMap();
    // Presumably: per op class, number of TestCases with a forward-pass check on an output of that op class
    // (collectCoverageInformation collects these classes; the update itself is outside this view - verify).
    private static Map<Class, Integer> fwdPassCoverageCountPerClass = new LinkedHashMap();
    // Presumably: per op class, number of OpTestCase (single-op) validations - verify against usage.
    private static Map<Class, Integer> singleOpTestCountPerClass = new LinkedHashMap();
    // Presumably: TensorFlow-import test counts, keyed by op class / TF op name respectively - verify against usage.
    private static Map<Class, Integer> opsWithTFMappingTFImportCounts = new LinkedHashMap();
    private static Map<String, Integer> tfMappedOpsImportTestCounts = new LinkedHashMap();

    /**
     * Validates a SameDiff {@link TestCase}, letting any exception raised during validation propagate.
     *
     * @param testCase test case to validate
     * @return {@code null} when every configured check passes, otherwise a description of the failure
     */
    public static String validate(TestCase testCase) {
        final boolean exceptionsAsErrorMessages = false;
        return validate(testCase, exceptionsAsErrorMessages);
    }

    /**
     * Validates a SameDiff {@link TestCase}.
     *
     * @param testCase test case to validate
     * @param z        when {@code true}, any exception thrown during validation is logged and converted into an
     *                 error string of the form {@code "EXCEPTION: <message>"}; when {@code false} it is rethrown
     * @return {@code null} when every configured check passes, otherwise a description of the failure
     */
    public static String validate(TestCase testCase, boolean z) {
        final boolean exceptionsAsErrorMessages = z;
        try {
            return validateHelper(testCase);
        } catch (Throwable failure) {
            if (exceptionsAsErrorMessages) {
                log.info("Exception encountered - returning as error message", failure);
                return "EXCEPTION: " + failure.getMessage();
            }
            // Precise rethrow: validateHelper declares no checked exceptions, so no throws clause is needed here.
            throw failure;
        }
    }

    /**
     * Runs every check configured on a {@link TestCase}. The checks are:
     * <ul>
     *     <li>an optional FlatBuffers round trip serialized before execution;</li>
     *     <li>forward-pass output checks;</li>
     *     <li>an optional FlatBuffers round trip serialized after execution;</li>
     *     <li>an optional gradient check.</li>
     * </ul>
     * A {@link NonInplaceValidationListener} is attached to the test's SameDiff instance if none is present.
     *
     * @param testCase test case to validate
     * @return {@code null} if all checks pass, otherwise a description of the first failure
     * @throws RuntimeException      if the forward pass itself fails
     * @throws IllegalStateException if a forward-pass validation function is misconfigured or throws, if a
     *                               serialization round trip differs, or if the gradient check throws
     */
    private static String validateHelper(TestCase testCase) {
        testCase.assertConfigValid();
        collectCoverageInformation(testCase);

        // Serialize before anything executes, so the round trip reflects the graph as the test defined it.
        ByteBuffer serializedBeforeExec = null;
        if (testCase.testFlatBufferSerialization() == TestCase.TestSerialization.BEFORE_EXEC
                || testCase.testFlatBufferSerialization() == TestCase.TestSerialization.BOTH) {
            serializedBeforeExec = testCase.sameDiff().asFlatBuffers(true);
            Preconditions.checkNotNull(serializedBeforeExec, "Serialization failed? Null output");
        }

        SameDiff sameDiff = testCase.sameDiff();
        ensureNonInplaceValidationListener(sameDiff);

        if (testCase.fwdTestFns() != null && !testCase.fwdTestFns().isEmpty()) {
            String forwardError = runForwardPassChecks(testCase, sameDiff);
            if (forwardError != null) {
                return forwardError;
            }

            // Serialize again after execution: previously this was gated on BEFORE_EXEC and the result was discarded.
            ByteBuffer serializedAfterExec = null;
            if (testCase.testFlatBufferSerialization() == TestCase.TestSerialization.AFTER_EXEC
                    || testCase.testFlatBufferSerialization() == TestCase.TestSerialization.BOTH) {
                serializedAfterExec = testCase.sameDiff().asFlatBuffers(true);
                Preconditions.checkNotNull(serializedAfterExec, "Serialization failed? Null output");
            }

            if (serializedBeforeExec != null) {
                checkDeserializedEquality(sameDiff, serializedBeforeExec, testCase);
            }
            if (serializedAfterExec != null) {
                checkDeserializedEquality(sameDiff, serializedAfterExec, testCase);
            }
        }

        if (!testCase.gradientCheck()) {
            return null;
        }
        try {
            if (GradCheckUtil.checkGradients(testCase)) {
                return null;
            }
            return "Gradient check failed" + testCase.testNameErrMsg();
        } catch (Throwable t) {
            // The cause is chained, so the stack trace is not lost by not printing it here.
            throw new IllegalStateException("Exception encountered during gradient check" + testCase.testNameErrMsg(), t);
        }
    }

    /** Adds a {@link NonInplaceValidationListener} to {@code sameDiff} unless one is already registered. */
    private static void ensureNonInplaceValidationListener(SameDiff sameDiff) {
        for (Listener listener : sameDiff.getListeners()) {
            if (listener instanceof NonInplaceValidationListener) {
                return;
            }
        }
        sameDiff.addListeners(new NonInplaceValidationListener());
    }

    /**
     * Executes the forward pass for all variables with expected-result functions and applies those functions.
     *
     * @return {@code null} if every function accepts its output, otherwise the first failure description
     */
    private static String runForwardPassChecks(TestCase testCase, SameDiff sameDiff) {
        Map<String, INDArray> outputs;
        try {
            outputs = sameDiff.output(testCase.placeholderValues(), new ArrayList<>(testCase.fwdTestFns().keySet()));
        } catch (Exception e) {
            throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e);
        }

        for (Map.Entry<String, Function<INDArray, String>> entry : testCase.fwdTestFns().entrySet()) {
            SDVariable variable = sameDiff.getVariable(entry.getKey());
            if (variable == null) {
                throw new IllegalStateException("Test case has expected result function defined for variable \"" + entry.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
            }
            INDArray actual = outputs.get(variable.name());
            if (actual == null) {
                throw new IllegalStateException("Null INDArray after forward pass for variable \"" + entry.getKey() + "\"");
            }

            String error;
            try {
                error = entry.getValue().apply(actual);
            } catch (Throwable t) {
                throw new IllegalStateException("Error checking forward pass for variable \"" + entry.getKey() + "\": exception was thrown by forward pass validation function", t);
            }
            if (error != null) {
                return testCase.testNameErrMsg() + ": Variable " + entry.getKey() + " failed: " + error;
            }
        }
        return null;
    }

    /**
     * Deserializes {@code byteBuffer} and checks that the result matches {@code sameDiff}. The checks cover:
     * <ul>
     *     <li>variable names and order;</li>
     *     <li>op names, inputs, outputs, control dependencies and classes;</li>
     *     <li>placeholders;</li>
     *     <li>per-variable metadata;</li>
     *     <li>loss variables.</li>
     * </ul>
     * If the test case defines forward-pass functions, both instances are also executed on the test's placeholder
     * values and their outputs are compared.
     *
     * @param sameDiff   original instance
     * @param byteBuffer FlatBuffers serialization of {@code sameDiff}
     * @param testCase   supplies placeholder values and forward-pass validation functions
     * @throws RuntimeException      if deserialization fails with an {@link IOException}
     * @throws IllegalStateException if any equality check fails
     */
    public static void checkDeserializedEquality(SameDiff sameDiff, ByteBuffer byteBuffer, TestCase testCase) {
        SameDiff deserialized;
        try {
            deserialized = SameDiff.fromFlatBuffers(byteBuffer);
        } catch (IOException e) {
            throw new RuntimeException("IOException deserializing from FlatBuffers", e);
        }

        checkDeserializedVariableNames(sameDiff, deserialized);
        checkDeserializedOps(sameDiff, deserialized);

        Set<String> placeholdersBefore = collectPlaceholderNames(sameDiff);
        Set<String> placeholdersAfter = collectPlaceholderNames(deserialized);
        Preconditions.checkState(placeholdersBefore.equals(placeholdersAfter), "Before: %s, after deserialization: %s", placeholdersBefore, placeholdersAfter);

        checkDeserializedVariableMetadata(sameDiff, deserialized);

        List<String> lossBefore = sameDiff.getLossVariables();
        List<String> lossAfter = deserialized.getLossVariables();
        if (lossBefore == null || lossBefore.isEmpty()) {
            Preconditions.checkState(lossAfter == null || lossAfter.isEmpty(), "Loss variables should be empty after deserialization, got: %s", lossAfter);
        } else {
            Preconditions.checkState(lossBefore.equals(lossAfter), "Loss variables are not equal after deserialization: %s vs %s", lossBefore, lossAfter);
        }

        if (testCase.fwdTestFns() == null || testCase.fwdTestFns().isEmpty()) {
            return;
        }
        checkDeserializedOutputs(sameDiff, deserialized, testCase);
    }

    /** Checks that both instances list the same variables, by name, in the same order. */
    private static void checkDeserializedVariableNames(SameDiff original, SameDiff deserialized) {
        List<SDVariable> before = original.variables();
        List<SDVariable> after = deserialized.variables();
        Preconditions.checkState(before.size() == after.size(), "Number of variables differs: expected %s, got %s", before.size(), after.size());
        for (int i = 0; i < before.size(); i++) {
            String nameBefore = before.get(i).name();
            String nameAfter = after.get(i).name();
            Preconditions.checkState(nameBefore.equals(nameAfter), "Names should be equal for variable %s: expected %s vs %s", Integer.valueOf(i), nameBefore, nameAfter);
        }
    }

    /** Checks that both instances hold the same ops with identical names, wiring, control dependencies and classes. */
    private static void checkDeserializedOps(SameDiff original, SameDiff deserialized) {
        Map<String, SameDiffOp> before = original.getOps();
        Map<String, SameDiffOp> after = deserialized.getOps();
        Preconditions.checkState(before.keySet().equals(after.keySet()), "Op names differs: %s vs. %s", before.keySet(), after.keySet());
        for (Map.Entry<String, SameDiffOp> entry : before.entrySet()) {
            SameDiffOp o = entry.getValue();
            SameDiffOp d = after.get(entry.getKey());
            Preconditions.checkState(o.getName().equals(d.getName()), "Names differ: %s vs %s", o.getName(), d.getName());
            // Objects.equals rejects "null on one side only" as well as differing contents
            Preconditions.checkState(Objects.equals(o.getInputsToOp(), d.getInputsToOp()), "Inputs differ: %s vs. %s", o.getInputsToOp(), d.getInputsToOp());
            Preconditions.checkState(Objects.equals(o.getOutputsOfOp(), d.getOutputsOfOp()), "Outputs differ: %s vs. %s", o.getOutputsOfOp(), d.getOutputsOfOp());
            Preconditions.checkState(Objects.equals(o.getControlDeps(), d.getControlDeps()), "Control dependencies differ: %s vs. %s", o.getControlDeps(), d.getControlDeps());
            Preconditions.checkState(Objects.equals(o.getVarControlDeps(), d.getVarControlDeps()), "Op variable control dependencies differ: %s vs. %s", o.getVarControlDeps(), d.getVarControlDeps());
            Preconditions.checkState(Objects.equals(o.getControlDepFor(), d.getControlDepFor()), "Op control dependencies for list differ: %s vs. %s", o.getControlDepFor(), d.getControlDepFor());
            Preconditions.checkState(o.getOp().getClass().equals(d.getOp().getClass()), "Classes differ: %s v. %s", o.getOp().getClass(), d.getOp().getClass());
        }
    }

    /** Returns the names of all placeholder variables in {@code sd}. */
    private static Set<String> collectPlaceholderNames(SameDiff sd) {
        Set<String> names = new HashSet<>();
        for (Variable variable : sd.getVariables().values()) {
            if (variable.getVariable().isPlaceHolder()) {
                names.add(variable.getName());
            }
        }
        return names;
    }

    /** Checks per-variable metadata (name, type, producing/consuming ops, control dependencies) for every variable. */
    private static void checkDeserializedVariableMetadata(SameDiff original, SameDiff deserialized) {
        Map<String, Variable> before = original.getVariables();
        Map<String, Variable> after = deserialized.getVariables();
        Preconditions.checkState(before.keySet().equals(after.keySet()), "Variable keysets do not match: %s vs %s", before.keySet(), after.keySet());
        for (String name : before.keySet()) {
            Variable b = before.get(name);
            Variable a = after.get(name);
            Preconditions.checkState(b.getName().equals(a.getName()), "Variable names do not match: %s vs %s", b.getName(), a.getName());
            Preconditions.checkState(b.getVariable().getVariableType() == a.getVariable().getVariableType(), "Variable types do not match: %s - %s vs %s", name, b.getVariable().getVariableType(), a.getVariable().getVariableType());
            equalConsideringNull(b.getInputsForOp(), a.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", name, b.getInputsForOp(), a.getInputsForOp());
            // Null-safe: the previous form threw NPE when only one side had no producing op
            Preconditions.checkState(Objects.equals(b.getOutputOfOp(), a.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", name, b.getOutputOfOp(), a.getOutputOfOp());
            equalConsideringNull(b.getControlDeps(), a.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", name, b.getControlDeps(), a.getControlDeps());
            equalConsideringNull(b.getControlDepsForOp(), a.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", name, b.getControlDepsForOp(), a.getControlDepsForOp());
            equalConsideringNull(b.getControlDepsForVar(), a.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", name, b.getControlDepsForVar(), a.getControlDepsForVar());
        }
    }

    /**
     * Executes both instances with the test's placeholder values and compares every output. A variable's
     * forward-pass function is used when defined; otherwise the output is compared against the original's.
     */
    private static void checkDeserializedOutputs(SameDiff original, SameDiff deserialized, TestCase testCase) {
        Map<String, INDArray> before = original.outputAll(testCase.placeholderValues());
        Map<String, INDArray> after = deserialized.outputAll(testCase.placeholderValues());
        Preconditions.checkState(before.keySet().equals(after.keySet()), "Keysets for execution after deserialization does not match key set for original model");
        for (String name : before.keySet()) {
            INDArray expected = before.get(name);
            INDArray actual = after.get(name);
            Function<INDArray, String> fn = testCase.fwdTestFns().get(name);
            String error = fn != null ? fn.apply(actual) : compareDeserializedArrays(expected, actual);
            Preconditions.checkState(error == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", name, expected, actual, expected, actual, error);
        }
    }

    /**
     * Compares two arrays. Plain {@code equals} is tried first. If that fails, numeric arrays that contain the same
     * number of NaNs are compared element-wise, treating NaN and infinite values position-by-position and allowing
     * an absolute difference of at most 1e-5.
     *
     * @return {@code null} if the arrays match, otherwise an error description
     */
    private static String compareDeserializedArrays(INDArray expected, INDArray actual) {
        if (expected.equals(actual)) {
            return null;
        }
        if (!expected.dataType().isNumerical()) {
            return "INDArray equality failed";
        }
        long nanCount = countNaNValues(expected);
        if (nanCount <= 0 || !expected.equalShapes(actual) || nanCount != countNaNValues(actual)) {
            return "INDArray equality failed";
        }
        NdIndexIterator iterator = new NdIndexIterator(expected.shape());
        while (iterator.hasNext()) {
            long[] idx = iterator.next();
            double e = expected.getDouble(idx);
            double a = actual.getDouble(idx);
            // NaN - NaN is NaN and "NaN > eps" is false, so matching NaN positions pass
            if (Double.isNaN(e) != Double.isNaN(a) || Double.isInfinite(e) != Double.isInfinite(a) || Math.abs(e - a) > 1.0E-5d) {
                return "INDArray equality failed";
            }
        }
        return null;
    }

    /** Returns the number of NaN elements in {@code arr}. */
    private static long countNaNValues(INDArray arr) {
        return Nd4j.getExecutioner().execAndReturn((ReduceOp) new MatchCondition(arr, Conditions.isNan(), new int[0])).getFinalResult().longValue();
    }

    /**
     * Asserts that two string lists are equal, treating {@code null} and an empty list as equivalent.
     *
     * @param list   first list (may be {@code null})
     * @param list2  second list (may be {@code null})
     * @param str    {@link Preconditions} message format used on failure
     * @param objArr arguments for {@code str}
     * @throws IllegalStateException if the lists differ
     */
    protected static void equalConsideringNull(List<String> list, List<String> list2, String str, Object... objArr) {
        boolean firstEmpty = list == null || list.isEmpty();
        boolean secondEmpty = list2 == null || list2.isEmpty();
        if (firstEmpty && secondEmpty) {
            return;
        }
        // Objects.equals: previously a null first list wrongly matched any non-empty second list
        Preconditions.checkState(Objects.equals(list, list2), str, objArr);
    }

    /**
     * Validates a single-op {@link OpTestCase}. First the op's output shape function is checked against the expected
     * shapes. Then the op is executed and each output is checked with its validation function.
     *
     * @param opTestCase test case to validate
     * @return {@code null} if all checks pass, otherwise a description of the first failure
     * @throws IllegalStateException if shape calculation, op execution or an output validation function throws
     */
    public static String validate(OpTestCase opTestCase) {
        collectCoverageInformation(opTestCase);

        List<LongShapeDescriptor> actualShapes;
        try {
            actualShapes = Nd4j.getExecutioner().calculateOutputShape(opTestCase.op());
        } catch (Throwable t) {
            throw new IllegalStateException("Error calculating output shapes during op validation", t);
        }

        int expectedOutputs = opTestCase.testFns().size();
        if (actualShapes.size() != expectedOutputs) {
            return "Expected number of output shapes and number of outputs differ. " + actualShapes.size() + " output shapes, but OpTestCase specifies " + expectedOutputs + " outputs expected";
        }

        for (int i = 0; i < actualShapes.size(); i++) {
            LongShapeDescriptor actual = actualShapes.get(i);
            LongShapeDescriptor expected = opTestCase.expShapes().get(Integer.valueOf(i));
            // A missing expected shape is reported as a failure rather than an NPE
            if (expected == null
                    || !Objects.equals(expected.dataType(), actual.dataType())
                    || !Arrays.equals(actual.getShape(), expected.getShape())) {
                return "Shape function check failed for output " + i + ": expected shape " + expected + ", actual shape " + actual;
            }
        }

        try {
            Nd4j.getExecutioner().execAndReturn(opTestCase.op());
        } catch (Throwable t) {
            throw new IllegalStateException("Error during op execution", t);
        }

        for (int i = 0; i < expectedOutputs; i++) {
            String error;
            try {
                error = (String) opTestCase.testFns().get(Integer.valueOf(i)).apply(opTestCase.op().outputArguments().get(i));
            } catch (Throwable t) {
                throw new IllegalStateException("Exception thrown during op output validation for output " + i, t);
            }
            if (error != null) {
                return "Output " + i + " failed: " + error;
            }
        }
        return null;
    }

    /**
     * Records coverage statistics for a SameDiff-based test case.
     * <p>
     * The gradient-check count is incremented once for each distinct op class present in the graph. For each
     * variable that has a forward-pass test function, the op that produces it (if any) is looked up. The
     * forward-pass count is then incremented once for each distinct class among those producer ops.
     *
     * @param testCase the test case whose graph and forward-pass test functions are inspected
     */
    private static void collectCoverageInformation(TestCase testCase) {
        SameDiff graph = testCase.sameDiff();

        Set<Class> graphOpClasses = new HashSet<>();
        for (DifferentialFunction fn : graph.ops()) {
            graphOpClasses.add(fn.getClass());
        }
        for (Class opClass : graphOpClasses) {
            gradCheckCoverageCountPerClass.put(opClass, gradCheckCoverageCountPerClass.getOrDefault(opClass, 0) + 1);
        }

        if (testCase.fwdTestFns() == null) {
            return;
        }
        Set<Class> testedProducerClasses = new HashSet<>();
        for (String varName : testCase.fwdTestFns().keySet()) {
            DifferentialFunction producer = graph.getVariableOutputOp(varName);
            if (producer != null) {
                testedProducerClasses.add(producer.getClass());
            }
        }
        for (Class opClass : testedProducerClasses) {
            fwdPassCoverageCountPerClass.put(opClass, fwdPassCoverageCountPerClass.getOrDefault(opClass, 0) + 1);
        }
    }

    /**
     * Increments the single-op test count for the class of the op under test.
     * A class not yet in the map starts at 1.
     *
     * @param opTestCase the single-op test case being validated
     */
    private static void collectCoverageInformation(OpTestCase opTestCase) {
        Class<?> opClass = opTestCase.op().getClass();
        int updated = singleOpTestCountPerClass.containsKey(opClass)
                ? singleOpTestCountPerClass.get(opClass).intValue() + 1
                : 1;
        singleOpTestCountPerClass.put(opClass, updated);
    }

    /**
     * Records TensorFlow-import coverage for every op in an imported graph.
     * <p>
     * For each op that reports at least one TensorFlow name, this increments three counts:
     * <ul>
     *   <li>the per-class TF-import count;</li>
     *   <li>the per-class forward-pass coverage count;</li>
     *   <li>the per-TF-name import-test count, for each of its names.</li>
     * </ul>
     * Ops whose {@code tensorflowNames()} throws, or returns null or an empty array, are skipped.
     *
     * @param sameDiff graph produced by a TensorFlow import test
     */
    public static void collectTensorflowImportCoverage(SameDiff sameDiff) {
        for (SameDiffOp sameDiffOp : sameDiff.getOps().values()) {
            DifferentialFunction op = sameDiffOp.getOp();
            String[] tensorflowNames;
            try {
                tensorflowNames = op.tensorflowNames();
            } catch (Throwable ignored) {
                // Best-effort: an op whose tensorflowNames() throws is treated as having no TF mapping.
                // Only this lookup is guarded; failures in the bookkeeping below are not hidden.
                continue;
            }
            if (tensorflowNames == null || tensorflowNames.length == 0) {
                continue;
            }

            Class<?> opClass = op.getClass();
            Integer importCount = opsWithTFMappingTFImportCounts.get(opClass);
            opsWithTFMappingTFImportCounts.put(opClass, importCount == null ? 1 : importCount.intValue() + 1);

            Integer fwdCount = fwdPassCoverageCountPerClass.get(opClass);
            fwdPassCoverageCountPerClass.put(opClass, fwdCount == null ? 1 : fwdCount.intValue() + 1);

            for (String tfName : tensorflowNames) {
                Integer nameCount = tfMappedOpsImportTestCounts.get(tfName);
                tfMappedOpsImportTestCounts.put(tfName, nameCount == null ? 1 : nameCount.intValue() + 1);
            }
        }
    }

    /**
     * Populates the static coverage tables that {@code logCoverageInformation} reports on.
     * <ul>
     *   <li>{@code dedupedCustomOps}: libnd4j custom ops reported by the executioner, grouped by descriptor
     *       hash. All op names sharing a hash are collected in one list.</li>
     *   <li>{@code allOps}: concrete (non-abstract, non-interface) {@link DifferentialFunction} classes found
     *       under {@code org.nd4j.linalg.api.ops}, sorted by class name. Classes whose simple name contains
     *       "Old" are not added.</li>
     *   <li>{@code nonMappedLibnd4jOps}: hashes of custom ops whose name is not the {@code opName()} of any
     *       instantiable op class, sorted by the first name recorded for each hash.</li>
     * </ul>
     * Every class in {@code allOps} then gets a zero count in the gradient-check, forward-pass and single-op
     * coverage maps.
     *
     * @throws RuntimeException if the classpath cannot be scanned
     * @throws NullPointerException if a scanned class name cannot be loaded
     */
    private static void initializeCoverage() {
        Set<ClassPath.ClassInfo> candidateClasses;
        try {
            candidateClasses = ClassPath.from(DifferentialFunctionClassHolder.class.getClassLoader())
                    .getTopLevelClassesRecursive("org.nd4j.linalg.api.ops");
        } catch (IOException e) {
            throw new RuntimeException("Failed to scan the classpath for op classes in org.nd4j.linalg.api.ops", e);
        }

        Map<String, CustomOpDescriptor> customOperations = Nd4j.getExecutioner().getCustomOperations();

        // Several op names can share one descriptor hash; group the names per hash.
        dedupedCustomOps = new HashMap<>();
        for (Map.Entry<String, CustomOpDescriptor> entry : customOperations.entrySet()) {
            Long hash = entry.getValue().getHash();
            Pair<List<String>, CustomOpDescriptor> namesAndDescriptor = dedupedCustomOps.get(hash);
            if (namesAndDescriptor == null) {
                namesAndDescriptor = new Pair<List<String>, CustomOpDescriptor>(new ArrayList<String>(), entry.getValue());
                dedupedCustomOps.put(hash, namesAndDescriptor);
            }
            List<String> names = namesAndDescriptor.getFirst();
            if (!names.contains(entry.getKey())) {
                names.add(entry.getKey());
            }
        }

        // Start with every hash as unmapped; remove the ones matched by an op class below.
        Set<Long> unmappedHashes = new HashSet<>(dedupedCustomOps.keySet());
        allOps = new ArrayList<>(gradCheckCoverageCountPerClass.keySet());
        for (ClassPath.ClassInfo classInfo : candidateClasses) {
            Class<?> cls = ND4JClassLoading.loadClassByName(classInfo.getName());
            Objects.requireNonNull(cls, "Could not load op class " + classInfo.getName());
            if (Modifier.isAbstract(cls.getModifiers()) || cls.isInterface()
                    || !DifferentialFunction.class.isAssignableFrom(cls)) {
                continue;
            }
            if (!cls.getSimpleName().contains("Old")) {
                allOps.add(cls);
            }

            // Instantiate to obtain the op name; classes that cannot be instantiated are logged and skipped.
            String opName = null;
            try {
                opName = ((DifferentialFunction) cls.getDeclaredConstructor().newInstance()).opName();
            } catch (Exception e) {
                log.warn("Could not instantiate object of type {}", cls.getName(), e);
            }
            if (opName != null) {
                CustomOpDescriptor descriptor = customOperations.get(opName);
                if (descriptor != null) {
                    unmappedHashes.remove(descriptor.getHash());
                }
            }
        }

        countTotalLibnd4jOps = dedupedCustomOps.size();
        nonMappedLibnd4jOps = new ArrayList<>(unmappedHashes);
        Collections.sort(nonMappedLibnd4jOps, new Comparator<Long>() {
            @Override
            public int compare(Long a, Long b) {
                // Order by the first op name recorded for each hash.
                String nameA = dedupedCustomOps.get(a).getFirst().get(0);
                String nameB = dedupedCustomOps.get(b).getFirst().get(0);
                return nameA.compareTo(nameB);
            }
        });
        Collections.sort(allOps, new Comparator<Class>() {
            @Override
            public int compare(Class a, Class b) {
                return a.getName().compareTo(b.getName());
            }
        });
        for (Class cls : allOps) {
            gradCheckCoverageCountPerClass.put(cls, 0);
            fwdPassCoverageCountPerClass.put(cls, 0);
            singleOpTestCountPerClass.put(cls, 0);
        }
    }

    /**
     * Logs the op test coverage collected so far, followed by a summary table.
     * <p>
     * The counts behind the summary are always computed, whichever listings are enabled. The summary is
     * therefore the same for any combination of flags. Percentages use floating-point division, and a zero
     * denominator is reported as 0.00%.
     *
     * @param logAdequatelyTested   log op classes that have forward tests and either gradient-check tests or
     *                              an exclusion from gradient-check coverage
     * @param logInadequate         log op classes with no forward tests, or no gradient-check tests while
     *                              not excluded from them
     * @param logUnmappedLibnd4jOps log libnd4j custom ops that no Java op class maps to (excluding ignored ones)
     * @param logUntestedTFImport   log TensorFlow op names that have a mapping but no recorded import test
     * @param logUnmappedTFOps      log TensorFlow op names that have no mapping (excluding ignored ones)
     */
    public static void logCoverageInformation(boolean logAdequatelyTested, boolean logInadequate,
                                              boolean logUnmappedLibnd4jOps, boolean logUntestedTFImport,
                                              boolean logUnmappedTFOps) {
        Set<Class> excludedFromGradientCheckCoverage = excludedFromGradientCheckCoverage();
        Set<Class> excludedFromAllTests = excludedFromAllTests();

        // --- Per-class forward / gradient-check coverage ---
        int countAdequate = 0;     // classes with both forward and gradient-check tests
        int countAdequateBwd = 0;  // classes with at least one gradient-check test
        int countAdequateFwd = 0;  // classes with at least one forward-pass or single-op test
        int totalFwd = 0;          // classes not excluded from all tests
        if (logAdequatelyTested) {
            log.info(" --- Adequately Tested Classes ---");
        }
        for (Class cls : allOps) {
            if (excludedFromAllTests.contains(cls)) {
                continue;
            }
            totalFwd++;
            int bwdCount = gradCheckCoverageCountPerClass.get(cls).intValue();
            int fwdCount = fwdPassCoverageCountPerClass.get(cls).intValue() + singleOpTestCountPerClass.get(cls).intValue();
            if (bwdCount > 0) {
                countAdequateBwd++;
            }
            if (fwdCount > 0) {
                countAdequateFwd++;
            }
            if (fwdCount > 0 && bwdCount > 0) {
                countAdequate++;
            }
            boolean gradExcluded = excludedFromGradientCheckCoverage.contains(cls);
            if (logAdequatelyTested && fwdCount > 0 && (bwdCount > 0 || gradExcluded)) {
                logOpCoverageLine(cls, fwdCount, bwdCount, gradExcluded);
            }
        }

        if (logInadequate) {
            log.info(" --- Classes NOT Tested Adequately ---");
            for (Class cls : allOps) {
                if (excludedFromAllTests.contains(cls)) {
                    continue;
                }
                int bwdCount = gradCheckCoverageCountPerClass.get(cls).intValue();
                int fwdCount = fwdPassCoverageCountPerClass.get(cls).intValue() + singleOpTestCountPerClass.get(cls).intValue();
                boolean gradExcluded = excludedFromGradientCheckCoverage.contains(cls);
                if (fwdCount == 0 || (bwdCount == 0 && !gradExcluded)) {
                    logOpCoverageLine(cls, fwdCount, bwdCount, gradExcluded);
                }
            }
        }

        // --- libnd4j custom ops with no Java op class ---
        Set<String> excludeFromLibnd4jCustomOpMapping = excludeFromLibnd4jCustomOpMapping();
        int countLibnd4jIgnored = 0;
        if (logUnmappedLibnd4jOps) {
            log.info(" --- Libnd4j Ops Not Mapped ---");
        }
        for (Long hash : nonMappedLibnd4jOps) {
            Pair<List<String>, CustomOpDescriptor> pair = dedupedCustomOps.get(hash);
            boolean ignored = false;
            for (String opName : pair.getFirst()) {
                if (excludeFromLibnd4jCustomOpMapping.contains(opName)) {
                    ignored = true;
                    break;
                }
            }
            if (ignored) {
                countLibnd4jIgnored++;
            } else if (logUnmappedLibnd4jOps) {
                log.info("Not mapped libnd4j custom op: {} (hash: {})", pair.getFirst(), hash);
            }
        }

        // --- TF-mapped ops and their import tests ---
        Map<String, DifferentialFunction> tensorFlowNames = DifferentialFunctionClassHolder.getInstance().getTensorFlowNames();
        int totalTfMappedNames = tensorFlowNames.size();
        if (logUntestedTFImport) {
            log.info(" --- Ops with TF Mapping but No TF Import Tests ---");
        }
        List<String> sortedTfNames = new ArrayList<>(tensorFlowNames.keySet());
        Collections.sort(sortedTfNames);
        Set<String> excludeFromTfImportCoverage = excludeFromTfImportCoverage();
        int tfOpsWithImportTests = 0;
        int tfImportIgnored = 0;
        for (String tfName : sortedTfNames) {
            Integer count = tfMappedOpsImportTestCounts.get(tfName);
            if (count != null && count.intValue() != 0) {
                tfOpsWithImportTests++;
            } else if (excludeFromTfImportCoverage.contains(tfName)) {
                tfImportIgnored++;
            } else if (logUntestedTFImport) {
                log.info("TF mapped op with no import tests: {}", tfName);
            }
        }

        if (logUnmappedTFOps) {
            log.info(" --- TF Ops Not Mapped for Import ---");
            try {
                Map<String, OpDef> opDescs = TensorflowDescriptorParser.opDescs();
                List<String> unmappedTfOps = new ArrayList<>();
                for (String tfName : opDescs.keySet()) {
                    if (DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(tfName) == null
                            && !excludeFromTfImportCoverage.contains(tfName)) {
                        unmappedTfOps.add(tfName);
                    }
                }
                Collections.sort(unmappedTfOps);
                // Log in groups of 10. The final group may be smaller and must not be dropped.
                for (int from = 0; from < unmappedTfOps.size(); from += 10) {
                    log.info("TF ops not mapped for import: {}", unmappedTfOps.subList(from, Math.min(from + 10, unmappedTfOps.size())));
                }
            } catch (Throwable th) {
                throw new RuntimeException(th);
            }
        }

        // --- Summary ---
        int totalBwd = 0;
        for (Class cls : allOps) {
            if (!isBackpropOp(cls)) {
                totalBwd++;
            }
        }
        int countTotalTfOps = DifferentialFunctionClassHolder.getInstance().getCountTotalTfOps();
        int countTotalMappedOps = DifferentialFunctionClassHolder.getInstance().getCountTotalMappedOps();
        int countLibnd4jMapped = countTotalLibnd4jOps - nonMappedLibnd4jOps.size();

        String pcFwd = formatPercent(countAdequateFwd, totalFwd);
        String pcBwd = formatPercent(countAdequateBwd, totalBwd);
        // Same denominator as the one shown in the log line below.
        String pcAdequate = formatPercent(countAdequate, totalFwd);
        String pcTfMapped = formatPercent(countTotalMappedOps, countTotalTfOps);
        String pcLibnd4j = formatPercent(countLibnd4jMapped, countTotalLibnd4jOps - countLibnd4jIgnored);
        String pcTfImportTested = formatPercent(tfOpsWithImportTests, totalTfMappedNames - tfImportIgnored);

        log.info("*****************************************************");
        log.info("Op Validation:                        {} of {} classes with adequate tests ({}% coverage)", new Object[]{Integer.valueOf(countAdequate), Integer.valueOf(totalFwd), pcAdequate});
        log.info("Forward pass tests:                   {} of {} classes ({}% coverage)", new Object[]{Integer.valueOf(countAdequateFwd), Integer.valueOf(totalFwd), pcFwd});
        log.info("Gradient check tests:                 {} of {} classes ({}% coverage)", new Object[]{Integer.valueOf(countAdequateBwd), Integer.valueOf(totalBwd), pcBwd});
        log.info("({} ops excluded from gradient check coverage)", Integer.valueOf(excludedFromGradientCheckCoverage.size()));
        log.info("({} ops excluded from fwd+gradient tests)", Integer.valueOf(excludedFromAllTests.size()));
        log.info("TF mapped ops:                        {} of {} ({}%)", new Object[]{Integer.valueOf(countTotalMappedOps), Integer.valueOf(countTotalTfOps), pcTfMapped});
        log.info("SD ops with TF import mapping + test  {} of {} ({}%) - {} ignored for coverage", new Object[]{Integer.valueOf(tfOpsWithImportTests), Integer.valueOf(totalTfMappedNames - tfImportIgnored), pcTfImportTested, Integer.valueOf(tfImportIgnored)});
        log.info("Libnd4j mapped ops:                   {} of {} ({}%) - {} excluded for coverage", new Object[]{Integer.valueOf(countLibnd4jMapped), Integer.valueOf(countTotalLibnd4jOps), pcLibnd4j, Integer.valueOf(countLibnd4jIgnored)});
        log.info("*****************************************************");
    }

    /**
     * Logs one per-class coverage line. The gradient-check count is shown as {@code <excluded>} when the
     * class is excluded from gradient-check coverage.
     */
    private static void logOpCoverageLine(Class<?> cls, int fwdCount, int bwdCount, boolean gradExcluded) {
        if (gradExcluded) {
            log.info("Forward: {} tests, GradCheck: <excluded> for op {}", String.format("%3d", Integer.valueOf(fwdCount)), cls.getName());
        } else {
            log.info("Forward: {} tests, GradCheck: {} tests  for op {}", new Object[]{String.format("%3d", Integer.valueOf(fwdCount)), String.format("%3d", Integer.valueOf(bwdCount)), cls.getName()});
        }
    }

    /**
     * Formats {@code 100 * numerator / denominator} with two decimal places, using floating-point division.
     * Returns {@code "0.00"} when the denominator is not positive, instead of dividing by zero.
     */
    private static String formatPercent(int numerator, int denominator) {
        double percent = denominator > 0 ? (100.0d * numerator) / denominator : 0.0d;
        return String.format("%.2f", Double.valueOf(percent));
    }

    /**
     * Heuristically decides whether a class is a backprop/gradient op from its simple name.
     *
     * @param cls class to inspect
     * @return true if the simple name contains "Bp", "Derivative" or "Grad"
     */
    private static boolean isBackpropOp(Class<?> cls) {
        final String name = cls.getSimpleName();
        for (String marker : new String[]{"Bp", "Derivative", "Grad"}) {
            if (name.contains(marker)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the op classes excluded from all per-class coverage statistics in
     * {@code logCoverageInformation}. They are skipped in the per-class listings and not counted in the
     * forward-pass total. Judging by the names, the list consists largely of backprop/derivative ops,
     * broadcast ops, meta ops and other helper op classes.
     *
     * @return a new mutable set on every call
     */
    private static Set<Class> excludedFromAllTests() {
        // NOTE(review): TanhDerivative.class and HashCode.class are listed twice. This is harmless for a HashSet,
        // but a different class may have been intended - confirm.
        return new HashSet(Arrays.asList(DynamicCustomOp.class, GradientBackwardsMarker.class, EqualsWithEps.class, FreeGridOp.class, MergeSum.class, ScalarRemainder.class, RestoreV2.class, SaveV2.class, ScalarSetValue.class, BinomialDistributionEx.class, BroadcastAMax.class, BroadcastAMin.class, BroadcastAddOp.class, BroadcastCopyOp.class, BroadcastDivOp.class, BroadcastEqualTo.class, BroadcastGreaterThan.class, BroadcastGreaterThanOrEqual.class, BroadcastLessThan.class, BroadcastLessThanOrEqual.class, BroadcastMax.class, BroadcastMin.class, BroadcastMulOp.class, BroadcastNotEqual.class, BroadcastRDivOp.class, BroadcastRSubOp.class, BroadcastSubOp.class, AddBpOp.class, DivBpOp.class, FloorDivBpOp.class, FloorModBpOp.class, MulBpOp.class, RDivBpOp.class, RSubBpOp.class, SquaredDifferenceBpOp.class, SubBpOp.class, CumProdBp.class, DotBp.class, SquaredNormBp.class, SoftmaxBp.class, CubeDerivative.class, GELUDerivative.class, PreciseGELUDerivative.class, HardSigmoidDerivative.class, HardTanhDerivative.class, LeakyReLUDerivative.class, LogSoftMaxDerivative.class, RationalTanhDerivative.class, RectifiedTanhDerivative.class, Relu6Derivative.class, PReluBp.class, SELUDerivative.class, SigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class, SoftSignDerivative.class, TanhDerivative.class, SwishDerivative.class, TanDerivative.class, TanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class, PowDerivative.class, RectifiedLinearDerivative.class, CubeBp.class, EluBp.class, HardSigmoidBp.class, HardTanhBp.class, LeakyReLUBp.class, RationalTanhBp.class, RectifiedTanhBp.class, SeluBp.class, SoftPlusBp.class, SoftSignBp.class, ThresholdReluBp.class, ModBpOp.class, BiasAddGrad.class, ConcatBp.class, TileBp.class, BatchNormDerivative.class, Conv2DDerivative.class, Conv3DDerivative.class, DeConv2DDerivative.class, LocalResponseNormalizationDerivative.class, Pooling2DDerivative.class, Pooling3DDerivative.class, 
SConv2DDerivative.class, Upsampling2dDerivative.class, Im2colBp.class, SliceBp.class, StridedSliceBp.class, MmulBp.class, DotProductAttentionBp.class, MultiHeadDotProductAttentionBp.class, LayerNormBp.class, StandardizeBp.class, DynamicPartitionBp.class, AbsoluteDifferenceLossBp.class, CosineDistanceLossBp.class, HingeLossBp.class, HuberLossBp.class, LogLossBp.class, LogPoissonLossBp.class, MeanPairwiseSquaredErrorLossBp.class, MeanSquaredErrorLossBp.class, SigmoidCrossEntropyLossBp.class, SoftmaxCrossEntropyLossBp.class, SparseSoftmaxCrossEntropyLossWithLogitsBp.class, SegmentMaxBp.class, SegmentMeanBp.class, SegmentMinBp.class, SegmentProdBp.class, SegmentSumBp.class, UnsortedSegmentMaxBp.class, UnsortedSegmentMeanBp.class, UnsortedSegmentMinBp.class, UnsortedSegmentProdBp.class, UnsortedSegmentSqrtNBp.class, UnsortedSegmentSumBp.class, ExternalErrorsFunction.class, InvertedPredicateMetaOp.class, PostulateMetaOp.class, PredicateMetaOp.class, ReduceMetaOp.class, BarnesEdgeForces.class, BarnesHutGains.class, BarnesHutSymmetrize.class, SpTreeCell.class, CbowRound.class, SkipGramRound.class, HashCode.class, HashCode.class, BitCast.class, ToggleBits.class));
    }

    /**
     * Returns op classes that are not expected to have gradient-check tests. In
     * {@code logCoverageInformation}, such a class counts as adequately tested once it has forward-pass
     * tests, and its gradient-check count is shown as {@code <excluded>}.
     *
     * @return a new mutable set on every call
     */
    private static Set<Class> excludedFromGradientCheckCoverage() {
        // NOTE(review): BinaryMinimalRelativeError.class and ConfusionMatrix.class are listed twice. This is
        // harmless for a HashSet, but a different class may have been intended - confirm.
        return new HashSet(Arrays.asList(DynamicCustomOp.class, EqualsWithEps.class, ConfusionMatrix.class, Eye.class, OneHot.class, BinaryMinimalRelativeError.class, BinaryMinimalRelativeError.class, InvertPermutation.class, ConfusionMatrix.class, Linspace.class, Assert.class, Any.class, All.class, IsInf.class, org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf.class, IsNaN.class, org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class, BooleanNot.class, Not.class, MatchConditionTransform.class, InTopK.class, IsNonDecreasing.class, IsStrictlyIncreasing.class, IsNumericTensor.class, FirstIndex.class, LastIndex.class, ArgMax.class, ArgMin.class, Shape.class, ShapeN.class, SizeAt.class, BroadcastDynamicShape.class, ReductionShape.class, ShiftBits.class, RShiftBits.class, BitsHammingDistance.class, CyclicShiftBits.class, CyclicRShiftBits.class, RandomStandardNormal.class, DistributionUniform.class, AlphaDropOut.class, BernoulliDistribution.class, BinomialDistribution.class, BinomialDistributionEx.class, Choice.class, DropOut.class, DropOutInverted.class, GaussianDistribution.class, LogNormalDistribution.class, ProbablisticMerge.class, Range.class, TruncatedNormalDistribution.class, UniformDistribution.class, Col2Im.class, NormalizeMoments.class, CumProdBp.class, CumSumBp.class, DotBp.class, MaxBp.class, MeanBp.class, MinBp.class, Norm1Bp.class, Norm2Bp.class, NormMaxBp.class, ProdBp.class, StandardDeviationBp.class, SumBp.class, VarianceBp.class, LogicalAnd.class, LogicalNot.class, LogicalOr.class, LogicalXor.class, Histogram.class));
    }

    /**
     * Returns TensorFlow op names excluded from TF-import coverage. In {@code logCoverageInformation}, a
     * mapped name in this set that has no import test is counted as ignored rather than untested. Names in
     * this set are also left out of the "TF ops not mapped for import" listing.
     *
     * @return a new mutable set on every call
     */
    private static Set<String> excludeFromTfImportCoverage() {
        return new HashSet(Arrays.asList("Reverse", "LogSigmoid", "HardSigmoid", "SpaceToBatch", "BatchToSpace", "Pad", "TopK", "InTopK", "BatchMatrixDeterminant", "BatchMatrixDiagPart", "BatchMatrixDiag", "BatchMatrixBandPart", "BatchMatrixInverse", "BatchMatrixSetDiag", "BatchMatrixSolve", "BatchMatrixSolveLs", "BatchMatrixTriangularSolve", "BatchSelfAdjointEig", "BatchSelfAdjointEigV2", "BatchSvd", "ExperimentalBytesProducedStatsDataset", "ExperimentalCSVDataset", "ExperimentalDatasetCardinality", "ExperimentalDatasetToTFRecord", "ExperimentalDenseToSparseBatchDataset", "ExperimentalDirectedInterleaveDataset", "ExperimentalGroupByReducerDataset", "ExperimentalGroupByWindowDataset", "ExperimentalIdentityIndexedDataset", "ExperimentalIgnoreErrorsDataset", "ExperimentalIndexedDatasetGet", "ExperimentalIndexedDatasetMaterialize", "ExperimentalIteratorGetDevice", "ExperimentalLMDBDataset", "ExperimentalLatencyStatsDataset", "ExperimentalMapAndBatchDataset", "ExperimentalMapDataset", "ExperimentalMatchingFilesDataset", "ExperimentalMaterializedIndexDatasetHandle", "ExperimentalMaxIntraOpParallelismDataset", "ExperimentalNonSerializableDataset", "ExperimentalNumaMapAndBatchDataset", "ExperimentalParallelInterleaveDataset", "ExperimentalParseExampleDataset", "ExperimentalPrivateThreadPoolDataset", "ExperimentalRandomDataset", "ExperimentalScanDataset", "ExperimentalSetStatsAggregatorDataset", "ExperimentalSleepDataset", "ExperimentalSlidingWindowDataset", "ExperimentalSqlDataset", "ExperimentalStatsAggregatorHandle", "ExperimentalStatsAggregatorSummary", "ExperimentalThreadPoolDataset", "ExperimentalThreadPoolHandle", "ExperimentalUnbatchDataset", "ExperimentalUniqueDataset", "DebugIdentity", "NcclAllReduce", "NcclBroadcast", "NcclReduce", "PyFunc", "PyFuncStateless", "QuantizedAdd", "QuantizedAvgPool", "QuantizedBatchNormWithGlobalNormalization", "QuantizedBiasAdd", "QuantizedConcat", "QuantizedConv2D", "QuantizedInstanceNorm", "QuantizedMatMul", "QuantizedMaxPool", 
"QuantizedMul", "QuantizedRelu", "QuantizedRelu6", "QuantizedReluX", "QuantizedReshape", "QuantizedResizeBilinear", "HardTanh", "Swish", "RDiv", "DivScalar", "LogX", "RationalTanh", "absargmax", "absargmin", "entropy_shannon", "count_zero", "SaveV2", "LoadV2", "RestoreV2", "RandomCrop"));
    }

    /**
     * Returns libnd4j custom op names that are not expected to have a Java op class. In
     * {@code logCoverageInformation}, an unmapped op with any of these names is counted as excluded rather
     * than reported as not mapped.
     *
     * @return a new mutable set on every call
     */
    private static Set<String> excludeFromLibnd4jCustomOpMapping() {
        String[] ignoredOpNames = {
                "TestOp2i2o", "testop2i2o", "firas_sparse", "test_output_reshape", "test_scalar",
                "testcustom", "testreduction", "to_double", "to_float16", "to_float32",
                "to_int32", "to_int64", "to_uint32", "to_uint64"
        };
        return new HashSet<>(Arrays.asList(ignoredOpNames));
    }

    // Build the coverage tables once, when this class is initialized.
    static {
        initializeCoverage();
    }
}
