package gov.sandia.cognition.learning.algorithm.svm;

import gov.sandia.cognition.annotation.PublicationReference;
import gov.sandia.cognition.annotation.PublicationType;
import gov.sandia.cognition.learning.algorithm.AbstractAnytimeSupervisedBatchLearner;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.learning.function.categorization.KernelBinaryCategorizer;
import gov.sandia.cognition.learning.function.kernel.Kernel;
import gov.sandia.cognition.learning.function.kernel.KernelContainer;
import gov.sandia.cognition.util.DefaultWeightedValue;
import gov.sandia.cognition.util.Randomized;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Random;

@PublicationReference(title = "Fast training of support vector machines using sequential minimal optimization", author = {"John C. Platt"}, year = 1999, type = PublicationType.BookChapter, pages = {185, 208}, publication = "Advances in Kernel Methods", url = "http://research.microsoft.com/pubs/68391/smo-book.pdf")
/* loaded from: input_file:gov/sandia/cognition/learning/algorithm/svm/SequentialMinimalOptimization.class */
public class SequentialMinimalOptimization<InputType> extends AbstractAnytimeSupervisedBatchLearner<InputType, Boolean, KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>>> implements KernelContainer<InputType>, Randomized {
    /** Default iteration limit. */
    public static final int DEFAULT_MAX_ITERATIONS = 1000;
    /** Default penalty bound: infinite, i.e. alphas are not capped from above. */
    public static final double DEFAULT_MAX_PENALTY = Double.POSITIVE_INFINITY;
    /** Default error tolerance. */
    public static final double DEFAULT_ERROR_TOLERANCE = 0.001d;
    /** Default threshold under which a value is treated as zero. */
    public static final double DEFAULT_EFFECTIVE_ZERO = 1.0E-10d;
    /** Default upper bound on the number of cached kernel evaluations. */
    public static final int DEFAULT_KERNEL_CACHE_SIZE = 1000;
    /** Upper bound on every alpha; must be positive (see setMaxPenalty). */
    private double maxPenalty;
    /** Non-negative tolerance configured through setErrorTolerance. */
    private double errorTolerance;
    /** Alphas, and changes in alpha, smaller than this are snapped to zero in takeStep. */
    private double effectiveZero;
    /** Maximum number of kernel values to cache; values of 0 or 1 disable the cache. */
    private int kernelCacheSize;
    /** Source of the random starting offsets used when searching for a second example. */
    private Random random;
    /** Kernel evaluated between pairs of training inputs. */
    private Kernel<? super InputType> kernel;
    /** Categorizer being learned; built over the live values view of supportsMap. */
    private transient KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> result;
    /** Training examples with non-null pair, input and output, in data order. */
    private transient ArrayList<InputOutputPair<? extends InputType, Boolean>> dataList;
    /** Number of entries in dataList. */
    private transient int dataSize;
    /** True when the next step sweeps every example rather than only the non-bound ones. */
    private transient boolean examineAll;
    /** Number of examples whose examination led to a successful update in the latest step. */
    private transient int changeCount;
    /** Example index mapped to its support vector; the weight stored is target * alpha. */
    private transient LinkedHashMap<Integer, DefaultWeightedValue<InputType>> supportsMap;
    /** Indices whose alpha is neither zero nor equal to maxPenalty. */
    private transient LinkedHashSet<Integer> nonBoundAlphaIndices;
    /** Example index mapped to its cached error (SVM output minus target). */
    private transient LinkedHashMap<Integer, Double> errorCache;
    /** Size-bounded, access-ordered cache of kernel values keyed by a long built from the index pair; null when caching is off. */
    private transient LinkedHashMap<Long, Double> kernelCache;

    /**
     * Creates a learner with no kernel, a fresh {@link Random} and the
     * default parameter values.
     */
    public SequentialMinimalOptimization() {
        this(null);
    }

    /**
     * Creates a learner with the given kernel, a fresh {@link Random} and the
     * default parameter values.
     *
     * @param kernel the kernel to use
     */
    public SequentialMinimalOptimization(Kernel<? super InputType> kernel) {
        this(kernel, new Random());
    }

    /**
     * Creates a learner with the given kernel and random number generator
     * and the default parameter values.
     *
     * @param kernel the kernel to use
     * @param random the random number generator to use
     */
    public SequentialMinimalOptimization(Kernel<? super InputType> kernel, Random random) {
        this(
                kernel,
                DEFAULT_MAX_PENALTY,
                DEFAULT_ERROR_TOLERANCE,
                DEFAULT_EFFECTIVE_ZERO,
                DEFAULT_KERNEL_CACHE_SIZE,
                DEFAULT_MAX_ITERATIONS,
                random);
    }

