package org.elasticsearch.search.retriever;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.common.ParsingException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.features.NodeFeature;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.search.retriever.rankdoc.RankDocsQueryBuilder;
import org.elasticsearch.search.vectors.ExactKnnQueryBuilder;
import org.elasticsearch.search.vectors.KnnSearchBuilder;
import org.elasticsearch.search.vectors.QueryVectorBuilder;
import org.elasticsearch.search.vectors.VectorData;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;

/* loaded from: input_file:org/elasticsearch/search/retriever/KnnRetrieverBuilder.class */
/**
 * Retriever builder registered under the name {@code "knn"}.
 *
 * <p>An instance carries a target field, {@code k}, a candidate count, an optional similarity
 * value and exactly one source for the query vector: either a literal {@code float[]} or a
 * {@link QueryVectorBuilder} that is resolved asynchronously in {@link #rewrite}.
 */
public final class KnnRetrieverBuilder extends RetrieverBuilder {
    public static final String NAME = "knn";
    public static final NodeFeature KNN_RETRIEVER_SUPPORTED = new NodeFeature("knn_retriever_supported");
    public static final ParseField FIELD_FIELD = new ParseField("field");
    public static final ParseField K_FIELD = new ParseField("k");
    public static final ParseField NUM_CANDS_FIELD = new ParseField("num_candidates");
    public static final ParseField QUERY_VECTOR_FIELD = new ParseField("query_vector");
    public static final ParseField QUERY_VECTOR_BUILDER_FIELD = new ParseField("query_vector_builder");
    public static final ParseField VECTOR_SIMILARITY = new ParseField("similarity");

