package ai.djl.engine.rust;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.nn.AbstractSymbolBlock;
import ai.djl.nn.ParameterList;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicReference;

/* loaded from: input_file:ai/djl/engine/rust/RsSymbolBlock.class */
/**
 * A symbol block whose forward pass is executed by a model loaded in the native Rust engine.
 *
 * <p>The block holds a native model handle and registers itself with the {@link RsNDManager} it was
 * created from, keyed by the handle's string value. {@link #close()} releases the handle reference
 * and detaches the block from that manager.
 */
public class RsSymbolBlock extends AbstractSymbolBlock implements AutoCloseable {

    /** Native model handle; holds {@code null} once {@link #close()} has run. */
    private final AtomicReference<Long> handle;

    /** Key under which this block is registered with {@link #manager}. */
    private final String uid;

    /** Manager this block is attached to; cleared by {@link #close()}. */
    private RsNDManager manager;

    /**
     * Wraps a native Rust model handle.
     *
     * @param manager the manager this block is attached to; forward outputs are created on it
     *     before being re-attached to the caller's manager
     * @param modelHandle the native model handle; its input names are queried immediately
     */
    public RsSymbolBlock(RsNDManager manager, long modelHandle) {
        this.handle = new AtomicReference<>(modelHandle);
        this.manager = manager;
        this.inputNames = Arrays.asList(RustLibrary.getInputNames(modelHandle));
        this.uid = String.valueOf(modelHandle);
        manager.attachInternal(uid, new AutoCloseable[] {this});
    }

    /**
     * Runs inference on the native model.
     *
     * <p>Each input is converted through a temporary sub-manager, which is always closed before
     * returning. The single output array is attached to the manager of the first input.
     *
     * @param parameterStore unused by this implementation
     * @param inputs the model inputs; their count must equal the number of model input names
     * @param training unused by this implementation
     * @param params unused by this implementation
     * @return a list containing the single output array
     * @throws IllegalArgumentException if the number of inputs does not match the model's inputs
     * @throws IllegalStateException if this block has already been closed
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        if (inputNames.size() != inputs.size()) {
            throw new IllegalArgumentException("Input size mismatch, requires: " + inputNames);
        }
        // Resolve the handle before touching the manager. close() nulls both, so checking the
        // handle first turns use-after-close into a clear IllegalStateException instead of an NPE.
        long modelHandle = getHandle();
        RsNDManager owner = manager;
        try (RsNDManager scratch = owner.newSubManager()) {
            long[] inputHandles = new long[inputs.size()];
            for (int i = 0; i < inputs.size(); i++) {
                NDArray converted = scratch.mo176from(inputs.get(i));
                inputHandles[i] = ((Long) converted.getHandle()).longValue();
            }
            RsNDArray output =
                    new RsNDArray(owner, RustLibrary.runInference(modelHandle, inputHandles));
            // Hand ownership of the result to the caller's manager so it outlives the scratch scope.
            output.attach(inputs.head().getManager());
            return new NDList(new NDArray[] {output});
        }
    }

    /**
     * Releases this block's handle reference and detaches it from its manager.
     *
     * <p>Idempotent: {@code getAndSet(null)} ensures only the first call detaches.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        // NOTE(review): no native free call is made here. Confirm the Rust model memory is
        // released elsewhere, or this leaks the native model.
        if (handle.getAndSet(null) != null) {
            manager.detachInternal(uid);
            manager = null;
        }
    }

    /**
     * Returns the native model handle.
     *
     * @return the handle
     * @throws IllegalStateException if {@link #close()} has already released it
     */
    public Long getHandle() {
        Long current = handle.get();
        if (current == null) {
            throw new IllegalStateException("Rust model handle has been released!");
        }
        return current;
    }

    /**
     * Not supported for Rust-engine models.
     *
     * @throws UnsupportedOperationException always
     */
    public ParameterList getDirectParameters() {
        throw new UnsupportedOperationException("Not yet supported");
    }
}
