package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig;
import org.nd4j.linalg.util.ArrayUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Two-dimensional transposed convolution ("deconvolution") op, executed by the native
 * {@code deconv2d} custom op.
 *
 * <p>Hyper-parameters are held in a {@link DeConv2DConfig} and passed to the native op as ten
 * integer arguments, in this order: kH, kW, sH, sW, pH, pW, dH, dW, same-mode flag (0/1) and
 * data format (0 for NCHW, 1 otherwise).
 */
public class DeConv2D extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(DeConv2D.class);

    /** Convolution hyper-parameters; {@code null} for an op created with the no-arg constructor. */
    protected DeConv2DConfig config;

    /** Fluent builder for {@link DeConv2D}; every unset field is passed to the constructor as {@code null}. */
    public static class DeConv2DBuilder {
        private SameDiff sameDiff;
        private SDVariable[] inputs;
        private INDArray[] inputArrays;
        private INDArray[] outputs;
        private DeConv2DConfig config;

        DeConv2DBuilder() {
        }

        public DeConv2DBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public DeConv2DBuilder inputs(SDVariable[] sDVariableArr) {
            this.inputs = sDVariableArr;
            return this;
        }

        public DeConv2DBuilder inputArrays(INDArray[] iNDArrayArr) {
            this.inputArrays = iNDArrayArr;
            return this;
        }

        public DeConv2DBuilder outputs(INDArray[] iNDArrayArr) {
            this.outputs = iNDArrayArr;
            return this;
        }

        public DeConv2DBuilder config(DeConv2DConfig deConv2DConfig) {
            this.config = deConv2DConfig;
            return this;
        }

        /** Creates the op from the values collected so far. */
        public DeConv2D build() {
            return new DeConv2D(this.sameDiff, this.inputs, this.inputArrays, this.outputs, this.config);
        }

        public String toString() {
            return "DeConv2D.DeConv2DBuilder(sameDiff=" + this.sameDiff + ", inputs=" + Arrays.deepToString(this.inputs) + ", inputArrays=" + Arrays.deepToString(this.inputArrays) + ", outputs=" + Arrays.deepToString(this.outputs) + ", config=" + this.config + ")";
        }
    }

    /**
     * Creates the op, records its integer arguments and, when a graph is supplied, registers the
     * op with that graph.
     *
     * @param sameDiff    graph to register with; may be {@code null} for array-only execution
     * @param inputs      graph variables used as op inputs (only used when {@code sameDiff} is non-null)
     * @param inputArrays input arrays, or {@code null}
     * @param outputs     output arrays, or {@code null}
     * @param config      convolution hyper-parameters; must not be {@code null}
     * @throws NullPointerException if {@code config} is {@code null}
     */
    public DeConv2D(SameDiff sameDiff, SDVariable[] inputs, INDArray[] inputArrays, INDArray[] outputs, DeConv2DConfig config) {
        super((String) null, inputArrays, outputs);
        this.sameDiff = sameDiff;
        this.config = config;
        // NOTE(review): the same arrays are also handed to the superclass constructor above;
        // confirm it does not register them itself, otherwise they end up added twice.
        if (inputArrays != null) {
            addInputArgument(inputArrays);
        }
        if (outputs != null) {
            addOutputArgument(outputs);
        }
        addArgs();
        // Without a graph there is nothing to register with; previously this threw an NPE.
        if (sameDiff != null) {
            sameDiff.putFunctionForId(getOwnName(), this);
            sameDiff.addArgsFor(inputs, this);
        }
    }

    /**
     * Returns the integer arguments, building them from {@link #config} first if none have been
     * recorded yet and a config is available.
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.linalg.api.ops.CustomOp
    public long[] iArgs() {
        if (this.iArguments.isEmpty() && this.config != null) {
            addArgs();
        }
        return super.iArgs();
    }

    /**
     * Returns the op's properties. If no config is set but integer arguments exist, the config is
     * first reconstructed from them (reverse of {@code addArgs()}).
     *
     * @return the config's properties, or an empty map when neither a config nor integer
     *     arguments are available
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Object> propertiesForFunction() {
        if (this.config == null && !this.iArguments.isEmpty()) {
            this.config = DeConv2DConfig.builder()
                    .kH(intArgAsLong(0))
                    .kW(intArgAsLong(1))
                    .sH(intArgAsLong(2))
                    .sW(intArgAsLong(3))
                    .pH(intArgAsLong(4))
                    .pW(intArgAsLong(5))
                    .dH(intArgAsLong(6))
                    .dW(intArgAsLong(7))
                    .isSameMode(intArgAsLong(8) == 1)
                    .dataFormat(intArgAsLong(9) == 1 ? "NHWC" : "NCHW")
                    .build();
        }
        if (this.config == null) {
            // Nothing to describe yet; avoid the NullPointerException the old code threw here.
            return new HashMap<>();
        }
        return this.config.toProperties();
    }

    /** Returns recorded integer argument {@code index} widened to {@code long}. */
    private long intArgAsLong(int index) {
        return this.iArguments.get(index).longValue();
    }

    /**
     * Appends the ten integer arguments derived from {@link #config}, in the order documented on
     * the class. The data format is encoded as 0 for NCHW and 1 otherwise, matching the decoding
     * in {@link #propertiesForFunction()}.
     */
    private void addArgs() {
        addIArgument(this.config.getKH());
        addIArgument(this.config.getKW());
        addIArgument(this.config.getSH());
        addIArgument(this.config.getSW());
        addIArgument(this.config.getPH());
        addIArgument(this.config.getPW());
        addIArgument(this.config.getDH());
        addIArgument(this.config.getDW());
        addIArgument(ArrayUtil.fromBoolean(this.config.isSameMode()));
        addIArgument(this.config.getDataFormat().equalsIgnoreCase("NCHW") ? 0 : 1);
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public boolean isConfigProperties() {
        return true;
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public String configFieldName() {
        return "config";
    }

    /** Reads {@code field} from the config, creating a default config first if none exists. */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Object getValue(Field field) {
        if (this.config == null) {
            this.config = DeConv2DConfig.builder().build();
        }
        return this.config.getValue(field);
    }

    /**
     * Describes how TensorFlow/ONNX attributes map onto config properties. The same per-property
     * mapping table is registered under both the ONNX and the TensorFlow op name.
     */
    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        PropertyMapping strides = PropertyMapping.builder()
                .tfAttrName("strides")
                .onnxAttrName("strides")
                .build();
        PropertyMapping kernel = PropertyMapping.builder()
                .propertyNames(new String[]{"kH", "kW"})
                .tfInputPosition(1)
                .onnxAttrName("kernel_shape")
                .build();
        // NOTE(review): property order here is (dW, dH) while the kernel mapping uses (kH, kW);
        // confirm which order the mapper expects.
        PropertyMapping dilation = PropertyMapping.builder()
                .onnxAttrName("dilations")
                .propertyNames(new String[]{"dW", "dH"})
                .tfAttrName("rates")
                .build();
        PropertyMapping sameMode = PropertyMapping.builder()
                .onnxAttrName("auto_pad")
                .propertyNames(new String[]{"isSameMode"})
                .tfAttrName("padding")
                .build();
        // NOTE(review): the ONNX ConvTranspose attribute is spelled "pads"; confirm "padding" is intended.
        PropertyMapping padding = PropertyMapping.builder()
                .onnxAttrName("padding")
                .propertyNames(new String[]{"pH", "pW"})
                .build();

        Map<String, PropertyMapping> byProperty = new HashMap<>();
        byProperty.put("sW", strides);
        byProperty.put("sH", strides);
        byProperty.put("kH", kernel);
        byProperty.put("kW", kernel);
        byProperty.put("dW", dilation);
        byProperty.put("dH", dilation);
        byProperty.put("isSameMode", sameMode);
        byProperty.put("pH", padding);
        byProperty.put("pW", padding);

        Map<String, Map<String, PropertyMapping>> byFramework = new HashMap<>();
        byFramework.put(onnxName(), byProperty);
        byFramework.put(tensorflowName(), byProperty);
        return byFramework;
    }

    /**
     * Builds the config from a TensorFlow node: strides and padding from attributes, kernel size
     * from the weight array of op argument 1.
     *
     * @throws IllegalStateException if no weight array can be found for argument 1
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        List<Long> strides = nodeDef.getAttrOrThrow("strides").getList().getIList();
        String paddingMode = nodeDef.getAttrOrThrow("padding").getS().toStringUtf8();

        SDVariable[] args = args();
        String weightsName = args[1].getVarName();
        INDArray weights = this.sameDiff.getVariable(weightsName).getArr();
        if (weights == null) {
            // NOTE(review): the array is read from node input 0 but attached to op argument 1;
            // confirm this pairing against the TF node's input order.
            weights = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
            if (weights != null) {
                initWith.associateArrayWithVariable(weights, initWith.getVariable(weightsName));
            }
        }
        if (weights == null) {
            throw new IllegalStateException("DeConv2D: no weight array available for variable \"" + weightsName + "\"; cannot determine kernel size");
        }

        String dataFormat = nodeDef.containsAttr("data_format")
                ? nodeDef.getAttrOrThrow("data_format").getS().toStringUtf8()
                : "nhwc";
        boolean channelsFirst = dataFormat.equalsIgnoreCase("NCHW");

        // TF strides are per input dimension: [1, sH, sW, 1] for NHWC and [1, 1, sH, sW] for NCHW.
        long strideH = strides.get(channelsFirst ? 2 : 1);
        long strideW = strides.get(channelsFirst ? 3 : 2);
        // TF transposed-convolution filters are [kH, kW, outChannels, inChannels] regardless of
        // data_format, so the kernel size always comes from dimensions 0 and 1.
        long kH = weights.size(0);
        long kW = weights.size(1);

        this.config = DeConv2DConfig.builder()
                .kH(kH)
                .kW(kW)
                .sH(strideH)
                .sW(strideW)
                .isSameMode(paddingMode.equalsIgnoreCase("SAME"))
                .dataFormat(dataFormat.equalsIgnoreCase("NHWC") ? "NHWC" : "NCHW")
                .build();
        addArgs();
    }

    /**
     * Builds the config from an ONNX ConvTranspose node and replaces the array of op argument 0
     * with a permuted (3, 2, 0, 1), c-ordered copy.
     *
     * @throws IllegalStateException if {@code kernel_shape} is absent or argument 0 has no array
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        String autoPad = attributesForNode.containsKey("auto_pad")
                ? attributesForNode.get("auto_pad").getS().toStringUtf8()
                : "VALID";

        OnnxProto3.AttributeProto dilations = attributesForNode.get("dilations");
        long dilationH = onnxIntAttr(dilations, 0, 1);
        long dilationW = onnxIntAttr(dilations, 1, 1);

        OnnxProto3.AttributeProto kernelShape = attributesForNode.get("kernel_shape");
        if (kernelShape == null || kernelShape.getIntsList().isEmpty()) {
            throw new IllegalStateException("DeConv2D: ONNX node has no kernel_shape attribute");
        }
        long kH = onnxIntAttr(kernelShape, 0, 1);
        long kW = onnxIntAttr(kernelShape, 1, 1);

        // NOTE(review): ONNX ConvTranspose lists its inputs as X, W, B, so argument 0 is normally
        // the data tensor rather than the weights; confirm which variable is meant here.
        SDVariable weightsVar = args()[0];
        INDArray original = weightsVar.getArr();
        if (original == null) {
            throw new IllegalStateException("DeConv2D: no array available for variable \"" + weightsVar.getVarName() + "\"");
        }
        INDArray permuted = original.permute(3, 2, 0, 1).dup('c');
        initWith.associateArrayWithVariable(permuted, weightsVar);

        // ONNX strides are optional and default to 1 along each spatial axis.
        OnnxProto3.AttributeProto strides = attributesForNode.get("strides");
        long strideH = onnxIntAttr(strides, 0, 1);
        long strideW = onnxIntAttr(strides, 1, 1);

        this.config = DeConv2DConfig.builder()
                .kH(kH)
                .kW(kW)
                .sH(strideH)
                .sW(strideW)
                .dH(dilationH)
                .dW(dilationW)
                // ONNX spells "same" padding as SAME_UPPER / SAME_LOWER, never plain "SAME".
                .isSameMode(autoPad.regionMatches(true, 0, "SAME", 0, 4))
                // NOTE(review): always NHWC here although ONNX tensors are conventionally NCHW; confirm.
                .dataFormat("NHWC")
                .build();
        addArgs();
        // NOTE(review): this registers the permuted array as an op *output*; confirm intended.
        addOutputArgument(permuted);
    }

    /**
     * Reads element {@code index} of an ONNX integer-list attribute. A missing or empty attribute
     * yields {@code defaultValue}; a list shorter than {@code index + 1} reuses its first element
     * (e.g. a single stride applies to both spatial axes).
     */
    private static long onnxIntAttr(OnnxProto3.AttributeProto attr, int index, long defaultValue) {
        if (attr == null || attr.getIntsList().isEmpty()) {
            return defaultValue;
        }
        List<Long> values = attr.getIntsList();
        return values.size() > index ? values.get(index) : values.get(0);
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String opName() {
        return "deconv2d";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String onnxName() {
        return "ConvTranspose";
    }

    // NOTE(review): TensorFlow's graph-level op for this is Conv2DBackpropInput ("Conv2DTranspose"
    // is the Keras layer name); confirm how this name is matched during import.
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "Conv2DTranspose";
    }

    /**
     * Creates the backward op: a {@code DeConv2DDerivative} fed with this op's arguments followed
     * by the incoming gradients, sharing this op's graph and config.
     *
     * @param f1 gradients flowing into this op's outputs
     * @return the derivative op's output variables, as a mutable list
     */
    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        List<SDVariable> derivativeInputs = new ArrayList<>(Arrays.asList(args()));
        derivativeInputs.addAll(f1);
        SDVariable[] gradients = DeConv2DDerivative.derivativeBuilder()
                .sameDiff(this.sameDiff)
                .config(this.config)
                .inputs(derivativeInputs.toArray(new SDVariable[0]))
                .build()
                .outputVariables();
        return new ArrayList<>(Arrays.asList(gradients));
    }

    public static DeConv2DBuilder builder() {
        return new DeConv2DBuilder();
    }

    /** Returns the current config, which may be {@code null}. */
    public DeConv2DConfig getConfig() {
        return this.config;
    }

    /** No-arg constructor; leaves {@link #config} {@code null}. */
    public DeConv2D() {
    }
}
