package ai.djl.training.tracker;

import ai.djl.util.Preconditions;
import ai.djl.util.cuda.CudaLibrary;

/* loaded from: input_file:ai/djl/training/tracker/CyclicalTracker.class */
/**
 * {@code CyclicalTracker} is a {@link Tracker} whose value cycles between {@code baseValue} and
 * {@code maxValue}.
 *
 * <p>One cycle spans {@code stepSizeUp + stepSizeDown} updates. During the first {@code
 * stepSizeUp} updates of a cycle the value rises linearly from {@code baseValue} towards {@code
 * maxValue}; during the remaining {@code stepSizeDown} updates it falls linearly back towards
 * {@code baseValue}. The height above {@code baseValue} is multiplied by a {@link ScaleFunction},
 * evaluated either on the 1-based cycle number or on the raw update count.
 *
 * <p>Built-in modes ({@link CyclicalMode}):
 *
 * <ul>
 *   <li>{@code TRIANGULAR}: constant amplitude (scale {@code 1}), evaluated per cycle.
 *   <li>{@code TRIANGULAR2}: amplitude halved each cycle (scale {@code 1 / 2^(cycle - 1)}).
 *   <li>{@code EXP_RANGE}: amplitude scaled by {@code gamma^numUpdate}, evaluated per update.
 * </ul>
 */
public class CyclicalTracker implements Tracker {

    private final float baseValue;
    private final float maxValue;
    private final int stepSizeUp;
    private final int stepSizeDown;
    /** Number of updates in one full cycle: {@code stepSizeUp + stepSizeDown}. */
    private final int totalSize;
    /** Fraction of a cycle spent increasing: {@code stepSizeUp / totalSize}. */
    private final float stepRatio;
    private final ScaleFunction scaleFunction;
    /** {@code true} to evaluate {@link #scaleFunction} on the cycle number, else on the update. */
    private final boolean scaleModeCycle;

    /** The Builder to construct a {@link CyclicalTracker}. */
    public static final class Builder {

        private float baseValue = 0.001f;
        private float maxValue = 0.006f;
        private int stepSizeUp = 2000;
        // 0 means "not set": the tracker then uses stepSizeUp for the falling half as well.
        private int stepSizeDown;
        private CyclicalMode mode = CyclicalMode.TRIANGULAR;
        private ScaleFunction scaleFunction;
        private boolean scaleModeCycle = true;
        private float gamma = 1.0f;

        private Builder() {}

        /**
         * Sets the lower bound of the cycle (default {@code 0.001}).
         *
         * @param baseValue the lower bound; must be positive
         * @return this {@code Builder}
         */
        public Builder optBaseValue(float baseValue) {
            this.baseValue = baseValue;
            return this;
        }

        /**
         * Sets the upper bound of the cycle (default {@code 0.006}).
         *
         * @param maxValue the upper bound; must be positive and not below {@code baseValue}
         * @return this {@code Builder}
         */
        public Builder optMaxValue(float maxValue) {
            this.maxValue = maxValue;
            return this;
        }

        /**
         * Sets the number of updates in the increasing half of a cycle (default {@code 2000}).
         *
         * @param stepSizeUp the number of increasing updates; must be at least 1
         * @return this {@code Builder}
         */
        public Builder optStepSizeUp(int stepSizeUp) {
            this.stepSizeUp = stepSizeUp;
            return this;
        }

        /**
         * Sets the number of updates in the decreasing half of a cycle. A value of {@code 0}
         * (the default) makes it equal to {@code stepSizeUp}.
         *
         * @param stepSizeDown the number of decreasing updates; must not be negative
         * @return this {@code Builder}
         */
        public Builder optStepSizeDown(int stepSizeDown) {
            this.stepSizeDown = stepSizeDown;
            return this;
        }

        /**
         * Sets the built-in scaling mode (default {@link CyclicalMode#TRIANGULAR}). Ignored when a
         * custom scale function is set via {@link #optScaleFunction(ScaleFunction)}.
         *
         * @param mode the cyclical mode
         * @return this {@code Builder}
         */
        public Builder optMode(CyclicalMode mode) {
            this.mode = mode;
            return this;
        }

        /**
         * Sets the decay base used by {@link CyclicalMode#EXP_RANGE} (default {@code 1}).
         *
         * @param gamma the decay base; must be within {@code [0, 1]}
         * @return this {@code Builder}
         */
        public Builder optGamma(float gamma) {
            this.gamma = gamma;
            return this;
        }

        /**
         * Sets a custom scale function, overriding the one implied by the mode.
         *
         * @param scaleFunction the custom scale function
         * @return this {@code Builder}
         */
        public Builder optScaleFunction(ScaleFunction scaleFunction) {
            this.scaleFunction = scaleFunction;
            return this;
        }

        /**
         * Sets whether a custom scale function is evaluated on the cycle number ({@code true},
         * the default) or on the update count ({@code false}). Only used together with {@link
         * #optScaleFunction(ScaleFunction)}; built-in modes choose this themselves.
         *
         * @param scaleModeCycle {@code true} to scale per cycle
         * @return this {@code Builder}
         */
        public Builder optScaleModeCycle(boolean scaleModeCycle) {
            this.scaleModeCycle = scaleModeCycle;
            return this;
        }

