/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableParallelIteratedSingleClassifierEnhancer;
import weka.classifiers.trees.REPTree;
import weka.core.AdditionalMeasureProducer;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.PartitionGenerator;
import weka.core.Randomizable;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * Bagging meta-classifier (Breiman, 1996). Each base classifier in {@code m_Classifiers} is
 * trained on a bootstrap resample of the training data. Predictions are averaged: class
 * distributions for nominal classes, predicted values for numeric classes. Optionally, the
 * out-of-bag error is estimated during training.
 */
public class Bagging
extends RandomizableParallelIteratedSingleClassifierEnhancer
implements WeightedInstancesHandler,
AdditionalMeasureProducer,
TechnicalInformationHandler,
PartitionGenerator {
    static final long serialVersionUID = -115879962237199703L;
    /** Size of each bag, as a percentage of the training set size. */
    protected int m_BagSizePercent = 100;
    /** Whether the out-of-bag error is computed while building. */
    protected boolean m_CalcOutOfBag = false;
    /** Out-of-bag error of the last build; 0 when it was not computed. */
    protected double m_OutOfBagError;
    /** Source of the seeds handed to randomizable base classifiers. */
    protected Random m_random;
    /** {@code m_inBag[b][i]} is true when training instance i was drawn into bag b (OOB mode only). */
    protected boolean[][] m_inBag;
    /** Working copy of the training data; cleared at the end of a successful build. */
    protected Instances m_data;

    /** Creates a bagger that uses {@code REPTree} as its base classifier. */
    public Bagging() {
        this.m_Classifier = new REPTree();
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the description, including the bibliographic reference
     */
    public String globalInfo() {
        return "Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. \n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for bagging.
     *
     * @return Breiman's 1996 "Bagging predictors" article
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Leo Breiman");
        result.setValue(TechnicalInformation.Field.YEAR, "1996");
        result.setValue(TechnicalInformation.Field.TITLE, "Bagging predictors");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        result.setValue(TechnicalInformation.Field.VOLUME, "24");
        result.setValue(TechnicalInformation.Field.NUMBER, "2");
        result.setValue(TechnicalInformation.Field.PAGES, "123-140");
        return result;
    }

    /** @return the class name of the default base classifier */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.REPTree";
    }

    /**
     * Lists the options of this classifier: {@code -P} and {@code -O}, followed by the
     * options of the superclass.
     *
     * @return an enumeration of {@link Option}s
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>(2);
        newVector.addElement(new Option("\tSize of each bag, as a percentage of the\n\ttraining set size. (default 100)", "P", 1, "-P"));
        newVector.addElement(new Option("\tCalculate the out of bag error.", "O", 0, "-O"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option) enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses {@code -P <percent>}, which defaults to 100 when absent, and the {@code -O}
     * flag. The remaining options are passed to the superclass.
     *
     * @param options the option array; consumed options are removed by {@link Utils}
     * @throws Exception if an option value is invalid
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String bagSize = Utils.getOption('P', options);
        if (bagSize.length() != 0) {
            this.setBagSizePercent(Integer.parseInt(bagSize));
        } else {
            this.setBagSizePercent(100);
        }
        this.setCalcOutOfBag(Utils.getFlag('O', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings. Unused trailing slots are padded with empty strings.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        String[] options = new String[superOptions.length + 3];
        int current = 0;
        options[current++] = "-P";
        options[current++] = "" + this.getBagSizePercent();
        if (this.getCalcOutOfBag()) {
            options[current++] = "-O";
        }
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @return tip text for the bag size property */
    public String bagSizePercentTipText() {
        return "Size of each bag, as a percentage of the training set size.";
    }

    /** @return the bag size as a percentage of the training set size */
    public int getBagSizePercent() {
        return this.m_BagSizePercent;
    }

    /**
     * Sets the bag size as a percentage of the training set size. Values above 100 currently
     * yield bags of the full training set size (see {@link #getTrainingSet(int)}).
     *
     * @param newBagSizePercent the bag size percentage
     */
    public void setBagSizePercent(int newBagSizePercent) {
        this.m_BagSizePercent = newBagSizePercent;
    }

    /** @return tip text for the out-of-bag property */
    public String calcOutOfBagTipText() {
        return "Whether the out-of-bag error is calculated.";
    }

    /** @param calcOutOfBag whether to compute the out-of-bag error during training */
    public void setCalcOutOfBag(boolean calcOutOfBag) {
        this.m_CalcOutOfBag = calcOutOfBag;
    }

    /** @return whether the out-of-bag error is computed during training */
    public boolean getCalcOutOfBag() {
        return this.m_CalcOutOfBag;
    }

    /**
     * @return the out-of-bag error of the last build. This is 0 if it was not computed, and
     *         NaN if no training instance had an out-of-bag prediction.
     */
    public double measureOutOfBagError() {
        return this.m_OutOfBagError;
    }

    /** @return the names of the additional measures (only {@code measureOutOfBagError}) */
    @Override
    public Enumeration enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureOutOfBagError");
        return newVector.elements();
    }

    /**
     * Returns the value of a named additional measure.
     *
     * @param additionalMeasureName the measure name (case-insensitive)
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.equalsIgnoreCase("measureOutOfBagError")) {
            return this.measureOutOfBagError();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (Bagging)");
    }

    /**
     * Draws the bootstrap sample used to train base classifier {@code iteration}.
     * The sample is seeded with {@code m_Seed + iteration}.
     *
     * @param iteration index of the base classifier
     * @return the bag for that classifier
     * @throws Exception if resampling fails
     */
    @Override
    protected synchronized Instances getTrainingSet(int iteration) throws Exception {
        int numInstances = this.m_data.numInstances();
        int bagSize = numInstances * this.m_BagSizePercent / 100;
        // The seed depends only on the iteration, so each bag is reproducible in any build order.
        Random r = new Random(this.m_Seed + iteration);
        if (this.m_CalcOutOfBag) {
            // Record which instances were drawn, so the OOB estimate can exclude them.
            this.m_inBag[iteration] = new boolean[numInstances];
            return this.m_data.resampleWithWeights(r, this.m_inBag[iteration]);
        }
        Instances bagData = this.m_data.resampleWithWeights(r);
        if (bagSize < numInstances) {
            // Shrink the full-size resample to the requested bag size by shuffling and truncating.
            bagData.randomize(r);
            bagData = new Instances(bagData, 0, bagSize);
        }
        return bagData;
    }

    /**
     * Builds the ensemble and, if enabled, computes the out-of-bag error.
     *
     * @param data the training data; instances with a missing class are ignored
     * @throws IllegalArgumentException if the OOB error is requested with a bag size other than 100%
     * @throws Exception if the data fails the capabilities test or a base classifier fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        // Reject an invalid configuration before copying data or rebuilding the ensemble.
        if (this.m_CalcOutOfBag && this.m_BagSizePercent != 100) {
            throw new IllegalArgumentException("Bag size needs to be 100% if out-of-bag error is to be calculated!");
        }
        this.m_data = new Instances(data);
        this.m_data.deleteWithMissingClass();
        super.buildClassifier(this.m_data);
        this.m_random = new Random(this.m_Seed);
        this.m_inBag = this.m_CalcOutOfBag ? new boolean[this.m_Classifiers.length][] : null;
        if (this.m_Classifier instanceof Randomizable) {
            for (int j = 0; j < this.m_Classifiers.length; ++j) {
                ((Randomizable) this.m_Classifiers[j]).setSeed(this.m_random.nextInt());
            }
        }
        this.buildClassifiers();
        this.m_OutOfBagError = this.getCalcOutOfBag() ? this.computeOutOfBagError() : 0.0;
        this.m_data = null;
    }

    /**
     * Computes the weighted out-of-bag error over {@code m_data}. Each instance is predicted
     * only by the classifiers whose bag did not contain it. For a numeric class the error is
     * the weighted mean absolute error; otherwise it is the weighted misclassification rate.
     * Instances that appear in every bag have no out-of-bag prediction and are skipped.
     *
     * @return the error, or NaN if no instance had an out-of-bag prediction
     * @throws Exception if a base classifier fails to predict
     */
    private double computeOutOfBagError() throws Exception {
        boolean numeric = this.m_data.classAttribute().isNumeric();
        double outOfBagWeight = 0.0;
        double errorSum = 0.0;
        for (int i = 0; i < this.m_data.numInstances(); ++i) {
            Instance inst = this.m_data.instance(i);
            double[] votes = new double[numeric ? 1 : this.m_data.numClasses()];
            int voteCount = 0;
            for (int j = 0; j < this.m_Classifiers.length; ++j) {
                if (this.m_inBag[j][i]) {
                    continue;
                }
                ++voteCount;
                if (numeric) {
                    votes[0] += this.m_Classifiers[j].classifyInstance(inst);
                } else {
                    double[] dist = this.m_Classifiers[j].distributionForInstance(inst);
                    for (int k = 0; k < dist.length; ++k) {
                        votes[k] += dist[k];
                    }
                }
            }
            // Without any out-of-bag vote, the "prediction" would be a meaningless default
            // (0 or class index 0), so such an instance must not be scored.
            if (voteCount == 0) {
                continue;
            }
            double weight = inst.weight();
            outOfBagWeight += weight;
            if (numeric) {
                double prediction = votes[0] / voteCount;
                errorSum += StrictMath.abs(prediction - inst.classValue()) * weight;
            } else if (Utils.maxIndex(votes) != inst.classValue()) {
                // Argmax of the summed distributions; normalizing first would not change it.
                errorSum += weight;
            }
        }
        return outOfBagWeight > 0.0 ? errorSum / outOfBagWeight : Double.NaN;
    }

    /**
     * Averages the predictions of all base classifiers.
     *
     * @param instance the instance to classify
     * @return for a numeric class, a one-element array with the mean prediction; otherwise
     *         the normalized sum of class distributions (left all-zero if the sum is zero)
     * @throws Exception if a base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] sums = new double[instance.numClasses()];
        // Use the trained ensemble's size, not m_NumIterations, which may have changed since training.
        int numMembers = this.m_Classifiers.length;
        if (instance.classAttribute().isNumeric()) {
            for (int i = 0; i < numMembers; ++i) {
                sums[0] += this.m_Classifiers[i].classifyInstance(instance);
            }
            sums[0] /= numMembers;
            return sums;
        }
        for (int i = 0; i < numMembers; ++i) {
            double[] dist = this.m_Classifiers[i].distributionForInstance(instance);
            for (int j = 0; j < dist.length; ++j) {
                sums[j] += dist[j];
            }
        }
        if (!Utils.eq(Utils.sum(sums), 0.0)) {
            Utils.normalize(sums);
        }
        return sums;
    }

    /**
     * @return a textual description of all base classifiers, plus the out-of-bag error when
     *         it is enabled
     */
    public String toString() {
        if (this.m_Classifiers == null) {
            return "Bagging: No model built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("All the base classifiers: \n\n");
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            text.append(this.m_Classifiers[i].toString()).append("\n\n");
        }
        if (this.m_CalcOutOfBag) {
            text.append("Out of bag error: ").append(Utils.doubleToString(this.m_OutOfBagError, 4)).append("\n\n");
        }
        return text.toString();
    }

    /**
     * Builds the ensemble so that it can be used as a partition generator.
     *
     * @param data the training data
     * @throws Exception if the base classifier is not a {@link PartitionGenerator} or building fails
     */
    @Override
    public void generatePartition(Instances data) throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
        }
        this.buildClassifier(data);
    }

    /**
     * Concatenates the membership values of all base classifiers, in ensemble order.
     *
     * @param inst the instance to compute membership values for
     * @return the concatenated membership values
     * @throws Exception if the base classifier is not a {@link PartitionGenerator}
     */
    @Override
    public double[] getMembershipValues(Instance inst) throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
        }
        ArrayList<double[]> parts = new ArrayList<double[]>(this.m_Classifiers.length);
        int size = 0;
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            double[] part = ((PartitionGenerator) this.m_Classifiers[i]).getMembershipValues(inst);
            size += part.length;
            parts.add(part);
        }
        double[] values = new double[size];
        int pos = 0;
        for (double[] part : parts) {
            System.arraycopy(part, 0, values, pos, part.length);
            pos += part.length;
        }
        return values;
    }

    /**
     * @return the total number of partition elements over all base classifiers
     * @throws Exception if the base classifier is not a {@link PartitionGenerator}
     */
    @Override
    public int numElements() throws Exception {
        if (!(this.m_Classifier instanceof PartitionGenerator)) {
            throw new Exception("Classifier: " + this.getClassifierSpec() + " cannot generate a partition");
        }
        int size = 0;
        for (int i = 0; i < this.m_Classifiers.length; ++i) {
            size += ((PartitionGenerator) this.m_Classifiers[i]).numElements();
        }
        return size;
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9186 $");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        Bagging.runClassifier(new Bagging(), argv);
    }
}

