package ai.djl.training.optimizer;

import ai.djl.Device;
import ai.djl.ndarray.NDArray;
import ai.djl.training.optimizer.Adadelta;
import ai.djl.training.optimizer.Adagrad;
import ai.djl.training.optimizer.Adam;
import ai.djl.training.optimizer.Nag;
import ai.djl.training.optimizer.RmsProp;
import ai.djl.training.optimizer.Sgd;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/training/optimizer/Optimizer.class */
/**
 * Base class for the optimizers created through the static factories below ({@link #sgd()},
 * {@link #nag()}, {@link #adam()}, {@link #rmsprop()}, {@link #adagrad()}, {@link #adadelta()}).
 *
 * <p>Hyper-parameters shared by every optimizer (gradient rescaling, gradient clipping, weight
 * decay and the starting update number) are collected by an {@link OptimizerBuilder} and copied
 * into this instance by the constructor. Concrete optimizers implement {@link #update}.
 *
 * <p>The bookkeeping helpers {@link #updateCount(String)} and {@link #withDefaultState} are backed
 * by concurrent maps; the global update counter is kept atomic to match.
 */
public abstract class Optimizer {

    /** Gradient rescaling factor copied from the builder (builder default {@code 1.0f}). */
    protected float rescaleGrad;

    /**
     * Gradient clipping value copied from the builder (builder default {@code -1.0f}; presumably a
     * "no clipping" sentinel — confirm in subclasses).
     */
    protected float clipGrad;

    /** Weight decay copied from the builder (builder default {@code 0}). */
    private final float weightDecays;

    /** Value that per-parameter update counts start from; the first update counts as this + 1. */
    private final int beginNumUpdate;

    /** Largest update count returned so far across all keys; updated atomically. */
    private final AtomicInteger numUpdate = new AtomicInteger();

    /** Per-key update counters maintained by {@link #updateCount(String)}. */
    private final Map<String, Integer> updateCounts = new ConcurrentHashMap<>();

    /**
     * Collects the hyper-parameters shared by all optimizers.
     *
     * <p>Each fluent setter returns {@link #self()} so chained calls keep the concrete builder type.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public abstract static class OptimizerBuilder<T extends OptimizerBuilder> {
        // NOTE(review): the bound uses the raw OptimizerBuilder type. Tightening it to
        // OptimizerBuilder<T> would be cleaner but could break builder subclasses declared elsewhere.

        private float weightDecays;
        private int beginNumUpdate;
        private float rescaleGrad = 1.0f;
        private float clipGrad = -1.0f;

        /**
         * Sets the gradient rescaling factor (default {@code 1.0f}).
         *
         * @param f the rescaling factor
         * @return this builder
         */
        public T setRescaleGrad(float f) {
            this.rescaleGrad = f;
            return self();
        }

        /**
         * Sets the weight decay later returned by {@link Optimizer#getWeightDecay()} (default 0).
         *
         * @param f the weight decay
         * @return this builder
         */
        public T optWeightDecays(float f) {
            this.weightDecays = f;
            return self();
        }

        /**
         * Sets the gradient clipping value (default {@code -1.0f}).
         *
         * @param f the clipping value
         * @return this builder
         */
        public T optClipGrad(float f) {
            this.clipGrad = f;
            return self();
        }

        /**
         * Sets the number update counts start from (default 0); the first call to {@link
         * Optimizer#updateCount(String)} for a key then returns at least {@code i + 1}.
         *
         * @param i the starting update number
         * @return this builder
         */
        public T optBeginNumUpdate(int i) {
            this.beginNumUpdate = i;
            return self();
        }

