package ai.djl.training.listener;

import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.Iterator;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/training/listener/TimeMeasureTrainingListener.class */
/**
 * A training listener that measures per-batch timings for training and validation and, when
 * training ends, appends the collected timings to log files.
 *
 * <p>Timings are reported to the trainer under the metric names {@code "train"} and {@code
 * "validate"}. If an output directory was supplied, the collected metrics are appended to {@code
 * training.log} and {@code validate.log} inside that directory.
 */
public class TimeMeasureTrainingListener extends TrainingListenerAdapter {

    private static final Logger logger =
            LoggerFactory.getLogger(TimeMeasureTrainingListener.class);

    /** Sentinel meaning "no batch callback has been seen yet in the current epoch". */
    private static final long NOT_STARTED = -1L;

    private final String outputDir;
    private long trainBatchBeginTime = NOT_STARTED;
    private long validateBatchBeginTime = NOT_STARTED;

    /**
     * Constructs a {@code TimeMeasureTrainingListener}.
     *
     * @param outputDir the directory the timing logs are written to, or {@code null} to skip
     *     writing them
     */
    public TimeMeasureTrainingListener(String outputDir) {
        this.outputDir = outputDir;
    }

    /**
     * Resets both batch clocks so that timings never span an epoch boundary.
     *
     * @param trainer the trainer (unused)
     */
    @Override
    public void onEpoch(Trainer trainer) {
        trainBatchBeginTime = NOT_STARTED;
        validateBatchBeginTime = NOT_STARTED;
    }

    /**
     * Reports the time since the previous training-batch callback as the {@code "train"} metric.
     *
     * <p>The first callback of each epoch only starts the clock and records nothing, because
     * there is no earlier timestamp to measure from.
     *
     * @param trainer the trainer receiving the metric
     * @param batchData the batch data (unused)
     */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        if (trainBatchBeginTime != NOT_STARTED) {
            // Passes the previous timestamp; presumably Trainer.addMetric records the elapsed
            // time relative to it -- confirm against Trainer.
            trainer.addMetric("train", trainBatchBeginTime);
        }
        trainBatchBeginTime = System.nanoTime();
    }

    /**
     * Reports the time since the previous validation-batch callback as the {@code "validate"}
     * metric.
     *
     * <p>The first callback of each epoch only starts the clock and records nothing.
     *
     * @param trainer the trainer receiving the metric
     * @param batchData the batch data (unused)
     */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        if (validateBatchBeginTime != NOT_STARTED) {
            trainer.addMetric("validate", validateBatchBeginTime);
        }
        validateBatchBeginTime = System.nanoTime();
    }

    /**
     * Writes the collected timing metrics to the output directory, if one was configured.
     *
     * @param trainer the trainer whose metrics are dumped
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        dumpTrainingTimeInfo(trainer.getMetrics(), outputDir);
    }

    /**
     * Appends the {@code "train"} and {@code "validate"} metrics to log files under {@code
     * logDir}, creating the directory if needed. I/O failures are logged rather than thrown,
     * because the dump is best-effort.
     *
     * @param metrics the trainer's metrics; may be {@code null} if the trainer has none attached
     * @param logDir the output directory; {@code null} disables the dump
     */
    private static void dumpTrainingTimeInfo(Metrics metrics, String logDir) {
        // Nothing to write when no output directory was configured or there are no metrics.
        if (metrics == null || logDir == null) {
            return;
        }
        try {
            Path dir = Paths.get(logDir);
            Files.createDirectories(dir);
            dumpMetricToFile(dir.resolve("training.log"), metrics.getMetric("train"));
            dumpMetricToFile(dir.resolve("validate.log"), metrics.getMetric("validate"));
        } catch (IOException e) {
            logger.error("Failed dump training log", e);
        }
    }

    /**
     * Appends one line per metric (its {@code toString()} form) to {@code path}. The file is
     * created if it does not exist.
     *
     * @param path the log file to append to
     * @param metrics the metrics to write; {@code null} or empty writes nothing
     * @throws IOException if the file cannot be opened, written or closed
     */
    private static void dumpMetricToFile(Path path, List<Metric> metrics) throws IOException {
        if (metrics == null || metrics.isEmpty()) {
            return;
        }
        // try-with-resources guarantees the writer is closed exactly once, with any close
        // failure attached as a suppressed exception when the body also failed.
        try (BufferedWriter writer =
                Files.newBufferedWriter(
                        path, StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
            for (Metric metric : metrics) {
                writer.append(metric.toString());
                writer.newLine();
            }
        }
    }
}
