package org.deeplearning4j.util;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipOutputStream;
import lombok.NonNull;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/util/ModelSerializer.class */
public class ModelSerializer {
    /** Legacy zip entry holding a Java-serialized updater object; only read, never written. */
    public static final String OLD_UPDATER_BIN = "updater.bin";
    /** Zip entry holding the updater state view array (written via {@code Nd4j.write}). */
    public static final String UPDATER_BIN = "updaterState.bin";

    /** Zip entry holding the network configuration as UTF-8 JSON text. */
    private static final String CONFIGURATION_JSON = "configuration.json";
    /** Zip entry holding the network parameters (written via {@code Nd4j.write}). */
    private static final String COEFFICIENTS_BIN = "coefficients.bin";

    private ModelSerializer() {
    }

    /**
     * Writes {@code model} to {@code file} as a zip archive.
     * See {@link #writeModel(Model, OutputStream, boolean)} for the archive layout.
     *
     * @param model       the network to save
     * @param file        destination file; created or overwritten
     * @param saveUpdater whether to also store the updater state
     * @throws IOException if the file cannot be written
     */
    public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (file == null) {
            throw new NullPointerException("file");
        }
        // close() flushes the buffer, so the archive is fully written when this returns
        // (and the file handle is released even if serialization fails).
        try (OutputStream out = new BufferedOutputStream(new FileOutputStream(file))) {
            writeModel(model, out, saveUpdater);
        }
    }

    /**
     * Writes {@code model} to the file at {@code path} as a zip archive.
     * See {@link #writeModel(Model, OutputStream, boolean)} for the archive layout.
     *
     * @param model       the network to save
     * @param path        destination file path; created or overwritten
     * @param saveUpdater whether to also store the updater state
     * @throws IOException if the file cannot be written
     */
    public static void writeModel(@NonNull Model model, @NonNull String path, boolean saveUpdater) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (path == null) {
            throw new NullPointerException("path");
        }
        writeModel(model, new File(path), saveUpdater);
    }

    /**
     * Writes {@code model} to {@code stream} as a zip archive containing:
     * <ul>
     *   <li>{@code configuration.json}: the configuration JSON (UTF-8). This is empty when the
     *       model is neither a {@link MultiLayerNetwork} nor a {@link ComputationGraph}.</li>
     *   <li>{@code coefficients.bin}: {@code model.params()} serialized with {@code Nd4j.write}.</li>
     *   <li>{@code updaterState.bin}: only when {@code saveUpdater} is true and the updater
     *       state array is non-null and non-empty.</li>
     * </ul>
     * The caller's stream is flushed but NOT closed.
     *
     * @param model       the network to save
     * @param stream      destination stream; left open
     * @param saveUpdater whether to also store the updater state
     * @throws IOException if writing fails
     */
    public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (stream == null) {
            throw new NullPointerException("stream");
        }
        String json = "";
        if (model instanceof MultiLayerNetwork) {
            json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson();
        } else if (model instanceof ComputationGraph) {
            json = ((ComputationGraph) model).getConfiguration().toJson();
        }
        // The close shield lets us close the zip (releasing its Deflater) without closing the caller's stream.
        try (ZipOutputStream zip = new ZipOutputStream(new CloseShieldOutputStream(stream))) {
            // Explicit UTF-8 so archives are portable across platforms with different default charsets.
            writeBytesEntry(zip, CONFIGURATION_JSON, json.getBytes(StandardCharsets.UTF_8));
            writeArrayEntry(zip, COEFFICIENTS_BIN, model.params());
            if (saveUpdater) {
                INDArray updaterState = null;
                if (model instanceof MultiLayerNetwork) {
                    updaterState = ((MultiLayerNetwork) model).getUpdater().getStateViewArray();
                } else if (model instanceof ComputationGraph) {
                    updaterState = ((ComputationGraph) model).getUpdater().getStateViewArray();
                }
                if (updaterState != null && updaterState.length() > 0) {
                    writeArrayEntry(zip, UPDATER_BIN, updaterState);
                }
            }
            // finish() writes the central directory; flush afterwards so the trailer actually
            // reaches the caller's stream (flushing before close() would leave it buffered).
            zip.finish();
            zip.flush();
        }
    }

    /** Adds a zip entry named {@code name} whose content is {@code data}. */
    private static void writeBytesEntry(ZipOutputStream zip, String name, byte[] data) throws IOException {
        zip.putNextEntry(new ZipEntry(name));
        zip.write(data);
        zip.closeEntry();
    }

    /** Serializes {@code array} with {@code Nd4j.write} and stores it as the zip entry {@code name}. */
    private static void writeArrayEntry(ZipOutputStream zip, String name, INDArray array) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            Nd4j.write(array, out);
        }
        writeBytesEntry(zip, name, bytes.toByteArray());
    }

    /** Restores a {@link MultiLayerNetwork} from {@code file}, including the updater if present. */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file");
        }
        return restoreMultiLayerNetwork(file, true);
    }

    /**
     * Restores a {@link MultiLayerNetwork} from a zip archive written by {@code writeModel}.
     *
     * @param file        the archive
     * @param loadUpdater whether to restore updater state ({@code updaterState.bin}, or the
     *                    legacy {@code updater.bin} when the former is absent)
     * @return the initialized network
     * @throws IOException           if the archive cannot be read
     * @throws IllegalStateException if the configuration or coefficients entry is missing
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file, boolean loadUpdater) throws IOException {
        if (file == null) {
            throw new NullPointerException("file");
        }
        String json;
        INDArray params;
        INDArray updaterState = null;
        Updater legacyUpdater = null;
        // Any "preprocessor.bin" entry is intentionally ignored (the original read discarded it).
        try (ZipFile zipFile = new ZipFile(file)) {
            json = readTextEntry(zipFile, CONFIGURATION_JSON);
            params = readArrayEntry(zipFile, COEFFICIENTS_BIN);
            if (loadUpdater) {
                updaterState = readArrayEntry(zipFile, UPDATER_BIN);
                if (updaterState == null) {
                    // The legacy updater is only used when no state array exists, so skip deserializing otherwise.
                    legacyUpdater = (Updater) readLegacyUpdater(zipFile);
                }
            }
        }
        checkModelFound(json, params, updaterState);
        MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json));
        network.init(params, false);
        if (updaterState != null) {
            network.getUpdater().setStateViewArray(network, updaterState, false);
        } else if (legacyUpdater != null) {
            network.setUpdater(legacyUpdater);
        }
        return network;
    }

    /**
     * Restores a {@link MultiLayerNetwork} from {@code is}. The stream is first copied to a
     * temporary file, which is deleted afterwards. The stream is not closed.
     */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) throws IOException {
        if (is == null) {
            throw new NullPointerException("is");
        }
        File tmp = File.createTempFile("restore", "multiLayer");
        try {
            Files.copy(is, tmp.toPath(), StandardCopyOption.REPLACE_EXISTING);
            return restoreMultiLayerNetwork(tmp, loadUpdater);
        } finally {
            deleteTempFile(tmp);
        }
    }

    /** Restores a {@link MultiLayerNetwork} from {@code is}, including the updater if present. */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is) throws IOException {
        if (is == null) {
            throw new NullPointerException("is");
        }
        return restoreMultiLayerNetwork(is, true);
    }

    /** Restores a {@link MultiLayerNetwork} from the file at {@code path}, including the updater if present. */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path) throws IOException {
        if (path == null) {
            throw new NullPointerException("path");
        }
        return restoreMultiLayerNetwork(new File(path), true);
    }

    /** Restores a {@link MultiLayerNetwork} from the file at {@code path}. */
    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path, boolean loadUpdater) throws IOException {
        if (path == null) {
            throw new NullPointerException("path");
        }
        return restoreMultiLayerNetwork(new File(path), loadUpdater);
    }

    /** Restores a {@link ComputationGraph} from the file at {@code path}, including the updater if present. */
    public static ComputationGraph restoreComputationGraph(@NonNull String path) throws IOException {
        if (path == null) {
            throw new NullPointerException("path");
        }
        return restoreComputationGraph(new File(path), true);
    }

    /** Restores a {@link ComputationGraph} from the file at {@code path}. */
    public static ComputationGraph restoreComputationGraph(@NonNull String path, boolean loadUpdater) throws IOException {
        if (path == null) {
            throw new NullPointerException("path");
        }
        return restoreComputationGraph(new File(path), loadUpdater);
    }

    /**
     * Restores a {@link ComputationGraph} from {@code is}. The stream is first copied to a
     * temporary file, which is deleted afterwards. The stream is not closed.
     */
    public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater) throws IOException {
        if (is == null) {
            throw new NullPointerException("is");
        }
        File tmp = File.createTempFile("restore", "compGraph");
        try {
            Files.copy(is, tmp.toPath(), StandardCopyOption.REPLACE_EXISTING);
            return restoreComputationGraph(tmp, loadUpdater);
        } finally {
            deleteTempFile(tmp);
        }
    }

    /** Restores a {@link ComputationGraph} from {@code is}, including the updater if present. */
    public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException {
        if (is == null) {
            throw new NullPointerException("is");
        }
        return restoreComputationGraph(is, true);
    }

    /** Restores a {@link ComputationGraph} from {@code file}, including the updater if present. */
    public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException {
        if (file == null) {
            throw new NullPointerException("file");
        }
        return restoreComputationGraph(file, true);
    }

    /**
     * Restores a {@link ComputationGraph} from a zip archive written by {@code writeModel}.
     *
     * @param file        the archive
     * @param loadUpdater whether to restore updater state ({@code updaterState.bin}, or the
     *                    legacy {@code updater.bin} when the former is absent)
     * @return the initialized graph
     * @throws IOException           if the archive cannot be read
     * @throws IllegalStateException if the configuration or coefficients entry is missing
     */
    public static ComputationGraph restoreComputationGraph(@NonNull File file, boolean loadUpdater) throws IOException {
        if (file == null) {
            throw new NullPointerException("file");
        }
        String json;
        INDArray params;
        INDArray updaterState = null;
        ComputationGraphUpdater legacyUpdater = null;
        // Any "preprocessor.bin" entry is intentionally ignored (the original read discarded it).
        try (ZipFile zipFile = new ZipFile(file)) {
            json = readTextEntry(zipFile, CONFIGURATION_JSON);
            params = readArrayEntry(zipFile, COEFFICIENTS_BIN);
            if (loadUpdater) {
                updaterState = readArrayEntry(zipFile, UPDATER_BIN);
                if (updaterState == null) {
                    // The legacy updater is only used when no state array exists, so skip deserializing otherwise.
                    legacyUpdater = (ComputationGraphUpdater) readLegacyUpdater(zipFile);
                }
            }
        }
        checkModelFound(json, params, updaterState);
        ComputationGraph graph = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
        graph.init(params, false);
        if (updaterState != null) {
            graph.getUpdater().setStateViewArray(updaterState);
        } else if (legacyUpdater != null) {
            graph.setUpdater(legacyUpdater);
        }
        return graph;
    }

    /** Reads entry {@code name} as UTF-8 text, one '\n' per line, or returns null if it is absent. */
    private static String readTextEntry(ZipFile zipFile, String name) throws IOException {
        ZipEntry entry = zipFile.getEntry(name);
        if (entry == null) {
            return null;
        }
        StringBuilder text = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(zipFile.getInputStream(entry), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                text.append(line).append('\n');
            }
        }
        return text.toString();
    }

    /** Reads entry {@code name} with {@code Nd4j.read}, or returns null if it is absent. */
    private static INDArray readArrayEntry(ZipFile zipFile, String name) throws IOException {
        ZipEntry entry = zipFile.getEntry(name);
        if (entry == null) {
            return null;
        }
        try (DataInputStream in = new DataInputStream(new BufferedInputStream(zipFile.getInputStream(entry)))) {
            return Nd4j.read(in);
        }
    }

    /**
     * Deserializes the legacy {@code updater.bin} entry, or returns null if it is absent.
     * SECURITY: this uses native Java deserialization; never load archives from untrusted sources.
     */
    private static Object readLegacyUpdater(ZipFile zipFile) throws IOException {
        ZipEntry entry = zipFile.getEntry(OLD_UPDATER_BIN);
        if (entry == null) {
            return null;
        }
        try (ObjectInputStream in = new ObjectInputStream(zipFile.getInputStream(entry))) {
            return in.readObject();
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /** Throws IllegalStateException unless both the configuration and the coefficients were found. */
    private static void checkModelFound(String json, INDArray params, INDArray updaterState) {
        if (json == null || params == null) {
            throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + (json != null)
                    + "], gotCoefficients: [" + (params != null) + "], gotUpdater: [" + (updaterState != null) + "]");
        }
    }

    /** Best-effort removal of a temporary file; falls back to deletion at JVM exit. */
    private static void deleteTempFile(File tmp) {
        if (!tmp.delete()) {
            tmp.deleteOnExit();
        }
    }

    /*
     * NOTE(review): the decompiler failed to recover this method, so its original body is not
     * available here. Recovered fragments show that some branches call
     * setArchitectureType(Task.ArchitectureType.RBM) and setArchitectureType(Task.ArchitectureType.RECURRENT).
     * Until the real implementation is restored from upstream source, this stub always throws.
     */
    public static org.nd4j.linalg.heartbeat.reports.Task taskByModel(org.deeplearning4j.nn.api.Model model) {
        throw new UnsupportedOperationException("Method not decompiled: org.deeplearning4j.util.ModelSerializer.taskByModel(org.deeplearning4j.nn.api.Model):org.nd4j.linalg.heartbeat.reports.Task");
    }
}
