package org.nd4j.autodiff.samediff;

import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;
import lombok.NonNull;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.ShapeOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.OpProfiler;

/* loaded from: input_file:org/nd4j/autodiff/samediff/SameDiffOpExecutioner.class */
/**
 * An {@link OpExecutioner} that records executed ops into a {@link SameDiff} graph.
 *
 * <p>Most {@code exec}/{@code execAndReturn} calls are routed to {@link #processOp(Op)}. That method
 * registers the op's x/y/z arrays as SameDiff variables (keyed by array identity) and invokes the op
 * on the graph. The op is NOT run on the backend. Shape ops, custom ops, graph management and
 * environment/configuration calls are delegated to the backend executioner that was current at
 * construction time ({@code Nd4j.getExecutioner()}).
 *
 * <p>The instance also registers itself as an {@link OpProfiler} listener, so ops reported through
 * {@link #invoke(Op)} are recorded as well.
 *
 * <p>NOTE(review): no synchronization is performed on the variable map or the SameDiff instance;
 * treat this class as not thread-safe unless confirmed otherwise.
 */
public class SameDiffOpExecutioner implements OpExecutioner, OpProfiler.OpProfilerListener {
    /** Holds the most recently processed op; empty until the first op is seen. */
    private final AtomicReference<Op> opAtomicReference = new AtomicReference<>();
    /** Executioner that was active when this instance was created; used for delegated calls. */
    private final OpExecutioner backendExecutioner = Nd4j.getExecutioner();
    /** Maps arrays (by identity, not equals) to the SameDiff variables that represent them. */
    private final Map<INDArray, SDVariable> variables = new IdentityHashMap<>();
    /** Graph that recorded ops are invoked on. */
    private final SameDiff sameDiff = SameDiff.create();

    /** Creates the executioner and registers it as an {@link OpProfiler} listener. */
    public SameDiffOpExecutioner() {
        OpProfiler.getInstance().addListener(this);
    }

    /**
     * Records {@code op} into the SameDiff graph.
     *
     * <p>Any of the op's x/y/z arrays not yet known are registered as variables with random UUID
     * names. The op is then invoked on the graph with x (and y, when both x and y are present). The
     * resulting variable is remembered as the representation of z.
     *
     * @param op the op to record
     * @return the same {@code op} instance
     */
    private Op processOp(Op op) {
        // Track the latest op on every call; previously only the first op was ever stored,
        // so getLastOp() reported a stale value.
        opAtomicReference.set(op);

        INDArray x = op.x();
        INDArray y = op.y();
        INDArray z = op.z();

        for (INDArray array : new INDArray[] {x, y, z}) {
            if (array != null && !variables.containsKey(array)) {
                variables.put(array, sameDiff.var(UUID.randomUUID().toString(), array));
            }
        }

        SDVariable result;
        if (x == null || y == null) {
            result = sameDiff.invoke(op, variables.get(x));
        } else {
            result = sameDiff.invoke(op, variables.get(x), variables.get(y));
        }

        // IdentityHashMap accepts null keys; storing under null would make later ops with a null
        // input resolve to this unrelated result, so only map real output arrays.
        if (z != null) {
            variables.put(z, result);
        }
        return op;
    }

    /**
     * Returns the name of the most recently processed op.
     *
     * @return the op name, or {@code null} if no op has been processed yet
     */
    @Override
    public String getLastOp() {
        Op last = opAtomicReference.get();
        return last == null ? null : last.opName();
    }

    /** Records {@code op} into the graph and returns it. */
    @Override
    public Op exec(Op op) {
        return processOp(op);
    }

