package org.deeplearning4j.optimize;

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/OutputLayerTrainingEvaluator.class */
/**
 * {@link TrainingEvaluator} that decides when to stop training early, based on the
 * network's current score and a patience budget measured in iterations.
 *
 * <p>Every {@code validationEpochs} iterations (starting at iteration 2) the network's
 * {@code score()} is compared with the best score seen so far. An improvement extends the
 * patience budget to at least {@code iteration * patienceIncrease}. Training is reported as
 * finished once the iteration count exceeds the budget.
 *
 * <p>Instances are mutable: {@code patience} and {@code bestLoss} change as
 * {@link #shouldStop(int)} is called. No synchronization is performed.
 */
public class OutputLayerTrainingEvaluator implements TrainingEvaluator {
    /** Multiplier applied to the mini-batch size when no positive patience is supplied. */
    private static final int DEFAULT_PATIENCE_PER_MINI_BATCH = 4;

    private static final Logger log = LoggerFactory.getLogger(OutputLayerTrainingEvaluator.class);

    private final MultiLayerNetwork network;
    private double patience;
    private final double patienceIncrease;
    private double bestLoss;
    private final int validationEpochs;
    private final int miniBatchSize;
    private final double improvementThreshold;

    /**
     * Fluent builder for {@link OutputLayerTrainingEvaluator}.
     * Any value that is not set keeps Java's field default (0, 0.0 or null).
     */
    public static class Builder {
        private MultiLayerNetwork network;
        private double patience;
        private double patienceIncrease;
        private double bestLoss;
        private int validationEpochs;
        private int miniBatchSize;
        private DataSet testSet;
        private double improvementThreshold;

        public Builder withNetwork(MultiLayerNetwork network) {
            this.network = network;
            return this;
        }

        public Builder patience(double patience) {
            this.patience = patience;
            return this;
        }

        public Builder patienceIncrease(double patienceIncrease) {
            this.patienceIncrease = patienceIncrease;
            return this;
        }

        public Builder bestLoss(double bestLoss) {
            this.bestLoss = bestLoss;
            return this;
        }

        public Builder validationEpochs(int validationEpochs) {
            this.validationEpochs = validationEpochs;
            return this;
        }

        public Builder testSet(DataSet testSet) {
            this.testSet = testSet;
            return this;
        }

        public Builder miniBatchSize(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
            return this;
        }

        public Builder improvementThreshold(double improvementThreshold) {
            this.improvementThreshold = improvementThreshold;
            return this;
        }

        /** Creates the evaluator from the values collected so far. */
        public OutputLayerTrainingEvaluator build() {
            return new OutputLayerTrainingEvaluator(network, patience, patienceIncrease, bestLoss,
                    validationEpochs, miniBatchSize, testSet, improvementThreshold);
        }
    }

    /**
     * Creates an evaluator.
     *
     * @param network the network whose {@code score()} is sampled by {@link #shouldStop(int)}
     * @param patience initial iteration budget; if not positive, {@code 4 * miniBatchSize} is used
     * @param patienceIncrease factor applied to the current iteration to extend the budget
     *        when the score improves
     * @param bestLoss initial best score that must be beaten to count as an improvement
     * @param validationEpochs interval, in iterations, between score checks;
     *        0 means check on every iteration
     * @param miniBatchSize mini-batch size; also determines the default patience
     * @param testSet accepted for API compatibility; not used by this evaluator
     * @param improvementThreshold a new score must also be below
     *        {@code bestLoss * improvementThreshold} to count as an improvement
     */
    public OutputLayerTrainingEvaluator(MultiLayerNetwork network, double patience,
            double patienceIncrease, double bestLoss, int validationEpochs, int miniBatchSize,
            DataSet testSet, double improvementThreshold) {
        this.network = network;
        // Honour an explicit positive patience; otherwise fall back to 4 * miniBatchSize.
        // The Builder leaves patience at 0 unless patience(...) is called, so those callers
        // still get the fallback. The product is computed in double to avoid int overflow.
        this.patience = patience > 0
                ? patience
                : DEFAULT_PATIENCE_PER_MINI_BATCH * (double) miniBatchSize;
        this.patienceIncrease = patienceIncrease;
        this.bestLoss = bestLoss;
        this.validationEpochs = validationEpochs;
        this.miniBatchSize = miniBatchSize;
        this.improvementThreshold = improvementThreshold;
    }

    /**
     * Decides whether training should stop at the given iteration.
     *
     * <p>Returns {@code false} without consulting the network unless {@code iteration} is at
     * least 2 and a multiple of the validation interval. Otherwise the current score is read.
     * If it beats both {@code bestLoss} and {@code bestLoss * improvementThreshold}, the best
     * loss is updated and patience is raised to at least {@code iteration * patienceIncrease}.
     *
     * @param iteration the current training iteration
     * @return {@code true} once the patience budget is smaller than {@code iteration}
     */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public boolean shouldStop(int iteration) {
        // A zero interval would make the modulo below throw ArithmeticException
        // (and 0 is the Builder default), so treat it as "check every iteration".
        int interval = validationEpochs == 0 ? 1 : validationEpochs;
        if (iteration % interval != 0 || iteration < 2) {
            return false;
        }
        double score = network.score();
        if (score < bestLoss && score < bestLoss * improvementThreshold) {
            bestLoss = score;
            patience = Math.max(patience, iteration * patienceIncrease);
        }
        boolean stop = patience < iteration;
        if (stop) {
            log.info("Returning early on finetune");
        }
        return stop;
    }

    /** @return the factor used to extend patience on improvement */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public double patienceIncrease() {
        return patienceIncrease;
    }

    /** @return the relative threshold a new score must beat to count as an improvement */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public double improvementThreshold() {
        return improvementThreshold;
    }

    /** @return the current patience budget, in iterations */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public double patience() {
        return patience;
    }

    /** @return the best score observed so far (or the initial value if none was better) */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public double bestLoss() {
        return bestLoss;
    }

    /** @return the configured validation interval, as passed to the constructor */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public int validationEpochs() {
        return validationEpochs;
    }

    /** @return the configured mini-batch size */
    @Override // org.deeplearning4j.optimize.api.TrainingEvaluator
    public int miniBatchSize() {
        return miniBatchSize;
    }
}
