Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(
self._next_buffered_index = 0
self._buffer = collections.deque()
self._lock = threading.Lock()
self._executor_wrapper = None

assert isinstance(read_options.num_threads, int)
assert isinstance(read_options.prefetch_buffer_size, int)
Expand Down Expand Up @@ -262,6 +263,11 @@ def __str__(self) -> str:
f" allow_nones={self._allow_nones})"
)

def set_executor_wrapper(
self, wrapper: typing.Callable[[futures.Executor], futures.Executor]
):
self._executor_wrapper = wrapper

def _set_prefetch_buffer_size(self, buffer_size: int):
self._target_prefetch_buffer_size = buffer_size
# The executor is created in the constructor only if the prefetch buffer
Expand All @@ -275,6 +281,8 @@ def _set_prefetch_buffer_size(self, buffer_size: int):
self._executor = futures.ThreadPoolExecutor(
self._target_num_threads, thread_name_prefix="grain-prefetch"
)
if self._executor_wrapper:
self._executor = self._executor_wrapper(self._executor)
elif self._target_prefetch_buffer_size == 0 and hasattr(self, "_executor"):
self._executor.shutdown()
delattr(self, "_executor")
Expand All @@ -290,6 +298,8 @@ def _set_num_threads(self, num_threads: int) -> None:
self._executor = futures.ThreadPoolExecutor(
self._target_num_threads, thread_name_prefix="grain-prefetch"
)
if self._executor_wrapper:
self._executor = self._executor_wrapper(self._executor)
elif hasattr(self, "_executor"):
delattr(self, "_executor")
if old_executor is not None:
Expand Down
Loading