package org.deeplearning4j.nn.updater;

import java.util.Arrays;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/updater/MultiLayerUpdater.class */
/**
 * {@link Updater} for a {@link MultiLayerNetwork} that delegates to one updater per layer.
 *
 * <p>{@link #update} expects every gradient key to have the form {@code "<layerIndex>_<name>"}. It
 * splits each key on the first {@code '_'}, groups the entries by layer index (with the prefix
 * stripped), and then calls each layer's updater with that layer's slice of the gradient.
 */
public class MultiLayerUpdater implements Updater {
    /** One updater per layer, indexed by layer position in the network. */
    private final Updater[] layerUpdaters;

    /**
     * Aggregates {@link MultiLayerUpdater} instances by keeping one child {@link UpdaterAggregator}
     * per layer and delegating to it.
     */
    protected static class MultiLayerUpdaterAggregator implements UpdaterAggregator {
        /** Per-layer child aggregators; {@code null} until the first aggregate or merge call. */
        private UpdaterAggregator[] aggregators;

        protected MultiLayerUpdaterAggregator() {
        }

        /**
         * Folds the per-layer state of {@code updater} into this aggregator.
         *
         * @param updater must be a {@link MultiLayerUpdater}; after the first call it must have the
         *     same number of layers as the first updater aggregated
         * @throws IllegalArgumentException if the layer count differs from the one already held
         */
        @Override // org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator
        public void aggregate(Updater updater) {
            MultiLayerUpdater multiLayerUpdater = (MultiLayerUpdater) updater;
            Updater[] updaters = multiLayerUpdater.layerUpdaters;
            if (this.aggregators == null) {
                // First updater seen: seed one child aggregator per layer. getAggregator(true) is
                // asked to include the child's own state, so no separate aggregate call follows.
                this.aggregators = new UpdaterAggregator[updaters.length];
                for (int i = 0; i < updaters.length; i++) {
                    this.aggregators[i] = updaters[i].getAggregator(true);
                }
                return;
            }
            checkSameLayerCount(this.aggregators.length, updaters.length);
            for (int i = 0; i < this.aggregators.length; i++) {
                this.aggregators[i].aggregate(updaters[i]);
            }
        }

        /**
         * Merges another {@link MultiLayerUpdaterAggregator} into this one, layer by layer.
         *
         * <p>If this aggregator is still empty it adopts the other's child-aggregator array by
         * reference. NOTE(review): the child aggregators are then shared between both instances.
         *
         * @param updaterAggregator must be a {@link MultiLayerUpdaterAggregator}
         * @throws IllegalArgumentException if both sides hold aggregators for different layer counts
         */
        @Override // org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator
        public void merge(UpdaterAggregator updaterAggregator) {
            MultiLayerUpdaterAggregator other = (MultiLayerUpdaterAggregator) updaterAggregator;
            if (this.aggregators == null) {
                this.aggregators = other.aggregators;
                return;
            }
            if (other.aggregators == null) {
                // The other side has not aggregated anything; nothing to merge.
                return;
            }
            checkSameLayerCount(this.aggregators.length, other.aggregators.length);
            for (int i = 0; i < this.aggregators.length; i++) {
                this.aggregators[i].merge(other.aggregators[i]);
            }
        }

        /**
         * Builds a new {@link MultiLayerUpdater} whose per-layer updaters come from the child
         * aggregators.
         *
         * @return a fresh multi-layer updater
         * @throws IllegalStateException if nothing has been aggregated or merged yet
         */
        @Override // org.deeplearning4j.nn.updater.aggregate.UpdaterAggregator
        public Updater getUpdater() {
            if (this.aggregators == null) {
                throw new IllegalStateException(
                        "Cannot create MultiLayerUpdater: no updater has been aggregated or merged yet");
            }
            MultiLayerUpdater multiLayerUpdater = new MultiLayerUpdater(this.aggregators.length);
            for (int i = 0; i < this.aggregators.length; i++) {
                multiLayerUpdater.layerUpdaters[i] = this.aggregators[i].getUpdater();
            }
            return multiLayerUpdater;
        }
    }

    /**
     * Creates one updater per layer of {@code multiLayerNetwork} via {@code UpdaterCreator}.
     *
     * @param multiLayerNetwork the network whose layers determine the per-layer updaters
     */
    public MultiLayerUpdater(MultiLayerNetwork multiLayerNetwork) {
        Layer[] layers = multiLayerNetwork.getLayers();
        this.layerUpdaters = new Updater[layers.length];
        for (int i = 0; i < layers.length; i++) {
            this.layerUpdaters[i] = UpdaterCreator.getUpdater(layers[i]);
        }
    }

