From bf80325a3c9b29dc6ba651efeb4a64c3620d6ab5 Mon Sep 17 00:00:00 2001 From: Sailesh Mukil Date: Wed, 10 Sep 2025 22:33:01 -0700 Subject: [PATCH 1/3] Add the concept of DB selectors This change is added to minimize code-duplication between internal and OSS codebases. The primary difference between internal and cloud is the params used to connect to them. We can create the concept of DB selectors which encapsulate details of the underlying DB being connected to. Hence, instead of passing cloud specific parameters (`--database/--instance/--project`), these selectors can be plumbed through the code. - 3 types of selectors are implemented: CLOUD, INFRA, MOCK - Passing the following options to %%spanner_graphs selects the appropriate DB: - `--mock` selects a Mock Spanner database. - `--project <> --instance <> --database` selects a Cloud Spanner database. - `--infra_db_path <>` selects an Infra database - Core cloud and infra related database code already separated. cloud_database.py will be available in OSS and the infra version will not. So `infra_db_path` will not work in OSS and vice versa. MOCK will be supported in both internal and OSS. - A new JSON based wire format between JS and Python is added: { "selector": {"env": "SpannerEnv.MOCK/CLOUD/INFRA", "project": "<>", "instance": "<>", "database": "<>", "infra_db_path": "<>"}, "graph": "<>" } For example the protocol for cloud could look like: { "selector": {"env": "SpannerEnv.CLOUD", "project": "span-cloud-testing", "instance": "graph-demo", "database": "viz-demo", "infra_db_path": null}, "graph": "FinGraph" } Similarly INFRA would have infra_db_path populated and the rest as null. --- spanner_graphs/database.py | 57 +++++++++++++++ spanner_graphs/exec_env.py | 71 ++++++++++++++---- spanner_graphs/graph_server.py | 78 ++++++++++---------- spanner_graphs/graph_visualization.py | 2 +- spanner_graphs/magics.py | 93 +++++++++++++++--------- tests/graph_server_test.py | 29 +++++--- tests/magics_test.py | 56 ++++++++++----- tests/node_expansion_test.py | 100 +++++++++----------------- tests/sample_notebook_test.py | 34 ++++----- tests/server_test.py | 49 +++++++------ 10 files changed, 347 insertions(+), 222 deletions(-) diff --git a/spanner_graphs/database.py b/spanner_graphs/database.py index 91db0ac..63d94d4 100644 --- a/spanner_graphs/database.py +++ b/spanner_graphs/database.py @@ -25,6 +25,63 @@ import csv from dataclasses import dataclass +from enum import Enum, auto + +class SpannerEnv(Enum): + """Defines the types of Spanner environments the application can connect to.""" + CLOUD = auto() + INFRA = auto() + MOCK = auto() + +@dataclass +class DatabaseSelector: + """ + A factory and configuration holder for Spanner database connection details. + + This class provides a clean way to specify which Spanner database to connect to, + whether it's on Google Cloud, an internal infrastructure, or a local mock. + + Attributes: + env: The Spanner environment type. + project: The Google Cloud project. + instance: The Spanner instance. + database: The Spanner database. + infra_db_path: The path for an internal infrastructure database. + """ + env: SpannerEnv + project: str | None = None + instance: str | None = None + database: str | None = None + infra_db_path: str | None = None + + @classmethod + def cloud(cls, project: str, instance: str, database: str) -> 'DatabaseSelector': + """Creates a selector for a Google Cloud Spanner database.""" + if not project or not instance or not database: + raise ValueError("project, instance, and database are required for Cloud Spanner") + return cls(env=SpannerEnv.CLOUD, project=project, instance=instance, database=database) + + @classmethod + def infra(cls, infra_db_path: str) -> 'DatabaseSelector': + """Creates a selector for an internal infrastructure Spanner database.""" + if not infra_db_path: + raise ValueError("infra_db_path is required for Infra Spanner") + return cls(env=SpannerEnv.INFRA, infra_db_path=infra_db_path) + + @classmethod + def mock(cls) -> 'DatabaseSelector': + """Creates a selector for a mock Spanner database.""" + return cls(env=SpannerEnv.MOCK) + + def get_key(self) -> str: + if self.env == SpannerEnv.CLOUD: + return f"cloud_{self.project}_{self.instance}_{self.database}" + elif self.env == SpannerEnv.INFRA: + return f"infra_{self.infra_db_path}" + elif self.env == SpannerEnv.MOCK: + return "mock" + else: + raise ValueError("Unknown Spanner environment") class SpannerQueryResult(NamedTuple): # A dict where each key is a field name returned in the query and the list diff --git a/spanner_graphs/exec_env.py b/spanner_graphs/exec_env.py index 4a60efe..93ed825 100644 --- a/spanner_graphs/exec_env.py +++ b/spanner_graphs/exec_env.py @@ -16,26 +16,73 @@ """ This module maintains state for the execution environment of a session """ -from typing import Dict, Union -from spanner_graphs.database import SpannerDatabase, MockSpannerDatabase -from spanner_graphs.cloud_database import CloudSpannerDatabase +import importlib +from typing import Dict, Union +from spanner_graphs.database import ( + SpannerDatabase, + MockSpannerDatabase, + DatabaseSelector, + SpannerEnv, +) # Global dict of database instances created in a single session database_instances: Dict[str, Union[SpannerDatabase, MockSpannerDatabase]] = {} -def get_database_instance(project: str, instance: str, database: str, mock = False): - if mock: +def get_database_instance( + selector: DatabaseSelector, +) -> Union[SpannerDatabase, MockSpannerDatabase]: + """Gets a cached or new database instance based on the selector. + + Args: + selector: A `DatabaseSelector` object that specifies which database to + connect to. + + Returns: + An initialized `SpannerDatabase` or `MockSpannerDatabase` instance. + A CloudSpannerDatabase will only be available in public environments. + An InfraSpannerDatabase will only be available in internal environments. + + Raises: + RuntimeError: If the required Spanner client library (for Cloud or Infra) + is not installed in the environment. + ValueError: If the selector specifies an unknown or unsupported + environment. + """ + if selector.env == SpannerEnv.MOCK: return MockSpannerDatabase() - key = f"{project}_{instance}_{database}" + key = selector.get_key() db = database_instances.get(key) + if db: + return db - # Currently, we only create and return CloudSpannerDatabase instances. In the future, different - # implementations could be introduced. - if not db: - db = CloudSpannerDatabase(project, instance, database) - database_instances[key] = db + elif selector.env == SpannerEnv.CLOUD: + try: + cloud_db_module = importlib.import_module( + "spanner_graphs.cloud_database" + ) + CloudSpannerDatabase = getattr(cloud_db_module, "CloudSpannerDatabase") + db = CloudSpannerDatabase( + selector.project, selector.instance, selector.database + ) + except ImportError: + raise RuntimeError( + "Cloud Spanner support is not available in this environment." + ) + elif selector.env == SpannerEnv.INFRA: + try: + infra_db_module = importlib.import_module( + "spanner_graphs.infra_database" + ) + InfraSpannerDatabase = getattr(infra_db_module, "InfraSpannerDatabase") + db = InfraSpannerDatabase(selector.infra_db_path) + except ImportError: + raise RuntimeError( + "Infra Spanner support is not available in this environment." + ) + else: + raise ValueError(f"Unsupported Spanner environment: {selector.env}") + database_instances[key] = db return db - diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index cf318c3..3fe936e 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -25,7 +25,7 @@ from spanner_graphs.conversion import get_nodes_edges from spanner_graphs.exec_env import get_database_instance -from spanner_graphs.database import SpannerQueryResult +from spanner_graphs.database import DatabaseSelector, SpannerQueryResult, SpannerEnv # Supported types for a property PROPERTY_TYPE_SET = { @@ -52,6 +52,24 @@ class EdgeDirection(Enum): INCOMING = "INCOMING" OUTGOING = "OUTGOING" + +def dict_to_selector(selector_dict: Dict[str, Any]) -> DatabaseSelector: + """ + Picks the correct DB selector based on the environment the server is running in. + """ + try: + env = SpannerEnv[selector_dict['env'].split('.')[-1]] + if env == SpannerEnv.CLOUD: + return DatabaseSelector.cloud(selector_dict['project'], selector_dict['instance'], selector_dict['database']) + elif env == SpannerEnv.INFRA: + return DatabaseSelector.infra(selector_dict['infra_db_path']) + elif env == SpannerEnv.MOCK: + return DatabaseSelector.mock() + raise ValueError(f"Invalid env in selector dict: {selector_dict}") + except Exception as e: + print (f"Unexpected error when fetching selector: {e}") + + def is_valid_property_type(property_type: str) -> bool: """ Validates a property type. @@ -79,7 +97,7 @@ def is_valid_property_type(property_type: str) -> bool: return True def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploration], EdgeDirection): - required_fields = ["project", "instance", "database", "graph", "uid", "node_labels", "direction"] + required_fields = ["uid", "node_labels", "direction"] missing_fields = [field for field in required_fields if data.get(field) is None] if missing_fields: @@ -146,7 +164,8 @@ def validate_node_expansion_request(data) -> (list[NodePropertyForDataExploratio return validated_properties, direction def execute_node_expansion( - params_str: str, + selector_dict: Dict[str, Any], + graph: str, request: dict) -> dict: """Execute a node expansion query to find connected nodes and edges. @@ -158,13 +177,9 @@ def execute_node_expansion( dict: A dictionary containing the query response with nodes and edges. """ - params = json.loads(params_str) - node_properties, direction = validate_node_expansion_request(params | request) + node_properties, direction = validate_node_expansion_request(request) - project = params.get("project") - instance = params.get("instance") - database = params.get("database") - graph = params.get("graph") + selector = dict_to_selector(selector_dict) uid = request.get("uid") node_labels = request.get("node_labels") edge_label = request.get("edge_label") @@ -204,14 +219,11 @@ def execute_node_expansion( RETURN TO_JSON(e) as e, TO_JSON(d) as d """ - return execute_query(project, instance, database, query, mock=False) + return execute_query(selector_dict, query) def execute_query( - project: str, - instance: str, - database: str, + selector_dict: Dict[str, Any], query: str, - mock: bool = False, ) -> Dict[str, Any]: """Executes a query against a database and formats the result. @@ -220,19 +232,14 @@ def execute_query( If the query fails, it returns a detailed error message, optionally including the database schema to aid in debugging. - Args: - project: The cloud project ID. - instance: The database instance name. - database: The database name. - query: The query string to execute. - mock: If True, use a mock database instance for testing. Defaults to False. - Returns: A dictionary containing either the structured 'response' with nodes, edges, and other data, or an 'error' key with a descriptive message. """ try: - db_instance = get_database_instance(project, instance, database, mock) + selector = dict_to_selector(selector_dict) + db_instance = get_database_instance(selector) + result: SpannerQueryResult = db_instance.execute_query(query) if len(result.rows) == 0 and result.err: @@ -382,32 +389,25 @@ def handle_post_query(self): data = self.parse_post_data() params = json.loads(data["params"]) response = execute_query( - project=params["project"], - instance=params["instance"], - database=params["database"], - query=data["query"], - mock=params["mock"] + selector_dict=params["selector"], + query=data["query"] ) self.do_data_response(response) def handle_post_node_expansion(self): - """Handle POST requests for node expansion. - - Expects a JSON payload with: - - params: A JSON string containing connection parameters (project, instance, database, graph) - - request: A dictionary with node details (uid, node_labels, node_properties, direction, edge_label) - """ try: data = self.parse_post_data() + params = json.loads(data.get("params")) + selector_dict = params["selector"] + graph = params.get("graph") + request_data = data.get("request") - # Execute node expansion with: - # - params_str: JSON string with connection parameters (project, instance, database, graph) - # - request: Dict with node details (uid, node_labels, node_properties, direction, edge_label) self.do_data_response(execute_node_expansion( - params_str=data.get("params"), - request=data.get("request") + selector_dict=selector_dict, + graph=graph, + request=request_data )) - except BaseException as e: + except Exception as e: self.do_error_response(e) return diff --git a/spanner_graphs/graph_visualization.py b/spanner_graphs/graph_visualization.py index 8a9cd77..30ace90 100644 --- a/spanner_graphs/graph_visualization.py +++ b/spanner_graphs/graph_visualization.py @@ -57,7 +57,7 @@ def generate_visualization_html(query: str, port: int, params: str): search_dir = parent template_content = _load_file([search_dir, 'frontend', 'static', 'jupyter.html']) - + # Load the JavaScript bundle directly js_file_path = os.path.join(search_dir, 'third_party', 'index.js') try: diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index b412006..cf9b5d0 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -24,6 +24,7 @@ import sys from threading import Thread import re +from dataclasses import is_dataclass, asdict from IPython.core.display import HTML, JSON from IPython.core.magic import Magics, magics_class, cell_magic @@ -33,6 +34,7 @@ from ipywidgets import interact from jinja2 import Template +from spanner_graphs.database import DatabaseSelector from spanner_graphs.exec_env import get_database_instance from spanner_graphs.graph_server import ( GraphServer, execute_query, execute_node_expansion, @@ -86,11 +88,13 @@ def is_colab() -> bool: def receive_query_request(query: str, params: str): params_dict = json.loads(params) - return JSON(execute_query(project=params_dict["project"], - instance=params_dict["instance"], - database=params_dict["database"], - query=query, - mock=params_dict["mock"])) + selector_dict = params_dict.get("selector") + if not selector_dict: + return JSON({"error": "Missing selector in params"}) + try: + return JSON(execute_query(selector_dict=selector_dict, query=query)) + except Exception as e: + return JSON({"error": str(e)}) def receive_node_expansion_request(request: dict, params_str: str): """Handle node expansion requests in Google Colab environment @@ -103,11 +107,8 @@ def receive_node_expansion_request(request: dict, params_str: str): - direction: str - Direction of expansion ("INCOMING" or "OUTGOING") - edge_label: Optional[str] - Label of edges to filter by params_str: A JSON string containing connection parameters: - - project: str - GCP project ID - - instance: str - Spanner instance ID - - database: str - Spanner database ID + - selector: Dict - The DatabaseSelector object as a dict - graph: str - Graph name - - mock: bool - Whether to use mock data Returns: JSON: A JSON-serialized response containing either: @@ -115,9 +116,23 @@ def receive_node_expansion_request(request: dict, params_str: str): - An error message if the request failed """ try: - return JSON(execute_node_expansion(params_str, request)) + params_dict = json.loads(params_str) + selector_dict = params_dict.get("selector") + graph = params_dict.get("graph") + if not selector_dict: + return JSON({"error": "Missing selector in params"}) + + return JSON(execute_node_expansion(selector_dict=selector_dict, graph=graph, request=request)) except BaseException as e: - return JSON({"error": e}) + return JSON({"error": str(e)}) + +def custom_json_serializer(o): + """A JSON serializer that handles dataclasses and enums.""" + if is_dataclass(o): + return asdict(o) + if isinstance(o, Enum): + return f"{o.__class__.__name__}.{o.name}" + raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") @magics_class class NetworkVisualizationMagics(Magics): @@ -129,6 +144,7 @@ def __init__(self, shell): self.limit = 5 self.args = None self.cell = None + self.selector = None if is_colab(): from google.colab import output @@ -149,17 +165,18 @@ def visualize(self): if match: graph = match.group(1) + # Pack the selector and graph into the params to be sent to the GraphServer + params = { + "selector": self.selector, + "graph": graph + } + # Generate the HTML content html_content = generate_visualization_html( query=self.cell, port=GraphServer.port, - params=json.dumps({ - "project": self.args.project, - "instance": self.args.instance, - "database": self.args.database, - "mock": self.args.mock, - "graph": graph - })) + params=json.dumps(params, default=custom_json_serializer)) + display(HTML(html_content)) @cell_magic @@ -177,35 +194,41 @@ def spanner_graph(self, line: str, cell: str): parser.add_argument("--mock", action="store_true", help="Use mock database") + parser.add_argument("--infra_db_path", + action="store_true", + help="Connect to internal Infra Spanner") try: args = parser.parse_args(line.split()) - if not args.mock: - if not (args.project and args.instance and args.database): + selector = None + if args.mock: + selector = DatabaseSelector.mock() + elif args.infra_db_path: + selector = DatabaseSelector.infra(infra_db_path=args.database) + else: + if not (args.project and args.instance): raise ValueError( - "Please provide `--project`, `--instance`, " - "and `--database` values for your query.") - if not cell or not cell.strip(): - print("Error: Query is required.") - return + "Please provide `--project` and `--instance` for Cloud Spanner." + ) + selector = DatabaseSelector.cloud(args.project, args.instance, args.database) - self.args = parser.parse_args(line.split()) + if not args.mock and (not cell or not cell.strip()): + print("Error: Query is required.") + return + + self.args = args self.cell = cell - self.database = get_database_instance( - self.args.project, - self.args.instance, - self.args.database, - mock=self.args.mock) + self.selector = selector + self.database = get_database_instance(self.selector) clear_output(wait=True) self.visualize() except BaseException as e: print(f"Error: {e}") - print("Usage: %%spanner_graph --project PROJECT_ID " - "--instance INSTANCE_ID --database DATABASE_ID " - "[--mock] ") + print("Usage: %%spanner_graph --infra_db_path ") + print(" %%spanner_graph --project --instance --database ") + print(" %%spanner_graph --mock") print(" Graph query here...") - def load_ipython_extension(ipython): """Registration function""" ipython.register_magics(NetworkVisualizationMagics) diff --git a/tests/graph_server_test.py b/tests/graph_server_test.py index 7b405e2..8a881af 100644 --- a/tests/graph_server_test.py +++ b/tests/graph_server_test.py @@ -6,6 +6,7 @@ is_valid_property_type, execute_node_expansion, ) +from spanner_graphs.database import SpannerEnv class TestPropertyTypeHandling(unittest.TestCase): def test_validate_property_type_valid_types(self): @@ -75,12 +76,14 @@ def test_property_value_formatting(self, mock_execute_query): ("ENUM", "ENUM_VALUE", "'''ENUM_VALUE'''"), ] - params = json.dumps({ + selector_dict = { + "env": str(SpannerEnv.CLOUD), "project": "test-project", "instance": "test-instance", "database": "test-database", - "graph": "test-graph", - }) + "infra_db_path": None + } + graph = "test-graph" for type_str, value, expected_format in test_cases: with self.subTest(type=type_str, value=value): @@ -95,13 +98,14 @@ def test_property_value_formatting(self, mock_execute_query): } execute_node_expansion( - params_str=params, + selector_dict=selector_dict, + graph=graph, request=request ) # Extract the actual formatted value from the query last_call = mock_execute_query.call_args[0] # Get the positional args - query = last_call[3] # The query is the 4th positional arg + query = last_call[1] # The query is the 2nd positional arg # Find the WHERE clause in the query and extract the value where_line = [line for line in query.split('\n') if 'WHERE' in line][0] @@ -117,12 +121,14 @@ def test_property_value_formatting_no_type(self, mock_execute_query): # Create a property dictionary with string type (since null type is not allowed) prop_dict = {"key": "test_property", "value": "test_value", "type": "STRING"} - params = json.dumps({ + selector_dict = { + "env": str(SpannerEnv.CLOUD), "project": "test-project", "instance": "test-instance", "database": "test-database", - "graph": "test-graph", - }) + "infra_db_path": None + } + graph = "test-graph" request = { "uid": "test-uid", @@ -132,13 +138,14 @@ def test_property_value_formatting_no_type(self, mock_execute_query): } execute_node_expansion( - params_str=params, + selector_dict=selector_dict, + graph=graph, request=request ) # Extract the actual formatted value from the query - last_call = mock_execute_query.call_args[0] - query = last_call[3] + last_call = mock_execute_query.call_args[0] # Get the positional args + query = last_call[1] # The query is the 2nd positional arg where_line = [line for line in query.split('\n') if 'WHERE' in line][0] expected_pattern = "n.test_property='''test_value'''" self.assertIn(expected_pattern, where_line, diff --git a/tests/magics_test.py b/tests/magics_test.py index fef2fac..51da9de 100644 --- a/tests/magics_test.py +++ b/tests/magics_test.py @@ -3,6 +3,7 @@ from IPython.core.interactiveshell import InteractiveShell from spanner_graphs.graph_server import GraphServer from spanner_graphs.magics import NetworkVisualizationMagics, load_ipython_extension +from spanner_graphs.database import DatabaseSelector class TestNetworkVisualizationMagics(unittest.TestCase): def setUp(self): @@ -11,6 +12,7 @@ def setUp(self): # Initialize our magic class self.magics = NetworkVisualizationMagics(self.ip) + self.magics.selector = None # Initialize selector @classmethod def tearDownClass(cls): @@ -34,38 +36,55 @@ def test_magic_registration(self): self.ip.register_magics.assert_called_once_with(NetworkVisualizationMagics) @patch('spanner_graphs.magics.get_database_instance') - @patch('spanner_graphs.magics.GraphServer') - @patch('spanner_graphs.magics.display') - def test_spanner_graph_magic_with_valid_args(self, mock_display, mock_server, mock_db): - """Test the %%spanner_graph magic with valid arguments""" + @patch('spanner_graphs.magics.generate_visualization_html') + def test_spanner_graph_magic_with_cloud_args(self, mock_generate_html, mock_db): + """Test the %%spanner_graph magic with valid cloud arguments""" # Setup mock database mock_db.return_value = MagicMock() - - # Setup mock server - mock_server.port = 8080 + mock_generate_html.return_value = "" # Test line with valid arguments line = "--project test_project --instance test_instance --database test_db" cell = "SELECT * FROM test_table" # Execute the magic - result = self.magics.spanner_graph(line, cell) + self.magics.spanner_graph(line, cell) + + # Verify database was initialized with correct parameters + expected_selector = DatabaseSelector.cloud("test_project", "test_instance", "test_db") + mock_db.assert_called_once_with(expected_selector) + self.assertEqual(self.magics.selector, expected_selector) + + # Verify display was called (exact HTML content verification would be complex) + mock_generate_html.assert_called_once() + + @patch('spanner_graphs.magics.get_database_instance') + @patch('spanner_graphs.magics.generate_visualization_html') + def test_spanner_graph_magic_with_mock_args(self, mock_generate_html, mock_db): + """Test the %%spanner_graph magic with mock arguments""" + # Setup mock database + mock_db.return_value = MagicMock() + mock_generate_html.return_value = "" + + # Test line with valid arguments + line = "--mock" + cell = "SELECT * FROM test_table" + + # Execute the magic + self.magics.spanner_graph(line, cell) # Verify database was initialized with correct parameters - mock_db.assert_called_once_with( - "test_project", - "test_instance", - "test_db", - mock=False - ) + expected_selector = DatabaseSelector.mock() + mock_db.assert_called_once_with(expected_selector) + self.assertEqual(self.magics.selector, expected_selector) # Verify display was called (exact HTML content verification would be complex) - mock_display.assert_called_once() + mock_generate_html.assert_called_once() def test_spanner_graph_magic_with_invalid_args(self): """Test the %%spanner_graph magic with invalid arguments""" - # Test with missing required arguments - line = "--project test_project" # Missing instance and database + # Test with missing required arguments for cloud + line = "--project test_project --database test_db" # Missing instance cell = "SELECT * FROM test_table" # Execute the magic and capture output @@ -74,8 +93,7 @@ def test_spanner_graph_magic_with_invalid_args(self): # Verify error message was printed mock_print.assert_any_call( - "Error: Please provide `--project`, `--instance`, " - "and `--database` values for your query." + "Error: Please provide `--project` and `--instance` for Cloud Spanner." ) def test_spanner_graph_magic_with_empty_cell(self): diff --git a/tests/node_expansion_test.py b/tests/node_expansion_test.py index 900caab..172d680 100644 --- a/tests/node_expansion_test.py +++ b/tests/node_expansion_test.py @@ -4,23 +4,29 @@ from spanner_graphs.magics import receive_node_expansion_request from spanner_graphs.graph_server import EdgeDirection +from spanner_graphs.database import DatabaseSelector, SpannerEnv class TestNodeExpansion(unittest.TestCase): def setUp(self): self.sample_request = { "uid": "node-123", - "node_key_property_name": "id", - "node_key_property_value": "123", - "node_key_property_type": "INT64", + "node_labels": ["Person"], + "node_properties": [ + {"key": "id", "value": "123", "type": "INT64"} + ], "direction": "OUTGOING", "edge_label": "CONNECTS_TO" } + # Updated params to use DatabaseSelector structure self.sample_params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", + "selector": { + "env": str(SpannerEnv.CLOUD), + "project": "test-project", + "instance": "test-instance", + "database": "test-database", + "infra_db_path": None + }, "graph": "test_graph", - "mock": False }) @patch('spanner_graphs.magics.validate_node_expansion_request') @@ -36,30 +42,16 @@ def test_receive_node_expansion_request(self, mock_execute, mock_validate): } } - # Create request and params objects - request = { - "uid": "node-123", - "node_labels": ["Person"], - "node_properties": [ - {"key": "id", "value": "123", "type": "INT64"} - ], - "direction": "OUTGOING", - "edge_label": "CONNECTS_TO" - } - - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) - # Call the function - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(self.sample_request, self.sample_params) # Verify execute_node_expansion was called with correct parameters - mock_execute.assert_called_once_with(params, request) + params_dict = json.loads(self.sample_params) + mock_execute.assert_called_once_with( + selector_dict=params_dict["selector"], + graph=params_dict["graph"], + request=self.sample_request + ) # Verify the result is wrapped in JSON self.assertEqual(result.data, mock_execute.return_value) @@ -77,30 +69,20 @@ def test_receive_node_expansion_request_without_edge_label(self, mock_execute, m } } - # Create request without edge_label and params objects - request = { - "uid": "node-123", - "node_labels": ["Person"], - "node_properties": [ - {"key": "id", "value": "123", "type": "INT64"} - ], - "direction": "OUTGOING" - # No edge_label - } - - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) + # Create request without edge_label + request = self.sample_request.copy() + del request["edge_label"] # Call the function - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(request, self.sample_params) # Verify execute_node_expansion was called with correct parameters - mock_execute.assert_called_once_with(params, request) + params_dict = json.loads(self.sample_params) + mock_execute.assert_called_once_with( + selector_dict=params_dict["selector"], + graph=params_dict["graph"], + request=request + ) # Verify the result is wrapped in JSON self.assertEqual(result.data, mock_execute.return_value) @@ -121,17 +103,10 @@ def test_invalid_property_type(self, mock_validate): "direction": "OUTGOING" } - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) - # Call the function and verify it returns an error response - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(request, self.sample_params) self.assertIn("error", result.data) + self.assertIn("Invalid property type", result.data["error"]) @patch('spanner_graphs.magics.validate_node_expansion_request') def test_invalid_direction(self, mock_validate): @@ -149,17 +124,10 @@ def test_invalid_direction(self, mock_validate): "direction": "INVALID_DIRECTION" } - params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "graph": "test_graph", - "mock": False - }) - # Call the function and verify it returns an error response - result = receive_node_expansion_request(request, params) + result = receive_node_expansion_request(request, self.sample_params) self.assertIn("error", result.data) + self.assertIn("Invalid direction", result.data["error"]) if __name__ == '__main__': unittest.main() diff --git a/tests/sample_notebook_test.py b/tests/sample_notebook_test.py index 17400d2..ab12c53 100644 --- a/tests/sample_notebook_test.py +++ b/tests/sample_notebook_test.py @@ -4,6 +4,7 @@ from IPython.core.interactiveshell import InteractiveShell from spanner_graphs.graph_server import GraphServer from spanner_graphs.magics import NetworkVisualizationMagics, load_ipython_extension +from spanner_graphs.database import DatabaseSelector class TestSampleNotebook(unittest.TestCase): def setUp(self): @@ -12,6 +13,7 @@ def setUp(self): # Initialize our magic class self.magics = NetworkVisualizationMagics(self.ip) + self.magics.selector = None # Load the notebook content with open('sample.ipynb', 'r') as f: @@ -59,29 +61,24 @@ def test_notebook_cells(self): # Test the mock visualization with mocked dependencies with patch('spanner_graphs.magics.get_database_instance') as mock_db, \ - patch('spanner_graphs.magics.GraphServer') as mock_server, \ - patch('spanner_graphs.magics.display') as mock_display: + patch('spanner_graphs.magics.generate_visualization_html') as mock_generate_html: mock_db.return_value = MagicMock() - mock_server.port = 8080 + mock_generate_html.return_value = "" # Test with a valid query since empty cell is handled by IPython line = '--mock' cell = 'GRAPH FinGraph\nMATCH p = (a)-[e]->(b)\nRETURN TO_JSON(p) AS path\nLIMIT 100' # Execute the magic with a valid query - result = self.magics.spanner_graph(line, cell) + self.magics.spanner_graph(line, cell) # Verify database was initialized with mock=True - mock_db.assert_called_once_with( - None, # project - None, # instance - None, # database - mock=True - ) + expected_selector = DatabaseSelector.mock() + mock_db.assert_called_once_with(expected_selector) # Verify display was called - mock_display.assert_called_once() + mock_generate_html.assert_called_once() # Fourth cell should be the Spanner Graph query query_cell = self.code_cells[3] @@ -97,29 +94,28 @@ def test_notebook_cells(self): # Test the query with mocked dependencies with patch('spanner_graphs.magics.get_database_instance') as mock_db, \ - patch('spanner_graphs.magics.GraphServer') as mock_server, \ - patch('spanner_graphs.magics.display') as mock_display: + patch('spanner_graphs.magics.generate_visualization_html') as mock_generate_html: mock_db.return_value = MagicMock() - mock_server.port = 8080 + mock_generate_html.return_value = "" # Extract the actual line and cell content from the notebook line = next(line for line in query_cell['source'] if line.startswith('%%spanner_graph')).replace('%%spanner_graph ', '') cell = ''.join(line for line in query_cell['source'] if not line.startswith('%%spanner_graph')) # Execute the magic with the actual notebook content - result = self.magics.spanner_graph(line, cell) + self.magics.spanner_graph(line, cell) # Verify database was initialized with placeholder values - mock_db.assert_called_once_with( + expected_selector = DatabaseSelector.cloud( "{project_id}", "{instance_name}", - "{database_name}", - mock=False + "{database_name}" ) + mock_db.assert_called_once_with(expected_selector) # Verify display was called - mock_display.assert_called_once() + mock_generate_html.assert_called_once() if __name__ == '__main__': unittest.main() diff --git a/tests/server_test.py b/tests/server_test.py index ecdf514..9c1bc28 100644 --- a/tests/server_test.py +++ b/tests/server_test.py @@ -16,6 +16,7 @@ import requests import json from spanner_graphs.graph_server import GraphServer +from spanner_graphs.database import SpannerEnv class TestSpannerServer(unittest.TestCase): def setUp(self): @@ -39,15 +40,15 @@ def test_post_query_with_mock(self): """Test querying with mock database""" # Build the request URL route = GraphServer.build_route(GraphServer.endpoints["post_query"]) - + # Create request data with the new structure params = json.dumps({ - "project": "test-project", - "instance": "test-instance", - "database": "test-database", - "mock": True + "selector": { + "env": str(SpannerEnv.MOCK) + }, + "graph": "TestGraph" }) - + request_data = { "params": params, "query": "GRAPH TestGraph MATCH (n) RETURN n" @@ -55,11 +56,11 @@ def test_post_query_with_mock(self): # Send POST request response = requests.post(route, json=request_data) - + # Verify response self.assertEqual(response.status_code, 200) response_data = response.json() - + # Check response structure self.assertIn("response", response_data) response = response_data["response"] @@ -72,13 +73,13 @@ def test_post_query_with_mock(self): # Verify we got some data self.assertTrue(len(response["nodes"]) > 0, "Should have at least one node") self.assertTrue(len(response["edges"]) > 0, "Should have at least one edge") - + # Verify node structure node = response["nodes"][0] self.assertIn("identifier", node) self.assertIn("labels", node) self.assertIn("properties", node) - + # Verify edge structure edge = response["edges"][0] self.assertIn("identifier", edge) @@ -91,25 +92,33 @@ def test_node_expansion_error_handling(self): """Test that errors in node expansion are properly handled and returned.""" # Build the request URL route = GraphServer.build_route(GraphServer.endpoints["post_node_expansion"]) - + # Create request data with invalid fields to trigger validation error - request_data = { - "project": "test-project", - "instance": "test-instance", - "database": "test-database", + params = { + "selector": { + "env": str(SpannerEnv.CLOUD), + "project_id": "test-project", + "instance_id": "test-instance", + "database_id": "test-database" + }, "graph": "test-graph", - "uid": "test-uid", - # Missing required node_labels field - "direction": "INVALID_DIRECTION" # Invalid direction + } + request_data = { + "params": json.dumps(params), + "request": { + "uid": "test-uid", + # Missing required node_labels field + "direction": "INVALID_DIRECTION" # Invalid direction + } } # Send POST request response = requests.post(route, json=request_data) - + # Verify response self.assertEqual(response.status_code, 200) # Server still returns 200 but with error data response_data = response.json() - + # Check error presence self.assertIn("error", response_data) self.assertIsNotNone(response_data["error"]) From 958fb37afb961e207aaa11af8061aaec678ab856 Mon Sep 17 00:00:00 2001 From: Sailesh Mukil Date: Wed, 17 Sep 2025 10:53:37 -0700 Subject: [PATCH 2/3] Make dev_util/serve_dev.py work with the new selector protocol - Also fixes the frontend/static/test --- frontend/static/dev.html | 19 ++++++++++++++----- frontend/static/test.html | 14 +++++++------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/frontend/static/dev.html b/frontend/static/dev.html index 5b1841e..93cc458 100644 --- a/frontend/static/dev.html +++ b/frontend/static/dev.html @@ -484,11 +484,20 @@

