package org.tensorflow;

import java.util.ArrayList;
import java.util.List;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
import org.tensorflow.internal.c_api.TFJ_GraphId;
import org.tensorflow.internal.c_api.TFJ_Scope;
import org.tensorflow.internal.c_api.TF_Operation;
import org.tensorflow.internal.c_api.TF_Output;

/* loaded from: input_file:org/tensorflow/AbstractGradientAdapter.class */
/**
 * Adapter that receives a native gradient-function callback ({@link #call}), converts its native
 * arguments into Java graph objects, delegates to {@link #apply}, and converts the Java result back
 * into a native {@code TF_Output} array.
 */
public abstract class AbstractGradientAdapter extends TFJ_GradFuncAdapter {

    /**
     * Builds the gradient of {@code operation} in Java.
     *
     * <p>Invoked from {@link #call} with the graph's dangerous-gradient-builder flag set to
     * {@code true}; {@link #call} resets it to {@code false} once this method returns or throws.
     *
     * @param graph the graph resolved from the native graph id
     * @param scope the native scope received by {@link #call}, forwarded unchanged
     * @param operation the operation whose gradient is being built
     * @param gradInputs the incoming gradients, converted from the native {@code TF_Output} array
     * @return the gradient outputs; every element must be non-null and its {@code asOutput().op()}
     *     must be a {@link GraphOperation} (otherwise conversion back to native form fails). The
     *     list size is returned to the native caller.
     */
    protected abstract List<Operand<?>> apply(
            Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs);

    /**
     * Native callback entry point: wraps the native arguments as Java objects, runs
     * {@link #apply}, and stores a pointer to the resulting native gradient outputs in
     * {@code nativeGradOutputsPtr}.
     *
     * <p>The whole call runs inside a {@link PointerScope} that is closed before returning, whether
     * normally or by exception. The gradient output buffer is allocated with
     * {@link Pointer#malloc} and is not freed here. NOTE(review): presumably the native caller
     * takes ownership of it; confirm.
     *
     * @param nativeGraphId id used to look up the owning {@link Graph}
     * @param nativeScope native scope forwarded to {@link #apply}
     * @param nativeOperation the operation whose gradient is requested
     * @param nativeGradInputs array of {@code nativeGradInputsLength} incoming gradients
     * @param nativeGradInputsLength number of entries in {@code nativeGradInputs}
     * @param nativeGradOutputsPtr receives the pointer to the newly allocated gradient outputs
     * @return the number of gradient outputs written
     */
    public int call(TFJ_GraphId nativeGraphId, TFJ_Scope nativeScope, TF_Operation nativeOperation,
            TF_Output nativeGradInputs, int nativeGradInputsLength, PointerPointer nativeGradOutputsPtr) {
        try (PointerScope callScope = new PointerScope()) {
            Graph graph = Graph.findGraph(nativeGraphId);
            GraphOperation operation = new GraphOperation(graph, nativeOperation);
            List<Output<?>> gradInputs =
                    fromNativeOutputs(graph, nativeGradInputs, nativeGradInputsLength);

            List<Operand<?>> gradOutputs;
            graph.setDangerousGradientBuilder(true);
            try {
                gradOutputs = apply(graph, nativeScope, operation, gradInputs);
            } finally {
                // Reset even when apply() throws, so the graph is never left with the flag set.
                graph.setDangerousGradientBuilder(false);
            }

            nativeGradOutputsPtr.put(toNativeOutputs(gradOutputs));
            return gradOutputs.size();
        }
    }

    /**
     * Wraps {@code length} consecutive native {@code TF_Output} entries as Java {@link Output}s
     * belonging to {@code graph}.
     *
     * <p>Note: this moves the position of {@code nativeOutputs} via {@code position(i)}; when
     * {@code length > 0} it is left at {@code length - 1}.
     *
     * @param graph the graph the wrapped operations belong to
     * @param nativeOutputs pointer to the first native entry
     * @param length number of entries to read
     * @return a new mutable list with one {@link Output} per native entry, in order
     */
    private static List<Output<?>> fromNativeOutputs(Graph graph, TF_Output nativeOutputs, int length) {
        List<Output<?>> outputs = new ArrayList<>(length);
        for (int i = 0; i < length; i++) {
            TF_Output nativeOutput = nativeOutputs.position(i);
            outputs.add(new Output<>(new GraphOperation(graph, nativeOutput.oper()), nativeOutput.index()));
        }
        return outputs;
    }

    /**
     * Copies {@code operands} into a newly {@link Pointer#malloc malloc}-ed array of native
     * {@code TF_Output} structs (operation handle plus output index per entry).
     *
     * @param operands the gradient outputs; each must be non-null and backed by a
     *     {@link GraphOperation}
     * @return pointer to the first element of the allocated array (not freed by this class)
     * @throws ClassCastException if an operand's operation is not a {@link GraphOperation}
     */
    private static TF_Output toNativeOutputs(List<Operand<?>> operands) {
        int count = operands.size();
        // Widen to long before multiplying so the byte count cannot overflow int.
        TF_Output nativeOutputs =
                new TF_Output(Pointer.malloc((long) count * Pointer.sizeof(TF_Output.class)));
        for (int i = 0; i < count; i++) {
            Output<?> output = operands.get(i).asOutput();
            TF_Output slot = nativeOutputs.getPointer(i);
            slot.oper(((GraphOperation) output.op()).getUnsafeNativeHandle());
            slot.index(output.index());
        }
        return nativeOutputs;
    }
}
