package moa.tasks;

import com.github.javacliparser.FileOption;
import com.github.javacliparser.FlagOption;
import com.github.javacliparser.IntOption;
import com.github.javacliparser.gui.FloatOptionEditComponent;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Instances;
import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import moa.classifiers.MultiClassClassifier;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.StringUtils;
import moa.core.TimingUtils;
import moa.evaluation.LearningCurve;
import moa.evaluation.LearningEvaluation;
import moa.evaluation.LearningPerformanceEvaluator;
import moa.learners.Learner;
import moa.options.ClassOption;
import moa.streams.CachedInstancesStream;
import moa.streams.ExampleStream;

/* loaded from: input_file:moa/tasks/EvaluatePeriodicHeldOutTest.class */
/**
 * Evaluates a classifier with a periodic held-out test.
 *
 * <p>The learner is trained on the stream in batches of {@code sampleFrequency}
 * examples. After every batch the evaluator is reset and the learner is tested on
 * {@code testSize} examples. Those test examples are obtained in one of two ways:
 * <ul>
 *   <li>cached in memory from the head of the stream before training starts
 *       ({@code -c});</li>
 *   <li>read from the stream itself right after each batch, in which case they are
 *       consumed and never used for training.</li>
 * </ul>
 * Each completed test adds one entry to the returned {@link LearningCurve}. When
 * {@code dumpFile} is set, the entry is also appended as a CSV line to that file.
 */
public class EvaluatePeriodicHeldOutTest extends ClassificationMainTask {
    private static final long serialVersionUID = 1;

    /** Classifier to train and evaluate. */
    public ClassOption learnerOption = new ClassOption("learner", 'l', "Classifier to train.", MultiClassClassifier.class, "moa.classifiers.trees.HoeffdingTree");

    /** Source of training examples (and of test examples when caching is off). */
    public ClassOption streamOption = new ClassOption("stream", 's', "Stream to learn from.", ExampleStream.class, "generators.RandomTreeGenerator");

    /** Evaluator that is reset before, and filled during, every test phase. */
    public ClassOption evaluatorOption = new ClassOption("evaluator", 'e', "Learning performance evaluation method.", LearningPerformanceEvaluator.class, "BasicClassificationPerformanceEvaluator");

    /** Number of examples in each held-out test. */
    public IntOption testSizeOption = new IntOption("testSize", 'n', "Number of testing examples.", 1000000, 0, Integer.MAX_VALUE);

    /** Maximum number of training examples; values below 1 mean unlimited. */
    public IntOption trainSizeOption = new IntOption("trainSize", 'i', "Number of training examples, <1 = unlimited.", 0, 0, Integer.MAX_VALUE);

    /** Training CPU-time budget in seconds, checked after each training batch. */
    public IntOption trainTimeOption = new IntOption("trainTime", 't', "Number of training seconds.", 36000, 0, Integer.MAX_VALUE);

    /**
     * Number of training examples between evaluations.
     * NOTE(review): the default is borrowed from a GUI slider constant, which looks like a
     * decompiler artifact; consider replacing it with an explicit literal.
     */
    public IntOption sampleFrequencyOption = new IntOption("sampleFrequency", 'f', "Number of training examples between samples of learning performance.", FloatOptionEditComponent.SLIDER_RESOLUTION, 0, Integer.MAX_VALUE);

    /** Optional CSV file that every evaluation entry is appended to. */
    public FileOption dumpFileOption = new FileOption("dumpFile", 'd', "File to append intermediate csv results to.", null, "csv", true);

    /** When set, the test set is read once up front and replayed for every evaluation. */
    public FlagOption cacheTestOption = new FlagOption("cacheTest", 'c', "Cache test instances in memory.");

    @Override
    public String getPurposeString() {
        return "Evaluates a classifier on a stream by periodically testing on a heldout set.";
    }

