diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index b915033..adba401 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -191,7 +191,7 @@ def parse_azure_endpoint( if not azure_endpoint: raise RuntimeError(f"Environment variable {endpoint_envvar} not found") - m = re.search(r"[?,]api-version=([\d-]+(?:preview)?)", azure_endpoint) + m = re.search(r"[?&]api-version=([\d-]+(?:preview)?)", azure_endpoint) if not m: raise RuntimeError( f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field" diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 46c4dba..63e2e77 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -150,8 +150,9 @@ def fuzzy_lookup_embedding_in_subset( max_hits: int | None = None, min_score: float | None = None, ) -> list[ScoredInt]: + ordinals_set = set(ordinals_of_subset) return self.fuzzy_lookup_embedding( - embedding, max_hits, min_score, lambda i: i in ordinals_of_subset + embedding, max_hits, min_score, lambda i: i in ordinals_set ) async def fuzzy_lookup( diff --git a/src/typeagent/emails/email_import.py b/src/typeagent/emails/email_import.py index ce61d6b..88f13b9 100644 --- a/src/typeagent/emails/email_import.py +++ b/src/typeagent/emails/email_import.py @@ -263,7 +263,8 @@ def _merge_chunks( yield cur_chunk cur_chunk = new_chunk else: - cur_chunk += separator + if cur_chunk: + cur_chunk += separator cur_chunk += new_chunk if (len(cur_chunk)) > 0: diff --git a/src/typeagent/knowpro/answers.py b/src/typeagent/knowpro/answers.py index ae6fad9..d1b87a3 100644 --- a/src/typeagent/knowpro/answers.py +++ b/src/typeagent/knowpro/answers.py @@ -404,11 +404,12 @@ async def get_enclosing_date_range_for_text_range( start_timestamp = (await messages.get_item(range.start.message_ordinal)).timestamp if not start_timestamp: return None - end_timestamp = ( - (await 
messages.get_item(range.end.message_ordinal)).timestamp - if range.end - else None - ) + end_timestamp: str | None = None + if range.end: + end_ordinal = range.end.message_ordinal + if end_ordinal < await messages.size(): + end_timestamp = (await messages.get_item(end_ordinal)).timestamp + # else: range extends to the end of the conversation; leave as None. return DateRange( start=Datetime.fromisoformat(start_timestamp), end=Datetime.fromisoformat(end_timestamp) if end_timestamp else None, @@ -535,7 +536,7 @@ def facets_to_merged_facets(facets: list[Facet]) -> MergedFacets: merged_facets: MergedFacets = {} for facet in facets: name = facet.name.lower() - value = str(facet).lower() + value = str(facet.value).lower() merged_facets.setdefault(name, []).append(value) return merged_facets diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index 1e5205f..a271657 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -91,10 +91,14 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None: ) ) else: + # New related-only match: hit_count stays 0 because + # only exact matches count as direct hits. This matters + # for select_with_hit_count / _matches_with_min_hit_count + # which filter on hit_count to weed out noise. 
self.set_match( Match( value, - hit_count=1, + hit_count=0, score=0.0, related_hit_count=1, related_score=score, @@ -250,9 +254,11 @@ def smooth_match_score[T](match: Match[T]) -> None: class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]): - def __init__(self, search_term_matches: set[str] = set()): + def __init__(self, search_term_matches: set[str] | None = None): super().__init__() - self.search_term_matches = search_term_matches + self.search_term_matches = ( + search_term_matches if search_term_matches is not None else set() + ) def add_term_matches( self, @@ -330,8 +336,7 @@ async def group_matches_by_type( semantic_ref = await semantic_refs.get_item(match.value) group = groups.get(semantic_ref.knowledge.knowledge_type) if group is None: - group = SemanticRefAccumulator() - group.search_term_matches = self.search_term_matches + group = SemanticRefAccumulator(self.search_term_matches) groups[semantic_ref.knowledge.knowledge_type] = group group.set_match(match) return groups @@ -513,11 +518,10 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No for text_range in text_ranges._ranges: self.add_range(text_range) - def is_in_range(self, inner_range: TextRange) -> bool: - if len(self._ranges) == 0: - return False - i = bisect.bisect_left(self._ranges, inner_range) - for outer_range in self._ranges[i:]: + def contains_range(self, inner_range: TextRange) -> bool: + # Since ranges are sorted by start, once we pass inner_range's start + # no further range can contain it. + for outer_range in self._ranges: if outer_range.start > inner_range.start: break if inner_range in outer_range: @@ -544,7 +548,7 @@ def is_range_in_scope(self, inner_range: TextRange) -> bool: # We have a very simple impl: we don't intersect/union ranges yet. # Instead, we ensure that the inner range is not rejected by any outer ranges. 
for outer_ranges in self.text_ranges: - if not outer_ranges.is_in_range(inner_range): + if not outer_ranges.contains_range(inner_range): return False return True diff --git a/src/typeagent/knowpro/interfaces_search.py b/src/typeagent/knowpro/interfaces_search.py index 4ff20c7..c3727d2 100644 --- a/src/typeagent/knowpro/interfaces_search.py +++ b/src/typeagent/knowpro/interfaces_search.py @@ -18,13 +18,14 @@ ) __all__ = [ - "SearchTerm", "KnowledgePropertyName", "PropertySearchTerm", + "SearchSelectExpr", + "SearchTerm", "SearchTermGroup", "SearchTermGroupTypes", + "SemanticRefSearchResult", "WhenFilter", - "SearchSelectExpr", ] @@ -142,15 +143,3 @@ class SemanticRefSearchResult: term_matches: set[str] semantic_ref_matches: list[ScoredSemanticRefOrdinal] - - -__all__ = [ - "KnowledgePropertyName", - "PropertySearchTerm", - "SearchSelectExpr", - "SearchTerm", - "SearchTermGroup", - "SearchTermGroupTypes", - "SemanticRefSearchResult", - "WhenFilter", -] diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index 5151054..0bedf95 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -45,6 +45,7 @@ Thread, ) from .kplib import ConcreteEntity +from .utils import aenumerate # TODO: Move to compilelib.py type BooleanOp = Literal["and", "or", "or_max"] @@ -101,11 +102,14 @@ async def get_text_range_for_date_range( messages = conversation.messages range_start_ordinal: MessageOrdinal = -1 range_end_ordinal = range_start_ordinal - async for message in messages: - if Datetime.fromisoformat(message.timestamp) in date_range: + async for ordinal, message in aenumerate(messages): + if ( + message.timestamp + and Datetime.fromisoformat(message.timestamp) in date_range + ): if range_start_ordinal < 0: - range_start_ordinal = message.ordinal - range_end_ordinal = message.ordinal + range_start_ordinal = ordinal + range_end_ordinal = ordinal else: if range_start_ordinal >= 0: # We have a range, so break. 
@@ -696,7 +700,7 @@ class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]): async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator: accumulator = await self.source_expr.eval(context) - filtered = SemanticRefAccumulator(accumulator.search_term_matches) + filtered = SemanticRefAccumulator(set(accumulator.search_term_matches)) # Filter matches asynchronously filtered_matches = [] diff --git a/src/typeagent/knowpro/search.py b/src/typeagent/knowpro/search.py index b7c37ac..9641590 100644 --- a/src/typeagent/knowpro/search.py +++ b/src/typeagent/knowpro/search.py @@ -90,11 +90,9 @@ class SearchOptions: def __repr__(self): parts = [] - for key in dir(self): - if not key.startswith("_"): - value = getattr(self, key) - if value is not None: - parts.append(f"{key}={value!r}") + for key, value in vars(self).items(): + if not key.startswith("_") and value is not None: + parts.append(f"{key}={value!r}") return f"{self.__class__.__name__}({', '.join(parts)})" diff --git a/src/typeagent/knowpro/searchlang.py b/src/typeagent/knowpro/searchlang.py index dbb8092..e2e990d 100644 --- a/src/typeagent/knowpro/searchlang.py +++ b/src/typeagent/knowpro/searchlang.py @@ -83,11 +83,9 @@ class LanguageSearchOptions(SearchOptions): def __repr__(self): parts = [] - for key in dir(self): - if not key.startswith("_"): - value = getattr(self, key) - if value is not None: - parts.append(f"{key}={value!r}") + for key, value in vars(self).items(): + if not key.startswith("_") and value is not None: + parts.append(f"{key}={value!r}") return f"{self.__class__.__name__}({', '.join(parts)})" @@ -371,6 +369,9 @@ def compile_action_term_as_search_terms( self.compile_entity_terms_as_search_terms( action_term.additional_entities, action_group ) + # Only append the nested or_max wrapper when we created one (use_or_max) and it is non-empty.
+ if use_or_max and action_group.terms: + term_group.terms.append(action_group) return term_group def compile_search_terms( @@ -609,21 +610,6 @@ def add_entity_name_to_group( exact_match_value, ) - def add_search_term_to_groupadd_entity_name_to_group( - self, - entity_term: EntityTerm, - property_name: PropertyNames, - term_group: SearchTermGroup, - exact_match_value: bool = False, - ) -> None: - if not entity_term.is_name_pronoun: - self.add_property_term_to_group( - property_name.value, - entity_term.name, - term_group, - exact_match_value, - ) - def add_property_term_to_group( self, property_name: str, diff --git a/src/typeagent/knowpro/utils.py b/src/typeagent/knowpro/utils.py index 298c09d..92eedac 100644 --- a/src/typeagent/knowpro/utils.py +++ b/src/typeagent/knowpro/utils.py @@ -3,9 +3,18 @@ """Utility functions for the knowpro package.""" +from collections.abc import AsyncIterable + from .interfaces import MessageOrdinal, TextLocation, TextRange +async def aenumerate[T](aiterable: AsyncIterable[T], start: int = 0): + i = start + async for item in aiterable: + yield i, item + i += 1 + + def text_range_from_message_chunk( message_ordinal: MessageOrdinal, chunk_ordinal: int = 0, diff --git a/src/typeagent/mcp/server.py b/src/typeagent/mcp/server.py index dcd4a3c..8fb4d03 100644 --- a/src/typeagent/mcp/server.py +++ b/src/typeagent/mcp/server.py @@ -9,7 +9,10 @@ import time from typing import Any -import coverage +try: + import coverage +except ImportError: + coverage = None # type: ignore[assignment] from dotenv import load_dotenv from mcp.server.fastmcp import Context, FastMCP @@ -18,7 +21,8 @@ import typechat # Enable coverage.py before local imports (a no-op unless COVERAGE_PROCESS_START is set). 
-coverage.process_startup() +if coverage is not None: + coverage.process_startup() from typeagent.aitools import embeddings, utils from typeagent.knowpro import answers, query, searchlang @@ -246,6 +250,12 @@ async def query_conversation( return QuestionResponse( success=True, answer=combined_answer.answer or "", time_used=dt ) + case _: + return QuestionResponse( + success=False, + answer=f"Unexpected answer type: {combined_answer.type}", + time_used=dt, + ) # Run the MCP server diff --git a/tests/test_answers.py b/tests/test_answers.py new file mode 100644 index 0000000..a8fc0ce --- /dev/null +++ b/tests/test_answers.py @@ -0,0 +1,192 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio + +from typeagent.knowpro.answers import ( + facets_to_merged_facets, + get_enclosing_date_range_for_text_range, + get_enclosing_text_range, + merged_facets_to_facets, + text_range_from_message_range, +) +from typeagent.knowpro.interfaces import TextLocation, TextRange +from typeagent.knowpro.kplib import Facet + +from conftest import FakeMessage, FakeMessageCollection + +# --------------------------------------------------------------------------- +# Change 1: facets_to_merged_facets uses str(facet.value), not str(facet) +# --------------------------------------------------------------------------- + + +class TestFacetsToMergedFacets: + """Verify that facet *values* (not the whole Facet object) are stringified.""" + + def test_string_value(self) -> None: + facets = [Facet(name="colour", value="red")] + merged = facets_to_merged_facets(facets) + assert merged == {"colour": ["red"]} + + def test_numeric_value(self) -> None: + facets = [Facet(name="age", value=30.0)] + merged = facets_to_merged_facets(facets) + # Should be "30.0", NOT "Facet('age', 30.0)" + assert merged == {"age": ["30.0"]} + assert "Facet" not in merged["age"][0] + + def test_bool_value(self) -> None: + 
facets = [Facet(name="active", value=True)] + merged = facets_to_merged_facets(facets) + assert merged == {"active": ["true"]} + + def test_multiple_facets_same_name(self) -> None: + facets = [ + Facet(name="tag", value="a"), + Facet(name="tag", value="b"), + ] + merged = facets_to_merged_facets(facets) + assert merged == {"tag": ["a", "b"]} + + def test_lowercases_names_and_values(self) -> None: + facets = [Facet(name="Colour", value="RED")] + merged = facets_to_merged_facets(facets) + assert "colour" in merged + assert merged["colour"] == ["red"] + + def test_roundtrip_through_merged(self) -> None: + """facets_to_merged_facets -> merged_facets_to_facets preserves semantics.""" + original = [ + Facet(name="colour", value="red"), + Facet(name="colour", value="blue"), + Facet(name="size", value="large"), + ] + merged = facets_to_merged_facets(original) + restored = merged_facets_to_facets(merged) + restored_by_name = {f.name: f.value for f in restored} + assert restored_by_name["colour"] == "red; blue" + assert restored_by_name["size"] == "large" + + +# --------------------------------------------------------------------------- +# Change 2: get_enclosing_date_range_for_text_range uses ordinal-1 for end +# --------------------------------------------------------------------------- + + +class TestGetEnclosingDateRangeForTextRange: + """Verify the off-by-one fix: end is exclusive, so we subtract 1.""" + + @pytest_asyncio.fixture() + async def messages(self) -> AsyncGenerator[FakeMessageCollection, None]: + """Three messages with ordinals 0, 1, 2 and timestamps derived from them.""" + coll = FakeMessageCollection() + for i in range(3): + msg = FakeMessage("text", message_ordinal=i) + await coll.append(msg) + yield coll + + @pytest.mark.asyncio + async def test_single_message_range(self, messages: FakeMessageCollection) -> None: + """Point range (end=None) should use only the start message's timestamp.""" + tr = TextRange(start=TextLocation(1)) + dr = await 
get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 1 + assert dr.end is None + + @pytest.mark.asyncio + async def test_multi_message_range_uses_exclusive_end( + self, messages: FakeMessageCollection + ) -> None: + """Range [0, 2) should use message 2 (the exclusive end) for end timestamp.""" + tr = TextRange( + start=TextLocation(0), + end=TextLocation(2), # exclusive end + ) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 0 + # End timestamp comes from the message at the exclusive end ordinal: + assert dr.end is not None + assert dr.end.hour == 2 + + @pytest.mark.asyncio + async def test_adjacent_messages(self, messages: FakeMessageCollection) -> None: + """Range [1, 2) covers only message 1; end timestamp is message 2.""" + tr = TextRange( + start=TextLocation(1), + end=TextLocation(2), + ) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 1 + assert dr.end is not None + assert dr.end.hour == 2 # exclusive end: timestamp of the next message + + @pytest.mark.asyncio + async def test_end_past_last_message(self, messages: FakeMessageCollection) -> None: + """If the exclusive end ordinal is past the last message, end is None.""" + tr = TextRange( + start=TextLocation(0), + end=TextLocation(3), # messages only have ordinals 0, 1, 2 + ) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 0 + assert dr.end is None + + @pytest.mark.asyncio + async def test_no_timestamp_returns_none(self) -> None: + """If start message has no timestamp, return None.""" + coll = FakeMessageCollection() + msg = FakeMessage("text") # no message_ordinal → no timestamp + await coll.append(msg) + tr = TextRange(start=TextLocation(0)) + dr = await get_enclosing_date_range_for_text_range(coll, tr) + assert dr is None + + +# 
--------------------------------------------------------------------------- +# Helper functions (also exercised for completeness) +# --------------------------------------------------------------------------- + + +class TestGetEnclosingTextRange: + def test_single_ordinal(self) -> None: + tr = get_enclosing_text_range([5]) + assert tr is not None + assert tr.start.message_ordinal == 5 + assert tr.end is None # point range + + def test_multiple_ordinals(self) -> None: + tr = get_enclosing_text_range([3, 1, 7]) + assert tr is not None + assert tr.start.message_ordinal == 1 + assert tr.end is not None + assert tr.end.message_ordinal == 7 + + def test_empty_ordinals(self) -> None: + tr = get_enclosing_text_range([]) + assert tr is None + + +class TestTextRangeFromMessageRange: + def test_point(self) -> None: + tr = text_range_from_message_range(3, 3) + assert tr is not None + assert tr.start.message_ordinal == 3 + assert tr.end is None + + def test_range(self) -> None: + tr = text_range_from_message_range(2, 5) + assert tr is not None + assert tr.start.message_ordinal == 2 + assert tr.end is not None + assert tr.end.message_ordinal == 5 + + def test_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="Expect message ordinal range"): + text_range_from_message_range(5, 2) diff --git a/tests/test_collections.py b/tests/test_collections.py index f35c3fe..ca6e14e 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -110,11 +110,15 @@ def test_text_range_collection_add_and_check(): assert len(collection) == 2 - assert collection.is_in_range(range1) is True - assert collection.is_in_range(range2) is True - assert collection.is_in_range(range3) is False - assert collection.is_in_range(range4) is False - assert collection.is_in_range(range5) is False + assert collection.contains_range(range1) is True + assert collection.contains_range(range2) is True + assert ( + collection.contains_range(range3) is True + ) # range3 [5,10) is inside range1 
[0,10) + assert ( + collection.contains_range(range4) is False + ) # range4 [5,25) spans across ranges + assert collection.contains_range(range5) is False def test_text_ranges_in_scope(): @@ -406,6 +410,90 @@ def test_match_accumulator_select_top_n_scoring(): assert matches[1].value == "medium" +def test_match_accumulator_add_non_exact_match(): + """Non-exact (related) matches must start with hit_count=0.""" + accumulator = MatchAccumulator[str]() + accumulator.add("related_term", score=0.7, is_exact_match=False) + + match = accumulator.get_match("related_term") + assert match is not None + assert match.hit_count == 0 + assert match.score == 0.0 + assert match.related_hit_count == 1 + assert match.related_score == 0.7 + + +def test_match_accumulator_non_exact_filtered_by_min_hit_count(): + """Related-only matches should be excluded by min_hit_count=1 filter.""" + accumulator = MatchAccumulator[str]() + accumulator.add("exact_term", score=1.0, is_exact_match=True) + accumulator.add("related_term", score=0.9, is_exact_match=False) + + matches = list(accumulator._matches_with_min_hit_count(min_hit_count=1)) # type: ignore + assert len(matches) == 1 + assert matches[0].value == "exact_term" + + +def test_match_accumulator_related_then_exact_same_value(): + """Adding a related match then an exact match for the same value.""" + accumulator = MatchAccumulator[str]() + accumulator.add("term", score=0.5, is_exact_match=False) + accumulator.add("term", score=1.0, is_exact_match=True) + + match = accumulator.get_match("term") + assert match is not None + assert match.hit_count == 1 + assert match.score == 1.0 + assert match.related_hit_count == 1 + assert match.related_score == 0.5 + + +def test_match_accumulator_exact_then_related_same_value(): + """Adding an exact match then a related match for the same value.""" + accumulator = MatchAccumulator[str]() + accumulator.add("term", score=1.0, is_exact_match=True) + accumulator.add("term", score=0.3, is_exact_match=False) + + 
match = accumulator.get_match("term") + assert match is not None + assert match.hit_count == 1 + assert match.score == 1.0 + assert match.related_hit_count == 1 + assert match.related_score == 0.3 + + +def test_match_accumulator_multiple_related_accumulate(): + """Multiple related matches for the same value accumulate correctly.""" + accumulator = MatchAccumulator[str]() + accumulator.add("term", score=0.4, is_exact_match=False) + accumulator.add("term", score=0.6, is_exact_match=False) + + match = accumulator.get_match("term") + assert match is not None + assert match.hit_count == 0 + assert match.score == 0.0 + assert match.related_hit_count == 2 + assert match.related_score == pytest.approx(1.0) + + +def test_match_accumulator_total_score_includes_related(): + """calculate_total_score adds smoothed related score to the main score.""" + accumulator = MatchAccumulator[str]() + accumulator.add("exact_only", score=2.0, is_exact_match=True) + accumulator.add("mixed", score=1.0, is_exact_match=True) + accumulator.add("mixed", score=0.5, is_exact_match=False) + + accumulator.calculate_total_score() + + exact_only = accumulator.get_match("exact_only") + mixed = accumulator.get_match("mixed") + assert exact_only is not None + assert mixed is not None + # "mixed" should have a higher score than its raw 1.0 + # because the related_score of 0.5 is added (smoothed). + assert mixed.score > 1.0 + + def test_get_smooth_score(): """Test calculating smooth scores.""" assert get_smooth_score(10.0, 1) == 10.0 # Single hit count, no smoothing diff --git a/tests/test_email_import.py b/tests/test_email_import.py new file mode 100644 index 0000000..371136b --- /dev/null +++ b/tests/test_email_import.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from typeagent.emails.email_import import ( + _merge_chunks, + _split_into_paragraphs, + _text_to_chunks, +) + + +class TestMergeChunks: + """Tests for _merge_chunks, specifically the separator-on-empty-chunk fix.""" + + def test_no_leading_separator(self) -> None: + """First chunk must NOT start with the separator.""" + result = list(_merge_chunks(["hello", "world"], "\n\n", 100)) + assert len(result) == 1 + assert result[0] == "hello\n\nworld" + assert not result[0].startswith("\n") + + def test_no_leading_separator_after_yield(self) -> None: + """After yielding a full chunk, the next chunk must not start with separator.""" + # Each piece is 5 chars; max_chunk_length=8 forces a split after each. + pieces = ["aaaaa", "bbbbb", "ccccc"] + result = list(_merge_chunks(pieces, "--", 8)) + for chunk in result: + assert not chunk.startswith("--"), f"chunk {chunk!r} starts with separator" + + def test_single_chunk(self) -> None: + result = list(_merge_chunks(["only"], "\n\n", 100)) + assert result == ["only"] + + def test_empty_input(self) -> None: + result = list(_merge_chunks([], "\n\n", 100)) + assert result == [] + + def test_exact_fit(self) -> None: + """Two chunks that fit exactly within max_chunk_length.""" + # "ab" + "\n\n" + "cd" = 6 chars + result = list(_merge_chunks(["ab", "cd"], "\n\n", 6)) + assert result == ["ab\n\ncd"] + + def test_overflow_splits(self) -> None: + """Chunks that don't fit together should be yielded separately.""" + # "ab" + "\n\n" + "cd" = 6 chars, max is 5 -> must split + result = list(_merge_chunks(["ab", "cd"], "\n\n", 5)) + assert result == ["ab", "cd"] + + def test_truncation_of_oversized_chunk(self) -> None: + """A single chunk longer than max_chunk_length is truncated.""" + result = list(_merge_chunks(["abcdefghij"], "\n\n", 5)) + assert result == ["abcde"] + + def test_multiple_merges_and_splits(self) -> None: + pieces = ["aa", "bb", "cccccc", "dd"] + # "aa" + "--" + "bb" = 6, fits in 8 + # "cccccc" alone = 6, can't merge with 
previous (6+2+6=14>8), yield "aa--bb" + # "cccccc" + "--" + "dd" = 10 > 8, yield "cccccc" + # "dd" yielded at end + result = list(_merge_chunks(pieces, "--", 8)) + assert result == ["aa--bb", "cccccc", "dd"] + + +class TestSplitIntoParagraphs: + def test_basic_split(self) -> None: + text = "para1\n\npara2\n\npara3" + assert _split_into_paragraphs(text) == ["para1", "para2", "para3"] + + def test_multiple_newlines(self) -> None: + text = "a\n\n\n\nb" + assert _split_into_paragraphs(text) == ["a", "b"] + + def test_no_split(self) -> None: + assert _split_into_paragraphs("single paragraph") == ["single paragraph"] + + def test_leading_trailing_newlines(self) -> None: + text = "\n\nfoo\n\n" + result = _split_into_paragraphs(text) + assert "foo" in result + assert "" not in result + + +class TestTextToChunks: + def test_short_text_single_chunk(self) -> None: + result = _text_to_chunks("short text", max_chunk_length=100) + assert result == ["short text"] + + def test_long_text_splits(self) -> None: + text = "para one\n\npara two\n\npara three" + result = _text_to_chunks(text, max_chunk_length=15) + assert len(result) > 1 + for chunk in result: + assert not chunk.startswith("\n"), f"chunk {chunk!r} has leading newline" + + def test_no_leading_separator_in_any_chunk(self) -> None: + """Regression: no chunk should start with the paragraph separator.""" + text = "A" * 50 + "\n\n" + "B" * 50 + "\n\n" + "C" * 50 + result = _text_to_chunks(text, max_chunk_length=60) + for chunk in result: + assert not chunk.startswith( + "\n\n" + ), f"chunk {chunk!r} has leading separator" diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 24933ca..d70c227 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-"""End-to-end tests for the MCP server.""" +"""End-to-end and unit tests for the MCP server.""" import json import os import sys from typing import Any +from unittest.mock import AsyncMock import pytest @@ -14,10 +15,17 @@ from mcp.client.session import ClientSession as ClientSessionType from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext -from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) from openai.types.chat import ChatCompletionMessageParam +import typechat from typeagent.aitools.utils import create_async_openai_client +from typeagent.mcp.server import MCPTypeChatModel, QuestionResponse from conftest import EPISODE_53_INDEX @@ -178,3 +186,160 @@ async def test_mcp_server_empty_question(server_params: StdioServerParameters): response_data = json.loads(response_text) assert response_data["success"] is False assert "No question provided" in response_data["answer"] + + +# --------------------------------------------------------------------------- +# Unit tests (formerly in test_mcp_server_unit.py) +# --------------------------------------------------------------------------- + +# Coverage import guard — tested implicitly (the module loads at all +# without `coverage` installed). We just verify the guard didn't break the +# import. 
+ + +def test_server_module_imports() -> None: + """Importing the server module should not raise even without coverage.""" + import typeagent.mcp.server as mod + + assert hasattr(mod, "mcp") # The FastMCP instance exists + + +# --------------------------------------------------------------------------- +# PromptSection role mapping ("system" → "assistant") +# --------------------------------------------------------------------------- + + +class TestMCPTypeChatModelRoleMapping: + """Verify that PromptSection roles are mapped correctly to MCP roles.""" + + @staticmethod + def _make_model() -> tuple[MCPTypeChatModel, AsyncMock]: + session = AsyncMock() + # create_message returns a result with TextContent + session.create_message.return_value = AsyncMock( + content=TextContent(type="text", text="response") + ) + model = MCPTypeChatModel(session) + return model, session + + @pytest.mark.asyncio + async def test_string_prompt_becomes_user_message(self) -> None: + model, session = self._make_model() + await model.complete("hello") + + call_args = session.create_message.call_args + messages: list[SamplingMessage] = call_args.kwargs["messages"] + assert len(messages) == 1 + assert messages[0].role == "user" + assert isinstance(messages[0].content, TextContent) + assert messages[0].content.text == "hello" + + @pytest.mark.asyncio + async def test_user_role_preserved(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "user", "content": "question"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "user" + + @pytest.mark.asyncio + async def test_assistant_role_preserved(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "assistant", "content": "context"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = 
session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "assistant" + + @pytest.mark.asyncio + async def test_system_role_mapped_to_assistant(self) -> None: + """System role doesn't exist in MCP SamplingMessage; it must be mapped.""" + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "instructions"}, + {"role": "user", "content": "question"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "assistant" # "system" → "assistant" + assert messages[1].role == "user" + + @pytest.mark.asyncio + async def test_mixed_roles_order(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "usr"}, + {"role": "assistant", "content": "asst"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert [m.role for m in messages] == ["assistant", "user", "assistant"] + + @pytest.mark.asyncio + async def test_exception_returns_failure(self) -> None: + model, session = self._make_model() + session.create_message.side_effect = RuntimeError("boom") + result = await model.complete("test") + assert isinstance(result, typechat.Failure) + assert "boom" in result.message + + @pytest.mark.asyncio + async def test_text_content_returns_success(self) -> None: + model, _ = self._make_model() + result = await model.complete("test") + assert isinstance(result, typechat.Success) + assert result.value == "response" + + @pytest.mark.asyncio + async def test_list_content_returns_joined(self) -> None: + model, session = self._make_model() + session.create_message.return_value = AsyncMock( + content=[ + TextContent(type="text", text="part1"), + TextContent(type="text", text="part2"), + ] + ) + result = await 
model.complete("test") + assert isinstance(result, typechat.Success) + assert result.value == "part1\npart2" + + +# --------------------------------------------------------------------------- +# match statement default case in query_conversation +# --------------------------------------------------------------------------- + + +class TestQuestionResponseMatchDefault: + """The match on combined_answer.type must handle unexpected types.""" + + def test_known_types(self) -> None: + """QuestionResponse can represent success and failure.""" + ok = QuestionResponse(success=True, answer="yes", time_used=42) + assert ok.success is True + fail = QuestionResponse(success=False, answer="no", time_used=0) + assert fail.success is False + + def test_answer_type_coverage(self) -> None: + """AnswerResponse.type should only be 'Answered' or 'NoAnswer'.""" + from typeagent.knowpro.answer_response_schema import AnswerResponse + + answered = AnswerResponse(type="Answered", answer="yes") + assert answered.type == "Answered" + no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno") + assert no_answer.type == "NoAnswer" diff --git a/tests/test_query.py b/tests/test_query.py index 5a58df3..4546e64 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -680,3 +680,122 @@ async def test_lookup_knowledge_type(): assert {r.semantic_ref_ordinal for r in result} == {0, 2} # Should return empty list if no match assert await lookup_knowledge_type(collection, "action") == [] + + +class TestGetTextRangeForDateRange: + """Tests for the ordinal counter fix and timestamp None guard.""" + + @pytest.mark.asyncio + async def test_messages_without_ordinal_attribute(self) -> None: + """Messages that lack .ordinal should still work (manual counter).""" + + class BareMessage(FakeMessage): + """A message subclass that explicitly lacks .ordinal.""" + + def __init__(self, ts: str) -> None: + super().__init__("text") + self.timestamp = ts + if hasattr(self, "ordinal"): + del self.ordinal + + 
conv = FakeConversation( + messages=[ + BareMessage("2020-01-01T01:00:00"), + BareMessage("2020-01-01T02:00:00"), + ], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + result = await get_text_range_for_date_range(conv, date_range) + assert result is not None + assert result.start.message_ordinal == 0 + assert result.end is not None + assert result.end.message_ordinal == 2 # exclusive end + + @pytest.mark.asyncio + async def test_none_timestamp_skipped(self) -> None: + """Messages with None timestamp should be skipped, not crash.""" + conv = FakeConversation( + messages=[ + FakeMessage("no timestamp"), # timestamp=None + FakeMessage("has timestamp", message_ordinal=1), + ], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + result = await get_text_range_for_date_range(conv, date_range) + # Only message at ordinal 1 matches: + assert result is not None + assert result.start.message_ordinal == 1 + assert result.end is not None + assert result.end.message_ordinal == 2 + + @pytest.mark.asyncio + async def test_all_none_timestamps_returns_none(self) -> None: + """If every message has None timestamp, result should be None.""" + conv = FakeConversation( + messages=[FakeMessage("a"), FakeMessage("b")], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + assert await get_text_range_for_date_range(conv, date_range) is None + + @pytest.mark.asyncio + async def test_single_message_in_range(self) -> None: + conv = FakeConversation( + messages=[FakeMessage("msg", message_ordinal=0)], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + result = await get_text_range_for_date_range(conv, date_range) + assert result is not None + assert result.start.message_ordinal == 0 + assert result.end is not None + assert result.end.message_ordinal == 1 + 
class TestWhereSemanticRefExprProvenance:
    """Verify that WhereSemanticRefExpr copies (not shares) search_term_matches."""

    @pytest.mark.asyncio
    async def test_filtered_accumulator_has_copied_provenance(
        self, searchable_conversation: FakeConversation
    ) -> None:
        """The filtered accumulator's search_term_matches is a copy."""
        from typeagent.knowpro.query import WhereSemanticRefExpr

        # Source accumulator carrying some provenance.
        source_acc = SemanticRefAccumulator()
        source_acc.search_term_matches.add("term_a")
        source_acc.add_term_matches(
            Term("test"),
            [ScoredSemanticRefOrdinal(0, 1.0)],
            is_exact_match=True,
            weight=1.0,
        )

        class ConstExpr(QueryOpExpr[SemanticRefAccumulator]):
            # Trivial source expression: always yields the accumulator above.
            async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
                return source_acc

        # With no predicates, every match passes the filter.
        expr = WhereSemanticRefExpr(
            source_expr=ConstExpr(),
            predicates=[],
        )
        filtered = await expr.eval(QueryEvalContext(searchable_conversation))

        # Provenance was copied, not shared: mutating the result must not
        # leak back into the source accumulator.
        assert "term_a" in filtered.search_term_matches
        filtered.search_term_matches.add("new_term")
        assert "new_term" not in source_acc.search_term_matches
from typeagent.knowpro.search import SearchOptions
from typeagent.knowpro.searchlang import (
    LanguageQueryCompileOptions,
    LanguageSearchOptions,
)


class TestSearchOptionsRepr:
    """Tests for the custom __repr__ on SearchOptions and LanguageSearchOptions."""

    def test_all_defaults_shows_non_none_fields(self) -> None:
        """Default fields that are not None (like exact_match=False) appear."""
        text = repr(SearchOptions())
        assert text.startswith("SearchOptions(")
        # exact_match defaults to False, which is not None, so it shows up:
        assert "exact_match=False" in text
        # None-valued fields are omitted:
        assert "max_knowledge_matches" not in text

    def test_non_none_fields_shown(self) -> None:
        """Explicitly-set fields appear; unset (None) fields do not."""
        text = repr(SearchOptions(max_knowledge_matches=10, threshold_score=0.5))
        assert "max_knowledge_matches=10" in text
        assert "threshold_score=0.5" in text
        # Fields left at None are omitted:
        assert "max_message_matches" not in text
        assert "max_chars_in_budget" not in text

    def test_false_field_shown(self) -> None:
        """False is not None, so it should appear."""
        assert "exact_match=False" in repr(SearchOptions(exact_match=False))

    def test_true_field_shown(self) -> None:
        """True values are rendered too."""
        assert "exact_match=True" in repr(SearchOptions(exact_match=True))

    def test_all_fields_set(self) -> None:
        """When every field is non-None, all appear in repr."""
        opts = SearchOptions(
            max_knowledge_matches=10,
            exact_match=True,
            max_message_matches=20,
            max_chars_in_budget=5000,
            threshold_score=0.75,
        )
        text = repr(opts)
        assert "max_knowledge_matches=10" in text
        assert "exact_match=True" in text
        assert "max_message_matches=20" in text
        assert "max_chars_in_budget=5000" in text
        assert "threshold_score=0.75" in text
in r + assert "threshold_score=0.0" in r + + def test_no_dunder_or_method_names(self) -> None: + """The repr must not contain dunder names or method objects.""" + opts = SearchOptions(max_knowledge_matches=5) + r = repr(opts) + assert "__init__" not in r + assert "__eq__" not in r + assert "bound method" not in r + + +class TestLanguageSearchOptionsRepr: + """Tests for LanguageSearchOptions.__repr__ (subclass of SearchOptions).""" + + def test_all_defaults_shows_class_name(self) -> None: + opts = LanguageSearchOptions() + r = repr(opts) + # Subclass name, not parent name: + assert r.startswith("LanguageSearchOptions(") + + def test_inherited_and_own_fields(self) -> None: + opts = LanguageSearchOptions( + max_knowledge_matches=5, + compile_options=LanguageQueryCompileOptions(exact_scope=True), + ) + r = repr(opts) + assert "LanguageSearchOptions(" in r + assert "max_knowledge_matches=5" in r + assert "compile_options=" in r + assert "exact_scope=True" in r + + def test_none_fields_omitted(self) -> None: + opts = LanguageSearchOptions() + r = repr(opts) + assert "compile_options" not in r + assert "model_instructions" not in r + assert "max_knowledge_matches" not in r + + def test_no_private_fields(self) -> None: + """Fields starting with _ should never appear in repr.""" + opts = LanguageSearchOptions(max_knowledge_matches=3) + r = repr(opts) + # No key=value pair where the key starts with underscore: + inside = r.split("(", 1)[1].rstrip(")") + for part in inside.split(", "): + if "=" in part: + key = part.split("=", 1)[0] + assert not key.startswith("_"), f"private field {key!r} in repr" diff --git a/tests/test_utils.py b/tests/test_utils.py index ceea367..5966af6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ import os from dotenv import load_dotenv +import pytest import pydantic.dataclasses import typechat @@ -51,3 +52,43 @@ class DummySchema: # This will raise if the environment or typechat is not set up correctly translator = 
class TestParseAzureEndpoint:
    """Tests for parse_azure_endpoint regex matching."""

    def test_api_version_after_question_mark(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """api-version as the first (and only) query parameter."""
        url = "https://myhost.openai.azure.com/openai/deployments/gpt-4?api-version=2025-01-01-preview"
        monkeypatch.setenv("TEST_ENDPOINT", url)
        endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
        assert version == "2025-01-01-preview"
        assert endpoint.startswith("https://")

    def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """api-version preceded by & (not the first query parameter)."""
        url = "https://myhost.openai.azure.com/openai/deployments/gpt-4?foo=bar&api-version=2025-01-01-preview"
        monkeypatch.setenv("TEST_ENDPOINT", url)
        _, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
        assert version == "2025-01-01-preview"

    def test_missing_env_var_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """RuntimeError when the environment variable is not set."""
        monkeypatch.delenv("NONEXISTENT_ENDPOINT", raising=False)
        with pytest.raises(RuntimeError, match="not found"):
            utils.parse_azure_endpoint("NONEXISTENT_ENDPOINT")

    def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """RuntimeError when the endpoint has no api-version field."""
        url = "https://myhost.openai.azure.com/openai/deployments/gpt-4"
        monkeypatch.setenv("TEST_ENDPOINT", url)
        with pytest.raises(RuntimeError, match="doesn't contain valid api-version"):
            utils.parse_azure_endpoint("TEST_ENDPOINT")