package org.deeplearning4j.nn.layers.normalization;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Local response normalization (LRN) layer that normalizes each activation by the squared
 * activations of neighbouring channels (dimension 1 of a 4d input).
 *
 * <p>Forward pass, for channel {@code c} and window half-width {@code halfN = (int) n / 2}:
 * <pre>
 *   u_c = k + alpha * sum_{c' = c - halfN .. c + halfN} x_{c'}^2   (window truncated at the edges)
 *   a_c = x_c * u_c^(-beta)
 * </pre>
 *
 * <p>If a {@code CudnnLocalResponseNormalizationHelper} can be loaded reflectively, it is tried
 * first for both the forward and the backward pass. The pure-ND4J code path is used whenever the
 * helper is absent or returns {@code null}.
 *
 * <p>This layer has no trainable parameters.
 */
public class LocalResponseNormalization extends BaseLayer<org.deeplearning4j.nn.conf.layers.LocalResponseNormalization> {
    protected static final Logger log = LoggerFactory.getLogger(LocalResponseNormalization.class);

    /** Optional accelerated implementation; {@code null} when it could not be loaded. */
    LocalResponseNormalizationHelper helper;

    // Hyper-parameters, refreshed from the layer configuration on every activate() call.
    private double k;
    private double n;
    private double alpha;
    private double beta;
    private int halfN;

    // Forward-pass state reused by the fallback backward pass.
    private INDArray activations;
    /** u = k + alpha * (windowed sum of squared inputs); null when not computed for the current input. */
    private INDArray unitScale;
    /** u^(-beta); null when not computed for the current input. */
    private INDArray scale;

    public LocalResponseNormalization(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        super(neuralNetConfiguration, iNDArray);
        this.helper = null;
        initializeHelper();
    }

    public LocalResponseNormalization(NeuralNetConfiguration neuralNetConfiguration) {
        super(neuralNetConfiguration);
        this.helper = null;
        initializeHelper();
    }

    /**
     * Tries to instantiate the cuDNN-backed helper by reflection.
     *
     * <p>A missing class is expected (no CUDA backend on the classpath) and is ignored silently.
     * Any other failure is logged, and the layer falls back to the ND4J implementation.
     */
    void initializeHelper() {
        try {
            helper = Class.forName("org.deeplearning4j.nn.layers.normalization.CudnnLocalResponseNormalizationHelper")
                    .asSubclass(LocalResponseNormalizationHelper.class)
                    .newInstance();
        } catch (Throwable t) {
            if (t instanceof ClassNotFoundException) {
                return;
            }
            log.warn("Could not load CudnnLocalResponseNormalizationHelper", t);
        }
    }

    /** No parameters, so no L2 penalty. */
    @Override
    public double calcL2() {
        return 0.0d;
    }

    /** No parameters, so no L1 penalty. */
    @Override
    public double calcL1() {
        return 0.0d;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.NORMALIZATION;
    }

    /** No-op: there is nothing to fit. */
    @Override
    public void fit(INDArray iNDArray) {
    }

    /**
     * Computes the gradient with respect to the layer input.
     *
     * <pre>
     *   dL/dx_i = eps_i * u_i^(-beta)
     *             - 2 * alpha * beta * x_i * sum_{j : i in window(j)} eps_j * a_j / u_j
     * </pre>
     * The division by {@code u_j} belongs inside the sum: each neighbour {@code j} contributes
     * through its own unit scale.
     *
     * @param epsilon gradient of the loss with respect to this layer's output
     * @return an empty gradient (no parameters) paired with the gradient with respect to the input
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        if (helper != null) {
            Pair<Gradient, INDArray> helperResult = helper.backpropGradient(input, epsilon, k, n, alpha, beta);
            if (helperResult != null) {
                return helperResult;
            }
        }

        // The forward pass may have been done by the helper, which does not expose u or u^(-beta).
        if (unitScale == null || scale == null) {
            computeScales();
        }
        INDArray forwardOutput = activations != null ? activations : input.mul(scale);

        // eps_j * a_j / u_j, summed over every channel j whose window contains channel i.
        INDArray perChannelTerm = forwardOutput.mul(epsilon).divi(unitScale);
        INDArray neighbourSum = sumOverChannelWindow(perChannelTerm);

        INDArray nextEpsilon = epsilon.mul(scale).subi(input.mul(2.0d * alpha * beta).muli(neighbourSum));
        return new Pair<>(new DefaultGradient(), nextEpsilon);
    }

    /**
     * Runs the forward pass, caching the intermediate terms needed by
     * {@link #backpropGradient(INDArray)}.
     *
     * @param training forwarded unchanged to the accelerated helper, if present
     * @return the normalized activations
     */
    @Override
    public INDArray activate(boolean training) {
        k = layerConf().getK();
        n = layerConf().getN();
        alpha = layerConf().getAlpha();
        beta = layerConf().getBeta();
        halfN = ((int) n) / 2;

        if (helper != null) {
            activations = helper.activate(input, training, k, n, alpha, beta);
            if (activations != null) {
                // Discard scale terms from any earlier fallback pass: they belong to a previous input.
                unitScale = null;
                scale = null;
                return activations;
            }
        }

        computeScales();
        activations = input.mul(scale);
        return activations;
    }

    /**
     * Computes and caches {@code unitScale = k + alpha * windowSum(x^2)} and
     * {@code scale = unitScale^(-beta)} for the current input.
     */
    private void computeScales() {
        INDArray squaredWindowSum = sumOverChannelWindow(input.mul(input));
        unitScale = squaredWindowSum.mul(alpha).add(k);
        scale = Transforms.pow(unitScale, -beta);
    }

    /**
     * Sums {@code values} over a window of neighbouring channels (dimension 1).
     *
     * <p>The window spans {@code halfN} channels on either side and is truncated at the channel
     * boundaries. The input is not modified.
     *
     * @param values 4d array; dimension 1 is treated as the channel axis
     * @return a new array of the same shape with the windowed sums
     */
    private INDArray sumOverChannelWindow(INDArray values) {
        int channels = values.shape()[1];
        INDArray windowSum = values.dup();
        // Offsets >= channels have no partner channel, so they are skipped.
        for (int offset = 1; offset <= halfN && offset < channels; offset++) {
            // Channels [offset, C) receive the value from `offset` channels below.
            windowSum.put(channelRange(offset, channels),
                    windowSum.get(channelRange(offset, channels)).addi(values.get(channelRange(0, channels - offset))));
            // Channels [0, C - offset) receive the value from `offset` channels above.
            windowSum.put(channelRange(0, channels - offset),
                    windowSum.get(channelRange(0, channels - offset)).addi(values.get(channelRange(offset, channels))));
        }
        return windowSum;
    }

    /**
     * Builds a fresh 4d index selecting channels {@code [from, to)} and everything on the other axes.
     *
     * <p>A new index array is created for every get/put call, as the original code did.
     * NOTE(review): ND4J index objects may be stateful, so they are not reused.
     */
    private static INDArrayIndex[] channelRange(int from, int to) {
        return new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(from, to), NDArrayIndex.all(), NDArrayIndex.all()};
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Gradient calcGradient(Gradient gradient, INDArray iNDArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void merge(Layer layer, int i) {
        throw new UnsupportedOperationException();
    }

    /** This layer has no parameters; always returns {@code null}. */
    @Override
    public INDArray params() {
        return null;
    }

    /** This layer has no parameters; always returns {@code null} regardless of {@code param}. */
    @Override
    public INDArray getParam(String param) {
        return params();
    }

    /** No-op: there are no parameters to set. */
    @Override
    public void setParams(INDArray iNDArray) {
    }
}
