package hex.genmodel.tools;

import au.com.bytecode.opencsv.CSVReader;
import hex.ModelCategory;
import hex.genmodel.GenModel;
import hex.genmodel.MojoModel;
import hex.genmodel.algos.deepwater.caffe.nano.Deepwater;
import hex.genmodel.algos.glrm.GlrmMojoModel;
import hex.genmodel.algos.tree.SharedTreeMojoModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;
import hex.genmodel.easy.prediction.BinomialModelPrediction;
import hex.genmodel.easy.prediction.DimReductionModelPrediction;
import hex.genmodel.easy.prediction.MultinomialModelPrediction;
import hex.genmodel.easy.prediction.OrdinalModelPrediction;
import hex.genmodel.easy.prediction.RegressionModelPrediction;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:hex/genmodel/tools/PredictCsv.class */
public class PredictCsv {
    private String inputCSVFileName;
    private String outputCSVFileName;
    private boolean useDecimalOutput = false;
    public char separator = ',';
    public boolean setInvNumNA = false;
    public boolean getTreePath = false;
    boolean returnGLRMReconstruct = false;
    public int glrmIterNumber = -1;
    private EasyPredictModelWrapper model;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: hex.genmodel.tools.PredictCsv$1, reason: invalid class name */
    /* loaded from: input_file:hex/genmodel/tools/PredictCsv$1.class */
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$hex$ModelCategory = new int[ModelCategory.values().length];

