/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops;

import java.nio.Buffer;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.ShapeOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public abstract class BaseOp
extends DifferentialFunction
implements Op {
    protected INDArray x;
    protected INDArray y;
    protected INDArray z;
    protected long n;
    protected long numProcessed;
    protected Object[] extraArgs;
    protected boolean passThrough;
    protected String xVertexId;
    protected String yVertexId;
    protected String zVertexId;
    protected DataBuffer extraArgz;

    public BaseOp() {
    }

    public BaseOp(SameDiff sameDiff, boolean inPlace, Object[] extraArgs) {
        super(sameDiff, inPlace, extraArgs);
    }

    public BaseOp(SameDiff sameDiff, Object[] extraArgs) {
        super(sameDiff, extraArgs);
    }

    @Override
    public boolean isExecSpecial() {
        return false;
    }

    public static Op.Type getOpType(Op op) {
        Op.Type type = null;
        if (op instanceof CustomOp) {
            return Op.Type.CUSTOM;
        }
        if (op instanceof ShapeOp) {
            return Op.Type.SHAPE;
        }
        if (op instanceof TransformOp) {
            type = op.y() == null ? (!op.isExecSpecial() ? Op.Type.TRANSFORM : Op.Type.SPECIAL) : Op.Type.PAIRWISE;
        } else if (op instanceof Accumulation) {
            type = op.y() == null ? Op.Type.REDUCE : Op.Type.REDUCE3;
        } else if (op instanceof ScalarOp) {
            type = Op.Type.SCALAR;
        } else if (op instanceof BroadcastOp) {
            type = Op.Type.BROADCAST;
        } else if (op instanceof IndexAccumulation) {
            type = Op.Type.INDEXREDUCE;
        } else if (op instanceof MetaOp) {
            type = Op.Type.META;
        } else if (op instanceof GridOp) {
            type = Op.Type.GRID;
        }
        return type;
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    @Override
    public DataBuffer extraArgsDataBuff() {
        if (this.extraArgz != null) {
            return this.extraArgz;
        }
        if (this.extraArgs != null) {
            DataBuffer.Type dtype;
            DataBuffer.Type type = dtype = this.x != null ? this.x.data().dataType() : Nd4j.dataType();
            if (dtype == DataBuffer.Type.FLOAT || dtype == DataBuffer.Type.HALF) {
                float[] extraz = new float[this.extraArgs.length];
                for (int i = 0; i < this.extraArgs.length; ++i) {
                    float val;
                    Number arg = (Number)this.extraArgs[i];
                    extraz[i] = val = arg.floatValue();
                }
                this.extraArgz = Nd4j.getConstantHandler().getConstantBuffer(extraz);
                return this.extraArgz;
            }
            if (dtype == DataBuffer.Type.DOUBLE) {
                double[] extraz = new double[this.extraArgs.length];
                for (int i = 0; i < this.extraArgs.length; ++i) {
                    double val;
                    if (!(this.extraArgs[i] instanceof Number)) continue;
                    Number arg = (Number)this.extraArgs[i];
                    if (arg == null) {
                        arg = 0.0;
                    }
                    extraz[i] = val = arg.doubleValue();
                }
                this.extraArgz = Nd4j.getConstantHandler().getConstantBuffer(extraz);
                return this.extraArgz;
            }
        }
        return null;
    }

    @Override
    public Buffer extraArgsBuff() {
        if (this.extraArgs != null) {
            if (this.x.data().dataType() == DataBuffer.Type.FLOAT) {
                DataBuffer retBuff = Nd4j.createBuffer(new float[this.extraArgs.length]);
                for (int i = 0; i < this.extraArgs.length; ++i) {
                    Number val = (Number)this.extraArgs[i];
                    retBuff.put((long)i, val.floatValue());
                }
                return retBuff.asNioFloat();
            }
            DataBuffer retBuff = Nd4j.createBuffer(new double[this.extraArgs.length]);
            for (int i = 0; i < this.extraArgs.length; ++i) {
                Number val = (Number)this.extraArgs[i];
                retBuff.put((long)i, val.doubleValue());
            }
            return retBuff.asNioDouble();
        }
        return null;
    }

    @Override
    public boolean isPassThrough() {
        return this.passThrough;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    public void setX(INDArray x) {
        if (x == null) {
            SDVariable sdVariable;
            if (this.args() == null || this.args().length < 1) throw new ND4JIllegalStateException("Unable to set null array for x. Also unable to infer from differential function arguments");
            SDVariable firstArg = this.args()[0];
            if (firstArg instanceof SDVariable && (sdVariable = firstArg).getArr() != null) {
                this.x = sdVariable.getArr();
            }
        } else {
            this.x = x;
        }
        this.numProcessed = 0L;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    public void setZ(INDArray z) {
        if (z == null) {
            SDVariable getResult = this.sameDiff.getVariable(this.zVertexId);
            if (getResult == null) throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
            if (getResult.getArr() != null) {
                this.z = getResult.getArr();
            } else {
                if (this.sameDiff.getShapeForVarName(getResult.getVarName()) == null) throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
                int[] shape = this.sameDiff.getShapeForVarName(getResult.getVarName());
                this.sameDiff.putArrayForVarName(getResult.getVarName(), getResult.getWeightInitScheme().create(shape));
            }
        } else {
            this.z = z;
        }
        this.numProcessed = 0L;
    }

    /*
     * Enabled force condition propagation
     * Lifted jumps to return sites
     */
    @Override
    public void setY(INDArray y) {
        if (y == null) {
            SDVariable sdVariable;
            if (this.args() == null || this.args().length <= 1) throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments");
            SDVariable firstArg = this.args()[1];
            if (firstArg instanceof SDVariable && (sdVariable = firstArg).getArr() != null) {
                this.y = sdVariable.getArr();
            }
        } else {
            this.y = y;
        }
        this.numProcessed = 0L;
    }

    public BaseOp(INDArray x, INDArray z) {
        this(x, z, x.lengthLong());
    }

    public BaseOp(INDArray x, INDArray z, long n) {
        this(x, null, z, n);
    }

    public BaseOp(INDArray x, INDArray y, INDArray z, long n) {
        this.init(x, y, z, n);
    }

    public BaseOp(INDArray x) {
        this(x, null, x, x == null ? 0L : x.lengthLong());
    }

    @Override
    public Object[] extraArgs() {
        return this.extraArgs;
    }

    @Override
    public INDArray x() {
        if (this.x == null && this.sameDiff != null) {
            this.x = this.sameDiff.getArrForVarName(this.args()[0].getVarName());
            if (this.x == null && this.args()[0].getShape() != null) {
                this.x = this.args()[0].storeAndAllocateNewArray();
            }
        }
        return this.x;
    }

    @Override
    public INDArray y() {
        if (this.y == null && this.sameDiff != null && this.args().length > 1) {
            this.y = this.sameDiff.getArrForVarName(this.args()[1].getVarName());
            if (this.y == null && this.args()[1].getShape() != null) {
                this.y = this.args()[1].storeAndAllocateNewArray();
            }
        }
        return this.y;
    }

    @Override
    public INDArray z() {
        if (this.z == null) {
            if (this.sameDiff != null) {
                SDVariable var;
                this.z = this.outputVariables()[0].getArr();
                if (this.z == null && (var = this.outputVariables()[0]).getShape() != null) {
                    this.z = var.storeAndAllocateNewArray();
                }
            }
        } else if (this.zVertexId != null && this.sameDiff != null && this.sameDiff.getArrForVarName(this.zVertexId) == null && this.z != null) {
            this.sameDiff.putArrayForVarName(this.zVertexId, this.z);
        }
        return this.z;
    }

    @Override
    public SDVariable[] outputVariables(String baseName) {
        if (this.zVertexId == null) {
            List<int[]> shapes;
            String[] outputNames = this.sameDiff.getOutputsForFunction(this);
            if (outputNames != null) {
                this.zVertexId = this.sameDiff.getVariable(outputNames[0]).getVarName();
                return new SDVariable[]{this.sameDiff.getVariable(outputNames[0])};
            }
            if (this.isInPlace()) {
                SDVariable[] newVars = this.sameDiff.generateOutputVariableForOp(this, null);
                INDArray inputArr = this.x();
                if (inputArr == null) {
                    return newVars;
                }
                this.sameDiff.putArrayForVarName(newVars[0].getVarName(), inputArr);
                this.z = inputArr;
                if (this.sameDiff.getOutputsForFunction(this) == null) {
                    this.sameDiff.addOutgoingFor(newVars, (DifferentialFunction)this);
                }
                return newVars;
            }
            SDVariable[] newVars = this.sameDiff.generateOutputVariableForOp(this, null);
            INDArray arr = null;
            arr = newVars == null || newVars.length < 1 || newVars[0].getShape() == null ? null : (newVars[0].getArr() == null ? newVars[0].storeAndAllocateNewArray() : newVars[0].getArr());
            if (arr == null && (shapes = this.calculateOutputShape()) != null && !shapes.isEmpty() && shapes.get(0) != null) {
                this.sameDiff.putShapeForVarName(newVars[0].getVarName(), shapes.get(0));
                arr = newVars[0].storeAndAllocateNewArray();
            }
            this.z = arr;
            if (this.sameDiff.getOutputsForFunction(this) == null) {
                this.sameDiff.addOutgoingFor(newVars, (DifferentialFunction)this);
            }
            return newVars;
        }
        return new SDVariable[]{this.sameDiff.getVariable(this.zVertexId)};
    }

    @Override
    public long n() {
        if (this.n == 0L && this.arg() != null) {
            this.n = ArrayUtil.prod((int[])this.arg().getShape());
        }
        return this.n;
    }

    @Override
    public void init(INDArray x, INDArray y, INDArray z, long n) {
        this.x = x;
        this.y = y;
        this.z = z;
        this.n = n;
    }

    @Override
    public void setN(long n) {
        this.n = n;
    }

    @Override
    public long numProcessed() {
        return this.numProcessed;
    }

    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public CustomOp toCustomOp() {
        DynamicCustomOp.DynamicCustomOpsBuilder customOpBuilder = DynamicCustomOp.builder(this.opName());
        customOpBuilder.callInplace(this.x() == this.z());
        if (this.y() != null) {
            customOpBuilder.addInputs(this.x(), this.y());
        } else {
            customOpBuilder.addInputs(this.x());
        }
        customOpBuilder.addOutputs(this.z());
        if (this.extraArgs != null) {
            for (int i = 0; i < this.extraArgs.length; ++i) {
                if (this.extraArgs[i] instanceof Integer) {
                    customOpBuilder.addIntegerArguments((int)((Integer)this.extraArgs[i]));
                    continue;
                }
                if (!(this.extraArgs[i] instanceof Double) && !(this.extraArgs[i] instanceof Float)) continue;
                Double num = (Double)this.extraArgs[i];
                customOpBuilder.addFloatingPointArguments(num);
            }
        }
        return customOpBuilder.build();
    }

    @Override
    public void exec() {
    }

    @Override
    public void exec(int ... dimensions) {
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BaseOp baseOp = (BaseOp)o;
        if (this.n != baseOp.n) {
            return false;
        }
        if (this.numProcessed != baseOp.numProcessed) {
            return false;
        }
        if (this.passThrough != baseOp.passThrough) {
            return false;
        }
        if (this.x != null ? !this.x.equals(baseOp.x) : baseOp.x != null) {
            return false;
        }
        if (this.y != null ? !this.y.equals(baseOp.y) : baseOp.y != null) {
            return false;
        }
        if (this.z != null ? !this.z.equals(baseOp.z) : baseOp.z != null) {
            return false;
        }
        if (!Arrays.equals(this.extraArgs, baseOp.extraArgs)) {
            return false;
        }
        return this.extraArgz != null ? this.extraArgz.equals(baseOp.extraArgz) : baseOp.extraArgz == null;
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + (this.x != null ? this.x.hashCode() : 0);
        result = 31 * result + (this.y != null ? this.y.hashCode() : 0);
        result = 31 * result + (this.z != null ? this.z.hashCode() : 0);
        result = 31 * result + (int)(this.n ^ this.n >>> 32);
        result = 31 * result + (int)(this.numProcessed ^ this.numProcessed >>> 32);
        result = 31 * result + Arrays.hashCode(this.extraArgs);
        result = 31 * result + (this.passThrough ? 1 : 0);
        result = 31 * result + (this.extraArgz != null ? this.extraArgz.hashCode() : 0);
        return result;
    }

    public INDArray getX() {
        return this.x;
    }

    public INDArray getY() {
        return this.y;
    }

    public INDArray getZ() {
        return this.z;
    }

    public long getN() {
        return this.n;
    }

    public long getNumProcessed() {
        return this.numProcessed;
    }

    @Override
    public Object[] getExtraArgs() {
        return this.extraArgs;
    }

    public DataBuffer getExtraArgz() {
        return this.extraArgz;
    }

    public void setNumProcessed(long numProcessed) {
        this.numProcessed = numProcessed;
    }

    @Override
    public void setExtraArgs(Object[] extraArgs) {
        this.extraArgs = extraArgs;
    }

    public void setPassThrough(boolean passThrough) {
        this.passThrough = passThrough;
    }

    public void setExtraArgz(DataBuffer extraArgz) {
        this.extraArgz = extraArgz;
    }

    public String getXVertexId() {
        return this.xVertexId;
    }

    public String getYVertexId() {
        return this.yVertexId;
    }

    public String getZVertexId() {
        return this.zVertexId;
    }

    public void setXVertexId(String xVertexId) {
        this.xVertexId = xVertexId;
    }

    public void setYVertexId(String yVertexId) {
        this.yVertexId = yVertexId;
    }

    public void setZVertexId(String zVertexId) {
        this.zVertexId = zVertexId;
    }
}

