package ai.djl.modality.nlp;

import ai.djl.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;

/* loaded from: input_file:ai/djl/modality/nlp/DefaultVocabulary.class */
/**
 * A {@link Vocabulary} backed by an in-memory token table built from tokenized sentences.
 *
 * <p>Every distinct token is mapped to a {@code long} index and its occurrence count is tracked.
 * Reserved tokens (and the optional unknown token) are registered first, before any sentence
 * token. The table can optionally be pruned by a minimum frequency and/or a maximum number of
 * tokens; when pruning is requested, the surviving tokens are re-numbered contiguously from zero
 * in the order of their original indices.
 *
 * <p>Instances are normally created through {@link #builder()}.
 */
public class DefaultVocabulary implements Vocabulary {
    /** Token text mapped to its frequency and index. */
    private final Map<String, TokenInfo> tokens;

    /** Reverse lookup: position {@code i} holds the token whose index is {@code i}. */
    private List<String> indexToToken;

    /** Tokens that are always part of the vocabulary and are never frequency-counted. */
    private final Set<String> reservedTokens;

    /** Token substituted for unknown tokens/indices, or {@code null} when none is configured. */
    private final String unknownToken;

    /** Fluent builder that collects sentences and options for a {@link DefaultVocabulary}. */
    public static final class Builder {
        List<List<String>> sentences;
        Set<String> reservedTokens;
        int minFrequency;
        int maxTokens;
        String unknownToken;

        private Builder() {
            this.sentences = new ArrayList<>();
            this.reservedTokens = new HashSet<>();
            // Non-positive values leave the corresponding pruning step disabled.
            this.minFrequency = -1;
            this.maxTokens = -1;
        }

        /**
         * Sets the minimum frequency a token needs in order to be kept.
         *
         * @param i the minimum frequency; pruning only happens for values greater than 1
         * @return this builder
         */
        public Builder optMinFrequency(int i) {
            this.minFrequency = i;
            return this;
        }

        /**
         * Sets the maximum number of tokens the vocabulary may hold.
         *
         * @param i the token limit; values less than or equal to 0 disable the limit
         * @return this builder
         */
        public Builder optMaxTokens(int i) {
            this.maxTokens = i;
            return this;
        }

        /**
         * Enables unknown-token support using the token {@code "<unk>"}.
         *
         * @return this builder
         */
        public Builder optUnknownToken() {
            return optUnknownToken("<unk>");
        }

        /**
         * Sets the token used to represent unknown tokens.
         *
         * @param str the unknown token, or {@code null} for none
         * @return this builder
         */
        public Builder optUnknownToken(String str) {
            this.unknownToken = str;
            return this;
        }

        /**
         * Adds tokens that must always be present in the vocabulary.
         *
         * @param collection the reserved tokens to add
         * @return this builder
         */
        public Builder optReservedTokens(Collection<String> collection) {
            this.reservedTokens.addAll(collection);
            return this;
        }

        /**
         * Adds one sentence, given as a list of tokens.
         *
         * @param list the tokens of the sentence
         * @return this builder
         */
        public Builder add(List<String> list) {
            this.sentences.add(list);
            return this;
        }

        /**
         * Adds several sentences, each given as a list of tokens.
         *
         * @param list the sentences to add
         * @return this builder
         */
        public Builder addAll(List<List<String>> list) {
            this.sentences.addAll(list);
            return this;
        }

        /**
         * Adds the lines of a text file as a single sentence, one token per line.
         *
         * <p>Lines are read with {@code Utils.readLines(path, true)}; the meaning of the boolean
         * flag is defined by that helper.
         *
         * @param path the file to read
         * @return this builder
         * @throws IOException if the file cannot be read
         */
        public Builder addFromTextFile(Path path) throws IOException {
            add(Utils.readLines(path, true));
            return this;
        }

