/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.execution;

import java.io.File;
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.autodiff.execution.GraphExecutioner;
import org.nd4j.autodiff.execution.Node;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatResult;
import org.nd4j.graph.FlatVariable;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.ResultWrapperAbstraction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link GraphExecutioner} that serializes a {@link SameDiff} graph to FlatBuffers and passes the
 * serialized bytes to the native backend's {@code executeFlatGraph} call for execution.
 */
public class NativeGraphExecutioner
implements GraphExecutioner {
    private static final Logger log = LoggerFactory.getLogger(NativeGraphExecutioner.class);

    /**
     * Returns the kind of executioner this is.
     *
     * @return always {@link GraphExecutioner.Type#LOCAL}
     */
    public GraphExecutioner.Type getExecutionerType() {
        return GraphExecutioner.Type.LOCAL;
    }

    /**
     * Executes {@code sd} natively with a default configuration: implicit output mode, sequential
     * execution and profiling disabled.
     *
     * @param sd graph to execute
     * @return the arrays produced by the native execution, in the order the native result lists them
     */
    public INDArray[] executeGraph(SameDiff sd) {
        ExecutorConfiguration defaults = ExecutorConfiguration.builder()
                .outputMode(OutputMode.IMPLICIT)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .build();
        return this.executeGraph(sd, defaults);
    }

    /**
     * Not supported by this executioner.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray[] reuseGraph(SameDiff graph, Map<Integer, INDArray> inputs) {
        throw new UnsupportedOperationException();
    }

    /**
     * Serializes {@code sd} to FlatBuffers using {@code configuration}.
     *
     * @param sd            graph to serialize
     * @param configuration executor configuration embedded in the serialized graph
     * @param intermediate  currently unused by this implementation; accepted for interface compatibility
     * @return the FlatBuffers representation produced by {@link SameDiff#asFlatBuffers}
     */
    public ByteBuffer convertToFlatBuffers(SameDiff sd, ExecutorConfiguration configuration, Map<Integer, Node> intermediate) {
        log.info("Configuration: {}", configuration);
        return sd.asFlatBuffers(configuration);
    }

    /**
     * Serializes {@code sd} to FlatBuffers using {@code configuration}, with a fresh (empty)
     * intermediate node map.
     */
    public ByteBuffer convertToFlatBuffers(SameDiff sd, ExecutorConfiguration configuration) {
        return this.convertToFlatBuffers(sd, configuration, new HashMap<Integer, Node>());
    }

    /**
     * Serializes {@code sd}, executes it on the native backend and associates each returned array with
     * the same-named variable in {@code sd} (when one exists).
     *
     * @param sd            graph to execute
     * @param configuration executor configuration passed to serialization
     * @return the arrays produced by the native execution, in the order the native result lists them
     * @throws ND4JIllegalStateException if the native call returns no result
     */
    public INDArray[] executeGraph(SameDiff sd, ExecutorConfiguration configuration) {
        Map<Integer, Node> intermediate = new HashMap<Integer, Node>();
        ByteBuffer buffer = this.convertToFlatBuffers(sd, configuration, intermediate);
        BytePointer bPtr = new BytePointer(buffer);
        log.info("Buffer length: {}", buffer.limit());
        ResultWrapperAbstraction res = NativeOpsHolder.getInstance().getDeviceNativeOps().executeFlatGraph(null, (Pointer) bPtr);
        if (res == null) {
            throw new ND4JIllegalStateException("Graph execution failed");
        }
        try {
            return readResults(sd, res);
        } finally {
            // Release the native result even if parsing or variable association throws;
            // previously an exception here leaked the native buffer.
            NativeOpsHolder.getInstance().getDeviceNativeOps().deleteResultWrapper(res);
        }
    }

    /**
     * Decodes the native {@link FlatResult} held by {@code res} into INDArrays and binds each to the
     * matching variable of {@code sd}. Unknown or unnamed variables are logged and skipped.
     */
    private static INDArray[] readResults(SameDiff sd, ResultWrapperAbstraction res) {
        PagedPointer pagedPointer = new PagedPointer(res.pointer(), res.size());
        FlatResult fr = FlatResult.getRootAsFlatResult(pagedPointer.asBytePointer().asByteBuffer());
        // Fetch the variable map once instead of once (or more) per output variable.
        Map<?, ?> varMap = sd.variableMap();
        log.info("VarMap: {}", varMap);
        int count = fr.variablesLength();
        INDArray[] results = new INDArray[count];
        for (int e = 0; e < count; ++e) {
            FlatVariable flatVar = fr.variables(e);
            FlatArray ndarray = flatVar.ndarray();
            INDArray val = Nd4j.createFromFlatArray(ndarray);
            results[e] = val;
            String name = flatVar.name();
            Object target = name == null ? null : varMap.get(name);
            if (target != null) {
                sd.associateArrayWithVariable(val, (SDVariable) target);
            } else {
                log.warn("Unknown variable received: [{}]", name);
            }
        }
        return results;
    }

    /**
     * Resolves the numeric identifier of an op.
     *
     * <p>For {@link Op.Type#CUSTOM} ops this is the hash of the custom op descriptor registered with the
     * current executioner (looked up by the lower-cased name); for all other types it is
     * {@link DifferentialFunction#opNum()} of the instance returned by
     * {@link DifferentialFunctionClassHolder}.
     *
     * @param name op name
     * @param type op type
     * @return the op number / custom-op hash
     * @throws ND4JIllegalStateException if {@code type} is CUSTOM and no such custom op is registered
     * @throws RuntimeException          if a non-custom op cannot be resolved
     */
    public static long getOpNum(String name, Op.Type type) {
        if (type == Op.Type.CUSTOM) {
            if (name == null) {
                throw new ND4JIllegalStateException("Custom op name must not be null");
            }
            // Locale.ROOT: default-locale lower-casing produces wrong keys in e.g. Turkish locales.
            CustomOpDescriptor descriptor = (CustomOpDescriptor) Nd4j.getExecutioner().getCustomOperations().get(name.toLowerCase(Locale.ROOT));
            if (descriptor == null) {
                throw new ND4JIllegalStateException("Unknown custom op: [" + name + "]");
            }
            return descriptor.getHash();
        }
        String message = "Could not find op number for operation: [" + name + "]";
        try {
            DifferentialFunction op = DifferentialFunctionClassHolder.getInstance().getInstance(name);
            if (op != null) {
                return op.opNum();
            }
        }
        catch (Exception e) {
            throw new RuntimeException(message, e);
        }
        throw new RuntimeException(message);
    }

    /**
     * Maps an {@link Op.Type} to the byte code used in the serialized graph.
     *
     * <p>NOTE(review): these constants presumably mirror the FlatBuffers {@code OpType} enum — confirm
     * against the schema before changing any of them.
     *
     * @param type op type to map
     * @return the byte code for {@code type}
     * @throws UnsupportedOperationException for op types without a mapping
     */
    public static byte getFlatOpType(Op.Type type) {
        switch (type) {
            case SCALAR:
                return 10;
            case BROADCAST:
                return 12;
            case TRANSFORM_FLOAT:
                return 0;
            case TRANSFORM_SAME:
                return 1;
            case TRANSFORM_STRICT:
                return 3;
            case TRANSFORM_BOOL:
                return 2;
            case REDUCE_FLOAT:
                return 5;
            case REDUCE_BOOL:
                return 8;
            case REDUCE_SAME:
                return 6;
            case INDEXREDUCE:
                return 9;
            case CUSTOM:
                return 21;
            default:
                break;
        }
        throw new UnsupportedOperationException("Unknown op type passed in: " + type);
    }

    /**
     * Placeholder: graph registration is not implemented here, so this always returns an empty array.
     */
    public INDArray[] executeGraph(int id, SDVariable ... variables) {
        return new INDArray[0];
    }

    /**
     * Placeholder: graph registration is not implemented here, so this always returns {@code 0}.
     */
    public int registerGraph(SameDiff graph) {
        return 0;
    }

    /**
     * Not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    public INDArray[] importProto(File file) {
        throw new UnsupportedOperationException("Not implemented yet");
    }
}

