package weka.classifiers.meta.multisearch;

import java.io.File;
import java.io.Serializable;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Future;
import weka.classifiers.meta.multisearch.AbstractSearch;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.converters.ConverterUtils;
import weka.core.setupgenerator.Point;
import weka.core.setupgenerator.Space;
import weka.filters.Filter;
import weka.filters.unsupervised.instance.Resample;

/* loaded from: input_file:weka/classifiers/meta/multisearch/RandomSearch.class */
/**
 * Search strategy that evaluates a random subset of the points of the search space.
 *
 * <p>The points of the space are shuffled with {@link #getRandomSeed()} and at most
 * {@link #getNumIterations()} of them are evaluated, either with cross-validation, on the
 * training data, or on an optional separate test set. The best-performing point is returned.
 */
public class RandomSearch extends AbstractMultiThreadedSearch {
    private static final long serialVersionUID = 2542453917013899104L;

    /** The test set loaded from {@link #m_SearchSpaceTestSet}; {@code null} if none was loaded. */
    protected Instances m_SearchSpaceTestInst;

    /** The size (in percent) of the sample of the training data used for the search. */
    protected double m_SampleSize = 100.0d;

    /** The number of cross-validation folds; values below 2 mean evaluation on the training set. */
    protected int m_SearchSpaceNumFolds = 2;

    /** The optional test set file; ignored if it does not exist or is a directory. */
    protected File m_SearchSpaceTestSet = new File(".");

    /** The maximum number of points of the search space that get evaluated. */
    protected int m_NumIterations = 100;

    /** The seed for the random number generator used to shuffle the search space. */
    protected int m_RandomSeed = 1;

    /**
     * Returns a short description of this search strategy.
     *
     * @return the description
     */
    @Override
    public String globalInfo() {
        return "Performs a search of an arbitrary number of parameters of a classifier and chooses the best pair found for the actual filtering and training.\n";
    }

    /**
     * Lists the options of this search, followed by the options of the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> result = new Vector<>();
        result.addElement(new Option(
            "\tThe size (in percent) of the sample to search the initial space with.\n\t(default: 100)",
            "sample-size", 1, "-sample-size <num>"));
        result.addElement(new Option(
            "\tThe number of cross-validation folds for the search space.\n\tNumbers smaller than 2 turn off cross-validation and\n\tjust perform evaluation on the training set.\n\t(default: 2)",
            "num-folds", 1, "-num-folds <num>"));
        // A directory (e.g. the default ".") disables the test set, see loadTestData.
        result.addElement(new Option(
            "\tThe (optional) test set to use for the search space.\n\tGets ignored if pointing to a directory. Overrides cross-validation.\n\t(default: .)",
            "test-set", 1, "-test-set <filename>"));
        result.addElement(new Option(
            "\tThe number of parameter settings that are tried (i.e., number of points in the search space that are checked).\n\t(default: 100)",
            "num-iterations", 1, "-num-iterations <num>"));
        // The option name must match the flag parsed in setOptions ("S").
        result.addElement(new Option("\tThe random seed.\n\t(default: 1)", "S", 1, "-S <num>"));

        Enumeration<?> inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            result.addElement((Option) inherited.nextElement());
        }
        return result.elements();
    }

    /**
     * Returns the current setup of this search as command-line options, followed by the
     * options of the superclass.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        ArrayList<String> result = new ArrayList<>();
        result.add("-sample-size");
        result.add("" + getSampleSizePercent());
        result.add("-num-folds");
        result.add("" + getSearchSpaceNumFolds());
        result.add("-test-set");
        result.add("" + getSearchSpaceTestSet());
        result.add("-num-iterations");
        result.add("" + getNumIterations());
        result.add("-S");
        result.add("" + getRandomSeed());
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[0]);
    }

    /**
     * Parses the options of this search; missing options are reset to their defaults. The
     * remaining options are passed on to the superclass.
     *
     * @param strArr the options to parse (consumed options are removed by {@link Utils#getOption})
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        String value = Utils.getOption("sample-size", strArr);
        setSampleSizePercent(value.length() != 0 ? Double.parseDouble(value) : 100.0d);

        value = Utils.getOption("num-folds", strArr);
        setSearchSpaceNumFolds(value.length() != 0 ? Integer.parseInt(value) : 2);

        value = Utils.getOption("test-set", strArr);
        // Default matches the field initializer and the documented default ".".
        setSearchSpaceTestSet(new File(value.length() != 0 ? value : "."));

        value = Utils.getOption("num-iterations", strArr);
        setNumIterations(value.length() != 0 ? Integer.parseInt(value) : 100);

        value = Utils.getOption("S", strArr);
        setRandomSeed(value.length() != 0 ? Integer.parseInt(value) : 1);

        super.setOptions(strArr);
    }

    /**
     * Returns the tip text for the sample size property.
     *
     * @return the tip text
     */
    public String sampleSizePercentTipText() {
        return "The sample size (in percent) to use in the search.";
    }

