package moa.classifiers.lazy;

import com.github.javacliparser.FlagOption;
import com.github.javacliparser.FloatOption;
import com.github.javacliparser.IntOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.InstanceImpl;
import com.yahoo.labs.samoa.instances.Instances;
import com.yahoo.labs.samoa.instances.InstancesHeader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.TreeSet;
import moa.classifiers.AbstractClassifier;
import moa.classifiers.MultiClassClassifier;
import moa.clusterers.kmeanspm.CoresetKMeans;
import moa.core.Measurement;
import moa.gui.visualization.RunVisualizer;

/* loaded from: input_file:moa/classifiers/lazy/SAMkNN.class */
/**
 * k-nearest-neighbour classifier with a self-adjusting memory (SAMkNN).
 *
 * <p>The model keeps two instance stores:
 * <ul>
 *   <li>{@code stm}: every training instance is appended here. After each insertion the store
 *       may be truncated (oldest instances first) to whichever candidate size has the lowest
 *       interleaved test-then-train error. Candidate sizes form the halving series
 *       {@code n, n/2, n/4, ...} and stop once a size drops below {@code 2 * minSTMSize}.</li>
 *   <li>{@code ltm}: receives instances evicted from {@code stm}. Evicted instances that
 *       contradict the retained {@code stm} are filtered out first. When both stores together
 *       exceed {@code limit} instances, the {@code ltm} is compressed class by class with
 *       k-means, keeping {@code max(size / 2, 1)} centroids per class.</li>
 * </ul>
 *
 * <p>A prediction comes from one of three distance-weighted kNN voters: {@code stm} only,
 * {@code ltm} only, or both stores combined. The voter with the most correct predictions in
 * its recorded history is used; ties favour the STM voter, then the LTM voter.
 */
public class SAMkNN extends AbstractClassifier implements MultiClassClassifier {
    private static final long serialVersionUID = 1;
    public IntOption kOption = new IntOption("k", 'k', "The number of neighbors", 5, 1, Integer.MAX_VALUE);
    public IntOption limitOption = new IntOption("limit", 'w', "The maximum number of instances to store", RunVisualizer.initialPauseInterval, 1, Integer.MAX_VALUE);
    public IntOption minSTMSizeOption = new IntOption("minSTMSize", 'm', "The minimum number of instances in the STM", 50, 1, Integer.MAX_VALUE);
    public FloatOption relativeLTMSizeOption = new FloatOption("relativeLTMSize", 'p', "The allowed LTM size relative to the total limit.", 0.4d, 0.0d, 1.0d);
    public FlagOption recalculateSTMErrorOption = new FlagOption("recalculateError", 'r', "Recalculates the error rate of the STM for size adaption (Costly operation). Otherwise, an approximation is used.");

    /** Largest class index seen during training; every vote array has length maxClassValue + 1. */
    private int maxClassValue = 0;
    /** Store that receives every training instance, oldest instance at index 0. */
    private Instances stm;
    /** Store that receives instances evicted from the STM and the k-means centroids replacing them. */
    private Instances ltm;
    /** LTM capacity: floor(relativeLTMSize * limit). */
    private int maxLTMSize;
    /** STM capacity: limit - maxLTMSize. */
    private int maxSTMSize;
    /** 0/1 correctness of the STM-only voter, one entry per getVotesForInstance call. */
    private List<Integer> stmHistory;
    /** 0/1 correctness of the LTM-only voter, one entry per getVotesForInstance call. */
    private List<Integer> ltmHistory;
    /** 0/1 correctness of the combined voter, one entry per getVotesForInstance call. */
    private List<Integer> cmHistory;
    /** Row r holds the distances from STM instance r to STM instances 0..r. */
    private double[][] distanceMatrixSTM;
    /** Cached test-then-train 0/1 histories, keyed by the candidate window's start offset in the STM. */
    private Map<Integer, List<Integer>> predictionHistories;
    /** Randomness source for k-means++ seeding (unseeded). */
    private Random random;

    @Override
    public String getPurposeString() {
        return "SAMkNN: special.";
    }

