package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.LinAlgExceptions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/ConjugateGradient.class */
/**
 * Nonlinear conjugate-gradient optimizer built on {@link BaseOptimizer}.
 *
 * <p>The per-iteration update (see {@link #postStep()}) follows the Polak–Ribière-style
 * scheme of Numerical Recipes' {@code frprmn}, using these entries of {@code searchState}:
 * <ul>
 *   <li>{@code BaseOptimizer.GRADIENT_KEY} – the gradient array held by the optimizer;</li>
 *   <li>{@code "xi"} – the current search direction;</li>
 *   <li>{@code "h"} – the conjugate direction carried between iterations;</li>
 *   <li>{@code "gg"}, {@code "dgg"}, {@code "gam"} – the scalars of the last update.</li>
 * </ul>
 * NOTE(review): the sign convention of the gradient vs. {@code "xi"} is defined by
 * {@code BaseOptimizer}, which is not visible here — confirm against it.
 */
public class ConjugateGradient extends BaseOptimizer {
    private static final Logger logger = LoggerFactory.getLogger(ConjugateGradient.class);

    /**
     * Creates the optimizer without extra termination conditions.
     *
     * @param neuralNetConfiguration network configuration passed to the base optimizer
     * @param stepFunction           step function used to apply updates
     * @param collection             iteration listeners
     * @param model                  model being optimized
     */
    public ConjugateGradient(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<IterationListener> collection, Model model) {
        super(neuralNetConfiguration, stepFunction, collection, model);
    }

    /**
     * Creates the optimizer with explicit termination conditions.
     *
     * @param neuralNetConfiguration network configuration passed to the base optimizer
     * @param stepFunction           step function used to apply updates
     * @param collection             iteration listeners
     * @param collection2            termination conditions
     * @param model                  model being optimized
     */
    public ConjugateGradient(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction, Collection<IterationListener> collection, Collection<TerminationCondition> collection2, Model model) {
        super(neuralNetConfiguration, stepFunction, collection, collection2, model);
    }

    /** Conjugate gradient needs no preprocessing of the line-search direction; intentionally a no-op. */
    @Override
    public void preProcessLine(INDArray iNDArray) {
    }

    /**
     * Computes the next conjugate search direction and stores it back into {@code searchState}.
     *
     * <p>With {@code g} the stored gradient and {@code xi} the current direction:
     * {@code gg = sum(g^2)}, {@code dgg = sum(xi * (xi - g))}, {@code gam = dgg / gg}, and
     * the new direction is {@code h = gam * h + xi}. If the new direction is not positively
     * aligned with {@code xi}, it falls back to {@code xi} alone ("Reverting back to GA").
     */
    @Override
    public void postStep() {
        INDArray g = (INDArray) this.searchState.get(BaseOptimizer.GRADIENT_KEY);
        INDArray xi = (INDArray) this.searchState.get("xi");
        INDArray h = (INDArray) this.searchState.get("h");

        double gg = Transforms.pow(g, 2).sum(Integer.MAX_VALUE).getDouble(0);
        double dgg = xi.mul(xi.sub(g)).sum(Integer.MAX_VALUE).getDouble(0);
        // A zero previous gradient (or any non-finite ratio) carries no conjugacy information;
        // restart with gam = 0 instead of propagating NaN/Infinity into the direction.
        double gam = gg == 0.0 ? 0.0 : dgg / gg;
        if (Double.isNaN(gam) || Double.isInfinite(gam)) {
            gam = 0.0;
        }
        this.searchState.put("gg", Double.valueOf(gg));
        this.searchState.put("dgg", Double.valueOf(dgg));
        this.searchState.put("gam", Double.valueOf(gam));

        if (h == null) {
            // Copy, don't alias: g is overwritten on the next line, and h is mutated in place below.
            h = g.dup();
        }
        g.assign(xi);
        h.assign(h.mul(Double.valueOf(gam)).addi(xi));
        BooleanIndexing.applyWhere(h, Conditions.isNan(), new Value(Double.valueOf(Nd4j.EPS_THRESHOLD)));
        LinAlgExceptions.assertValidNum(h);
        if (Nd4j.getBlasWrapper().dot(xi, h) > 0.0d) {
            xi.assign(h);
        } else {
            // Not a usable descent direction: restart from xi alone.
            logger.warn("Reverting back to GA");
            h.assign(xi);
        }
        this.searchState.put(BaseOptimizer.GRADIENT_KEY, g);
        this.searchState.put("xi", xi);
        // Store the direction just computed. The previous code stored xi + gam*h; since xi == h
        // at this point, that was (1 + gam) * h and double-applied the conjugacy term.
        this.searchState.put("h", h);
    }

    /**
     * Initializes the base search state, then seeds the CG-specific entries: {@code "h"} and
     * {@code "xi"} start as independent copies of the gradient, and the scalars start at 0.
     *
     * @param pair gradient / score pair forwarded to {@link BaseOptimizer#setupSearchState}
     */
    @Override
    public void setupSearchState(Pair<Gradient, Double> pair) {
        super.setupSearchState(pair);
        INDArray gradient = (INDArray) this.searchState.get(BaseOptimizer.GRADIENT_KEY);
        this.searchState.put("h", gradient.dup());
        this.searchState.put("xi", gradient.dup());
        this.searchState.put("gg", Double.valueOf(0.0d));
        this.searchState.put("gam", Double.valueOf(0.0d));
        this.searchState.put("dgg", Double.valueOf(0.0d));
    }
}