    /**
     * Creates a learner with every parameter given explicitly.
     *
     * @param kernel the kernel to use
     * @param d the maximum penalty; must be positive
     * @param d2 the error tolerance; cannot be negative
     * @param d3 the effective zero; cannot be negative
     * @param i the kernel cache size; cannot be negative
     * @param i2 the value handed to the superclass constructor (the iteration limit)
     * @param random the random number generator to use
     * @throws IllegalArgumentException if one of the setters rejects its value
     */
    public SequentialMinimalOptimization(Kernel<? super InputType> kernel, double d, double d2, double d3, int i, int i2, Random random) {
        super(i2);
        // Go through the setters so their range checks apply here too.
        setKernel(kernel);
        setMaxPenalty(d);
        setErrorTolerance(d2);
        setEffectiveZero(d3);
        setKernelCacheSize(i);
        setRandom(random);
    }

    /**
     * Resets all per-run state and validates the training data.
     *
     * <p>Only examples whose pair, input and output are all non-null are
     * kept. The run starts with a full sweep over the data, independent of
     * whatever state a previous run on this instance left behind.
     *
     * @return false if there is no data or no usable example; true once the
     *         learner is ready to step
     * @throws IllegalArgumentException if every usable example has the same label
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean initializeAlgorithm() {
        this.result = null;
        if (getData() == null) {
            return false;
        }

        // Keep only fully populated examples and count the positive ones.
        this.dataList = new ArrayList<>(getData().size());
        int positiveCount = 0;
        for (InputOutputPair<? extends InputType, Boolean> example : getData()) {
            if (example != null && example.getInput() != null && example.getOutput() != null) {
                this.dataList.add(example);
                if (example.getOutput().booleanValue()) {
                    positiveCount++;
                }
            }
        }

        this.dataSize = this.dataList.size();
        if (this.dataSize <= 0) {
            this.dataList = null;
            return false;
        }
        if (positiveCount <= 0 || positiveCount >= this.dataSize) {
            throw new IllegalArgumentException("Data is all one category");
        }

        this.changeCount = getData().size();
        // Always begin with a pass over every example. Without this reset the
        // flag keeps the value the previous run happened to end with.
        this.examineAll = true;
        this.supportsMap = new LinkedHashMap<>();
        this.nonBoundAlphaIndices = new LinkedHashSet<>();
        this.errorCache = new LinkedHashMap<>();

        if (this.kernelCacheSize > 1 && this.dataSize > 1) {
            // Square the size in long arithmetic: as an int product it wraps
            // for more than 46340 examples and can turn negative or zero.
            final int capacity = (int) Math.min(
                    (long) this.dataSize * (long) this.dataSize,
                    (long) this.kernelCacheSize);
            // Access-ordered map that drops its eldest entry once full.
            this.kernelCache = new LinkedHashMap<Long, Double>(capacity, 0.75f, true) {
                @Override
                protected boolean removeEldestEntry(Map.Entry<Long, Double> entry) {
                    return size() > capacity;
                }
            };
        }

        // The categorizer reads the live values view, so later changes to the
        // support map are visible through the result.
        this.result = new KernelBinaryCategorizer<>(this.kernel, this.supportsMap.values(), 0.0d);
        return true;
    }

    /**
     * Runs one sweep of the optimization: over every example when
     * {@code examineAll} is set, otherwise over the examples whose alpha is
     * strictly between zero and the maximum penalty.
     *
     * @return true if this sweep changed something or a full sweep has been
     *         scheduled for the next call; false otherwise
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected boolean step() {
        this.changeCount = 0;

        if (!this.examineAll) {
            // Iterate over a snapshot: examining an example may add to or
            // remove from the non-bound set.
            for (Integer index : new ArrayList<>(this.nonBoundAlphaIndices)) {
                final int example = index.intValue();
                final double alpha = getAlpha(example);
                // Skip entries that reached a bound earlier in this sweep.
                if (alpha > 0.0d && alpha < this.maxPenalty) {
                    this.changeCount += examineExample(example);
                }
            }
        } else {
            for (int example = 0; example < this.dataSize; example++) {
                this.changeCount += examineExample(example);
            }
        }

        final boolean changed = this.changeCount > 0;
        if (this.examineAll) {
            // A full sweep is always followed by a non-bound sweep.
            this.examineAll = false;
        } else if (!changed) {
            // The non-bound sweep made no progress; fall back to a full sweep.
            this.examineAll = true;
        }
        return changed || this.examineAll;
    }

    /**
     * Releases the working structures of a run: both caches, the index
     * collections and the copied data list. The learned result is kept.
     */
    @Override // gov.sandia.cognition.learning.algorithm.AbstractAnytimeBatchLearner
    protected void cleanupAlgorithm() {
        this.kernelCache = null;
        this.errorCache = null;
        this.nonBoundAlphaIndices = null;
        this.supportsMap = null;
        this.dataList = null;
    }

