package org.nd4j.linalg.api.ops.impl.reduce;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp;
import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.guava.primitives.Ints;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * The {@code "moments"} op: reduces an input along a set of dimensions and produces two outputs.
 *
 * <p>Output 0 is the mean and output 1 is the variance; the backward pass in {@link #doDiff(List)}
 * routes output 0's gradient through {@code MeanBp} and output 1's through {@code VarianceBp}.
 * The reduction dimensions are supplied either as integer arguments ({@code dimensions}) or, for
 * the constructors that take two inputs, as a second input holding the axes.
 */
public class Moments extends DynamicCustomOp {
    /** Whether reduced dimensions are kept; passed to the native op as its boolean argument. */
    private boolean keepDims;

    /** No-argument constructor; arguments are expected to be configured later (e.g. by TF import). */
    public Moments() {
    }

    /**
     * Creates the op for direct INDArray execution.
     *
     * @param input      array to reduce; must not be null
     * @param keepDims   whether to keep the reduced dimensions
     * @param dimensions dimensions to reduce over
     */
    public Moments(@NonNull INDArray input, boolean keepDims, int... dimensions) {
        super(new INDArray[]{input}, (INDArray[]) null);
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        this.dimensions = dimensions;
        this.keepDims = keepDims;
        addArgs();
    }

    /**
     * Creates the op for direct INDArray execution with {@code keepDims == false}.
     *
     * @param input      array to reduce; must not be null
     * @param dimensions dimensions to reduce over
     */
    public Moments(@NonNull INDArray input, int... dimensions) {
        super(new INDArray[]{input}, (INDArray[]) null);
        if (input == null) {
            throw new NullPointerException("input is marked non-null but is null");
        }
        this.dimensions = dimensions;
        addArgs();
    }

    /**
     * Creates the op in a SameDiff graph with no explicit dimensions.
     *
     * @param sameDiff owning graph
     * @param input    variable to reduce
     */
    public Moments(SameDiff sameDiff, SDVariable input) {
        // The delegated constructor already calls addArgs(); calling it again here would
        // register the keepDims boolean argument twice.
        this(sameDiff, input, (int[]) null);
    }

    /**
     * Creates the op in a SameDiff graph.
     *
     * @param sameDiff   owning graph
     * @param input      variable to reduce
     * @param dimensions dimensions to reduce over (may be null)
     */
    public Moments(SameDiff sameDiff, SDVariable input, int[] dimensions) {
        super(null, sameDiff, new SDVariable[]{input}, false);
        this.dimensions = dimensions;
        addArgs();
    }

    /**
     * Creates the op for INDArray execution with caller-supplied output arrays.
     *
     * @param input       array to reduce
     * @param outMean     output array for the mean (output 0)
     * @param outVariance output array for the variance (output 1)
     * @param dimensions  dimensions to reduce over
     */
    public Moments(INDArray input, INDArray outMean, INDArray outVariance, int... dimensions) {
        // Do not hand the dimensions to the superclass as integer arguments: addArgs() below
        // registers them, and doing both would duplicate every axis in the integer arguments.
        super((String) null, new INDArray[]{input}, new INDArray[]{outMean, outVariance});
        this.dimensions = dimensions;
        addArgs();
    }

    /**
     * Creates the op for direct INDArray execution.
     *
     * @param input      array to reduce
     * @param dimensions dimensions to reduce over
     * @param keepDims   whether to keep the reduced dimensions
     */
    public Moments(INDArray input, int[] dimensions, boolean keepDims) {
        super((String) null, new INDArray[]{input}, (INDArray[]) null);
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        addArgs();
    }

    /**
     * Creates the op for INDArray execution, with the reduction axes given as a second input.
     *
     * @param input    array to reduce
     * @param axes     array holding the dimensions to reduce over
     * @param keepDims whether to keep the reduced dimensions
     */
    public Moments(INDArray input, INDArray axes, boolean keepDims) {
        super((String) null, new INDArray[]{input, axes}, (INDArray[]) null);
        this.keepDims = keepDims;
        addArgs();
    }

    /**
     * Creates the op in a SameDiff graph.
     *
     * @param sameDiff   owning graph
     * @param input      variable to reduce
     * @param dimensions dimensions to reduce over (may be null)
     * @param keepDims   whether to keep the reduced dimensions
     */
    public Moments(SameDiff sameDiff, SDVariable input, int[] dimensions, boolean keepDims) {
        super(null, sameDiff, new SDVariable[]{input}, false);
        this.keepDims = keepDims;
        this.dimensions = dimensions;
        addArgs();
    }

    /**
     * Creates the op in a SameDiff graph, with the reduction axes given as a second input.
     *
     * @param sameDiff owning graph
     * @param input    variable to reduce
     * @param axes     variable holding the dimensions to reduce over
     * @param keepDims whether to keep the reduced dimensions
     */
    public Moments(SameDiff sameDiff, SDVariable input, SDVariable axes, boolean keepDims) {
        super(null, sameDiff, new SDVariable[]{input, axes}, false);
        this.keepDims = keepDims;
        addArgs();
    }

    /** Populates properties from a TensorFlow node, then registers the native-op arguments. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        addArgs();
    }

    @Override
    public String opName() {
        return "moments";
    }

    /**
     * Backward pass: the input gradient is the sum of the mean gradient and the variance gradient.
     *
     * <p>Reduction dimensions are taken, in order of preference, from {@code dimensions}, from the
     * integer arguments (cached into {@code dimensions} as a side effect), or from the second
     * input (axes variable). If none is available, {@code null} dimensions are passed through.
     *
     * @param grads gradients w.r.t. output 0 (mean) and output 1 (variance)
     * @return a single-element list holding the gradient w.r.t. the input
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> grads) {
        SDVariable gradMean = grads.get(0);
        SDVariable gradVariance = grads.get(1);

        if (this.dimensions == null && numIArguments() > 0) {
            // Recover the reduction dimensions from the integer arguments (e.g. after import).
            this.dimensions = Ints.toArray(this.iArguments);
        }

        if (this.dimensions == null && numInputArguments() > 1) {
            SDVariable axes = arg(1);
            SDVariable meanGrad = new MeanBp(this.sameDiff, arg(), gradMean, this.keepDims, axes).outputVariable();
            // NOTE(review): the 'false' argument presumably disables bias correction; confirm in VarianceBp.
            SDVariable varGrad = new VarianceBp(this.sameDiff, arg(), gradVariance, false, this.keepDims, axes).outputVariable();
            return Collections.singletonList(meanGrad.add(varGrad));
        }

        SDVariable meanGrad = new MeanBp(this.sameDiff, arg(), gradMean, this.keepDims, this.dimensions).outputVariable();
        SDVariable varGrad = new VarianceBp(this.sameDiff, arg(), gradVariance, false, this.keepDims, this.dimensions).outputVariable();
        return Collections.singletonList(meanGrad.add(varGrad));
    }

    /**
     * Both outputs share the input's datatype when it is floating point; otherwise both use the
     * default floating-point type.
     *
     * @param dataTypes datatypes of the inputs: the array to reduce, optionally followed by the
     *                  axes input used by the two-input constructors
     * @return datatypes for the mean and variance outputs
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        // Two inputs are legitimate: the (input, axes) constructors add the axes as a second input.
        Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2),
                "Expected 1 or 2 datatypes for %s, got %s", getClass(), dataTypes);
        DataType inputType = dataTypes.get(0);
        DataType outputType = inputType.isFPType() ? inputType : Nd4j.defaultFloatingPointType();
        return Arrays.asList(outputType, outputType);
    }

    /** Exposes {@code keepDims} and {@code dimensions} as named op properties. */
    @Override
    public Map<String, Object> propertiesForFunction() {
        Map<String, Object> properties = new HashMap<>();
        properties.put("keepDims", this.keepDims);
        properties.put("dimensions", this.dimensions);
        return properties;
    }

    /**
     * Registers the native-op arguments: {@code keepDims} as a boolean argument, then the
     * dimensions as integer arguments.
     *
     * <p>Dimensions are omitted when null, empty, or the single sentinel {@code Integer.MAX_VALUE}
     * (presumably meaning "reduce over all dimensions" — confirm with the native op).
     * Each call appends; callers should invoke this once per op instance.
     */
    protected void addArgs() {
        addBArgument(this.keepDims);
        if (this.dimensions == null || this.dimensions.length <= 0) {
            return;
        }
        if (this.dimensions.length == 1 && this.dimensions[0] == Integer.MAX_VALUE) {
            return;
        }
        addIArgument(this.dimensions);
    }
}
