diff --git a/README.md b/README.md index a54f25a..18e157a 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,7 @@ qql> SEARCH notes SIMILAR TO 'vector databases' LIMIT 5 USING HYBRID RERANK - [INSERT — add a point](#insert--add-a-point) - [INSERT BULK — batch insert](#insert-bulk--batch-insert-multiple-points) - [SEARCH — find similar points](#search--find-similar-points) + - [RECOMMEND — retrieve by example IDs](#recommend--retrieve-by-example-ids) - [Query-Time Search Params (`EXACT`, `WITH`)](#query-time-search-params-exact-with) - [WHERE Clause Filters](#where-clause-filters) - [Hybrid Search (USING HYBRID)](#hybrid-search-using-hybrid) @@ -186,6 +187,8 @@ Inserts a new document into a collection. The `text` field is **mandatory** — If the collection does not exist yet, it is **created automatically** with the correct vector dimensions. +If you include an `id` field in `VALUES`, QQL uses it as the Qdrant point ID. Supported explicit IDs are unsigned integers or UUID strings. If you omit `id`, QQL generates a UUID automatically. + **Syntax:** ``` INSERT INTO COLLECTION VALUES {} @@ -204,6 +207,7 @@ INSERT INTO COLLECTION articles VALUES {'text': 'Qdrant supports cosine similari Insert with metadata: ```sql INSERT INTO COLLECTION articles VALUES { + 'id': 1001, 'text': 'Neural networks learn representations from data', 'author': 'alice', 'category': 'ml', @@ -231,13 +235,13 @@ INSERT INTO COLLECTION articles VALUES {'text': 'hello world'} **What happens internally:** 1. The `text` value is embedded into a dense vector using the configured model. 2. In hybrid mode, a sparse BM25 vector is also generated. -3. A UUID is auto-generated as the point ID. -4. All fields (including `text`) are stored in the payload. +3. If `id` is provided, it is used as the point ID; otherwise a UUID is auto-generated. +4. All fields except `id` are stored in the payload. 5. The point is upserted into Qdrant. **Rules:** - `text` is always required. Omitting it raises an error. -- A point ID (UUID) is generated automatically — you do not provide one. +- `id`, when provided, must be an unsigned integer or UUID string. - If the collection already exists with a different vector size (from a different model), an error is raised with a clear message. - Hybrid inserts require a hybrid collection (created with `CREATE COLLECTION ... HYBRID` or auto-created on first `USING HYBRID` insert). @@ -249,6 +253,8 @@ Inserts multiple documents in a single statement. Each item in the array must co If the collection does not exist yet, it is **created automatically** on the first bulk insert. +Each record may optionally include an `id` field. This is the preferred way to keep seed data deterministic and to make follow-up operations like `RECOMMEND` or `DELETE` reproducible. + **Syntax:** ``` INSERT BULK INTO COLLECTION VALUES [, , ...] @@ -271,9 +277,9 @@ INSERT BULK INTO COLLECTION articles VALUES [ Bulk insert with metadata: ```sql INSERT BULK INTO COLLECTION articles VALUES [ - {'text': 'Attention is all you need', 'author': 'vaswani', 'year': 2017}, - {'text': 'BERT: Pre-training of deep bidirectional transformers', 'author': 'devlin', 'year': 2018}, - {'text': 'Language models are few-shot learners', 'author': 'brown', 'year': 2020} + {'id': 1001, 'text': 'Attention is all you need', 'author': 'vaswani', 'year': 2017}, + {'id': 1002, 'text': 'BERT: Pre-training of deep bidirectional transformers', 'author': 'devlin', 'year': 2018}, + {'id': 1003, 'text': 'Language models are few-shot learners', 'author': 'brown', 'year': 2020} ] ``` @@ -288,7 +294,7 @@ INSERT BULK INTO COLLECTION articles VALUES [ **Rules:** - Every dict in the array must contain a `"text"` key. Missing `text` on any item raises an error with the offending index. - An empty array `[]` raises an error. -- A UUID is auto-generated for each point — you do not provide IDs. +- `id`, when provided, must be an unsigned integer or UUID string. - Supports all the same `USING` clauses as single `INSERT`. --- @@ -371,13 +377,111 @@ Results are displayed as a table with three columns: ``` - **Score** — similarity score. Higher is more relevant. -- **ID** — the UUID of the matching point. +- **ID** — the point ID returned by Qdrant. This may be an integer or a UUID string. - **Payload** — all fields stored alongside the vector. **Important:** Use the same model for SEARCH as you used for INSERT. Mixing models produces meaningless scores because the vectors live in different spaces. --- +### RECOMMEND — retrieve by example IDs + +Performs a Qdrant recommendation query using existing point IDs as positive and optional negative examples. + +This is useful when you already know which stored points represent the kind of result you want. Qdrant uses those examples to retrieve nearby points, and QQL automatically excludes the seed IDs from the results. + +**Syntax:** +```sql +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT +RECOMMEND FROM POSITIVE IDS (, ...) NEGATIVE IDS (, ...) LIMIT +RECOMMEND FROM POSITIVE IDS (, ...) STRATEGY '' LIMIT +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT WHERE +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT OFFSET +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT SCORE THRESHOLD +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT WITH { exact: true, hnsw_ef: } +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT LOOKUP FROM +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT LOOKUP FROM VECTOR '' +RECOMMEND FROM POSITIVE IDS (, ...) LIMIT USING '' +``` + +**Examples:** + +Recommend more results like two known articles: +```sql +RECOMMEND FROM articles POSITIVE IDS (1001, 1002) LIMIT 5 +``` + +Recommend similar results while steering away from one bad example: +```sql +RECOMMEND FROM articles POSITIVE IDS (1001, 1002) NEGATIVE IDS (1009) LIMIT 5 +``` + +Use Qdrant's `best_score` recommendation strategy: +```sql +RECOMMEND FROM articles POSITIVE IDS (1001) STRATEGY 'best_score' LIMIT 10 +``` + +Recommend only within a filtered subset: +```sql +RECOMMEND FROM articles POSITIVE IDS (1001) LIMIT 5 WHERE year >= 2020 AND status = 'published' +``` + +Paginate recommendations (skip first 5, return next 10): +```sql +RECOMMEND FROM articles POSITIVE IDS (1001) LIMIT 10 OFFSET 5 +``` + +Filter out low-confidence recommendations: +```sql +RECOMMEND FROM articles POSITIVE IDS (1001) LIMIT 10 SCORE THRESHOLD 0.5 +``` + +Exact KNN baseline for recommendations: +```sql +RECOMMEND FROM articles POSITIVE IDS (1001) LIMIT 5 WITH { exact: true } +``` + +Cross-collection recommend (look up example IDs from another collection): +```sql +RECOMMEND FROM target_collection + POSITIVE IDS ('a') + LOOKUP FROM source_collection VECTOR 'dense' + LIMIT 5 +``` + +Recommend using a specific named vector in the target collection: +```sql +RECOMMEND FROM articles + POSITIVE IDS (1001) + USING 'sparse' + LIMIT 5 +``` + +Full-featured recommend: +```sql +RECOMMEND FROM articles + POSITIVE IDS (1001, 1002) + NEGATIVE IDS (1009) + STRATEGY 'best_score' + LOOKUP FROM other_collection VECTOR 'dense' + USING 'dense' + LIMIT 10 + OFFSET 5 + SCORE THRESHOLD 0.5 + WHERE year >= 2020 + WITH { exact: true } +``` + +**Supported strategies:** + +- `average_vector` +- `best_score` +- `sum_scores` + +**Clause order:** `POSITIVE IDS` → `NEGATIVE IDS` → `STRATEGY` → `LOOKUP FROM` → `USING` → `LIMIT` → `OFFSET` → `SCORE THRESHOLD` → `WHERE` → `WITH` + +--- + ### Query-Time Search Params (`EXACT`, `WITH`) QQL supports a small set of Qdrant query-time search parameters on `SEARCH` statements. @@ -818,7 +922,7 @@ Raises an error if the collection does not exist. ### DELETE — remove a point -Deletes a single point from a collection by its ID. The point ID is the UUID returned by INSERT. +Deletes a single point from a collection by its ID. The ID may be an integer or a UUID string, either generated by QQL or supplied explicitly on INSERT. **Syntax:** ``` @@ -890,9 +994,14 @@ SHOW COLLECTIONS **Rules:** - `--` to end-of-line is a comment and is ignored (inline or full-line) - Statements can span multiple lines (e.g. `INSERT BULK ... VALUES [...]`) +- `RECOMMEND` statements work in `.qql` files the same way they do in the REPL - Blank lines between statements are ignored - By default all statements run even if one fails; use `--stop-on-error` to halt early +**Included examples:** +- [`resources/sample.qql`](resources/sample.qql) seeds the demo medical dataset +- [`resources/sample_v2.qql`](resources/sample_v2.qql) is a compact end-to-end example with explicit IDs and runnable `RECOMMEND` statements + **Example output:** ``` Executing: /path/to/script.qql @@ -1165,15 +1274,15 @@ result = run_query( "INSERT INTO COLLECTION notes VALUES {'text': 'hello world', 'author': 'alice', 'year': 2024}", url="http://localhost:6333", ) -print(result.message) # "Inserted 1 point []" -print(result.data) # {"id": "...", "collection": "notes"} +print(result.message) # "Inserted 1 point []" +print(result.data) # {"id": 1001 or "", "collection": "notes"} # Insert with hybrid vectors result = run_query( "INSERT INTO COLLECTION notes VALUES {'text': 'hello world'} USING HYBRID", url="http://localhost:6333", ) -print(result.message) # "Inserted 1 point [] (hybrid)" +print(result.message) # "Inserted 1 point [] (hybrid)" # Dense search with WHERE filter result = run_query( @@ -1228,9 +1337,10 @@ class ExecutionResult: | Operation | `result.data` type | |---|---| -| INSERT (dense) | `{"id": "", "collection": ""}` | -| INSERT (hybrid) | `{"id": "", "collection": ""}` | +| INSERT (dense) | `{"id": int | "", "collection": ""}` | +| INSERT (hybrid) | `{"id": int | "", "collection": ""}` | | SEARCH | `[{"id": str, "score": float, "payload": dict}, ...]` | +| RECOMMEND | `[{"id": str, "score": float, "payload": dict}, ...]` | | SHOW COLLECTIONS | `["name1", "name2", ...]` | | CREATE COLLECTION | `None` | | DROP COLLECTION | `None` | diff --git a/resources/sample_v2.qql b/resources/sample_v2.qql new file mode 100644 index 0000000..a84348d --- /dev/null +++ b/resources/sample_v2.qql @@ -0,0 +1,124 @@ +-- QQL sample v2 +-- Compact end-to-end showcase for deterministic inserts, search, query-time +-- params, recommendation, and hybrid retrieval. + +SHOW COLLECTIONS + +-- Dense collection +CREATE COLLECTION qql_sample_v2 + +INSERT BULK INTO COLLECTION qql_sample_v2 VALUES [ + { + 'id': 1001, + 'text': 'STEMI requires emergent revascularization with primary PCI and dual antiplatelet therapy.', + 'department': 'cardiology', + 'topic': 'acute_coronary_syndrome', + 'year': 2024 + }, + { + 'id': 1002, + 'text': 'Heart failure with reduced ejection fraction is treated with ARNI, beta-blocker, MRA, and SGLT2 inhibitor therapy.', + 'department': 'cardiology', + 'topic': 'heart_failure', + 'year': 2024 + }, + { + 'id': 2001, + 'text': 'Acute ischemic stroke management includes rapid imaging, alteplase within the treatment window, and thrombectomy in selected patients.', + 'department': 'neurology', + 'topic': 'stroke', + 'year': 2024 + }, + { + 'id': 2002, + 'text': 'Transient ischemic attack requires urgent secondary prevention and vascular risk stratification.', + 'department': 'neurology', + 'topic': 'stroke', + 'year': 2023 + }, + { + 'id': 2003, + 'text': 'Secondary stroke prevention includes antiplatelet therapy, statins, blood pressure control, and carotid evaluation when indicated.', + 'department': 'neurology', + 'topic': 'stroke_prevention', + 'year': 2024 + }, + { + 'id': 3001, + 'text': 'COPD exacerbations are managed with bronchodilators, corticosteroids, and antibiotics when indicated.', + 'department': 'pulmonology', + 'topic': 'copd', + 'year': 2024 + } +] + +-- Basic dense search +SEARCH qql_sample_v2 SIMILAR TO 'stroke thrombolysis and thrombectomy' LIMIT 3 + +-- Dense search with filter +SEARCH qql_sample_v2 SIMILAR TO 'secondary stroke prevention' LIMIT 3 WHERE department = 'neurology' + +-- Query-time search params +SEARCH qql_sample_v2 SIMILAR TO 'acute coronary syndrome' LIMIT 3 EXACT +SEARCH qql_sample_v2 SIMILAR TO 'stroke prevention' LIMIT 3 WITH { hnsw_ef: 128 } + +-- Recommendation from known example IDs +RECOMMEND FROM qql_sample_v2 + POSITIVE IDS (2001) + LIMIT 3 + +RECOMMEND FROM qql_sample_v2 + POSITIVE IDS (2001, 2002) + NEGATIVE IDS (1001) + STRATEGY 'best_score' + LIMIT 3 + WHERE department = 'neurology' + +-- Recommend with pagination and score threshold +RECOMMEND FROM qql_sample_v2 + POSITIVE IDS (2001) + LIMIT 5 + OFFSET 2 + SCORE THRESHOLD 0.3 + +-- Recommend with exact KNN baseline +RECOMMEND FROM qql_sample_v2 + POSITIVE IDS (2001) + LIMIT 3 + WITH { exact: true } + +-- Recommend using sparse vector instead of dense +RECOMMEND FROM qql_sample_v2_hybrid + POSITIVE IDS (4001) + LIMIT 3 + USING 'sparse' + +-- Hybrid collection +CREATE COLLECTION qql_sample_v2_hybrid HYBRID + +INSERT BULK INTO COLLECTION qql_sample_v2_hybrid VALUES [ + { + 'id': 4001, + 'text': 'Transformer attention mechanisms improve long-context sequence modeling and retrieval quality.', + 'domain': 'ml', + 'year': 2024 + }, + { + 'id': 4002, + 'text': 'Sparse retrieval with BM25 remains strong for exact terminology and keyword-heavy document search.', + 'domain': 'ir', + 'year': 2023 + }, + { + 'id': 4003, + 'text': 'Hybrid retrieval combines dense semantic matching with sparse keyword search using reciprocal rank fusion.', + 'domain': 'ir', + 'year': 2024 + } +] USING HYBRID + +-- Hybrid and sparse-only search +SEARCH qql_sample_v2_hybrid SIMILAR TO 'keyword retrieval and bm25' LIMIT 3 USING HYBRID +SEARCH qql_sample_v2_hybrid SIMILAR TO 'keyword retrieval and bm25' LIMIT 3 USING SPARSE + +SHOW COLLECTIONS diff --git a/src/qql/ast_nodes.py b/src/qql/ast_nodes.py index fc95611..7bd64bf 100644 --- a/src/qql/ast_nodes.py +++ b/src/qql/ast_nodes.py @@ -169,6 +169,22 @@ class SearchStmt: rerank_model: str | None = None # cross-encoder model; None → CrossEncoderEmbedder.DEFAULT_MODEL with_clause: SearchWith | None = None + +@dataclass(frozen=True) +class RecommendStmt: + collection: str + positive_ids: tuple[str | int, ...] + negative_ids: tuple[str | int, ...] = () + limit: int = 10 + strategy: str | None = None + query_filter: FilterExpr | None = None + offset: int = 0 + score_threshold: float | None = None + with_clause: SearchWith | None = None + lookup_from: tuple[str, str | None] | None = None + using: str | None = None + + @dataclass(frozen=True) class DeleteStmt: collection: str @@ -183,5 +199,6 @@ class DeleteStmt: | DropCollectionStmt | ShowCollectionsStmt | SearchStmt + | RecommendStmt | DeleteStmt ) diff --git a/src/qql/cli.py b/src/qql/cli.py index 0f7f996..bfc169a 100644 --- a/src/qql/cli.py +++ b/src/qql/cli.py @@ -25,11 +25,13 @@ [yellow]INSERT INTO COLLECTION[/yellow] [yellow]VALUES[/yellow] {[yellow]'text'[/yellow]: '...', ...} Insert a point. 'text' is required and auto-vectorized. + Optional: include [yellow]'id'[/yellow] in VALUES as an integer or UUID Optional: [yellow]USING MODEL[/yellow] '' Optional: [yellow]USING HYBRID[/yellow] [DENSE MODEL ''] [SPARSE MODEL ''] [yellow]INSERT BULK INTO COLLECTION[/yellow] [yellow]VALUES[/yellow] [{[yellow]'text'[/yellow]: '...', ...}, ...] Batch insert multiple points in a single call. Each dict must contain 'text'. + Optional: each dict may include [yellow]'id'[/yellow] as an integer or UUID. Supports the same [yellow]USING[/yellow] clauses as INSERT. [yellow]CREATE COLLECTION[/yellow] [[yellow]HYBRID[/yellow]] @@ -53,6 +55,13 @@ Optional: [yellow]EXACT[/yellow] bypass HNSW and perform exact search Optional: [yellow]WITH[/yellow] { hnsw_ef: , exact: , acorn: } search parameters + [yellow]RECOMMEND FROM[/yellow] [yellow]POSITIVE IDS[/yellow] (, ...) + Find points similar to known examples. + Optional: [yellow]NEGATIVE IDS[/yellow] (, ...) + Optional: [yellow]STRATEGY[/yellow] 'average_vector|best_score|sum_scores' + Optional: [yellow]WHERE[/yellow] + Requires: [yellow]LIMIT[/yellow] + [yellow]DELETE FROM[/yellow] [yellow]WHERE id =[/yellow] '' Delete a point by its ID. @@ -66,8 +75,8 @@ The file can be re-imported with EXECUTE. Keyboard shortcuts: - ← → arrows move cursor within the current line - ↑ ↓ arrows scroll through command history + Left/Right arrows move cursor within the current line + Up/Down arrows scroll through command history Ctrl-A / Ctrl-E jump to beginning / end of line Ctrl-C cancel current input Ctrl-D exit shell @@ -215,7 +224,7 @@ def dump(collection: str, output: str) -> None: from .dumper import dump_collection console.print( - f"[bold cyan]Dumping:[/bold cyan] '{collection}' → {output}\n" + f"[bold cyan]Dumping:[/bold cyan] '{collection}' -> {output}\n" ) written, skipped = dump_collection(collection, output, client, console, err_console) @@ -315,7 +324,7 @@ def _launch_repl(cfg: QQLConfig) -> None: coll_name, out_path = parts[1], parts[2] from .dumper import dump_collection console.print( - f"[bold cyan]Dumping:[/bold cyan] '{coll_name}' → {out_path}\n" + f"[bold cyan]Dumping:[/bold cyan] '{coll_name}' -> {out_path}\n" ) written, skipped = dump_collection( coll_name, out_path, client, console, err_console @@ -348,7 +357,7 @@ def _run_and_print(executor: Executor, query: str) -> None: err_console.print(f"[bold red]Failed:[/bold red] {result.message}") return - console.print(f"[bold green]✓[/bold green] {result.message}") + console.print(f"[bold green]OK[/bold green] {result.message}") if result.data is None: return diff --git a/src/qql/dumper.py b/src/qql/dumper.py index e2b590b..3b63156 100644 --- a/src/qql/dumper.py +++ b/src/qql/dumper.py @@ -160,7 +160,9 @@ def dump_collection( if "text" not in payload: skipped += 1 continue - valid.append(payload) + dump_payload = dict(payload) + dump_payload["id"] = rec.id + valid.append(dump_payload) if valid: f.write( diff --git a/src/qql/executor.py b/src/qql/executor.py index 88166de..ee78a69 100644 --- a/src/qql/executor.py +++ b/src/qql/executor.py @@ -1,5 +1,6 @@ from __future__ import annotations +import time import uuid from dataclasses import dataclass from typing import Any @@ -13,8 +14,10 @@ Filter, Fusion, FusionQuery, + HasIdCondition, IsEmptyCondition, IsNullCondition, + LookupLocation, MatchAny, MatchExcept, MatchPhrase, @@ -26,6 +29,9 @@ PointStruct, Prefetch, Range, + RecommendInput, + RecommendQuery, + RecommendStrategy, SearchParams, SparseVector, SparseVectorParams, @@ -54,6 +60,7 @@ NotExpr, NotInExpr, OrExpr, + RecommendStmt, SearchStmt, SearchWith, ShowCollectionsStmt, @@ -62,6 +69,8 @@ from .embedder import CrossEncoderEmbedder, Embedder, SparseEmbedder _RERANK_FETCH_MULTIPLIER = 4 +_COLLECTION_VISIBILITY_TIMEOUT_SECONDS = 5.0 +_COLLECTION_VISIBILITY_POLL_SECONDS = 0.05 from .exceptions import QQLRuntimeError @@ -90,6 +99,8 @@ def execute(self, node: ASTNode) -> ExecutionResult: return self._execute_show(node) if isinstance(node, SearchStmt): return self._execute_search(node) + if isinstance(node, RecommendStmt): + return self._execute_recommend(node) if isinstance(node, DeleteStmt): return self._execute_delete(node) raise QQLRuntimeError(f"Unknown AST node type: {type(node)}") @@ -120,7 +131,7 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: # Auto-create hybrid collection if it doesn't exist yet if not self._client.collection_exists(node.collection): - self._client.create_collection( + self._create_collection_and_wait( collection_name=node.collection, vectors_config={ "dense": VectorParams( @@ -132,15 +143,16 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: }, ) - point_id = str(uuid.uuid4()) + point_id, payload = self._extract_point_id_and_payload(node.values) try: self._client.upsert( collection_name=node.collection, + wait=True, points=[ PointStruct( id=point_id, vector={"dense": dense_vector, "sparse": sparse_vector}, - payload=dict(node.values), + payload=payload, ) ], ) @@ -160,12 +172,12 @@ def _execute_insert(self, node: InsertStmt) -> ExecutionResult: self._ensure_collection(node.collection, len(vector)) - point_id = str(uuid.uuid4()) - payload = dict(node.values) + point_id, payload = self._extract_point_id_and_payload(node.values) try: self._client.upsert( collection_name=node.collection, + wait=True, points=[PointStruct(id=point_id, vector=vector, payload=payload)], ) except UnexpectedResponse as e: @@ -199,23 +211,23 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: points: list[PointStruct] = [] for vals in node.values_list: + point_id, payload = self._extract_point_id_and_payload(vals) dense_vector = dense_embedder.embed(vals["text"]) sparse_obj = sparse_embedder.embed(vals["text"]) sparse_vector = SparseVector( indices=sparse_obj["indices"], values=sparse_obj["values"] ) - point_id = str(uuid.uuid4()) points.append( PointStruct( id=point_id, vector={"dense": dense_vector, "sparse": sparse_vector}, - payload=dict(vals), + payload=payload, ) ) if not self._client.collection_exists(node.collection): first_dense = dense_embedder.embed(node.values_list[0]["text"]) - self._client.create_collection( + self._create_collection_and_wait( collection_name=node.collection, vectors_config={ "dense": VectorParams(size=len(first_dense), distance=Distance.COSINE) @@ -226,7 +238,11 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: ) try: - self._client.upsert(collection_name=node.collection, points=points) + self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) except UnexpectedResponse as e: raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e @@ -242,15 +258,19 @@ def _execute_insert_bulk(self, node: InsertBulkStmt) -> ExecutionResult: points = [] for vals in node.values_list: vector = embedder.embed(vals["text"]) - point_id = str(uuid.uuid4()) + point_id, payload = self._extract_point_id_and_payload(vals) points.append( - PointStruct(id=point_id, vector=vector, payload=dict(vals)) + PointStruct(id=point_id, vector=vector, payload=payload) ) self._ensure_collection(node.collection, len(points[0].vector)) try: - self._client.upsert(collection_name=node.collection, points=points) + self._client.upsert( + collection_name=node.collection, + wait=True, + points=points, + ) except UnexpectedResponse as e: raise QQLRuntimeError(f"Qdrant error during INSERT BULK: {e}") from e @@ -272,7 +292,7 @@ def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: if node.hybrid: embedder = Embedder(dense_model_name) dims = embedder.dimensions - self._client.create_collection( + self._create_collection_and_wait( collection_name=node.collection, vectors_config={ "dense": VectorParams(size=dims, distance=Distance.COSINE) @@ -292,7 +312,7 @@ def _execute_create(self, node: CreateCollectionStmt) -> ExecutionResult: # ── Standard dense-only collection ───────────────────────────────── embedder = Embedder(dense_model_name) dims = embedder.dimensions - self._client.create_collection( + self._create_collection_and_wait( collection_name=node.collection, vectors_config=VectorParams(size=dims, distance=Distance.COSINE), ) @@ -470,6 +490,61 @@ def _execute_search(self, node: SearchStmt) -> ExecutionResult: data=results, ) + def _execute_recommend(self, node: RecommendStmt) -> ExecutionResult: + if not self._client.collection_exists(node.collection): + raise QQLRuntimeError(f"Collection '{node.collection}' does not exist") + + qdrant_filter: Filter | None = None + if node.query_filter is not None: + qdrant_filter = self._wrap_as_filter( + self._build_qdrant_filter(node.query_filter) + ) + qdrant_filter = self._exclude_ids_from_filter( + qdrant_filter, + [*node.positive_ids, *node.negative_ids], + ) + + recommend_input = RecommendInput( + positive=list(node.positive_ids), + negative=list(node.negative_ids) or None, + strategy=self._parse_recommend_strategy(node.strategy), + ) + + search_params = self._build_search_params(node.with_clause) + + lookup_from: LookupLocation | None = None + if node.lookup_from is not None: + lookup_from = LookupLocation( + collection=node.lookup_from[0], + vector=node.lookup_from[1], + ) + + try: + response = self._client.query_points( + collection_name=node.collection, + query=RecommendQuery(recommend=recommend_input), + limit=node.limit, + offset=node.offset or None, + query_filter=qdrant_filter, + search_params=search_params, + score_threshold=node.score_threshold, + using=node.using, + lookup_from=lookup_from, + ) + except UnexpectedResponse as e: + raise QQLRuntimeError(f"Qdrant error during RECOMMEND: {e}") from e + + results = [ + {"id": str(h.id), "score": round(h.score, 4), "payload": h.payload} + for h in response.points + ] + + return ExecutionResult( + success=True, + message=f"Found {len(results)} recommendation(s)", + data=results, + ) + def _build_search_params(self, with_clause: SearchWith | None) -> SearchParams | None: if with_clause is None: return None @@ -479,6 +554,68 @@ def _build_search_params(self, with_clause: SearchWith | None) -> SearchParams | acorn=AcornSearchParams(enable=True) if with_clause.acorn else None, ) + def _parse_recommend_strategy( + self, strategy: str | None + ) -> RecommendStrategy | None: + if strategy is None: + return None + try: + return RecommendStrategy(strategy) + except ValueError as e: + raise QQLRuntimeError( + "Unknown recommend strategy " + f"'{strategy}'. Expected one of: average_vector, best_score, sum_scores" + ) from e + + def _exclude_ids_from_filter( + self, + query_filter: Filter | None, + point_ids: list[str | int], + ) -> Filter | None: + if not point_ids: + return query_filter + + exclude_condition = HasIdCondition(has_id=point_ids) + if query_filter is None: + return Filter(must_not=[exclude_condition]) + + return Filter( + must=list(query_filter.must or []), + should=list(query_filter.should or []), + must_not=[*(query_filter.must_not or []), exclude_condition], + min_should=query_filter.min_should, + ) + + def _extract_point_id_and_payload( + self, values: dict[str, Any] + ) -> tuple[str | int, dict[str, Any]]: + payload = dict(values) + if "id" not in payload: + return str(uuid.uuid4()), payload + + point_id = payload.pop("id") + if isinstance(point_id, bool): + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + if isinstance(point_id, int): + if point_id < 0: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + return point_id, payload + if isinstance(point_id, str): + try: + uuid.UUID(point_id) + except ValueError as e: + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) from e + return point_id, payload + raise QQLRuntimeError( + "INSERT id must be an unsigned integer or UUID string when provided" + ) + def _get_dense_vector_name(self, collection_name: str) -> str | None: """Return the dense vector name for named-vector collections. @@ -516,6 +653,7 @@ def _execute_delete(self, node: DeleteStmt) -> ExecutionResult: try: self._client.delete( collection_name=node.collection, + wait=True, points_selector=PointIdsList(points=[node.point_id]), ) except UnexpectedResponse as e: @@ -651,7 +789,21 @@ def _ensure_collection(self, name: str, vector_size: int) -> None: f"Specify a compatible model with USING MODEL ''." ) else: - self._client.create_collection( + self._create_collection_and_wait( collection_name=name, vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE), ) + + def _create_collection_and_wait(self, **kwargs: Any) -> None: + collection_name = kwargs["collection_name"] + self._client.create_collection(**kwargs) + + deadline = time.monotonic() + _COLLECTION_VISIBILITY_TIMEOUT_SECONDS + while time.monotonic() < deadline: + if self._client.collection_exists(collection_name): + return + time.sleep(_COLLECTION_VISIBILITY_POLL_SECONDS) + + raise QQLRuntimeError( + f"Collection '{collection_name}' was created but did not become visible in time" + ) diff --git a/src/qql/lexer.py b/src/qql/lexer.py index 90b705d..ae49247 100644 --- a/src/qql/lexer.py +++ b/src/qql/lexer.py @@ -25,9 +25,19 @@ class TokenKind(Enum): SHOW = auto() COLLECTIONS = auto() SEARCH = auto() + RECOMMEND = auto() + POSITIVE = auto() + NEGATIVE = auto() + IDS = auto() + STRATEGY = auto() SIMILAR = auto() TO = auto() LIMIT = auto() + OFFSET = auto() + SCORE = auto() + THRESHOLD = auto() + LOOKUP = auto() + VECTOR = auto() DELETE = auto() FROM = auto() WHERE = auto() @@ -90,9 +100,19 @@ class TokenKind(Enum): "SHOW": TokenKind.SHOW, "COLLECTIONS": TokenKind.COLLECTIONS, "SEARCH": TokenKind.SEARCH, + "RECOMMEND": TokenKind.RECOMMEND, + "POSITIVE": TokenKind.POSITIVE, + "NEGATIVE": TokenKind.NEGATIVE, + "IDS": TokenKind.IDS, + "STRATEGY": TokenKind.STRATEGY, "SIMILAR": TokenKind.SIMILAR, "TO": TokenKind.TO, "LIMIT": TokenKind.LIMIT, + "OFFSET": TokenKind.OFFSET, + "SCORE": TokenKind.SCORE, + "THRESHOLD": TokenKind.THRESHOLD, + "LOOKUP": TokenKind.LOOKUP, + "VECTOR": TokenKind.VECTOR, "DELETE": TokenKind.DELETE, "FROM": TokenKind.FROM, "WHERE": TokenKind.WHERE, diff --git a/src/qql/parser.py b/src/qql/parser.py index f466a76..54ac119 100644 --- a/src/qql/parser.py +++ b/src/qql/parser.py @@ -22,6 +22,7 @@ NotExpr, NotInExpr, OrExpr, + RecommendStmt, SearchStmt, SearchWith, ShowCollectionsStmt, @@ -59,6 +60,8 @@ def parse(self) -> ASTNode: node = self._parse_show() elif tok.kind == TokenKind.SEARCH: node = self._parse_search() + elif tok.kind == TokenKind.RECOMMEND: + node = self._parse_recommend() elif tok.kind == TokenKind.DELETE: node = self._parse_delete() else: @@ -275,6 +278,79 @@ def _parse_search(self) -> SearchStmt: with_clause=with_clause, ) + def _parse_recommend(self) -> RecommendStmt: + self._expect(TokenKind.RECOMMEND) + self._expect(TokenKind.FROM) + collection = self._parse_identifier() + self._expect(TokenKind.POSITIVE) + self._expect(TokenKind.IDS) + positive_ids = self._parse_point_id_list() + + negative_ids: tuple[str | int, ...] = () + if self._peek().kind == TokenKind.NEGATIVE: + self._advance() + self._expect(TokenKind.IDS) + negative_ids = self._parse_point_id_list() + + strategy: str | None = None + if self._peek().kind == TokenKind.STRATEGY: + self._advance() + strategy = self._expect(TokenKind.STRING).value + + lookup_from: tuple[str, str | None] | None = None + if self._peek().kind == TokenKind.LOOKUP: + self._advance() + self._expect(TokenKind.FROM) + lookup_collection = self._parse_identifier() + lookup_vector: str | None = None + if self._peek().kind == TokenKind.VECTOR: + self._advance() + lookup_vector = self._expect(TokenKind.STRING).value + lookup_from = (lookup_collection, lookup_vector) + + using: str | None = None + if self._peek().kind == TokenKind.USING: + self._advance() + using = self._expect(TokenKind.STRING).value + + self._expect(TokenKind.LIMIT) + limit = int(self._expect(TokenKind.INTEGER).value) + + offset: int = 0 + if self._peek().kind == TokenKind.OFFSET: + self._advance() + offset = int(self._expect(TokenKind.INTEGER).value) + + score_threshold: float | None = None + if self._peek().kind == TokenKind.SCORE: + self._advance() + self._expect(TokenKind.THRESHOLD) + score_threshold = float(self._expect(TokenKind.FLOAT).value) + + query_filter: FilterExpr | None = None + if self._peek().kind == TokenKind.WHERE: + self._advance() + query_filter = self._parse_filter_expr() + + with_clause: SearchWith | None = None + if self._peek().kind == TokenKind.WITH: + self._advance() + with_clause = self._parse_with_clause() + + return RecommendStmt( + collection=collection, + positive_ids=positive_ids, + negative_ids=negative_ids, + limit=limit, + strategy=strategy, + query_filter=query_filter, + offset=offset, + score_threshold=score_threshold, + with_clause=with_clause, + lookup_from=lookup_from, + using=using, + ) + def _parse_delete(self) -> DeleteStmt: self._expect(TokenKind.DELETE) self._expect(TokenKind.FROM) @@ -417,12 +493,29 @@ def _parse_predicate(self) -> FilterExpr: def _parse_field_path(self) -> str: """Dot-notation paths are already single IDENTIFIER tokens from the lexer.""" tok = self._peek() - if tok.kind != TokenKind.IDENTIFIER: - raise QQLSyntaxError( - f"Expected a field name, got '{tok.value}'", tok.pos - ) - self._advance() - return tok.value + if tok.kind == TokenKind.IDENTIFIER: + self._advance() + return tok.value + # Allow bare keywords to serve as field names (e.g. score, limit), + # but not filter operator keywords or literal tokens. + if tok.kind not in { + TokenKind.AND, TokenKind.OR, TokenKind.NOT, + TokenKind.IN, TokenKind.BETWEEN, TokenKind.IS, + TokenKind.NULL, TokenKind.EMPTY, TokenKind.MATCH, + TokenKind.ANY, TokenKind.PHRASE, + TokenKind.STRING, TokenKind.INTEGER, TokenKind.FLOAT, + TokenKind.LPAREN, TokenKind.RPAREN, + TokenKind.LBRACE, TokenKind.RBRACE, + TokenKind.LBRACKET, TokenKind.RBRACKET, + TokenKind.COMMA, TokenKind.COLON, TokenKind.EQUALS, + TokenKind.NOT_EQUALS, TokenKind.GT, TokenKind.GTE, + TokenKind.LT, TokenKind.LTE, TokenKind.EOF, + }: + self._advance() + return tok.value + raise QQLSyntaxError( + f"Expected a field name, got '{tok.value}'", tok.pos + ) def _parse_literal(self) -> str | int | float: """STRING | INTEGER | FLOAT""" @@ -472,6 +565,33 @@ def _parse_literal_list(self) -> list[str | int | float]: self._expect(TokenKind.RPAREN) return items + def _parse_point_id_list(self) -> tuple[str | int, ...]: + self._expect(TokenKind.LPAREN) + items: list[str | int] = [] + if self._peek().kind == TokenKind.RPAREN: + raise QQLSyntaxError("Expected at least one point id", self._peek().pos) + while True: + tok = self._peek() + if tok.kind == TokenKind.STRING: + self._advance() + items.append(tok.value) + elif tok.kind == TokenKind.INTEGER: + self._advance() + items.append(int(tok.value)) + else: + raise QQLSyntaxError( + f"Expected string or integer point id, got '{tok.value}'", + tok.pos, + ) + if self._peek().kind == TokenKind.COMMA: + self._advance() + if self._peek().kind == TokenKind.RPAREN: + break + else: + break + self._expect(TokenKind.RPAREN) + return tuple(items) + # ── Dict / value parsers (for INSERT VALUES) ────────────────────────── def _parse_identifier(self) -> str: diff --git a/src/qql/script.py b/src/qql/script.py index 9b138a1..534e749 100644 --- a/src/qql/script.py +++ b/src/qql/script.py @@ -25,6 +25,7 @@ TokenKind.DROP, TokenKind.SHOW, TokenKind.SEARCH, + TokenKind.RECOMMEND, TokenKind.DELETE, } @@ -53,7 +54,8 @@ def split_statements(tokens: list[Token]) -> list[list[Token]]: """Split a flat token list into per-statement chunks. A new chunk begins whenever a statement-starter keyword (INSERT, CREATE, - DROP, SHOW, SEARCH, DELETE) is encountered at brace/bracket/paren depth 0. + DROP, SHOW, SEARCH, RECOMMEND, DELETE) is encountered at + brace/bracket/paren depth 0. The EOF sentinel is consumed and never included in any chunk. """ chunks: list[list[Token]] = [] @@ -138,19 +140,19 @@ def run_script( node = Parser(chunk + [eof_tok]).parse() result = executor.execute(node) except QQLError as e: - err_console.print(f" [bold red]✗[/bold red] {e}") + err_console.print(f" [bold red]x[/bold red] {e}") failed += 1 if stop_on_error: break continue except Exception as e: - err_console.print(f" [bold red]✗ Unexpected error:[/bold red] {e}") + err_console.print(f" [bold red]x Unexpected error:[/bold red] {e}") failed += 1 if stop_on_error: break continue - console.print(f" [bold green]✓[/bold green] {result.message}") + console.print(f" [bold green]OK[/bold green] {result.message}") succeeded += 1 return succeeded, failed diff --git a/tests/test_dumper.py b/tests/test_dumper.py index feb556a..26bffd4 100644 --- a/tests/test_dumper.py +++ b/tests/test_dumper.py @@ -20,10 +20,11 @@ def null_console() -> Console: return Console(quiet=True) -def _make_record(mocker, payload: dict): +def _make_record(mocker, payload: dict, point_id="rec-1"): """Create a mock Qdrant ScoredPoint / Record with the given payload.""" rec = mocker.MagicMock() rec.payload = payload + rec.id = point_id return rec @@ -52,7 +53,7 @@ def _make_client(mocker, *, exists=True, hybrid=False, points=None, total=None): client.count.return_value = cnt # scroll — single-batch by default - records = [mocker.MagicMock(payload=p) for p in points] + records = [_make_record(mocker, p, f"id-{i}") for i, p in enumerate(points, 1)] client.scroll.return_value = (records, None) return client @@ -187,6 +188,13 @@ def test_payload_values_serialized_correctly(self, tmp_path, mocker): assert "'active': true" in content assert "'score':" in content + def test_dump_preserves_point_id_in_insert_values(self, tmp_path, mocker): + out = str(tmp_path / "dump.qql") + client = _make_client(mocker, points=[{"text": "hello"}]) + dump_collection("col", out, client, null_console(), null_console()) + content = (tmp_path / "dump.qql").read_text() + assert "'id': 'id-1'" in content + def test_batches_multiple_scroll_pages(self, tmp_path, mocker): """When scroll returns two pages, two INSERT BULK blocks should be written.""" out = str(tmp_path / "dump.qql") @@ -197,8 +205,8 @@ def test_batches_multiple_scroll_pages(self, tmp_path, mocker): cnt.count = _DUMP_BATCH_SIZE + 1 client.count.return_value = cnt - batch1 = [mocker.MagicMock(payload={"text": f"doc {i}"}) for i in range(_DUMP_BATCH_SIZE)] - batch2 = [mocker.MagicMock(payload={"text": "last doc"})] + batch1 = [_make_record(mocker, {"text": f"doc {i}"}, f"id-{i}") for i in range(_DUMP_BATCH_SIZE)] + batch2 = [_make_record(mocker, {"text": "last doc"}, "id-last")] # First scroll call returns batch1 with a non-None offset; second returns batch2 + None client.scroll.side_effect = [ (batch1, "some_offset"), diff --git a/tests/test_executor.py b/tests/test_executor.py index d930815..5ec3ebd 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -6,6 +6,7 @@ DropCollectionStmt, InsertBulkStmt, InsertStmt, + RecommendStmt, SearchStmt, SearchWith, ShowCollectionsStmt, @@ -28,6 +29,16 @@ def cfg(): def mock_client(mocker): client = mocker.MagicMock() client.collection_exists.return_value = False + state = {"exists": False} + + def collection_exists(_name): + return state["exists"] or bool(client.collection_exists.return_value) + + def create_collection(**_kwargs): + state["exists"] = True + + client.collection_exists.side_effect = collection_exists + client.create_collection.side_effect = create_collection return client @@ -72,6 +83,46 @@ def test_insert_result_contains_point_id(self, executor, mock_client): assert result.data["id"] is not None assert len(result.data["id"]) == 36 # UUID format + def test_insert_uses_explicit_uuid_id_when_provided(self, executor, mock_client): + node = InsertStmt( + collection="notes", + values={"id": "550e8400-e29b-41d4-a716-446655440000", "text": "hello"}, + model=None, + ) + result = executor.execute(node) + point = mock_client.upsert.call_args.kwargs["points"][0] + assert point.id == "550e8400-e29b-41d4-a716-446655440000" + assert "id" not in point.payload + assert result.data["id"] == "550e8400-e29b-41d4-a716-446655440000" + + def test_insert_uses_explicit_integer_id_when_provided(self, executor, mock_client): + node = InsertStmt( + collection="notes", + values={"id": 42, "text": "hello"}, + model=None, + ) + executor.execute(node) + point = mock_client.upsert.call_args.kwargs["points"][0] + assert point.id == 42 + + def test_insert_rejects_non_scalar_id(self, executor): + node = InsertStmt( + collection="notes", + values={"id": {"bad": "id"}, "text": "hello"}, + model=None, + ) + with pytest.raises(QQLRuntimeError, match="unsigned integer or UUID string"): + executor.execute(node) + + def test_insert_rejects_non_uuid_string_id(self, executor): + node = InsertStmt( + collection="notes", + values={"id": "note-1", "text": "hello"}, + model=None, + ) + with pytest.raises(QQLRuntimeError, match="unsigned integer or UUID string"): + executor.execute(node) + def test_insert_stores_text_in_payload(self, executor, mock_client): node = InsertStmt(collection="notes", values={"text": "hello"}, model=None) executor.execute(node) @@ -157,6 +208,20 @@ def test_bulk_insert_result_message_contains_count(self, executor, mock_client): assert "2" in result.message assert "points" in result.message + def test_bulk_insert_preserves_explicit_ids(self, executor, mock_client): + node = InsertBulkStmt( + collection="col", + values_list=( + {"id": "550e8400-e29b-41d4-a716-446655440001", "text": "a"}, + {"id": 2, "text": "b"}, + ), + model=None, + ) + executor.execute(node) + points = mock_client.upsert.call_args.kwargs["points"] + assert [point.id for point in points] == ["550e8400-e29b-41d4-a716-446655440001", 2] + assert all("id" not in point.payload for point in points) + def test_single_insert_unaffected_by_bulk_dispatch(self, executor, mock_client): """Ensure single INSERT still routes correctly after bulk dispatch added.""" node = InsertStmt(collection="notes", values={"text": "hello"}, model=None) @@ -351,6 +416,212 @@ def test_dense_search_against_hybrid_collection_uses_dense_vector_name( assert mock_client.query_points.call_args.kwargs["using"] == "dense" +class TestRecommend: + def test_recommend_calls_qdrant_query_points(self, executor, mock_client, mocker): + from qdrant_client.models import RecommendQuery + + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt(collection="notes", positive_ids=("a",), limit=5) + result = executor.execute(node) + + mock_client.query_points.assert_called_once() + assert isinstance(mock_client.query_points.call_args.kwargs["query"], RecommendQuery) + assert result.success is True + assert "recommendation" in result.message + + def test_recommend_excludes_seed_ids_from_results_filter( + self, executor, mock_client, mocker + ): + from qdrant_client.models import Filter, HasIdCondition + + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", + positive_ids=("a", 2), + negative_ids=("x",), + limit=5, + ) + executor.execute(node) + + query_filter = mock_client.query_points.call_args.kwargs["query_filter"] + assert isinstance(query_filter, Filter) + assert isinstance(query_filter.must_not[0], HasIdCondition) + assert query_filter.must_not[0].has_id == ["a", 2, "x"] + + def test_recommend_merges_where_filter_with_seed_exclusion( + self, executor, mock_client, mocker + ): + from qdrant_client.models import Filter + from qql.ast_nodes import CompareExpr + + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", + positive_ids=("a",), + limit=5, + query_filter=CompareExpr(field="year", op=">", value=2020), + ) + executor.execute(node) + + query_filter = mock_client.query_points.call_args.kwargs["query_filter"] + assert isinstance(query_filter, Filter) + assert query_filter.must is not None + assert query_filter.must_not is not None + + def test_recommend_forwards_strategy(self, executor, mock_client, mocker): + from qdrant_client.models import RecommendStrategy + + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", + positive_ids=("a",), + limit=5, + strategy="best_score", + ) + executor.execute(node) + + recommend = mock_client.query_points.call_args.kwargs["query"].recommend + assert recommend.strategy == RecommendStrategy.BEST_SCORE + + def test_recommend_invalid_strategy_raises(self, executor, mock_client): + mock_client.collection_exists.return_value = True + node = RecommendStmt( + collection="notes", + positive_ids=("a",), + limit=5, + strategy="not-a-strategy", + ) + with pytest.raises(QQLRuntimeError, match="Unknown recommend strategy"): + executor.execute(node) + + def test_recommend_nonexistent_collection_raises(self, executor, mock_client): + mock_client.collection_exists.return_value = False + node = RecommendStmt(collection="ghost", positive_ids=("a",), limit=5) + with pytest.raises(QQLRuntimeError, match="does not exist"): + executor.execute(node) + + def test_recommend_forwards_offset(self, executor, mock_client, mocker): + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", positive_ids=("a",), limit=5, offset=10 + ) + executor.execute(node) + assert mock_client.query_points.call_args.kwargs["offset"] == 10 + + def test_recommend_forwards_score_threshold(self, executor, mock_client, mocker): + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", positive_ids=("a",), limit=5, score_threshold=0.5 + ) + executor.execute(node) + assert mock_client.query_points.call_args.kwargs["score_threshold"] == pytest.approx(0.5) + + def test_recommend_forwards_using(self, executor, mock_client, mocker): + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", positive_ids=("a",), limit=5, using="sparse" + ) + executor.execute(node) + assert mock_client.query_points.call_args.kwargs["using"] == "sparse" + + def test_recommend_forwards_lookup_from(self, executor, mock_client, mocker): + from qdrant_client.models import LookupLocation + + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", + positive_ids=("a",), + limit=5, + lookup_from=("source", "dense"), + ) + executor.execute(node) + lookup = mock_client.query_points.call_args.kwargs["lookup_from"] + assert isinstance(lookup, LookupLocation) + assert lookup.collection == "source" + assert lookup.vector == "dense" + + def test_recommend_forwards_lookup_from_without_vector(self, executor, mock_client, mocker): + from qdrant_client.models import LookupLocation + + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", + positive_ids=("a",), + limit=5, + lookup_from=("source", None), + ) + executor.execute(node) + lookup = mock_client.query_points.call_args.kwargs["lookup_from"] + assert isinstance(lookup, LookupLocation) + assert lookup.collection == "source" + assert lookup.vector is None + + def test_recommend_forwards_search_params(self, executor, mock_client, mocker): + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", + positive_ids=("a",), + limit=5, + with_clause=SearchWith(exact=True, hnsw_ef=128), + ) + executor.execute(node) + search_params = mock_client.query_points.call_args.kwargs["search_params"] + assert search_params.exact is True + assert search_params.hnsw_ef == 128 + + def test_recommend_offset_zero_passes_none(self, executor, mock_client, mocker): + mock_client.collection_exists.return_value = True + mock_response = mocker.MagicMock() + mock_response.points = [] + mock_client.query_points.return_value = mock_response + + node = RecommendStmt( + collection="notes", positive_ids=("a",), limit=5, offset=0 + ) + executor.execute(node) + assert mock_client.query_points.call_args.kwargs["offset"] is None + + class TestDelete: def test_delete_calls_qdrant_delete(self, executor, mock_client): mock_client.collection_exists.return_value = True diff --git a/tests/test_parser.py b/tests/test_parser.py index a0dab81..1d5c22e 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -20,6 +20,7 @@ NotExpr, NotInExpr, OrExpr, + RecommendStmt, SearchStmt, SearchWith, ShowCollectionsStmt, @@ -205,6 +206,112 @@ def test_delete_by_integer_id(self): assert node.point_id == 99 +class TestRecommend: + def test_recommend_with_positive_ids(self): + node = parse("RECOMMEND FROM notes POSITIVE IDS ('a', 'b') LIMIT 5") + assert isinstance(node, RecommendStmt) + assert node.collection == "notes" + assert node.positive_ids == ("a", "b") + assert node.negative_ids == () + assert node.limit == 5 + assert node.strategy is None + + def test_recommend_with_negative_ids_and_strategy(self): + node = parse( + "RECOMMEND FROM notes POSITIVE IDS ('a', 2) " + "NEGATIVE IDS ('x') STRATEGY 'best_score' LIMIT 7" + ) + assert node.positive_ids == ("a", 2) + assert node.negative_ids == ("x",) + assert node.strategy == "best_score" + assert node.limit == 7 + + def test_recommend_with_where_filter(self): + node = parse( + "RECOMMEND FROM notes POSITIVE IDS ('a') LIMIT 5 WHERE year > 2020" + ) + assert isinstance(node.query_filter, CompareExpr) + assert node.query_filter.field == "year" + + def test_recommend_requires_non_empty_positive_ids(self): + with pytest.raises(QQLSyntaxError): + parse("RECOMMEND FROM notes POSITIVE IDS () LIMIT 5") + + def test_recommend_with_offset(self): + node = parse("RECOMMEND FROM notes POSITIVE IDS ('a') LIMIT 10 OFFSET 5") + assert node.offset == 5 + + def test_recommend_with_score_threshold(self): + node = parse( + "RECOMMEND FROM notes POSITIVE IDS ('a') LIMIT 10 SCORE THRESHOLD 0.5" + ) + assert node.score_threshold == pytest.approx(0.5) + + def test_recommend_with_clause(self): + node = parse( + "RECOMMEND FROM notes POSITIVE IDS ('a') LIMIT 10 WITH { exact: true }" + ) + assert node.with_clause is not None + assert node.with_clause.exact is True + + def test_recommend_with_clause_hnsw_ef(self): + node = parse( + "RECOMMEND FROM notes POSITIVE IDS ('a') LIMIT 10 WITH { hnsw_ef: 128 }" + ) + assert node.with_clause is not None + assert node.with_clause.hnsw_ef == 128 + + def test_recommend_lookup_from(self): + node = parse( + "RECOMMEND FROM target_collection POSITIVE IDS ('a') " + "LOOKUP FROM source_collection LIMIT 5" + ) + assert node.lookup_from == ("source_collection", None) + + def test_recommend_lookup_from_with_vector(self): + node = parse( + "RECOMMEND FROM target_collection POSITIVE IDS ('a') " + "LOOKUP FROM source_collection VECTOR 'dense' LIMIT 5" + ) + assert node.lookup_from == ("source_collection", "dense") + + def test_recommend_using(self): + node = parse( + "RECOMMEND FROM docs POSITIVE IDS ('a') USING 'sparse' LIMIT 5" + ) + assert node.using == "sparse" + + def test_recommend_lookup_from_and_using(self): + node = parse( + "RECOMMEND FROM target_collection POSITIVE IDS ('a') " + "LOOKUP FROM source_collection VECTOR 'dense' USING 'sparse' LIMIT 5" + ) + assert node.lookup_from == ("source_collection", "dense") + assert node.using == "sparse" + + def test_recommend_full_clause_order(self): + node = parse( + "RECOMMEND FROM docs POSITIVE IDS ('a', 'b') " + "NEGATIVE IDS ('x') STRATEGY 'best_score' " + "LOOKUP FROM src VECTOR 'dense' USING 'sparse' " + "LIMIT 10 OFFSET 5 SCORE THRESHOLD 0.5 " + "WHERE year > 2020 WITH { exact: true, hnsw_ef: 128 }" + ) + assert node.collection == "docs" + assert node.positive_ids == ("a", "b") + assert node.negative_ids == ("x",) + assert node.strategy == "best_score" + assert node.lookup_from == ("src", "dense") + assert node.using == "sparse" + assert node.limit == 10 + assert node.offset == 5 + assert node.score_threshold == pytest.approx(0.5) + assert isinstance(node.query_filter, CompareExpr) + assert node.with_clause is not None + assert node.with_clause.exact is True + assert node.with_clause.hnsw_ef == 128 + + class TestErrors: def test_unknown_keyword(self): with pytest.raises(QQLSyntaxError): diff --git a/tests/test_script.py b/tests/test_script.py index 9b90330..2aefe4e 100644 --- a/tests/test_script.py +++ b/tests/test_script.py @@ -99,6 +99,18 @@ def test_first_chunk_starts_with_create(self): assert chunks[0][0].kind == TokenKind.CREATE assert chunks[1][0].kind == TokenKind.DROP + def test_recommend_starts_new_top_level_statement(self): + from qql.lexer import TokenKind + + tokens = tokenize( + "SEARCH x SIMILAR TO 'stroke' LIMIT 5\n" + "RECOMMEND FROM x POSITIVE IDS ('id-1') LIMIT 3\n" + "SHOW COLLECTIONS" + ) + chunks = split_statements(tokens) + assert len(chunks) == 3 + assert chunks[1][0].kind == TokenKind.RECOMMEND + # ── run_script ──────────────────────────────────────────────────────────────── @@ -175,6 +187,16 @@ def test_comments_are_stripped(self, script_file, mock_executor): assert ok == 1 assert fail == 0 + def test_recommend_statement_executes_from_script(self, script_file, mock_executor): + path = script_file( + "SEARCH x SIMILAR TO 'stroke' LIMIT 5\n" + "RECOMMEND FROM x POSITIVE IDS ('id-1') LIMIT 3\n" + ) + ok, fail = run_script(path, mock_executor, null_console(), null_console()) + assert ok == 2 + assert fail == 0 + assert mock_executor.execute.call_count == 2 + def test_nonexistent_file_returns_failure(self, mock_executor): ok, fail = run_script( "/no/such/file.qql", mock_executor, null_console(), null_console()