    /**
     * Checks whether the given example violates the KKT conditions by more
     * than the configured error tolerance and, if so, looks for a second
     * example to optimize it with.
     *
     * <p>Candidates are tried in this order: the non-bound example with the
     * most different error, then every non-bound example from a random
     * offset, then every example from a random offset.
     *
     * @param i index of the example to examine
     * @return 1 if a joint step was taken, 0 otherwise
     */
    protected int examineExample(int i) {
        final double penalty = this.maxPenalty;
        // The KKT check is governed by errorTolerance; effectiveZero only
        // controls how alphas are snapped inside takeStep.
        final double tolerance = this.errorTolerance;
        final double target = getTarget(i);
        final double alpha = getAlpha(i);
        final double error = getError(i);
        final double margin = error * target;

        final boolean lowerSideOk = margin >= -tolerance || alpha >= penalty;
        final boolean upperSideOk = margin <= tolerance || alpha <= 0.0d;
        if (lowerSideOk && upperSideOk) {
            return 0;
        }

        final int nonBoundCount = this.nonBoundAlphaIndices.size();

        // Heuristic 1: partner with the largest expected step size.
        if (nonBoundCount > 1) {
            final int partner = error > 0.0d ? getMinErrorIndex() : getMaxErrorIndex();
            if (takeStep(partner, i)) {
                return 1;
            }
        }

        // Heuristic 2: every non-bound example, starting at a random position.
        if (nonBoundCount > 0) {
            final int offset = this.random.nextInt(nonBoundCount);
            final ArrayList<Integer> candidates = new ArrayList<>(this.nonBoundAlphaIndices);
            for (int k = 0; k < nonBoundCount; k++) {
                if (takeStep(candidates.get((offset + k) % nonBoundCount).intValue(), i)) {
                    return 1;
                }
            }
        }

        // Heuristic 3: every example, starting at a random position.
        final int offset = this.random.nextInt(this.dataSize);
        for (int k = 0; k < this.dataSize; k++) {
            if (takeStep((offset + k) % this.dataSize, i)) {
                return 1;
            }
        }
        return 0;
    }

