Skip to content

Commit f0edca4

Browse files
authored
Merge pull request #456 from validmind/cacahfla/tiktoken-fallback
Provide fallback if tiktoken cannot be imported
2 parents 4c096c5 + 53670d2 commit f0edca4

4 files changed

Lines changed: 202 additions & 17 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "validmind"
3-
version = "2.10.5"
3+
version = "2.10.6"
44
description = "ValidMind Library"
55
readme = "README.pypi.md"
66
requires-python = ">=3.9,<3.13"

tests/test_test_descriptions.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
2+
# See the LICENSE file in the root of this repository for details.
3+
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
4+
5+
import unittest
6+
from unittest.mock import patch
7+
8+
import validmind.ai.test_descriptions as test_desc_module
9+
from validmind.ai.test_descriptions import (
10+
_estimate_tokens_simple,
11+
_truncate_summary,
12+
_truncate_text_simple,
13+
)
14+
15+
16+
class TestTokenEstimation(unittest.TestCase):
    """Tests for the character-based token estimation and truncation helpers."""

    def test_estimate_tokens_simple(self):
        """The estimate is the text length integer-divided by four."""
        cases = (
            ("", 0),  # empty input
            ("a" * 7, 1),  # remainder is dropped, not rounded up
            ("a" * 100, 25),
            ("a" * 400, 100),
        )
        for text, expected in cases:
            with self.subTest(length=len(text)):
                self.assertEqual(_estimate_tokens_simple(text), expected)

    def test_truncate_text_simple_no_truncation(self):
        """Text that fits the character budget is returned unchanged."""
        short_text = "This is a short text."
        self.assertEqual(_truncate_text_simple(short_text, max_tokens=100), short_text)

        # Boundary: exactly max_tokens * 4 characters still fits.
        exact_fit = "a" * 400
        self.assertEqual(_truncate_text_simple(exact_fit, max_tokens=100), exact_fit)

    def test_truncate_text_simple_with_truncation(self):
        """Over-budget text keeps its head and tail around the marker."""
        # Distinct head/tail characters make a wrong split detectable; a
        # single repeated character would satisfy startswith/endswith trivially.
        long_text = "h" * 5000 + "t" * 5000  # 10000 chars = ~2500 tokens

        result = _truncate_text_simple(long_text, max_tokens=100)

        self.assertIn("...[truncated]", result)
        self.assertLess(len(result), len(long_text))

        # Budget is 100 * 4 = 400 chars: 300 from the start, 100 from the end.
        self.assertEqual(result, "h" * 300 + "...[truncated]" + "t" * 100)
52+
53+
54+
class TestTruncateSummary(unittest.TestCase):
    """Tests for the `_truncate_summary` entry point."""

    def test_none_and_short_text(self):
        """None and short summaries are passed through untouched."""
        # None input
        self.assertIsNone(_truncate_summary(None, "test.id"))

        # Short text
        short_text = "This is a short summary."
        result = _truncate_summary(short_text, "test.id", max_tokens=100)
        self.assertEqual(result, short_text)

        # Character-length shortcut: fewer characters than max_tokens
        text = "a" * 50
        result = _truncate_summary(text, "test.id", max_tokens=100)
        self.assertEqual(result, text)

    @patch("validmind.ai.test_descriptions._TIKTOKEN_AVAILABLE", False)
    def test_fallback_truncation(self):
        """With tiktoken disabled, truncation uses the character heuristic."""
        # Distinct head/tail characters so the preserved regions are verifiable.
        long_summary = "h" * 5000 + "t" * 5000  # ~2500 estimated tokens

        result = _truncate_summary(long_summary, "test.id", max_tokens=100)

        self.assertIn("...[truncated]", result)
        self.assertLess(len(result), len(long_summary))
        # Budget is 100 * 4 = 400 chars: 300 from the start, 100 from the end.
        self.assertEqual(result, "h" * 300 + "...[truncated]" + "t" * 100)
