/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.feedforward.rbm.RBM;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.heartbeat.reports.Task;

public class ModelSerializer {
    public static final String UPDATER_BIN = "updater.bin";

    private ModelSerializer() {
    }

    public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (file == null) {
            throw new NullPointerException("file");
        }
        try (BufferedOutputStream stream = new BufferedOutputStream(new FileOutputStream(file));){
            ModelSerializer.writeModel(model, stream, saveUpdater);
        }
    }

    public static void writeModel(@NonNull Model model, @NonNull String path, boolean saveUpdater) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (path == null) {
            throw new NullPointerException("path");
        }
        try (BufferedOutputStream stream = new BufferedOutputStream(new FileOutputStream(path));){
            ModelSerializer.writeModel(model, stream, saveUpdater);
        }
    }

    public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater) throws IOException {
        if (model == null) {
            throw new NullPointerException("model");
        }
        if (stream == null) {
            throw new NullPointerException("stream");
        }
        ZipOutputStream zipfile = new ZipOutputStream(stream);
        String json = "";
        if (model instanceof MultiLayerNetwork) {
            json = ((MultiLayerNetwork)model).getLayerWiseConfigurations().toJson();
        } else if (model instanceof ComputationGraph) {
            json = ((ComputationGraph)model).getConfiguration().toJson();
        }
        ZipEntry config = new ZipEntry("configuration.json");
        zipfile.putNextEntry(config);
        ModelSerializer.writeEntry(new ByteArrayInputStream(json.getBytes()), zipfile);
        ZipEntry coefficients = new ZipEntry("coefficients.bin");
        zipfile.putNextEntry(coefficients);
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(bos);
        Nd4j.write((INDArray)model.params(), (DataOutputStream)dos);
        dos.flush();
        dos.close();
        ByteArrayInputStream inputStream = new ByteArrayInputStream(bos.toByteArray());
        ModelSerializer.writeEntry(inputStream, zipfile);
        if (saveUpdater) {
            ZipEntry updater = new ZipEntry(UPDATER_BIN);
            zipfile.putNextEntry(updater);
            bos = new ByteArrayOutputStream();
            ObjectOutputStream oos = new ObjectOutputStream(bos);
            if (model instanceof MultiLayerNetwork) {
                oos.writeObject(((MultiLayerNetwork)model).getUpdater());
            } else if (model instanceof ComputationGraph) {
                oos.writeObject(((ComputationGraph)model).getUpdater());
            }
            oos.flush();
            oos.close();
            inputStream = new ByteArrayInputStream(bos.toByteArray());
            ModelSerializer.writeEntry(inputStream, zipfile);
        }
        zipfile.flush();
        zipfile.close();
    }

    private static void writeEntry(InputStream inputStream, ZipOutputStream zipStream) throws IOException {
        int bytesRead;
        byte[] bytes = new byte[1024];
        while ((bytesRead = inputStream.read(bytes)) != -1) {
            zipStream.write(bytes, 0, bytesRead);
        }
    }

    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull File file) throws IOException {
        ZipEntry prep;
        ZipEntry updaters;
        ZipEntry coefficients;
        if (file == null) {
            throw new NullPointerException("file");
        }
        ZipFile zipFile = new ZipFile(file);
        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdater = false;
        boolean gotPreProcessor = false;
        String json = "";
        INDArray params = null;
        Updater updater = null;
        DataSetPreProcessor preProcessor = null;
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            InputStream stream = zipFile.getInputStream(config);
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
            String line = "";
            StringBuilder js = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
            json = js.toString();
            reader.close();
            stream.close();
            gotConfig = true;
        }
        if ((coefficients = zipFile.getEntry("coefficients.bin")) != null) {
            InputStream stream = zipFile.getInputStream(coefficients);
            DataInputStream dis = new DataInputStream(stream);
            params = Nd4j.read((DataInputStream)dis);
            dis.close();
            gotCoefficients = true;
        }
        if ((updaters = zipFile.getEntry(UPDATER_BIN)) != null) {
            InputStream stream = zipFile.getInputStream(updaters);
            ObjectInputStream ois = new ObjectInputStream(stream);
            try {
                updater = (Updater)ois.readObject();
            }
            catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotUpdater = true;
        }
        if ((prep = zipFile.getEntry("preprocessor.bin")) != null) {
            InputStream stream = zipFile.getInputStream(prep);
            ObjectInputStream ois = new ObjectInputStream(stream);
            try {
                preProcessor = (DataSetPreProcessor)ois.readObject();
            }
            catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotPreProcessor = true;
        }
        zipFile.close();
        if (gotConfig && gotCoefficients) {
            MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
            MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
            network.init(params, false);
            if (gotUpdater && updater != null) {
                network.setUpdater(updater);
            }
            return network;
        }
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }

    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is) throws IOException {
        ZipEntry entry;
        if (is == null) {
            throw new NullPointerException("is");
        }
        ZipInputStream zipFile = new ZipInputStream(is);
        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdater = false;
        boolean gotPreProcessor = false;
        String json = "";
        INDArray params = null;
        DataSetPreProcessor preProcessor = null;
        Updater updater = null;
        while ((entry = zipFile.getNextEntry()) != null) {
            switch (entry.getName()) {
                case "configuration.json": {
                    DataInputStream dis = new DataInputStream(zipFile);
                    params = Nd4j.read((DataInputStream)dis);
                    gotConfig = true;
                    break;
                }
                case "coefficients.bin": {
                    DataInputStream dis2 = new DataInputStream(zipFile);
                    params = Nd4j.read((DataInputStream)dis2);
                    gotCoefficients = true;
                    break;
                }
                case "updater.bin": {
                    ObjectInputStream ois = new ObjectInputStream(zipFile);
                    try {
                        updater = (Updater)ois.readObject();
                    }
                    catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    gotUpdater = true;
                    break;
                }
                case "preprocessor.bin": {
                    ObjectInputStream ois = new ObjectInputStream(zipFile);
                    try {
                        preProcessor = (DataSetPreProcessor)ois.readObject();
                    }
                    catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    gotPreProcessor = true;
                }
            }
            zipFile.closeEntry();
        }
        zipFile.close();
        if (gotConfig && gotCoefficients) {
            MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json);
            MultiLayerNetwork network = new MultiLayerNetwork(confFromJson);
            network.init(params, false);
            if (gotUpdater && updater != null) {
                network.setUpdater(updater);
            }
            return network;
        }
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }

    public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull String path) throws IOException {
        if (path == null) {
            throw new NullPointerException("path");
        }
        return ModelSerializer.restoreMultiLayerNetwork(new File(path));
    }

    public static ComputationGraph restoreComputationGraph(@NonNull String path) throws IOException {
        if (path == null) {
            throw new NullPointerException("path");
        }
        return ModelSerializer.restoreComputationGraph(new File(path));
    }

    public static ComputationGraph restoreComputationGraph(@NonNull InputStream is) throws IOException {
        ZipEntry entry;
        if (is == null) {
            throw new NullPointerException("is");
        }
        ZipInputStream zis = new ZipInputStream(is);
        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdater = false;
        boolean gotPreProcessor = false;
        String json = "";
        INDArray params = null;
        ComputationGraphUpdater updater = null;
        DataSetPreProcessor preProcessor = null;
        BufferedReader reader = new BufferedReader(new InputStreamReader(zis));
        while ((entry = zis.getNextEntry()) != null) {
            switch (entry.getName()) {
                case "configuration.json": {
                    String line;
                    StringBuilder js = new StringBuilder();
                    while ((line = reader.readLine()) != null) {
                        js.append(line).append("\n");
                    }
                    json = js.toString();
                    gotConfig = true;
                    break;
                }
                case "coefficients.bin": {
                    DataInputStream dis = new DataInputStream(zis);
                    params = Nd4j.read((DataInputStream)dis);
                    gotCoefficients = true;
                    break;
                }
                case "updater.bin": {
                    ObjectInputStream ois = new ObjectInputStream(zis);
                    try {
                        updater = (ComputationGraphUpdater)ois.readObject();
                    }
                    catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    gotUpdater = true;
                    break;
                }
                case "preprocessor.bin": {
                    ObjectInputStream ois = new ObjectInputStream(zis);
                    try {
                        preProcessor = (DataSetPreProcessor)ois.readObject();
                    }
                    catch (ClassNotFoundException e) {
                        throw new RuntimeException(e);
                    }
                    gotPreProcessor = true;
                }
            }
            zis.closeEntry();
        }
        if (gotConfig && gotCoefficients) {
            ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
            ComputationGraph cg = new ComputationGraph(confFromJson);
            cg.init(params, false);
            if (gotUpdater && updater != null) {
                cg.setUpdater(updater);
            }
            zis.close();
            return cg;
        }
        zis.close();
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }

    public static ComputationGraph restoreComputationGraph(@NonNull File file) throws IOException {
        ZipEntry prep;
        ZipEntry updaters;
        ZipEntry coefficients;
        if (file == null) {
            throw new NullPointerException("file");
        }
        ZipFile zipFile = new ZipFile(file);
        boolean gotConfig = false;
        boolean gotCoefficients = false;
        boolean gotUpdater = false;
        boolean gotPreProcessor = false;
        String json = "";
        INDArray params = null;
        ComputationGraphUpdater updater = null;
        DataSetPreProcessor preProcessor = null;
        ZipEntry config = zipFile.getEntry("configuration.json");
        if (config != null) {
            InputStream stream = zipFile.getInputStream(config);
            BufferedReader reader = new BufferedReader(new InputStreamReader(stream));
            String line = "";
            StringBuilder js = new StringBuilder();
            while ((line = reader.readLine()) != null) {
                js.append(line).append("\n");
            }
            json = js.toString();
            reader.close();
            stream.close();
            gotConfig = true;
        }
        if ((coefficients = zipFile.getEntry("coefficients.bin")) != null) {
            InputStream stream = zipFile.getInputStream(coefficients);
            DataInputStream dis = new DataInputStream(stream);
            params = Nd4j.read((DataInputStream)dis);
            dis.close();
            gotCoefficients = true;
        }
        if ((updaters = zipFile.getEntry(UPDATER_BIN)) != null) {
            InputStream stream = zipFile.getInputStream(updaters);
            ObjectInputStream ois = new ObjectInputStream(stream);
            try {
                updater = (ComputationGraphUpdater)ois.readObject();
            }
            catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotUpdater = true;
        }
        if ((prep = zipFile.getEntry("preprocessor.bin")) != null) {
            InputStream stream = zipFile.getInputStream(prep);
            ObjectInputStream ois = new ObjectInputStream(stream);
            try {
                preProcessor = (DataSetPreProcessor)ois.readObject();
            }
            catch (ClassNotFoundException e) {
                throw new RuntimeException(e);
            }
            gotPreProcessor = true;
        }
        zipFile.close();
        if (gotConfig && gotCoefficients) {
            ComputationGraphConfiguration confFromJson = ComputationGraphConfiguration.fromJson(json);
            ComputationGraph cg = new ComputationGraph(confFromJson);
            cg.init(params, false);
            if (gotUpdater && updater != null) {
                cg.setUpdater(updater);
            }
            return cg;
        }
        throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdater + "]");
    }

    public static Task taskByModel(Model model) {
        Task task = new Task();
        try {
            block21: {
                task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                if (model instanceof ComputationGraph) {
                    task.setNetworkType(Task.NetworkType.ComputationalGraph);
                    ComputationGraph network = (ComputationGraph)model;
                    try {
                        if (network.getLayers() != null && network.getLayers().length > 0) {
                            for (Layer layer : network.getLayers()) {
                                if (layer instanceof org.deeplearning4j.nn.conf.layers.RBM || layer instanceof RBM) {
                                    task.setArchitectureType(Task.ArchitectureType.RBM);
                                } else if (layer.type().equals((Object)Layer.Type.CONVOLUTIONAL)) {
                                    task.setArchitectureType(Task.ArchitectureType.CONVOLUTION);
                                } else {
                                    if (!layer.type().equals((Object)Layer.Type.RECURRENT) && !layer.type().equals((Object)Layer.Type.RECURSIVE)) continue;
                                    task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                                }
                                break block21;
                            }
                            break block21;
                        }
                        task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
                    }
                    catch (Exception exception) {}
                } else if (model instanceof MultiLayerNetwork) {
                    task.setNetworkType(Task.NetworkType.MultilayerNetwork);
                    MultiLayerNetwork network = (MultiLayerNetwork)model;
                    try {
                        if (network.getLayers() != null && network.getLayers().length > 0) {
                            for (Layer layer : network.getLayers()) {
                                if (layer instanceof org.deeplearning4j.nn.conf.layers.RBM || layer instanceof RBM) {
                                    task.setArchitectureType(Task.ArchitectureType.RBM);
                                } else if (layer.type().equals((Object)Layer.Type.CONVOLUTIONAL)) {
                                    task.setArchitectureType(Task.ArchitectureType.CONVOLUTION);
                                } else {
                                    if (!layer.type().equals((Object)Layer.Type.RECURRENT) && !layer.type().equals((Object)Layer.Type.RECURSIVE)) continue;
                                    task.setArchitectureType(Task.ArchitectureType.RECURRENT);
                                }
                                break block21;
                            }
                            break block21;
                        }
                        task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
                    }
                    catch (Exception exception) {
                        // empty catch block
                    }
                }
            }
            return task;
        }
        catch (Exception e) {
            task.setArchitectureType(Task.ArchitectureType.UNKNOWN);
            task.setNetworkType(Task.NetworkType.DenseNetwork);
            return task;
        }
    }
}

