package org.apache.spark.ml.feature;

import org.apache.spark.annotation.Experimental;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.feature.VectorIndexerParams;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.Param;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.param.shared.HasInputCol;
import org.apache.spark.ml.param.shared.HasOutputCol;
import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.ml.util.SchemaUtils$;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.util.collection.OpenHashSet;
import scala.Array$;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.immutable.Map;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: VectorIndexer.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005\rf\u0001B\u0001\u0003\u00015\u0011QBV3di>\u0014\u0018J\u001c3fq\u0016\u0014(BA\u0002\u0005\u0003\u001d1W-\u0019;ve\u0016T!!\u0002\u0004\u0002\u00055d'BA\u0004\t\u0003\u0015\u0019\b/\u0019:l\u0015\tI!\"\u0001\u0004ba\u0006\u001c\u0007.\u001a\u0006\u0002\u0017\u0005\u0019qN]4\u0004\u0001M\u0019\u0001A\u0004\f\u0011\u0007=\u0001\"#D\u0001\u0005\u0013\t\tBAA\u0005FgRLW.\u0019;peB\u00111\u0003F\u0007\u0002\u0005%\u0011QC\u0001\u0002\u0013-\u0016\u001cGo\u001c:J]\u0012,\u00070\u001a:N_\u0012,G\u000e\u0005\u0002\u0014/%\u0011\u0001D\u0001\u0002\u0014-\u0016\u001cGo\u001c:J]\u0012,\u00070\u001a:QCJ\fWn\u001d\u0005\t5\u0001\u0011)\u0019!C!7\u0005\u0019Q/\u001b3\u0016\u0003q\u0001\"!H\u0012\u000f\u0005y\tS\"A\u0010\u000b\u0003\u0001\nQa]2bY\u0006L!AI\u0010\u0002\rA\u0013X\rZ3g\u0013\t!SE\u0001\u0004TiJLgn\u001a\u0006\u0003E}A\u0001b\n\u0001\u0003\u0002\u0003\u0006I\u0001H\u0001\u0005k&$\u0007\u0005C\u0003*\u0001\u0011\u0005!&\u0001\u0004=S:LGO\u0010\u000b\u0003W1\u0002\"a\u0005\u0001\t\u000biA\u0003\u0019\u0001\u000f\t\u000b%\u0002A\u0011\u0001\u0018\u0015\u0003-BQ\u0001\r\u0001\u0005\u0002E\n\u0001c]3u\u001b\u0006D8)\u0019;fO>\u0014\u0018.Z:\u0015\u0005I\u001aT\"\u0001\u0001\t\u000bQz\u0003\u0019A\u001b\u0002\u000bY\fG.^3\u0011\u0005y1\u0014BA\u001c 
\u0005\rIe\u000e\u001e\u0005\u0006s\u0001!\tAO\u0001\fg\u0016$\u0018J\u001c9vi\u000e{G\u000e\u0006\u00023w!)A\u0007\u000fa\u00019!)Q\b\u0001C\u0001}\u0005a1/\u001a;PkR\u0004X\u000f^\"pYR\u0011!g\u0010\u0005\u0006iq\u0002\r\u0001\b\u0005\u0006\u0003\u0002!\tEQ\u0001\u0004M&$HC\u0001\nD\u0011\u0015!\u0005\t1\u0001F\u0003\u001d!\u0017\r^1tKR\u0004\"AR%\u000e\u0003\u001dS!\u0001\u0013\u0004\u0002\u0007M\fH.\u0003\u0002K\u000f\nIA)\u0019;b\rJ\fW.\u001a\u0005\u0006\u0019\u0002!\t%T\u0001\u0010iJ\fgn\u001d4pe6\u001c6\r[3nCR\u0011a\n\u0016\t\u0003\u001fJk\u0011\u0001\u0015\u0006\u0003#\u001e\u000bQ\u0001^=qKNL!a\u0015)\u0003\u0015M#(/^2u)f\u0004X\rC\u0003V\u0017\u0002\u0007a*\u0001\u0004tG\",W.\u0019\u0005\u0006/\u0002!\t\u0005W\u0001\u0005G>\u0004\u0018\u0010\u0006\u0002,3\")!L\u0016a\u00017\u0006)Q\r\u001f;sCB\u0011AlX\u0007\u0002;*\u0011a\fB\u0001\u0006a\u0006\u0014\u0018-\\\u0005\u0003Av\u0013\u0001\u0002U1sC6l\u0015\r\u001d\u0015\u0003\u0001\t\u0004\"a\u00194\u000e\u0003\u0011T!!\u001a\u0004\u0002\u0015\u0005tgn\u001c;bi&|g.\u0003\u0002hI\naQ\t\u001f9fe&lWM\u001c;bY\u001e)\u0011N\u0001E\u0005U\u0006ia+Z2u_JLe\u000eZ3yKJ\u0004\"aE6\u0007\u000b\u0005\u0011\u0001\u0012\u00027\u0014\u0007-l\u0007\u000f\u0005\u0002\u001f]&\u0011qn\b\u0002\u0007\u0003:L(+\u001a4\u0011\u0005y\t\u0018B\u0001: 
\u00051\u0019VM]5bY&T\u0018M\u00197f\u0011\u0015I3\u000e\"\u0001u)\u0005Qg\u0001\u0002<l\u0001]\u0014QbQ1uK\u001e|'/_*uCR\u001c8cA;na\"A\u00110\u001eBC\u0002\u0013%!0A\u0006ok64U-\u0019;ve\u0016\u001cX#A\u001b\t\u0011q,(\u0011!Q\u0001\nU\nAB\\;n\r\u0016\fG/\u001e:fg\u0002B\u0001B`;\u0003\u0006\u0004%IA_\u0001\u000e[\u0006D8)\u0019;fO>\u0014\u0018.Z:\t\u0013\u0005\u0005QO!A!\u0002\u0013)\u0014AD7bq\u000e\u000bG/Z4pe&,7\u000f\t\u0005\u0007SU$\t!!\u0002\u0015\r\u0005\u001d\u00111BA\u0007!\r\tI!^\u0007\u0002W\"1\u00110a\u0001A\u0002UBaA`A\u0002\u0001\u0004)\u0004\"CA\tk\n\u0007I\u0011BA\n\u0003A1W-\u0019;ve\u00164\u0016\r\\;f'\u0016$8/\u0006\u0002\u0002\u0016A)a$a\u0006\u0002\u001c%\u0019\u0011\u0011D\u0010\u0003\u000b\u0005\u0013(/Y=\u0011\r\u0005u\u0011qEA\u0016\u001b\t\tyB\u0003\u0003\u0002\"\u0005\r\u0012AC2pY2,7\r^5p]*\u0019\u0011Q\u0005\u0004\u0002\tU$\u0018\u000e\\\u0005\u0005\u0003S\tyBA\u0006Pa\u0016t\u0007*Y:i'\u0016$\bc\u0001\u0010\u0002.%\u0019\u0011qF\u0010\u0003\r\u0011{WO\u00197f\u0011!\t\u0019$\u001eQ\u0001\n\u0005U\u0011!\u00054fCR,(/\u001a,bYV,7+\u001a;tA!9\u0011qG;\u0005\u0002\u0005e\u0012!B7fe\u001e,G\u0003BA\u0004\u0003wA\u0001\"!\u0010\u00026\u0001\u0007\u0011qA\u0001\u0006_RDWM\u001d\u0005\b\u0003\u0003*H\u0011AA\"\u0003%\tG\r\u001a,fGR|'\u000f\u0006\u0003\u0002F\u0005-\u0003c\u0001\u0010\u0002H%\u0019\u0011\u0011J\u0010\u0003\tUs\u0017\u000e\u001e\u0005\t\u0003\u001b\ny\u00041\u0001\u0002P\u0005\ta\u000f\u0005\u0003\u0002R\u0005mSBAA*\u0015\u0011\t)&a\u0016\u0002\r1Lg.\u00197h\u0015\r\tIFB\u0001\u0006[2d\u0017NY\u0005\u0005\u0003;\n\u0019F\u0001\u0004WK\u000e$xN\u001d\u0005\b\u0003C*H\u0011AA2\u0003=9W\r^\"bi\u0016<wN]=NCB\u001cXCAA3!\u0019i\u0012qM\u001b\u0002l%\u0019\u0011\u0011N\u0013\u0003\u00075\u000b\u0007\u000f\u0005\u0004\u001e\u0003O\nY#\u000e\u0005\b\u0003_*H\u0011BA9\u00039\tG\r\u001a#f]N,g+Z2u_J$B!!\u0012\u0002t!A\u0011QOA7\u0001\u0004\t9(\u0001\u0002emB!\u0011\u0011KA=\u0013\u0011\tY(a\u0015\u0003\u0017\u0011+gn]3WK\u000e$xN\u001d\u0
005\b\u0003\u007f*H\u0011BAA\u0003=\tG\rZ*qCJ\u001cXMV3di>\u0014H\u0003BA#\u0003\u0007C\u0001\"!\"\u0002~\u0001\u0007\u0011qQ\u0001\u0003gZ\u0004B!!\u0015\u0002\n&!\u00111RA*\u00051\u0019\u0006/\u0019:tKZ+7\r^8s\u0011%\tyi[A\u0001\n\u0013\t\t*A\u0006sK\u0006$'+Z:pYZ,GCAAJ!\u0011\t)*a(\u000e\u0005\u0005]%\u0002BAM\u00037\u000bA\u0001\\1oO*\u0011\u0011QT\u0001\u0005U\u00064\u0018-\u0003\u0003\u0002\"\u0006]%AB(cU\u0016\u001cG\u000f")
@Experimental
/* loaded from: input_file:org/apache/spark/ml/feature/VectorIndexer.class */
public class VectorIndexer extends Estimator<VectorIndexerModel> implements VectorIndexerParams {
    /** Identifier of this estimator; assigned once in the {@code VectorIndexer(String)} constructor. */
    private final String uid;

