package org.nd4j.linalg.api.ops.impl.shape;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/shape/Gather.class */
/**
 * Dynamic custom op registered under the native op name {@code "gather"}.
 *
 * <p>The op carries an integer {@code axis} plus a set of indices, which are supplied
 * either as a fixed {@code int[]} or as a second {@link SDVariable} input. It is mapped
 * from the TensorFlow ops {@code Gather} / {@code GatherV2} and the ONNX op
 * {@code Gather}.
 */
public class Gather extends DynamicCustomOp {
    /**
     * Fixed indices. Within this class they are only assigned by the {@code int[]}
     * constructor; the import mappings in {@link #mappingsForFunction()} also name this
     * property.
     */
    protected int[] indices;

    /** Axis argument of the op; {@link #doDiff(List)} treats negative values as offsets from the end. */
    protected int axis = 0;

    /** Creates an unconfigured op with {@code axis == 0} and no indices. */
    public Gather() {
    }

    /**
     * Creates a gather op with a fixed set of indices.
     *
     * @param sameDiff   the SameDiff instance this op belongs to
     * @param sDVariable the single input variable passed to the superclass
     * @param iArr       the indices; stored by reference (not copied) and also appended to the
     *                   integer arguments after the axis
     * @param i          the axis; registered as the first integer argument
     * @param z          flag forwarded unchanged to the superclass constructor
     *                   (presumably the in-place flag - confirm against DynamicCustomOp)
     */
    public Gather(SameDiff sameDiff, SDVariable sDVariable, int[] iArr, int i, boolean z) {
        super(null, sameDiff, new SDVariable[]{sDVariable}, z);
        addIArgument(i);
        addIArgument(iArr);
        this.axis = i;
        this.indices = iArr;
    }

    /**
     * Creates a gather op whose indices are provided by a second input variable.
     *
     * @param sameDiff    the SameDiff instance this op belongs to
     * @param sDVariable  the first input variable
     * @param sDVariable2 the second input variable (the indices)
     * @param i           the axis; registered as the only integer argument
     * @param z           flag forwarded unchanged to the superclass constructor
     *                    (presumably the in-place flag - confirm against DynamicCustomOp)
     */
    public Gather(SameDiff sameDiff, SDVariable sDVariable, SDVariable sDVariable2, int i, boolean z) {
        super(null, sameDiff, new SDVariable[]{sDVariable, sDVariable2}, z);
        addIArgument(i);
        this.axis = i;
    }

    /** @return the ONNX op type handled by this class, {@code "Gather"} */
    @Override
    public String onnxName() {
        return "Gather";
    }

    /** @return the TensorFlow op names handled by this class, {@code "Gather"} and {@code "GatherV2"} */
    @Override
    public String[] tensorflowNames() {
        return new String[]{"Gather", "GatherV2"};
    }

