package org.streaminer.stream.classifier.tree;

import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.learner.AbstractRegressor;
import org.streaminer.stream.learner.LearnerUtils;
import org.streaminer.stream.learner.Regressor;

/**
 * Incremental regression tree learner.
 *
 * <p>Every call to {@link #learn(Data)} does the following:
 * <ul>
 *   <li>routes the item to a leaf;</li>
 *   <li>updates that leaf's regression model and per-attribute {@link BTreeNode} statistics;</li>
 *   <li>recomputes a Chernoff-style bound {@link #epsilon};</li>
 *   <li>replaces the leaf by an {@link InnerNode} when the best candidate split's standard
 *       deviation reduction (SDR) exceeds the runner-up candidate's SDR by at least
 *       {@code epsilon}.</li>
 * </ul>
 */
public class RegressionTree extends AbstractRegressor<Data> {
    private static final long serialVersionUID = -2683782830606679008L;

    static final Logger log = LoggerFactory.getLogger(RegressionTree.class);

    /** Tree structure grown by this learner. */
    final RegressionTreeModel tree;

    /** Best split candidate found by the most recent split evaluation; {@code null} if none. */
    BTreeNode possibleSplitpoint = null;

    /** Whether {@link #possibleSplitpoint} beats the runner-up candidate by at least {@link #epsilon}. */
    boolean splitPossible;

    /** Current Chernoff bound, recomputed for every learned item. */
    double epsilon;

    /** Confidence parameter of the Chernoff bound; strictly positive and finite. */
    final double delta;

    /**
     * SDR of the runner-up candidate in the current split evaluation.
     * NaN while fewer than two candidates have been seen; NaN makes the split test evaluate to
     * false. This is scratch state, reset at the start of every {@link #checkSplitpoints} call.
     */
    private transient double runnerUpSdr = Double.NaN;

    /**
     * Creates a regression tree.
     *
     * @param d         confidence parameter delta of the Chernoff bound; must be positive and finite
     * @param regressor regressor handed to the underlying {@link RegressionTreeModel}
     * @throws IllegalArgumentException if {@code d} is null, NaN, non-positive or infinite
     * @throws Exception                propagated from the {@link RegressionTreeModel} constructor
     */
    public RegressionTree(Double d, Regressor<Data> regressor) throws Exception {
        // Validate first: ln(2 / delta) is undefined or meaningless for delta <= 0 or NaN.
        if (d == null || !(d.doubleValue() > 0.0d) || Double.isInfinite(d.doubleValue())) {
            throw new IllegalArgumentException("delta must be a positive finite number, got " + d);
        }
        this.tree = new RegressionTreeModel(regressor);
        this.delta = d.doubleValue();
    }

    /** Returns the tree model grown by this learner. */
    @Override
    public RegressionTreeModel getModel() {
        return this.tree;
    }

    /**
     * Learns from one item and splits the receiving leaf when the split test passes.
     *
     * @param data the item to learn from
     */
    @Override
    public void learn(Data data) {
        LeafNode leaf = this.tree.getLeaf(data);
        leaf.getRegressionModel().learn(data);
        leaf.updateBTrees(data);
        computeChernoffBound(leaf, data);
        if (checkSplitpoints(leaf, data) != null) {
            try {
                doSplit(leaf);
            } catch (Exception e) {
                // A failed split must not abort stream processing, but it must not vanish silently either.
                log.error("Failed to split leaf", e);
            }
        }
    }

    /**
     * Evaluates every attribute of {@code data} as a split candidate at {@code leafNode}.
     *
     * <p>After this call, {@link #possibleSplitpoint} holds the candidate with the highest SDR
     * (or {@code null}). {@link #splitPossible} tells whether that candidate beats the runner-up
     * by at least {@link #epsilon}.
     *
     * @param leafNode leaf whose per-attribute statistics are inspected
     * @param data     item whose attributes determine which candidates are inspected
     * @return the winning candidate if a split should be performed, otherwise {@code null}
     */
    protected BTreeNode checkSplitpoints(LeafNode leafNode, Data data) {
        // Start from a clean slate so candidates from an earlier item or leaf never leak in.
        this.possibleSplitpoint = null;
        this.splitPossible = false;
        this.runnerUpSdr = Double.NaN;

        Iterator<String> attributes = LearnerUtils.getAttributes(data).iterator();
        while (attributes.hasNext()) {
            BTreeNode candidate = leafNode.getBTrees().get(attributes.next());
            if (candidate == null) {
                continue; // the leaf holds no BTree for this attribute
            }
            computeSDRs(candidate, candidate.getNodeInfo().getStandardDeviation().doubleValue());
        }
        return this.splitPossible ? this.possibleSplitpoint : null;
    }

