-
Notifications
You must be signed in to change notification settings - Fork 0
PR #571: Fix bp2vec TensorFlow crash — graceful fallback + add TF to requirements #440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -262,8 +262,8 @@ def build_model(num_batters: int, num_pitchers: int, num_outcomes: int): | |
| from tensorflow import keras | ||
| from tensorflow.keras import layers | ||
| except ImportError: | ||
| log.error("TensorFlow not installed. Run: pip install tensorflow") | ||
| sys.exit(1) | ||
| log.error("TensorFlow not installed — bp2vec training skipped. Add tensorflow to requirements_army.txt") | ||
| return None | ||
|
|
||
| batter_idx_in = keras.Input(shape=(1,), dtype="int32", name="batter_idx") | ||
| pitcher_idx_in = keras.Input(shape=(1,), dtype="int32", name="pitcher_idx") | ||
|
|
@@ -312,6 +312,9 @@ def train(seasons: list[int]) -> None: | |
| log.info("Building model: %d batters, %d pitchers, %d outcomes", | ||
| num_batters, num_pitchers, num_outcomes) | ||
| model = build_model(num_batters, num_pitchers, num_outcomes) | ||
| if model is None: | ||
| log.warning("bp2vec training aborted — TensorFlow unavailable. Add tensorflow to requirements_army.txt") | ||
| return | ||
|
Comment on lines
+315
to
+317
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This check correctly handles the case where TensorFlow is missing. However, it occurs after |
||
| model.summary() | ||
|
|
||
| log.info("Training for %d epochs...", NUM_EPOCHS) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,3 +41,4 @@ pyarrow | |
| pybaseball>=2.2.7 | ||
| mlb-statsapi>=1.7.2 | ||
| curl_cffi>=0.6.0 | ||
| tensorflow>=2.15.0 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The transition from
sys.exit(1)toreturn Nonecorrectly addresses the issue of disrupting the scheduler's event loop. However, please note thatload_statcast_seasons(line 130) still contains asys.exit(1)call onpybaseballimport failure. To fully achieve the goal of a graceful fallback and prevent process termination in a shared environment, that instance should also be updated to return or raise an exception instead of exiting.