package org.nd4j.autodiff.samediff.config;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/* loaded from: input_file:org/nd4j/autodiff/samediff/config/FitConfig.class */
/**
 * Fluent configuration object for a {@link SameDiff} training run.
 *
 * <p>Collects the training iterator, the number of epochs, an optional validation iterator with
 * its frequency, and any listeners, then hands them to {@code SameDiff.fit(...)} when
 * {@link #exec()} is called. All fluent methods return {@code this} so calls can be chained.
 *
 * <p>Defaults: {@code epochs = -1} (must be set to a positive value before {@link #exec()}),
 * {@code validationFrequency = 1}, no validation data, and an empty listener list.
 */
public class FitConfig {
    /** The SameDiff instance that {@link #exec()} trains; fixed at construction time. */
    private final SameDiff sd;
    private MultiDataSetIterator trainingData;
    private MultiDataSetIterator validationData = null;
    private int epochs = -1;
    private int validationFrequency = 1;

    @NonNull
    private List<Listener> listeners = new ArrayList<>();

    /**
     * Creates a configuration bound to the given SameDiff instance.
     *
     * @param sameDiff the instance whose {@code fit} method {@link #exec()} will call
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public FitConfig(@NonNull SameDiff sameDiff) {
        if (sameDiff == null) {
            throw new NullPointerException("sd is marked non-null but is null");
        }
        this.sd = sameDiff;
    }

    /**
     * Sets the number of epochs. The value is not checked here; {@link #exec()} rejects
     * values that are not greater than zero.
     *
     * @param i number of epochs
     * @return this configuration
     */
    public FitConfig epochs(int i) {
        this.epochs = i;
        return this;
    }