    /**
     * Runs the train/test cycle. The cycle stops when any of the following happens:
     * <ul>
     *   <li>the training-size limit is reached;</li>
     *   <li>the stream runs out;</li>
     *   <li>the training CPU-time budget is exceeded;</li>
     *   <li>a complete test set can no longer be drawn.</li>
     * </ul>
     *
     * @return the learning curve, or {@code null} if the task was aborted
     * @throws RuntimeException if the dump file cannot be opened
     */
    @Override
    protected Object doMainTask(TaskMonitor taskMonitor, ObjectRepository objectRepository) {
        Learner learner = (Learner) getPreparedClassOption(this.learnerOption);
        ExampleStream trainStream = (ExampleStream) getPreparedClassOption(this.streamOption);
        LearningPerformanceEvaluator evaluator = (LearningPerformanceEvaluator) getPreparedClassOption(this.evaluatorOption);
        learner.setModelContext(trainStream.getHeader());
        LearningCurve learningCurve = new LearningCurve("evaluation instances");
        // try-with-resources skips a null resource (no dump file). It also closes the file
        // on every exit path, including the early returns taken when the task is aborted.
        try (PrintStream dumpStream = openDumpStream(this.dumpFileOption.getFile())) {
            return runEvaluation(taskMonitor, learner, trainStream, evaluator, learningCurve, dumpStream);
        }
    }