    // The three param holders below are intentionally NOT final: each one is written by its
    // trait-generated `..._setter_$..._$eq` method in this class, and Java rejects an assignment
    // to a final field anywhere outside a constructor or initializer.

    /** Param object handed out by {@code maxCategories()}. */
    private IntParam maxCategories;

    /** Param object handed out by {@code outputCol()}. */
    private Param<String> outputCol;

    /** Param object handed out by {@code inputCol()}. */
    private Param<String> inputCol;

    /* compiled from: VectorIndexer.scala */
    /* loaded from: input_file:org/apache/spark/ml/feature/VectorIndexer$CategoryStats.class */
    public static class CategoryStats implements Serializable {
        /** Vector length this accumulator accepts; {@code addVector} requires {@code vector.size()} to equal it. */
        private final int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures;
        /** Threshold each per-feature set's size is compared against before another value is added. */
        private final int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories;
        /** One set per feature index (array length is the numFeatures constructor argument), holding boxed doubles. */
        private final OpenHashSet<Object>[] featureValueSets;

        /**
         * Accessor for the name-mangled {@code numFeatures} field.
         *
         * @return the feature count passed as the first constructor argument
         */
        public int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures() {
            return org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures;
        }

        /**
         * Accessor for the name-mangled {@code maxCategories} field.
         *
         * @return the threshold passed as the second constructor argument
         */
        public int org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories() {
            return org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories;
        }

