package org.datavec.api.transform.transform.categorical;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.metadata.IntegerMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Writable;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Converts a single categorical column into an integer column. Each category name is replaced
 * by its zero-based position in the column's state-name list, as returned by
 * {@link CategoricalMetaData#getStateNames()}.
 *
 * <p>{@link #setInputSchema(Schema)} must be called before any {@code map*} method. It resolves
 * the column index and builds the name-to-index lookup table. Only {@code columnName} is
 * serialized; the derived state is rebuilt from the schema.
 */
@JsonIgnoreProperties({"inputSchema", "columnIdx", "stateNames", "statesMap"})
public class CategoricalToIntegerTransform extends BaseTransform {
    /** Name of the categorical column to convert. */
    private String columnName;
    /** Index of {@link #columnName} in the input schema; -1 until {@link #setInputSchema} runs. */
    private int columnIdx = -1;
    /** Category names in schema order; a name's position in this list is its integer encoding. */
    private List<String> stateNames;
    /** Lookup from category name to its integer encoding, derived from {@link #stateNames}. */
    private Map<String, Integer> statesMap;

    /**
     * @param columnName name of the categorical column to convert
     */
    public CategoricalToIntegerTransform(@JsonProperty("columnName") String columnName) {
        this.columnName = columnName;
    }

    /**
     * Records the input schema, resolves the target column index and builds the category lookup.
     *
     * @param schema input schema; must contain {@link #columnName} as a categorical column
     * @throws IllegalStateException if the column is not categorical
     */
    @Override
    public void setInputSchema(Schema schema) {
        super.setInputSchema(schema);
        this.columnIdx = schema.getIndexOfColumn(this.columnName);
        ColumnMetaData metaData = schema.getMetaData(this.columnName);
        if (!(metaData instanceof CategoricalMetaData)) {
            throw new IllegalStateException("Cannot convert column \"" + this.columnName
                    + "\" from categorical to integer: column is not categorical (is: "
                    + metaData.getColumnType() + ")");
        }
        List<String> names = ((CategoricalMetaData) metaData).getStateNames();
        // Size for the default 0.75 load factor so the map never rehashes while being filled.
        Map<String, Integer> lookup = new HashMap<>(names.size() * 4 / 3 + 1);
        for (int i = 0; i < names.size(); i++) {
            // Duplicate names keep the last index, matching a plain sequential put.
            lookup.put(names.get(i), i);
        }
        this.stateNames = names;
        this.statesMap = lookup;
    }

    /**
     * Returns a copy of {@code schema} where the target column's metadata is replaced by integer
     * metadata with the range {@code [0, stateNames.size() - 1]}.
     *
     * @param schema schema to transform (expected to match the schema passed to
     *               {@link #setInputSchema})
     * @return the output schema
     */
    @Override
    public Schema transform(Schema schema) {
        List<ColumnMetaData> inputMeta = schema.getColumnMetaData();
        List<ColumnMetaData> outputMeta = new ArrayList<>(schema.numColumns());
        int i = 0;
        for (ColumnMetaData meta : inputMeta) {
            if (i == this.columnIdx) {
                outputMeta.add(new IntegerMetaData(meta.getName(), 0, Integer.valueOf(this.stateNames.size() - 1)));
            } else {
                outputMeta.add(meta);
            }
            i++;
        }
        return schema.newSchema(outputMeta);
    }

    /**
     * Replaces the categorical value in the target column with its integer index. All other
     * values are passed through unchanged.
     *
     * @param list one record; its length must equal the input schema's column count
     * @return a new list with the target column converted to an {@link IntWritable}
     * @throws IllegalStateException if the schema is unset, the record length is wrong,
     *                               or the value is not a known category
     */
    @Override
    public List<Writable> map(List<Writable> list) {
        if (this.inputSchema == null) {
            throw new IllegalStateException("Cannot execute transform: input schema has not been set (call setInputSchema first). Transform = " + toString());
        }
        if (list.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + list.size() + ") does not match expected number of elements (schema: " + this.inputSchema.numColumns() + "). Transform = " + toString());
        }
        int targetIdx = getColumnIdx();
        // The output has exactly as many columns as the input; only the value type changes.
        List<Writable> out = new ArrayList<>(list.size());
        int i = 0;
        for (Writable writable : list) {
            if (i == targetIdx) {
                out.add(new IntWritable(indexOfCategory(writable)));
            } else {
                out.add(writable);
            }
            i++;
        }
        return out;
    }

    /**
     * Maps a single categorical value (compared by {@code toString()}) to its integer index.
     *
     * @param obj the categorical value
     * @return the boxed integer index
     * @throws IllegalStateException if the value is not a known category
     */
    @Override
    public Object map(Object obj) {
        return Integer.valueOf(indexOfCategory(obj));
    }

    /**
     * Maps each categorical value of a sequence to its integer index.
     *
     * @param sequence a {@link List} of categorical values, or {@code null}
     * @return a new {@code List<Integer>} of indices, or {@code null} if {@code sequence} is null
     * @throws IllegalArgumentException if {@code sequence} is not a {@link List}
     * @throws IllegalStateException    if any value is not a known category
     */
    @Override
    public Object mapSequence(Object sequence) {
        if (sequence == null) {
            return null;
        }
        if (!(sequence instanceof List)) {
            throw new IllegalArgumentException("Cannot map sequence: expected a List of categorical values, got "
                    + sequence.getClass().getName());
        }
        List<?> values = (List<?>) sequence;
        List<Integer> out = new ArrayList<>(values.size());
        for (Object value : values) {
            out.add(indexOfCategory(value));
        }
        return out;
    }

    /** Looks up the integer encoding of {@code value.toString()}, failing on unknown categories. */
    private int indexOfCategory(Object value) {
        if (this.statesMap == null) {
            throw new IllegalStateException("Cannot convert categorical value to integer value: category lookup is not initialized (call setInputSchema first)");
        }
        String key = value.toString();
        Integer idx = this.statesMap.get(key);
        if (idx == null) {
            throw new IllegalStateException("Cannot convert categorical value to integer value: input value (\"" + key + "\") is not in the list of known categories (state names/categories: " + this.stateNames + ")");
        }
        return idx.intValue();
    }

    /** Two transforms are equal when they target the same column name. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof CategoricalToIntegerTransform)) {
            return false;
        }
        CategoricalToIntegerTransform other = (CategoricalToIntegerTransform) obj;
        // canEqual keeps equality symmetric if a subclass narrows it.
        return other.canEqual(this) && Objects.equals(this.columnName, other.columnName);
    }

    /** Hash consistent with {@link #equals}; null-safe for an unset column name. */
    @Override
    public int hashCode() {
        return Objects.hashCode(this.columnName);
    }

    /** Symmetric-equality hook: subclasses may override to refuse equality with this type. */
    protected boolean canEqual(Object obj) {
        return obj instanceof CategoricalToIntegerTransform;
    }

    @Override
    public String outputColumnName() {
        return this.columnName;
    }

    @Override
    public String[] outputColumnNames() {
        return new String[]{columnName()};
    }

    @Override
    public String[] columnNames() {
        return new String[]{columnName()};
    }

    @Override
    public String columnName() {
        return this.columnName;
    }

    public String getColumnName() {
        return this.columnName;
    }

    public int getColumnIdx() {
        return this.columnIdx;
    }

    public List<String> getStateNames() {
        return this.stateNames;
    }

    public Map<String, Integer> getStatesMap() {
        return this.statesMap;
    }

    public void setColumnName(String str) {
        this.columnName = str;
    }

    public void setColumnIdx(int i) {
        this.columnIdx = i;
    }

    public void setStateNames(List<String> list) {
        this.stateNames = list;
    }

    public void setStatesMap(Map<String, Integer> map) {
        this.statesMap = map;
    }

    @Override
    public String toString() {
        return "CategoricalToIntegerTransform(columnName=" + getColumnName() + ", columnIdx=" + getColumnIdx() + ", stateNames=" + getStateNames() + ", statesMap=" + getStatesMap() + ")";
    }
}
