package ai.djl.training;

import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.TrainingMetric;

/* loaded from: input_file:ai/djl/training/Trainer.class */
/**
 * Contract for an object that drives training of a {@link Model}.
 *
 * <p>The interface groups operations for processing {@link Batch} instances, accessing the
 * associated {@link Loss} and {@link TrainingMetric} objects, and obtaining the {@link NDManager}
 * and {@link Model} tied to the trainer. It extends {@link AutoCloseable}, so an instance can be
 * managed by a try-with-resources statement.
 *
 * <p>NOTE(review): only {@link #iterateDataset(Dataset)} has a body here; the behavior of every
 * other method is defined by the implementing class and should be confirmed there.
 */
public interface Trainer extends AutoCloseable {

    /**
     * Initializes this trainer with the supplied shapes.
     *
     * @param shapeArr zero or more shapes; presumably the shapes of the model inputs — TODO confirm
     *     against the implementation
     */
    void initialize(Shape... shapeArr);

    /**
     * Fetches the batches of a dataset using this trainer's manager.
     *
     * <p>The default implementation asks {@link #getManager()} for the manager and passes it to
     * the dataset's {@code getData} method, returning whatever that call yields.
     *
     * @param dataset the dataset to read from
     * @return the iterable of batches produced by {@code dataset}
     */
    default Iterable<Batch> iterateDataset(Dataset dataset) {
        // The manager is resolved first, then handed to the dataset.
        NDManager manager = getManager();
        return dataset.getData(manager);
    }

    /**
     * Supplies a {@link GradientCollector}.
     *
     * @return a gradient collector; the name suggests a fresh instance per call — TODO confirm
     */
    GradientCollector newGradientCollector();

    /**
     * Handles a single batch for training purposes.
     *
     * @param batch the batch to train on
     */
    void trainBatch(Batch batch);

    /**
     * Maps an input {@link NDList} to an output {@link NDList}.
     *
     * <p>NOTE(review): presumably this runs the model's forward computation — confirm against the
     * implementation.
     *
     * @param nDList the input list
     * @return the resulting list
     */
    NDList forward(NDList nDList);

    /**
     * Handles a single batch for validation purposes.
     *
     * @param batch the batch to validate against
     */
    void validateBatch(Batch batch);

    /**
     * Performs one step.
     *
     * <p>NOTE(review): presumably a parameter update based on previously collected gradients —
     * confirm against the implementation.
     */
    void step();

    /**
     * Assigns the {@link Metrics} object this trainer should use.
     *
     * @param metrics the metrics instance to set
     */
    void setMetrics(Metrics metrics);

    /**
     * Assigns the {@link TrainingListener} for this trainer.
     *
     * @param trainingListener the listener to set
     */
    void setTrainingListener(TrainingListener trainingListener);

    /** Resets the training metrics held by this trainer. */
    void resetTrainingMetrics();

    /**
     * Gives access to the trainer's {@link Loss}.
     *
     * @return the loss object
     */
    Loss getLoss();

    /**
     * Gives access to the trainer's validation {@link Loss}.
     *
     * @return the validation loss object
     */
    Loss getValidationLoss();

    /**
     * Gives access to the {@link Model} associated with this trainer.
     *
     * @return the model
     */
    Model getModel();

    /**
     * Looks up a training metric by its class.
     *
     * @param cls the class object identifying the requested metric type
     * @param <T> the concrete {@link TrainingMetric} type
     * @return the metric typed as {@code T}; what is returned when no such metric exists is not
     *     specified here — TODO confirm
     */
    <T extends TrainingMetric> T getTrainingMetric(Class<T> cls);

    /**
     * Looks up a validation metric by its class.
     *
     * @param cls the class object identifying the requested metric type
     * @param <T> the concrete {@link TrainingMetric} type
     * @return the metric typed as {@code T}; what is returned when no such metric exists is not
     *     specified here — TODO confirm
     */
    <T extends TrainingMetric> T getValidationMetric(Class<T> cls);

    /**
     * Gives access to the {@link NDManager} of this trainer.
     *
     * @return the manager; also the one used by {@link #iterateDataset(Dataset)}
     */
    NDManager getManager();

    /**
     * Closes this trainer.
     *
     * <p>Redeclared from {@link AutoCloseable#close()} without a {@code throws} clause, so callers
     * do not have to handle a checked exception.
     */
    @Override
    void close();
}
