package org.deeplearning4j.nn.transferlearning;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.graph.FrozenVertex;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.LayerVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.graph.vertex.impl.InputVertex;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/transferlearning/TransferLearning.class */
/**
 * Utilities for transfer learning: derive a new network from an existing, trained one by applying a
 * fine-tune configuration, freezing layers, replacing layer sizes, and removing or adding layers,
 * while reusing the original network's parameters wherever the structure is unchanged.
 *
 * <p>{@link Builder} works on a {@link MultiLayerNetwork}; {@link GraphBuilder} works on a
 * {@link ComputationGraph}.
 *
 * <p>NOTE(review): this source looks like decompiler output. Calls such as {@code m30clone()},
 * {@code m33clone()}, {@code m26clone()} and {@code mo56clone()} appear to be renamed
 * {@code clone()} overloads of the configuration classes — confirm before renaming them.
 */
public class TransferLearning {
    private static final Logger log = LoggerFactory.getLogger(TransferLearning.class);

    /**
     * Builder that derives a new {@link MultiLayerNetwork} from an existing one.
     *
     * <p>Edits recorded through {@code nOutReplace(...)}, {@link #removeOutputLayer()} and
     * {@link #removeLayersFromOutput(int)} are applied once, the first time {@link #addLayer(Layer)}
     * or {@link #build()} is called; edits of those kinds recorded after that point have no effect.
     */
    public static class Builder {
        private MultiLayerConfiguration origConf;
        private MultiLayerNetwork origModel;
        private MultiLayerNetwork editedModel;
        private FineTuneConfiguration finetuneConfiguration;
        /** Index of the last layer to freeze (layers 0..frozenTill), or -1 for no freezing. */
        private int frozenTill = -1;
        /** Number of layers to remove from the output end of the network. */
        private int popN = 0;
        /** Whether doPrep() has already applied the recorded edits. */
        private boolean prepDone = false;
        private Set<Integer> editedLayers = new HashSet<>();
        /** Layer index -> (new nOut, (init, dist) for that layer, (init, dist) for the following layer). */
        private Map<Integer, Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>>> editedLayersMap = new HashMap<>();
        /** Parameters of the retained original layers, one entry per layer; entries may be null. */
        private List<INDArray> editedParams = new ArrayList<>();
        private List<NeuralNetConfiguration> editedConfs = new ArrayList<>();
        /** Parameters of layers appended with addLayer(...). */
        private List<INDArray> appendParams = new ArrayList<>();
        private List<NeuralNetConfiguration> appendConfs = new ArrayList<>();
        private Map<Integer, InputPreProcessor> inputPreProcessors;
        /** NOTE(review): never assigned in this class, so null is always passed to the configuration builder. */
        private InputType inputType;

        /**
         * Creates a builder that derives a new network from {@code multiLayerNetwork}.
         *
         * @param multiLayerNetwork the trained network to start from; its configuration is cloned
         */
        public Builder(MultiLayerNetwork multiLayerNetwork) {
            this.origModel = multiLayerNetwork;
            this.origConf = multiLayerNetwork.getLayerWiseConfigurations().m30clone();
            // Preprocessor edits are made on the cloned configuration's map
            // (presumably a copy independent of the original network — confirm clone depth).
            this.inputPreProcessors = this.origConf.getInputPreProcessors();
        }

        /**
         * Sets the fine-tune configuration applied to every retained layer, to the resulting
         * network configuration, and used as the base configuration of layers added with
         * {@link #addLayer(Layer)}. Must be set before {@link #addLayer(Layer)} is called.
         */
        public Builder fineTuneConfiguration(FineTuneConfiguration fineTuneConfiguration) {
            this.finetuneConfiguration = fineTuneConfiguration;
            return this;
        }

        /**
         * Freezes layers {@code 0..i} (inclusive) of the built network.
         *
         * @param i index of the last layer to freeze
         */
        public Builder setFeatureExtractor(int i) {
            this.frozenTill = i;
            return this;
        }

        /**
         * Sets layer {@code i}'s nOut to {@code i2}, re-initializing it and the following layer
         * (whose nIn becomes {@code i2}) with {@code weightInit}.
         */
        public Builder nOutReplace(int i, int i2, WeightInit weightInit) {
            return nOutReplace(i, i2, weightInit, weightInit, null, null);
        }

