package org.datavec.api.transform.transform.column;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.Writable;

/* loaded from: input_file:org/datavec/api/transform/transform/column/ReorderColumnsTransform.class */
/**
 * Transform that reorders the columns of a record.
 *
 * <p>Columns named in {@code newOrder} are placed first, in the order given. Every input column
 * that is not named keeps its original relative order and is appended after them. For example,
 * with input columns {@code [a, b, c, d]} and {@code newOrder = [c, a]}, the output columns are
 * {@code [c, a, b, d]}.
 *
 * <p>{@link #setInputSchema(Schema)} must be called before {@link #map(List)} or
 * {@link #mapSequence(List)}.
 */
public class ReorderColumnsTransform implements Transform {
    /** Names of the columns to move to the front, in the requested order. */
    private final List<String> newOrder;
    /** Schema passed to {@link #setInputSchema(Schema)}; {@code null} until then. */
    private Schema inputSchema;
    /**
     * {@code outputOrder[i]} is the index in the input record of the value emitted at output
     * position {@code i}; {@code null} until {@link #setInputSchema(Schema)} is called.
     */
    private int[] outputOrder;

    /**
     * Creates a transform that moves the named columns to the front.
     *
     * @param strArr names of the columns to place first, in the desired order
     */
    public ReorderColumnsTransform(String... strArr) {
        this(Arrays.asList(strArr));
    }

    /**
     * Creates a transform that moves the named columns to the front.
     *
     * @param list names of the columns to place first, in the desired order; the list is copied,
     *     so later changes to the caller's list do not affect this transform
     */
    public ReorderColumnsTransform(List<String> list) {
        this.newOrder = new ArrayList<>(list);
    }

    /**
     * Returns the schema obtained by applying this reordering to {@code schema}.
     *
     * @param schema the input schema
     * @return a new schema with the same columns (names and metadata) in the reordered sequence
     * @throws IllegalStateException if a column named in {@code newOrder} is absent from {@code schema}
     * @throws IllegalArgumentException if {@code newOrder} names more columns than {@code schema}
     *     has, or names the same column more than once
     */
    @Override // org.datavec.api.transform.Transform
    public Schema transform(Schema schema) {
        int[] order = computeOutputOrder(schema);
        List<String> columnNames = schema.getColumnNames();
        List<ColumnMetaData> columnMetaData = schema.getColumnMetaData();
        List<String> newNames = new ArrayList<>(order.length);
        List<ColumnMetaData> newMetaData = new ArrayList<>(order.length);
        for (int inputIndex : order) {
            newNames.add(columnNames.get(inputIndex));
            newMetaData.add(columnMetaData.get(inputIndex));
        }
        return schema.newSchema(newNames, newMetaData);
    }

    /**
     * Validates {@code schema} against {@code newOrder}, then records the schema and the
     * per-record index mapping used by {@link #map(List)}.
     *
     * <p>On failure, the previously configured schema and mapping (if any) are left unchanged.
     *
     * @param schema the input schema
     * @throws IllegalStateException if a column named in {@code newOrder} is absent from {@code schema}
     * @throws IllegalArgumentException if {@code newOrder} names more columns than {@code schema}
     *     has, or names the same column more than once
     */
    @Override // org.datavec.api.transform.Transform
    public void setInputSchema(Schema schema) {
        int[] order = computeOutputOrder(schema);
        this.outputOrder = order;
        // Previously never assigned, so getInputSchema() always returned null.
        this.inputSchema = schema;
    }

    /**
     * Returns the schema last passed to {@link #setInputSchema(Schema)}.
     *
     * @return the input schema, or {@code null} if none has been set
     */
    @Override // org.datavec.api.transform.Transform
    public Schema getInputSchema() {
        return this.inputSchema;
    }

    /**
     * Reorders the values of a single record according to the configured input schema.
     *
     * @param list the record's values, in input-schema column order
     * @return a new list containing the same {@link Writable} instances in output order
     * @throws IllegalStateException if {@link #setInputSchema(Schema)} has not been called, or if
     *     the record's length does not match the input schema's column count
     */
    @Override // org.datavec.api.transform.Transform
    public List<Writable> map(List<Writable> list) {
        if (outputOrder == null) {
            throw new IllegalStateException(
                    "Cannot reorder columns: setInputSchema(Schema) has not been called");
        }
        // Without this check, surplus trailing values would be silently dropped.
        if (list.size() != outputOrder.length) {
            throw new IllegalStateException("Cannot reorder columns: record has " + list.size()
                    + " values but input schema has " + outputOrder.length + " columns");
        }
        List<Writable> reordered = new ArrayList<>(outputOrder.length);
        for (int inputIndex : outputOrder) {
            reordered.add(list.get(inputIndex));
        }
        return reordered;
    }

    /**
     * Applies {@link #map(List)} to every time step of a sequence.
     *
     * @param list the sequence, one record per time step
     * @return a new list of reordered records, one per time step
     */
    @Override // org.datavec.api.transform.Transform
    public List<List<Writable>> mapSequence(List<List<Writable>> list) {
        List<List<Writable>> result = new ArrayList<>(list.size());
        for (List<Writable> step : list) {
            result.add(map(step));
        }
        return result;
    }

    /**
     * Validates {@code newOrder} against {@code schema} and builds the output-to-input index
     * mapping: first the named columns in the requested order, then all other columns in their
     * original order.
     */
    private int[] computeOutputOrder(Schema schema) {
        for (String name : newOrder) {
            if (!schema.hasColumn(name)) {
                throw new IllegalStateException("Input schema does not contain column with name \"" + name + "\"");
            }
        }
        if (schema.numColumns() < newOrder.size()) {
            throw new IllegalArgumentException("Schema has " + schema.numColumns() + " columns but newOrder has " + newOrder.size() + " columns");
        }
        int numColumns = schema.getColumnNames().size();
        int[] order = new int[numColumns];
        boolean[] placed = new boolean[numColumns];
        int next = 0;
        for (String name : newOrder) {
            int inputIndex = schema.getIndexOfColumn(name);
            // A repeated name would otherwise overflow 'order' (or duplicate a column in transform()).
            if (placed[inputIndex]) {
                throw new IllegalArgumentException("Column \"" + name + "\" appears more than once in newOrder");
            }
            placed[inputIndex] = true;
            order[next++] = inputIndex;
        }
        for (int i = 0; i < numColumns; i++) {
            if (!placed[i]) {
                order[next++] = i;
            }
        }
        return order;
    }
}
