package org.nd4j.autodiff.samediff;

import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.SquaredDifferenceOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * A named variable in a {@link SameDiff} graph.
 *
 * <p>An {@code SDVariable} is itself a {@link DifferentialFunction} whose only output (and only
 * argument) is the variable itself. Its shape and array are not stored on this object; they are
 * looked up in, and written to, the owning {@code SameDiff} instance by {@link #getVarName()}.
 *
 * <p>The arithmetic helpers ({@code add}, {@code sub}, {@code mul}, ...) create a new node through
 * {@code sameDiff.f()} and give the result either a caller-supplied name or a name generated from
 * the corresponding op class.
 */
public class SDVariable extends DifferentialFunction implements Serializable {
    /** Name under which this variable's shape/array are stored in the owning SameDiff. */
    private String varName;
    /** Scheme used to create this variable's array when one has to be allocated. */
    protected WeightInitScheme weightInitScheme;

    /** Fluent builder for {@link SDVariable}; obtain one via {@link SDVariable#builder()}. */
    public static class SDVariableBuilder {
        private String varName;
        private SameDiff sameDiff;
        private int[] shape;
        private WeightInitScheme weightInitScheme;

        SDVariableBuilder() {
        }

        /** Sets the variable name. */
        public SDVariableBuilder varName(String str) {
            this.varName = str;
            return this;
        }

        /** Sets the owning SameDiff instance (must be non-null when {@link #build()} is called). */
        public SDVariableBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        /**
         * Sets the shape. {@code null} or any negative dimension makes the variable a placeholder.
         * The array is stored by reference, not copied.
         */
        public SDVariableBuilder shape(int[] iArr) {
            this.shape = iArr;
            return this;
        }

        /** Sets the weight-init scheme; {@code null} falls back to {@code new ZeroInitScheme('f')}. */
        public SDVariableBuilder weightInitScheme(WeightInitScheme weightInitScheme) {
            this.weightInitScheme = weightInitScheme;
            return this;
        }

        /** Creates the variable and registers its shape/placeholder status with the SameDiff. */
        public SDVariable build() {
            return new SDVariable(this.varName, this.sameDiff, this.shape, this.weightInitScheme);
        }

        @Override
        public String toString() {
            return "SDVariable.SDVariableBuilder(varName=" + this.varName + ", sameDiff=" + this.sameDiff + ", shape=" + Arrays.toString(this.shape) + ", weightInitScheme=" + this.weightInitScheme + ")";
        }
    }

    /**
     * Creates a variable and registers it with {@code sameDiff}.
     *
     * <ul>
     *   <li>{@code shape == null}: registered as a placeholder with no shape.</li>
     *   <li>any negative dimension: registered as a placeholder and {@code shape} is recorded as
     *       its original placeholder shape.</li>
     *   <li>otherwise: {@code shape} is recorded as the variable's shape.</li>
     * </ul>
     *
     * @param varName          variable name
     * @param sameDiff         owning graph; must be non-null
     * @param shape            shape, possibly {@code null} or containing negative dimensions
     * @param weightInitScheme init scheme; {@code null} means {@code new ZeroInitScheme('f')}
     */
    private SDVariable(String varName, SameDiff sameDiff, int[] shape, WeightInitScheme weightInitScheme) {
        super(sameDiff, new Object[0]);
        this.varName = varName;
        this.weightInitScheme = weightInitScheme != null ? weightInitScheme : new ZeroInitScheme('f');
        if (shape == null) {
            sameDiff.addAsPlaceHolder(varName);
        } else if (hasNegativeDimension(shape)) {
            sameDiff.addAsPlaceHolder(varName);
            sameDiff.setOriginalPlaceHolderShape(varName, shape);
        } else {
            sameDiff.putShapeForVarName(varName, shape);
        }
        this.sameDiff = sameDiff;
    }

    /** Returns true if any entry of {@code shape} is negative (i.e. not yet known). */
    private static boolean hasNegativeDimension(int[] shape) {
        for (int dim : shape) {
            if (dim < 0) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether the owning SameDiff treats this variable as a placeholder. */
    public boolean isPlaceHolder() {
        return this.sameDiff.isPlaceHolder(this.varName);
    }

    @Override
    public String opName() {
        return "variable";
    }

    /** A variable's only output is itself. */
    @Override
    public SDVariable[] outputVariables() {
        return new SDVariable[]{this};
    }

    /** A variable's only argument is itself. */
    @Override
    public SDVariable arg() {
        return this;
    }

    /** A variable's only argument is itself. */
    @Override
    public SDVariable[] args() {
        return new SDVariable[]{this};
    }

    /** Ignores {@code str}; a variable's only output is itself. */
    @Override
    public SDVariable[] outputVariables(String str) {
        return new SDVariable[]{this};
    }

    /** No-op: variables carry no TensorFlow attributes to import. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
    }

    /** No-op: variables carry no ONNX attributes to import. */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
    }

    /**
     * Returns this variable's array if it already matches the registered shape; otherwise creates
     * a fresh one with the weight-init scheme and stores it via {@code putArrayForVarName}.
     *
     * @return an array whose shape equals the registered shape
     * @throws ND4JIllegalStateException if the variable name is null or no shape is registered
     */
    public INDArray storeAndAllocateNewArray() {
        // Validate the name before using it as a lookup key.
        if (this.varName == null) {
            throw new ND4JIllegalStateException("Unable to store array for null variable name!");
        }
        int[] shape = this.sameDiff.getShapeForVarName(getVarName());
        // getArr() may allocate as a side effect, so call it only once.
        INDArray existing = getArr();
        if (existing != null && Arrays.equals(existing.shape(), shape)) {
            return existing;
        }
        if (shape == null) {
            throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + this.varName);
        }
        INDArray created = getWeightInitScheme().create(shape);
        this.sameDiff.putArrayForVarName(getVarName(), created);
        return created;
    }

    /**
     * Returns the array associated with this variable, creating and associating one if needed.
     *
     * <p>If no array exists yet, a shape must be registered. If none is, {@code null} is
     * returned. With a shape, the new array is filled with the scalar value when one is set and
     * the shape has exactly one element; otherwise it comes from the weight-init scheme.
     *
     * @return the associated array, or {@code null} if none exists and no shape is registered
     */
    public INDArray getArr() {
        String name = getVarName();
        if (this.sameDiff.arrayAlreadyExistsForVarName(name)) {
            return this.sameDiff.getArrForVarName(name);
        }
        // Read the shape directly: going through getShape() would call back into getArr()
        // when no shape is registered, recursing without bound for scalar-valued variables.
        int[] shape = this.sameDiff.getShapeForVarName(name);
        if (shape == null) {
            return null;
        }
        INDArray created;
        if (getScalarValue() != null && ArrayUtil.prod(shape) == 1) {
            created = Nd4j.valueArrayOf(shape, getScalarValue().doubleValue());
        } else {
            created = getWeightInitScheme().create(shape);
        }
        this.sameDiff.associateArrayWithVariable(created, this);
        return this.sameDiff.getArrForVarName(name);
    }

    /** Alias for {@link #getGradient()}. */
    public SDVariable gradient() {
        return getGradient();
    }

    /** Returns the gradient variable the owning SameDiff holds for this variable's name. */
    public SDVariable getGradient() {
        return this.sameDiff.getGradForVariable(getVarName());
    }

    /** Always throws: only functions, not variables, can be differentiated. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        throw new ND4JIllegalStateException("Unable to differentiate a variable! Must be a function.");
    }

    /**
     * Returns the registered shape. If none is registered, returns the shape of the existing
     * array, or {@code null} if there is no array either.
     */
    public int[] getShape() {
        int[] shape = this.sameDiff.getShapeForVarName(getVarName());
        if (shape != null) {
            return shape;
        }
        INDArray arr = getArr();
        return arr == null ? null : arr.shape();
    }

    /** Delegates to {@code sameDiff.var(this)}. */
    @Override
    public SDVariable dup() {
        return this.sameDiff.var(this);
    }

    // ---------------------------------------------------------------------------------------
    // Naming helpers shared by the arithmetic operations below.
    // ---------------------------------------------------------------------------------------

    /** Generates a fresh variable name derived from {@code opName}. */
    private String autoName(String opName) {
        return this.sameDiff.generateNewVarName(opName, 0);
    }

    /** Names {@code result} {@code name} via {@code sameDiff.updateVariableNameAndReference}. */
    private SDVariable named(SDVariable result, String name) {
        return this.sameDiff.updateVariableNameAndReference(result, name);
    }

    // ---------------------------------------------------------------------------------------
    // Auto-named operations: the result name is generated from the corresponding op class.
    // ---------------------------------------------------------------------------------------

    /** Reverse subtraction with a scalar (see {@link #rsub(String, double)}), auto-named. */
    public SDVariable rsub(double d) {
        return rsub(autoName(new RSubOp().opName()), d);
    }

    /** Reverse division with a scalar (see {@link #rdiv(String, double)}), auto-named. */
    public SDVariable rdiv(double d) {
        return rdiv(autoName(new RDivOp().opName()), d);
    }

    /** {@code this + d}, auto-named. */
    public SDVariable add(double d) {
        return add(autoName(new AddOp().opName()), d);
    }

    /** {@code this - d}, auto-named. */
    public SDVariable sub(double d) {
        return sub(autoName(new SubOp().opName()), d);
    }

    /** Squared difference with {@code sDVariable}, auto-named. */
    public SDVariable squaredDifference(SDVariable sDVariable) {
        return squaredDifference(autoName(new SquaredDifferenceOp().opName()), sDVariable);
    }

    /** {@code this / d}, auto-named. */
    public SDVariable div(double d) {
        return div(autoName(new DivOp().opName()), d);
    }

    /** {@code this * d}, auto-named. */
    public SDVariable mul(double d) {
        return mul(autoName(new MulOp().opName()), d);
    }

    /** In-place ({@code i}) variant of {@link #rsub(double)}, auto-named. */
    public SDVariable rsubi(double d) {
        return rsubi(autoName(new RSubOp().opName()), d);
    }

    /** In-place ({@code i}) variant of {@link #rdiv(double)}, auto-named. */
    public SDVariable rdivi(double d) {
        return rdivi(autoName(new RDivOp().opName()), d);
    }

    /** In-place ({@code i}) variant of {@link #add(double)}, auto-named. */
    public SDVariable addi(double d) {
        return addi(autoName(new AddOp().opName()), d);
    }

    /** In-place ({@code i}) variant of {@link #sub(double)}, auto-named. */
    public SDVariable subi(double d) {
        return subi(autoName(new SubOp().opName()), d);
    }

    /** In-place ({@code i}) variant of {@link #div(double)}, auto-named. */
    public SDVariable divi(double d) {
        return divi(autoName(new DivOp().opName()), d);
    }

    /** In-place ({@code i}) variant of {@link #mul(double)}, auto-named. */
    public SDVariable muli(double d) {
        return muli(autoName(new MulOp().opName()), d);
    }

    /** Reverse subtraction with {@code sDVariable}, auto-named. */
    public SDVariable rsub(SDVariable sDVariable) {
        return rsub(autoName(new RSubOp().opName()), sDVariable);
    }

    /** Reverse division with {@code sDVariable}, auto-named. */
    public SDVariable rdiv(SDVariable sDVariable) {
        return rdiv(autoName(new RDivOp().opName()), sDVariable);
    }

    /** Truncated division by {@code sDVariable}, auto-named. */
    public SDVariable truncatedDiv(SDVariable sDVariable) {
        return truncatedDiv(autoName(new TruncateDivOp().opName()), sDVariable);
    }

    /** {@code this + sDVariable}, auto-named. */
    public SDVariable add(SDVariable sDVariable) {
        return add(autoName(new AddOp().opName()), sDVariable);
    }

    /** {@code this - sDVariable}, auto-named. */
    public SDVariable sub(SDVariable sDVariable) {
        return sub(autoName(new SubOp().opName()), sDVariable);
    }

    /** {@code this / sDVariable}, auto-named. */
    public SDVariable div(SDVariable sDVariable) {
        return div(autoName(new DivOp().opName()), sDVariable);
    }

    /** {@code this * sDVariable}, auto-named. */
    public SDVariable mul(SDVariable sDVariable) {
        return mul(autoName(new MulOp().opName()), sDVariable);
    }

    /** In-place ({@code i}) variant of {@link #rsub(SDVariable)}, auto-named. */
    public SDVariable rsubi(SDVariable sDVariable) {
        return rsubi(autoName(new RSubOp().opName()), sDVariable);
    }

    /** In-place ({@code i}) variant of {@link #rdiv(SDVariable)}, auto-named. */
    public SDVariable rdivi(SDVariable sDVariable) {
        return rdivi(autoName(new RDivOp().opName()), sDVariable);
    }

    /** In-place ({@code i}) variant of {@link #add(SDVariable)}, auto-named. */
    public SDVariable addi(SDVariable sDVariable) {
        return addi(autoName(new AddOp().opName()), sDVariable);
    }

    /** In-place ({@code i}) variant of {@link #sub(SDVariable)}, auto-named. */
    public SDVariable subi(SDVariable sDVariable) {
        return subi(autoName(new SubOp().opName()), sDVariable);
    }

    /** In-place ({@code i}) variant of {@link #div(SDVariable)}, auto-named. */
    public SDVariable divi(SDVariable sDVariable) {
        return divi(autoName(new DivOp().opName()), sDVariable);
    }

    /** In-place ({@code i}) variant of {@link #mul(SDVariable)}, auto-named. */
    public SDVariable muli(SDVariable sDVariable) {
        return muli(autoName(new MulOp().opName()), sDVariable);
    }

    // ---------------------------------------------------------------------------------------
    // Explicitly named scalar operations: result is created by sameDiff.f() and named str.
    // ---------------------------------------------------------------------------------------

    /** Applies {@code sameDiff.f().rsub(this, d)} and names the result {@code str}. */
    public SDVariable rsub(String str, double d) {
        return named(this.sameDiff.f().rsub(this, d), str);
    }

    /** Applies {@code sameDiff.f().rdiv(this, d)} and names the result {@code str}. */
    public SDVariable rdiv(String str, double d) {
        return named(this.sameDiff.f().rdiv(this, d), str);
    }

    /** Applies {@code sameDiff.f().truncatedDiv(this, sDVariable)} and names the result {@code str}. */
    public SDVariable truncatedDiv(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().truncatedDiv(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().add(this, d)} and names the result {@code str}. */
    public SDVariable add(String str, double d) {
        return named(this.sameDiff.f().add(this, d), str);
    }

    /** Applies {@code sameDiff.f().sub(this, d)} and names the result {@code str}. */
    public SDVariable sub(String str, double d) {
        return named(this.sameDiff.f().sub(this, d), str);
    }

    /** Applies {@code sameDiff.f().div(this, d)} and names the result {@code str}. */
    public SDVariable div(String str, double d) {
        return named(this.sameDiff.f().div(this, d), str);
    }

    /** Applies {@code sameDiff.f().mul(this, d)} and names the result {@code str}. */
    public SDVariable mul(String str, double d) {
        return named(this.sameDiff.f().mul(this, d), str);
    }

    /** Applies {@code sameDiff.f().rsubi(this, d)} and names the result {@code str}. */
    public SDVariable rsubi(String str, double d) {
        return named(this.sameDiff.f().rsubi(this, d), str);
    }

    /** Applies {@code sameDiff.f().rdivi(this, d)} and names the result {@code str}. */
    public SDVariable rdivi(String str, double d) {
        return named(this.sameDiff.f().rdivi(this, d), str);
    }

    /** Applies {@code sameDiff.f().addi(this, d)} and names the result {@code str}. */
    public SDVariable addi(String str, double d) {
        return named(this.sameDiff.f().addi(this, d), str);
    }

    /** Applies {@code sameDiff.f().subi(this, d)} and names the result {@code str}. */
    public SDVariable subi(String str, double d) {
        return named(this.sameDiff.f().subi(this, d), str);
    }

    /** Applies {@code sameDiff.f().divi(this, d)} and names the result {@code str}. */
    public SDVariable divi(String str, double d) {
        return named(this.sameDiff.f().divi(this, d), str);
    }

    /** Applies {@code sameDiff.f().muli(this, d)} and names the result {@code str}. */
    public SDVariable muli(String str, double d) {
        return named(this.sameDiff.f().muli(this, d), str);
    }

    // ---------------------------------------------------------------------------------------
    // Explicitly named variable-variable operations.
    // ---------------------------------------------------------------------------------------

    /** Applies {@code sameDiff.f().rsub(this, sDVariable)} and names the result {@code str}. */
    public SDVariable rsub(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().rsub(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().rdiv(this, sDVariable)} and names the result {@code str}. */
    public SDVariable rdiv(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().rdiv(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().add(this, sDVariable)} and names the result {@code str}. */
    public SDVariable add(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().add(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().sub(this, sDVariable)} and names the result {@code str}. */
    public SDVariable sub(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().sub(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().squaredDifference(this, sDVariable)} and names the result {@code str}. */
    public SDVariable squaredDifference(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().squaredDifference(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().div(this, sDVariable)} and names the result {@code str}. */
    public SDVariable div(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().div(this, sDVariable), str);
    }

    /**
     * Applies {@code sameDiff.f().mul(this, sDVariable)} and names the result {@code str}.
     *
     * @throws IllegalStateException if {@code sDVariable} is null
     */
    public SDVariable mul(String str, SDVariable sDVariable) {
        // 'this' can never be null, so only the right-hand input needs checking.
        Preconditions.checkState(sDVariable != null, "Right input is null!");
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().mul(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().rsubi(this, sDVariable)} and names the result {@code str}. */
    public SDVariable rsubi(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().rsubi(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().rdivi(this, sDVariable)} and names the result {@code str}. */
    public SDVariable rdivi(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().rdivi(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().addi(this, sDVariable)} and names the result {@code str}. */
    public SDVariable addi(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().addi(this, sDVariable), str);
    }

    /** Applies {@code sameDiff.f().subi(this, sDVariable)} and names the result {@code str}. */
    public SDVariable subi(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().subi(this, sDVariable), str);
    }

    /**
     * Applies {@code sameDiff.f().divi(this, sDVariable)} and names the result {@code str}.
     * Renaming goes through {@code updateVariableNameAndReference} like every other named op,
     * rather than only setting the name on the result object.
     */
    public SDVariable divi(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().divi(this, sDVariable), str);
    }

    /**
     * Applies {@code sameDiff.f().muli(this, sDVariable)} and names the result {@code str}.
     * Renaming goes through {@code updateVariableNameAndReference} like every other named op,
     * rather than only setting the name on the result object.
     */
    public SDVariable muli(String str, SDVariable sDVariable) {
        assertShapeEquals(sDVariable);
        return named(this.sameDiff.f().muli(this, sDVariable), str);
    }

    @Override
    public Op.Type opType() {
        return Op.Type.RETURN;
    }

    /**
     * Evaluates this variable: duplicates the owning graph, defines a function named
     * {@code "output"} on the copy that returns this variable, invokes it, and returns the result
     * of {@code execAndEndResult()}.
     */
    public INDArray eval() {
        SameDiff copy = this.sameDiff.dup();
        copy.defineFunction("output", new SameDiff.SameDiffFunctionDefinition() {
            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> map, SDVariable[] sDVariableArr) {
                return new SDVariable[]{SDVariable.this};
            }
        });
        return copy.invokeFunctionOn("output", copy).getSameDiff().execAndEndResult();
    }

    /** Placeholder for an operand shape check; currently performs no validation. */
    private void assertShapeEquals(SDVariable sDVariable) {
    }

    /** Returns the variable name. */
    @Override
    public String toString() {
        return this.varName;
    }

    /** Equal when the superclass considers them equal and name and weight-init scheme match. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass() || !super.equals(obj)) {
            return false;
        }
        SDVariable other = (SDVariable) obj;
        boolean sameName = this.varName == null ? other.varName == null : this.varName.equals(other.varName);
        if (!sameName) {
            return false;
        }
        return this.weightInitScheme == null
                ? other.weightInitScheme == null
                : this.weightInitScheme.equals(other.weightInitScheme);
    }

    /** Combines the superclass hash with the name and weight-init scheme hashes. */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + (this.varName != null ? this.varName.hashCode() : 0);
        result = 31 * result + (this.weightInitScheme != null ? this.weightInitScheme.hashCode() : 0);
        return result;
    }

    /** Always throws: variables have no ONNX op mapping. */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
    }

    /** Always throws: variables have no TensorFlow op mapping. */
    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
    }

    /** Returns a new, empty builder. */
    public static SDVariableBuilder builder() {
        return new SDVariableBuilder();
    }

    /** No-arg constructor; leaves all fields unset (no SameDiff registration happens). */
    public SDVariable() {
    }

    public String getVarName() {
        return this.varName;
    }

    /**
     * Sets the name field only. Unlike the named arithmetic ops, this does not call
     * {@code sameDiff.updateVariableNameAndReference}.
     */
    public void setVarName(String str) {
        this.varName = str;
    }

    public WeightInitScheme getWeightInitScheme() {
        return this.weightInitScheme;
    }

    public void setWeightInitScheme(WeightInitScheme weightInitScheme) {
        this.weightInitScheme = weightInitScheme;
    }
}
