diff --git a/tests/test_study/test_study_functions.py b/tests/test_study/test_study_functions.py
index 2a2d276ec..5dc74bb09 100644
--- a/tests/test_study/test_study_functions.py
+++ b/tests/test_study/test_study_functions.py
@@ -8,14 +8,164 @@
import openml.study
from openml.testing import TestBase
+from unittest import mock
+import requests
+
+class StudyMockServer:
+ """Helper class to encapsulate all mock XML generation and response mocking."""
+
+ @staticmethod
+ def make_response(xml: str | Exception):
+ if isinstance(xml, Exception):
+ return xml
+ r = mock.Mock(spec=["status_code", "text", "headers"])
+ r.status_code = 200
+ r.headers = {}
+ r.text = xml
+ return r
+
+ @staticmethod
+ def build_study_xml(
+ study_id, n_data, n_tasks, n_flows, n_setups, n_runs=None,
+ main_entity_type="run", alias=None, status="active",
+ name="Dummy Study", description="desc", run_ids=None,
+ task_ids=None, visibility="public"
+ ):
+        def repeat(tag, n):
+            return "".join(f"<oml:{tag}_id>{i}</oml:{tag}_id>" for i in range(1, n + 1))
+
+        runs_block = ""
+        if run_ids is not None:
+            runs_block = "<oml:runs>" + "".join(f"<oml:run_id>{rid}</oml:run_id>" for rid in run_ids) + "</oml:runs>"
+        elif n_runs is not None:
+            runs_block = f"<oml:runs>{repeat('run', n_runs)}</oml:runs>"
+
+        tasks_block = "".join(f"<oml:task_id>{tid}</oml:task_id>" for tid in task_ids) if task_ids else repeat("task", n_tasks)
+        alias_block = f"<oml:alias>{alias}</oml:alias>" if alias else ""
+
+        linked_entities_block = ""
+        if main_entity_type == "task":
+            tids = task_ids if task_ids else range(1, n_tasks + 1)
+            linked_entities = "".join(f"<oml:linked_entity><oml:id>{tid}</oml:id><oml:type>task</oml:type></oml:linked_entity>" for tid in tids)
+            linked_entities_block = f"<oml:linked_entities>{linked_entities}</oml:linked_entities>"
+
+        return f"""
+        <oml:study xmlns:oml="http://openml.org/openml">
+            <oml:id>{study_id}</oml:id>
+            {alias_block}
+            <oml:main_entity_type>{main_entity_type}</oml:main_entity_type>
+            <oml:name>{name}</oml:name>
+            <oml:description>{description}</oml:description>
+            <oml:visibility>{visibility}</oml:visibility>
+            <oml:status>{status}</oml:status>
+            <oml:creation_date>2020-01-01</oml:creation_date>
+            <oml:creator>tester</oml:creator>
+            <oml:data>{repeat("data", n_data)}</oml:data>
+            {linked_entities_block}
+            <oml:tasks>{tasks_block}</oml:tasks>
+            <oml:flows>{repeat("flow", n_flows)}</oml:flows>
+            <oml:setups>{repeat("setup", n_setups)}</oml:setups>
+            {runs_block}
+        </oml:study>
+        """
+
+ @staticmethod
+ def build_study_upload_xml(study_id: int):
+        return f'<oml:study_upload xmlns:oml="http://openml.org/openml"><oml:id>{study_id}</oml:id></oml:study_upload>'
+
+ @staticmethod
+ def build_study_attach_xml(study_id: int, n_entities: int):
+        return f'<oml:study_attach xmlns:oml="http://openml.org/openml"><oml:id>{study_id}</oml:id><oml:main_entity_type>run</oml:main_entity_type><oml:linked_entities>{n_entities}</oml:linked_entities></oml:study_attach>'
+
+ @staticmethod
+ def build_study_detach_xml(study_id: int, n_entities: int):
+        return f'<oml:study_detach xmlns:oml="http://openml.org/openml"><oml:id>{study_id}</oml:id><oml:main_entity_type>run</oml:main_entity_type><oml:linked_entities>{n_entities}</oml:linked_entities></oml:study_detach>'
+
+ @staticmethod
+ def build_status_update_xml(study_id: int, status: str):
+        return f'<oml:study_status_update xmlns:oml="http://openml.org/openml"><oml:id>{study_id}</oml:id><oml:status>{status}</oml:status></oml:study_status_update>'
+
+ @staticmethod
+ def build_delete_xml(study_id: int):
+        return f'<oml:study_delete xmlns:oml="http://openml.org/openml"><oml:id>{study_id}</oml:id></oml:study_delete>'
+
+ @staticmethod
+    def build_evaluations_xml(run_ids):
+        evaluations = "".join(f"<oml:evaluation><oml:run_id>{rid}</oml:run_id><oml:task_id>{rid}</oml:task_id><oml:setup_id>{rid}</oml:setup_id><oml:flow_id>{rid}</oml:flow_id><oml:flow_name>dummy</oml:flow_name><oml:data_id>1</oml:data_id><oml:data_name>dummy</oml:data_name><oml:function>predictive_accuracy</oml:function><oml:upload_time>2020-01-01</oml:upload_time><oml:uploader>3229</oml:uploader><oml:value>0.5</oml:value></oml:evaluation>" for rid in run_ids)
+        return f'<oml:evaluations xmlns:oml="http://openml.org/openml">{evaluations}</oml:evaluations>'
+
+ @staticmethod
+    def build_users_xml():
+        return '<oml:users xmlns:oml="http://openml.org/openml"><oml:user><oml:id>3229</oml:id><oml:username>tester</oml:username></oml:user></oml:users>'
+
+ @staticmethod
+    def build_runs_xml(run_ids):
+        runs = "".join(f"<oml:run><oml:run_id>{rid}</oml:run_id><oml:task_id>{rid}</oml:task_id><oml:task_type_id>1</oml:task_type_id><oml:setup_id>{rid}</oml:setup_id><oml:flow_id>{rid}</oml:flow_id><oml:uploader>3229</oml:uploader><oml:upload_time>2020-01-01</oml:upload_time></oml:run>" for rid in run_ids)
+        return f'<oml:runs xmlns:oml="http://openml.org/openml">{runs}</oml:runs>'
+
+ @staticmethod
+ def setup_publish_benchmark_suite_mocks(mock_get, mock_post, fixture_name, fixture_descr):
+ mock_post.side_effect = [
+ StudyMockServer.make_response(StudyMockServer.build_study_upload_xml(146)),
+ StudyMockServer.make_response(StudyMockServer.build_study_attach_xml(146, 6)),
+ StudyMockServer.make_response(StudyMockServer.build_study_detach_xml(146, 3)),
+ StudyMockServer.make_response(StudyMockServer.build_status_update_xml(146, "deactivated")),
+ ]
+ mock_get.side_effect = [
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=146, name=fixture_name, description=fixture_descr, main_entity_type="task", task_ids=[1,2,3], n_data=1, n_tasks=3, n_flows=0, n_setups=0, status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=146, name=fixture_name, description=fixture_descr, main_entity_type="task", task_ids=[1,2,3,4,5,6], n_data=1, n_tasks=6, n_flows=0, n_setups=0, status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=146, name=fixture_name, description=fixture_descr, main_entity_type="task", task_ids=[4,5,6], n_data=1, n_tasks=3, n_flows=0, n_setups=0, status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=146, name=fixture_name, description=fixture_descr, main_entity_type="task", task_ids=[4,5,6], n_data=1, n_tasks=3, n_flows=0, n_setups=0, status="deactivated")),
+ ]
+
+ @staticmethod
+ def setup_publish_study_mocks(mock_get, mock_post, mock_delete, fixt_name, fixt_descr):
+ mock_get.side_effect = [
+ StudyMockServer.make_response(StudyMockServer.build_evaluations_xml(range(1,11))),
+ StudyMockServer.make_response(StudyMockServer.build_users_xml()),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(name=fixt_name, description=fixt_descr, study_id=157, main_entity_type="run", n_data=1, n_tasks=10, n_flows=10, n_setups=10, n_runs=10, status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_runs_xml(range(1,11))),
+ StudyMockServer.make_response(StudyMockServer.build_evaluations_xml(range(1,11))),
+ StudyMockServer.make_response(StudyMockServer.build_users_xml()),
+ StudyMockServer.make_response(StudyMockServer.build_runs_xml(range(11,22))),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=157, name=fixt_name, description=fixt_descr, main_entity_type="run", n_data=1, n_tasks=20, n_flows=20, n_setups=20, n_runs=21, status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=157, name=fixt_name, description=fixt_descr, main_entity_type="run", n_data=1, n_tasks=10, n_flows=10, n_setups=10, run_ids=range(11, 22), status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=157, name=fixt_name, description=fixt_descr, main_entity_type="run", n_data=1, n_tasks=10, n_flows=10, n_setups=10, run_ids=range(11, 22), status="deactivated")),
+ ]
+ mock_post.side_effect = [
+ StudyMockServer.make_response(StudyMockServer.build_study_upload_xml(157)),
+ StudyMockServer.make_response(StudyMockServer.build_study_attach_xml(157, 21)),
+ StudyMockServer.make_response(StudyMockServer.build_study_detach_xml(157, 11)),
+ StudyMockServer.make_response(StudyMockServer.build_status_update_xml(157, "deactivated")),
+ ]
+ mock_delete.return_value = StudyMockServer.make_response(StudyMockServer.build_delete_xml(157))
+
+ @staticmethod
+ def setup_study_attach_illegal_mocks(mock_get, mock_post, mock_delete):
+ mock_get.side_effect = [
+ StudyMockServer.make_response(StudyMockServer.build_runs_xml(range(1,11))),
+ StudyMockServer.make_response(StudyMockServer.build_runs_xml(range(1,21))),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=300, name="study with illegal runs", description="none", main_entity_type="run", n_data=1, n_tasks=10, n_flows=10, n_setups=10, n_runs=10, status="in_preparation")),
+ StudyMockServer.make_response(StudyMockServer.build_study_xml(study_id=300, name="study with illegal runs", description="none", main_entity_type="run", n_data=1, n_tasks=10, n_flows=10, n_setups=10, n_runs=10, status="in_preparation")),
+ ]
+ mock_post.side_effect = [
+ StudyMockServer.make_response(StudyMockServer.build_study_upload_xml(300)),
+ openml.exceptions.OpenMLServerException("Problem attaching entities."),
+ openml.exceptions.OpenMLServerException("Problem attaching entities."),
+ ]
+ mock_delete.return_value = StudyMockServer.make_response(StudyMockServer.build_delete_xml(300))
+
class TestStudyFunctions(TestBase):
_multiprocess_can_split_ = True
@pytest.mark.production_server()
@pytest.mark.xfail(reason="failures_issue_1544", strict=False)
- def test_get_study_old(self):
- self.use_production_server()
+ @mock.patch.object(requests.Session, "get")
+ def test_get_study_old(self, mock_get):
+ mock_get.return_value = StudyMockServer.make_response(
+ StudyMockServer.build_study_xml(study_id=34, n_data=105, n_tasks=105, n_flows=27, n_setups=30, n_runs=None)
+ )
study = openml.study.get_study(34)
assert len(study.data) == 105
@@ -25,9 +175,11 @@ def test_get_study_old(self):
assert study.runs is None
@pytest.mark.production_server()
- def test_get_study_new(self):
- self.use_production_server()
-
+ @mock.patch.object(requests.Session, "get")
+ def test_get_study_new(self, mock_get):
+ mock_get.return_value = StudyMockServer.make_response(
+ StudyMockServer.build_study_xml(study_id=123, n_data=299, n_tasks=299, n_flows=5, n_setups=1253, n_runs=1693)
+ )
study = openml.study.get_study(123)
assert len(study.data) == 299
assert len(study.tasks) == 299
@@ -36,9 +188,11 @@ def test_get_study_new(self):
assert len(study.runs) == 1693
@pytest.mark.production_server()
- def test_get_openml100(self):
- self.use_production_server()
-
+ @mock.patch.object(requests.Session, "get")
+ def test_get_openml100(self, mock_get):
+ mock_get.return_value = StudyMockServer.make_response(
+ StudyMockServer.build_study_xml(study_id=99, alias="OpenML100", n_data=100, n_tasks=100, n_flows=0, n_setups=0, n_runs=None, main_entity_type="task")
+ )
study = openml.study.get_study("OpenML100", "tasks")
assert isinstance(study, openml.study.OpenMLBenchmarkSuite)
study_2 = openml.study.get_suite("OpenML100")
@@ -46,8 +200,11 @@ def test_get_openml100(self):
assert study.study_id == study_2.study_id
@pytest.mark.production_server()
- def test_get_study_error(self):
- self.use_production_server()
+ @mock.patch.object(requests.Session, "get")
+ def test_get_study_error(self, mock_get):
+ mock_get.return_value = StudyMockServer.make_response(
+ StudyMockServer.build_study_xml(study_id=99, n_data=1, n_tasks=1, n_flows=0, n_setups=0, n_runs=None, main_entity_type="task")
+ )
with pytest.raises(
ValueError, match="Unexpected entity type 'task' reported by the server, expected 'run'"
@@ -55,8 +212,11 @@ def test_get_study_error(self):
openml.study.get_study(99)
@pytest.mark.production_server()
- def test_get_suite(self):
- self.use_production_server()
+ @mock.patch.object(requests.Session, "get")
+ def test_get_suite(self, mock_get):
+ mock_get.return_value = StudyMockServer.make_response(
+ StudyMockServer.build_study_xml(study_id=99, n_data=72, n_tasks=72, n_flows=0, n_setups=0, n_runs=None, main_entity_type="task")
+ )
study = openml.study.get_suite(99)
assert len(study.data) == 72
@@ -66,8 +226,11 @@ def test_get_suite(self):
assert study.setups is None
@pytest.mark.production_server()
- def test_get_suite_error(self):
- self.use_production_server()
+ @mock.patch.object(requests.Session, "get")
+ def test_get_suite_error(self, mock_get):
+ mock_get.return_value = StudyMockServer.make_response(
+ StudyMockServer.build_study_xml(study_id=123, n_data=1, n_tasks=1, n_flows=0, n_setups=0, n_runs=None, main_entity_type="run")
+ )
with pytest.raises(
ValueError, match="Unexpected entity type 'run' reported by the server, expected 'task'"
@@ -75,12 +238,15 @@ def test_get_suite_error(self):
openml.study.get_suite(123)
@pytest.mark.test_server()
- def test_publish_benchmark_suite(self):
+ @mock.patch.object(requests.Session, "post")
+ @mock.patch.object(requests.Session, "get")
+ def test_publish_benchmark_suite(self, mock_get, mock_post):
fixture_alias = None
fixture_name = "unit tested benchmark suite"
fixture_descr = "bla"
fixture_task_ids = [1, 2, 3]
-
+
+ StudyMockServer.setup_publish_benchmark_suite_mocks(mock_get, mock_post, fixture_name, fixture_descr)
study = openml.study.create_benchmark_suite(
alias=fixture_alias,
name=fixture_name,
@@ -144,23 +310,41 @@ def _test_publish_empty_study_is_allowed(self, explicit: bool):
assert study_downloaded.runs is None
@pytest.mark.test_server()
- def test_publish_empty_study_explicit(self):
+ @mock.patch.object(requests.Session, "post")
+ @mock.patch.object(requests.Session, "get")
+ def test_publish_empty_study_explicit(self, mock_get, mock_post):
+ mock_post.side_effect = [StudyMockServer.make_response(StudyMockServer.build_study_upload_xml(200))]
+ empty_study_xml = StudyMockServer.build_study_xml(study_id=200, name="empty-study-explicit", description="a study with no runs attached explicitly", main_entity_type="run", task_ids=None, n_data=0, n_tasks=0, n_flows=0, n_setups=0, n_runs=None, status="in_preparation")
+ mock_get.side_effect = [StudyMockServer.make_response(empty_study_xml), StudyMockServer.make_response(empty_study_xml)]
self._test_publish_empty_study_is_allowed(explicit=True)
@pytest.mark.test_server()
- def test_publish_empty_study_implicit(self):
+ @mock.patch.object(requests.Session, "post")
+ @mock.patch.object(requests.Session, "get")
+ def test_publish_empty_study_implicit(self, mock_get, mock_post):
+ mock_post.side_effect = [StudyMockServer.make_response(StudyMockServer.build_study_upload_xml(200))]
+ empty_study_xml = StudyMockServer.build_study_xml(study_id=200, name="empty-study-implicit", description="a study with no runs attached implicitly", main_entity_type="run", task_ids=None, n_data=0, n_tasks=0, n_flows=0, n_setups=0, n_runs=None, status="in_preparation")
+ mock_get.side_effect = [StudyMockServer.make_response(empty_study_xml), StudyMockServer.make_response(empty_study_xml)]
+
self._test_publish_empty_study_is_allowed(explicit=False)
@pytest.mark.flaky()
@pytest.mark.test_server()
- def test_publish_study(self):
+ @mock.patch.object(requests.Session, "delete")
+ @mock.patch.object(requests.Session, "post")
+ @mock.patch.object(requests.Session, "get")
+ def test_publish_study(self, mock_get, mock_post, mock_delete):
+ fixt_alias = None
+ fixt_name = "unit tested study"
+ fixt_descr = "bla"
+
+ StudyMockServer.setup_publish_study_mocks(mock_get, mock_post, mock_delete, fixt_name, fixt_descr)
+
# get some random runs to attach
run_list = openml.evaluations.list_evaluations("predictive_accuracy", size=10)
assert len(run_list) == 10
- fixt_alias = None
- fixt_name = "unit tested study"
- fixt_descr = "bla"
+
fixt_flow_ids = {evaluation.flow_id for evaluation in run_list.values()}
fixt_task_ids = {evaluation.task_id for evaluation in run_list.values()}
fixt_setup_ids = {evaluation.setup_id for evaluation in run_list.values()}
@@ -223,7 +407,13 @@ def test_publish_study(self):
assert res
@pytest.mark.test_server()
- def test_study_attach_illegal(self):
+ @mock.patch.object(requests.Session, "delete")
+ @mock.patch.object(requests.Session, "post")
+ @mock.patch.object(requests.Session, "get")
+ def test_study_attach_illegal(self, mock_get, mock_post, mock_delete):
+
+ StudyMockServer.setup_study_attach_illegal_mocks(mock_get, mock_post, mock_delete)
+
run_list = openml.runs.list_runs(size=10)
assert len(run_list) == 10
run_list_more = openml.runs.list_runs(size=20)
@@ -258,7 +448,15 @@ def test_study_attach_illegal(self):
self.assertListEqual(study_original.runs, study_downloaded.runs)
@unittest.skip("It is unclear when we can expect the test to pass or fail.")
- def test_study_list(self):
+ @mock.patch.object(requests.Session, "get")
+ def test_study_list(self, mock_get):
+        studies_xml = """
+        <oml:study_list xmlns:oml="http://openml.org/openml">
+            <oml:study><oml:id>1</oml:id><oml:name>study-one</oml:name><oml:status>in_preparation</oml:status></oml:study>
+            <oml:study><oml:id>2</oml:id><oml:name>study-two</oml:name><oml:status>in_preparation</oml:status></oml:study>
+        </oml:study_list>
+        """
+ mock_get.return_value = StudyMockServer.make_response(studies_xml)
study_list = openml.study.list_studies(status="in_preparation")
# might fail if server is recently reset
assert len(study_list) >= 2