package org.datavec.api.transform.sequence.trim;

import java.util.ArrayList;
import java.util.List;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;
import org.nd4j.common.base.Preconditions;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Sequence transform that limits every sequence to at most {@code maxLength} time steps.
 * <p>
 * In {@link Mode#TRIM} mode, sequences longer than {@code maxLength} are truncated (the first
 * {@code maxLength} steps are kept) and shorter sequences are returned unchanged.
 * In {@link Mode#TRIM_OR_PAD} mode, longer sequences are truncated and shorter ones (including empty
 * sequences) are extended at the end with copies of {@code pad} until they have exactly
 * {@code maxLength} steps.
 * <p>
 * The output schema is identical to the input schema.
 */
@JsonIgnoreProperties({"schema"})
public class SequenceTrimToLengthTransform implements Transform {
    private int maxLength;
    private Mode mode;
    private List<Writable> pad;
    private Schema schema;

    /** How sequences that are not longer than the maximum length are handled. */
    public enum Mode {
        /** Truncate sequences longer than the maximum length; leave shorter sequences unchanged. */
        TRIM,
        /** Truncate sequences longer than the maximum length; pad shorter sequences with the padding values. */
        TRIM_OR_PAD
    }

    /**
     * @param maxLength maximum number of time steps in the output (exact length for TRIM_OR_PAD); must be > 0
     * @param mode      trimming mode
     * @param pad       values making up one padding time step (one value per column); must be non-null
     *                  unless {@code mode == Mode.TRIM}
     * @throws IllegalStateException if {@code maxLength <= 0}, or if {@code pad} is null and mode is not TRIM
     */
    public SequenceTrimToLengthTransform(@JsonProperty("maxLength") int maxLength, @JsonProperty("mode") Mode mode,
                                         @JsonProperty("pad") List<Writable> pad) {
        Preconditions.checkState(maxLength > 0, "Maximum length must be > 0, got %s", maxLength);
        Preconditions.checkState(mode == Mode.TRIM || pad != null,
                "If mode == Mode.TRIM_OR_PAD, padding values must be provided (pad must not be null)");
        this.maxLength = maxLength;
        this.mode = mode;
        this.pad = pad;
    }

    /** Not supported: this transform only applies to sequences. */
    @Override
    public List<Writable> map(List<Writable> writables) {
        throw new UnsupportedOperationException("SequenceTrimToLengthTransform cannot be applied to non-sequence values");
    }

    /**
     * Trims (and, in TRIM_OR_PAD mode, pads) a single sequence.
     *
     * @param sequence the input sequence; each inner list is one time step
     * @return the input list itself when no change is needed, otherwise a new list
     * @throws IllegalStateException in TRIM_OR_PAD mode, if padding is required and the number of padding
     *                               values differs from the number of values per time step
     */
    @Override
    public List<List<Writable>> mapSequence(List<List<Writable>> sequence) {
        int length = sequence.size();
        if (length > maxLength) {
            // Both modes truncate identically. Copy the subList view so the result does not pin
            // or alias the caller's backing list.
            return new ArrayList<>(sequence.subList(0, maxLength));
        }
        if (mode == Mode.TRIM || length == maxLength) {
            return sequence;
        }

        // Padding is required. Only inspect the first time step when one exists: Preconditions message
        // arguments are evaluated eagerly, so reading sequence.get(0) unconditionally would throw
        // IndexOutOfBoundsException for an empty sequence.
        if (length > 0) {
            int numColumns = sequence.get(0).size();
            Preconditions.checkState(numColumns == pad.size(),
                    "Invalid padding values: %s padding values were provided, but data has %s values per time step (columns)",
                    pad.size(), numColumns);
        }

        List<List<Writable>> out = new ArrayList<>(maxLength);
        out.addAll(sequence);
        while (out.size() < maxLength) {
            // Fresh (shallow) copy per step so padded time steps do not alias each other
            // or this transform's own pad list.
            out.add(new ArrayList<>(pad));
        }
        return out;
    }

    /** Not supported. */
    @Override
    public Object map(Object input) {
        throw new UnsupportedOperationException();
    }

    /** Not supported. */
    @Override
    public Object mapSequence(Object sequence) {
        throw new UnsupportedOperationException();
    }

    /** Returns the input schema unchanged: trimming/padding does not alter columns. */
    @Override
    public Schema transform(Schema inputSchema) {
        return inputSchema;
    }

    @Override
    public void setInputSchema(Schema inputSchema) {
        this.schema = inputSchema;
    }

    @Override
    public Schema getInputSchema() {
        return this.schema;
    }

    /** Returns null: this transform operates on all columns, not a single one. */
    @Override
    public String outputColumnName() {
        return null;
    }

    /** Returns all column names of the input schema; requires the input schema to have been set. */
    @Override
    public String[] outputColumnNames() {
        return (String[]) this.schema.getColumnNames().toArray(new String[this.schema.numColumns()]);
    }

    @Override
    public String[] columnNames() {
        return outputColumnNames();
    }

    /** Returns null: this transform operates on all columns, not a single one. */
    @Override
    public String columnName() {
        return null;
    }

    /** Equality is based on maxLength, mode and pad; the schema is intentionally excluded. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SequenceTrimToLengthTransform)) {
            return false;
        }
        SequenceTrimToLengthTransform other = (SequenceTrimToLengthTransform) obj;
        if (!other.canEqual(this) || getMaxLength() != other.getMaxLength()) {
            return false;
        }
        Mode thisMode = getMode();
        Mode otherMode = other.getMode();
        if (thisMode == null ? otherMode != null : !thisMode.equals(otherMode)) {
            return false;
        }
        List<Writable> thisPad = getPad();
        List<Writable> otherPad = other.getPad();
        return thisPad == null ? otherPad == null : thisPad.equals(otherPad);
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof SequenceTrimToLengthTransform;
    }

    /** Hash over the same fields used by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        int result = (1 * 59) + getMaxLength();
        Mode thisMode = getMode();
        result = (result * 59) + (thisMode == null ? 43 : thisMode.hashCode());
        List<Writable> thisPad = getPad();
        return (result * 59) + (thisPad == null ? 43 : thisPad.hashCode());
    }

    public int getMaxLength() {
        return this.maxLength;
    }

    public Mode getMode() {
        return this.mode;
    }

    public List<Writable> getPad() {
        return this.pad;
    }

    public Schema getSchema() {
        return this.schema;
    }

    public void setMaxLength(int maxLength) {
        this.maxLength = maxLength;
    }

    public void setMode(Mode mode) {
        this.mode = mode;
    }

    public void setPad(List<Writable> pad) {
        this.pad = pad;
    }

    public void setSchema(Schema schema) {
        this.schema = schema;
    }

    @Override
    public String toString() {
        return "SequenceTrimToLengthTransform(maxLength=" + getMaxLength() + ", mode=" + getMode() + ", pad=" + getPad() + ", schema=" + getSchema() + ")";
    }
}
