package org.nd4j.autodiff.samediff;

import com.google.flatbuffers.FlatBufferBuilder;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.NonNull;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutionMode;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.ListenerVariables;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.api.OutAndGrad;
import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder;
import org.nd4j.autodiff.samediff.array.ThreadSafeArrayHolder;
import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.TrainingSession;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDBitwise;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDImage;
import org.nd4j.autodiff.samediff.ops.SDLinalg;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.autodiff.util.SameDiffUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.common.util.ND4JFileUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.graph.UpdaterState;
import org.nd4j.imports.VariableUtils;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4UnresolvedOutputVariables;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Sets;
import org.nd4j.shade.guava.collect.Table;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.GraphDef;

/* loaded from: input_file:org/nd4j/autodiff/samediff/SameDiff.class */
public class SameDiff extends SDBaseOps {
    private static final Logger log = LoggerFactory.getLogger(SameDiff.class);
    protected static final String GRAD_FN_KEY = "grad";
    // All variables of this graph keyed by name (LinkedHashMap, so insertion-ordered - see constructor).
    private final Map<String, Variable> variables;
    // All ops of this graph keyed by op name (LinkedHashMap, so insertion-ordered - see constructor).
    private final Map<String, SameDiffOp> ops;
    // One InferenceSession per thread, keyed by Thread#getId(); ARRAY-type values are read from the
    // calling thread's session (see getArrForVarName / arrayAlreadyExistsForVarName).
    private final Map<Long, InferenceSession> sessions;
    // Arrays for CONSTANT-type variables; may be swapped out via setArrayHolders.
    private ArrayHolder constantArrays;
    // Arrays for VARIABLE-type variables; may be swapped out via setArrayHolders.
    private ArrayHolder variablesArrays;
    // Placeholder arrays, kept separately per thread: outer key is Thread#getId(), inner key is variable name.
    private final Map<Long, Map<String, INDArray>> placeholdersPerThread;
    private final List<String> lossVariables;
    // Listeners managed by setListeners / addListeners / getListeners.
    private final List<Listener> listeners;
    // Currently open name scopes, outermost first; currentNameScope() joins their names with "/".
    private final List<NameScope> nameScopes;
    private List<String> outputs;
    private TrainingConfig trainingConfig;
    private boolean initializedTraining;
    private Map<String, GradientUpdater> updaterMap;
    private int variableId;
    // Op namespaces; each one is constructed with this SameDiff instance in the constructor.
    public final SDMath math;
    public final SDRandom random;
    public final SDNN nn;
    public final SDCNN cnn;
    public final SDRNN rnn;
    public final SDLoss loss;
    public final SDImage image;
    public final SDBitwise bitwise;
    public final SDLinalg linalg;
    // Sub-graph SameDiff instances keyed by name (see putSubFunction); associateArrayWithVariable
    // propagates array associations into them for variables of the same name.
    private Map<String, SameDiff> sameDiffFunctionInstances;
    private Table<String, String, String> fieldVariableResolutionMapping;
    private transient AtomicBoolean wasRegistered;
    // Toggled by enableDebugMode() / disableDebugging().
    private boolean debugMode;
    // Stack of argument interceptors; addArgsFor uses the top-most interceptor that is not paused.
    private Stack<ArgumentInterceptor> argumentInterceptors;
    // Interceptors temporarily disabled via pauseArgumentInterceptor / unpauseArgumentInterceptor.
    private Set<ArgumentInterceptor> pausedArgumentInterceptors;
    private Set<String> blockNames;
    boolean logExecution;
    private SameDiff parent;
    private SameDiff child;

    /**
     * Returns the math op namespace bound to this instance.
     *
     * @return the same object as the public {@code math} field
     */
    public SDMath math() {
        return math;
    }

    /**
     * Returns the random op namespace bound to this instance.
     *
     * @return the same object as the public {@code random} field
     */
    public SDRandom random() {
        return random;
    }

    /**
     * Returns the neural-network op namespace bound to this instance.
     *
     * @return the same object as the public {@code nn} field
     */
    public SDNN nn() {
        return nn;
    }

    /**
     * Returns the CNN op namespace bound to this instance.
     *
     * @return the same object as the public {@code cnn} field
     */
    public SDCNN cnn() {
        return cnn;
    }

    /**
     * Returns the RNN op namespace bound to this instance.
     *
     * @return the same object as the public {@code rnn} field
     */
    public SDRNN rnn() {
        return rnn;
    }

    /**
     * Returns the loss-function op namespace bound to this instance.
     *
     * @return the same object as the public {@code loss} field
     */
    public SDLoss loss() {
        return loss;
    }

    /**
     * Returns the image op namespace bound to this instance.
     *
     * @return the same object as the public {@code image} field
     */
    public SDImage image() {
        return image;
    }

    /**
     * Returns the bitwise op namespace bound to this instance.
     *
     * @return the same object as the public {@code bitwise} field
     */
    public SDBitwise bitwise() {
        return bitwise;
    }

    /**
     * Returns the linear-algebra op namespace bound to this instance.
     *
     * @return the same object as the public {@code linalg} field
     */
    public SDLinalg linalg() {
        return linalg;
    }

    /**
     * Turns debug mode off.
     *
     * @return this instance, for call chaining
     */
    public SameDiff disableDebugging() {
        debugMode = false;
        return this;
    }

    /**
     * Turns debug mode on.
     *
     * @return this instance, for call chaining
     */
    public SameDiff enableDebugMode() {
        debugMode = true;
        return this;
    }

    /**
     * Replaces every registered listener with the given ones.
     *
     * @param listenerArr the new listeners; existing listeners are discarded first
     */
    public void setListeners(Listener... listenerArr) {
        // Drop the old set, then delegate to the varargs add so that any override of it still applies.
        listeners.clear();
        addListeners(listenerArr);
    }

    /**
     * Replaces every registered listener with the listeners in {@code collection}.
     *
     * @param collection the new listeners; existing listeners are discarded first
     */
    public void setListeners(Collection<? extends Listener> collection) {
        // Drop the old set, then delegate to the collection add so that any override of it still applies.
        listeners.clear();
        addListeners(collection);
    }

    /**
     * Appends the given listeners to the registered ones.
     *
     * @param listenerArr listeners to add, in order
     */
    public void addListeners(Listener... listenerArr) {
        List<Listener> toAdd = Arrays.asList(listenerArr);
        addListeners(toAdd);
    }

    /**
     * Appends all listeners in {@code collection} to the registered ones.
     *
     * @param collection listeners to add, in iteration order
     */
    public void addListeners(Collection<? extends Listener> collection) {
        listeners.addAll(collection);
    }

    /**
     * Returns the registered listeners.
     *
     * @return the internal, mutable listener list (not a copy) - changes to it affect this instance
     */
    public List<Listener> getListeners() {
        return listeners;
    }

    /**
     * Installs new array holders for VARIABLE-type and CONSTANT-type arrays.
     *
     * @param arrayHolder  new holder for VARIABLE-type arrays; must not be null
     * @param arrayHolder2 new holder for CONSTANT-type arrays; must not be null
     * @param z            when true, each new holder is first initialised (via {@code initFrom}) from the
     *                     holder it replaces
     * @throws NullPointerException if either holder is null
     */
    public void setArrayHolders(@NonNull ArrayHolder arrayHolder, @NonNull ArrayHolder arrayHolder2, boolean z) {
        if (arrayHolder == null) {
            throw new NullPointerException("variableArrayHolder is marked non-null but is null");
        }
        if (arrayHolder2 == null) {
            throw new NullPointerException("constantArrayHolder is marked non-null but is null");
        }
        ArrayHolder previousVariables = variablesArrays;
        ArrayHolder previousConstants = constantArrays;
        if (z) {
            arrayHolder.initFrom(previousVariables);
            arrayHolder2.initFrom(previousConstants);
        }
        variablesArrays = arrayHolder;
        constantArrays = arrayHolder2;
    }

    /**
     * Returns the full path of the currently open name scopes.
     *
     * @return the scope names joined with {@code "/"}, outermost first, or {@code null} if no scope is open
     */
    public String currentNameScope() {
        if (nameScopes.isEmpty()) {
            return null;
        }
        StringBuilder path = new StringBuilder();
        for (int idx = 0; idx < nameScopes.size(); idx++) {
            if (idx > 0) {
                path.append('/');
            }
            path.append(nameScopes.get(idx).getName());
        }
        return path.toString();
    }

    /**
     * Prefixes {@code str} with the current name scope, unless it already carries that prefix.
     *
     * @param str the unscoped (or already scoped) name
     * @return {@code str} unchanged if no scope is open or it already starts with {@code "<scope>/"};
     *         otherwise {@code "<scope>/" + str}
     */
    protected String nameWithScope(String str) {
        String scope = currentNameScope();
        if (scope == null) {
            return str;
        }
        String prefix = scope + "/";
        return str.startsWith(prefix) ? str : prefix + str;
    }

    /**
     * Pushes {@code nameScope} as the new innermost open scope.
     *
     * @param nameScope the scope to open
     */
    void addNameScope(NameScope nameScope) {
        nameScopes.add(nameScope);
    }

    /**
     * Closes the innermost open name scope.
     *
     * <p>Scopes must be closed in reverse order of opening. (Decompiler note: this method is reported to have
     * been package-private originally.)
     *
     * @param nameScope the scope expected to be innermost
     * @throws IllegalStateException (via Preconditions) if no scope is open or {@code nameScope} is not innermost
     */
    public void closeNameScope(NameScope nameScope) {
        Preconditions.checkState(!nameScopes.isEmpty(), "Cannot close name scope: no name scopes are currently defined");
        int innermost = nameScopes.size() - 1;
        Preconditions.checkState(nameScopes.get(innermost).equals(nameScope), "Cannot close name scope %s: Name scopes must be closed in order. Current name scopes: \"%s\"", nameScope, currentNameScope());
        nameScopes.remove(innermost);
    }

    /**
     * Opens a new name scope nested inside the current one.
     *
     * @param str name of the new scope
     * @return the opened scope; close it with {@link #closeNameScope(NameScope)}
     */
    public NameScope withNameScope(String str) {
        NameScope opened = new NameScope(this, str);
        addNameScope(opened);
        return opened;
    }

    /**
     * Returns every op whose name lies inside the given name scope.
     *
     * <p>An op is in scope {@code s} when its name equals {@code s} or starts with {@code s + "/"}. A bare
     * {@code startsWith(s)} test would also match sibling scopes that merely share a prefix (e.g. scope
     * {@code "layer1"} would match op {@code "layer10/w"}). A scope name that already ends in {@code "/"} is
     * used as the prefix directly, and an empty scope name matches every op.
     *
     * @param nameScope the scope to search
     * @return matching ops, in the iteration order of the op map
     */
    public List<SameDiffOp> getOpsInScope(NameScope nameScope) {
        String scope = nameScope.getName();
        String prefix = scope.endsWith("/") ? scope : scope + "/";
        List<SameDiffOp> result = new ArrayList<>();
        for (SameDiffOp op : ops.values()) {
            String opName = op.getName();
            if (scope.isEmpty() || opName.equals(scope) || opName.startsWith(prefix)) {
                result.add(op);
            }
        }
        return result;
    }

    /**
     * Convenience overload of {@link #getOpsInScope(NameScope)} taking the scope name directly.
     * The temporary NameScope created here is not opened on this instance.
     *
     * @param str the scope name
     * @return matching ops
     */
    public List<SameDiffOp> getOpsInScope(String str) {
        NameScope scope = new NameScope(this, str);
        return getOpsInScope(scope);
    }

    /**
     * Returns every variable whose name lies inside the given name scope.
     *
     * <p>A variable is in scope {@code s} when its name equals {@code s} or starts with {@code s + "/"}. A bare
     * {@code startsWith(s)} test would also match sibling scopes that merely share a prefix (e.g. scope
     * {@code "layer1"} would match variable {@code "layer10/w"}). A scope name that already ends in {@code "/"}
     * is used as the prefix directly, and an empty scope name matches every variable.
     *
     * @param nameScope the scope to search
     * @return matching variables, in the order returned by {@code variables()}
     */
    public List<SDVariable> getVariablesInScope(NameScope nameScope) {
        String scope = nameScope.getName();
        String prefix = scope.endsWith("/") ? scope : scope + "/";
        List<SDVariable> result = new ArrayList<>();
        for (SDVariable variable : variables()) {
            String varName = variable.name();
            if (scope.isEmpty() || varName.equals(scope) || varName.startsWith(prefix)) {
                result.add(variable);
            }
        }
        return result;
    }

    /**
     * Convenience overload of {@link #getVariablesInScope(NameScope)} taking the scope name directly.
     * The temporary NameScope created here is not opened on this instance.
     *
     * @param str the scope name
     * @return matching variables
     */
    public List<SDVariable> getVariablesInScope(String str) {
        NameScope scope = new NameScope(this, str);
        return getVariablesInScope(scope);
    }

    /**
     * Copies this graph's variables and ops into {@code sameDiff}.
     *
     * <p>Every variable is cloned and registered on {@code sameDiff} via {@code var(...)}. For non-ARRAY
     * variables that currently have an array, that array is associated with the new variable. Every op is
     * cloned with {@link FlatBuffersMapper#cloneViaSerialize}, and its inputs and outputs are registered on
     * {@code sameDiff}.
     *
     * @param sameDiff the instance that receives the copied graph
     * @return the last entry of {@code sameDiff.variables()} after copying (presumably fails if that is empty)
     */
    public SDVariable invokeGraphOn(SameDiff sameDiff) {
        // NOTE(review): hashMap is filled with i -> i below but is never read afterwards.
        HashMap hashMap = new HashMap();
        int i = 1;
        for (SDVariable sDVariable : variables()) {
            SDVariable clone = sDVariable.clone(this);
            SDVariable var = sameDiff.var(clone);
            // ARRAY-type variables are skipped: their values are computed, not stored.
            if (sDVariable.getVariableType() != VariableType.ARRAY && sDVariable.getArr() != null) {
                sameDiff.associateArrayWithVariable(sDVariable.getArr(), var);
            }
            hashMap.put(Integer.valueOf(i), Integer.valueOf(i));
            clone.setSameDiff(sameDiff);
            i++;
        }
        // Variable name -> insertion index in this graph; passed to cloneViaSerialize.
        HashMap hashMap2 = new HashMap();
        int i2 = 0;
        Iterator<Variable> it = this.variables.values().iterator();
        while (it.hasNext()) {
            int i3 = i2;
            i2++;
            hashMap2.put(it.next().getName(), Integer.valueOf(i3));
        }
        // NOTE(review): linkedHashMap collects the op clones but is never read or returned.
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        for (SameDiffOp sameDiffOp : this.ops.values()) {
            DifferentialFunction op = sameDiffOp.getOp();
            DifferentialFunction cloneViaSerialize = FlatBuffersMapper.cloneViaSerialize(this, op, hashMap2);
            cloneViaSerialize.setSameDiff(sameDiff);
            cloneViaSerialize.setOwnName(op.getOwnName());
            // NOTE(review): putOpForId is only reached when the op already exists in sameDiff, and in that
            // case putOpForId adds nothing (it returns early, or throws if the existing entry's op is null).
            // The condition looks inverted - confirm intent.
            if (sameDiff.opExists(op.getOwnName())) {
                sameDiff.putOpForId(op.getOwnName(), op);
            }
            linkedHashMap.put(op.getOwnName(), cloneViaSerialize);
            SDVariable[] args = op.args();
            SDVariable[] outputVariables = op.outputVariables();
            // Inputs are registered against the clone (addArgsFor creates the op entry if missing),
            // while outputs are registered against the original op; both share the same own name.
            sameDiff.addArgsFor(args, cloneViaSerialize);
            sameDiff.addOutgoingFor(outputVariables, op);
            for (SDVariable sDVariable2 : cloneViaSerialize.args()) {
                sDVariable2.setSameDiff(sameDiff);
            }
            for (SDVariable sDVariable3 : cloneViaSerialize.outputVariables()) {
                sDVariable3.setSameDiff(sameDiff);
            }
            // NOTE(review): this overwrites the entry created above with THIS graph's SameDiffOp, which
            // holds the original op rather than the clone - confirm that this is intended.
            sameDiff.ops.put(op.getOwnName(), sameDiffOp);
        }
        return sameDiff.variables().get(sameDiff.variables().size() - 1);
    }

    /**
     * Reports whether an op with the given name is registered in this graph.
     *
     * @param str the op name
     * @return true if the op map contains {@code str}
     */
    public boolean opExists(String str) {
        return ops.containsKey(str);
    }

    /**
     * Returns the op that produces the named variable.
     *
     * @param str the variable name; must exist in this graph
     * @return the producing op, or {@code null} if the variable has no producer or the producer (after
     *         {@link VariableUtils#stripVarSuffix}) is not registered
     * @throws IllegalStateException (via Preconditions) if no such variable exists
     */
    public DifferentialFunction getVariableOutputOp(String str) {
        Preconditions.checkState(variables.containsKey(str), "No variable with name \"%s\" found in graph", str);
        String producerName = variables.get(str).getOutputOfOp();
        if (producerName == null) {
            return null;
        }
        SameDiffOp producer = ops.get(VariableUtils.stripVarSuffix(producerName));
        return producer == null ? null : producer.getOp();
    }

