package org.jpmml.evaluator.support_vector_machine;

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.dmg.pmml.Array;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.HasValue;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.RealSparseArray;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.support_vector_machine.Coefficient;
import org.dmg.pmml.support_vector_machine.Coefficients;
import org.dmg.pmml.support_vector_machine.Kernel;
import org.dmg.pmml.support_vector_machine.SupportVector;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.support_vector_machine.SupportVectors;
import org.dmg.pmml.support_vector_machine.VectorDictionary;
import org.dmg.pmml.support_vector_machine.VectorFields;
import org.dmg.pmml.support_vector_machine.VectorInstance;
import org.jpmml.evaluator.ArrayUtil;
import org.jpmml.evaluator.CacheUtil;
import org.jpmml.evaluator.Classification;
import org.jpmml.evaluator.EvaluationContext;
import org.jpmml.evaluator.EvaluationException;
import org.jpmml.evaluator.ExpressionUtil;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.FieldValues;
import org.jpmml.evaluator.InvalidAttributeException;
import org.jpmml.evaluator.InvalidElementException;
import org.jpmml.evaluator.InvalidElementListException;
import org.jpmml.evaluator.MisplacedAttributeException;
import org.jpmml.evaluator.MisplacedElementException;
import org.jpmml.evaluator.MissingAttributeException;
import org.jpmml.evaluator.MissingElementException;
import org.jpmml.evaluator.MissingValueException;
import org.jpmml.evaluator.ModelEvaluationContext;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.OutputUtil;
import org.jpmml.evaluator.PMMLAttributes;
import org.jpmml.evaluator.PMMLElements;
import org.jpmml.evaluator.PMMLException;
import org.jpmml.evaluator.PMMLUtil;
import org.jpmml.evaluator.SparseArrayUtil;
import org.jpmml.evaluator.TargetUtil;
import org.jpmml.evaluator.UnsupportedAttributeException;
import org.jpmml.evaluator.Value;
import org.jpmml.evaluator.ValueFactory;
import org.jpmml.evaluator.ValueMap;
import org.jpmml.evaluator.XPathUtil;
import org.jpmml.model.ReflectionUtil;