    /**
     * Sets the training data.
     *
     * @param multiDataSetIterator the training iterator
     * @return this configuration
     * @throws NullPointerException if the iterator is null
     */
    public FitConfig train(@NonNull MultiDataSetIterator multiDataSetIterator) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("trainingData is marked non-null but is null");
        }
        this.trainingData = multiDataSetIterator;
        return this;
    }

    /**
     * Sets the training data from a {@link DataSetIterator}, wrapping it in a
     * {@link MultiDataSetIteratorAdapter}.
     *
     * @param dataSetIterator the training iterator
     * @return this configuration
     * @throws NullPointerException if the iterator is null
     */
    public FitConfig train(@NonNull DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            throw new NullPointerException("trainingData is marked non-null but is null");
        }
        return train(new MultiDataSetIteratorAdapter(dataSetIterator));
    }

    /**
     * Sets the training data and the number of epochs in one call.
     *
     * @param multiDataSetIterator the training iterator
     * @param i number of epochs
     * @return this configuration
     * @throws NullPointerException if the iterator is null
     */
    public FitConfig train(@NonNull MultiDataSetIterator multiDataSetIterator, int i) {
        if (multiDataSetIterator == null) {
            throw new NullPointerException("trainingData is marked non-null but is null");
        }
        return train(multiDataSetIterator).epochs(i);
    }

    /**
     * Sets the training data (wrapped in a {@link MultiDataSetIteratorAdapter}) and the number
     * of epochs in one call.
     *
     * @param dataSetIterator the training iterator
     * @param i number of epochs
     * @return this configuration
     * @throws NullPointerException if the iterator is null
     */
    public FitConfig train(@NonNull DataSetIterator dataSetIterator, int i) {
        if (dataSetIterator == null) {
            throw new NullPointerException("trainingData is marked non-null but is null");
        }
        return train(dataSetIterator).epochs(i);
    }

    /**
     * Sets the validation data. Passing {@code null} clears any previously set validation data.
     *
     * @param multiDataSetIterator the validation iterator, or null for none
     * @return this configuration
     */
    public FitConfig validate(MultiDataSetIterator multiDataSetIterator) {
        this.validationData = multiDataSetIterator;
        return this;
    }

    /**
     * Sets the validation data from a {@link DataSetIterator}, wrapping it in a
     * {@link MultiDataSetIteratorAdapter}. A null argument clears the validation data instead
     * of being wrapped.
     *
     * @param dataSetIterator the validation iterator, or null for none
     * @return this configuration
     */
    public FitConfig validate(DataSetIterator dataSetIterator) {
        if (dataSetIterator == null) {
            // The cast selects the MultiDataSetIterator overload; a bare null would be ambiguous.
            return validate((MultiDataSetIterator) null);
        }
        return validate(new MultiDataSetIteratorAdapter(dataSetIterator));
    }

    /**
     * Sets the validation frequency. The value is not checked here; {@link #exec()} rejects
     * non-positive values when validation data is present.
     *
     * @param i the validation frequency
     * @return this configuration
     */
    public FitConfig validationFrequency(int i) {
        this.validationFrequency = i;
        return this;
    }

    /**
     * Sets the validation data and the validation frequency in one call.
     *
     * @param multiDataSetIterator the validation iterator, or null for none
     * @param i the validation frequency
     * @return this configuration
     */
    public FitConfig validate(MultiDataSetIterator multiDataSetIterator, int i) {
        return validate(multiDataSetIterator).validationFrequency(i);
    }

    /**
     * Sets the validation data (wrapped in a {@link MultiDataSetIteratorAdapter} unless null)
     * and the validation frequency in one call.
     *
     * @param dataSetIterator the validation iterator, or null for none
     * @param i the validation frequency
     * @return this configuration
     */
    public FitConfig validate(DataSetIterator dataSetIterator, int i) {
        return validate(dataSetIterator).validationFrequency(i);
    }

    /**
     * Appends the given listeners to those already configured.
     *
     * @param listenerArr listeners to add
     * @return this configuration
     * @throws NullPointerException if the array itself is null
     */
    public FitConfig listeners(@NonNull Listener... listenerArr) {
        if (listenerArr == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        this.listeners.addAll(Arrays.asList(listenerArr));
        return this;
    }

    /**
     * Checks that the configuration is complete: training data is set, epochs is positive, and,
     * if validation data is present, the validation frequency is positive.
     */
    private void validateConfig() {
        Preconditions.checkNotNull(this.trainingData, "Training data must not be null");
        Preconditions.checkState(this.epochs > 0, "Epochs must be > 0, got %s", this.epochs);
        if (this.validationData != null) {
            Preconditions.checkState(this.validationFrequency > 0, "Validation Frequency must be > 0 if validation data is given, got %s", this.validationFrequency);
        }
    }

    /**
     * Validates this configuration and runs {@code SameDiff.fit} with the configured training
     * data, epochs, validation data, validation frequency and listeners.
     *
     * @return the {@link History} returned by {@code fit}
     */
    public History exec() {
        validateConfig();
        Listener[] listenerArray = this.listeners.toArray(new Listener[0]);
        return this.sd.fit(this.trainingData, this.epochs, this.validationData, this.validationFrequency, listenerArray);
    }

    /** @return the SameDiff instance this configuration was created for */
    public SameDiff getSd() {
        return this.sd;
    }

    /** @return the training iterator, or null if none has been set */
    public MultiDataSetIterator getTrainingData() {
        return this.trainingData;
    }

    /** @return the validation iterator, or null if none has been set */
    public MultiDataSetIterator getValidationData() {
        return this.validationData;
    }

    /** @return the configured number of epochs ({@code -1} until set) */
    public int getEpochs() {
        return this.epochs;
    }

    /** @return the configured validation frequency ({@code 1} until set) */
    public int getValidationFrequency() {
        return this.validationFrequency;
    }

    /** @return the internal listener list (not a copy) */
    @NonNull
    public List<Listener> getListeners() {
        return this.listeners;
    }

    /**
     * Sets the training data without a null check; a null value is rejected later by
     * {@link #exec()}.
     *
     * @param multiDataSetIterator the training iterator
     */
    public void setTrainingData(MultiDataSetIterator multiDataSetIterator) {
        this.trainingData = multiDataSetIterator;
    }

    /**
     * Sets the validation data.
     *
     * @param multiDataSetIterator the validation iterator, or null for none
     */
    public void setValidationData(MultiDataSetIterator multiDataSetIterator) {
        this.validationData = multiDataSetIterator;
    }

    /**
     * Sets the number of epochs.
     *
     * @param i number of epochs
     */
    public void setEpochs(int i) {
        this.epochs = i;
    }

    /**
     * Sets the validation frequency.
     *
     * @param i the validation frequency
     */
    public void setValidationFrequency(int i) {
        this.validationFrequency = i;
    }

    /**
     * Replaces the configured listeners with the contents of the given list.
     *
     * <p>The list is copied into a new {@link ArrayList}, so a later call to
     * {@link #listeners(Listener...)} works even when the caller passed a fixed-size or
     * unmodifiable list, and never modifies the caller's list.
     *
     * @param list the listeners to use
     * @throws NullPointerException if {@code list} is null
     */
    public void setListeners(@NonNull List<Listener> list) {
        if (list == null) {
            throw new NullPointerException("listeners is marked non-null but is null");
        }
        this.listeners = new ArrayList<>(list);
    }
}