    /**
     * Returns the sample size (in percent) used for the search.
     *
     * @return the sample size
     */
    public double getSampleSizePercent() {
        return this.m_SampleSize;
    }

    /**
     * Sets the sample size (in percent) used for the search; 100 uses the full data.
     *
     * @param d the sample size
     */
    public void setSampleSizePercent(double d) {
        this.m_SampleSize = d;
    }

    /**
     * Returns the tip text for the number of folds property.
     *
     * @return the tip text
     */
    public String searchSpaceNumFoldsTipText() {
        return "The number of cross-validation folds when evaluating the search space; values smaller than 2 turn cross-validation off and simple evaluation on the training set is performed.";
    }

    /**
     * Returns the number of cross-validation folds for the search space.
     *
     * @return the number of folds
     */
    public int getSearchSpaceNumFolds() {
        return this.m_SearchSpaceNumFolds;
    }

    /**
     * Sets the number of cross-validation folds; values below 2 mean evaluation on the training
     * set.
     *
     * @param i the number of folds
     */
    public void setSearchSpaceNumFolds(int i) {
        this.m_SearchSpaceNumFolds = i;
    }

    /**
     * Returns the tip text for the test set property.
     *
     * @return the tip text
     */
    public String searchSpaceTestSetTipText() {
        return "The (optional) test set to use for evaluating the search space; overrides cross-validation; gets ignored if pointing to a directory.";
    }

    /**
     * Returns the optional test set file.
     *
     * @return the test set file
     */
    public File getSearchSpaceTestSet() {
        return this.m_SearchSpaceTestSet;
    }

    /**
     * Sets the optional test set file; a directory or a non-existent file disables it.
     *
     * @param file the test set file
     */
    public void setSearchSpaceTestSet(File file) {
        this.m_SearchSpaceTestSet = file;
    }

    /**
     * Returns the tip text for the number of iterations property.
     *
     * @return the tip text
     */
    public String numIterationsTipText() {
        return "The number of parameter settings that are tried (i.e., the number of points in the search space that are checked).";
    }

    /**
     * Returns the maximum number of points of the search space that get evaluated.
     *
     * @return the number of iterations
     */
    public int getNumIterations() {
        return this.m_NumIterations;
    }

    /**
     * Sets the maximum number of points of the search space that get evaluated.
     *
     * @param i the number of iterations
     */
    public void setNumIterations(int i) {
        this.m_NumIterations = i;
    }

    /**
     * Returns the tip text for the random seed property.
     *
     * @return the tip text
     */
    public String randomSeedTipText() {
        return "The seed used for randomization";
    }

    /**
     * Returns the seed used for shuffling the search space.
     *
     * @return the seed
     */
    public int getRandomSeed() {
        return this.m_RandomSeed;
    }

    /**
     * Sets the seed used for shuffling the search space.
     *
     * @param i the seed
     */
    public void setRandomSeed(int i) {
        this.m_RandomSeed = i;
    }

