Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/test_mooncake_force_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,56 @@ def test_enable_hard_pin_default_off(self):
assert restored.enable_hard_pin is False


class TestMooncakeEnvDefaults:
    """Checks for the ``MC_STORE_MEMCPY`` default applied by ``MooncakeConfig``."""

    def test_tcp_memcpy_default_is_applied_by_export_env(self):
        """``export_env`` on a TCP config sets ``MC_STORE_MEMCPY`` to ``"0"``."""
        tcp_config = MooncakeConfig(protocol="tcp")

        # Start from an empty environment; patch.dict restores it on exit,
        # so the assertion has to run while the patch is still active.
        with patch.dict(os.environ, {}, clear=True):
            tcp_config.export_env()

            assert os.environ["MC_STORE_MEMCPY"] == "0"

    def test_tcp_memcpy_default_preserves_user_override(self):
        """A value already present in the environment is left untouched."""
        tcp_config = MooncakeConfig(protocol="tcp")
        user_env = {"MC_STORE_MEMCPY": "1"}

        with patch.dict(os.environ, user_env, clear=True):
            tcp_config.apply_env_defaults()

            assert os.environ["MC_STORE_MEMCPY"] == "1"

    def test_tcp_memcpy_default_not_applied_for_rdma(self):
        """An RDMA config does not introduce ``MC_STORE_MEMCPY``."""
        rdma_config = MooncakeConfig(protocol="rdma")

        with patch.dict(os.environ, {}, clear=True):
            rdma_config.apply_env_defaults()

            assert "MC_STORE_MEMCPY" not in os.environ

    def test_direct_store_setup_applies_tcp_memcpy_before_mooncake_client_setup(self):
        """``store.setup()`` must set the default before the raw client is built."""
        tcp_config = MooncakeConfig(protocol="tcp", async_put_pool_size=0)
        raw_store = MagicMock()
        raw_store.setup.return_value = 0

        # Trivial subclass so class-level patches stay local to this test.
        class ConcreteStore(MooncakeHiddenStateStore):
            pass

        def make_raw_store():
            # Stand-in for the Mooncake client constructor: by the time it is
            # invoked, the env default must already be in place.
            assert os.environ["MC_STORE_MEMCPY"] == "0"
            return raw_store

        store = ConcreteStore(tcp_config)
        with (
            patch.dict(os.environ, {}, clear=True),
            patch("torchspec.transfer.mooncake.store.MooncakeDistributedStore", make_raw_store),
            patch.object(ConcreteStore, "_verify_force_delete"),
            patch.object(ConcreteStore, "_build_replicate_config"),
            patch("torch.cuda.is_available", return_value=False),
        ):
            store.setup()

        raw_store.setup.assert_called_once()


# ---------------------------------------------------------------------------
# Tests 2-3: _verify_force_delete
# ---------------------------------------------------------------------------
Expand Down
5 changes: 2 additions & 3 deletions tests/test_placement_group.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from argparse import Namespace
import importlib.util
import sys
import types
from argparse import Namespace
from pathlib import Path
from unittest.mock import MagicMock, patch

import pytest


repo_root = Path(__file__).resolve().parents[1]
torchspec_pkg = sys.modules.get("torchspec")
if torchspec_pkg is None and importlib.util.find_spec("torch") is None:
Expand Down Expand Up @@ -67,8 +66,8 @@ def __init__(self, **kwargs):
sys.modules["torchspec.ray.train_group"] = train_group_stub

from torchspec.ray.placement_group import ( # noqa: E402
_NodeConstraint,
_build_custom_bundles,
_NodeConstraint,
_sort_probed_bundle_infos,
create_placement_groups,
)
Expand Down
9 changes: 9 additions & 0 deletions torchspec/config/mooncake_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ def export_env(self) -> None:
os.environ["MOONCAKE_PROTOCOL"] = self.protocol
os.environ["MOONCAKE_DEVICE_NAME"] = self.device_name
os.environ["MOONCAKE_ENABLE_GPU_DIRECT"] = "1" if self.enable_gpu_direct else "0"
self.apply_env_defaults()
if self.async_put_pool_size is not None:
os.environ["MOONCAKE_ASYNC_PUT_POOL_SIZE"] = str(self.async_put_pool_size)
os.environ["MOONCAKE_STORE_FULL_WAIT_SECONDS"] = str(self.store_full_wait_seconds)
Expand All @@ -190,6 +191,14 @@ def export_env(self) -> None:
os.environ["MOONCAKE_GET_RETRY_MAX_WAIT_SECONDS"] = str(self.get_retry_max_wait_seconds)
os.environ["MOONCAKE_ENABLE_HARD_PIN"] = "1" if self.enable_hard_pin else "0"

def apply_env_defaults(self) -> None:
    """Seed Mooncake process-level environment defaults ahead of client setup.

    When the configured protocol is TCP (compared case-insensitively),
    ``MC_STORE_MEMCPY`` is set to ``"0"`` unless the variable already exists
    in the environment. For any other protocol the environment is left
    unchanged.
    """
    # Workaround for https://github.com/kvcache-ai/Mooncake/issues/1986
    if self.protocol.lower() != "tcp":
        return
    # Mooncake's TCP-only memcpy fast path can segfault in same-host
    # multi-process get paths. setdefault only writes when the key is
    # absent, so an explicit user override is preserved.
    os.environ.setdefault("MC_STORE_MEMCPY", "0")

@classmethod
def from_env(cls) -> "MooncakeConfig":
"""Create config from environment variables."""
Expand Down
1 change: 0 additions & 1 deletion torchspec/ray/placement_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from torchspec.ray.train_group import RayTrainGroup
from torchspec.utils.logging import logger


# Ray exposes a tiny "node:<ip>" resource on each node. Requiring a fractional
# amount pins a bundle to that node without consuming a full logical resource.
_NODE_RESOURCE_EPSILON = 0.001
Expand Down
1 change: 1 addition & 0 deletions torchspec/transfer/mooncake/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def setup(self, device: torch.device | int | None = None) -> None:
"Set mooncake.device_name to a specific RDMA device (e.g. 'mlx5_0')."
)

self.config.apply_env_defaults()
self._store = MooncakeDistributedStore()
logger.info(
"Connecting to Mooncake master at %s, metadata server at %s",
Expand Down
Loading