/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.carbon.ml.lifecycle.test;

import java.io.IOException;
import javax.ws.rs.core.Response;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.json.JSONArray;
import org.json.JSONException;
import org.testng.AssertJUnit;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import org.wso2.carbon.ml.MLTestUtils;
import org.wso2.carbon.ml.integration.common.utils.MLBaseTest;
import org.wso2.carbon.ml.integration.common.utils.MLHttpClient;
import org.wso2.carbon.ml.integration.common.utils.exception.MLHttpClientException;
import org.wso2.carbon.ml.integration.common.utils.exception.MLIntegrationBaseTestException;

@Test(groups={"diabetesDataset"})
public class Dataset1DiabetesTestCase
extends MLBaseTest {
    private MLHttpClient mlHttpclient;
    private static String modelName;
    private static int modelId;
    private CloseableHttpResponse response;
    private int versionSetId;
    private int projectId;

    @BeforeClass(alwaysRun=true)
    public void initTest() throws MLIntegrationBaseTestException, MLHttpClientException, IOException, JSONException {
        super.init();
        this.mlHttpclient = this.getMLHttpClient();
        String version = "1.0";
        int datasetId = this.createDataset("Diabetes", version, "artifacts/ML/data/pIndiansDiabetes.csv");
        this.versionSetId = this.getVersionSetId(datasetId, version);
        this.isDatasetProcessed(this.versionSetId, 120000L, 1000);
        this.projectId = this.createProject("Diabetes_Project", "Diabetes");
    }

    private void testPredictDiabetes() throws MLHttpClientException, JSONException {
        this.testPredictDiabetes(false);
    }

    private void testPredictDiabetes(boolean skipDecoding) throws MLHttpClientException, JSONException {
        String payload = "[[1,89,66,23,94,28.1,0.167,21],[2,197,70,45,543,30.5,0.158,53]]";
        String url = skipDecoding ? "/api/models/" + modelId + "/predict?skipDecoding=true" : "/api/models/" + modelId + "/predict";
        this.response = this.mlHttpclient.doHttpPost(url, payload);
        AssertJUnit.assertEquals((String)"Unexpected response received", (int)Response.Status.OK.getStatusCode(), (int)this.response.getStatusLine().getStatusCode());
        String reply = this.mlHttpclient.getResponseAsString(this.response);
        JSONArray predictions = new JSONArray(reply);
        AssertJUnit.assertEquals((String)("Expected 2 predictions but received only " + predictions.length()), (int)2, (int)predictions.length());
        if (skipDecoding) {
            AssertJUnit.assertEquals((String)("Expected a double value but found " + predictions.get(0)), (boolean)true, (boolean)(predictions.get(0) instanceof Double));
            AssertJUnit.assertEquals((String)("Expected a double value but found " + predictions.get(1)), (boolean)true, (boolean)(predictions.get(1) instanceof Double));
        }
    }

    private void testPredictDiabetesInvalidNumberOfFeatures() throws MLHttpClientException, JSONException {
        String payload = "[[1,89,66,23,94,28.1,0.167],[2,197,70,45,543,30.5,0.158]]";
        this.response = this.mlHttpclient.doHttpPost("/api/models/" + modelId + "/predict", payload);
        AssertJUnit.assertEquals((String)"Unexpected response received", (int)Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), (int)this.response.getStatusLine().getStatusCode());
    }

    private void testPredictDiabetesInvalidNumericalFeatures() throws MLHttpClientException, JSONException {
        String payload = "[[1,89,66,23,94,28afdc.1,0.167,21],[2,197,70,45,543,30.5,0.158,53]]";
        this.response = this.mlHttpclient.doHttpPost("/api/models/" + modelId + "/predict", payload);
        AssertJUnit.assertEquals((String)"Unexpected response received", (int)Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), (int)this.response.getStatusLine().getStatusCode());
    }

    private void testPredictDiabetesFromFile() throws MLHttpClientException, JSONException {
        this.response = this.mlHttpclient.predictFromCSV((long)modelId, "artifacts/ML/data/pIndiansDiabetesTest.csv");
        AssertJUnit.assertEquals((String)"Unexpected response received", (int)Response.Status.OK.getStatusCode(), (int)this.response.getStatusLine().getStatusCode());
        String reply = this.mlHttpclient.getResponseAsString(this.response);
        JSONArray predictions = new JSONArray(reply);
        AssertJUnit.assertEquals((int)7, (int)predictions.length());
    }

    private void buildModelWithLearningAlgorithm(String algorithmName, String algorithmType) throws MLHttpClientException, IOException, JSONException, InterruptedException {
        modelName = MLTestUtils.createModelWithConfigurations(algorithmName, algorithmType, "Class", "0.7", this.projectId, this.versionSetId, this.mlHttpclient);
        modelId = this.mlHttpclient.getModelId(modelName);
        this.response = this.mlHttpclient.doHttpPost("/api/models/" + modelId);
        AssertJUnit.assertEquals((String)"Unexpected response received", (int)Response.Status.OK.getStatusCode(), (int)this.response.getStatusLine().getStatusCode());
        this.response.close();
        boolean status = MLTestUtils.checkModelStatusCompleted(modelName, this.mlHttpclient, 120000L, 1000);
        AssertJUnit.assertEquals((String)"Model building did not complete successfully", (boolean)true, (boolean)status);
    }

    @Test(description="Build a Naive Bayes model and predict for Diabetes dataset", groups={"createNaiveBayesModelDiabetes"})
    public void testBuildNaiveBayesModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("NAIVE_BAYES", "Classification");
        this.testPredictDiabetes();
        this.testPredictDiabetesInvalidNumberOfFeatures();
    }

    @Test(description="Build a SVM model and predict for Diabetes dataset", groups={"createSVMModelDiabetes"}, dependsOnGroups={"createNaiveBayesModelDiabetes"})
    public void testBuildSVMModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("SVM", "Classification");
        this.testPredictDiabetes();
        this.testPredictDiabetesInvalidNumericalFeatures();
        this.testExportAsPMML(modelId);
        this.testPublishAsPMML(modelId);
    }

    @Test(description="Build a Decision Tree model and predict for Diabetes dataset", groups={"createDecisionTreeModelDiabetes"}, dependsOnGroups={"createSVMModelDiabetes"})
    public void testBuildDecisionTreeModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("DECISION_TREE", "Classification");
        this.testPredictDiabetes();
    }

    @Test(description="Build a Random Forest model and predict for Diabetes dataset", groups={"createRandomForestModelDiabetes"}, dependsOnGroups={"createDecisionTreeModelDiabetes"})
    public void testBuildRandomForestModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("RANDOM_FOREST_CLASSIFICATION", "Classification");
        this.testPredictDiabetes();
    }

    @Test(description="Build a Stacked Autoencoders model and predict for Diabetes dataset", groups={"createStackedAutoencodersModelDiabetes"}, dependsOnGroups={"createRandomForestModelDiabetes"})
    public void testBuildStackedAutoencodersModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("STACKED_AUTOENCODERS", "Deeplearning");
        this.testPredictDiabetes();
    }

    @Test(description="Build a Logistic Regression model and predict for Diabetes dataset", groups={"createLogisticRegressionDiabetes"}, dependsOnGroups={"createStackedAutoencodersModelDiabetes"})
    public void testBuildLogisticRegressionModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("LOGISTIC_REGRESSION", "Classification");
        this.testPredictDiabetes();
        this.testPredictDiabetes(true);
        this.testPredictDiabetesFromFile();
        this.testExportAsPMML(modelId);
        this.testPublishAsPMML(modelId);
    }

    @Test(description="Build a K-means model", groups={"createKMeansDiabetes"}, dependsOnGroups={"createLogisticRegressionDiabetes"})
    public void testBuildKMeansModel() throws MLHttpClientException, IOException, JSONException, InterruptedException {
        this.buildModelWithLearningAlgorithm("K_MEANS", "Clustering");
        this.testExportAsPMML(modelId);
        this.testPublishAsPMML(modelId);
    }

    @AfterClass(alwaysRun=true)
    public void tearDown() throws InterruptedException, MLHttpClientException {
        super.destroy();
    }
}

