package org.deeplearning4j.optimize.listeners;

import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.Map;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/listeners/ParamAndGradientIterationListener.class */
/**
 * {@link IterationListener} that reports, every {@code iterations} calls, the model score plus
 * summary statistics (mean, min/max, mean absolute value) of every parameter array and of the
 * matching gradient array.
 *
 * <p>Rows are delimiter-separated and start with the listener's own call counter and the score.
 * They can be sent to the console, to an SLF4J logger and/or to a file. An optional header row
 * is emitted on the first call. When writing to a file, the header write truncates the file and
 * data rows are appended. With {@code printHeader=false} the file is never truncated, so rows are
 * appended to any existing content.
 *
 * <p>Statistics that cannot be computed (no gradient available for a parameter) are written as
 * {@code NaN}, so every row keeps the same column layout as the header.
 */
public class ParamAndGradientIterationListener implements IterationListener {
    private static final int MAX_WRITE_FAILURE_MESSAGES = 10;
    private static final Logger logger = LoggerFactory.getLogger(ParamAndGradientIterationListener.class);
    private static final String DEFAULT_DELIMITER = "\t";

    // Unused; retained so the instance field layout stays unchanged.
    private boolean invoked;
    // Emit a data row every this-many calls; always >= 1 (see constructor).
    private int iterations;
    // Number of iterationDone calls seen so far; also the first column of each row.
    private long totalIterationCount;
    private boolean printMean;
    private boolean printHeader;
    private boolean printMinMax;
    private boolean printMeanAbsValue;
    private File file;
    // Derived from file; null when no file was configured.
    private Path filePath;
    private boolean outputToConsole;
    private boolean outputToFile;
    private boolean outputToLogger;
    private String delimiter;
    // Number of file-write failures reported so far; stops increasing at MAX_WRITE_FAILURE_MESSAGES.
    private int writeFailureCount;

    /**
     * Builder for {@link ParamAndGradientIterationListener}.
     *
     * <p>Unset fields take Java defaults ({@code 0}, {@code false}, {@code null}). The listener
     * constructor maps {@code iterations <= 0} to {@code 1} and a {@code null} delimiter to a tab.
     */
    public static class ParamAndGradientIterationListenerBuilder {
        private int iterations;
        private boolean printHeader;
        private boolean printMean;
        private boolean printMinMax;
        private boolean printMeanAbsValue;
        private boolean outputToConsole;
        private boolean outputToFile;
        private boolean outputToLogger;
        private File file;
        private String delimiter;

        ParamAndGradientIterationListenerBuilder() {
        }

