package org.jpmml.evaluator.neural_network;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.Multimaps;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasField;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.neural_network.NeuralInput;
import org.dmg.pmml.neural_network.NeuralInputs;
import org.dmg.pmml.neural_network.NeuralLayer;
import org.dmg.pmml.neural_network.NeuralNetwork;
import org.dmg.pmml.neural_network.NeuralOutput;
import org.dmg.pmml.neural_network.NeuralOutputs;
import org.dmg.pmml.neural_network.Neuron;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EntityUtil;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.HasEntityRegistry;
import org.jpmml.evaluator.InvalidFeatureException;
import org.jpmml.evaluator.InvalidResultException;
import org.jpmml.evaluator.MissingFieldException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.NormalizationUtil;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.TargetField;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedFeatureException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;

/**
 * Evaluator for PMML {@link NeuralNetwork} models.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions under the
 * {@code float} and {@code double} math contexts. Every {@link NeuralInput} and {@link Neuron}
 * is registered as an {@link Entity} (see {@link #getEntityRegistry()}); classification results
 * are built from the registered output-neuron entities.
 *
 * <p>NOTE(review): this class was recovered by decompilation and {@code evaluateRaw} is still a
 * stub that always throws, so {@link #evaluate(ModelEvaluationContext)} cannot currently produce
 * a prediction.
 */
public class NeuralNetworkEvaluator extends ModelEvaluator<NeuralNetwork> implements HasEntityRegistry<Entity> {

    /** Lazily built index of neural outputs by the field their output expression refers to. */
    private transient Map<FieldName, List<NeuralOutput>> neuralOutputMap = null;

    /** Lazily fetched (from {@link #entityCache}) registry of inputs and neurons. */
    private transient BiMap<String, Entity> entityRegistry = null;

