package org.streaminer.stream.classifier.tree;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.learner.Learner;
import org.streaminer.stream.learner.LearnerUtils;

/* loaded from: input_file:org/streaminer/stream/classifier/tree/HoeffdingTree.class */
public class HoeffdingTree implements Learner<Data, HoeffdingTreeModel> {
    private static final long serialVersionUID = -2578445642027682454L;
    // NOTE(review): presumably a cut-off for treating a feature as numeric; the methods of this class do not read it.
    public static final int DEFAULT_DECLARE_NUMERIC = 5;
    // Default delta; initialize(Map, ...) uses delta as ln(1/delta) when computing the Hoeffding-bound numerator.
    public static final double DEFAULT_DELTA = 1.0E-7d;
    // The tree model being grown; returned by getModel().
    protected HoeffdingTreeModel tree;
    // Per feature name: the list of values whose positions index the second dimension of NodeData's counters.
    // For Double-typed features constructFeatureValuePairs stores thresholds instead of raw values.
    protected Map<String, List<Serializable>> featureValuePairs;
    // Class labels; a label's position indexes the third dimension of NodeData's counters.
    protected List<Serializable> labels;
    // Statistics kept for each leaf that can still be split; an entry is removed when its leaf is split.
    protected Map<HoeffdingTreeNode, NodeData> NodeDataPairs;
    // highestGain^2 * ln(1/delta), set in initialize(Map, ...); learn() divides it by 2n and takes the square root.
    protected double numerator;
    // Criterion used by NodeData to score the class distribution with and without a split.
    protected QualityCriterion qualityCriterion;
    // learn() does not split unless the best quality exceeds this value by at least the Hoeffding bound.
    protected double minQualityForSplit;
    // Set to true at the end of initialize(Map, ...); learn() ignores examples while false.
    protected boolean initialized;
    // Name of the attribute holding the class label; excluded from the feature set in getAllFeatures().
    protected String labelAttribute;
    private static final Logger logger = LoggerFactory.getLogger(HoeffdingTree.class);
    public static final QualityCriterion INFORMATION_GAIN = QualityCriterion.INFO_GAIN;
    public static final QualityCriterion GINI_INDEX = QualityCriterion.GINI_INDEX;
    // Quantile levels in ascending order; NOTE(review): presumably meant as the default input to getNumericThresholds.
    public static final Double[] DEFAULT_QUANTILES = {Double.valueOf(0.25d), Double.valueOf(0.5d), Double.valueOf(0.75d), Double.valueOf(1.0d)};
    // Criterion used by initialize(Collection).
    public static final QualityCriterion DEFAULT_QUALITY_CRITERION = INFORMATION_GAIN;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/streaminer/stream/classifier/tree/HoeffdingTree$NodeData.class */
    /**
     * Sufficient statistics of one leaf: for every remaining feature, every known value of
     * that feature and every class label, the number of examples seen at the leaf.
     *
     * <p>Indices follow the enclosing learner: the first dimension of {@code counter} is the
     * position of the feature in {@code remainingFeatures}, the second the position of the value
     * in {@code featureValuePairs.get(feature)}, the third the position of the label in
     * {@code labels}.
     */
    public class NodeData implements Serializable {
        private static final long serialVersionUID = 1;
        /** Features that may still be used for a split at this leaf. */
        private final List<String> remainingFeatures;
        /** counter[feature][value][label] = number of counted examples. */
        private final int[][][] counter;
        /** Index of the feature with the fewest values; used to sum up per-label totals cheaply. */
        private final int fewestValuesFeature;
        String labelAttribute;
        /** Number of examples that were actually added to the counters. */
        private int exampleCount = 0;
        /** While positive, the cached majority class is returned without recounting. */
        private int examplesTillNextLabelUpdate = 0;
        private Serializable lastMajorityClass = null;
        /** True as long as every counted example carried the same label. */
        private boolean isUniLabel = true;
        private int uniLabelIndex = -1;