    /**
     * Evaluates up to {@link #m_NumIterations} randomly chosen points of the given space and
     * returns the best one. Cached results are reused; all other points are evaluated by tasks
     * submitted to the executor pool.
     *
     * @param space the search space
     * @param train the data to evaluate on
     * @param test the optional test set, may be {@code null}
     * @param folds the number of cross-validation folds (below 2: evaluation on training set)
     * @param random the random number generator used for shuffling the space
     * @return the best performance found
     * @throws IllegalStateException if an evaluation task fails, the wait is interrupted, or no
     *     setup was evaluated (the original failure is attached as cause)
     * @throws Exception if submitting the tasks fails
     */
    protected Performance determineBestInSpace(Space space, Instances train, Instances test, int folds, Random random) throws Exception {
        this.m_Performances.clear();
        if (folds >= 2) {
            log("Determining best values with " + folds + "-fold CV in space:\n" + space + "\n");
        } else {
            log("Determining best values with evaluation on training set in space:\n" + space + "\n");
        }

        ArrayList<?> setups = Collections.list(space.values());
        this.m_NumSetups = Math.min(space.size(), this.m_NumIterations);
        Collections.shuffle(setups, random);

        int evaluation = this.m_Owner.getEvaluation().getSelectedTag().getID();
        int classLabelIndex = instances(train);
        ArrayList<Future<?>> futures = new ArrayList<>();
        for (int n = 0; n < this.m_NumSetups; n++) {
            @SuppressWarnings("unchecked") // the space only contains Point<Object> setups
            Point<Object> point = (Point<Object>) setups.get(n);
            if (this.m_Cache.isCached(folds, point)) {
                Performance cached = this.m_Cache.get(folds, point);
                this.m_Performances.add(cached);
                this.m_Trace.add(new AbstractMap.SimpleEntry<>(Integer.valueOf(folds), cached));
                log(cached + ": cached=true");
            } else {
                // NOTE(review): uncached results are expected to be added to m_Performances by
                // the submitted task itself -- confirm in the task implementation.
                futures.add(this.m_ExecutorPool.submit(this.m_Owner.getFactory().newTask(
                    this.m_Owner, train, test, this.m_Owner.getGenerator(), point, folds, evaluation, classLabelIndex)));
            }
        }

        try {
            for (Future<?> future : futures) {
                if (!Boolean.TRUE.equals(future.get())) {
                    throw new IllegalStateException("Execution of evaluation thread failed.");
                }
            }
            if (this.m_Performances.isEmpty()) {
                throw new IllegalStateException("No setups were evaluated (search space size: "
                    + space.size() + ", iterations: " + this.m_NumIterations + ")");
            }

            Collections.sort(this.m_Performances, new PerformanceComparator(evaluation, this.m_Owner.getMetrics()));
            Performance best = this.m_Performances.firstElement();

            this.m_UniformPerformance = true;
            for (int n = 1; n < this.m_Performances.size(); n++) {
                if (this.m_Performances.get(n).getPerformance(evaluation) != best.getPerformance(evaluation)) {
                    this.m_UniformPerformance = false;
                    break;
                }
            }
            if (this.m_UniformPerformance) {
                log("All performances are the same!");
            }

            logPerformances(space, this.m_Performances);
            log("\nBest performance:\n" + best);
            this.m_Performances.clear();
            return best;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt status so callers can still observe the interruption.
                Thread.currentThread().interrupt();
            }
            throw new IllegalStateException("Thread-based execution of evaluation tasks failed: " + e.getMessage(), e);
        }
    }

    /**
     * Determines the class label index passed to the evaluation tasks.
     *
     * @param data the data whose class attribute is inspected
     * @return the owner's class label index for nominal classes, -1 otherwise
     */
    private int instances(Instances data) {
        return data.classAttribute().isNominal()
            ? this.m_Owner.getClassLabelIndex(data.classAttribute().numValues())
            : -1;
    }

    /**
     * Optionally subsamples the data and searches the space for the best setup.
     *
     * @param instances the training data
     * @return the best performance found
     * @throws Exception if sampling or the search fails
     */
    protected Performance findBest(Instances instances) throws Exception {
        Random random = new Random(this.m_RandomSeed);
        log("Step 1:\n");

        Instances sample = instances;
        if (getSampleSizePercent() != 100.0d) {
            log("Generating sample (" + getSampleSizePercent() + "%)");
            Resample resample = new Resample();
            // NOTE(review): sampling uses the owner's seed, not m_RandomSeed -- confirm intent.
            resample.setRandomSeed(retrieveOwner().getSeed());
            resample.setSampleSizePercent(getSampleSizePercent());
            resample.setInputFormat(instances);
            sample = Filter.useFilter(instances, resample);
        }

        this.m_UniformPerformance = false;
        log("\n=== Search space - Start ===");
        Performance best = determineBestInSpace(this.m_Space, sample, this.m_SearchSpaceTestInst, this.m_SearchSpaceNumFolds, random);
        log("\nResult: " + best + "\n");
        log("=== Search space - End ===\n");
        log("Classifier: " + getCommandline(this.m_Owner.getGenerator().setup(
            (Serializable) this.m_Owner.getClassifier(), this.m_Owner.getGenerator().evaluate(best.getValues()))));
        return best;
    }

    /**
     * Loads the optional test set into {@link #m_SearchSpaceTestInst}. Nothing is loaded if the
     * configured file does not exist or is a directory.
     *
     * @param instances the training data, used for the class index and header compatibility
     * @throws IllegalArgumentException if the test set header differs from the training data
     * @throws Exception if reading the test set fails
     */
    protected void loadTestData(Instances instances) throws Exception {
        this.m_SearchSpaceTestInst = null;
        if (!this.m_SearchSpaceTestSet.exists() || this.m_SearchSpaceTestSet.isDirectory()) {
            return;
        }
        this.m_SearchSpaceTestInst = ConverterUtils.DataSource.read(this.m_SearchSpaceTestSet.getAbsolutePath());
        this.m_SearchSpaceTestInst.setClassIndex(instances.classIndex());
        String headerMismatch = instances.equalHeadersMsg(this.m_SearchSpaceTestInst);
        if (headerMismatch != null) {
            throw new IllegalArgumentException("Test set for search space not compatible with training data:\n" + headerMismatch);
        }
        this.m_SearchSpaceTestInst.deleteWithMissingClass();
        log("Using test set for search space: " + this.m_SearchSpaceTestSet);
    }

    /**
     * Performs the random search and sets up the owner's classifier with the best values.
     *
     * @param instances the training data (a copy is used for the search)
     * @return the search result: configured classifier, its performance and parameter values
     * @throws Exception if loading the test set or the search fails
     */
    @Override
    public AbstractSearch.SearchResult doSearch(Instances instances) throws Exception {
        loadTestData(instances);
        Performance best = findBest(new Instances(instances));
        Point<Object> values = this.m_Owner.getGenerator().evaluate(best.getValues());

        AbstractSearch.SearchResult result = new AbstractSearch.SearchResult();
        result.classifier = this.m_Owner.getGenerator().setup((Serializable) this.m_Owner.getClassifier(), values);
        result.performance = best;
        result.values = values;
        return result;
    }
}