    /**
     * Feeds one candidate into the running best-versus-runner-up comparison.
     *
     * <p>Candidates whose standard deviation {@code d} is not strictly positive (including NaN)
     * are ignored. After the update, {@link #splitPossible} is true exactly when
     * {@code bestSdr - runnerUpSdr - epsilon >= 0}. This requires at least two candidates.
     *
     * @param bTreeNode candidate split statistics
     * @param d         standard deviation reported by the candidate's node info
     */
    protected void computeSDRs(BTreeNode bTreeNode, double d) {
        if (!(d > 0.0d)) {
            return;
        }
        double sdr = bTreeNode.getStandardDeviationReduction().doubleValue();
        if (this.possibleSplitpoint == null) {
            this.possibleSplitpoint = bTreeNode;
            return;
        }
        double bestSdr = this.possibleSplitpoint.getStandardDeviationReduction().doubleValue();
        if (sdr > bestSdr) {
            // The previous leader becomes the runner-up; it is >= any earlier runner-up.
            this.runnerUpSdr = bestSdr;
            this.possibleSplitpoint = bTreeNode;
            bestSdr = sdr;
        } else if (Double.isNaN(this.runnerUpSdr) || sdr > this.runnerUpSdr) {
            this.runnerUpSdr = sdr;
        }
        // With a NaN runner-up (single candidate) or NaN epsilon, this comparison is false.
        this.splitPossible = (bestSdr - this.runnerUpSdr) - this.epsilon >= 0.0d;
    }

    /**
     * Replaces {@code leafNode} by an {@link InnerNode} built from {@link #possibleSplitpoint}.
     *
     * <p>The new node becomes the tree root if the leaf has no parent. Otherwise it takes the
     * leaf's position as the right or left child of the parent.
     *
     * @param leafNode leaf to replace
     * @throws IllegalStateException if no split candidate is available
     * @throws InstantiationException    propagated from the {@link InnerNode} constructor
     * @throws IllegalAccessException    propagated from the {@link InnerNode} constructor
     * @throws ClassNotFoundException    propagated from the {@link InnerNode} constructor
     */
    protected void doSplit(LeafNode leafNode) throws InstantiationException, IllegalAccessException, ClassNotFoundException {
        BTreeNode split = this.possibleSplitpoint;
        if (split == null) {
            throw new IllegalStateException("doSplit called without a split candidate");
        }
        log.debug("perform split on attribute {} at {}", split.getName(), split.getValue());
        InnerNode replacement = new InnerNode(split.getName(), split.getValue(), leafNode.getRegressionModel(), leafNode.getN());
        InnerNode parent = (InnerNode) leafNode.getParent();
        if (parent == null) {
            this.tree.setRoot(replacement);
        } else if (leafNode.isRightChild()) {
            parent.setRightChild(replacement);
        } else {
            parent.setLeftChild(replacement);
        }
    }

    /**
     * Recomputes {@link #epsilon} as {@code sqrt(|3 * yHat / n * ln(2 / delta)|)}.
     *
     * <p>Here {@code yHat} is the leaf model's prediction for {@code data}, and {@code n} is
     * {@code leafNode.getN()}. When {@code n == 0}, the result is infinite or NaN, and the split
     * test then never passes.
     *
     * @param leafNode leaf providing the model and the count {@code n}
     * @param data     item used to obtain the prediction
     */
    protected void computeChernoffBound(LeafNode leafNode, Data data) {
        log.debug("call predict: {}", data);
        double prediction = leafNode.getRegressionModel().predict(data).doubleValue();
        double n = leafNode.getN();
        this.epsilon = Math.sqrt(Math.abs(((3.0d * prediction) / n) * Math.log(2.0d / this.delta)));
    }

    /**
     * Predicts the target value for {@code data} using the tree model.
     *
     * @param data item to predict for
     * @return the tree model's prediction
     */
    @Override
    public Double predict(Data data) {
        return this.tree.predict(data);
    }
}