        /**
         * Allocates zeroed counters for the given features.
         *
         * @param str  name of the label attribute (stored, not otherwise used here)
         * @param list features still available at this leaf; the list is kept by reference
         */
        public NodeData(String str, List<String> list) {
            this.labelAttribute = str;
            this.remainingFeatures = list;
            // Only the outer dimension is sized here; Java zero-initialises the int cells,
            // so no explicit reset pass is needed.
            this.counter = new int[list.size()][][];
            int featureIndex = 0;
            int smallestSize = -1;
            int smallestIndex = 0;
            for (String feature : list) {
                int valueCount = HoeffdingTree.this.featureValuePairs.get(feature).size();
                this.counter[featureIndex] = new int[valueCount][HoeffdingTree.this.labels.size()];
                if (smallestSize == -1 || valueCount < smallestSize) {
                    smallestIndex = featureIndex;
                    smallestSize = valueCount;
                }
                featureIndex++;
            }
            this.fewestValuesFeature = smallestIndex;
        }

        /**
         * Adds one labelled example to the counters.
         *
         * <p>Examples without a label, or with a label that is not part of the learner's label
         * list, are logged and ignored; they do not count towards {@link #getExampleCount()},
         * so the counters and the example count always agree.
         *
         * @param data the example to count
         */
        public void addExample(Data data) {
            Serializable label = LearnerUtils.getLabel(data);
            if (label == null) {
                HoeffdingTree.logger.warn("No label found for example: {}", data);
                return;
            }
            int labelIndex = HoeffdingTree.this.labels.indexOf(label);
            if (labelIndex < 0) {
                // An index of -1 would otherwise surface as ArrayIndexOutOfBoundsException below.
                HoeffdingTree.logger.warn("Unknown label {} for example: {}", label, data);
                return;
            }
            this.exampleCount++;
            if (label.equals(this.lastMajorityClass)) {
                this.examplesTillNextLabelUpdate++;
            } else {
                this.examplesTillNextLabelUpdate--;
            }
            for (int featureIndex = 0; featureIndex < this.remainingFeatures.size(); featureIndex++) {
                String feature = this.remainingFeatures.get(featureIndex);
                incrementCounter(featureIndex, feature, data.get(feature), labelIndex);
            }
            if (this.isUniLabel) {
                if (this.uniLabelIndex == -1) {
                    this.uniLabelIndex = labelIndex;
                } else if (this.uniLabelIndex != labelIndex) {
                    this.isUniLabel = false;
                    this.uniLabelIndex = 0;
                }
            }
        }

        /** @return the features still available for a split at this leaf (the internal list) */
        public List<String> getRemainingFeatures() {
            return this.remainingFeatures;
        }

        /**
         * Increments the counter cell for one feature value and label.
         * Values of Double-typed features are first mapped to a threshold via
         * {@link HoeffdingTree#numericToNominal}.
         */
        @SuppressWarnings({"unchecked", "rawtypes"})
        private void incrementCounter(int featureIndex, String feature, Serializable value, int labelIndex) {
            List<Serializable> knownValues = HoeffdingTree.this.featureValuePairs.get(feature);
            Object bucket = value;
            if (HoeffdingTree.this.tree.getType(feature) == Double.class) {
                // Raw cast: the value lists are declared as Serializable but hold Comparable thresholds.
                bucket = HoeffdingTree.numericToNominal((Collection) knownValues, (Comparable) value);
            }
            // NOTE(review): a value that is not in knownValues yields index -1 and fails here.
            this.counter[featureIndex][knownValues.indexOf(bucket)][labelIndex]++;
        }

