package smile.classification;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import smile.data.Attribute;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.math.Random;
import smile.util.MulticoreExecutor;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/* loaded from: input_file:smile/classification/RandomForest.class */
/**
 * Random forest classifier: an ensemble of {@link DecisionTree}s, each trained on a
 * bootstrap sample of the training data, whose predictions are combined by majority vote.
 *
 * <p>While training, every tree casts a vote for each training sample that was not drawn
 * into its own bootstrap sample (an out-of-bag vote). The majority of these votes is
 * compared with the true label to produce the error estimate returned by {@link #error()}.
 */
public class RandomForest implements Classifier<double[]> {
    /** The trees of the forest, in training order. */
    private List<DecisionTree> trees;
    /** The number of classes (labels are 0 .. k-1). */
    private int k;
    /** Out-of-bag misclassification rate measured during training. */
    private double error;
    /** Per-feature sum of the importance scores reported by every tree. */
    private double[] importance;

    /**
     * Trainer for random forest classifiers.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Number of trees to grow. */
        private int ntrees;
        /**
         * Number of randomly selected features considered for splitting. A negative value
         * means "not set", in which case the forest's default floor(sqrt(p)) is used.
         */
        private int mtry = -1;

        /**
         * Constructor.
         * @param ntrees the number of trees.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public Trainer(int ntrees) {
            checkNumTrees(ntrees);
            this.ntrees = ntrees;
        }

        /**
         * Constructor.
         * @param attributes the attributes of independent variable.
         * @param ntrees the number of trees.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public Trainer(Attribute[] attributes, int ntrees) {
            super(attributes);
            checkNumTrees(ntrees);
            this.ntrees = ntrees;
        }

        /**
         * Sets the number of trees in the forest.
         * @param ntrees the number of trees.
         * @throws IllegalArgumentException if {@code ntrees < 1}.
         */
        public void setNumTrees(int ntrees) {
            checkNumTrees(ntrees);
            this.ntrees = ntrees;
        }

        /**
         * Sets the number of randomly selected features considered for splitting.
         * @param mtry the number of features.
         * @throws IllegalArgumentException if {@code mtry < 1}.
         */
        public void setNumRandomFeatures(int mtry) {
            if (mtry < 1) {
                throw new IllegalArgumentException("Invalid number of random selected features for splitting: " + mtry);
            }
            this.mtry = mtry;
        }