85+
86+
87+
class TestCodePathSelection(unittest.TestCase):
    """Check which truncation strategy `_truncate_summary` dispatches to."""

    def test_module_state(self):
        """The availability flag and the cached encoding must agree."""
        available = test_desc_module._TIKTOKEN_AVAILABLE
        self.assertIsInstance(available, bool)

        encoding = test_desc_module._TIKTOKEN_ENCODING
        if not available:
            self.assertIsNone(encoding)
        else:
            self.assertIsNotNone(encoding)

    @patch.object(test_desc_module, "_TIKTOKEN_AVAILABLE", True)
    @patch.object(test_desc_module, "_TIKTOKEN_ENCODING")
    @patch.object(test_desc_module, "_estimate_tokens_simple")
    def test_tiktoken_path(self, estimate_stub, encoding_stub):
        """When tiktoken is flagged available, only the encoder is consulted."""
        encoding_stub.encode.return_value = list(range(1000))
        encoding_stub.decode.return_value = "decoded"
        summary_text = "x" * 10000

        outcome = _truncate_summary(summary_text, "test.id", max_tokens=100)

        # The encoder tokenizes once and decodes the head and the tail.
        encoding_stub.encode.assert_called_once_with(summary_text)
        self.assertEqual(encoding_stub.decode.call_count, 2)
        # The character heuristic must stay unused on this path.
        estimate_stub.assert_not_called()

        self.assertIn("decoded", outcome)

    @patch.object(test_desc_module, "_TIKTOKEN_AVAILABLE", False)
    @patch.object(test_desc_module, "_TIKTOKEN_ENCODING")
    @patch.object(test_desc_module, "_estimate_tokens_simple")
    @patch.object(test_desc_module, "_truncate_text_simple")
    def test_fallback_path(self, truncate_stub, estimate_stub, encoding_stub):
        """When tiktoken is flagged unavailable, only the heuristic helpers run."""
        estimate_stub.return_value = 1000
        truncate_stub.return_value = "fallback_result"
        summary_text = "y" * 10000

        outcome = _truncate_summary(summary_text, "test.id", max_tokens=100)

        # Both character-based helpers are invoked exactly once.
        estimate_stub.assert_called_once_with(summary_text)
        truncate_stub.assert_called_once_with(summary_text, 100)
        # The encoder must not be touched at all.
        encoding_stub.encode.assert_not_called()
        encoding_stub.decode.assert_not_called()

        self.assertEqual(outcome, "fallback_result")
138+
139+

validmind/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.10.5"
1+
__version__ = "2.10.6"

validmind/ai/test_descriptions.py

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from concurrent.futures import ThreadPoolExecutor
88
from typing import Any, Dict, List, Optional, Union
99

10-
import tiktoken
11-
1210
from ..client_config import client_config
1311
from ..logging import get_logger
1412
from ..utils import NumpyEncoder, md_to_html, test_id_to_name
@@ -25,6 +23,21 @@
2523

2624
logger = get_logger(__name__)
2725

26+
# Probe for tiktoken once at module load and cache the outcome in the two
# module-level names below, so later calls never retry the import or rebuild
# the encoding.
_TIKTOKEN_AVAILABLE = False
_TIKTOKEN_ENCODING = None

try:
    import tiktoken

    _TIKTOKEN_ENCODING = tiktoken.encoding_for_model("gpt-4o")
    _TIKTOKEN_AVAILABLE = True
except Exception as e:
    # Deliberately broad: Exception already covers ImportError, and any failure
    # while building the encoding should also select the character-based
    # fallback rather than break module import.
    logger.debug(
        f"tiktoken unavailable, will use character-based token estimation: {e}"
    )
40+
2841

2942
def _get_llm_global_context():
3043
# Get the context from the environment variable
@@ -42,27 +55,60 @@ def _get_llm_global_context():
4255
return context if context_enabled and context else None
4356

4457