    /**
     * Tries to jointly optimize the alphas of two examples, updating the
     * alphas, the bias and the error cache on success.
     *
     * @param i index of the first example
     * @param i2 index of the second example
     * @return true if the alphas were changed; false if the indices are
     *         equal, the feasible segment is empty, the curvature is not
     *         negative, or the change would be below the effective zero
     */
    private boolean takeStep(int i, int i2) {
        if (i == i2) {
            return false;
        }

        final double penalty = this.maxPenalty;
        final double epsilon = this.effectiveZero;
        final double nearPenalty = penalty - epsilon;

        final double y1 = getTarget(i);
        final double e1 = getError(i);
        final double oldAlpha1 = getAlpha(i);
        final double y2 = getTarget(i2);
        final double e2 = getError(i2);
        final double oldAlpha2 = getAlpha(i2);

        // Ends of the segment the second alpha is allowed to move along.
        final double low;
        final double high;
        if (y1 != y2) {
            final double difference = oldAlpha2 - oldAlpha1;
            low = Math.max(0.0d, difference);
            high = Math.min(penalty, difference + penalty);
        } else {
            final double sum = oldAlpha1 + oldAlpha2;
            low = Math.max(0.0d, sum - penalty);
            high = Math.min(penalty, sum);
        }
        if (low >= high) {
            return false;
        }

        final double k11 = evaluateKernel(i, i);
        final double k12 = evaluateKernel(i, i2);
        final double k22 = evaluateKernel(i2, i2);
        // Second derivative of the objective along the segment.
        final double eta = ((k12 + k12) - k11) - k22;
        if (eta >= 0.0d) {
            return false;
        }

        // Unconstrained optimum for the second alpha, clipped to the segment.
        double newAlpha2 = oldAlpha2 - ((y2 * (e1 - e2)) / eta);
        if (newAlpha2 <= low) {
            newAlpha2 = low;
        } else if (newAlpha2 >= high) {
            newAlpha2 = high;
        }
        // Snap values within epsilon of a bound onto that bound.
        if (newAlpha2 < epsilon) {
            newAlpha2 = 0.0d;
        } else if (newAlpha2 > nearPenalty) {
            newAlpha2 = penalty;
        }
        if (Math.abs(newAlpha2 - oldAlpha2) < epsilon) {
            return false;
        }

        // The first alpha moves so that the weighted sum of alphas is kept.
        double newAlpha1 = oldAlpha1 + (y1 * y2 * (oldAlpha2 - newAlpha2));
        if (newAlpha1 < epsilon) {
            newAlpha1 = 0.0d;
        } else if (newAlpha1 > nearPenalty) {
            newAlpha1 = penalty;
        }

        final double oldBias = getBias();
        final double weightChange1 = y1 * (newAlpha1 - oldAlpha1);
        final double weightChange2 = y2 * (newAlpha2 - oldAlpha2);
        final double bias1 = ((oldBias - e1) - (weightChange1 * k11)) - (weightChange2 * k12);
        final double bias2 = ((oldBias - e2) - (weightChange1 * k12)) - (weightChange2 * k22);

        // Prefer the bias of whichever alpha ended strictly inside the
        // bounds; when both sit on a bound, use the midpoint.
        final boolean alpha1AtBound = newAlpha1 <= epsilon || newAlpha1 >= nearPenalty;
        final boolean alpha2AtBound = newAlpha2 <= epsilon || newAlpha2 >= nearPenalty;
        final double newBias;
        if (!alpha1AtBound) {
            newBias = bias1;
        } else if (!alpha2AtBound) {
            newBias = bias2;
        } else {
            newBias = (bias1 + bias2) / 2.0d;
        }

        setAlpha(i, newAlpha1);
        setAlpha(i2, newAlpha2);
        setBias(newBias);
        updateErrorCache(i, y1, oldAlpha1, newAlpha1, i2, y2, oldAlpha2, newAlpha2, oldBias, newBias);
        return true;
    }

    /**
     * Brings the error cache up to date after a successful joint step.
     *
     * <p>Entries of the two stepped examples are dropped when their new
     * alpha is on a bound. Every non-bound example then gets a fresh entry:
     * zero for the two stepped examples, a full recomputation when nothing
     * was cached, and an incremental update otherwise.
     *
     * @param i index of the first stepped example
     * @param d target of the first example
     * @param d2 old alpha of the first example
     * @param d3 new alpha of the first example
     * @param i2 index of the second stepped example
     * @param d4 target of the second example
     * @param d5 old alpha of the second example
     * @param d6 new alpha of the second example
     * @param d7 old bias
     * @param d8 new bias
     */
    private void updateErrorCache(int i, double d, double d2, double d3, int i2, double d4, double d5, double d6, double d7, double d8) {
        final double penalty = this.maxPenalty;
        if (d3 <= 0.0d || d3 >= penalty) {
            this.errorCache.remove(Integer.valueOf(i));
        }
        if (d6 <= 0.0d || d6 >= penalty) {
            this.errorCache.remove(Integer.valueOf(i2));
        }

        final double weightChange1 = d * (d3 - d2);
        final double weightChange2 = d4 * (d6 - d5);
        final double biasChange = d8 - d7;

        for (Integer index : this.nonBoundAlphaIndices) {
            final int k = index.intValue();
            final double updated;
            if (k == i || k == i2) {
                updated = 0.0d;
            } else {
                final Double cached = this.errorCache.get(index);
                if (cached == null) {
                    updated = getSVMOutput(k) - getTarget(k);
                } else {
                    updated = cached.doubleValue()
                            + (weightChange1 * evaluateKernel(i, k))
                            + (weightChange2 * evaluateKernel(i2, k))
                            + biasChange;
                }
            }
            this.errorCache.put(index, Double.valueOf(updated));
        }
    }

