package org.deeplearning4j.base;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/deeplearning4j/base/IrisUtils.class */
/**
 * Loaders for the Iris data file found on the classpath at {@code /iris.dat}.
 *
 * <p>Each line is split on commas. The first four fields are parsed as numeric features and
 * the last field is treated as the class label. Labels are one-hot encoded in the order they
 * are first encountered during a single loader call, so the column assigned to a given label
 * depends on the order in which lines are read.
 */
public class IrisUtils {

    /** Classpath location of the Iris data file. */
    private static final String IRIS_RESOURCE = "/iris.dat";

    /** Number of leading comma-separated fields parsed as numeric features. */
    private static final int NUM_FEATURES = 4;

    /** Width of the one-hot label vectors; at most this many distinct labels are supported. */
    private static final int NUM_LABELS = 3;

    /**
     * Loads lines {@code [from, to)} of the Iris file, in file order, as one example per line.
     *
     * <p>The selected feature rows are passed through {@link MatrixUtil#scaleByMax} before the
     * examples are built, so scaling only considers the selected range.
     *
     * @param from index of the first line to load (inclusive)
     * @param to index one past the last line to load (exclusive)
     * @return one {@link DataSet} per selected line, pairing a feature row with a one-hot label row
     * @throws IOException if the resource cannot be read
     * @throws IndexOutOfBoundsException if {@code from < 0} or {@code to} exceeds the line count
     * @throws IllegalArgumentException if {@code from > to}
     */
    public static List<DataSet> loadIris(int from, int to) throws IOException {
        // Index rows relative to 'from' so every matrix row and label row is actually filled;
        // subList also validates the requested range.
        List<String> selected = readIrisLines().subList(from, to);
        int rows = selected.size();
        DoubleMatrix features = DoubleMatrix.ones(rows, NUM_FEATURES);
        double[][] labels = new double[rows][];
        List<String> seenLabels = new ArrayList<>();
        for (int row = 0; row < rows; row++) {
            encodeLine(selected.get(row), row, features, labels, seenLabels);
        }
        MatrixUtil.scaleByMax(features);
        List<DataSet> result = new ArrayList<>(rows);
        for (int row = 0; row < rows; row++) {
            // new DoubleMatrix(double[]) builds a column vector; transpose it into a row.
            result.add(new DataSet(features.getRow(row), new DoubleMatrix(labels[row]).transpose()));
        }
        return result;
    }

    /**
     * Loads every line of the Iris file, in shuffled order, as a single {@link DataSet}.
     *
     * <p>The parsed feature matrix goes through {@link MatrixUtil#columnNormalizeBySum}. It is
     * then rounded with {@link MatrixUtil#roundToTheNearest} (10000), passed to
     * {@link MatrixUtil#discretizeColumns} (4) and finally multiplied by 0.01.
     *
     * @return a data set with one row per line: processed features and one-hot labels
     * @throws IOException if the resource cannot be read
     */
    public static DataSet loadIris() throws IOException {
        List<String> lines = readIrisLines();
        Collections.shuffle(lines);
        Collections.rotate(lines, 3);
        int rows = lines.size();
        DoubleMatrix features = DoubleMatrix.ones(rows, NUM_FEATURES);
        double[][] labels = new double[rows][];
        List<String> seenLabels = new ArrayList<>();
        for (int row = 0; row < rows; row++) {
            encodeLine(lines.get(row), row, features, labels, seenLabels);
        }
        MatrixUtil.columnNormalizeBySum(features);
        DoubleMatrix rounded = MatrixUtil.roundToTheNearest(features, 10000);
        MatrixUtil.discretizeColumns(rounded, 4);
        return new DataSet(rounded.mul(0.01d), new DoubleMatrix(labels));
    }

    /**
     * Loads exactly {@code n} examples as a (features, one-hot labels) pair of matrices.
     *
     * <p>The file's lines are shuffled and used in that order. If {@code n} exceeds the number of
     * lines, the remaining examples are drawn by sampling lines with a {@link Random} seeded
     * with 1. Features are returned without any scaling.
     *
     * @param n number of examples (rows) to produce
     * @return a pair of an {@code n x 4} feature matrix and an {@code n x 3} label matrix
     * @throws IOException if the resource cannot be read
     */
    public static Pair<DoubleMatrix, DoubleMatrix> loadIris(int n) throws IOException {
        List<String> lines = readIrisLines();
        Collections.shuffle(lines);
        Collections.rotate(lines, 3);
        // Fixed seed so the oversampled tail is reproducible for a given shuffled order.
        Random random = new Random(1L);
        DoubleMatrix features = DoubleMatrix.ones(n, NUM_FEATURES);
        double[][] labels = new double[n][];
        List<String> seenLabels = new ArrayList<>();
        for (int row = 0; row < n; row++) {
            String line = row < lines.size()
                    ? lines.get(row)
                    : lines.get(random.nextInt(lines.size()));
            encodeLine(line, row, features, labels, seenLabels);
        }
        return new Pair<>(features, new DoubleMatrix(labels));
    }

    /**
     * Reads all lines of the Iris resource, closing the underlying stream afterwards.
     *
     * @return a mutable list of the file's lines
     * @throws IOException if the resource cannot be opened or read
     */
    private static List<String> readIrisLines() throws IOException {
        try (InputStream in = new ClassPathResource(IRIS_RESOURCE).getInputStream()) {
            return IOUtils.readLines(in, "UTF-8");
        }
    }

    /**
     * Parses one comma-separated line into row {@code row} of {@code features} and stores its
     * one-hot label in {@code labels[row]}, registering the label in {@code seenLabels} if new.
     *
     * @throws NumberFormatException if one of the first four fields is not numeric
     * @throws ArrayIndexOutOfBoundsException if more than three distinct labels are seen
     */
    private static void encodeLine(String line, int row, DoubleMatrix features,
                                   double[][] labels, List<String> seenLabels) {
        String[] fields = line.split(",");
        addRow(features, row, fields);
        String label = fields[fields.length - 1];
        int index = seenLabels.indexOf(label);
        if (index < 0) {
            seenLabels.add(label);
            index = seenLabels.size() - 1;
        }
        double[] oneHot = new double[NUM_LABELS];
        oneHot[index] = 1.0d;
        labels[row] = oneHot;
    }

    /**
     * Parses the first four entries of {@code strArr} as doubles and writes them into row
     * {@code i} of {@code doubleMatrix}.
     */
    private static void addRow(DoubleMatrix doubleMatrix, int i, String[] strArr) {
        double[] values = new double[NUM_FEATURES];
        for (int col = 0; col < NUM_FEATURES; col++) {
            values[col] = Double.parseDouble(strArr[col]);
        }
        doubleMatrix.putRow(i, new DoubleMatrix(values));
    }
}
