package mulan.evaluation.loss;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import mulan.data.LabelNode;
import mulan.data.LabelsMetaData;
import mulan.data.MultiLabelInstances;

/* loaded from: input_file:mulan/evaluation/loss/HierarchicalLoss.class */
/**
 * Hierarchical loss for multi-label bipartitions whose labels are organised in a hierarchy.
 *
 * <p>Labels are visited top-down, starting from the root labels of the hierarchy. If a label's
 * value differs between the two boolean arrays, the loss is increased by 1 and that label's
 * descendants are NOT examined. If the values agree, the comparison continues into the label's
 * children. An error high in the hierarchy is therefore counted once, not once per descendant.
 *
 * <p>{@link #computeLoss(boolean[], boolean[])} does not modify any instance field; the result is
 * accumulated through return values of the recursive helper.
 */
public class HierarchicalLoss extends BipartitionLossFunctionBase {
    /** Label hierarchy taken from the dataset passed to the constructor. */
    private final LabelsMetaData metaData;
    /** Maps a label attribute name to its position within the boolean arrays given to computeLoss. */
    private final Map<String, Integer> labelPosition = new HashMap<String, Integer>();

    /**
     * Builds the loss for the label hierarchy and label layout of the given dataset.
     *
     * @param multiLabelInstances dataset supplying the label hierarchy metadata and the label
     *     attributes; the i-th entry of {@code getLabelIndices()} is mapped to array position i
     */
    public HierarchicalLoss(MultiLabelInstances multiLabelInstances) {
        this.metaData = multiLabelInstances.getLabelsMetaData();
        // NOTE(review): assumes the boolean arrays passed to computeLoss are ordered like
        // getLabelIndices(), and that hierarchy node names equal label attribute names — confirm.
        int position = 0;
        for (int attributeIndex : multiLabelInstances.getLabelIndices()) {
            String labelName = multiLabelInstances.getDataSet().attribute(attributeIndex).name();
            this.labelPosition.put(labelName, Integer.valueOf(position));
            position++;
        }
    }

    @Override
    public String getName() {
        return "Hierarchical Loss";
    }

    /**
     * Computes the hierarchical loss between two label assignments.
     *
     * <p>The per-label comparison is symmetric, so the order of the two arrays does not affect the
     * result.
     *
     * @param bipartition presumably the predicted bipartition, indexed by label position
     * @param groundTruth presumably the true label assignment, indexed by label position
     * @return number of hierarchy nodes that mismatch while all their ancestors match
     * @throws IllegalStateException if a label in the hierarchy is not a label attribute of the
     *     dataset given to the constructor
     */
    @Override
    public double computeLoss(boolean[] bipartition, boolean[] groundTruth) {
        return calculateHLoss(bipartition, groundTruth, this.metaData.getRootLabels());
    }

    /**
     * Returns the loss contributed by the given sibling nodes and, for nodes that match, by their
     * subtrees.
     */
    private double calculateHLoss(boolean[] bipartition, boolean[] groundTruth, Set<LabelNode> nodes) {
        double loss = 0.0d;
        for (LabelNode node : nodes) {
            int position = positionOf(node);
            if (bipartition[position] == groundTruth[position]) {
                // Agreement: descend and let deeper nodes contribute.
                loss += calculateHLoss(bipartition, groundTruth, node.getChildren());
            } else {
                // Disagreement: count it once and prune the subtree below this node.
                loss += 1.0d;
            }
        }
        return loss;
    }

    /** Resolves a hierarchy node to its array position, failing with a descriptive message. */
    private int positionOf(LabelNode node) {
        Integer position = this.labelPosition.get(node.getName());
        if (position == null) {
            throw new IllegalStateException("Label '" + node.getName()
                    + "' from the label hierarchy is not a label attribute of the dataset");
        }
        return position.intValue();
    }
}
