package moa.classifiers.multilabel;

import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.MultiLabelInstance;
import com.yahoo.labs.samoa.instances.Prediction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import moa.classifiers.Classifier;
import moa.classifiers.MultiLabelLearner;
import moa.classifiers.MultiTargetRegressor;
import moa.classifiers.core.attributeclassobservers.AttributeClassObserver;
import moa.classifiers.trees.HoeffdingTree;
import moa.classifiers.trees.HoeffdingTreeClassifLeaves;
import moa.core.Example;
import moa.core.StringUtils;

/* loaded from: input_file:moa/classifiers/multilabel/MultilabelHoeffdingTree.class */
/**
 * Hoeffding tree variant for multi-label / multi-target data whose active leaves each wrap a
 * classifier.
 *
 * <p>Leaf statistics (the observed class distribution and the per-attribute observers) are
 * accumulated once per <em>relevant</em> label of an instance, as returned by
 * {@link #getRelevantLabels(Instance)}, rather than once for a single class value. Predictions
 * are delegated to the classifier stored in the leaf that the instance is sorted into.
 */
public class MultilabelHoeffdingTree extends HoeffdingTreeClassifLeaves implements MultiLabelLearner, MultiTargetRegressor {
    private static final long serialVersionUID = 1;

    /**
     * Leaf installed by {@link MultilabelHoeffdingTree#deactivateLearningNode}. It holds no
     * attribute observers and no leaf classifier; it only keeps accumulating label weights.
     */
    public static class MultilabelInactiveLearningNode extends HoeffdingTree.InactiveLearningNode {
        private static final long serialVersionUID = 1;

        /**
         * @param initialClassObservations initial per-label weight distribution for this leaf
         */
        public MultilabelInactiveLearningNode(double[] initialClassObservations) {
            super(initialClassObservations);
        }

        /**
         * Adds the instance weight to the observed distribution once for every relevant label.
         *
         * @param instance training instance
         * @param hoeffdingTree owning tree (unused here)
         */
        @Override
        public void learnFromInstance(Instance instance, HoeffdingTree hoeffdingTree) {
            double weight = instance.weight();
            for (int label : MultilabelHoeffdingTree.getRelevantLabels(instance)) {
                this.observedClassDistribution.addToValue(label, weight);
            }
        }
    }

    /**
     * Active leaf that owns a classifier, used for prediction, and keeps per-label statistics.
     */
    public class MultilabelLearningNodeClassifier extends HoeffdingTreeClassifLeaves.LearningNodeClassifier {
        private static final long serialVersionUID = 1;

        /**
         * Creates a leaf with either a copy of the given classifier or a fresh, reset one.
         *
         * @param initialClassObservations initial per-label weight distribution for this leaf
         * @param classifier model to copy into this leaf, or {@code null} to build a new model
         *     from the tree's learner option
         * @param tree owning tree, which supplies the learner option and the model context
         */
        public MultilabelLearningNodeClassifier(double[] initialClassObservations, Classifier classifier, MultilabelHoeffdingTree tree) {
            super(initialClassObservations);
            if (classifier != null) {
                // Copy so that this leaf never shares mutable model state with the caller.
                this.classifier = classifier.copy();
                return;
            }
            this.classifier = ((Classifier) MultilabelHoeffdingTree.this.getPreparedClassOption(tree.learnerOption)).copy();
            this.classifier.resetLearning();
            this.classifier.setModelContext(tree.getModelContext());
        }

        /**
         * Returns the votes of the leaf classifier for {@code instance}.
         */
        @Override
        public double[] getClassVotes(Instance instance, HoeffdingTree hoeffdingTree) {
            return this.classifier.getVotesForInstance(instance);
        }

        /**
         * Returns the prediction of the leaf classifier for {@code instance}.
         */
        public Prediction getPredictionForInstance(Instance instance, HoeffdingTree hoeffdingTree) {
            return this.classifier.getPredictionForInstance(instance);
        }

        /**
         * Intentionally a no-op: this leaf type never disables attributes.
         */
        @Override
        public void disableAttribute(int attIndex) {
        }

        /**
         * @return the classifier owned by this leaf
         */
        @Override
        public Classifier getClassifier() {
            return this.classifier;
        }

        /**
         * Trains the leaf classifier on the instance, then updates the observed label
         * distribution and the per-attribute observers once for every relevant label.
         *
         * @param instance training instance
         * @param hoeffdingTree owning tree; must be a {@link MultilabelHoeffdingTree} because it
         *     is used to create the attribute observers
         */
        @Override
        public void learnFromInstance(Instance instance, HoeffdingTree hoeffdingTree) {
            this.classifier.trainOnInstance(instance);
            MultilabelHoeffdingTree tree = (MultilabelHoeffdingTree) hoeffdingTree;
            List<Integer> relevantLabels = MultilabelHoeffdingTree.getRelevantLabels(instance);
            double weight = instance.weight();
            for (int label : relevantLabels) {
                this.observedClassDistribution.addToValue(label, weight);
            }
            for (int i = 0; i < instance.numInputAttributes(); i++) {
                AttributeClassObserver observer = this.attributeObservers.get(i);
                if (observer == null) {
                    // Observers are created lazily, with the kind chosen by the attribute type.
                    observer = instance.inputAttribute(i).isNominal() ? tree.newNominalClassObserver() : tree.newNumericClassObserver();
                    this.attributeObservers.set(i, observer);
                }
                double value = instance.valueInputAttribute(i);
                for (int label : relevantLabels) {
                    observer.observeAttributeClass(value, label, weight);
                }
            }
        }