    /**
     * Looks up an op by name.
     *
     * @param str the op name; must not be null
     * @return the op registered under {@code str}
     * @throws NullPointerException       if {@code str} is null
     * @throws ND4JIllegalStateException if no op is registered under {@code str}
     */
    public DifferentialFunction getOpById(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("id is marked non-null but is null");
        }
        if (!ops.containsKey(str)) {
            throw new ND4JIllegalStateException("No function with id " + str + " found!");
        }
        return ops.get(str).getOp();
    }

    /**
     * Registers {@code differentialFunction} under {@code str} if no op with that name exists yet.
     * An existing entry is never replaced.
     *
     * @param str                  op name
     * @param differentialFunction op to register
     * @throws ND4JIllegalStateException if an entry for {@code str} exists but holds a null op
     */
    public void putOpForId(String str, DifferentialFunction differentialFunction) {
        if (!ops.containsKey(str)) {
            ops.put(str, SameDiffOp.builder().name(str).op(differentialFunction).build());
            return;
        }
        // NOTE(review): the message says "already exists", yet this only fires when the existing entry's op
        // is null; an existing non-null entry is silently kept. Confirm this condition is intended.
        if (ops.get(str).getOp() == null) {
            throw new ND4JIllegalStateException("Function by id already exists!");
        }
    }

    /**
     * Returns the names of the variables used as inputs by the given op.
     *
     * @param differentialFunction the op; must not be null and must be registered in this graph
     * @return input variable names, or {@code null} if the op has no recorded inputs
     * @throws NullPointerException       if the op is null
     * @throws ND4JIllegalStateException if the op's name is not registered
     */
    public String[] getInputsForOp(@NonNull DifferentialFunction differentialFunction) {
        if (differentialFunction == null) {
            throw new NullPointerException("function is marked non-null but is null");
        }
        String opName = differentialFunction.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Unknown function instance id found: \"" + opName + "\"");
        }
        List<String> inputNames = ops.get(opName).getInputsToOp();
        return inputNames == null ? null : inputNames.toArray(new String[0]);
    }

    /**
     * Returns the names of the variables produced by the given op.
     *
     * @param differentialFunction the op; must be registered in this graph
     * @return output variable names, or {@code null} if the op has no recorded outputs
     * @throws ND4JIllegalStateException if the op's name is not registered
     */
    public String[] getOutputsForOp(DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        if (!ops.containsKey(opName)) {
            throw new ND4JIllegalStateException("Illegal function instance id found " + opName);
        }
        List<String> outputNames = ops.get(opName).getOutputsOfOp();
        return outputNames == null ? null : outputNames.toArray(new String[0]);
    }

    /**
     * Resolves the output variables of the given op.
     *
     * @param differentialFunction the op; must be registered in this graph
     * @return one entry per recorded output name, resolved via {@code getVariable}; an entry may be
     *         {@code null} if that name does not resolve (no check is done here)
     * @throws ND4JIllegalStateException if the op is unknown or has no recorded outputs
     */
    public SDVariable[] getOutputVariablesForOp(DifferentialFunction differentialFunction) {
        String[] outputNames = getOutputsForOp(differentialFunction);
        if (outputNames == null) {
            // Previously reported "No inputs found", which pointed at the wrong side of the op.
            throw new ND4JIllegalStateException("No outputs found for function " + differentialFunction);
        }
        SDVariable[] result = new SDVariable[outputNames.length];
        for (int i = 0; i < outputNames.length; i++) {
            result[i] = getVariable(outputNames[i]);
        }
        return result;
    }

    /**
     * Resolves the input variables of the given op.
     *
     * @param differentialFunction the op; must be registered in this graph
     * @return one resolved variable per recorded input name
     * @throws ND4JIllegalStateException if the op is unknown, has no recorded inputs, or an input name
     *                                   does not resolve to a variable
     */
    public SDVariable[] getInputVariablesForOp(DifferentialFunction differentialFunction) {
        String[] inputNames = getInputsForOp(differentialFunction);
        if (inputNames == null) {
            throw new ND4JIllegalStateException("No inputs found for function " + differentialFunction);
        }
        SDVariable[] resolved = new SDVariable[inputNames.length];
        int idx = 0;
        for (String inputName : inputNames) {
            SDVariable input = getVariable(inputName);
            if (input == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            resolved[idx++] = input;
        }
        return resolved;
    }

    /**
     * Stores {@code iNDArray} as the value of the named variable.
     *
     * <p>Constants go to the constant holder and VARIABLE-type variables to the variable holder. Placeholders
     * are stored for the calling thread only. No shape or datatype validation is done here.
     *
     * @param str      variable name; must exist
     * @param iNDArray the array to store; must not be null
     * @throws NullPointerException          if either argument is null
     * @throws IllegalStateException         (via Preconditions) if the variable does not exist
     * @throws UnsupportedOperationException for any other variable type (e.g. ARRAY)
     */
    public void setArrayForVariable(@NonNull String str, @NonNull INDArray iNDArray) {
        if (str == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        if (iNDArray == null) {
            throw new NullPointerException("arr is marked non-null but is null");
        }
        Preconditions.checkState(variables.containsKey(str), "No variable with name \"%s\" exists", str);
        SDVariable target = getVariable(str);
        if (target.isConstant()) {
            constantArrays.setArray(str, iNDArray);
        } else if (target.getVariableType() == VariableType.VARIABLE) {
            variablesArrays.setArray(str, iNDArray);
        } else if (target.isPlaceHolder()) {
            Long threadId = Thread.currentThread().getId();
            Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(threadId);
            if (threadPlaceholders == null) {
                threadPlaceholders = new HashMap<>();
                placeholdersPerThread.put(threadId, threadPlaceholders);
            }
            threadPlaceholders.put(str, iNDArray);
        } else {
            throw new UnsupportedOperationException("Cannot set variable of type " + target.getVariableType() + " using this method");
        }
    }

    /**
     * Reports whether an array is currently available for the named variable.
     *
     * <p>VARIABLE and CONSTANT types check their array holders. ARRAY checks the calling thread's inference
     * session (outer frame, iteration 0). PLACEHOLDER checks the calling thread's placeholder values.
     *
     * @param str the variable name
     * @return true if an array is present for the variable
     * @throws RuntimeException for an unrecognised variable type
     */
    public boolean arrayAlreadyExistsForVarName(String str) {
        SDVariable target = getVariable(str);
        Long threadId = Thread.currentThread().getId();
        switch (target.getVariableType()) {
            case VARIABLE:
                return variablesArrays.hasArray(str);
            case CONSTANT:
                return constantArrays.hasArray(str);
            case ARRAY: {
                InferenceSession session = sessions.get(threadId);
                return session != null && session.contains(str, AbstractSession.OUTER_FRAME, 0, null);
            }
            case PLACEHOLDER: {
                Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(threadId);
                return threadPlaceholders != null && threadPlaceholders.containsKey(str);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + target.getVariableType());
        }
    }

    /**
     * Returns the array currently associated with the named variable.
     *
     * @param str the variable name; must exist
     * @return the array; {@code null} for a CONSTANT without an array or a PLACEHOLDER not set on this thread
     * @throws NullPointerException          if {@code str} is null
     * @throws IllegalStateException         (via Preconditions) if the variable does not exist
     * @throws UnsupportedOperationException for an ARRAY variable when this thread has no inference session
     * @throws RuntimeException              for an unrecognised variable type
     */
    public INDArray getArrForVarName(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        Preconditions.checkState(variables.containsKey(str), "No variable found with name \"%s\"", str);
        SDVariable target = variables.get(str).getVariable();
        Long threadId = Thread.currentThread().getId();
        switch (target.getVariableType()) {
            case VARIABLE:
                return variablesArrays.getArray(str);
            case CONSTANT:
                return constantArrays.hasArray(str) ? constantArrays.getArray(str) : null;
            case PLACEHOLDER: {
                Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(threadId);
                if (threadPlaceholders == null || !threadPlaceholders.containsKey(str)) {
                    return null;
                }
                return threadPlaceholders.get(str);
            }
            case ARRAY: {
                InferenceSession session = sessions.get(threadId);
                if (session == null) {
                    throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
                }
                return session.get(str, AbstractSession.OUTER_FRAME, 0, null, false);
            }
            default:
                throw new RuntimeException("Unknown variable type: " + target.getVariableType());
        }
    }

    /**
     * Name-based overload of {@link #associateArrayWithVariable(INDArray, SDVariable)}.
     *
     * @param iNDArray the array to associate
     * @param str      name of an existing variable; must not be null
     * @throws NullPointerException  if {@code str} is null
     * @throws IllegalStateException (via Preconditions) if no such variable exists
     */
    public void associateArrayWithVariable(INDArray iNDArray, @NonNull String str) {
        if (str == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        boolean exists = variables.containsKey(str);
        Preconditions.checkState(exists, "Cannot associate array with variable \"%s\": variable \"%s\" does not exist in this SameDiff instance", str, str);
        SDVariable target = getVariable(str);
        associateArrayWithVariable(iNDArray, target);
    }

    /**
     * Associates {@code iNDArray} with {@code sDVariable}, storing it according to the variable's type.
     *
     * <p>The array is cast to the variable's datatype if needed and detached if it is attached. It is then
     * stored in the variable or constant holder, or (for placeholders, after a shape check) in the calling
     * thread's placeholder map. An InferenceSession is created for the calling thread if none exists yet.
     * Finally, the association is repeated on every registered sub-function SameDiff that has a variable
     * with the same name.
     *
     * @param iNDArray   the array; must not be null
     * @param sDVariable the target variable; must not be null
     * @throws ND4JIllegalArgumentException  if either argument is null
     * @throws UnsupportedOperationException for ARRAY-type variables
     * @throws IllegalStateException         for a placeholder shape mismatch, or an unknown variable type
     */
    public void associateArrayWithVariable(INDArray iNDArray, SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new ND4JIllegalArgumentException("Variable must not be null!");
        }
        if (iNDArray == null) {
            throw new ND4JIllegalArgumentException("Array must not be null");
        }
        INDArray array = iNDArray;
        DataType varType = sDVariable.dataType();
        if (varType != array.dataType()) {
            array = array.castTo(varType);
        }
        Preconditions.checkState(varType == array.dataType(), "Variable \"%s\" has datatype %s: cannot associate array with type %s with this variable", sDVariable.name(), varType, array.dataType());

        Long threadId = Thread.currentThread().getId();
        if (sessions.get(threadId) == null) {
            sessions.put(threadId, new InferenceSession(this));
        }
        if (array.isAttached()) {
            array = array.detach();
        }

        String varName = sDVariable.name();
        switch (sDVariable.getVariableType()) {
            case VARIABLE:
                variablesArrays.setArray(varName, array);
                break;
            case ARRAY:
                throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for this type of variable is calculated ");
            case CONSTANT:
                constantArrays.setArray(varName, array);
                break;
            case PLACEHOLDER: {
                long[] expectedShape = sDVariable.placeholderShape();
                Preconditions.checkState(expectedShape == null || Shape.shapeMatchesPlaceholder(expectedShape, array.shape()), "Invalid array shape: cannot associate an array with shape %ndShape with a placeholder of shape %s:shape is wrong rank or does not match on one or more dimensions", array, expectedShape);
                Map<String, INDArray> threadPlaceholders = placeholdersPerThread.get(threadId);
                if (threadPlaceholders == null) {
                    threadPlaceholders = new HashMap<>();
                    placeholdersPerThread.put(threadId, threadPlaceholders);
                }
                threadPlaceholders.put(varName, array);
                break;
            }
            default:
                throw new IllegalStateException("Unknown variable type: " + sDVariable.getVariableType());
        }

        // Mirror the association into sub-function graphs that have a same-named variable.
        if (sameDiffFunctionInstances != null) {
            for (SameDiff subGraph : sameDiffFunctionInstances.values()) {
                SDVariable mirrored = subGraph.getVariable(varName);
                if (mirrored != null) {
                    subGraph.associateArrayWithVariable(array, mirrored);
                }
            }
        }
    }

    /**
     * Replaces the stored array of a VARIABLE or CONSTANT type variable.
     *
     * <p>If {@code iNDArray} is a view, a copy ({@code dup()}) is stored instead of the view itself.
     *
     * @param iNDArray   the new array; must not be null
     * @param sDVariable the target variable; must not be null and must be VARIABLE or CONSTANT type
     * @throws NullPointerException  if either argument is null
     * @throws IllegalStateException (via Preconditions) for any other variable type
     */
    public void assignArray(@NonNull INDArray iNDArray, @NonNull SDVariable sDVariable) {
        if (iNDArray == null) {
            throw new NullPointerException("arr is marked non-null but is null");
        }
        if (sDVariable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        VariableType type = sDVariable.getVariableType();
        // Message previously misspelled VARIABLE as "VARIBLE".
        Preconditions.checkState(type == VariableType.VARIABLE || type == VariableType.CONSTANT, "assignArray method can only be used with VARIABLE or CONSTANT type SDVariables, variable \"%s\" has type %s", sDVariable.name(), type);
        INDArray toStore = iNDArray.isView() ? iNDArray.dup() : iNDArray;
        if (type == VariableType.VARIABLE) {
            variablesArrays.setArray(sDVariable.name(), toStore);
        } else {
            constantArrays.setArray(sDVariable.name(), toStore);
        }
    }

    /**
     * Registers a sub-function SameDiff under {@code str}.
     *
     * <p>Re-registering the same instance under the same name is allowed; registering a different instance
     * under an existing name is rejected.
     *
     * @param str      the sub-function name
     * @param sameDiff the sub-graph instance
     * @throws ND4JIllegalStateException if a different instance is already registered under {@code str}
     */
    public void putSubFunction(String str, SameDiff sameDiff) {
        boolean nameTaken = sameDiffFunctionInstances.containsKey(str);
        if (nameTaken && sameDiffFunctionInstances.get(str) != sameDiff) {
            throw new ND4JIllegalStateException("Unable to replace samediff namespace. Please choose another opName");
        }
        sameDiffFunctionInstances.put(str, sameDiff);
    }

    /**
     * Builds a fresh name-to-variable map of all variables in this graph.
     *
     * @return a new insertion-ordered map (changes to it do not affect this instance)
     */
    public Map<String, SDVariable> variableMap() {
        Map<String, SDVariable> byName = new LinkedHashMap<>();
        for (Variable entry : variables.values()) {
            String key = entry.getName();
            byName.put(key, entry.getVariable());
        }
        return byName;
    }

    /**
     * Returns the names of all registered sub-function SameDiff instances.
     *
     * @return a live key-set view of the sub-function map (not a copy)
     */
    public Collection<String> definedFunctionNames() {
        return sameDiffFunctionInstances.keySet();
    }

    /**
     * Creates an empty graph with no variables, ops, listeners or name scopes.
     *
     * <p>Private: instances are presumably obtained through a factory method elsewhere in this class.
     * Fields not assigned here (e.g. {@code outputs}, {@code trainingConfig}, {@code updaterMap}) start at
     * their Java defaults.
     */
    private SameDiff() {
        // The base class receives null; sd is pointed at this instance further below.
        super(null);
        this.variables = new LinkedHashMap();
        this.ops = new LinkedHashMap();
        // Per-thread maps keyed by Thread#getId().
        this.sessions = new ConcurrentHashMap();
        this.constantArrays = new ThreadSafeArrayHolder(true);
        this.variablesArrays = new ThreadSafeArrayHolder(true);
        this.placeholdersPerThread = new ConcurrentHashMap();
        this.lossVariables = new ArrayList();
        this.listeners = new ArrayList();
        this.nameScopes = new ArrayList();
        this.variableId = 0;
        // NOTE(review): the namespaces receive `this` before this.sd is assigned below; confirm their
        // constructors do not read sd.
        this.math = new SDMath(this);
        this.random = new SDRandom(this);
        this.nn = new SDNN(this);
        this.cnn = new SDCNN(this);
        this.rnn = new SDRNN(this);
        this.loss = new SDLoss(this);
        this.image = new SDImage(this);
        this.bitwise = new SDBitwise(this);
        this.linalg = new SDLinalg(this);
        this.wasRegistered = new AtomicBoolean(false);
        this.argumentInterceptors = new Stack<>();
        this.pausedArgumentInterceptors = new HashSet();
        this.blockNames = new HashSet();
        this.logExecution = true;
        this.sd = this;
        this.sameDiffFunctionInstances = new LinkedHashMap();
        this.fieldVariableResolutionMapping = HashBasedTable.create();
    }

    /**
     * Binds {@code x} to this SameDiff instance if it currently belongs to a different one.
     *
     * @param x   the variable; must not be null
     * @param <X> concrete SDVariable subtype, preserved in the return type
     * @return {@code x} itself
     * @throws NullPointerException (via Preconditions) if {@code x} is null
     */
    public <X extends SDVariable> X setupFunction(X x) {
        Preconditions.checkNotNull(x, "Passed in function must not be null!");
        // No instanceof test needed: a non-null X is always an SDVariable by its type bound.
        if (x.getSameDiff() != this) {
            x.setSameDiff(this);
        }
        return x;
    }

    /**
     * Variable-based overload of {@link #addOutgoingFor(String[], DifferentialFunction)}; records the given
     * variables (by name) as the outputs of the op.
     *
     * @param sDVariableArr        output variables, in output order
     * @param differentialFunction the producing op
     */
    public void addOutgoingFor(SDVariable[] sDVariableArr, DifferentialFunction differentialFunction) {
        String[] outputNames = new String[sDVariableArr.length];
        int idx = 0;
        for (SDVariable output : sDVariableArr) {
            outputNames[idx++] = output.name();
        }
        addOutgoingFor(outputNames, differentialFunction);
    }

    /**
     * Records {@code strArr} as the outputs of {@code differentialFunction} and marks each named variable as
     * produced by that op.
     *
     * <p>All arguments are validated before anything is modified, so a failure leaves the graph unchanged.
     * The op's output list is a fixed-size view backed by {@code strArr} ({@link Arrays#asList}).
     *
     * @param strArr               output variable names, in output order; no element may be null
     * @param differentialFunction the producing op; must already be registered in this graph
     * @throws ND4JIllegalStateException if the op has no name or is not registered, outputs were already
     *                                   declared, {@code strArr} or one of its elements is null, or a name
     *                                   does not refer to an existing variable
     */
    public void addOutgoingFor(String[] strArr, DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        SameDiffOp op = ops.get(opName);
        if (op == null) {
            // Previously an unregistered op surfaced as a bare NullPointerException.
            throw new ND4JIllegalStateException("No op with name \"" + opName + "\" exists in this SameDiff instance: cannot declare outputs for " + differentialFunction);
        }
        List<String> existingOutputs = op.getOutputsOfOp();
        if (existingOutputs != null && !existingOutputs.isEmpty()) {
            throw new ND4JIllegalStateException("Outgoing arguments already declared for " + differentialFunction);
        }
        if (strArr == null) {
            throw new ND4JIllegalStateException("Var names can not be null!");
        }
        for (String str : strArr) {
            if (str == null) {
                throw new ND4JIllegalStateException("Variable name elements can not be null!");
            }
            // Check before mutating: previously an unknown name threw NPE after the op's outputs were set.
            if (!variables.containsKey(str)) {
                throw new ND4JIllegalStateException("Cannot declare output \"" + str + "\" for op \"" + opName + "\": no variable with that name exists");
            }
        }
        op.setOutputsOfOp(Arrays.asList(strArr));
        for (String str : strArr) {
            variables.get(str).setOutputOfOp(opName);
        }
    }

    /**
     * Pushes an argument interceptor onto the interceptor stack; it becomes the first candidate consulted
     * by addArgsFor.
     *
     * @param argumentInterceptor the interceptor; must not be null
     * @throws NullPointerException if the interceptor is null
     */
    public void addArgumentInterceptor(@NonNull ArgumentInterceptor argumentInterceptor) {
        if (argumentInterceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        ArgumentInterceptor pushed = argumentInterceptors.push(argumentInterceptor);
    }

    /**
     * Reports whether the given interceptor is currently paused.
     *
     * @param argumentInterceptor the interceptor; must not be null
     * @return true if it is in the paused set
     * @throws NullPointerException if the interceptor is null
     */
    private boolean isArgumentInterceptorPaused(@NonNull ArgumentInterceptor argumentInterceptor) {
        if (argumentInterceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        boolean paused = pausedArgumentInterceptors.contains(argumentInterceptor);
        return paused;
    }

    /**
     * Finds the interceptor that addArgsFor should apply.
     *
     * @return the top-most interceptor on the stack that is not paused, or {@code null} if the stack is empty
     *         or every interceptor is paused
     */
    private ArgumentInterceptor getArgumentInterceptorToUse() {
        // Scan from the top of the stack (highest index) downwards.
        for (int idx = argumentInterceptors.size() - 1; idx >= 0; idx--) {
            ArgumentInterceptor candidate = argumentInterceptors.elementAt(idx);
            if (!isArgumentInterceptorPaused(candidate)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Pops the top-most argument interceptor; does nothing if none is registered.
     */
    public void removeArgumentInterceptor() {
        if (!argumentInterceptors.isEmpty()) {
            argumentInterceptors.pop();
        }
    }

    /**
     * Pauses the interceptor currently at the top of the stack.
     *
     * @throws java.util.EmptyStackException if no interceptor is registered
     */
    public void pauseArgumentInterceptor() {
        ArgumentInterceptor top = argumentInterceptors.peek();
        pausedArgumentInterceptors.add(top);
    }

    /**
     * Pauses a specific interceptor so that getArgumentInterceptorToUse skips it.
     *
     * @param argumentInterceptor the interceptor; must not be null
     * @throws NullPointerException if the interceptor is null
     */
    public void pauseArgumentInterceptor(@NonNull ArgumentInterceptor argumentInterceptor) {
        if (argumentInterceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        boolean added = pausedArgumentInterceptors.add(argumentInterceptor);
    }

    /**
     * Un-pauses the interceptor currently at the top of the stack.
     *
     * @throws java.util.EmptyStackException if no interceptor is registered
     */
    public void unpauseArgumentInterceptor() {
        ArgumentInterceptor top = argumentInterceptors.peek();
        pausedArgumentInterceptors.remove(top);
    }

    /**
     * Un-pauses a specific interceptor.
     *
     * @param argumentInterceptor the interceptor; must not be null
     * @throws NullPointerException if the interceptor is null
     */
    public void unpauseArgumentInterceptor(@NonNull ArgumentInterceptor argumentInterceptor) {
        if (argumentInterceptor == null) {
            throw new NullPointerException("interceptor is marked non-null but is null");
        }
        boolean removed = pausedArgumentInterceptors.remove(argumentInterceptor);
    }

    /**
     * Registers the given variable names as the inputs of {@code differentialFunction}.
     *
     * <p>If an argument interceptor is active, each name is first replaced (in place, in
     * {@code strArr}) by the name of the variable the interceptor returns. The op entry is
     * created on first use. Each input variable known to this graph then records the op as one
     * of its consumers, without duplicates.
     *
     * @param strArr               input variable names; rewritten in place when an interceptor is active
     * @param differentialFunction the function whose inputs are being set
     * @throws ND4JIllegalStateException if the function has no own name
     */
    public void addArgsFor(String[] strArr, DifferentialFunction differentialFunction) {
        ArgumentInterceptor interceptor = getArgumentInterceptorToUse();
        if (interceptor != null) {
            // The interceptor is paused while it runs, presumably so ops it creates are not
            // intercepted recursively. The finally block ensures it is always resumed, even if
            // intercept() throws.
            pauseArgumentInterceptor(interceptor);
            try {
                for (int i = 0; i < strArr.length; i++) {
                    strArr[i] = interceptor.intercept(getVariable(strArr[i])).name();
                }
            } finally {
                unpauseArgumentInterceptor(interceptor);
            }
        }
        String opName = differentialFunction.getOwnName();
        if (opName == null) {
            throw new ND4JIllegalStateException("Instance id can not be null. Function not initialized properly");
        }
        if (!this.ops.containsKey(opName)) {
            this.ops.put(opName, SameDiffOp.builder().name(opName).op(differentialFunction).build());
        }
        this.ops.get(opName).setInputsToOp(Arrays.asList(strArr));
        for (String inputName : strArr) {
            Variable input = this.variables.get(inputName);
            if (input == null) {
                continue;
            }
            List<String> consumers = input.getInputsForOp();
            if (consumers == null) {
                consumers = new ArrayList<>();
                input.setInputsForOp(consumers);
            }
            if (!consumers.contains(opName)) {
                consumers.add(opName);
            }
        }
    }

    /**
     * Variable-based convenience overload: converts each SDVariable to its name and delegates
     * to {@code addArgsFor(String[], DifferentialFunction)}.
     *
     * @param sDVariableArr        input variables; no element may be null
     * @param differentialFunction the function whose inputs are being set
     * @throws ND4JIllegalStateException if any element of {@code sDVariableArr} is null
     */
    public void addArgsFor(SDVariable[] sDVariableArr, DifferentialFunction differentialFunction) {
        String[] names = new String[sDVariableArr.length];
        int idx = 0;
        for (SDVariable v : sDVariableArr) {
            if (v == null) {
                throw new ND4JIllegalStateException("Found null variable at index " + idx);
            }
            names[idx++] = v.name();
        }
        addArgsFor(names, differentialFunction);
    }

    /**
     * Replaces the {@code i}-th input of {@code differentialFunction} with {@code sDVariable}.
     *
     * <p>The op's input list is updated and the new variable records the op as a consumer. The
     * old variable stops listing the op as a consumer only if it no longer appears among the
     * function's argument names.
     *
     * <p>All validation happens before any mutation, so a bad call leaves the graph unchanged.
     *
     * @param i                    index of the argument to replace; must be in {@code [0, args().length)}
     * @param sDVariable           replacement input; must be a variable of this SameDiff instance
     * @param differentialFunction function whose argument is replaced
     * @throws NullPointerException     if either reference argument is null
     * @throws IllegalArgumentException if {@code i} is out of range or the new variable is unknown
     */
    public void replaceArgFor(int i, @NonNull SDVariable sDVariable, @NonNull DifferentialFunction differentialFunction) {
        if (sDVariable == null) {
            throw new NullPointerException("newArg is marked non-null but is null");
        }
        if (differentialFunction == null) {
            throw new NullPointerException("function is marked non-null but is null");
        }
        String opName = differentialFunction.getOwnName();
        int numArgs = differentialFunction.args().length;
        Preconditions.checkArgument(i >= 0 && i < numArgs, "Index out of range: function %s only has %s args but you are trying to replace the argument at %s", opName, numArgs, i);
        String oldName = differentialFunction.arg(i).name();
        String newName = sDVariable.name();
        Variable newVar = this.variables.get(newName);
        Preconditions.checkArgument(newVar != null, "Cannot replace argument %s of function %s: variable \"%s\" does not exist in this SameDiff instance", i, opName, newName);

        SameDiffOp op = this.ops.get(opName);
        List<String> newInputs = new ArrayList<>(op.getInputsToOp());
        newInputs.set(i, newName);
        op.setInputsToOp(newInputs);

        List<String> newConsumers = newVar.getInputsForOp();
        if (newConsumers == null) {
            newConsumers = new ArrayList<>();
            newVar.setInputsForOp(newConsumers);
        }
        if (!newConsumers.contains(opName)) {
            newConsumers.add(opName);
        }

        // The old variable may still feed this op through another argument slot; only unlink it
        // when it is no longer among the function's argument names.
        List<String> oldConsumers = this.variables.get(oldName).getInputsForOp();
        if (oldConsumers != null && !ArrayUtils.contains(differentialFunction.argNames(), oldName)) {
            oldConsumers.remove(opName);
        }
    }

    /**
     * Checks whether the op registered for {@code differentialFunction} has at least one input.
     *
     * @param differentialFunction function to look up by its own name; it must already be registered
     * @return true if the op's input list is non-null and non-empty
     */
    public boolean hasArgs(DifferentialFunction differentialFunction) {
        SameDiffOp op = this.ops.get(differentialFunction.getOwnName());
        List<String> inputs = op.getInputsToOp();
        if (inputs == null) {
            return false;
        }
        return !inputs.isEmpty();
    }

    /**
     * Removes placeholder arrays from this instance and, recursively, from every function instance.
     *
     * @param z if true, clear the placeholders of all threads; otherwise only those of the calling thread
     */
    public void clearPlaceholders(boolean z) {
        if (!z) {
            this.placeholdersPerThread.remove(Long.valueOf(Thread.currentThread().getId()));
        } else {
            this.placeholdersPerThread.clear();
        }
        for (SameDiff sub : this.sameDiffFunctionInstances.values()) {
            sub.clearPlaceholders(z);
        }
    }

    /**
     * Drops the input arrays held by every op in this graph and in all function instances.
     *
     * <p>{@code Op} implementations have X (and Y, if set) cleared. {@code DynamicCustomOp}s
     * have their input-argument array set to null. Other function types are left untouched.
     */
    public void clearOpInputs() {
        for (SameDiffOp sdo : this.ops.values()) {
            DifferentialFunction fn = sdo.getOp();
            if (fn instanceof Op) {
                Op op = (Op) fn;
                op.setX(null);
                if (op.y() != null) {
                    op.setY(null);
                }
                continue;
            }
            if (fn instanceof DynamicCustomOp) {
                ((DynamicCustomOp) fn).setInputArguments((INDArray[]) null);
            }
        }
        for (SameDiff sub : this.sameDiffFunctionInstances.values()) {
            sub.clearOpInputs();
        }
    }

    /**
     * Returns a snapshot of all functions registered in this graph.
     *
     * @return a new array with the function of every op, in {@code ops} map iteration order
     */
    public DifferentialFunction[] ops() {
        List<DifferentialFunction> functions = new ArrayList<>(this.ops.size());
        for (SameDiffOp sdo : this.ops.values()) {
            functions.add(sdo.getOp());
        }
        return functions.toArray(new DifferentialFunction[0]);
    }

    /**
     * Hash code consistent with {@link #equals(Object)}.
     *
     * <p>{@code equals} compares the {@code variables} and {@code ops} maps. Hashing only the
     * {@code variables} map guarantees that equal instances have equal hash codes. The previous
     * version also mixed in {@code super.hashCode()}, presumably the identity hash, which gave
     * equal instances different hash codes and broke the Object contract.
     *
     * @return hash of the variables map, or 0 if it is null
     */
    @Override
    public int hashCode() {
        return this.variables != null ? this.variables.hashCode() : 0;
    }

    /**
     * Two SameDiff instances of the same runtime class are equal when both their variable maps
     * and their op maps are equal.
     *
     * @param obj object to compare with
     * @return true if {@code obj} is an equal SameDiff
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (obj == null || obj.getClass() != getClass()) {
            return false;
        }
        SameDiff other = (SameDiff) obj;
        if (!this.variables.equals(other.variables)) {
            return false;
        }
        return this.ops.equals(other.ops);
    }

    /**
     * Static factory for an empty SameDiff graph.
     *
     * @return a newly constructed instance
     */
    public static SameDiff create() {
        SameDiff instance = new SameDiff();
        return instance;
    }

    /**
     * Creates a copy of this graph by serializing to FlatBuffers (with {@code true} passed to
     * {@code asFlatBuffers}) and deserializing the result.
     *
     * @return the deserialized copy
     * @throws RuntimeException wrapping any IOException raised during deserialization
     */
    public SameDiff dup() {
        SameDiff copy;
        try {
            copy = fromFlatBuffers(asFlatBuffers(true));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return copy;
    }

    /**
     * Sums the element counts of every variable whose shape is known.
     *
     * <p>Variables with a null shape are skipped. The previous (decompiled) version referenced an
     * undefined local {@code r0} and did not compile. It also used {@code ArrayUtil.prod}, which
     * returns an int and can overflow for large variables; {@code prodLong} accumulates in long.
     *
     * @return total number of elements across all shaped variables
     */
    public long numElements() {
        long total = 0;
        for (SDVariable variable : variables()) {
            long[] shape = variable.getShape();
            if (shape != null) {
                total += ArrayUtil.prodLong(shape);
            }
        }
        return total;
    }

    /**
     * Lists the names of all placeholder variables in this graph.
     *
     * @return a new mutable list of placeholder names, in {@code variables} key order
     */
    public List<String> inputs() {
        List<String> placeholders = new ArrayList<>();
        for (String name : this.variables.keySet()) {
            if (!isPlaceHolder(name)) {
                continue;
            }
            placeholders.add(name);
        }
        return placeholders;
    }

    /**
     * Returns the output variable names last set via {@code setOutputs}.
     *
     * @return the stored list itself (not a copy); may be null
     */
    public List<String> outputs() {
        List<String> current = this.outputs;
        return current;
    }

    /**
     * Varargs form of {@code setOutputs(List)}. A null array clears the outputs.
     *
     * @param strArr output variable names, or null
     */
    public void setOutputs(String... strArr) {
        List<String> names = null;
        if (strArr != null) {
            names = Arrays.asList(strArr);
        }
        setOutputs(names);
    }

    /**
     * Sets the graph's output variable names after checking each one exists.
     *
     * <p>Fix: the precondition message contains a {@code %s} placeholder, but the offending name
     * was never passed, so users saw a literal "%s". The name is now supplied.
     *
     * @param list output variable names, or null to clear; the list is stored as-is (not copied)
     * @throws IllegalArgumentException if any name is not a variable of this instance
     */
    public void setOutputs(List<String> list) {
        if (list != null) {
            for (String name : list) {
                Preconditions.checkArgument(this.variables.containsKey(name), "Cannot set variable \"%s\" as an output: SameDiff instance does not contain a variable with this name", name);
            }
        }
        this.outputs = list;
    }

    /**
     * Returns a snapshot list of all variables in this graph.
     *
     * @return a new mutable list holding the values of {@code variableMap()}
     */
    public List<SDVariable> variables() {
        Collection<SDVariable> all = variableMap().values();
        return new ArrayList<>(all);
    }

    /**
     * Returns the names of the variables marked as losses.
     *
     * @return an unmodifiable view of the internal loss-variable list
     */
    public List<String> getLossVariables() {
        List<String> view = Collections.unmodifiableList(this.lossVariables);
        return view;
    }

    /**
     * Replaces the set of loss variables with the given names.
     *
     * <p>Each name is validated by {@code addLossVariable(String)}. Afterwards the cached
     * {@code "grad"} function instance is discarded so it is rebuilt for the new losses.
     *
     * @param strArr loss variable names; must not be null
     * @throws NullPointerException if {@code strArr} is null
     */
    public void setLossVariables(@NonNull String... strArr) {
        if (strArr == null) {
            throw new NullPointerException("lossVariableNames is marked non-null but is null");
        }
        this.lossVariables.clear();
        for (int idx = 0; idx < strArr.length; idx++) {
            addLossVariable(strArr[idx]);
        }
        this.sameDiffFunctionInstances.remove("grad");
    }

    /**
     * Variable-based overload of {@code setLossVariables(String...)}.
     *
     * @param sDVariableArr loss variables; must not be null
     * @throws NullPointerException if {@code sDVariableArr} is null
     */
    public void setLossVariables(@NonNull SDVariable... sDVariableArr) {
        if (sDVariableArr == null) {
            throw new NullPointerException("lossVariables is marked non-null but is null");
        }
        String[] names = new String[sDVariableArr.length];
        int idx = 0;
        for (SDVariable v : sDVariableArr) {
            names[idx++] = v.name();
        }
        setLossVariables(names);
    }

    /**
     * Marks an existing variable as a loss to be minimized, ignoring duplicates.
     *
     * @param str name of the variable; must exist, be floating point, and be of type ARRAY
     * @throws NullPointerException  if {@code str} is null
     * @throws IllegalStateException if any of the preconditions above fail
     */
    public void addLossVariable(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("variableName is marked non-null but is null");
        }
        Preconditions.checkState(hasVariable(str), "No variable with name \"%s\" exists", str);
        SDVariable candidate = getVariable(str);
        Preconditions.checkState(candidate.dataType().isFPType(), "Only floating point type variables can be marked as losses to be minimized. SDVariable \"%s\" has datatype %s", str, candidate.dataType());
        Preconditions.checkState(candidate.getVariableType() == VariableType.ARRAY, "Only ARRAY type SDVariables can be marked as losses to be minimized. SDVariable \"%s\" has variable type %s", str, candidate.getVariableType());
        if (!this.lossVariables.contains(str)) {
            this.lossVariables.add(str);
        }
    }

    /**
     * Variable-based overload of {@code addLossVariable(String)}.
     *
     * @param sDVariable variable to mark as a loss; must not be null
     * @throws NullPointerException if {@code sDVariable} is null
     */
    public void addLossVariable(@NonNull SDVariable sDVariable) {
        if (sDVariable != null) {
            addLossVariable(sDVariable.name());
            return;
        }
        throw new NullPointerException("variable is marked non-null but is null");
    }

    /**
     * Sets the configuration used by {@code fit} and related training methods.
     *
     * @param trainingConfig the configuration to store (null is accepted as-is)
     */
    public void setTrainingConfig(TrainingConfig trainingConfig) {
        TrainingConfig newConfig = trainingConfig;
        this.trainingConfig = newConfig;
    }

    /**
     * Trains for one epoch on a single DataSet.
     *
     * <p>The DataSet is wrapped in a singleton iterator. Epoch counting is disabled and there is
     * no validation.
     *
     * @param dataSet     training data; must not be null
     * @param listenerArr additional listeners; must not be null
     * @return training history
     */
    public History fit(@NonNull DataSet dataSet, @NonNull Listener... listenerArr) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet.toMultiDataSet());
        return fit(single, 1, false, null, 1, listenerArr);
    }

    /**
     * Trains for one epoch on a single MultiDataSet.
     *
     * <p>Epoch counting is disabled and there is no validation.
     *
     * @param multiDataSet training data; must not be null
     * @param listenerArr  additional listeners; must not be null
     * @return training history
     */
    public History fit(@NonNull MultiDataSet multiDataSet, @NonNull Listener... listenerArr) {
        if (multiDataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        MultiDataSetIterator single = new SingletonMultiDataSetIterator(multiDataSet);
        return fit(single, 1, false, null, 1, listenerArr);
    }

    /**
     * Trains on a DataSetIterator with optional validation, via the {@link FitConfig} builder.
     *
     * @param dataSetIterator  training data; must not be null
     * @param i                number of epochs
     * @param dataSetIterator2 validation data (passed through to {@code validate})
     * @param i2               validation frequency (passed through to {@code validate})
     * @param listenerArr      additional listeners; must not be null
     * @return training history
     */
    public History fit(@NonNull DataSetIterator dataSetIterator, int i, DataSetIterator dataSetIterator2, int i2, @NonNull Listener... listenerArr) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        return fit()
                .train(dataSetIterator, i)
                .validate(dataSetIterator2, i2)
                .listeners(listenerArr)
                .exec();
    }

    /**
     * Trains on a DataSetIterator for the given number of epochs, via the {@link FitConfig} builder.
     *
     * @param dataSetIterator training data; must not be null
     * @param i               number of epochs
     * @param listenerArr     additional listeners; must not be null
     * @return training history
     */
    public History fit(@NonNull DataSetIterator dataSetIterator, int i, @NonNull Listener... listenerArr) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        return fit()
                .train(dataSetIterator, i)
                .listeners(listenerArr)
                .exec();
    }

    /**
     * Trains on a MultiDataSetIterator with optional validation and epoch counting enabled.
     *
     * @param multiDataSetIterator  training data; must not be null
     * @param i                     number of epochs
     * @param multiDataSetIterator2 validation data, or null
     * @param i2                    validation frequency in epochs ({@code <= 0} means every epoch)
     * @param listenerArr           additional listeners; must not be null
     * @return training history
     */
    public History fit(@NonNull MultiDataSetIterator multiDataSetIterator, int i, MultiDataSetIterator multiDataSetIterator2, int i2, @NonNull Listener... listenerArr) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        final boolean countEpochs = true;
        return fit(multiDataSetIterator, i, countEpochs, multiDataSetIterator2, i2, listenerArr);
    }

    /**
     * Trains on a MultiDataSetIterator for the given number of epochs, via the {@link FitConfig} builder.
     *
     * @param multiDataSetIterator training data; must not be null
     * @param i                    number of epochs
     * @param listenerArr          additional listeners; must not be null
     * @return training history
     */
    public History fit(@NonNull MultiDataSetIterator multiDataSetIterator, int i, @NonNull Listener... listenerArr) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        return fit()
                .train(multiDataSetIterator, i)
                .listeners(listenerArr)
                .exec();
    }

    /**
     * Starts a fluent training configuration bound to this instance.
     *
     * @return a new {@link FitConfig} for this SameDiff
     */
    public FitConfig fit() {
        FitConfig config = new FitConfig(this);
        return config;
    }

    /**
     * Wraps the training and validation iterators for asynchronous prefetching where supported,
     * runs {@code fitHelper}, and always shuts the wrappers down afterwards.
     *
     * <p>Cleanup now uses nested try/finally. The decompiled version duplicated the shutdown calls
     * in the try and catch paths, so a failing shutdown re-ran shutdown on both iterators. It also
     * leaked the training wrapper if creating the validation wrapper threw. Now each wrapper is
     * shut down exactly once, and the validation wrapper is shut down even if the training
     * wrapper's shutdown throws.
     *
     * @param multiDataSetIterator  training data; must not be null
     * @param i                     number of epochs
     * @param z                     whether to count epochs / run epoch-end logic (see fitHelper)
     * @param multiDataSetIterator2 validation data, or null
     * @param i2                    validation frequency in epochs
     * @param listenerArr           additional listeners; must not be null
     * @return training history
     */
    protected synchronized History fit(@NonNull MultiDataSetIterator multiDataSetIterator, int i, boolean z, MultiDataSetIterator multiDataSetIterator2, int i2, @NonNull Listener... listenerArr) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("iter is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        boolean asyncTrainSupported = multiDataSetIterator.asyncSupported();
        boolean asyncValidationSupported = multiDataSetIterator2 != null && multiDataSetIterator2.asyncSupported();
        AsyncMultiDataSetIterator asyncTrain = null;
        AsyncMultiDataSetIterator asyncValidation = null;
        try {
            // Constructor arguments (3, true) are kept from the original; presumably prefetch
            // queue size and a workspace flag.
            if (asyncTrainSupported) {
                asyncTrain = new AsyncMultiDataSetIterator(multiDataSetIterator, 3, true);
            }
            if (asyncValidationSupported) {
                asyncValidation = new AsyncMultiDataSetIterator(multiDataSetIterator2, 3, true);
            }
            MultiDataSetIterator trainIter = asyncTrain != null ? asyncTrain : multiDataSetIterator;
            MultiDataSetIterator validationIter = asyncValidation != null ? asyncValidation : multiDataSetIterator2;
            return fitHelper(trainIter, i, z, validationIter, i2, Arrays.asList(listenerArr));
        } finally {
            try {
                if (asyncTrain != null) {
                    asyncTrain.shutdown();
                }
            } finally {
                if (asyncValidation != null) {
                    asyncValidation.shutdown();
                }
            }
        }
    }

    /**
     * Core training loop shared by the {@code fit} overloads.
     *
     * <p>Training runs through the {@code "grad"} function instance, which is created on demand.
     * For each epoch the method:
     * <ul>
     *   <li>feeds every batch from {@code iter} to a {@link TrainingSession}, firing
     *       iteration callbacks;</li>
     *   <li>if {@code incrementEpochCount} is set, averages the per-batch losses into a
     *       {@link LossCurve}, fires epochEnd, optionally runs validation, and increments the
     *       config's epoch counter;</li>
     *   <li>stops early if a listener returns STOP before the last epoch.</li>
     * </ul>
     *
     * <p>Changes vs. the previous version:
     * <ul>
     *   <li>An epoch with no data while {@code incrementEpochCount} is set now fails with a
     *       descriptive IllegalStateException instead of a NullPointerException.</li>
     *   <li>An unused set of listener training variables is no longer built.</li>
     *   <li>Raw collection types are replaced with generics.</li>
     * </ul>
     *
     * @param iter                training data; must support reset when {@code numEpochs > 1}
     * @param numEpochs           number of epochs; must be positive
     * @param incrementEpochCount whether to run epoch-start/end logic, validation and epoch counting
     * @param validationData      optional validation data (may be null)
     * @param validationFrequency run validation when {@code epoch % validationFrequency == 0};
     *                            values {@code <= 0} mean every epoch
     * @param listeners           additional listeners for this call; must not be null
     * @return the report of the internal {@link HistoryListener}
     * @throws IllegalStateException if no training config is set, the epoch count is not positive,
     *                               multi-epoch training is requested on a non-resettable iterator,
     *                               the dataset/config mappings do not match, or an epoch yields no data
     */
    protected synchronized History fitHelper(@NonNull MultiDataSetIterator iter, int numEpochs, boolean incrementEpochCount, MultiDataSetIterator validationData, int validationFrequency, @NonNull List<Listener> listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Preconditions.checkNotNull(iter, "Iterator must not be null");
        Preconditions.checkState(numEpochs > 0, "Number of training epochs must be a positive number. Got: %s", numEpochs);
        Preconditions.checkState(this.trainingConfig != null, "No training configuration has been set. A training configuration must be set before training. Use setTrainingConfig(TrainingConfig)");
        Preconditions.checkState(numEpochs == 1 || iter.resetSupported(), "Cannot train for multiple epochs on an iterator that does not support resetting");

        HistoryListener history = new HistoryListener(this.trainingConfig);

        // Listeners that receive the TRAINING callbacks, in this order: the history listener, the
        // instance-level listeners, then the per-call listeners. Each is filtered by
        // isActive(TRAINING).
        List<Listener> activeListeners = new ArrayList<>();
        activeListeners.add(history);
        for (Listener l : this.listeners) {
            if (l.isActive(Operation.TRAINING)) {
                activeListeners.add(l);
            }
        }
        for (Listener l : listeners) {
            if (l.isActive(Operation.TRAINING)) {
                activeListeners.add(l);
            }
        }
        validateListenerActivations(activeListeners, Operation.TRAINING);
        validateListenerActivations(activeListeners, Operation.TRAINING_VALIDATION);

        if (!iter.hasNext() && iter.resetSupported()) {
            iter.reset();
        }

        boolean checkedFirstBatch = false;
        long threadId = Thread.currentThread().getId();
        // Always true here because the history listener was added above.
        boolean hasListeners = !activeListeners.isEmpty();
        At at = At.builder().epoch(this.trainingConfig.getEpochCount()).iteration(this.trainingConfig.getIterationCount()).trainingThreadNum(0).javaThreadNum(threadId).operation(Operation.TRAINING).build();
        LossCurve lossCurve = null;

        // Listeners handed to the training session and validation: the per-call listeners, then
        // the instance listeners not already present, then the history listener.
        List<Listener> listenersWithHistory = new ArrayList<>(listeners);
        for (Listener l : this.listeners) {
            if (!listenersWithHistory.contains(l)) {
                listenersWithHistory.add(l);
            }
        }
        listenersWithHistory.add(history);

        SameDiff gradInstance = getFunction("grad");
        if (gradInstance == null) {
            createGradFunction();
            gradInstance = getFunction("grad");
        }
        TrainingSession session = new TrainingSession(gradInstance);
        gradInstance.setTrainingConfig(this.trainingConfig);
        for (Listener l : activeListeners) {
            l.operationStart(gradInstance, Operation.TRAINING);
        }

        // Every VARIABLE-type variable is passed to the session as a parameter to train.
        Set<String> paramsToTrain = new LinkedHashSet<>();
        for (Variable v : this.variables.values()) {
            if (v.getVariable().getVariableType() == VariableType.VARIABLE) {
                paramsToTrain.add(v.getName());
            }
        }

        Loss lastLoss = null;
        for (int epoch = 0; epoch < numEpochs; epoch++) {
            if (incrementEpochCount && hasListeners) {
                at.setEpoch(this.trainingConfig.getEpochCount());
                for (Listener l : activeListeners) {
                    l.epochStart(this, at);
                }
            }
            long epochStartTime = System.currentTimeMillis();
            double[] lossSums = null;
            int batchCount = 0;
            while (iter.hasNext()) {
                long dataStart = hasListeners ? System.currentTimeMillis() : 0L;
                MultiDataSet ds = iter.next();
                long dataEnd = hasListeners ? System.currentTimeMillis() : 0L;
                if (!checkedFirstBatch) {
                    // Validate the dataset/config mapping once, on the first batch only.
                    Preconditions.checkState(this.trainingConfig.getDataSetFeatureMapping().size() == ds.numFeatureArrays(), "The number of dataset feature mapping variables set in the training configuration (%s) must match the number of dataset feature arrays (%s)", this.trainingConfig.getDataSetFeatureMapping().size(), ds.numFeatureArrays());
                    List<String> labelMapping = this.trainingConfig.getDataSetLabelMapping();
                    int numLabelMappings = labelMapping == null ? 0 : labelMapping.size();
                    Preconditions.checkState(numLabelMappings == ds.numLabelsArrays(), "The number of dataset label mapping variables set in the training configuration (%s) must match the number of dataset label arrays (%s)", numLabelMappings, ds.numLabelsArrays());
                    checkedFirstBatch = true;
                }
                if (hasListeners) {
                    at.setIteration(this.trainingConfig.getIterationCount());
                    for (Listener l : activeListeners) {
                        l.iterationStart(this, at, ds, dataEnd - dataStart);
                    }
                }
                Map<String, INDArray> placeholders = toPlaceholderMap(ds);
                Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
                if (!this.initializedTraining) {
                    initializeTraining();
                }
                lastLoss = session.trainingIteration(this.trainingConfig, placeholders, paramsToTrain, this.updaterMap, ds, getLossVariables(), listenersWithHistory, at);
                double[] batchLosses = lastLoss.getLosses();
                if (lossSums == null) {
                    lossSums = batchLosses.clone();
                } else {
                    for (int j = 0; j < lossSums.length; j++) {
                        lossSums[j] += batchLosses[j];
                    }
                }
                batchCount++;
                this.trainingConfig.incrementIterationCount();
            }
            long epochTime = System.currentTimeMillis() - epochStartTime;

            if (incrementEpochCount) {
                // Without at least one batch there is no loss to average; fail clearly instead of
                // dereferencing null loss arrays (and dividing by zero).
                Preconditions.checkState(batchCount > 0, "No data was returned by the training iterator in epoch %s; cannot compute the epoch loss", epoch);
                List<String> lossNames = lastLoss.getLossNames();
                for (int j = 0; j < lossSums.length; j++) {
                    lossSums[j] /= batchCount;
                }
                lossCurve = lossCurve != null ? lossCurve.addLossAndCopy(lossSums, lossNames) : new LossCurve(lossSums, lossNames);

                if (hasListeners) {
                    // A STOP response is only honoured before the final epoch.
                    Listener stoppedBy = null;
                    for (Listener l : activeListeners) {
                        if (l.epochEnd(this, at, lossCurve, epochTime) == ListenerResponse.STOP && epoch < numEpochs - 1) {
                            stoppedBy = l;
                        }
                    }
                    if (stoppedBy != null) {
                        log.info("Stopping training early.  Listener " + stoppedBy + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
                        for (Listener l : activeListeners) {
                            l.operationEnd(this, Operation.TRAINING);
                        }
                        if (epoch < numEpochs - 1) {
                            iter.reset();
                        }
                        this.trainingConfig.incrementEpochCount();
                        return history.getReport();
                    }

                    if (validationData != null && (validationFrequency <= 0 || epoch % validationFrequency == 0)) {
                        long validationStart = System.currentTimeMillis();
                        outputHelper(validationData, new At(at.epoch(), 0, 0, 0L, null, Operation.TRAINING_VALIDATION), listenersWithHistory, new String[0]);
                        long validationTime = System.currentTimeMillis() - validationStart;
                        Listener validationStoppedBy = null;
                        for (Listener l : activeListeners) {
                            if (l.validationDone(this, at, validationTime) == ListenerResponse.STOP && epoch < numEpochs - 1) {
                                validationStoppedBy = l;
                            }
                        }
                        if (validationStoppedBy != null) {
                            log.info("Stopping training early from validation.  Listener " + validationStoppedBy + " gave a STOP signal at epoch " + at.epoch() + " and iteration " + at.iteration());
                            for (Listener l : activeListeners) {
                                l.operationEnd(this, Operation.TRAINING);
                            }
                            if (epoch < numEpochs - 1) {
                                iter.reset();
                            }
                            this.trainingConfig.incrementEpochCount();
                            return history.getReport();
                        }
                    }
                }
                this.trainingConfig.incrementEpochCount();
            }
            if (epoch < numEpochs - 1) {
                iter.reset();
            }
        }
        for (Listener l : activeListeners) {
            l.operationEnd(this, Operation.TRAINING);
        }
        return history.getReport();
    }

    /**
     * Ensures every variable a listener declares as required for {@code operation} exists in this graph.
     *
     * @param list      listeners to check
     * @param operation operation whose required variables are validated
     * @throws IllegalStateException if a listener requires a variable that is not defined
     */
    private void validateListenerActivations(List<Listener> list, Operation operation) {
        for (Listener l : list) {
            ListenerVariables needed = l.requiredVariables(this);
            if (needed == null) {
                continue;
            }
            for (String varName : needed.requiredVariables(operation)) {
                Preconditions.checkState(this.variables.containsKey(varName), "Listener %s requested variable %s that is not defined in this SameDiff graph", l, varName);
            }
        }
    }

    /**
     * Computes the total regularization score over all trainable floating-point parameters.
     *
     * <p>Each configured {@link Regularization} is scored against the array of every VARIABLE-type
     * floating-point variable, using the current iteration and epoch counts.
     *
     * @return the summed score, or 0 if no regularization is configured
     * @throws IllegalStateException if no training config has been set
     */
    public double calcRegularizationScore() {
        Preconditions.checkState(this.trainingConfig != null, "No training configuration has been set. A training configuration must be set before calculating the L2 loss. Use setTrainingConfig(TrainingConfig)");
        List<Regularization> regs = this.trainingConfig.getRegularization();
        if (regs == null || regs.isEmpty()) {
            return 0.0d;
        }
        double score = 0.0d;
        for (Variable v : this.variables.values()) {
            SDVariable param = v.getVariable();
            if (param.getVariableType() != VariableType.VARIABLE || !param.dataType().isFPType()) {
                continue;
            }
            for (Regularization r : regs) {
                score += r.score(param.getArr(), this.trainingConfig.getIterationCount(), this.trainingConfig.getEpochCount());
            }
        }
        return score;
    }

    /**
     * Lazily creates one gradient updater per trainable floating-point parameter.
     *
     * <p>For each VARIABLE-type FP variable, a state array of the updater's required size is
     * allocated (none when the size is 0). The updater is instantiated over that array and stored
     * in {@code updaterMap} under the variable's name. Calls after the first are no-ops.
     *
     * @throws ND4JIllegalStateException if no training config has been set
     */
    protected void initializeTraining() {
        if (this.initializedTraining) {
            return;
        }
        if (this.trainingConfig == null) {
            throw new ND4JIllegalStateException("Please specify a training config with setTrainingConfig");
        }
        this.updaterMap = new HashMap<>();
        for (Variable v : this.variables.values()) {
            SDVariable sdv = v.getVariable();
            boolean trainableFloat = sdv.getVariableType() == VariableType.VARIABLE && sdv.dataType().isFPType();
            if (!trainableFloat) {
                continue;
            }
            INDArray param = sdv.getArr();
            long stateSize = this.trainingConfig.getUpdater().stateSize(param.length());
            INDArray state = null;
            if (stateSize != 0) {
                state = Nd4j.createUninitialized(param.dataType(), 1, stateSize);
            }
            GradientUpdater updater = this.trainingConfig.getUpdater().instantiate(state, false);
            updater.setStateViewArray(state, param.shape(), param.ordering(), true);
            this.updaterMap.put(v.getName(), updater);
        }
        this.initializedTraining = true;
    }

    /**
     * Maps a MultiDataSet's arrays to placeholder names using the training config's mappings.
     *
     * <p>Features and labels are bound positionally. For feature and label masks, a null entry in
     * the mapping skips that mask index. Entries are inserted in the order features, labels,
     * feature masks, label masks.
     *
     * @param multiDataSet batch to convert
     * @return a new map from placeholder name to array
     */
    private Map<String, INDArray> toPlaceholderMap(MultiDataSet multiDataSet) {
        Map<String, INDArray> placeholders = new HashMap<>();

        int featureIdx = 0;
        for (String name : this.trainingConfig.getDataSetFeatureMapping()) {
            placeholders.put(name, multiDataSet.getFeatures(featureIdx++));
        }

        List<String> labelNames = this.trainingConfig.getDataSetLabelMapping();
        if (labelNames != null) {
            int labelIdx = 0;
            for (String name : labelNames) {
                placeholders.put(name, multiDataSet.getLabels(labelIdx++));
            }
        }

        List<String> featureMaskNames = this.trainingConfig.getDataSetFeatureMaskMapping();
        if (featureMaskNames != null && !featureMaskNames.isEmpty()) {
            int maskIdx = 0;
            for (String name : featureMaskNames) {
                if (name != null) {
                    placeholders.put(name, multiDataSet.getFeaturesMaskArray(maskIdx));
                }
                maskIdx++;
            }
        }

        List<String> labelMaskNames = this.trainingConfig.getDataSetLabelMaskMapping();
        if (labelMaskNames != null && !labelMaskNames.isEmpty()) {
            int maskIdx = 0;
            for (String name : labelMaskNames) {
                if (name != null) {
                    placeholders.put(name, multiDataSet.getLabelsMaskArray(maskIdx));
                }
                maskIdx++;
            }
        }
        return placeholders;
    }

    /**
     * Evaluates one output variable over a DataSetIterator with the given listeners.
     *
     * @param dataSetIterator data to evaluate on; must not be null
     * @param str             output variable name; must not be null
     * @param list            listeners to attach; must not be null
     * @param iEvaluationArr  evaluations to populate; must be non-null and non-empty
     * @throws IllegalArgumentException if no evaluations are given
     */
    public void evaluate(@NonNull DataSetIterator dataSetIterator, @NonNull String str, @NonNull List<Listener> list, @NonNull IEvaluation... iEvaluationArr) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        }
        if (list == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (iEvaluationArr == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        Preconditions.checkArgument(iEvaluationArr.length > 0, "No evaluations were passed to the evaluate method");
        Listener[] listenerArray = list.toArray(new Listener[0]);
        evaluate()
                .data(dataSetIterator)
                .evaluate(str, iEvaluationArr)
                .listeners(listenerArray)
                .exec();
    }

    /**
     * Evaluates one output variable over a DataSetIterator without extra listeners.
     *
     * @param dataSetIterator data to evaluate on; must not be null
     * @param str             output variable name; must not be null
     * @param iEvaluationArr  evaluations to populate; must not be null
     */
    public void evaluate(@NonNull DataSetIterator dataSetIterator, @NonNull String str, @NonNull IEvaluation... iEvaluationArr) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (str == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        }
        if (iEvaluationArr == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        evaluate()
                .data(dataSetIterator)
                .evaluate(str, iEvaluationArr)
                .exec();
    }

    /**
     * Evaluates several variables, one evaluation each, over a DataSetIterator.
     *
     * <p>Each evaluation is wrapped in a singleton list. Every variable is paired with label index
     * 0, which is presumably the single label array of the adapted DataSet. The call is delegated
     * to the multi-evaluation overload.
     *
     * @param dataSetIterator data to evaluate on; must not be null
     * @param map             variable name to evaluation; must not be null
     * @param listenerArr     listeners to attach; must not be null
     */
    public void evaluate(@NonNull DataSetIterator dataSetIterator, @NonNull Map<String, IEvaluation> map, @NonNull Listener... listenerArr) {
        if (dataSetIterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (map == null) {
            throw new NullPointerException("variableEvals is marked non-null but is null");
        }
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Map<String, Integer> labelIndices = new HashMap<>();
        Map<String, List<IEvaluation>> evalLists = new HashMap<>();
        for (Map.Entry<String, IEvaluation> entry : map.entrySet()) {
            labelIndices.put(entry.getKey(), 0);
            evalLists.put(entry.getKey(), Collections.singletonList(entry.getValue()));
        }
        evaluate(new MultiDataSetIteratorAdapter(dataSetIterator), evalLists, labelIndices, listenerArr);
    }

    /**
     * Evaluates several variables, possibly with multiple evaluations each, over a DataSetIterator.
     *
     * <p>Every variable is paired with label index 0 before delegating to the multi-evaluation overload.
     *
     * @param dataSetIterator data to evaluate on
     * @param map             variable name to list of evaluations
     * @param listenerArr     listeners to attach; must not be null
     */
    public void evaluateMultiple(DataSetIterator dataSetIterator, Map<String, List<IEvaluation>> map, @NonNull Listener... listenerArr) {
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Map<String, Integer> labelIndices = new HashMap<>();
        for (String varName : map.keySet()) {
            labelIndices.put(varName, 0);
        }
        evaluate(new MultiDataSetIteratorAdapter(dataSetIterator), map, labelIndices, listenerArr);
    }

    /**
     * Evaluates one output variable against the batches of {@code iterator}, with extra listeners.
     *
     * @param iterator       source of features and labels
     * @param outputVariable name of the variable whose predictions are evaluated
     * @param labelIndex     passed through to {@code EvaluationConfig.evaluate(String, int, ...)};
     *                       presumably the label array index within each MultiDataSet (confirm)
     * @param listeners      extra listeners active for this call only
     * @param evaluations    evaluation instances to update; at least one is required
     * @throws NullPointerException     if any object argument is null
     * @throws IllegalArgumentException if {@code evaluations} is empty
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull List<Listener> listeners, @NonNull IEvaluation... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (outputVariable == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        // evaluations is known non-null here, so only the length needs checking.
        Preconditions.checkArgument(evaluations.length > 0, "No evaluations were passed to the evaluate method");
        evaluate()
                .data(iterator)
                .evaluate(outputVariable, labelIndex, evaluations)
                .listeners(listeners.toArray(new Listener[0]))
                .exec();
    }

    /**
     * Evaluates one output variable against the batches of {@code iterator}, without extra listeners.
     *
     * @param iterator       source of features and labels
     * @param outputVariable name of the variable whose predictions are evaluated
     * @param labelIndex     passed through to {@code EvaluationConfig.evaluate(String, int, ...)};
     *                       presumably the label array index within each MultiDataSet (confirm)
     * @param evaluations    evaluation instances to update
     * @throws NullPointerException if any object argument is null
     */
    public void evaluate(@NonNull MultiDataSetIterator iterator, @NonNull String outputVariable, int labelIndex, @NonNull IEvaluation... evaluations) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (outputVariable == null) {
            throw new NullPointerException("outputVariable is marked non-null but is null");
        }
        if (evaluations == null) {
            throw new NullPointerException("evaluations is marked non-null but is null");
        }
        evaluate()
                .data(iterator)
                .evaluate(outputVariable, labelIndex, evaluations)
                .exec();
    }

    /**
     * Evaluates several variables, each against its own label array, over {@code iterator}.
     * Delegates to the internal helper with a default {@code At} for {@link Operation#EVALUATION}.
     *
     * @param iterator               batches to evaluate
     * @param variableEvals          variable name -> evaluations for that variable
     * @param predictionLabelMapping variable name -> label array index in each batch
     * @param listeners              extra listeners active for this call only
     */
    public void evaluate(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, Listener... listeners) {
        At evalAt = At.defaultAt(Operation.EVALUATION);
        evaluateHelper(iterator, variableEvals, predictionLabelMapping, evalAt, listeners);
    }

    /**
     * Starts a fluent evaluation configuration bound to this SameDiff instance.
     *
     * @return a new {@link EvaluationConfig} for this graph
     */
    public EvaluationConfig evaluate() {
        EvaluationConfig config = new EvaluationConfig(this);
        return config;
    }

    /**
     * Runs inference over {@code iterator} and feeds each requested prediction, with the matching
     * label and label-mask arrays, into the configured evaluations.
     *
     * <p>Active listeners (the explicit ones first, then those registered on this instance) get
     * {@code operationStart} before the first batch and {@code operationEnd} after the last.
     * Variables that listeners request through {@code ListenerVariables.evaluationVariables()} are
     * also computed.
     *
     * @param iterator               batches to evaluate; reset first if exhausted and resettable
     * @param variableEvals          variable name -> evaluations to update with that variable's output
     * @param predictionLabelMapping variable name -> label array index in each batch
     * @param at                     position info; its iteration counter is incremented per batch
     * @param listeners              extra listeners for this call only
     * @throws NullPointerException  if {@code listeners} is null
     * @throws IllegalStateException if no training config is set or the two maps' key sets differ
     */
    private void evaluateHelper(MultiDataSetIterator iterator, Map<String, List<IEvaluation>> variableEvals, Map<String, Integer> predictionLabelMapping, At at, @NonNull Listener... listeners) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        Preconditions.checkState(this.trainingConfig != null, "Training config has not been set");
        Preconditions.checkState(variableEvals.keySet().equals(predictionLabelMapping.keySet()), "Keysets for variable evaluations and for the prediction label mapping must be equal. Keys for variables to evaluate: %s vs. keys for label mapping: %s", variableEvals.keySet(), predictionLabelMapping.keySet());

        Operation op = at.operation();
        List<Listener> activeListeners = new ArrayList<>();
        for (Listener l : listeners) {
            if (l.isActive(op)) {
                activeListeners.add(l);
            }
        }
        for (Listener l : this.listeners) {
            if (l.isActive(op)) {
                activeListeners.add(l);
            }
        }
        validateListenerActivations(activeListeners, op);
        for (Listener l : activeListeners) {
            l.operationStart(this, op);
        }

        // Allow re-use of an iterator that a previous pass exhausted.
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        // Evaluated variables plus anything the listeners want to observe.
        Set<String> requiredVars = new HashSet<>(variableEvals.keySet());
        for (Listener l : activeListeners) {
            ListenerVariables required = l.requiredVariables(this);
            if (required != null) {
                requiredVars.addAll(required.evaluationVariables());
            }
        }
        String[] requiredVarsArr = requiredVars.toArray(new String[0]);

        while (iterator.hasNext()) {
            MultiDataSet batch = iterator.next();
            Map<String, INDArray> activations = directExecHelper(toPlaceholderMap(batch), at, batch, Collections.<String>emptyList(), activeListeners, requiredVarsArr);
            for (Map.Entry<String, List<IEvaluation>> entry : variableEvals.entrySet()) {
                List<IEvaluation> evals = entry.getValue();
                if (evals.isEmpty()) {
                    continue;
                }
                // Label/mask lookups are the same for every evaluation of this variable: fetch once.
                INDArray prediction = activations.get(entry.getKey());
                int labelIdx = predictionLabelMapping.get(entry.getKey());
                INDArray labels = batch.getLabels(labelIdx);
                INDArray labelMask = batch.getLabelsMaskArray(labelIdx);
                for (IEvaluation eval : evals) {
                    eval.eval(labels, prediction, labelMask);
                }
            }
            at.setIteration(at.iteration() + 1);
        }

        for (Listener l : activeListeners) {
            l.operationEnd(this, op);
        }
    }

    /**
     * Computes the requested outputs for a single DataSet.
     *
     * @param dataSet input batch (converted to a MultiDataSet)
     * @param outputs names of the variables to return
     * @return variable name -> array for the single batch
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull DataSet dataSet, @NonNull String... outputs) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        SingletonMultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet.toMultiDataSet());
        List<Map<String, INDArray>> batches = outputBatches(single, outputs);
        return batches.get(0);
    }

    /**
     * Computes the requested outputs for a single MultiDataSet.
     *
     * @param dataSet input batch
     * @param outputs names of the variables to return
     * @return variable name -> array for the single batch
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSet dataSet, @NonNull String... outputs) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        SingletonMultiDataSetIterator single = new SingletonMultiDataSetIterator(dataSet);
        List<Map<String, INDArray>> batches = outputBatches(single, outputs);
        return batches.get(0);
    }

    /**
     * Computes the requested outputs over every batch of {@code iterator}, with extra listeners.
     *
     * @param iterator  input batches
     * @param listeners extra listeners active for this call only
     * @param outputs   names of the variables to return
     * @return result of {@code OutputConfig.exec()}
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull DataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        return output()
                .data(iterator)
                .output(outputs)
                .listeners(listeners.toArray(new Listener[0]))
                .exec();
    }

    /**
     * Computes the requested outputs over every batch of {@code iterator}.
     *
     * @param iterator input batches
     * @param outputs  names of the variables to return
     * @return result of {@code OutputConfig.exec()}
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull DataSetIterator iterator, @NonNull String... outputs) {
        if (iterator == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        return output()
                .data(iterator)
                .output(outputs)
                .exec();
    }

    /**
     * Computes the requested outputs per batch of {@code iterator}, with extra listeners.
     *
     * @param iterator  input batches
     * @param listeners extra listeners active for this call only
     * @param outputs   names of the variables to return
     * @return one map (variable name -> array) per batch
     */
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, List<Listener> listeners, String... outputs) {
        return output()
                .data(iterator)
                .output(outputs)
                .listeners(listeners.toArray(new Listener[0]))
                .execBatches();
    }

    /**
     * Computes the requested outputs per batch of {@code iterator}.
     *
     * @param iterator input batches
     * @param outputs  names of the variables to return
     * @return one map (variable name -> array) per batch
     */
    public List<Map<String, INDArray>> outputBatches(DataSetIterator iterator, String... outputs) {
        return output()
                .data(iterator)
                .output(outputs)
                .execBatches();
    }

    /**
     * Computes the requested outputs over every batch of {@code iterator} and stacks the per-batch
     * results with {@code SameDiffUtils.stackOutputs}.
     *
     * @param iterator  input batches
     * @param listeners extra listeners active for this call only
     * @param outputs   names of the variables to return
     * @return stacked outputs, variable name -> array
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSetIterator iterator, @NonNull List<Listener> listeners, @NonNull String... outputs) {
        if (iterator == null) {
            throw new NullPointerException("iterator is marked non-null but is null");
        }
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        At inferenceAt = At.defaultAt(Operation.INFERENCE);
        List<Map<String, INDArray>> perBatch = outputHelper(iterator, inferenceAt, listeners, outputs);
        return SameDiffUtils.stackOutputs(perBatch);
    }

    /**
     * Computes the requested outputs over every batch of {@code iterator}.
     *
     * @param iterator input batches
     * @param outputs  names of the variables to return
     * @return result of {@code OutputConfig.exec()}
     * @throws NullPointerException if any argument is null
     */
    public Map<String, INDArray> output(@NonNull MultiDataSetIterator iterator, @NonNull String... outputs) {
        if (iterator == null) {
            throw new NullPointerException("dataSet is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        return output()
                .data(iterator)
                .output(outputs)
                .exec();
    }

    /**
     * Computes the requested outputs per batch of {@code iterator} with extra listeners, using a
     * default {@code At} for {@link Operation#INFERENCE}.
     *
     * @param iterator  input batches
     * @param listeners extra listeners active for this call only
     * @param outputs   names of the variables to return
     * @return one map (variable name -> array) per batch
     */
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, List<Listener> listeners, String... outputs) {
        At inferenceAt = At.defaultAt(Operation.INFERENCE);
        return outputHelper(iterator, inferenceAt, listeners, outputs);
    }

    /**
     * Computes the requested outputs per batch of {@code iterator}.
     *
     * @param iterator input batches
     * @param outputs  names of the variables to return
     * @return one map (variable name -> array) per batch
     */
    public List<Map<String, INDArray>> outputBatches(MultiDataSetIterator iterator, String... outputs) {
        return output()
                .data(iterator)
                .output(outputs)
                .execBatches();
    }

    /**
     * Starts a fluent output (inference) configuration bound to this SameDiff instance.
     *
     * @return a new {@link OutputConfig} for this graph
     */
    public OutputConfig output() {
        OutputConfig config = new OutputConfig(this);
        return config;
    }

    /**
     * Runs inference over every batch of {@code iterator} and collects the requested outputs.
     *
     * <p>Active listeners (the explicit ones first, then those registered on this instance) get
     * {@code operationStart}/{@code operationEnd} around the whole run, and
     * {@code iterationStart}/{@code iterationDone} around each batch. Listener-requested
     * validation or inference variables are also computed. A listener returning {@code null} from
     * {@code requiredVariables} is treated as requesting nothing, matching {@code evaluateHelper}.
     *
     * @param iterator  input batches; reset first if exhausted and resettable
     * @param at        position info; its iteration counter is incremented per batch
     * @param listeners extra listeners for this call only
     * @param outputs   names of the variables to return; if empty, the loss variables are used
     * @return one map (variable name -> array) per batch, in iteration order
     * @throws NullPointerException  if {@code listeners} or {@code outputs} is null
     * @throws IllegalStateException if no training config is set
     */
    private List<Map<String, INDArray>> outputHelper(MultiDataSetIterator iterator, At at, @NonNull List<Listener> listeners, @NonNull String... outputs) {
        if (listeners == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        Preconditions.checkState(this.trainingConfig != null, "Training config has not been set");

        Operation op = at.operation();
        List<Listener> activeListeners = new ArrayList<>();
        for (Listener l : listeners) {
            if (l.isActive(op)) {
                activeListeners.add(l);
            }
        }
        for (Listener l : this.listeners) {
            if (l.isActive(op)) {
                activeListeners.add(l);
            }
        }
        validateListenerActivations(activeListeners, op);
        for (Listener l : activeListeners) {
            l.operationStart(this, op);
        }
        boolean hasListeners = !activeListeners.isEmpty();

        // outputs is known non-null here; an empty request falls back to the loss variables.
        List<String> outputNames = outputs.length == 0 ? getLossVariables() : Arrays.asList(outputs);
        String[] outputArr = outputNames.toArray(new String[0]);
        List<Map<String, INDArray>> results = new ArrayList<>();

        // Allow re-use of an iterator that a previous pass exhausted.
        if (!iterator.hasNext() && iterator.resetSupported()) {
            iterator.reset();
        }

        Set<String> requiredActivations = new HashSet<>();
        for (Listener l : activeListeners) {
            ListenerVariables required = l.requiredVariables(this);
            if (required == null) {
                continue;
            }
            if (op == Operation.TRAINING_VALIDATION) {
                requiredActivations.addAll(required.validationVariables());
            } else {
                requiredActivations.addAll(required.inferenceVariables());
            }
        }

        while (iterator.hasNext()) {
            // Time the fetch only when someone will receive it via iterationStart.
            long fetchStart = hasListeners ? System.currentTimeMillis() : 0L;
            MultiDataSet batch = iterator.next();
            long fetchEnd = hasListeners ? System.currentTimeMillis() : 0L;
            Map<String, INDArray> placeholders = toPlaceholderMap(batch);
            if (hasListeners) {
                for (Listener l : activeListeners) {
                    l.iterationStart(this, at, batch, fetchEnd - fetchStart);
                }
            }
            Map<String, INDArray> batchOutputs = directExecHelper(placeholders, at, batch, requiredActivations, activeListeners, outputArr);
            if (hasListeners) {
                for (Listener l : activeListeners) {
                    l.iterationDone(this, at, batch, null);
                }
            }
            results.add(batchOutputs);
            at.setIteration(at.iteration() + 1);
        }

        for (Listener l : activeListeners) {
            l.operationEnd(this, op);
        }
        return results;
    }

    /**
     * Starts a fluent single-batch output configuration bound to this SameDiff instance.
     *
     * @return a new {@link BatchOutputConfig} for this graph
     */
    public BatchOutputConfig batchOutput() {
        BatchOutputConfig config = new BatchOutputConfig(this);
        return config;
    }

    /**
     * Computes all outputs for the given placeholder values, via {@code BatchOutputConfig.outputAll()}.
     *
     * @param placeholders placeholder name -> value
     * @return variable name -> array
     */
    public Map<String, INDArray> outputAll(Map<String, INDArray> placeholders) {
        return batchOutput()
                .outputAll()
                .inputs(placeholders)
                .output();
    }

    /**
     * Computes a single output variable for the given placeholder values.
     *
     * @param placeholders placeholder name -> value
     * @param output       name of the variable to compute
     * @return the computed array
     */
    public INDArray outputSingle(Map<String, INDArray> placeholders, String output) {
        return batchOutput()
                .output(output)
                .inputs(placeholders)
                .outputSingle();
    }

    /**
     * Computes the listed output variables for the given placeholder values.
     *
     * @param placeholders placeholder name -> value
     * @param outputs      names of the variables to compute
     * @return variable name -> array
     * @throws NullPointerException if {@code outputs} is null
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, @NonNull List<String> outputs) {
        if (outputs == null) {
            throw new NullPointerException("outputs is marked non-null but is null");
        }
        BatchOutputConfig config = batchOutput();
        return config.output(outputs.toArray(new String[0]))
                .inputs(placeholders)
                .output();
    }

    /**
     * Computes the named output variables for the given placeholder values.
     *
     * @param placeholders placeholder name -> value
     * @param outputs      names of the variables to compute
     * @return variable name -> array
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, String... outputs) {
        return batchOutput()
                .output(outputs)
                .inputs(placeholders)
                .output();
    }

    /**
     * Computes the named output variables for the given placeholder values as an
     * {@link Operation#INFERENCE} operation, with extra listeners.
     *
     * @param placeholders placeholder name -> value
     * @param listeners    extra listeners for this call; may be null
     * @param outputs      names of the variables to compute
     * @return variable name -> array
     */
    public Map<String, INDArray> output(Map<String, INDArray> placeholders, List<Listener> listeners, String... outputs) {
        Operation inference = Operation.INFERENCE;
        return batchOutputHelper(placeholders, listeners, inference, outputs);
    }

    /**
     * Executes a single batch and returns the requested outputs, notifying active listeners.
     *
     * <p>Listener activations are validated before any listener receives {@code operationStart}.
     * The other helpers use the same order, and it means a validation failure never leaves
     * listeners started without a matching {@code operationEnd}.
     *
     * @param placeholders placeholder name -> value (may be null; see {@code directExecHelper})
     * @param listeners    extra listeners for this call; may be null
     * @param operation    operation type reported to listeners; null means {@link Operation#INFERENCE}
     * @param outputs      names of the variables to compute
     * @return variable name -> array
     */
    protected Map<String, INDArray> batchOutputHelper(Map<String, INDArray> placeholders, List<Listener> listeners, Operation operation, String... outputs) {
        if (operation == null) {
            operation = Operation.INFERENCE;
        }
        // Instance-level listeners first, then the per-call ones (preserves callback order).
        List<Listener> activeListeners = new ArrayList<>();
        for (Listener l : this.listeners) {
            if (l.isActive(operation)) {
                activeListeners.add(l);
            }
        }
        if (listeners != null) {
            for (Listener l : listeners) {
                if (l.isActive(operation)) {
                    activeListeners.add(l);
                }
            }
        }
        validateListenerActivations(activeListeners, operation);
        for (Listener l : activeListeners) {
            l.operationStart(this, operation);
        }
        Map<String, INDArray> result = directExecHelper(placeholders, At.defaultAt(operation), null, Collections.<String>emptyList(), activeListeners, outputs);
        for (Listener l : activeListeners) {
            l.operationEnd(this, operation);
        }
        return result;
    }

    /**
     * Executes the graph on the current thread's {@code InferenceSession} and returns the
     * requested outputs. The session is created lazily the first time a thread calls this method.
     *
     * @param placeholders        placeholder name -> value; if null (and the graph has inputs), the
     *                            values previously stored for this thread in
     *                            {@code placeholdersPerThread} are used
     * @param at                  position info; null means {@code At.defaultAt()}
     * @param batch               the batch being processed, passed through to the session (may be null)
     * @param requiredActivations extra variables to compute, passed through to the session
     * @param activeListeners     listeners passed through to the session
     * @param outputs             names of the variables to return; must be non-empty
     * @return variable name -> array, as produced by {@code InferenceSession.output}
     * @throws IllegalStateException if {@code outputs} is null or empty
     */
    protected Map<String, INDArray> directExecHelper(Map<String, INDArray> placeholders, At at, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners, String... outputs) {
        if (at == null) {
            at = At.defaultAt();
        }
        Preconditions.checkState(outputs != null && outputs.length > 0, "No outputs were specified");
        Long threadId = Long.valueOf(Thread.currentThread().getId());
        if (!this.sessions.containsKey(threadId)) {
            log.info("Creating new InferenceSession for thread {}", threadId);
            this.sessions.put(threadId, new InferenceSession(this));
        }
        List<String> inputNames = inputs();
        if (placeholders == null && inputNames != null) {
            // Fall back to placeholder values previously set for this thread.
            placeholders = this.placeholdersPerThread.get(threadId);
        }
        // outputs is guaranteed non-null by the precondition above.
        return this.sessions.get(threadId).output(Arrays.asList(outputs), placeholders, batch, requiredActivations, activeListeners, at);
    }

    /**
     * Creates a constant of ones with the default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable one(String name, int... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return one(name, defaultType, shape);
    }

    /**
     * Creates a constant of ones with the default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable one(String name, long... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return one(name, defaultType, shape);
    }

    /**
     * Creates a constant of ones with the given data type; the int shape is widened to long.
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable one(String name, DataType dataType, int... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        return one(name, dataType, longShape);
    }

    /**
     * Creates a constant filled via {@code Nd4j.ones} with the given data type and shape.
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable one(String name, DataType dataType, long... shape) {
        INDArray ones = Nd4j.ones(dataType, shape);
        return constant(name, ones);
    }

    /**
     * Creates a constant of zeros with the default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable zero(String name, long... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return zero(name, defaultType, shape);
    }

    /**
     * Creates a constant of zeros with the default floating point type.
     *
     * @param name  variable name
     * @param shape array shape
     * @return the new constant
     */
    public SDVariable zero(String name, int... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return zero(name, defaultType, shape);
    }

    /**
     * Creates a constant filled via {@code Nd4j.zeros} with the given data type and shape.
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable zero(String name, DataType dataType, long... shape) {
        INDArray zeros = Nd4j.zeros(dataType, shape);
        return constant(name, zeros);
    }

    /**
     * Creates a constant of zeros with the given data type; the int shape is widened to long.
     *
     * @param name     variable name
     * @param dataType array data type
     * @param shape    array shape
     * @return the new constant
     */
    public SDVariable zero(String name, DataType dataType, int... shape) {
        long[] longShape = ArrayUtil.toLongArray(shape);
        return zero(name, dataType, longShape);
    }

    /**
     * Creates a constant with a generated name.
     *
     * @param constant array value for the constant
     * @return the new constant variable
     * @throws NullPointerException if {@code constant} is null
     */
    public SDVariable constant(@NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked non-null but is null");
        }
        String generatedName = getNewVarName();
        return constant(generatedName, constant);
    }

    /**
     * Creates a CONSTANT variable holding {@code constant}.
     *
     * <p>If {@code name} is null or empty, a fresh name is generated. The duplicate-name check runs
     * on the resolved name, so generated names are checked too. View arrays are duplicated outside
     * any workspace, so the stored constant does not alias the parent array.
     *
     * @param name     desired variable name; null or empty means "generate one"
     * @param constant array value for the constant
     * @return the new constant variable
     * @throws NullPointerException  if {@code constant} is null
     * @throws IllegalStateException if a variable with the resolved name already exists
     */
    public SDVariable constant(String name, @NonNull INDArray constant) {
        if (constant == null) {
            throw new NullPointerException("constant is marked non-null but is null");
        }
        if (name == null || name.length() < 1) {
            name = getNewVarName();
        }
        Preconditions.checkState(!this.variables.containsKey(name), "Variable with name \"%s\" already exists", name);
        if (constant.isView()) {
            // Detach from the parent buffer; allocate outside any active workspace.
            try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                constant = constant.dup();
            }
        }
        SDVariable v = new SDVariable(name, VariableType.CONSTANT, this, constant.shape(), constant.dataType());
        // Use the name the SDVariable actually ended up with as the registry key.
        String resolvedName = v.name();
        this.variables.put(resolvedName, Variable.builder().name(resolvedName).variable(v).build());
        this.constantArrays.setArray(resolvedName, constant);
        return v;
    }

    /**
     * Registers a PLACEHOLDER variable whose value is supplied at execution time.
     *
     * @param name     variable name; must not already exist
     * @param dataType placeholder data type
     * @param shape    placeholder shape, passed as-is to the SDVariable constructor
     * @return the new placeholder variable
     * @throws NullPointerException  if {@code name} is null
     * @throws IllegalStateException if a variable named {@code name} already exists
     */
    public SDVariable placeHolder(@NonNull String name, DataType dataType, long... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        Preconditions.checkState(!this.variables.containsKey(name), "Variable already exists with name %s", name);
        SDVariable placeholder = new SDVariable(name, VariableType.PLACEHOLDER, this, shape, dataType);
        Variable meta = Variable.builder().name(name).variable(placeholder).build();
        this.variables.put(name, meta);
        return placeholder;
    }

    /**
     * Creates a trainable VARIABLE initialized with {@code weightInitScheme}.
     *
     * @param name             variable name
     * @param weightInitScheme initializer for the variable's array
     * @param dataType         array data type
     * @param shape            array shape
     * @return the new variable
     * @throws NullPointerException if any argument is null
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull DataType dataType, @NonNull long... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme is marked non-null but is null");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        VariableType trainable = VariableType.VARIABLE;
        return var(name, trainable, weightInitScheme, dataType, shape);
    }

    /**
     * Creates and registers a variable of the given type.
     *
     * <p>An empty name means a generated name; otherwise the name is passed through
     * {@code generateNewVarName(name, 0)}. For {@link VariableType#VARIABLE} the backing array is
     * created with {@code weightInitScheme} outside any workspace.
     *
     * @param name             desired name (empty means "generate one")
     * @param variableType     type of variable to create
     * @param weightInitScheme initializer; required when {@code variableType} is VARIABLE
     * @param dataType         data type
     * @param shape            shape; may be null, but must not contain zero dimensions
     * @return the new variable
     * @throws NullPointerException     if {@code name} or {@code variableType} is null
     * @throws IllegalArgumentException if the shape contains a zero or the resolved name already exists
     * @throws IllegalStateException    if a VARIABLE is requested without a weight init scheme
     */
    public SDVariable var(@NonNull String name, @NonNull VariableType variableType, WeightInitScheme weightInitScheme, DataType dataType, long... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (variableType == null) {
            throw new NullPointerException("variableType is marked non-null but is null");
        }
        if (shape != null) {
            for (long dim : shape) {
                Preconditions.checkArgument(dim != 0, "Cannot create variable with a shape that contains zeros (empty array shape) - got shape %s", shape);
            }
        }
        // name is known non-null here.
        String varName = name.isEmpty() ? getNewVarName() : generateNewVarName(name, 0);
        if (this.variables.containsKey(varName)) {
            // Only mention the name scope when one is actually active.
            if (!this.nameScopes.isEmpty()) {
                throw new IllegalArgumentException("Another variable with the name " + varName + " already exists (current name scope: \"" + currentNameScope() + "\")");
            }
            throw new IllegalArgumentException("Another variable with the name " + varName + " already exists.");
        }
        Preconditions.checkState(variableType != VariableType.VARIABLE || weightInitScheme != null, "A weight initalization scheme must be provided when creating a VARIABLE type SDVariables - variable name: \"%s\"", varName);
        SDVariable variable = new SDVariable(varName, variableType, this, shape, dataType);
        addVariable(variable);
        if (variableType == VariableType.VARIABLE) {
            // Parameter arrays must outlive any workspace active in the caller.
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                this.variablesArrays.setArray(varName, weightInitScheme.create(dataType, shape));
            }
        }
        return variable;
    }

    /**
     * Creates a trainable VARIABLE whose shape and data type come from a shape descriptor.
     *
     * @param name             variable name
     * @param shape            descriptor supplying shape and data type
     * @param weightInitScheme initializer for the variable's array
     * @return the new variable
     * @throws NullPointerException if {@code name} or {@code shape} is null
     */
    public SDVariable var(@NonNull String name, @NonNull LongShapeDescriptor shape, WeightInitScheme weightInitScheme) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        DataType descriptorType = shape.dataType();
        long[] dims = shape.getShape();
        return var(name, weightInitScheme, descriptorType, dims);
    }

    /**
     * Creates a zero-initialized VARIABLE, or a PLACEHOLDER if {@code Shape.isPlaceholderShape}
     * reports the shape as a placeholder shape.
     *
     * <p>The null check applies to the array itself. The previous check boxed {@code shape != null}
     * into a Boolean, which is never null, so it could not fail. The int[] overload already
     * checks the array directly.
     *
     * @param name     variable name
     * @param dataType data type
     * @param shape    shape; must not be null
     * @return the new variable or placeholder
     * @throws NullPointerException if {@code shape} is null
     */
    public SDVariable var(String name, DataType dataType, long... shape) {
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
        if (Shape.isPlaceholderShape(shape)) {
            return placeHolder(name, dataType, shape);
        }
        return var(name, new ZeroInitScheme(), dataType, shape);
    }

    /**
     * Creates a zero-initialized VARIABLE whose shape and data type come from a shape descriptor.
     *
     * <p>The null check applies to the descriptor itself. The previous check boxed
     * {@code descriptor != null} into a Boolean, which is never null, so it could not fail.
     *
     * @param name  variable name
     * @param shape descriptor supplying shape and data type; must not be null
     * @return the new variable
     * @throws NullPointerException if {@code shape} is null
     */
    public SDVariable var(String name, LongShapeDescriptor shape) {
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
        return var(name, shape, new ZeroInitScheme());
    }

    /**
     * Creates a variable (or a placeholder, for placeholder shapes) with the default floating point type.
     *
     * @param name  variable name
     * @param shape shape
     * @return the new variable
     */
    public SDVariable var(String name, int... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, defaultType, shape);
    }

    /**
     * Creates a variable (or a placeholder, for placeholder shapes) with the default floating point type.
     *
     * @param name  variable name
     * @param shape shape
     * @return the new variable
     */
    public SDVariable var(String name, long... shape) {
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, defaultType, shape);
    }

    /**
     * Creates a trainable VARIABLE with the default floating point type.
     *
     * @param name             variable name
     * @param weightInitScheme initializer for the variable's array
     * @param shape            shape
     * @return the new variable
     * @throws NullPointerException if any argument is null
     */
    public SDVariable var(@NonNull String name, @NonNull WeightInitScheme weightInitScheme, @NonNull long... shape) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (weightInitScheme == null) {
            throw new NullPointerException("weightInitScheme is marked non-null but is null");
        }
        if (shape == null) {
            throw new NullPointerException("shape is marked non-null but is null");
        }
        DataType defaultType = Nd4j.defaultFloatingPointType();
        return var(name, weightInitScheme, defaultType, shape);
    }

    /**
     * Creates a zero-initialized VARIABLE, or a PLACEHOLDER if {@code Shape.isPlaceholderShape}
     * reports the shape as a placeholder shape. The int shape is widened to long.
     *
     * @param name     variable name
     * @param dataType data type
     * @param shape    shape; must not be null
     * @return the new variable or placeholder
     */
    public SDVariable var(String name, DataType dataType, int... shape) {
        Preconditions.checkNotNull(shape, "Invalid shape: shape may not be null");
        boolean isPlaceholder = Shape.isPlaceholderShape(shape);
        long[] longShape = ArrayUtil.toLongArray(shape);
        if (isPlaceholder) {
            return placeHolder(name, dataType, longShape);
        }
        return var(name, new ZeroInitScheme(), dataType, longShape);
    }

    /**
     * Imports {@code v}, which may belong to another SameDiff instance, into this graph.
     *
     * <p>If a variable with the same name already exists here and has an array, that existing
     * variable is returned unchanged. Otherwise a new variable is created according to its type:
     * VARIABLE copies the array (allocated outside any workspace), ARRAY registers a new
     * SDVariable, CONSTANT delegates to {@link #constant(String, INDArray)}, and PLACEHOLDER
     * delegates to {@link #placeHolder(String, DataType, long...)}.
     *
     * @param v variable to import
     * @return the variable registered in this graph
     * @throws NullPointerException     if {@code v} is null
     * @throws IllegalArgumentException if {@code v} has no name
     * @throws RuntimeException         for unsupported variable types
     */
    public SDVariable var(@NonNull SDVariable v) {
        if (v == null) {
            throw new NullPointerException("v is marked non-null but is null");
        }
        if (this.variables.containsKey(v.name()) && this.variables.get(v.name()).getVariable().getArr() != null) {
            return this.variables.get(v.name()).getVariable();
        }
        if (v.name() == null || v.name().length() < 1) {
            throw new IllegalArgumentException("Name for variable must be defined");
        }
        VariableType variableType = v.getVariableType();
        switch (variableType) {
            case VARIABLE: {
                SDVariable copy = new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType());
                addVariable(copy);
                // Parameter arrays must outlive any workspace active in the caller.
                try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                    this.variablesArrays.setArray(v.name(), v.getArr().dup());
                }
                return copy;
            }
            case ARRAY:
                return addVariable(new SDVariable(v.name(), v.getVariableType(), this, v.getShape(), v.dataType()));
            case CONSTANT:
                return constant(v.name(), v.getArr());
            case PLACEHOLDER:
                return placeHolder(v.name(), v.dataType(), v.placeholderShape());
            default:
                throw new RuntimeException("Unknown/not supported variable type: " + variableType);
        }
    }

    /**
     * Generates a fresh variable name with the {@code sd_var} prefix via {@code generateNewVarName}.
     *
     * @return a generated variable name
     */
    private String getNewVarName() {
        final String prefix = "sd_var";
        return generateNewVarName(prefix, 0, false);
    }

    /**
     * Creates a variable (or placeholder) with a generated name.
     *
     * @param dataType data type
     * @param shape    shape
     * @return the new variable
     */
    public SDVariable var(DataType dataType, int... shape) {
        String generatedName = getNewVarName();
        return var(generatedName, dataType, shape);
    }

    /**
     * Creates a variable (or placeholder) with a generated name.
     *
     * @param dataType data type
     * @param shape    shape
     * @return the new variable
     */
    public SDVariable var(DataType dataType, long... shape) {
        String generatedName = getNewVarName();
        return var(generatedName, dataType, shape);
    }

    /**
     * Creates a trainable VARIABLE with a generated name.
     *
     * @param weightInitScheme initializer for the variable's array
     * @param dataType         data type
     * @param shape            shape
     * @return the new variable
     */
    public SDVariable var(WeightInitScheme weightInitScheme, DataType dataType, long... shape) {
        String generatedName = getNewVarName();
        return var(generatedName, weightInitScheme, dataType, shape);
    }

    /**
     * Creates a VARIABLE with a generated name, backed by {@code arr}.
     *
     * @param arr initial value
     * @return the new variable
     */
    public SDVariable var(INDArray arr) {
        String generatedName = getNewVarName();
        return var(generatedName, arr);
    }

    /**
     * Creates a VARIABLE backed by {@code arr}.
     *
     * <p>Attached arrays are detached. Otherwise, if the very same array instance already backs
     * another variable, it is duplicated so the two variables do not share storage.
     *
     * @param name variable name; null or empty means "generate one"
     * @param arr  initial value; must be a non-empty floating point array
     * @return the new variable
     * @throws NullPointerException     if {@code arr} is null
     * @throws IllegalArgumentException if the name is taken by a variable with an array, or {@code arr} is empty
     * @throws IllegalStateException    if {@code arr} is not floating point
     */
    public SDVariable var(String name, @NonNull INDArray arr) {
        if (arr == null) {
            throw new NullPointerException("arr is marked non-null but is null");
        }
        boolean nameTaken = this.variables.containsKey(name) && this.variables.get(name).getVariable().getArr() != null;
        if (nameTaken) {
            throw new IllegalArgumentException("Another variable with the name " + name + " already exists.");
        }
        Preconditions.checkState(arr.dataType().isFPType(), "Cannot create variable with non-floating point type: provided array has datatype %s. Variables must be floating point type to be trainable by backpropagation.\nFor non floating point types, these should be created as placeholders or constants instead.", arr.dataType());
        Preconditions.checkArgument(!arr.isEmpty(), "Empty arrays cannot be used when creating variables. Array shape: %ndShape", arr);
        if (name == null || name.length() < 1) {
            name = getNewVarName();
        }
        if (arr.isAttached()) {
            arr = arr.detach();
        } else {
            // Avoid two variables aliasing the same array instance.
            for (String existing : this.variablesArrays.arrayNames()) {
                if (this.variablesArrays.getArray(existing) == arr) {
                    arr = arr.dup();
                    break;
                }
            }
        }
        SDVariable variable = new SDVariable(name, VariableType.VARIABLE, this, arr.shape(), arr.dataType());
        associateArrayWithVariable(arr, variable);
        addVariable(variable);
        return variable;
    }

    /**
     * Converts a single variable to a constant in place; see {@link #convertToConstants(List)}.
     *
     * @param variable variable to convert
     * @return the same SDVariable instance, now of type CONSTANT
     * @throws NullPointerException if {@code variable} is null
     */
    public SDVariable convertToConstant(@NonNull SDVariable variable) {
        if (variable == null) {
            throw new NullPointerException("variable is marked non-null but is null");
        }
        List<SDVariable> single = Collections.singletonList(variable);
        convertToConstants(single);
        return variable;
    }

    /**
     * Converts the given variables into constants, in place.
     * <p>
     * Each entry's current array is stored in the constant array store and removed from the variable array store.
     * Any per-thread placeholder value with the same name is dropped, and the variable type becomes CONSTANT.
     * Cached sessions and the "grad" function instance are discarded first.
     * If training has been initialized, the following also happens for each entry:
     * <ul>
     *   <li>its updater is removed, and any closeable state arrays are closed;</li>
     *   <li>its name is removed from the dataset feature, label and mask mappings.</li>
     * </ul>
     * Does nothing when the list is empty or every entry is already a constant.
     *
     * @param list the variables to convert
     */
    public void convertToConstants(List<SDVariable> list) {
        if (list.isEmpty()) {
            return;
        }

        boolean anyToConvert = false;
        for (SDVariable candidate : list) {
            if (candidate.getVariableType() == VariableType.CONSTANT) {
                continue;
            }
            anyToConvert = true;
            Preconditions.checkState(candidate.getVariableType() != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a constant: %s", candidate);
        }
        if (!anyToConvert) {
            // Every entry is already a constant: nothing to do
            return;
        }

        // Drop cached sessions and the gradient function so neither is reused with the old variable types
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove("grad");

        for (SDVariable target : list) {
            String varName = target.name();
            INDArray current = target.getArr();
            Preconditions.checkNotNull(current, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", target);
            this.constantArrays.setArray(varName, current);
            this.variablesArrays.removeArray(varName);
            if (!this.placeholdersPerThread.isEmpty()) {
                for (Map<String, INDArray> perThread : this.placeholdersPerThread.values()) {
                    perThread.remove(varName);
                }
            }
            target.setVariableType(VariableType.CONSTANT);
        }

        if (this.trainingConfig != null && this.initializedTraining) {
            for (SDVariable converted : list) {
                String varName = converted.name();

                // Constants are not trained: release their updater state
                GradientUpdater updater = this.updaterMap.remove(varName);
                Map<String, INDArray> updaterState = (updater != null) ? updater.getState() : null;
                if (updaterState != null) {
                    for (INDArray stateArr : updaterState.values()) {
                        if (stateArr.closeable()) {
                            stateArr.close();
                        }
                    }
                }

                List<String> features = this.trainingConfig.getDataSetFeatureMapping();
                if (features != null && features.contains(varName)) {
                    List<String> pruned = new ArrayList<>(features);
                    pruned.remove(varName);
                    this.trainingConfig.setDataSetFeatureMapping(pruned);
                }

                List<String> labels = this.trainingConfig.getDataSetLabelMapping();
                if (labels != null && labels.contains(varName)) {
                    List<String> pruned = new ArrayList<>(labels);
                    pruned.remove(varName);
                    this.trainingConfig.setDataSetLabelMapping(pruned);
                }

                List<String> featureMasks = this.trainingConfig.getDataSetFeatureMaskMapping();
                if (featureMasks != null && featureMasks.contains(varName)) {
                    List<String> pruned = new ArrayList<>(featureMasks);
                    pruned.remove(varName);
                    this.trainingConfig.setDataSetFeatureMaskMapping(pruned);
                }

                List<String> labelMasks = this.trainingConfig.getDataSetLabelMaskMapping();
                if (labelMasks != null && labelMasks.contains(varName)) {
                    List<String> pruned = new ArrayList<>(labelMasks);
                    pruned.remove(varName);
                    this.trainingConfig.setDataSetLabelMaskMapping(pruned);
                }
            }
        }
    }

    /**
     * Converts a single constant (or placeholder) into a trainable variable, in place.
     * <p>
     * Only floating point variables are accepted. The work is done by {@link #convertToVariables(List)}.
     *
     * @param sDVariable the variable to convert; must not be null
     * @return the same {@code SDVariable} instance that was passed in
     * @throws NullPointerException if {@code sDVariable} is null
     */
    public SDVariable convertToVariable(@NonNull SDVariable sDVariable) {
        if (sDVariable == null) {
            throw new NullPointerException("constant is marked non-null but is null");
        }
        DataType currentType = sDVariable.dataType();
        Preconditions.checkState(currentType.isFPType(), "Only floating point SDVariables can be converted to variables, datatype of %s is %s", sDVariable.name(), currentType);
        List<SDVariable> single = Collections.singletonList(sDVariable);
        convertToVariables(single);
        return sDVariable;
    }

    /**
     * Converts the given constants (or placeholders) into trainable variables, in place.
     * <p>
     * Each entry's current array is stored in the variable array store and removed from the constant array store.
     * Any per-thread placeholder value with the same name is dropped, and the variable type becomes VARIABLE.
     * Cached sessions and the "grad" function instance are discarded first.
     * If training has been initialized, an updater is created for each entry that does not already have one.
     * <p>
     * Entries that are not already VARIABLEs must be floating point. This is the same rule
     * {@code convertToVariable(SDVariable)} enforces, now applied to the list form as well.
     * Does nothing when the list is empty or every entry is already a variable.
     *
     * @param list the variables to convert; must not be null
     * @throws NullPointerException if {@code list} is null
     */
    public void convertToVariables(@NonNull List<SDVariable> list) {
        if (list == null) {
            throw new NullPointerException("constants is marked non-null but is null");
        }
        if (list.isEmpty()) {
            return;
        }

        boolean allAlreadyVariables = true;
        for (SDVariable candidate : list) {
            VariableType currentType = candidate.getVariableType();
            Preconditions.checkState(currentType != VariableType.ARRAY, "Cannot convert variable of type ARRAY to a variable: %s", candidate);
            if (currentType != VariableType.VARIABLE) {
                allAlreadyVariables = false;
                // Only floating point values can be trained; mirrors the single-variable overload
                Preconditions.checkState(candidate.dataType().isFPType(), "Only floating point SDVariables can be converted to variables, datatype of %s is %s", candidate.name(), candidate.dataType());
            }
        }
        if (allAlreadyVariables) {
            return;
        }

        // Drop cached sessions and the gradient function so neither is reused with the old variable types
        this.sessions.clear();
        this.sameDiffFunctionInstances.remove("grad");

        for (SDVariable target : list) {
            String varName = target.name();
            INDArray current = target.getArr();
            Preconditions.checkNotNull(current, "Could not get array for variable %s: if this is a placeholder, use SDVariable.setArray before converting", target);
            this.variablesArrays.setArray(varName, current);
            this.constantArrays.removeArray(varName);
            if (!this.placeholdersPerThread.isEmpty()) {
                for (Map<String, INDArray> perThread : this.placeholdersPerThread.values()) {
                    perThread.remove(varName);
                }
            }
            target.setVariableType(VariableType.VARIABLE);
        }

        if (this.trainingConfig == null || !this.initializedTraining) {
            return;
        }

        // Training is already set up: give each newly trainable variable its own updater
        for (SDVariable converted : list) {
            String varName = converted.name();
            if (this.updaterMap.containsKey(varName)) {
                continue;
            }
            INDArray arr = converted.getArr();
            long stateSize = this.trainingConfig.getUpdater().stateSize(arr.length());
            GradientUpdater updater;
            if (stateSize > 0) {
                INDArray stateView = Nd4j.create(arr.dataType(), 1, stateSize);
                updater = this.trainingConfig.getUpdater().instantiate(stateView, false);
                updater.setStateViewArray(stateView, arr.shape(), arr.ordering(), true);
            } else {
                updater = this.trainingConfig.getUpdater().instantiate((INDArray) null, true);
            }
            this.updaterMap.put(varName, updater);
        }
    }

    /**
     * Changes the datatype of the named non-ARRAY variables and propagates the new types to downstream op outputs.
     * <p>
     * All requested conversions are validated before anything is mutated. For each variable whose type actually
     * changes:
     * <ul>
     *   <li>its SDVariable datatype is updated;</li>
     *   <li>its stored array is cast: the variable or constant store, or, for placeholders, only the calling
     *   thread's placeholder value.</li>
     * </ul>
     * If anything changed, cached sessions are cleared. Output datatypes are then recomputed breadth-first through
     * every op that consumes a converted variable, and through ops downstream of those.
     * <p>
     * NOTE(review): each op is enqueued at most once. An op reached before all of its changed inputs have been
     * updated keeps output types computed from the partially-updated inputs. Confirm this is acceptable for graphs
     * where one op receives changed inputs along several paths.
     *
     * @param map variable name to new datatype; must not be null
     * @throws NullPointerException if {@code map} is null
     */
    public void convertDataTypes(@NonNull Map<String, DataType> map) {
        if (map == null) {
            throw new NullPointerException("dataTypeMap is marked non-null but is null");
        }
        if (map.isEmpty()) {
            return;
        }

        // Validate every requested conversion before mutating anything
        for (Map.Entry<String, DataType> entry : map.entrySet()) {
            String key = entry.getKey();
            Preconditions.checkState(this.variables.containsKey(key), "Cannot change datatype of variable \"%s\": No variable with this name exists", key);
            SDVariable variable = this.variables.get(key).getVariable();
            Preconditions.checkState(variable.getVariableType() != VariableType.ARRAY, "Cannot change datatype of ARRAY type variable \"%s\": datatype of ARRAY type variables is determined by the datatypes of their inputs plus corresponding op type", key);
            if (variable.getVariableType() != VariableType.PLACEHOLDER) {
                Preconditions.checkState(variable.dataType().isNumerical() == entry.getValue().isNumerical(), "Cannot convert variables between numerical and non-numerical types: attempting to convert variable \"%s\" from %s to %s", key, variable.dataType(), entry.getValue());
            }
        }

        boolean anyChanged = false;
        for (Map.Entry<String, DataType> entry : map.entrySet()) {
            String key = entry.getKey();
            DataType newType = entry.getValue();
            SDVariable variable = this.variables.get(key).getVariable();
            if (variable.dataType() == newType) {
                continue;
            }
            variable.setDataType(newType);
            switch (variable.getVariableType()) {
                case VARIABLE:
                    this.variablesArrays.setArray(key, this.variablesArrays.removeArray(key).castTo(newType));
                    break;
                case CONSTANT:
                    this.constantArrays.setArray(key, this.constantArrays.removeArray(key).castTo(newType));
                    break;
                case PLACEHOLDER: {
                    // Only the calling thread's placeholder values are cast
                    Map<String, INDArray> threadPlaceholders = this.placeholdersPerThread.get(Long.valueOf(Thread.currentThread().getId()));
                    if (threadPlaceholders != null && threadPlaceholders.containsKey(key)) {
                        threadPlaceholders.put(key, threadPlaceholders.get(key).castTo(newType));
                    }
                    break;
                }
                default:
                    throw new IllegalStateException("Cannot convert array type variable");
            }
            anyChanged = true;
        }
        if (!anyChanged) {
            return;
        }

        this.sessions.clear();

        // Seed the queue with every op that consumes one of the converted variables
        Set<String> queuedOps = new HashSet<>();
        LinkedList<String> pending = new LinkedList<>();
        for (Map.Entry<String, DataType> entry : map.entrySet()) {
            Variable converted = this.variables.get(entry.getKey());
            converted.getVariable().setDataType(entry.getValue());
            List<String> consumers = converted.getInputsForOp();
            if (consumers != null) {
                for (String consumer : consumers) {
                    if (queuedOps.add(consumer)) {
                        pending.addLast(consumer);
                    }
                }
            }
        }

        // Recompute output datatypes op by op, following consumers downstream
        while (!pending.isEmpty()) {
            SameDiffOp op = this.ops.get(pending.removeFirst());
            List<String> opInputs = op.getInputsToOp();
            List<DataType> inputTypes = new ArrayList<>();
            if (opInputs != null) {
                for (String inputName : opInputs) {
                    inputTypes.add(this.variables.get(inputName).getVariable().dataType());
                }
            }
            List<DataType> outputTypes = op.getOp().calculateOutputDataTypes(inputTypes);
            List<String> opOutputs = op.getOutputsOfOp();
            for (int i = 0; i < opOutputs.size(); i++) {
                Variable output = this.variables.get(opOutputs.get(i));
                output.getVariable().setDataType(outputTypes.get(i));
                List<String> consumers = output.getInputsForOp();
                if (consumers != null) {
                    for (String consumer : consumers) {
                        if (queuedOps.add(consumer)) {
                            pending.addLast(consumer);
                        }
                    }
                }
            }
        }
    }

    /**
     * Renames a variable throughout this SameDiff instance.
     * <p>
     * The following are updated to use the new name wherever they reference the old one:
     * <ul>
     *   <li>the variable record and its SDVariable;</li>
     *   <li>op input, output and control-dependency lists;</li>
     *   <li>variable control-dependency lists;</li>
     *   <li>the variables map and the constant/variable array stores;</li>
     *   <li>per-thread placeholder values;</li>
     *   <li>updater state, since {@code updaterMap} is keyed by variable name;</li>
     *   <li>the training configuration's dataset mappings and loss variables;</li>
     *   <li>every function sub-instance that has the old name;</li>
     *   <li>the loss variable list.</li>
     * </ul>
     *
     * @param sameDiffOp not used by this method
     * @param str        current name of the variable; must exist
     * @param str2       new name; must not already be in use
     */
    public void renameVariable(SameDiffOp sameDiffOp, String str, String str2) {
        Preconditions.checkState(this.variables.containsKey(str), "Cannot rename variable \"%s\": no variable with this name exists", str);
        Preconditions.checkState(!this.variables.containsKey(str2), "Cannot rename variable \"%s\" to name \"%s\": a variable with name \"%s\" already exists", str, str2, str2);

        Variable variable = this.variables.get(str);
        variable.setName(str2);
        variable.getVariable().setVarName(str2);

        // Ops that consume this variable as a regular input
        if (variable.getInputsForOp() != null) {
            for (String opName : variable.getInputsForOp()) {
                SameDiffOp consumer = this.ops.get(opName);
                List<String> inputs = new ArrayList<>(consumer.getInputsToOp());
                Collections.replaceAll(inputs, str, str2);
                consumer.setInputsToOp(inputs);
            }
        }

        // Ops that have this variable as a control dependency
        if (variable.getControlDepsForOp() != null) {
            for (String opName : variable.getControlDepsForOp()) {
                SameDiffOp dependent = this.ops.get(opName);
                List<String> deps = new ArrayList<>(dependent.getControlDeps());
                Collections.replaceAll(deps, str, str2);
                dependent.setControlDeps(deps);
            }
        }

        // Variables that have this variable as a control dependency
        if (variable.getControlDepsForVar() != null) {
            for (String varName : variable.getControlDepsForVar()) {
                Variable dependent = this.variables.get(varName);
                List<String> deps = new ArrayList<>(dependent.getControlDeps());
                Collections.replaceAll(deps, str, str2);
                dependent.setControlDeps(deps);
            }
        }

        // Reverse links from this variable's own control dependencies
        if (variable.getControlDeps() != null) {
            for (String varName : variable.getControlDeps()) {
                Variable dependency = this.variables.get(varName);
                List<String> reverse = new ArrayList<>(dependency.getControlDepsForVar());
                Collections.replaceAll(reverse, str, str2);
                dependency.setControlDepsForVar(reverse);
            }
        }

        // The op that produces this variable
        if (variable.getOutputOfOp() != null) {
            SameDiffOp producer = this.ops.get(variable.getOutputOfOp());
            List<String> outputs = new ArrayList<>(producer.getOutputsOfOp());
            Collections.replaceAll(outputs, str, str2);
            producer.setOutputsOfOp(outputs);
        }

        this.variables.remove(str);
        this.variables.put(str2, variable);

        VariableType type = variable.getVariable().getVariableType();
        if (type == VariableType.CONSTANT && this.constantArrays.hasArray(str)) {
            this.constantArrays.rename(str, str2);
        }
        if (type == VariableType.VARIABLE && this.variablesArrays.hasArray(str)) {
            this.variablesArrays.rename(str, str2);
        }
        if (type == VariableType.PLACEHOLDER) {
            for (Map<String, INDArray> perThread : this.placeholdersPerThread.values()) {
                if (perThread != null && perThread.containsKey(str)) {
                    perThread.put(str2, perThread.remove(str));
                }
            }
        }

        // updaterMap is keyed by variable name: move any existing updater state so it is not orphaned
        if (this.updaterMap != null && this.updaterMap.containsKey(str)) {
            this.updaterMap.put(str2, this.updaterMap.remove(str));
        }

        if (this.trainingConfig != null) {
            List<String> features = this.trainingConfig.getDataSetFeatureMapping();
            if (features != null && features.contains(str)) {
                List<String> renamed = new ArrayList<>(features);
                Collections.replaceAll(renamed, str, str2);
                this.trainingConfig.setDataSetFeatureMapping(renamed);
            }
            List<String> labels = this.trainingConfig.getDataSetLabelMapping();
            if (labels != null && labels.contains(str)) {
                List<String> renamed = new ArrayList<>(labels);
                Collections.replaceAll(renamed, str, str2);
                this.trainingConfig.setDataSetLabelMapping(renamed);
            }
            List<String> featureMasks = this.trainingConfig.getDataSetFeatureMaskMapping();
            if (featureMasks != null && featureMasks.contains(str)) {
                List<String> renamed = new ArrayList<>(featureMasks);
                Collections.replaceAll(renamed, str, str2);
                this.trainingConfig.setDataSetFeatureMaskMapping(renamed);
            }
            List<String> labelMasks = this.trainingConfig.getDataSetLabelMaskMapping();
            if (labelMasks != null && labelMasks.contains(str)) {
                List<String> renamed = new ArrayList<>(labelMasks);
                Collections.replaceAll(renamed, str, str2);
                this.trainingConfig.setDataSetLabelMaskMapping(renamed);
            }
            List<String> losses = this.trainingConfig.getLossVariables();
            if (losses != null && losses.contains(str)) {
                List<String> renamed = new ArrayList<>(losses);
                Collections.replaceAll(renamed, str, str2);
                this.trainingConfig.setLossVariables(renamed);
            }
        }

        for (SameDiff sub : this.sameDiffFunctionInstances.values()) {
            if (sub.hasVariable(str)) {
                sub.renameVariable(str, str2);
            }
        }

        if (this.lossVariables.contains(str)) {
            this.lossVariables.set(this.lossVariables.indexOf(str), str2);
        }
    }

    /**
     * Renames variable {@code str} to {@code str2}.
     * <p>
     * Delegates to {@link #renameVariable(SameDiffOp, String, String)}. The op argument is the op registered under
     * {@code str} with its variable suffix stripped, and may be null.
     *
     * @param str  current variable name
     * @param str2 new variable name
     */
    public void renameVariable(String str, String str2) {
        String opName = VariableUtils.stripVarSuffix(str);
        SameDiffOp matchingOp = this.ops.get(opName);
        renameVariable(matchingOp, str, str2);
    }

    /**
     * Detaches variable {@code str} from the inputs of op {@code differentialFunction}.
     * <p>
     * If {@code str} is one of the op's current arguments, every occurrence of it is removed from the op's input
     * list. In all cases, the op's name is removed from the variable's list of consuming ops.
     * <p>
     * The previous version never advanced or left its loop after a match, so it spun forever. The loop now stops
     * after the first match, because the filter already removes every occurrence.
     *
     * @param str                  name of the input variable to remove
     * @param differentialFunction the op to remove it from
     */
    public void removeArgFromOp(String str, DifferentialFunction differentialFunction) {
        String opName = differentialFunction.getOwnName();
        for (SDVariable arg : differentialFunction.args()) {
            if (arg.name().equals(str)) {
                SameDiffOp op = this.ops.get(opName);
                List<String> currentInputs = op.getInputsToOp();
                List<String> remaining = new ArrayList<>(currentInputs.size());
                for (String input : currentInputs) {
                    if (!input.equals(str)) {
                        remaining.add(input);
                    }
                }
                op.setInputsToOp(remaining);
                break;
            }
        }

        List<String> consumers = this.variables.get(str).getInputsForOp();
        if (consumers != null) {
            consumers.remove(opName);
        }
    }

    /**
     * Looks up a variable by name.
     *
     * @param str the variable name
     * @return the matching {@code SDVariable}, or {@code null} if no variable with that name is registered
     */
    public SDVariable getVariable(String str) {
        Variable entry = this.variables.get(str);
        return entry == null ? null : entry.getVariable();
    }

    /**
     * Reports whether a variable with the given name is registered in this instance.
     *
     * @param str the variable name
     * @return {@code true} if the variables map has an entry for {@code str}
     */
    public boolean hasVariable(String str) {
        final boolean registered = this.variables.containsKey(str);
        return registered;
    }

    /**
     * Returns the gradient variable for the named variable, if one exists.
     * <p>
     * The gradient is looked up in two places, in order:
     * <ol>
     *   <li>the variable's own record;</li>
     *   <li>the matching variable in the "grad" function instance, if that instance exists.</li>
     * </ol>
     *
     * @param str name of a floating point variable
     * @return the gradient variable, or {@code null} if none is recorded
     */
    public SDVariable getGradForVariable(String str) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable with name \"%s\" exists", str);
        SDVariable variable = getVariable(str);
        // Arguments follow the message order: datatype first, then variable name
        Preconditions.checkState(variable.dataType().isFPType(), "Cannot get gradient of %s variable \"%s\": only floating point variables have gradients", variable.dataType(), str);

        SDVariable ownGradient = this.variables.get(str).getGradient();
        if (ownGradient != null) {
            return ownGradient;
        }

        SameDiff gradFn = this.sameDiffFunctionInstances.get("grad");
        if (gradFn != null && gradFn.variables.containsKey(str)) {
            return gradFn.variables.get(str).getGradient();
        }
        return null;
    }

    /**
     * Reports whether the named variable has a gradient.
     * <p>
     * Non-floating-point variables and constants never have one. For other variables this checks whether
     * {@link #getGradForVariable(String)} returns a non-null result.
     *
     * @param str the variable name; must exist
     * @return {@code true} if a gradient variable is available
     */
    public boolean variableHasGradient(String str) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable with name \"%s\" exists", str);
        SDVariable variable = getVariable(str);
        if (!variable.dataType().isFPType() || variable.isConstant()) {
            return false;
        }
        return getGradForVariable(str) != null;
    }

    /**
     * Records {@code sDVariable} as the gradient of the named variable.
     *
     * @param str        name of an existing variable
     * @param sDVariable the gradient variable; must not be null
     * @throws ND4JIllegalStateException if {@code sDVariable} is null
     */
    public void setGradientForVariableName(String str, SDVariable sDVariable) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable exists with name \"%s\"", str);
        if (sDVariable != null) {
            this.variables.get(str).setGradient(sDVariable);
            return;
        }
        throw new ND4JIllegalStateException("Unable to set null gradient for variable name " + str);
    }

    /**
     * Returns the gradient of the named variable from the "grad" function.
     * <p>
     * The gradient function is created first if it does not exist yet.
     *
     * @param str the variable name
     * @return the gradient variable, or {@code null} if the grad function records none
     */
    public SDVariable grad(String str) {
        if (!this.sameDiffFunctionInstances.containsKey("grad")) {
            createGradFunction();
        }
        SameDiff gradFunction = getFunction("grad");
        String gradVarName = gradFunction.getVariable(str).name();
        return gradFunction.getGradForVariable(gradVarName);
    }

    /**
     * Creates a variable (via {@code var}) that holds a scalar double value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept. The decompiled original could close twice and lose the primary exception.
     *
     * @param str name for the new variable (passed through to {@code var})
     * @param d   the scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, double d) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(d));
        }
    }

    /**
     * Creates a variable (via {@code var}) that holds a scalar float value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str name for the new variable (passed through to {@code var})
     * @param f   the scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, float f) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(f));
        }
    }

    /**
     * Creates a variable (via {@code var}) that holds a scalar int value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str name for the new variable (passed through to {@code var})
     * @param i   the scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, int i) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(i));
        }
    }

    /**
     * Creates a variable (via {@code var}) that holds a scalar long value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str name for the new variable (passed through to {@code var})
     * @param j   the scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, long j) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(j));
        }
    }

    /**
     * Creates a variable (via {@code var}) that holds a scalar of the given datatype and value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str      name for the new variable (passed through to {@code var})
     * @param dataType datatype of the scalar array
     * @param number   the scalar value
     * @return the new variable
     */
    public SDVariable scalar(String str, DataType dataType, Number number) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return var(str, Nd4j.scalar(dataType, number));
        }
    }

    /**
     * Creates a scalar double constant without an explicit name.
     * <p>
     * Delegates to {@link #constant(String, double)} with a null name.
     *
     * @param d the scalar value
     * @return the new constant
     */
    public SDVariable constant(double d) {
        final String unnamed = null;
        return constant(unnamed, d);
    }

    /**
     * Creates a constant that holds a scalar double value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept. The decompiled original could close twice and lose the primary exception.
     *
     * @param str name for the new constant (passed through to {@code constant(String, INDArray)})
     * @param d   the scalar value
     * @return the new constant
     */
    public SDVariable constant(String str, double d) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(d));
        }
    }

    /**
     * Creates a scalar float constant without an explicit name.
     * <p>
     * Delegates to {@link #constant(String, float)} with a null name.
     *
     * @param f the scalar value
     * @return the new constant
     */
    public SDVariable constant(float f) {
        final String unnamed = null;
        return constant(unnamed, f);
    }

    /**
     * Creates a constant that holds a scalar float value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str name for the new constant (passed through to {@code constant(String, INDArray)})
     * @param f   the scalar value
     * @return the new constant
     */
    public SDVariable constant(String str, float f) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(f));
        }
    }

    /**
     * Creates a scalar int constant without an explicit name.
     * <p>
     * Delegates to {@link #constant(String, int)} with a null name.
     *
     * @param i the scalar value
     * @return the new constant
     */
    public SDVariable constant(int i) {
        final String unnamed = null;
        return constant(unnamed, i);
    }

    /**
     * Creates a constant that holds a scalar int value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str name for the new constant (passed through to {@code constant(String, INDArray)})
     * @param i   the scalar value
     * @return the new constant
     */
    public SDVariable constant(String str, int i) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(i));
        }
    }

    /**
     * Creates a scalar long constant without an explicit name.
     * <p>
     * Delegates to {@link #constant(String, long)} with a null name.
     *
     * @param j the scalar value
     * @return the new constant
     */
    public SDVariable constant(long j) {
        final String unnamed = null;
        return constant(unnamed, j);
    }

    /**
     * Creates a constant that holds a scalar long value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str name for the new constant (passed through to {@code constant(String, INDArray)})
     * @param j   the scalar value
     * @return the new constant
     */
    public SDVariable constant(String str, long j) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(j));
        }
    }

    /**
     * Creates a constant that holds a scalar of the given datatype and value.
     * <p>
     * The scalar array is created inside the scope returned by
     * {@code Nd4j.getMemoryManager().scopeOutOfWorkspaces()}, which is closed before returning.
     * try-with-resources closes that scope exactly once. If both the body and {@code close()} throw, the body's
     * exception is kept.
     *
     * @param str      name for the new constant (passed through to {@code constant(String, INDArray)})
     * @param dataType datatype of the scalar array
     * @param number   the scalar value
     * @return the new constant
     */
    public SDVariable constant(String str, DataType dataType, Number number) {
        try (MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            return constant(str, Nd4j.scalar(dataType, number));
        }
    }

    /**
     * Registers an existing SDVariable (owned by this SameDiff instance) in the variable table.
     * <p>
     * Re-adding a variable that {@code equals} the one already registered under that name is
     * allowed. Registering a different variable under an existing name is rejected.
     *
     * @param sDVariable variable to register; must belong to this SameDiff instance
     * @return the same variable
     * @throws IllegalArgumentException if a different variable with the same name already exists
     */
    public SDVariable addVariable(SDVariable sDVariable) {
        Preconditions.checkState(sDVariable.getSameDiff() == this, "Samediff instance must be the same.");
        String name = sDVariable.name();
        if (this.variables.containsKey(name) && !this.variables.get(name).getVariable().equals(sDVariable)) {
            throw new IllegalArgumentException("Variable with name \"" + name + "\" already exists");
        }
        this.variables.put(name, Variable.builder().name(name).variable(sDVariable).build());
        return sDVariable;
    }

    /**
     * Creates (or reuses existing) output SDVariables for the given op and registers them as the
     * op's outputs. Registration only happens if no outputs were registered yet.
     * <p>
     * When {@code z} is {@code false}, output datatypes are inferred from the datatypes of the
     * op's inputs via {@code calculateOutputDataTypes}. When {@code z} is {@code true}, no
     * inference happens and outputs are created with a {@code null} datatype (presumably the
     * import path — TODO confirm).
     *
     * @param differentialFunction the op whose outputs should be generated
     * @param str base name for the outputs; falls back to the op's own name, then its op name
     * @param z   if true, skip output datatype inference
     * @return the op's output variables
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction differentialFunction, String str, boolean z) {
        if (str == null) {
            str = differentialFunction.getOwnName();
        }
        if (str == null) {
            str = differentialFunction.opName();
        }

        List<DataType> outputTypes = null;
        if (!z) {
            List<DataType> inputTypes = new ArrayList<>();
            List<String> inputsToOp = this.ops.get(differentialFunction.getOwnName()).getInputsToOp();
            if (inputsToOp != null) {
                for (String inputName : inputsToOp) {
                    inputTypes.add(this.variables.get(inputName).getVariable().dataType());
                }
            }
            outputTypes = differentialFunction.calculateOutputDataTypes(inputTypes);
        }

        if (!(differentialFunction instanceof CustomOp)) {
            if (!(differentialFunction instanceof BaseOp)) {
                throw new RuntimeException("Unknown op type: " + differentialFunction.getClass());
            }
            // Legacy (BaseOp) ops have exactly one output
            SDVariable variable = getVariable(str);
            if (variable == null) {
                // outputTypes is null when z == true; fall back to a null datatype instead of NPE-ing
                variable = var(str, VariableType.ARRAY, null, outputTypeAt(outputTypes, 0), (long[]) null);
            }
            variable.setCreator(differentialFunction);
            SDVariable[] single = new SDVariable[]{variable};
            if (getOutputsForOp(differentialFunction) == null) {
                addOutgoingFor(single, differentialFunction);
            }
            return single;
        }

        CustomOp customOp = (CustomOp) differentialFunction;
        int numOutputs = differentialFunction.getNumOutputs();
        if (numOutputs <= 0) {
            CustomOpDescriptor descriptor = customOp.getDescriptor();
            if (descriptor != null) {
                numOutputs = descriptor.getNumOutputs();
            }
            if (numOutputs <= 0) {
                throw new ND4UnresolvedOutputVariables("Could not determine number of output variables for op " + differentialFunction.getOwnName() + " - " + differentialFunction.getClass().getSimpleName() + ". Ops can override getNumOutputs() to specify number of outputs if required");
            }
        }

        // Only report a mismatch when datatypes were actually inferred and their count is wrong
        if (!z && (outputTypes == null || outputTypes.size() != numOutputs)) {
            log.trace("Incorrect number of output datatypes: got {} but expected datatypes for {} outputs - {} (op: {}), could be due to variable input types.",
                    outputTypes == null ? null : outputTypes.size(), numOutputs, outputTypes,
                    differentialFunction.getClass().getSimpleName());
        }

        SDVariable[] outputs = new SDVariable[numOutputs];
        for (int i = 0; i < numOutputs; i++) {
            // Output 0 uses the base name; subsequent outputs use "name:i"
            SDVariable out = getVariable(i == 0 ? str : str + ":" + i);
            if (out == null) {
                out = var(generateNewVarName(str, i), VariableType.ARRAY, null, outputTypeAt(outputTypes, i), (long[]) null);
            }
            out.setCreator(differentialFunction);
            outputs[i] = out;
        }
        if (getOutputsForOp(differentialFunction) == null) {
            addOutgoingFor(outputs, differentialFunction);
        }
        return outputs;
    }

    /** Returns the i-th inferred output datatype, or {@code null} if none was inferred for that index. */
    private static DataType outputTypeAt(List<DataType> types, int i) {
        return types != null && i < types.size() ? types.get(i) : null;
    }

    /**
     * Generates output variables for an op, using its own name as the base name.
     * Falls back to the op name if the op has no own name, and always infers output datatypes.
     *
     * @param differentialFunction the op whose outputs should be generated
     * @return the op's output variables
     */
    public SDVariable[] generateOutputVariableForOp(DifferentialFunction differentialFunction) {
        String baseName = differentialFunction.getOwnName();
        if (baseName == null) {
            baseName = differentialFunction.opName();
        }
        return generateOutputVariableForOp(differentialFunction, baseName, false);
    }

    /**
     * Looks up a previously defined sub-function (e.g. {@code "grad"}) by name.
     *
     * @param str function name
     * @return the function's SameDiff instance, or {@code null} if none is registered under that name
     */
    public SameDiff getFunction(String str) {
        SameDiff registered = this.sameDiffFunctionInstances.get(str);
        return registered;
    }

    /**
     * Creates a new TensorArray op bound to this SameDiff instance.
     *
     * @param dataType element datatype of the tensor array
     * @return the new TensorArray
     */
    public TensorArray tensorArray(DataType dataType) {
        final TensorArray result = new TensorArray(this, dataType);
        // Called only for its side effect (presumably creating/registering the output variables — TODO confirm);
        // the returned array is not used here.
        result.outputVariables();
        return result;
    }

    /**
     * Copies the graph of the named sub-function into the given SameDiff instance.
     * Throws NullPointerException if no function is registered under {@code str}.
     *
     * @param str      name of the registered sub-function
     * @param sameDiff target instance to invoke the graph on
     * @return the result of {@code invokeGraphOn}
     */
    public SDVariable invokeFunctionOn(String str, SameDiff sameDiff) {
        SameDiff function = this.sameDiffFunctionInstances.get(str);
        return function.invokeGraphOn(sameDiff);
    }

    /**
     * Defines (once) a named sub-function whose inputs are copies of the given variables.
     * <p>
     * While the definition runs, {@code this.child} points at the new sub-graph and its
     * {@code parent} is set to this instance. {@code child} is always reset to {@code null}
     * afterwards, including when the definition throws, so no half-built sub-graph stays
     * attached.
     *
     * @param str                        function name; if already defined, nothing is re-defined
     * @param sameDiffFunctionDefinition definition callback
     * @param sDVariableArr              input variables to copy into the sub-graph
     * @return the (possibly pre-existing) function's SameDiff instance
     */
    public SameDiff defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition, SDVariable[] sDVariableArr) {
        try {
            if (!this.sameDiffFunctionInstances.containsKey(str)) {
                SameDiff sub = create();
                this.child = sub;
                sub.parent = this;
                SDVariable[] subInputs = new SDVariable[sDVariableArr.length];
                for (int i = 0; i < subInputs.length; i++) {
                    subInputs[i] = sub.var(sDVariableArr[i]);
                }
                sameDiffFunctionDefinition.define(sub, null, subInputs);
                this.sameDiffFunctionInstances.put(str, sub);
            }
        } finally {
            this.child = null;
        }
        return this.sameDiffFunctionInstances.get(str);
    }

    /**
     * Defines (once) a named sub-function with no input arrays.
     *
     * @param str                        function name
     * @param sameDiffFunctionDefinition definition callback
     */
    public void defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition) {
        Map<String, INDArray> noInputs = new LinkedHashMap<>();
        defineFunction(str, sameDiffFunctionDefinition, noInputs);
    }

    /**
     * Defines (once) a named sub-function in a fresh SameDiff instance, passing the given input
     * arrays to the definition callback. Already-defined names are left untouched.
     *
     * @param str                        function name
     * @param sameDiffFunctionDefinition definition callback
     * @param map                        input arrays handed to the callback
     */
    public void defineFunction(String str, SameDiffFunctionDefinition sameDiffFunctionDefinition, Map<String, INDArray> map) {
        if (!this.sameDiffFunctionInstances.containsKey(str)) {
            SameDiff sub = create();
            sameDiffFunctionDefinition.define(sub, map, null);
            this.sameDiffFunctionInstances.put(str, sub);
        }
    }

    /**
     * Calculates gradients for the named variables, given the placeholder values.
     *
     * @param map    placeholder values
     * @param strArr names of the variables to calculate gradients for (non-empty)
     * @return map from variable name to gradient array
     */
    public Map<String, INDArray> calculateGradients(Map<String, INDArray> map, @NonNull String... strArr) {
        if (strArr != null) {
            Preconditions.checkArgument(strArr.length > 0, "No variables were specified");
            List<String> names = Arrays.asList(strArr);
            return calculateGradients(map, names);
        }
        throw new NullPointerException("variables is marked non-null but is null");
    }

    /**
     * Calculates gradients for the given variables, given the placeholder values.
     *
     * @param map        placeholder values
     * @param collection names of the variables to calculate gradients for (non-empty)
     * @return map from variable name to gradient array
     */
    public Map<String, INDArray> calculateGradients(Map<String, INDArray> map, @NonNull Collection<String> collection) {
        if (collection != null) {
            Preconditions.checkArgument(!collection.isEmpty(), "No variables were specified");
            OutAndGrad result = calculateGradientsAndOutputs(map, null, collection);
            return result.getGradients();
        }
        throw new NullPointerException("variables is marked non-null but is null");
    }

    /**
     * Runs the gradient function once (creating it first if necessary) and collects both
     * activations and gradients.
     *
     * @param map         placeholder values
     * @param collection  names of variables whose activations to return (may be null/empty)
     * @param collection2 names of variables whose gradients to return (may be null/empty); each must exist.
     *                    Variables without a gradient variable are silently omitted from the result.
     * @return outputs (null if {@code collection} is null) and gradients (null if {@code collection2} is null)
     */
    public OutAndGrad calculateGradientsAndOutputs(Map<String, INDArray> map, Collection<String> collection, Collection<String> collection2) {
        boolean noOutputs = collection == null || collection.isEmpty();
        boolean noGradients = collection2 == null || collection2.isEmpty();
        Preconditions.checkArgument(!(noOutputs && noGradients), "No variables were specified for either output or gradients");
        if (getFunction("grad") == null) {
            createGradFunction();
        }

        List<String> toExecute = new ArrayList<>();
        if (collection != null) {
            toExecute.addAll(collection);
        }
        // variable name -> name of its gradient variable (only for variables that have one)
        Map<String, String> gradNames = new HashMap<>();
        if (collection2 != null) {
            for (String varName : collection2) {
                Preconditions.checkState(this.variables.containsKey(varName), "No variable with name \"%s\" exists in the SameDiff instance", varName);
                SDVariable gradient = getVariable(varName).getGradient();
                if (gradient != null) {
                    toExecute.add(gradient.name());
                    gradNames.put(varName, gradient.name());
                }
            }
        }

        SameDiff function = getFunction("grad");
        function.setListeners(this.listeners);
        Map<String, INDArray> results = function.batchOutputHelper(map, null, Operation.TRAINING, toExecute.toArray(new String[0]));

        Map<String, INDArray> outputs = collection == null ? null : new HashMap<>();
        if (collection != null) {
            for (String outName : collection) {
                outputs.put(outName, results.get(outName));
            }
        }
        Map<String, INDArray> gradients = collection2 == null ? null : new HashMap<>();
        if (collection2 != null) {
            for (String varName : collection2) {
                String gradName = gradNames.get(varName);
                if (gradName != null) {
                    gradients.put(varName, results.get(gradName));
                }
            }
        }
        return new OutAndGrad(outputs, gradients);
    }

    /**
     * Reports whether the gradient ("grad") sub-function has been created.
     *
     * @return true if a function is registered under {@code "grad"}
     */
    public boolean hasGradientFunction() {
        final String gradKey = "grad";
        return this.sameDiffFunctionInstances.containsKey(gradKey);
    }

    /**
     * Creates the gradient function without requiring gradients for any extra variables.
     */
    public void createGradFunction() {
        String[] noExtraVariables = null;
        createGradFunction(noExtraVariables);
    }

    /**
     * Creates the gradient ("grad") sub-function: a copy of this graph extended with backward ops.
     * <p>
     * If no loss variables are set, they are chosen in this order:
     * <ol>
     *   <li>the loss variables from the training configuration;</li>
     *   <li>otherwise the single best-guess loss variable;</li>
     *   <li>otherwise, if there are no guesses, the first output of every {@code ExternalErrorsFunction} op.</li>
     * </ol>
     *
     * @param strArr optional names of variables that must get gradients even if they are CONSTANT or
     *               PLACEHOLDER variables. Each must exist and be floating point.
     * @throws IllegalStateException if no loss variables can be determined, or if a required gradient
     *                               was not produced during differentiation
     */
    public void createGradFunction(final String... strArr) {
        if (this.lossVariables.isEmpty()) {
            if (this.trainingConfig == null || this.trainingConfig.getLossVariables() == null || this.trainingConfig.getLossVariables().isEmpty()) {
                List<String> bestGuessLossVariables = bestGuessLossVariables();
                if (bestGuessLossVariables.size() == 1) {
                    String outputOfOp = this.variables.get(bestGuessLossVariables.get(0)).getOutputOfOp();
                    if (outputOfOp == null || !(this.ops.get(outputOfOp).getOp() instanceof ExternalErrorsFunction)) {
                        log.info("Inferring output \"{}\" as loss variable as none were previously set.Use SameDiff.setLossVariables() or SDVariable.markAsLoss() to override", bestGuessLossVariables.get(0));
                    }
                    this.lossVariables.add(bestGuessLossVariables.get(0));
                } else if (bestGuessLossVariables.isEmpty()) {
                    for (SameDiffOp sameDiffOp : this.ops.values()) {
                        if (sameDiffOp.getOp() instanceof ExternalErrorsFunction) {
                            this.lossVariables.add(sameDiffOp.getOutputsOfOp().get(0));
                        }
                    }
                }
            } else {
                this.lossVariables.addAll(this.trainingConfig.getLossVariables());
            }
        }
        Preconditions.checkState(!this.lossVariables.isEmpty(), "Cannot create gradient function: No loss variables (variables to minimize) have been specified. Loss variables are the variables that represent the loss/cost/score to be minimized during training, and that all gradients are calculated with respect to.\n Losses can be specified either in TrainingConfiguration (Builder.minimize(...)) or via SameDiff.setLossVariables()/addLossVariable()");
        if (log.isTraceEnabled()) {
            log.trace("Defining function \"grad\"");
        }
        // Validate the explicitly requested gradient variables up front
        if (strArr != null && strArr.length > 0) {
            for (String str : strArr) {
                Preconditions.checkArgument(this.variables.containsKey(str), "Cannot ensure gradient exists for variable: no variable with name \"%s\" exists", str);
                DataType dataType = this.variables.get(str).getVariable().dataType();
                Preconditions.checkState(dataType.isFPType(), "Cannot ensure gradient exists for variable \"%s\": variable is not a floating point SDVariable. Only floating point SDVariables have gradients defined - variable has type %s", str, dataType);
            }
        }
        defineFunction("grad", new SameDiffFunctionDefinition() { // from class: org.nd4j.autodiff.samediff.SameDiff.1
            @Override // org.nd4j.autodiff.samediff.SameDiffFunctionDefinition
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> map, SDVariable[] sDVariableArr) {
                List<String> inputsToOp;
                List<String> outputsOfOp;
                List<String> outputsOfOp2;
                List<String> inputsToOp2;
                sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false);
                if (SameDiff.this.debugMode) {
                    sameDiff.enableDebugMode();
                }
                // Copy the forward graph of the OUTER SameDiff into the new gradient instance.
                // Must be SameDiff.this: a bare `this` here is the anonymous definition object.
                SameDiff.this.invokeGraphOn(sameDiff);
                if (SameDiff.this.debugMode) {
                    Preconditions.checkState(sameDiff.ops.keySet().equals(SameDiff.this.ops.keySet()), "ops keysets not equal");
                }
                ArrayList arrayList = new ArrayList(sameDiff.ops.values());
                if (arrayList.isEmpty()) {
                    throw new ND4JIllegalStateException("No ops found!");
                }
                // Re-point every copied op and its inputs/outputs at the new instance
                Iterator it = arrayList.iterator();
                while (it.hasNext()) {
                    DifferentialFunction op = ((SameDiffOp) it.next()).getOp();
                    for (SDVariable sDVariable : op.args()) {
                        sDVariable.setSameDiff(sameDiff);
                    }
                    for (SDVariable sDVariable2 : op.outputVariables()) {
                        sDVariable2.setSameDiff(sameDiff);
                    }
                    op.setSameDiff(sameDiff);
                }
                // Reduce each loss to a scalar sum and seed its gradient with a scalar 1 (cast if needed)
                ArrayList arrayList2 = new ArrayList(SameDiff.this.lossVariables.size());
                SDVariable var = sameDiff.var("one-var", Nd4j.scalar(1.0f));
                for (String str2 : SameDiff.this.lossVariables) {
                    Preconditions.checkNotNull(str2, "Encountered null value in loss variables. Null loss variables are not allowed. Use SameDiff.setLossVariables with non-null array names to fix");
                    Preconditions.checkState(SameDiff.this.variables.containsKey(str2), "Specified loss function variable \"%s\" does not exist", str2);
                    SDVariable variable = ((Variable) SameDiff.this.variables.get(str2)).getVariable();
                    Preconditions.checkState(variable.dataType().isFPType(), "Specified loss function variable \"%s\" is not a floatingpoint variable (datatype: %s). Only floating point variables may be used as loss function variable", str2, variable.dataType());
                    SDVariable sum = variable.sum(new int[0]);
                    if (sum.dataType() == var.dataType()) {
                        sameDiff.setGradientForVariableName(sum.name(), var);
                    } else {
                        sameDiff.setGradientForVariableName(sum.name(), var.castTo(sum.dataType()));
                    }
                    if (arrayList2.contains(sum)) {
                        SameDiff.log.warn("Loss function variable \"{}\" appears multiple times in list of loss variables - using only first instance", str2);
                    } else {
                        arrayList2.add(sum);
                    }
                }
                if (SameDiff.log.isTraceEnabled()) {
                    String[] outputVariablesNames = ((SameDiffOp) arrayList.get(arrayList.size() - 1)).getOp().outputVariablesNames();
                    SameDiff.log.trace("Defining backward function: initial outputs {}", outputVariablesNames == null ? "null" : Arrays.toString(outputVariablesNames));
                }
                // hashSet: every floating point variable the losses (transitively) depend on,
                // found by walking backwards from the losses through producing ops (BFS).
                HashSet<String> hashSet = new HashSet();
                LinkedList linkedList = new LinkedList();
                for (String str3 : SameDiff.this.lossVariables) {
                    if (!linkedList.contains(str3)) {
                        linkedList.add(str3);
                    }
                }
                while (!linkedList.isEmpty()) {
                    String str4 = (String) linkedList.remove();
                    if (!hashSet.contains(str4)) {
                        Variable variable2 = (Variable) SameDiff.this.variables.get(str4);
                        if (variable2.getVariable().dataType().isFPType()) {
                            hashSet.add(variable2.getName());
                            if (variable2.getOutputOfOp() != null && (inputsToOp2 = ((SameDiffOp) SameDiff.this.ops.get(variable2.getOutputOfOp())).getInputsToOp()) != null) {
                                for (String str5 : inputsToOp2) {
                                    if (((Variable) SameDiff.this.variables.get(str5)).getVariable().dataType().isFPType()) {
                                        linkedList.add(str5);
                                    }
                                }
                            }
                        }
                    }
                }
                // hashSet2: hashSet minus variables that do not depend on anything needing a gradient.
                // Seeds for removal: ARRAY variables whose producing op has no input in hashSet, and
                // CONSTANT/PLACEHOLDER variables not explicitly requested in strArr. Removal then
                // propagates forward to ops none of whose inputs remain in hashSet2 (or are requested).
                HashSet<String> hashSet2 = new HashSet(hashSet);
                LinkedList linkedList2 = new LinkedList();
                for (String str6 : hashSet) {
                    Variable variable3 = (Variable) SameDiff.this.variables.get(str6);
                    if (variable3.getVariable().getVariableType() == VariableType.ARRAY) {
                        List<String> inputsToOp3 = ((SameDiffOp) SameDiff.this.ops.get(variable3.getOutputOfOp())).getInputsToOp();
                        boolean z = false;
                        if (inputsToOp3 != null) {
                            Iterator<String> it2 = inputsToOp3.iterator();
                            while (true) {
                                if (!it2.hasNext()) {
                                    break;
                                }
                                if (hashSet.contains(it2.next())) {
                                    z = true;
                                    break;
                                }
                            }
                        }
                        if (!z) {
                            linkedList2.add(str6);
                        }
                    }
                    VariableType variableType = variable3.getVariable().getVariableType();
                    boolean z2 = strArr != null && ArrayUtils.contains(strArr, str6);
                    if (variableType == VariableType.CONSTANT || variableType == VariableType.PLACEHOLDER) {
                        if (!z2) {
                            linkedList2.add(str6);
                        }
                    }
                }
                while (!linkedList2.isEmpty()) {
                    String str7 = (String) linkedList2.remove();
                    Variable variable4 = (Variable) SameDiff.this.variables.get(str7);
                    hashSet2.remove(str7);
                    List<String> inputsForOp = variable4.getInputsForOp();
                    if (inputsForOp != null && !inputsForOp.isEmpty()) {
                        Iterator<String> it3 = inputsForOp.iterator();
                        while (it3.hasNext()) {
                            SameDiffOp sameDiffOp2 = (SameDiffOp) SameDiff.this.ops.get(it3.next());
                            boolean z3 = false;
                            for (String str8 : sameDiffOp2.getInputsToOp()) {
                                if (hashSet2.contains(str8) || (strArr != null && ArrayUtils.contains(strArr, str8))) {
                                    z3 = true;
                                    break;
                                }
                            }
                            if (!z3 && (outputsOfOp2 = sameDiffOp2.getOutputsOfOp()) != null) {
                                for (String str9 : outputsOfOp2) {
                                    if (!linkedList2.contains(str9)) {
                                        linkedList2.add(str9);
                                    }
                                }
                            }
                        }
                    }
                }
                Preconditions.checkState(!hashSet2.isEmpty(), "Cannot differentiate graph relative to the specified loss function variables %s: graph does not contain any trainable SDVariables (floating point VARIABLE type SDVariables) that the loss function depend on.", SameDiff.this.lossVariables);
                // linkedList3: queue of ops to differentiate, starting from the ops producing the summed losses
                LinkedList linkedList3 = new LinkedList();
                Iterator it4 = arrayList2.iterator();
                while (it4.hasNext()) {
                    Variable variable5 = (Variable) sameDiff.variables.get(((SDVariable) it4.next()).name());
                    if (variable5.getOutputOfOp() != null) {
                        linkedList3.add(variable5.getOutputOfOp());
                    }
                }
                // hashMap: for each variable in hashSet2, the consumer ops that produce something in hashSet2.
                // Those ops must be differentiated before the variable's gradient is complete.
                HashMap hashMap = new HashMap();
                Iterator it5 = hashSet2.iterator();
                while (it5.hasNext()) {
                    Variable variable6 = (Variable) SameDiff.this.variables.get((String) it5.next());
                    List<String> inputsForOp2 = variable6.getInputsForOp();
                    if (inputsForOp2 != null) {
                        ArrayList arrayList3 = new ArrayList();
                        for (String str10 : inputsForOp2) {
                            List<String> outputsOfOp3 = ((SameDiffOp) SameDiff.this.ops.get(str10)).getOutputsOfOp();
                            boolean z4 = false;
                            if (outputsOfOp3 != null) {
                                Iterator<String> it6 = outputsOfOp3.iterator();
                                while (true) {
                                    if (!it6.hasNext()) {
                                        break;
                                    }
                                    if (hashSet2.contains(it6.next())) {
                                        z4 = true;
                                        break;
                                    }
                                }
                            }
                            if (z4) {
                                arrayList3.add(str10);
                            }
                        }
                        hashMap.put(variable6.getName(), arrayList3);
                    }
                }
                // Backward pass: call diff() on each queued op. Then enqueue the producers of its inputs
                // once all of their relevant outputs have complete gradients. hashSet3 = ops already differentiated.
                HashSet hashSet3 = new HashSet();
                while (!linkedList3.isEmpty()) {
                    DifferentialFunction op2 = ((SameDiffOp) sameDiff.ops.get((String) linkedList3.remove())).getOp();
                    if (op2 instanceof GradientBackwardsMarker) {
                        inputsToOp = ((SameDiffOp) sameDiff.ops.get(op2.getOwnName())).getInputsToOp();
                        outputsOfOp = Collections.emptyList();
                    } else {
                        inputsToOp = ((SameDiffOp) sameDiff.ops.get(op2.getOwnName())).getInputsToOp();
                        outputsOfOp = ((SameDiffOp) sameDiff.ops.get(op2.getOwnName())).getOutputsOfOp();
                    }
                    // Output gradients: the existing gradient, zeros for FP outputs without one, null otherwise
                    ArrayList arrayList4 = new ArrayList();
                    Iterator<String> it7 = outputsOfOp.iterator();
                    while (it7.hasNext()) {
                        SDVariable variable7 = sameDiff.getVariable(it7.next());
                        SDVariable gradient = variable7.hasGradient() ? variable7.gradient() : null;
                        if (gradient != null) {
                            arrayList4.add(gradient);
                        } else if (variable7.dataType().isFPType()) {
                            arrayList4.add(sameDiff.zerosLike(variable7));
                        } else {
                            arrayList4.add(null);
                        }
                    }
                    op2.diff(arrayList4);
                    hashSet3.add(op2.getOwnName());
                    Iterator<String> it8 = inputsToOp.iterator();
                    while (it8.hasNext()) {
                        String outputOfOp2 = ((Variable) sameDiff.variables.get(it8.next())).getOutputOfOp();
                        if (outputOfOp2 != null && !hashSet3.contains(outputOfOp2)) {
                            // z5: the producing op has an input that needs a gradient and is not yet differentiated
                            boolean z5 = false;
                            SameDiffOp sameDiffOp3 = (SameDiffOp) SameDiff.this.ops.get(outputOfOp2);
                            if (sameDiffOp3.getInputsToOp() != null) {
                                boolean z6 = false;
                                Iterator<String> it9 = sameDiffOp3.getInputsToOp().iterator();
                                while (true) {
                                    if (!it9.hasNext()) {
                                        break;
                                    }
                                    if (hashSet2.contains(it9.next())) {
                                        z6 = true;
                                        break;
                                    }
                                }
                                if (z6 && !hashSet3.contains(sameDiffOp3.getName())) {
                                    z5 = true;
                                }
                            }
                            if (z5) {
                                // z7: every relevant FP output of the producer has a gradient, and all consumers are done
                                boolean z7 = true;
                                SameDiffOp sameDiffOp4 = (SameDiffOp) sameDiff.ops.get(outputOfOp2);
                                Iterator<String> it10 = sameDiffOp4.getOutputsOfOp().iterator();
                                while (true) {
                                    if (!it10.hasNext()) {
                                        break;
                                    }
                                    Variable variable8 = (Variable) SameDiff.this.variables.get(it10.next());
                                    if (variable8.getVariable().dataType().isFPType() && hashSet2.contains(variable8.getName())) {
                                        if (variable8.getVariable().gradient() == null) {
                                            z7 = false;
                                            break;
                                        }
                                        List list = (List) hashMap.get(variable8.getName());
                                        if (list != null) {
                                            z7 &= hashSet3.containsAll(list);
                                            if (!z7) {
                                                break;
                                            }
                                        } else {
                                            continue;
                                        }
                                    }
                                }
                                if (z7 && !linkedList3.contains(sameDiffOp4.getOp().getOwnName())) {
                                    linkedList3.add(sameDiffOp4.getOp().getOwnName());
                                }
                            }
                        }
                    }
                }
                // Every variable that required a gradient (except the losses themselves) must now have one
                for (String str11 : hashSet2) {
                    if (!SameDiff.this.lossVariables.contains(str11) && ((Variable) SameDiff.this.variables.get(str11)).getVariable().gradient() == null) {
                        throw new IllegalStateException("Error encountered during differentiation: no gradient for required variable \"" + str11 + "\" was calculated");
                    }
                }
                return new SDVariable[]{sameDiff.var("grad", DataType.FLOAT, 1)};
            }
        });
        associateSameDiffWithOpsAndVariables();
    }

    /**
     * Guesses which variables are graph outputs (and hence loss candidates).
     * <p>
     * A variable qualifies if it meets all of these:
     * <ul>
     *   <li>it is not a constant or placeholder;</li>
     *   <li>it is not an input to any op;</li>
     *   <li>it is not a control dependency of any op or variable;</li>
     *   <li>it is not a floating point output of an {@code Assert} or {@code Switch} op.</li>
     * </ul>
     * The last rule was intended by the original nesting but lost: the original exclusion
     * branch had an empty body, so those outputs were still returned.
     *
     * @return names of candidate output variables
     */
    protected List<String> bestGuessLossVariables() {
        List<String> candidates = new ArrayList<>();
        for (Variable variable : this.variables.values()) {
            SDVariable sdv = variable.getVariable();
            if (sdv.isConstant() || sdv.isPlaceHolder()) {
                continue;
            }
            if (variable.getInputsForOp() != null && !variable.getInputsForOp().isEmpty()) {
                continue;
            }
            if (variable.getControlDepsForOp() != null && !variable.getControlDepsForOp().isEmpty()) {
                continue;
            }
            if (variable.getControlDepsForVar() != null && !variable.getControlDepsForVar().isEmpty()) {
                continue;
            }
            if (variable.getOutputOfOp() != null && sdv.dataType().isFPType()) {
                DifferentialFunction producer = this.ops.get(variable.getOutputOfOp()).getOp();
                if (producer instanceof Assert || producer instanceof Switch) {
                    continue;
                }
            }
            candidates.add(variable.getName());
        }
        return candidates;
    }

    /**
     * Reports whether the named variable is a placeholder.
     *
     * @param str variable name; must exist in this instance
     * @return true if the variable is a placeholder
     */
    public boolean isPlaceHolder(String str) {
        Preconditions.checkState(this.variables.containsKey(str), "No variable present in SameDiff instance with name \"%s\"", str);
        SDVariable sdv = this.variables.get(str).getVariable();
        return sdv.isPlaceHolder();
    }

    /**
     * Renames a variable, prefixing the current name scope if needed.
     * <p>
     * If {@code str} is null and the variable's current name is taken by a different variable,
     * a fresh name is generated instead. If the resulting name is unchanged (or null), nothing
     * is renamed.
     *
     * @param sameDiffOp op passed through to {@code renameVariable}
     * @param sDVariable variable to rename; must not be null
     * @param str        requested new name, or null
     * @return the (possibly renamed) variable
     * @throws IllegalStateException if the new name is already used by a different variable
     */
    public SDVariable updateVariableNameAndReference(SameDiffOp sameDiffOp, SDVariable sDVariable, String str) {
        if (sDVariable == null) {
            throw new NullPointerException("Null input: No variable found for updating!");
        }
        String newName = str;
        if (newName != null) {
            String scope = currentNameScope();
            if (scope != null && !newName.startsWith(scope + "/")) {
                newName = scope + "/" + newName;
            }
            if (this.variables.containsKey(newName) && this.variables.get(newName).getVariable() != sDVariable) {
                throw new IllegalStateException("Variable name \"" + newName + "\" already exists for a different SDVariable");
            }
        } else if (this.variables.containsKey(sDVariable.name()) && this.variables.get(sDVariable.name()).getVariable() != sDVariable) {
            newName = generateNewVarName(sDVariable.name(), 0);
        }

        String oldName = sDVariable.name();
        if (newName == null || oldName.equals(newName)) {
            return sDVariable;
        }
        sDVariable.setVarName(newName);
        renameVariable(sameDiffOp, oldName, newName);
        return sDVariable;
    }

    /**
     * Renames a variable, looking up the associated op in {@code ops} by the variable's name.
     * <p>
     * NOTE(review): elsewhere {@code ops} is keyed by op name, so this lookup by variable name
     * likely yields null unless an op shares the variable's name — TODO confirm intent.
     *
     * @param sDVariable variable to rename
     * @param str        requested new name, or null
     * @return the (possibly renamed) variable
     */
    public SDVariable updateVariableNameAndReference(SDVariable sDVariable, String str) {
        SameDiffOp matchingOp = this.ops.get(sDVariable.name());
        return updateVariableNameAndReference(matchingOp, sDVariable, str);
    }

    /**
     * Renames each variable to the name at the same index.
     *
     * @param sDVariableArr variables to rename
     * @param strArr        new names, index-aligned with {@code sDVariableArr}; null means "no requested name" for all
     * @return the (possibly renamed) variables, in the same order
     */
    public SDVariable[] updateVariableNamesAndReferences(SDVariable[] sDVariableArr, String[] strArr) {
        SDVariable[] renamed = new SDVariable[sDVariableArr.length];
        int idx = 0;
        for (SDVariable current : sDVariableArr) {
            String requested = (strArr != null) ? strArr[idx] : null;
            renamed[idx] = updateVariableNameAndReference(current, requested);
            idx++;
        }
        return renamed;
    }

    /**
     * Points every registered variable, every op, and each op's input/output variables back
     * at this SameDiff instance.
     */
    protected void associateSameDiffWithOpsAndVariables() {
        for (SDVariable v : variableMap().values()) {
            v.setSameDiff(this);
        }
        for (SameDiffOp sdOp : this.ops.values()) {
            DifferentialFunction fn = sdOp.getOp();
            fn.setSameDiff(this);
            SDVariable[] inputs = fn.args();
            if (inputs != null) {
                for (SDVariable in : inputs) {
                    in.setSameDiff(this);
                }
            }
            SDVariable[] outputs = fn.outputVariables();
            if (outputs != null) {
                for (SDVariable out : outputs) {
                    out.setSameDiff(this);
                }
            }
        }
    }

    /**
     * Adds a {@code FlatNode} named {@code str} to {@code flatBufferBuilder`} using fixed op-type/op-number
     * constants, with every other numeric field 0 except one field set to -1.
     *
     * @param str               node name (written once as a string and referenced twice)
     * @param sameDiff          only null-checked here
     * @param flatBufferBuilder builder receiving the node
     * @return offset of the created node inside {@code flatBufferBuilder}
     * @throws NullPointerException if {@code sameDiff} or {@code flatBufferBuilder} is null
     */
    protected int asFlatNode(String str, @NonNull SameDiff sameDiff, @NonNull FlatBufferBuilder flatBufferBuilder) {
        if (sameDiff == null) {
            throw new NullPointerException("scope is marked non-null but is null");
        }
        if (flatBufferBuilder == null) {
            throw new NullPointerException("bufferBuilder is marked non-null but is null");
        }
        // One string offset is supplied for both of the first two node fields.
        final int nameOffset = flatBufferBuilder.createString(str);
        // NOTE(review): 119 presumably maps to an OpType constant and 10 is a hard-coded op number — confirm.
        final byte opType = (byte) 119;
        final long opNum = 10L;
        return FlatNode.createFlatNode(flatBufferBuilder,
                nameOffset, nameOffset, opType, opNum,
                0, 0, 0, 0, 0, 0, 0, 0,
                -1,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
    }

    /**
     * Splits a variable reference of the form {@code "name:index"} into its name and output index.
     * <p>
     * The index is the text after the <em>last</em> {@code ':'}, so names that themselves contain colons are
     * preserved (e.g. {@code "a:b:3"} becomes {@code ("a:b", 3)}). A reference without any {@code ':'} has output
     * index 0. Unlike the previous split-based implementation, trailing colons (e.g. {@code "a:1:"}) are no longer
     * silently dropped; they now fail like any other malformed index.
     *
     * @param str variable reference such as {@code "x"}, {@code "x:2"} or {@code "a:b:3"}
     * @return pair of (name, index)
     * @throws NullPointerException  if {@code str} is null
     * @throws NumberFormatException if the text after the last {@code ':'} is not a valid integer
     */
    public static Pair<String, Integer> parseVariable(@NonNull String str) {
        if (str == null) {
            throw new NullPointerException("varName is marked non-null but is null");
        }
        int separator = str.lastIndexOf(':');
        if (separator < 0) {
            return Pair.pairOf(str, 0);
        }
        Integer index = Integer.valueOf(str.substring(separator + 1));
        return Pair.pairOf(str.substring(0, separator), index);
    }

    /**
     * Serializes this graph with the given executor configuration, passing {@code 0L} as the graph id argument.
     *
     * @param executorConfiguration executor configuration to embed; must not be null
     * @param z                     whether updater state is serialized as well
     * @return serialized FlatBuffers data (bytes start at the buffer's current position)
     * @throws NullPointerException if {@code executorConfiguration} is null
     */
    public ByteBuffer asFlatBuffers(@NonNull ExecutorConfiguration executorConfiguration, boolean z) {
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        final long graphId = 0L;
        return asFlatBuffers(graphId, executorConfiguration, z);
    }

    /**
     * Serializes this SameDiff graph into a FlatBuffers {@code FlatGraph}.
     * <p>
     * Side effects: calls {@code Nd4j.getExecutioner().commit()} before serializing. Afterwards, while holding this
     * instance's monitor, it stores the id assigned to each variable via {@code Variable#setVariableIndex}.
     *
     * @param j                     value passed as the {@code long} argument of {@code FlatGraph.createFlatGraph}
     *                              (presumably the graph id — confirm)
     * @param executorConfiguration executor configuration embedded in the graph; must not be null
     * @param z                     if true and {@code updaterMap} is non-empty, each updater's state arrays are
     *                              serialized as well
     * @return the builder's data buffer; the serialized bytes start at the buffer's current position
     * @throws NullPointerException if {@code executorConfiguration} is null
     */
    public ByteBuffer asFlatBuffers(long j, @NonNull ExecutorConfiguration executorConfiguration, boolean z) {
        int incrementAndGet;
        int i;
        long[] shape;
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        // Flush pending executioner work (presumably so the arrays read below are up to date).
        Nd4j.getExecutioner().commit();
        FlatBufferBuilder flatBufferBuilder = new FlatBufferBuilder(1024);
        // Id source for variables/nodes; incrementAndGet() means the first id handed out is 1.
        AtomicInteger atomicInteger = new AtomicInteger(0);
        // Offsets of the FlatVariable tables built below.
        ArrayList arrayList = new ArrayList();
        // NOTE(review): never populated, so the vector built from it is always empty — confirm intended.
        ArrayList arrayList2 = new ArrayList();
        // Offsets of the FlatNode tables built for each op.
        ArrayList arrayList3 = new ArrayList();
        // Snapshot of all variables, passed through to FlatBuffersMapper.asFlatNode.
        ArrayList arrayList4 = new ArrayList(variables());
        // Variable name -> assigned id; also used at the end to record variable indices.
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        // Passed through to FlatBuffersMapper.asFlatNode; contents are not visible from here.
        LinkedHashMap linkedHashMap2 = new LinkedHashMap();
        LinkedHashMap linkedHashMap3 = new LinkedHashMap();
        // Op (by identity) -> id, so all outputs of the same op share one id.
        IdentityHashMap identityHashMap = new IdentityHashMap();
        for (SDVariable sDVariable : variables()) {
            // ARRAY-type variables never have their array fetched.
            INDArray arr = sDVariable.getVariableType() == VariableType.ARRAY ? null : sDVariable.getArr();
            log.trace("Exporting variable: [{}]", sDVariable.name());
            String name = sDVariable.name();
            if (this.variables.get(name).getOutputOfOp() != null) {
                // Output of an op: reuse (or allocate) the op's id; i is this variable's position in the op's outputs.
                DifferentialFunction op = this.ops.get(this.variables.get(name).getOutputOfOp()).getOp();
                if (identityHashMap.containsKey(op)) {
                    incrementAndGet = ((Integer) identityHashMap.get(op)).intValue();
                } else {
                    incrementAndGet = atomicInteger.incrementAndGet();
                    identityHashMap.put(op, Integer.valueOf(incrementAndGet));
                }
                String[] outputVariablesNames = op.outputVariablesNames();
                i = ArrayUtils.indexOf(outputVariablesNames, name);
                Preconditions.checkState(i >= 0, "Variable name \"%s\" not found in list of outputs: %s", name, outputVariablesNames);
            } else {
                // Not produced by any op: gets its own id and output index 0.
                incrementAndGet = atomicInteger.incrementAndGet();
                i = 0;
            }
            linkedHashMap.put(sDVariable.name(), Integer.valueOf(incrementAndGet));
            log.trace("Adding [{}] as [{}]", sDVariable.name(), Integer.valueOf(incrementAndGet));
            int i2 = 0;
            int createString = flatBufferBuilder.createString(sDVariable.name());
            int createIntPair = IntPair.createIntPair(flatBufferBuilder, incrementAndGet, i);
            byte ordinal = (byte) sDVariable.getVariableType().ordinal();
            // Array contents are written only for constants, placeholders and VARIABLE-type variables, and only if an array exists.
            int flatArray = (sDVariable.isConstant() || sDVariable.isPlaceHolder() || sDVariable.getVariableType() == VariableType.VARIABLE) ? arr == null ? 0 : arr.toFlatArray(flatBufferBuilder) : 0;
            // Shape is written only for placeholders with a known shape.
            if (sDVariable.getVariableType() == VariableType.PLACEHOLDER && (shape = sDVariable.getShape()) != null) {
                i2 = FlatVariable.createShapeVector(flatBufferBuilder, shape);
            }
            int i3 = 0;
            // Control-dependency vectors; an offset of 0 means the corresponding list is absent.
            Variable variable = this.variables.get(name);
            int[] mapOrNull = FlatBuffersMapper.mapOrNull(variable.getControlDeps(), flatBufferBuilder);
            int createControlDepsVector = mapOrNull != null ? FlatVariable.createControlDepsVector(flatBufferBuilder, mapOrNull) : 0;
            int[] mapOrNull2 = FlatBuffersMapper.mapOrNull(variable.getControlDepsForOp(), flatBufferBuilder);
            int createControlDepForOpVector = mapOrNull2 != null ? FlatVariable.createControlDepForOpVector(flatBufferBuilder, mapOrNull2) : 0;
            int[] mapOrNull3 = FlatBuffersMapper.mapOrNull(variable.getControlDepsForVar(), flatBufferBuilder);
            if (mapOrNull3 != null) {
                i3 = FlatVariable.createControlDepsForVarVector(flatBufferBuilder, mapOrNull3);
            }
            arrayList.add(Integer.valueOf(FlatVariable.createFlatVariable(flatBufferBuilder, createIntPair, createString, FlatBuffersMapper.getDataTypeAsByte(sDVariable.dataType()), i2, flatArray, -1, ordinal, createControlDepsVector, createControlDepForOpVector, i3)));
        }
        // One FlatNode per op; the op's id (if one of its outputs was seen above) is passed along.
        Iterator<SameDiffOp> it = this.ops.values().iterator();
        while (it.hasNext()) {
            DifferentialFunction op2 = it.next().getOp();
            arrayList3.add(Integer.valueOf(FlatBuffersMapper.asFlatNode(this, op2, flatBufferBuilder, arrayList4, linkedHashMap, linkedHashMap2, linkedHashMap3, atomicInteger, (Integer) identityHashMap.get(op2))));
        }
        int createVariablesVector = FlatGraph.createVariablesVector(flatBufferBuilder, Ints.toArray(arrayList2));
        int createVariablesVector2 = FlatGraph.createVariablesVector(flatBufferBuilder, Ints.toArray(arrayList));
        int createNodesVector = FlatGraph.createNodesVector(flatBufferBuilder, Ints.toArray(arrayList3));
        // Placeholder names: count them first, then create a string offset for each.
        int i4 = 0;
        Iterator<SDVariable> it2 = variables().iterator();
        while (it2.hasNext()) {
            if (it2.next().isPlaceHolder()) {
                i4++;
            }
        }
        int[] iArr = new int[i4];
        if (i4 > 0) {
            int i5 = 0;
            for (SDVariable sDVariable2 : variables()) {
                if (sDVariable2.isPlaceHolder()) {
                    int i6 = i5;
                    i5++;
                    iArr[i6] = flatBufferBuilder.createString(sDVariable2.name());
                }
            }
        }
        int createPlaceholdersVector = FlatGraph.createPlaceholdersVector(flatBufferBuilder, iArr);
        // Loss variable names (empty vector when there are none).
        List<String> lossVariables = getLossVariables();
        int[] iArr2 = new int[lossVariables == null ? 0 : lossVariables.size()];
        for (int i7 = 0; i7 < iArr2.length; i7++) {
            iArr2[i7] = flatBufferBuilder.createString(lossVariables.get(i7));
        }
        int createLossVariablesVector = FlatGraph.createLossVariablesVector(flatBufferBuilder, iArr2);
        int i8 = 0;
        // Training config is stored as its JSON string, or offset 0 if absent.
        int createString2 = this.trainingConfig != null ? flatBufferBuilder.createString(this.trainingConfig.toJson()) : 0;
        // Optional updater state: for each updater key, its state map as parallel key/value vectors.
        if (z && this.updaterMap != null && !this.updaterMap.isEmpty()) {
            int[] iArr3 = new int[this.updaterMap.size()];
            int i9 = 0;
            for (Map.Entry<String, GradientUpdater> entry : this.updaterMap.entrySet()) {
                int createString3 = flatBufferBuilder.createString(entry.getKey());
                int i10 = 0;
                int i11 = 0;
                Map<String, INDArray> state = entry.getValue().getState();
                if (state != null && !state.isEmpty()) {
                    int[] iArr4 = new int[state.size()];
                    int[] iArr5 = new int[state.size()];
                    int i12 = 0;
                    for (Map.Entry<String, INDArray> entry2 : state.entrySet()) {
                        iArr4[i12] = flatBufferBuilder.createString(entry2.getKey());
                        iArr5[i12] = entry2.getValue().toFlatArray(flatBufferBuilder);
                        i12++;
                    }
                    i10 = UpdaterState.createUpdaterStateKeysVector(flatBufferBuilder, iArr4);
                    i11 = UpdaterState.createUpdaterStateValuesVector(flatBufferBuilder, iArr5);
                }
                int i13 = i9;
                i9++;
                iArr3[i13] = UpdaterState.createUpdaterState(flatBufferBuilder, createString3, i10, i11);
            }
            i8 = FlatGraph.createUpdaterStateVector(flatBufferBuilder, iArr3);
        }
        flatBufferBuilder.finish(FlatGraph.createFlatGraph(flatBufferBuilder, j, createVariablesVector2, createNodesVector, createVariablesVector, executorConfiguration.getFlatConfiguration(flatBufferBuilder), createPlaceholdersVector, createLossVariablesVector, createString2, i8));
        // Record the id assigned to each variable during export.
        synchronized (this) {
            for (Map.Entry entry3 : linkedHashMap.entrySet()) {
                this.variables.get(entry3.getKey()).setVariableIndex(((Integer) entry3.getValue()).intValue());
            }
        }
        return flatBufferBuilder.dataBuffer();
    }

    /**
     * Serializes this graph with the default executor configuration and wraps the bytes as a {@code FlatGraph}.
     *
     * @param z whether updater state is serialized as well
     * @return root {@code FlatGraph} view over the serialized buffer
     */
    public FlatGraph asFlatGraph(boolean z) {
        ByteBuffer serialized = asFlatBuffers(z);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph with an explicit graph id argument and executor configuration, and wraps the bytes as a
     * {@code FlatGraph}.
     *
     * @param j                     graph-level {@code long} argument forwarded to serialization
     * @param executorConfiguration executor configuration to embed
     * @param z                     whether updater state is serialized as well
     * @return root {@code FlatGraph} view over the serialized buffer
     */
    public FlatGraph asFlatGraph(long j, ExecutorConfiguration executorConfiguration, boolean z) {
        ByteBuffer serialized = asFlatBuffers(j, executorConfiguration, z);
        return FlatGraph.getRootAsFlatGraph(serialized);
    }

    /**
     * Serializes this graph using a default executor configuration: VARIABLE_SPACE output mode, SEQUENTIAL
     * execution, profiling DISABLED and timing collection enabled.
     *
     * @param z whether updater state is serialized as well
     * @return serialized FlatBuffers data (bytes start at the buffer's current position)
     */
    public ByteBuffer asFlatBuffers(boolean z) {
        ExecutorConfiguration defaultConfiguration = ExecutorConfiguration.builder()
                .outputMode(OutputMode.VARIABLE_SPACE)
                .executionMode(ExecutionMode.SEQUENTIAL)
                .profilingMode(OpExecutioner.ProfilingMode.DISABLED)
                .gatherTimings(true)
                .build();
        return asFlatBuffers(defaultConfiguration, z);
    }

    /**
     * Writes this SameDiff instance to {@code file} in FlatBuffers format via {@code asFlatFile(file, z)}.
     *
     * @param file destination file; must not be null
     * @param z    whether updater state is included
     * @throws NullPointerException if {@code file} is null
     * @throws RuntimeException     wrapping any {@link IOException} raised while writing
     */
    public void save(@NonNull File file, boolean z) {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        try {
            asFlatFile(file, z);
        } catch (IOException ioe) {
            throw new RuntimeException("Error saving SameDiff instance to file", ioe);
        }
    }

    /**
     * Writes this SameDiff instance to {@code outputStream}. The instance is first saved to a temporary file via
     * {@link #save(File, boolean)}, and that file is then copied to the stream.
     * <p>
     * Note: {@code outputStream} is <em>closed</em> by this method. It is wrapped in a {@link BufferedOutputStream}
     * unless it already is one. The temporary file is deleted in all cases.
     *
     * @param outputStream destination stream; must not be null
     * @param z            whether updater state is included
     * @throws NullPointerException if {@code outputStream} is null
     * @throws RuntimeException     if writing to the stream or reading the temporary file fails
     */
    public void save(@NonNull OutputStream outputStream, boolean z) {
        if (outputStream == null) {
            throw new NullPointerException("outputStream is marked non-null but is null");
        }
        File createTempFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp");
        try {
            save(createTempFile, z);
            OutputStream target = outputStream instanceof BufferedOutputStream
                    ? outputStream
                    : new BufferedOutputStream(outputStream);
            // try-with-resources closes each stream exactly once (input first, then output) and attaches close()
            // failures as suppressed exceptions instead of letting them mask the primary failure.
            try (OutputStream os = target;
                 InputStream is = new BufferedInputStream(new FileInputStream(createTempFile))) {
                IOUtils.copy(is, os);
            } catch (IOException e) {
                throw new RuntimeException("Error writing to output stream (or reading from temp file)", e);
            }
        } finally {
            createTempFile.delete();
        }
    }

    /**
     * Loads a SameDiff instance from a FlatBuffers file via {@code fromFlatFile(file, z)}.
     *
     * @param file source file; must not be null
     * @param z    flag forwarded to {@code fromFlatFile}
     * @return the loaded instance
     * @throws NullPointerException if {@code file} is null
     * @throws RuntimeException     wrapping any {@link IOException} raised while reading
     */
    public static SameDiff load(@NonNull File file, boolean z) {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        SameDiff loaded;
        try {
            loaded = fromFlatFile(file, z);
        } catch (IOException ioe) {
            throw new RuntimeException("Error loading SameDiff instance from file", ioe);
        }
        return loaded;
    }

    /**
     * Loads a SameDiff instance from {@code inputStream}. The stream contents are copied to a temporary file, which
     * is then read with {@code fromFlatFile(tempFile, z)}.
     * <p>
     * {@code inputStream} is not closed by this method. The temporary file is deleted in all cases.
     *
     * @param inputStream source stream; must not be null
     * @param z           flag forwarded to {@code fromFlatFile}
     * @return the loaded instance
     * @throws NullPointerException if {@code inputStream} is null
     * @throws RuntimeException     wrapping any {@link IOException} raised while copying or loading
     */
    public static SameDiff load(@NonNull InputStream inputStream, boolean z) {
        if (inputStream == null) {
            throw new NullPointerException("is is marked non-null but is null");
        }
        File createTempFile = ND4JFileUtils.createTempFile("SameDiffFile", "temp");
        try {
            // Close (and flush) the temp-file stream before reading the file back; a close failure is reported once,
            // not masked by or masking the copy failure.
            try (OutputStream os = new BufferedOutputStream(new FileOutputStream(createTempFile))) {
                IOUtils.copy(inputStream, os);
            }
            return fromFlatFile(createTempFile, z);
        } catch (IOException e) {
            throw new RuntimeException("Error loading SameDiff instance from file", e);
        } finally {
            createTempFile.delete();
        }
    }

    /**
     * Writes this graph to {@code file} in FlatBuffers format, including updater state.
     *
     * @param file destination file; must not be null
     * @throws NullPointerException if {@code file} is null
     * @throws IOException          if the file cannot be written
     */
    public void asFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        final boolean includeUpdaterState = true;
        asFlatFile(file, includeUpdaterState);
    }

    /* JADX WARN: Failed to calculate best type for var: r14v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Failed to calculate best type for var: r15v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 14, insn: 0x00e1: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r14 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:72:0x00e1 */
    /* JADX WARN: Not initialized variable reg: 15, insn: 0x00e6: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r15 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:74:0x00e6 */
    /* JADX WARN: Type inference failed for: r14v0, types: [java.io.BufferedOutputStream] */
    /* JADX WARN: Type inference failed for: r15v0, types: [java.lang.Throwable] */
    public void asFlatFile(@NonNull File file, boolean z) throws IOException {
        ?? r14;
        ?? r15;
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        ByteBuffer asFlatBuffers = asFlatBuffers(z);
        int position = asFlatBuffers.position();
        byte[] array = asFlatBuffers.array();
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            try {
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                Throwable th2 = null;
                DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream);
                Throwable th3 = null;
                try {
                    try {
                        dataOutputStream.write(array, position, array.length - position);
                        if (dataOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    dataOutputStream.close();
                                } catch (Throwable th4) {
                                    th3.addSuppressed(th4);
                                }
                            } else {
                                dataOutputStream.close();
                            }
                        }
                        if (bufferedOutputStream != null) {
                            if (0 != 0) {
                                try {
                                    bufferedOutputStream.close();
                                } catch (Throwable th5) {
                                    th2.addSuppressed(th5);
                                }
                            } else {
                                bufferedOutputStream.close();
                            }
                        }
                        if (fileOutputStream != null) {
                            if (0 == 0) {
                                fileOutputStream.close();
                                return;
                            }
                            try {
                                fileOutputStream.close();
                            } catch (Throwable th6) {
                                th.addSuppressed(th6);
                            }
                        }
                    } catch (Throwable th7) {
                        th3 = th7;
                        throw th7;
                    }
                } catch (Throwable th8) {
                    if (dataOutputStream != null) {
                        if (th3 != null) {
                            try {
                                dataOutputStream.close();
                            } catch (Throwable th9) {
                                th3.addSuppressed(th9);
                            }
                        } else {
                            dataOutputStream.close();
                        }
                    }
                    throw th8;
                }
            } catch (Throwable th10) {
                if (r14 != 0) {
                    if (r15 != 0) {
                        try {
                            r14.close();
                        } catch (Throwable th11) {
                            r15.addSuppressed(th11);
                        }
                    } else {
                        r14.close();
                    }
                }
                throw th10;
            }
        } catch (Throwable th12) {
            if (fileOutputStream != null) {
                if (0 != 0) {
                    try {
                        fileOutputStream.close();
                    } catch (Throwable th13) {
                        th.addSuppressed(th13);
                    }
                } else {
                    fileOutputStream.close();
                }
            }
            throw th12;
        }
    }

    /**
     * Serializes this graph with {@link #asFlatBuffers(ExecutorConfiguration, boolean)} and writes the resulting
     * FlatBuffers bytes to {@code file}, replacing any existing content.
     * <p>
     * Serialization happens before the file is opened, so a serialization failure leaves an existing file untouched.
     *
     * @param file                  destination file; must not be null
     * @param executorConfiguration executor configuration to embed; must not be null
     * @param z                     whether updater state is included
     * @throws NullPointerException if {@code file} or {@code executorConfiguration} is null
     * @throws IOException          if the file cannot be opened or written
     */
    public void asFlatFile(@NonNull File file, @NonNull ExecutorConfiguration executorConfiguration, boolean z) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        if (executorConfiguration == null) {
            throw new NullPointerException("configuration is marked non-null but is null");
        }
        ByteBuffer asFlatBuffers = asFlatBuffers(executorConfiguration, z);
        // try-with-resources closes the streams in reverse order exactly once and records close() failures as
        // suppressed exceptions rather than letting them replace the original error.
        try (FileOutputStream fileOutputStream = new FileOutputStream(file);
             BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
             DataOutputStream dataOutputStream = new DataOutputStream(bufferedOutputStream)) {
            if (asFlatBuffers.hasArray()) {
                // The serialized graph occupies the buffer's remaining region [position, limit).
                dataOutputStream.write(asFlatBuffers.array(),
                        asFlatBuffers.arrayOffset() + asFlatBuffers.position(),
                        asFlatBuffers.remaining());
            } else {
                // Non-heap buffer: copy the remaining bytes without disturbing the buffer's position.
                byte[] bytes = new byte[asFlatBuffers.remaining()];
                asFlatBuffers.duplicate().get(bytes);
                dataOutputStream.write(bytes);
            }
        }
    }

    /**
     * Loads a SameDiff instance from a FlatBuffers file, passing {@code true} as the flag of
     * {@code fromFlatFile(File, boolean)}.
     *
     * @param file source file; must not be null
     * @return the loaded instance
     * @throws NullPointerException if {@code file} is null
     * @throws IOException          if the file cannot be read
     */
    public static SameDiff fromFlatFile(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        final boolean defaultFlag = true;
        return fromFlatFile(file, defaultFlag);
    }

    /**
     * Reads the whole of {@code file} into memory and deserializes it with
     * {@code fromFlatBuffers(ByteBuffer.wrap(bytes), z)}.
     * <p>
     * The file stream is closed before deserialization starts.
     *
     * @param file source file; must not be null
     * @param z    flag forwarded to {@code fromFlatBuffers}
     * @return the loaded instance
     * @throws NullPointerException if {@code file} is null
     * @throws IOException          if the file cannot be read
     */
    public static SameDiff fromFlatFile(@NonNull File file, boolean z) throws IOException {
        if (file == null) {
            throw new NullPointerException("file is marked non-null but is null");
        }
        byte[] byteArray;
        // try-with-resources closes the stream once; a close() failure during an error is suppressed, not masking.
        try (InputStream in = new BufferedInputStream(new FileInputStream(file))) {
            byteArray = IOUtils.toByteArray(in);
        }
        return fromFlatBuffers(ByteBuffer.wrap(byteArray), z);
    }

    /**
     * Deserializes a SameDiff instance from FlatBuffers data, passing {@code true} as the flag of
     * {@code fromFlatBuffers(ByteBuffer, boolean)} (presumably "load updater state" — confirm).
     *
     * @param byteBuffer serialized FlatGraph data
     * @return the deserialized instance
     * @throws IOException as declared by the two-argument overload
     */
    public static SameDiff fromFlatBuffers(ByteBuffer byteBuffer) throws IOException {
        final boolean defaultFlag = true;
        return fromFlatBuffers(byteBuffer, defaultFlag);
    }

    public static SameDiff fromFlatBuffers(ByteBuffer byteBuffer, boolean z) throws IOException {
        String[] strArr;
        FlatGraph rootAsFlatGraph = FlatGraph.getRootAsFlatGraph(byteBuffer);
        int nodesLength = rootAsFlatGraph.nodesLength();
        int variablesLength = rootAsFlatGraph.variablesLength();
        ArrayList<FlatNode> arrayList = new ArrayList(nodesLength);
        for (int i = 0; i < nodesLength; i++) {
            arrayList.add(rootAsFlatGraph.nodes(i));
        }
        ArrayList<FlatVariable> arrayList2 = new ArrayList(variablesLength);
        for (int i2 = 0; i2 < variablesLength; i2++) {
            arrayList2.add(rootAsFlatGraph.variables(i2));
        }
        SameDiff create = create();
        int placeholdersLength = rootAsFlatGraph.placeholdersLength();
        LinkedHashSet linkedHashSet = new LinkedHashSet();
        for (int i3 = 0; i3 < placeholdersLength; i3++) {
            linkedHashSet.add(rootAsFlatGraph.placeholders(i3));
        }
        new HashMap();
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (FlatVariable flatVariable : arrayList2) {
            int shapeLength = flatVariable.shapeLength();
            long[] jArr = new long[shapeLength];
            for (int i4 = 0; i4 < shapeLength; i4++) {
                jArr[i4] = flatVariable.shape(i4);
            }
            String name = flatVariable.name();
            DataType dataTypeFromByte = FlatBuffersMapper.getDataTypeFromByte(flatVariable.dtype());
            VariableType variableType = VariableType.values()[flatVariable.variabletype()];
            SDVariable sDVariable = new SDVariable(name, variableType, create, jArr, dataTypeFromByte);
            create.variables.put(name, Variable.builder().name(name).variable(sDVariable).build());
            Variable variable = create.variables.get(name);
            if (flatVariable.controlDepsLength() > 0) {
                int controlDepsLength = flatVariable.controlDepsLength();
                ArrayList arrayList3 = new ArrayList(controlDepsLength);
                for (int i5 = 0; i5 < controlDepsLength; i5++) {
                    arrayList3.add(flatVariable.controlDeps(i5));
                }
                variable.setControlDeps(arrayList3);
            }
            if (flatVariable.controlDepForOpLength() > 0) {
                int controlDepForOpLength = flatVariable.controlDepForOpLength();
                ArrayList arrayList4 = new ArrayList(controlDepForOpLength);
                for (int i6 = 0; i6 < controlDepForOpLength; i6++) {
                    arrayList4.add(flatVariable.controlDepForOp(i6));
                }
                variable.setControlDepsForOp(arrayList4);
            }
            if (flatVariable.controlDepsForVarLength() > 0) {
                int controlDepsForVarLength = flatVariable.controlDepsForVarLength();
                ArrayList arrayList5 = new ArrayList(controlDepsForVarLength);
                for (int i7 = 0; i7 < controlDepsForVarLength; i7++) {
                    arrayList5.add(flatVariable.controlDepsForVar(i7));
                }
                variable.setControlDepsForVar(arrayList5);
            }
            FlatArray ndarray = flatVariable.ndarray();
            if (ndarray != null && variableType != VariableType.ARRAY) {
                MemoryWorkspace scopeOutOfWorkspaces = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();
                Throwable th = null;
                try {
                    try {
                        INDArray createFromFlatArray = Nd4j.createFromFlatArray(ndarray);
                        if (scopeOutOfWorkspaces != null) {
                            if (0 != 0) {
                                try {
                                    scopeOutOfWorkspaces.close();
                                } catch (Throwable th2) {
                                    th.addSuppressed(th2);
                                }
                            } else {
                                scopeOutOfWorkspaces.close();
                            }
                        }
                        create.setArrayForVariable(name, createFromFlatArray);
                    } catch (Throwable th3) {
                        if (scopeOutOfWorkspaces != null) {
                            if (th != null) {
                                try {
                                    scopeOutOfWorkspaces.close();
                                } catch (Throwable th4) {
                                    th.addSuppressed(th4);
                                }
                            } else {
                                scopeOutOfWorkspaces.close();
                            }
                        }
                        throw th3;
                    }
                } finally {
                }
            }
            IntPair id = flatVariable.id();
            hashMap.put(new Pair(Integer.valueOf(id.first()), Integer.valueOf(id.second())), sDVariable);
            if (!hashMap2.containsKey(name)) {
                hashMap2.put(name, new ArrayList());
            }
            ((List) hashMap2.get(name)).add(sDVariable);
        }
        for (FlatNode flatNode : arrayList) {
            DifferentialFunction fromFlatNode = FlatBuffersMapper.fromFlatNode(flatNode);
            String name2 = flatNode.name();
            fromFlatNode.setSameDiff(create);
            fromFlatNode.setOwnName(name2);
            if (create.ops.containsKey(name2)) {
                create.ops.get(name2).setOp(fromFlatNode);
            } else {
                create.ops.put(name2, SameDiffOp.builder().name(name2).op(fromFlatNode).build());
            }
            int outputLength = flatNode.outputLength();
            int[] iArr = new int[outputLength];
            for (int i8 = 0; i8 < outputLength; i8++) {
                iArr[i8] = flatNode.output(i8);
            }
            int id2 = flatNode.id();
            int[] iArr2 = new int[flatNode.outputLength()];
            for (int i9 = 0; i9 < iArr2.length; i9++) {
                iArr2[i9] = flatNode.output(i9);
            }
            int[] iArr3 = new int[flatNode.inputLength()];
            for (int i10 = 0; i10 < iArr3.length; i10++) {
                iArr3[i10] = flatNode.input(i10);
            }
            IntPair[] intPairArr = new IntPair[flatNode.inputPairedLength()];
            ArrayList arrayList6 = new ArrayList();
            for (int i11 = 0; i11 < intPairArr.length; i11++) {
                intPairArr[i11] = flatNode.inputPaired(i11);
                arrayList6.add(new Pair(Integer.valueOf(intPairArr[i11].first()), Integer.valueOf(intPairArr[i11].second())));
            }
            String[] strArr2 = new String[intPairArr.length];
            for (int i12 = 0; i12 < intPairArr.length; i12++) {
                SDVariable sDVariable2 = (SDVariable) hashMap.get(new Pair(Integer.valueOf(intPairArr[i12].first()), Integer.valueOf(intPairArr[i12].second())));
                if (sDVariable2 == null) {
                }
                strArr2[i12] = sDVariable2.name();
            }
            SameDiffOp sameDiffOp = create.ops.get(fromFlatNode.getOwnName());
            sameDiffOp.setInputsToOp(Arrays.asList(strArr2));
            if (flatNode.controlDepsLength() > 0) {
                int controlDepsLength2 = flatNode.controlDepsLength();
                ArrayList arrayList7 = new ArrayList(controlDepsLength2);
                for (int i13 = 0; i13 < controlDepsLength2; i13++) {
                    arrayList7.add(flatNode.controlDeps(i13));
                }
                sameDiffOp.setControlDeps(arrayList7);
            }
            if (flatNode.varControlDepsLength() > 0) {
                int varControlDepsLength = flatNode.varControlDepsLength();
                ArrayList arrayList8 = new ArrayList(varControlDepsLength);
                for (int i14 = 0; i14 < varControlDepsLength; i14++) {
                    arrayList8.add(flatNode.varControlDeps(i14));
                }
                sameDiffOp.setVarControlDeps(arrayList8);
            }
            if (flatNode.controlDepForLength() > 0) {
                int controlDepForLength = flatNode.controlDepForLength();
                ArrayList arrayList9 = new ArrayList(controlDepForLength);
                for (int i15 = 0; i15 < controlDepForLength; i15++) {
                    arrayList9.add(flatNode.controlDepFor(i15));
                }
                sameDiffOp.setControlDepFor(arrayList9);
            }
            for (String str : strArr2) {
                Variable variable2 = create.getVariables().get(str);
                if (variable2.getInputsForOp() == null) {
                    variable2.setInputsForOp(new ArrayList());
                }
                if (!variable2.getInputsForOp().contains(fromFlatNode.getOwnName())) {
                    variable2.getInputsForOp().add(fromFlatNode.getOwnName());
                }
            }
            List list = (List) hashMap2.get(name2);
            int numOutputs = fromFlatNode.getNumOutputs();
            if (numOutputs <= 0) {
                numOutputs = flatNode.outputLength();
            }
            if (list == null || list.size() != numOutputs) {
                int outputNamesLength = flatNode.outputNamesLength();
                strArr = new String[outputNamesLength];
                for (int i16 = 0; i16 < outputNamesLength; i16++) {
                    String outputNames = flatNode.outputNames(i16);
                    strArr[i16] = outputNames;
                    if (!create.variables.containsKey(outputNames)) {
                        SDVariable sDVariable3 = new SDVariable(outputNames, VariableType.VARIABLE, create, null, null);
                        create.variables.put(outputNames, Variable.builder().name(outputNames).variable(sDVariable3).build());
                        hashMap.put(new Pair(Integer.valueOf(id2), Integer.valueOf(i16)), sDVariable3);
                    }
                    create.getVariables().get(strArr[i16]).setOutputOfOp(fromFlatNode.getOwnName());
                }
                create.ops.get(fromFlatNode.getOwnName()).setOutputsOfOp(Arrays.asList(strArr));
            } else {
                strArr = new String[list.size()];
                for (int i17 = 0; i17 < strArr.length; i17++) {
                    strArr[i17] = ((SDVariable) list.get(i17)).name();
                    create.getVariables().get(strArr[i17]).setOutputOfOp(fromFlatNode.getOwnName());
                }
                create.ops.get(fromFlatNode.getOwnName()).setOutputsOfOp(Arrays.asList(strArr));
            }
            for (int i18 = 0; i18 < strArr.length; i18++) {
                Pair pair = new Pair(Integer.valueOf(id2), Integer.valueOf(i18));
                if (!hashMap.containsKey(pair)) {
                    hashMap.put(pair, create.getVariable(strArr[i18]));
                }
            }
        }
        if (rootAsFlatGraph.lossVariablesLength() > 0) {
            for (int i19 = 0; i19 < rootAsFlatGraph.lossVariablesLength(); i19++) {
                create.addLossVariable(rootAsFlatGraph.lossVariables(i19));
            }
        }
        String trainingConfig = rootAsFlatGraph.trainingConfig();
        if (trainingConfig != null) {
            create.trainingConfig = TrainingConfig.fromJson(trainingConfig);
        }
        if (z && rootAsFlatGraph.updaterStateLength() > 0) {
            create.updaterMap = new HashMap();
            int updaterStateLength = rootAsFlatGraph.updaterStateLength();
            for (int i20 = 0; i20 < updaterStateLength; i20++) {
                UpdaterState updaterState = rootAsFlatGraph.updaterState(i20);
                String paramName = updaterState.paramName();
                int updaterStateKeysLength = updaterState.updaterStateKeysLength();
                HashMap hashMap3 = new HashMap();
                for (int i21 = 0; i21 < updaterStateKeysLength; i21++) {
                    hashMap3.put(updaterState.updaterStateKeys(i21), Nd4j.createFromFlatArray(updaterState.updaterStateValues(i21)));
                }
                create.updaterMap.put(paramName, create.trainingConfig.getUpdater().instantiate((Map<String, INDArray>) hashMap3, false));
            }
            create.initializedTraining = true;
        }
        return create;
    }

    /**
     * Produces a human-readable dump of this graph as encoded by {@code asFlatBuffers(false)}.
     *
     * <p>The text has two sections:
     * <ul>
     *   <li>"External variables": one line per FlatBuffers variable with its id, name, shape-info buffer and its
     *       values as floats (all values when the array has fewer than 50 elements, otherwise the first 50);</li>
     *   <li>"Ops sequence": one line per node with its id, name, op type, op number (or the custom-op name found by
     *       hash, "unknown" if none matches) and its paired inputs.</li>
     * </ul>
     * Each node's id and name is also logged at INFO level.
     *
     * @return the formatted dump
     */
    public String asFlatPrint() {
        StringBuilder sb = new StringBuilder();
        FlatGraph graph = FlatGraph.getRootAsFlatGraph(asFlatBuffers(false));

        sb.append("\nExternal variables:\n\n");
        for (int v = 0; v < graph.variablesLength(); v++) {
            FlatVariable variable = graph.variables(v);
            INDArray arr;
            // Only the array materialisation must happen outside of any workspace. Keeping the scope this narrow
            // guarantees it is closed exactly once, even if building the text below throws.
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                FlatArray flatArray = variable.ndarray();
                arr = flatArray != null ? Nd4j.createFromFlatArray(flatArray) : null;
            }

            sb.append(variable.id().first()).append(":<").append(variable.name()).append("> ");
            if (arr == null) {
                sb.append("<no array>").append("; Values: ").append("<no array>").append(";\n");
                continue;
            }
            sb.append(Arrays.toString(arr.shapeInfoDataBuffer().asInt())).append("; Values: ");
            if (arr.data() == null) {
                sb.append("<empty array>");
            } else if (arr.dataType() == DataType.UTF8) {
                sb.append("<string array>");
            } else if (arr.length() < 50) {
                sb.append(Arrays.toString(arr.data().asFloat()).replaceAll(" ", ""));
            } else {
                // Long arrays: print only the first 50 values.
                sb.append("[");
                for (int k = 0; k < 50; k++) {
                    if (k > 0) {
                        sb.append(",");
                    }
                    sb.append(arr.data().getFloat(k));
                }
                sb.append("]");
            }
            sb.append(";\n");
        }

        Map<String, CustomOpDescriptor> customOps = Nd4j.getExecutioner().getCustomOperations();
        sb.append("\nOps sequence:\n\n");
        for (int n = 0; n < graph.nodesLength(); n++) {
            FlatNode node = graph.nodes(n);
            log.info("{}:<{}>", node.id(), node.name());
            Op.Type opType = FlatBuffersMapper.getTypeFromByte(node.opType());
            sb.append(node.id()).append(":<").append(node.name()).append("> ").append(opType);
            if (opType != Op.Type.CUSTOM) {
                sb.append(": ").append(node.opNum());
            } else {
                // Custom ops are matched by hash against the executioner's descriptors; if several share the
                // hash, the last one encountered wins.
                String opName = null;
                for (Map.Entry<String, CustomOpDescriptor> entry : customOps.entrySet()) {
                    if (entry.getValue().getHash() == node.opNum()) {
                        opName = entry.getKey();
                    }
                }
                if (opName == null) {
                    opName = "unknown";
                }
                sb.append(": ").append(opName);
            }

            sb.append("; Inputs: {");
            int pairedCount = node.inputPairedLength();
            for (int k = 0; k < pairedCount; k++) {
                IntPair input = node.inputPaired(k);
                sb.append("[").append(input.first()).append(":").append(input.second()).append("]");
                if (k < pairedCount - 1) {
                    sb.append(", ");
                }
            }
            sb.append("};");
            sb.append(" OpNum: {").append(node.opNum()).append("};");
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Builds a human-readable, multi-section summary of this SameDiff instance.
     *
     * <p>Sections: overall counts (variables, variables with arrays, functions, SameDiff function definitions,
     * loss variables); a variables table (name, array or placeholder shape, variable type, data type, producing
     * op, consuming ops); an ops table (name, op class, inputs, outputs); and, when present, a table of the
     * SameDiff function definitions held by this instance. Column widths grow to fit the longest entry.
     *
     * @return the formatted summary
     */
    public String summary() {
        Map<String, SDVariable> varMap = variableMap();
        DifferentialFunction[] functions = ops();

        int varsWithArrays = 0;
        for (String varName : varMap.keySet()) {
            if (getArrForVarName(varName) != null) {
                varsWithArrays++;
            }
        }

        StringBuilder sb = new StringBuilder();
        sb.append("--- Summary ---\n");
        sb.append(String.format("%-25s%-20s", "Variables:", varMap.size()))
                .append(" (").append(varsWithArrays).append(" with arrays)").append("\n")
                .append(String.format("%-25s%-20s", "Functions:", functions.length)).append("\n")
                .append(String.format("%-25s%-20s", "SameDiff Function Defs:", this.sameDiffFunctionInstances.size()))
                .append("\n")
                .append("Loss function variables: ").append(getLossVariables()).append("\n\n");

        sb.append("--- Variables ---\n");
        // Index each variable to the first op (in this.ops iteration order) that lists it as an output.
        // One pass over the ops replaces a scan of every op for every variable.
        Map<String, String> producerByVar = new HashMap<>();
        for (SameDiffOp op : this.ops.values()) {
            List<String> outputs = op.getOutputsOfOp();
            if (outputs == null) {
                continue;
            }
            for (String out : outputs) {
                if (!producerByVar.containsKey(out)) {
                    producerByVar.put(out, op.getName());
                }
            }
        }

        Map<String, String> outputOfDescription = new HashMap<>();
        int outputOfWidth = 22;
        int nameWidth = 8;
        for (String varName : varMap.keySet()) {
            String producer = producerByVar.get(varName);
            String description;
            if (producer == null) {
                description = "<none>";
            } else {
                DifferentialFunction producingFn = getOpById(producer);
                description = producingFn.getOwnName() + "(" + producingFn.opName() + ")";
            }
            outputOfDescription.put(varName, description);
            outputOfWidth = Math.max(outputOfWidth, description.length());
            nameWidth = Math.max(nameWidth, varName.length());
        }

        String varRowFormat = "%-" + (nameWidth + 2) + "s%-20s%-20s%-20s%-" + (outputOfWidth + 2) + "s%-20s";
        sb.append(String.format(varRowFormat, "- Name -", "- Array Shape -", "- Variable Type -", "- Data Type-",
                "- Output Of Function -", "- Inputs To Functions -")).append("\n");
        for (String varName : varMap.keySet()) {
            INDArray arr = getArrForVarName(varName);
            String shape = "-";
            if (arr != null) {
                shape = Arrays.toString(arr.shape());
            } else if (varMap.get(varName).isPlaceHolder()) {
                long[] placeholderShape = varMap.get(varName).placeholderShape();
                if (placeholderShape != null) {
                    shape = Arrays.toString(placeholderShape);
                }
            }
            SDVariable sdVar = getVariable(varName);
            String varType = sdVar.getVariableType().toString();
            String dataType = sdVar.dataType().toString();
            List<String> consumers = this.variables.get(varName).getInputsForOp();
            String consumersStr = consumers == null ? "" : consumers.toString();
            sb.append(String.format(varRowFormat, varName, shape, varType, dataType,
                    outputOfDescription.get(varName), consumersStr)).append("\n");
        }

        sb.append("\n\n--- Functions ---\n");
        List<String> inputStrs = new ArrayList<>(functions.length);
        List<String> outputStrs = new ArrayList<>(functions.length);
        int inputsWidth = 10;
        int outputsWidth = 11;
        int fnNameWidth = 17;
        int opClassWidth = 10;
        for (DifferentialFunction fn : functions) {
            String inputs = Arrays.toString(fn.argNames());
            String outputs = Arrays.toString(fn.outputVariablesNames());
            inputsWidth = Math.max(inputsWidth, inputs.length());
            outputsWidth = Math.max(outputsWidth, outputs.length());
            inputStrs.add(inputs);
            outputStrs.add(outputs);
            String displayName = fn.getOwnName() == null ? fn.opName() : fn.getOwnName();
            fnNameWidth = Math.max(fnNameWidth, displayName.length());
            opClassWidth = Math.max(opClassWidth, fn.getClass().getSimpleName().length());
        }
        String fnRowFormat = "%-5s%-" + (fnNameWidth + 2) + "s%-" + (opClassWidth + 2) + "s%-" + (inputsWidth + 2)
                + "s%-" + (outputsWidth + 2) + "s";
        sb.append(String.format(fnRowFormat, "", "- Function Name -", "- Op -", "- Inputs -", "- Outputs -"))
                .append("\n");
        for (int idx = 0; idx < functions.length; idx++) {
            DifferentialFunction fn = functions[idx];
            String displayName = fn.getOwnName() == null ? fn.opName() : fn.getOwnName();
            sb.append(String.format(fnRowFormat, idx, displayName, fn.getClass().getSimpleName(),
                    inputStrs.get(idx), outputStrs.get(idx))).append("\n");
        }

        if (this.sameDiffFunctionInstances.size() > 0) {
            sb.append("\n\n--- SameDiff Defined Functions ---\n");
            sb.append(String.format("%-20s%-15s%-15s%-15s", "- Name -", "- Variables -", "- Functions -",
                    "- Fn Defs -")).append("\n");
            for (Map.Entry<String, SameDiff> entry : this.sameDiffFunctionInstances.entrySet()) {
                SameDiff fnDef = entry.getValue();
                int fnCount = fnDef.ops() == null ? 0 : fnDef.ops().length;
                sb.append(String.format("%-20s%-15s%-15s%-15s", entry.getKey(), fnDef.variableMap().size(), fnCount,
                        fnDef.definedFunctionNames().size())).append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Reserves and returns a block name based on {@code str} that has not been handed out before.
     *
     * <p>{@code str} itself is used when it is still free; otherwise the first free {@code str_1}, {@code str_2},
     * ... is chosen. The chosen name is recorded in {@code blockNames}.
     *
     * @param str requested base name; may be null
     * @return the reserved name, or {@code null} when {@code str} is null (nothing is recorded)
     */
    public String newBlockName(String str) {
        if (str == null) {
            return null;
        }
        String candidate = str;
        for (int suffix = 1; this.blockNames.contains(candidate); suffix++) {
            candidate = str + "_" + suffix;
        }
        this.blockNames.add(candidate);
        return candidate;
    }

    /**
     * Imports a TensorFlow graph stored in {@code file}; delegates to {@code TFGraphMapper.importGraph(File)}.
     *
     * @param file graph file to import
     * @return the imported graph
     */
    public static SameDiff importFrozenTF(File file) {
        final SameDiff imported = TFGraphMapper.importGraph(file);
        return imported;
    }

    /**
     * Imports a TensorFlow graph from an in-memory {@code GraphDef}; delegates to
     * {@code TFGraphMapper.importGraph(GraphDef)}.
     *
     * @param graphDef graph definition to import
     * @return the imported graph
     */
    public static SameDiff importFrozenTF(GraphDef graphDef) {
        final SameDiff imported = TFGraphMapper.importGraph(graphDef);
        return imported;
    }

    /**
     * Imports a TensorFlow graph read from {@code inputStream}; delegates to
     * {@code TFGraphMapper.importGraph(InputStream)}.
     *
     * @param inputStream stream containing the serialized graph
     * @return the imported graph
     */
    public static SameDiff importFrozenTF(InputStream inputStream) {
        final SameDiff imported = TFGraphMapper.importGraph(inputStream);
        return imported;
    }

    /**
     * Returns an op name derived from {@code str} (after applying the current name scope) for use by a new op.
     *
     * <p>When {@code z} is true, the scoped name is returned verbatim. If an op with that name already exists,
     * an {@link IllegalArgumentException} is thrown.
     *
     * <p>Otherwise, a trailing {@code _<digits>} suffix is split into a base name and a starting counter (1 when
     * there is no such suffix, or when the digits do not fit in an int). The base name is tried first, then
     * {@code base_counter}, {@code base_(counter+1)}, and so on. The first candidate accepted is one that is not an
     * existing op name, is not equal to an existing variable name, and is not a prefix {@code candidate + ":"} of
     * an existing variable name.
     *
     * @param str requested name
     * @param z   if true, require exactly the requested (scoped) name
     * @return a name usable for a new op
     * @throws IllegalArgumentException if {@code z} is true and the name is already used by an op
     */
    public String getOpName(String str, boolean z) {
        String base = nameWithScope(str);
        if (z) {
            if (this.ops.containsKey(base)) {
                throw new IllegalArgumentException("Op with name \"" + base + "\" already exists");
            }
            return base;
        }

        int counter = 1;
        Matcher suffix = Pattern.compile("(.*)_(\\d+)").matcher(base);
        if (suffix.matches()) {
            try {
                // Parse before touching base so an unparseable suffix leaves the name intact.
                counter = Integer.parseInt(suffix.group(2));
                base = suffix.group(1);
            } catch (NumberFormatException ignored) {
                // Suffix too large for an int: treat it as part of the name instead of failing.
            }
        }

        String candidate = base;
        while (true) {
            boolean clashesWithVariable = false;
            String outputPrefix = candidate + ":";
            for (String varName : this.variables.keySet()) {
                if (varName.startsWith(outputPrefix) || varName.equals(candidate)) {
                    clashesWithVariable = true;
                    break;
                }
            }
            if (!clashesWithVariable && !this.ops.containsKey(candidate)) {
                return candidate;
            }
            candidate = base + "_" + counter;
            counter++;
        }
    }

    /**
     * Returns a unique op name derived from {@code str}; same as {@code getOpName(str, false)}.
     *
     * @param str requested name
     * @return a name not clashing with existing ops or variable outputs
     */
    public String getOpName(String str) {
        final boolean forceExactName = false;
        return getOpName(str, forceExactName);
    }

    /**
     * Generates a new variable name from {@code str} (after applying the current name scope).
     *
     * <p>When {@code i > 0} and the name contains a {@code :<digits>} index, that index is stripped and
     * {@code index + 1} is used instead of {@code i}. If the digits do not fit in an int, or incrementing them
     * would overflow, the name and {@code i} are kept unchanged. Unless {@code z} is true, the base name is then
     * made unique via {@link #getOpName(String)}. Finally {@code ":" + index} is appended when the index is positive.
     *
     * <p>NOTE(review): the index is located with {@code Matcher.find()}, which does not anchor at the end of the
     * name, so any text after the first {@code :<digits>} match is dropped — confirm this is intended.
     *
     * @param str requested base name
     * @param i   requested output index (no {@code :index} suffix is added when it is not positive)
     * @param z   true if the name belongs to an existing op, so no op-name de-duplication is applied
     * @return the generated variable name
     * @throws IllegalArgumentException if a variable with the resulting name already exists
     */
    public String generateNewVarName(String str, int i, boolean z) {
        String base = nameWithScope(str);
        int argIndex = i;
        if (argIndex > 0 && base.contains(":")) {
            Matcher indexSuffix = Pattern.compile("(.*):(\\d+)").matcher(base);
            if (indexSuffix.find()) {
                try {
                    int parsed = Integer.parseInt(indexSuffix.group(2));
                    // parsed + 1 would overflow to a negative index, silently dropping the ":index" suffix.
                    if (parsed < Integer.MAX_VALUE) {
                        argIndex = parsed + 1;
                        base = indexSuffix.group(1);
                    }
                } catch (NumberFormatException ignored) {
                    // Index too large for an int: keep the name and the requested index as they are.
                }
            }
        }
        if (!z) {
            base = getOpName(base);
        }
        if (argIndex > 0) {
            base = base + ":" + argIndex;
        }
        if (this.variables.containsKey(base)) {
            throw new IllegalArgumentException("Variable with name \"" + base + "\" already exists");
        }
        return base;
    }

    /**
     * Generates a new variable name for an existing op; same as {@code generateNewVarName(str, i, true)}.
     *
     * @param str requested base name
     * @param i   output index
     * @return the generated variable name
     */
    public String generateNewVarName(String str, int i) {
        final boolean existingOp = true;
        return generateNewVarName(str, i, existingOp);
    }

    /**
     * Returns {@code str} if no variable uses it yet, otherwise the first of {@code str_1}, {@code str_2}, ...
     * that is not already a variable name. Nothing is reserved; the name is only computed.
     *
     * @param str requested variable name
     * @return a variable name not currently present in {@code variables}
     */
    public String generateDistinctCustomVariableName(String str) {
        String candidate = str;
        int suffix = 0;
        while (this.variables.containsKey(candidate)) {
            suffix++;
            candidate = str + "_" + suffix;
        }
        return candidate;
    }

    /**
     * Returns a compact description containing the number of variables and ops in this instance.
     *
     * @return text of the form {@code SameDiff(nVars=<n>,nOps=<m>)}
     */
    @Override
    public String toString() {
        return new StringBuilder("SameDiff(nVars=")
                .append(this.variables.size())
                .append(",nOps=")
                .append(this.ops.size())
                .append(')')
                .toString();
    }

    /**
     * Builds an if/else construct using the default block name and no explicit output name.
     *
     * @param sameDiffNoArgSingleLambda  defines the condition
     * @param sameDiffNoArgSingleLambda2 defines the true branch
     * @param sameDiffNoArgSingleLambda3 defines the false branch
     * @return the merged output of the two branches
     * @throws NullPointerException if any lambda is null
     */
    public SDVariable ifCond(@NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda2, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda3) {
        if (sameDiffNoArgSingleLambda == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (sameDiffNoArgSingleLambda2 == null) {
            throw new NullPointerException("trueBody is marked non-null but is null");
        }
        if (sameDiffNoArgSingleLambda3 == null) {
            throw new NullPointerException("falseBody is marked non-null but is null");
        }
        // No explicit output name; the full overload falls back to the "if" block name when this is null.
        final String outputName = null;
        final String blockName = null;
        return ifCond(outputName, blockName, sameDiffNoArgSingleLambda, sameDiffNoArgSingleLambda2, sameDiffNoArgSingleLambda3);
    }

    /**
     * Builds an if/else construct with the given block name and no explicit output name.
     *
     * @param str                        base name for the if block ({@code "if"} when null)
     * @param sameDiffNoArgSingleLambda  defines the condition
     * @param sameDiffNoArgSingleLambda2 defines the true branch
     * @param sameDiffNoArgSingleLambda3 defines the false branch
     * @return the merged output of the two branches
     * @throws NullPointerException if any lambda is null
     */
    public SDVariable ifCond(String str, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda2, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda3) {
        if (sameDiffNoArgSingleLambda == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (sameDiffNoArgSingleLambda2 == null) {
            throw new NullPointerException("trueBody is marked non-null but is null");
        }
        if (sameDiffNoArgSingleLambda3 == null) {
            throw new NullPointerException("falseBody is marked non-null but is null");
        }
        // str names the if block; the output variable is left without an explicit name.
        final String outputName = null;
        return ifCond(outputName, str, sameDiffNoArgSingleLambda, sameDiffNoArgSingleLambda2, sameDiffNoArgSingleLambda3);
    }

    /**
     * Adds an if/else construct to the graph.
     *
     * <p>Inside a uniquely named block scope (from {@code str2}, default {@code "if"}) this method:
     * <ol>
     *   <li>defines the condition in a {@code cond} sub-scope and requires its data type to be BOOL;</li>
     *   <li>defines the true branch in a {@code trueBody} sub-scope. While the branch is defined, an argument
     *       interceptor routes every variable that existed beforehand through {@code switchOp(var, cond)} and
     *       uses output {@code [1]};</li>
     *   <li>defines the false branch in a {@code falseBody} sub-scope the same way, using output {@code [0]};</li>
     *   <li>joins both branch results with {@code merge} and applies the requested output name.</li>
     * </ol>
     * Switch ops are shared between branches through one map keyed by variable name. If a branch returns a
     * variable that existed before that branch was defined, the variable is also routed through a switch.
     *
     * <p>Argument interceptors are removed and name scopes opened here are closed even if a lambda throws or the
     * condition is rejected, so a failed call does not leave later ops inside this block's scope.
     *
     * @param str                        requested name of the output variable; may be null
     * @param str2                       base name for the if block, or null for {@code "if"}
     * @param sameDiffNoArgSingleLambda  defines the condition (must produce a BOOL variable)
     * @param sameDiffNoArgSingleLambda2 defines the true branch
     * @param sameDiffNoArgSingleLambda3 defines the false branch
     * @return the (possibly renamed) merge of the two branch outputs
     * @throws NullPointerException  if any lambda is null
     * @throws IllegalStateException if the condition is not BOOL; variables and ops in the block's scope are
     *                               removed first
     */
    public SDVariable ifCond(String str, String str2, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda2, @NonNull SameDiffNoArgSingleLambda sameDiffNoArgSingleLambda3) {
        if (sameDiffNoArgSingleLambda == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (sameDiffNoArgSingleLambda2 == null) {
            throw new NullPointerException("trueBody is marked non-null but is null");
        }
        if (sameDiffNoArgSingleLambda3 == null) {
            throw new NullPointerException("falseBody is marked non-null but is null");
        }
        NameScope ifScope = this.sd.withNameScope(newBlockName(str2 == null ? "if" : str2));
        SDVariable merged;
        try {
            final SDVariable pred;
            NameScope condScope = withNameScope("cond");
            try {
                pred = sameDiffNoArgSingleLambda.define(this);
            } finally {
                condScope.close();
            }

            if (pred.dataType() != DataType.BOOL) {
                // Roll back whatever has been added to this block before rejecting the condition.
                for (SDVariable scoped : getVariablesInScope(ifScope)) {
                    getVariables().remove(scoped.name());
                }
                for (SameDiffOp op : getOpsInScope(ifScope)) {
                    // Iterate over a copy in case removeArgFromOp rewrites this op's input list.
                    for (String input : new ArrayList<>(op.getInputsToOp())) {
                        removeArgFromOp(input, op.getOp());
                    }
                    getOps().remove(op.getName());
                }
                throw new IllegalStateException("Can not use " + pred.name() + " as the condition of an If statement, the condition must be a boolean.");
            }

            final Map<String, SDVariable[]> switches = new HashMap<>();

            // True branch: pre-existing variables are taken from output [1] of their switch.
            final Set<String> declaredBeforeTrue = Sets.newHashSet(variableMap().keySet());
            addArgumentInterceptor(new ArgumentInterceptor() {
                @Override
                public SDVariable intercept(SDVariable sDVariable) {
                    if (!declaredBeforeTrue.contains(sDVariable.name())) {
                        return sDVariable;
                    }
                    if (switches.containsKey(sDVariable.name())) {
                        return switches.get(sDVariable.name())[1];
                    }
                    SDVariable[] created = SameDiff.this.switchOp(sDVariable, pred);
                    switches.put(sDVariable.name(), created);
                    return created[1];
                }
            });
            SDVariable trueOut;
            NameScope trueScope = withNameScope("trueBody");
            try {
                try {
                    trueOut = sameDiffNoArgSingleLambda2.define(this);
                } finally {
                    removeArgumentInterceptor();
                }
                if (declaredBeforeTrue.contains(trueOut.name())) {
                    SDVariable[] routed = switchOp(trueOut, pred);
                    switches.put(trueOut.name(), routed);
                    trueOut = routed[1];
                }
            } finally {
                trueScope.close();
            }

            // False branch: pre-existing variables are taken from output [0] of their switch.
            final Set<String> declaredBeforeFalse = Sets.newHashSet(variableMap().keySet());
            this.sd.addArgumentInterceptor(new ArgumentInterceptor() {
                @Override
                public SDVariable intercept(SDVariable sDVariable) {
                    if (!declaredBeforeFalse.contains(sDVariable.name())) {
                        return sDVariable;
                    }
                    if (switches.containsKey(sDVariable.name())) {
                        return switches.get(sDVariable.name())[0];
                    }
                    SDVariable[] created = SameDiff.this.switchOp(sDVariable, pred);
                    switches.put(sDVariable.name(), created);
                    return created[0];
                }
            });
            SDVariable falseOut;
            NameScope falseScope = withNameScope("falseBody");
            try {
                try {
                    falseOut = sameDiffNoArgSingleLambda3.define(this);
                } finally {
                    removeArgumentInterceptor();
                }
                if (declaredBeforeFalse.contains(falseOut.name())) {
                    SDVariable[] routed = switchOp(falseOut, pred);
                    switches.put(falseOut.name(), routed);
                    falseOut = routed[0];
                }
            } finally {
                falseScope.close();
            }

            merged = merge(trueOut, falseOut);
        } finally {
            ifScope.close();
        }
        return updateVariableNameAndReference(merged, str);
    }

    /**
     * Builds a while loop using the default loop name; delegates to the full five-argument overload with
     * {@code null} for both leading name arguments.
     *
     * @param sDVariableArr        initial loop variables
     * @param sameDiffSingleLambda defines the loop condition
     * @param sameDiffLambda       defines the loop body
     * @return the variables returned by the full overload
     * @throws NullPointerException if any argument is null
     */
    public SDVariable[] whileLoop(@NonNull SDVariable[] sDVariableArr, @NonNull SameDiffSingleLambda sameDiffSingleLambda, @NonNull SameDiffLambda sameDiffLambda) {
        if (sDVariableArr == null) {
            throw new NullPointerException("loopVars is marked non-null but is null");
        }
        if (sameDiffSingleLambda == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (sameDiffLambda == null) {
            throw new NullPointerException("body is marked non-null but is null");
        }
        // NOTE(review): presumably names for the loop outputs; their use is outside this view.
        final String[] outputNames = null;
        // A null loop name makes the full overload fall back to "while".
        final String loopName = null;
        return whileLoop(outputNames, loopName, sDVariableArr, sameDiffSingleLambda, sameDiffLambda);
    }

    /**
     * Builds a while loop with the given loop name; delegates to the full five-argument overload with
     * {@code null} as its first (String[]) argument.
     *
     * @param str                  loop block name ({@code "while"} when null)
     * @param sDVariableArr        initial loop variables
     * @param sameDiffSingleLambda defines the loop condition
     * @param sameDiffLambda       defines the loop body
     * @return the variables returned by the full overload
     * @throws NullPointerException if any of the last three arguments is null
     */
    public SDVariable[] whileLoop(String str, @NonNull SDVariable[] sDVariableArr, @NonNull SameDiffSingleLambda sameDiffSingleLambda, @NonNull SameDiffLambda sameDiffLambda) {
        if (sDVariableArr == null) {
            throw new NullPointerException("loopVars is marked non-null but is null");
        }
        if (sameDiffSingleLambda == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (sameDiffLambda == null) {
            throw new NullPointerException("body is marked non-null but is null");
        }
        // NOTE(review): presumably names for the loop outputs; their use is outside this view.
        final String[] outputNames = null;
        return whileLoop(outputNames, str, sDVariableArr, sameDiffSingleLambda, sameDiffLambda);
    }

    /**
     * Builds a while loop in this graph out of Enter / Merge / Switch / NextIteration / Exit ops.
     *
     * <p>All ops are created inside a new name scope whose name comes from {@code newBlockName}
     * applied to {@code str} (or to {@code "while"} when {@code str} is null). The same name is
     * passed to the Enter ops as their frame name. The condition is defined inside a nested
     * {@code "cond"} scope and the body inside a nested {@code "body"} scope.
     *
     * <p>While the body is being defined, an argument interceptor is installed. Any argument whose
     * name was already present in {@code variableMap()} just before the body was defined, and
     * which is not one of the body's own loop inputs, is routed through an Enter op into the loop
     * frame. Repeated uses of the same variable reuse one cached Enter output.
     *
     * <p>The name scopes are closed and the interceptor is removed in {@code finally} blocks. A
     * condition or body lambda that throws therefore does not leave this instance with a dangling
     * name scope or interceptor.
     *
     * @param strArr               names for the loop outputs, forwarded to
     *                             {@code updateVariableNamesAndReferences}; may be null
     * @param str                  base name of the loop; {@code "while"} is used when null
     * @param sDVariableArr        initial values of the loop variables
     * @param sameDiffSingleLambda loop condition; must return a BOOL variable
     * @param sameDiffLambda       loop body; output {@code i} becomes the next value of loop
     *                             variable {@code i}. It must return at least one output per loop
     *                             variable; any extra outputs are ignored
     * @return the Exit outputs (one per loop variable) after passing them through
     *         {@code updateVariableNamesAndReferences}
     * @throws NullPointerException  if {@code sDVariableArr}, the condition or the body is null
     * @throws IllegalStateException if the condition returns null or a non-BOOL variable, or if
     *                               the body returns null or fewer outputs than loop variables
     */
    public SDVariable[] whileLoop(String[] strArr, String str, @NonNull SDVariable[] sDVariableArr, @NonNull SameDiffSingleLambda sameDiffSingleLambda, @NonNull SameDiffLambda sameDiffLambda) {
        if (sDVariableArr == null) {
            throw new NullPointerException("loopVars is marked non-null but is null");
        }
        if (sameDiffSingleLambda == null) {
            throw new NullPointerException("cond is marked non-null but is null");
        }
        if (sameDiffLambda == null) {
            throw new NullPointerException("body is marked non-null but is null");
        }
        final int numLoopVars = sDVariableArr.length;
        final String frameName = newBlockName(str == null ? "while" : str);
        // Declared outside the try so the outputs are renamed only after the loop scope is closed.
        final SDVariable[] exits = new SDVariable[numLoopVars];

        NameScope loopScope = withNameScope(frameName);
        try {
            // Bring each initial loop variable into the loop frame.
            SDVariable[] entered = new SDVariable[numLoopVars];
            for (int i = 0; i < numLoopVars; i++) {
                entered[i] = new Enter(this, frameName, sDVariableArr[i]).outputVariable();
            }

            // The second Merge input should be the NextIteration output. That output cannot exist
            // yet because it depends on the body, which depends on these merges. The Enter output
            // is used as a placeholder and replaced once the body has been defined.
            SDVariable[] merged = new SDVariable[numLoopVars];
            Merge[] mergeOps = new Merge[numLoopVars];
            for (int i = 0; i < numLoopVars; i++) {
                mergeOps[i] = new Merge(this, entered[i], entered[i]);
                merged[i] = mergeOps[i].outputVariable();
            }

            SDVariable condition;
            NameScope condScope = withNameScope("cond");
            try {
                condition = sameDiffSingleLambda.define(this, merged);
            } finally {
                condScope.close();
            }
            if (condition == null) {
                throw new IllegalStateException("The condition of a While loop must not be null.");
            }
            if (condition.dataType() != DataType.BOOL) {
                throw new IllegalStateException("Can not use " + condition.name() + " as the condition of an While loop, the condition must be a boolean.");
            }

            // Switch output [1] feeds the body; output [0] leaves the loop through an Exit op.
            final Set<String> bodyInputNames = new HashSet<>();
            SDVariable[] bodyInputs = new SDVariable[numLoopVars];
            for (int i = 0; i < numLoopVars; i++) {
                SDVariable[] switched = switchOp(merged[i], condition);
                bodyInputs[i] = switched[1];
                bodyInputNames.add(switched[1].name());
                exits[i] = new Exit(this, switched[0]).outputVariable();
            }

            final Set<String> declaredBeforeBody = new HashSet<>(variableMap().keySet());
            final Map<String, SDVariable> enteredInvariants = new HashMap<>();
            addArgumentInterceptor(new ArgumentInterceptor() {
                @Override
                public SDVariable intercept(SDVariable argument) {
                    String argName = argument.name();
                    if (!declaredBeforeBody.contains(argName) || bodyInputNames.contains(argName)) {
                        return argument;
                    }
                    if (enteredInvariants.containsKey(argName)) {
                        return enteredInvariants.get(argName);
                    }
                    // Use an outer-class helper here. In this anonymous class a bare `this` is the
                    // interceptor, not the SameDiff graph the Enter op must belong to.
                    SDVariable enteredVar = enterLoopInvariant(frameName, argument);
                    enteredInvariants.put(argName, enteredVar);
                    return enteredVar;
                }
            });

            SDVariable[] bodyOutputs;
            try {
                NameScope bodyScope = withNameScope("body");
                try {
                    bodyOutputs = sameDiffLambda.define(this, bodyInputs);
                } finally {
                    bodyScope.close();
                }
            } finally {
                removeArgumentInterceptor();
            }
            if (bodyOutputs == null || bodyOutputs.length < numLoopVars) {
                throw new IllegalStateException("The body of a While loop must return one output per loop variable: expected "
                        + numLoopVars + ", got " + (bodyOutputs == null ? "null" : String.valueOf(bodyOutputs.length)));
            }

            // Close the loop: feed each body output back into its Merge via NextIteration.
            for (int i = 0; i < numLoopVars; i++) {
                mergeOps[i].replaceArg(1, new NextIteration(this, bodyOutputs[i]).outputVariable());
            }
        } finally {
            loopScope.close();
        }
        return updateVariableNamesAndReferences(exits, strArr);
    }

    /**
     * Creates an Enter op that brings {@code variable} into the loop frame {@code frameName}, with
     * the Enter constructor's trailing flag set to {@code true}, and returns the op's output. The
     * flag presumably marks the input as constant/loop-invariant; confirm against {@code Enter}.
     */
    private SDVariable enterLoopInvariant(String frameName, SDVariable variable) {
        return new Enter(this, frameName, variable, true).outputVariable();
    }

    /**
     * Returns the map held in the {@code variables} field.
     *
     * <p>No defensive copy is made: the caller receives the same map instance this object holds.
     *
     * @return the {@code variables} map
     */
    public Map<String, Variable> getVariables() {
        return variables;
    }

    /**
     * Returns the map held in the {@code ops} field.
     *
     * <p>No defensive copy is made: the caller receives the same map instance this object holds.
     *
     * @return the {@code ops} map
     */
    public Map<String, SameDiffOp> getOps() {
        return ops;
    }

    /**
     * Returns the map of inference sessions held in the {@code sessions} field.
     *
     * <p>The map is keyed by {@code Long} (presumably a thread id; confirm against where sessions
     * are created). No defensive copy is made.
     *
     * @return the {@code sessions} map
     */
    public Map<Long, InferenceSession> getSessions() {
        return sessions;
    }

    /**
     * Returns the value of the {@code trainingConfig} field.
     *
     * @return the training configuration; nullability is not enforced here
     */
    public TrainingConfig getTrainingConfig() {
        return trainingConfig;
    }

    /**
     * Returns the value of the {@code initializedTraining} flag.
     *
     * @return {@code true} if the {@code initializedTraining} field is set
     */
    public boolean isInitializedTraining() {
        return initializedTraining;
    }

    /**
     * Returns the map of gradient updaters held in the {@code updaterMap} field.
     *
     * <p>No defensive copy is made: the caller receives the same map instance this object holds.
     *
     * @return the {@code updaterMap} map
     */
    public Map<String, GradientUpdater> getUpdaterMap() {
        return updaterMap;
    }

    /**
     * Returns the value of the {@code debugMode} flag.
     *
     * @return {@code true} if the {@code debugMode} field is set
     */
    public boolean isDebugMode() {
        return debugMode;
    }

    /**
     * Returns the stack of argument interceptors held in the {@code argumentInterceptors} field.
     *
     * <p>No defensive copy is made: the caller receives the same stack instance this object holds.
     *
     * @return the {@code argumentInterceptors} stack
     */
    public Stack<ArgumentInterceptor> getArgumentInterceptors() {
        return argumentInterceptors;
    }

    /**
     * Returns the set held in the {@code pausedArgumentInterceptors} field.
     *
     * <p>No defensive copy is made: the caller receives the same set instance this object holds.
     *
     * @return the {@code pausedArgumentInterceptors} set
     */
    public Set<ArgumentInterceptor> getPausedArgumentInterceptors() {
        return pausedArgumentInterceptors;
    }

    /**
     * Returns the value of the {@code logExecution} flag.
     *
     * @return {@code true} if the {@code logExecution} field is set
     */
    public boolean isLogExecution() {
        return logExecution;
    }

    /**
     * Sets the {@code logExecution} flag.
     *
     * @param z the new value of the {@code logExecution} field
     */
    public void setLogExecution(boolean z) {
        logExecution = z;
    }

    /**
     * Returns the SameDiff instance held in the {@code parent} field.
     *
     * @return the parent instance; nullability is not enforced here
     */
    public SameDiff getParent() {
        return parent;
    }

    /**
     * Returns the SameDiff instance held in the {@code child} field.
     *
     * @return the child instance; nullability is not enforced here
     */
    public SameDiff getChild() {
        return child;
    }
}
