/*
 * Decompiled with CFR 0.152.
 */
package io.milvus.orm.iterator;

import io.milvus.common.utils.ExceptionUtils;
import io.milvus.grpc.DescribeCollectionRequest;
import io.milvus.grpc.DescribeCollectionResponse;
import io.milvus.grpc.MilvusServiceGrpc;
import io.milvus.grpc.SearchIteratorV2Results;
import io.milvus.grpc.SearchRequest;
import io.milvus.grpc.SearchResultData;
import io.milvus.grpc.SearchResults;
import io.milvus.orm.iterator.SearchIterator;
import io.milvus.v2.service.collection.response.DescribeCollectionResp;
import io.milvus.v2.service.vector.request.SearchIteratorReqV2;
import io.milvus.v2.service.vector.request.SearchReq;
import io.milvus.v2.service.vector.response.SearchResp;
import io.milvus.v2.utils.ConvertUtils;
import io.milvus.v2.utils.RpcUtils;
import io.milvus.v2.utils.VectorUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SearchIteratorV2 {
    private static final Logger logger = LoggerFactory.getLogger(SearchIterator.class);
    private final MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub;
    private final SearchIteratorReqV2 searchIteratorReq;
    private final int batchSize;
    private Map<String, Object> searchParams;
    private final RpcUtils rpcUtils;
    private Integer leftResCnt = null;
    private Long collectionID = null;
    private Function<List<SearchResp.SearchResult>, List<SearchResp.SearchResult>> externalFilterFunc = null;
    private List<SearchResp.SearchResult> cache = new ArrayList<SearchResp.SearchResult>();

    public SearchIteratorV2(SearchIteratorReqV2 searchIteratorReq, MilvusServiceGrpc.MilvusServiceBlockingStub blockingStub) {
        this.blockingStub = blockingStub;
        this.searchIteratorReq = searchIteratorReq;
        this.batchSize = (int)searchIteratorReq.getBatchSize();
        this.externalFilterFunc = searchIteratorReq.getExternalFilterFunc();
        this.rpcUtils = new RpcUtils();
        this.checkParams();
        this.setupCollectionID();
        this.probeForCompability();
    }

    private void checkParams() {
        int rows;
        if (this.batchSize < 0) {
            ExceptionUtils.throwUnExpectedException("Batch size cannot be less than zero");
        } else if (this.batchSize > 16384) {
            ExceptionUtils.throwUnExpectedException(String.format("Batch size cannot be larger than %d", 16384));
        }
        this.searchParams = this.searchIteratorReq.getSearchParams();
        if (this.searchParams.containsKey("offset") && (Integer)this.searchParams.get("offset") > 0) {
            ExceptionUtils.throwUnExpectedException("Offset is not supported for SearchIterator");
        }
        if ((rows = this.searchIteratorReq.getVectors().size()) > 1) {
            ExceptionUtils.throwUnExpectedException("SearchIterator does not support processing multiple vectors simultaneously");
        } else if (rows <= 0) {
            ExceptionUtils.throwUnExpectedException("The vector data for search cannot be empty");
        }
        if (this.searchIteratorReq.getTopK() != -1) {
            this.leftResCnt = this.searchIteratorReq.getTopK();
        }
    }

    private void setupCollectionID() {
        DescribeCollectionRequest.Builder builder = DescribeCollectionRequest.newBuilder().setCollectionName(this.searchIteratorReq.getCollectionName());
        if (StringUtils.isNotEmpty((CharSequence)this.searchIteratorReq.getDatabaseName())) {
            builder.setDbName(this.searchIteratorReq.getDatabaseName());
        }
        DescribeCollectionResponse response = this.rpcUtils.retry(() -> this.blockingStub.describeCollection(builder.build()));
        String title = String.format("DescribeCollectionRequest collectionName:%s", this.searchIteratorReq.getCollectionName());
        this.rpcUtils.handleResponse(title, response.getStatus());
        DescribeCollectionResp respR = new ConvertUtils().convertDescCollectionResp(response);
        this.collectionID = respR.getCollectionID();
    }

    private SearchResults executeSearch(int limit) {
        this.searchParams.put("search_iter_batch_size", limit);
        Object request = ((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)((SearchReq.SearchReqBuilder)SearchReq.builder().collectionName(this.searchIteratorReq.getCollectionName())).partitionNames(this.searchIteratorReq.getPartitionNames())).databaseName(this.searchIteratorReq.getDatabaseName())).annsField(this.searchIteratorReq.getVectorFieldName())).data(this.searchIteratorReq.getVectors())).topK(limit)).filter(this.searchIteratorReq.getFilter())).consistencyLevel(this.searchIteratorReq.getConsistencyLevel())).outputFields(this.searchIteratorReq.getOutputFields())).roundDecimal(this.searchIteratorReq.getRoundDecimal())).searchParams(this.searchParams)).metricType(this.searchIteratorReq.getMetricType())).ignoreGrowing(this.searchIteratorReq.isIgnoreGrowing())).groupByFieldName(this.searchIteratorReq.getGroupByFieldName())).build();
        SearchRequest searchRequest = new VectorUtils().ConvertToGrpcSearchRequest((SearchReq)request);
        SearchResults response = this.rpcUtils.retry(() -> this.blockingStub.search(searchRequest));
        String title = String.format("SearchRequest collectionName:%s", this.searchIteratorReq.getCollectionName());
        this.rpcUtils.handleResponse(title, response.getStatus());
        return response;
    }

    private void probeForCompability() {
        this.searchParams.put("collection_id", this.collectionID);
        this.searchParams.put("iterator", true);
        this.searchParams.put("search_iter_v2", true);
        this.searchParams.put("guarantee_timestamp", 0L);
        SearchResultData resultData = this.executeSearch(1).getResults();
        this.checkTokenExists(resultData);
    }

    private void checkTokenExists(SearchResultData resultData) {
        String token = resultData.getSearchIteratorV2Results().getToken();
        if (StringUtils.isEmpty((CharSequence)token)) {
            ExceptionUtils.throwUnExpectedException("The server does not support Search Iterator V2. The search_iterator (v1) is used instead.\n    Please upgrade your Milvus server version to 2.5.2 and later,\n    or use a pymilvus version before 2.5.3 (excluded) to avoid this issue.");
        }
    }

    public List<SearchResp.SearchResult> next() {
        List<SearchResp.SearchResult> hits;
        if (this.leftResCnt != null && this.leftResCnt <= 0) {
            return new ArrayList<SearchResp.SearchResult>();
        }
        if (this.externalFilterFunc == null) {
            return this.wrapReturnRes(this._next());
        }
        int targetLen = this.batchSize;
        if (this.leftResCnt != null && this.leftResCnt < targetLen) {
            targetLen = this.leftResCnt;
        }
        while ((hits = this._next()) != null && !hits.isEmpty()) {
            if (this.externalFilterFunc != null) {
                hits = this.externalFilterFunc.apply(hits);
            }
            this.cache.addAll(hits);
            if (this.cache.size() < targetLen) continue;
            break;
        }
        List<SearchResp.SearchResult> subList = this.cache.subList(0, targetLen);
        ArrayList<SearchResp.SearchResult> ret = new ArrayList<SearchResp.SearchResult>(subList);
        subList.clear();
        return this.wrapReturnRes(ret);
    }

    private List<SearchResp.SearchResult> _next() {
        long ts;
        SearchResults response = this.executeSearch(this.batchSize);
        this.checkTokenExists(response.getResults());
        SearchIteratorV2Results iterInfo = response.getResults().getSearchIteratorV2Results();
        this.searchParams.put("search_iter_last_bound", Float.valueOf(iterInfo.getLastBound()));
        if (!this.searchParams.containsKey("search_iter_id")) {
            this.searchParams.put("search_iter_id", iterInfo.getToken());
        }
        if ((ts = ((Long)this.searchParams.get("guarantee_timestamp")).longValue()) <= 0L) {
            if (response.getSessionTs() > 0L) {
                this.searchParams.put("guarantee_timestamp", response.getSessionTs());
            } else {
                logger.warn("Failed to set up mvccTs from milvus server, use client-side ts instead");
                long clientTs = System.currentTimeMillis() + 1000L;
                this.searchParams.put("guarantee_timestamp", clientTs <<= 18);
            }
        }
        List<List<SearchResp.SearchResult>> res = new ConvertUtils().getEntities(response);
        return res.get(0);
    }

    private List<SearchResp.SearchResult> wrapReturnRes(List<SearchResp.SearchResult> res) {
        if (this.leftResCnt == null) {
            return res;
        }
        int currentLen = res.size();
        if (currentLen > this.leftResCnt) {
            res = res.subList(0, this.leftResCnt);
        }
        this.leftResCnt = this.leftResCnt - currentLen;
        return res;
    }

    public void close() {
    }
}