        /**
         * Adds the lines read from a URL as a single sentence, one token per line.
         *
         * <p>Lines are read with {@code Utils.readLines(stream, true)}; the meaning of the
         * boolean flag is defined by that helper. The stream is always closed, and a failure
         * while closing never hides a failure that happened while reading.
         *
         * @param url the location to read from
         * @return this builder
         * @throws IOException if the stream cannot be opened or read
         */
        public Builder addFromTextFile(URL url) throws IOException {
            try (InputStream stream = url.openStream()) {
                return add(Utils.readLines(stream, true));
            }
        }

        /**
         * Adds the tokens produced by a caller-supplied reader as a single sentence.
         *
         * @param url the location handed to {@code function}
         * @param function converts the URL into a list of tokens
         * @return this builder
         */
        public Builder addFromCustomizedFile(URL url, Function<URL, List<String>> function) {
            return add(function.apply(url));
        }

        /**
         * Builds the vocabulary.
         *
         * @return the new {@link DefaultVocabulary}
         * @throws IllegalArgumentException if a positive {@code maxTokens} is smaller than the
         *     number of reserved tokens (the unknown token counts as a reserved token)
         */
        public DefaultVocabulary build() {
            // The constructor registers the unknown token as a reserved token, so it has to be
            // counted here as well. Otherwise pruning to maxTokens could evict a reserved
            // token, possibly the unknown token itself, which later breaks getIndex.
            int reservedCount = this.reservedTokens.size();
            if (this.unknownToken != null && !this.reservedTokens.contains(this.unknownToken)) {
                reservedCount++;
            }
            if (this.maxTokens > 0 && this.maxTokens < reservedCount) {
                throw new IllegalArgumentException("The vocabulary maxTokens can not be smaller than the number of reserved tokens");
            }
            return new DefaultVocabulary(this);
        }
    }

    /**
     * Per-token bookkeeping: occurrence count and vocabulary index.
     *
     * <p>NOTE(review): the decompiler reported that this class's access modifier was changed
     * from private; it is kept public here so the visible interface stays as it is.
     */
    public static final class TokenInfo {
        int frequency;
        long index;

        private TokenInfo() {
            // -1 marks "no index assigned yet".
            this.index = -1L;
        }
    }

    /**
     * Creates a vocabulary from a single list of tokens with default builder options.
     *
     * @param list the tokens to register
     */
    public DefaultVocabulary(List<String> list) {
        this(builder().add(list));
    }

    /**
     * Creates a vocabulary from the sentences and options held by a builder.
     *
     * <p>The builder's reserved-token set is used directly (not copied), and the unknown token,
     * if any, is added to it.
     *
     * @param builder the configured builder
     */
    public DefaultVocabulary(Builder builder) {
        this.tokens = new ConcurrentHashMap<>();
        this.reservedTokens = builder.reservedTokens;
        this.unknownToken = builder.unknownToken;
        if (this.unknownToken != null) {
            this.reservedTokens.add(this.unknownToken);
        }
        // Reserved tokens are registered first so they take the lowest indices.
        addReservedTokens(this.reservedTokens);
        for (List<String> sentence : builder.sentences) {
            for (String token : sentence) {
                addToken(token);
            }
        }
        if (pruneTokens(builder.minFrequency, builder.maxTokens)) {
            // Pruning may leave gaps in the index space, so re-number the survivors.
            initializeIndexToTokenReplacingIndices();
        } else {
            initializeIndexToTokenKeepingIndices();
        }
    }

    /**
     * Registers one occurrence of a token, assigning it the next index on first sight.
     *
     * <p>Reserved tokens are ignored. The frequency saturates at {@link Integer#MAX_VALUE}.
     */
    private void addToken(String str) {
        if (this.reservedTokens.contains(str)) {
            return;
        }
        int nextIndex = this.tokens.size();
        TokenInfo info = this.tokens.computeIfAbsent(str, key -> {
            TokenInfo created = new TokenInfo();
            created.index = nextIndex;
            return created;
        });
        if (info.frequency < Integer.MAX_VALUE) {
            info.frequency++;
        }
    }