    /**
     * Copy constructor: clones each per-layer updater of {@code multiLayerUpdater}.
     *
     * @param multiLayerUpdater the updater to copy
     */
    public MultiLayerUpdater(MultiLayerUpdater multiLayerUpdater) {
        this.layerUpdaters = new Updater[multiLayerUpdater.layerUpdaters.length];
        for (int i = 0; i < multiLayerUpdater.layerUpdaters.length; i++) {
            this.layerUpdaters[i] = multiLayerUpdater.layerUpdaters[i].m97clone();
        }
    }

    /** Creates an updater with {@code i} empty slots; the caller must fill every slot. */
    private MultiLayerUpdater(int i) {
        this.layerUpdaters = new Updater[i];
    }

    /**
     * Splits {@code gradient} by layer index and applies each layer's updater to its slice.
     *
     * <p>Every layer updater is called, including those whose slice is empty.
     *
     * @param layer must be a {@link MultiLayerNetwork}
     * @param gradient gradient whose keys have the form {@code "<layerIndex>_<name>"}
     * @param i forwarded unchanged to each layer updater (presumably the iteration - TODO confirm)
     * @param i2 forwarded unchanged to each layer updater (presumably the minibatch size - TODO confirm)
     * @throws IllegalStateException if a gradient key has no separator, a non-numeric layer index,
     *     or a layer index outside {@code [0, number of layers)}
     */
    @Override // org.deeplearning4j.nn.api.Updater
    public void update(Layer layer, Gradient gradient, int i, int i2) {
        MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) layer;
        Gradient[] gradientArr = new Gradient[this.layerUpdaters.length];
        for (int i3 = 0; i3 < gradientArr.length; i3++) {
            gradientArr[i3] = new DefaultGradient();
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int separator = key.indexOf('_');
            if (separator == -1) {
                throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key does not have layer separator: \"" + key + "\"");
            }
            int layerIndex = parseLayerIndex(key, separator);
            gradientArr[layerIndex].gradientForVariable().put(key.substring(separator + 1), entry.getValue());
        }
        for (int i4 = 0; i4 < this.layerUpdaters.length; i4++) {
            this.layerUpdaters[i4].update(multiLayerNetwork.getLayer(i4), gradientArr[i4], i, i2);
        }
    }

    /**
     * Parses and range-checks the layer index that precedes {@code separator} in {@code key}.
     *
     * @throws IllegalStateException if the prefix is not an integer or is out of range
     */
    private int parseLayerIndex(String key, int separator) {
        int layerIndex;
        try {
            layerIndex = Integer.parseInt(key.substring(0, separator));
        } catch (NumberFormatException e) {
            throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key does not start with a numeric layer index: \"" + key + "\"", e);
        }
        if (layerIndex < 0 || layerIndex >= this.layerUpdaters.length) {
            throw new IllegalStateException("Invalid key: MultiLayerNetwork Gradient key has layer index " + layerIndex
                    + " outside [0, " + this.layerUpdaters.length + "): \"" + key + "\"");
        }
        return layerIndex;
    }

    /** Throws if two per-layer arrays that must line up have different lengths. */
    private static void checkSameLayerCount(int expected, int actual) {
        if (expected != actual) {
            throw new IllegalArgumentException("Layer count mismatch: aggregator holds " + expected
                    + " layers but was given " + actual);
        }
    }

    /**
     * Returns a new aggregator for this updater.
     *
     * @param z if {@code true}, this updater's state is aggregated into the new aggregator immediately
     * @return a new {@link MultiLayerUpdaterAggregator}
     */
    @Override // org.deeplearning4j.nn.api.Updater
    public UpdaterAggregator getAggregator(boolean z) {
        MultiLayerUpdaterAggregator multiLayerUpdaterAggregator = new MultiLayerUpdaterAggregator();
        if (z) {
            multiLayerUpdaterAggregator.aggregate(this);
        }
        return multiLayerUpdaterAggregator;
    }

    /** Returns a copy whose per-layer updaters are clones of this updater's. */
    @Override // org.deeplearning4j.nn.api.Updater
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Updater m97clone() {
        return new MultiLayerUpdater(this);
    }

    /** Two instances are equal when their per-layer updater arrays are deeply equal. */
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MultiLayerUpdater)) {
            return false;
        }
        MultiLayerUpdater multiLayerUpdater = (MultiLayerUpdater) obj;
        return multiLayerUpdater.canEqual(this) && Arrays.deepEquals(this.layerUpdaters, multiLayerUpdater.layerUpdaters);
    }

    /** Symmetry hook for {@link #equals}: subclasses may override to restrict equality. */
    protected boolean canEqual(Object obj) {
        return obj instanceof MultiLayerUpdater;
    }

    /** Hash consistent with {@link #equals}: derived from the deep hash of the per-layer updaters. */
    public int hashCode() {
        return 59 + Arrays.deepHashCode(this.layerUpdaters);
    }
}
