/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocAndScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.SeededKnnVectorQuery;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.search.TimeLimitingKnnCollectorManager;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;
import org.apache.lucene.search.knn.TopKnnCollectorManager;
import org.apache.lucene.util.Bits;

/**
 * Base class for queries that find the {@code k} nearest neighbours of a target vector in a vector
 * field, optionally restricted to documents matching a filter query.
 *
 * <p>All of the search work happens in {@link #rewrite(IndexSearcher)}:
 * <ul>
 *   <li>every leaf is searched, approximately or exactly depending on how selective the filter is;</li>
 *   <li>the per-leaf hits are merged into a global top {@code k};</li>
 *   <li>the result is returned as a {@link DocAndScoreQuery}, or as a {@link MatchNoDocsQuery}
 *       when nothing matched.</li>
 * </ul>
 */
abstract class AbstractKnnVectorQuery extends Query {
    /** Shared empty result returned by exact search when a leaf has no usable vectors. */
    private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS;

    /**
     * Multiplier applied to {@code sqrt(k * p * (1 - p))}, the standard deviation of a
     * Binomial(k, p) count. It is used when estimating how many of the global top {@code k} hits
     * may fall into a leaf that holds a fraction {@code p} of the documents.
     */
    private static final int LAMBDA = 16;

    /** Name of the vector field being searched; never {@code null}. */
    protected final String field;
    /** Number of nearest neighbours to return; always at least 1. */
    protected final int k;
    /** Optional filter restricting the candidate documents; may be {@code null}. */
    protected final Query filter;
    /** Search strategy supplied by the caller; not validated here, but part of {@link #equals}. */
    protected final KnnSearchStrategy searchStrategy;

    /**
     * @param field the vector field to search; must not be {@code null}
     * @param k the number of nearest neighbours to return; must be at least 1
     * @param filter an optional filter on the candidate documents, or {@code null} for none
     * @param searchStrategy the KNN search strategy (not validated here)
     * @throws NullPointerException if {@code field} is {@code null}
     * @throws IllegalArgumentException if {@code k < 1}
     */
    AbstractKnnVectorQuery(String field, int k, Query filter, KnnSearchStrategy searchStrategy) {
        this.field = Objects.requireNonNull(field, "field");
        if (k < 1) {
            throw new IllegalArgumentException("k must be at least 1, got: " + k);
        }
        this.k = k;
        this.filter = filter;
        this.searchStrategy = searchStrategy;
    }