    /** Performs the optional caching phase, then alternating training/testing; null on abort. */
    private LearningCurve runEvaluation(TaskMonitor taskMonitor, Learner learner, ExampleStream trainStream,
            LearningPerformanceEvaluator evaluator, LearningCurve learningCurve, PrintStream dumpStream) {
        int testSize = this.testSizeOption.getValue();
        int trainSize = this.trainSizeOption.getValue();
        // A frequency of 0 would never train; with a cached test set that loops forever.
        int sampleFrequency = Math.max(1, this.sampleFrequencyOption.getValue());

        ExampleStream testStream = trainStream;
        if (this.cacheTestOption.isSet()) {
            testStream = cacheTestExamples(taskMonitor, trainStream, testSize);
            if (testStream == null) {
                return null;
            }
        }

        boolean headerWritten = false;
        long instancesTrained = 0;
        double totalTrainTime = 0.0d;
        TimingUtils.enablePreciseTiming();
        while ((trainSize < 1 || instancesTrained < trainSize) && trainStream.hasMoreInstances()) {
            // ---- training phase ----
            taskMonitor.setCurrentActivityDescription("Training...");
            long batchEnd = instancesTrained + sampleFrequency;
            if (trainSize >= 1) {
                // Never train on more examples than requested.
                batchEnd = Math.min(batchEnd, trainSize);
            }
            long batchStart = instancesTrained;
            long trainStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
            while (instancesTrained < batchEnd && trainStream.hasMoreInstances()) {
                learner.trainOnInstance(trainStream.nextInstance2());
                instancesTrained++;
                if (instancesTrained % 10 == 0) {
                    if (taskMonitor.taskShouldAbort()) {
                        return null;
                    }
                    taskMonitor.setCurrentActivityFractionComplete(fractionOf(instancesTrained, trainSize));
                }
            }
            double lastTrainTime = TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - trainStartTime);
            long lastBatchSize = instancesTrained - batchStart;
            totalTrainTime += lastTrainTime;
            if (totalTrainTime > this.trainTimeOption.getValue()) {
                break;
            }

            // ---- testing phase ----
            if (this.cacheTestOption.isSet()) {
                testStream.restart();
            }
            evaluator.reset();
            taskMonitor.setCurrentActivityDescription(describeTesting(instancesTrained, trainSize));
            long testStartTime = TimingUtils.getNanoCPUTimeOfCurrentThread();
            int tested = 0;
            // Check the stream that is actually being tested. With caching, the training
            // stream may be exhausted while the cached test set is still fully available.
            while (tested < testSize && testStream.hasMoreInstances()) {
                Example testInst = testStream.nextInstance2();
                evaluator.addResult(testInst, learner.getVotesForInstance(testInst));
                tested++;
                if (tested % 10 == 0) {
                    if (taskMonitor.taskShouldAbort()) {
                        return null;
                    }
                    taskMonitor.setCurrentActivityFractionComplete(fractionOf(tested, testSize));
                }
            }
            if (tested != testSize) {
                // No complete test set could be drawn; stop without recording a partial entry.
                break;
            }
            double testTime = TimingUtils.nanoTimeToSeconds(TimingUtils.getNanoCPUTimeOfCurrentThread() - testStartTime);

            // ---- record the evaluation ----
            ArrayList<Measurement> measurements = new ArrayList<Measurement>();
            measurements.add(new Measurement("evaluation instances", instancesTrained));
            measurements.add(new Measurement("total train time", totalTrainTime));
            measurements.add(new Measurement("total train speed", instancesTrained / totalTrainTime));
            measurements.add(new Measurement("last train time", lastTrainTime));
            // Use the number of examples actually trained; the last batch may be short.
            measurements.add(new Measurement("last train speed", lastBatchSize / lastTrainTime));
            measurements.add(new Measurement("test time", testTime));
            measurements.add(new Measurement("test speed", testSize / testTime));
            for (Measurement measurement : evaluator.getPerformanceMeasurements()) {
                measurements.add(measurement);
            }
            for (Measurement measurement : learner.getModelMeasurements()) {
                measurements.add(measurement);
            }
            learningCurve.insertEntry(new LearningEvaluation(measurements.toArray(new Measurement[measurements.size()])));

            if (dumpStream != null) {
                if (!headerWritten) {
                    dumpStream.println(learningCurve.headerToString());
                    headerWritten = true;
                }
                dumpStream.println(learningCurve.entryToString(learningCurve.numEntries() - 1));
                dumpStream.flush();
            }
            if (taskMonitor.resultPreviewRequested()) {
                taskMonitor.setLatestResultPreview(learningCurve.copy());
            }
        }
        return learningCurve;
    }

    /** Opens {@code file} for appending with auto-flush, or returns null when no file is given. */
    private static PrintStream openDumpStream(File file) {
        if (file == null) {
            return null;
        }
        try {
            // Append mode also creates a missing file, so no exists() check is needed.
            return new PrintStream(new FileOutputStream(file, true), true);
        } catch (Exception e) {
            throw new RuntimeException("Unable to open immediate result file: " + file, e);
        }
    }

    /**
     * Reads up to {@code testSize} examples from {@code source} into memory.
     *
     * @return a replayable stream over the cached examples, or {@code null} if aborted
     */
    private static ExampleStream cacheTestExamples(TaskMonitor taskMonitor, ExampleStream source, int testSize) {
        taskMonitor.setCurrentActivity("Caching test examples...", -1.0d);
        Instances cache = new Instances(source.getHeader(), testSize);
        // Stop early if the stream ends before a full test set has been collected.
        while (cache.numInstances() < testSize && source.hasMoreInstances()) {
            cache.add((Instance) source.nextInstance2().getData());
            if (cache.numInstances() % 10 == 0) {
                if (taskMonitor.taskShouldAbort()) {
                    return null;
                }
                taskMonitor.setCurrentActivityFractionComplete(fractionOf(cache.numInstances(), testSize));
            }
        }
        return new CachedInstancesStream(cache);
    }

    /**
     * Returns {@code done / total} in floating point, or -1 when {@code total} is not
     * positive (e.g. unlimited training). -1 is the value this class also passes when
     * caching starts.
     */
    private static double fractionOf(long done, long total) {
        return total > 0 ? (double) done / (double) total : -1.0d;
    }

    /** Builds the monitor description shown while testing. */
    private static String describeTesting(long instancesTrained, int trainSize) {
        if (trainSize < 1) {
            // No training limit, so a percentage is meaningless (and used to divide by zero).
            return "Testing (after " + instancesTrained + " training examples)...";
        }
        return "Testing (after " + StringUtils.doubleToString(100.0d * instancesTrained / trainSize, 2) + "% training)...";
    }

    @Override
    public Class<?> getTaskResultType() {
        return LearningCurve.class;
    }
}
