diff --git a/README.md b/README.md index 511435e..7c721b3 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,50 @@ export NEO4J_USERNAME= your NEO4J_USERNAME export NEO4J_PASSWORD= your NEO4J_PASSWORD +#### Using MiniMax as an alternative LLM provider + +Medical-Graph-RAG supports [MiniMax](https://www.minimaxi.com/) as an alternative LLM provider via its OpenAI-compatible API. MiniMax-M2.7 offers a 204K context window, which is well-suited for processing long medical documents. + +To use MiniMax instead of OpenAI: + +```bash +export MINIMAX_API_KEY=your_minimax_api_key +# Optional: explicitly set the provider (auto-detected from MINIMAX_API_KEY) +export LLM_PROVIDER=minimax +# Optional: override the default model +export LLM_MODEL=MiniMax-M2.7-highspeed +``` + +Available MiniMax models: `MiniMax-M2.7` (default, 204K context), `MiniMax-M2.7-highspeed` (faster, 204K context). + +> **Note:** MiniMax does not provide a public embedding API. When using MiniMax as the LLM provider, embeddings will still use OpenAI's `text-embedding-3-small` model. Ensure `OPENAI_API_KEY` is set for embedding operations. + +For the `nano_graphrag` pipeline, pass the MiniMax completion function: + +```python +from nano_graphrag import GraphRAG +from nano_graphrag._llm import minimax_m27_complete + +graph_func = GraphRAG( +    working_dir="./nanotest", +    best_model_func=minimax_m27_complete, +    cheap_model_func=minimax_m27_complete, +) +``` + +For the CAMEL agent framework, use the MiniMax model type: + +```python +from camel.models import ModelFactory +from camel.types import ModelPlatformType, ModelType + +model = ModelFactory.create( +    model_platform=ModelPlatformType.MINIMAX, +    model_type=ModelType.MINIMAX_M27, +    model_config_dict={"temperature": 0.2}, +) +``` + ### 2. Construct the graph (use "mimic_ex" dataset as an example) 1. 
Download mimic_ex [here](https://huggingface.co/datasets/Morson/mimic_ex), put that under your data path, like ./dataset/mimic_ex diff --git a/camel/configs/__init__.py b/camel/configs/__init__.py index a4eb671..294cc14 100644 --- a/camel/configs/__init__.py +++ b/camel/configs/__init__.py @@ -16,6 +16,7 @@ from .gemini_config import Gemini_API_PARAMS, GeminiConfig from .groq_config import GROQ_API_PARAMS, GroqConfig from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig +from .minimax_config import MINIMAX_API_PARAMS, MiniMaxConfig from .mistral_config import MISTRAL_API_PARAMS, MistralConfig from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig, OpenSourceConfig @@ -33,6 +34,8 @@ 'OpenSourceConfig', 'LiteLLMConfig', 'LITELLM_API_PARAMS', + 'MiniMaxConfig', + 'MINIMAX_API_PARAMS', 'OllamaConfig', 'OLLAMA_API_PARAMS', 'ZhipuAIConfig', diff --git a/camel/configs/minimax_config.py b/camel/configs/minimax_config.py new file mode 100644 index 0000000..1b854ad --- /dev/null +++ b/camel/configs/minimax_config.py @@ -0,0 +1,50 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. 
=========== +from __future__ import annotations + +from typing import Optional, Sequence, Union + +from openai._types import NOT_GIVEN, NotGiven + +from camel.configs.base_config import BaseConfig + + +class MiniMaxConfig(BaseConfig): + r"""Defines the parameters for generating chat completions using the + MiniMax API (OpenAI-compatible). + + MiniMax requires temperature in (0.0, 1.0]. Values outside this range + are clamped automatically. + + Args: + temperature (float, optional): Sampling temperature, clamped to + (0.0, 1.0] for MiniMax. (default: :obj:`0.2`) + top_p (float, optional): Nucleus sampling parameter. + (default: :obj:`1.0`) + max_tokens (int, optional): Maximum number of tokens to generate. + (default: :obj:`NOT_GIVEN`) + stream (bool, optional): Whether to stream partial results. + (default: :obj:`False`) + stop (str or list, optional): Stop sequences. + (default: :obj:`NOT_GIVEN`) + """ + + temperature: float = 0.2 + top_p: float = 1.0 + stream: bool = False + stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN + max_tokens: Union[int, NotGiven] = NOT_GIVEN + + +MINIMAX_API_PARAMS = {param for param in MiniMaxConfig.model_fields.keys()} diff --git a/camel/models/__init__.py b/camel/models/__init__.py index 833217c..ebfe456 100644 --- a/camel/models/__init__.py +++ b/camel/models/__init__.py @@ -17,6 +17,7 @@ from .gemini_model import GeminiModel from .groq_model import GroqModel from .litellm_model import LiteLLMModel +from .minimax_model import MiniMaxModel from .mistral_model import MistralModel from .model_factory import ModelFactory from .nemotron_model import NemotronModel @@ -33,6 +34,7 @@ 'OpenAIModel', 'AzureOpenAIModel', 'AnthropicModel', + 'MiniMaxModel', 'MistralModel', 'GroqModel', 'StubModel', diff --git a/camel/models/minimax_model.py b/camel/models/minimax_model.py new file mode 100644 index 0000000..f1161f4 --- /dev/null +++ b/camel/models/minimax_model.py @@ -0,0 +1,111 @@ +# =========== Copyright 2023 @ CAMEL-AI.org. 
All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== +import os +from typing import Any, Dict, List, Optional, Union + +from openai import OpenAI, Stream + +from camel.configs import MINIMAX_API_PARAMS +from camel.messages import OpenAIMessage +from camel.models import BaseModelBackend +from camel.types import ChatCompletion, ChatCompletionChunk, ModelType +from camel.utils import ( + BaseTokenCounter, + OpenAITokenCounter, + api_keys_required, +) + + +class MiniMaxModel(BaseModelBackend): + r"""MiniMax API in a unified BaseModelBackend interface. + + Uses MiniMax's OpenAI-compatible endpoint at https://api.minimax.io/v1. + Supports MiniMax-M2.7 and MiniMax-M2.7-highspeed models (204K context). + Temperature is automatically clamped to (0.0, 1.0]. 
+ """ + + def __init__( + self, + model_type: ModelType, + model_config_dict: Dict[str, Any], + api_key: Optional[str] = None, + url: Optional[str] = None, + token_counter: Optional[BaseTokenCounter] = None, + ) -> None: + super().__init__( + model_type, model_config_dict, api_key, url, token_counter + ) + self._url = url or os.environ.get( + "MINIMAX_API_BASE_URL", "https://api.minimax.io/v1" + ) + self._api_key = api_key or os.environ.get("MINIMAX_API_KEY") + self._client = OpenAI( + timeout=60, + max_retries=3, + base_url=self._url, + api_key=self._api_key, + ) + + @property + def token_counter(self) -> BaseTokenCounter: + if not self._token_counter: + self._token_counter = OpenAITokenCounter(ModelType.GPT_4O) + return self._token_counter + + @api_keys_required("MINIMAX_API_KEY") + def run( + self, + messages: List[OpenAIMessage], + ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]: + r"""Runs inference of MiniMax chat completion. + + Args: + messages (List[OpenAIMessage]): Message list with the chat history + in OpenAI API format. + + Returns: + Union[ChatCompletion, Stream[ChatCompletionChunk]]: + `ChatCompletion` in the non-stream mode, or + `Stream[ChatCompletionChunk]` in the stream mode. + """ + config = dict(self.model_config_dict) + # Clamp temperature to MiniMax's valid range (0.0, 1.0] + if "temperature" in config: + config["temperature"] = max(0.01, min(config["temperature"], 1.0)) + + response = self._client.chat.completions.create( + messages=messages, + model=self.model_type.value, + **config, + ) + return response + + def check_model_config(self): + r"""Check whether the model configuration contains any + unexpected arguments to MiniMax API. + + Raises: + ValueError: If the model configuration dictionary contains any + unexpected arguments. + """ + for param in self.model_config_dict: + if param not in MINIMAX_API_PARAMS: + raise ValueError( + f"Unexpected argument `{param}` is " + "input into MiniMax model backend." 
+ ) + + @property + def stream(self) -> bool: + return self.model_config_dict.get('stream', False) diff --git a/camel/models/model_factory.py b/camel/models/model_factory.py index 5bac5f1..db9b5e9 100644 --- a/camel/models/model_factory.py +++ b/camel/models/model_factory.py @@ -19,6 +19,7 @@ from camel.models.gemini_model import GeminiModel from camel.models.groq_model import GroqModel from camel.models.litellm_model import LiteLLMModel +from camel.models.minimax_model import MiniMaxModel from camel.models.mistral_model import MistralModel from camel.models.ollama_model import OllamaModel from camel.models.open_source_model import OpenSourceModel @@ -90,6 +91,8 @@ def create( model_class = GeminiModel elif model_platform.is_mistral and model_type.is_mistral: model_class = MistralModel + elif model_platform.is_minimax and model_type.is_minimax: + model_class = MiniMaxModel elif model_type == ModelType.STUB: model_class = StubModel else: diff --git a/camel/types/enums.py b/camel/types/enums.py index 0cbdef0..dd4b6ed 100644 --- a/camel/types/enums.py +++ b/camel/types/enums.py @@ -82,6 +82,10 @@ class ModelType(Enum): MISTRAL_MIXTRAL_8x22B = "open-mixtral-8x22b" MISTRAL_CODESTRAL_MAMBA = "open-codestral-mamba" + # MiniMax AI Models + MINIMAX_M27 = "MiniMax-M2.7" + MINIMAX_M27_HIGHSPEED = "MiniMax-M2.7-highspeed" + @property def value_for_tiktoken(self) -> str: return ( @@ -191,6 +195,14 @@ def is_nvidia(self) -> bool: ModelType.NEMOTRON_4_REWARD, } + @property + def is_minimax(self) -> bool: + r"""Returns whether this type of models is a MiniMax model.""" + return self in { + ModelType.MINIMAX_M27, + ModelType.MINIMAX_M27_HIGHSPEED, + } + @property def is_gemini(self) -> bool: return self in {ModelType.GEMINI_1_5_FLASH, ModelType.GEMINI_1_5_PRO} @@ -265,6 +277,11 @@ def token_limit(self) -> int: ModelType.CLAUDE_3_5_SONNET, }: return 200_000 + elif self in { + ModelType.MINIMAX_M27, + ModelType.MINIMAX_M27_HIGHSPEED, + }: + return 204_000 elif self in { 
ModelType.MISTRAL_CODESTRAL_MAMBA, }: @@ -448,6 +465,7 @@ class ModelPlatformType(Enum): GEMINI = "gemini" VLLM = "vllm" MISTRAL = "mistral" + MINIMAX = "minimax" @property def is_openai(self) -> bool: @@ -504,6 +522,11 @@ def is_gemini(self) -> bool: r"""Returns whether this platform is Gemini.""" return self is ModelPlatformType.GEMINI + @property + def is_minimax(self) -> bool: + r"""Returns whether this platform is MiniMax.""" + return self is ModelPlatformType.MINIMAX + class AudioModelType(Enum): TTS_1 = "tts-1" diff --git a/nano_graphrag/_llm.py b/nano_graphrag/_llm.py index ccdbf6d..7585ae8 100644 --- a/nano_graphrag/_llm.py +++ b/nano_graphrag/_llm.py @@ -36,6 +36,41 @@ async def openai_complete_if_cache( return response.choices[0].message.content +async def minimax_complete_if_cache( + model, prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + """MiniMax completion via OpenAI-compatible API with temperature clamping.""" + minimax_async_client = AsyncOpenAI( + api_key=os.getenv("MINIMAX_API_KEY"), + base_url=os.getenv("MINIMAX_API_BASE_URL", "https://api.minimax.io/v1"), + ) + hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None) + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.extend(history_messages) + messages.append({"role": "user", "content": prompt}) + if hashing_kv is not None: + args_hash = compute_args_hash(model, messages) + if_cache_return = await hashing_kv.get_by_id(args_hash) + if if_cache_return is not None: + return if_cache_return["return"] + + # Clamp temperature for MiniMax: must be in (0.0, 1.0] + if "temperature" in kwargs: + kwargs["temperature"] = max(0.01, min(kwargs["temperature"], 1.0)) + + response = await minimax_async_client.chat.completions.create( + model=model, messages=messages, **kwargs + ) + + if hashing_kv is not None: + await hashing_kv.upsert( + {args_hash: {"return": response.choices[0].message.content, "model": model}} + ) + return 
response.choices[0].message.content + + async def gpt_4o_complete( prompt, system_prompt=None, history_messages=[], **kwargs ) -> str: @@ -60,6 +95,32 @@ async def gpt_4o_mini_complete( ) +async def minimax_m27_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + """MiniMax M2.7 completion (204K context).""" + return await minimax_complete_if_cache( + "MiniMax-M2.7", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + +async def minimax_m27_highspeed_complete( + prompt, system_prompt=None, history_messages=[], **kwargs +) -> str: + """MiniMax M2.7-highspeed completion (204K context, faster inference).""" + return await minimax_complete_if_cache( + "MiniMax-M2.7-highspeed", + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + **kwargs, + ) + + @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192) async def openai_embedding(texts: list[str]) -> np.ndarray: openai_async_client = AsyncOpenAI( diff --git a/summerize.py b/summerize.py index 4510d39..207c672 100644 --- a/summerize.py +++ b/summerize.py @@ -4,8 +4,30 @@ import tiktoken import os -# Add your own OpenAI API key -openai_api_key = os.getenv("OPENAI_API_KEY") +# Provider configuration presets (shared with utils.py) +_PROVIDER_PRESETS = { + "openai": { + "api_key_env": "OPENAI_API_KEY", + "base_url_env": "OPENAI_API_BASE_URL", + "default_model": "gpt-4-1106-preview", + }, + "minimax": { + "api_key_env": "MINIMAX_API_KEY", + "base_url_env": "MINIMAX_API_BASE_URL", + "default_base_url": "https://api.minimax.io/v1", + "default_model": "MiniMax-M2.7", + }, +} + + +def _detect_provider(): + explicit = os.getenv("LLM_PROVIDER", "").lower() + if explicit in _PROVIDER_PRESETS: + return explicit + if os.getenv("MINIMAX_API_KEY"): + return "minimax" + return "openai" + sum_prompt = """ Generate a structured summary from the provided medical source (report, paper, or book), strictly adhering to the 
following categories. The summary should list key information under each category in a concise format: 'CATEGORY_NAME: Key information'. No additional explanations or detailed descriptions are necessary unless directly related to the categories: @@ -33,12 +55,20 @@ """ def call_openai_api(chunk): - client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_API_BASE_URL") + provider = _detect_provider() + preset = _PROVIDER_PRESETS[provider] + api_key = os.getenv(preset["api_key_env"]) + base_url = os.getenv( + preset["base_url_env"], + preset.get("default_base_url"), ) + client = OpenAI(api_key=api_key, base_url=base_url) + model = os.getenv("LLM_MODEL", preset["default_model"]) + temperature = 0.5 + if provider == "minimax": + temperature = max(0.01, min(temperature, 1.0)) response = client.chat.completions.create( - model="gpt-4-1106-preview", + model=model, messages=[ {"role": "system", "content": sum_prompt}, {"role": "user", "content": f" {chunk}"}, @@ -46,7 +76,7 @@ def call_openai_api(chunk): max_tokens=500, n=1, stop=None, - temperature=0.5, + temperature=temperature, ) return response.choices[0].message.content diff --git a/tests/test_minimax_integration.py b/tests/test_minimax_integration.py new file mode 100644 index 0000000..132224d --- /dev/null +++ b/tests/test_minimax_integration.py @@ -0,0 +1,88 @@ +"""Integration tests for MiniMax LLM provider. + +These tests require a valid MINIMAX_API_KEY environment variable and +make real API calls to the MiniMax service. They are skipped when +the key is not available. 
+""" +import os +import unittest +import asyncio + +MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY") +SKIP_REASON = "MINIMAX_API_KEY not set" + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxUtilsIntegration(unittest.TestCase): + """Integration test for utils.py call_llm with MiniMax.""" + + def test_call_llm_with_minimax(self): + """Test end-to-end LLM call via MiniMax provider.""" + os.environ["LLM_PROVIDER"] = "minimax" + from utils import call_llm + result = call_llm( + "You are a helpful assistant.", + "What is 2+2? Answer with just the number." + ) + self.assertIsInstance(result, str) + self.assertIn("4", result) + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxNanoGraphRAGIntegration(unittest.TestCase): + """Integration test for nano_graphrag MiniMax functions.""" + + def test_minimax_m27_complete(self): + """Test MiniMax M2.7 completion via nano_graphrag.""" + from nano_graphrag._llm import minimax_m27_complete + + async def _run(): + result = await minimax_m27_complete( + "What is 1+1? Answer with just the number." + ) + return result + + result = asyncio.get_event_loop().run_until_complete(_run()) + self.assertIsInstance(result, str) + self.assertIn("2", result) + + def test_minimax_m27_highspeed_complete(self): + """Test MiniMax M2.7-highspeed completion.""" + from nano_graphrag._llm import minimax_m27_highspeed_complete + + async def _run(): + result = await minimax_m27_highspeed_complete( + "What is 3+3? Answer with just the number." 
+ ) + return result + + result = asyncio.get_event_loop().run_until_complete(_run()) + self.assertIsInstance(result, str) + self.assertIn("6", result) + + +@unittest.skipUnless(MINIMAX_API_KEY, SKIP_REASON) +class TestMiniMaxModelIntegration(unittest.TestCase): + """Integration test for CAMEL MiniMax model backend.""" + + def test_camel_minimax_model_run(self): + """Test CAMEL ModelFactory-created MiniMax model.""" + from camel.models import ModelFactory + from camel.types import ModelPlatformType, ModelType + + model = ModelFactory.create( + model_platform=ModelPlatformType.MINIMAX, + model_type=ModelType.MINIMAX_M27, + model_config_dict={"temperature": 0.5, "max_tokens": 50}, + ) + messages = [ + {"role": "user", "content": "What is 5+5? Answer with just the number."} + ] + response = model.run(messages) + content = response.choices[0].message.content + self.assertIsInstance(content, str) + self.assertIn("10", content) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_minimax_unit.py b/tests/test_minimax_unit.py new file mode 100644 index 0000000..08e3a39 --- /dev/null +++ b/tests/test_minimax_unit.py @@ -0,0 +1,351 @@ +"""Unit tests for MiniMax LLM provider integration.""" +import os +import unittest +from unittest.mock import MagicMock, patch, AsyncMock +import asyncio + + +class TestProviderDetection(unittest.TestCase): + """Test provider auto-detection logic in utils.py.""" + + def _import_utils(self): + """Import utils module functions directly.""" + import importlib + import sys + # We need to handle the import chain carefully + # utils.py imports from summerize.py, so we mock that chain + if 'utils' in sys.modules: + importlib.reload(sys.modules['utils']) + # Import the detection functions from utils + from utils import _detect_provider, PROVIDER_PRESETS, _clamp_temperature + return _detect_provider, PROVIDER_PRESETS, _clamp_temperature + + @patch.dict(os.environ, {}, clear=True) + def test_default_provider_is_openai(self): + 
"""When no env vars are set, default to openai.""" + _detect_provider, _, _ = self._import_utils() + self.assertEqual(_detect_provider(), "openai") + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}, clear=True) + def test_auto_detect_minimax_from_api_key(self): + """MINIMAX_API_KEY triggers minimax auto-detection.""" + _detect_provider, _, _ = self._import_utils() + self.assertEqual(_detect_provider(), "minimax") + + @patch.dict(os.environ, {"LLM_PROVIDER": "minimax"}, clear=True) + def test_explicit_provider_env_var(self): + """LLM_PROVIDER env var takes priority.""" + _detect_provider, _, _ = self._import_utils() + self.assertEqual(_detect_provider(), "minimax") + + @patch.dict(os.environ, { + "LLM_PROVIDER": "openai", + "MINIMAX_API_KEY": "test-key" + }, clear=True) + def test_explicit_provider_overrides_auto_detect(self): + """LLM_PROVIDER=openai overrides MINIMAX_API_KEY auto-detection.""" + _detect_provider, _, _ = self._import_utils() + self.assertEqual(_detect_provider(), "openai") + + def test_provider_presets_have_required_keys(self): + """Verify all provider presets have required configuration keys.""" + _, PROVIDER_PRESETS, _ = self._import_utils() + for name, preset in PROVIDER_PRESETS.items(): + self.assertIn("api_key_env", preset, f"{name} missing api_key_env") + self.assertIn("default_model", preset, f"{name} missing default_model") + + def test_minimax_preset_values(self): + """Verify MiniMax preset has correct defaults.""" + _, PROVIDER_PRESETS, _ = self._import_utils() + mm = PROVIDER_PRESETS["minimax"] + self.assertEqual(mm["api_key_env"], "MINIMAX_API_KEY") + self.assertEqual(mm["default_base_url"], "https://api.minimax.io/v1") + self.assertEqual(mm["default_model"], "MiniMax-M2.7") + self.assertIsNone(mm["embedding_model"]) + + +class TestTemperatureClamping(unittest.TestCase): + """Test temperature clamping for MiniMax.""" + + def _import_clamp(self): + from utils import _clamp_temperature + return _clamp_temperature + + def 
test_minimax_clamp_zero(self): + """MiniMax temp=0 should clamp to 0.01.""" + clamp = self._import_clamp() + self.assertAlmostEqual(clamp(0.0, "minimax"), 0.01) + + def test_minimax_clamp_negative(self): + """MiniMax negative temp should clamp to 0.01.""" + clamp = self._import_clamp() + self.assertAlmostEqual(clamp(-1.0, "minimax"), 0.01) + + def test_minimax_clamp_above_one(self): + """MiniMax temp>1.0 should clamp to 1.0.""" + clamp = self._import_clamp() + self.assertAlmostEqual(clamp(1.5, "minimax"), 1.0) + + def test_minimax_valid_temp_unchanged(self): + """MiniMax temp in valid range should stay unchanged.""" + clamp = self._import_clamp() + self.assertAlmostEqual(clamp(0.5, "minimax"), 0.5) + + def test_minimax_temp_one_unchanged(self): + """MiniMax temp=1.0 is the upper bound, should be unchanged.""" + clamp = self._import_clamp() + self.assertAlmostEqual(clamp(1.0, "minimax"), 1.0) + + def test_openai_no_clamping(self): + """OpenAI provider should not clamp temperature.""" + clamp = self._import_clamp() + self.assertAlmostEqual(clamp(1.5, "openai"), 1.5) + self.assertAlmostEqual(clamp(0.0, "openai"), 0.0) + + +class TestSummerizeProviderDetection(unittest.TestCase): + """Test provider detection in summerize.py.""" + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}, clear=True) + def test_summerize_detects_minimax(self): + from summerize import _detect_provider + self.assertEqual(_detect_provider(), "minimax") + + @patch.dict(os.environ, {}, clear=True) + def test_summerize_defaults_openai(self): + from summerize import _detect_provider + self.assertEqual(_detect_provider(), "openai") + + +class TestMiniMaxModelType(unittest.TestCase): + """Test MiniMax model type enum values.""" + + def test_minimax_model_types_exist(self): + from camel.types import ModelType + self.assertEqual(ModelType.MINIMAX_M27.value, "MiniMax-M2.7") + self.assertEqual(ModelType.MINIMAX_M27_HIGHSPEED.value, "MiniMax-M2.7-highspeed") + + def test_minimax_is_minimax(self): + 
from camel.types import ModelType + self.assertTrue(ModelType.MINIMAX_M27.is_minimax) + self.assertTrue(ModelType.MINIMAX_M27_HIGHSPEED.is_minimax) + + def test_openai_is_not_minimax(self): + from camel.types import ModelType + self.assertFalse(ModelType.GPT_4O.is_minimax) + + def test_minimax_is_not_openai(self): + from camel.types import ModelType + self.assertFalse(ModelType.MINIMAX_M27.is_openai) + + def test_minimax_token_limit(self): + from camel.types import ModelType + self.assertEqual(ModelType.MINIMAX_M27.token_limit, 204_000) + self.assertEqual(ModelType.MINIMAX_M27_HIGHSPEED.token_limit, 204_000) + + +class TestMiniMaxPlatformType(unittest.TestCase): + """Test MiniMax platform type enum.""" + + def test_minimax_platform_exists(self): + from camel.types import ModelPlatformType + self.assertEqual(ModelPlatformType.MINIMAX.value, "minimax") + + def test_is_minimax(self): + from camel.types import ModelPlatformType + self.assertTrue(ModelPlatformType.MINIMAX.is_minimax) + self.assertFalse(ModelPlatformType.OPENAI.is_minimax) + + +class TestMiniMaxConfig(unittest.TestCase): + """Test MiniMax configuration class.""" + + def test_config_defaults(self): + from camel.configs import MiniMaxConfig + cfg = MiniMaxConfig() + self.assertEqual(cfg.temperature, 0.2) + self.assertEqual(cfg.top_p, 1.0) + self.assertFalse(cfg.stream) + + def test_config_api_params(self): + from camel.configs import MINIMAX_API_PARAMS + self.assertIn("temperature", MINIMAX_API_PARAMS) + self.assertIn("top_p", MINIMAX_API_PARAMS) + self.assertIn("max_tokens", MINIMAX_API_PARAMS) + self.assertIn("stream", MINIMAX_API_PARAMS) + self.assertIn("stop", MINIMAX_API_PARAMS) + + +class TestMiniMaxModel(unittest.TestCase): + """Test MiniMax model backend.""" + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_model_initialization(self): + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelType + model = MiniMaxModel( + 
model_type=ModelType.MINIMAX_M27, + model_config_dict={"temperature": 0.5}, + ) + self.assertEqual(model._url, "https://api.minimax.io/v1") + self.assertEqual(model._api_key, "test-key") + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_model_custom_url(self): + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelType + model = MiniMaxModel( + model_type=ModelType.MINIMAX_M27, + model_config_dict={}, + url="https://custom.minimax.io/v1", + ) + self.assertEqual(model._url, "https://custom.minimax.io/v1") + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_check_model_config_valid(self): + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelType + model = MiniMaxModel( + model_type=ModelType.MINIMAX_M27, + model_config_dict={"temperature": 0.5, "max_tokens": 100}, + ) + model.check_model_config() # Should not raise + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_check_model_config_invalid(self): + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelType + with self.assertRaises(ValueError): + MiniMaxModel( + model_type=ModelType.MINIMAX_M27, + model_config_dict={"invalid_param": True}, + ) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_stream_property(self): + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelType + model = MiniMaxModel( + model_type=ModelType.MINIMAX_M27, + model_config_dict={"stream": True}, + ) + self.assertTrue(model.stream) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_temperature_clamping_in_run(self): + """Verify that run() clamps temperature before sending to API.""" + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelType + model = MiniMaxModel( + model_type=ModelType.MINIMAX_M27, + model_config_dict={"temperature": 0.0}, + ) + mock_response = MagicMock() + 
mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "test response" + model._client.chat.completions.create = MagicMock(return_value=mock_response) + model.run([{"role": "user", "content": "test"}]) + call_kwargs = model._client.chat.completions.create.call_args + self.assertGreater(call_kwargs[1]["temperature"], 0.0) + + +class TestModelFactory(unittest.TestCase): + """Test MiniMax integration in ModelFactory.""" + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_factory_creates_minimax_model(self): + from camel.models import ModelFactory + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelPlatformType, ModelType + model = ModelFactory.create( + model_platform=ModelPlatformType.MINIMAX, + model_type=ModelType.MINIMAX_M27, + model_config_dict={"temperature": 0.2}, + ) + self.assertIsInstance(model, MiniMaxModel) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_factory_creates_minimax_highspeed(self): + from camel.models import ModelFactory + from camel.models.minimax_model import MiniMaxModel + from camel.types import ModelPlatformType, ModelType + model = ModelFactory.create( + model_platform=ModelPlatformType.MINIMAX, + model_type=ModelType.MINIMAX_M27_HIGHSPEED, + model_config_dict={}, + ) + self.assertIsInstance(model, MiniMaxModel) + + def test_factory_rejects_wrong_platform_model_pair(self): + from camel.models import ModelFactory + from camel.types import ModelPlatformType, ModelType + with self.assertRaises(ValueError): + ModelFactory.create( + model_platform=ModelPlatformType.MINIMAX, + model_type=ModelType.GPT_4O, + model_config_dict={}, + ) + + +class TestNanoGraphRAGMiniMax(unittest.TestCase): + """Test MiniMax functions in nano_graphrag._llm.""" + + def test_minimax_complete_if_cache_exists(self): + from nano_graphrag._llm import minimax_complete_if_cache + self.assertTrue(callable(minimax_complete_if_cache)) + + def 
test_minimax_m27_complete_exists(self): + from nano_graphrag._llm import minimax_m27_complete + self.assertTrue(callable(minimax_m27_complete)) + + def test_minimax_m27_highspeed_complete_exists(self): + from nano_graphrag._llm import minimax_m27_highspeed_complete + self.assertTrue(callable(minimax_m27_highspeed_complete)) + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_minimax_complete_if_cache_clamps_temperature(self): + """Test that temperature is clamped in minimax_complete_if_cache.""" + from nano_graphrag._llm import minimax_complete_if_cache + + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "test" + + async def _run(): + with patch('nano_graphrag._llm.AsyncOpenAI') as MockClient: + mock_instance = MagicMock() + mock_instance.chat.completions.create = AsyncMock(return_value=mock_response) + MockClient.return_value = mock_instance + await minimax_complete_if_cache( + "MiniMax-M2.7", "test prompt", temperature=0.0 + ) + call_kwargs = mock_instance.chat.completions.create.call_args + self.assertGreaterEqual(call_kwargs[1].get("temperature", 0.01), 0.01) + + asyncio.get_event_loop().run_until_complete(_run()) + + +class TestNanoGraphRAGMiniMaxCaching(unittest.TestCase): + """Test MiniMax caching in nano_graphrag._llm.""" + + @patch.dict(os.environ, {"MINIMAX_API_KEY": "test-key"}) + def test_cache_hit_returns_cached_value(self): + """Test that cached results are returned without API call.""" + from nano_graphrag._llm import minimax_complete_if_cache + + mock_kv = MagicMock() + mock_kv.get_by_id = AsyncMock(return_value={"return": "cached result"}) + + async def _run(): + with patch('nano_graphrag._llm.AsyncOpenAI'): + result = await minimax_complete_if_cache( + "MiniMax-M2.7", "test prompt", hashing_kv=mock_kv + ) + self.assertEqual(result, "cached result") + + asyncio.get_event_loop().run_until_complete(_run()) + + +if __name__ == "__main__": + 
unittest.main() diff --git a/utils.py b/utils.py index 411b264..1d5fe1f 100644 --- a/utils.py +++ b/utils.py @@ -15,14 +15,70 @@ Modify the response to the question using the provided references. Include precise citations relevant to your answer. You may use multiple citations simultaneously, denoting each with the reference index number. For example, cite the first and third documents as [1][3]. If the references do not pertain to the response, simply provide a concise answer to the original question. """ -# Add your own OpenAI API key -openai_api_key = os.getenv("OPENAI_API_KEY") - -def get_embedding(text, mod = "text-embedding-3-small"): - client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_API_BASE_URL") +# Provider configuration presets +PROVIDER_PRESETS = { + "openai": { + "api_key_env": "OPENAI_API_KEY", + "base_url_env": "OPENAI_API_BASE_URL", + "default_model": "gpt-4-1106-preview", + "embedding_model": "text-embedding-3-small", + }, + "minimax": { + "api_key_env": "MINIMAX_API_KEY", + "base_url_env": "MINIMAX_API_BASE_URL", + "default_base_url": "https://api.minimax.io/v1", + "default_model": "MiniMax-M2.7", + "embedding_model": None, # MiniMax does not have a public embedding API + }, +} + + +def _detect_provider(): + """Auto-detect LLM provider from environment variables. + + Priority: LLM_PROVIDER env var > MINIMAX_API_KEY presence > default openai. 
+ """ + explicit = os.getenv("LLM_PROVIDER", "").lower() + if explicit in PROVIDER_PRESETS: + return explicit + if os.getenv("MINIMAX_API_KEY"): + return "minimax" + return "openai" + + +def _get_client(provider=None): + """Build an OpenAI-compatible client for the active provider.""" + provider = provider or _detect_provider() + preset = PROVIDER_PRESETS[provider] + api_key = os.getenv(preset["api_key_env"]) + base_url = os.getenv( + preset["base_url_env"], + preset.get("default_base_url"), ) + return OpenAI(api_key=api_key, base_url=base_url), preset + + +def _clamp_temperature(temperature, provider): + """Clamp temperature to valid range for the provider.""" + if provider == "minimax": + # MiniMax requires temperature in (0.0, 1.0] + return max(0.01, min(temperature, 1.0)) + return temperature + + +def get_embedding(text, mod=None): + provider = _detect_provider() + preset = PROVIDER_PRESETS[provider] + if mod is None: + mod = preset.get("embedding_model") or "text-embedding-3-small" + if provider == "minimax" and preset.get("embedding_model") is None: + # Fall back to OpenAI for embeddings when using MiniMax as LLM provider + client = OpenAI( + api_key=os.getenv("OPENAI_API_KEY"), + base_url=os.getenv("OPENAI_API_BASE_URL"), + ) + else: + client, _ = _get_client(provider) response = client.embeddings.create( input=text, @@ -83,12 +139,12 @@ def add_sum(n4j,content,gid): return s def call_llm(sys, user): - client = OpenAI( - api_key=os.getenv("OPENAI_API_KEY"), - base_url=os.getenv("OPENAI_API_BASE_URL") - ) + provider = _detect_provider() + client, preset = _get_client(provider) + model = os.getenv("LLM_MODEL", preset["default_model"]) + temperature = _clamp_temperature(0.5, provider) response = client.chat.completions.create( - model="gpt-4-1106-preview", + model=model, messages=[ {"role": "system", "content": sys}, {"role": "user", "content": f" {user}"}, @@ -96,7 +152,7 @@ def call_llm(sys, user): max_tokens=500, n=1, stop=None, - temperature=0.5, + 
temperature=temperature, ) return response.choices[0].message.content