        /**
         * Validates the configuration and builds a {@link CyclicalTracker}.
         *
         * @return the configured tracker
         * @throws IllegalArgumentException if any setting is out of range
         */
        public CyclicalTracker build() {
            Preconditions.checkArgument(baseValue > 0.0f, "baseValue has to be positive!");
            Preconditions.checkArgument(maxValue > 0.0f, "maxValue has to be positive!");
            Preconditions.checkArgument(
                    baseValue <= maxValue, "baseValue cannot be greater than maxValue!");
            Preconditions.checkArgument(stepSizeUp >= 1, "stepSizeUp has to be positive!");
            Preconditions.checkArgument(stepSizeDown >= 0, "stepSizeDown cannot be negative!");
            Preconditions.checkArgument(
                    gamma >= 0.0f && gamma <= 1.0f, "gamma has to be between 0 and 1!");
            Preconditions.checkArgument(
                    scaleFunction != null || mode != null,
                    "mode cannot be null unless a custom scaleFunction is set!");
            return new CyclicalTracker(this);
        }
    }

    /** The built-in amplitude scaling policies. */
    public enum CyclicalMode {
        TRIANGULAR,
        TRIANGULAR2,
        EXP_RANGE
    }

    /** A function that returns the multiplier applied to the height of the cycle. */
    @FunctionalInterface
    public interface ScaleFunction {

        /**
         * Returns the amplitude multiplier.
         *
         * @param steps the 1-based cycle number when scaling per cycle, otherwise the update count
         * @return the multiplier applied to {@code maxValue - baseValue}
         */
        float func(int steps);
    }

    /** Constant amplitude: always returns {@code 1}. */
    private static final class TriangularScaleFunction implements ScaleFunction {

        @Override
        public float func(int steps) {
            return 1.0f;
        }
    }

    /** Halves the amplitude each cycle: {@code 1 / 2^(steps - 1)}. */
    private static final class Triangular2ScaleFunction implements ScaleFunction {

        @Override
        public float func(int steps) {
            return (float) (1.0d / Math.pow(2.0d, steps - 1));
        }
    }

    /** Exponentially decays the amplitude: {@code gamma^steps}. */
    private static final class ExpRangeScaleFunction implements ScaleFunction {

        private final float gamma;

        ExpRangeScaleFunction(float gamma) {
            this.gamma = gamma;
        }

        @Override
        public float func(int steps) {
            return (float) Math.pow(gamma, steps);
        }
    }

    /**
     * Creates a new {@code CyclicalTracker}.
     *
     * @param builder the configuration, normally obtained from {@link #builder()} and validated by
     *     {@link Builder#build()}
     * @throws UnsupportedOperationException if no custom scale function is set and the mode is
     *     not a supported {@link CyclicalMode}
     */
    public CyclicalTracker(Builder builder) {
        this.baseValue = builder.baseValue;
        this.maxValue = builder.maxValue;
        this.stepSizeUp = builder.stepSizeUp;
        // A stepSizeDown of 0 means "not set": fall back to a symmetric cycle.
        this.stepSizeDown = builder.stepSizeDown > 0 ? builder.stepSizeDown : builder.stepSizeUp;
        this.totalSize = this.stepSizeUp + this.stepSizeDown;
        // Divide in float: int / int would truncate to 0 for every valid configuration.
        this.stepRatio = (float) this.stepSizeUp / this.totalSize;
        if (builder.scaleFunction != null) {
            this.scaleFunction = builder.scaleFunction;
            this.scaleModeCycle = builder.scaleModeCycle;
        } else {
            switch (builder.mode) {
                case TRIANGULAR:
                    this.scaleFunction = new TriangularScaleFunction();
                    this.scaleModeCycle = true;
                    break;
                case TRIANGULAR2:
                    this.scaleFunction = new Triangular2ScaleFunction();
                    this.scaleModeCycle = true;
                    break;
                case EXP_RANGE:
                    // Exponential decay is applied per update rather than per cycle.
                    this.scaleFunction = new ExpRangeScaleFunction(builder.gamma);
                    this.scaleModeCycle = false;
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported Cyclical mode.");
            }
        }
    }

    /**
     * Creates a new builder to build a {@link CyclicalTracker}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns the value for the given update count.
     *
     * <p>The value is {@code baseValue + (maxValue - baseValue) * scaleFactor * scale}, where
     * {@code scaleFactor} ramps linearly from 0 to 1 over the rising part of the cycle and back
     * towards 0 over the falling part, and {@code scale} comes from the {@link ScaleFunction}.
     *
     * @param numUpdate the number of updates performed so far
     * @return the tracked value
     */
    @Override
    public float getNewValue(int numUpdate) {
        // 1-based cycle index, equal to floor(1 + numUpdate / totalSize).
        int cycle = Math.floorDiv(numUpdate, totalSize) + 1;
        // Position inside the current cycle in [0, 1). Integer floorMod keeps this exact even
        // for update counts too large to be represented exactly as a float.
        float x = (float) Math.floorMod(numUpdate, totalSize) / totalSize;
        float scaleFactor = x < stepRatio ? x / stepRatio : (x - 1.0f) / (stepRatio - 1.0f);
        float baseHeight = (maxValue - baseValue) * scaleFactor;
        float scale = scaleModeCycle ? scaleFunction.func(cycle) : scaleFunction.func(numUpdate);
        return baseValue + baseHeight * scale;
    }
}
