From d465e756a20df9c9c6f11d626ed30de5eb3fc838 Mon Sep 17 00:00:00 2001 From: "manthapavankumar11@gmail.com" Date: Mon, 13 Apr 2026 06:26:27 +0530 Subject: [PATCH] Implemented Full Reranking --- README.md | 120 +++++++++++++++++++++++++++- pyproject.toml | 2 +- src/qql/ast_nodes.py | 2 + src/qql/cli.py | 1 + src/qql/embedder.py | 31 ++++++++ src/qql/executor.py | 46 ++++++++++- src/qql/lexer.py | 2 + src/qql/parser.py | 10 +++ tests/test_executor.py | 177 +++++++++++++++++++++++++++++++++++++++++ tests/test_lexer.py | 24 ++++++ tests/test_parser.py | 59 ++++++++++++++ 11 files changed, 468 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2b65dff..e3a4942 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,12 @@ qql> SEARCH notes SIMILAR TO 'vector databases' LIMIT 5 USING HYBRID Score │ ID │ Payload ────────┼──────────────────────────────────────┼────────────────────────────────────── 0.9102 │ 3f2e1a4b-8c91-4d0e-b123-abc123def456 │ {'text': 'Qdrant is a ...', 'author': 'alice', 'year': 2024} + +qql> SEARCH notes SIMILAR TO 'vector databases' LIMIT 5 USING HYBRID RERANK +✓ Found 1 result(s) (hybrid, reranked) + Score │ ID │ Payload +────────┼──────────────────────────────────────┼────────────────────────────────────── + 5.3714 │ 3f2e1a4b-8c91-4d0e-b123-abc123def456 │ {'text': 'Qdrant is a ...', 'author': 'alice', 'year': 2024} ``` --- @@ -32,6 +38,7 @@ qql> SEARCH notes SIMILAR TO 'vector databases' LIMIT 5 USING HYBRID - [SEARCH — find similar points](#search--find-similar-points) - [WHERE Clause Filters](#where-clause-filters) - [Hybrid Search (USING HYBRID)](#hybrid-search-using-hybrid) + - [Cross-Encoder Reranking (RERANK)](#cross-encoder-reranking-rerank) - [SHOW COLLECTIONS — list collections](#show-collections--list-collections) - [CREATE COLLECTION — create a collection](#create-collection--create-a-collection) - [DROP COLLECTION — delete a collection](#drop-collection--delete-a-collection) @@ -244,6 +251,7 @@ SEARCH SIMILAR TO '' LIMIT USING MODEL ' SIMILAR TO '' LIMIT [USING MODEL ''] WHERE SEARCH SIMILAR TO '' LIMIT USING HYBRID SEARCH SIMILAR TO '' LIMIT USING HYBRID [DENSE MODEL ''] [SPARSE MODEL ''] [WHERE ] +SEARCH SIMILAR TO '' LIMIT [USING ...] [WHERE ] RERANK [MODEL ''] ``` **Examples:** @@ -508,6 +516,95 @@ Both can be overridden independently with `DENSE MODEL` and `SPARSE MODEL`. --- +### Cross-Encoder Reranking (RERANK) + +Appending `RERANK` to any SEARCH statement activates a **second-pass relevance scoring** step using a [cross-encoder](https://www.sbert.net/examples/applications/cross-encoder/README.html) model. Unlike bi-encoders (which encode query and document independently), a cross-encoder processes the **(query, document)** pair jointly, producing a more accurate relevance score at the cost of extra compute. + +#### How it works internally + +1. Qdrant executes the normal dense or hybrid search, but fetches `LIMIT × 4` candidates instead of just `LIMIT` — giving the reranker enough material to work with. +2. Each candidate's `payload["text"]` is paired with the original query text. +3. The cross-encoder scores all (query, document) pairs in one batch. +4. Results are sorted **descending by cross-encoder score** and sliced to `LIMIT`. +5. The `score` column in the output reflects the cross-encoder relevance score (raw logits — higher is more relevant). + +#### Syntax + +``` +SEARCH SIMILAR TO '' LIMIT RERANK +SEARCH SIMILAR TO '' LIMIT RERANK MODEL '' +``` + +`RERANK` must come **after** any `USING` and `WHERE` clauses: + +``` +SEARCH ... LIMIT n [USING ...] [WHERE ...] RERANK [MODEL '...'] +``` + +#### Examples + +Dense search + rerank (default cross-encoder): +```sql +SEARCH articles SIMILAR TO 'machine learning for healthcare' LIMIT 5 RERANK +``` + +Hybrid search + rerank (best of all three worlds): +```sql +SEARCH articles SIMILAR TO 'attention mechanism in transformers' LIMIT 10 USING HYBRID RERANK +``` + +Dense search + WHERE filter + rerank: +```sql +SEARCH articles SIMILAR TO 'deep learning' LIMIT 10 WHERE year > 2020 RERANK +``` + +Custom cross-encoder model: +```sql +SEARCH articles SIMILAR TO 'semantic search' LIMIT 5 + RERANK MODEL 'cross-encoder/ms-marco-MiniLM-L-6-v2' +``` + +All clauses combined: +```sql +SEARCH articles SIMILAR TO 'neural IR' LIMIT 10 + USING HYBRID DENSE MODEL 'BAAI/bge-base-en-v1.5' + WHERE year >= 2020 + RERANK MODEL 'cross-encoder/ms-marco-MiniLM-L-6-v2' +``` + +#### Default cross-encoder model + +``` +cross-encoder/ms-marco-MiniLM-L-6-v2 +``` + +- A lightweight but effective passage reranker fine-tuned on MS MARCO. +- Downloaded on first use and cached locally by Fastembed. +- No additional packages needed — `TextCrossEncoder` is included in the `fastembed` package. + +#### Commonly available cross-encoder models (Fastembed) + +| Model | Notes | +|---|---| +| `cross-encoder/ms-marco-MiniLM-L-6-v2` | Default. Fast and accurate for passage reranking | +| `cross-encoder/ms-marco-MiniLM-L-12-v2` | Larger, higher quality, slower | +| `BAAI/bge-reranker-base` | BGE reranker, strong general-purpose performance | +| `BAAI/bge-reranker-large` | Highest quality BGE reranker, slower | + +#### When to use RERANK + +| Situation | Recommendation | +|---|---| +| High-precision retrieval (legal, medical, research) | Add `RERANK` | +| Small LIMIT (top-3 or top-5 results) | Very effective — reranker focuses precision | +| Low latency required | Skip `RERANK` (adds ~100–500 ms per batch) | +| Large collections with keyword-heavy queries | `USING HYBRID RERANK` for best coverage + precision | +| General-purpose semantic search | Optional; `RERANK` improves quality at mild cost | + +> **Note on scores:** After reranking, the `score` column shows the cross-encoder's raw logit (can be any real number, unbounded). Do not compare reranked scores to non-reranked cosine similarity scores — they are on different scales. + +--- + ### SHOW COLLECTIONS — list collections Lists all collections in the connected Qdrant instance. @@ -670,6 +767,25 @@ SEARCH docs SIMILAR TO 'hello' LIMIT 5 | `prithivida/Splade_PP_en_v1` | SPLADE++ — strong keyword + semantic overlap | | `Qdrant/Unicoil` | UniCOIL sparse encoder | +### Cross-encoder reranking (RERANK default) + +``` +cross-encoder/ms-marco-MiniLM-L-6-v2 +``` + +- A passage reranker fine-tuned on MS MARCO. +- No new dependencies — `TextCrossEncoder` is included in the `fastembed` package. +- Override with `RERANK MODEL ''`. + +### Commonly available cross-encoder models (Fastembed) + +| Model | Notes | +|---|---| +| `cross-encoder/ms-marco-MiniLM-L-6-v2` | Default. Fast passage reranker | +| `cross-encoder/ms-marco-MiniLM-L-12-v2` | Larger, higher quality | +| `BAAI/bge-reranker-base` | Strong general-purpose reranker | +| `BAAI/bge-reranker-large` | Highest quality, slower | + > Models are downloaded automatically on first use and cached by Fastembed. Loading a new model for the first time takes a few seconds. ### Model consistency rule @@ -847,7 +963,7 @@ qql/ │ ├── lexer.py # Tokenizer: string → List[Token] │ ├── ast_nodes.py # Frozen dataclasses for each statement and filter type │ ├── parser.py # Recursive descent parser: tokens → AST node -│ ├── embedder.py # Embedder (dense) + SparseEmbedder (BM25) with per-model cache +│ ├── embedder.py # Embedder (dense) + SparseEmbedder (BM25) + CrossEncoderEmbedder (rerank) │ └── executor.py # AST node → Qdrant client call + filter + hybrid search └── tests/ ├── test_lexer.py # Tokenizer unit tests (keywords, operators, dot-paths, hybrid tokens) @@ -865,7 +981,7 @@ Tests do not require a running Qdrant instance — the Qdrant client is mocked. pytest tests/ -v ``` -Expected output: **169 tests passing**. +Expected output: **193 tests passing**. --- diff --git a/pyproject.toml b/pyproject.toml index 8e537ad..b8ecae1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "qql-cli" -version = "1.0.0" +version = "1.1.0" description = "A SQL-like query language CLI wrapper for Qdrant vector database" readme = "README.md" license = { file = "LICENSE" } diff --git a/src/qql/ast_nodes.py b/src/qql/ast_nodes.py index 6a9e81f..fa8e4fc 100644 --- a/src/qql/ast_nodes.py +++ b/src/qql/ast_nodes.py @@ -146,6 +146,8 @@ class SearchStmt: hybrid: bool = False # if True, use prefetch+RRF hybrid search sparse_model: str | None = None # sparse model for hybrid; None → SparseEmbedder.DEFAULT_MODEL query_filter: FilterExpr | None = None # optional WHERE clause; default keeps existing tests valid + rerank: bool = False # if True, apply cross-encoder reranking post-Qdrant + rerank_model: str | None = None # cross-encoder model; None → CrossEncoderEmbedder.DEFAULT_MODEL @dataclass(frozen=True) diff --git a/src/qql/cli.py b/src/qql/cli.py index b7ab731..1367d41 100644 --- a/src/qql/cli.py +++ b/src/qql/cli.py @@ -42,6 +42,7 @@ Optional: [yellow]USING MODEL[/yellow] '' Optional: [yellow]USING HYBRID[/yellow] [DENSE MODEL ''] [SPARSE MODEL ''] Optional: [yellow]WHERE[/yellow] (e.g. WHERE year > 2020 AND status = 'ok') + Optional: [yellow]RERANK[/yellow] [MODEL ''] rerank results with a cross-encoder [yellow]DELETE FROM[/yellow] [yellow]WHERE id =[/yellow] '' Delete a point by its ID. diff --git a/src/qql/embedder.py b/src/qql/embedder.py index 41243b5..cd08fe8 100644 --- a/src/qql/embedder.py +++ b/src/qql/embedder.py @@ -65,3 +65,34 @@ def query_embed(self, text: str) -> dict[str, list]: """Embed a query string (BM25 applies different IDF weighting at query time).""" result = next(iter(self._model.query_embed(text))) # type: ignore[attr-defined] return {"indices": result.indices.tolist(), "values": result.values.tolist()} + + +class CrossEncoderEmbedder: + """Cross-encoder reranker using fastembed.TextCrossEncoder. + + Jointly encodes (query, document) pairs to produce relevance scores. + Higher score = more relevant. No new package dependencies — + TextCrossEncoder is included in the fastembed package bundled with + qdrant-client[fastembed]. + """ + + DEFAULT_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" + + # Class-level cache mirrors Embedder's pattern + _cache: dict[str, object] = {} + + def __init__(self, model_name: str = DEFAULT_MODEL) -> None: + self._model_name = model_name + if model_name not in CrossEncoderEmbedder._cache: + from fastembed import TextCrossEncoder + + CrossEncoderEmbedder._cache[model_name] = TextCrossEncoder(model_name) + self._model = CrossEncoderEmbedder._cache[model_name] + + def rerank(self, query: str, documents: list[str]) -> list[float]: + """Return a relevance score for each (query, document) pair. + + Scores are raw logits — higher means more relevant. + The returned list is the same length as ``documents`` and in the same order. + """ + return list(self._model.rerank(query, documents)) # type: ignore[attr-defined] diff --git a/src/qql/executor.py b/src/qql/executor.py index 9408e7d..26d9a29 100644 --- a/src/qql/executor.py +++ b/src/qql/executor.py @@ -55,7 +55,9 @@ ShowCollectionsStmt, ) from .config import QQLConfig -from .embedder import Embedder, SparseEmbedder +from .embedder import CrossEncoderEmbedder, Embedder, SparseEmbedder + +_RERANK_FETCH_MULTIPLIER = 4 from .exceptions import QQLRuntimeError @@ -234,6 +236,10 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: self._build_qdrant_filter(node.query_filter) ) + # When reranking is requested, fetch more candidates so the reranker has + # enough material to reorder; only `node.limit` results are returned. + fetch_limit = node.limit * _RERANK_FETCH_MULTIPLIER if node.rerank else node.limit + # ── Hybrid SEARCH: prefetch dense+sparse, fuse with RRF ─────────── if node.hybrid: dense_model = node.model or self._config.default_model @@ -264,7 +270,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: ), ], query=FusionQuery(fusion=Fusion.RRF), - limit=node.limit, + limit=fetch_limit, query_filter=qdrant_filter, ) except UnexpectedResponse as e: @@ -274,6 +280,15 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} for h in response.points ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (hybrid, reranked)", + data=results, + ) + return ExecutionResult( success=True, message=f"Found {len(results)} result(s) (hybrid)", @@ -289,7 +304,7 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: response = self._client.query_points( collection_name=node.collection, query=vector, - limit=node.limit, + limit=fetch_limit, query_filter=qdrant_filter, ) except UnexpectedResponse as e: @@ -299,12 +314,37 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} for h in response.points ] + + if node.rerank: + results = self._apply_reranking(node.query_text, results, node.limit, node.rerank_model) + return ExecutionResult( + success=True, + message=f"Found {len(results)} result(s) (reranked)", + data=results, + ) + return ExecutionResult( success=True, message=f"Found {len(results)} result(s)", data=results, ) + def _apply_reranking( + self, + query: str, + results: list[dict], + limit: int, + rerank_model: str | None, + ) -> list[dict]: + """Re-score candidates with a cross-encoder and return top-``limit`` results.""" + model_name = rerank_model or CrossEncoderEmbedder.DEFAULT_MODEL + reranker = CrossEncoderEmbedder(model_name) + texts = [r["payload"].get("text", "") for r in results] + scores = reranker.rerank(query, texts) + for r, s in zip(results, scores): + r["score"] = round(float(s), 4) + return sorted(results, key=lambda r: r["score"], reverse=True)[:limit] + def _execute_delete(self, node: DeleteStmt) -> ExecutionResult: if not self._client.collection_exists(node.collection): raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") diff --git a/src/qql/lexer.py b/src/qql/lexer.py index 029a8a3..17ec703 100644 --- a/src/qql/lexer.py +++ b/src/qql/lexer.py @@ -15,6 +15,7 @@ class TokenKind(Enum): HYBRID = auto() DENSE = auto() SPARSE = auto() + RERANK = auto() CREATE = auto() DROP = auto() SHOW = auto() @@ -75,6 +76,7 @@ class TokenKind(Enum): "HYBRID": TokenKind.HYBRID, "DENSE": TokenKind.DENSE, "SPARSE": TokenKind.SPARSE, + "RERANK": TokenKind.RERANK, "CREATE": TokenKind.CREATE, "DROP": TokenKind.DROP, "SHOW": TokenKind.SHOW, diff --git a/src/qql/parser.py b/src/qql/parser.py index f141745..0595a3d 100644 --- a/src/qql/parser.py +++ b/src/qql/parser.py @@ -154,6 +154,14 @@ def _parse_search(self) -> SearchStmt: if self._peek().kind == TokenKind.WHERE: self._advance() # consume WHERE query_filter = self._parse_filter_expr() + rerank: bool = False + rerank_model: str | None = None + if self._peek().kind == TokenKind.RERANK: + self._advance() # consume RERANK + rerank = True + if self._peek().kind == TokenKind.MODEL: + self._advance() # consume MODEL + rerank_model = self._expect(TokenKind.STRING).value return SearchStmt( collection=collection, query_text=query_text, @@ -162,6 +170,8 @@ def _parse_search(self) -> SearchStmt: hybrid=hybrid, sparse_model=sparse_model, query_filter=query_filter, + rerank=rerank, + rerank_model=rerank_model, ) def _parse_delete(self) -> DeleteStmt: diff --git a/tests/test_executor.py b/tests/test_executor.py index fd70b8a..6c06b52 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -678,3 +678,180 @@ def test_unnamed_vector_mismatch_still_raises(self, executor, mock_client): ) with pytest.raises(QQLRuntimeError, match="dimension mismatch"): executor._ensure_collection("col", 384) + + +FAKE_SPARSE = {"indices": [1, 42, 100], "values": [0.22, 0.8, 0.3]} + + +class TestRerankSearch: + @pytest.fixture + def mock_cross_encoder(self, mocker): + mock = mocker.MagicMock() + mock.rerank.return_value = [0.9, 0.3, 0.7] + mocker.patch("qql.executor.CrossEncoderEmbedder", return_value=mock) + return mock + + def _make_point(self, mocker, id_, score, text): + p = mocker.MagicMock() + p.id = id_ + p.score = score + p.payload = {"text": text} + return p + + def test_rerank_calls_cross_encoder_with_query_and_texts( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + pts = [ + self._make_point(mocker, "a", 0.9, "doc A"), + self._make_point(mocker, "b", 0.5, "doc B"), + self._make_point(mocker, "c", 0.7, "doc C"), + ] + mock_resp = mocker.MagicMock() + mock_resp.points = pts + mock_client.query_points.return_value = mock_resp + + node = SearchStmt( + collection="col", query_text="my query", limit=3, model=None, rerank=True + ) + executor.execute(node) + mock_cross_encoder.rerank.assert_called_once_with( + "my query", ["doc A", "doc B", "doc C"] + ) + + def test_rerank_qdrant_fetches_multiplied_limit( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + mock_resp = mocker.MagicMock() + mock_resp.points = [] + mock_client.query_points.return_value = mock_resp + + node = SearchStmt( + collection="col", query_text="q", limit=5, model=None, rerank=True + ) + executor.execute(node) + kw = mock_client.query_points.call_args.kwargs + assert kw["limit"] == 5 * 4 # _RERANK_FETCH_MULTIPLIER + + def test_rerank_results_sorted_by_cross_encoder_score( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + pts = [ + self._make_point(mocker, "a", 0.9, "doc A"), + self._make_point(mocker, "b", 0.5, "doc B"), + self._make_point(mocker, "c", 0.7, "doc C"), + ] + mock_resp = mocker.MagicMock() + mock_resp.points = pts + mock_client.query_points.return_value = mock_resp + # scores: A→0.9, B→0.3, C→0.7 → sorted order: A, C, B + mock_cross_encoder.rerank.return_value = [0.9, 0.3, 0.7] + + node = SearchStmt( + collection="col", query_text="q", limit=3, model=None, rerank=True + ) + result = executor.execute(node) + ids = [r["id"] for r in result.data] + assert ids == ["a", "c", "b"] + + def test_rerank_slices_to_limit( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + pts = [self._make_point(mocker, str(i), 0.5, f"doc {i}") for i in range(8)] + mock_resp = mocker.MagicMock() + mock_resp.points = pts + mock_client.query_points.return_value = mock_resp + mock_cross_encoder.rerank.return_value = [float(i) for i in range(8)] + + node = SearchStmt( + collection="col", query_text="q", limit=3, model=None, rerank=True + ) + result = executor.execute(node) + assert len(result.data) == 3 + + def test_rerank_message_contains_reranked( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + mock_resp = mocker.MagicMock() + mock_resp.points = [] + mock_client.query_points.return_value = mock_resp + mock_cross_encoder.rerank.return_value = [] + + node = SearchStmt( + collection="col", query_text="q", limit=5, model=None, rerank=True + ) + result = executor.execute(node) + assert "reranked" in result.message + + def test_no_rerank_does_not_call_cross_encoder( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + mock_resp = mocker.MagicMock() + mock_resp.points = [] + mock_client.query_points.return_value = mock_resp + + node = SearchStmt( + collection="col", query_text="q", limit=5, model=None, rerank=False + ) + executor.execute(node) + mock_cross_encoder.rerank.assert_not_called() + + def test_no_rerank_uses_original_limit( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + mock_resp = mocker.MagicMock() + mock_resp.points = [] + mock_client.query_points.return_value = mock_resp + + node = SearchStmt( + collection="col", query_text="q", limit=5, model=None, rerank=False + ) + executor.execute(node) + kw = mock_client.query_points.call_args.kwargs + assert kw["limit"] == 5 + + def test_rerank_custom_model_forwarded( + self, executor, mock_client, mocker + ): + mock_client.collection_exists.return_value = True + mock_resp = mocker.MagicMock() + mock_resp.points = [] + mock_client.query_points.return_value = mock_resp + + mock_ce = mocker.MagicMock() + mock_ce.rerank.return_value = [] + ce_cls = mocker.patch("qql.executor.CrossEncoderEmbedder", return_value=mock_ce) + + node = SearchStmt( + collection="col", query_text="q", limit=5, model=None, + rerank=True, rerank_model="my-custom/reranker", + ) + executor.execute(node) + ce_cls.assert_called_once_with("my-custom/reranker") + + def test_rerank_hybrid_search_message( + self, executor, mock_client, mock_cross_encoder, mocker + ): + mock_client.collection_exists.return_value = True + mock_resp = mocker.MagicMock() + mock_resp.points = [] + mock_client.query_points.return_value = mock_resp + mock_cross_encoder.rerank.return_value = [] + + mock_sparse = mocker.MagicMock() + mock_sparse.query_embed.return_value = FAKE_SPARSE + mocker.patch("qql.executor.SparseEmbedder", return_value=mock_sparse) + + node = SearchStmt( + collection="col", query_text="q", limit=5, model=None, + hybrid=True, rerank=True, + ) + result = executor.execute(node) + assert "hybrid" in result.message + assert "reranked" in result.message diff --git a/tests/test_lexer.py b/tests/test_lexer.py index df41b59..d3af629 100644 --- a/tests/test_lexer.py +++ b/tests/test_lexer.py @@ -229,3 +229,27 @@ def test_sparse_as_identifier_in_dotted_path(self): tokens = tokenize("sparse.value") assert tokens[0].kind == TokenKind.IDENTIFIER assert tokens[0].value == "sparse.value" + + +class TestRerankKeyword: + def test_rerank_keyword_uppercase(self): + ks = kinds("RERANK") + assert ks[0] == TokenKind.RERANK + + def test_rerank_keyword_lowercase(self): + ks = kinds("rerank") + assert ks[0] == TokenKind.RERANK + + def test_rerank_keyword_mixed_case(self): + ks = kinds("Rerank") + assert ks[0] == TokenKind.RERANK + + def test_rerank_in_search_statement(self): + ks = kinds("SEARCH col SIMILAR TO 'q' LIMIT 5 RERANK") + assert TokenKind.RERANK in ks + + def test_rerank_with_model_in_search(self): + ks = kinds("SEARCH col SIMILAR TO 'q' LIMIT 5 RERANK MODEL 'x'") + rerank_idx = ks.index(TokenKind.RERANK) + assert ks[rerank_idx + 1] == TokenKind.MODEL + assert ks[rerank_idx + 2] == TokenKind.STRING diff --git a/tests/test_parser.py b/tests/test_parser.py index 19b2650..ba64939 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -477,3 +477,62 @@ def test_search_hybrid_dense_model_and_where(self): def test_search_hybrid_limit_preserved(self): node = parse("SEARCH col SIMILAR TO 'q' LIMIT 7 USING HYBRID") assert node.limit == 7 + + +class TestRerankSearch: + def test_rerank_flag_set(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 RERANK") + assert node.rerank is True + assert node.rerank_model is None + + def test_rerank_with_model(self): + node = parse( + "SEARCH col SIMILAR TO 'q' LIMIT 5 RERANK MODEL 'cross-encoder/ms-marco-MiniLM-L-6-v2'" + ) + assert node.rerank is True + assert node.rerank_model == "cross-encoder/ms-marco-MiniLM-L-6-v2" + + def test_rerank_default_false(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5") + assert node.rerank is False + assert node.rerank_model is None + + def test_rerank_with_using_model(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 USING MODEL 'BAAI/bge-small-en-v1.5' RERANK") + assert node.model == "BAAI/bge-small-en-v1.5" + assert node.rerank is True + + def test_rerank_with_hybrid(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 USING HYBRID RERANK") + assert node.hybrid is True + assert node.rerank is True + assert node.rerank_model is None + + def test_rerank_with_where(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 WHERE year > 2020 RERANK") + assert node.query_filter is not None + assert node.rerank is True + + def test_rerank_with_hybrid_where_and_model(self): + node = parse( + "SEARCH col SIMILAR TO 'q' LIMIT 5 USING HYBRID WHERE year > 2020 " + "RERANK MODEL 'cross-encoder/ms-marco-MiniLM-L-6-v2'" + ) + assert node.hybrid is True + assert node.query_filter is not None + assert node.rerank is True + assert node.rerank_model == "cross-encoder/ms-marco-MiniLM-L-6-v2" + + def test_rerank_lowercase(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 rerank") + assert node.rerank is True + + def test_rerank_model_custom(self): + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 5 RERANK MODEL 'my-custom/reranker'") + assert node.rerank_model == "my-custom/reranker" + + def test_existing_search_unaffected_by_rerank_addition(self): + """Existing parse calls without RERANK still produce rerank=False.""" + node = parse("SEARCH col SIMILAR TO 'q' LIMIT 10 USING MODEL 'BAAI/bge-small-en-v1.5'") + assert node.rerank is False + assert node.rerank_model is None