        @Override
        public RandomForest train(double[][] x, int[] y) {
            if (mtry < 0) {
                return new RandomForest(this.attributes, x, y, ntrees);
            }
            return new RandomForest(this.attributes, x, y, ntrees, mtry);
        }
    }

    /**
     * Trains one tree on a bootstrap sample and records that tree's out-of-bag votes.
     */
    static class TrainingTask implements Callable<DecisionTree> {
        /** Counter behind {@link #nextSeed()}; starts at a time-dependent value. */
        private static long seedCounter = System.nanoTime();

        Attribute[] attributes;
        double[][] x;
        int[] y;
        /** Pre-sorted feature order from {@code SmileUtils.sort}, passed through to DecisionTree. */
        int[][] order;
        /** Number of random features considered per split. */
        int M;
        /** Shared out-of-bag vote table: prediction[i][c] counts votes for class c on sample i. */
        int[][] prediction;
        /** Seed for this task's bootstrap RNG; distinct for every task. */
        private final long seed;

        TrainingTask(Attribute[] attributes, double[][] x, int[] y, int M, int[][] order, int[][] prediction) {
            this.attributes = attributes;
            this.x = x;
            this.y = y;
            this.order = order;
            this.M = M;
            this.prediction = prediction;
            this.seed = nextSeed();
        }

        /**
         * Returns a seed different from every other seed this method has returned.
         * The counter advances by an odd constant, so it cannot repeat within 2^64 calls.
         * The previous {@code threadId * currentTimeMillis} seeding gave identical
         * bootstrap samples to tasks that ran on the same thread within one millisecond.
         */
        private static synchronized long nextSeed() {
            seedCounter += 0x9E3779B97F4A7C15L;
            return seedCounter;
        }

        @Override
        public DecisionTree call() {
            int n = x.length;
            Random random = new Random(seed);

            // samples[i] = number of times sample i was drawn (with replacement) into the bootstrap sample.
            int[] samples = new int[n];
            for (int draw = 0; draw < n; draw++) {
                samples[random.nextInt(n)]++;
            }

            DecisionTree tree = new DecisionTree(attributes, x, y, M, samples, order);

            // Samples never drawn are out-of-bag for this tree; record its vote for them.
            // Rows of the shared table are locked individually because tasks may run concurrently.
            for (int i = 0; i < n; i++) {
                if (samples[i] == 0) {
                    int label = tree.predict(x[i]);
                    synchronized (prediction[i]) {
                        prediction[i][label]++;
                    }
                }
            }
            return tree;
        }
    }

    /**
     * Constructor. Learns a random forest with numeric attributes named V1..Vp and
     * floor(sqrt(p)) random features per split.
     * @param x the training instances.
     * @param y the class labels, in 0 .. k-1.
     * @param ntrees the number of trees.
     */
    public RandomForest(double[][] x, int[] y, int ntrees) {
        this((Attribute[]) null, x, y, ntrees);
    }

    /**
     * Constructor. Learns a random forest with numeric attributes named V1..Vp.
     * @param x the training instances.
     * @param y the class labels, in 0 .. k-1.
     * @param ntrees the number of trees.
     * @param mtry the number of random features considered per split.
     */
    public RandomForest(double[][] x, int[] y, int ntrees, int mtry) {
        this(null, x, y, ntrees, mtry);
    }

    /**
     * Constructor. Learns a random forest with floor(sqrt(p)) random features per split.
     * @param attributes the attribute properties, or {@code null} for numeric attributes V1..Vp.
     * @param x the training instances.
     * @param y the class labels, in 0 .. k-1.
     * @param ntrees the number of trees.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees) {
        this(attributes, x, y, ntrees, (int) Math.floor(Math.sqrt(x[0].length)));
    }

    /**
     * Constructor. Learns a random forest for classification.
     * @param attributes the attribute properties, or {@code null} for numeric attributes V1..Vp.
     * @param x the training instances.
     * @param y the class labels, which must be exactly the contiguous range 0 .. k-1 with k >= 2.
     * @param ntrees the number of trees.
     * @param mtry the number of random features considered per split.
     * @throws IllegalArgumentException if the sizes of x and y differ, ntrees or mtry is
     *         less than 1, or the labels are negative, non-contiguous or a single class.
     */
    public RandomForest(Attribute[] attributes, double[][] x, int[] y, int ntrees, int mtry) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        checkNumTrees(ntrees);
        if (mtry < 1) {
            throw new IllegalArgumentException("Invalid number of variables for splitting: " + mtry);
        }

        k = countClasses(y);

        if (attributes == null) {
            attributes = defaultAttributes(x[0].length);
        }

        int[][] oobVotes = new int[x.length][k];
        int[][] order = SmileUtils.sort(attributes, x);

        List<TrainingTask> tasks = new ArrayList<TrainingTask>(ntrees);
        for (int t = 0; t < ntrees; t++) {
            tasks.add(new TrainingTask(attributes, x, y, mtry, order, oobVotes));
        }

        trees = trainTrees(tasks, oobVotes);
        error = oobError(oobVotes, y);
        importance = sumImportance(trees, attributes.length);
    }

    /** Throws if {@code ntrees} is not a valid forest size. */
    private static void checkNumTrees(int ntrees) {
        if (ntrees < 1) {
            throw new IllegalArgumentException("Invalid number of trees: " + ntrees);
        }
    }

    /**
     * Validates that the labels are non-negative and form the contiguous range 0 .. k-1,
     * and returns k.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            // Labels are sorted and unique, so any jump above the expected value marks a gap.
            int expected = i == 0 ? 0 : labels[i - 1] + 1;
            if (labels[i] > expected) {
                throw new IllegalArgumentException("Missing class: " + expected);
            }
        }
        if (labels.length < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        return labels.length;
    }

    /** Builds p numeric attributes named V1..Vp. */
    private static Attribute[] defaultAttributes(int p) {
        Attribute[] attributes = new Attribute[p];
        for (int j = 0; j < p; j++) {
            attributes[j] = new NumericAttribute("V" + (j + 1));
        }
        return attributes;
    }

    /**
     * Runs the tasks in parallel. If the parallel run fails, falls back to training
     * them serially.
     */
    private static List<DecisionTree> trainTrees(List<TrainingTask> tasks, int[][] oobVotes) {
        try {
            return MulticoreExecutor.run(tasks);
        } catch (Exception e) {
            System.err.println(e);
            // The failed run may already have recorded some out-of-bag votes; clear them so
            // serial retraining does not count any tree twice.
            // NOTE(review): assumes MulticoreExecutor.run leaves no task running when it throws — confirm.
            for (int[] votes : oobVotes) {
                Arrays.fill(votes, 0);
            }
            List<DecisionTree> result = new ArrayList<DecisionTree>(tasks.size());
            for (TrainingTask task : tasks) {
                result.add(task.call());
            }
            return result;
        }
    }

    /**
     * Returns the fraction of samples with at least one out-of-bag vote whose majority
     * vote differs from the label. Returns 0 if no sample received any out-of-bag vote.
     */
    private static double oobError(int[][] oobVotes, int[] y) {
        int tested = 0;
        int misclassified = 0;
        for (int i = 0; i < oobVotes.length; i++) {
            int vote = Math.whichMax(oobVotes[i]);
            if (oobVotes[i][vote] > 0) {
                tested++;
                if (vote != y[i]) {
                    misclassified++;
                }
            }
        }
        return tested > 0 ? (double) misclassified / tested : 0.0;
    }

    /** Sums {@link DecisionTree#importance()} element-wise over all trees. */
    private static double[] sumImportance(List<DecisionTree> trees, int p) {
        double[] total = new double[p];
        for (DecisionTree tree : trees) {
            double[] score = tree.importance();
            for (int j = 0; j < score.length; j++) {
                total[j] += score[j];
            }
        }
        return total;
    }

    /**
     * Returns the out-of-bag error estimate measured during training. It is not updated
     * by {@link #trim(int)}.
     */
    public double error() {
        return error;
    }

    /**
     * Returns the per-feature sum of tree importance scores computed during training.
     * The internal array is returned, not a copy, and it is not updated by {@link #trim(int)}.
     */
    public double[] importance() {
        return importance;
    }

    /** Returns the number of trees in the forest. */
    public int size() {
        return trees.size();
    }

    /**
     * Keeps only the first {@code ntrees} trees (in training order).
     * @param ntrees the new model size.
     * @throws IllegalArgumentException if ntrees is larger than the current size or not positive.
     */
    public void trim(int ntrees) {
        if (ntrees > trees.size()) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (ntrees <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + ntrees);
        }
        trees = new ArrayList<DecisionTree>(trees.subList(0, ntrees));
    }

    /** Counts, per class, how many trees predict that class for {@code x}. */
    private int[] vote(double[] x) {
        int[] votes = new int[k];
        for (DecisionTree tree : trees) {
            votes[tree.predict(x)]++;
        }
        return votes;
    }

    @Override
    public int predict(double[] x) {
        return Math.whichMax(vote(x));
    }

    /**
     * Predicts the class of {@code x} and fills {@code posteriori} with the fraction of
     * trees voting for each class.
     * @throws IllegalArgumentException if posteriori.length != k.
     */
    @Override
    public int predict(double[] x, double[] posteriori) {
        if (posteriori.length != k) {
            throw new IllegalArgumentException(String.format("Invalid posteriori vector size: %d, expected: %d", posteriori.length, k));
        }
        int[] votes = vote(x);
        double ntrees = trees.size();
        for (int c = 0; c < k; c++) {
            posteriori[c] = votes[c] / ntrees;
        }
        return Math.whichMax(votes);
    }

    /**
     * Evaluates the accuracy of the forest truncated to its first 1, 2, ..., size() trees.
     * @param x the test instances.
     * @param y the test labels.
     * @return element t is the accuracy of the first t+1 trees.
     */
    public double[] test(double[][] x, int[] y) {
        double[][] results = test(x, y, new ClassificationMeasure[] {new Accuracy()});
        double[] accuracy = new double[results.length];
        for (int t = 0; t < results.length; t++) {
            accuracy[t] = results[t][0];
        }
        return accuracy;
    }

    /**
     * Evaluates the given measures for the forest truncated to its first 1, 2, ..., size() trees.
     * @param x the test instances.
     * @param y the test labels.
     * @param measures the performance measures.
     * @return element [t][m] is measure m evaluated on the first t+1 trees.
     */
    public double[][] test(double[][] x, int[] y, ClassificationMeasure[] measures) {
        int ntrees = trees.size();
        int n = x.length;
        double[][] results = new double[ntrees][measures.length];
        int[] label = new int[n];
        // Votes accumulate tree by tree, so each iteration adds one tree to the ensemble.
        int[][] votes = new int[n][k];
        for (int t = 0; t < ntrees; t++) {
            DecisionTree tree = trees.get(t);
            for (int i = 0; i < n; i++) {
                votes[i][tree.predict(x[i])]++;
                label[i] = Math.whichMax(votes[i]);
            }
            for (int m = 0; m < measures.length; m++) {
                results[t][m] = measures[m].measure(y, label);
            }
        }
        return results;
    }
}