        /** As {@link #nOutReplace(int, int, WeightInit)}, initializing both layers from {@code distribution}. */
        public Builder nOutReplace(int i, int i2, Distribution distribution) {
            return nOutReplace(i, i2, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, distribution, distribution);
        }

        /** As {@link #nOutReplace(int, int, WeightInit)}, with separate inits for layer {@code i} and layer {@code i + 1}. */
        public Builder nOutReplace(int i, int i2, WeightInit weightInit, WeightInit weightInit2) {
            return nOutReplace(i, i2, weightInit, weightInit2, null, null);
        }

        /** As {@link #nOutReplace(int, int, WeightInit)}, with separate distributions for layer {@code i} and layer {@code i + 1}. */
        public Builder nOutReplace(int i, int i2, Distribution distribution, Distribution distribution2) {
            return nOutReplace(i, i2, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, distribution, distribution2);
        }

        /** As {@link #nOutReplace(int, int, WeightInit)}: layer {@code i} uses {@code weightInit}, layer {@code i + 1} uses {@code distribution}. */
        public Builder nOutReplace(int i, int i2, WeightInit weightInit, Distribution distribution) {
            return nOutReplace(i, i2, weightInit, WeightInit.DISTRIBUTION, null, distribution);
        }

        /** As {@link #nOutReplace(int, int, WeightInit)}: layer {@code i} uses {@code distribution}, layer {@code i + 1} uses {@code weightInit}. */
        public Builder nOutReplace(int i, int i2, Distribution distribution, WeightInit weightInit) {
            return nOutReplace(i, i2, WeightInit.DISTRIBUTION, weightInit, distribution, null);
        }

        /** Records an nOut replacement for layer {@code i}; applied later by doPrep(). */
        private Builder nOutReplace(int i, int i2, WeightInit weightInit, WeightInit weightInit2, Distribution distribution, Distribution distribution2) {
            this.editedLayers.add(i);
            this.editedLayersMap.put(i, new Triple<>(i2, new Pair<>(weightInit, distribution), new Pair<>(weightInit2, distribution2)));
            return this;
        }

        /** Removes the last layer of the network. */
        public Builder removeOutputLayer() {
            this.popN = 1;
            return this;
        }

        /**
         * Removes the last {@code i} layers of the network.
         *
         * @throws IllegalArgumentException if a layer removal count was already set
         */
        public Builder removeLayersFromOutput(int i) {
            if (this.popN != 0) {
                throw new IllegalArgumentException("Remove layers from can only be called once");
            }
            this.popN = i;
            return this;
        }

        /**
         * Appends a layer after the retained layers. The layer's configuration is built on top of the
         * fine-tune configuration and its parameters are freshly initialized.
         *
         * @throws IllegalStateException if no fine-tune configuration has been set
         */
        public Builder addLayer(Layer layer) {
            if (this.finetuneConfiguration == null) {
                throw new IllegalStateException(
                        "A fine-tune configuration must be set with fineTuneConfiguration(...) before calling addLayer(...)");
            }
            if (!this.prepDone) {
                doPrep();
            }
            NeuralNetConfiguration layerConf =
                    this.finetuneConfiguration.appliedNeuralNetConfigurationBuilder().layer(layer).build();
            long numParams = layer.initializer().numParams(layerConf);
            if (numParams > 0) {
                // Instantiate a throwaway layer over a fresh parameter array to obtain its parameters
                // (the trailing 'true' presumably requests initialization — confirm in Layer.instantiate).
                org.deeplearning4j.nn.api.Layer instantiated =
                        layer.instantiate(layerConf, null, 0, Nd4j.create(new long[] {1, numParams}), true);
                this.appendParams.add(instantiated.params());
                this.appendConfs.add(instantiated.conf());
            } else {
                this.appendConfs.add(layerConf);
            }
            return this;
        }

        /** Sets (or replaces) the input preprocessor for layer {@code i}. */
        public Builder setInputPreProcessor(int i, InputPreProcessor inputPreProcessor) {
            this.inputPreProcessors.put(i, inputPreProcessor);
            return this;
        }

