package mulan.classifier.neural;

import java.util.List;
import java.util.Map;
import mulan.classifier.MultiLabelOutput;
import mulan.classifier.neural.model.Neuron;
import mulan.core.ArgumentNullException;
import mulan.evaluation.loss.RankingLossFunction;

/**
 * Base class for MMP-style model update rules.
 * <p>
 * Holds one perceptron per label. Each call to {@link #process(DataPair, Map)} performs one
 * online training step:
 * <ol>
 * <li>every perceptron scores the example's input;</li>
 * <li>the scores are turned into a label ranking and evaluated against the example's true
 * labels with a {@link RankingLossFunction};</li>
 * <li>when the loss is non-zero, each perceptron's weights are shifted along the input vector
 * by a per-label amount chosen by the subclass in
 * {@link #computeUpdateParameters(DataPair, double[], double)}.</li>
 * </ol>
 * NOTE(review): "MMP" presumably refers to the Multiclass Multilabel Perceptron of Crammer &amp;
 * Singer. Confirm against the concrete subclasses.
 */
public abstract class MMPUpdateRuleBase implements ModelUpdateRule {

    /** Perceptrons being trained; the perceptron at index {@code i} scores label {@code i}. */
    private final List<Neuron> perceptrons;
    /** Loss used to evaluate the ranking produced by the perceptrons. */
    private final RankingLossFunction lossFunction;

    /**
     * Creates a new update rule.
     *
     * @param perceptrons the perceptrons to update, one per label; the list is stored as-is, not
     *                    copied
     * @param lossMeasure the ranking loss that decides whether an update is needed and is passed
     *                    on to {@link #computeUpdateParameters(DataPair, double[], double)}
     * @throws ArgumentNullException if {@code perceptrons} or {@code lossMeasure} is {@code null}
     */
    public MMPUpdateRuleBase(List<Neuron> perceptrons, RankingLossFunction lossMeasure) {
        if (perceptrons == null) {
            throw new ArgumentNullException("perceptrons");
        }
        if (lossMeasure == null) {
            throw new ArgumentNullException("lossMeasure");
        }
        this.perceptrons = perceptrons;
        this.lossFunction = lossMeasure;
    }

    /**
     * Runs one training step of the perceptrons on a single example.
     * <p>
     * Only the first {@code example.getInput().length} weights of each perceptron are updated.
     * Any further weights the neuron holds are left untouched. This relies on
     * {@link Neuron#getWeights()} returning the neuron's live weight array rather than a copy;
     * NOTE(review): confirm in {@code Neuron}.
     *
     * @param example the training example; the number of labels is taken from the length of
     *                {@code example.getOutput()}
     * @param params  not used by this implementation
     * @return the ranking loss of the example, computed before any weights are changed
     * @throws IndexOutOfBoundsException if there are fewer perceptrons than labels
     * @throws IllegalStateException     if {@link #computeUpdateParameters} returns {@code null}
     *                                   or fewer values than there are labels
     */
    @Override // mulan.classifier.neural.ModelUpdateRule
    public final double process(DataPair example, Map<String, Object> params) {
        final int numLabels = example.getOutput().length;
        final double[] input = example.getInput();
        final int numFeatures = input.length;

        // Score every label with its own perceptron.
        final double[] confidences = new double[numLabels];
        for (int label = 0; label < numLabels; label++) {
            confidences[label] = perceptrons.get(label).processInput(input);
        }

        final double loss = lossFunction.computeLoss(
                new MultiLabelOutput(confidences).getRanking(), example.getOutputBoolean());
        if (loss == 0.0d) {
            // The ranking is already perfect for this example, so there is nothing to learn.
            return loss;
        }

        final double[] updateParams = computeUpdateParameters(example, confidences, loss);
        // Validate before touching any weights so that a faulty subclass cannot leave the model
        // partially updated (some perceptrons changed, others not).
        if (updateParams == null || updateParams.length < numLabels) {
            throw new IllegalStateException("computeUpdateParameters must return at least "
                    + numLabels + " values but returned "
                    + (updateParams == null ? "null" : String.valueOf(updateParams.length)));
        }

        // Additive update: w_label += step_label * x
        for (int label = 0; label < numLabels; label++) {
            final double[] weights = perceptrons.get(label).getWeights();
            final double step = updateParams[label];
            for (int feature = 0; feature < numFeatures; feature++) {
                weights[feature] += step * input[feature];
            }
        }
        return loss;
    }

    /**
     * Computes the per-label step sizes for an update.
     * <p>
     * Called by {@link #process(DataPair, Map)} only when the loss is non-zero. The weights of
     * the perceptron for label {@code i} are then moved by {@code result[i] * input}.
     *
     * @param example     the training example being processed
     * @param confidences the perceptron outputs for the example, one entry per label
     * @param loss        the (non-zero) ranking loss of the current prediction
     * @return one step size per label; must be non-null and contain at least
     *         {@code confidences.length} values
     */
    protected abstract double[] computeUpdateParameters(DataPair example, double[] confidences, double loss);
}
