package org.datavec.api.transform.transform.sequence;

import java.util.ArrayList;
import java.util.List;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.DoubleMetaData;
import org.datavec.api.transform.metadata.FloatMetaData;
import org.datavec.api.transform.metadata.IntegerMetaData;
import org.datavec.api.transform.metadata.LongMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.schema.SequenceSchema;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonInclude;

/**
 * Sequence transform that replaces the values of one numeric (or time) column with the
 * difference between the value at the current time step and the value {@code lookback}
 * time steps earlier: {@code out[t] = in[t] - in[t - lookback]}.
 *
 * <p>For the first {@code lookback} time steps there is no value that far back. What is
 * emitted for those steps depends on the {@link FirstStepMode}:
 * <ul>
 *   <li>{@link FirstStepMode#Default}: the difference is taken against time step 0
 *       (so step 0 itself always yields zero);</li>
 *   <li>{@link FirstStepMode#SpecifiedValue}: the writable supplied at construction is
 *       emitted as-is.</li>
 * </ul>
 *
 * <p>Supported column types are Integer, Long, Double, Float and Time. Time columns are
 * differenced as longs and become a Long column in the output schema.
 *
 * <p>Only sequence data is supported; the non-sequence {@code map} methods throw
 * {@link UnsupportedOperationException}.
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonIgnoreProperties({"inputSchema", "columnType"})
public class SequenceDifferenceTransform implements Transform {
    /** Name of the input column the difference is computed on. */
    private final String columnName;
    /** Name given to the column in the output schema. */
    private final String newColumnName;
    /** Number of time steps to look back; always &gt; 0. */
    private final int lookback;
    private final FirstStepMode firstStepMode;
    /** Value emitted for the first {@code lookback} steps in SpecifiedValue mode, otherwise null. */
    private final Writable specifiedValueWritable;
    private Schema inputSchema;
    /** Type of {@link #columnName} in the input schema; set by {@link #setInputSchema(Schema)}. */
    private ColumnType columnType;

    /** How to produce output for the first {@code lookback} time steps of a sequence. */
    public enum FirstStepMode {
        /** Difference against the first time step of the sequence. */
        Default,
        /** Emit the writable supplied to the constructor. */
        SpecifiedValue
    }

    /**
     * Creates a transform with lookback 1 and {@link FirstStepMode#Default}, keeping the
     * column name unchanged.
     *
     * @param str name of the column to difference
     */
    public SequenceDifferenceTransform(String str) {
        this(str, str, 1, FirstStepMode.Default, null);
    }

    /**
     * Creates a transform using {@link FirstStepMode#Default}.
     *
     * @param str  name of the column to difference
     * @param str2 name of the column in the output
     * @param i    lookback period in time steps; must be &gt; 0
     */
    public SequenceDifferenceTransform(String str, String str2, int i) {
        this(str, str2, i, FirstStepMode.Default, null);
    }

    /**
     * Creates a fully specified transform.
     *
     * @param str           name of the column to difference
     * @param str2          name of the column in the output
     * @param i             lookback period in time steps; must be &gt; 0
     * @param firstStepMode how the first {@code i} time steps are handled
     * @param writable      value for the first steps; required when {@code firstStepMode} is
     *                      {@link FirstStepMode#SpecifiedValue}, and must be null otherwise
     * @throws IllegalArgumentException if the writable and mode are inconsistent, or the
     *                                  lookback period is not positive
     */
    public SequenceDifferenceTransform(String str, String str2, int i, FirstStepMode firstStepMode, Writable writable) {
        if (firstStepMode != FirstStepMode.SpecifiedValue && writable != null) {
            throw new IllegalArgumentException("Specified value writable provided (" + writable + ") but firstStepMode != FirstStepMode.SpecifiedValue");
        }
        if (firstStepMode == FirstStepMode.SpecifiedValue && writable == null) {
            // Message previously said "!=", contradicting the condition that triggers it.
            throw new IllegalArgumentException("Specified value writable is null but firstStepMode == FirstStepMode.SpecifiedValue");
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Lookback period must be > 0. Got: lookback period = " + i);
        }
        this.columnName = str;
        this.newColumnName = str2;
        this.lookback = i;
        this.firstStepMode = firstStepMode;
        this.specifiedValueWritable = writable;
    }

    /**
     * @return the name of the column produced by this transform, i.e. the name that
     *         {@link #transform(Schema)} assigns to the differenced column
     */
    @Override
    public String outputColumnName() {
        // The output schema names the column newColumnName, not the input name.
        return this.newColumnName;
    }

    /** @return a single-element array holding {@link #outputColumnName()} */
    @Override
    public String[] outputColumnNames() {
        return new String[]{outputColumnName()};
    }

    /** @return a single-element array holding {@link #columnName()} */
    @Override
    public String[] columnNames() {
        return new String[]{columnName()};
    }

    /** @return the name of the input column this transform reads */
    @Override
    public String columnName() {
        return this.columnName;
    }

    /**
     * Computes the output schema: identical to the input except that the target column is
     * renamed to the new column name and given fresh metadata of the matching numeric type
     * (Time becomes Long).
     *
     * @param schema input schema; must be a {@link SequenceSchema} containing the target column
     * @return the schema after this transform is applied
     * @throws IllegalStateException if the column is missing, the schema is not a sequence
     *                               schema, or the column type is not supported
     */
    @Override
    public Schema transform(Schema schema) {
        if (!schema.hasColumn(this.columnName)) {
            throw new IllegalStateException("Invalid input schema: does not have column with name \"" + this.columnName + "\"\n. All schema names: " + schema.getColumnNames());
        }
        if (!(schema instanceof SequenceSchema)) {
            throw new IllegalStateException("Invalid input schema: expected a SequenceSchema, got " + schema.getClass());
        }
        List<ColumnMetaData> newMeta = new ArrayList<>(schema.numColumns());
        for (ColumnMetaData meta : schema.getColumnMetaData()) {
            if (!this.columnName.equals(meta.getName())) {
                newMeta.add(meta);
                continue;
            }
            switch (meta.getColumnType()) {
                case Integer:
                    newMeta.add(new IntegerMetaData(this.newColumnName));
                    break;
                case Long:
                case Time:
                    // A difference of two times is a duration, represented as a long.
                    newMeta.add(new LongMetaData(this.newColumnName));
                    break;
                case Double:
                    newMeta.add(new DoubleMetaData(this.newColumnName));
                    break;
                case Float:
                    newMeta.add(new FloatMetaData(this.newColumnName));
                    break;
                default:
                    throw new IllegalStateException("Cannot perform sequence difference on column of type " + meta.getColumnType());
            }
        }
        return schema.newSchema(newMeta);
    }

    /**
     * Records the input schema and caches the type of the target column.
     *
     * @param schema input schema; must contain the target column
     * @throws IllegalStateException if the schema does not contain the target column
     */
    @Override
    public void setInputSchema(Schema schema) {
        if (!schema.hasColumn(this.columnName)) {
            throw new IllegalStateException("Invalid input schema: does not have column with name \"" + this.columnName + "\"\n. All schema names: " + schema.getColumnNames());
        }
        this.columnType = schema.getMetaData(this.columnName).getColumnType();
        this.inputSchema = schema;
    }

    /** @return the schema last passed to {@link #setInputSchema(Schema)}, or null if none */
    @Override
    public Schema getInputSchema() {
        return this.inputSchema;
    }

    /**
     * Not supported: this transform needs a whole sequence.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public List<Writable> map(List<Writable> list) {
        throw new UnsupportedOperationException("Only sequence operations are supported for SequenceDifferenceTransform. Attempting to apply SequenceDifferenceTransform on non-sequence data?");
    }

    /**
     * Applies the difference to one sequence. All columns other than the target column are
     * passed through unchanged; the input lists are not modified.
     *
     * @param list the sequence, one inner list of writables per time step
     * @return a new sequence with the target column replaced by its lookback difference
     * @throws IllegalStateException if the input schema has not been set, or the column
     *                               type is not supported
     */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> list) {
        if (this.inputSchema == null) {
            throw new IllegalStateException("Input schema has not been set: call setInputSchema before mapSequence");
        }
        int columnIdx = this.inputSchema.getIndexOfColumn(this.columnName);
        int numSteps = list.size();
        List<List<Writable>> out = new ArrayList<>(numSteps);
        for (int step = 0; step < numSteps; step++) {
            List<Writable> timeStep = list.get(step);
            List<Writable> newTimeStep = new ArrayList<>(timeStep.size());
            // The "first steps" test is on the time-step index; comparing the column index
            // against lookback (as was done before) made the result depend on column position.
            boolean useSpecifiedValue = step < this.lookback && this.firstStepMode == FirstStepMode.SpecifiedValue;
            for (int col = 0; col < timeStep.size(); col++) {
                if (col != columnIdx) {
                    newTimeStep.add(timeStep.get(col));
                } else if (useSpecifiedValue) {
                    newTimeStep.add(this.specifiedValueWritable);
                } else {
                    // Clamp to step 0 so the early steps in Default mode diff against the first value.
                    Writable past = list.get(Math.max(0, step - this.lookback)).get(col);
                    newTimeStep.add(difference(timeStep.get(col), past));
                }
            }
            out.add(newTimeStep);
        }
        return out;
    }

    /** Returns {@code current - past} as a writable matching the cached column type. */
    private Writable difference(Writable current, Writable past) {
        switch (this.columnType) {
            case Integer:
                return new IntWritable(current.toInt() - past.toInt());
            case Long:
            case Time:
                return new LongWritable(current.toLong() - past.toLong());
            case Double:
                return new DoubleWritable(current.toDouble() - past.toDouble());
            case Float:
                return new FloatWritable(current.toFloat() - past.toFloat());
            default:
                throw new IllegalStateException("Cannot perform sequence difference on column of type " + this.columnType);
        }
    }

    /**
     * Not supported: this transform needs a whole sequence.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object map(Object obj) {
        throw new UnsupportedOperationException("Only sequence operations are supported for SequenceDifferenceTransform. Attempting to apply SequenceDifferenceTransform on non-sequence data?");
    }

    /**
     * Not supported for single-column object input.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Object mapSequence(Object obj) {
        throw new UnsupportedOperationException("Only sequence operations are supported for SequenceDifferenceTransform. Attempting to apply SequenceDifferenceTransform on non-sequence data?");
    }
}
