package ai.djl.timeseries.translator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.timeseries.Forecast;
import ai.djl.timeseries.SampleForecast;
import ai.djl.timeseries.TimeSeriesData;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.timefeature.Lag;
import ai.djl.timeseries.timefeature.TimeFeature;
import ai.djl.timeseries.transform.InstanceSampler;
import ai.djl.timeseries.transform.PredictionSplitSampler;
import ai.djl.timeseries.transform.convert.Convert;
import ai.djl.timeseries.transform.feature.Feature;
import ai.djl.timeseries.transform.field.Field;
import ai.djl.timeseries.transform.split.Split;
import ai.djl.timeseries.translator.BaseTimeSeriesTranslator;
import ai.djl.translate.ArgumentsUtil;
import ai.djl.translate.TranslatorContext;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;

/* loaded from: input_file:ai/djl/timeseries/translator/DeepARTranslator.class */
public class DeepARTranslator extends BaseTimeSeriesTranslator {
    private boolean useFeatDynamicReal;
    private boolean useFeatStaticReal;
    private boolean useFeatStaticCat;
    private int historyLength;
    private static final String[] PRED_INPUT_FIELDS = {FieldName.FEAT_STATIC_CAT.name(), FieldName.FEAT_STATIC_REAL.name(), "PAST_" + FieldName.FEAT_TIME.name(), "PAST_" + FieldName.TARGET.name(), "PAST_" + FieldName.OBSERVED_VALUES.name(), "FUTURE_" + FieldName.FEAT_TIME.name()};
    private static final FieldName[] TIME_SERIES_FIELDS = {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES};
    private List<BiFunction<NDManager, List<LocalDateTime>, NDArray>> timeFeatures;
    private InstanceSampler instanceSampler;
    private String[] predictInputFields;

    /* loaded from: input_file:ai/djl/timeseries/translator/DeepARTranslator$Builder.class */
    public static class Builder extends BaseTimeSeriesTranslator.BaseBuilder<Builder> {
        boolean useFeatDynamicReal;
        boolean useFeatStaticReal;
        boolean useFeatStaticCat;
        boolean useIsPad;

        Builder() {
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override // ai.djl.timeseries.translator.BaseTimeSeriesTranslator.BaseBuilder
        public Builder self() {
            return this;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        @Override // ai.djl.timeseries.translator.BaseTimeSeriesTranslator.BaseBuilder
        public void configPreProcess(Map<String, ?> map) {
            super.configPreProcess(map);
            this.useFeatDynamicReal = ArgumentsUtil.booleanValue(map, "use_" + FieldName.FEAT_DYNAMIC_REAL.name().toLowerCase(), false);
            this.useFeatStaticCat = ArgumentsUtil.booleanValue(map, "use_" + FieldName.FEAT_STATIC_CAT.name().toLowerCase(), false);
            this.useFeatStaticReal = ArgumentsUtil.booleanValue(map, "use_" + FieldName.FEAT_STATIC_REAL.name().toLowerCase(), false);
            this.useIsPad = ArgumentsUtil.booleanValue(map, "use_is_pad", true);
        }

        public DeepARTranslator build() {
            validate();
            return new DeepARTranslator(this);
        }
    }

    public DeepARTranslator(Builder builder) {
        super(builder);
        this.useFeatDynamicReal = builder.useFeatDynamicReal;
        this.useFeatStaticReal = builder.useFeatStaticReal;
        this.useFeatStaticCat = builder.useFeatStaticCat;
        List<Integer> lagsForFreq = Lag.getLagsForFreq(this.freq);
        this.timeFeatures = TimeFeature.timeFeaturesFromFreqStr(this.freq);
        this.historyLength = this.contextLength + lagsForFreq.get(lagsForFreq.size() - 1).intValue();
        this.instanceSampler = PredictionSplitSampler.newTestSplitSampler();
        if (!builder.useIsPad) {
            this.predictInputFields = PRED_INPUT_FIELDS;
            return;
        }
        int length = PRED_INPUT_FIELDS.length;
        this.predictInputFields = new String[length + 1];
        System.arraycopy(PRED_INPUT_FIELDS, 0, this.predictInputFields, 0, length);
        this.predictInputFields[length] = "PAST_" + FieldName.IS_PAD.name();
    }

    /* renamed from: processOutput, reason: merged with bridge method [inline-methods] */
    public Forecast m8processOutput(TranslatorContext translatorContext, NDList nDList) {
        NDArray singletonOrThrow = nDList.singletonOrThrow();
        TimeSeriesData timeSeriesData = (TimeSeriesData) translatorContext.getAttachment("input");
        singletonOrThrow.attach((NDManager) translatorContext.getAttachment("manager"));
        return new SampleForecast(singletonOrThrow, timeSeriesData.getStartTime(), this.freq);
    }

    public NDList processInput(TranslatorContext translatorContext, TimeSeriesData timeSeriesData) {
        FieldName[] fieldNameArr;
        NDManager nDManager = translatorContext.getNDManager();
        translatorContext.setAttachment("input", timeSeriesData);
        translatorContext.setAttachment("manager", timeSeriesData.get(FieldName.TARGET).getManager());
        ArrayList arrayList = new ArrayList();
        arrayList.add(FieldName.FEAT_DYNAMIC_CAT);
        if (!this.useFeatStaticReal) {
            arrayList.add(FieldName.FEAT_STATIC_REAL);
        }
        if (!this.useFeatDynamicReal) {
            arrayList.add(FieldName.FEAT_DYNAMIC_REAL);
        }
        Field.removeFields(arrayList, timeSeriesData);
        if (!this.useFeatStaticCat) {
            Field.setField(FieldName.FEAT_STATIC_CAT, nDManager.zeros(new Shape(new long[]{1})), timeSeriesData);
        }
        if (!this.useFeatStaticReal) {
            Field.setField(FieldName.FEAT_STATIC_REAL, nDManager.zeros(new Shape(new long[]{1})), timeSeriesData);
        }
        Convert.asArray(FieldName.FEAT_STATIC_CAT, 1, DataType.INT32, timeSeriesData);
        Convert.asArray(FieldName.FEAT_STATIC_REAL, 1, timeSeriesData);
        Feature.addObservedValuesIndicator(nDManager, FieldName.TARGET, FieldName.OBSERVED_VALUES, timeSeriesData);
        Feature.addTimeFeature(nDManager, FieldName.START, FieldName.TARGET, FieldName.FEAT_TIME, this.timeFeatures, this.predictionLength, this.freq, timeSeriesData);
        Feature.addAgeFeature(nDManager, FieldName.TARGET, FieldName.FEAT_AGE, this.predictionLength, timeSeriesData);
        if (this.useFeatDynamicReal) {
            fieldNameArr = new FieldName[3];
            fieldNameArr[2] = FieldName.FEAT_DYNAMIC_REAL;
        } else {
            fieldNameArr = new FieldName[2];
        }
        fieldNameArr[0] = FieldName.FEAT_TIME;
        fieldNameArr[1] = FieldName.FEAT_AGE;
        Convert.vstackFeatures(FieldName.FEAT_TIME, fieldNameArr, timeSeriesData);
        Split.instanceSplit(nDManager, FieldName.TARGET, FieldName.IS_PAD, FieldName.START, FieldName.FORECAST_START, this.instanceSampler, this.historyLength, this.predictionLength, TIME_SERIES_FIELDS, 0.0f, timeSeriesData);
        return Field.selectField(this.predictInputFields, timeSeriesData).toNDList();
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> map) {
        Builder builder = new Builder();
        builder.configPreProcess(map);
        builder.configPostProcess(map);
        return builder;
    }
}
