package org.deeplearning4j.gradient.multilayer;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.gradient.MultiLayerGradient;

/* loaded from: input_file:org/deeplearning4j/gradient/multilayer/AverageChangeMultiLayerGradientListener.class */
/**
 * {@link MultiLayerGradientListener} that records every {@link MultiLayerGradient} it is notified
 * of and can combine them into a single averaged gradient on demand.
 *
 * <p>Not thread-safe: received gradients are stored in an unsynchronized {@link ArrayList}.
 */
public class AverageChangeMultiLayerGradientListener implements MultiLayerGradientListener {
    private static final long serialVersionUID = 9078190492614228289L;

    /** Gradients received so far, in arrival order. */
    private final List<MultiLayerGradient> gradients = new ArrayList<MultiLayerGradient>();

    /**
     * Records the given gradient for later averaging.
     *
     * @param multiLayerGradient the gradient reported by the network; stored by reference
     */
    @Override
    public void onMultiLayerGradient(MultiLayerGradient multiLayerGradient) {
        this.gradients.add(multiLayerGradient);
    }

    /**
     * Averages all gradients recorded so far.
     *
     * <p>The result is built from copies of the recorded gradients. The first copy is summed with
     * copies of the others via {@code addGradient}, then divided by the number of gradients.
     *
     * @return a new gradient holding the average of all recorded gradients
     * @throws IllegalStateException if no gradient has been recorded yet
     */
    public MultiLayerGradient averaged() {
        if (gradients.isEmpty()) {
            throw new IllegalStateException("No gradients have been recorded; nothing to average");
        }
        // Start from a copy so the running sum is not accumulated into a recorded gradient
        // (addGradient/div results are ignored, so they presumably mutate the receiver in place).
        MultiLayerGradient sum = gradients.get(0).m16clone();
        for (int i = 1; i < gradients.size(); i++) {
            // NOTE(review): the argument is copied as before; unclear whether addGradient can
            // modify its argument — confirm before dropping this defensive copy.
            sum.addGradient(gradients.get(i).m16clone());
        }
        sum.div(gradients.size());
        return sum;
    }
}
