/*
 * Decompiled with CFR 0.152.
 */
package org.openmetadata.service.util.incidentSeverityClassifier;

import java.util.Arrays;
import java.util.List;
import org.openmetadata.common.utils.CommonUtil;
import org.openmetadata.schema.EntityInterface;
import org.openmetadata.schema.tests.type.Severity;
import org.openmetadata.schema.type.TagLabel;
import org.openmetadata.service.util.incidentSeverityClassifier.IncidentSeverityClassifierInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Predicts an incident {@link Severity} for an entity with a multinomial logistic-regression model.
 *
 * <p>The model scores four features of the entity: its tier (from a {@code Tier1}..{@code Tier5}
 * tag), whether it has owners, its follower count and its up-vote count. The weighted sums are
 * turned into class probabilities with softmax, and the most probable class is mapped to
 * {@code Severity1}..{@code Severity5}. The model has no intercept term. Entities without a tier
 * tag are not classified.
 */
public class LogisticRegressionIncidentSeverityClassifier extends IncidentSeverityClassifierInterface {
    private static final Logger LOG = LoggerFactory.getLogger(LogisticRegressionIncidentSeverityClassifier.class);

    /**
     * Model weights: one row per input feature, one column per output class.
     *
     * <p>Rows follow the feature order built by {@code getVectorX}: tier, hasOwner, followers,
     * up-votes. Columns follow the class order mapped in {@link #classifyIncidentSeverity}:
     * Severity1..Severity5.
     */
    static final double[][] coefMatrix = new double[][]{{-39.7199427, -3.16664212, 7.52955733, 16.7600252, 18.5970022}, {65.6563864, 9.33015912, -3.11353307, -13.7841793, -58.0888332}, {0.0102508192, 0.00490356651, -0.00162766138, -0.00622724217, -0.0072994822}, {0.0784018717, -0.01140259, -0.00911123152, -0.0237962385, -0.0340918118}};

    /**
     * Classifies the incident severity of {@code entity}.
     *
     * @param entity the entity the incident relates to
     * @return the predicted severity, or {@code null} when the entity has no Tier1-Tier5 tag or
     *     when scoring fails (the failure is logged)
     */
    @Override
    public Severity classifyIncidentSeverity(EntityInterface entity) {
        double[] vectorX = this.getVectorX(entity);
        if (vectorX.length == 0) {
            // No recognised tier tag: there is nothing meaningful to score.
            return null;
        }
        try {
            double[] vectorZ = this.dotProduct(vectorX);
            double[] softmaxVector = this.softmax(vectorZ);
            int predictedClass = this.argmax(softmaxVector);
            switch (predictedClass) {
                case 0:
                    return Severity.Severity1;
                case 1:
                    return Severity.Severity2;
                case 2:
                    return Severity.Severity3;
                case 3:
                    return Severity.Severity4;
                case 4:
                    return Severity.Severity5;
                default:
                    return null;
            }
        } catch (Exception e) {
            // Classification is best-effort: log and fall back to "no severity".
            LOG.error("Error occurred while classifying incident severity", e);
            return null;
        }
    }

    /**
     * Computes one logit per class: {@code z[c] = sum_f vectorX[f] * coefMatrix[f][c]}.
     *
     * @param vectorX feature vector with one entry per row of {@link #coefMatrix}
     * @return a vector with one logit per column of {@link #coefMatrix}
     */
    private double[] dotProduct(double[] vectorX) {
        int numClasses = coefMatrix[0].length;
        double[] result = new double[numClasses];
        // Iterate over every class (column), not over the feature rows, so the last class also
        // gets a logit. Accumulate in a double so fractional contributions are not truncated.
        for (int c = 0; c < numClasses; ++c) {
            double sum = 0.0;
            for (int f = 0; f < vectorX.length; ++f) {
                sum += vectorX[f] * coefMatrix[f][c];
            }
            result[c] = sum;
        }
        return result;
    }

    /**
     * Converts logits into probabilities that sum to 1.
     *
     * @param vectorZ logits, one per class
     * @return the softmax of {@code vectorZ}
     */
    private double[] softmax(double[] vectorZ) {
        // Subtract the largest logit before exponentiating so Math.exp cannot overflow to
        // Infinity (which would turn every probability into NaN); the shift cancels in the ratio.
        double maxZ = Arrays.stream(vectorZ).max().orElse(0.0);
        double[] softmax = new double[vectorZ.length];
        double expSum = 0.0;
        for (int i = 0; i < vectorZ.length; ++i) {
            softmax[i] = Math.exp(vectorZ[i] - maxZ);
            expSum += softmax[i];
        }
        for (int i = 0; i < softmax.length; ++i) {
            softmax[i] /= expSum;
        }
        return softmax;
    }

    /**
     * Returns the index of the largest value; on ties the first index wins.
     *
     * @param softmaxVector values to inspect
     * @return the index of the maximum, or 0 when no value exceeds negative infinity
     */
    private int argmax(double[] softmaxVector) {
        int maxIndex = 0;
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < softmaxVector.length; ++i) {
            if (softmaxVector[i] > best) {
                best = softmaxVector[i];
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Builds the model's feature vector {tier, hasOwner, followers, upVotes} for {@code entity}.
     *
     * @param entity the entity to featurize
     * @return the four-element feature vector, or an empty array when the entity has no
     *     Tier1-Tier5 tag
     */
    private double[] getVectorX(EntityInterface entity) {
        double tier = entity.getTags() != null ? this.getTier(entity.getTags()) : 0.0;
        if (tier == 0.0) {
            return new double[0];
        }
        double hasOwner = CommonUtil.nullOrEmpty(entity.getOwners()) ? 0.0 : 1.0;
        double followers = entity.getFollowers() != null ? (double) entity.getFollowers().size() : 0.0;
        double votes = 0.0;
        // Guard both the votes object and its up-vote count, since either may be unset.
        if (entity.getVotes() != null && entity.getVotes().getUpVotes() != null) {
            votes = entity.getVotes().getUpVotes().intValue();
        }
        return new double[]{tier, hasOwner, followers, votes};
    }

    /**
     * Extracts the numeric tier from the first tag named {@code Tier1}..{@code Tier5}.
     *
     * @param tags the entity's tags
     * @return 1.0-5.0 for the first matching tier tag, or 0.0 when none matches
     */
    private double getTier(List<TagLabel> tags) {
        for (TagLabel tag : tags) {
            String name = tag.getName();
            if (name == null) {
                continue;
            }
            switch (name) {
                case "Tier1":
                    return 1.0;
                case "Tier2":
                    return 2.0;
                case "Tier3":
                    return 3.0;
                case "Tier4":
                    return 4.0;
                case "Tier5":
                    return 5.0;
                default:
                    // Not a tier tag; keep looking.
                    break;
            }
        }
        return 0.0;
    }
}