        /**
         * Accessor for the per-feature value sets.
         *
         * @return the backing array itself (not a copy); element {@code j} belongs to feature {@code j}
         */
        private OpenHashSet<Object>[] featureValueSets() {
            return featureValueSets;
        }

        /**
         * Folds another accumulator into this one and returns this instance.
         *
         * <p>The per-feature sets of both accumulators are zipped position by position, and each
         * resulting pair is passed to the {@code merge$1} closure.
         * NOTE(review): the closure body is not visible in this file; presumably it copies the
         * other side's values into this side's set - confirm.
         *
         * @param categoryStats the accumulator whose sets are paired with this one's
         * @return {@code this}
         */
        public CategoryStats merge(CategoryStats categoryStats) {
            // Step 1: pair up set j of this instance with set j of the other instance.
            Object[] pairedSets = (Object[]) Predef$.MODULE$.refArrayOps(featureValueSets()).zip(
                    Predef$.MODULE$.wrapRefArray(categoryStats.featureValueSets()),
                    Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));

            // Step 2: let the closure process every (thisSet, otherSet) pair.
            Predef$.MODULE$.refArrayOps(pairedSets).foreach(new VectorIndexer$CategoryStats$$anonfun$merge$1(this));
            return this;
        }

        /**
         * Records the values of one vector in the per-feature value sets.
         *
         * <p>The vector's size must equal this accumulator's feature count; otherwise
         * {@code Predef.require} fails with the message produced by the {@code addVector$1} closure.
         * Dense and sparse vectors are routed to their dedicated helpers.
         *
         * @param vector the vector to record
         * @throws MatchError if {@code vector} is neither a {@link DenseVector} nor a {@link SparseVector}
         */
        public void addVector(Vector vector) {
            Predef$.MODULE$.require(
                    vector.size() == org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures(),
                    new VectorIndexer$CategoryStats$$anonfun$addVector$1(this, vector));

            if (vector instanceof DenseVector) {
                addDenseVector((DenseVector) vector);
                return;
            }
            if (vector instanceof SparseVector) {
                addSparseVector((SparseVector) vector);
                return;
            }
            // Mirrors a non-exhaustive Scala pattern match: any other Vector subtype is rejected.
            throw new MatchError(vector);
        }