        /**
         * Returns the most frequent label among the counted examples. The result is cached and
         * only recomputed once enough non-majority examples arrived to possibly change it.
         *
         * <p>NOTE(review): if no labelled example has been counted yet, {@code uniLabelIndex}
         * is still -1 and {@code labels.get(-1)} throws.
         *
         * @return the current majority label
         */
        public Serializable getMajorityClass() {
            if (this.examplesTillNextLabelUpdate > 0) {
                return this.lastMajorityClass;
            }
            if (this.isUniLabel) {
                this.lastMajorityClass = HoeffdingTree.this.labels.get(this.uniLabelIndex);
                this.examplesTillNextLabelUpdate = this.exampleCount + 1;
                return this.lastMajorityClass;
            }
            int[][] reference = this.counter[this.fewestValuesFeature];
            int bestLabel = 0;
            int bestCount = 0;
            int secondCount = 0;
            for (int label = 0; label < HoeffdingTree.this.labels.size(); label++) {
                int total = 0;
                for (int[] perValue : reference) {
                    total += perValue[label];
                }
                if (total > bestCount) {
                    secondCount = bestCount;
                    bestCount = total;
                    bestLabel = label;
                } else if (total > secondCount) {
                    secondCount = total;
                }
            }
            this.lastMajorityClass = HoeffdingTree.this.labels.get(bestLabel);
            // The lead over the runner-up bounds how many examples can pass before the majority may flip.
            this.examplesTillNextLabelUpdate = (bestCount - secondCount) + 1;
            return this.lastMajorityClass;
        }

        /**
         * Returns the most frequent label among the counted examples that had the given value
         * for the given feature (the first label wins ties and empty cells).
         *
         * @param str          feature name
         * @param serializable feature value as stored in the learner's value list
         * @return the majority label for that feature value
         */
        public Serializable getMajorityClass(String str, Serializable serializable) {
            int featureIndex = this.remainingFeatures.indexOf(str);
            int valueIndex = HoeffdingTree.this.featureValuePairs.get(str).indexOf(serializable);
            int bestLabel = 0;
            int bestCount = 0;
            for (int label = 0; label < HoeffdingTree.this.labels.size(); label++) {
                int count = this.counter[featureIndex][valueIndex][label];
                if (count > bestCount) {
                    bestCount = count;
                    bestLabel = label;
                }
            }
            return HoeffdingTree.this.labels.get(bestLabel);
        }

        /** @return the number of examples added to the counters */
        public int getExampleCount() {
            return this.exampleCount;
        }

        /** @return true while all counted examples carried the same label */
        public boolean isUniformLabel() {
            return this.isUniLabel;
        }

        /**
         * Computes, for every remaining feature, the quality without a split minus the
         * weighted quality after splitting on that feature.
         *
         * @return map from feature name to that difference
         */
        public Map<String, Double> getQualityAllAttributes() {
            Map<String, Double> gains = new HashMap<String, Double>();
            double qualityNoSplit = getQualityNoSplit();
            for (String feature : getRemainingFeatures()) {
                gains.put(feature, Double.valueOf(qualityNoSplit - getQualitySplitByFeature(feature)));
            }
            return gains;
        }

        /**
         * Evaluates the quality criterion on the label distribution of all counted examples.
         *
         * @return the criterion value for the unsplit leaf
         */
        public double getQualityNoSplit() {
            int total = getExampleCount();
            double[] distribution = new double[HoeffdingTree.this.labels.size()];
            for (int[] perValue : this.counter[this.fewestValuesFeature]) {
                for (int label = 0; label < distribution.length; label++) {
                    distribution[label] += perValue[label];
                }
            }
            for (int label = 0; label < distribution.length; label++) {
                distribution[label] /= total;
            }
            return HoeffdingTree.this.qualityCriterion.getQuality(distribution);
        }

