package org.deeplearning4j.models.sequencevectors.transformers.impl;

import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.concurrent.atomic.AtomicInteger;
import lombok.NonNull;
import org.deeplearning4j.models.sequencevectors.graph.primitives.IGraph;
import org.deeplearning4j.models.sequencevectors.graph.walkers.GraphWalker;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.Huffman;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.labels.LabelsProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformer.class */
/**
 * Exposes the walks produced by a {@link GraphWalker} as an {@link Iterable} of {@link Sequence}s.
 *
 * <p>When built (see {@link Builder#build()}), every vertex value of the source graph gets its
 * element frequency set to its number of connected vertices. If a {@link VocabCache} was supplied,
 * each value is also added to that cache, and Huffman indexes are built and applied to it.
 *
 * <p>All iterators returned by {@link #iterator()} share the single underlying walker and sequence
 * counter. Obtaining a new iterator resets both, which also affects iterators obtained earlier.
 *
 * @param <T> vertex value type
 */
public class GraphTransformer<T extends SequenceElement> implements Iterable<Sequence<T>> {
    protected IGraph<T, ?> sourceGraph;
    protected GraphWalker<T> walker;
    protected LabelsProvider<T> labelsProvider;
    /** Id assigned to the next emitted sequence; reset to 0 by {@link #iterator()}. */
    protected AtomicInteger counter = new AtomicInteger(0);
    /** Value passed to {@link GraphWalker#reset(boolean)} each time a new iterator is created. */
    protected boolean shuffle = true;
    /** Optional; when non-null, vertex values are registered here and Huffman-indexed. */
    protected VocabCache<T> vocabCache;
    protected static final Logger log = LoggerFactory.getLogger(GraphTransformer.class);

    /**
     * Fluent builder for {@link GraphTransformer}.
     *
     * <p>A {@link GraphWalker} is mandatory. The source graph may be given explicitly; otherwise
     * it is taken from the walker at {@link #build()} time.
     *
     * @param <T> vertex value type
     */
    public static class Builder<T extends SequenceElement> {
        protected IGraph<T, ?> sourceGraph;
        protected LabelsProvider<T> labelsProvider;
        protected GraphWalker<T> walker;
        protected boolean shuffle = true;
        protected VocabCache<T> vocabCache;

        /** Creates an empty builder; a walker must be set before {@link #build()}. */
        public Builder() {
        }

        /**
         * Creates a builder using the given walker (and, implicitly, its source graph).
         *
         * @param graphWalker walker producing the sequences; must not be null
         * @throws NullPointerException if {@code graphWalker} is null
         */
        public Builder(@NonNull GraphWalker<T> graphWalker) {
            if (graphWalker == null) {
                throw new NullPointerException("walker is marked @NonNull but is null");
            }
            this.walker = graphWalker;
        }

        /**
         * Creates a builder with an explicit source graph; a walker must still be set.
         *
         * @param iGraph graph whose vertices are processed in {@code initialize()}; must not be null
         * @throws NullPointerException if {@code iGraph} is null
         */
        public Builder(@NonNull IGraph<T, ?> iGraph) {
            if (iGraph == null) {
                throw new NullPointerException("sourceGraph is marked @NonNull but is null");
            }
            this.sourceGraph = iGraph;
        }

        /**
         * Sets the provider used to label sequences that have no labels yet.
         *
         * @param labelsProvider label source, queried by sequence id; must not be null
         * @return this builder
         * @throws NullPointerException if {@code labelsProvider} is null
         */
        public Builder<T> setLabelsProvider(@NonNull LabelsProvider<T> labelsProvider) {
            if (labelsProvider == null) {
                throw new NullPointerException("provider is marked @NonNull but is null");
            }
            this.labelsProvider = labelsProvider;
            return this;
        }

        /**
         * Sets (or replaces) the walker that produces sequences.
         *
         * @param graphWalker walker; must not be null
         * @return this builder
         * @throws NullPointerException if {@code graphWalker} is null
         */
        public Builder<T> setGraphWalker(@NonNull GraphWalker<T> graphWalker) {
            if (graphWalker == null) {
                throw new NullPointerException("walker is marked @NonNull but is null");
            }
            this.walker = graphWalker;
            return this;
        }

