package hex.genmodel.algos.targetencoder;

import hex.genmodel.MojoModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;

/* loaded from: input_file:hex/genmodel/algos/targetencoder/TargetEncoderMojoModel.class */
public class TargetEncoderMojoModel extends MojoModel {
    /** Per-column encoding tables; must be set before {@link #score0(double[], double[])} is called. */
    public EncodingMaps _targetEncodingMap;
    /** Maps each target-encoded column name to the index of its value in the input row. */
    public Map<String, Integer> _teColumnNameToIdx;
    /**
     * Per column name: a value of 1 means that, for a missing (NaN) input, {@link #score0} looks up the
     * encoding entry keyed by {@code encodingMap._encodingMap.size() - 1}; any other value (or no entry)
     * means missing inputs are encoded with {@link #_priorMean}.
     */
    public Map<String, Integer> _teColumnNameToMissingValuesPresence;
    /** When true, the posterior mean is blended with {@link #_priorMean} using {@link #computeLambda}. */
    public boolean _withBlending;
    /** Row count at which the blending weight computed by {@link #computeLambda} equals 0.5. */
    public double _inflectionPoint;
    /** Divisor controlling how steeply the blending weight changes around the inflection point. */
    public double _smoothing;
    /** Fallback encoding for missing/unknown levels, and the prior side of the blend. */
    public double _priorMean;
    /** When true, levels absent from an encoding map are encoded with {@link #_priorMean}. */
    private final boolean _imputationOfUnknownLevelsIsEnabled = true;

    /**
     * Orders map entries by the column index that {@code _teColumnNameToIdx} associates with each key.
     * Throws {@link NullPointerException} if a compared key has no associated index.
     */
    public static class SortByKeyAssociatedIndex<K extends String, V> implements Comparator<Map.Entry<K, V>> {
        public Map<String, Integer> _teColumnNameToIdx;

        public SortByKeyAssociatedIndex(Map<String, Integer> map) {
            this._teColumnNameToIdx = map;
        }

        @Override // java.util.Comparator
        public int compare(Map.Entry<K, V> entry, Map.Entry<K, V> entry2) {
            return this._teColumnNameToIdx.get(entry.getKey()).compareTo(this._teColumnNameToIdx.get(entry2.getKey()));
        }
    }

    /**
     * Creates the model; arguments are forwarded unchanged to {@link MojoModel}.
     */
    public TargetEncoderMojoModel(String[] columns, String[][] domains, String responseColumn) {
        super(columns, domains, responseColumn);
        this._teColumnNameToIdx = new HashMap<String, Integer>();
        // _imputationOfUnknownLevelsIsEnabled is a final field with an initializer; re-assigning it here
        // would not compile.
    }

    /**
     * Computes the blending weight {@code 1 / (1 + exp((inflectionPoint - nrows) / smoothing))}.
     * The result approaches 1 as {@code nrows} grows past {@code inflectionPoint}, and 0 as it falls below.
     *
     * @param nrows           number of rows backing the level's statistics
     * @param inflectionPoint row count at which the weight is exactly 0.5
     * @param smoothing       divisor controlling the steepness of the transition
     * @return weight in [0, 1] (NaN if both the numerator and {@code smoothing} are zero)
     */
    public static double computeLambda(int nrows, double inflectionPoint, double smoothing) {
        return 1.0d / (1.0d + Math.exp((inflectionPoint - nrows) / smoothing));
    }

    /**
     * Linearly blends a posterior and a prior mean: {@code lambda * posterior + (1 - lambda) * prior}.
     *
     * @param lambda        weight given to the posterior mean
     * @param posteriorMean per-level mean
     * @param priorMean     global mean
     * @return the blended encoding
     */
    public static double computeBlendedEncoding(double lambda, double posteriorMean, double priorMean) {
        return (lambda * posteriorMean) + ((1.0d - lambda) * priorMean);
    }

