package org.nd4j.linalg.api.ops.impl.summarystats;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Variance reduction op (op name {@code "var"}, type {@link Op.Type#VARIANCE}) over the configured
 * dimensions.
 *
 * <p>The {@link #biasCorrected} flag is forwarded to the gradient op ({@link VarianceBp}).
 * NOTE(review): presumably it selects the sample (N-1) rather than population (N) denominator in the
 * native kernel — confirm against the backend implementation.
 */
public class Variance extends BaseReduceOp {
    // NOTE(review): not read anywhere in this class; presumably kept for subclasses or serialization
    // compatibility — verify before removing.
    protected double mean;
    // NOTE(review): not read anywhere in this class; see note on {@code mean}.
    protected double bias;
    /** Whether bias correction is requested; defaults to {@code true}. */
    protected boolean biasCorrected = true;

    /**
     * SameDiff constructor.
     *
     * @param sameDiff      owning graph
     * @param input         variable to reduce
     * @param biasCorrected whether bias correction is requested
     * @param keepDims      keepDims flag forwarded to {@link BaseReduceOp}
     * @param dimensions    dimensions to reduce along
     */
    public Variance(SameDiff sameDiff, SDVariable input, boolean biasCorrected, boolean keepDims, int[] dimensions) {
        super(sameDiff, input, dimensions, keepDims);
        this.biasCorrected = biasCorrected;
        defineDimensions(dimensions);
    }

    /** No-arg constructor (e.g. for reflective instantiation); bias correction defaults to {@code true}. */
    public Variance() {
    }

    /**
     * @param biasCorrected whether bias correction is requested
     */
    public Variance(boolean biasCorrected) {
        this.biasCorrected = biasCorrected;
    }

    /**
     * Eager constructor with bias correction enabled.
     *
     * @param x          input array
     * @param dimensions dimensions to reduce along
     */
    public Variance(INDArray x, int... dimensions) {
        this(x, true, dimensions);
    }

    /**
     * Eager constructor with an explicit output array and keepDims disabled.
     *
     * @param x             input array
     * @param z             array passed as the output ({@code z}) argument to {@link BaseReduceOp}; may be null
     * @param biasCorrected whether bias correction is requested
     * @param dimensions    dimensions to reduce along
     */
    public Variance(INDArray x, INDArray z, boolean biasCorrected, int... dimensions) {
        this(x, z, biasCorrected, false, dimensions);
    }

    /**
     * Eager constructor without an explicit output array.
     *
     * @param x             input array
     * @param biasCorrected whether bias correction is requested
     * @param keepDims      keepDims flag forwarded to {@link BaseReduceOp}
     * @param dimensions    dimensions to reduce along
     */
    public Variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) {
        this(x, (INDArray) null, biasCorrected, keepDims, dimensions);
    }

    /**
     * Eager constructor; dimensions are applied via {@link #defineDimensions} after construction.
     *
     * @param x             input array
     * @param biasCorrected whether bias correction is requested
     * @param dimensions    dimensions to reduce along
     */
    public Variance(INDArray x, boolean biasCorrected, int... dimensions) {
        super(x, new int[0]);
        this.biasCorrected = biasCorrected;
        defineDimensions(dimensions);
    }

    /**
     * Full eager constructor.
     *
     * @param x             input array
     * @param z             array passed as the output ({@code z}) argument to {@link BaseReduceOp}; may be null
     * @param biasCorrected whether bias correction is requested
     * @param keepDims      keepDims flag forwarded to {@link BaseReduceOp}
     * @param dimensions    dimensions to reduce along
     */
    public Variance(INDArray x, INDArray z, boolean biasCorrected, boolean keepDims, int... dimensions) {
        super(x, (INDArray) null, z, keepDims, dimensions);
        this.biasCorrected = biasCorrected;
        defineDimensions(dimensions);
    }

    /** Returns an all-zeros array shaped like the input. */
    @Override
    public INDArray noOp() {
        return Nd4j.zerosLike(x());
    }

    @Override
    public int opNum() {
        return 0;
    }

    @Override
    public String opName() {
        return "var";
    }

    /** @return whether bias correction is requested */
    public boolean isBiasCorrected() {
        return this.biasCorrected;
    }

    /** @param biasCorrected whether bias correction is requested */
    public void setBiasCorrected(boolean biasCorrected) {
        this.biasCorrected = biasCorrected;
    }

    /** Gradient is delegated to {@link VarianceBp} with the same bias-correction, keepDims and dimensions. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> gradients) {
        return new VarianceBp(this.sameDiff, arg(), gradients.get(0), this.biasCorrected, this.keepDims, this.dimensions).outputs();
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx opName found for " + opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    public Op.Type getOpType() {
        return Op.Type.VARIANCE;
    }

    @Override
    public DataType resultType() {
        return resultType(null);
    }

    /**
     * Result dtype: the input's dtype when it is floating point, otherwise the first SameDiff arg's dtype,
     * otherwise the default floating-point type.
     *
     * @param opContext optional context whose input 0 is used instead of {@code x()}
     */
    @Override
    public DataType resultType(OpContext opContext) {
        INDArray input = opContext != null ? opContext.getInputArray(0) : x();
        if (input != null && input.isR()) {
            return input.dataType();
        }
        SDVariable first = arg();
        return first != null ? first.dataType() : Nd4j.defaultFloatingPointType();
    }

    /**
     * Checks that input 0, input 1 and output 0 (from the context if given, else x/y/z) are either
     * absent or floating point.
     */
    @Override
    public boolean validateDataTypes(OpContext opContext) {
        INDArray input = opContext != null ? opContext.getInputArray(0) : x();
        if (input != null && !input.isR()) {
            return false;
        }
        INDArray second = opContext != null ? opContext.getInputArray(1) : y();
        if (second != null && !second.isR()) {
            return false;
        }
        INDArray output = opContext != null ? opContext.getOutputArray(0) : z();
        return output == null || output.isR();
    }

    @Override
    public List<LongShapeDescriptor> calculateOutputShape() {
        return calculateOutputShape(null);
    }

    /**
     * Computes the reduced output shape.
     *
     * <p>The SameDiff arg shape is preferred when it is known and not a placeholder; otherwise the
     * concrete input array (from {@code opContext}, or {@code x()} when no context is given) is used.
     * The dtype is resolved against the same context, so that both come from the same input.
     *
     * @param opContext optional execution context
     * @return a single-element list, or an empty list when no concrete shape is available
     * @throws ND4JIllegalStateException if no context, no SameDiff args and no input array are available
     */
    @Override
    public List<LongShapeDescriptor> calculateOutputShape(OpContext opContext) {
        INDArray input = opContext != null ? opContext.getInputArray(0) : x();
        SDVariable[] sdArgs = args();
        boolean hasArg = sdArgs != null && sdArgs.length > 0;
        if (opContext == null && !hasArg && input == null) {
            throw new ND4JIllegalStateException("Unable to compute input shape. No arguments found.");
        }

        long[] argShape = hasArg ? sdArgs[0].getShape() : null;
        long[] inputShape;
        if (argShape != null && !Shape.isPlaceholderShape(argShape)) {
            inputShape = argShape;
        } else if (input != null) {
            inputShape = input.shape();
        } else {
            // Unknown or placeholder arg shape and no concrete array to resolve it from.
            return Collections.emptyList();
        }

        List<LongShapeDescriptor> shapes = new ArrayList<>(1);
        shapes.add(LongShapeDescriptor.fromShape(
                Shape.getReducedShape(inputShape, this.dimensions, isKeepDims()), resultType(opContext)));
        return shapes;
    }

    @Override
    public Op.Type opType() {
        return Op.Type.VARIANCE;
    }

    /**
     * Floating-point inputs keep their dtype; any other input maps to the default floating-point type.
     *
     * @param dataTypes exactly one input dtype
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got input %s", getClass(), dataTypes);
        DataType inputType = dataTypes.get(0);
        return Collections.singletonList(inputType.isFPType() ? inputType : Nd4j.defaultFloatingPointType());
    }
}