    /**
     * Executes the nearest-neighbour search over every leaf and rewrites this query into one that
     * matches exactly the resulting top {@code k} documents with their scores.
     *
     * <p>The first pass uses an {@link OptimisticKnnCollectorManager}, so leaves may collect only
     * their estimated share of the top {@code k}. If that pass completed and more than one leaf
     * produced results, a second pass re-searches the leaves whose weakest collected hit is still
     * competitive with the global k-th score. That second pass is seeded with the first-pass hits.
     *
     * @param indexSearcher the searcher whose reader, task executor and timeout are used
     * @return a {@link MatchNoDocsQuery} when nothing matched, otherwise a query over the top hits
     * @throws IOException if reading the index fails
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        IndexReader reader = indexSearcher.getIndexReader();

        final Weight filterWeight;
        if (this.filter == null) {
            filterWeight = null;
        } else {
            Query rewrittenFilter = this.filter.rewrite(indexSearcher);
            if (rewrittenFilter.getClass() == MatchNoDocsQuery.class) {
                // The filter matches nothing, so neither can this query.
                return rewrittenFilter;
            }
            if (rewrittenFilter.getClass() == MatchAllDocsQuery.class) {
                // A match-all filter is equivalent to no filter: take the unfiltered code path.
                filterWeight = null;
            } else {
                // Restrict candidates to documents that actually have a value for the vector field.
                BooleanQuery booleanQuery = new BooleanQuery.Builder()
                        .add(this.filter, BooleanClause.Occur.FILTER)
                        .add(new FieldExistsQuery(this.field), BooleanClause.Occur.FILTER)
                        .build();
                Query rewritten = indexSearcher.rewrite(booleanQuery);
                if (rewritten.getClass() == MatchNoDocsQuery.class) {
                    return rewritten;
                }
                filterWeight = rewritten.createWeight(indexSearcher, ScoreMode.COMPLETE_NO_SCORES, 1.0f);
            }
        }

        // Phase 1: search every leaf, optimistically collecting only each leaf's estimated share.
        KnnCollectorManager knnCollectorManager = this.getKnnCollectorManager(this.k, indexSearcher);
        TimeLimitingKnnCollectorManager firstPassManager = new TimeLimitingKnnCollectorManager(
                new OptimisticKnnCollectorManager(this.k, knnCollectorManager), indexSearcher.getTimeout());
        TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
        List<LeafReaderContext> leafReaderContexts = new ArrayList<>(reader.leaves());
        List<Callable<TopDocs>> tasks = new ArrayList<>(leafReaderContexts.size());
        for (LeafReaderContext context : leafReaderContexts) {
            tasks.add(() -> this.searchLeaf(context, filterWeight, firstPassManager));
        }
        Map<Integer, TopDocs> perLeafResults = new HashMap<>();
        TopDocs topK = this.runSearchTasks(tasks, taskExecutor, perLeafResults, leafReaderContexts);

        // Phase 2 runs only when all of the following hold:
        // - there were results;
        // - more than one leaf was searched;
        // - the collector is optimistic;
        // - the merged totalHits relation is EQUAL_TO.
        if (topK.scoreDocs.length > 0
                && perLeafResults.size() > 1
                && knnCollectorManager.isOptimistic()
                && topK.totalHits.relation() == TotalHits.Relation.EQUAL_TO) {
            float minTopKScore = topK.scoreDocs[topK.scoreDocs.length - 1].score;
            TimeLimitingKnnCollectorManager secondPassManager = new TimeLimitingKnnCollectorManager(
                    new ReentrantKnnCollectorManager(this.getKnnCollectorManager(this.k, indexSearcher), perLeafResults),
                    indexSearcher.getTimeout());
            // Keep only leaves whose weakest phase-1 hit still reaches the global k-th score; those
            // are the leaves that may hold competitive hits beyond their optimistic share.
            Iterator<LeafReaderContext> ctxIter = leafReaderContexts.iterator();
            while (ctxIter.hasNext()) {
                LeafReaderContext ctx = ctxIter.next();
                TopDocs perLeaf = perLeafResults.get(ctx.ord);
                boolean competitive = perLeaf.scoreDocs.length > 0
                        && perLeaf.scoreDocs[perLeaf.scoreDocs.length - 1].score >= minTopKScore;
                if (competitive) {
                    tasks.add(() -> this.searchLeaf(ctx, filterWeight, secondPassManager));
                } else {
                    // Non-competitive leaves keep their phase-1 results in perLeafResults.
                    ctxIter.remove();
                }
            }
            assert leafReaderContexts.size() == tasks.size();
            assert perLeafResults.size() == reader.leaves().size();
            topK = this.runSearchTasks(tasks, taskExecutor, perLeafResults, leafReaderContexts);
        }

        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery();
        }
        return DocAndScoreQuery.createDocAndScoreQuery(reader, topK);
    }

    /**
     * Runs {@code tasks} on the executor and merges the results into the global top {@code k}.
     *
     * <p>The i-th task result is stored in {@code perLeafResults} under the ord of the i-th entry
     * of {@code leafReaderContexts}, overwriting any earlier result for that leaf. {@code tasks} is
     * cleared afterwards so the caller can reuse the list. The merge covers every entry of
     * {@code perLeafResults}, including leaves that were not part of this batch.
     */
    private TopDocs runSearchTasks(List<Callable<TopDocs>> tasks, TaskExecutor taskExecutor, Map<Integer, TopDocs> perLeafResults, List<LeafReaderContext> leafReaderContexts) throws IOException {
        List<TopDocs> taskResults = taskExecutor.invokeAll(tasks);
        for (int i = 0; i < taskResults.size(); ++i) {
            perLeafResults.put(leafReaderContexts.get(i).ord, taskResults.get(i));
        }
        tasks.clear();
        return this.mergeLeafResults(perLeafResults.values().toArray(TopDocs[]::new));
    }