    /**
     * Object parser for the retriever body. Constructor argument positions follow the order of the
     * {@code declare*} calls in the static initializer below: field, query_vector,
     * query_vector_builder, k, num_candidates, similarity.
     */
    public static final ConstructingObjectParser<KnnRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
        NAME,
        args -> {
            @SuppressWarnings("unchecked") // position 1 is declared through declareFloatArray
            List<Float> parsedVector = (List<Float>) args[1];
            return new KnnRetrieverBuilder(
                (String) args[0],
                toFloatArray(parsedVector),
                (QueryVectorBuilder) args[2],
                (Integer) args[3],
                (Integer) args[4],
                (Float) args[5]
            );
        }
    );

    static {
        PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD_FIELD);
        PARSER.declareFloatArray(ConstructingObjectParser.optionalConstructorArg(), QUERY_VECTOR_FIELD);
        PARSER.declareNamedObject(
            ConstructingObjectParser.optionalConstructorArg(),
            (parser, context, name) -> parser.namedObject(QueryVectorBuilder.class, name, context),
            QUERY_VECTOR_BUILDER_FIELD
        );
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), K_FIELD);
        PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CANDS_FIELD);
        PARSER.declareFloat(ConstructingObjectParser.optionalConstructorArg(), VECTOR_SIMILARITY);
        RetrieverBuilder.declareBaseParserFields(NAME, PARSER);
    }

    private final String field;
    /** Supplies the query vector; {@code null} while only a {@link QueryVectorBuilder} is known. */
    private final Supplier<float[]> queryVector;
    private final QueryVectorBuilder queryVectorBuilder;
    private final int k;
    private final int numCands;
    private final Float similarity;

    /**
     * Parses a knn retriever from {@code parser}.
     *
     * @param parser  source of the retriever body
     * @param context parsing context, also consulted for cluster feature support
     * @return the parsed builder
     * @throws ParsingException if the cluster does not support {@link #KNN_RETRIEVER_SUPPORTED}
     * @throws IOException      declared for parser failures
     */
    public static KnnRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
        if (context.clusterSupportsFeature(KNN_RETRIEVER_SUPPORTED) == false) {
            throw new ParsingException(parser.getTokenLocation(), "unknown retriever [" + NAME + "]");
        }
        return PARSER.apply(parser, context);
    }

    /**
     * Creates a builder from explicit values.
     *
     * @param field              name of the field to search
     * @param queryVector        literal query vector, or {@code null} when a builder is supplied
     * @param queryVectorBuilder builder producing the query vector, or {@code null} when a literal is supplied
     * @param k                  value stored as {@code k}
     * @param numCands           value stored as {@code num_candidates}
     * @param similarity         optional similarity value, may be {@code null}
     * @throws IllegalArgumentException if both or neither of {@code queryVector} and
     *                                  {@code queryVectorBuilder} are given
     */
    public KnnRetrieverBuilder(
        String field,
        float[] queryVector,
        QueryVectorBuilder queryVectorBuilder,
        int k,
        int numCands,
        Float similarity
    ) {
        if (queryVector == null && queryVectorBuilder == null) {
            throw new IllegalArgumentException(
                Strings.format(
                    "either [%s] or [%s] must be provided",
                    QUERY_VECTOR_FIELD.getPreferredName(),
                    QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
                )
            );
        }
        if (queryVector != null && queryVectorBuilder != null) {
            throw new IllegalArgumentException(
                Strings.format(
                    "only one of [%s] and [%s] must be provided",
                    QUERY_VECTOR_FIELD.getPreferredName(),
                    QUERY_VECTOR_BUILDER_FIELD.getPreferredName()
                )
            );
        }
        this.field = field;
        this.queryVector = queryVector != null ? () -> queryVector : null;
        this.queryVectorBuilder = queryVectorBuilder;
        this.k = k;
        this.numCands = numCands;
        this.similarity = similarity;
    }

    /**
     * Copy constructor used by {@link #rewrite}: takes every setting from {@code original} except
     * the query vector source, which is replaced by the given supplier/builder pair.
     */
    private KnnRetrieverBuilder(KnnRetrieverBuilder original, Supplier<float[]> queryVector, QueryVectorBuilder queryVectorBuilder) {
        this.queryVector = queryVector;
        this.queryVectorBuilder = queryVectorBuilder;
        this.field = original.field;
        this.k = original.k;
        this.numCands = original.numCands;
        this.similarity = original.similarity;
        this.retrieverName = original.retrieverName;
        this.preFilterQueryBuilders = original.preFilterQueryBuilders;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /**
     * Rewrites this retriever in up to two steps.
     *
     * <ol>
     *   <li>If rewriting the pre-filters yields a different list instance, a copy holding the
     *       rewritten filters is returned and nothing else happens in this round.</li>
     *   <li>Otherwise, if a {@link QueryVectorBuilder} is present, an async action is registered
     *       that builds the vector, and a copy is returned whose vector supplier reads the value
     *       once that action has stored it.</li>
     * </ol>
     * With no builder present the call is delegated to the superclass.
     *
     * @param ctx rewrite context used for filter rewriting and async action registration
     * @return a rewritten copy, or the superclass result when no vector builder is present
     * @throws IOException declared by the overridden method
     */
    @Override
    public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException {
        List<QueryBuilder> rewrittenFilters = rewritePreFilters(ctx);
        // Identity comparison on purpose: a different list instance signals that filters changed.
        if (rewrittenFilters != preFilterQueryBuilders) {
            KnnRetrieverBuilder copy = new KnnRetrieverBuilder(this, queryVector, queryVectorBuilder);
            copy.preFilterQueryBuilders = rewrittenFilters;
            return copy;
        }
        if (queryVectorBuilder == null) {
            return super.rewrite(ctx);
        }

        SetOnce<float[]> resolvedVector = new SetOnce<>();
        ctx.registerAsyncAction((client, listener) -> {
            queryVectorBuilder.buildVector(client, listener.delegateFailureAndWrap((delegate, vector) -> {
                resolvedVector.set(vector);
                if (vector == null) {
                    delegate.onFailure(
                        new IllegalArgumentException(
                            Strings.format(
                                "[%s] with name [%s] returned null query_vector",
                                QUERY_VECTOR_BUILDER_FIELD.getPreferredName(),
                                queryVectorBuilder.getWriteableName()
                            )
                        )
                    );
                } else {
                    delegate.onResponse(null);
                }
            }));
        });
        // The builder is dropped from the copy; the supplier reads whatever the async action stored.
        return new KnnRetrieverBuilder(this, () -> resolvedVector.get(), null);
    }

    /**
     * Builds a query over the already materialized {@code rankDocs}, without any explanation
     * sub-queries, filtered by the pre-filters when there are any.
     */
    @Override
    public QueryBuilder topDocsQuery() {
        assert queryVector != null : "query vector must be materialized at this point";
        assert rankDocs != null : "rankDocs should have been materialized by now";
        return applyPreFilters(new RankDocsQueryBuilder(rankDocs, null, true));
    }

    /**
     * Same as {@link #topDocsQuery()} but additionally hands the rank-docs query an
     * {@link ExactKnnQueryBuilder} built from the query vector, field and similarity.
     */
    @Override
    public QueryBuilder explainQuery() {
        assert queryVector != null : "query vector must be materialized at this point";
        assert rankDocs != null : "rankDocs should have been materialized by now";
        QueryBuilder exactKnn = new ExactKnnQueryBuilder(VectorData.fromFloats(queryVector.get()), field, similarity);
        return applyPreFilters(new RankDocsQueryBuilder(rankDocs, new QueryBuilder[] { exactKnn }, true));
    }

    /**
     * Names {@code rankDocsQuery} after this retriever; when pre-filters exist, the query is first
     * placed in the {@code must} clause of a bool query that carries every pre-filter as a filter.
     */
    private QueryBuilder applyPreFilters(RankDocsQueryBuilder rankDocsQuery) {
        if (preFilterQueryBuilders.isEmpty()) {
            return rankDocsQuery.queryName(retrieverName);
        }
        BoolQueryBuilder filtered = new BoolQueryBuilder().must(rankDocsQuery);
        preFilterQueryBuilders.forEach(filtered::filter);
        return filtered.queryName(retrieverName);
    }

    /**
     * Appends a {@link KnnSearchBuilder} equivalent to this retriever to the knn searches already
     * present on {@code searchSourceBuilder}.
     *
     * @param searchSourceBuilder target whose knn search list is extended
     * @param compoundUsed        not consulted by this implementation
     */
    @Override
    public void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed) {
        assert queryVector != null : "query vector must be materialized at this point.";
        KnnSearchBuilder knnSearch = new KnnSearchBuilder(
            field,
            VectorData.fromFloats(queryVector.get()),
            (QueryVectorBuilder) null,
            k,
            numCands,
            similarity
        );
        if (preFilterQueryBuilders != null) {
            knnSearch.addFilterQueries(preFilterQueryBuilders);
        }
        if (retrieverName != null) {
            knnSearch.queryName(retrieverName);
        }
        // Copy before appending so the list previously held by the source builder is not mutated.
        List<KnnSearchBuilder> knnSearches = new ArrayList<>(searchSourceBuilder.knnSearch());
        knnSearches.add(knnSearch);
        searchSourceBuilder.knnSearch(knnSearches);
    }

    /**
     * Writes field, k and num_candidates unconditionally, followed by whichever of query_vector,
     * query_vector_builder and similarity are set.
     */
    @Override
    public void doToXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.field(FIELD_FIELD.getPreferredName(), field);
        builder.field(K_FIELD.getPreferredName(), k);
        builder.field(NUM_CANDS_FIELD.getPreferredName(), numCands);
        if (queryVector != null) {
            builder.field(QUERY_VECTOR_FIELD.getPreferredName(), queryVector.get());
        }
        if (queryVectorBuilder != null) {
            builder.field(QUERY_VECTOR_BUILDER_FIELD.getPreferredName(), queryVectorBuilder);
        }
        if (similarity != null) {
            builder.field(VECTOR_SIMILARITY.getPreferredName(), similarity);
        }
    }

    /**
     * Compares the knn-specific state. {@code obj} is cast without a type check, exactly as
     * before, so callers must only pass a {@link KnnRetrieverBuilder}.
     */
    @Override
    public boolean doEquals(Object obj) {
        KnnRetrieverBuilder other = (KnnRetrieverBuilder) obj;
        return k == other.k
            && numCands == other.numCands
            && Objects.equals(field, other.field)
            && sameVector(queryVector, other.queryVector)
            && Objects.equals(queryVectorBuilder, other.queryVectorBuilder)
            && Objects.equals(similarity, other.similarity);
    }

    @Override
    public int doHashCode() {
        int result = Objects.hash(field, queryVectorBuilder, k, numCands, similarity);
        return 31 * result + Arrays.hashCode(queryVector != null ? queryVector.get() : null);
    }

    /** Two suppliers match when both are absent, or both are present and supply equal arrays. */
    private static boolean sameVector(Supplier<float[]> left, Supplier<float[]> right) {
        if (left == null || right == null) {
            return left == null && right == null;
        }
        return Arrays.equals(left.get(), right.get());
    }

    /** Unboxes a parsed float list into a primitive array; a {@code null} list stays {@code null}. */
    private static float[] toFloatArray(List<Float> values) {
        if (values == null) {
            return null;
        }
        float[] result = new float[values.size()];
        for (int i = 0; i < result.length; i++) {
            result[i] = values.get(i);
        }
        return result;
    }
}
