/*
 * Decompiled with CFR 0.152.
 */
package org.streaminer.stream.classifier;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.streaminer.stream.classifier.AbstractClassifier;
import org.streaminer.stream.data.Data;
import org.streaminer.stream.learner.LearnerUtils;
import org.streaminer.stream.model.HyperplaneModel;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Online binary perceptron classifier.
 *
 * <p>The first two distinct values observed in the label attribute become the
 * class labels: the first one seen gets index 0, the second one index 1. Any
 * further label is rejected, because the update rule implemented here only
 * supports two classes. The feature vector of an example is the numeric map
 * returned by {@link LearnerUtils#getNumericVector}; the decision function is
 * held by a {@link HyperplaneModel}.</p>
 *
 * <p>On a misclassified example the bias is moved by {@code learnRate * d} and
 * every weight by {@code learnRate * d * x_i}, where {@code d} is {@code -1}
 * for label index 0 and {@code +1} for label index 1.</p>
 */
public class Perceptron extends AbstractClassifier<Data, String> {
    private static final long serialVersionUID = -3263838547557335984L;
    private static final Logger LOG = LoggerFactory.getLogger(Perceptron.class);

    /** Model outputs below this value map to label index 0, all others to index 1. */
    private static final double DECISION_THRESHOLD = 0.5;

    /** Step size for bias/weight corrections; must be non-null when {@link #learn} runs. */
    private Double learnRate;
    /** Attribute holding the class label; auto-detected on the first {@link #learn} if unset. */
    private String labelAttribute;
    /** Known class labels (at most two); a label's position is its label index. */
    private List<String> labels = new ArrayList<String>();
    /**
     * Not used by the algorithm. Retained only so the serialized form of this
     * class stays unchanged for previously persisted instances.
     */
    private List<String> attributes = new ArrayList<String>();
    /** Hyperplane holding the bias and the per-attribute weights. */
    private HyperplaneModel model;

    /** Creates a perceptron with kernel type 1 and a learning rate of 0.05. */
    public Perceptron() {
        this(1, 0.05);
    }

    /**
     * Creates a perceptron with kernel type 1 and an integral learning rate.
     * Kept for backward compatibility; {@link #Perceptron(double)} also accepts
     * fractional rates.
     *
     * @param learnRate step size used for corrections
     */
    public Perceptron(int learnRate) {
        this(1, (double) learnRate);
    }

    /**
     * Creates a perceptron with kernel type 1 and the given learning rate.
     *
     * @param learnRate step size used for corrections
     */
    public Perceptron(double learnRate) {
        this(1, learnRate);
    }

    /**
     * Creates a perceptron backed by a {@link HyperplaneModel} of the given
     * kernel type, starting from an empty weight map and a zero bias.
     *
     * @param kernelType kernel type passed through to {@link HyperplaneModel}
     * @param learnRate  step size used for corrections
     */
    public Perceptron(int kernelType, double learnRate) {
        this.model = new HyperplaneModel(kernelType);
        this.model.initModel(new LinkedHashMap<String, Double>(), 0.0);
        this.learnRate = learnRate;
    }

    /** @return the name of the label attribute, or {@code null} if not yet known */
    public String getLabelAttribute() {
        return this.labelAttribute;
    }

    /** @param labelAttribute name of the attribute holding the class label */
    public void setLabelAttribute(String labelAttribute) {
        this.labelAttribute = labelAttribute;
    }

    /** @return the current learning rate */
    public Double getLearnRate() {
        return this.learnRate;
    }

    /** @param learnRate new learning rate; must not be {@code null} when learning */
    public void setLearnRate(Double learnRate) {
        this.learnRate = learnRate;
    }

    /**
     * Updates the model with one labelled example.
     *
     * <p>Examples without a label, with a third distinct label, or without any
     * numeric attribute are logged and ignored. Otherwise the model is only
     * changed when its current prediction disagrees with the example's label.</p>
     *
     * @param item labelled example
     */
    @Override
    public void learn(Data item) {
        if (this.labelAttribute == null) {
            this.labelAttribute = LearnerUtils.detectLabelAttribute(item);
        }
        if (this.labelAttribute == null) {
            LOG.info("No label defined!");
            return;
        }
        Object rawLabel = item.get(this.labelAttribute);
        if (rawLabel == null) {
            LOG.error("No label found for example!");
            return;
        }
        int labelIndex = this.resolveLabelIndex(rawLabel.toString());
        if (labelIndex < 0) {
            return;
        }
        Map<String, Double> example = LearnerUtils.getNumericVector(item);
        if (example.isEmpty()) {
            LOG.info("No numerical attributes found for learning! Ignoring example!");
            return;
        }
        Double prediction = this.model.predict(item);
        // Decide correctness with the same rule predict() uses, so training
        // and classification agree on what the model currently outputs.
        if (prediction == null || toLabelIndex(prediction) == labelIndex) {
            return;
        }
        double delta = this.learnRate * (labelIndex == 0 ? -1.0 : 1.0);
        this.model.setBias(this.model.getBias() + delta);
        Map<String, Double> weights = this.model.getWeights();
        // Only attributes present in this example have a non-zero value, so
        // they are the only weights the update can change.
        for (Map.Entry<String, Double> feature : example.entrySet()) {
            Double value = feature.getValue();
            if (value == null) {
                continue;
            }
            Double weight = weights.get(feature.getKey());
            double current = weight == null ? 0.0 : weight;
            weights.put(feature.getKey(), current + delta * value);
        }
        this.model.setWeights(weights);
    }

    /**
     * Returns the index of {@code label}, registering it if fewer than two
     * labels are known yet.
     *
     * @param label textual class label
     * @return 0 or 1, or -1 if the label is unknown and both slots are taken
     */
    private int resolveLabelIndex(String label) {
        int labelIndex = this.labels.indexOf(label);
        if (labelIndex < 0 && this.labels.size() < 2) {
            LOG.info("Adding label '{}'", label);
            this.labels.add(label);
            labelIndex = this.labels.size() - 1;
        }
        if (labelIndex < 0) {
            LOG.error("My labels are {}, unknown label: {}", this.labels, label);
            if (this.labels.size() == 2) {
                LOG.error("The perceptron algorithm only works for binary classification tasks!");
            }
        }
        return labelIndex;
    }

    /** Maps a raw model output to label index 0 or 1. */
    private static int toLabelIndex(double prediction) {
        return prediction < DECISION_THRESHOLD ? 0 : 1;
    }

    /**
     * Predicts the label of an example.
     *
     * @param item example to classify
     * @return {@code "?"} if no label is known or the model gives no output,
     *         the single known label if only one has been seen, otherwise the
     *         label selected by thresholding the model output
     */
    @Override
    public String predict(Data item) {
        if (this.labels.isEmpty()) {
            LOG.warn("No labels available, predicting '?'!");
            return "?";
        }
        if (this.labels.size() == 1) {
            LOG.warn("Only 1 label available, predicting '{}'!", this.labels.get(0));
            return this.labels.get(0);
        }
        Double pred = this.model.predict(item);
        if (pred == null) {
            LOG.warn("Model returned no prediction, predicting '?'!");
            return "?";
        }
        return this.labels.get(toLabelIndex(pred));
    }
}