    /**
     * Searches a single leaf and rebases the returned doc ids from leaf-local to index-wide ids by
     * adding {@code ctx.docBase}. The returned {@link ScoreDoc}s are modified in place.
     */
    protected TopDocs searchLeaf(LeafReaderContext ctx, Weight filterWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException {
        TopDocs results = this.getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager);
        if (ctx.docBase > 0) {
            for (ScoreDoc scoreDoc : results.scoreDocs) {
                scoreDoc.doc += ctx.docBase;
            }
        }
        return results;
    }

    /**
     * Computes leaf-local results, choosing between approximate and exact search.
     *
     * <ul>
     *   <li>Without a filter, approximate search runs with no visit limit.</li>
     *   <li>With a filter that matches at most {@code perLeafTopK} documents, exact search is used
     *       directly.</li>
     *   <li>Otherwise approximate search runs with a visit limit of {@code cost + 1}. It falls back
     *       to exact search unless it completed with at least {@code perLeafTopK} hits or the query
     *       timed out.</li>
     * </ul>
     */
    private TopDocs getLeafResults(LeafReaderContext ctx, Weight filterWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException {
        LeafReader reader = ctx.reader();
        Bits liveDocs = reader.getLiveDocs();
        if (filterWeight == null) {
            AcceptDocs acceptDocs = AcceptDocs.fromLiveDocs(liveDocs, reader.maxDoc());
            return this.approximateSearch(ctx, acceptDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager);
        }
        AcceptDocs acceptDocs = AcceptDocs.fromIteratorSupplier(() -> {
            Scorer scorer = filterWeight.scorer(ctx);
            if (scorer == null) {
                // The filter matches nothing in this leaf.
                return DocIdSetIterator.empty();
            }
            return scorer.iterator();
        }, liveDocs, reader.maxDoc());
        int cost = acceptDocs.cost();
        QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();
        int perLeafTopK = perLeafTopKCalculation(this.k, leafProportion(ctx));
        if (cost <= perLeafTopK) {
            return this.exactSearch(ctx, acceptDocs.iterator(), queryTimeout);
        }
        TopDocs results = this.approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager);
        boolean approximateCompleted = results.totalHits.relation() == TotalHits.Relation.EQUAL_TO
                && results.scoreDocs.length >= perLeafTopK;
        boolean timedOut = queryTimeout != null && queryTimeout.shouldExit();
        if (approximateCompleted || timedOut) {
            return results;
        }
        return this.exactSearch(ctx, acceptDocs.iterator(), queryTimeout);
    }

