package org.streaminer.stream.classifier.tree;

import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.learner.LearnerUtils;
import org.streaminer.stream.learner.Regressor;
import org.streaminer.stream.model.Model;

/**
 * Incremental regression tree over {@link Data} items.
 *
 * <p>For every leaf the tree keeps running {@link RegressionTreeStatistics} of the numeric value of
 * the attribute named by {@link #getName()}, plus a binary tree ({@link BTreeNode}) of the observed
 * values. A candidate split value is chosen from that tree by maximizing the standard-deviation
 * reduction (SDR) computed by {@link BestSplitValueFinder}. Nominal values are rejected.
 */
public class RTree extends ModelTree<RegressionTreeStatistics, Double> implements Regressor<Data> {
    private static final long serialVersionUID = 4926545397273482368L;
    static final Logger log = LoggerFactory.getLogger(RTree.class);

    /** Criterion deciding when a leaf should be split; the same instance handed to the superclass. */
    SplitCriterion<RegressionTreeStatistics> splitCriterion;

    /** Per-leaf trees of observed values, keyed by leaf name. */
    Map<String, BTreeNode> btrees = new HashMap<String, BTreeNode>();

    /** Reusable visitor used by {@link #getBestSplitValue(BTreeNode)}; reset before each use. */
    BestSplitValueFinder splitValueFinder = new BestSplitValueFinder();

    /**
     * Visitor that tracks the node value with the largest standard-deviation reduction seen so far.
     *
     * <p>Serializable because it is held in a field of this class, which declares a
     * {@code serialVersionUID}. NOTE(review): assumes RTree instances are actually serialized.
     */
    public class BestSplitValueFinder
            implements Visitor<BinaryTreeNode<RegressionTreeStatistics, Double>>, java.io.Serializable {
        private static final long serialVersionUID = 1L;

        Double maxValue = null;
        Double maxSdr = Double.valueOf(Double.NEGATIVE_INFINITY);

        BestSplitValueFinder() {
        }

        /** Forgets the current best candidate so the finder can be reused for another traversal. */
        public void reset() {
            this.maxValue = null;
            this.maxSdr = Double.valueOf(Double.NEGATIVE_INFINITY);
        }

        /** @return the value of the best node found so far, or {@code null} if none was visited */
        public Double getValue() {
            return this.maxValue;
        }

        /**
         * @return the largest SDR found so far; negative infinity if no visited node produced a
         *     value comparable with {@code >} (e.g. nothing visited, or only NaN reductions)
         */
        public Double getMaximum() {
            return this.maxSdr;
        }

        /**
         * Considers one node as a split candidate.
         *
         * <p>Every node's SDR is compared, including the first one; the first visited value is
         * additionally kept as a fallback so a value is still reported when no reduction is
         * comparable (a NaN reduction never beats the current maximum).
         */
        @Override // org.streaminer.stream.classifier.tree.Visitor
        public void visit(BinaryTreeNode<RegressionTreeStatistics, Double> binaryTreeNode) {
            double sdr = getStandardDeviationReduction(binaryTreeNode).doubleValue();
            if (this.maxValue == null) {
                this.maxValue = binaryTreeNode.getValue();
                RTree.log.info("Found initial split value: {}  (sdr: {})", this.maxValue, this.maxSdr);
            }
            if (sdr > this.maxSdr.doubleValue()) {
                this.maxValue = binaryTreeNode.getValue();
                this.maxSdr = Double.valueOf(sdr);
                RTree.log.info("Found new best split value: {}  (sdr: {})", this.maxValue, this.maxSdr);
            }
        }

        /**
         * Computes {@code sd(node) - sum(n(child) / n(node) * sd(child))} over the node's existing
         * left and right children. Missing children contribute nothing.
         *
         * @param binaryTreeNode node whose statistics and children are evaluated
         * @return the standard-deviation reduction (NaN/infinite if the node reports zero examples
         *     and has children)
         */
        public Double getStandardDeviationReduction(BinaryTreeNode<RegressionTreeStatistics, Double> binaryTreeNode) {
            double total = binaryTreeNode.getNodeInfo().getNumberOfExamples();
            double reduction = binaryTreeNode.getNodeInfo().getStandardDeviation();
            if (binaryTreeNode.getLeft() != null) {
                double leftCount = binaryTreeNode.getLeft().getNodeInfo().getNumberOfExamples();
                double leftDeviation = binaryTreeNode.getLeft().getNodeInfo().getStandardDeviation();
                reduction -= (leftCount / total) * leftDeviation;
            }
            if (binaryTreeNode.getRight() != null) {
                double rightCount = binaryTreeNode.getRight().getNodeInfo().getNumberOfExamples();
                double rightDeviation = binaryTreeNode.getRight().getNodeInfo().getStandardDeviation();
                reduction -= (rightCount / total) * rightDeviation;
            }
            return Double.valueOf(reduction);
        }
    }