        /**
         * Builds the edited network, freezing layers {@code 0..frozenTill} if requested.
         *
         * @throws IllegalStateException if the freeze index is beyond the built network's last layer
         */
        public MultiLayerNetwork build() {
            if (!this.prepDone) {
                doPrep();
            }
            this.editedModel = new MultiLayerNetwork(constructConf(), constructParams());
            if (this.frozenTill != -1) {
                org.deeplearning4j.nn.api.Layer[] layers = this.editedModel.getLayers();
                if (this.frozenTill >= layers.length) {
                    throw new IllegalStateException("Cannot freeze layers up to index " + this.frozenTill
                            + ": the built network only has " + layers.length + " layers");
                }
                for (int i = this.frozenTill; i >= 0; i--) {
                    // The layer implementation keeps a clone of the original configuration, while the
                    // network-level configuration is switched to a FrozenLayer wrapper below.
                    NeuralNetConfiguration netConf = this.editedModel.getLayerWiseConfigurations().getConf(i);
                    NeuralNetConfiguration layerConf = netConf.m33clone();
                    layers[i].setConf(layerConf);
                    layers[i] = new FrozenLayer(layers[i]);
                    if (netConf.getVariables() != null) {
                        List<String> variables = netConf.variables(true);
                        netConf.clearVariables();
                        layerConf.clearVariables();
                        for (String variable : variables) {
                            netConf.variables(false).add(variable);
                            layerConf.variables(false).add(variable);
                        }
                    }
                    Layer origLayerConf = this.editedModel.getLayerWiseConfigurations().getConf(i).getLayer();
                    org.deeplearning4j.nn.conf.layers.misc.FrozenLayer frozenLayerConf =
                            new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf);
                    frozenLayerConf.setLayerName(origLayerConf.getLayerName());
                    this.editedModel.getLayerWiseConfigurations().getConf(i).setLayer(frozenLayerConf);
                }
                this.editedModel.setLayers(layers);
            }
            return this.editedModel;
        }

        /**
         * Applies the recorded edits once: fine-tunes the cloned layer configurations, copies the
         * original parameters, applies nOut replacements in ascending layer order, and finally pops
         * layers (and their input preprocessors) from the output end.
         */
        private void doPrep() {
            fineTuneConfigurationBuild();
            int nLayers = this.origModel.getnLayers();
            for (int i = 0; i < nLayers; i++) {
                org.deeplearning4j.nn.api.Layer origLayer = this.origModel.getLayer(i);
                // Duplicate only when there is something to copy; parameterless layers keep params() as-is.
                this.editedParams.add(origLayer.numParams() > 0 ? origLayer.params().dup() : origLayer.params());
            }
            if (!this.editedLayers.isEmpty()) {
                Integer[] sortedLayers = this.editedLayers.toArray(new Integer[0]);
                Arrays.sort(sortedLayers);
                for (int layerIdx : sortedLayers) {
                    Triple<Integer, Pair<WeightInit, Distribution>, Pair<WeightInit, Distribution>> edit =
                            this.editedLayersMap.get(layerIdx);
                    nOutReplaceBuild(layerIdx, edit.getLeft(), edit.getMiddle(), edit.getRight());
                }
            }
            // Layer indices run 0..nLayers-1, so the i-th layer popped from the output end has index
            // nLayers - 1 - i; its preprocessor must go with it rather than linger and attach to
            // whatever layer is appended at that index later.
            for (int i = 0; i < this.popN; i++) {
                this.inputPreProcessors.remove(nLayers - 1 - i);
                this.editedConfs.remove(this.editedConfs.size() - 1);
                this.editedParams.remove(this.editedParams.size() - 1);
            }
            this.prepDone = true;
        }

        /** Clones every original layer configuration, applying the fine-tune configuration if one is set. */
        private void fineTuneConfigurationBuild() {
            for (int i = 0; i < this.origConf.getConfs().size(); i++) {
                NeuralNetConfiguration layerConf = this.origConf.getConf(i).m33clone();
                if (this.finetuneConfiguration != null) {
                    this.finetuneConfiguration.applyToNeuralNetConfiguration(layerConf);
                }
                this.editedConfs.add(layerConf);
            }
        }