    /**
     * Evaluates the kernel between two training examples, going through the
     * kernel cache when one is configured.
     *
     * <p>The cache key ignores argument order, so (a, b) and (b, a) share an
     * entry: the smaller index fills the upper 32 bits of the key and the
     * larger index the lower 32 bits.
     *
     * @param i index of the first example
     * @param i2 index of the second example
     * @return the kernel value for the two examples
     */
    private double evaluateKernel(int i, int i2) {
        if (this.kernelCache == null) {
            return this.kernel.evaluate(getPoint(i), getPoint(i2));
        }

        // Widen before shifting. An int shifted by 32 is shifted by 0, which
        // would reduce the key to (i | i2) and make distinct pairs collide.
        final long smaller = Math.min(i, i2);
        final long larger = Math.max(i, i2);
        final Long key = Long.valueOf((smaller << 32) | larger);

        final Double cached = this.kernelCache.get(key);
        if (cached != null) {
            return cached.doubleValue();
        }
        final double value = this.kernel.evaluate(getPoint(i), getPoint(i2));
        this.kernelCache.put(key, Double.valueOf(value));
        return value;
    }

    /**
     * Computes the current raw output for a training example: the bias plus
     * each support's weight times its kernel value with the example.
     *
     * @param i index of the example
     * @return the current output for that example
     */
    private double getSVMOutput(int i) {
        double output = getBias();
        for (Map.Entry<Integer, DefaultWeightedValue<InputType>> support : this.supportsMap.entrySet()) {
            final int supportIndex = support.getKey().intValue();
            final double weight = support.getValue().getWeight();
            output += weight * evaluateKernel(i, supportIndex);
        }
        return output;
    }

    /**
     * Returns the error (output minus target) of an example, using the
     * cached value when there is one.
     *
     * @param i index of the example
     * @return the cached or freshly computed error
     */
    private double getError(int i) {
        final Double cached = this.errorCache.get(Integer.valueOf(i));
        if (cached == null) {
            return getSVMOutput(i) - getTarget(i);
        }
        return cached.doubleValue();
    }

    /**
     * Finds the non-bound example with the smallest error.
     *
     * @return its index, or -1 if no error compared below positive infinity
     */
    private int getMinErrorIndex() {
        int bestIndex = -1;
        double smallest = Double.POSITIVE_INFINITY;
        for (Integer candidate : this.nonBoundAlphaIndices) {
            final double error = getError(candidate.intValue());
            if (error < smallest) {
                smallest = error;
                bestIndex = candidate.intValue();
            }
        }
        return bestIndex;
    }

    /**
     * Finds the non-bound example with the largest error.
     *
     * @return its index, or -1 if no error compared above negative infinity
     */
    private int getMaxErrorIndex() {
        int bestIndex = -1;
        double largest = Double.NEGATIVE_INFINITY;
        for (Integer candidate : this.nonBoundAlphaIndices) {
            final double error = getError(candidate.intValue());
            if (error > largest) {
                largest = error;
                bestIndex = candidate.intValue();
            }
        }
        return bestIndex;
    }

    /**
     * Returns the input of a training example.
     *
     * @param i index into the filtered data list
     * @return the input stored at that index
     */
    private InputType getPoint(int i) {
        final InputOutputPair<? extends InputType, Boolean> example = this.dataList.get(i);
        return example.getInput();
    }

    /**
     * Returns the label of a training example as a signed value.
     *
     * @param i index into the filtered data list
     * @return +1.0 when the output is true, -1.0 when it is false
     */
    private double getTarget(int i) {
        if (this.dataList.get(i).getOutput().booleanValue()) {
            return 1.0d;
        }
        return -1.0d;
    }

    /**
     * Returns the alpha of an example. Supports store the signed weight
     * (target times alpha), so the alpha is its absolute value.
     *
     * @param i index of the example
     * @return the alpha, or 0.0 when the example is not a support
     */
    private double getAlpha(int i) {
        final DefaultWeightedValue<InputType> support = this.supportsMap.get(Integer.valueOf(i));
        return support == null ? 0.0d : Math.abs(support.getWeight());
    }

    /**
     * Stores a new alpha for an example and keeps the support map and the
     * non-bound index set in step with it.
     *
     * @param i index of the example
     * @param d the new alpha; exactly 0.0 removes the example from the supports
     */
    private void setAlpha(int i, double d) {
        final Integer key = Integer.valueOf(i);

        if (d == 0.0d) {
            // No longer a support, and therefore not non-bound either.
            this.supportsMap.remove(key);
            this.nonBoundAlphaIndices.remove(key);
            return;
        }

        // The stored weight carries the sign of the example's target.
        final double weight = getTarget(i) * d;
        final DefaultWeightedValue<InputType> existing = this.supportsMap.get(key);
        if (existing != null) {
            existing.setWeight(weight);
        } else {
            this.supportsMap.put(key, new DefaultWeightedValue<>(getPoint(i), weight));
        }

        if (d != this.maxPenalty) {
            this.nonBoundAlphaIndices.add(key);
        } else {
            this.nonBoundAlphaIndices.remove(key);
        }
    }