    /**
     * Splits {@code limit} into LTM and STM capacities and allocates fresh bookkeeping
     * structures. The STM distance matrix is {@code (limit + 1) x (limit + 1)} doubles, so memory
     * use grows quadratically with {@code limit}.
     */
    protected void init() {
        int limit = this.limitOption.getValue();
        this.maxLTMSize = (int) (this.relativeLTMSizeOption.getValue() * limit);
        this.maxSTMSize = limit - this.maxLTMSize;
        this.stmHistory = new ArrayList<>();
        this.ltmHistory = new ArrayList<>();
        this.cmHistory = new ArrayList<>();
        this.distanceMatrixSTM = new double[limit + 1][limit + 1];
        this.predictionHistories = new HashMap<>();
        this.random = new Random();
    }

    /**
     * Creates empty STM/LTM stores for the given header and (re)initialises all bookkeeping.
     *
     * @param instancesHeader stream header; its class index is copied to both stores
     */
    @Override
    public void setModelContext(InstancesHeader instancesHeader) {
        try {
            this.stm = new Instances(instancesHeader, 0);
            this.stm.setClassIndex(instancesHeader.classIndex());
            this.ltm = new Instances(instancesHeader, 0);
            this.ltm.setClassIndex(instancesHeader.classIndex());
            init();
        } catch (Exception e) {
            // NOTE(review): terminating the JVM from a library class is heavy-handed; kept
            // for backward compatibility, but rethrowing would be kinder to embedding callers.
            System.err.println("Error: no Model Context available.");
            e.printStackTrace();
            System.exit(1);
        }
    }

    /**
     * Discards all learned state. The stores and history lists stay {@code null} until
     * {@link #setModelContext} is called again.
     */
    @Override
    public void resetLearningImpl() {
        this.stm = null;
        this.ltm = null;
        this.stmHistory = null;
        this.ltmHistory = null;
        this.cmHistory = null;
        this.distanceMatrixSTM = null;
        this.predictionHistories = null;
        // Without this, vote arrays and clusterDown would keep being sized by the labels of
        // the stream seen before the reset.
        this.maxClassValue = 0;
    }

    /**
     * Adds one labelled instance to the STM and then adapts the memory.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Enforce the size limit.</li>
     *   <li>Filter the LTM against the new instance.</li>
     *   <li>Record the new instance's STM distances.</li>
     *   <li>Possibly shrink the STM. Evicted instances are cleaned against the retained STM and
     *       moved into the LTM.</li>
     * </ol>
     *
     * @param instance labelled training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (instance.classValue() > this.maxClassValue) {
            this.maxClassValue = (int) instance.classValue();
        }
        this.stm.add(instance);
        memorySizeCheck();
        clean(this.stm, this.ltm, true);

        // memorySizeCheck only removes from the front of the STM, so the new instance is still
        // the last row; store its distances to every STM instance (including itself).
        double[] newestDistances = get1ToNDistances(instance, this.stm);
        int newestRow = this.stm.numInstances() - 1;
        System.arraycopy(newestDistances, 0, this.distanceMatrixSTM[newestRow], 0, newestDistances.length);

        int currentSize = this.stm.numInstances();
        int targetSize = getNewSTMSize(this.recalculateSTMErrorOption.isSet());
        if (targetSize < currentSize) {
            int evictedCount = currentSize - targetSize;
            Instances evicted = new Instances(this.stm, 0);
            for (int n = 0; n < evictedCount; n++) {
                evicted.add(this.stm.get(0).copy());
                this.stm.delete(0);
            }
            shiftDistanceMatrix(evictedCount);
            dropOldestPredictions(evictedCount);
            // Remove evicted instances that conflict with the retained STM before keeping them.
            clean(this.stm, evicted, false);
            for (int n = 0; n < evicted.numInstances(); n++) {
                this.ltm.add(evicted.get(n).copy());
            }
            memorySizeCheck();
        }
    }

    /**
     * Returns distance-weighted class votes from the voter (STM, LTM or combined) with the most
     * correct predictions so far.
     *
     * <p>Side effect: appends one 0/1 entry to each voter's history, comparing its top class
     * with {@code instance.classValue()}. Calling this with instances whose class value is not
     * the true label therefore distorts later voter selection.
     *
     * @param instance instance to classify
     * @return vote array; all zeros while the STM is empty or if any exception occurs
     */
    @Override
    public double[] getVotesForInstance(Instance instance) {
        int predictedSTM = 0;
        int predictedLTM = 0;
        int predictedCM = 0;
        try {
            double[] votes;
            if (this.stm.numInstances() > 0) {
                double[] stmDistances = get1ToNDistances(instance, this.stm);
                double[] stmVotes = kNearestVotes(stmDistances, this.stm);
                predictedSTM = getClassFromVotes(stmVotes);

                double[] ltmDistances = get1ToNDistances(instance, this.ltm);
                double[] cmVotes = getCMVotes(stmDistances, this.stm, ltmDistances, this.ltm);
                predictedCM = getClassFromVotes(cmVotes);

                double[] ltmVotes;
                if (this.ltm.numInstances() > 0) {
                    ltmVotes = kNearestVotes(ltmDistances, this.ltm);
                    predictedLTM = getClassFromVotes(ltmVotes);
                } else {
                    ltmVotes = new double[instance.numClasses()];
                }

                int stmCorrect = historySum(this.stmHistory);
                int ltmCorrect = historySum(this.ltmHistory);
                int cmCorrect = historySum(this.cmHistory);
                if (stmCorrect >= ltmCorrect && stmCorrect >= cmCorrect) {
                    votes = stmVotes;
                } else if (ltmCorrect > stmCorrect && ltmCorrect >= cmCorrect) {
                    votes = ltmVotes;
                } else {
                    votes = cmVotes;
                }
            } else {
                votes = new double[instance.numClasses()];
            }
            double trueClass = instance.classValue();
            this.stmHistory.add(predictedSTM == trueClass ? 1 : 0);
            this.ltmHistory.add(predictedLTM == trueClass ? 1 : 0);
            this.cmHistory.add(predictedCM == trueClass ? 1 : 0);
            return votes;
        } catch (Exception e) {
            // Deliberate best-effort: any failure yields an all-zero (no opinion) vote.
            return new double[instance.numClasses()];
        }
    }