        /**
         * Sets layer {@code i}'s nOut to {@code i2} and re-initializes its parameters; if a following
         * layer exists, sets its nIn to {@code i2} and re-initializes it too. Both layers are cast to
         * {@link FeedForwardLayer}.
         */
        private void nOutReplaceBuild(int i, int i2, Pair<WeightInit, Distribution> pair, Pair<WeightInit, Distribution> pair2) {
            NeuralNetConfiguration layerConf = this.editedConfs.get(i);
            Layer layer = layerConf.getLayer();
            FeedForwardLayer ffLayer = (FeedForwardLayer) layer;
            ffLayer.setWeightInit(pair.getLeft());
            ffLayer.setDist(pair.getRight());
            ffLayer.setNOut(i2);
            long numParams = layer.initializer().numParams(layerConf);
            this.editedParams.set(i, layer.instantiate(layerConf, null, 0, Nd4j.create(new long[] {1, numParams}), true).params());
            if (i + 1 < this.editedConfs.size()) {
                NeuralNetConfiguration nextConf = this.editedConfs.get(i + 1);
                Layer nextLayer = nextConf.getLayer();
                FeedForwardLayer nextFfLayer = (FeedForwardLayer) nextLayer;
                nextFfLayer.setWeightInit(pair2.getLeft());
                nextFfLayer.setDist(pair2.getRight());
                nextFfLayer.setNIn(i2);
                long nextNumParams = nextLayer.initializer().numParams(nextConf);
                if (nextNumParams > 0) {
                    this.editedParams.set(i + 1, nextLayer.instantiate(nextConf, null, 0, Nd4j.create(new long[] {1, nextNumParams}), true).params());
                }
            }
        }

        /**
         * Concatenates the retained and appended parameters into one flat array, skipping null entries.
         *
         * @return the combined parameters, the single array itself when only one exists, or null when there are none
         */
        private INDArray constructParams() {
            List<INDArray> parts = new ArrayList<>();
            for (INDArray params : this.editedParams) {
                if (params != null) {
                    parts.add(params);
                }
            }
            parts.addAll(this.appendParams);
            if (parts.isEmpty()) {
                return null;
            }
            if (parts.size() == 1) {
                return parts.get(0);
            }
            // One concatenation instead of pairwise re-stacking; also works when no retained layer has params.
            return Nd4j.hstack(parts);
        }

