package ai.djl;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SymbolBlock;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/BaseModel.class */
/**
 * Base implementation of {@link Model} holding the shared model state: block, manager, data type,
 * input description, cached artifacts and string properties. It also implements the DJL
 * {@code .params} file format written by {@link #save(Path, String)} and read by
 * {@link #readParameters(Path, Map)}.
 *
 * <p>This class never assigns {@link #manager}; subclasses presumably must set it (and usually
 * {@link #block} and {@link #dataType}) before most methods are usable.
 */
public abstract class BaseModel implements Model {

    private static final Logger logger = LoggerFactory.getLogger(BaseModel.class);

    /** Format version written right after the {@code "DJL@"} magic in parameter files. */
    private static final int MODEL_VERSION = 1;

    protected Path modelDir;
    protected Block block;
    protected String modelName;
    protected NDManager manager;
    protected DataType dataType;
    protected PairList<String, Shape> inputData;
    protected Map<String, Object> artifacts = new ConcurrentHashMap<>();
    protected Map<String, String> properties = new ConcurrentHashMap<>();

    /**
     * Creates a model with the given name.
     *
     * @param modelName the model name; also the default file-name prefix used by {@link #save}
     */
    protected BaseModel(String modelName) {
        this.modelName = modelName;
    }

    /** {@inheritDoc} */
    @Override
    public Block getBlock() {
        return block;
    }

    /** {@inheritDoc} */
    @Override
    public void setBlock(Block block) {
        this.block = block;
    }

    /** {@inheritDoc} */
    @Override
    public String getName() {
        return modelName;
    }

    /** {@inheritDoc} */
    @Override
    public NDManager getNDManager() {
        return manager;
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** Creates a {@link Predictor} bound to this model, the given translator and device. */
    @Override
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator, Device device) {
        return new Predictor<>(this, translator, device, false);
    }