        /**
         * Appends a one-line description of this leaf and its label weights.
         */
        @Override
        public void describeSubtree(HoeffdingTree hoeffdingTree, StringBuilder out, int indent) {
            StringUtils.appendIndented(out, indent, "Leaf ");
            out.append(" = ");
            out.append(" weights: ");
            this.observedClassDistribution.getSingleLineDescription(out, this.observedClassDistribution.numValues());
            StringUtils.appendNewline(out);
        }
    }

    /**
     * Creates a new active leaf with a freshly built classifier.
     */
    @Override
    protected HoeffdingTree.LearningNode newLearningNode(double[] initialClassObservations) {
        return new MultilabelLearningNodeClassifier(initialClassObservations, null, this);
    }

    /**
     * Creates a new active leaf holding a copy of {@code classifier}.
     */
    @Override
    protected HoeffdingTree.LearningNode newLearningNode(double[] initialClassObservations, Classifier classifier) {
        return new MultilabelLearningNodeClassifier(initialClassObservations, classifier, this);
    }

    /**
     * Replaces an active leaf with a {@link MultilabelInactiveLearningNode} that carries over
     * its observed class distribution, and updates the leaf counters.
     *
     * <p>Kept {@code public} (the decompiled form) so that existing callers keep compiling.
     *
     * @param toDeactivate leaf being replaced
     * @param parent split node holding the leaf, or {@code null} if the leaf is the root
     * @param parentBranch child index of the leaf within {@code parent}
     */
    @Override
    public void deactivateLearningNode(HoeffdingTree.ActiveLearningNode toDeactivate, HoeffdingTree.SplitNode parent, int parentBranch) {
        MultilabelInactiveLearningNode replacement = new MultilabelInactiveLearningNode(toDeactivate.getObservedClassDistribution());
        if (parent == null) {
            this.treeRoot = replacement;
        } else {
            parent.setChild(parentBranch, replacement);
        }
        this.activeLeafNodeCount--;
        this.inactiveLeafNodeCount++;
    }

    /**
     * Unwraps the example and predicts for the contained multi-label instance.
     */
    @Override
    public Prediction getPredictionForInstance(Example<Instance> example) {
        return getPredictionForInstance((MultiLabelInstance) example.getData());
    }

    /**
     * Sorts the instance to a leaf and returns that leaf classifier's prediction.
     *
     * @param multiLabelInstance instance to predict for
     * @return the leaf classifier's prediction. Returns {@code null} when the tree has no
     *     root, or when the instance does not reach a leaf that owns a classifier: either an
     *     empty child branch of a split node, or a leaf replaced by
     *     {@link #deactivateLearningNode}.
     */
    @Override
    public Prediction getPredictionForInstance(MultiLabelInstance multiLabelInstance) {
        if (this.treeRoot == null) {
            System.err.println("[WARNING] Root Node == Null !!!!!!");
            return null;
        }
        HoeffdingTree.FoundNode found = this.treeRoot.filterInstanceToLeaf(multiLabelInstance, null, -1);
        // found.node can be null (empty branch) or an inactive leaf without a classifier;
        // casting it unconditionally would throw NullPointerException / ClassCastException.
        if (!(found.node instanceof MultilabelLearningNodeClassifier)) {
            return null;
        }
        return ((MultilabelLearningNodeClassifier) found.node).getPredictionForInstance(multiLabelInstance, this);
    }

    /**
     * Trains on {@code instance} if its weight is positive; otherwise ignores it.
     */
    @Override
    public void trainOnInstance(Instance instance) {
        if (instance.weight() > 0.0d) {
            this.trainingWeightSeenByModel += instance.weight();
            trainOnInstanceImpl((MultiLabelInstance) instance);
        }
    }

    /**
     * Delegates to the single-instance Hoeffding tree training routine.
     */
    @Override
    public void trainOnInstanceImpl(MultiLabelInstance multiLabelInstance) {
        trainOnInstanceImpl((Instance) multiLabelInstance);
    }

    /**
     * Returns the indices of the output targets whose value is strictly positive.
     *
     * @param instance instance whose output targets are inspected
     * @return a new mutable list of relevant label indices in ascending order; empty if none
     */
    public static List<Integer> getRelevantLabels(Instance instance) {
        int numTargets = instance.numberOutputTargets();
        List<Integer> relevant = new ArrayList<>(numTargets);
        for (int i = 0; i < numTargets; i++) {
            if (instance.classValue(i) > 0.0d) {
                relevant.add(i);
            }
        }
        return relevant;
    }
}
