diff --git a/bp2vec_train.py b/bp2vec_train.py index b6002ad..5d393ab 100644 --- a/bp2vec_train.py +++ b/bp2vec_train.py @@ -262,8 +262,8 @@ def build_model(num_batters: int, num_pitchers: int, num_outcomes: int): from tensorflow import keras from tensorflow.keras import layers except ImportError: - log.error("TensorFlow not installed. Run: pip install tensorflow") - sys.exit(1) + log.error("TensorFlow not installed — bp2vec training skipped. Add tensorflow to requirements_army.txt") + return None batter_idx_in = keras.Input(shape=(1,), dtype="int32", name="batter_idx") pitcher_idx_in = keras.Input(shape=(1,), dtype="int32", name="pitcher_idx") @@ -312,6 +312,9 @@ def train(seasons: list[int]) -> None: log.info("Building model: %d batters, %d pitchers, %d outcomes", num_batters, num_pitchers, num_outcomes) model = build_model(num_batters, num_pitchers, num_outcomes) + if model is None: + log.warning("bp2vec training aborted — TensorFlow unavailable. Add tensorflow to requirements_army.txt") + return model.summary() log.info("Training for %d epochs...", NUM_EPOCHS) diff --git a/requirements_army.txt b/requirements_army.txt index a8d1f15..b070764 100644 --- a/requirements_army.txt +++ b/requirements_army.txt @@ -41,3 +41,4 @@ pyarrow pybaseball>=2.2.7 mlb-statsapi>=1.7.2 curl_cffi>=0.6.0 +tensorflow>=2.15.0