From d28381dabc150b3efc3f6a0c6b05bbe088714a45 Mon Sep 17 00:00:00 2001 From: Nitin Kanukolanu Date: Wed, 13 May 2026 15:48:45 -0400 Subject: [PATCH] feat(query): add shared query base classes --- src/indexes/search-index.ts | 13 +- src/query/base.ts | 270 +++++++++++++++++++++++- src/query/count.ts | 21 +- src/query/filter-query.ts | 25 +-- src/query/range.ts | 39 +--- src/query/text.ts | 25 +-- src/query/vector.ts | 41 +--- tests/unit/indexes/search-index.test.ts | 23 ++ tests/unit/query/base.test.ts | 171 +++++++++++++++ tests/unit/query/vector.test.ts | 37 ++++ 10 files changed, 547 insertions(+), 118 deletions(-) create mode 100644 tests/unit/query/base.test.ts diff --git a/src/indexes/search-index.ts b/src/indexes/search-index.ts index b6ab935..60ec87a 100644 --- a/src/indexes/search-index.ts +++ b/src/indexes/search-index.ts @@ -602,7 +602,18 @@ export class SearchIndex { searchOptions.LIMIT = { from: offset, size: limit }; } - // Add sorting if specified + // Add sorting if specified on the query. FT.SEARCH accepts one + // SORTBY clause, so use the first collected sort field. + if (query.sortFields.length > 0) { + const [sortField] = query.sortFields; + searchOptions.SORTBY = { + BY: sortField.field, + DIRECTION: sortField.direction, + }; + } + + // Add sorting if specified in execution options. These options + // preserve the historical API and override query-level sorting. if (options?.sortBy) { searchOptions.SORTBY = options.sortBy; if (options.sortOrder) { diff --git a/src/query/base.ts b/src/query/base.ts index 8130b22..d7557bd 100644 --- a/src/query/base.ts +++ b/src/query/base.ts @@ -4,6 +4,9 @@ */ import type { FilterExpression } from './filter.js'; +import { QueryValidationError, SchemaValidationError } from '../errors.js'; +import { normalizeVectorDataType } from '../redis/utils.js'; +import { VectorDataType } from '../schema/types.js'; /** * A filter clause supplied to a query — either a pre-built {@link FilterExpression} @@ -21,17 +24,61 @@ export function renderFilter(filter: FilterInput | undefined): string { } /** - * Base interface for all query types + * Sort direction for Redis search results. */ -export interface BaseQuery { +export type SortDirection = 'ASC' | 'DESC'; + +/** + * A normalized sort field specification. + */ +export interface SortField { + field: string; + direction: SortDirection; +} + +/** + * Common query configuration shared by all FT.SEARCH query types. + */ +export interface BaseQueryConfig { + /** Filter expression used as the base Redis query string. */ + filter?: FilterInput; + /** Fields to return in results */ returnFields?: string[]; + /** Offset for pagination */ + offset?: number; + /** Number of results to return */ limit?: number; +} - /** Offset for pagination */ - offset?: number; +/** + * Options for configuring returned fields. + */ +export interface ReturnFieldsOptions { + /** Fields that should not be decoded by higher-level result processors. */ + skipDecode?: string | string[]; +} + +/** + * Options for sorting query results. + */ +export interface SortByOptions { + /** Sort direction. Defaults to ASC. */ + direction?: SortDirection; +} + +/** + * Base abstract class for all FT.SEARCH query types. + */ +export abstract class BaseQuery { + private _filter?: FilterInput; + private _returnFields?: string[]; + private _skipDecodeFields?: string[]; + private _offset?: number; + private _limit?: number; + private readonly _sortFields: SortField[] = []; /** When true, ask Redis to return only document ids/counts (FT.SEARCH NOCONTENT). */ noContent?: boolean; @@ -39,11 +86,220 @@ export interface BaseQuery { /** Optional RediSearch scorer to apply when ranking text results. */ textScorer?: string; - /** Build the Redis query string */ - buildQuery(): string; + protected constructor(config: BaseQueryConfig = {}) { + if (config.filter !== undefined) { + this.setFilter(config.filter); + } + if (config.returnFields !== undefined) { + this.setReturnFields(config.returnFields); + } + if (config.offset !== undefined || config.limit !== undefined) { + this.setPagingFromConfig(config.offset, config.limit); + } + } + + /** Filter expression used by the query. */ + get filter(): FilterInput | undefined { + return this._filter; + } + + /** Fields to return in results. */ + get returnFields(): string[] | undefined { + return this._returnFields ? [...this._returnFields] : undefined; + } + + /** Fields that should not be decoded by higher-level result processors. */ + get skipDecodeFields(): string[] | undefined { + return this._skipDecodeFields ? [...this._skipDecodeFields] : undefined; + } + + /** Offset for pagination. */ + get offset(): number | undefined { + return this._offset; + } + + /** Number of results to return. */ + get limit(): number | undefined { + return this._limit; + } + + /** Sort fields collected for query execution. */ + get sortFields(): SortField[] { + return this._sortFields.map((field) => ({ ...field })); + } + + /** Set or clear the query filter. */ + setFilter(filter?: FilterInput | null): this { + if (filter === undefined || filter === null) { + this._filter = undefined; + return this; + } + + const rendered = renderFilter(filter); + if (rendered.trim() === '') { + throw new QueryValidationError('filter cannot be empty'); + } + this._filter = filter; + return this; + } + + /** Set or clear return fields. */ + setReturnFields(fields?: string[], options: ReturnFieldsOptions = {}): this { + if (fields === undefined) { + this._returnFields = undefined; + this._skipDecodeFields = undefined; + return this; + } + + this._returnFields = validateStringList(fields, 'returnFields'); + + if (options.skipDecode !== undefined) { + const skipDecode = Array.isArray(options.skipDecode) + ? options.skipDecode + : [options.skipDecode]; + this._skipDecodeFields = validateStringList(skipDecode, 'skipDecode'); + } else { + this._skipDecodeFields = undefined; + } + + return this; + } + + /** Set pagination values. */ + paging(offset: number, limit: number): this { + validateOffset(offset); + validateLimit(limit); + this._offset = offset; + this._limit = limit; + return this; + } + + /** Add a sort field. */ + sortBy(field: string, options: SortByOptions = {}): this { + validateNonEmptyString(field, 'sort field'); + const direction = options.direction ?? 'ASC'; + if (direction !== 'ASC' && direction !== 'DESC') { + throw new QueryValidationError('sort direction must be either ASC or DESC'); + } + this._sortFields.push({ field, direction }); + return this; + } + + /** Build the Redis query string. */ + abstract buildQuery(): string; /** Build the query parameters for Redis */ - buildParams(): Record; + buildParams(): Record { + return {}; + } + + private setPagingFromConfig(offset?: number, limit?: number): void { + if (offset !== undefined) { + validateOffset(offset); + this._offset = offset; + } + if (limit !== undefined) { + validateLimit(limit); + this._limit = limit; + } + } +} + +/** + * Common vector query configuration. + */ +export interface BaseVectorQueryConfig extends BaseQueryConfig { + /** Vector to search with. */ + vector: number[]; + + /** Name of the vector field in the index. */ + vectorField: string; + + /** Vector datatype to use when serializing the query vector. */ + datatype?: VectorDataType | string; + + /** Whether to normalize distances during result processing. */ + normalizeDistance?: boolean; +} + +/** + * Base abstract class for vector-backed query types. + */ +export abstract class BaseVectorQuery extends BaseQuery { + private readonly _vector: number[]; + private readonly _vectorField: string; + private readonly _datatype: VectorDataType; + private readonly _normalizeDistance: boolean; + + protected constructor(config: BaseVectorQueryConfig) { + super(config); + + if (!config.vector || config.vector.length === 0) { + throw new QueryValidationError('Vector cannot be empty'); + } + + if (!config.vectorField || config.vectorField.trim() === '') { + throw new QueryValidationError('vectorField is required'); + } + + this._vector = [...config.vector]; + this._vectorField = config.vectorField; + this._normalizeDistance = config.normalizeDistance ?? false; + + try { + this._datatype = normalizeVectorDataType(config.datatype); + } catch (error) { + if (error instanceof SchemaValidationError) { + throw new QueryValidationError(error.message); + } + throw error; + } + } + + /** Vector to search with. */ + get vector(): number[] { + return [...this._vector]; + } + + /** Name of the vector field in the index. */ + get vectorField(): string { + return this._vectorField; + } + + /** Vector datatype used when serializing the query vector. */ + get datatype(): VectorDataType { + return this._datatype; + } + + /** Whether to normalize distances during result processing. */ + get normalizeDistance(): boolean { + return this._normalizeDistance; + } +} + +function validateStringList(values: string[], label: string): string[] { + return values.map((value) => { + validateNonEmptyString(value, label); + return value; + }); +} + +function validateNonEmptyString(value: string, label: string): void { + if (typeof value !== 'string' || value.trim() === '') { + throw new QueryValidationError(`${label} cannot be empty`); + } +} + +function validateOffset(offset: number): void { + if (!Number.isInteger(offset) || offset < 0) { + throw new QueryValidationError('offset must be a non-negative integer'); + } +} + +function validateLimit(limit: number): void { + if (!Number.isInteger(limit) || limit <= 0) { + throw new QueryValidationError('limit must be a positive integer'); + } } /** diff --git a/src/query/count.ts b/src/query/count.ts index c60b44c..7b90bed 100644 --- a/src/query/count.ts +++ b/src/query/count.ts @@ -1,4 +1,4 @@ -import { renderFilter, type BaseQuery, type FilterInput } from './base.js'; +import { BaseQuery, renderFilter, type FilterInput } from './base.js'; /** * Configuration for {@link CountQuery}. @@ -22,21 +22,22 @@ export interface CountQueryConfig { * const total = (await index.search(new CountQuery({ filter: Tag('brand').eq('nike') }))).total; * ``` */ -export class CountQuery implements BaseQuery { - public readonly filter?: FilterInput; - public readonly offset = 0; - public readonly limit = 0; +export class CountQuery extends BaseQuery { public readonly noContent = true; constructor(config: CountQueryConfig = {}) { - this.filter = config.filter; + super({ filter: config.filter }); } - buildQuery(): string { - return renderFilter(this.filter); + get offset(): number { + return 0; + } + + get limit(): number { + return 0; } - buildParams(): Record { - return {}; + buildQuery(): string { + return renderFilter(this.filter); } } diff --git a/src/query/filter-query.ts b/src/query/filter-query.ts index 847c8fc..7e2da00 100644 --- a/src/query/filter-query.ts +++ b/src/query/filter-query.ts @@ -1,4 +1,4 @@ -import { renderFilter, type BaseQuery, type FilterInput } from './base.js'; +import { BaseQuery, renderFilter, type FilterInput } from './base.js'; /** * Configuration for {@link FilterQuery}. @@ -35,26 +35,21 @@ export interface FilterQueryConfig { * const results = await index.search(q); * ``` */ -export class FilterQuery implements BaseQuery { - public readonly filter?: FilterInput; - public readonly returnFields?: string[]; +export class FilterQuery extends BaseQuery { public readonly numResults: number; - public readonly offset?: number; - public readonly limit?: number; constructor(config: FilterQueryConfig = {}) { - this.filter = config.filter; - this.returnFields = config.returnFields; - this.numResults = config.numResults ?? 10; - this.offset = config.offset; - this.limit = config.limit ?? this.numResults; + const numResults = config.numResults ?? 10; + super({ + filter: config.filter, + returnFields: config.returnFields, + offset: config.offset, + limit: config.limit ?? numResults, + }); + this.numResults = numResults; } buildQuery(): string { return renderFilter(this.filter); } - - buildParams(): Record { - return {}; - } } diff --git a/src/query/range.ts b/src/query/range.ts index 14ceddd..6d2ab3b 100644 --- a/src/query/range.ts +++ b/src/query/range.ts @@ -1,7 +1,7 @@ -import { renderFilter, type BaseQuery, type FilterInput } from './base.js'; +import { BaseVectorQuery, renderFilter, type FilterInput } from './base.js'; import { VectorDataType, VectorDistanceMetric } from '../schema/types.js'; -import { QueryValidationError, SchemaValidationError } from '../errors.js'; -import { encodeVectorBuffer, normalizeVectorDataType } from '../redis/utils.js'; +import { QueryValidationError } from '../errors.js'; +import { encodeVectorBuffer } from '../redis/utils.js'; import type { HybridPolicy } from './vector.js'; /** @@ -76,29 +76,15 @@ export interface VectorRangeQueryConfig { * const results = await index.search(q); * ``` */ -export class VectorRangeQuery implements BaseQuery { - public readonly vector: number[]; - public readonly vectorField: string; +export class VectorRangeQuery extends BaseVectorQuery { public readonly distanceThreshold: number; - public readonly filter?: FilterInput; - public readonly returnFields?: string[]; public readonly distanceMetric: VectorDistanceMetric; - public readonly datatype: VectorDataType; - public readonly offset?: number; - public readonly limit?: number; public readonly scoreAlias: string; public readonly hybridPolicy?: HybridPolicy; public readonly batchSize?: number; - public readonly normalizeDistance: boolean; constructor(config: VectorRangeQueryConfig) { - if (!config.vector || config.vector.length === 0) { - throw new QueryValidationError('Vector cannot be empty'); - } - - if (!config.vectorField) { - throw new QueryValidationError('vectorField is required'); - } + super(config); if (config.distanceThreshold !== undefined && config.distanceThreshold < 0) { throw new QueryValidationError('distanceThreshold must be non-negative'); @@ -119,26 +105,11 @@ export class VectorRangeQuery implements BaseQuery { } } - this.vector = config.vector; - this.vectorField = config.vectorField; this.distanceThreshold = config.distanceThreshold ?? 0.2; - this.filter = config.filter; - this.returnFields = config.returnFields; this.distanceMetric = config.distanceMetric ?? VectorDistanceMetric.COSINE; - try { - this.datatype = normalizeVectorDataType(config.datatype); - } catch (error) { - if (error instanceof SchemaValidationError) { - throw new QueryValidationError(error.message); - } - throw error; - } - this.offset = config.offset; - this.limit = config.limit; this.scoreAlias = config.scoreAlias ?? 'vector_distance'; this.hybridPolicy = config.hybridPolicy; this.batchSize = config.batchSize; - this.normalizeDistance = config.normalizeDistance ?? false; } buildQuery(): string { diff --git a/src/query/text.ts b/src/query/text.ts index bece785..aae58de 100644 --- a/src/query/text.ts +++ b/src/query/text.ts @@ -1,4 +1,4 @@ -import { renderFilter, type BaseQuery, type FilterInput } from './base.js'; +import { BaseQuery, renderFilter, type FilterInput } from './base.js'; import { TokenEscaper } from '../utils/token-escaper.js'; import { QueryValidationError } from '../errors.js'; @@ -66,15 +66,11 @@ export interface TextQueryConfig { * const results = await index.search(q); * ``` */ -export class TextQuery implements BaseQuery { +export class TextQuery extends BaseQuery { public readonly text: string; public readonly textFieldName: string; public readonly textScorer: TextScorer; - public readonly filter?: FilterInput; - public readonly returnFields?: string[]; public readonly numResults: number; - public readonly offset?: number; - public readonly limit?: number; constructor(config: TextQueryConfig) { if (!config.text || config.text.trim() === '') { @@ -85,14 +81,17 @@ export class TextQuery implements BaseQuery { throw new QueryValidationError('textFieldName is required'); } + const numResults = config.numResults ?? 10; + super({ + filter: config.filter, + returnFields: config.returnFields, + offset: config.offset, + limit: config.limit ?? numResults, + }); this.text = config.text; this.textFieldName = config.textFieldName; this.textScorer = config.textScorer ?? 'BM25STD'; - this.filter = config.filter; - this.returnFields = config.returnFields; - this.numResults = config.numResults ?? 10; - this.offset = config.offset; - this.limit = config.limit ?? this.numResults; + this.numResults = numResults; } buildQuery(): string { @@ -113,8 +112,4 @@ export class TextQuery implements BaseQuery { } return `(${filterStr} ${textClause})`; } - - buildParams(): Record { - return {}; - } } diff --git a/src/query/vector.ts b/src/query/vector.ts index f499b04..5c7ba04 100644 --- a/src/query/vector.ts +++ b/src/query/vector.ts @@ -1,7 +1,7 @@ -import { renderFilter, type BaseQuery, type FilterInput } from './base.js'; +import { BaseVectorQuery, renderFilter, type FilterInput } from './base.js'; import { VectorDataType, VectorDistanceMetric } from '../schema/types.js'; -import { QueryValidationError, SchemaValidationError } from '../errors.js'; -import { encodeVectorBuffer, normalizeVectorDataType } from '../redis/utils.js'; +import { QueryValidationError } from '../errors.js'; +import { encodeVectorBuffer } from '../redis/utils.js'; /** * Hybrid policy options for vector search with filters @@ -184,36 +184,20 @@ export interface VectorQueryConfig { * const results = await index.search(query); * ``` */ -export class VectorQuery implements BaseQuery { - public readonly vector: number[]; - public readonly vectorField: string; +export class VectorQuery extends BaseVectorQuery { public readonly numResults: number; - public readonly filter?: FilterInput; - public readonly returnFields?: string[]; public readonly distanceMetric: VectorDistanceMetric; - public readonly datatype: VectorDataType; - public readonly offset?: number; - public readonly limit?: number; public readonly scoreAlias: string; public readonly efRuntime?: number; public readonly epsilon?: number; public readonly hybridPolicy?: HybridPolicy; public readonly batchSize?: number; - public readonly normalizeDistance: boolean; public readonly searchWindowSize?: number; public readonly useSearchHistory?: UseSearchHistory; public readonly searchBufferCapacity?: number; constructor(config: VectorQueryConfig) { - // Validate vector - if (!config.vector || config.vector.length === 0) { - throw new QueryValidationError('Vector cannot be empty'); - } - - // Validate vectorField - if (!config.vectorField) { - throw new QueryValidationError('vectorField is required'); - } + super(config); // Validate HNSW parameters if (config.efRuntime !== undefined && config.efRuntime <= 0) { @@ -264,28 +248,13 @@ export class VectorQuery implements BaseQuery { throw new QueryValidationError('searchBufferCapacity must be positive'); } - this.vector = config.vector; - this.vectorField = config.vectorField; this.numResults = config.numResults ?? 10; - this.filter = config.filter; - this.returnFields = config.returnFields; this.distanceMetric = config.distanceMetric ?? VectorDistanceMetric.COSINE; - try { - this.datatype = normalizeVectorDataType(config.datatype); - } catch (error) { - if (error instanceof SchemaValidationError) { - throw new QueryValidationError(error.message); - } - throw error; - } - this.offset = config.offset; - this.limit = config.limit; this.scoreAlias = config.scoreAlias ?? 'vector_distance'; this.efRuntime = config.efRuntime; this.epsilon = config.epsilon; this.hybridPolicy = config.hybridPolicy; this.batchSize = config.batchSize; - this.normalizeDistance = config.normalizeDistance ?? false; this.searchWindowSize = config.searchWindowSize; this.useSearchHistory = config.useSearchHistory; this.searchBufferCapacity = config.searchBufferCapacity; diff --git a/tests/unit/indexes/search-index.test.ts b/tests/unit/indexes/search-index.test.ts index 4e3c845..c15e33a 100644 --- a/tests/unit/indexes/search-index.test.ts +++ b/tests/unit/indexes/search-index.test.ts @@ -7,6 +7,7 @@ import type { RedisClientType } from 'redis'; import { RedisVLError, SchemaValidationError } from '../../../src/errors.js'; import { VectorQuery } from '../../../src/query/vector.js'; import { TextQuery } from '../../../src/query/text.js'; +import { FilterQuery } from '../../../src/query/filter-query.js'; describe('SearchIndex', () => { let schema: IndexSchema; @@ -1362,5 +1363,27 @@ describe('SearchIndex', () => { }) ); }); + + it('should pass chainable query sort fields to FT.SEARCH', async () => { + const ftSearch = mockClient.ft.search as MockedFunction; + ftSearch.mockResolvedValue({ + total: 0, + documents: [], + } as Awaited>); + + const index = new SearchIndex(schema, mockClient); + await index.search(new FilterQuery().sortBy('title', { direction: 'DESC' })); + + expect(ftSearch).toHaveBeenCalledWith( + 'redisvl-test-index', + '*', + expect.objectContaining({ + SORTBY: { + BY: 'title', + DIRECTION: 'DESC', + }, + }) + ); + }); }); }); diff --git a/tests/unit/query/base.test.ts b/tests/unit/query/base.test.ts new file mode 100644 index 0000000..7180ffa --- /dev/null +++ b/tests/unit/query/base.test.ts @@ -0,0 +1,171 @@ +import { describe, expect, it } from 'vitest'; +import { + BaseQuery, + BaseVectorQuery, + renderFilter, + type BaseQueryConfig, + type BaseVectorQueryConfig, +} from '../../../src/query/base.js'; +import { Tag } from '../../../src/query/filter.js'; +import { QueryValidationError } from '../../../src/errors.js'; +import { VectorDataType } from '../../../src/schema/types.js'; + +class TestQuery extends BaseQuery { + constructor(config: BaseQueryConfig = {}) { + super(config); + } + + buildQuery(): string { + return renderFilter(this.filter); + } +} + +class TestVectorQuery extends BaseVectorQuery { + constructor(config: BaseVectorQueryConfig) { + super(config); + } + + buildQuery(): string { + return `${renderFilter(this.filter)} @${this.vectorField}`; + } +} + +describe('BaseQuery', () => { + it('initializes common query state', () => { + const query = new TestQuery({ + filter: '@category:{books}', + returnFields: ['title', 'price'], + offset: 10, + limit: 5, + }); + + expect(query.filter).toBe('@category:{books}'); + expect(query.returnFields).toEqual(['title', 'price']); + expect(query.offset).toBe(10); + expect(query.limit).toBe(5); + expect(query.buildParams()).toEqual({}); + }); + + it('sets and clears filters from strings and FilterExpression objects', () => { + const query = new TestQuery().setFilter(Tag('brand').eq('redis')); + + expect(query.buildQuery()).toBe('@brand:{redis}'); + + query.setFilter('@brand:{valkey}'); + expect(query.buildQuery()).toBe('@brand:{valkey}'); + + query.setFilter(null); + expect(query.filter).toBeUndefined(); + expect(query.buildQuery()).toBe('*'); + }); + + it('rejects empty filter strings', () => { + expect(() => new TestQuery({ filter: '' })).toThrow(QueryValidationError); + expect(() => new TestQuery().setFilter(' ')).toThrow(QueryValidationError); + }); + + it('sets return fields and skip-decode fields defensively', () => { + const fields = ['title', 'embedding']; + const skipDecode = ['embedding']; + const query = new TestQuery().setReturnFields(fields, { skipDecode }); + + fields.push('price'); + skipDecode.push('blob'); + + expect(query.returnFields).toEqual(['title', 'embedding']); + expect(query.skipDecodeFields).toEqual(['embedding']); + + query.setReturnFields(['title'], { skipDecode: 'raw' }); + expect(query.returnFields).toEqual(['title']); + expect(query.skipDecodeFields).toEqual(['raw']); + }); + + it('clears return fields and skip-decode fields together', () => { + const query = new TestQuery() + .setReturnFields(['title'], { skipDecode: 'embedding' }) + .setReturnFields(); + + expect(query.returnFields).toBeUndefined(); + expect(query.skipDecodeFields).toBeUndefined(); + }); + + it('rejects invalid return and skip-decode fields', () => { + expect(() => new TestQuery().setReturnFields(['title', ''])).toThrow(QueryValidationError); + expect(() => + new TestQuery().setReturnFields(['title'], { skipDecode: ['embedding', ''] }) + ).toThrow(QueryValidationError); + }); + + it('sets paging values and validates them', () => { + const query = new TestQuery().paging(20, 10); + + expect(query.offset).toBe(20); + expect(query.limit).toBe(10); + expect(() => new TestQuery().paging(-1, 10)).toThrow(QueryValidationError); + expect(() => new TestQuery().paging(0, 0)).toThrow(QueryValidationError); + expect(() => new TestQuery({ offset: -1 })).toThrow(QueryValidationError); + expect(() => new TestQuery({ limit: 0 })).toThrow(QueryValidationError); + }); + + it('collects sort fields and validates sort input', () => { + const query = new TestQuery().sortBy('price').sortBy('created_at', { + direction: 'DESC', + }); + + expect(query.sortFields).toEqual([ + { field: 'price', direction: 'ASC' }, + { field: 'created_at', direction: 'DESC' }, + ]); + expect(() => new TestQuery().sortBy('')).toThrow(QueryValidationError); + expect(() => new TestQuery().sortBy('price', { direction: 'DOWN' as any })).toThrow( + QueryValidationError + ); + }); +}); + +describe('BaseVectorQuery', () => { + it('initializes vector state and common query state', () => { + const query = new TestVectorQuery({ + vector: [0.1, 0.2, 0.3], + vectorField: 'embedding', + datatype: VectorDataType.FLOAT64, + normalizeDistance: true, + filter: '@category:{books}', + returnFields: ['title'], + }); + + expect(query).toBeInstanceOf(BaseQuery); + expect(query.vector).toEqual([0.1, 0.2, 0.3]); + expect(query.vectorField).toBe('embedding'); + expect(query.datatype).toBe(VectorDataType.FLOAT64); + expect(query.normalizeDistance).toBe(true); + expect(query.buildQuery()).toBe('@category:{books} @embedding'); + }); + + it('defensively copies vectors', () => { + const vector = [0.1, 0.2, 0.3]; + const query = new TestVectorQuery({ vector, vectorField: 'embedding' }); + + vector.push(0.4); + query.vector.push(0.5); + + expect(query.vector).toEqual([0.1, 0.2, 0.3]); + }); + + it('rejects invalid vector state', () => { + expect(() => new TestVectorQuery({ vector: [], vectorField: 'embedding' })).toThrow( + QueryValidationError + ); + expect(() => new TestVectorQuery({ vector: [0.1], vectorField: '' })).toThrow( + QueryValidationError + ); + expect( + () => + new TestVectorQuery({ + vector: [0.1], + vectorField: 'embedding', + datatype: 'float25', + }) + ).toThrow(QueryValidationError); + }); +}); diff --git a/tests/unit/query/vector.test.ts b/tests/unit/query/vector.test.ts index d38f7b0..8025777 100644 --- a/tests/unit/query/vector.test.ts +++ b/tests/unit/query/vector.test.ts @@ -1,4 +1,5 @@ import { describe, it, expect, vi } from 'vitest'; +import { BaseQuery, BaseVectorQuery } from '../../../src/query/base.js'; import { VectorQuery } from '../../../src/query/vector.js'; import { QueryValidationError } from '../../../src/errors.js'; import { VectorDataType, VectorDistanceMetric } from '../../../src/schema/types.js'; @@ -14,6 +15,8 @@ describe('VectorQuery', () => { }); expect(query).toBeInstanceOf(VectorQuery); + expect(query).toBeInstanceOf(BaseVectorQuery); + expect(query).toBeInstanceOf(BaseQuery); expect(query.numResults).toBe(10); expect(query.vectorField).toBe('embedding'); expect(query.returnFields).toEqual(['title', 'score']); @@ -176,6 +179,16 @@ describe('VectorQuery', () => { expect(query.offset).toBe(20); expect(query.limit).toBe(10); }); + + it('should support chainable paging updates', () => { + const query = new VectorQuery({ + vector: [0.1, 0.2, 0.3], + vectorField: 'embedding', + }).paging(30, 15); + + expect(query.offset).toBe(30); + expect(query.limit).toBe(15); + }); }); describe('returnFields', () => { @@ -197,6 +210,30 @@ describe('VectorQuery', () => { expect(query.returnFields).toBeUndefined(); }); + + it('should support chainable return field updates with skip-decode fields', () => { + const query = new VectorQuery({ + vector: [0.1, 0.2, 0.3], + vectorField: 'embedding', + }).setReturnFields(['title', 'embedding'], { skipDecode: 'embedding' }); + + expect(query.returnFields).toEqual(['title', 'embedding']); + expect(query.skipDecodeFields).toEqual(['embedding']); + }); + }); + + describe('filter updates', () => { + it('should support chainable filter updates', () => { + const query = new VectorQuery({ + vector: [0.1, 0.2, 0.3], + vectorField: 'embedding', + }).setFilter('@category:{books}'); + + expect(query.filter).toBe('@category:{books}'); + expect(query.buildQuery()).toBe( + '(@category:{books})=>[KNN 10 @embedding $vector AS vector_distance]' + ); + }); }); describe('Distance Normalization', () => {