/*
 * Decompiled with CFR 0.152.
 */
package org.apache.mahout.benchmark;

import java.io.IOException;
import org.apache.mahout.benchmark.VectorBenchmarks;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.common.TimingStatistics;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.SparseMatrix;
import org.apache.mahout.math.Vector;

/**
 * Benchmarks nearest-centroid search for a {@link DistanceMeasure}, comparing a brute-force scan over
 * every centroid with a scan that uses Elkan's triangle-inequality bound to skip centroids that
 * cannot be closer than the current best one.
 *
 * <p>Vectors, centroids, loop counts and time limits are read from the supplied
 * {@link VectorBenchmarks}; results are reported through its {@code printStats} method.
 */
public class ClosestCentroidBenchmark {
    public static final String SERIALIZE = "Serialize";
    public static final String DESERIALIZE = "Deserialize";
    private final VectorBenchmarks mark;

    /**
     * @param mark the benchmark harness supplying vectors, centroids and timing parameters
     */
    public ClosestCentroidBenchmark(VectorBenchmarks mark) {
        this.mark = mark;
    }

    /**
     * Runs the brute-force and the Elkan-pruned nearest-centroid searches with {@code measure} and
     * reports, for each, the timing statistics and the number of distance evaluations performed.
     *
     * <p>The pruning is only guaranteed to find the true nearest centroid when {@code measure}
     * satisfies the triangle inequality.
     *
     * @param measure the distance measure under test
     * @throws IOException propagated from the harness (presumably from {@code printStats} — confirm)
     */
    public void benchmark(DistanceMeasure measure) throws IOException {
        SparseMatrix clusterDistances = computeCentroidDistances(measure);
        benchmarkBruteForce(measure);
        benchmarkElkan(measure, clusterDistances);
    }

    /** Builds the pairwise centroid-to-centroid distance table; the diagonal is +Infinity. */
    private SparseMatrix computeCentroidDistances(DistanceMeasure measure) {
        SparseMatrix clusterDistances = new SparseMatrix(mark.numClusters, mark.numClusters);
        for (int i = 0; i < mark.numClusters; i++) {
            for (int j = 0; j < mark.numClusters; j++) {
                double distance = Double.POSITIVE_INFINITY;
                if (i != j) {
                    distance = measure.distance(mark.clusters[i], mark.clusters[j]);
                }
                clusterDistances.setQuick(i, j, distance);
            }
        }
        return clusterDistances;
    }

    /** Times the brute-force search that evaluates the distance to every centroid for every vector. */
    private void benchmarkBruteForce(DistanceMeasure measure) throws IOException {
        long distanceCalculations = 0L;
        TimingStatistics stats = new TimingStatistics();
        for (int l = 0; l < mark.loop; l++) {
            TimingStatistics.Call call = stats.newCall(mark.leadTimeUsec);
            for (int i = 0; i < mark.numVectors; i++) {
                Vector vector = mark.vectors[1][mark.vIndex(i)];
                double minDistance = Double.MAX_VALUE;
                for (int k = 0; k < mark.numClusters; k++) {
                    double distance = measure.distance(vector, mark.clusters[k]);
                    distanceCalculations++;
                    if (distance < minDistance) {
                        minDistance = distance;
                    }
                }
            }
            if (call.end(mark.maxTimeUsec)) {
                break;
            }
        }
        mark.printStats(stats, measure.getClass().getName(), "Closest C w/o Elkan's trick",
            "distanceCalculations = " + distanceCalculations);
    }

    /**
     * Times the Elkan-pruned search. Starting from a randomly chosen centroid {@code c}, a centroid
     * {@code k} is skipped when {@code d(c, k) >= 2 * d(x, c)}: by the triangle inequality
     * {@code d(x, k) >= d(c, k) - d(x, c) >= d(x, c)}, so {@code k} cannot be strictly closer.
     */
    private void benchmarkElkan(DistanceMeasure measure, SparseMatrix clusterDistances)
        throws IOException {
        long distanceCalculations = 0L;
        TimingStatistics stats = new TimingStatistics();
        RandomWrapper rand = RandomUtils.getRandom();
        for (int l = 0; l < mark.loop; l++) {
            TimingStatistics.Call call = stats.newCall(mark.leadTimeUsec);
            for (int i = 0; i < mark.numVectors; i++) {
                Vector vector = mark.vectors[1][mark.vIndex(i)];
                int closestCentroid = rand.nextInt(mark.numClusters);
                double minDistance = measure.distance(vector, mark.clusters[closestCentroid]);
                distanceCalculations++;
                for (int k = 0; k < mark.numClusters; k++) {
                    if (k == closestCentroid) {
                        continue;
                    }
                    // Only pay for d(x, k) when the triangle-inequality bound cannot rule k out.
                    if (clusterDistances.getQuick(k, closestCentroid) < 2.0 * minDistance) {
                        double distance = measure.distance(vector, mark.clusters[k]);
                        distanceCalculations++;
                        // A failed bound does not mean k is closer; adopt it only if it actually is.
                        if (distance < minDistance) {
                            minDistance = distance;
                            closestCentroid = k;
                        }
                    }
                }
            }
            if (call.end(mark.maxTimeUsec)) {
                break;
            }
        }
        mark.printStats(stats, measure.getClass().getName(), "Closest C w/ Elkan's trick",
            "distanceCalculations = " + distanceCalculations);
    }
}