    /**
     * Initialises this op from a TensorFlow node by delegating to
     * {@link TFGraphMapper#initFunctionFromProperties}, keyed by the node's op name.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, map, nodeDef, graphDef);
    }

    /**
     * Initialises this op from an ONNX node by delegating to
     * {@link OnnxGraphMapper#initFunctionFromProperties}, keyed by the node's op type.
     */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto nodeProto, SameDiff sameDiff, Map<String, OnnxProto3.AttributeProto> map, OnnxProto3.GraphProto graphProto) {
        OnnxGraphMapper.getInstance().initFunctionFromProperties(nodeProto.getOpType(), this, map, nodeProto, graphProto);
    }

    /**
     * Fills in any input arrays, integer arguments and output arrays that are still
     * missing before the op is executed.
     *
     * <ul>
     *   <li>If fixed {@link #indices} exist and fewer than two input arrays are registered,
     *       the indices are converted to a float array and added as an input (together with
     *       the first argument's array when no input has been registered yet).</li>
     *   <li>If no integer argument is registered, {@link #axis} is added.</li>
     *   <li>If fewer output arrays are registered than the descriptor requires, the arrays
     *       of all output variables are added - but only when every one of them is
     *       available; otherwise nothing is added.</li>
     * </ul>
     */
    @Override
    public void resolvePropertiesFromSameDiffBeforeExecution() {
        super.resolvePropertiesFromSameDiffBeforeExecution();

        if (indices != null) {
            int registeredInputs = numInputArguments();
            if (registeredInputs == 0) {
                INDArray indexArray = Nd4j.create(ArrayUtil.toFloats(indices));
                INDArray inputArray = args()[0].getArr();
                // Several indices become a rank-1 array; otherwise the array is reshaped with an
                // empty shape.
                INDArray shapedIndices = indices.length > 1
                        ? indexArray.reshape(indices.length)
                        : indexArray.reshape(new int[0]);
                addInputArgument(inputArray, shapedIndices);
            } else if (registeredInputs == 1) {
                // NOTE(review): unlike the branch above, the indices are not reshaped here -
                // confirm that this asymmetry is intended.
                addInputArgument(Nd4j.create(ArrayUtil.toFloats(indices)));
            }
        }

        if (numIArguments() < 1) {
            addIArgument(axis);
        }

        if (numOutputArguments() < getDescriptor().getNumOutputs()) {
            SDVariable[] outputs = outputVariables();
            // Register outputs all-or-nothing: bail out if any array is not yet available.
            for (SDVariable output : outputs) {
                if (output.getArr() == null) {
                    return;
                }
            }
            for (SDVariable output : outputs) {
                addOutputArgument(output.getArr());
            }
        }
    }

    /**
     * Describes how framework properties map onto the fields of this op.
     *
     * <p>The first TensorFlow name and the ONNX name share one mapping that takes
     * {@code indices} from TensorFlow input position 1 or the ONNX attribute
     * {@code "indices"}. {@code "GatherV2"} additionally takes {@code axis} from
     * TensorFlow input position 2.
     *
     * @return a new map from framework op name to its property mappings
     */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, Map<String, PropertyMapping>> mappingsByOpName = new HashMap<>();

        Map<String, PropertyMapping> gatherMappings = new HashMap<>();
        gatherMappings.put("indices", PropertyMapping.builder()
                .onnxAttrName("indices")
                .tfInputPosition(1)
                .propertyNames(new String[]{"indices"})
                .build());
        mappingsByOpName.put(tensorflowNames()[0], gatherMappings);
        mappingsByOpName.put(onnxName(), gatherMappings);

        Map<String, PropertyMapping> gatherV2Mappings = new HashMap<>();
        gatherV2Mappings.put("indices", PropertyMapping.builder()
                .tfInputPosition(1)
                .propertyNames(new String[]{"indices"})
                .build());
        gatherV2Mappings.put("axis", PropertyMapping.builder()
                .tfInputPosition(2)
                .propertyNames(new String[]{"axis"})
                .build());
        mappingsByOpName.put("GatherV2", gatherV2Mappings);

        return mappingsByOpName;
    }

    /** @return the native op name, {@code "gather"} */
    @Override
    public String opName() {
        return "gather";
    }

    /**
     * Builds the gradients for both arguments of this op.
     *
     * <p>The gradient for the first argument is a scatter-add of the incoming gradient
     * into zeros shaped like that argument, using the second argument as indices. For a
     * non-zero axis both tensors are permuted so that the axis comes first, scattered,
     * and then permuted back. The gradient for the second argument (the indices) is all
     * zeros.
     *
     * <p>NOTE(review): this reads {@code arg(1)}, so it appears to require the indices to
     * be present as a second argument; the {@code int[]} constructor passes only one
     * variable to the superclass - confirm how that case is handled. The non-zero-axis
     * path also permutes the incoming gradient with a permutation sized by the input
     * rank, which assumes the gradient has the same rank as the input - TODO confirm.
     *
     * @param list incoming gradients; only element 0 is used
     * @return the gradients for the first argument and for the indices, in that order
     * @throws IllegalArgumentException if the axis is non-zero after normalisation and
     *                                  lies outside {@code [0, rank)}
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> list) {
        SDVariable indicesGrad = sameDiff.zerosLike(arg(1));
        SDVariable zeroInputGrad = sameDiff.zerosLike(arg(0));

        int rank = arg(0).getShape().length;
        int resolvedAxis = axis < 0 ? axis + rank : axis;

        SDVariable inputGrad;
        if (resolvedAxis == 0) {
            inputGrad = sameDiff.scatterAdd(zeroInputGrad, arg(1), list.get(0));
        } else {
            // Without this check an axis of -(rank + 1) would silently be handled as axis 0,
            // and other out-of-range values would fail with an opaque index error.
            if (resolvedAxis < 0 || resolvedAxis >= rank) {
                throw new IllegalArgumentException(
                        "Invalid axis " + axis + " for input of rank " + rank);
            }
            int[] axisFirst = axisFirstPermutation(resolvedAxis, rank);
            SDVariable permutedZeros = sameDiff.permute(zeroInputGrad, axisFirst);
            SDVariable permutedGrad = sameDiff.permute(list.get(0), axisFirst);
            SDVariable scattered = sameDiff.scatterAdd(permutedZeros, arg(1), permutedGrad);
            inputGrad = sameDiff.permute(scattered, invertPermutation(axisFirst));
        }
        return Arrays.asList(inputGrad, indicesGrad);
    }

    /** Returns the permutation {@code [axis, 0, ..., axis-1, axis+1, ..., rank-1]}. */
    private static int[] axisFirstPermutation(int axis, int rank) {
        int[] permutation = new int[rank];
        permutation[0] = axis;
        for (int dim = 0; dim < axis; dim++) {
            permutation[dim + 1] = dim;
        }
        for (int dim = axis + 1; dim < rank; dim++) {
            permutation[dim] = dim;
        }
        return permutation;
    }

    /** Returns the inverse of the given permutation, i.e. {@code inverse[permutation[k]] == k}. */
    private static int[] invertPermutation(int[] permutation) {
        int[] inverse = new int[permutation.length];
        for (int position = 0; position < permutation.length; position++) {
            inverse[permutation[position]] = position;
        }
        return inverse;
    }
}
