package org.datavec.api.transform.transform.column;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.datavec.api.transform.ColumnOp;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.util.StringUtils;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Transform that drops a fixed set of named columns from every record and from the schema.
 *
 * <p>The columns to drop are supplied by name at construction time. Their positions are
 * resolved against the input schema in {@link #setInputSchema(Schema)}, which must be called
 * before {@link #map(List)}, {@link #outputColumnNames()} or {@link #columnNames()} are used.
 * A name that is listed more than once is treated as if it had been listed once.
 */
@JsonIgnoreProperties({"inputSchema", "columnsToRemoveIdx", "indicesToRemove"})
public class RemoveColumnsTransform extends BaseTransform implements ColumnOp {
    /** Input-schema index of each entry in {@link #columnsToRemove}, in the same order. */
    private int[] columnsToRemoveIdx;
    /** Names of the columns to drop, exactly as passed to the constructor. */
    private String[] columnsToRemove;
    /** Distinct input-schema indices to drop; consulted per value in {@link #map(List)}. */
    private Set<Integer> indicesToRemove;
    /** Names of the columns that survive the transform, in input-schema order. */
    private String[] leftOverColumns;

    /**
     * @param strArr names of the columns to remove
     */
    public RemoveColumnsTransform(@JsonProperty("columnsToRemove") String... strArr) {
        this.columnsToRemove = strArr;
    }

    /**
     * Stores the input schema and resolves the configured column names against it.
     *
     * @param schema the schema of the records this transform will receive
     * @throws IllegalStateException if any configured column is not present in the schema
     * @throws RuntimeException if the schema reports a negative index for a configured column
     */
    @Override // org.datavec.api.transform.transform.BaseTransform, org.datavec.api.transform.ColumnOp
    public void setInputSchema(Schema schema) {
        super.setInputSchema(schema);
        for (String name : this.columnsToRemove) {
            if (!this.inputSchema.hasColumn(name)) {
                throw new IllegalStateException("Cannot remove column \"" + name + "\": column does not exist. All columns for input schema: " + this.inputSchema.getColumnNames());
            }
        }

        this.indicesToRemove = new HashSet<Integer>();
        this.columnsToRemoveIdx = new int[this.columnsToRemove.length];
        Set<String> namesToRemove = new HashSet<String>();
        int next = 0;
        for (String name : this.columnsToRemove) {
            int indexOfColumn = schema.getIndexOfColumn(name);
            if (indexOfColumn < 0) {
                throw new RuntimeException("Column \"" + name + "\" not found");
            }
            this.columnsToRemoveIdx[next++] = indexOfColumn;
            this.indicesToRemove.add(Integer.valueOf(indexOfColumn));
            namesToRemove.add(name);
        }

        // Collect into a list first: sizing the array from columnsToRemove.length would be
        // too small (or negative) whenever the same name is listed more than once.
        List<String> remaining = new ArrayList<String>();
        for (String name : schema.getColumnNames()) {
            if (!namesToRemove.contains(name)) {
                remaining.add(name);
            }
        }
        this.leftOverColumns = remaining.toArray(new String[remaining.size()]);
    }

    /**
     * Builds the output schema by omitting the metadata of every column named for removal.
     *
     * @param schema the schema to transform
     * @return a new schema, created via {@code schema.newSchema(...)}, holding the kept columns
     * @throws IllegalStateException if the schema's column count minus the number of distinct
     *     names to remove is zero or negative
     */
    @Override // org.datavec.api.transform.ColumnOp
    public Schema transform(Schema schema) {
        Set<String> namesToRemove = new HashSet<String>(Arrays.asList(this.columnsToRemove));
        // Count distinct names so a repeated name is not subtracted twice.
        int numColumns = schema.numColumns() - namesToRemove.size();
        if (numColumns <= 0) {
            throw new IllegalStateException("Number of columns after executing operation is " + numColumns + " (is <= 0). origColumns = " + schema.getColumnNames() + ", toRemove = " + Arrays.toString(this.columnsToRemove));
        }
        List<String> columnNames = schema.getColumnNames();
        List<ColumnMetaData> columnMetaData = schema.getColumnMetaData();
        List<ColumnMetaData> kept = new ArrayList<ColumnMetaData>(numColumns);
        // Names and metadata are walked in lockstep; both lists come from the same schema.
        Iterator<ColumnMetaData> metaIterator = columnMetaData.iterator();
        for (String name : columnNames) {
            ColumnMetaData meta = metaIterator.next();
            if (!namesToRemove.contains(name)) {
                kept.add(meta);
            }
        }
        return schema.newSchema(kept);
    }

    /**
     * Removes the configured columns from a single record.
     *
     * @param list the record; must have exactly as many values as the input schema has columns
     * @return a new list holding the values whose positions are not marked for removal,
     *     in their original order
     * @throws IllegalStateException if the record length differs from the input schema's
     *     column count
     */
    @Override // org.datavec.api.transform.Transform
    public List<Writable> map(List<Writable> list) {
        if (list.size() != this.inputSchema.numColumns()) {
            List<String> asStrings = new ArrayList<String>();
            for (Writable writable : list) {
                asStrings.add(writable.toString());
            }
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + list.size() + ") does not match expected number of elements (schema: " + this.inputSchema.numColumns() + "). Transform = " + toString() + " and record " + StringUtils.join(",", asStrings));
        }
        // Size from the distinct indices; columnsToRemove.length over-counts repeated names
        // and could make the capacity negative.
        List<Writable> kept = new ArrayList<Writable>(list.size() - this.indicesToRemove.size());
        int position = 0;
        for (Writable writable : list) {
            if (!this.indicesToRemove.contains(Integer.valueOf(position))) {
                kept.add(writable);
            }
            position++;
        }
        return kept;
    }

    /**
     * Not supported by this transform.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.datavec.api.transform.Transform
    public Object map(Object obj) {
        throw new UnsupportedOperationException("Unable to map. Please treat this as a special operation. This should be handled by your implementation.");
    }

    /**
     * Not supported by this transform.
     *
     * @throws UnsupportedOperationException always
     */
    @Override // org.datavec.api.transform.Transform
    public Object mapSequence(Object obj) {
        throw new UnsupportedOperationException("Unable to map. Please treat this as a special operation. This should be handled by your implementation.");
    }

    @Override // org.datavec.api.transform.transform.BaseTransform
    public String toString() {
        return "RemoveColumnsTransform(" + Arrays.toString(this.columnsToRemove) + ")";
    }

    /** Two instances are equal when they are of the same class and list the same names in the same order. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Arrays.equals(this.columnsToRemove, ((RemoveColumnsTransform) obj).columnsToRemove);
    }

    /** Hash derived from the configured column names only, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Arrays.hashCode(this.columnsToRemove);
    }

    /** @return the first of the columns that remain after removal */
    @Override // org.datavec.api.transform.ColumnOp
    public String outputColumnName() {
        return outputColumnNames()[0];
    }

    /** @return the names of the columns that remain, as computed by {@link #setInputSchema(Schema)} */
    @Override // org.datavec.api.transform.ColumnOp
    public String[] outputColumnNames() {
        return this.leftOverColumns;
    }

    /** @return all column names of the input schema, copied into a new array */
    @Override // org.datavec.api.transform.ColumnOp
    public String[] columnNames() {
        return (String[]) this.inputSchema.getColumnNames().toArray(new String[this.inputSchema.numColumns()]);
    }

    /** @return the first column name of the input schema */
    @Override // org.datavec.api.transform.ColumnOp
    public String columnName() {
        return columnNames()[0];
    }
}