    /** @return {@code null}: this learner reports no model-specific measurements. */
    @Override
    protected Measurement[] getModelMeasurementsImpl() {
        return null;
    }

    /** Intentionally produces no description. */
    @Override
    public void getModelDescription(StringBuilder sb, int i) {
    }

    @Override
    public boolean isRandomizable() {
        return false;
    }

    /** Returns the number of correct (1) entries in a 0/1 history. */
    private int historySum(List<Integer> list) {
        int correct = 0;
        for (int outcome : list) {
            correct += outcome;
        }
        return correct;
    }

    /**
     * Runs k-means with k-means++ seeding and returns the resulting centroids.
     *
     * @param list points; each array is passed unchanged to CoresetKMeans
     * @param i number of centroids
     */
    private List<double[]> kMeans(List<double[]> list, int i) {
        List<double[]> centroids = CoresetKMeans.generatekMeansPlusPlusCentroids(i, list, this.random);
        CoresetKMeans.kMeans(centroids, list);
        return centroids;
    }

    /**
     * Replaces the LTM instances of each class {@code 0..maxClassValue} with
     * {@code max(count / 2, 1)} k-means centroids of that class.
     */
    private void clusterDown() {
        int classIndex = this.ltm.classIndex();
        int numAttributes = this.ltm.numAttributes();
        for (int label = 0; label <= this.maxClassValue; label++) {
            List<double[]> samples = new ArrayList<>();
            // Walk backwards so deleting index j does not shift indices still to be visited.
            for (int j = this.ltm.numInstances() - 1; j >= 0; j--) {
                if (this.ltm.get(j).classValue() == label) {
                    samples.add(this.ltm.get(j).toDoubleArray());
                    this.ltm.delete(j);
                }
            }
            if (samples.isEmpty()) {
                continue;
            }
            for (double[] sample : samples) {
                // Move attribute 0 into the (now unneeded) class slot and put a unit value in
                // slot 0; presumably CoresetKMeans reads slot 0 as the point weight - confirm.
                if (classIndex != 0) {
                    sample[classIndex] = sample[0];
                }
                sample[0] = 1.0d;
            }
            for (double[] centroid : kMeans(samples, Math.max(samples.size() / 2, 1))) {
                double[] values = new double[numAttributes];
                // The +1 offset implies returned centroids omit the weight slot - TODO confirm.
                System.arraycopy(centroid, 0, values, 1, numAttributes - 1);
                if (classIndex != 0) {
                    values[0] = values[classIndex];
                }
                values[classIndex] = label;
                InstanceImpl centroidInstance = new InstanceImpl(1.0d, values);
                centroidInstance.setDataset(this.ltm);
                this.ltm.add(centroidInstance);
            }
        }
    }