    /**
     * Creates a regression tree node that splits with a {@code ChernoffSplitCriterion}.
     *
     * @param str name of this node (also the attribute whose numeric value is learned)
     * @param treeNode parent node
     */
    public RTree(String str, TreeNode<RegressionTreeStatistics> treeNode) {
        this(str, treeNode, new ChernoffSplitCriterion());
    }

    /**
     * Shares one criterion instance between the superclass and {@link #splitCriterion}; previously
     * the field was never assigned, so {@link #learn(Data)} always failed with a
     * NullPointerException.
     */
    private RTree(String str, TreeNode<RegressionTreeStatistics> treeNode,
            SplitCriterion<RegressionTreeStatistics> criterion) {
        super(str, treeNode, criterion);
        this.splitCriterion = criterion;
    }

    /**
     * Updates the tree with one example.
     *
     * <p>The example is routed to its leaf, the leaf's model learns it, and the numeric value of the
     * attribute named by {@link #getName()} is added to the leaf's statistics and its value tree.
     * If the split criterion requests a split, split nodes are constructed.
     *
     * @param data the example to learn from
     * @throws IllegalArgumentException if the attribute named by {@link #getName()} is not numerical
     */
    @Override // org.streaminer.stream.learner.Learner
    @SuppressWarnings({"rawtypes", "unchecked"}) // ModelTree's generic constructor signature is not visible here
    public void learn(Data data) {
        // Validate before mutating anything so a rejected example leaves the tree untouched.
        if (!LearnerUtils.isNumerical(getName(), data)) {
            throw new IllegalArgumentException("Nominal values are not supported!");
        }
        ModelTree<RegressionTreeStatistics, Double> leaf = getLeaf(data);
        leaf.getModel().learn(data);

        Double value = LearnerUtils.getDouble(getName(), data);
        ((RegressionTreeStatistics) leaf.getNodeInfo()).update(value);

        String leafName = leaf.getName();
        BTreeNode values = this.btrees.get(leafName);
        if (values == null) {
            values = new BTreeNode(leafName, value);
            this.btrees.put(leafName, values);
        } else {
            values.insert(value);
        }

        boolean requiresSplit = this.splitCriterion.requiresSplit(leaf.getNodeInfo());
        // NOTE(review): the best split value is computed (and logged) but its result is not used yet.
        getBestSplitValue(values);
        if (requiresSplit) {
            // NOTE(review): the nodes built here are not stored by this method; whether the
            // ModelTree constructor links them into the parent cannot be seen from this class.
            new ModelTree(leafName, (ModelTree) leaf.getParent(), this.splitCriterion)
                    .add(new ModelTree(leafName, null, this.splitCriterion));
        }
    }

    /**
     * Traverses {@code bTreeNode} in order with the shared finder and returns the value with the
     * largest standard-deviation reduction.
     *
     * @param bTreeNode value tree of one leaf; must not be {@code null}
     * @return the best split value, or {@code null} if the traversal visited no node
     */
    public Double getBestSplitValue(BTreeNode bTreeNode) {
        this.splitValueFinder.reset();
        bTreeNode.inOrder(this.splitValueFinder);
        log.info("Best split value is: {}  (SDR: {})", this.splitValueFinder.getValue(), this.splitValueFinder.getMaximum());
        return this.splitValueFinder.getValue();
    }

    /**
     * Predicts by delegating to the leaf the example is routed to.
     *
     * @param data the example to predict
     * @return the leaf's prediction
     */
    @Override // org.streaminer.stream.classifier.tree.ModelTree, org.streaminer.stream.model.PredictionModel
    public Double predict(Data data) {
        return getLeaf(data).predict(data);
    }

    /** Nothing to initialize: per-leaf value trees are created on demand in {@link #learn(Data)}. */
    @Override // org.streaminer.stream.learner.Learner
    public void init() {
    }
}