        static {
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.AutoEncoder.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.Binomial.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.Multinomial.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.Ordinal.ordinal()] = 4;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.Clustering.ordinal()] = 5;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.Regression.ordinal()] = 6;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$hex$ModelCategory[ModelCategory.DimReduction.ordinal()] = 7;
            } catch (NoSuchFieldError e7) {
            }
        }
    }

    public static void main(String[] strArr) {
        PredictCsv predictCsv = new PredictCsv();
        predictCsv.parseArgs(strArr);
        try {
            predictCsv.run();
        } catch (Exception e) {
            e.printStackTrace();
            System.exit(2);
        }
        System.exit(0);
    }

    /* JADX WARN: Failed to find 'out' block for switch in B:5:0x0031. Please report as an issue. */
    private static RowData formatDataRow(String[] strArr, String[] strArr2) {
        RowData rowData = new RowData();
        int min = Math.min(strArr2.length, strArr.length);
        for (int i = 0; i < min; i++) {
            String str = strArr2[i];
            String str2 = strArr[i];
            boolean z = -1;
            switch (str2.hashCode()) {
                case Deepwater.Create /* 0 */:
                    if (str2.equals("")) {
                        z = false;
                        break;
                    }
                    break;
                case 45:
                    if (str2.equals("-")) {
                        z = 3;
                        break;
                    }
                    break;
                case 2483:
                    if (str2.equals("NA")) {
                        z = true;
                        break;
                    }
                    break;
                case 76480:
                    if (str2.equals("N/A")) {
                        z = 2;
                        break;
                    }
                    break;
            }
            switch (z) {
                case Deepwater.Create /* 0 */:
                case Deepwater.Train /* 1 */:
                case Deepwater.Predict /* 2 */:
                case Deepwater.SaveGraph /* 3 */:
                    break;
                default:
                    rowData.put(str, str2);
                    break;
            }
        }
        return rowData;
    }

    private String myDoubleToString(double d) {
        return Double.isNaN(d) ? "NA" : this.useDecimalOutput ? Double.toString(d) : Double.toHexString(d);
    }

    private void writeTreePathNames(BufferedWriter bufferedWriter) throws Exception {
        String[] decisionPathNames = ((SharedTreeMojoModel) this.model.m).getDecisionPathNames();
        int length = decisionPathNames.length - 1;
        for (int i = 0; i < length; i++) {
            bufferedWriter.write(decisionPathNames[i]);
            bufferedWriter.write(",");
        }
        bufferedWriter.write(decisionPathNames[length]);
    }

    private void run() throws Exception {
        int i;
        String str;
        ModelCategory modelCategory = this.model.getModelCategory();
        CSVReader cSVReader = new CSVReader(new FileReader(this.inputCSVFileName), this.separator);
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(this.outputCSVFileName));
        int i2 = -1;
        switch (AnonymousClass1.$SwitchMap$hex$ModelCategory[modelCategory.ordinal()]) {
            case Deepwater.Train /* 1 */:
                String[] names = this.model.m.getNames();
                int size = this.model.domainMap.size();
                int nfeatures = this.model.m.nfeatures() - size;
                String[][] domainValues = this.model.m.getDomainValues();
                int i3 = size - 1;
                for (int i4 = 0; i4 <= i3; i4++) {
                    String[] strArr = domainValues[i4];
                    int length = strArr.length - 1;
                    for (int i5 = 0; i5 <= length; i5++) {
                        i2++;
                        bufferedWriter.write("reconstr_" + strArr[i5]);
                        bufferedWriter.write(44);
                    }
                    i2++;
                    bufferedWriter.write("reconstr_" + names[i4] + ".missing(NA)");
                    if (nfeatures > 0 || i4 < i3) {
                        bufferedWriter.write(44);
                    }
                }
                int length2 = names.length - 1;
                for (int i6 = size; i6 < names.length; i6++) {
                    i2++;
                    bufferedWriter.write("reconstr_" + names[i6]);
                    if (i6 < length2) {
                        bufferedWriter.write(44);
                    }
                }
                break;
            case Deepwater.Predict /* 2 */:
            case Deepwater.SaveGraph /* 3 */:
                if (this.getTreePath) {
                    writeTreePathNames(bufferedWriter);
                    break;
                } else {
                    bufferedWriter.write("predict");
                    for (String str2 : this.model.getResponseDomainValues()) {
                        bufferedWriter.write(",");
                        bufferedWriter.write(str2);
                    }
                    break;
                }
            case Deepwater.Save /* 4 */:
                bufferedWriter.write("predict");
                for (String str3 : this.model.getResponseDomainValues()) {
                    bufferedWriter.write(",");
                    bufferedWriter.write(str3);
                }
                break;
            case Deepwater.Load /* 5 */:
                bufferedWriter.write("cluster");
                break;
            case 6:
                if (this.getTreePath) {
                    writeTreePathNames(bufferedWriter);
                    break;
                } else {
                    bufferedWriter.write("predict");
                    break;
                }
            case 7:
                String[] names2 = this.model.m.getNames();
                if (this.returnGLRMReconstruct) {
                    i = ((GlrmMojoModel) this.model.m)._permutation.length;
                    str = "reconstr_";
                } else {
                    i = ((GlrmMojoModel) this.model.m)._ncolX;
                    str = "Arch";
                }
                int i7 = i - 1;
                for (int i8 = 0; i8 < i; i8++) {
                    bufferedWriter.write(this.returnGLRMReconstruct ? str + names2[i8] : str + (i8 + 1));
                    if (i8 < i7) {
                        bufferedWriter.write(44);
                    }
                }
                break;
            default:
                throw new Exception("Unknown model category " + modelCategory);
        }
        bufferedWriter.write("\n");
        int i9 = 1;
        try {
            try {
                String[] readNext = cSVReader.readNext();
                if (readNext == null) {
                    throw new Exception("Input dataset file is empty!");
                }
                checkMissingColumns(readNext);
                while (true) {
                    String[] readNext2 = cSVReader.readNext();
                    if (readNext2 == null) {
                        bufferedWriter.close();
                        cSVReader.close();
                        return;
                    }
                    RowData formatDataRow = formatDataRow(readNext2, readNext);
                    switch (AnonymousClass1.$SwitchMap$hex$ModelCategory[modelCategory.ordinal()]) {
                        case Deepwater.Train /* 1 */:
                            AutoEncoderModelPrediction predictAutoEncoder = this.model.predictAutoEncoder(formatDataRow);
                            for (int i10 = 0; i10 < predictAutoEncoder.reconstructed.length; i10++) {
                                bufferedWriter.write(myDoubleToString(predictAutoEncoder.reconstructed[i10]));
                                if (i10 < i2) {
                                    bufferedWriter.write(44);
                                }
                            }
                            break;
                        case Deepwater.Predict /* 2 */:
                            BinomialModelPrediction predictBinomial = this.model.predictBinomial(formatDataRow);
                            if (this.getTreePath) {
                                writeTreePaths(predictBinomial.leafNodeAssignments, bufferedWriter);
                                break;
                            } else {
                                bufferedWriter.write(predictBinomial.label);
                                bufferedWriter.write(",");
                                for (int i11 = 0; i11 < predictBinomial.classProbabilities.length; i11++) {
                                    if (i11 > 0) {
                                        bufferedWriter.write(",");
                                    }
                                    bufferedWriter.write(myDoubleToString(predictBinomial.classProbabilities[i11]));
                                }
                                break;
                            }
                        case Deepwater.SaveGraph /* 3 */:
                            MultinomialModelPrediction predictMultinomial = this.model.predictMultinomial(formatDataRow);
                            if (this.getTreePath) {
                                writeTreePaths(predictMultinomial.leafNodeAssignments, bufferedWriter);
                                break;
                            } else {
                                bufferedWriter.write(predictMultinomial.label);
                                bufferedWriter.write(",");
                                for (int i12 = 0; i12 < predictMultinomial.classProbabilities.length; i12++) {
                                    if (i12 > 0) {
                                        bufferedWriter.write(",");
                                    }
                                    bufferedWriter.write(myDoubleToString(predictMultinomial.classProbabilities[i12]));
                                }
                                break;
                            }
                        case Deepwater.Save /* 4 */:
                            OrdinalModelPrediction predictOrdinal = this.model.predictOrdinal(formatDataRow);
                            bufferedWriter.write(predictOrdinal.label);
                            bufferedWriter.write(",");
                            for (int i13 = 0; i13 < predictOrdinal.classProbabilities.length; i13++) {
                                if (i13 > 0) {
                                    bufferedWriter.write(",");
                                }
                                bufferedWriter.write(myDoubleToString(predictOrdinal.classProbabilities[i13]));
                            }
                            break;
                        case Deepwater.Load /* 5 */:
                            bufferedWriter.write(myDoubleToString(this.model.predictClustering(formatDataRow).cluster));
                            break;
                        case 6:
                            RegressionModelPrediction predictRegression = this.model.predictRegression(formatDataRow);
                            if (this.getTreePath) {
                                writeTreePaths(predictRegression.leafNodeAssignments, bufferedWriter);
                                break;
                            } else {
                                bufferedWriter.write(myDoubleToString(predictRegression.value));
                                break;
                            }
                        case 7:
                            DimReductionModelPrediction predictDimReduction = this.model.predictDimReduction(formatDataRow);
                            double[] dArr = this.returnGLRMReconstruct ? predictDimReduction.reconstructed : predictDimReduction.dimensions;
                            int length3 = dArr.length - 1;
                            for (int i14 = 0; i14 < dArr.length; i14++) {
                                bufferedWriter.write(myDoubleToString(dArr[i14]));
                                if (i14 < length3) {
                                    bufferedWriter.write(44);
                                }
                            }
                            break;
                        default:
                            throw new Exception("Unknown model category " + modelCategory);
                    }
                    bufferedWriter.write("\n");
                    i9++;
                }
            } catch (Exception e) {
                System.out.println("Caught exception on line 1");
                System.out.println("");
                e.printStackTrace();
                System.exit(1);
                bufferedWriter.close();
                cSVReader.close();
            }
        } catch (Throwable th) {
            bufferedWriter.close();
            cSVReader.close();
            throw th;
        }
    }

    private void writeTreePaths(String[] strArr, BufferedWriter bufferedWriter) throws Exception {
        int length = strArr.length - 1;
        for (int i = 0; i < length; i++) {
            bufferedWriter.write(strArr[i]);
            bufferedWriter.write(",");
        }
        bufferedWriter.write(strArr[length]);
    }

    private void loadModel(String str) throws Exception {
        try {
            loadMojo(str);
        } catch (IOException e) {
            loadPojo(str);
        }
    }

    private void loadPojo(String str) throws Exception {
        EasyPredictModelWrapper.Config convertInvalidNumbersToNa = new EasyPredictModelWrapper.Config().setModel((GenModel) Class.forName(str).newInstance()).setConvertUnknownCategoricalLevelsToNa(true).setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            convertInvalidNumbersToNa.setEnableLeafAssignment(true);
        }
        if (this.returnGLRMReconstruct) {
            convertInvalidNumbersToNa.setEnableGLRMReconstrut(true);
        }
        this.model = new EasyPredictModelWrapper(convertInvalidNumbersToNa);
    }

    private void loadMojo(String str) throws IOException {
        EasyPredictModelWrapper.Config convertInvalidNumbersToNa = new EasyPredictModelWrapper.Config().setModel(MojoModel.load(str)).setConvertUnknownCategoricalLevelsToNa(true).setConvertInvalidNumbersToNa(this.setInvNumNA);
        if (this.getTreePath) {
            convertInvalidNumbersToNa.setEnableLeafAssignment(true);
        }
        if (this.returnGLRMReconstruct) {
            convertInvalidNumbersToNa.setEnableGLRMReconstrut(true);
        }
        if (this.glrmIterNumber > 0) {
            convertInvalidNumbersToNa.setGLRMIterNumber(this.glrmIterNumber);
        }
        this.model = new EasyPredictModelWrapper(convertInvalidNumbersToNa);
    }

    private static void usage() {
        System.out.println("");
        System.out.println("Usage:  java [...java args...] hex.genmodel.tools.PredictCsv --mojo mojoName");
        System.out.println("             --pojo pojoName --input inputFile --output outputFile --separator sepStr --decimal --setConvertInvalidNum");
        System.out.println("");
        System.out.println("     --mojo    Name of the zip file containing model's MOJO.");
        System.out.println("     --pojo    Name of the java class containing the model's POJO. Either this ");
        System.out.println("               parameter or --model must be specified.");
        System.out.println("     --input   text file containing the test data set to score.");
        System.out.println("     --output  Name of the output CSV file with computed predictions.");
        System.out.println("     --separator Separator to be used in input file containing test data set.");
        System.out.println("     --decimal Use decimal numbers in the output (default is to use hexademical).");
        System.out.println("     --setConvertInvalidNum Will call .setConvertInvalidNumbersToNa(true) when loading models.");
        System.out.println("     --leafNodeAssignment will show the leaf node assignment for GBM and DRF instead of the prediction results");
        System.out.println("     --glrmReconstruct will return the reconstructed dataset for GLRM mojo instead of X factor derived from the dataset.");
        System.out.println("     --glrmIterNumber integer indicating number of iterations to go through when constructing X factor derived from the dataset.");
        System.out.println("");
        System.exit(1);
    }

    private void checkMissingColumns(String[] strArr) {
        String[] strArr2 = this.model.m._names;
        HashSet hashSet = new HashSet(strArr.length);
        for (String str : strArr) {
            hashSet.add(str);
        }
        ArrayList arrayList = new ArrayList();
        for (String str2 : strArr2) {
            if (hashSet.contains(str2) || str2.equals(this.model.m._responseColumn)) {
                hashSet.remove(arrayList);
            } else {
                arrayList.add(str2);
            }
        }
        if (arrayList.size() > 0) {
            StringBuilder sb = new StringBuilder("There were ");
            sb.append(arrayList.size());
            sb.append(" missing columns found in the input data set: {");
            for (int i = 0; i < arrayList.size(); i++) {
                sb.append((String) arrayList.get(i));
                if (i != arrayList.size() - 1) {
                    sb.append(",");
                }
            }
            sb.append('}');
            System.out.println(sb);
        }
        if (hashSet.size() > 0) {
            StringBuilder sb2 = new StringBuilder("Detected ");
            sb2.append(hashSet.size());
            sb2.append(" unused columns in the input data set: {");
            Iterator it = hashSet.iterator();
            while (it.hasNext()) {
                sb2.append((String) it.next());
                if (it.hasNext()) {
                    sb2.append(",");
                }
            }
            sb2.append('}');
            System.out.println(sb2);
        }
    }

    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Code restructure failed: missing block: B:53:0x013c, code lost:
    
        switch(r13) {
            case 0: goto L48;
            case 1: goto L49;
            case 2: goto L50;
            case 3: goto L51;
            case 4: goto L52;
            case 5: goto L53;
            case 6: goto L54;
            default: goto L55;
        };
     */
    /* JADX WARN: Code restructure failed: missing block: B:54:0x0168, code lost:
    
        r7 = r0;
        r8 = 2;
     */
    /* JADX WARN: Code restructure failed: missing block: B:56:0x0170, code lost:
    
        r7 = r0;
        r8 = true;
     */
    /* JADX WARN: Code restructure failed: missing block: B:58:0x0178, code lost:
    
        r7 = r0;
        r8 = false;
     */
    /* JADX WARN: Code restructure failed: missing block: B:60:0x0180, code lost:
    
        r5.inputCSVFileName = r0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:62:0x0189, code lost:
    
        r5.outputCSVFileName = r0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:64:0x0192, code lost:
    
        r5.separator = r0.charAt(r0.length() - 1);
     */
    /* JADX WARN: Code restructure failed: missing block: B:66:0x01a5, code lost:
    
        r5.glrmIterNumber = java.lang.Integer.valueOf(r0).intValue();
     */
    /* JADX WARN: Code restructure failed: missing block: B:68:0x01b4, code lost:
    
        java.lang.System.out.println("ERROR: Unknown command line argument: " + r0);
        usage();
     */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    private void parseArgs(java.lang.String[] r6) {
        /*
            Method dump skipped, instructions count: 533
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: hex.genmodel.tools.PredictCsv.parseArgs(java.lang.String[]):void");
    }
}