        /**
         * Builds the network configuration from the retained and appended layer configurations, naming
         * unnamed layers "layer" + index and applying the fine-tune configuration if one is set.
         */
        private MultiLayerConfiguration constructConf() {
            List<NeuralNetConfiguration> allConfs = new ArrayList<>(this.editedConfs);
            allConfs.addAll(this.appendConfs);
            for (int i = 0; i < allConfs.size(); i++) {
                Layer layerConf = allConfs.get(i).getLayer();
                if (layerConf.getLayerName() == null) {
                    layerConf.setLayerName("layer" + i);
                }
            }
            MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder()
                    .inputPreProcessors(this.inputPreProcessors)
                    .setInputType(this.inputType)
                    .confs(allConfs)
                    .build();
            if (this.finetuneConfiguration != null) {
                this.finetuneConfiguration.applyToMultiLayerConfiguration(conf);
            }
            return conf;
        }
    }

    /** Builder that derives a new {@link ComputationGraph} from an existing one. */
    public static class GraphBuilder {
        private ComputationGraph origGraph;
        private ComputationGraphConfiguration origConfig;
        private FineTuneConfiguration fineTuneConfiguration;
        private ComputationGraphConfiguration.GraphBuilder editedConfigBuilder;
        private String[] frozenOutputAt;
        private boolean hasFrozen = false;
        /** Names of vertices whose parameters must not be copied from the original graph. */
        private Set<String> editedVertices = new HashSet<>();
        private WorkspaceMode workspaceMode;

        /**
         * Creates a builder that derives a new graph from {@code computationGraph}.
         *
         * @param computationGraph the trained graph to start from; its configuration is cloned
         */
        public GraphBuilder(ComputationGraph computationGraph) {
            this.origGraph = computationGraph;
            this.origConfig = computationGraph.getConfiguration().m26clone();
        }

        /**
         * Sets the fine-tune configuration and applies it to every layer vertex. This rebuilds the
         * edited configuration from the original graph's configuration, so it should be called before
         * any other edit; earlier edits are discarded.
         */
        public GraphBuilder fineTuneConfiguration(FineTuneConfiguration fineTuneConfiguration) {
            this.fineTuneConfiguration = fineTuneConfiguration;
            this.editedConfigBuilder = new ComputationGraphConfiguration.GraphBuilder(this.origConfig,
                    fineTuneConfiguration.appliedNeuralNetConfigurationBuilder());
            Map<String, GraphVertex> vertices = this.editedConfigBuilder.getVertices();
            for (Map.Entry<String, GraphVertex> entry : vertices.entrySet()) {
                if (entry.getValue() instanceof LayerVertex) {
                    LayerVertex layerVertex = (LayerVertex) entry.getValue();
                    NeuralNetConfiguration tunedConf = layerVertex.getLayerConf().m33clone();
                    fineTuneConfiguration.applyToNeuralNetConfiguration(tunedConf);
                    // setValue never structurally modifies the map, so it is safe during iteration.
                    entry.setValue(new LayerVertex(tunedConf, layerVertex.getPreProcessor()));
                    tunedConf.getLayer().setLayerName(entry.getKey());
                }
            }
            return this;
        }

        /** Freezes the named vertices and, transitively, every vertex that feeds into them. */
        public GraphBuilder setFeatureExtractor(String... strArr) {
            this.hasFrozen = true;
            this.frozenOutputAt = strArr;
            return this;
        }

        /**
         * Sets layer vertex {@code str}'s nOut to {@code i}, re-initializing it and every layer vertex
         * it feeds (whose nIn becomes {@code i}) with {@code weightInit}.
         */
        public GraphBuilder nOutReplace(String str, int i, WeightInit weightInit) {
            return nOutReplace(str, i, weightInit, weightInit, null, null);
        }

        /** As {@link #nOutReplace(String, int, WeightInit)}, initializing all affected layers from {@code distribution}. */
        public GraphBuilder nOutReplace(String str, int i, Distribution distribution) {
            return nOutReplace(str, i, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, distribution, distribution);
        }

        /** As {@link #nOutReplace(String, int, WeightInit)}, with separate distributions for the vertex and its consumers. */
        public GraphBuilder nOutReplace(String str, int i, Distribution distribution, Distribution distribution2) {
            return nOutReplace(str, i, WeightInit.DISTRIBUTION, WeightInit.DISTRIBUTION, distribution, distribution2);
        }

        /** As {@link #nOutReplace(String, int, WeightInit)}: the vertex uses {@code weightInit}, its consumers {@code distribution}. */
        public GraphBuilder nOutReplace(String str, int i, WeightInit weightInit, Distribution distribution) {
            return nOutReplace(str, i, weightInit, WeightInit.DISTRIBUTION, null, distribution);
        }

        /** As {@link #nOutReplace(String, int, WeightInit)}: the vertex uses {@code distribution}, its consumers {@code weightInit}. */
        public GraphBuilder nOutReplace(String str, int i, Distribution distribution, WeightInit weightInit) {
            return nOutReplace(str, i, WeightInit.DISTRIBUTION, weightInit, distribution, null);
        }

        /** As {@link #nOutReplace(String, int, WeightInit)}, with separate inits for the vertex and its consumers. */
        public GraphBuilder nOutReplace(String str, int i, WeightInit weightInit, WeightInit weightInit2) {
            return nOutReplace(str, i, weightInit, weightInit2, null, null);
        }

        /**
         * Replaces layer vertex {@code str} with a copy whose nOut is {@code i}, then replaces each
         * layer vertex that consumes it with a copy whose nIn is {@code i}.
         *
         * @throws IllegalArgumentException if {@code str} is not a layer vertex
         * @throws UnsupportedOperationException if {@code str} feeds a non-layer vertex
         */
        private GraphBuilder nOutReplace(String str, int i, WeightInit weightInit, WeightInit weightInit2, Distribution distribution, Distribution distribution2) {
            initBuilderIfReq();
            if (!this.origGraph.getVertex(str).hasLayer()) {
                throw new IllegalArgumentException("noutReplace can only be applied to layer vertices. " + str + " is not a layer vertex");
            }
            Layer replaced = this.origGraph.getLayer(str).conf().getLayer().mo56clone();
            replaced.resetLayerDefaultConfig();
            FeedForwardLayer replacedFf = (FeedForwardLayer) replaced;
            replacedFf.setWeightInit(weightInit);
            replacedFf.setDist(distribution);
            replacedFf.setNOut(i);
            this.editedConfigBuilder.removeVertex(str, false);
            this.editedConfigBuilder.addLayer(str, replaced,
                    ((LayerVertex) this.origConfig.getVertices().get(str)).getPreProcessor(),
                    this.origConfig.getVertexInputs().get(str).toArray(new String[0]));
            this.editedVertices.add(str);

            // Collect every vertex that takes 'str' as an input in the original graph.
            List<String> consumers = new ArrayList<>();
            for (Map.Entry<String, List<String>> entry : this.origConfig.getVertexInputs().entrySet()) {
                String key = entry.getKey();
                if (!key.equals(str) && entry.getValue().contains(str)) {
                    consumers.add(key);
                }
            }
            for (String consumer : consumers) {
                if (!this.origGraph.getVertex(consumer).hasLayer()) {
                    throw new UnsupportedOperationException("Cannot modify nOut of a layer vertex that feeds non-layer vertices. Use removeVertexKeepConnections followed by addVertex instead");
                }
                Layer consumerLayer = this.origGraph.getLayer(consumer).conf().getLayer().mo56clone();
                FeedForwardLayer consumerFf = (FeedForwardLayer) consumerLayer;
                consumerFf.setWeightInit(weightInit2);
                consumerFf.setDist(distribution2);
                consumerFf.setNIn(i);
                this.editedConfigBuilder.removeVertex(consumer, false);
                this.editedConfigBuilder.addLayer(consumer, consumerLayer,
                        ((LayerVertex) this.origConfig.getVertices().get(consumer)).getPreProcessor(),
                        this.origConfig.getVertexInputs().get(consumer).toArray(new String[0]));
                this.editedVertices.add(consumer);
            }
            return this;
        }

        /** Removes a vertex but keeps references to it in other vertices' inputs (to be re-added under the same name). */
        public GraphBuilder removeVertexKeepConnections(String str) {
            initBuilderIfReq();
            this.editedConfigBuilder.removeVertex(str, false);
            return this;
        }

        /** Removes a vertex together with its connections. */
        public GraphBuilder removeVertexAndConnections(String str) {
            initBuilderIfReq();
            this.editedConfigBuilder.removeVertex(str, true);
            return this;
        }

        /** Adds a new layer vertex fed by {@code strArr}; its parameters are freshly initialized. */
        public GraphBuilder addLayer(String str, Layer layer, String... strArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.addLayer(str, layer, null, strArr);
            this.editedVertices.add(str);
            return this;
        }

        /** Adds a new layer vertex with an input preprocessor; its parameters are freshly initialized. */
        public GraphBuilder addLayer(String str, Layer layer, InputPreProcessor inputPreProcessor, String... strArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.addLayer(str, layer, inputPreProcessor, strArr);
            this.editedVertices.add(str);
            return this;
        }

        /** Adds a new vertex fed by {@code strArr}. */
        public GraphBuilder addVertex(String str, GraphVertex graphVertex, String... strArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.addVertex(str, graphVertex, strArr);
            this.editedVertices.add(str);
            return this;
        }

        /** Sets the output vertices of the edited graph. */
        public GraphBuilder setOutputs(String... strArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.setOutputs(strArr);
            return this;
        }

        /** Creates the edited configuration builder with a default fine-tune configuration (original seed) if none exists yet. */
        private void initBuilderIfReq() {
            if (this.editedConfigBuilder == null) {
                fineTuneConfiguration(new FineTuneConfiguration.Builder().seed(this.origConfig.getDefaultConfiguration().getSeed()).build());
            }
        }

        /** Replaces the network inputs of the edited graph. */
        public GraphBuilder setInputs(String... strArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.setNetworkInputs(Arrays.asList(strArr));
            return this;
        }

        /** Sets the input types of the edited graph. */
        public GraphBuilder setInputTypes(InputType... inputTypeArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.setInputTypes(inputTypeArr);
            return this;
        }

        /** Adds network inputs to the edited graph. */
        public GraphBuilder addInputs(String... strArr) {
            initBuilderIfReq();
            this.editedConfigBuilder.addInputs(strArr);
            return this;
        }

        /** Sets the training workspace mode of the built graph. */
        public GraphBuilder setWorkspaceMode(WorkspaceMode workspaceMode) {
            this.workspaceMode = workspaceMode;
            return this;
        }

        /**
         * Builds and initializes the edited graph, copies parameters from the original graph for every
         * unedited layer, and freezes the requested vertices.
         */
        public ComputationGraph build() {
            initBuilderIfReq();
            ComputationGraphConfiguration newConf = this.editedConfigBuilder.build();
            if (this.workspaceMode != null) {
                newConf.setTrainingWorkspaceMode(this.workspaceMode);
            }
            ComputationGraph newGraph = new ComputationGraph(newConf);
            newGraph.init();
            int[] topoOrder = newGraph.topologicalSortOrder();
            org.deeplearning4j.nn.graph.vertex.GraphVertex[] vertices = newGraph.getVertices();

            // A flat copy is only valid when nothing was added/edited AND the parameter count is
            // unchanged; vertex removals alone change the count, so fall back to per-layer copying.
            if (this.editedVertices.isEmpty() && newGraph.numParams() == this.origGraph.numParams()) {
                newGraph.setParams(this.origGraph.params());
            } else {
                for (int idx : topoOrder) {
                    if (!vertices[idx].hasLayer()) {
                        continue;
                    }
                    org.deeplearning4j.nn.api.Layer layer = vertices[idx].getLayer();
                    String vertexName = vertices[idx].getVertexName();
                    if (layer.numParams() > 0 && !this.editedVertices.contains(vertexName)) {
                        layer.setParams(this.origGraph.getLayer(vertexName).params().dup());
                    }
                }
            }

            if (this.hasFrozen) {
                Set<String> toFreeze = new HashSet<>();
                Collections.addAll(toFreeze, this.frozenOutputAt);
                // Reverse topological order visits each vertex after its consumers, so marking a vertex's
                // inputs for freezing propagates the freeze all the way back to the network inputs.
                for (int pos = topoOrder.length - 1; pos >= 0; pos--) {
                    org.deeplearning4j.nn.graph.vertex.GraphVertex graphVertex = vertices[topoOrder[pos]];
                    if (!toFreeze.contains(graphVertex.getVertexName())) {
                        continue;
                    }
                    if (graphVertex.hasLayer()) {
                        org.deeplearning4j.nn.api.Layer unfrozenLayer = graphVertex.getLayer();
                        graphVertex.setLayerAsFrozen();
                        // Network-level configuration gets the FrozenLayer wrapper; the layer implementation
                        // keeps the original layer configuration object.
                        LayerVertex layerVertex = (LayerVertex) newConf.getVertices().get(graphVertex.getVertexName());
                        Layer origLayerConf = layerVertex.getLayerConf().getLayer();
                        org.deeplearning4j.nn.conf.layers.misc.FrozenLayer frozenLayerConf =
                                new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf);
                        frozenLayerConf.setLayerName(origLayerConf.getLayerName());
                        NeuralNetConfiguration newNnc = layerVertex.getLayerConf().m33clone();
                        layerVertex.setLayerConf(newNnc);
                        layerVertex.getLayerConf().setLayer(frozenLayerConf);
                        List<String> variables = layerVertex.getLayerConf().variables(true);
                        layerVertex.getLayerConf().clearVariables();
                        for (String variable : variables) {
                            newNnc.variables(false).add(variable);
                        }
                        // Swap the frozen implementation into the graph's layer array.
                        org.deeplearning4j.nn.api.Layer[] layers = newGraph.getLayers();
                        for (int j = 0; j < layers.length; j++) {
                            if (layers[j] == unfrozenLayer) {
                                layers[j] = graphVertex.getLayer();
                                break;
                            }
                        }
                    } else if (!(graphVertex instanceof InputVertex)) {
                        newConf.getVertices().put(graphVertex.getVertexName(),
                                new FrozenVertex(newConf.getVertices().get(graphVertex.getVertexName())));
                        vertices[topoOrder[pos]] = new org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex(graphVertex);
                    }
                    VertexIndices[] inputVertices = graphVertex.getInputVertices();
                    if (inputVertices != null) {
                        for (VertexIndices input : inputVertices) {
                            toFreeze.add(vertices[input.getVertexIndex()].getVertexName());
                        }
                    }
                }
                newGraph.initGradientsView();
            }
            return newGraph;
        }
    }
}