        /**
         * Builds a Scala immutable map from the collected per-feature value sets.
         *
         * <p>Each set is paired with its array index, the pairs are filtered by the
         * {@code getCategoryMaps$1} closure, transformed into tuples by the {@code getCategoryMaps$2}
         * closure, and the tuples are collected into a map.
         * NOTE(review): neither closure body is visible here; presumably the filter drops features
         * with more than maxCategories distinct values and the mapper yields
         * (featureIndex, value-to-categoryIndex map) - confirm.
         *
         * @return the map assembled from the tuples the mapping closure produces
         */
        public Map<Object, Map<Object, Object>> getCategoryMaps() {
            // (set, index) pairs for every feature.
            Object[] indexedSets = (Object[]) Predef$.MODULE$.refArrayOps(featureValueSets())
                    .zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));

            // Pairs accepted by the filter closure.
            Object[] keptSets = (Object[]) Predef$.MODULE$.refArrayOps(indexedSets)
                    .filter(new VectorIndexer$CategoryStats$$anonfun$getCategoryMaps$1(this));

            // One tuple per surviving pair.
            Object[] mapEntries = (Object[]) Predef$.MODULE$.refArrayOps(keptSets)
                    .map(new VectorIndexer$CategoryStats$$anonfun$getCategoryMaps$2(this),
                            Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));

            return Predef$.MODULE$.refArrayOps(mapEntries).toMap(Predef$.MODULE$.conforms());
        }

        /**
         * Adds every entry of a dense vector to the set belonging to its position.
         *
         * <p>A value is only added while the set's size is still {@code <= maxCategories}, so a set
         * stops growing once it holds more than {@code maxCategories} entries.
         *
         * @param denseVector the vector whose entries are recorded
         */
        private void addDenseVector(DenseVector denseVector) {
            final int numEntries = denseVector.size();
            int featureIndex = 0;
            while (featureIndex < numEntries) {
                OpenHashSet<Object> seenValues = featureValueSets()[featureIndex];
                if (seenValues.size() <= org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories()) {
                    seenValues.add(BoxesRunTime.boxToDouble(denseVector.apply(featureIndex)));
                }
                featureIndex++;
            }
        }

        /**
         * Adds every entry of a sparse vector, including the implicit zeros, to the set belonging
         * to its position.
         *
         * <p>Walks all positions {@code 0..size-1} while keeping a cursor into the stored entries;
         * a position that matches the index under the cursor takes the stored value and advances
         * the cursor, every other position contributes {@code 0.0}. As in the dense case, a value
         * is only added while the set's size is still {@code <= maxCategories}.
         * NOTE(review): this walk assumes {@code indices()} is sorted ascending - confirm.
         *
         * @param sparseVector the vector whose entries are recorded
         */
        private void addSparseVector(SparseVector sparseVector) {
            int storedCursor = 0; // position within indices()/values()
            final int numEntries = sparseVector.size();
            for (int featureIndex = 0; featureIndex < numEntries; featureIndex++) {
                final double featureValue;
                if (storedCursor < sparseVector.indices().length
                        && featureIndex == sparseVector.indices()[storedCursor]) {
                    featureValue = sparseVector.values()[storedCursor];
                    storedCursor++;
                } else {
                    featureValue = 0.0d;
                }
                if (featureValueSets()[featureIndex].size()
                        <= org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories()) {
                    featureValueSets()[featureIndex].add(BoxesRunTime.boxToDouble(featureValue));
                }
            }
        }

        public CategoryStats(int i, int i2) {
            this.org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$numFeatures = i;
            this.org$apache$spark$ml$feature$VectorIndexer$CategoryStats$$maxCategories = i2;
            this.featureValueSets = (OpenHashSet[]) Array$.MODULE$.fill(i, new VectorIndexer$CategoryStats$$anonfun$5(this), ClassTag$.MODULE$.apply(OpenHashSet.class));
        }
    }

    /**
     * Accessor for the {@code maxCategories} param object.
     *
     * @return the {@link IntParam} currently held in the {@code maxCategories} field
     */
    @Override // org.apache.spark.ml.feature.VectorIndexerParams
    public IntParam maxCategories() {
        return maxCategories;
    }

    /**
     * Trait setter declared by {@code VectorIndexerParams}: stores the supplied param object in
     * the {@code maxCategories} field.
     *
     * @param intParam the param object to store
     */
    @Override // org.apache.spark.ml.feature.VectorIndexerParams
    public void org$apache$spark$ml$feature$VectorIndexerParams$_setter_$maxCategories_$eq(IntParam intParam) {
        maxCategories = intParam;
    }

    /**
     * Reads the configured maximum number of categories.
     *
     * @return whatever {@code VectorIndexerParams.Cclass.getMaxCategories} computes for this instance
     */
    @Override // org.apache.spark.ml.feature.VectorIndexerParams
    public int getMaxCategories() {
        final int configuredMax = VectorIndexerParams.Cclass.getMaxCategories(this);
        return configuredMax;
    }

    /**
     * Accessor for the {@code outputCol} param object.
     *
     * @return the {@code Param<String>} currently held in the {@code outputCol} field
     */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final Param<String> outputCol() {
        return outputCol;
    }

    /**
     * Trait setter declared by {@code HasOutputCol}: stores the supplied param object in the
     * {@code outputCol} field.
     *
     * @param param the param object to store (raw type, as dictated by the overridden signature)
     */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final void org$apache$spark$ml$param$shared$HasOutputCol$_setter_$outputCol_$eq(Param param) {
        outputCol = param;
    }

    /**
     * Reads the configured output column name.
     *
     * @return whatever {@code HasOutputCol.Cclass.getOutputCol} computes for this instance
     */
    @Override // org.apache.spark.ml.param.shared.HasOutputCol
    public final String getOutputCol() {
        final String columnName = HasOutputCol.Cclass.getOutputCol(this);
        return columnName;
    }

    /**
     * Accessor for the {@code inputCol} param object.
     *
     * @return the {@code Param<String>} currently held in the {@code inputCol} field
     */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final Param<String> inputCol() {
        return inputCol;
    }

    /**
     * Trait setter declared by {@code HasInputCol}: stores the supplied param object in the
     * {@code inputCol} field.
     *
     * @param param the param object to store (raw type, as dictated by the overridden signature)
     */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final void org$apache$spark$ml$param$shared$HasInputCol$_setter_$inputCol_$eq(Param param) {
        inputCol = param;
    }

    /**
     * Reads the configured input column name.
     *
     * @return whatever {@code HasInputCol.Cclass.getInputCol} computes for this instance
     */
    @Override // org.apache.spark.ml.param.shared.HasInputCol
    public final String getInputCol() {
        final String columnName = HasInputCol.Cclass.getInputCol(this);
        return columnName;
    }

    /**
     * Accessor for this estimator's identifier.
     *
     * @return the uid string supplied to (or generated for) the constructor
     */
    @Override // org.apache.spark.ml.util.Identifiable
    public String uid() {
        return uid;
    }

    /**
     * Sets the {@code maxCategories} param to the given value.
     *
     * <p>The decompiled form cast the boxed {@code Integer} to {@code IntParam} and the param to
     * {@code Param<IntParam>}; an {@code Integer} can never be an {@code IntParam}, so those casts
     * were removed and the arguments are passed as they are.
     * NOTE(review): relies on {@code IntParam} being a {@code Param<Object>} in the Java view so
     * that {@code set}'s type parameter is inferred as {@code Object} - confirm against Spark.
     *
     * @param i the new maximum number of categories
     * @return the result of {@code set(...)} cast to {@code VectorIndexer} (presumably this
     *         instance, for call chaining - confirm)
     */
    public VectorIndexer setMaxCategories(int i) {
        return (VectorIndexer) set(maxCategories(), BoxesRunTime.boxToInteger(i));
    }

    /**
     * Sets the {@code inputCol} param to the given column name.
     *
     * <p>The decompiled form cast the {@code String} value to {@code Param<String>} and the param
     * to {@code Param<Param<String>>}; a {@code String} can never be a {@code Param}, so those
     * casts were removed and the arguments are passed with their real types.
     *
     * @param str name of the input column
     * @return the result of {@code set(...)} cast to {@code VectorIndexer} (presumably this
     *         instance, for call chaining - confirm)
     */
    public VectorIndexer setInputCol(String str) {
        return (VectorIndexer) set(inputCol(), str);
    }

    /**
     * Sets the {@code outputCol} param to the given column name.
     *
     * <p>The decompiled form cast the {@code String} value to {@code Param<String>} and the param
     * to {@code Param<Param<String>>}; a {@code String} can never be a {@code Param}, so those
     * casts were removed and the arguments are passed with their real types.
     *
     * @param str name of the output column
     * @return the result of {@code set(...)} cast to {@code VectorIndexer} (presumably this
     *         instance, for call chaining - confirm)
     */
    public VectorIndexer setOutputCol(String str) {
        return (VectorIndexer) set(outputCol(), str);
    }

    /**
     * Fits a {@link VectorIndexerModel} to the input column of the given data frame.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>validate the schema via the two-argument {@code transformSchema} overload;</li>
     *   <li>take one row of the input column, require that it exists, and use the size of its
     *       vector as the feature count;</li>
     *   <li>map the input column to an RDD of vectors, build {@link CategoryStats} per partition
     *       and reduce them to a single instance;</li>
     *   <li>wrap the resulting category maps in a model whose parent is this estimator and copy
     *       param values onto it.</li>
     * </ol>
     * NOTE(review): the closures {@code $anonfun$2/3/4} are not visible in this file; presumably
     * they extract the vector from each row, accumulate one CategoryStats per partition, and merge
     * two CategoryStats respectively - confirm.
     *
     * @param dataFrame data set containing the input vector column
     * @return the fitted model
     */
    /* JADX WARN: Can't rename method to resolve collision */
    @Override // org.apache.spark.ml.Estimator
    public VectorIndexerModel fit(DataFrame dataFrame) {
        transformSchema(dataFrame.schema(), true);

        // Peek at a single row to learn how many features the vectors carry.
        Row[] firstRows = dataFrame.select((String) $(inputCol()), Predef$.MODULE$.wrapRefArray(new String[0])).take(1);
        Predef$.MODULE$.require(firstRows.length == 1, new VectorIndexer$$anonfun$fit$1(this));
        Vector firstVector = (Vector) firstRows[0].getAs(0);
        int numFeatures = firstVector.size();

        // Input column as an RDD of vectors.
        RDD vectors = dataFrame.select((String) $(inputCol()), Predef$.MODULE$.wrapRefArray(new String[0]))
                .map(new VectorIndexer$$anonfun$2(this), ClassTag$.MODULE$.apply(Vector.class));

        // Per-partition statistics, reduced to one instance.
        int maxCats = BoxesRunTime.unboxToInt($(maxCategories()));
        CategoryStats mergedStats = (CategoryStats) vectors
                .mapPartitions(
                        new VectorIndexer$$anonfun$3(this, numFeatures, maxCats),
                        vectors.mapPartitions$default$2(),
                        ClassTag$.MODULE$.apply(CategoryStats.class))
                .reduce(new VectorIndexer$$anonfun$4(this));
        Map<Object, Map<Object, Object>> categoryMaps = mergedStats.getCategoryMaps();

        return (VectorIndexerModel) copyValues(
                new VectorIndexerModel(uid(), numFeatures, categoryMaps).setParent(this),
                copyValues$default$2());
    }

    /**
     * Validates the input schema and derives the output schema.
     *
     * <p>Requires both {@code inputCol} and {@code outputCol} to be defined, checks via
     * {@code SchemaUtils.checkColumnType} that the input column has the vector type, and returns
     * the schema produced by {@code SchemaUtils.appendColumn} for the output column with the same
     * vector type.
     *
     * @param structType schema of the incoming data set
     * @return the schema returned by {@code SchemaUtils.appendColumn}
     */
    @Override // org.apache.spark.ml.PipelineStage
    public StructType transformSchema(StructType structType) {
        DataType vectorType = new VectorUDT();

        // Both column params must be set before their values are read below.
        Predef$.MODULE$.require(isDefined(inputCol()), new VectorIndexer$$anonfun$transformSchema$2(this));
        Predef$.MODULE$.require(isDefined(outputCol()), new VectorIndexer$$anonfun$transformSchema$3(this));

        String inputName = (String) $(inputCol());
        SchemaUtils$.MODULE$.checkColumnType(structType, inputName, vectorType);

        String outputName = (String) $(outputCol());
        return SchemaUtils$.MODULE$.appendColumn(structType, outputName, vectorType);
    }

    /**
     * Creates a copy of this estimator by delegating to {@code defaultCopy}.
     *
     * @param paramMap extra params handed to {@code defaultCopy}
     * @return the object {@code defaultCopy} produced, cast to {@code VectorIndexer}
     */
    @Override // org.apache.spark.ml.Estimator, org.apache.spark.ml.PipelineStage, org.apache.spark.ml.param.Params
    public VectorIndexer copy(ParamMap paramMap) {
        Object duplicate = defaultCopy(paramMap);
        return (VectorIndexer) duplicate;
    }

    /**
     * Creates an estimator with the given identifier and runs the trait initializers.
     *
     * @param str the uid to assign to this instance
     */
    public VectorIndexer(String str) {
        this.uid = str;
        // Trait initializers run after uid is assigned, in this order: HasInputCol, HasOutputCol,
        // VectorIndexerParams.
        // NOTE(review): presumably each $init$ populates the matching param field through the
        // `_setter_$..._$eq` methods of this class - their bodies are not visible here; confirm.
        HasInputCol.Cclass.$init$(this);
        HasOutputCol.Cclass.$init$(this);
        VectorIndexerParams.Cclass.$init$(this);
    }

    /**
     * Creates an estimator whose uid is obtained from {@code Identifiable.randomUID("vecIdx")}.
     */
    public VectorIndexer() {
        this(Identifiable$.MODULE$.randomUID("vecIdx"));
    }
}
