package org.deeplearning4j.models.embeddings.learning.impl.sequence;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.learning.SequenceLearningAlgorithm;
import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;
import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
import org.deeplearning4j.models.sequencevectors.interfaces.SequenceIterator;
import org.deeplearning4j.models.sequencevectors.sequence.Sequence;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;

/**
 * Distributed Bag of Words ("DBOW") sequence learning algorithm.
 *
 * <p>For every element of a sequence, the sequence label is trained against that element. The
 * actual per-pair weight update is delegated to an internal {@link SkipGram} instance, which is
 * configured with the same vocabulary, lookup table and configuration as this algorithm.
 *
 * @param <T> the sequence element type
 */
public class DBOW<T extends SequenceElement> implements SequenceLearningAlgorithm<T> {
    protected VocabCache<T> vocabCache;
    protected WeightLookupTable<T> lookupTable;
    protected VectorsConfiguration configuration;
    /** Window size read from the configuration; used to derive the per-position random offset. */
    protected int window;
    protected boolean useAdaGrad;
    protected double negative;
    /** Delegate that performs the (element, label) update in {@link #dbow}. */
    protected SkipGram<T> skipGram = new SkipGram<>();

    @Override
    public String getCodeName() {
        return "DBOW";
    }

    /**
     * Binds this algorithm, and its internal SkipGram delegate, to the given vocabulary, lookup
     * table and configuration.
     *
     * @param vocabCache vocabulary shared with the delegate
     * @param weightLookupTable weights shared with the delegate
     * @param vectorsConfiguration source of window / AdaGrad / negative-sampling settings
     * @throws NullPointerException if any argument is {@code null}
     */
    @Override
    public void configure(@NonNull VocabCache<T> vocabCache, @NonNull WeightLookupTable<T> weightLookupTable,
                    @NonNull VectorsConfiguration vectorsConfiguration) {
        if (vocabCache == null) {
            throw new NullPointerException("vocabCache");
        }
        if (weightLookupTable == null) {
            throw new NullPointerException("lookupTable");
        }
        if (vectorsConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        this.vocabCache = vocabCache;
        this.lookupTable = weightLookupTable;
        // Previously never assigned, which left the protected field null for subclasses.
        this.configuration = vectorsConfiguration;
        this.window = vectorsConfiguration.getWindow();
        this.useAdaGrad = vectorsConfiguration.isUseAdaGrad();
        this.negative = vectorsConfiguration.getNegative();
        this.skipGram.configure(vocabCache, weightLookupTable, vectorsConfiguration);
    }

    /** DBOW needs no pretraining pass; this is intentionally a no-op. */
    @Override
    public void pretrain(SequenceIterator<T> sequenceIterator) {
        // intentionally empty
    }

    /**
     * Trains the sequence label against every element of {@code sequence}.
     *
     * @param sequence the sequence to learn; must carry a non-null label
     * @param nextRandom shared random state, read here and passed through to the delegate
     * @param learningRate learning rate passed to each update
     * @throws NullPointerException if {@code sequence} or {@code nextRandom} is {@code null}
     * @throws IllegalStateException if the sequence has a non-empty element list but no label
     */
    @Override
    public void learnSequence(@NonNull Sequence<T> sequence, @NonNull AtomicLong nextRandom, double learningRate) {
        if (sequence == null) {
            throw new NullPointerException("sequence");
        }
        if (nextRandom == null) {
            throw new NullPointerException("nextRandom");
        }
        final List<T> elements = sequence.getElements();
        for (int i = 0; i < elements.size(); i++) {
            // The offset is recomputed per position because the delegate may advance nextRandom.
            dbow(i, sequence, randomWindowOffset(nextRandom), nextRandom, learningRate);
        }
    }

    /** DBOW never requests early termination. */
    @Override
    public boolean isEarlyTerminationHit() {
        return false;
    }

    /**
     * Derives a random offset in {@code [0, window)} from the current random state, or 0 when the
     * window is not positive.
     */
    private int randomWindowOffset(AtomicLong nextRandom) {
        // The raw state cast to int may be negative, and a plain % would then yield a negative
        // offset. floorMod keeps it in range; the guard avoids division by zero for window == 0.
        return this.window > 0 ? Math.floorMod((int) nextRandom.get(), this.window) : 0;
    }

    /**
     * Trains the sequence label against the element at position {@code i} via
     * {@code skipGram.iterateSample(element, label, nextRandom, alpha)}.
     *
     * <p>The previous implementation slid a word-context window over a one-element label list.
     * The label was therefore only trained for positions {@code 1..window-b}, never for the first
     * element or the rest of the sequence. Here the label is paired with every non-null element.
     *
     * @param i index of the element to train against
     * @param sequence the sequence supplying the element and its label
     * @param b random window offset; unused by this implementation (the label is not positional),
     *     kept for signature compatibility with callers and subclasses
     * @param nextRandom shared random state passed to the delegate
     * @param alpha learning rate passed to the delegate
     * @throws IllegalStateException if the sequence label is {@code null}
     * @throws IndexOutOfBoundsException if {@code i} is not a valid element index
     */
    protected void dbow(int i, Sequence<T> sequence, int b, AtomicLong nextRandom, double alpha) {
        final T label = sequence.getSequenceLabel();
        if (label == null) {
            throw new IllegalStateException("Label is NULL");
        }
        final T word = sequence.getElements().get(i);
        if (word == null) {
            return;
        }
        this.skipGram.iterateSample(word, label, nextRandom, alpha);
    }
}
