diff --git a/pyroengine/engine.py b/pyroengine/engine.py index e109ed0f..55289d3f 100644 --- a/pyroengine/engine.py +++ b/pyroengine/engine.py @@ -80,12 +80,12 @@ class Engine: def __init__( self, model_path: Optional[str] = None, - conf_thresh: float = 0.2, + conf_thresh: float = 0.35, model_conf_thresh: float = 0.05, max_bbox_size: float = 0.4, api_url: Optional[str] = None, cam_creds: Optional[Dict[str, Dict[str, str]]] = None, - nb_consecutive_frames: int = 5, + nb_consecutive_frames: int = 7, frame_size: Optional[Tuple[int, int]] = None, cache_backup_period: int = 60, frame_saving_period: Optional[int] = None, diff --git a/src/run.py b/src/run.py index f022c3bb..f33ac04a 100644 --- a/src/run.py +++ b/src/run.py @@ -82,7 +82,7 @@ def main(args): ) # Model parser.add_argument("--model_path", type=str, default=None, help="model path") - parser.add_argument("--thresh", type=float, default=0.2, help="Confidence threshold") + parser.add_argument("--thresh", type=float, default=0.35, help="Confidence threshold") parser.add_argument("--max_bbox_size", type=float, default=0.4, help="Maximum bbox size") # Camera & cache @@ -106,7 +106,7 @@ def main(args): parser.add_argument( "--nb-consecutive_frames", type=int, - default=5, + default=6, help="Number of consecutive frames to combine for prediction", ) parser.add_argument( diff --git a/tests/test_engine.py b/tests/test_engine.py index 7452b759..73129ca3 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -56,14 +56,25 @@ def test_engine_offline(tmpdir_factory, mock_wildfire_image, mock_forest_image): assert isinstance(out, float) assert 0 <= out <= 1 assert len(engine._states["-1"]["last_predictions"]) == 3 - assert engine._states["-1"]["ongoing"] + assert not engine._states["-1"]["ongoing"] assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image) assert engine._states["-1"]["last_predictions"][2][1].shape[0] > 0 assert engine._states["-1"]["last_predictions"][2][1].shape[1] == 5 - assert len(engine._states["-1"]["last_predictions"][-1][2][0]) == 5 assert engine._states["-1"]["last_predictions"][2][3] < datetime.now().isoformat() assert engine._states["-1"]["last_predictions"][2][4] is False + out = engine.predict(mock_wildfire_image) + assert isinstance(out, float) + assert 0 <= out <= 1 + assert len(engine._states["-1"]["last_predictions"]) == 4 + assert engine._states["-1"]["ongoing"] + assert isinstance(engine._states["-1"]["last_predictions"][0][0], Image.Image) + assert engine._states["-1"]["last_predictions"][3][1].shape[0] > 0 + assert engine._states["-1"]["last_predictions"][3][1].shape[1] == 5 + assert len(engine._states["-1"]["last_predictions"][-1][2][0]) == 5 + assert engine._states["-1"]["last_predictions"][3][3] < datetime.now().isoformat() + assert engine._states["-1"]["last_predictions"][3][4] is False + def create_dummy_onnx_model(model_path): """Creates a small dummy ONNX model."""