/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.search;

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.QueryTimeout;
import org.apache.lucene.search.AbstractKnnVectorQuery;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.TopScoreDocCollectorManager;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.knn.KnnCollectorManager;
import org.apache.lucene.search.knn.KnnSearchStrategy;

/**
 * A KNN vector query that first runs a separate {@code seed} query and hands its top-scoring hits
 * to the wrapped {@link AbstractKnnVectorQuery} ({@link #delegate}) as a
 * {@link KnnSearchStrategy.Seeded} search strategy for the approximate search.
 *
 * <p>Before execution the seed query is combined with a {@link FieldExistsQuery} on the vector
 * field and with this query's filter (if any); see {@link #createSeedWeight(IndexSearcher)}. Exact
 * search, per-leaf result merging, vector scoring, query visiting and the field/k/filter getters
 * are all forwarded to the delegate.
 */
public class SeededKnnVectorQuery
extends AbstractKnnVectorQuery {
    /** The user-supplied query whose top hits seed the approximate search; never {@code null}. */
    final Query seed;
    /**
     * Weight of the combined seed query. It may be {@code null}, in which case {@link #rewrite}
     * builds it and returns a copy of this query that carries it.
     */
    final Weight seedWeight;
    /** The float- or byte-vector KNN query that performs the actual vector search. */
    final AbstractKnnVectorQuery delegate;

    /**
     * Wraps a float-vector KNN query so that its approximate search is seeded by {@code seed}. The
     * seed weight is created later, during {@link #rewrite}.
     *
     * @param knnQuery the float-vector query to delegate to
     * @param seed the query whose top hits seed the search; must not be {@code null}
     * @return a new seeded query with no seed weight yet
     */
    public static SeededKnnVectorQuery fromFloatQuery(KnnFloatVectorQuery knnQuery, Query seed) {
        return new SeededKnnVectorQuery(knnQuery, seed, null);
    }

    /**
     * Wraps a byte-vector KNN query so that its approximate search is seeded by {@code seed}. The
     * seed weight is created later, during {@link #rewrite}.
     *
     * @param knnQuery the byte-vector query to delegate to
     * @param seed the query whose top hits seed the search; must not be {@code null}
     * @return a new seeded query with no seed weight yet
     */
    public static SeededKnnVectorQuery fromByteQuery(KnnByteVectorQuery knnQuery, Query seed) {
        return new SeededKnnVectorQuery(knnQuery, seed, null);
    }

    /**
     * Full constructor. The field, k, filter and strategy are passed to the parent class, while
     * the delegate is kept for forwarding.
     *
     * @throws NullPointerException if {@code seed} is {@code null}
     */
    SeededKnnVectorQuery(AbstractKnnVectorQuery knnQuery, Query seed, Weight seedWeight, String field, int k, Query filter, KnnSearchStrategy searchStrategy) {
        super(field, k, filter, searchStrategy);
        this.delegate = knnQuery;
        this.seed = Objects.requireNonNull(seed);
        this.seedWeight = seedWeight;
    }

    /**
     * Wraps a float-vector query and copies its field, k, filter and search strategy.
     *
     * @param seedWeight a prebuilt seed weight, or {@code null} to have {@link #rewrite} build one
     */
    public SeededKnnVectorQuery(KnnFloatVectorQuery knnQuery, Query seed, Weight seedWeight) {
        this(knnQuery, seed, seedWeight, knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy);
    }

    /**
     * Wraps a byte-vector query and copies its field, k, filter and search strategy.
     *
     * @param seedWeight a prebuilt seed weight, or {@code null} to have {@link #rewrite} build one
     */
    public SeededKnnVectorQuery(KnnByteVectorQuery knnQuery, Query seed, Weight seedWeight) {
        this(knnQuery, seed, seedWeight, knnQuery.field, knnQuery.k, knnQuery.filter, knnQuery.searchStrategy);
    }

    /** Describes the seed, seed weight and delegate; the {@code field} argument is not used. */
    @Override
    public String toString(String field) {
        return "SeededKnnVectorQuery{seed=" + String.valueOf(this.seed) + ", seedWeight=" + String.valueOf(this.seedWeight) + ", delegate=" + String.valueOf(this.delegate) + "}";
    }

    /**
     * Ensures a seed weight exists before the KNN rewrite runs.
     *
     * <p>If the seed weight is already set, this defers to the parent rewrite. Otherwise it builds
     * the weight with {@link #createSeedWeight(IndexSearcher)}, creates a copy of this query that
     * carries it, and rewrites that copy.
     */
    @Override
    public Query rewrite(IndexSearcher indexSearcher) throws IOException {
        if (this.seedWeight != null) {
            return super.rewrite(indexSearcher);
        }
        Weight builtSeedWeight = this.createSeedWeight(indexSearcher);
        SeededKnnVectorQuery rewritten = new SeededKnnVectorQuery(this.delegate, this.seed, builtSeedWeight, this.delegate.field, this.delegate.k, this.delegate.filter, this.delegate.searchStrategy);
        return rewritten.rewrite(indexSearcher);
    }

    /**
     * Builds the weight used to collect seed documents.
     *
     * <p>The executed query is {@code +seed #FieldExistsQuery(field) [#filter]}. The seed clause
     * scores, and the other clauses restrict matches to documents that have this vector field
     * and, when a filter is set, pass it. The query is rewritten and then weighted with
     * {@link ScoreMode#TOP_SCORES} and a boost of 1.
     */
    Weight createSeedWeight(IndexSearcher indexSearcher) throws IOException {
        BooleanQuery.Builder booleanSeedQueryBuilder = new BooleanQuery.Builder().add(this.seed, BooleanClause.Occur.MUST).add(new FieldExistsQuery(this.field), BooleanClause.Occur.FILTER);
        if (this.filter != null) {
            booleanSeedQueryBuilder.add(this.filter, BooleanClause.Occur.FILTER);
        }
        Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build());
        return indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1.0f);
    }

    /**
     * Runs the delegate's approximate search. The collector manager is wrapped so that each
     * per-leaf collector is seeded with this leaf's seed-query hits.
     */
    @Override
    protected TopDocs approximateSearch(LeafReaderContext context, AcceptDocs acceptDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException {
        return this.delegate.approximateSearch(context, acceptDocs, visitedLimit, new SeededCollectorManager(knnCollectorManager));
    }

    /** Forwards to the delegate. */
    @Override
    protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) {
        return this.delegate.getKnnCollectorManager(k, searcher);
    }

    /** Forwards to the delegate; seeding only applies to the approximate search. */
    @Override
    protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) throws IOException {
        return this.delegate.exactSearch(context, acceptIterator, queryTimeout);
    }

    /** Forwards to the delegate. */
    @Override
    protected TopDocs mergeLeafResults(TopDocs[] perLeafResults) {
        return this.delegate.mergeLeafResults(perLeafResults);
    }

    /** Visits the delegate only; the seed query is not exposed to the visitor. */
    @Override
    public void visit(QueryVisitor visitor) {
        this.delegate.visit(visitor);
    }

    /**
     * Equal when the parent state, seed, seed weight and delegate are all equal. The seed weight
     * is compared with {@link Objects#equals}.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        SeededKnnVectorQuery that = (SeededKnnVectorQuery)o;
        return Objects.equals(this.seed, that.seed) && Objects.equals(this.seedWeight, that.seedWeight) && Objects.equals(this.delegate, that.delegate);
    }

    /** Hash over the same state that {@link #equals(Object)} compares. */
    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), this.seed, this.seedWeight, this.delegate);
    }

    /** Returns the delegate's field. */
    @Override
    public String getField() {
        return this.delegate.getField();
    }

    /** Returns the delegate's k. */
    @Override
    public int getK() {
        return this.delegate.getK();
    }

    /** Returns the delegate's filter. */
    @Override
    public Query getFilter() {
        return this.delegate.getFilter();
    }

    /** Forwards to the delegate. */
    @Override
    VectorScorer createVectorScorer(LeafReaderContext context, FieldInfo fi) throws IOException {
        return this.delegate.createVectorScorer(context, fi);
    }

    /**
     * Collector manager that, for each leaf, collects the top {@code k} hits of the seed weight
     * and, when possible, asks the wrapped manager for a collector using a
     * {@link KnnSearchStrategy.Seeded} strategy built from those hits.
     */
    class SeededCollectorManager
    implements KnnCollectorManager {
        final KnnCollectorManager knnCollectorManager;

        SeededCollectorManager(KnnCollectorManager knnCollectorManager) {
            this.knnCollectorManager = knnCollectorManager;
        }

        /**
         * Creates the collector for leaf {@code ctx}.
         *
         * <p>Exactly one collector is requested from the wrapped manager. It is seeded when the
         * seed query matched something in this leaf and the leaf's vector iterator can map doc IDs
         * to vector ordinals. Otherwise the caller's {@code searchStrategy} is used unchanged.
         */
        @Override
        public KnnCollector newCollector(int visitLimit, KnnSearchStrategy searchStrategy, LeafReaderContext ctx) throws IOException {
            TopDocs seedTopDocs = this.collectSeedTopDocs(ctx);
            LeafReader leafReader = ctx.reader();
            VectorScorer vectorScorer = SeededKnnVectorQuery.this.delegate.createVectorScorer(ctx, leafReader.getFieldInfos().fieldInfo(SeededKnnVectorQuery.this.field));
            // Only attempt seeding when there are seeds and this leaf actually has vectors.
            if (seedTopDocs.totalHits.value() != 0L && vectorScorer != null) {
                DocIdSetIterator vectorIterator = vectorScorer.iterator();
                if (vectorIterator instanceof IndexedDISI) {
                    // Sparse vector storage: expose the doc -> ordinal mapping via a DocIndexIterator.
                    vectorIterator = IndexedDISI.asDocIndexIterator((IndexedDISI)vectorIterator);
                }
                if (vectorIterator instanceof KnnVectorValues.DocIndexIterator) {
                    KnnVectorValues.DocIndexIterator indexIterator = (KnnVectorValues.DocIndexIterator)vectorIterator;
                    MappedDISI seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs, ctx));
                    return this.knnCollectorManager.newCollector(visitLimit, new KnnSearchStrategy.Seeded(seedDocs, seedTopDocs.scoreDocs.length, searchStrategy), ctx);
                }
            }
            // No seeds, no vectors, or an iterator we cannot map: fall back to an unseeded collector.
            // It is created only here, so no collector is allocated and then discarded on the seeded path.
            return this.knnCollectorManager.newCollector(visitLimit, searchStrategy, ctx);
        }

        /**
         * Scores the seed weight over the live documents of {@code ctx} and keeps the top
         * {@code k} hits. The returned doc IDs include {@code ctx.docBase}, as added by the
         * collector.
         */
        private TopDocs collectSeedTopDocs(LeafReaderContext ctx) throws IOException {
            TopScoreDocCollector seedCollector = new TopScoreDocCollectorManager(SeededKnnVectorQuery.this.k, null, Integer.MAX_VALUE).newCollector();
            LeafCollector leafCollector = seedCollector.getLeafCollector(ctx);
            if (leafCollector != null) {
                try {
                    BulkScorer bulkScorer = SeededKnnVectorQuery.this.seedWeight.bulkScorer(ctx);
                    if (bulkScorer != null) {
                        bulkScorer.score(leafCollector, ctx.reader().getLiveDocs(), 0, DocIdSetIterator.NO_MORE_DOCS);
                    }
                }
                catch (CollectionTerminatedException ignored) {
                    // The collector ended collection of this leaf early. That is not an error, and
                    // the hits gathered so far are still used.
                }
                leafCollector.finish();
            }
            return seedCollector.topDocs();
        }
    }

    /**
     * Iterates the leaf-relative doc IDs of a {@link TopDocs} in increasing order. The constructor
     * subtracts {@code ctx.docBase} from each hit and sorts the result.
     */
    static class TopDocsDISI
    extends DocIdSetIterator {
        private final int[] sortedDocIds;
        private int idx = -1;

        TopDocsDISI(TopDocs topDocs, LeafReaderContext ctx) {
            this.sortedDocIds = new int[topDocs.scoreDocs.length];
            for (int i = 0; i < topDocs.scoreDocs.length; ++i) {
                // The collector added the leaf's docBase; convert back to a leaf-relative ID.
                this.sortedDocIds[i] = topDocs.scoreDocs[i].doc - ctx.docBase;
            }
            Arrays.sort(this.sortedDocIds);
        }

        @Override
        public int advance(int target) throws IOException {
            return this.slowAdvance(target);
        }

        @Override
        public long cost() {
            return this.sortedDocIds.length;
        }

        @Override
        public int docID() {
            if (this.idx == -1) {
                return -1;
            }
            if (this.idx >= this.sortedDocIds.length) {
                return DocIdSetIterator.NO_MORE_DOCS;
            }
            return this.sortedDocIds[this.idx];
        }

        @Override
        public int nextDoc() {
            ++this.idx;
            return this.docID();
        }
    }

    /**
     * Walks {@code sourceDISI}'s doc IDs, advances {@code indexedDISI} to each one, and reports
     * {@code indexedDISI.index()} (the vector ordinal) as its own doc ID.
     *
     * <p>NOTE(review): the mapping is only exact if every source doc has a vector. For seed docs
     * this appears to be guaranteed by the FieldExistsQuery filter in createSeedWeight; confirm
     * before reusing this class elsewhere.
     */
    static class MappedDISI
    extends DocIdSetIterator {
        KnnVectorValues.DocIndexIterator indexedDISI;
        DocIdSetIterator sourceDISI;

        MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) {
            this.indexedDISI = indexedDISI;
            this.sourceDISI = sourceDISI;
        }

        @Override
        public int advance(int target) throws IOException {
            int newTarget = this.sourceDISI.advance(target);
            if (newTarget != DocIdSetIterator.NO_MORE_DOCS) {
                this.indexedDISI.advance(newTarget);
            }
            return this.docID();
        }

        @Override
        public long cost() {
            return this.sourceDISI.cost();
        }

        @Override
        public int docID() {
            if (this.indexedDISI.docID() == DocIdSetIterator.NO_MORE_DOCS || this.sourceDISI.docID() == DocIdSetIterator.NO_MORE_DOCS) {
                return DocIdSetIterator.NO_MORE_DOCS;
            }
            return this.indexedDISI.index();
        }

        @Override
        public int nextDoc() throws IOException {
            int newTarget = this.sourceDISI.nextDoc();
            if (newTarget != DocIdSetIterator.NO_MORE_DOCS) {
                this.indexedDISI.advance(newTarget);
            }
            return this.docID();
        }
    }
}