    /** {@inheritDoc} */
    @Override
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }

    /** {@inheritDoc} */
    @Override
    public DataType getDataType() {
        return dataType;
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void load(InputStream is, Map<String, ?> options)
            throws IOException, MalformedModelException {
        throw new UnsupportedOperationException("Not supported!");
    }

    /** Closes the model's {@link NDManager}; does nothing if no manager was ever assigned. */
    @Override
    public void close() {
        if (manager != null) {
            manager.close();
        }
    }

    /**
     * Returns the input description. It is populated lazily from {@link Block#describeInput()}
     * on first use, unless {@link #save} or {@link #readParameters} already set it.
     */
    @Override
    public PairList<String, Shape> describeInput() {
        if (inputData == null) {
            inputData = block.describeInput();
        }
        return inputData;
    }

    /**
     * Returns the output description.
     *
     * <p>A {@link SymbolBlock} describes its own outputs. For any other block, this runs one
     * inference-mode forward pass on all-ones inputs shaped like {@link #describeInput()}. The
     * resulting shapes are named {@code output0}, {@code output1}, ...
     *
     * <p>NOTE(review): the dummy input arrays are allocated from {@link #manager} and are not
     * freed here; presumably they live until the manager is closed.
     */
    @Override
    public PairList<String, Shape> describeOutput() {
        if (block instanceof SymbolBlock) {
            return ((SymbolBlock) block).describeOutput();
        }
        NDList dummyInput = new NDList();
        for (Pair<String, Shape> input : describeInput()) {
            dummyInput.add(manager.ones(input.getValue()));
        }
        NDList output = block.forward(new ParameterStore(manager, true), dummyInput, false);
        Shape[] outputShapes = output.stream().map(array -> array.getShape()).toArray(Shape[]::new);
        List<String> outputNames = new ArrayList<>(outputShapes.length);
        for (int i = 0; i < outputShapes.length; i++) {
            outputNames.add("output" + i);
        }
        return new PairList<>(outputNames, Arrays.asList(outputShapes));
    }

    /**
     * Not supported by this base class.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public String[] getArtifactNames() {
        throw new UnsupportedOperationException("Not supported!");
    }

    /**
     * Loads an artifact with {@code function} on first request and caches the result in
     * {@link #artifacts} under {@code name}. Later calls return the cached object.
     *
     * <p>The artifact stream is always closed, even if {@code function} throws.
     *
     * @param name the artifact name, resolved via {@link #getArtifactAsStream(String)}
     * @param function converts the opened stream into the artifact object
     * @return the cached or freshly loaded artifact
     * @throws IOException if the artifact cannot be opened or closed
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T> T getArtifact(String name, Function<InputStream, T> function) throws IOException {
        try {
            Object artifact =
                    artifacts.computeIfAbsent(
                            name,
                            key -> {
                                try (InputStream is = getArtifactAsStream(key)) {
                                    return function.apply(is);
                                } catch (IOException e) {
                                    // computeIfAbsent cannot throw checked exceptions.
                                    throw new UncheckedIOException(e);
                                }
                            });
            return (T) artifact;
        } catch (RuntimeException e) {
            // Restore the checked IOException that was tunnelled through the mapping function.
            Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                throw (IOException) cause;
            }
            throw e;
        }
    }

    /**
     * Returns the URL of a file named {@code name} inside {@link #modelDir}.
     *
     * <p>NOTE(review): {@code name} is resolved without normalization, so {@code ..} segments can
     * point outside the model directory.
     *
     * @param name the artifact file name, relative to the model directory
     * @return a {@code file:} URL for the artifact
     * @throws IllegalArgumentException if {@code name} is {@code null}
     * @throws FileNotFoundException if the file does not exist or is not readable
     * @throws IOException if the path cannot be converted to a URL
     */
    @Override
    public URL getArtifact(String name) throws IOException {
        if (name == null) {
            throw new IllegalArgumentException("artifactName cannot be null");
        }
        Path file = modelDir.resolve(name);
        if (Files.exists(file) && Files.isReadable(file)) {
            return file.toUri().toURL();
        }
        throw new FileNotFoundException("File not found: " + file);
    }

    /**
     * Opens a buffered stream over {@link #getArtifact(String)}. The caller must close it.
     *
     * @param name the artifact file name, relative to the model directory
     * @return a new buffered input stream
     * @throws IOException if the artifact cannot be found or opened
     */
    @Override
    public InputStream getArtifactAsStream(String name) throws IOException {
        return new BufferedInputStream(getArtifact(name).openStream());
    }

    /**
     * Stores a string property; properties are persisted by {@link #save}. The backing
     * {@link ConcurrentHashMap} rejects {@code null} keys and values.
     */
    @Override
    public void setProperty(String key, String value) {
        properties.put(key, value);
    }

    /** Returns the property value for {@code key}, or {@code null} if it is not set. */
    @Override
    public String getProperty(String key) {
        return properties.get(key);
    }

    /**
     * Sets {@link #modelDir} to the directory chosen by {@link Utils#getNestedModelDir(Path)}.
     *
     * @param modelPath the model location as given by the caller
     */
    protected void setModelDir(Path modelPath) {
        this.modelDir = Utils.getNestedModelDir(modelPath);
    }

    /**
     * Saves the model parameters to {@code <modelPath>/<name>-<epoch>.params}, with the epoch
     * zero-padded to four digits.
     *
     * <p>The epoch is taken from the {@code "Epoch"} property when set. Otherwise it is one past
     * the value returned by {@link Utils#getCurrentEpoch(Path, String)}.
     *
     * <p>File layout, in order:
     * <ul>
     *   <li>the ASCII magic {@code "DJL@"} and the int format version;
     *   <li>the model name (UTF) and the data-type name (UTF);
     *   <li>the input count, then each input's name ({@code ""} for a {@code null} name) and its
     *       encoded shape;
     *   <li>the property count, then each key/value pair (UTF);
     *   <li>finally, whatever {@link Block#saveParameters} writes.
     * </ul>
     *
     * <p>On success, {@link #getModelPath()} returns the absolute {@code modelPath}.
     *
     * @param modelPath directory to write into; created if missing
     * @param newModelName file-name prefix; {@code null} or empty means the current model name
     * @throws IllegalStateException if the block is missing or not initialized
     * @throws NumberFormatException if the {@code "Epoch"} property is not an integer
     * @throws IOException if the directory or file cannot be written
     */
    @Override
    public void save(Path modelPath, String newModelName) throws IOException {
        // Validate first so a failed save does not leave an empty directory behind.
        if (block == null || !block.isInitialized()) {
            throw new IllegalStateException("Model has not be trained or loaded yet.");
        }
        String name = newModelName == null || newModelName.isEmpty() ? modelName : newModelName;
        if (Files.notExists(modelPath)) {
            Files.createDirectories(modelPath);
        }

        String epochValue = getProperty("Epoch");
        // The next epoch is the latest one plus 1; this is unrelated to MODEL_VERSION.
        int epoch =
                epochValue == null
                        ? Utils.getCurrentEpoch(modelPath, name) + 1
                        : Integer.parseInt(epochValue);
        String fileName = String.format(Locale.ROOT, "%s-%04d.params", name, epoch);
        Path paramFile = modelPath.resolve(fileName);

        try (DataOutputStream dos =
                new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(paramFile)))) {
            dos.writeBytes("DJL@");
            dos.writeInt(MODEL_VERSION);
            dos.writeUTF(name);
            dos.writeUTF(dataType.name());
            inputData = block.describeInput();
            dos.writeInt(inputData.size());
            for (Pair<String, Shape> input : inputData) {
                String inputName = input.getKey();
                // writeUTF rejects null, so unnamed inputs are stored as "".
                dos.writeUTF(inputName == null ? "" : inputName);
                dos.write(input.getValue().getEncoded());
            }
            dos.writeInt(properties.size());
            for (Map.Entry<String, String> entry : properties.entrySet()) {
                dos.writeUTF(entry.getKey());
                dos.writeUTF(entry.getValue());
            }
            block.saveParameters(dos);
        }
        modelDir = modelPath.toAbsolutePath();
    }

    /** {@inheritDoc} */
    @Override
    public Path getModelPath() {
        return modelDir;
    }

    /** Returns a multi-line summary: name, location (if known), data type and all properties. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Model (\n\tName: ").append(modelName);
        if (modelDir != null) {
            sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
        }
        sb.append("\n\tData Type: ").append(dataType);
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
        }
        sb.append("\n)");
        return sb.toString();
    }

    /** Safety net: warns and closes the manager if {@link #close()} was never called. */
    @Override
    protected void finalize() throws Throwable {
        try {
            if (manager != null && manager.isOpen()) {
                logger.warn("Model: {} was not closed explicitly.", modelName);
                manager.close();
            }
        } finally {
            super.finalize();
        }
    }

    /**
     * Resolves the parameter file {@code <modelDir>/<prefix>-<epoch>.params}, with the epoch
     * zero-padded to four digits.
     *
     * @param prefix the parameter file-name prefix
     * @param options optional load options; an {@code "epoch"} entry selects a specific epoch
     * @return the parameter file path, or {@code null} when no epoch option is given and
     *     {@link Utils#getCurrentEpoch(Path, String)} returns {@code -1}
     * @throws NumberFormatException if the {@code "epoch"} option is not an integer
     */
    protected Path paramPathResolver(String prefix, Map<String, ?> options) throws IOException {
        Object epochOption = options == null ? null : options.get("epoch");
        int epoch;
        if (epochOption == null) {
            epoch = Utils.getCurrentEpoch(modelDir, prefix);
            if (epoch == -1) {
                return null;
            }
        } else {
            epoch = Integer.parseInt(epochOption.toString());
        }
        return modelDir.resolve(String.format(Locale.ROOT, "%s-%04d.params", prefix, epoch));
    }

    /**
     * Reads a parameter file in the format written by {@link #save}.
     *
     * <p>This populates {@link #dataType} and {@link #inputData}, merges the stored properties
     * into {@link #properties}, and loads the block's parameters with {@link #manager}. Input
     * names saved as {@code ""} are read back as {@code ""}, not {@code null}.
     *
     * @param paramFile the {@code .params} file to read
     * @param options load options (unused here)
     * @return {@code false} if the file does not start with the {@code "DJL@"} magic;
     *     {@code true} once parameters are loaded
     * @throws IOException if the format version is unsupported or reading fails
     * @throws MalformedModelException if the block rejects the stored parameters
     */
    protected boolean readParameters(Path paramFile, Map<String, ?> options)
            throws IOException, MalformedModelException {
        logger.debug("Try to load model from {}", paramFile);
        try (DataInputStream dis =
                new DataInputStream(new BufferedInputStream(Files.newInputStream(paramFile)))) {
            byte[] magic = new byte[4];
            dis.readFully(magic);
            if (!"DJL@".equals(new String(magic, StandardCharsets.US_ASCII))) {
                return false;
            }
            int version = dis.readInt();
            if (version != MODEL_VERSION) {
                throw new IOException("Unsupported model version: " + version);
            }
            String savedName = dis.readUTF();
            logger.debug("Loading saved model: {} parameter", savedName);
            dataType = DataType.valueOf(dis.readUTF());

            int inputCount = dis.readInt();
            inputData = new PairList<>();
            for (int i = 0; i < inputCount; i++) {
                inputData.add(dis.readUTF(), Shape.decode(dis));
            }

            int propertyCount = dis.readInt();
            for (int i = 0; i < propertyCount; i++) {
                properties.put(dis.readUTF(), dis.readUTF());
            }

            block.loadParameters(manager, dis);
            logger.debug("DJL model loaded successfully");
            return true;
        }
    }
}
