Merged
7 changes: 5 additions & 2 deletions bp2vec_train.py
@@ -262,8 +262,8 @@ def build_model(num_batters: int, num_pitchers: int, num_outcomes: int):
from tensorflow import keras
from tensorflow.keras import layers
except ImportError:
- log.error("TensorFlow not installed. Run: pip install tensorflow")
- sys.exit(1)
+ log.error("TensorFlow not installed — bp2vec training skipped. Add tensorflow to requirements_army.txt")
+ return None

high

Replacing sys.exit(1) with return None keeps a missing TensorFlow install from terminating the scheduler's event loop. However, load_statcast_seasons (line 130) still calls sys.exit(1) when the pybaseball import fails. For a graceful fallback that never terminates the process in a shared environment, that call should also return or raise an exception instead of exiting.
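
One way to apply this suggestion is sketched below. It assumes a hypothetical `MissingDependencyError` exception, a hypothetical `require_module` helper, and a hypothetical scheduler-side wrapper `run_job`; none of these names appear in the PR, and the real `load_statcast_seasons` may be structured differently.

```python
import importlib
import logging

log = logging.getLogger("bp2vec_train")


class MissingDependencyError(RuntimeError):
    """Hypothetical exception: an optional dependency is not installed."""


def require_module(name: str):
    """Import and return module `name`, raising instead of calling sys.exit."""
    try:
        return importlib.import_module(name)
    except ImportError as exc:
        raise MissingDependencyError(
            f"{name} not installed; add it to requirements_army.txt"
        ) from exc


def run_job(job) -> bool:
    """Hypothetical scheduler-side wrapper: log the failure and keep running."""
    try:
        job()
    except MissingDependencyError as exc:
        log.error("job skipped: %s", exc)
        return False
    return True
```

Raising leaves the decision to the caller: the scheduler can log and move on, while a standalone command-line entry point could still catch the exception and exit with a non-zero status.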


batter_idx_in = keras.Input(shape=(1,), dtype="int32", name="batter_idx")
pitcher_idx_in = keras.Input(shape=(1,), dtype="int32", name="pitcher_idx")
@@ -312,6 +312,9 @@ def train(seasons: list[int]) -> None:
log.info("Building model: %d batters, %d pitchers, %d outcomes",
num_batters, num_pitchers, num_outcomes)
model = build_model(num_batters, num_pitchers, num_outcomes)
+ if model is None:
+ log.warning("bp2vec training aborted — TensorFlow unavailable. Add tensorflow to requirements_army.txt")
+ return
Comment on lines +315 to +317

medium

This check correctly handles the case where TensorFlow is missing. However, it occurs after load_statcast_seasons and build_indices, which can be time-consuming (the docstring mentions 10-20 minutes). Consider performing a quick check for TensorFlow availability at the very beginning of the train function to "fail fast" and avoid unnecessary data processing when the dependency is unavailable.
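
A minimal sketch of the fail-fast check, assuming a hypothetical `module_available` helper (not part of the PR) built on the standard library's `importlib.util.find_spec`, which locates a top-level module without importing it. The body of `train` is abbreviated; only the position of the guard is the point.

```python
import importlib.util
import logging

log = logging.getLogger("bp2vec_train")


def module_available(name: str) -> bool:
    """Return True if top-level module `name` can be located (no import)."""
    return importlib.util.find_spec(name) is not None


def train(seasons: list[int]) -> None:
    # Fail fast: check the dependency before any slow data loading.
    if not module_available("tensorflow"):
        log.warning("bp2vec training aborted: TensorFlow unavailable")
        return
    # load_statcast_seasons, build_indices, and build_model would follow here.
```

Because `find_spec` does not import TensorFlow, the guard stays cheap even when the package is installed. The existing `model is None` check can remain as a second line of defense in case the import itself fails later.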

model.summary()

log.info("Training for %d epochs...", NUM_EPOCHS)
1 change: 1 addition & 0 deletions requirements_army.txt
@@ -41,3 +41,4 @@ pyarrow
pybaseball>=2.2.7
mlb-statsapi>=1.7.2
curl_cffi>=0.6.0
+ tensorflow>=2.15.0