        /**
         * Evaluates the quality criterion on each value's label distribution for the given
         * feature and returns the sum weighted by the share of examples having that value.
         *
         * @param str feature name; must be one of the remaining features
         * @return the weighted criterion value after a hypothetical split on the feature
         */
        public double getQualitySplitByFeature(String str) {
            int total = getExampleCount();
            int featureIndex = this.remainingFeatures.indexOf(str);
            double weightedQuality = 0.0d;
            for (int[] perLabel : this.counter[featureIndex]) {
                int valueTotal = 0;
                for (int count : perLabel) {
                    valueTotal += count;
                }
                if (valueTotal > 0) {
                    double[] distribution = new double[perLabel.length];
                    for (int label = 0; label < distribution.length; label++) {
                        // Promote to double first: int / int would truncate every share to 0 or 1.
                        distribution[label] = (double) perLabel[label] / valueTotal;
                    }
                    double weight = (double) valueTotal / total;
                    weightedQuality += weight * HoeffdingTree.this.qualityCriterion.getQuality(distribution);
                }
            }
            return weightedQuality;
        }

        /**
         * Sets every element of the array to zero.
         *
         * @param dArr array to reset
         * @return the same array instance
         */
        public double[] initArray(double[] dArr) {
            for (int i = 0; i < dArr.length; i++) {
                dArr[i] = 0.0d;
            }
            return dArr;
        }