    /**
     * Keeps the total stored instance count within {@code maxSTMSize + maxLTMSize}.
     *
     * <p>If the LTM is over its own capacity it is compressed via {@link #clusterDown()}.
     * Otherwise the oldest STM instances are moved into the LTM until it is one over capacity,
     * then the LTM is compressed and the cached window histories are discarded.
     */
    private void memorySizeCheck() {
        if (this.stm.numInstances() + this.ltm.numInstances() <= this.maxSTMSize + this.maxLTMSize) {
            return;
        }
        if (this.ltm.numInstances() > this.maxLTMSize) {
            clusterDown();
            return;
        }
        int transferCount = (this.maxLTMSize - this.ltm.numInstances()) + 1;
        for (int n = 0; n < transferCount; n++) {
            this.ltm.add(this.stm.get(0).copy());
            this.stm.delete(0);
        }
        dropOldestPredictions(transferCount);
        clusterDown();
        this.predictionHistories.clear();
        shiftDistanceMatrix(transferCount);
    }

    /**
     * Moves the STM distance sub-matrix up-left by {@code offset} after that many instances
     * were removed from the front of the STM.
     */
    private void shiftDistanceMatrix(int offset) {
        int size = this.stm.numInstances();
        for (int row = 0; row < size; row++) {
            // Source row (offset + row) is never overwritten before it is read: rows are
            // processed in ascending order and offset + row >= row.
            System.arraycopy(this.distanceMatrixSTM[offset + row], offset, this.distanceMatrixSTM[row], 0, size);
        }
    }

    /**
     * Drops the {@code count} oldest entries from all three voter histories, or as many as
     * exist. Histories can be shorter than the STM when training without prior predictions.
     */
    private void dropOldestPredictions(int count) {
        dropOldest(this.stmHistory, count);
        dropOldest(this.ltmHistory, count);
        dropOldest(this.cmHistory, count);
    }

    /** Removes up to {@code count} elements from the front of {@code history}. */
    private static void dropOldest(List<Integer> history, int count) {
        history.subList(0, Math.min(count, history.size())).clear();
    }

    /**
     * Deletes from {@code instances2} the entries that contradict one reference instance.
     *
     * <p>A radius is taken as the largest distance from {@code instances.get(i)} to a
     * same-label instance among its k nearest other members of {@code instances} (0 if there
     * is none). Then every instance among the query's k nearest in {@code instances2} that has a
     * different label and lies within that radius is deleted.
     *
     * @param instances  reference set containing the query instance
     * @param i          index of the query instance in {@code instances}
     * @param instances2 set to be filtered in place
     */
    private void cleanSingle(Instances instances, int i, Instances instances2) {
        int k = this.kOption.getValue();
        Instance query = instances.get(i);
        Instances others = new Instances(instances);
        others.delete(i);

        double[] ownDistances = get1ToNDistances(query, others);
        int[] ownNeighbours = nArgMin(Math.min(k, ownDistances.length), ownDistances);
        double[] targetDistances = get1ToNDistances(query, instances2);
        int[] targetNeighbours = nArgMin(Math.min(k, targetDistances.length), targetDistances);

        double radius = 0.0d;
        for (int idx : ownNeighbours) {
            if (others.get(idx).classValue() == query.classValue() && ownDistances[idx] > radius) {
                radius = ownDistances[idx];
            }
        }
        List<Integer> toDelete = new ArrayList<>();
        for (int idx : targetNeighbours) {
            if (instances2.get(idx).classValue() != query.classValue() && targetDistances[idx] <= radius) {
                toDelete.add(idx);
            }
        }
        // Delete highest indices first so the remaining indices stay valid.
        Collections.sort(toDelete, Collections.reverseOrder());
        for (int idx : toDelete) {
            instances2.delete(idx);
        }
    }