    /**
     * Reads the bias of the categorizer being learned.
     *
     * @return the current bias
     */
    private double getBias() {
        return result.getBias();
    }

    /**
     * Writes the bias of the categorizer being learned.
     *
     * @param d the new bias
     */
    private void setBias(double d) {
        result.setBias(d);
    }

    /**
     * Returns the categorizer learned so far.
     *
     * <p>The decompiler had renamed this method to {@code m105getResult}
     * when it merged it with its bridge method. The original name is
     * restored so the class overrides the inherited accessor again.
     *
     * @return the current result, or null if learning has not been
     *         initialized successfully
     */
    @Override
    public KernelBinaryCategorizer<InputType, DefaultWeightedValue<InputType>> getResult() {
        return this.result;
    }

    /**
     * Returns the kernel used by this learner.
     *
     * @return the kernel; may be null if none was set
     */
    @Override // gov.sandia.cognition.learning.function.kernel.KernelContainer
    public Kernel<? super InputType> getKernel() {
        return kernel;
    }

    /**
     * Sets the kernel used by this learner. No validation is performed.
     *
     * @param kernel the kernel to use
     */
    public void setKernel(Kernel<? super InputType> kernel) {
        this.kernel = kernel;
    }

    /**
     * Returns the upper bound applied to every alpha.
     *
     * @return the maximum penalty
     */
    public double getMaxPenalty() {
        return maxPenalty;
    }

    /**
     * Sets the upper bound applied to every alpha.
     *
     * @param d the maximum penalty; must be greater than zero (positive
     *          infinity is allowed)
     * @throws IllegalArgumentException if {@code d} is zero, negative or NaN
     */
    public void setMaxPenalty(double d) {
        // NaN fails every ordered comparison, so it needs an explicit check.
        if (Double.isNaN(d) || d <= 0.0d) {
            throw new IllegalArgumentException("maxPenalty must be positive.");
        }
        this.maxPenalty = d;
    }

    /**
     * Returns the configured error tolerance.
     *
     * @return the error tolerance
     */
    public double getErrorTolerance() {
        return errorTolerance;
    }

    /**
     * Sets the error tolerance.
     *
     * @param d the error tolerance; zero or greater
     * @throws IllegalArgumentException if {@code d} is negative or NaN
     */
    public void setErrorTolerance(double d) {
        // NaN fails every ordered comparison, so it needs an explicit check.
        if (Double.isNaN(d) || d < 0.0d) {
            throw new IllegalArgumentException("errorTolerance cannot be negative.");
        }
        this.errorTolerance = d;
    }

    /**
     * Returns the threshold below which values are treated as zero.
     *
     * @return the effective zero
     */
    public double getEffectiveZero() {
        return effectiveZero;
    }

    /**
     * Sets the threshold below which values are treated as zero.
     *
     * @param d the effective zero; zero or greater
     * @throws IllegalArgumentException if {@code d} is negative or NaN
     */
    public void setEffectiveZero(double d) {
        // NaN fails every ordered comparison, so it needs an explicit check.
        if (Double.isNaN(d) || d < 0.0d) {
            throw new IllegalArgumentException("effectiveZero cannot be negative.");
        }
        this.effectiveZero = d;
    }

    /**
     * Returns the maximum number of kernel values that may be cached.
     *
     * @return the kernel cache size
     */
    public int getKernelCacheSize() {
        return kernelCacheSize;
    }

    /**
     * Sets the maximum number of kernel values that may be cached. The new
     * value is read when a run is initialized.
     *
     * @param i the kernel cache size; zero or greater
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public void setKernelCacheSize(int i) {
        final boolean negative = i < 0;
        if (negative) {
            throw new IllegalArgumentException("kernelCacheSize cannot be negative");
        }
        this.kernelCacheSize = i;
    }

    /**
     * Returns the random number generator used by this learner.
     *
     * @return the random number generator
     */
    public Random getRandom() {
        return random;
    }

    /**
     * Sets the random number generator used by this learner. No validation
     * is performed.
     *
     * @param random the random number generator to use
     */
    public void setRandom(Random random) {
        this.random = random;
    }

    /**
     * Returns the change counter: the data size right after initialization,
     * then the number of successful updates made by the latest step.
     *
     * @return the current change count
     */
    public int getChangeCount() {
        return changeCount;
    }
}