Configure Visualization

window.app.tearDown(); } + let selector; + if (mock) { + selector = { env: 'SpannerEnv.MOCK' }; + } else { + selector = { + env: 'SpannerEnv.CLOUD', + project: project, + instance: instance, + database: database + }; + } + const params = { - 'project': project, - 'instance': instance, - 'database': database, - 'mock': mock, + 'selector': selector, 'graph': graph }; @@ -546,4 +555,4 @@

Configure Visualization

toggleCommandPalette(); - \ No newline at end of file + diff --git a/frontend/static/test.html b/frontend/static/test.html index 915e2b2..a85d6db 100644 --- a/frontend/static/test.html +++ b/frontend/static/test.html @@ -68,16 +68,16 @@ } const mount = document.querySelector('.mount-spanner-test'); - params = { - 'project': 'project-foo', - 'instance': 'instance-foo', - 'database': 'database-foo', - 'mock': true - } + const params = { + selector: { + env: 'SpannerEnv.MOCK' + }, + graph: '' + }; window.app = new SpannerApp({ id: 'spanner-test', port:'', params:params, mount:mount, query: '' }); }); - \ No newline at end of file + From a8d66e8e68163720ac1820d7155f079140dc95dd Mon Sep 17 00:00:00 2001 From: Sailesh Mukil Date: Thu, 18 Sep 2025 11:47:04 -0700 Subject: [PATCH 3/3] Address review comments for DB selectors --- spanner_graphs/graph_server.py | 1 - spanner_graphs/magics.py | 1 - 2 files changed, 2 deletions(-) diff --git a/spanner_graphs/graph_server.py b/spanner_graphs/graph_server.py index 3fe936e..6324207 100644 --- a/spanner_graphs/graph_server.py +++ b/spanner_graphs/graph_server.py @@ -179,7 +179,6 @@ def execute_node_expansion( node_properties, direction = validate_node_expansion_request(request) - selector = dict_to_selector(selector_dict) uid = request.get("uid") node_labels = request.get("node_labels") edge_label = request.get("edge_label") diff --git a/spanner_graphs/magics.py b/spanner_graphs/magics.py index cf9b5d0..1741e0d 100644 --- a/spanner_graphs/magics.py +++ b/spanner_graphs/magics.py @@ -224,7 +224,6 @@ def spanner_graph(self, line: str, cell: str): self.visualize() except BaseException as e: print(f"Error: {e}") - print("Usage: %%spanner_graph --infra_db_path ") print(" %%spanner_graph --project --instance --database ") print(" %%spanner_graph --mock") print(" Graph query here...")