    /**
     * Filters {@code instances2} against {@code instances}. Does nothing unless
     * {@code instances} has more than k members and {@code instances2} is non-empty.
     *
     * @param instances  reference set
     * @param instances2 set to be filtered in place
     * @param z          {@code true}: use only the last reference instance; {@code false}: use every one
     */
    private void clean(Instances instances, Instances instances2, boolean z) {
        if (instances.numInstances() <= this.kOption.getValue() || instances2.numInstances() <= 0) {
            return;
        }
        if (z) {
            cleanSingle(instances, instances.numInstances() - 1, instances2);
            return;
        }
        for (int i = 0; i < instances.numInstances(); i++) {
            cleanSingle(instances, i, instances2);
        }
    }

    /** Distance-weighted votes of the {@code min(k, |memory|)} nearest members of {@code memory}. */
    private double[] kNearestVotes(double[] distances, Instances memory) {
        int[] nearest = nArgMin(Math.min(distances.length, this.kOption.getValue()), distances);
        return getDistanceWeightedVotes(distances, nearest, memory);
    }

    /**
     * Sums inverse-distance votes per class. Each distance is floored at 1e-9 to avoid division
     * by zero.
     *
     * @param dArr      distances indexed like {@code instances}
     * @param iArr      indices of the neighbours that vote
     * @param instances instances providing the class labels
     * @return votes of length {@code maxClassValue + 1}
     */
    private double[] getDistanceWeightedVotes(double[] dArr, int[] iArr, Instances instances) {
        double[] votes = new double[this.maxClassValue + 1];
        for (int idx : iArr) {
            int label = (int) instances.instance(idx).classValue();
            votes[label] += 1.0d / Math.max(dArr[idx], 1.0E-9d);
        }
        return votes;
    }

    /**
     * Like {@link #getDistanceWeightedVotes}, but over the concatenation of two stores.
     * Indices below {@code instances.numInstances()} refer to {@code instances}; larger indices
     * refer to {@code instances2}, shifted by that count.
     */
    private double[] getDistanceWeightedVotesCM(double[] dArr, int[] iArr, Instances instances, Instances instances2) {
        double[] votes = new double[this.maxClassValue + 1];
        int firstSize = instances.numInstances();
        for (int idx : iArr) {
            Instance neighbour = idx < firstSize ? instances.instance(idx) : instances2.instance(idx - firstSize);
            int label = (int) neighbour.classValue();
            votes[label] += 1.0d / Math.max(dArr[idx], 1.0E-9d);
        }
        return votes;
    }

    /** Combined-memory votes: kNN over the STM distances followed by the LTM distances. */
    private double[] getCMVotes(double[] dArr, Instances instances, double[] dArr2, Instances instances2) {
        double[] combined = new double[dArr.length + dArr2.length];
        System.arraycopy(dArr, 0, combined, 0, dArr.length);
        System.arraycopy(dArr2, 0, combined, dArr.length, dArr2.length);
        int[] nearest = nArgMin(Math.min(combined.length, this.kOption.getValue()), combined);
        return getDistanceWeightedVotesCM(combined, nearest, instances, instances2);
    }

    /**
     * Returns the index of the largest vote; the first index wins ties.
     *
     * @return -1 for an empty array
     */
    private int getClassFromVotes(double[] dArr) {
        double bestVote = -1.0d;
        int bestClass = -1;
        for (int c = 0; c < dArr.length; c++) {
            if (dArr[c] > bestVote) {
                bestVote = dArr[c];
                bestClass = c;
            }
        }
        return bestClass;
    }

    /** Predicts a class by kNN restricted to distance indices {@code i..i2} (inclusive). */
    private int getLabelFct(double[] dArr, Instances instances, int i, int i2) {
        int[] nearest = nArgMin(Math.min(this.kOption.getValue(), dArr.length), dArr, i, i2);
        return getClassFromVotes(getDistanceWeightedVotes(dArr, nearest, instances));
    }