    /**
     * Per-model cache of the entity registry. All neural inputs are registered first, then the
     * neurons of each layer in document order. {@link EntityUtil#put} receives a shared counter
     * starting at 1 (presumably used to number entities that lack an explicit id — confirm).
     */
    private static final LoadingCache<NeuralNetwork, BiMap<String, Entity>> entityCache = CacheUtil.buildLoadingCache(new CacheLoader<NeuralNetwork, BiMap<String, Entity>>() {

        @Override
        public BiMap<String, Entity> load(NeuralNetwork neuralNetwork) {
            ImmutableBiMap.Builder<String, Entity> builder = new ImmutableBiMap.Builder<>();

            AtomicInteger index = new AtomicInteger(1);

            for (NeuralInput neuralInput : neuralNetwork.getNeuralInputs()) {
                builder = EntityUtil.put(neuralInput, index, builder);
            }

            for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
                for (Neuron neuron : neuralLayer.getNeurons()) {
                    builder = EntityUtil.put(neuron, index, builder);
                }
            }

            return builder.build();
        }
    });

    public NeuralNetworkEvaluator(PMML pmml) {
        this(pmml, selectModel(pmml, NeuralNetwork.class));
    }

    /**
     * Creates an evaluator for the given neural network.
     *
     * @throws InvalidFeatureException if the model has no (or an empty) {@code NeuralInputs},
     *     no {@code NeuralLayer}s, or no (or an empty) {@code NeuralOutputs} element.
     */
    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);

        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }
        if (!neuralInputs.hasNeuralInputs()) {
            throw new InvalidFeatureException(neuralInputs);
        }

        if (!neuralNetwork.hasNeuralLayers()) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }
        if (!neuralOutputs.hasNeuralOutputs()) {
            throw new InvalidFeatureException(neuralOutputs);
        }
    }

    @Override
    public String getSummary() {
        return "Neural network";
    }

    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        if (this.entityRegistry == null) {
            this.entityRegistry = getValue(entityCache);
        }

        return this.entityRegistry;
    }

    /**
     * Scores the model and applies its output fields.
     *
     * @throws InvalidResultException if the model is marked as not scorable.
     * @throws UnsupportedFeatureException for a math context other than {@code float}/{@code double},
     *     or a mining function other than regression/classification.
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        NeuralNetwork neuralNetwork = getModel();
        if (!neuralNetwork.isScorable()) {
            throw new InvalidResultException(neuralNetwork);
        }

        MathContext mathContext = neuralNetwork.getMathContext();
        switch (mathContext) {
            case FLOAT:
            case DOUBLE:
                break;
            default:
                throw new UnsupportedFeatureException(neuralNetwork, mathContext);
        }

        ValueFactory<?> valueFactory = getValueFactory();

        Map<FieldName, ?> predictions;

        MiningFunction miningFunction = neuralNetwork.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(valueFactory, modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(valueFactory, modelEvaluationContext);
                break;
            default:
                throw new UnsupportedFeatureException(neuralNetwork, miningFunction);
        }

        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Produces regression predictions. Each target field must be mapped to exactly one
     * {@link NeuralOutput}; the output neuron's value is passed through as-is for a
     * {@link FieldRef} output expression, or denormalized for a {@link NormContinuous} one.
     *
     * @return the single target's prediction map, or a merged map for several targets (empty
     *     when there are no target fields).
     */
    private <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        List<TargetField> targetFields = getTargetFields();

        ValueMap<String, V> neuronValues = evaluateRaw(valueFactory, context);
        if (neuronValues == null) {
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateRegressionDefault(valueFactory, targetFields.get(0));
            }

            Map<FieldName, Object> defaults = new LinkedHashMap<>();
            for (TargetField targetField : targetFields) {
                defaults.putAll(TargetUtil.evaluateRegressionDefault(valueFactory, targetField));
            }

            return defaults;
        }

        Map<FieldName, List<NeuralOutput>> neuralOutputMap = getNeuralOutputMap();

        Map<FieldName, Object> results = new LinkedHashMap<>();
        for (TargetField targetField : targetFields) {
            List<NeuralOutput> neuralOutputs = neuralOutputMap.get(targetField.getName());
            if (neuralOutputs == null || neuralOutputs.size() != 1) {
                throw new InvalidFeatureException(neuralNetwork);
            }

            NeuralOutput neuralOutput = neuralOutputs.get(0);

            Value<V> value = neuronValues.get(neuralOutput.getOutputNeuron());
            if (value == null) {
                throw new InvalidFeatureException(neuralOutput);
            }

            // denormalize() is called for its effect on the value (its result is unused),
            // so work on a copy rather than on the shared raw neuron value
            value = value.copy2();

            Expression expression = getOutputExpression(neuralOutput);
            if (expression instanceof NormContinuous) {
                NormalizationUtil.denormalize((NormContinuous) expression, value);
            } else if (!(expression instanceof FieldRef)) {
                throw new UnsupportedFeatureException(expression);
            }

            Map<FieldName, ?> prediction = TargetUtil.evaluateRegression(targetField, value);
            if (targetFields.size() == 1) {
                return prediction;
            }

            results.putAll(prediction);
        }

        return results;
    }

    /**
     * Produces classification predictions. For each target field, every mapped
     * {@link NeuralOutput} must use a {@link NormDiscrete} output expression whose value names
     * the category; the output neurons' values form a {@link NeuronProbabilityDistribution}.
     *
     * @return the single target's prediction map, or a merged map for several targets (empty
     *     when there are no target fields).
     */
    private <V extends Number> Map<FieldName, ? extends Classification<V>> evaluateClassification(ValueFactory<V> valueFactory, EvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        List<TargetField> targetFields = getTargetFields();

        ValueMap<String, V> neuronValues = evaluateRaw(valueFactory, context);
        if (neuronValues == null) {
            if (targetFields.size() == 1) {
                return TargetUtil.evaluateClassificationDefault(valueFactory, targetFields.get(0));
            }

            Map<FieldName, Classification<V>> defaults = new LinkedHashMap<>();
            for (TargetField targetField : targetFields) {
                defaults.putAll(TargetUtil.evaluateClassificationDefault(valueFactory, targetField));
            }

            return defaults;
        }

        Map<FieldName, List<NeuralOutput>> neuralOutputMap = getNeuralOutputMap();

        BiMap<String, Entity> entityRegistry = getEntityRegistry();

        Map<FieldName, Classification<V>> results = new LinkedHashMap<>();
        for (TargetField targetField : targetFields) {
            List<NeuralOutput> neuralOutputs = neuralOutputMap.get(targetField.getName());
            if (neuralOutputs == null) {
                throw new InvalidFeatureException(neuralNetwork);
            }

            NeuronProbabilityDistribution<V> distribution = new NeuronProbabilityDistribution<V>(new ValueMap<String, V>(2 * neuralOutputs.size()), entityRegistry);

            for (NeuralOutput neuralOutput : neuralOutputs) {
                String outputNeuron = neuralOutput.getOutputNeuron();

                Entity entity = entityRegistry.get(outputNeuron);
                if (entity == null) {
                    throw new InvalidFeatureException(neuralOutput);
                }

                Value<V> value = neuronValues.get(outputNeuron);
                if (value == null) {
                    throw new InvalidFeatureException(neuralOutput);
                }

                Expression expression = getOutputExpression(neuralOutput);
                if (!(expression instanceof NormDiscrete)) {
                    throw new UnsupportedFeatureException(expression);
                }

                NormDiscrete normDiscrete = (NormDiscrete) expression;

                String category = normDiscrete.getValue();
                if (category == null) {
                    throw new InvalidFeatureException(normDiscrete);
                }

                distribution.put(entity, category, value);
            }

            Map<FieldName, ? extends Classification<V>> prediction = TargetUtil.evaluateClassification(targetField, distribution);
            if (targetFields.size() == 1) {
                return prediction;
            }

            results.putAll(prediction);
        }

        return results;
    }

    /**
     * Returns the expression of a neural output's {@code DerivedField}. A {@link FieldRef} to
     * another derived field is followed one level and that field's expression is returned
     * instead; a {@link FieldRef} to a {@link DataField} is returned unchanged.
     *
     * @throws InvalidFeatureException if a derived field or expression is missing, or the
     *     referenced field is neither a data field nor a derived field.
     * @throws MissingFieldException if the referenced field cannot be resolved.
     */
    private Expression getOutputExpression(NeuralOutput neuralOutput) {
        DerivedField derivedField = neuralOutput.getDerivedField();
        if (derivedField == null) {
            throw new InvalidFeatureException(neuralOutput);
        }

        Expression expression = derivedField.getExpression();
        if (expression == null) {
            throw new InvalidFeatureException(derivedField);
        }

        if (!(expression instanceof FieldRef)) {
            return expression;
        }

        FieldRef fieldRef = (FieldRef) expression;

        FieldName name = fieldRef.getField();

        // Only instanceof checks are applied to the resolved field, so its exact static type is irrelevant
        PMMLObject field = resolveField(name);
        if (field == null) {
            throw new MissingFieldException(name, fieldRef);
        }

        if (field instanceof DataField) {
            return expression;
        }

        if (field instanceof DerivedField) {
            DerivedField referencedField = (DerivedField) field;

            Expression referencedExpression = referencedField.getExpression();
            if (referencedExpression == null) {
                throw new InvalidFeatureException(referencedField);
            }

            return referencedExpression;
        }

        throw new InvalidFeatureException(fieldRef);
    }

    /**
     * Computes the raw neuron values of the network.
     *
     * <p>What the rest of this class relies on: the returned map is keyed by id strings (the
     * callers look values up by {@link NeuralOutput#getOutputNeuron()}), and a {@code null}
     * return makes the callers fall back to the targets' default predictions.
     *
     * <p>NOTE(review): the body of this method could not be decompiled (the decompiler skipped
     * its 721 instructions), so it currently always throws {@link UnsupportedOperationException}
     * and must be restored from the original source. The decompiler-generated switch maps of this
     * class referenced the activation functions THRESHOLD, LOGISTIC, TANH, IDENTITY, EXPONENTIAL,
     * RECIPROCAL, SQUARE, GAUSS, SINE, COSINE, ELLIOTT, ARCTAN and RECTIFIER, and the normalization
     * methods NONE, SIMPLEMAX and SOFTMAX — presumably from this method; TODO confirm when
     * restoring it.
     */
    private <V extends Number> ValueMap<String, V> evaluateRaw(ValueFactory<V> valueFactory, EvaluationContext context) {
        throw new UnsupportedOperationException("Method not decompiled: org.jpmml.evaluator.neural_network.NeuralNetworkEvaluator.evaluateRaw(org.jpmml.evaluator.ValueFactory, org.jpmml.evaluator.EvaluationContext):org.jpmml.evaluator.ValueMap");
    }

    private Map<FieldName, List<NeuralOutput>> getNeuralOutputMap() {
        if (this.neuralOutputMap == null) {
            this.neuralOutputMap = parseNeuralOutputs();
        }

        return this.neuralOutputMap;
    }

    /**
     * Groups the model's neural outputs by the field referenced by their output expression.
     *
     * @throws UnsupportedFeatureException if an output expression does not reference a field.
     * @throws InvalidFeatureException if the referenced field name is missing.
     */
    private Map<FieldName, List<NeuralOutput>> parseNeuralOutputs() {
        NeuralOutputs neuralOutputs = getModel().getNeuralOutputs();
        if (neuralOutputs == null) {
            return Collections.emptyMap();
        }

        ArrayListMultimap<FieldName, NeuralOutput> result = ArrayListMultimap.create();

        for (NeuralOutput neuralOutput : neuralOutputs) {
            Expression expression = getOutputExpression(neuralOutput);
            if (!(expression instanceof HasField)) {
                throw new UnsupportedFeatureException(expression);
            }

            FieldName name = ((HasField<?>) expression).getField();
            if (name == null) {
                throw new InvalidFeatureException(expression);
            }

            result.put(name, neuralOutput);
        }

        // Multimaps.asMap keeps the List value type that ArrayListMultimap.asMap() widens to Collection
        return Multimaps.asMap(result);
    }
}