        /** Lists every counter cell followed by the remaining features. */
        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("NodeData [counter=");
            for (int i = 0; i < this.counter.length; i++) {
                for (int j = 0; j < this.counter[i].length; j++) {
                    for (int k = 0; k < this.counter[i][j].length; k++) {
                        sb.append("\n[").append(i).append("][").append(j).append("][").append(k)
                                .append("] = ").append(this.counter[i][j][k]);
                    }
                }
            }
            sb.append("\n remainingFeatures=");
            sb.append(this.remainingFeatures);
            sb.append("]");
            return sb.toString();
        }
    }

    /**
     * Creates a learner without a label attribute. The learner stays uninitialized
     * until one of the {@code initialize} methods (or {@code init()}) is called.
     */
    public HoeffdingTree() {
        initialized = false;
        labelAttribute = null;
    }

    /**
     * Creates an uninitialized learner that reads class labels from the given attribute.
     *
     * @param str name of the label attribute
     */
    public HoeffdingTree(String str) {
        initialized = false;
        labelAttribute = str;
    }

    /** @return the name of the attribute that holds the class label, possibly {@code null} */
    public String getLabelAttribute() {
        return labelAttribute;
    }

    /**
     * Sets the name of the attribute that holds the class label.
     *
     * @param str name of the label attribute
     */
    public void setLabelAttribute(String str) {
        labelAttribute = str;
    }

    /**
     * Creates a learner and initializes it immediately; all arguments are forwarded
     * unchanged to {@link #initialize(Map, Map, List, double, QualityCriterion, double)}.
     *
     * @param map              feature name to feature type
     * @param map2             feature name to its known values (or numeric thresholds)
     * @param list             class labels
     * @param d                delta used for the Hoeffding bound
     * @param qualityCriterion split quality criterion
     * @param d2               minimum quality required for a split
     */
    public HoeffdingTree(Map<String, Class<?>> map, Map<String, List<Serializable>> map2, List<Serializable> list, double d, QualityCriterion qualityCriterion, double d2) {
        this.initialize(map, map2, list, d, qualityCriterion, d2);
    }

    /**
     * Creates a learner and initializes it from a sample of examples via
     * {@link #initialize(Collection)}.
     *
     * <p>NOTE(review): {@code labelAttribute} has not been assigned when this constructor
     * runs, so the initialization collects labels using a {@code null} attribute name.
     *
     * @param collection sample examples used to derive features, values and labels
     */
    public HoeffdingTree(Collection<Data> collection) {
        this.initialize(collection);
    }

    /**
     * Initializes the learner from an empty sample.
     *
     * <p>NOTE(review): an empty sample yields an empty label list, and
     * {@code initialize(Map, ...)} reads {@code list.get(0)} — confirm this path is usable.
     */
    @Override // org.streaminer.stream.learner.Learner
    public void init() {
        Collection<Data> noExamples = new HashSet<Data>();
        initialize(noExamples);
    }

    /**
     * Sets up the learner: stores the configuration, creates the tree model with the first
     * label as its default, registers statistics for the root leaf over all features and
     * precomputes the numerator of the Hoeffding bound.
     *
     * @param map              feature name to feature type
     * @param map2             feature name to its known values (or numeric thresholds)
     * @param list             class labels; the first one is handed to the model constructor
     * @param d                delta; enters the bound as ln(1/d)
     * @param qualityCriterion split quality criterion
     * @param d2               minimum quality required for a split
     */
    public void initialize(Map<String, Class<?>> map, Map<String, List<Serializable>> map2, List<Serializable> list, double d, QualityCriterion qualityCriterion, double d2) {
        this.qualityCriterion = qualityCriterion;
        this.minQualityForSplit = d2;
        this.labels = list;
        this.featureValuePairs = map2;

        // labels and featureValuePairs must be set before NodeData is constructed below.
        Serializable defaultLabel = list.get(0);
        this.tree = new HoeffdingTreeModel(map, defaultLabel);
        this.NodeDataPairs = new HashMap<HoeffdingTreeNode, NodeData>();
        HoeffdingTreeNode rootLeaf = this.tree.getLeaf(null);
        List<String> allFeatures = new ArrayList<String>(map2.keySet());
        this.NodeDataPairs.put(rootLeaf, new NodeData(this.labelAttribute, allFeatures));

        double highestGain = qualityCriterion.getHighestGain(list.size());
        double squaredGain = highestGain * highestGain;
        this.numerator = squaredGain * Math.log(1.0d / d);
        this.initialized = true;

        String summary = "HoeffdingTree learner initialized with\n quality criterion: " + this.qualityCriterion
                + " gain,\n prepruning min quality value: " + this.minQualityForSplit
                + ",\n delta: " + d
                + "\n valid class labels: " + this.labels
                + ",\n feature types: " + map
                + ",\n valid (nominal) / threshold (numeric) feature values: " + this.featureValuePairs
                + ",\n highest value of quality criterion (for hoeffding bound computation): " + highestGain;
        logger.debug(summary);
    }

    /**
     * Initializes the learner from a sample of examples using the default delta and the
     * default quality criterion.
     *
     * <p>Feature types, feature values and class labels are derived from the sample. Numeric
     * features are discretized at {@link #DEFAULT_QUANTILES}; without quantiles they would end
     * up with no thresholds at all. Labels are reduced to their distinct values (first
     * occurrence order is kept), because the number of labels sizes the per-leaf counters and
     * is passed to {@code getHighestGain}.
     *
     * @param collection sample examples
     */
    public void initialize(Collection<Data> collection) {
        Map<String, Class<?>> types = LearnerUtils.getTypes(collection);
        // Pass a copy: getNumericThresholds may sort the array it receives.
        Map<String, List<Serializable>> featureValues =
                constructFeatureValuePairs(collection, types, DEFAULT_QUANTILES.clone());

        List<Serializable> distinctLabels = new ArrayList<Serializable>();
        Set<Serializable> seen = new HashSet<Serializable>();
        for (Serializable label : getAllValues(collection, this.labelAttribute)) {
            if (seen.add(label)) {
                distinctLabels.add(label);
            }
        }

        initialize(types, featureValues, distinctLabels, DEFAULT_DELTA, DEFAULT_QUALITY_CRITERION,
                DEFAULT_QUALITY_CRITERION.getLowestGain());
    }

    /**
     * Collects the non-null values of one attribute over all examples, in iteration order
     * and including duplicates. Nothing is collected for an attribute that
     * {@code LearnerUtils.isHidden} reports as hidden.
     *
     * @param collection examples to read
     * @param str        attribute name
     * @return the collected values (a new list)
     */
    public List<Serializable> getAllValues(Collection<Data> collection, String str) {
        List<Serializable> values = new ArrayList<Serializable>();
        for (Data example : collection) {
            if (LearnerUtils.isHidden(str)) {
                continue;
            }
            Serializable value = example.get(str);
            if (value != null) {
                values.add(value);
            }
        }
        return values;
    }

    /**
     * Builds the value domain of every feature found in the examples. Features typed as
     * {@code Double} get the thresholds computed by {@link #getNumericThresholds}; all other
     * features get their distinct observed values.
     *
     * @param collection examples to read
     * @param map        feature name to feature type
     * @param dArr       quantile levels forwarded to {@link #getNumericThresholds}
     * @return map from feature name to its value domain
     */
    public Map<String, List<Serializable>> constructFeatureValuePairs(Collection<Data> collection, Map<String, Class<?>> map, Double... dArr) {
        Map<String, List<Serializable>> domains = new HashMap<String, List<Serializable>>();
        for (String feature : getAllFeatures(collection)) {
            List<Serializable> observed = getAllValues(collection, feature);
            List<Serializable> domain;
            if (map.get(feature) == Double.class) {
                domain = getNumericThresholds(observed, dArr);
            } else {
                domain = new ArrayList<Serializable>(new HashSet<Serializable>(observed));
            }
            domains.put(feature, domain);
        }
        return domains;
    }

    /**
     * Collects the names of all attributes that occur in the examples, leaving out the
     * label attribute and anything {@code LearnerUtils.isHiddenOrSpecial} rejects.
     *
     * @param collection examples to read
     * @return the set of feature names
     */
    public Set<String> getAllFeatures(Collection<Data> collection) {
        Set<String> featureNames = new HashSet<String>();
        for (Data example : collection) {
            for (String attribute : LearnerUtils.getAttributes(example)) {
                if (attribute.equals(this.labelAttribute) || LearnerUtils.isHiddenOrSpecial(attribute)) {
                    continue;
                }
                featureNames.add(attribute);
            }
        }
        return featureNames;
    }

    /**
     * Picks threshold values from a collection of mutually comparable values at the given
     * quantile levels.
     *
     * <p>The values are sorted in natural order; for each quantile {@code q} (processed in
     * ascending order) the element at position {@code round(q * size) - 1}, clamped to the
     * valid range, is taken. A threshold equal to the one added just before it is skipped.
     * The caller's quantile array is left untouched, and an empty value collection yields an
     * empty result.
     *
     * @param collection observed values; elements must be comparable with each other
     * @param dArr       quantile levels
     * @return thresholds in ascending order
     * @throws ClassCastException if the values are not mutually comparable
     */
    public static List<Serializable> getNumericThresholds(Collection<Serializable> collection, Double... dArr) {
        List<Serializable> thresholds = new ArrayList<Serializable>();
        int size = collection.size();
        if (size == 0) {
            // No values means there is no position to pick a threshold from.
            return thresholds;
        }
        // Sort a copy so that an array passed by the caller keeps its order.
        Double[] quantiles = dArr.clone();
        Arrays.sort(quantiles);
        Object[] sorted = collection.toArray();
        Arrays.sort(sorted);
        for (Double quantile : quantiles) {
            int index = Math.round((float) (quantile.doubleValue() * size)) - 1;
            index = Math.max(0, Math.min(index, size - 1));
            Serializable threshold = (Serializable) sorted[index];
            if (thresholds.isEmpty() || !thresholds.get(thresholds.size() - 1).equals(threshold)) {
                thresholds.add(threshold);
            }
        }
        return thresholds;
    }

    /**
     * Maps a value onto one of the given thresholds: the smallest threshold that is not
     * below the value, or, when the value lies above every threshold, the largest threshold.
     *
     * @param collection candidate thresholds, in any order
     * @param comparable value to map
     * @return the selected threshold, or {@code null} if there are no thresholds
     */
    public static Comparable numericToNominal(Collection<Comparable> collection, Comparable comparable) {
        Comparable ceiling = null;
        Iterator<Comparable> candidates = collection.iterator();
        while (candidates.hasNext()) {
            Comparable candidate = candidates.next();
            boolean coversValue = comparable.compareTo(candidate) <= 0;
            if (coversValue && (ceiling == null || candidate.compareTo(ceiling) < 0)) {
                ceiling = candidate;
            }
        }
        if (ceiling != null) {
            return ceiling;
        }

        // The value exceeds every threshold: fall back to the largest one.
        Comparable largest = null;
        for (Comparable candidate : collection) {
            if (largest == null || largest.compareTo(candidate) < 0) {
                largest = candidate;
            }
        }
        return largest;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /** @return the tree model grown so far; {@code null} before the learner is initialized */
    @Override // org.streaminer.stream.learner.Learner
    public HoeffdingTreeModel getModel() {
        return tree;
    }

    /**
     * Feeds one example to the tree.
     *
     * <p>The example is routed to its leaf, added to that leaf's statistics, and the leaf's
     * label is updated to the majority class. If the leaf has seen more than one label, the
     * two best split features are compared: the leaf is split on the best feature once its
     * quality exceeds both the runner-up and {@code minQualityForSplit} by the Hoeffding
     * bound {@code sqrt(numerator / (2 * n))}, where {@code n} is the leaf's example count.
     * Each value of the chosen feature becomes a child leaf with fresh statistics.
     *
     * @param data the example to learn from; ignored (with a warning) before initialization
     */
    @Override // org.streaminer.stream.learner.Learner
    public void learn(Data data) {
        if (!this.initialized) {
            logger.warn("Learner has not been initialized!");
            return;
        }
        HoeffdingTreeNode leaf = this.tree.getLeaf(data);
        NodeData nodeData = this.NodeDataPairs.get(leaf);
        if (nodeData.getRemainingFeatures().isEmpty()) {
            return;
        }
        nodeData.addExample(data);
        leaf.setLabel(nodeData.getMajorityClass());
        if (nodeData.isUniformLabel()) {
            // A single-label leaf has nothing to gain from a split.
            return;
        }

        Map<String, Double> qualities = nodeData.getQualityAllAttributes();
        String bestFeature = LearnerUtils.getMaximumKey(qualities);
        double bestQuality = qualities.get(bestFeature).doubleValue();
        qualities.remove(bestFeature);
        String secondFeature = LearnerUtils.getMaximumKey(qualities);
        double secondQuality = secondFeature != null ? qualities.get(secondFeature).doubleValue() : 0.0d;

        // 2.0d keeps the product in double arithmetic, so very large example counts cannot overflow int.
        double epsilon = Math.sqrt(this.numerator / (2.0d * nodeData.getExampleCount()));
        if (bestQuality - secondQuality <= epsilon || bestQuality - this.minQualityForSplit < epsilon) {
            return;
        }

        leaf.setFeature(bestFeature);
        leaf.setLabel(null);
        this.NodeDataPairs.remove(leaf);
        List<String> childFeatures = new ArrayList<String>(nodeData.getRemainingFeatures());
        childFeatures.remove(bestFeature);
        for (Serializable value : this.featureValuePairs.get(bestFeature)) {
            HoeffdingTreeNode child = new HoeffdingTreeNode(nodeData.getMajorityClass(bestFeature, value));
            leaf.addChild(child, value);
            this.NodeDataPairs.put(child, new NodeData(this.labelAttribute, childFeatures));
        }

        // Rendering the whole tree is costly; only do it when the message will be emitted.
        if (logger.isDebugEnabled()) {
            logger.debug("split node on feature " + bestFeature + "\n examples processed at this node: " + nodeData.getExampleCount() + "\n Epsilon=" + epsilon + ", G(" + bestFeature + ")=" + bestQuality + ", 2nd best: G(" + secondFeature + ")=" + secondQuality + "\n creating " + this.featureValuePairs.get(bestFeature).size() + " children for values: " + this.featureValuePairs.get(bestFeature) + "\n remaining features for each child: " + childFeatures + "\n new tree:\n" + this.tree.toString());
        }
    }
}
