package org.deeplearning4j.iterator.provider;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Random;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.datavec.api.util.RandomUtils;
import org.deeplearning4j.iterator.LabeledSentenceProvider;
import org.nd4j.linalg.collection.CompactHeapStringList;
import org.nd4j.linalg.primitives.Pair;

/* loaded from: input_file:org/deeplearning4j/iterator/provider/FileLabeledSentenceProvider.class */
/**
 * A {@link LabeledSentenceProvider} in which each "sentence" is the entire text content of one
 * file, and each file's label is the map key it was listed under.
 *
 * <p>Labels are sorted in natural {@link String} order, so label indices do not depend on the
 * iteration order of the supplied map. If a {@link Random} is supplied, files are returned in a
 * shuffled order that is re-shuffled on every {@link #reset()}. If {@code null} is supplied, files
 * are returned in the map's entry iteration order.
 *
 * <p>Instances keep a mutable iteration cursor and perform no synchronization.
 */
public class FileLabeledSentenceProvider implements LabeledSentenceProvider {
    /** Total number of files across all labels. */
    private final int totalCount;
    /** File paths; entry {@code i} is paired with {@code fileLabelIndexes[i]}. */
    private final List<String> filePaths;
    /** Index into {@link #allLabels} for each entry of {@link #filePaths}. */
    private final int[] fileLabelIndexes;
    /** Shuffle source, or {@code null} for unshuffled (map iteration) order. */
    private final Random rng;
    /** Permutation of {@code [0, totalCount)} when {@link #rng} is non-null; otherwise {@code null}. */
    private final int[] order;
    /** Sorted, unmodifiable label names. */
    private final List<String> allLabels;
    /** Charset used to decode file contents. */
    private final Charset charset;
    /** Number of sentences returned since construction or the last {@link #reset()}. */
    private int cursor = 0;

    /**
     * Creates a provider that shuffles files with a new {@link Random} and decodes file contents
     * with the platform default charset.
     *
     * @param filesByLabel map from label name to the files belonging to that label; must not be null
     */
    public FileLabeledSentenceProvider(Map<String, List<File>> filesByLabel) {
        this(filesByLabel, new Random());
    }

    /**
     * Creates a provider that decodes file contents with the platform default charset.
     *
     * @param filesByLabel map from label name to the files belonging to that label; must not be null
     * @param random       source used to shuffle the file order, or {@code null} to disable shuffling
     */
    public FileLabeledSentenceProvider(@NonNull Map<String, List<File>> filesByLabel, Random random) {
        this(filesByLabel, random, Charset.defaultCharset());
    }

    /**
     * Creates a provider that decodes file contents with the given charset.
     *
     * @param filesByLabel map from label name to the files belonging to that label; must not be null
     * @param random       source used to shuffle the file order, or {@code null} to disable shuffling
     * @param charset      charset used to decode each file's contents; must not be null
     * @throws NullPointerException if {@code filesByLabel} or {@code charset} is null
     */
    public FileLabeledSentenceProvider(@NonNull Map<String, List<File>> filesByLabel, Random random,
                    @NonNull Charset charset) {
        if (filesByLabel == null) {
            throw new NullPointerException("filesByLabel");
        }
        if (charset == null) {
            throw new NullPointerException("charset");
        }

        int count = 0;
        for (List<File> files : filesByLabel.values()) {
            count += files.size();
        }
        this.totalCount = count;
        this.rng = random;
        this.charset = charset;

        if (random == null) {
            this.order = null;
        } else {
            this.order = new int[count];
            for (int i = 0; i < count; i++) {
                this.order[i] = i;
            }
            RandomUtils.shuffleInPlace(this.order, random);
        }

        // Sorting makes label indices independent of the map's iteration order.
        List<String> sortedLabels = new ArrayList<>(filesByLabel.keySet());
        Collections.sort(sortedLabels);
        this.allLabels = Collections.unmodifiableList(sortedLabels);

        Map<String, Integer> labelIndexByName = new HashMap<>();
        for (int i = 0; i < sortedLabels.size(); i++) {
            labelIndexByName.put(sortedLabels.get(i), i);
        }

        this.filePaths = new CompactHeapStringList();
        this.fileLabelIndexes = new int[count];
        int position = 0;
        for (Map.Entry<String, List<File>> entry : filesByLabel.entrySet()) {
            int labelIndex = labelIndexByName.get(entry.getKey());
            for (File file : entry.getValue()) {
                this.filePaths.add(file.getPath());
                this.fileLabelIndexes[position++] = labelIndex;
            }
        }
    }

    /**
     * Returns whether another sentence is available before the next {@link #reset()}.
     *
     * @return {@code true} if {@link #nextSentence()} can be called
     */
    @Override // org.deeplearning4j.iterator.LabeledSentenceProvider
    public boolean hasNext() {
        return this.cursor < this.totalCount;
    }

    /**
     * Reads the next file and returns its contents together with its label.
     *
     * @return a pair of (file contents, label name)
     * @throws NoSuchElementException if all files have already been returned since the last reset
     * @throws RuntimeException       wrapping the {@link IOException} if the file cannot be read
     */
    @Override // org.deeplearning4j.iterator.LabeledSentenceProvider
    public Pair<String, String> nextSentence() {
        // Check before advancing so an exhausted provider keeps a consistent cursor.
        if (this.cursor >= this.totalCount) {
            throw new NoSuchElementException("No more sentences: all " + this.totalCount
                            + " files have been returned; call reset() to iterate again");
        }
        int index = (this.order == null) ? this.cursor : this.order[this.cursor];
        this.cursor++;

        String path = this.filePaths.get(index);
        String label = this.allLabels.get(this.fileLabelIndexes[index]);
        try {
            return new Pair<>(FileUtils.readFileToString(new File(path), this.charset), label);
        } catch (IOException e) {
            throw new RuntimeException("Error reading file: " + path, e);
        }
    }

    /**
     * Rewinds to the first sentence. When shuffling is enabled, the file order is re-shuffled.
     */
    @Override // org.deeplearning4j.iterator.LabeledSentenceProvider
    public void reset() {
        this.cursor = 0;
        if (this.rng != null) {
            RandomUtils.shuffleInPlace(this.order, this.rng);
        }
    }

    /**
     * Returns the total number of files (sentences) across all labels.
     *
     * @return the total sentence count
     */
    @Override // org.deeplearning4j.iterator.LabeledSentenceProvider
    public int totalNumSentences() {
        return this.totalCount;
    }

    /**
     * Returns all label names in sorted order. The position of a label is its label index.
     *
     * @return an unmodifiable, sorted list of label names
     */
    @Override // org.deeplearning4j.iterator.LabeledSentenceProvider
    public List<String> allLabels() {
        return this.allLabels;
    }

    /**
     * Returns the number of distinct labels.
     *
     * @return the number of label classes
     */
    @Override // org.deeplearning4j.iterator.LabeledSentenceProvider
    public int numLabelClasses() {
        return this.allLabels.size();
    }
}
