package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.Maps;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import org.dmg.pmml.ActivationFunctionType;
import org.dmg.pmml.Connection;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Entity;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.NeuralInput;
import org.dmg.pmml.NeuralInputs;
import org.dmg.pmml.NeuralLayer;
import org.dmg.pmml.NeuralNetwork;
import org.dmg.pmml.NeuralOutput;
import org.dmg.pmml.NeuralOutputs;
import org.dmg.pmml.Neuron;
import org.dmg.pmml.NnNormalizationMethodType;
import org.dmg.pmml.NormContinuous;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/* loaded from: input_file:org/jpmml/evaluator/NeuralNetworkEvaluator.class */
/**
 * Evaluator for PMML {@code NeuralNetwork} models.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions; any other
 * mining function is rejected with an {@link UnsupportedFeatureException}.
 *
 * <p>Evaluation is a feed-forward pass. Each {@link NeuralInput} is computed from its derived
 * field. Each {@link NeuralLayer} is then evaluated in document order, and its outputs become
 * visible to later layers only after the whole layer (including normalization) is finished.
 * Finally the {@link NeuralOutput} elements map output neurons back to target fields.
 */
public class NeuralNetworkEvaluator extends ModelEvaluator<NeuralNetwork> implements HasEntityRegistry<Entity> {