    /** Euclidean distance over the input (non-class) attributes. */
    private double getDistance(Instance instance, Instance instance2) {
        int numInputs = instance.numInputAttributes();
        double squaredSum = 0.0d;
        for (int a = 0; a < numInputs; a++) {
            double diff = instance.valueInputAttribute(a) - instance2.valueInputAttribute(a);
            squaredSum += diff * diff;
        }
        return Math.sqrt(squaredSum);
    }

    /** Distances from {@code instance} to every member of {@code instances}, in order. */
    private double[] get1ToNDistances(Instance instance, Instances instances) {
        double[] distances = new double[instances.numInstances()];
        for (int j = 0; j < distances.length; j++) {
            distances[j] = getDistance(instance, instances.get(j));
        }
        return distances;
    }

    /**
     * Returns the indices of the {@code i} smallest values in {@code dArr[i2..i3]} (inclusive),
     * ascending by value, with the earliest index winning ties. Values not below
     * {@code Double.MAX_VALUE} are never selected; their slots stay 0.
     */
    private int[] nArgMin(int i, double[] dArr, int i2, int i3) {
        int[] chosen = new int[i];
        for (int slot = 0; slot < i; slot++) {
            double best = Double.MAX_VALUE;
            for (int idx = i2; idx <= i3; idx++) {
                if (dArr[idx] < best) {
                    boolean alreadyChosen = false;
                    for (int prev = 0; prev < slot; prev++) {
                        if (chosen[prev] == idx) {
                            alreadyChosen = true;
                        }
                    }
                    if (!alreadyChosen) {
                        chosen[slot] = idx;
                        best = dArr[idx];
                    }
                }
            }
        }
        return chosen;
    }

    /** {@link #nArgMin(int, double[], int, int)} over the whole array. */
    private int[] nArgMin(int i, double[] dArr) {
        return nArgMin(i, dArr, 0, dArr.length - 1);
    }

    /**
     * Re-keys the cached window histories after the STM shrank. Repeated {@code i} times: drop
     * the history with the smallest start offset, then shift the remaining keys so the smallest
     * becomes 0.
     */
    private void adaptHistories(int i) {
        for (int n = 0; n < i; n++) {
            this.predictionHistories.remove(new TreeSet<>(this.predictionHistories.keySet()).first());
            TreeSet<Integer> remainingKeys = new TreeSet<>(this.predictionHistories.keySet());
            if (remainingKeys.isEmpty()) {
                continue;
            }
            int shift = remainingKeys.first();
            // Ascending order guarantees a shifted key never overwrites a not-yet-moved one.
            for (Integer key : remainingKeys) {
                this.predictionHistories.put(key - shift, this.predictionHistories.remove(key));
            }
        }
    }

    /**
     * Extends {@code list} in place with test-then-train outcomes for the window starting at STM
     * offset {@code i}. Each instance from {@code i + k + list.size()} onward is predicted from
     * its predecessors {@code i..value-1} and recorded as 1 if correct, 0 otherwise.
     *
     * @return {@code list}, after extension
     */
    private List<Integer> getIncrementalTestTrainPredHistory(Instances instances, int i, List<Integer> list) {
        for (int value = i + this.kOption.getValue() + list.size(); value < instances.numInstances(); value++) {
            int predicted = getLabelFct(this.distanceMatrixSTM[value], instances, i, value - 1);
            list.add(predicted == instances.get(value).classValue() ? 1 : 0);
        }
        return list;
    }

    /** Computes the full test-then-train history of the window starting at STM offset {@code i}. */
    private List<Integer> getTestTrainPredHistory(Instances instances, int i) {
        return getIncrementalTestTrainPredHistory(instances, i, new ArrayList<>());
    }

    /**
     * Candidate STM sizes: {@code size, size/2, size/4, ...}. The series stops after the first
     * value below {@code 2 * minSTMSize}.
     */
    private List<Integer> candidateWindowSizes(int size) {
        int halvingThreshold = 2 * this.minSTMSizeOption.getValue();
        List<Integer> sizes = new ArrayList<>();
        int last = size;
        sizes.add(last);
        while (last >= halvingThreshold) {
            last /= 2;
            sizes.add(last);
        }
        return sizes;
    }