    /** Not supported by this executioner. */
    @Override
    public void iterateOverAllRows(Op op) {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this executioner. */
    @Override
    public void iterateOverAllColumns(Op op) {
        throw new UnsupportedOperationException();
    }

    /** Records the transform op and returns its z array. */
    @Override
    public INDArray execAndReturn(TransformOp transformOp) {
        return processOp(transformOp).z();
    }

    /**
     * Records the accumulation and returns the op itself.
     *
     * <p>Previously this cast the op's z array to {@code Accumulation}, which always failed with
     * a {@code ClassCastException}.
     */
    @Override
    public Accumulation execAndReturn(Accumulation accumulation) {
        return (Accumulation) processOp(accumulation);
    }

    /** Records the variance op and returns the op itself; the boolean flag is ignored. */
    @Override
    public Accumulation execAndReturn(Variance variance, boolean z) {
        return (Accumulation) processOp(variance);
    }

    /** Records the index accumulation and returns the op itself. */
    @Override
    public IndexAccumulation execAndReturn(IndexAccumulation indexAccumulation) {
        return (IndexAccumulation) processOp(indexAccumulation);
    }

    /** Records the scalar op and returns its z array. */
    @Override
    public INDArray execAndReturn(ScalarOp scalarOp) {
        return processOp(scalarOp).z();
    }

    /** Records the broadcast op and returns its z array. */
    @Override
    public INDArray execAndReturn(BroadcastOp broadcastOp) {
        return processOp(broadcastOp).z();
    }

    /** Delegated to the backend executioner; shape ops are not recorded. */
    @Override
    public INDArray execAndReturn(ShapeOp shapeOp) {
        return backendExecutioner.execAndReturn(shapeOp);
    }

    /** Records {@code op}; the dimension arguments are ignored. */
    @Override
    public Op exec(Op op, int... dimensions) {
        return processOp(op);
    }

    /** Records the accumulation and returns its z array; dimensions are ignored. */
    @Override
    public INDArray exec(Accumulation accumulation, int... dimensions) {
        return processOp(accumulation).z();
    }

    /** Records the broadcast op and returns its z array; dimensions are ignored. */
    @Override
    public INDArray exec(BroadcastOp broadcastOp, int... dimensions) {
        return processOp(broadcastOp).z();
    }

    /** Records the variance op and returns its z array; flag and dimensions are ignored. */
    @Override
    public INDArray exec(Variance variance, boolean z, int... dimensions) {
        return processOp(variance).z();
    }

    /** Records the index accumulation and returns its z array; dimensions are ignored. */
    @Override
    public INDArray exec(IndexAccumulation indexAccumulation, int... dimensions) {
        return processOp(indexAccumulation).z();
    }

    /** Records {@code op} and returns its z array. */
    @Override
    public INDArray execAndReturn(Op op) {
        return processOp(op).z();
    }

    @Override
    public OpExecutioner.ExecutionMode executionMode() {
        return backendExecutioner.executionMode();
    }

    @Override
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        backendExecutioner.setExecutionMode(executionMode);
    }

    // NOTE(review): the following overrides are intentionally empty in the original. The ops are
    // neither recorded nor executed. Consider delegating to backendExecutioner or throwing
    // UnsupportedOperationException.
    @Override
    public void exec(MetaOp metaOp) {
    }

    @Override
    public void exec(GridOp gridOp) {
    }

    @Override
    public void exec(Aggregate aggregate) {
    }

    /** Delegated to the backend executioner; shape ops are not recorded. */
    @Override
    public void exec(ShapeOp shapeOp) {
        backendExecutioner.exec(shapeOp);
    }

    @Override
    public <T extends Aggregate> void exec(Batch<T> batch) {
    }

    @Override
    public void exec(List<Aggregate> list) {
    }

    /** Records the random op and returns its z array. */
    @Override
    public INDArray exec(RandomOp randomOp) {
        return processOp(randomOp).z();
    }

    /** Records the random op and returns its z array; the supplied RNG is ignored. */
    @Override
    public INDArray exec(RandomOp randomOp, Random random) {
        return processOp(randomOp).z();
    }

    @Override
    public Properties getEnvironmentInformation() {
        return backendExecutioner.getEnvironmentInformation();
    }

    @Override
    public void setProfilingMode(OpExecutioner.ProfilingMode profilingMode) {
        backendExecutioner.setProfilingMode(profilingMode);
    }

