package org.jpmml.evaluator;

import com.google.common.base.Predicate;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Ordering;
import com.google.common.collect.Sets;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.BaseCumHazardTables;
import org.dmg.pmml.BaselineCell;
import org.dmg.pmml.BaselineStratum;
import org.dmg.pmml.Categories;
import org.dmg.pmml.Category;
import org.dmg.pmml.CumulativeLinkFunctionType;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.GeneralRegressionModel;
import org.dmg.pmml.LinkFunctionType;
import org.dmg.pmml.Matrix;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PCell;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PPCell;
import org.dmg.pmml.Parameter;
import org.dmg.pmml.ParameterCell;
import org.dmg.pmml.ParameterList;
import org.dmg.pmml.Predictor;
import org.dmg.pmml.PredictorList;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator.class */
public class GeneralRegressionModelEvaluator extends ModelEvaluator<GeneralRegressionModel> {
    /*
     * Per-model caches of parsed model structures. All caches use CacheBuilder.weakKeys(), so a
     * cached entry does not by itself keep its GeneralRegressionModel reachable. Each loader
     * delegates to a parse* helper of this class (defined outside this excerpt) and wraps the
     * result in an immutable/unmodifiable view before caching it.
     */

    // ParameterList: parameter name <-> Parameter, from parseParameterRegistry(getParameterList()).
    private static final LoadingCache<GeneralRegressionModel, BiMap<String, Parameter>> parameterCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<GeneralRegressionModel, BiMap<String, Parameter>>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.4
        public BiMap<String, Parameter> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf(GeneralRegressionModelEvaluator.parseParameterRegistry(generalRegressionModel.getParameterList()));
        }
    });
    // FactorList: predictor field name <-> Predictor, from parsePredictorRegistry(getFactorList()).
    private static final LoadingCache<GeneralRegressionModel, BiMap<FieldName, Predictor>> factorCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<GeneralRegressionModel, BiMap<FieldName, Predictor>>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.5
        public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf(GeneralRegressionModelEvaluator.parsePredictorRegistry(generalRegressionModel.getFactorList()));
        }
    });
    // CovariateList: predictor field name <-> Predictor, from parsePredictorRegistry(getCovariateList()).
    private static final LoadingCache<GeneralRegressionModel, BiMap<FieldName, Predictor>> covariateCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<GeneralRegressionModel, BiMap<FieldName, Predictor>>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.6
        public BiMap<FieldName, Predictor> load(GeneralRegressionModel generalRegressionModel) {
            return ImmutableBiMap.copyOf(GeneralRegressionModelEvaluator.parsePredictorRegistry(generalRegressionModel.getCovariateList()));
        }
    });
    // PPMatrix as a two-level map of Rows (outer key is looked up by target category, with a
    // null key used as the shared fallback by the evaluation code; inner key presumably the
    // parameter name — TODO confirm against parsePPMatrix).
    private static final LoadingCache<GeneralRegressionModel, Map<String, Map<String, Row>>> ppMatrixCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<GeneralRegressionModel, Map<String, Map<String, Row>>>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.7
        public Map<String, Map<String, Row>> load(GeneralRegressionModel generalRegressionModel) {
            return Collections.unmodifiableMap(GeneralRegressionModelEvaluator.parsePPMatrix(generalRegressionModel));
        }
    });
    // ParamMatrix PCells grouped by key (looked up by target category; a null key holds cells
    // shared across categories, per how evaluateClassification reads it).
    private static final LoadingCache<GeneralRegressionModel, Map<String, List<PCell>>> paramMatrixCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<GeneralRegressionModel, Map<String, List<PCell>>>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.8
        public Map<String, List<PCell>> load(GeneralRegressionModel generalRegressionModel) {
            return Collections.unmodifiableMap(GeneralRegressionModelEvaluator.parseParamMatrix(generalRegressionModel));
        }
    });

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.jpmml.evaluator.GeneralRegressionModelEvaluator$12, reason: invalid class name */
    /* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator$12.class */
    /*
     * Compiler-generated "switch map" for the enum switches in this class, surfaced by the
     * decompiler. Each array maps an enum constant's ordinal() to the case label used by the
     * switch statements below; unmapped ordinals stay 0 and fall into `default`.
     *
     * Label mapping used throughout this file:
     *   CumulativeLinkFunctionType: LOGIT=1, PROBIT=2, CLOGLOG=3, LOGLOG=4, CAUCHIT=5
     *   LinkFunctionType:           CLOGLOG=1, IDENTITY=2, LOG=3, LOGC=4, LOGIT=5, LOGLOG=6,
     *                               ODDSPOWER=7, POWER=8, PROBIT=9
     *   OpType:                     CONTINUOUS=1, CATEGORICAL=2, ORDINAL=3
     *   MiningFunctionType:         REGRESSION=1, CLASSIFICATION=2
     *   GeneralRegressionModel.ModelType: COX_REGRESSION=1, REGRESSION=2, GENERAL_LINEAR=3,
     *                               GENERALIZED_LINEAR=4, MULTINOMIAL_LOGISTIC=5,
     *                               ORDINAL_MULTINOMIAL=6
     *
     * Each assignment swallows NoSuchFieldError so that a constant missing from the enum class
     * at runtime leaves its slot unmapped instead of failing class initialization.
     */
    public static /* synthetic */ class AnonymousClass12 {
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$MiningFunctionType;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$OpType;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$LinkFunctionType;
        static final /* synthetic */ int[] $SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType = new int[CumulativeLinkFunctionType.values().length];

        static {
            // CumulativeLinkFunctionType labels 1..5
            try {
                $SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType[CumulativeLinkFunctionType.LOGIT.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType[CumulativeLinkFunctionType.PROBIT.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType[CumulativeLinkFunctionType.CLOGLOG.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType[CumulativeLinkFunctionType.LOGLOG.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType[CumulativeLinkFunctionType.CAUCHIT.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            // LinkFunctionType labels 1..9
            $SwitchMap$org$dmg$pmml$LinkFunctionType = new int[LinkFunctionType.values().length];
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.CLOGLOG.ordinal()] = 1;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.IDENTITY.ordinal()] = 2;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.LOG.ordinal()] = 3;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.LOGC.ordinal()] = 4;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.LOGIT.ordinal()] = 5;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.LOGLOG.ordinal()] = 6;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.ODDSPOWER.ordinal()] = 7;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.POWER.ordinal()] = 8;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$dmg$pmml$LinkFunctionType[LinkFunctionType.PROBIT.ordinal()] = 9;
            } catch (NoSuchFieldError e14) {
            }
            // OpType labels 1..3
            $SwitchMap$org$dmg$pmml$OpType = new int[OpType.values().length];
            try {
                $SwitchMap$org$dmg$pmml$OpType[OpType.CONTINUOUS.ordinal()] = 1;
            } catch (NoSuchFieldError e15) {
            }
            try {
                $SwitchMap$org$dmg$pmml$OpType[OpType.CATEGORICAL.ordinal()] = 2;
            } catch (NoSuchFieldError e16) {
            }
            try {
                $SwitchMap$org$dmg$pmml$OpType[OpType.ORDINAL.ordinal()] = 3;
            } catch (NoSuchFieldError e17) {
            }
            // MiningFunctionType labels 1..2
            $SwitchMap$org$dmg$pmml$MiningFunctionType = new int[MiningFunctionType.values().length];
            try {
                $SwitchMap$org$dmg$pmml$MiningFunctionType[MiningFunctionType.REGRESSION.ordinal()] = 1;
            } catch (NoSuchFieldError e18) {
            }
            try {
                $SwitchMap$org$dmg$pmml$MiningFunctionType[MiningFunctionType.CLASSIFICATION.ordinal()] = 2;
            } catch (NoSuchFieldError e19) {
            }
            // GeneralRegressionModel.ModelType labels 1..6
            $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType = new int[GeneralRegressionModel.ModelType.values().length];
            try {
                $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[GeneralRegressionModel.ModelType.COX_REGRESSION.ordinal()] = 1;
            } catch (NoSuchFieldError e20) {
            }
            try {
                $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[GeneralRegressionModel.ModelType.REGRESSION.ordinal()] = 2;
            } catch (NoSuchFieldError e21) {
            }
            try {
                $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[GeneralRegressionModel.ModelType.GENERAL_LINEAR.ordinal()] = 3;
            } catch (NoSuchFieldError e22) {
            }
            try {
                $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[GeneralRegressionModel.ModelType.GENERALIZED_LINEAR.ordinal()] = 4;
            } catch (NoSuchFieldError e23) {
            }
            try {
                $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[GeneralRegressionModel.ModelType.MULTINOMIAL_LOGISTIC.ordinal()] = 5;
            } catch (NoSuchFieldError e24) {
            }
            try {
                $SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[GeneralRegressionModel.ModelType.ORDINAL_MULTINOMIAL.ordinal()] = 6;
            } catch (NoSuchFieldError e25) {
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator$Row.class */
    public static class Row {
        private List<FactorHandler> factorHandlers;
        private List<CovariateHandler> covariateHandlers;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator$Row$ContrastMatrixHandler.class */
        public class ContrastMatrixHandler extends FactorHandler {
            private Matrix matrix;
            private List<String> categories;

            private ContrastMatrixHandler(PPCell pPCell, Matrix matrix, List<String> list) {
                super(pPCell);
                this.matrix = null;
                this.categories = null;
                setMatrix(matrix);
                setCategories(list);
            }

            @Override // org.jpmml.evaluator.GeneralRegressionModelEvaluator.Row.FactorHandler, org.jpmml.evaluator.GeneralRegressionModelEvaluator.Row.PredictorHandler
            public Double evaluate(FieldValue fieldValue) {
                Matrix matrix = getMatrix();
                int index = getIndex(fieldValue);
                int index2 = getIndex(getCategory());
                if (index < 0 || index2 < 0) {
                    throw new EvaluationException();
                }
                Number elementAt = MatrixUtil.getElementAt(matrix, index + 1, index2 + 1);
                if (elementAt == null) {
                    throw new EvaluationException();
                }
                return Double.valueOf(elementAt.doubleValue());
            }

            public int getIndex(FieldValue fieldValue) {
                List<String> categories = getCategories();
                for (int i = 0; i < categories.size(); i++) {
                    if (fieldValue.equalsString(categories.get(i))) {
                        return i;
                    }
                }
                return -1;
            }

            public int getIndex(String str) {
                return getCategories().indexOf(str);
            }

            public Matrix getMatrix() {
                return this.matrix;
            }

            private void setMatrix(Matrix matrix) {
                this.matrix = matrix;
            }

            public List<String> getCategories() {
                return this.categories;
            }

            private void setCategories(List<String> list) {
                this.categories = list;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator$Row$CovariateHandler.class */
        public class CovariateHandler extends PredictorHandler {
            private CovariateHandler(PPCell pPCell) {
                super(pPCell);
            }

            @Override // org.jpmml.evaluator.GeneralRegressionModelEvaluator.Row.PredictorHandler
            public Double evaluate(FieldValue fieldValue) {
                return Double.valueOf(Math.pow(fieldValue.asNumber().doubleValue(), getMultiplicity().doubleValue()));
            }

            public Double getMultiplicity() {
                return Double.valueOf(getPPCell().getValue());
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator$Row$FactorHandler.class */
        public class FactorHandler extends PredictorHandler {
            private FactorHandler(PPCell pPCell) {
                super(pPCell);
            }

            @Override // org.jpmml.evaluator.GeneralRegressionModelEvaluator.Row.PredictorHandler
            public Double evaluate(FieldValue fieldValue) {
                return Double.valueOf(fieldValue.equalsString(getCategory()) ? 1.0d : 0.0d);
            }

            public String getCategory() {
                return getPPCell().getValue();
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:org/jpmml/evaluator/GeneralRegressionModelEvaluator$Row$PredictorHandler.class */
        public abstract class PredictorHandler {
            private PPCell ppCell;

            private PredictorHandler(PPCell pPCell) {
                this.ppCell = null;
                setPPCell(pPCell);
            }

            public abstract Double evaluate(FieldValue fieldValue);

            public FieldName getPredictorName() {
                return getPPCell().getPredictorName();
            }

            public PPCell getPPCell() {
                return this.ppCell;
            }

            private void setPPCell(PPCell pPCell) {
                this.ppCell = pPCell;
            }
        }

        private Row() {
            this.factorHandlers = Lists.newArrayList();
            this.covariateHandlers = Lists.newArrayList();
        }

        public Double evaluate(Map<FieldName, FieldValue> map) {
            List<FactorHandler> factorHandlers = getFactorHandlers();
            List<CovariateHandler> covariateHandlers = getCovariateHandlers();
            if (factorHandlers.isEmpty() && covariateHandlers.isEmpty()) {
                return Double.valueOf(1.0d);
            }
            Double computeProduct = computeProduct(factorHandlers, map);
            Double computeProduct2 = computeProduct(covariateHandlers, map);
            if (covariateHandlers.isEmpty()) {
                return computeProduct;
            }
            if (factorHandlers.isEmpty()) {
                return computeProduct2;
            }
            if (computeProduct == null || computeProduct2 == null) {
                return null;
            }
            return Double.valueOf(computeProduct.doubleValue() * computeProduct2.doubleValue());
        }

        public void addFactor(PPCell pPCell, Predictor predictor) {
            List<FactorHandler> factorHandlers = getFactorHandlers();
            Matrix matrix = predictor.getMatrix();
            if (matrix == null) {
                factorHandlers.add(new FactorHandler(pPCell));
                return;
            }
            Categories categories = predictor.getCategories();
            if (categories == null) {
                throw new UnsupportedFeatureException(predictor);
            }
            factorHandlers.add(new ContrastMatrixHandler(pPCell, matrix, Lists.transform(categories.getCategories(), new com.google.common.base.Function<Category, String>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.Row.1
                public String apply(Category category) {
                    return category.getValue();
                }
            })));
        }

        /* JADX INFO: Access modifiers changed from: private */
        public void addCovariate(PPCell pPCell) {
            getCovariateHandlers().add(new CovariateHandler(pPCell));
        }

        public List<FactorHandler> getFactorHandlers() {
            return this.factorHandlers;
        }

        public List<CovariateHandler> getCovariateHandlers() {
            return this.covariateHandlers;
        }

        private static Double computeProduct(List<? extends PredictorHandler> list, Map<FieldName, FieldValue> map) {
            Double d = null;
            for (PredictorHandler predictorHandler : list) {
                FieldValue fieldValue = map.get(predictorHandler.getPredictorName());
                if (fieldValue == null) {
                    return null;
                }
                d = d == null ? predictorHandler.evaluate(fieldValue) : Double.valueOf(d.doubleValue() * predictorHandler.evaluate(fieldValue).doubleValue());
            }
            return d;
        }
    }

    /**
     * Creates an evaluator for the {@link GeneralRegressionModel} that {@code find(...)} selects
     * from {@code pmml.getModels()} (selection rule defined in {@code find}, not shown here).
     *
     * @param pmml the PMML document containing the model
     */
    public GeneralRegressionModelEvaluator(PMML pmml) {
        this(pmml, find(pmml.getModels(), GeneralRegressionModel.class));
    }

    /**
     * Creates an evaluator for an explicitly given model of the PMML document.
     *
     * @param pmml the PMML document the model belongs to
     * @param generalRegressionModel the model to evaluate
     */
    public GeneralRegressionModelEvaluator(PMML pmml, GeneralRegressionModel generalRegressionModel) {
        super(pmml, generalRegressionModel);
    }

    /**
     * Returns a short human-readable label for the evaluated model.
     *
     * @return {@code "Cox regression"} for COX_REGRESSION models, {@code "General regression"}
     *         for every other model type
     */
    public String getSummary() {
        GeneralRegressionModel.ModelType modelType = getModel().getModelType();
        switch (modelType) {
            case COX_REGRESSION:
                return "Cox regression";
            default:
                return "General regression";
        }
    }

    /**
     * Evaluates the model and post-processes the predictions through {@code OutputUtil}.
     *
     * @param modelEvaluationContext the evaluation context holding the input arguments
     * @return the target and output field values
     * @throws InvalidResultException if the model is not scorable
     * @throws InvalidFeatureException if the model declares no mining function
     * @throws UnsupportedFeatureException for mining functions other than regression/classification
     */
    @Override // org.jpmml.evaluator.ModelEvaluator
    public Map<FieldName, ?> evaluate(ModelEvaluationContext modelEvaluationContext) {
        GeneralRegressionModel model = getModel();
        if (!model.isScorable()) {
            throw new InvalidResultException(model);
        }
        MiningFunctionType functionName = model.getFunctionName();
        if (functionName == null) {
            // Previously surfaced as an NPE from the switch; report the malformed model instead
            throw new InvalidFeatureException(model);
        }
        // Holds regression (numeric) or classification (ClassificationMap) predictions
        Map<FieldName, ?> predictions;
        switch (functionName) {
            case REGRESSION:
                predictions = evaluateRegression(modelEvaluationContext);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(modelEvaluationContext);
                break;
            default:
                throw new UnsupportedFeatureException(model, functionName);
        }
        return OutputUtil.evaluate(predictions, modelEvaluationContext);
    }

    /**
     * Dispatches a regression evaluation: COX_REGRESSION models go through the survival
     * computation, all other model types through the general regression path.
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelEvaluationContext context) {
        GeneralRegressionModel.ModelType modelType = getModel().getModelType();
        switch (modelType) {
            case COX_REGRESSION:
                return evaluateCoxRegression(context);
            default:
                return evaluateGeneralRegression(context);
        }
    }

    /**
     * Evaluates a Cox regression model, returning the predicted survival value
     * {@code exp(-H0(t) * exp(r - s))}.
     *
     * <p>In that formula, {@code H0(t)} is the cumulative hazard of the latest baseline cell
     * whose time is {@code <= t}. {@code r} is the dot product from {@code computeDotProduct}
     * and {@code s} is the value from {@code computeReferencePoint}.
     *
     * <p>Special cases:
     * <ul>
     * <li>If {@code t} lies before the earliest baseline time, the result is 1.0.</li>
     * <li>The result is {@code null} (missing) in four situations:
     *   <ul>
     *   <li>the strata value has no matching baseline stratum;</li>
     *   <li>the end-time value is missing;</li>
     *   <li>{@code t} exceeds the applicable max time;</li>
     *   <li>the dot product or the reference point is missing.</li>
     *   </ul>
     * </li>
     * </ul>
     *
     * @throws InvalidFeatureException if the hazard tables are absent, max time is missing for
     *         the unstratified table, the end-time variable is not declared, or the applicable
     *         list of baseline cells is empty
     */
    private Map<FieldName, ? extends Number> evaluateCoxRegression(ModelEvaluationContext modelEvaluationContext) {
        GeneralRegressionModel model = getModel();
        BaseCumHazardTables baseCumHazardTables = model.getBaseCumHazardTables();
        if (baseCumHazardTables == null) {
            throw new InvalidFeatureException(model);
        }
        FieldName targetField = getTargetField();

        // Select the baseline cells and max time: per stratum if stratified, else the global table
        List<BaselineCell> baselineCells;
        Double maxTime;
        FieldName baselineStrataVariable = model.getBaselineStrataVariable();
        if (baselineStrataVariable != null) {
            BaselineStratum baselineStratum = getBaselineStratum(baseCumHazardTables, getVariable(baselineStrataVariable, modelEvaluationContext));
            if (baselineStratum == null) {
                return null;
            }
            baselineCells = baselineStratum.getBaselineCells();
            maxTime = Double.valueOf(baselineStratum.getMaxTime());
        } else {
            baselineCells = baseCumHazardTables.getBaselineCells();
            maxTime = baseCumHazardTables.getMaxTime();
            if (maxTime == null) {
                throw new InvalidFeatureException(baseCumHazardTables);
            }
        }

        FieldName endTimeVariable = model.getEndTimeVariable();
        if (endTimeVariable == null) {
            throw new InvalidFeatureException(model);
        }
        // Ordering.min/max throw NoSuchElementException on empty input; report a model error instead
        if (baselineCells.isEmpty()) {
            throw new InvalidFeatureException(baseCumHazardTables);
        }

        Ordering<BaselineCell> byTime = Ordering.from(new Comparator<BaselineCell>() {
            @Override
            public int compare(BaselineCell left, BaselineCell right) {
                return Double.compare(left.getTime(), right.getTime());
            }
        });

        FieldValue endTime = getVariable(endTimeVariable, modelEvaluationContext);
        if (endTime == null) {
            // NOTE(review): assumes getVariable signals a missing value with null; treat as missing result
            return null;
        }

        Double minTime = Double.valueOf(byTime.min(baselineCells).getTime());
        if (endTime.compareToValue(FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, minTime)) < 0) {
            // Before the first baseline time: no accumulated hazard, survival is 1
            return Collections.singletonMap(targetField, Double.valueOf(1.0d));
        }
        if (endTime.compareToValue(FieldValueUtil.create(DataType.DOUBLE, OpType.CONTINUOUS, maxTime)) > 0) {
            return null;
        }

        // Latest baseline cell at or before the end time (non-empty: the min cell qualifies)
        final double time = endTime.asNumber().doubleValue();
        BaselineCell baselineCell = byTime.max(Iterables.filter(baselineCells, new Predicate<BaselineCell>() {
            @Override
            public boolean apply(BaselineCell cell) {
                return cell.getTime() <= time;
            }
        }));
        double baselineCumHazard = baselineCell.getCumHazard();

        Double dotProduct = computeDotProduct(modelEvaluationContext);
        Double referencePoint = computeReferencePoint();
        if (dotProduct == null || referencePoint == null) {
            return null;
        }
        double cumHazard = baselineCumHazard * Math.exp(dotProduct.doubleValue() - referencePoint.doubleValue());
        return Collections.singletonMap(targetField, Double.valueOf(Math.exp(-cumHazard)));
    }

    /**
     * Evaluates a non-Cox regression model.
     *
     * <p>The linear predictor comes from {@code computeDotProduct}. REGRESSION and
     * GENERAL_LINEAR models use it as-is; GENERALIZED_LINEAR models pass it through
     * {@code computeLink}. When the linear predictor is missing, the target's default
     * regression result is returned. That check happens before the model type is
     * validated.
     *
     * @throws UnsupportedFeatureException for any other model type
     */
    private Map<FieldName, ? extends Number> evaluateGeneralRegression(ModelEvaluationContext context) {
        GeneralRegressionModel model = getModel();
        Double result = computeDotProduct(context);
        if (result == null) {
            return TargetUtil.evaluateRegressionDefault(context);
        }

        GeneralRegressionModel.ModelType modelType = model.getModelType();
        switch (modelType) {
            case REGRESSION:
            case GENERAL_LINEAR:
                // Linear predictor is used unchanged
                break;
            case GENERALIZED_LINEAR:
                result = computeLink(result, context);
                break;
            default:
                throw new UnsupportedFeatureException(model, modelType);
        }
        return TargetUtil.evaluateRegression(result, context);
    }

    /**
     * Computes class probabilities for a classification GeneralRegressionModel.
     *
     * <p>The target field must be CATEGORICAL or ORDINAL. Supported model types are
     * GENERALIZED_LINEAR, MULTINOMIAL_LOGISTIC and ORDINAL_MULTINOMIAL.
     *
     * <p>Every category except the last gets a dot product from its PPMatrix rows and
     * ParamMatrix cells. The last category is the reference category, moved to the end when
     * one is known, and its dot product is fixed at 0.0.
     *
     * <ul>
     * <li>GENERALIZED_LINEAR: each value goes through {@code computeLink}, then the values
     *     are normalized.</li>
     * <li>MULTINOMIAL_LOGISTIC: each value is {@code exp(dot)}, then the values are
     *     normalized.</li>
     * <li>ORDINAL_MULTINOMIAL: each value is a cumulative value from
     *     {@code computeCumulativeLink}, with 1.0 for the last category. A category's
     *     probability is the difference from the previous cumulative value.</li>
     * </ul>
     *
     * @return the target's probability map, or the default classification if any dot
     *         product is missing
     * @throws InvalidFeatureException for a continuous target, a single declared target
     *         category, an unresolvable reference category, or missing PPMatrix/ParamMatrix
     *         entries
     * @throws UnsupportedFeatureException for other model types or target op types
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelEvaluationContext modelEvaluationContext) {
        Double valueOf;
        Map<String, Row> map;
        Iterable<PCell> concat;
        GeneralRegressionModel model = getModel();
        FieldName targetField = getTargetField();
        DataField dataField = getDataField(targetField);
        OpType opType = dataField.getOpType();
        // Labels: 1=CONTINUOUS, 2=CATEGORICAL, 3=ORDINAL
        switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$OpType[opType.ordinal()]) {
            case 1:
                throw new InvalidFeatureException(dataField);
            case 2:
            case 3:
                List<String> targetCategories = ArgumentUtil.getTargetCategories(dataField);
                // Exactly one declared category is rejected; an empty list passes through
                // (the loop below then yields an empty map). NOTE(review): confirm empty is intended.
                if (targetCategories.size() > 0 && targetCategories.size() < 2) {
                    throw new InvalidFeatureException(dataField);
                }
                Map<FieldName, FieldValue> arguments = getArguments(modelEvaluationContext);
                Map<String, Map<String, Row>> pPMatrixMap = getPPMatrixMap();
                final Map<String, List<PCell>> paramMatrixMap = getParamMatrixMap();
                String targetReferenceCategory = model.getTargetReferenceCategory();
                GeneralRegressionModel.ModelType modelType = model.getModelType();
                // Labels: 4=GENERALIZED_LINEAR, 5=MULTINOMIAL_LOGISTIC, 6=ORDINAL_MULTINOMIAL
                switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[modelType.ordinal()]) {
                    case 4:
                    case 5:
                        if (targetReferenceCategory == null) {
                            // Infer the reference category: the single target category that has
                            // no ParamMatrix cells of its own
                            LinkedHashSet newLinkedHashSet = Sets.newLinkedHashSet(Iterables.filter(targetCategories, new Predicate<String>() { // from class: org.jpmml.evaluator.GeneralRegressionModelEvaluator.3
                                public boolean apply(String str) {
                                    return !paramMatrixMap.containsKey(str);
                                }
                            }));
                            if (newLinkedHashSet.size() != 1) {
                                throw new InvalidFeatureException(model.getParamMatrix());
                            }
                            targetReferenceCategory = (String) Iterables.getOnlyElement(newLinkedHashSet);
                            break;
                        }
                        break;
                    case 6:
                        break;
                    default:
                        throw new UnsupportedFeatureException(model, modelType);
                }
                // Move the reference category (if declared among the targets) to the end, so it is
                // the last iteration below, whose dot product is fixed at 0.0. Copy first: the
                // original list must not be mutated.
                if (targetReferenceCategory != null) {
                    targetCategories = Lists.newArrayList(targetCategories);
                    if (targetCategories.remove(targetReferenceCategory)) {
                        targetCategories.add(targetReferenceCategory);
                    }
                }
                ProbabilityClassificationMap probabilityClassificationMap = new ProbabilityClassificationMap();
                // Previous iteration's value; used by ORDINAL_MULTINOMIAL to difference cumulatives
                Double d = null;
                for (int i = 0; i < targetCategories.size(); i++) {
                    String str = targetCategories.get(i);
                    if (i < targetCategories.size() - 1) {
                        // PPMatrix rows for this category, falling back to the shared (null-key) rows
                        if (pPMatrixMap.isEmpty()) {
                            map = Collections.emptyMap();
                        } else {
                            map = pPMatrixMap.get(str);
                            if (map == null) {
                                map = pPMatrixMap.get(null);
                            }
                            if (map == null) {
                                throw new InvalidFeatureException(model.getPPMatrix());
                            }
                        }
                        switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[modelType.ordinal()]) {
                            case 4:
                            case 5:
                                // Category-specific cells; the shared (null-key) cells are only an
                                // acceptable substitute for a binary target
                                concat = paramMatrixMap.get(str);
                                if (concat == null && targetCategories.size() == 2) {
                                    concat = paramMatrixMap.get(null);
                                }
                                if (concat == null) {
                                    throw new InvalidFeatureException(model.getParamMatrix());
                                }
                                break;
                            case 6:
                                // Exactly one category-specific cell (presumably the category's
                                // threshold) plus the shared (null-key) cells
                                List<PCell> list = paramMatrixMap.get(str);
                                if (list == null || list.size() != 1) {
                                    throw new InvalidFeatureException(model.getParamMatrix());
                                }
                                List<PCell> list2 = paramMatrixMap.get(null);
                                if (list2 == null) {
                                    throw new InvalidFeatureException(model.getParamMatrix());
                                }
                                concat = Iterables.concat(list, list2);
                                break;
                            default:
                                throw new UnsupportedFeatureException(model, modelType);
                        }
                        Double computeDotProduct = computeDotProduct(concat, map, arguments);
                        if (computeDotProduct == null) {
                            return TargetUtil.evaluateClassificationDefault(modelEvaluationContext);
                        }
                        switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[modelType.ordinal()]) {
                            case 4:
                                valueOf = computeLink(computeDotProduct, modelEvaluationContext);
                                break;
                            case 5:
                                valueOf = Double.valueOf(Math.exp(computeDotProduct.doubleValue()));
                                break;
                            case 6:
                                valueOf = computeCumulativeLink(computeDotProduct, modelEvaluationContext);
                                break;
                            default:
                                throw new UnsupportedFeatureException(model, modelType);
                        }
                    } else {
                        // Last (reference) category: dot product 0.0; for the ordinal model the
                        // cumulative value is 1.0
                        Double valueOf2 = Double.valueOf(0.0d);
                        switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[modelType.ordinal()]) {
                            case 4:
                                valueOf = computeLink(valueOf2, modelEvaluationContext);
                                break;
                            case 5:
                                valueOf = Double.valueOf(Math.exp(valueOf2.doubleValue()));
                                break;
                            case 6:
                                valueOf = Double.valueOf(1.0d);
                                break;
                            default:
                                throw new UnsupportedFeatureException(model, modelType);
                        }
                    }
                    switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[modelType.ordinal()]) {
                        case 4:
                        case 5:
                            probabilityClassificationMap.put(str, valueOf);
                            break;
                        case 6:
                            // Store the increment over the previous cumulative value
                            if (d == null) {
                                probabilityClassificationMap.put(str, valueOf);
                                break;
                            } else {
                                probabilityClassificationMap.put(str, Double.valueOf(valueOf.doubleValue() - d.doubleValue()));
                                break;
                            }
                        default:
                            throw new UnsupportedFeatureException(model, modelType);
                    }
                    d = valueOf;
                }
                // Non-ordinal scores are unnormalized; normalizeValues() presumably rescales
                // them to sum to 1 (TODO confirm in ProbabilityClassificationMap)
                switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$GeneralRegressionModel$ModelType[modelType.ordinal()]) {
                    case 4:
                    case 5:
                        probabilityClassificationMap.normalizeValues();
                        break;
                    case 6:
                        break;
                    default:
                        throw new UnsupportedFeatureException(model, modelType);
                }
                return TargetUtil.evaluateClassification((Map<FieldName, ? extends ClassificationMap<?>>) Collections.singletonMap(targetField, probabilityClassificationMap), modelEvaluationContext);
            default:
                throw new UnsupportedFeatureException(dataField, opType);
        }
    }

    /**
     * Computes the linear predictor for a model whose parameter matrix is not split by target category.
     *
     * <p>PPMatrix rows are taken from the {@code null}-keyed (category-less) group. When the PPMatrix
     * map is empty, no rows exist and every parameter contributes {@code 1 * beta}.
     *
     * @param evaluationContext context used to evaluate the factor and covariate predictors
     * @return the dot product, or {@code null} when a PPMatrix row evaluates to {@code null}
     * @throws InvalidFeatureException if the PPMatrix is non-empty but has no {@code null}-keyed group,
     *         or if the ParamMatrix is not exactly one {@code null}-keyed group
     */
    private Double computeDotProduct(EvaluationContext evaluationContext) {
        GeneralRegressionModel grm = getModel();
        Map<FieldName, FieldValue> predictorValues = getArguments(evaluationContext);

        Map<String, Map<String, Row>> rowsByCategory = getPPMatrixMap();
        Map<String, Row> rows = Collections.emptyMap();
        if (!rowsByCategory.isEmpty()) {
            rows = rowsByCategory.get(null);
            if (rows == null) {
                throw new InvalidFeatureException(grm.getPPMatrix());
            }
        }

        Map<String, List<PCell>> cellsByCategory = getParamMatrixMap();
        boolean singleCategorylessGroup = (cellsByCategory.size() == 1) && cellsByCategory.containsKey(null);
        if (!singleCategorylessGroup) {
            throw new InvalidFeatureException(grm.getParamMatrix());
        }
        return computeDotProduct(cellsByCategory.get(null), rows, predictorValues);
    }

    /**
     * Sums {@code x * beta} over the given parameter cells.
     *
     * <p>{@code x} is the value of the parameter's PPMatrix row, or {@code 1} when the parameter
     * has no row (an intercept-like term).
     *
     * @param cells parameter cells to accumulate
     * @param rows PPMatrix rows keyed by parameter name
     * @param arguments predictor values keyed by field name
     * @return the sum; {@code null} if {@code cells} is empty or any row evaluates to {@code null}
     */
    private Double computeDotProduct(Iterable<PCell> cells, Map<String, Row> rows, Map<FieldName, FieldValue> arguments) {
        Double sum = null;
        for (PCell cell : cells) {
            double x = 1.0d;
            Row row = rows.get(cell.getParameterName());
            if (row != null) {
                Double rowValue = row.evaluate(arguments);
                if (rowValue == null) {
                    return null;
                }
                x = rowValue.doubleValue();
            }
            double term = x * cell.getBeta();
            sum = (sum == null) ? Double.valueOf(term) : Double.valueOf(sum.doubleValue() + term);
        }
        return sum;
    }

    /**
     * Computes the model's reference point as the sum of {@code referencePoint * beta} over the
     * category-less ParamMatrix cells.
     *
     * @return the reference point; {@code null} if a cell names a parameter that is absent from the
     *         parameter registry
     * @throws InvalidFeatureException if the ParamMatrix is not exactly one {@code null}-keyed group
     */
    private Double computeReferencePoint() {
        GeneralRegressionModel grm = getModel();
        BiMap<String, Parameter> parameters = getParameterRegistry();
        Map<String, List<PCell>> cellsByCategory = getParamMatrixMap();

        boolean singleCategorylessGroup = (cellsByCategory.size() == 1) && cellsByCategory.containsKey(null);
        if (!singleCategorylessGroup) {
            throw new InvalidFeatureException(grm.getParamMatrix());
        }

        Double total = null;
        Iterator<PCell> cells = cellsByCategory.get(null).iterator();
        while (cells.hasNext()) {
            PCell cell = cells.next();
            Parameter parameter = parameters.get(cell.getParameterName());
            if (parameter == null) {
                return null;
            }
            double contribution = parameter.getReferencePoint() * cell.getBeta();
            total = (total == null) ? Double.valueOf(contribution) : Double.valueOf(total.doubleValue() + contribution);
        }
        return total;
    }

    /**
     * Maps a linear predictor value onto the response scale through the model's inverse link function.
     *
     * <p>The predictor is shifted by the offset (see {@code getOffset}), the inverse link is applied,
     * and the result is multiplied by the number of trials (see {@code getTrials}).
     *
     * <p>Dispatch goes through the decompiler-generated switch map in {@code AnonymousClass12}, which is
     * defined outside this view. The per-case comments name the PMML link whose inverse the formula
     * matches. NOTE(review): confirm against the switch map.
     *
     * @param d the linear predictor (dot product)
     * @param evaluationContext context used to resolve the offset and trials variables
     * @return the inverse-link value multiplied by the number of trials
     * @throws InvalidFeatureException if the model has no link function, or if case 7 or 8 is selected
     *         without a link parameter
     * @throws UnsupportedFeatureException if the link function has no case here
     */
    private Double computeLink(Double d, EvaluationContext evaluationContext) {
        GeneralRegressionModel grm = getModel();
        LinkFunctionType link = grm.getLinkFunction();
        if (link == null) {
            throw new InvalidFeatureException(grm);
        }
        Double offset = getOffset(grm, evaluationContext);
        Integer trials = getTrials(grm, evaluationContext);
        Double linkParam = grm.getLinkParameter();

        double mean;
        switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$LinkFunctionType[link.ordinal()]) {
            case 1: // 1 - exp(-exp(eta)): matches cloglog
                mean = 1.0d - Math.exp(-Math.exp(d + offset));
                break;
            case 2: // eta: matches identity
                mean = d + offset;
                break;
            case 3: // exp(eta): matches log
                mean = Math.exp(d + offset);
                break;
            case 4: // 1 - exp(eta): matches logc
                mean = 1.0d - Math.exp(d + offset);
                break;
            case 5: // 1 / (1 + exp(-eta)): matches logit
                mean = 1.0d / (1.0d + Math.exp(-(d + offset)));
                break;
            case 6: // exp(-exp(-eta)): matches loglog
                mean = Math.exp(-Math.exp(-(d + offset)));
                break;
            case 7: // matches oddspower; the logistic fallback is used when the parameter is 0 (or NaN)
                if (linkParam == null) {
                    throw new InvalidFeatureException(grm);
                }
                if (linkParam < 0.0d || linkParam > 0.0d) {
                    mean = 1.0d / (1.0d + Math.pow(1.0d + (linkParam * (d + offset)), -(1.0d / linkParam)));
                } else {
                    mean = 1.0d / (1.0d + Math.exp(-(d + offset)));
                }
                break;
            case 8: // matches power; the exp fallback is used when the parameter is 0 (or NaN)
                if (linkParam == null) {
                    throw new InvalidFeatureException(grm);
                }
                if (linkParam < 0.0d || linkParam > 0.0d) {
                    mean = Math.pow(d + offset, 1.0d / linkParam);
                } else {
                    mean = Math.exp(d + offset);
                }
                break;
            case 9: // standard normal CDF: matches probit
                mean = NormalDistributionUtil.cumulativeProbability(d + offset);
                break;
            default:
                throw new UnsupportedFeatureException(grm, link);
        }
        return Double.valueOf(mean * trials.intValue());
    }

    /**
     * Applies the model's cumulative link function to an offset-shifted linear predictor.
     *
     * <p>Dispatch goes through the decompiler-generated switch map in {@code AnonymousClass12}, which is
     * defined outside this view. The per-case comments name the PMML cumulative link whose formula
     * matches. NOTE(review): confirm against the switch map.
     *
     * @param d the linear predictor (dot product)
     * @param evaluationContext context used to resolve the offset variable
     * @return the cumulative probability
     * @throws InvalidFeatureException if the model declares no cumulative link
     * @throws UnsupportedFeatureException if the cumulative link has no case here
     */
    private Double computeCumulativeLink(Double d, EvaluationContext evaluationContext) {
        GeneralRegressionModel grm = getModel();
        CumulativeLinkFunctionType link = grm.getCumulativeLink();
        if (link == null) {
            throw new InvalidFeatureException(grm);
        }
        Double offset = getOffset(grm, evaluationContext);

        double probability;
        switch (AnonymousClass12.$SwitchMap$org$dmg$pmml$CumulativeLinkFunctionType[link.ordinal()]) {
            case 1: // 1 / (1 + exp(-eta)): matches logit
                probability = 1.0d / (1.0d + Math.exp(-(d + offset)));
                break;
            case 2: // standard normal CDF: matches probit
                probability = NormalDistributionUtil.cumulativeProbability(d + offset);
                break;
            case 3: // 1 - exp(-exp(eta)): matches cloglog
                probability = 1.0d - Math.exp(-Math.exp(d + offset));
                break;
            case 4: // exp(-exp(-eta)): matches loglog
                probability = Math.exp(-Math.exp(-(d + offset)));
                break;
            case 5: // 0.5 + atan(eta) / pi: matches cauchit (0.3183098861837907 is 1 / pi)
                probability = 0.5d + (0.3183098861837907d * Math.atan(d + offset));
                break;
            default:
                throw new UnsupportedFeatureException(grm, link);
        }
        return Double.valueOf(probability);
    }

    /**
     * Evaluates every registered predictor in the given context.
     *
     * <p>Factors are visited first, then covariates, each in its registry's iteration order. The
     * returned map preserves that order.
     *
     * @param evaluationContext context in which each predictor field is evaluated
     * @return predictor values keyed by predictor field name
     */
    private Map<FieldName, FieldValue> getArguments(EvaluationContext evaluationContext) {
        BiMap<FieldName, Predictor> factors = getFactorRegistry();
        BiMap<FieldName, Predictor> covariates = getCovariateRegistry();

        // Typed map and loop: the decompiled version used a raw LinkedHashMap/Iterator with a manual cast.
        Map<FieldName, FieldValue> arguments = Maps.newLinkedHashMap();
        for (Predictor predictor : Iterables.concat(factors.values(), covariates.values())) {
            FieldName name = predictor.getName();
            arguments.put(name, ExpressionUtil.evaluate(name, evaluationContext));
        }
        return arguments;
    }

    /**
     * Returns the model's parameters keyed by parameter name, as loaded through {@code parameterCache}.
     */
    public BiMap<String, Parameter> getParameterRegistry() {
        @SuppressWarnings("unchecked")
        BiMap<String, Parameter> registry = (BiMap<String, Parameter>) getValue(parameterCache);
        return registry;
    }

    /**
     * Returns the factor (categorical) predictors keyed by field name, as loaded through {@code factorCache}.
     */
    public BiMap<FieldName, Predictor> getFactorRegistry() {
        @SuppressWarnings("unchecked")
        BiMap<FieldName, Predictor> registry = (BiMap<FieldName, Predictor>) getValue(factorCache);
        return registry;
    }

    /**
     * Returns the covariate predictors keyed by field name, as loaded through {@code covariateCache}.
     */
    public BiMap<FieldName, Predictor> getCovariateRegistry() {
        @SuppressWarnings("unchecked")
        BiMap<FieldName, Predictor> registry = (BiMap<FieldName, Predictor>) getValue(covariateCache);
        return registry;
    }

    /**
     * Returns the parsed PPMatrix (target category, then parameter name, then row), as loaded through
     * {@code ppMatrixCache}.
     */
    private Map<String, Map<String, Row>> getPPMatrixMap() {
        @SuppressWarnings("unchecked")
        Map<String, Map<String, Row>> rowsByCategory = (Map<String, Map<String, Row>>) getValue(ppMatrixCache);
        return rowsByCategory;
    }

    /**
     * Returns the ParamMatrix cells grouped by target category, as loaded through {@code paramMatrixCache}.
     */
    private Map<String, List<PCell>> getParamMatrixMap() {
        @SuppressWarnings("unchecked")
        Map<String, List<PCell>> cellsByCategory = (Map<String, List<PCell>>) getValue(paramMatrixCache);
        return cellsByCategory;
    }

    /**
     * Resolves the offset added to the linear predictor.
     *
     * <p>A declared offset variable takes precedence and is read from the context. Otherwise the
     * constant offset value is used, and {@code 0.0} is used when neither is declared.
     *
     * @throws MissingFieldException if the offset variable is declared but the context has no value for it
     */
    private static Double getOffset(GeneralRegressionModel generalRegressionModel, EvaluationContext evaluationContext) {
        FieldName variable = generalRegressionModel.getOffsetVariable();
        if (variable == null) {
            Double constant = generalRegressionModel.getOffsetValue();
            return (constant == null) ? Double.valueOf(0.0d) : constant;
        }
        Number value = getVariable(variable, evaluationContext).asNumber();
        return Double.valueOf(value.doubleValue());
    }

    /**
     * Resolves the number of trials that scales the inverse-link result.
     *
     * <p>A declared trials variable takes precedence and is read from the context. Otherwise the
     * constant trials value is used, and {@code 1} is used when neither is declared.
     *
     * @throws MissingFieldException if the trials variable is declared but the context has no value for it
     */
    private static Integer getTrials(GeneralRegressionModel generalRegressionModel, EvaluationContext evaluationContext) {
        FieldName variable = generalRegressionModel.getTrialsVariable();
        if (variable == null) {
            Integer constant = generalRegressionModel.getTrialsValue();
            return (constant == null) ? Integer.valueOf(1) : constant;
        }
        return getVariable(variable, evaluationContext).asInteger();
    }

    /**
     * Looks up a field value in the context, treating absence as an error.
     *
     * @throws MissingFieldException if the context returns {@code null} for the field
     */
    private static FieldValue getVariable(FieldName fieldName, EvaluationContext evaluationContext) {
        FieldValue value = evaluationContext.getField(fieldName);
        if (value != null) {
            return value;
        }
        throw new MissingFieldException(fieldName);
    }

    /**
     * Finds the baseline stratum whose value equals the given field value.
     *
     * @return the first matching stratum in list order, or {@code null} if none matches
     */
    private static BaselineStratum getBaselineStratum(BaseCumHazardTables baseCumHazardTables, FieldValue fieldValue) {
        Iterator<BaselineStratum> strata = baseCumHazardTables.getBaselineStrata().iterator();
        while (strata.hasNext()) {
            BaselineStratum candidate = strata.next();
            if (fieldValue.equalsString(candidate.getValue())) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Indexes the parameters of a ParameterList by name.
     *
     * <p>JADX widened this method from {@code private} during decompilation, presumably because an
     * anonymous cache loader in this class calls it.
     */
    public static BiMap<String, Parameter> parseParameterRegistry(ParameterList parameterList) {
        // Typed BiMap: the decompiled version used a raw HashBiMap.
        BiMap<String, Parameter> registry = HashBiMap.create();
        for (Parameter parameter : parameterList.getParameters()) {
            registry.put(parameter.getName(), parameter);
        }
        return registry;
    }

    /**
     * Indexes the predictors of a PredictorList by field name.
     *
     * <p>JADX widened this method from {@code private} during decompilation, presumably because an
     * anonymous cache loader in this class calls it.
     *
     * @param predictorList the list to index; {@code null} yields an empty registry
     */
    public static BiMap<FieldName, Predictor> parsePredictorRegistry(PredictorList predictorList) {
        // Typed BiMap: the decompiled version used a raw HashBiMap.
        BiMap<FieldName, Predictor> registry = HashBiMap.create();
        if (predictorList != null) {
            for (Predictor predictor : predictorList.getPredictors()) {
                registry.put(predictor.getName(), predictor);
            }
        }
        return registry;
    }

    /**
     * Builds the PPMatrix lookup: target category, then parameter name, then {@link Row}.
     *
     * <p>Each row collects the PPCells of one parameter. A cell whose predictor is a registered factor
     * is added as a factor term, and a cell whose predictor is a registered covariate is added as a
     * covariate term. Insertion order is preserved at both map levels.
     *
     * <p>JADX widened this method from {@code private} during decompilation.
     *
     * @throws InvalidFeatureException if a PPCell names a predictor that is neither a registered factor
     *         nor a registered covariate
     */
    public static Map<String, Map<String, Row>> parsePPMatrix(final GeneralRegressionModel generalRegressionModel) {
        com.google.common.base.Function<List<PPCell>, Row> rowBuilder = new com.google.common.base.Function<List<PPCell>, Row>() {
            private BiMap<FieldName, Predictor> factors;
            private BiMap<FieldName, Predictor> covariates;

            {
                this.factors = (BiMap) CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.factorCache);
                this.covariates = (BiMap) CacheUtil.getValue(generalRegressionModel, GeneralRegressionModelEvaluator.covariateCache);
            }

            @Override
            public Row apply(List<PPCell> cells) {
                Row row = new Row();
                for (PPCell cell : cells) {
                    FieldName predictorName = cell.getPredictorName();
                    Predictor factor = this.factors.get(predictorName);
                    if (factor != null) {
                        row.addFactor(cell, factor);
                    } else if (this.covariates.get(predictorName) != null) {
                        row.addCovariate(cell);
                    } else {
                        throw new InvalidFeatureException(cell);
                    }
                }
                return row;
            }
        };

        // Typed maps and entries: the decompiled version used raw ListMultimap/LinkedHashMap/Map.Entry.
        ListMultimap<String, PPCell> cellsByCategory = groupByTargetCategory(generalRegressionModel.getPPMatrix().getPPCells());
        Map<String, Map<String, Row>> result = Maps.newLinkedHashMap();
        for (Map.Entry<String, List<PPCell>> categoryEntry : asMap(cellsByCategory).entrySet()) {
            Map<String, Row> rows = Maps.newLinkedHashMap();
            for (Map.Entry<String, List<PPCell>> parameterEntry : asMap(groupByParameterName(categoryEntry.getValue())).entrySet()) {
                rows.put(parameterEntry.getKey(), rowBuilder.apply(parameterEntry.getValue()));
            }
            result.put(categoryEntry.getKey(), rows);
        }
        return result;
    }

    /**
     * Groups the ParamMatrix cells by target category. Cells whose target category is {@code null}
     * end up under the {@code null} key.
     *
     * <p>JADX widened this method from {@code private} during decompilation.
     */
    public static Map<String, List<PCell>> parseParamMatrix(GeneralRegressionModel generalRegressionModel) {
        List<PCell> cells = generalRegressionModel.getParamMatrix().getPCells();
        ListMultimap<String, PCell> cellsByCategory = groupByTargetCategory(cells);
        return asMap(cellsByCategory);
    }

    /**
     * Views a list multimap as a map from key to the list of values for that key.
     *
     * @param listMultimap the multimap to view; the returned map is a live view of it
     * @return the {@link ListMultimap#asMap()} view, typed with {@code List} values
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <C extends ParameterCell> Map<String, List<C>> asMap(ListMultimap<String, C> listMultimap) {
        // ListMultimap.asMap() is declared as Map<K, Collection<V>>, so returning it directly as
        // Map<String, List<C>> does not compile. Guava documents that every value of this view is a
        // List<V>, so the raw cast is safe.
        return (Map) listMultimap.asMap();
    }

    /**
     * Groups cells by their parameter name, keeping the cells of each group in encounter order.
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupByParameterName(List<C> list) {
        // Typed apply(C) with @Override: the decompiled version declared apply(ParameterCell), which
        // only overrode Function<C, String>.apply through its erased signature.
        com.google.common.base.Function<C, String> parameterNameOf = new com.google.common.base.Function<C, String>() {

            @Override
            public String apply(C cell) {
                return cell.getParameterName();
            }
        };
        return groupCells(list, parameterNameOf);
    }

    /**
     * Groups cells by their target category, keeping the cells of each group in encounter order.
     * Cells with a {@code null} target category share the {@code null} key.
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupByTargetCategory(List<C> list) {
        // Typed apply(C) with @Override: the decompiled version declared apply(ParameterCell), which
        // only overrode Function<C, String>.apply through its erased signature.
        com.google.common.base.Function<C, String> targetCategoryOf = new com.google.common.base.Function<C, String>() {

            @Override
            public String apply(C cell) {
                return cell.getTargetCategory();
            }
        };
        return groupCells(list, targetCategoryOf);
    }

    /**
     * Groups cells by the key computed by {@code function}.
     *
     * <p>Backed by an {@link ArrayListMultimap}, so the values under each key keep encounter order and
     * a {@code null} key is allowed.
     *
     * @param list cells to group
     * @param function key extractor applied to each cell
     * @return a new, mutable multimap from key to cells
     */
    private static <C extends ParameterCell> ListMultimap<String, C> groupCells(List<C> list, com.google.common.base.Function<C, String> function) {
        // Typed multimap: the decompiled version used a raw ArrayListMultimap.
        ListMultimap<String, C> groups = ArrayListMultimap.create();
        for (C cell : list) {
            groups.put(function.apply(cell), cell);
        }
        return groups;
    }
}