58+
def _estimate_tokens_simple(text: str, chars_per_token: int = 4) -> int:
    """Estimate token count using a character-based heuristic.

    This is the fallback used when tiktoken is unavailable.

    Args:
        text: The text to estimate.
        chars_per_token: Assumed average characters per token. Defaults to 4,
            a rough figure for English/JSON text.

    Returns:
        ``len(text) // chars_per_token``; any remainder is dropped.

    Raises:
        ValueError: If ``chars_per_token`` is less than 1.
    """
    if chars_per_token < 1:
        raise ValueError("chars_per_token must be a positive integer")
    return len(text) // chars_per_token
65+
66+
67+
def _truncate_text_simple(text: str, max_tokens: int) -> str:
    """Truncate text to a character budget derived from a token budget.

    The budget is ``max_tokens * 4`` characters (~4 characters per token).
    Text within the budget is returned unchanged. Otherwise the start and end
    of the text are kept and joined with a ``...[truncated]`` marker.

    Args:
        text: The text to truncate.
        max_tokens: Token budget; zero or negative values are treated as zero.

    Returns:
        The original text, or ``head + "...[truncated]" + tail`` where head and
        tail together hold exactly the character budget (the marker is extra).
    """
    # Clamp so a zero/negative budget cannot produce negative slice bounds.
    estimated_chars = max(max_tokens, 0) * 4
    if len(text) <= estimated_chars:
        return text

    # Keep up to 100 tokens' worth (~400 chars) of tail, but never more than
    # 25% of the budget; the head takes the rest.
    last_chars = min(400, estimated_chars // 4)
    first_chars = estimated_chars - last_chars

    # text[-0:] is the whole string, so an empty tail must be special-cased.
    tail = text[-last_chars:] if last_chars else ""
    return text[:first_chars] + "...[truncated]" + tail
79+
80+
4581
def _truncate_summary(
    summary: Union[str, None], test_id: str, max_tokens: int = 100_000
) -> Union[str, None]:
    """Shorten a test summary so it stays within a token budget.

    Uses the cached tiktoken encoding when it is available and the
    character-based helpers otherwise. In both cases the start and end of the
    summary are kept and joined with a ``...[truncated]`` marker, and a warning
    naming ``test_id`` is logged.

    Args:
        summary: The summary text, or None.
        test_id: Test identifier, used only in the warning message.
        max_tokens: Maximum number of tokens to keep (the marker is extra).

    Returns:
        ``summary`` unchanged when it is None or within budget, otherwise the
        truncated text.
    """
    if summary is None or len(summary) < max_tokens:
        # Cheap pre-check: fewer characters than max_tokens is treated as
        # within budget (assumes a token is at least one character).
        return summary

    if _TIKTOKEN_AVAILABLE:
        # Use tiktoken for accurate token counting
        summary_tokens = _TIKTOKEN_ENCODING.encode(summary)
        budget = max(max_tokens, 0)

        if len(summary_tokens) > budget:
            logger.warning(
                f"Truncating {test_id} due to context length restrictions..."
                " Generated description may be inaccurate"
            )
            # Split the budget the same way the character fallback does: up to
            # 100 tail tokens, capped at 25% of the budget, the rest for the
            # head. Head and tail therefore never overlap or exceed the budget.
            tail_count = min(100, budget // 4)
            head_count = budget - tail_count

            head = _TIKTOKEN_ENCODING.decode(summary_tokens[:head_count])
            # tokens[-0:] is the whole list, so an empty tail is special-cased.
            tail = (
                _TIKTOKEN_ENCODING.decode(summary_tokens[-tail_count:])
                if tail_count
                else ""
            )
            summary = head + "...[truncated]" + tail
    else:
        # Fallback to character-based estimation
        estimated_tokens = _estimate_tokens_simple(summary)

        if estimated_tokens > max_tokens:
            logger.warning(
                f"Truncating {test_id} (estimated) due to context length restrictions..."
                " Generated description may be inaccurate"
            )
            summary = _truncate_text_simple(summary, max_tokens)

    return summary
68114

0 commit comments

Comments
 (0)