package org.deeplearning4j.nn.updater;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.updater.UpdaterBlock;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.accum.Norm2;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Base class for updaters that apply per-parameter updater logic across all layers of a model.
 *
 * <p>On construction, the model's parameters are walked in the order returned by
 * {@link #getOrderedLayers()}. Consecutive parameters whose updater configurations are equal
 * (according to {@code UpdaterUtils.updaterConfigurationsEquals}) are merged into a single
 * {@link UpdaterBlock}. Each block owns a contiguous slice of the flattened parameter, gradient
 * and updater-state arrays.
 *
 * <p>NOTE(review): {@link #layersByName} is never assigned in this class. Subclasses presumably
 * populate it before {@link #update(Gradient, int, int, int)} is called — confirm.
 *
 * @param <T> the model type this updater is attached to
 */
public abstract class BaseMultiLayerUpdater<T extends Model> implements Updater {
    protected final T network;
    protected Map<String, Layer> layersByName;
    protected final List<UpdaterBlock> updaterBlocks;
    protected INDArray updaterStateViewArray;

    /**
     * Creates an updater for the given model and allocates a new updater-state array.
     *
     * @param network the model whose parameters will be updated
     */
    public BaseMultiLayerUpdater(T network) {
        this(network, null);
    }

    /**
     * Creates an updater for the given model.
     *
     * @param network      the model whose parameters will be updated
     * @param updaterState an existing updater-state view to use. If {@code null}, a new
     *                     uninitialized {@code [1, totalStateSize]} array is allocated (only when
     *                     the total state size is positive), and the blocks are flagged as
     *                     requiring initialization.
     */
    public BaseMultiLayerUpdater(T network, INDArray updaterState) {
        this.network = network;
        Layer[] orderedLayers = getOrderedLayers();
        this.updaterBlocks = new ArrayList<>();
        INDArray params = network.params();
        INDArray flattenedGradientsView = getFlattenedGradientsView();

        // Pass 1: group consecutive parameters that share an updater configuration into blocks.
        Layer previousLayer = null;
        String previousVariable = null;
        UpdaterBlock currentBlock = null;
        int paramOffset = 0;
        int stateOffset = 0;
        for (Layer layer : orderedLayers) {
            Map<String, INDArray> paramTable = layer.paramTable();
            if (paramTable == null) {
                continue;
            }
            for (String variable : new ArrayList<>(paramTable.keySet())) {
                int paramLength = paramTable.get(variable).length();
                int stateSize = (int) layer.conf().getLayer().getUpdaterByParam(variable).stateSize(paramLength);

                INDArray paramView = null;
                INDArray gradientView = null;
                if (paramLength > 0) {
                    paramView = rowSlice(params, paramOffset, paramOffset + paramLength);
                    gradientView = rowSlice(flattenedGradientsView, paramOffset, paramOffset + paramLength);
                }
                UpdaterBlock.ParamState paramState = new UpdaterBlock.ParamState(
                        layer, variable, paramOffset, paramOffset + paramLength, paramView, gradientView);

                boolean startNewBlock = currentBlock == null
                        || !UpdaterUtils.updaterConfigurationsEquals(previousLayer, previousVariable, layer, variable);
                if (startNewBlock) {
                    List<UpdaterBlock.ParamState> states = new ArrayList<>();
                    states.add(paramState);
                    currentBlock = new UpdaterBlock(paramOffset, paramOffset + paramLength,
                            stateOffset, stateOffset + stateSize, states);
                    this.updaterBlocks.add(currentBlock);
                } else {
                    // Same updater config as the previous parameter: extend the current block.
                    currentBlock.setParamOffsetEnd(currentBlock.getParamOffsetEnd() + paramLength);
                    currentBlock.setUpdaterViewOffsetEnd(currentBlock.getUpdaterViewOffsetEnd() + stateSize);
                    currentBlock.getLayersAndVariablesInBlock().add(paramState);
                }

                previousLayer = layer;
                previousVariable = variable;
                paramOffset += paramLength;
                stateOffset += stateSize;
            }
        }

        // Use the supplied state array, or allocate one when any updater needs state.
        boolean requiresInitialization = false;
        if (updaterState != null) {
            this.updaterStateViewArray = updaterState;
        } else if (stateOffset > 0) {
            this.updaterStateViewArray = Nd4j.createUninitialized(new int[] {1, stateOffset}, Nd4j.order().charValue());
            requiresInitialization = true;
        }

        // Pass 2: hand each block its slices of the state and gradient arrays, then initialize it.
        int stateCursor = 0;
        int gradientCursor = 0;
        for (UpdaterBlock block : this.updaterBlocks) {
            int blockStateSize = block.getUpdaterViewOffsetEnd() - block.getUpdaterViewOffsetStart();
            int blockParamSize = block.getParamOffsetEnd() - block.getParamOffsetStart();
            if (blockStateSize > 0) {
                block.setUpdaterView(rowSlice(this.updaterStateViewArray, stateCursor, stateCursor + blockStateSize));
                block.setUpdaterViewRequiresInitialization(requiresInitialization);
            }
            if (blockParamSize > 0) {
                block.setGradientView(rowSlice(flattenedGradientsView, gradientCursor, gradientCursor + blockParamSize));
            }
            block.init();
            stateCursor += blockStateSize;
            gradientCursor += blockParamSize;
        }
    }

    /** Returns the view {@code array[0, from:to]}, built with fresh index objects on each call. */
    private static INDArray rowSlice(INDArray array, int from, int to) {
        return array.get(new INDArrayIndex[] {NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }

    protected abstract Layer[] getOrderedLayers();

    protected abstract INDArray getFlattenedGradientsView();

    protected abstract INDArray getParams();

    protected abstract boolean isMiniBatch();

    /**
     * Copies {@code viewArray} into the existing updater-state array.
     *
     * @param viewArray the new state values. They are assigned element-wise, not used as the view.
     * @throws IllegalStateException if the lengths of the two arrays differ
     */
    public void setStateViewArray(INDArray viewArray) {
        if (this.updaterStateViewArray.length() != viewArray.length()) {
            throw new IllegalStateException("Invalid input: view arrays differ in length. Expected length "
                    + this.updaterStateViewArray.length() + ", got length " + viewArray.length());
        }
        this.updaterStateViewArray.assign(viewArray);
    }

    /** Delegates to {@link #setStateViewArray(INDArray)}; {@code layer} and the flag are ignored. */
    @Override
    public void setStateViewArray(Layer layer, INDArray viewArray, boolean initialize) {
        setStateViewArray(viewArray);
    }

    @Override
    public INDArray getStateViewArray() {
        return this.updaterStateViewArray;
    }

    /** Delegates to {@link #update(Gradient, int, int, int)}; {@code layer} is ignored. */
    @Override
    public void update(Layer layer, Gradient gradient, int iteration, int epoch, int batchSize) {
        update(gradient, iteration, epoch, batchSize);
    }

    /**
     * Runs the update pipeline on {@code gradient}.
     *
     * <ol>
     *   <li>Apply per-layer gradient normalization ({@link #preApply}).</li>
     *   <li>Run every updater block that is not skipped by its pretrain config.</li>
     *   <li>When {@link #isMiniBatch()} is true, divide the gradient by {@code batchSize}.</li>
     * </ol>
     *
     * <p>If {@code gradient.gradient()} is not the same instance as the model's flattened gradient
     * view, it is treated as an external gradient.
     */
    public void update(Gradient gradient, int iteration, int epoch, int batchSize) {
        // Reference comparison is intentional: "external" means "not our own flattened view".
        boolean externalGradient = gradient.gradient() != getFlattenedGradientsView();

        Map<String, Gradient> gradientsByLayer = splitGradientByLayer(gradient);
        for (Map.Entry<String, Gradient> entry : gradientsByLayer.entrySet()) {
            preApply(this.layersByName.get(entry.getKey()), entry.getValue(), iteration);
        }

        for (UpdaterBlock block : this.updaterBlocks) {
            if (block.skipDueToPretrainConfig()) {
                continue;
            }
            if (Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(ComputationGraph.WORKSPACE_FEED_FORWARD)) {
                // try-with-resources closes the workspace even if the update throws.
                try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager()
                        .getAndActivateWorkspace(ComputationGraph.WORKSPACE_FEED_FORWARD)) {
                    applyUpdaterBlock(block, externalGradient, gradient, iteration, epoch);
                }
            } else {
                applyUpdaterBlock(block, externalGradient, gradient, iteration, epoch);
            }
        }

        if (isMiniBatch()) {
            if (externalGradient) {
                gradient.gradient().divi(batchSize);
            } else {
                INDArray flattenedGradientsView = getFlattenedGradientsView();
                if (flattenedGradientsView != null) {
                    flattenedGradientsView.divi(batchSize);
                }
            }
        }
    }

    /**
     * Groups the gradient's variables by layer name.
     *
     * <p>For a single-layer updater with exactly one layer, the whole gradient is mapped to that
     * layer's name. Otherwise each key is split at its last {@code '_'} into layer name and
     * variable name.
     */
    private Map<String, Gradient> splitGradientByLayer(Gradient gradient) {
        Map<String, Gradient> gradientsByLayer = new HashMap<>();
        Layer[] orderedLayers = getOrderedLayers();
        if (orderedLayers.length == 1 && isSingleLayerUpdater()) {
            gradientsByLayer.put(orderedLayers[0].conf().getLayer().getLayerName(), gradient);
            return gradientsByLayer;
        }
        for (Map.Entry<String, INDArray> entry : gradient.gradientForVariable().entrySet()) {
            String key = entry.getKey();
            int separator = key.lastIndexOf('_');
            if (separator == -1) {
                throw new IllegalStateException("Invalid key: Gradient key does not have layer separator: \"" + key + "\"");
            }
            String layerName = key.substring(0, separator);
            Gradient layerGradient = gradientsByLayer.get(layerName);
            if (layerGradient == null) {
                layerGradient = new DefaultGradient();
                gradientsByLayer.put(layerName, layerGradient);
            }
            layerGradient.setGradientFor(key.substring(separator + 1), entry.getValue());
        }
        return gradientsByLayer;
    }

    /** Runs one updater block against either the external gradient or the internal view. */
    private void applyUpdaterBlock(UpdaterBlock block, boolean externalGradient, Gradient gradient,
                                   int iteration, int epoch) {
        if (externalGradient) {
            block.updateExternalGradient(iteration, epoch, gradient.gradient(), getParams());
        } else {
            block.update(iteration, epoch);
        }
    }

    protected boolean isSingleLayerUpdater() {
        return false;
    }

    /**
     * Applies the layer's configured gradient normalization in place.
     *
     * <p>Nothing happens when any of the following holds: the layer config is not a
     * {@link BaseLayer}, no normalization is set, the normalization is
     * {@link GradientNormalization#None}, or the layer is in pretrain mode. For the renormalize
     * strategies, an all-zero gradient is left untouched rather than divided by zero (which would
     * produce NaN).
     *
     * @param layer     the layer whose configuration selects the strategy
     * @param gradient  the per-layer gradient (used by the per-parameter-type strategies)
     * @param iteration unused here
     * @throws RuntimeException if the strategy is not handled
     */
    public void preApply(Layer layer, Gradient gradient, int iteration) {
        if (!(layer.conf().getLayer() instanceof BaseLayer)) {
            return;
        }
        BaseLayer baseLayer = (BaseLayer) layer.conf().getLayer();
        GradientNormalization normalization = baseLayer.getGradientNormalization();
        if (normalization == null || normalization == GradientNormalization.None || layer.conf().isPretrain()) {
            return;
        }

        double threshold = baseLayer.getGradientNormalizationThreshold();
        INDArray layerGradientView = layer.getGradientsViewArray();
        switch (normalization) {
            case RenormalizeL2PerLayer:
                if (layerGradientView != null) {
                    double l2 = layerGradientView.norm2Number().doubleValue();
                    if (l2 != 0.0) {
                        layerGradientView.divi(l2);
                    }
                }
                break;
            case RenormalizeL2PerParamType:
                for (INDArray paramGradient : gradient.gradientForVariable().values()) {
                    double l2 = paramGradient.norm2Number().doubleValue();
                    if (l2 != 0.0) {
                        paramGradient.divi(l2);
                    }
                }
                break;
            case ClipElementWiseAbsoluteValue:
                if (layerGradientView != null) {
                    Nd4j.getExecutioner().exec(DynamicCustomOp.builder("clipbyvalue")
                            .addInputs(new INDArray[] {layerGradientView})
                            .callInplace(true)
                            .addFloatingPointArguments(new Double[] {-threshold, threshold})
                            .build());
                }
                break;
            case ClipL2PerLayer:
                if (layerGradientView != null) {
                    double l2 = layerGradientView.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        layerGradientView.muli(threshold / l2);
                    }
                }
                break;
            case ClipL2PerParamType:
                for (INDArray paramGradient : gradient.gradientForVariable().values()) {
                    double l2 = paramGradient.norm2Number().doubleValue();
                    if (l2 > threshold) {
                        paramGradient.divi(l2 / threshold);
                    }
                }
                break;
            default:
                throw new RuntimeException("Unknown (or not implemented) gradient normalization strategy: " + normalization);
        }
    }

    /** Two updaters are equal when they have the same class and equal updater-state arrays. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        BaseMultiLayerUpdater<?> other = (BaseMultiLayerUpdater<?>) obj;
        return this.updaterStateViewArray != null
                ? this.updaterStateViewArray.equals(other.updaterStateViewArray)
                : other.updaterStateViewArray == null;
    }

    /**
     * Hashes only the state that {@link #equals} compares, so equal updaters always produce equal
     * hash codes.
     */
    @Override
    public int hashCode() {
        return this.updaterStateViewArray != null ? this.updaterStateViewArray.hashCode() : 0;
    }

    public T getNetwork() {
        return this.network;
    }

    public Map<String, Layer> getLayersByName() {
        return this.layersByName;
    }

    public List<UpdaterBlock> getUpdaterBlocks() {
        return this.updaterBlocks;
    }

    public INDArray getUpdaterStateViewArray() {
        return this.updaterStateViewArray;
    }
}