    /**
     * Returns the collector manager used for each pass of the search; subclasses may override.
     * The default collects the top {@code k} hits with a {@link TopKnnCollectorManager}.
     */
    protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return new TopKnnCollectorManager(k, searcher);
    }

    /**
     * Estimates how many of the global top {@code k} hits a leaf holding a fraction
     * {@code leafProportion} of the documents should collect.
     *
     * <p>The estimate is {@code k * p + LAMBDA * sqrt(k * p * (1 - p))}, i.e. the binomial mean
     * plus {@link #LAMBDA} standard deviations. It is never less than 1 for a finite proportion.
     */
    private static int perLeafTopKCalculation(int k, float leafProportion) {
        float expected = (float) k * leafProportion;
        double stdDev = Math.sqrt(expected * (1.0f - leafProportion));
        return (int) Math.max(1.0, expected + LAMBDA * stdDev);
    }

    /**
     * Returns the fraction of its parent reader's documents that this leaf holds.
     *
     * <p>A leaf with no parent is treated as the whole index and gets proportion 1, which avoids an
     * NPE. This is the case when the searcher was built directly on a {@link LeafReader}. An empty
     * parent also yields 1, so the proportion is never NaN.
     *
     * <p>NOTE(review): {@code ctx.parent} is the immediate parent. For nested composite readers this
     * is a sub-reader rather than the whole index; consider the top-level context if that matters.
     */
    private static float leafProportion(LeafReaderContext ctx) {
        int leafMaxDoc = ctx.reader().maxDoc();
        int parentMaxDoc = ctx.parent == null ? leafMaxDoc : ctx.parent.reader().maxDoc();
        if (parentMaxDoc == 0) {
            return 1.0f;
        }
        return (float) leafMaxDoc / (float) parentMaxDoc;
    }

    /**
     * Runs an approximate nearest-neighbour search over one leaf.
     *
     * @param context the leaf to search
     * @param acceptDocs the documents that may be returned
     * @param visitedLimit a visit budget. Callers pass {@code Integer.MAX_VALUE} when unfiltered and
     *     {@code filterCost + 1} when filtered; presumably a limit on visited candidates, to be
     *     confirmed in implementations.
     * @param knnCollectorManager supplies the collector for this leaf
     * @return leaf-local hits. A totalHits relation other than {@code EQUAL_TO} makes the caller
     *     fall back to exact search.
     */
    protected abstract TopDocs approximateSearch(LeafReaderContext context, AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException;

    /**
     * Creates a scorer of this query's target against the leaf's vectors for {@code fieldInfo}.
     * Returning {@code null} means the leaf has nothing to score.
     */
    abstract VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fieldInfo) throws IOException;

    /**
     * Scores every document that has a vector and is accepted by {@code acceptIterator}, keeping
     * the best {@code min(k, acceptIterator.cost())} hits.
     *
     * <p>If {@code queryTimeout} fires mid-scan, the scan stops and the returned totalHits relation
     * is {@code GREATER_THAN_OR_EQUAL_TO}. The totalHits value is always {@code acceptIterator.cost()}.
     *
     * @return leaf-local hits in descending score order, or an empty result when the field has no
     *     vectors in this leaf
     */
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException {
        FieldInfo fi = context.reader().getFieldInfos().fieldInfo(this.field);
        if (fi == null || fi.getVectorDimension() == 0) {
            return NO_RESULTS;
        }
        VectorScorer vectorScorer = this.createVectorScorer(context, fi);
        if (vectorScorer == null) {
            return NO_RESULTS;
        }
        int queueSize = Math.min(this.k, Math.toIntExact(acceptIterator.cost()));
        // Pre-populated queue: top() is always the weakest entry, so each candidate is compared
        // against it and replaces it in place.
        HitQueue queue = new HitQueue(queueSize, true);
        TotalHits.Relation relation = TotalHits.Relation.EQUAL_TO;
        ScoreDoc topDoc = (ScoreDoc) queue.top();
        DocIdSetIterator vectorIterator = vectorScorer.iterator();
        DocIdSetIterator conjunction = ConjunctionDISI.createConjunction(List.of(vectorIterator, acceptIterator), List.of());
        int doc;
        while ((doc = conjunction.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
            if (queryTimeout != null && queryTimeout.shouldExit()) {
                relation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
                break;
            }
            assert vectorIterator.docID() == doc;
            float score = vectorScorer.score();
            if (score > topDoc.score) {
                topDoc.score = score;
                topDoc.doc = doc;
                topDoc = (ScoreDoc) queue.updateTop();
            }
        }
        // Drop entries with negative scores (this also removes any unfilled pre-populated entries).
        while (queue.size() > 0 && ((ScoreDoc) queue.top()).score < 0.0f) {
            queue.pop();
        }
        // The queue pops weakest-first, so fill the array from the end to get descending order.
        ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
        for (int i = topScoreDocs.length - 1; i >= 0; --i) {
            topScoreDocs[i] = (ScoreDoc) queue.pop();
        }
        TotalHits totalHits = new TotalHits(acceptIterator.cost(), relation);
        return new TopDocs(totalHits, topScoreDocs);
    }

    /** Merges per-leaf hits (with index-wide doc ids) into the global top {@code k}. */
    protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
        return TopDocs.merge(this.k, perLeafResults);
    }

    @Override
    public void visit(QueryVisitor visitor) {
        if (visitor.acceptField(this.field)) {
            visitor.visitLeaf(this);
        }
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o;
        return this.k == that.k
                && Objects.equals(this.field, that.field)
                && Objects.equals(this.filter, that.filter)
                && Objects.equals(this.searchStrategy, that.searchStrategy);
    }

    /** Hashes every field compared by {@link #equals(Object)}, including the search strategy. */
    @Override
    public int hashCode() {
        return Objects.hash(this.field, this.k, this.filter, this.searchStrategy);
    }

    /** Returns the name of the vector field being searched. */
    public String getField() {
        return this.field;
    }

    /** Returns the number of nearest neighbours this query returns. */
    public int getK() {
        return this.k;
    }

    /** Returns the filter query, or {@code null} if there is none. */
    public Query getFilter() {
        return this.filter;
    }

    /** Returns the search strategy supplied at construction. */
    public KnnSearchStrategy getSearchStrategy() {
        return this.searchStrategy;
    }

    /**
     * Collector manager for the first search pass. When the delegate supports optimistic
     * collection, each leaf collects only its estimated share of the top {@code k} (see
     * {@code perLeafTopKCalculation}) rather than the full {@code k}. Otherwise it defers to the
     * delegate's regular collector.
     */
    static class OptimisticKnnCollectorManager implements KnnCollectorManager {
        private final int k;
        private final KnnCollectorManager delegate;

        OptimisticKnnCollectorManager(int k, KnnCollectorManager delegate) {
            this.k = k;
            this.delegate = delegate;
        }

        @Override
        public KnnCollector newCollector(int visitedLimit, KnnSearchStrategy searchStrategy, LeafReaderContext context) throws IOException {
            if (!this.delegate.isOptimistic()) {
                return this.delegate.newCollector(visitedLimit, searchStrategy, context);
            }
            int perLeafTopK = perLeafTopKCalculation(this.k, leafProportion(context));
            assert perLeafTopK > 0;
            return this.delegate.newOptimisticCollector(visitedLimit, searchStrategy, context, perLeafTopK);
        }
    }

    /**
     * Collector manager for the second search pass. It seeds each leaf's search with that leaf's
     * first-pass hits, looked up in {@code perLeafResults} by leaf ord. This is done by wrapping
     * the strategy in {@link KnnSearchStrategy.Seeded}.
     */
    private class ReentrantKnnCollectorManager implements KnnCollectorManager {
        private final KnnCollectorManager knnCollectorManager;
        private final Map<Integer, TopDocs> perLeafResults;

        ReentrantKnnCollectorManager(KnnCollectorManager knnCollectorManager, Map<Integer, TopDocs> perLeafResults) {
            this.knnCollectorManager = knnCollectorManager;
            this.perLeafResults = perLeafResults;
        }

        @Override
        public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException {
            TopDocs seedTopDocs = this.perLeafResults.get(ctx.ord);
            VectorScorer scorer = AbstractKnnVectorQuery.this.createVectorScorer(ctx, ctx.reader().getFieldInfos().fieldInfo(AbstractKnnVectorQuery.this.field));
            if (seedTopDocs.totalHits.value() == 0L || scorer == null) {
                // Not expected: only leaves with first-pass hits are re-searched.
                assert false : "second pass reached a leaf without seed hits or vectors";
                return this.knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx);
            }
            DocIdSetIterator vectorIterator = scorer.iterator();
            if (vectorIterator instanceof IndexedDISI) {
                // Sparse field: expose the index of each vector alongside its doc id.
                vectorIterator = IndexedDISI.asDocIndexIterator((IndexedDISI) vectorIterator);
            }
            if (vectorIterator instanceof KnnVectorValues.DocIndexIterator) {
                KnnVectorValues.DocIndexIterator indexIterator = (KnnVectorValues.DocIndexIterator) vectorIterator;
                SeededKnnVectorQuery.MappedDISI seedDocs = new SeededKnnVectorQuery.MappedDISI(indexIterator, new SeededKnnVectorQuery.TopDocsDISI(seedTopDocs, ctx));
                return this.knnCollectorManager.newCollector(visitLimit, new KnnSearchStrategy.Seeded(seedDocs, seedTopDocs.scoreDocs.length, searchStrategy), ctx);
            }
            // Seeds cannot be mapped onto this iterator type; fall back to an unseeded collector.
            assert false : "vector iterator cannot map seed docs: " + vectorIterator.getClass();
            return this.knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx);
        }
    }
}