    /**
     * Per-model registry mapping the ids of all neural inputs and neurons to their elements.
     * Keys are held weakly so that a discarded model does not keep its registry alive.
     */
    private static final LoadingCache<NeuralNetwork, BiMap<String, Entity>> entityCache = CacheBuilder.newBuilder()
        .weakKeys()
        .build(new CacheLoader<NeuralNetwork, BiMap<String, Entity>>() {

            @Override
            public BiMap<String, Entity> load(NeuralNetwork neuralNetwork) {
                ImmutableBiMap.Builder<String, Entity> builder = new ImmutableBiMap.Builder<String, Entity>();

                NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
                if (neuralInputs == null) {
                    throw new InvalidFeatureException(neuralNetwork);
                }

                for (NeuralInput neuralInput : neuralInputs) {
                    builder = EntityUtil.put(neuralInput, builder);
                }

                for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
                    for (Neuron neuron : neuralLayer.getNeurons()) {
                        builder = EntityUtil.put(neuron, builder);
                    }
                }

                return builder.build();
            }
        });

    public NeuralNetworkEvaluator(PMML pmml) {
        this(pmml, find(pmml.getModels(), NeuralNetwork.class));
    }

    public NeuralNetworkEvaluator(PMML pmml, NeuralNetwork neuralNetwork) {
        super(pmml, neuralNetwork);
    }

    public String getSummary() {
        return "Neural network";
    }

    /**
     * Returns the registry that maps neural input and neuron ids to their PMML elements.
     * The registry is obtained through the per-model {@code entityCache}.
     */
    @Override
    public BiMap<String, Entity> getEntityRegistry() {
        return getValue(entityCache);
    }

    /**
     * Evaluates the model and post-processes the predictions with the model's output fields.
     *
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        NeuralNetwork neuralNetwork = getModel();
        if (!neuralNetwork.isScorable()) {
            throw new InvalidResultException(neuralNetwork);
        }

        Map<FieldName, ?> predictions;

        MiningFunctionType miningFunction = neuralNetwork.getFunctionName();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(modelEvaluationContext);
                break;
            default:
                throw new UnsupportedFeatureException(neuralNetwork, miningFunction);
        }

        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Regression: each neural output contributes one target value.
     *
     * <p>A {@code FieldRef} output takes the raw neuron value as-is. A {@code NormContinuous}
     * output is denormalized back to the original scale. If an input value is missing, the
     * target's default regression result is returned instead.
     */
    private Map<FieldName, ?> evaluateRegression(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        Map<String, Double> neuronValues = evaluateRaw(context);
        if (neuronValues == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }

        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        Map<FieldName, Double> result = Maps.newLinkedHashMap();

        for (NeuralOutput neuralOutput : neuralOutputs) {
            Expression expression = getExpression(neuralOutput.getDerivedField(), context);

            if (expression instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) expression;

                result.put(fieldRef.getField(), getOutputValue(neuralOutput, neuronValues));
            } else if (expression instanceof NormContinuous) {
                NormContinuous normContinuous = (NormContinuous) expression;

                double value = getOutputValue(neuralOutput, neuronValues).doubleValue();
                result.put(normContinuous.getField(), NormalizationUtil.denormalize(normContinuous, value));
            } else {
                throw new UnsupportedFeatureException(expression);
            }
        }

        return TargetUtil.evaluateRegression(result, context);
    }

    /**
     * Classification: each neural output must resolve to a {@code NormDiscrete}.
     *
     * <p>The output neuron's value is recorded as the score of the {@code NormDiscrete}
     * category, grouped by target field. If an input value is missing, the target's default
     * classification result is returned instead.
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelEvaluationContext context) {
        NeuralNetwork neuralNetwork = getModel();

        BiMap<String, Entity> entityRegistry = getEntityRegistry();

        Map<String, Double> neuronValues = evaluateRaw(context);
        if (neuronValues == null) {
            return TargetUtil.evaluateClassificationDefault(context);
        }

        NeuralOutputs neuralOutputs = neuralNetwork.getNeuralOutputs();
        if (neuralOutputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        Map<FieldName, NeuronClassificationMap> result = Maps.newLinkedHashMap();

        for (NeuralOutput neuralOutput : neuralOutputs) {
            Expression expression = getExpression(neuralOutput.getDerivedField(), context);
            if (!(expression instanceof NormDiscrete)) {
                throw new UnsupportedFeatureException(expression);
            }

            NormDiscrete normDiscrete = (NormDiscrete) expression;

            Double value = getOutputValue(neuralOutput, neuronValues);

            FieldName field = normDiscrete.getField();

            NeuronClassificationMap classificationMap = result.get(field);
            if (classificationMap == null) {
                classificationMap = new NeuronClassificationMap();

                result.put(field, classificationMap);
            }

            classificationMap.put(entityRegistry.get(neuralOutput.getOutputNeuron()), normDiscrete.getValue(), value);
        }

        return TargetUtil.evaluateClassification(result, context);
    }

    /**
     * Looks up the computed value of the neuron that a neural output refers to.
     *
     * @throws InvalidFeatureException if {@code outputNeuron} does not name any evaluated input or neuron
     */
    private static Double getOutputValue(NeuralOutput neuralOutput, Map<String, Double> neuronValues) {
        // Values are never null in the map, so null means "unknown id"; fail with a PMML-level
        // error instead of a NullPointerException further down.
        Double value = neuronValues.get(neuralOutput.getOutputNeuron());
        if (value == null) {
            throw new InvalidFeatureException(neuralOutput);
        }

        return value;
    }

    /**
     * Returns the expression of a derived field, following {@code FieldRef} chains through
     * other derived fields that the context can resolve.
     *
     * <p>A {@code FieldRef} that cannot be resolved to a derived field is returned as-is.
     */
    private Expression getExpression(DerivedField derivedField, EvaluationContext context) {
        Expression expression = derivedField.getExpression();

        if (expression instanceof FieldRef) {
            FieldRef fieldRef = (FieldRef) expression;

            EvaluationContext.Result<DerivedField> resolved = context.resolveDerivedField(fieldRef.getField());
            if (resolved != null) {
                return getExpression(resolved.getElement(), context);
            }
        }

        return expression;
    }

    /**
     * Runs the forward pass of the network.
     *
     * @return a map from neural input and neuron ids to their computed values (inputs first,
     *     then each layer in document order), or {@code null} if any neural input evaluates to
     *     a missing value
     * @throws InvalidFeatureException if the network has no {@code NeuralInputs}, or a connection
     *     refers to an id that has not been computed by the time its layer is evaluated
     */
    public Map<String, Double> evaluateRaw(EvaluationContext evaluationContext) {
        NeuralNetwork neuralNetwork = getModel();

        NeuralInputs neuralInputs = neuralNetwork.getNeuralInputs();
        if (neuralInputs == null) {
            throw new InvalidFeatureException(neuralNetwork);
        }

        Map<String, Double> values = Maps.newLinkedHashMap();

        for (NeuralInput neuralInput : neuralInputs) {
            FieldValue value = ExpressionUtil.evaluate(neuralInput.getDerivedField(), evaluationContext);
            if (value == null) {
                return null;
            }

            values.put(neuralInput.getId(), value.asNumber().doubleValue());
        }

        for (NeuralLayer neuralLayer : neuralNetwork.getNeuralLayers()) {
            // Collected separately so that neurons only see outputs of earlier layers, and so the
            // layer can be normalized as a whole before being published.
            Map<String, Double> layerValues = Maps.newLinkedHashMap();

            for (Neuron neuron : neuralLayer.getNeurons()) {
                // The bias attribute is optional; an absent bias contributes nothing.
                Double bias = neuron.getBias();
                double z = (bias != null ? bias.doubleValue() : 0d);

                for (Connection connection : neuron.getConnections()) {
                    Double input = values.get(connection.getFrom());
                    if (input == null) {
                        throw new InvalidFeatureException(connection);
                    }

                    z += input.doubleValue() * connection.getWeight();
                }

                layerValues.put(neuron.getId(), activation(z, neuralLayer));
            }

            normalizeNeuronOutputs(neuralLayer, layerValues);

            values.putAll(layerValues);
        }

        return values;
    }

    /**
     * Normalizes one layer's outputs in place.
     *
     * <p>The layer's own {@code normalizationMethod} takes precedence. When the layer does not
     * declare one, the network-level method is used.
     */
    private void normalizeNeuronOutputs(NeuralLayer neuralLayer, Map<String, Double> neuronOutputs) {
        NeuralNetwork neuralNetwork = getModel();

        NnNormalizationMethodType normalizationMethod = neuralLayer.getNormalizationMethod();

        boolean inherited = (normalizationMethod == null);
        if (inherited) {
            normalizationMethod = neuralNetwork.getNormalizationMethod();

            if (normalizationMethod == null) {
                throw new InvalidFeatureException(neuralNetwork);
            }
        }

        switch (normalizationMethod) {
            case NONE:
                return;
            case SIMPLEMAX:
                ClassificationMap.normalize(neuronOutputs);
                return;
            case SOFTMAX:
                ClassificationMap.normalizeSoftMax(neuronOutputs);
                return;
            default:
                // Report the element that actually declared the unsupported value.
                if (inherited) {
                    throw new UnsupportedFeatureException(neuralNetwork, normalizationMethod);
                }

                throw new UnsupportedFeatureException(neuralLayer, normalizationMethod);
        }
    }

    /**
     * Applies the activation function to a neuron's weighted input sum {@code z}.
     *
     * <p>The layer's own {@code activationFunction} (and, for {@code threshold}, its
     * {@code threshold}) takes precedence over the network-level defaults.
     *
     * @throws InvalidFeatureException if neither the layer nor the network declares an activation function
     */
    private double activation(double z, NeuralLayer neuralLayer) {
        NeuralNetwork neuralNetwork = getModel();

        ActivationFunctionType activationFunction = neuralLayer.getActivationFunction();

        boolean inherited = (activationFunction == null);
        if (inherited) {
            activationFunction = neuralNetwork.getActivationFunction();

            if (activationFunction == null) {
                throw new InvalidFeatureException(neuralNetwork);
            }
        }

        switch (activationFunction) {
            case THRESHOLD: {
                Double layerThreshold = neuralLayer.getThreshold();

                double threshold = (layerThreshold != null ? layerThreshold.doubleValue() : neuralNetwork.getThreshold());

                return z > threshold ? 1d : 0d;
            }
            case LOGISTIC:
                return 1d / (1d + Math.exp(-z));
            case TANH:
                return Math.tanh(z);
            case IDENTITY:
                return z;
            case EXPONENTIAL:
                return Math.exp(z);
            case RECIPROCAL:
                return 1d / z;
            case SQUARE:
                return z * z;
            case GAUSS:
                return Math.exp(-(z * z));
            case SINE:
                return Math.sin(z);
            case COSINE:
                return Math.cos(z);
            case ELLIOTT:
                return z / (1d + Math.abs(z));
            case ARCTAN:
                return Math.atan(z);
            default:
                // Report the element that actually declared the unsupported value.
                if (inherited) {
                    throw new UnsupportedFeatureException(neuralNetwork, activationFunction);
                }

                throw new UnsupportedFeatureException(neuralLayer, activationFunction);
        }
    }
}