        /**
         * Sets the vocabulary cache that vertex values are added to and Huffman-indexed in.
         *
         * @param vocabCache target cache; must not be null
         * @return this builder
         * @throws NullPointerException if {@code vocabCache} is null
         */
        public Builder<T> setVocabCache(@NonNull VocabCache<T> vocabCache) {
            if (vocabCache == null) {
                throw new NullPointerException("vocabCache is marked @NonNull but is null");
            }
            this.vocabCache = vocabCache;
            return this;
        }

        /**
         * Controls the flag passed to {@link GraphWalker#reset(boolean)} on every new iterator.
         *
         * @param z shuffle flag (default {@code true})
         * @return this builder
         */
        public Builder<T> shuffleOnReset(boolean z) {
            this.shuffle = z;
            return this;
        }

        /**
         * Builds and initializes the transformer.
         *
         * <p>This builder is not modified, so it can be reconfigured (for example with a different
         * walker) and built again.
         *
         * @return an initialized transformer
         * @throws IllegalStateException if no walker was provided, or no source graph was given
         *     and the walker does not supply one
         */
        public GraphTransformer<T> build() {
            if (this.walker == null) {
                throw new IllegalStateException("Please provide GraphWalker instance.");
            }
            // Resolve into a local instead of writing back to this.sourceGraph. Otherwise a later
            // setGraphWalker(...) + build() would silently keep the previous walker's graph.
            IGraph<T, ?> graph = this.sourceGraph;
            if (graph == null) {
                graph = this.walker.getSourceGraph();
            }
            if (graph == null) {
                throw new IllegalStateException(
                        "Please provide IGraph instance, or a GraphWalker that exposes its source graph.");
            }
            GraphTransformer<T> graphTransformer = new GraphTransformer<>();
            graphTransformer.sourceGraph = graph;
            graphTransformer.labelsProvider = this.labelsProvider;
            graphTransformer.shuffle = this.shuffle;
            graphTransformer.vocabCache = this.vocabCache;
            graphTransformer.walker = this.walker;
            graphTransformer.initialize();
            return graphTransformer;
        }
    }

    /** Use {@link Builder}. */
    protected GraphTransformer() {
    }

    /**
     * Sets each vertex value's element frequency to its number of connected vertices.
     *
     * <p>If a vocabulary cache is configured, each value is also added to it as a token. A
     * {@link Huffman} tree is then built over the cache's words and its indexes are applied back
     * to the cache.
     */
    protected void initialize() {
        int numVertices = sourceGraph.numVertices();
        log.info("Calculating element frequencies for {} vertices of source graph...", numVertices);
        for (int i = 0; i < numVertices; i++) {
            T value = sourceGraph.getVertex(i).getValue();
            value.setElementFrequency(sourceGraph.getConnectedVertices(i).size());
            if (vocabCache != null) {
                vocabCache.addToken(value);
            }
        }
        if (vocabCache != null) {
            // The Huffman tree is only built when there is a cache to apply its indexes to.
            log.info("Building Huffman tree for source graph...");
            Huffman huffman = new Huffman(vocabCache.vocabWords());
            huffman.build();
            huffman.applyIndexes(vocabCache);
        }
    }

    /**
     * Resets the sequence counter and the walker, then returns a read-only iterator over its walks.
     *
     * <p>Sequences receive consecutive ids starting at 0. If the walker has labels enabled, a
     * sequence has no labels yet, and a labels provider is configured, the sequence is labelled
     * with {@code labelsProvider.getLabel(sequenceId)}.
     *
     * @return iterator backed by the shared walker
     */
    @Override
    public Iterator<Sequence<T>> iterator() {
        counter.set(0);
        walker.reset(shuffle);
        // Capture the walker at creation time, as the original per-iterator field did.
        final GraphWalker<T> currentWalker = walker;
        return new Iterator<Sequence<T>>() {
            @Override
            public boolean hasNext() {
                return currentWalker.hasNext();
            }

            @Override
            public Sequence<T> next() {
                // Honour the Iterator contract instead of calling into an exhausted walker.
                // NOTE(review): assumes GraphWalker.hasNext() is side-effect free.
                if (!currentWalker.hasNext()) {
                    throw new NoSuchElementException("Graph walker has no more sequences");
                }
                Sequence<T> sequence = currentWalker.next();
                sequence.setSequenceId(counter.getAndIncrement());
                if (currentWalker.isLabelEnabled()
                        && sequence.getSequenceLabels() == null
                        && labelsProvider != null) {
                    sequence.setSequenceLabel(labelsProvider.getLabel(sequence.getSequenceId()));
                }
                return sequence;
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("This is not supported on read-only iterator");
            }
        };
    }
}
