Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change
* Adds experimental `ShapeDtypeStructProtocol` and `ShapeDtypeStruct` to
represent dataset element specs.
* Updates TfMixtureIndexSampler to support datasets with weights of 0.
* Adds profiling of multiprocess workers when using the XProf profiler. To
enable, set the flag `grain_enable_multiprocess_worker_profiling=true` and add
`"profile_subprocesses" = True` to the advanced profiler options.

* Breaking changes:

Expand Down
13 changes: 12 additions & 1 deletion grain/_src/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,10 @@ py_library(
name = "profiler",
srcs = ["profiler.py"],
srcs_version = "PY3",
deps = ["@abseil-py//absl/logging"],
deps = [
"@abseil-py//absl/flags",
"@abseil-py//absl/logging",
],
)

py_test(
Expand All @@ -194,7 +197,11 @@ py_test(
srcs_version = "PY3",
deps = [
":profiler",
"@abseil-py//absl/flags",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:flagsaver",
"@pypi//cloudpickle:pkg",
"@pypi//portpicker:pkg",
],
)

Expand All @@ -208,8 +215,12 @@ py_test(
srcs_version = "PY3",
deps = [
":profiler",
"@abseil-py//absl/flags",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:flagsaver",
"@pypi//cloudpickle:pkg",
"@pypi//jax:pkg", # buildcleaner: keep
"@pypi//portpicker:pkg",
],
)

Expand Down
136 changes: 123 additions & 13 deletions grain/_src/core/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,139 @@
# limitations under the License.
"""Import wrapper for framework specific profilers."""

import functools
from typing import Callable

from absl import flags
from absl import logging

# Opt-in switch for profiling worker processes; its value is read by
# is_worker_profiling_enabled().
_GRAIN_ENABLE_MULTIPROCESS_WORKER_PROFILING = flags.DEFINE_bool(
"grain_enable_multiprocess_worker_profiling",
False,
"If True, starts profiler servers on spawned worker processes to be"
" profiled alongside the main process (when requested).",
)

# Internal constants.
# Sentinel framework name meaning that no profiler backend has been loaded.
_NO_FRAMEWORK = "NO_FRAMEWORK"
# NOTE(review): the next five names are bound again further down this module
# (the TraceAnnotation class and the is_enabled/start_server/stop_server
# defs), and `framework` duplicates `_framework`. In this diff view they look
# like the removed pre-change lines -- confirm against the merged file.
is_enabled: Callable[[], bool] = lambda: False
TraceAnnotation = None # pylint: disable=invalid-name
start_server: Callable[[int], None] = lambda port: None
stop_server: Callable[[], None] = lambda: None
framework = _NO_FRAMEWORK
# Name of the loaded profiling framework; compared against _NO_FRAMEWORK by
# is_loaded() and returned by get_framework().
_framework = _NO_FRAMEWORK
# Read by is_worker_profiling_supported(); the import block at the bottom of
# this module sets it to True when jaxlib exposes subprocess profiling hooks.
_subprocess_hooks_loaded = False


class TraceAnnotation:
  """No-op stand-in for a profiler trace annotation.

  Lets call sites write `with TraceAnnotation(...)` unconditionally when no
  profiler is loaded. Any positional or keyword arguments are accepted and
  ignored.
  """

  def __init__(self, *args, **kwargs):
    # Arguments are accepted only so the call signature stays permissive.
    del args, kwargs

  def __enter__(self):
    # Hand back the annotation itself so `with TraceAnnotation(...) as t`
    # binds a usable object rather than None.
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Implicitly returns None, so exceptions raised inside the `with` body
    # are never suppressed.
    del exc_type, exc_value, traceback


def is_enabled() -> bool:
  """Reports whether profiling is currently active.

  This implementation always reports False.
  """
  enabled = False
  return enabled


def is_loaded() -> bool:
  """Tells whether a profiler framework has been loaded."""
  nothing_loaded = _framework == _NO_FRAMEWORK
  return not nothing_loaded


def is_worker_profiling_supported() -> bool:
  """Tells whether worker processes can be profiled.

  Support requires a loaded profiler and its subprocess hooks.
  """
  if not is_loaded():
    return False
  return _subprocess_hooks_loaded


def is_worker_profiling_enabled() -> bool:
  """Tells whether worker profiling is both supported and switched on.

  The `grain_enable_multiprocess_worker_profiling` flag is only consulted
  when worker profiling is supported.
  """
  if not is_worker_profiling_supported():
    return False
  return _GRAIN_ENABLE_MULTIPROCESS_WORKER_PROFILING.value


def get_framework() -> str:
  """Gives the name of the framework backing the profiler."""
  current_framework = _framework
  return current_framework


def start_server(port: int) -> None:
  """No-op profiler server start; the requested port is ignored."""
  del port
  return None


def stop_server() -> None:
  """No-op profiler server stop."""
  return None


def register_subprocess(pid: int, port: int) -> Callable[[], None]:
  """Associates a subprocess with the port of its xprof server.

  This implementation never succeeds: it unconditionally raises.

  Args:
    pid: Process ID of the subprocess to register.
    port: Port on which the subprocess's profiler server listens.

  Returns:
    A callable that unregisters the subprocess when invoked.

  Raises:
    RuntimeError: If the subprocess cannot be registered.
  """
  del pid, port
  message = (
      "Subprocess profiler registration is not supported in this environment."
  )
  raise RuntimeError(message)


def get_worker_init_fn(port: int) -> Callable[[], None]:
  """Builds the callable that boots the profiler server inside a worker.

  Args:
    port: Port passed to `start_server` when the returned callable runs.

  Returns:
    A zero-argument function. When worker profiling is not enabled it does
    nothing; otherwise it calls `start_server(port)` and, if that raises,
    logs a warning instead of propagating the error.
  """

  def _start_profiler_server() -> None:
    try:
      start_server(port)
    except Exception as err:  # pylint: disable=broad-except
      logging.warning("Failed to start profiler server: %s", str(err))

  if is_worker_profiling_enabled():
    return _start_profiler_server
  return lambda: None


# Optional JAX backend: when the jax imports succeed, the module-level profiler
# hooks are re-bound to JAX's implementations; on ImportError a warning is
# logged instead.
# NOTE(review): this diff view interleaves removed and added lines (both the
# `framework` and `_framework` checks, two `except ImportError` clauses, a
# nested `is_loaded`) -- confirm against the merged file before editing.
try:
if framework == _NO_FRAMEWORK:
if _framework == _NO_FRAMEWORK:
from jax import profiler # pylint: disable=g-import-not-at-top # pytype: disable=import-error
from jax._src.lib import jaxlib_extension_version # pylint: disable=g-import-not-at-top # pytype: disable=import-error

TraceAnnotation = profiler.TraceAnnotation
is_enabled = profiler.TraceAnnotation.is_enabled
start_server = profiler.start_server
stop_server = profiler.stop_server
framework = "jax"
except ImportError:
logging.warning("Failed to load jax profiler")
# jaxlib extension versions >= 448 expose the subprocess profiling hooks.
if jaxlib_extension_version >= 448:
# Multiprocess workers crash when they attempt to connect to a backend, so
# pass requires_backend=False to skip that connection.
start_server = functools.partial(
profiler.start_server, requires_backend=False
)
register_subprocess = profiler.register_subprocess
_subprocess_hooks_loaded = True
else:
# Older jaxlib: keep the plain server start and leave worker profiling off.
start_server = profiler.start_server
_subprocess_hooks_loaded = False
logging.warning(
"Grain multiprocess worker profiling requires jaxlib extension"
" version 448 or later (jaxlib >= 0.11.0). Current version: %s.",
jaxlib_extension_version,
)

stop_server = profiler.stop_server

def is_loaded():
return framework != _NO_FRAMEWORK
# Record the loaded framework; is_loaded() and get_framework() read this.
_framework = "jax"
except ImportError as e:
logging.warning("Failed to load jax profiler: %s", e)
89 changes: 79 additions & 10 deletions grain/_src/core/profiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,41 @@

import os
import socket
import time
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
import cloudpickle
import multiprocessing as mp
from grain._src.core import profiler
import portpicker


def _worker_main(worker_init_fn: bytes):
  """Subprocess entry point: runs a pickled init function, then idles.

  Args:
    worker_init_fn: A cloudpickle-serialized zero-argument callable.
  """
  init_fn = cloudpickle.loads(worker_init_fn)
  init_fn()
  # Keep the process alive for a while after initialization; presumably so
  # the parent test can interact with a live worker.
  time.sleep(10)


class ProfilerTest(absltest.TestCase):

def test_framework(self):
# Compares the framework reported by the profiler module with the value of
# the EXPECTED_FRAMEWORK environment variable.
# NOTE(review): this diff view shows two lookup/assert pairs; the first pair
# (using `profiler.framework`) looks like the removed version -- confirm.
expected_framework = os.environ.get("EXPECTED_FRAMEWORK") or "jax"
self.assertEqual(profiler.framework, expected_framework)
# No fallback here: when EXPECTED_FRAMEWORK is unset this compares against
# None.
expected_framework = os.environ.get("EXPECTED_FRAMEWORK")
self.assertEqual(profiler.get_framework(), expected_framework)

def test_trace_annotation(self):
# Checks that profiler.TraceAnnotation exists and works as a context manager.
# NOTE(review): this diff view shows both a branch on `profiler.framework`
# and an unconditional copy of the same assertions; the branch looks like
# the removed version -- confirm.
if profiler.framework == profiler._NO_FRAMEWORK:
self.assertIsNone(profiler.TraceAnnotation)
else:
self.assertIsNotNone(profiler.TraceAnnotation)
with profiler.TraceAnnotation("test"):
passes = True
self.assertTrue(passes)
self.assertIsNotNone(profiler.TraceAnnotation)
# Entering and leaving the annotation must not raise.
with profiler.TraceAnnotation("test"):
passes = True
self.assertTrue(passes)

def test_is_enabled(self):
  """`profiler.is_enabled` exists and reports False."""
  self.assertIsNotNone(profiler.is_enabled)
  currently_enabled = profiler.is_enabled()
  self.assertFalse(currently_enabled)

def test_profiler_server(self):
if profiler.framework == profiler._NO_FRAMEWORK:
if not profiler.is_loaded():
self.assertIsNone(profiler.start_server(1234))
else:
port = 1234
Expand All @@ -32,6 +46,61 @@ def test_profiler_server(self):
self.assertEqual(result, 0)
profiler.stop_server()

@flagsaver.flagsaver(grain_enable_multiprocess_worker_profiling=True)
def test_register_unregister_subprocess(self):
  """Registers a live worker with the profiler, then unregisters it.

  When worker profiling is unsupported, registration must raise
  RuntimeError instead.
  """
  port = portpicker.pick_unused_port()
  mp_context = mp.get_context("spawn")
  worker = mp_context.Process(
      target=_worker_main,
      kwargs={
          "worker_init_fn": cloudpickle.dumps(
              profiler.get_worker_init_fn(port)
          )
      },
      daemon=True,
  )
  worker.start()
  try:
    if not profiler.is_worker_profiling_supported():
      with self.assertRaises(RuntimeError):
        profiler.register_subprocess(worker.pid, port)
    else:
      unregister_fn = profiler.register_subprocess(worker.pid, port)
      unregister_fn()
  finally:
    # Tear the worker down even when an assertion above fails, and reap it so
    # no zombie process outlives the test.
    worker.kill()
    worker.join()

@absltest.skipUnless(
    profiler.is_worker_profiling_supported(),
    "Worker profiling is not supported.",
)
@flagsaver.flagsaver(grain_enable_multiprocess_worker_profiling=True)
def test_get_worker_init_fn_starts_server(self):
  """With the flag on, the init function opens a server on the given port."""
  port = portpicker.pick_unused_port()
  worker_init_fn = profiler.get_worker_init_fn(port)
  worker_init_fn()
  # Register the shutdown up front so the server is stopped even when the
  # connection assertion below fails; otherwise it would leak into later
  # tests.
  self.addCleanup(profiler.stop_server)
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
    result = sock.connect_ex(("localhost", port))
    self.assertEqual(result, 0)

@flagsaver.flagsaver(grain_enable_multiprocess_worker_profiling=False)
def test_get_worker_init_fn_does_not_start_server(self):
  """With the flag off, running the init function leaves the port closed."""
  unused_port = portpicker.pick_unused_port()
  init_fn = profiler.get_worker_init_fn(unused_port)
  init_fn()
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
    connect_status = probe.connect_ex(("localhost", unused_port))
    # A non-zero status means nothing accepted the connection.
    self.assertNotEqual(connect_status, 0)

@flagsaver.flagsaver
def test_worker_profiling_enabled_flag(self):
  """The enabled state follows the flag, gated on support."""
  flags.FLAGS.grain_enable_multiprocess_worker_profiling = False
  enabled_when_off = profiler.is_worker_profiling_enabled()
  self.assertFalse(enabled_when_off)

  flags.FLAGS.grain_enable_multiprocess_worker_profiling = True
  enabled_when_on = profiler.is_worker_profiling_enabled()
  supported = profiler.is_worker_profiling_supported()
  # With the flag on, "enabled" should mirror "supported".
  self.assertEqual(enabled_when_on, supported)


if __name__ == "__main__":
absltest.main()
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ py_library(
"//grain/_src/core:config",
"//grain/_src/core:exceptions",
"//grain/_src/core:monitoring",
"//grain/_src/core:profiler",
"//grain/_src/core:sharding",
"//grain/_src/core:traceback_util",
"//grain/_src/core:transforms",
Expand All @@ -64,6 +65,7 @@ py_library(
"@pypi//cloudpickle:pkg",
"@pypi//etils:pkg",
"@pypi//numpy:pkg",
"@pypi//portpicker:pkg",
],
)

Expand Down
Loading
Loading