    /**
     * Encodes every target-encoded column of {@code row} and writes one value per column into {@code preds}.
     * Columns are processed in ascending order of their index in {@code _teColumnNameToIdx}, and their
     * encodings are written to consecutive slots starting at 0.
     *
     * @param row   input row; each encoded column's value is read at its index from {@code _teColumnNameToIdx}
     * @param preds output array; must have at least one slot per encoded column
     * @return {@code preds}
     * @throws IllegalStateException if {@code _targetEncodingMap} has not been set
     */
    @Override // hex.genmodel.GenModel
    public double[] score0(double[] row, double[] preds) {
        if (this._targetEncodingMap == null) {
            throw new IllegalStateException("Encoding map is missing.");
        }
        int predictionIndex = 0;
        for (Map.Entry<?, ?> entry : sortByColumnIndex(this._targetEncodingMap.encodingMap()).entrySet()) {
            String columnName = (String) entry.getKey();
            EncodingMap encodingMap = (EncodingMap) entry.getValue();
            double value = row[this._teColumnNameToIdx.get(columnName).intValue()];
            if (!Double.isNaN(value)) {
                // Assumes the categorical value arrives as an integral level index.
                computeEncodings(preds, predictionIndex, encodingMap, (int) value);
            } else if (hasMissingValuesEntry(columnName)) {
                // The entry for missing values is looked up under the last level index.
                computeEncodings(preds, predictionIndex, encodingMap, encodingMap._encodingMap.size() - 1);
            } else {
                preds[predictionIndex] = this._priorMean;
            }
            predictionIndex++;
        }
        return preds;
    }

    /** Returns true only when the presence flag for {@code columnName} is recorded and equals 1. */
    private boolean hasMissingValuesEntry(String columnName) {
        Integer presence = this._teColumnNameToMissingValuesPresence.get(columnName);
        return presence != null && presence.intValue() == 1;
    }

    /**
     * Writes the encoding of {@code level} into {@code preds[predictionIndex]}.
     * The posterior mean is {@code numerator / denominator} from the level's entry. It is optionally blended
     * with the prior mean, using the denominator as the row count for {@link #computeLambda}.
     * A level without an entry is encoded with {@link #_priorMean}.
     */
    private void computeEncodings(double[] preds, int predictionIndex, EncodingMap encodingMap, int level) {
        int[] numAndDen = encodingMap._encodingMap.get(Integer.valueOf(level));
        if (numAndDen == null) {
            if (!this._imputationOfUnknownLevelsIsEnabled) {
                throw new IllegalStateException("No encoding found for level " + level + ".");
            }
            // Unknown level: fall back to the prior, as is done for missing values without an NA entry.
            preds[predictionIndex] = this._priorMean;
            return;
        }
        // Cast before dividing: int / int would truncate the mean (e.g. 3 / 4 == 0).
        double posteriorMean = (double) numAndDen[0] / numAndDen[1];
        if (this._withBlending) {
            double lambda = computeLambda(numAndDen[1], this._inflectionPoint, this._smoothing);
            preds[predictionIndex] = computeBlendedEncoding(lambda, posteriorMean, this._priorMean);
        } else {
            preds[predictionIndex] = posteriorMean;
        }
    }

    /**
     * Returns a copy of {@code map} whose iteration order follows the column index that
     * {@code _teColumnNameToIdx} associates with each key.
     */
    // SortByKeyAssociatedIndex bounds K to String, which this method's unbounded K cannot express.
    @SuppressWarnings({"unchecked", "rawtypes"})
    <K, V> LinkedHashMap<K, V> sortByColumnIndex(Map<K, V> map) {
        ArrayList<Map.Entry<K, V>> entries = new ArrayList<Map.Entry<K, V>>(map.entrySet());
        Collections.sort(entries, new SortByKeyAssociatedIndex(this._teColumnNameToIdx));
        LinkedHashMap<K, V> sorted = new LinkedHashMap<K, V>();
        for (Map.Entry<K, V> entry : entries) {
            sorted.put(entry.getKey(), entry.getValue());
        }
        return sorted;
    }
}