/* loaded from: input_file:org/jpmml/evaluator/support_vector_machine/SupportVectorMachineModelEvaluator.class */
public class SupportVectorMachineModelEvaluator extends ModelEvaluator<SupportVectorMachineModel> {
    /**
     * Per-evaluator memo of the parsed vector dictionary (vector id to vector values).
     * Set to {@code null} by the constructor and filled in on first use by
     * {@code getVectorMap()}.
     */
    private transient Map<String, double[]> vectorMap;
    /**
     * Class-wide cache of parsed vector dictionaries, keyed by model instance.
     * The loader wraps the result of {@code parseVectorDictionary} in an
     * {@link ImmutableMap}; note that this freezes the map's entries only, the
     * {@code double[]} values themselves remain ordinary mutable arrays.
     * NOTE(review): eviction/reference policy is decided inside
     * {@code CacheUtil.buildLoadingCache}, which is not visible here.
     */
    private static final LoadingCache<SupportVectorMachineModel, Map<String, double[]>> vectorCache = CacheUtil.buildLoadingCache(new CacheLoader<SupportVectorMachineModel, Map<String, double[]>>() { // from class: org.jpmml.evaluator.support_vector_machine.SupportVectorMachineModelEvaluator.2
        // Invoked by the cache on a miss: parse the dictionary, then freeze the mapping.
        public Map<String, double[]> load(SupportVectorMachineModel supportVectorMachineModel) {
            return ImmutableMap.copyOf(SupportVectorMachineModelEvaluator.parseVectorDictionary(supportVectorMachineModel));
        }
    });

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.jpmml.evaluator.support_vector_machine.SupportVectorMachineModelEvaluator$3, reason: invalid class name */
    /* loaded from: input_file:org/jpmml/evaluator/support_vector_machine/SupportVectorMachineModelEvaluator$3.class */
    /**
     * Synthetic holder of "switch map" lookup tables, as recovered by the decompiler.
     *
     * Each table is indexed by an enum constant's {@code ordinal()} and yields the
     * small integer used as the {@code case} label in the enclosing class's
     * {@code switch} statements. Slots that are never assigned keep the array
     * default of {@code 0}, which matches no {@code case} label and therefore
     * falls through to {@code default}.
     *
     * Slot numbers assigned below:
     * ClassificationMethod: ONE_AGAINST_ALL=1, ONE_AGAINST_ONE=2.
     * MiningFunction: REGRESSION=1, CLASSIFICATION=2, ASSOCIATION_RULES=3, SEQUENCES=4,
     * CLUSTERING=5, TIME_SERIES=6, MIXED=7.
     * MathContext: DOUBLE=1.
     * Representation: SUPPORT_VECTORS=1.
     *
     * The numbers must stay in sync with the {@code case} labels of every switch
     * that reads these tables.
     */
    public static /* synthetic */ class AnonymousClass3 {
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$support_vector_machine$SupportVectorMachineModel$Representation;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$MathContext;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$MiningFunction;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$support_vector_machine$SupportVectorMachineModel$ClassificationMethod = new int[SupportVectorMachineModel.ClassificationMethod.values().length];

        // Every assignment is wrapped in its own try/catch so that an enum constant
        // missing at run time (NoSuchFieldError) only leaves that one slot at 0
        // instead of aborting class initialization. The empty catch blocks are
        // therefore deliberate.
        static {
            try {
                $SwitchMap$org$dmg$pmml$support_vector_machine$SupportVectorMachineModel$ClassificationMethod[SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ALL.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$dmg$pmml$support_vector_machine$SupportVectorMachineModel$ClassificationMethod[SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            $SwitchMap$org$dmg$pmml$MiningFunction = new int[MiningFunction.values().length];
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.REGRESSION.ordinal()] = 1;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.CLASSIFICATION.ordinal()] = 2;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.ASSOCIATION_RULES.ordinal()] = 3;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.SEQUENCES.ordinal()] = 4;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.CLUSTERING.ordinal()] = 5;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.TIME_SERIES.ordinal()] = 6;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunction[MiningFunction.MIXED.ordinal()] = 7;
            } catch (NoSuchFieldError e9) {
            }
            $SwitchMap$org$dmg$pmml$MathContext = new int[MathContext.values().length];
            try {
                $SwitchMap$org$dmg$pmml$MathContext[MathContext.DOUBLE.ordinal()] = 1;
            } catch (NoSuchFieldError e10) {
            }
            $SwitchMap$org$dmg$pmml$support_vector_machine$SupportVectorMachineModel$Representation = new int[SupportVectorMachineModel.Representation.values().length];
            try {
                $SwitchMap$org$dmg$pmml$support_vector_machine$SupportVectorMachineModel$Representation[SupportVectorMachineModel.Representation.SUPPORT_VECTORS.ordinal()] = 1;
            } catch (NoSuchFieldError e11) {
            }
        }
    }

    /**
     * Creates an evaluator for the {@link SupportVectorMachineModel} that
     * {@code PMMLUtil.findModel} locates inside the given document, then delegates
     * to the two-argument constructor (which performs all validation).
     *
     * @param pmml the PMML document to search for a support vector machine model
     */
    public SupportVectorMachineModelEvaluator(PMML pmml) {
        this(pmml, PMMLUtil.findModel(pmml, SupportVectorMachineModel.class));
    }

    /**
     * Creates an evaluator for the given model and eagerly rejects models this
     * implementation cannot score.
     *
     * Checks, in order: the {@code maxWins} flag must be off; the representation
     * must be {@code SUPPORT_VECTORS}; the model must carry a vector dictionary
     * with vector fields; and it must declare at least one support vector machine.
     *
     * @param pmml the enclosing PMML document
     * @param supportVectorMachineModel the model to evaluate
     * @throws UnsupportedAttributeException if {@code maxWins} is set or the representation is not supported
     * @throws MissingElementException if the vector dictionary, its vector fields, or the support vector machines are absent
     */
    public SupportVectorMachineModelEvaluator(PMML pmml, SupportVectorMachineModel supportVectorMachineModel) {
        super(pmml, supportVectorMachineModel);
        this.vectorMap = null;

        final boolean maxWins = supportVectorMachineModel.isMaxWins();
        if (maxWins) {
            throw new UnsupportedAttributeException(supportVectorMachineModel, PMMLAttributes.SUPPORTVECTORMACHINEMODEL_MAXWINS, Boolean.valueOf(maxWins));
        }

        // Only the support-vector representation is implemented.
        final SupportVectorMachineModel.Representation modelRepresentation = supportVectorMachineModel.getRepresentation();
        switch (modelRepresentation) {
            case SUPPORT_VECTORS:
                break;
            default:
                throw new UnsupportedAttributeException((PMMLObject) supportVectorMachineModel, (Enum<?>) modelRepresentation);
        }

        final VectorDictionary dictionary = supportVectorMachineModel.getVectorDictionary();
        if (dictionary == null) {
            throw new MissingElementException((PMMLObject) supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_VECTORDICTIONARY);
        }

        final VectorFields dictionaryFields = dictionary.getVectorFields();
        if (dictionaryFields == null) {
            throw new MissingElementException((PMMLObject) dictionary, PMMLElements.VECTORDICTIONARY_VECTORFIELDS);
        }

        final boolean hasMachines = supportVectorMachineModel.hasSupportVectorMachines();
        if (!hasMachines) {
            throw new MissingElementException((PMMLObject) supportVectorMachineModel, PMMLElements.SUPPORTVECTORMACHINEMODEL_SUPPORTVECTORMACHINES);
        }
    }

    /**
     * Returns the fixed, human-readable name of this evaluator's model type.
     *
     * @return the literal string {@code "Support vector machine"}
     */
    @Override // org.jpmml.evaluator.Evaluator
    public String getSummary() {
        return "Support vector machine";
    }

    /**
     * Scores the model for the given context and post-processes the raw
     * prediction through {@code OutputUtil.evaluate}.
     *
     * Only the {@code DOUBLE} math context is accepted. Regression and
     * classification are dispatched to their dedicated methods; the remaining
     * known mining functions are reported as invalid for this model type, and
     * anything else as unsupported.
     *
     * @param modelEvaluationContext the context holding the input field values
     * @return the target prediction together with any evaluated output fields
     * @throws UnsupportedAttributeException if the math context or mining function is not supported
     * @throws InvalidAttributeException if the mining function does not apply to this model type
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        SupportVectorMachineModel model = ensureScorableModel();

        MathContext mathContext = model.getMathContext();
        switch (mathContext) {
            case DOUBLE:
                break;
            default:
                throw new UnsupportedAttributeException((PMMLObject) model, (Enum<?>) mathContext);
        }

        // The math context was just verified to be DOUBLE, so the wildcard-typed
        // factory is narrowed to Double here; the scoring methods below require
        // a ValueFactory<Double> and would not accept ValueFactory<?>.
        @SuppressWarnings("unchecked")
        ValueFactory<Double> valueFactory = (ValueFactory<Double>) ensureValueFactory();

        Map<FieldName, ?> predictions;
        MiningFunction miningFunction = model.getMiningFunction();
        switch (miningFunction) {
            case REGRESSION:
                predictions = evaluateRegression(valueFactory, modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(valueFactory, modelEvaluationContext);
                break;
            case ASSOCIATION_RULES:
            case SEQUENCES:
            case CLUSTERING:
            case TIME_SERIES:
            case MIXED:
                throw new InvalidAttributeException((PMMLObject) model, (Enum<?>) miningFunction);
            default:
                throw new UnsupportedAttributeException((PMMLObject) model, (Enum<?>) miningFunction);
        }

        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Scores a regression model, which must consist of exactly one support
     * vector machine.
     *
     * @param valueFactory factory for the numeric accumulator
     * @param evaluationContext the context holding the input field values
     * @return the regression result for the target field
     * @throws InvalidElementListException if the model does not declare exactly one machine
     */
    private Map<FieldName, ?> evaluateRegression(ValueFactory<Double> valueFactory, EvaluationContext evaluationContext) {
        List<SupportVectorMachine> machines = getModel().getSupportVectorMachines();

        boolean singleMachine = (machines.size() == 1);
        if (!singleMachine) {
            throw new InvalidElementListException(machines);
        }

        SupportVectorMachine onlyMachine = machines.get(0);

        // The target field is resolved before the input vector is built,
        // matching the original argument evaluation order.
        return TargetUtil.evaluateRegression(
            getTargetField(),
            evaluateSupportVectorMachine(valueFactory, onlyMachine, createInput(evaluationContext)));
    }

    /**
     * Scores a classification model by evaluating every support vector machine
     * against the same input vector and aggregating the results according to
     * the classification method.
     *
     * One-against-all: each machine's decision value is stored under its target
     * category and the result is a {@code DistanceDistribution}.
     * One-against-one: each machine casts one vote (see
     * {@code selectOneAgainstOneCategory}) and the result is a
     * {@code VoteDistribution}.
     *
     * @param valueFactory factory for the numeric accumulators and vote counters
     * @param evaluationContext the context holding the input field values
     * @return the classification result for the target field
     * @throws UnsupportedAttributeException if the classification method is not supported
     * @throws MissingAttributeException if a machine lacks a required category attribute
     * @throws MisplacedAttributeException if a machine declares an alternate target category where none is allowed
     * @throws EvaluationException if a binary machine's rounded decision value is neither 0 nor 1
     */
    private Map<FieldName, ? extends Classification<Double>> evaluateClassification(final ValueFactory<Double> valueFactory, EvaluationContext evaluationContext) {
        SupportVectorMachineModel model = getModel();
        List<SupportVectorMachine> supportVectorMachines = model.getSupportVectorMachines();
        String alternateBinaryTargetCategory = model.getAlternateBinaryTargetCategory();
        SupportVectorMachineModel.ClassificationMethod classificationMethod = getClassificationMethod();

        // Exactly one of the two accumulators is created, depending on the method.
        // They are kept in separately typed variables because a plain ValueMap is
        // not a VoteMap and therefore cannot be stored in a VoteMap variable.
        int capacity = 2 * supportVectorMachines.size();
        ValueMap<String, Double> distances = null;
        VoteMap<String, Double> votes = null;
        switch (classificationMethod) {
            case ONE_AGAINST_ALL:
                distances = new ValueMap<String, Double>(capacity);
                break;
            case ONE_AGAINST_ONE:
                votes = new VoteMap<String, Double>(capacity) {
                    @Override // org.jpmml.evaluator.ValueMap
                    public ValueFactory<Double> getValueFactory() {
                        return valueFactory;
                    }
                };
                break;
            default:
                throw new UnsupportedAttributeException((PMMLObject) model, (Enum<?>) classificationMethod);
        }

        double[] input = createInput(evaluationContext);
        for (SupportVectorMachine supportVectorMachine : supportVectorMachines) {
            String targetCategory = supportVectorMachine.getTargetCategory();
            if (targetCategory == null) {
                throw new MissingAttributeException((PMMLObject) supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_TARGETCATEGORY);
            }
            Value<Double> value = evaluateSupportVectorMachine(valueFactory, supportVectorMachine, input);
            switch (classificationMethod) {
                case ONE_AGAINST_ALL: {
                    String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();
                    if (alternateTargetCategory != null) {
                        throw new MisplacedAttributeException(supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY, alternateTargetCategory);
                    }
                    distances.put(targetCategory, value);
                    break;
                }
                case ONE_AGAINST_ONE:
                    votes.increment(selectOneAgainstOneCategory(model, supportVectorMachine, alternateBinaryTargetCategory, value));
                    break;
                default:
                    // Unreachable: unsupported methods were rejected before the loop.
                    break;
            }
        }

        // The distribution classes are instantiated raw, as in the original code;
        // their generic signatures are not visible from this file.
        Classification result;
        switch (classificationMethod) {
            case ONE_AGAINST_ALL:
                result = new DistanceDistribution(distances);
                break;
            case ONE_AGAINST_ONE:
                result = new VoteDistribution(votes);
                break;
            default:
                throw new UnsupportedAttributeException((PMMLObject) model, (Enum<?>) classificationMethod);
        }
        return TargetUtil.evaluateClassification(getTargetField(), result);
    }

    /**
     * Decides which category a single machine votes for under the
     * one-against-one method.
     *
     * When the model declares an alternate binary target category, the decision
     * value is rounded in place and must equal 1 (vote for the machine's target
     * category) or 0 (vote for the alternate binary target category). Otherwise
     * the value is compared with the machine's threshold, falling back to the
     * model's threshold: below the threshold votes for the target category,
     * anything else for the machine's alternate target category.
     */
    private static String selectOneAgainstOneCategory(SupportVectorMachineModel model, SupportVectorMachine supportVectorMachine, String alternateBinaryTargetCategory, Value<Double> value) {
        String targetCategory = supportVectorMachine.getTargetCategory();
        String alternateTargetCategory = supportVectorMachine.getAlternateTargetCategory();

        if (alternateBinaryTargetCategory != null) {
            if (alternateTargetCategory != null) {
                throw new MisplacedAttributeException(supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY, alternateTargetCategory);
            }
            value.round2();
            if (value.equals(1.0d)) {
                return targetCategory;
            }
            if (value.equals(0.0d)) {
                return alternateBinaryTargetCategory;
            }
            throw new EvaluationException("Expected " + PMMLException.formatValue(Double.valueOf(0.0d)) + " or " + PMMLException.formatValue(Double.valueOf(1.0d)) + ", got " + PMMLException.formatValue(value.getValue()));
        }

        if (alternateTargetCategory == null) {
            throw new MissingAttributeException((PMMLObject) supportVectorMachine, PMMLAttributes.SUPPORTVECTORMACHINE_ALTERNATETARGETCATEGORY);
        }
        Double threshold = supportVectorMachine.getThreshold();
        if (threshold == null) {
            threshold = model.getThreshold();
        }
        return value.compareTo(threshold.doubleValue()) < 0 ? targetCategory : alternateTargetCategory;
    }

    /**
     * Computes the decision value of one support vector machine for an input
     * vector.
     *
     * Coefficients and support vector references are walked in lock-step; for
     * each pair the kernel of the input and the referenced dictionary vector is
     * added to the accumulator, weighted by the coefficient value. Finally the
     * coefficients' absolute value is added when it is non-zero.
     *
     * @param valueFactory factory for the numeric accumulator
     * @param supportVectorMachine the machine to evaluate
     * @param dArr the input vector, laid out in vector-field order
     * @return the accumulated decision value
     * @throws MissingElementException if the model has no kernel, or the machine has no coefficients or no support vectors
     * @throws MissingAttributeException if a support vector has no vector id
     * @throws InvalidAttributeException if a vector id is not present in the vector dictionary
     * @throws InvalidElementException if the numbers of coefficients and support vectors differ
     */
    private Value<Double> evaluateSupportVectorMachine(ValueFactory<Double> valueFactory, SupportVectorMachine supportVectorMachine, double[] dArr) {
        SupportVectorMachineModel model = getModel();
        Value<Double> result = valueFactory.newValue();

        Kernel kernel = model.getKernel();
        if (kernel == null) {
            // The kernel is a child of the model, so the model is the offending element.
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(model.getClass()) + "/<Kernel>"), (PMMLObject) model);
        }

        // Report absent child elements as model errors instead of failing with a
        // bare NullPointerException on the iterator() calls below.
        Coefficients coefficients = supportVectorMachine.getCoefficients();
        if (coefficients == null) {
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachine.getClass()) + "/<Coefficients>"), (PMMLObject) supportVectorMachine);
        }
        SupportVectors supportVectors = supportVectorMachine.getSupportVectors();
        if (supportVectors == null) {
            throw new MissingElementException(MissingElementException.formatMessage(XPathUtil.formatElement(supportVectorMachine.getClass()) + "/<SupportVectors>"), (PMMLObject) supportVectorMachine);
        }

        Iterator<Coefficient> coefficientIt = coefficients.iterator();
        Iterator<SupportVector> supportVectorIt = supportVectors.iterator();
        Map<String, double[]> vectorMap = getVectorMap();
        while (coefficientIt.hasNext() && supportVectorIt.hasNext()) {
            Coefficient coefficient = coefficientIt.next();
            SupportVector supportVector = supportVectorIt.next();

            String vectorId = supportVector.getVectorId();
            if (vectorId == null) {
                throw new MissingAttributeException((PMMLObject) supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID);
            }
            double[] vector = vectorMap.get(vectorId);
            if (vector == null) {
                throw new InvalidAttributeException(supportVector, PMMLAttributes.SUPPORTVECTOR_VECTORID, vectorId);
            }
            result.add2(coefficient.getValue().doubleValue(), Double.valueOf(KernelUtil.evaluate(kernel, dArr, vector)));
        }

        // Either iterator having elements left means the two lists differ in length.
        if (coefficientIt.hasNext() || supportVectorIt.hasNext()) {
            throw new InvalidElementException((PMMLObject) supportVectorMachine);
        }

        double absoluteValue = coefficients.getAbsoluteValue().doubleValue();
        if (absoluteValue != 0.0d) {
            result.add2(absoluteValue);
        }
        return result;
    }

    /**
     * Determines the classification method of the model.
     *
     * The raw value of the model's classification method field is read
     * reflectively and used when it is set. Otherwise the method is inferred
     * from the first support vector machine: a model with an alternate binary
     * target category must hold exactly one machine and is one-against-one; any
     * other model must hold at least one machine and is one-against-one when
     * that machine names an alternate target category, else one-against-all.
     *
     * @return the declared or inferred classification method
     * @throws InvalidElementException if the machine count is wrong or the first machine has no target category
     */
    private SupportVectorMachineModel.ClassificationMethod getClassificationMethod() {
        SupportVectorMachineModel svmModel = getModel();

        SupportVectorMachineModel.ClassificationMethod declared = (SupportVectorMachineModel.ClassificationMethod) ReflectionUtil.getFieldValue(PMMLAttributes.SUPPORTVECTORMACHINEMODEL_CLASSIFICATIONMETHOD, svmModel);
        if (declared != null) {
            return declared;
        }

        List<SupportVectorMachine> machines = svmModel.getSupportVectorMachines();
        boolean binary = (svmModel.getAlternateBinaryTargetCategory() != null);

        // Binary models need exactly one machine, all others at least one.
        boolean badMachineCount = binary ? (machines.size() != 1) : machines.isEmpty();
        if (badMachineCount) {
            throw new InvalidElementException((PMMLObject) svmModel);
        }

        SupportVectorMachine first = machines.get(0);
        if (first.getTargetCategory() == null) {
            throw new InvalidElementException((PMMLObject) first);
        }

        if (binary || first.getAlternateTargetCategory() != null) {
            return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE;
        }
        return SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ALL;
    }

    /**
     * Builds the numeric input vector for the kernel, one slot per entry of the
     * model's vector fields, in declaration order.
     *
     * A {@link FieldRef} contributes the numeric value of the evaluated field.
     * A {@link CategoricalPredictor} contributes {@code 1.0} when the field value
     * equals the predictor's value and {@code 0.0} otherwise; its coefficient
     * must be exactly {@code 1.0}. Any other entry is rejected.
     *
     * @param evaluationContext the context holding the input field values
     * @return the input vector, with the same length as the vector fields' content
     * @throws MissingValueException if a referenced field evaluates to the missing value
     * @throws MissingAttributeException if a categorical predictor has no field name
     * @throws InvalidAttributeException if a categorical predictor's coefficient is not 1
     * @throws MisplacedElementException if an entry is neither a field reference nor a categorical predictor
     */
    private double[] createInput(EvaluationContext evaluationContext) {
        VectorFields vectorFields = getModel().getVectorDictionary().getVectorFields();
        List<?> content = vectorFields.getContent();
        int size = content.size();
        double[] input = new double[size];
        for (int i = 0; i < size; i++) {
            // Typed as the common supertype: the entry may be either of two
            // unrelated element classes, so it cannot be declared as one of them.
            PMMLObject item = (PMMLObject) content.get(i);
            if (item instanceof FieldRef) {
                FieldRef fieldRef = (FieldRef) item;
                FieldName field = fieldRef.getField();
                FieldValue value = ExpressionUtil.evaluate(fieldRef, evaluationContext);
                if (Objects.equals(FieldValues.MISSING_VALUE, value)) {
                    throw new MissingValueException(field, vectorFields);
                }
                input[i] = value.asNumber().doubleValue();
            } else if (item instanceof CategoricalPredictor) {
                CategoricalPredictor categoricalPredictor = (CategoricalPredictor) item;
                FieldName name = categoricalPredictor.getName();
                if (name == null) {
                    throw new MissingAttributeException((PMMLObject) categoricalPredictor, PMMLAttributes.CATEGORICALPREDICTOR_FIELD);
                }
                FieldValue value = evaluationContext.evaluate(name);
                if (Objects.equals(FieldValues.MISSING_VALUE, value)) {
                    throw new MissingValueException(name, categoricalPredictor);
                }
                double coefficient = categoricalPredictor.getCoefficient();
                if (coefficient != 1.0d) {
                    throw new InvalidAttributeException(categoricalPredictor, PMMLAttributes.CATEGORICALPREDICTOR_COEFFICIENT, Double.valueOf(coefficient));
                }
                // One-hot indicator: does the field hold this predictor's category?
                input[i] = value.equals((HasValue<?>) categoricalPredictor) ? 1.0d : 0.0d;
            } else {
                throw new MisplacedElementException(item);
            }
        }
        return input;
    }

    /**
     * Returns the parsed vector dictionary for this evaluator's model, fetching
     * it through the shared {@code vectorCache} on first use and remembering it
     * in the {@code vectorMap} field afterwards.
     *
     * @return the mapping from vector id to vector values
     */
    private Map<String, double[]> getVectorMap() {
        Map<String, double[]> vectors = this.vectorMap;
        if (vectors != null) {
            return vectors;
        }

        // First call on this instance: resolve via the cache, then memoize.
        vectors = (Map) getValue(vectorCache);
        this.vectorMap = vectors;
        return vectors;
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Converts the model's vector dictionary into a map from vector id to a
     * primitive {@code double[]}, preserving the declaration order of the
     * vector instances.
     *
     * Every vector instance must have an id, must carry exactly one of a dense
     * array or a real sparse array, and must have as many values as the
     * dictionary has vector fields.
     *
     * @param supportVectorMachineModel the model whose dictionary is parsed
     * @return a new insertion-ordered map of vector id to vector values
     * @throws MissingAttributeException if a vector instance has no id
     * @throws InvalidElementException if an instance has neither or both array forms, or the wrong number of values
     */
    public static Map<String, double[]> parseVectorDictionary(SupportVectorMachineModel supportVectorMachineModel) {
        VectorDictionary dictionary = supportVectorMachineModel.getVectorDictionary();
        List<?> fields = dictionary.getVectorFields().getContent();

        Map<String, double[]> vectorsById = new LinkedHashMap<>();
        for (VectorInstance instance : dictionary.getVectorInstances()) {
            String vectorId = instance.getId();
            if (vectorId == null) {
                throw new MissingAttributeException((PMMLObject) instance, PMMLAttributes.VECTORINSTANCE_ID);
            }

            Array dense = instance.getArray();
            RealSparseArray sparse = instance.getRealSparseArray();

            // Exactly one of the two representations must be present.
            boolean hasDense = (dense != null);
            boolean hasSparse = (sparse != null);
            if (hasDense == hasSparse) {
                throw new InvalidElementException((PMMLObject) instance);
            }

            List<? extends Number> values;
            if (hasDense) {
                values = ArrayUtil.asNumberList(dense);
            } else {
                values = SparseArrayUtil.asNumberList(sparse);
            }

            if (values.size() != fields.size()) {
                throw new InvalidElementException((PMMLObject) instance);
            }
            vectorsById.put(vectorId, Doubles.toArray(values));
        }
        return vectorsById;
    }
}
