package org.nd4j.linalg.api.ops.impl.accum;

import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BaseAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Accumulation that sums the deviations of every element of {@code x} from the mean of {@code x}.
 *
 * <p>{@link #exec()} first runs a {@link Mean} over {@code x}, caches the value in {@link #mean},
 * and then runs a {@link Sum} over {@code x - mean}. The per-element {@code update(...)} callbacks
 * read the cached {@link #mean}, so they only reflect the current input after {@link #exec()} has
 * assigned it (before that it is the field default, {@code 0.0}).
 *
 * <p>All constructors mark the op as pass-through ({@code passThrough = true}).
 * NOTE(review): presumably this makes the executioner delegate to this op's own {@code exec}
 * methods instead of driving the {@code update} callbacks itself; confirm against the executioner.
 */
public class Bias extends BaseAccumulation {
    /** Mean of {@code x}; assigned by {@link #exec()} and read by every {@code update} overload. */
    private double mean;

    public Bias() {
    }

    /**
     * @param x input array
     * @param y secondary array (not read by {@link #exec()})
     * @param z result array handed to the base class
     * @param n number of elements, forwarded to the base class
     */
    public Bias(INDArray x, INDArray y, INDArray z, int n) {
        super(x, y, z, n);
        this.passThrough = true;
    }

    /**
     * Convenience constructor that uses {@code x} as the result array as well.
     * {@code passThrough} is already set by the delegated constructor.
     */
    public Bias(INDArray x, INDArray y, int n) {
        this(x, y, x, n);
    }

    public Bias(INDArray x) {
        super(x);
        this.passThrough = true;
    }

    public Bias(INDArray x, INDArray y) {
        super(x, y);
        this.passThrough = true;
    }

    @Override // org.nd4j.linalg.api.ops.Op
    public String name() {
        return "bias";
    }

    /**
     * Builds a {@code Bias} over the {@code index}-th vector along {@code dimension}.
     * When this op has a {@code y}, the matching vector of {@code y} is passed along too.
     */
    @Override // org.nd4j.linalg.api.ops.Op
    public Op opForDimension(int index, int dimension) {
        INDArray xVector = this.x.vectorAlongDimension(index, dimension);
        if (y() == null) {
            // Reuse the view computed above rather than slicing x a second time.
            return new Bias(xVector);
        }
        return new Bias(xVector, this.y.vectorAlongDimension(index, dimension), xVector.length());
    }

    /**
     * Builds a {@code Bias} over the {@code index}-th tensor along {@code dimension}.
     * When this op has a {@code y}, the matching tensor of {@code y} is passed along too.
     */
    @Override // org.nd4j.linalg.api.ops.Op
    public Op opForDimension(int index, int... dimension) {
        INDArray xTensor = this.x.tensorAlongDimension(index, dimension);
        if (y() == null) {
            // Reuse the view computed above rather than slicing x a second time.
            return new Bias(xTensor);
        }
        return new Bias(xTensor, this.y.tensorAlongDimension(index, dimension), xTensor.length());
    }

    /** Adds {@code (value - mean)} to the accumulator. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public double update(double accum, double value) {
        return accum + (value - this.mean);
    }

    /** Adds {@code (value - mean)} to the accumulator; the {@code y} element is ignored. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public double update(double accum, double value, double yValue) {
        return accum + (value - this.mean);
    }

    /** Adds {@code (value - mean)} to the accumulator; the difference is computed in double. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public float update(float accum, float value) {
        return accum + ((float) (value - this.mean));
    }

    /** Adds {@code (value - mean)} to the accumulator; the {@code y} element is ignored. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public float update(float accum, float value, float yValue) {
        return accum + ((float) (value - this.mean));
    }

    /** Adds the real value {@code (value - mean)} to the complex accumulator. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber update(IComplexNumber accum, double value) {
        return accum.add(Double.valueOf(value - this.mean));
    }

    /** Adds the real value {@code (value - mean)} to the complex accumulator; {@code yValue} is ignored. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber update(IComplexNumber accum, double value, double yValue) {
        return accum.add(Double.valueOf(value - this.mean));
    }

    /** Adds {@code (value - mean)} to the complex accumulator. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber update(IComplexNumber accum, IComplexNumber value) {
        return accum.add(value.sub(Double.valueOf(this.mean)));
    }

    /** Adds {@code (value - mean)} to the complex accumulator; {@code yValue} is ignored. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber update(IComplexNumber accum, IComplexNumber value, IComplexNumber yValue) {
        return accum.add(value.sub(Double.valueOf(this.mean)));
    }

    /** Adds {@code (value - mean)} to the complex accumulator; {@code yValue} is ignored. */
    @Override // org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber update(IComplexNumber accum, IComplexNumber value, double yValue) {
        return accum.add(value.sub(Double.valueOf(this.mean)));
    }

    /** Complex zero ({@code 0 + 0i}) used as the starting accumulator. */
    @Override // org.nd4j.linalg.api.ops.BaseAccumulation, org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber zeroComplex() {
        return Nd4j.createComplexNumber(Double.valueOf(0.0d), Double.valueOf(0.0d));
    }

    /** Delegates to the base-class initialisation without adding anything. */
    @Override // org.nd4j.linalg.api.ops.BaseAccumulation, org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public void init(INDArray x, INDArray y, INDArray z, int n) {
        super.init(x, y, z, n);
    }

    /** Partial results are combined by addition. */
    @Override // org.nd4j.linalg.api.ops.BaseAccumulation, org.nd4j.linalg.api.ops.Accumulation
    public double combineSubResults(double first, double second) {
        return first + second;
    }

    /** Partial results are combined by addition. */
    @Override // org.nd4j.linalg.api.ops.BaseAccumulation, org.nd4j.linalg.api.ops.Accumulation
    public float combineSubResults(float first, float second) {
        return first + second;
    }

    /** Partial results are combined by complex addition. */
    @Override // org.nd4j.linalg.api.ops.BaseAccumulation, org.nd4j.linalg.api.ops.Accumulation
    public IComplexNumber combineSubResults(IComplexNumber first, IComplexNumber second) {
        return first.add(second);
    }

    /**
     * Computes the bias over all of {@code x} and stores it in {@code finalResult}.
     * Side effect: assigns {@link #mean}, which the {@code update} callbacks rely on.
     */
    @Override // org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public void exec() {
        this.mean = Nd4j.getExecutioner()
                .execAndReturn((Accumulation) new Mean(this.x))
                .getFinalResult()
                .doubleValue();
        // x.sub(...) returns a new array, so x itself is left untouched.
        INDArray centered = this.x.sub(Double.valueOf(this.mean));
        double total = Nd4j.getExecutioner()
                .execAndReturn((Accumulation) new Sum(centered))
                .getFinalResult()
                .doubleValue();
        this.finalResult = Double.valueOf(total);
    }

    /**
     * Computes one bias per tensor along {@code dimension}.
     *
     * <p>{@code z} is replaced by a fresh array whose shape is {@code x}'s shape with the reduced
     * dimensions removed. The t-th entry of {@code z} (linear index) receives the bias of the t-th
     * tensor. Each tensor op computes its own mean.
     */
    @Override // org.nd4j.linalg.api.ops.BaseOp, org.nd4j.linalg.api.ops.Op
    public void exec(int... dimension) {
        int[] resultShape = ArrayUtil.removeIndex(this.x.shape(), dimension);
        int tensorCount = this.x.tensorssAlongDimension(dimension);
        this.z = Nd4j.create(resultShape);
        for (int t = 0; t < tensorCount; t++) {
            double tensorBias = Nd4j.getExecutioner()
                    .execAndReturn((Accumulation) opForDimension(t, dimension))
                    .getFinalResult()
                    .doubleValue();
            this.z.putScalar(t, tensorBias);
        }
    }
}
