From 0bed11d34af14c95f0b20f25188956987762aed8 Mon Sep 17 00:00:00 2001 From: Jiabin Hu Date: Wed, 4 Mar 2026 10:56:39 -0800 Subject: [PATCH 1/2] Allow specifiying query tags as a dict upon connection creation Signed-off-by: Jiabin Hu --- examples/query_tags_example.py | 15 ++++++--------- src/databricks/sql/client.py | 11 +++++++++++ tests/unit/test_session.py | 21 +++++++++++++++++++++ 3 files changed, 38 insertions(+), 9 deletions(-) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py index 687ce4140..8696c4d1a 100644 --- a/examples/query_tags_example.py +++ b/examples/query_tags_example.py @@ -8,7 +8,7 @@ in the system.query.history table for analytical purposes. There are two ways to set query tags: -1. Session-level: Set in session_configuration (applies to all queries in the session) +1. Connection-level: Pass query_tags parameter to sql.connect() (applies to all queries in the session) 2. Per-query level: Pass query_tags parameter to execute() or execute_async() (applies to specific query) Format: Dictionary with string keys and optional string values @@ -17,21 +17,18 @@ Special cases: - If a value is None, only the key is included (no colon or value) - Special characters (comma, colon and backslash) in values are automatically escaped -- Keys are not escaped (should be controlled identifiers) +- Backslashes in keys are automatically escaped; other special characters in keys are not escaped """ print("=== Query Tags Example ===\n") -# Example 1: Session-level query tags (old approach) -print("Example 1: Session-level query tags") +# Example 1: Connection-level query tags +print("Example 1: Connection-level query tags") with sql.connect( server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), http_path=os.getenv("DATABRICKS_HTTP_PATH"), access_token=os.getenv("DATABRICKS_TOKEN"), - session_configuration={ - 'QUERY_TAGS': 'team:engineering,test:query-tags', - 'ansi_mode': False - } + query_tags={"team": "engineering", "application": "etl"}, ) as connection: with connection.cursor() as cursor: @@ -41,7 +38,7 @@ print() -# Example 2: Per-query query tags (new approach) +# Example 2: Per-query query tags print("Example 2: Per-query query tags") with sql.connect( server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"), diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index efaf6ae4d..2aeea175e 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -36,6 +36,7 @@ ColumnQueue, build_client_context, get_session_config_value, + serialize_query_tags, ) from databricks.sql.parameters.native import ( DbsqlParameterBase, @@ -106,6 +107,7 @@ def __init__( schema: Optional[str] = None, _use_arrow_native_complex_types: Optional[bool] = True, ignore_transactions: bool = True, + query_tags: Optional[Dict[str, Optional[str]]] = None, **kwargs, ) -> None: """ @@ -281,6 +283,15 @@ def read(self) -> Optional[OAuthToken]: "spark.sql.thriftserver.metadata.metricview.enabled" ] = "true" + if query_tags is not None: + if session_configuration is None: + session_configuration = {} + serialized = serialize_query_tags(query_tags) + if serialized: + session_configuration["QUERY_TAGS"] = serialized + else: + session_configuration.pop("QUERY_TAGS", None) + self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 1d70ec4c4..3a43c1a75 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -202,3 +202,24 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): close_session_call_args = instance.close_session.call_args[0][0] assert close_session_call_args.guid == b"\x22" assert close_session_call_args.secret == b"\x33" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_query_tags_dict_sets_session_config(self, mock_client_class): + databricks.sql.connect( + query_tags={"team": "data-eng", "project": "etl"}, + **self.DUMMY_CONNECTION_ARGS, + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:data-eng,project:etl" + + @patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME) + def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_class): + databricks.sql.connect( + query_tags={"team": "new-team"}, + session_configuration={"QUERY_TAGS": "team:old-team,other:value"}, + **self.DUMMY_CONNECTION_ARGS, + ) + + call_kwargs = mock_client_class.return_value.open_session.call_args[1] + assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team" From b0aff372bdd07aaf157a4f6b23e569f684ea9699 Mon Sep 17 00:00:00 2001 From: Jiabin Hu Date: Wed, 4 Mar 2026 11:03:52 -0800 Subject: [PATCH 2/2] fix comment Signed-off-by: Jiabin Hu --- examples/query_tags_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/query_tags_example.py b/examples/query_tags_example.py index 8696c4d1a..977dc6ad5 100644 --- a/examples/query_tags_example.py +++ b/examples/query_tags_example.py @@ -17,7 +17,7 @@ Special cases: - If a value is None, only the key is included (no colon or value) - Special characters (comma, colon and backslash) in values are automatically escaped -- Backslashes in keys are automatically escaped; other special characters in keys are not escaped +- Backslashes in keys are automatically escaped; other special characters in keys are not allowed """ print("=== Query Tags Example ===\n")