    /**
     * Applies the optional pruning steps to the token table.
     *
     * @param i minimum frequency; tokens seen fewer times are dropped when this is above 1
     * @param i2 maximum number of tokens; the most frequent ones are kept when this is positive
     *     and the table is larger
     * @return {@code true} if a pruning step ran, in which case indices must be re-assigned
     */
    private boolean pruneTokens(int i, int i2) {
        boolean pruned = false;
        if (i > 1) {
            this.tokens.entrySet().removeIf(entry -> entry.getValue().frequency < i);
            // Reported as pruned even when nothing was removed, matching the re-numbering path.
            pruned = true;
        }
        if (i2 > 0 && this.tokens.size() > i2) {
            // Reserved tokens carry Integer.MAX_VALUE as frequency, so they sort to the front.
            List<String> evicted = this.tokens.entrySet().stream()
                    .sorted(Map.Entry.comparingByValue(
                            Comparator.comparingInt((TokenInfo info) -> info.frequency).reversed()))
                    .skip(i2)
                    .map(Map.Entry::getKey)
                    .collect(Collectors.toList());
            for (String token : evicted) {
                this.tokens.remove(token);
            }
            pruned = true;
        }
        return pruned;
    }

    /**
     * Registers each reserved token with the next free index and the maximum frequency, so that
     * frequency-based pruning ranks it ahead of ordinary tokens.
     */
    private void addReservedTokens(Collection<String> collection) {
        for (String token : collection) {
            TokenInfo info = new TokenInfo();
            info.frequency = Integer.MAX_VALUE;
            info.index = this.tokens.size();
            this.tokens.put(token, info);
        }
    }

    /** Builds the index-to-token list using the indices assigned when tokens were added. */
    private void initializeIndexToTokenKeepingIndices() {
        // Fixed-size list; every slot is filled because indices are contiguous without pruning.
        this.indexToToken = Arrays.asList(new String[this.tokens.size()]);
        for (Map.Entry<String, TokenInfo> entry : this.tokens.entrySet()) {
            this.indexToToken.set(Math.toIntExact(entry.getValue().index), entry.getKey());
        }
    }

    /**
     * Builds the index-to-token list from the surviving tokens ordered by their old index, then
     * rewrites each token's index to its new position in that list.
     */
    private void initializeIndexToTokenReplacingIndices() {
        this.indexToToken = this.tokens.entrySet().stream()
                .sorted(Comparator.comparingLong(
                        (Map.Entry<String, TokenInfo> entry) -> entry.getValue().index))
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
        for (int position = 0; position < this.indexToToken.size(); position++) {
            this.tokens.get(this.indexToToken.get(position)).index = position;
        }
    }

    /**
     * Checks whether a token is part of the vocabulary.
     *
     * @param str the token to look up
     * @return {@code true} if the token is present
     */
    @Override // ai.djl.modality.nlp.Vocabulary
    public boolean contains(String str) {
        return this.tokens.containsKey(str);
    }

    /**
     * Returns the token stored at an index.
     *
     * @param j the index to look up
     * @return the token, or the unknown token (possibly {@code null}) if the index is out of
     *     range
     */
    @Override // ai.djl.modality.nlp.Vocabulary
    public String getToken(long j) {
        if (j < 0 || j >= (long) this.indexToToken.size()) {
            return this.unknownToken;
        }
        return this.indexToToken.get((int) j);
    }

    /**
     * Returns the index of a token.
     *
     * @param str the token to look up
     * @return the token's index, or the unknown token's index if the token is absent
     * @throws IllegalStateException if the token is absent and no unknown token is configured
     */
    @Override // ai.djl.modality.nlp.Vocabulary
    public long getIndex(String str) {
        TokenInfo info = this.tokens.get(str);
        if (info != null) {
            return info.index;
        }
        if (this.unknownToken != null) {
            return this.tokens.get(this.unknownToken).index;
        }
        throw new IllegalStateException("Unexpected token in getIndex. Define an unknownToken for the vocabulary to enable support for unknown tokens.");
    }

    /**
     * Returns the number of tokens in the vocabulary, reserved tokens included.
     *
     * @return the token count
     */
    @Override // ai.djl.modality.nlp.Vocabulary
    public long size() {
        return this.tokens.size();
    }

    /**
     * Creates a new builder.
     *
     * @return a fresh {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }
}