    @Override
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return backendExecutioner.getProfilingMode();
    }

    @Override
    public TADManager getTADManager() {
        return backendExecutioner.getTADManager();
    }

    @Override
    public void printEnvironmentInformation() {
        backendExecutioner.printEnvironmentInformation();
    }

    @Override
    public void push() {
        backendExecutioner.push();
    }

    @Override
    public void commit() {
        backendExecutioner.commit();
    }

    @Override
    public INDArray thresholdEncode(INDArray input, double threshold) {
        return backendExecutioner.thresholdEncode(input, threshold);
    }

    @Override
    public INDArray thresholdEncode(INDArray input, double threshold, Integer boundary) {
        return backendExecutioner.thresholdEncode(input, threshold, boundary);
    }

    @Override
    public INDArray thresholdDecode(INDArray encoded, INDArray target) {
        return backendExecutioner.thresholdDecode(encoded, target);
    }

    @Override
    public long bitmapEncode(INDArray indArray, INDArray target, double threshold) {
        return backendExecutioner.bitmapEncode(indArray, target, threshold);
    }

    @Override
    public INDArray bitmapEncode(INDArray indArray, double threshold) {
        return backendExecutioner.bitmapEncode(indArray, threshold);
    }

    @Override
    public INDArray bitmapDecode(INDArray encoded, INDArray target) {
        return backendExecutioner.bitmapDecode(encoded, target);
    }

    /** Profiler callback: records every op reported by {@link OpProfiler}. */
    @Override
    public void invoke(Op op) {
        processOp(op);
    }

    @Override
    public Map<String, CustomOpDescriptor> getCustomOperations() {
        return backendExecutioner.getCustomOperations();
    }

    /** Delegated to the backend executioner; custom ops are not recorded. */
    @Override
    public void exec(CustomOp customOp) {
        backendExecutioner.exec(customOp);
    }

    @Override
    public List<long[]> calculateOutputShape(CustomOp customOp) {
        return backendExecutioner.calculateOutputShape(customOp);
    }

    @Override
    public INDArray[] allocateOutputArrays(CustomOp customOp) {
        return backendExecutioner.allocateOutputArrays(customOp);
    }

    @Override
    public void registerGraph(long id, Pointer graph) {
        backendExecutioner.registerGraph(id, graph);
    }

    /**
     * Delegates graph execution to the backend executioner.
     *
     * @throws NullPointerException if {@code map} or {@code reverseMap} is null
     */
    @Override
    public Map<String, INDArray> executeGraph(long id, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> reverseMap) {
        if (map == null) {
            throw new NullPointerException("map is marked @NonNull but is null");
        }
        if (reverseMap == null) {
            throw new NullPointerException("reverseMap is marked @NonNull but is null");
        }
        return backendExecutioner.executeGraph(id, map, reverseMap);
    }

    @Override
    public void forgetGraph(long id) {
        backendExecutioner.forgetGraph(id);
    }

    @Override
    public void enableDebugMode(boolean reallyEnable) {
        backendExecutioner.enableDebugMode(reallyEnable);
    }

    @Override
    public void enableVerboseMode(boolean reallyEnable) {
        backendExecutioner.enableVerboseMode(reallyEnable);
    }

    @Override
    public void setElementsThreshold(int threshold) {
        backendExecutioner.setElementsThreshold(threshold);
    }

    @Override
    public void setTadThreshold(int threshold) {
        backendExecutioner.setTadThreshold(threshold);
    }

    @Override
    public OpExecutioner.ExecutionerType type() {
        return backendExecutioner.type();
    }

    @Override
    public boolean isVerbose() {
        return backendExecutioner.isVerbose();
    }

    @Override
    public boolean isDebug() {
        return backendExecutioner.isDebug();
    }

    /** Returns the live (mutable) identity map from arrays to their SameDiff variables. */
    public Map<INDArray, SDVariable> getVariables() {
        return variables;
    }

    /** Returns the SameDiff graph ops are recorded into. */
    public SameDiff getSameDiff() {
        return sameDiff;
    }

    /** Returns the holder of the most recently processed op; its value is null before any op. */
    public AtomicReference<Op> getOpAtomicReference() {
        return opAtomicReference;
    }

    /** Returns the executioner that delegated calls are forwarded to. */
    public OpExecutioner getBackendExecutioner() {
        return backendExecutioner;
    }
}