    /**
     * Chooses the candidate STM size with the lowest test-then-train error. Cached histories are
     * extended only with new predictions, which keeps them exact.
     *
     * @return chosen STM size; the current size if it is below {@code 2 * minSTMSize}
     */
    private int getMinErrorRateWindowSize() {
        int size = this.stm.numInstances();
        if (size < 2 * this.minSTMSizeOption.getValue()) {
            return size;
        }
        List<Integer> windowSizes = candidateWindowSizes(size);
        // Discard cached histories whose start offset no longer matches any candidate window.
        this.predictionHistories.keySet().removeIf(start -> !windowSizes.contains(size - start));

        List<Double> errorRates = new ArrayList<>();
        for (int windowSize : windowSizes) {
            int start = size - windowSize;
            List<Integer> history = this.predictionHistories.containsKey(start)
                    ? getIncrementalTestTrainPredHistory(this.stm, start, this.predictionHistories.get(start))
                    : getTestTrainPredHistory(this.stm, start);
            this.predictionHistories.put(start, history);
            errorRates.add(getHistoryErrorRate(history));
        }
        int bestIndex = errorRates.indexOf(Collections.min(errorRates));
        int bestSize = windowSizes.get(bestIndex);
        if (bestSize < size) {
            adaptHistories(bestIndex);
        }
        return bestSize;
    }

    /**
     * Returns the error rate (fraction of 0 entries) of a 0/1 history.
     *
     * @return {@code NaN} for an empty history
     */
    private double getHistoryErrorRate(List<Integer> list) {
        double correct = 0.0d;
        for (int outcome : list) {
            correct += outcome;
        }
        return 1.0d - (correct / list.size());
    }

    /**
     * Cheaper variant of {@link #getMinErrorRateWindowSize()}.
     *
     * <p>When a window's start moved by one position, the cached history of the previous start
     * is reused after dropping its oldest entry. This is an approximation, because the remaining
     * entries were computed with the earlier start. Any smaller window that appears to beat the
     * full window is recomputed exactly before the final choice.
     *
     * @return chosen STM size; the current size if it is below {@code 2 * minSTMSize}
     */
    private int getMinErrorRateWindowSizeIncremental() {
        int size = this.stm.numInstances();
        if (size < 2 * this.minSTMSizeOption.getValue()) {
            return size;
        }
        List<Integer> windowSizes = candidateWindowSizes(size);
        List<Double> errorRates = new ArrayList<>();
        for (int windowSize : windowSizes) {
            int start = size - windowSize;
            List<Integer> history;
            if (this.predictionHistories.containsKey(start)) {
                history = getIncrementalTestTrainPredHistory(this.stm, start, this.predictionHistories.get(start));
            } else if (this.predictionHistories.containsKey(start - 1)) {
                List<Integer> shifted = this.predictionHistories.remove(start - 1);
                shifted.remove(0);
                history = getIncrementalTestTrainPredHistory(this.stm, start, shifted);
                this.predictionHistories.put(start, history);
            } else {
                history = getTestTrainPredHistory(this.stm, start);
                this.predictionHistories.put(start, history);
            }
            errorRates.add(getHistoryErrorRate(history));
        }

        int bestIndex = errorRates.indexOf(Collections.min(errorRates));
        if (bestIndex > 0) {
            double fullWindowError = errorRates.get(0);
            for (int n = 1; n < errorRates.size(); n++) {
                if (errorRates.get(n) < fullWindowError) {
                    int start = size - windowSizes.get(n);
                    List<Integer> exact = getTestTrainPredHistory(this.stm, start);
                    errorRates.set(n, getHistoryErrorRate(exact));
                    this.predictionHistories.put(start, exact);
                }
            }
            bestIndex = errorRates.indexOf(Collections.min(errorRates));
        }
        int bestSize = windowSizes.get(bestIndex);
        if (bestSize < size) {
            adaptHistories(bestIndex);
        }
        return bestSize;
    }

    /**
     * Chooses the next STM size.
     *
     * @param z {@code true} for the exact (costly) error computation, {@code false} for the
     *          incremental approximation
     */
    private int getNewSTMSize(boolean z) {
        return z ? getMinErrorRateWindowSize() : getMinErrorRateWindowSizeIncremental();
    }
}