        public ParamAndGradientIterationListenerBuilder iterations(int i) {
            this.iterations = i;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder printHeader(boolean z) {
            this.printHeader = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder printMean(boolean z) {
            this.printMean = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder printMinMax(boolean z) {
            this.printMinMax = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder printMeanAbsValue(boolean z) {
            this.printMeanAbsValue = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder outputToConsole(boolean z) {
            this.outputToConsole = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder outputToFile(boolean z) {
            this.outputToFile = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder outputToLogger(boolean z) {
            this.outputToLogger = z;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder file(File file) {
            this.file = file;
            return this;
        }

        public ParamAndGradientIterationListenerBuilder delimiter(String str) {
            this.delimiter = str;
            return this;
        }

        /** Creates the listener from the values collected so far. */
        public ParamAndGradientIterationListener build() {
            return new ParamAndGradientIterationListener(this.iterations, this.printHeader, this.printMean, this.printMinMax, this.printMeanAbsValue, this.outputToConsole, this.outputToFile, this.outputToLogger, this.file, this.delimiter);
        }

        @Override
        public String toString() {
            return "ParamAndGradientIterationListener.ParamAndGradientIterationListenerBuilder(iterations=" + this.iterations + ", printHeader=" + this.printHeader + ", printMean=" + this.printMean + ", printMinMax=" + this.printMinMax + ", printMeanAbsValue=" + this.printMeanAbsValue + ", outputToConsole=" + this.outputToConsole + ", outputToFile=" + this.outputToFile + ", outputToLogger=" + this.outputToLogger + ", file=" + this.file + ", delimiter=" + this.delimiter + ")";
        }
    }

    /**
     * Creates a listener with the following defaults:
     * <ul>
     *   <li>reports every call;</li>
     *   <li>prints the header and all statistics;</li>
     *   <li>outputs to the console only;</li>
     *   <li>uses a tab delimiter.</li>
     * </ul>
     */
    public ParamAndGradientIterationListener() {
        this(1, true, true, true, true, true, false, false, null, DEFAULT_DELIMITER);
    }

    /**
     * @param iterations        emit a data row every this-many calls; values {@code <= 0} are
     *                          treated as {@code 1} (previously they caused a division by zero)
     * @param printHeader       emit a header row on the first call
     * @param printMean         include mean columns
     * @param printMinMax       include min and max columns
     * @param printMeanAbsValue include mean-absolute-value columns
     * @param outputToConsole   print rows to {@code System.out}
     * @param outputToFile      write rows to {@code file}
     * @param outputToLogger    log rows at INFO level
     * @param file              output file, may be {@code null} when {@code outputToFile} is false
     * @param delimiter         column separator; {@code null} falls back to a tab
     */
    public ParamAndGradientIterationListener(int iterations, boolean printHeader, boolean printMean,
                    boolean printMinMax, boolean printMeanAbsValue, boolean outputToConsole,
                    boolean outputToFile, boolean outputToLogger, File file, String delimiter) {
        this.invoked = false;
        this.totalIterationCount = 0L;
        this.writeFailureCount = 0;
        this.iterations = Math.max(1, iterations);
        this.printHeader = printHeader;
        this.printMean = printMean;
        this.printMinMax = printMinMax;
        this.printMeanAbsValue = printMeanAbsValue;
        this.file = file;
        this.filePath = file != null ? file.toPath() : null;
        this.outputToConsole = outputToConsole;
        this.outputToFile = outputToFile;
        this.outputToLogger = outputToLogger;
        this.delimiter = delimiter != null ? delimiter : DEFAULT_DELIMITER;
    }

    /**
     * Counts the call. The first call emits the header (if enabled). Every {@code iterations}-th
     * call emits a statistics row.
     *
     * @param model the model being trained
     * @param i     unused; this listener keeps its own call counter
     * @param i2    unused
     */
    @Override // org.deeplearning4j.optimize.api.IterationListener
    public void iterationDone(Model model, int i, int i2) {
        totalIterationCount++;
        if (totalIterationCount == 1 && printHeader) {
            // The header starts a fresh file; data rows are appended after it.
            emit(buildHeader(model), StandardOpenOption.TRUNCATE_EXISTING);
        }
        if (totalIterationCount % iterations != 0) {
            return;
        }
        emit(buildRow(model), StandardOpenOption.APPEND);
    }

    public static ParamAndGradientIterationListenerBuilder builder() {
        return new ParamAndGradientIterationListenerBuilder();
    }

    /** Builds the newline-terminated header row: "n", the score key, then per-parameter columns. */
    private String buildHeader(Model model) {
        StringBuilder header = new StringBuilder();
        header.append("n").append(delimiter).append(BaseOptimizer.SCORE_KEY);
        for (String name : model.paramTable().keySet()) {
            appendHeaderColumns(header, name, "");
            appendHeaderColumns(header, name, "G");
        }
        return header.append('\n').toString();
    }

    /** Appends the enabled statistic column names for one array; {@code suffix} marks gradients. */
    private void appendHeaderColumns(StringBuilder header, String name, String suffix) {
        if (printMean) {
            header.append(delimiter).append(name).append("_mean").append(suffix);
        }
        if (printMinMax) {
            header.append(delimiter).append(name).append("_min").append(suffix)
                  .append(delimiter).append(name).append("_max").append(suffix);
        }
        if (printMeanAbsValue) {
            header.append(delimiter).append(name).append("_meanAbsValue").append(suffix);
        }
    }

    /** Builds one newline-terminated data row: counter, score, then parameter/gradient stats. */
    private String buildRow(Model model) {
        Map<String, INDArray> params = model.paramTable();
        // The gradient may be unavailable, or may lack an entry for some parameter. Those columns
        // become NaN instead of throwing from inside the training loop.
        Map<String, INDArray> gradients = model.gradient() == null ? null : model.gradient().gradientForVariable();
        StringBuilder row = new StringBuilder();
        row.append(totalIterationCount).append(delimiter).append(model.score());
        for (Map.Entry<String, INDArray> entry : params.entrySet()) {
            appendStats(row, entry.getValue());
            appendStats(row, gradients == null ? null : gradients.get(entry.getKey()));
        }
        return row.append('\n').toString();
    }

    /** Appends the enabled statistics of {@code array}, or NaN placeholders when it is null. */
    private void appendStats(StringBuilder row, INDArray array) {
        boolean missing = array == null;
        if (printMean) {
            row.append(delimiter).append(missing ? Double.NaN : array.meanNumber().doubleValue());
        }
        if (printMinMax) {
            row.append(delimiter).append(missing ? Double.NaN : array.minNumber().doubleValue());
            row.append(delimiter).append(missing ? Double.NaN : array.maxNumber().doubleValue());
        }
        if (printMeanAbsValue) {
            // dup() so that abs() does not modify the model's own array.
            row.append(delimiter).append(missing ? Double.NaN : Transforms.abs(array.dup()).meanNumber().doubleValue());
        }
    }

    /** Sends {@code text} (already newline-terminated) to every enabled output. */
    private void emit(String text, StandardOpenOption fileMode) {
        if (outputToLogger) {
            logger.info(text);
        }
        if (outputToConsole) {
            // print, not println: text already ends with a newline.
            System.out.print(text);
        }
        if (outputToFile) {
            writeToFile(text, fileMode);
        }
    }

    /** Writes {@code text} as UTF-8, creating the file if needed; failures are logged, not thrown. */
    private void writeToFile(String text, StandardOpenOption mode) {
        if (filePath == null) {
            reportWriteFailure(new IOException("outputToFile is enabled but no file was configured"));
            return;
        }
        try {
            Files.write(filePath, text.getBytes(StandardCharsets.UTF_8), StandardOpenOption.CREATE, mode);
        } catch (IOException e) {
            reportWriteFailure(e);
        }
    }

    /** Logs a write failure, up to MAX_WRITE_FAILURE_MESSAGES times, then goes silent. */
    private void reportWriteFailure(IOException e) {
        if (writeFailureCount >= MAX_WRITE_FAILURE_MESSAGES) {
            return;
        }
        logger.warn("Error writing to file: {}", filePath, e);
        writeFailureCount++;
        if (writeFailureCount == MAX_WRITE_FAILURE_MESSAGES) {
            logger.warn("Max file write messages displayed. No more failure messages will be printed");
        }
    }
}
