package ai.djl.timeseries.block;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.timeseries.block.Scaler;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

/* loaded from: input_file:ai/djl/timeseries/block/MeanScaler.class */
/**
 * A {@link Scaler} that divides its input by a weighted mean of the input's absolute values.
 *
 * <p>The mean is taken along the inherited {@code dim} axis, using the second input as
 * per-element weights. Entries whose weighted absolute sum is not strictly positive fall back to
 * a scale obtained by additionally reducing over axis 0. The resulting scale is never smaller
 * than {@code minimumScale}.
 */
public class MeanScaler extends Scaler {

    /** Lower bound applied to the computed scale before it is used as a divisor. */
    private final float minimumScale;

    /**
     * Creates a scaler from the given builder.
     *
     * @param builder the builder carrying the configuration
     */
    MeanScaler(Builder builder) {
        super(builder);
        this.minimumScale = builder.minimumScale;
    }

    /**
     * Creates a builder to construct a {@code MeanScaler}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Scales the first input by the weighted mean of its absolute values.
     *
     * <p>NOTE(review): this presumably overrides the forward hook declared further up the
     * {@code Scaler} hierarchy; the superclass is not visible here, so confirm before adding
     * {@code @Override}.
     *
     * @param parameterStore the parameter store (not used by this implementation)
     * @param inputs a list whose element 0 is the data and element 1 is the per-element weight
     *     (presumably an observed-value indicator; verify against callers)
     * @param training whether the block runs in training mode (not used by this implementation)
     * @param params optional extra parameters (not used by this implementation)
     * @return a list of two arrays: the data divided by the scale, and the scale itself, which
     *     keeps the expanded {@code dim} axis when {@code keepDim} is set and is squeezed
     *     otherwise
     */
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray data = inputs.get(0);
        NDArray weights = inputs.get(1);

        int[] scaleAxis = {this.dim};
        NDArray totalWeight = weights.sum(scaleAxis);
        NDArray weightedSum = data.abs().mul(weights).sum(scaleAxis);

        // Entries with a strictly positive weighted sum use their own mean.
        NDArray hasSignal = weightedSum.gt(weightedSum.zerosLike());

        // Denominators are clamped to at least 1 so that an all-zero weight never divides by zero.
        NDArray ownScale = weightedSum.div(NDArrays.maximum(totalWeight, 1f));

        // Fallback for the remaining entries: reduce over axis 0 as well (presumably the batch
        // axis; confirm), then multiply by ones to bring it back to the shape of totalWeight.
        NDArray fallbackScale =
                weightedSum
                        .sum(new int[] {0})
                        .div(NDArrays.maximum(totalWeight.sum(new int[] {0}), 1f))
                        .mul(totalWeight.onesLike());

        NDArray scale =
                NDArrays.maximum(minimumScale, NDArrays.where(hasSignal, ownScale, fallbackScale))
                        .expandDims(this.dim);

        NDArray scaledData = data.div(scale);
        NDArray reportedScale = this.keepDim ? scale : scale.squeeze(this.dim);
        return new NDList(scaledData, reportedScale);
    }

    /** The builder to construct a {@code MeanScaler}. */
    public static final class Builder extends Scaler.ScalerBuilder<Builder> {

        /** Smallest scale the built block may produce; defaults to {@code 1e-10}. */
        private float minimumScale = 1e-10f;

        Builder() {
        }

        /**
         * Sets the lower bound for the computed scale.
         *
         * @param minimumScale the minimum scale value
         * @return this builder
         */
        public Builder optMinimumScale(float minimumScale) {
            this.minimumScale = minimumScale;
            return this;
        }

        /**
         * Returns this builder, typed as the concrete builder class.
         *
         * @return this builder
         */
        @Override
        public Builder self() {
            return this;
        }

        /**
         * Validates the configuration and builds the {@code MeanScaler}.
         *
         * @return the new {@code MeanScaler}
         */
        public MeanScaler build() {
            validate();
            return new MeanScaler(this);
        }
    }
}