        /**
         * Returns this builder typed as the concrete subclass.
         *
         * @return this builder
         */
        protected abstract T self();
    }

    /**
     * Copies the shared hyper-parameters out of {@code optimizerBuilder}.
     *
     * @param optimizerBuilder the builder holding the configured values
     */
    public Optimizer(OptimizerBuilder<?> optimizerBuilder) {
        // The builder's private fields are accessible from the enclosing class; no cast is needed.
        this.rescaleGrad = optimizerBuilder.rescaleGrad;
        this.weightDecays = optimizerBuilder.weightDecays;
        this.clipGrad = optimizerBuilder.clipGrad;
        this.beginNumUpdate = optimizerBuilder.beginNumUpdate;
    }

    /**
     * Returns a new {@link Sgd.Builder}.
     *
     * @return a new SGD builder
     */
    public static Sgd.Builder sgd() {
        return new Sgd.Builder();
    }

    /**
     * Returns a new {@link Nag.Builder}.
     *
     * @return a new NAG builder
     */
    public static Nag.Builder nag() {
        return new Nag.Builder();
    }

    /**
     * Returns a new {@link Adam.Builder}.
     *
     * @return a new Adam builder
     */
    public static Adam.Builder adam() {
        return new Adam.Builder();
    }

    /**
     * Returns a new {@link RmsProp.Builder}.
     *
     * @return a new RMSProp builder
     */
    public static RmsProp.Builder rmsprop() {
        return new RmsProp.Builder();
    }

    /**
     * Returns a new {@link Adagrad.Builder}.
     *
     * @return a new Adagrad builder
     */
    public static Adagrad.Builder adagrad() {
        return new Adagrad.Builder();
    }

    /**
     * Returns a new {@link Adadelta.Builder}.
     *
     * @return a new Adadelta builder
     */
    public static Adadelta.Builder adadelta() {
        return new Adadelta.Builder();
    }

    /**
     * Returns the weight decay configured through {@link OptimizerBuilder#optWeightDecays(float)}.
     *
     * <p>NOTE: the decompiler reports this member was {@code protected} originally; it stays
     * {@code public} here so existing callers keep compiling.
     *
     * @return the weight decay (0 unless set on the builder)
     */
    public float getWeightDecay() {
        return this.weightDecays;
    }

    /**
     * Advances the update counter for {@code str} and returns the largest count seen so far.
     *
     * <p>The first call for a key records {@code beginNumUpdate + 1}; each later call for the same
     * key adds one. The returned value is the maximum over all keys, so it never decreases.
     *
     * <p>NOTE: the decompiler reports this member was {@code protected} originally; it stays
     * {@code public} here so existing callers keep compiling.
     *
     * @param str key identifying the counter (presumably a parameter id)
     * @return the largest update count recorded by this optimizer
     */
    public int updateCount(String str) {
        // compute() on a ConcurrentHashMap is atomic per key.
        int count =
                updateCounts.compute(
                        str, (key, previous) -> previous == null ? beginNumUpdate + 1 : previous + 1);
        // Atomic running maximum: a plain read-max-write on an int field could let a concurrent
        // caller overwrite a larger value with a smaller one.
        return numUpdate.accumulateAndGet(count, Math::max);
    }

    /**
     * Updates a parameter; the exact semantics are defined by each concrete optimizer.
     *
     * @param str presumably the id of the parameter being updated — confirm in subclasses
     * @param nDArray presumably the parameter values to update — confirm in subclasses
     * @param nDArray2 presumably the gradient for {@code nDArray} — confirm in subclasses
     */
    public abstract void update(String str, NDArray nDArray, NDArray nDArray2);

    /**
     * Returns the state array for {@code str} on {@code device}, creating it on demand.
     *
     * <p>If {@code map} has no entry for {@code str}, {@code function} is applied to {@code str} to
     * create the initial array; that array is detached and stored under {@code device} in a new
     * {@link ConcurrentHashMap}. If the key exists but has no array for {@code device}, one of the
     * key's existing arrays is copied with {@code toDevice(device, true)} and cached.
     *
     * <p>NOTE: the decompiler reports this member was {@code protected} originally; it stays
     * {@code public} here so existing callers keep compiling.
     *
     * @param map state storage keyed by {@code str}, then by device
     * @param str key whose state is requested (presumably a parameter id)
     * @param device device the returned array should be associated with
     * @param function creates the initial state array for a key that has none yet
     * @return the state array for {@code str} on {@code device}
     */
    public NDArray withDefaultState(
            Map<String, Map<Device, NDArray>> map,
            String str,
            Device device,
            Function<String, NDArray> function) {
        Map<Device, NDArray> perDevice =
                map.computeIfAbsent(
                        str,
                        key -> {
                            Map<Device, NDArray> created = new ConcurrentHashMap<>();
                            NDArray initial = function.apply(key);
                            // NOTE(review): presumably detached so the state outlives the
                            // manager that created it — confirm against NDArray.detach().
                            initial.detach();
                            created.put(device, initial);
                            return created;
                        });
        // Copy from whichever of the key's arrays the map yields first.
        return perDevice.computeIfAbsent(
                device, missing -> perDevice.values().iterator().next().toDevice(device, true));
    }
}
