/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.imports.graphmapper.tf;

import com.github.os72.protobuf351.Message;
import com.google.common.primitives.Floats;
import com.google.common.primitives.Ints;
import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.IfImportState;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.DataType;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorProto;
import org.tensorflow.framework.TensorShapeProto;

public class TFGraphMapper
extends BaseGraphMapper<GraphDef, NodeDef, AttrValue, NodeDef> {
    private static final Logger log = LoggerFactory.getLogger(TFGraphMapper.class);
    /** Names of nodes already processed; {@code alreadySeen} tests membership in this set. */
    private Set<String> seenNodes = new LinkedHashSet<String>();
    /** Node attribute that holds a constant's tensor value. */
    public static final String VALUE_ATTR_KEY = "value";
    /** Node attribute that holds a declared shape. */
    public static final String SHAPE_KEY = "shape";
    /** Shared instance returned by {@code getInstance()}; the constructor is private. */
    private static TFGraphMapper MAPPER_INSTANCE = new TFGraphMapper();
    /**
     * TensorFlow control-flow / no-op op type names.
     * <p>
     * Built from a plain list rather than double-brace initialisation, which created an anonymous
     * {@code HashSet} subclass holding a hidden reference to the enclosing mapper.
     * NOTE(review): not referenced in the visible part of this file — confirm it is still used.
     */
    private Set<String> graphMapper = new HashSet<String>(Arrays.asList(
            "LoopCond", "Merge", "Exit", "NextIteration", "NoOp", "Switch"));

    private TFGraphMapper() {
        // Singleton: callers obtain the shared mapper through getInstance().
    }

    /**
     * Returns the shared mapper instance.
     *
     * @return the singleton {@code TFGraphMapper}
     */
    public static TFGraphMapper getInstance() {
        final TFGraphMapper shared = TFGraphMapper.MAPPER_INSTANCE;
        return shared;
    }

    /**
     * Parses a binary {@link GraphDef} from {@code inputFile} and appends the text form of every
     * node to {@code outputFile}.
     * <p>
     * Best-effort: an {@link IOException} is logged and not rethrown. The input stream belongs to
     * the caller and is not closed here; the output writer is always closed, even if a write fails.
     *
     * @param inputFile  stream holding a serialized {@code GraphDef}
     * @param outputFile file the node text is appended to
     */
    @Override
    public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
        try {
            GraphDef graphDef = GraphDef.parseFrom(inputFile);
            // The writer is opened only after a successful parse, as before.
            try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
                for (NodeDef node : graphDef.getNodeList()) {
                    bufferedWriter.write(node.toString());
                }
            }
        } catch (IOException e) {
            log.error("Failed to dump graph as text to {}", outputFile, e);
        }
    }

    /**
     * Reports whether op-mapping exceptions may be ignored for {@code node}.
     *
     * @param node the node being mapped (not inspected)
     * @return always {@code true}
     */
    @Override
    public boolean isOpIgnoreException(NodeDef node) {
        final boolean ignore = true;
        return ignore;
    }

    /**
     * Returns the name that {@code function} is mapped to.
     *
     * @param function the function whose op name is returned
     * @param node     the source node (not inspected)
     * @return {@code function.opName()}
     */
    @Override
    public String getTargetMappingForOp(DifferentialFunction function, NodeDef node) {
        final String target = function.opName();
        return target;
    }

    /**
     * Linearly scans {@code graph} for a node with the given name.
     *
     * @param graph graph to search
     * @param name  exact node name
     * @return the first node whose name equals {@code name}, or {@code null} if none matches
     */
    @Override
    public NodeDef getNodeWithNameFromGraph(GraphDef graph, String name) {
        for (NodeDef candidate : graph.getNodeList()) {
            if (candidate.getName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Sets property {@code name} on {@code on} from the TensorFlow node {@code node}.
     * <p>
     * The {@link PropertyMapping} is looked up first by the node's op type, then by {@code name}.
     * If the mapping names an input position below the node's input count, the value comes from the
     * array behind that input. Otherwise it comes from the node attribute named by the mapping.
     *
     * @param name                        property to resolve
     * @param on                          function whose field is set
     * @param node                        node being imported; must not be {@code null}
     * @param graph                       graph used to look up input nodes
     * @param sameDiff                    consulted for existing arrays and used to defer resolution
     * @param propertyMappingsForFunction mappings keyed by TF op type, then by property name
     * @throws ND4JIllegalStateException if {@code node} is {@code null}, or (attribute path) if the
     *                                   function has no fields or the mapping has no property names
     */
    @Override
    public void mapProperty(String name, DifferentialFunction on, NodeDef node, GraphDef graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
        if (node == null) {
            throw new ND4JIllegalStateException("No node found for name " + name);
        }
        // NOTE(review): an op type or property name absent from the mappings causes an NPE below.
        PropertyMapping mapping = propertyMappingsForFunction.get(this.getOpType(node)).get(name);
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        // Input-backed property. Any negative position passes this bound check and is then
        // interpreted relative to the end of the input list.
        if (mapping.getTfInputPosition() != null && mapping.getTfInputPosition() < node.getInputCount()) {
            int tfMappingIdx = mapping.getTfInputPosition();
            if (tfMappingIdx < 0) {
                tfMappingIdx += node.getInputCount();
            }
            String input = node.getInput(tfMappingIdx);
            NodeDef inputNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, input);
            // Prefer the tensor stored in the input node's "value" attribute; otherwise use any
            // array SameDiff already holds under the raw input name.
            INDArray arr = this.getArrayFrom(inputNode, graph);
            if (arr == null) {
                arr = sameDiff.getArrForVarName(input);
            }
            // Input node exists but no value is known yet: defer, remembering which variable feeds the field.
            if (arr == null && inputNode != null) {
                sameDiff.addPropertyToResolve(on, name);
                sameDiff.addVariableMappingForField(on, name, inputNode.getName());
                return;
            }
            // Input is not a node of this graph: register it as a placeholder and stop.
            if (inputNode == null) {
                sameDiff.addAsPlaceHolder(input);
                return;
            }
            // NOTE(review): `fields` is only null-checked on the attribute path, and a missing field
            // for `name` would throw an NPE on getType().
            Field field = fields.get(name);
            Class<?> type = field.getType();
            if (type.equals(int[].class)) {
                on.setValueFor(field, arr.data().asInt());
            } else if (type.equals(Integer.TYPE) || type.equals(Long.TYPE) || type.equals(Long.class) || type.equals(Integer.class)) {
                // With a shape position the field receives that dimension's size; otherwise the first element.
                if (mapping.getShapePosition() != null) {
                    on.setValueFor(field, arr.size(mapping.getShapePosition()));
                } else {
                    on.setValueFor(field, arr.getInt(0));
                }
            } else if (type.equals(Float.TYPE) || type.equals(Double.TYPE) || type.equals(Float.class) || type.equals(Double.class)) {
                on.setValueFor(field, arr.getDouble(0));
            }
            // Fields of any other type are left unset.
        } else {
            // Attribute-backed property: silently skipped when no attribute is mapped or the node lacks it.
            String tfMappingAttrName = mapping.getTfAttrName();
            if (tfMappingAttrName == null) {
                return;
            }
            if (!node.containsAttr(tfMappingAttrName)) {
                return;
            }
            AttrValue attr = node.getAttrOrThrow(tfMappingAttrName);
            // NOTE(review): this dispatches on the attribute's `type` field, which is only meaningful
            // for type-valued attributes; for other attributes it is presumably the default — confirm intent.
            DataType type = attr.getType();
            if (fields == null) {
                throw new ND4JIllegalStateException("No fields found for op " + mapping);
            }
            if (mapping.getPropertyNames() == null) {
                throw new ND4JIllegalStateException("no property found for " + name + " and op " + on.opName());
            }
            // Only the first mapped property name receives the value.
            Field field = fields.get(mapping.getPropertyNames()[0]);
            Object valueToSet = null;
            switch (type) {
                case DT_BOOL: {
                    valueToSet = attr.getB();
                    break;
                }
                case DT_INT8: {
                    valueToSet = attr.getI();
                    break;
                }
                case DT_INT16: {
                    valueToSet = attr.getI();
                    break;
                }
                case DT_INT32: {
                    valueToSet = attr.getI();
                    break;
                }
                case DT_FLOAT: {
                    valueToSet = Float.valueOf(attr.getF());
                    break;
                }
                case DT_DOUBLE: {
                    // NOTE(review): reads the float attribute and boxes it as Float, exactly like DT_FLOAT.
                    valueToSet = Float.valueOf(attr.getF());
                    break;
                }
                case DT_STRING: {
                    valueToSet = attr.getS();
                    break;
                }
                case DT_INT64: {
                    // Last case in the switch, so the missing break has no effect.
                    valueToSet = attr.getI();
                }
            }
            // Unhandled types and missing fields leave the property untouched.
            if (field != null && valueToSet != null) {
                on.setValueFor(field, valueToSet);
            }
        }
    }

    /**
     * Tests whether {@code node} is a placeholder, i.e. its op type starts with {@code "Placeholder"}.
     *
     * @param node node to test
     * @return {@code true} for placeholder op types
     */
    @Override
    public boolean isPlaceHolderNode(NodeDef node) {
        String op = node.getOp();
        return op.startsWith("Placeholder");
    }

    /**
     * Reads a binary {@link GraphDef} from {@code inputFile} and appends the text form of every node
     * to {@code outputFile}.
     * <p>
     * Best-effort: an {@link IOException} (including a missing input file) is logged and not
     * rethrown. Both the input stream and the output writer are always closed; previously the input
     * stream was never closed.
     *
     * @param inputFile  file holding a serialized {@code GraphDef}
     * @param outputFile file the node text is appended to
     */
    @Override
    public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
        try (InputStream inputStream = new BufferedInputStream(new FileInputStream(inputFile))) {
            GraphDef graphDef = GraphDef.parseFrom(inputStream);
            // The writer is opened only after a successful parse, as before.
            try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile, true))) {
                for (NodeDef node : graphDef.getNodeList()) {
                    bufferedWriter.write(node.toString());
                }
            }
        } catch (IOException e) {
            log.error("Failed to dump graph {} as text to {}", inputFile, outputFile, e);
        }
    }

    /**
     * Converts the shape stored in {@code attr} via {@code shapeFromShapeProto}.
     *
     * @param attr attribute carrying a shape
     * @return the shape as an int array
     */
    @Override
    public int[] getShapeFromAttr(AttrValue attr) {
        TensorShapeProto shapeProto = attr.getShape();
        return this.shapeFromShapeProto(shapeProto);
    }

    /**
     * Returns the attribute map of {@code nodeDef}.
     *
     * @param nodeDef node whose attributes are returned
     * @return the node's attributes keyed by name
     */
    @Override
    public Map<String, AttrValue> getAttrMap(NodeDef nodeDef) {
        final Map<String, AttrValue> attrs = nodeDef.getAttrMap();
        return attrs;
    }

    /**
     * Returns the name of {@code nodeDef}.
     *
     * @param nodeDef node whose name is returned
     * @return the node name
     */
    @Override
    public String getName(NodeDef nodeDef) {
        final String nodeName = nodeDef.getName();
        return nodeName;
    }

    /**
     * Tests whether {@code nodeDef}'s name is in the set of seen node names.
     *
     * @param nodeDef node to test
     * @return {@code true} if its name has been recorded as seen
     */
    @Override
    public boolean alreadySeen(NodeDef nodeDef) {
        String nodeName = nodeDef.getName();
        return seenNodes.contains(nodeName);
    }

    /**
     * Tests whether {@code nodeDef} is a variable or a constant: its op type starts with
     * {@code "VariableV"} or equals {@code "const"} ignoring case.
     *
     * @param nodeDef node to test
     * @return {@code true} for variable/constant op types
     */
    @Override
    public boolean isVariableNode(NodeDef nodeDef) {
        String op = nodeDef.getOp();
        if (op.startsWith("VariableV")) {
            return true;
        }
        return op.equalsIgnoreCase("const");
    }

    /**
     * Decides whether a node is skipped during import.
     *
     * @param opType the node to test (despite the parameter name, a full {@link NodeDef})
     * @return {@code true} for {@code null}, for nodes whose name ends with {@code "/read"}, or for
     *         nodes whose op type ends with {@code "/reduction_indices"}
     */
    @Override
    public boolean shouldSkip(NodeDef opType) {
        if (opType == null) {
            return true;
        }
        if (opType.getName().endsWith("/read")) {
            return true;
        }
        // NOTE(review): this tests the op *type*, which normally contains no '/', so it is likely
        // always false; "/reduction_indices" looks like a node-name suffix. Left as-is because
        // switching to getName() would change which nodes are imported — confirm intent.
        return opType.getOp().endsWith("/reduction_indices");
    }

    /**
     * Tests whether {@code nodeDef} carries a {@code "shape"} attribute.
     *
     * @param nodeDef node to test
     * @return {@code true} if the attribute is present
     */
    @Override
    public boolean hasShape(NodeDef nodeDef) {
        final boolean present = nodeDef.containsAttr(SHAPE_KEY);
        return present;
    }

    /**
     * Reads the {@code "shape"} attribute of {@code nodeDef}.
     *
     * @param nodeDef node carrying a shape attribute
     * @return the shape as an int array
     * @throws IllegalArgumentException (from {@code getAttrOrThrow}) if the attribute is missing
     */
    @Override
    public int[] getShape(NodeDef nodeDef) {
        AttrValue shapeAttr = nodeDef.getAttrOrThrow(SHAPE_KEY);
        return this.getShapeFromAttr(shapeAttr);
    }

    /**
     * Returns the constant array stored on {@code nodeDef}, if any.
     *
     * @param nodeDef node to read; may be {@code null}
     * @param graph   graph containing the node
     * @return the node's tensor value, or {@code null} when the node is {@code null} or has none
     */
    @Override
    public INDArray getArrayFrom(NodeDef nodeDef, GraphDef graph) {
        return nodeDef == null
                ? null
                : this.getNDArrayFromTensor(nodeDef.getName(), nodeDef, graph);
    }

    /**
     * Returns the TensorFlow op type of {@code nodeDef}.
     *
     * @param nodeDef node to inspect
     * @return the op type string
     */
    @Override
    public String getOpType(NodeDef nodeDef) {
        final String op = nodeDef.getOp();
        return op;
    }

    /**
     * Returns all nodes of {@code graphDef} in declaration order.
     *
     * @param graphDef graph to read
     * @return the graph's node list
     */
    @Override
    public List<NodeDef> getNodeList(GraphDef graphDef) {
        final List<NodeDef> nodes = graphDef.getNodeList();
        return nodes;
    }

    /**
     * Looks up the function registered for TensorFlow op {@code name}.
     *
     * @param name TensorFlow op name
     * @return the registered function, or whatever the class holder returns for unknown names
     */
    @Override
    public DifferentialFunction getMappedOp(String name) {
        DifferentialFunctionClassHolder holder = DifferentialFunctionClassHolder.getInstance();
        return holder.getOpWithTensorflowName(name);
    }

    /**
     * Normalises a TensorFlow input reference to the node name used for variable lookups.
     * <p>
     * A leading {@code '^'} (TensorFlow's control-dependency marker) is dropped, and a single
     * trailing {@code "/read"} is removed. Only the suffix is stripped: the previous
     * {@code String.replace} also deleted {@code "/read"} occurrences inside the name, e.g.
     * {@code "model/readout/W/read"} became {@code "modelout/W"}.
     *
     * @param name input reference as it appears in a node's input list
     * @return the normalised node name
     */
    public String getNodeName(String name) {
        String ret = name;
        if (ret.startsWith("^")) {
            ret = ret.substring(1);
        }
        final String readSuffix = "/read";
        if (ret.endsWith(readSuffix)) {
            ret = ret.substring(0, ret.length() - readSuffix.length());
        }
        return ret;
    }

    /**
     * Maps every node of {@code graphDef}, except those whose name ends with {@code "/read"}, by its
     * SameDiff name. Insertion order follows the graph's node order.
     *
     * @param graphDef graph to index
     * @return an insertion-ordered map from SameDiff name to node
     */
    @Override
    public Map<String, NodeDef> variablesForGraph(GraphDef graphDef) {
        Map<String, NodeDef> byName = new LinkedHashMap<String, NodeDef>();
        for (NodeDef nodeDef : graphDef.getNodeList()) {
            String tfName = nodeDef.getName();
            if (!tfName.endsWith("/read")) {
                byName.put(this.translateToSameDiffName(tfName, nodeDef), nodeDef);
            }
        }
        return byName;
    }

    /**
     * Translates a TensorFlow name into the name used in SameDiff.
     * <p>
     * Variable and placeholder nodes keep their name unchanged. For all other nodes, everything from
     * the last {@code ':'} onward is dropped.
     *
     * @param name TensorFlow name
     * @param node node the name belongs to
     * @return the SameDiff name
     */
    @Override
    public String translateToSameDiffName(String name, NodeDef node) {
        if (this.isVariableNode(node) || this.isPlaceHolder(node)) {
            return name;
        }
        int colon = name.lastIndexOf(':');
        return colon >= 0 ? name.substring(0, colon) : name;
    }

    /**
     * Creates an empty {@link GraphDef} builder.
     *
     * @return a new builder
     */
    @Override
    public Message.Builder getNewGraphBuilder() {
        Message.Builder builder = GraphDef.newBuilder();
        return builder;
    }

    /**
     * Parses a serialized {@link GraphDef} from a byte array.
     *
     * @param inputStream serialized graph bytes
     * @return the parsed graph
     * @throws IOException if the bytes cannot be parsed
     */
    @Override
    public GraphDef parseGraphFrom(byte[] inputStream) throws IOException {
        GraphDef parsed = GraphDef.parseFrom(inputStream);
        return parsed;
    }

    /**
     * Parses a serialized {@link GraphDef} from a stream; the stream is not closed here.
     *
     * @param inputStream stream of serialized graph bytes
     * @return the parsed graph
     * @throws IOException if reading or parsing fails
     */
    @Override
    public GraphDef parseGraphFrom(InputStream inputStream) throws IOException {
        GraphDef parsed = GraphDef.parseFrom(inputStream);
        return parsed;
    }

    /**
     * Hook for importing a conditional construct. Currently a no-op: the body is empty, so all
     * arguments are ignored.
     *
     * @param conditionName name of the condition
     * @param tfNode        node associated with the condition
     * @param importState   current import state
     */
    protected void importCondition(String conditionName, NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
    }

    /**
     * Maps a single TensorFlow node onto the {@link SameDiff} instance held by {@code importState}.
     * <p>
     * The node is handled in one of three ways:
     * <ul>
     *   <li>Skipped: nodes rejected by {@code shouldSkip}, already-seen nodes, and variable/constant
     *       nodes.</li>
     *   <li>Placeholder: the existing variable of the same name is marked as a placeholder.</li>
     *   <li>Op: every other node is instantiated as the {@link DifferentialFunction} registered for
     *       its op type. It is wired to its inputs, initialised from the node's attributes and
     *       properties, and registered with the {@link SameDiff} instance.</li>
     * </ul>
     *
     * @param tfNode      node to import
     * @param importState import state carrying the target {@link SameDiff} and the source graph
     * @throws ND4JIllegalStateException if no function is registered for the node's op type
     * @throws RuntimeException          wrapping any failure while creating or initialising the op
     */
    @Override
    public void mapNodeType(NodeDef tfNode, ImportState<GraphDef, NodeDef> importState) {
        // Variable/constant nodes return here. The decompiled original also had a variable-handling
        // branch that this check made unreachable; it has been removed.
        if (this.shouldSkip(tfNode) || this.alreadySeen(tfNode) || this.isVariableNode(tfNode)) {
            return;
        }
        SameDiff diff = importState.getSameDiff();
        if (this.isPlaceHolder(tfNode)) {
            // NOTE(review): assumes a variable named after this node already exists in `diff`;
            // otherwise this throws an NPE.
            SDVariable vertexId = diff.getVariable(this.getName(tfNode));
            diff.addAsPlaceHolder(vertexId.getVarName());
            return;
        }
        String opName = tfNode.getOp();
        DifferentialFunction differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithTensorflowName(opName);
        if (differentialFunction == null) {
            throw new ND4JIllegalStateException("No tensorflow op found for " + opName + " possibly missing operation class?");
        }
        try {
            // The registered function is used only for its class; each node gets a fresh instance.
            DifferentialFunction newInstance = (DifferentialFunction) differentialFunction.getClass().newInstance();
            SDVariable[] args = new SDVariable[tfNode.getInputCount()];
            newInstance.setOwnName(tfNode.getName());
            for (int i = 0; i < tfNode.getInputCount(); ++i) {
                // getNodeName strips a leading '^' and a trailing "/read" from the input reference.
                String name = this.getNodeName(tfNode.getInput(i));
                args[i] = diff.getVariable(name);
                if (args[i] == null) {
                    // Unknown input: create a stand-in variable (ZeroInitScheme) and mark it as a placeholder.
                    args[i] = diff.var(name, null, new ZeroInitScheme('f'));
                    diff.addAsPlaceHolder(args[i].getVarName());
                }
                if (diff.isPlaceHolder(args[i].getVarName())) {
                    diff.putPlaceHolderForVariable(args[i].getVarName(), name);
                }
            }
            diff.addArgsFor(args, newInstance);
            newInstance.setSameDiff(importState.getSameDiff());
            newInstance.initFromTensorFlow(tfNode, diff, this.getAttrMap(tfNode), importState.getGraph());
            this.mapProperties(newInstance, tfNode, importState.getGraph(), importState.getSameDiff(), newInstance.mappingsForFunction());
            importState.getSameDiff().putFunctionForId(newInstance.getOwnName(), newInstance);
            diff.setBaseNameForFunctionInstanceId(tfNode.getName(), newInstance);
            diff.addVarNameForImport(tfNode.getName());
        } catch (Exception e) {
            log.error("Failed with [{}]", opName);
            throw new RuntimeException(e);
        }
    }

    /**
     * Initialises {@code on} from the node's attributes and inputs, using the function's own
     * TensorFlow name ({@code on.tensorflowName()}) to select mappings.
     *
     * @param on                function to populate
     * @param attributesForNode attributes of {@code node}
     * @param node              node being imported
     * @param graph             graph containing the node
     */
    public void initFunctionFromProperties(DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
        String tfName = on.tensorflowName();
        this.initFunctionFromProperties(tfName, on, attributesForNode, node, graph);
    }

    /**
     * Copies TensorFlow attribute and input values onto the fields of {@code on}. The copy is driven
     * by the property mappings the function declares for {@code mappedTfName}.
     * <p>
     * Each mapping is handled as follows:
     * <ul>
     *   <li>If it names a TF attribute, the attribute is applied when it is present on the node and
     *       the target field exists. An {@link AttributeAdapter} is used when one is registered;
     *       otherwise the field is set directly.</li>
     *   <li>Otherwise, if it names an input position (negative values count from the end), the value
     *       comes from the constant tensor of the input node in {@code graph}. If that is missing, an
     *       array already registered in the function's {@link SameDiff} is used. If neither exists,
     *       the property is registered to be resolved later.</li>
     * </ul>
     *
     * @param mappedTfName      TensorFlow op name used to select mappings and adapters
     * @param on                function whose fields are populated
     * @param attributesForNode attributes of {@code node}
     * @param node              node being imported
     * @param graph             graph containing {@code node} and its inputs
     */
    public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, AttrValue> attributesForNode, NodeDef node, GraphDef graph) {
        Map<String, Map<String, PropertyMapping>> properties = on.mappingsForFunction();
        Map<String, PropertyMapping> tfProperties = properties.get(mappedTfName);
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
        Map<String, Map<String, AttributeAdapter>> attributeAdapters = on.attributeAdaptersForFunction();
        if (tfProperties == null) {
            return;
        }
        // A function may register adapters for other framework names but none for this one; the
        // previous per-entry lookup dereferenced that null map.
        Map<String, AttributeAdapter> tfAdapters = attributeAdapters == null ? null : attributeAdapters.get(mappedTfName);
        for (Map.Entry<String, PropertyMapping> entry : tfProperties.entrySet()) {
            String propertyName = entry.getKey();
            PropertyMapping mapping = entry.getValue();
            Field currentField = fields.get(propertyName);
            AttributeAdapter adapter = tfAdapters == null ? null : tfAdapters.get(propertyName);
            String tfAttrName = mapping.getTfAttrName();
            if (tfAttrName != null) {
                // Attribute-backed: skipped when either the field or the attribute is missing.
                if (currentField != null && attributesForNode.containsKey(tfAttrName)) {
                    this.applyTfAttribute(attributesForNode.get(tfAttrName), currentField, adapter, on);
                }
                continue;
            }
            if (mapping.getTfInputPosition() == null) {
                continue;
            }
            int position = mapping.getTfInputPosition();
            if (position < 0) {
                position += node.getInputCount();
            }
            String inputName = node.getInput(position);
            NodeDef inputFromNode = TFGraphMapper.getInstance().getNodeWithNameFromGraph(graph, inputName);
            INDArray tensor = inputFromNode != null
                    ? TFGraphMapper.getInstance().getNDArrayFromTensor(VALUE_ATTR_KEY, inputFromNode, graph)
                    : null;
            if (tensor == null) {
                tensor = on.getSameDiff().getArrForVarName(this.getNodeName(inputName));
            }
            if (tensor == null) {
                // Value not available at import time; resolve it later.
                on.getSameDiff().addPropertyToResolve(on, propertyName);
                continue;
            }
            this.applyTfInputTensor(tensor, currentField, adapter, on);
        }
    }

    /**
     * Applies one TensorFlow attribute value to {@code field} of {@code on}.
     * <p>
     * Booleans and types are applied only through an adapter. Scalar floats, functions, placeholders
     * and unset values are ignored. Of the list kinds, only int and float lists are applied.
     */
    private void applyTfAttribute(AttrValue attr, Field field, AttributeAdapter adapter, DifferentialFunction on) {
        switch (attr.getValueCase()) {
            case B: {
                if (adapter != null) {
                    adapter.mapAttributeFor(attr.getB(), field, on);
                }
                return;
            }
            case S: {
                String setString = attr.getS().toStringUtf8();
                if (adapter != null) {
                    adapter.mapAttributeFor(setString, field, on);
                } else {
                    on.setValueFor(field, setString);
                }
                return;
            }
            case I: {
                int setInt = (int) attr.getI();
                if (adapter != null) {
                    adapter.mapAttributeFor(setInt, field, on);
                } else {
                    on.setValueFor(field, setInt);
                }
                return;
            }
            case SHAPE: {
                List<TensorShapeProto.Dim> shape = attr.getShape().getDimList();
                int[] dimsToSet = new int[shape.size()];
                for (int i = 0; i < dimsToSet.length; ++i) {
                    dimsToSet[i] = (int) shape.get(i).getSize();
                }
                if (adapter != null) {
                    adapter.mapAttributeFor(dimsToSet, field, on);
                } else {
                    on.setValueFor(field, dimsToSet);
                }
                return;
            }
            case LIST: {
                AttrValue.ListValue setList = attr.getList();
                if (!setList.getIList().isEmpty()) {
                    int[] intList = Ints.toArray(setList.getIList());
                    if (adapter != null) {
                        adapter.mapAttributeFor(intList, field, on);
                    } else {
                        on.setValueFor(field, intList);
                    }
                    return;
                }
                // Boolean lists are not applied; float lists are only considered when no bool list is set.
                if (setList.getBList().isEmpty() && !setList.getFList().isEmpty()) {
                    float[] floats = Floats.toArray(setList.getFList());
                    if (adapter != null) {
                        adapter.mapAttributeFor(floats, field, on);
                    } else {
                        on.setValueFor(field, floats);
                    }
                }
                return;
            }
            case TENSOR: {
                INDArray tensorToGet = TFGraphMapper.getInstance().mapTensorProto(attr.getTensor());
                if (adapter != null) {
                    adapter.mapAttributeFor(tensorToGet, field, on);
                } else {
                    on.setValueFor(field, tensorToGet);
                }
                return;
            }
            case TYPE: {
                if (adapter != null) {
                    adapter.mapAttributeFor((Object) attr.getType(), field, on);
                }
                return;
            }
            default:
                // F, FUNC, PLACEHOLDER and VALUE_NOT_SET are not mapped.
                return;
        }
    }

    /**
     * Applies a resolved input tensor to {@code field} of {@code on}. An adapter is used when one is
     * present; otherwise the field is set according to its declared type. Field types other than
     * int[], double[], float[], INDArray, int, double and float are left untouched.
     */
    private void applyTfInputTensor(INDArray tensor, Field field, AttributeAdapter adapter, DifferentialFunction on) {
        if (adapter != null) {
            adapter.mapAttributeFor(tensor, field, on);
            return;
        }
        if (field == null) {
            // No adapter and no target field: nothing to set. This used to throw an NPE; it now
            // mirrors the attribute path, which also skips missing fields.
            return;
        }
        Class<?> type = field.getType();
        if (type.equals(int[].class)) {
            on.setValueFor(field, tensor.data().asInt());
        } else if (type.equals(double[].class)) {
            on.setValueFor(field, tensor.data().asDouble());
        } else if (type.equals(float[].class)) {
            on.setValueFor(field, tensor.data().asFloat());
        } else if (type.equals(INDArray.class)) {
            on.setValueFor(field, tensor);
        } else if (type.equals(Integer.TYPE)) {
            on.setValueFor(field, tensor.getInt(0));
        } else if (type.equals(Double.TYPE)) {
            on.setValueFor(field, tensor.getDouble(0));
        } else if (type.equals(Float.TYPE)) {
            on.setValueFor(field, Float.valueOf(tensor.getFloat(0)));
        }
    }

    /**
     * Infers the ND4J buffer type of a node from its type attribute.
     * <p>
     * The attribute is taken from {@code "dtype"} if present, else {@code "T"}, else {@code "Tidx"}.
     * If none of these exist, or the TensorFlow type is not one of the mapped types, the result is
     * {@link DataBuffer.Type#UNKNOWN}.
     * NOTE(review): DT_INT64 maps to INT and DT_BFLOAT16 maps to HALF; both conversions are lossy.
     * DT_HALF is not mapped — confirm this is intended.
     *
     * @param tensorProto node whose type attribute is inspected
     * @return the corresponding buffer type
     */
    @Override
    public DataBuffer.Type dataTypeForTensor(NodeDef tensorProto) {
        String typeKey;
        if (tensorProto.containsAttr("dtype")) {
            typeKey = "dtype";
        } else if (tensorProto.containsAttr("T")) {
            typeKey = "T";
        } else if (tensorProto.containsAttr("Tidx")) {
            typeKey = "Tidx";
        } else {
            return DataBuffer.Type.UNKNOWN;
        }
        DataType tfType = tensorProto.getAttrOrThrow(typeKey).getType();
        switch (tfType) {
            case DT_FLOAT:
                return DataBuffer.Type.FLOAT;
            case DT_DOUBLE:
                return DataBuffer.Type.DOUBLE;
            case DT_INT32:
            case DT_INT64:
                return DataBuffer.Type.INT;
            case DT_BFLOAT16:
                return DataBuffer.Type.HALF;
            default:
                return DataBuffer.Type.UNKNOWN;
        }
    }

    /**
     * Reads a string attribute of {@code nodeDef}, decoded as UTF-8.
     *
     * @param nodeDef node to read
     * @param key     attribute name
     * @return the attribute's string value
     * @throws IllegalArgumentException (from {@code getAttrOrThrow}) if the attribute is missing
     */
    @Override
    public String getAttrValueFromNode(NodeDef nodeDef, String key) {
        AttrValue value = nodeDef.getAttrOrThrow(key);
        return value.getS().toStringUtf8();
    }

    /**
     * Converts the shape held by {@code attrValue} into an int array. Each dimension size is
     * narrowed from long to int.
     *
     * @param attrValue attribute carrying a shape
     * @return one entry per dimension, in order
     */
    @Override
    public int[] getShapeFromAttribute(AttrValue attrValue) {
        List<TensorShapeProto.Dim> dims = attrValue.getShape().getDimList();
        int[] sizes = new int[dims.size()];
        int idx = 0;
        for (TensorShapeProto.Dim dim : dims) {
            sizes[idx++] = (int) dim.getSize();
        }
        return sizes;
    }

    /**
     * Tests whether {@code nodeDef}'s op type starts with {@code "Placeholder"}.
     *
     * @param nodeDef node to test
     * @return {@code true} for placeholder op types
     */
    @Override
    public boolean isPlaceHolder(NodeDef nodeDef) {
        String op = nodeDef.getOp();
        return op.startsWith("Placeholder");
    }

    /**
     * Converts the tensor stored in {@code node}'s {@code "value"} attribute into an array.
     * {@code tensorName} and {@code graph} are not used.
     *
     * @param tensorName unused
     * @param node       node to read
     * @param graph      unused
     * @return the mapped array, or {@code null} if the node has no {@code "value"} attribute
     */
    @Override
    public INDArray getNDArrayFromTensor(String tensorName, NodeDef node, GraphDef graph) {
        if (node.getAttrMap().containsKey(VALUE_ATTR_KEY)) {
            TensorProto tfTensor = node.getAttrOrThrow(VALUE_ATTR_KEY).getTensor();
            return this.mapTensorProto(tfTensor);
        }
        return null;
    }

    public INDArray mapTensorProto(TensorProto tfTensor) {
        int e;
        int dims = tfTensor.getTensorShape().getDimCount();
        int[] arrayShape = null;
        ArrayList<Integer> dimensions = new ArrayList<Integer>();
        for (int e2 = 0; e2 < dims; ++e2) {
            int dim = (int)tfTensor.getTensorShape().getDim(e2).getSize();
            dimensions.add(dim);
        }
        arrayShape = Ints.toArray(dimensions);
        if (tfTensor.getDtype() == DataType.DT_INT32 || tfTensor.getDtype() == DataType.DT_INT16 || tfTensor.getDtype() == DataType.DT_INT8) {
            if (tfTensor.getIntValCount() == 1 || ArrayUtil.prod((int[])arrayShape) == 1) {
                if (tfTensor.getIntValCount() < 1) {
                    return Nd4j.trueScalar(0.0);
                }
                int val = tfTensor.getIntVal(0);
                if (arrayShape == null || arrayShape.length == 0) {
                    arrayShape = new int[]{};
                }
                INDArray array = Nd4j.valueArrayOf(arrayShape, (double)val);
                return array;
            }
            if (tfTensor.getInt64ValCount() > 0) {
                double[] jArray = new double[tfTensor.getIntValCount()];
                for (e = 0; e < tfTensor.getIntValCount(); ++e) {
                    jArray[e] = tfTensor.getIntVal(e);
                }
                INDArray array = Nd4j.create(jArray, arrayShape, 0L, 'c');
                return array;
            }
            long length = ArrayUtil.prodLong((int[])arrayShape);
            ByteBuffer bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
            IntBuffer fb = bb.order(ByteOrder.nativeOrder()).asIntBuffer();
            float[] fa = new float[fb.capacity()];
            for (int e3 = 0; e3 < fb.capacity(); ++e3) {
                fa[e3] = fb.get(e3);
            }
            if (fa.length == 0) {
                throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
            }
            if (fa.length == 1) {
                return Nd4j.trueScalar(Float.valueOf(fa[0]));
            }
            if (arrayShape.length == 1) {
                return Nd4j.trueVector(fa);
            }
            INDArray array = Nd4j.create(fa, arrayShape, 'c', 0L);
            return array;
        }
        if (tfTensor.getDtype() == DataType.DT_FLOAT) {
            if (tfTensor.getFloatValCount() == 1 || ArrayUtil.prod((int[])arrayShape) == 1) {
                if (tfTensor.getFloatValCount() < 1) {
                    return Nd4j.scalar(0.0);
                }
                float val = tfTensor.getFloatVal(0);
                if (arrayShape == null || arrayShape.length == 0) {
                    arrayShape = new int[]{};
                }
                INDArray array = Nd4j.valueArrayOf(arrayShape, (double)val);
                return array;
            }
            if (tfTensor.getFloatValCount() > 0) {
                float[] jArray = new float[tfTensor.getFloatValCount()];
                for (e = 0; e < tfTensor.getFloatValCount(); ++e) {
                    jArray[e] = tfTensor.getFloatVal(e);
                }
                INDArray array = Nd4j.create(Nd4j.createBuffer(jArray), arrayShape, 99L);
                return array;
            }
            if (tfTensor.getTensorContent().size() > 0) {
                ByteBuffer bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
                FloatBuffer fb = bb.order(ByteOrder.nativeOrder()).asFloatBuffer();
                float[] fa = new float[fb.capacity()];
                for (int e4 = 0; e4 < fb.capacity(); ++e4) {
                    fa[e4] = fb.get(e4);
                }
                if (fa.length == 0) {
                    throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
                }
                if (fa.length == 1) {
                    return Nd4j.trueScalar(Float.valueOf(fa[0]));
                }
                if (arrayShape.length == 1) {
                    return Nd4j.trueVector(fa);
                }
                INDArray array = Nd4j.create(fa, arrayShape, 'c', 0L);
                return array;
            }
        } else if (tfTensor.getDtype() == DataType.DT_DOUBLE) {
            if (tfTensor.getDoubleValCount() == 1 || ArrayUtil.prod((int[])arrayShape) == 1) {
                if (tfTensor.getDoubleValCount() < 1) {
                    return Nd4j.trueScalar(0.0);
                }
                double val = tfTensor.getDoubleVal(0);
                INDArray array = Nd4j.trueScalar(val);
                return array;
            }
            if (tfTensor.getDoubleValCount() > 0) {
                double[] jArray = new double[tfTensor.getDoubleValCount()];
                for (e = 0; e < tfTensor.getDoubleValCount(); ++e) {
                    jArray[e] = tfTensor.getDoubleVal(e);
                }
                INDArray array = Nd4j.create(jArray, arrayShape, 0L, 'c');
                return array;
            }
            if (tfTensor.getTensorContent().size() > 0) {
                ByteBuffer bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
                DoubleBuffer fb = bb.order(ByteOrder.nativeOrder()).asDoubleBuffer();
                double[] da = new double[fb.capacity()];
                for (int e5 = 0; e5 < fb.capacity(); ++e5) {
                    da[e5] = fb.get(e5);
                }
                if (da.length == 0) {
                    throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
                }
                if (da.length == 1) {
                    return Nd4j.trueScalar(da[0]);
                }
                if (arrayShape.length == 1) {
                    return Nd4j.trueVector(da);
                }
                INDArray array = Nd4j.create(da, arrayShape, 0L, 'c');
                return array;
            }
        } else if (tfTensor.getDtype() == DataType.DT_INT64) {
            if (tfTensor.getInt64ValCount() == 1 || ArrayUtil.prod((int[])arrayShape) == 1) {
                if (tfTensor.getDoubleValCount() < 1) {
                    return Nd4j.trueScalar(0.0);
                }
                double val = tfTensor.getInt64Val(0);
                INDArray array = Nd4j.trueScalar(val);
                return array;
            }
            if (tfTensor.getInt64ValCount() > 0) {
                double[] jArray = new double[tfTensor.getInt64ValCount()];
                for (e = 0; e < tfTensor.getInt64ValCount(); ++e) {
                    jArray[e] = tfTensor.getInt64Val(e);
                }
                INDArray array = Nd4j.create(jArray, arrayShape, 0L, 'c');
                return array;
            }
            if (tfTensor.getTensorContent().size() > 0) {
                ByteBuffer bb = tfTensor.getTensorContent().asReadOnlyByteBuffer();
                LongBuffer lb = bb.order(ByteOrder.nativeOrder()).asLongBuffer();
                float[] fa = new float[lb.capacity()];
                for (int e6 = 0; e6 < lb.capacity(); ++e6) {
                    fa[e6] = lb.get(e6);
                }
                if (fa.length == 0) {
                    throw new ND4JIllegalStateException("Can't find Tensor values! Probably you've forgot to freeze graph before saving?");
                }
                if (fa.length == 1) {
                    return Nd4j.trueScalar(Float.valueOf(fa[0]));
                }
                if (arrayShape.length == 1) {
                    return Nd4j.trueVector(fa);
                }
                INDArray array = Nd4j.create(fa, arrayShape, 'c', 0L);
                return array;
            }
        } else {
            throw new UnsupportedOperationException("Unknown dataType found: [" + (Object)((Object)tfTensor.getDtype()) + "]");
        }
        throw new ND4JIllegalStateException("Invalid method state");
    }

    /**
     * Resolves the shape of a node that carries tensor metadata.
     *
     * <p>An explicit {@code SHAPE_KEY} attribute wins; otherwise the shape embedded in the
     * tensor stored under {@code VALUE_ATTR_KEY} is used. Both paths go through
     * {@link #shapeFromShapeProto}, which pads rank-0 and rank-1 shapes to rank 2.
     *
     * @param tensorProto the node to inspect
     * @return the resolved shape, or {@code null} when the node has neither attribute
     */
    @Override
    public int[] getShapeFromTensor(NodeDef tensorProto) {
        boolean hasShapeAttr = tensorProto.containsAttr(SHAPE_KEY);
        if (hasShapeAttr) {
            AttrValue shapeAttr = tensorProto.getAttrOrThrow(SHAPE_KEY);
            return this.shapeFromShapeProto(shapeAttr.getShape());
        }
        boolean hasValueAttr = tensorProto.containsAttr(VALUE_ATTR_KEY);
        if (hasValueAttr) {
            TensorProto valueTensor = tensorProto.getAttrOrThrow(VALUE_ATTR_KEY).getTensor();
            return this.shapeFromShapeProto(valueTensor.getTensorShape());
        }
        // Neither a declared shape nor a constant value: shape is unknown here.
        return null;
    }

    /**
     * Returns the op names this mapper reports as ignorable.
     *
     * <p>The set is the {@code graphMapper} field declared elsewhere in this class. It is returned
     * directly (no defensive copy), so any changes a caller makes are visible to this mapper.
     *
     * @return the live set of ignorable op names
     */
    @Override
    public Set<String> opsToIgnore() {
        final Set<String> ignoredOps = graphMapper;
        return ignoredOps;
    }

    /**
     * Returns the name of the input at position {@code index} of {@code node}.
     *
     * @param node  the node whose inputs are read
     * @param index zero-based input position
     * @return the input name exactly as stored in the node definition
     */
    @Override
    public String getInputFromNode(NodeDef node, int index) {
        final String inputName = node.getInput(index);
        return inputName;
    }

    /**
     * Counts the inputs declared on a node definition.
     *
     * @param nodeDef the node to inspect
     * @return number of entries in the node's input list
     */
    @Override
    public int numInputsFor(NodeDef nodeDef) {
        return nodeDef.getInputList().size();
    }

    /**
     * Converts a TensorFlow shape proto into an {@code int[]} shape.
     *
     * <p>Each dimension size is narrowed from {@code long} with a plain cast; negative sizes are
     * copied as-is. Shapes of rank below 2 are padded: rank 0 becomes {@code [1, 1]} and rank 1
     * {@code [n]} becomes {@code [1, n]}.
     *
     * @param tensorShapeProto the shape proto to convert
     * @return an int shape of rank at least 2
     */
    private int[] shapeFromShapeProto(TensorShapeProto tensorShapeProto) {
        final int rank = tensorShapeProto.getDimCount();
        switch (rank) {
            case 0:
                return new int[]{1, 1};
            case 1:
                return new int[]{1, (int) tensorShapeProto.getDim(0).getSize()};
            default:
                int[] dims = new int[rank];
                for (int d = 0; d < rank; d++) {
                    dims[d] = (int) tensorShapeProto.getDim(d).getSize();
                }
                return dims;
        }
    }

    /**
     * Partitions the nodes preceding {@code from} into condition, true-body and false-body groups
     * for importing a conditional.
     *
     * <p>The graph's node list is scanned backwards from the position of {@code from}
     * ({@code from} itself is never collected):
     * <ul>
     *   <li>Nodes visited before reaching the node named by {@code from}'s input 1 go to the
     *       false body.</li>
     *   <li>From that node onward, nodes go to the true body until a node whose name contains
     *       {@code "pred_id"} is reached.</li>
     *   <li>After that, nodes are collected as condition nodes while they are either
     *       {@code pred_id} nodes or were named as an input of an already-collected condition
     *       node; the first node matching neither stops the scan.</li>
     * </ul>
     * Each group is returned in original graph order. If {@code from} is not in {@code graph},
     * all groups are empty.
     *
     * <p>NOTE(review): input 0 of {@code from} (presumably the false-branch output) is not used by
     * this scan; confirm that input order matches the producing op.
     *
     * @param from  the node that closes the conditional; must have at least two inputs
     * @param graph the graph containing {@code from}
     * @return the collected node groups plus freshly generated, unique scope names
     * @throws ND4JIllegalStateException if {@code from} has fewer than two inputs
     */
    public IfImportState nodesForIf(NodeDef from, GraphDef graph) {
        if (from.getInputCount() < 2) {
            throw new ND4JIllegalStateException("Expected node [" + from.getName()
                    + "] to have at least 2 inputs for If import, but found " + from.getInputCount());
        }
        int currNodeIndex = graph.getNodeList().indexOf(from);
        String trueDefName = from.getInput(1);

        // Use the top-level name scope of the true branch as a readable prefix; fall back to the
        // whole name when it is unscoped (previously this threw StringIndexOutOfBoundsException).
        int slashIdx = trueDefName.indexOf('/');
        String scopePrefix = slashIdx >= 0 ? trueDefName.substring(0, slashIdx) : trueDefName;
        String scopeName = UUID.randomUUID().toString() + "-" + scopePrefix;
        String trueDefScopeName = scopeName + "-true-scope";
        String falseDefScopeName = scopeName + "-false-scope";

        boolean onFalseDefinition = true;
        boolean onTrueDefinition = false;
        List<NodeDef> falseBodyNodes = new ArrayList<NodeDef>();
        List<NodeDef> trueBodyNodes = new ArrayList<NodeDef>();
        List<NodeDef> conditionNodes = new ArrayList<NodeDef>();
        // Names reachable from the condition nodes collected so far (their inputs and themselves).
        Set<String> seenNames = new LinkedHashSet<String>();

        for (int i = currNodeIndex; i >= 0; --i) {
            NodeDef currNode = graph.getNode(i);
            String currName = currNode.getName();
            boolean isPredId = currName.contains("pred_id");

            // State transitions happen before classification so the boundary node itself
            // lands in the group it opens.
            if (currName.equals(trueDefName)) {
                onFalseDefinition = false;
                onTrueDefinition = true;
            }
            if (isPredId) {
                onTrueDefinition = false;
            }
            if (currNode.equals(from)) {
                continue;
            }
            if (onTrueDefinition) {
                trueBodyNodes.add(currNode);
                continue;
            }
            if (onFalseDefinition) {
                falseBodyNodes.add(currNode);
                continue;
            }
            // Condition region: stop at the first node not connected to the predicate sub-graph.
            if (!seenNames.contains(currName) && !isPredId) {
                break;
            }
            seenNames.addAll(currNode.getInputList());
            seenNames.add(currName);
            conditionNodes.add(currNode);
        }

        // The scan ran backwards; restore graph order.
        Collections.reverse(falseBodyNodes);
        Collections.reverse(trueBodyNodes);
        Collections.reverse(conditionNodes);
        // conditionBodyScopeName was previously set twice (the first value was always
        // overwritten); only the effective value is set now.
        return IfImportState.builder()
                .condNodes(conditionNodes)
                .falseNodes(falseBodyNodes)
                .trueNodes(trueBodyNodes)
                .conditionBodyScopeName(scopeName)
                .falseBodyScopeName(falseDefScopeName)
                .trueBodyScopeName(trueDefScopeName)
                .build();
    }
}

