Skip to content

Commit 3bdca9e

Browse files
timsaucerclaude
andcommitted
refactor(udf): raise KeyError on UDF/UDAF/UDWF lookup miss
`SessionContext.udf` / `udaf` / `udwf` previously surfaced upstream `DataFusionError::Plan` as a generic exception whose message ("There is no UDF named ...") is set by DataFusion and can drift between releases. Pre-check membership via `udfs()` / `udafs()` / `udwfs()` and raise `PyKeyError` on miss so callers get the Pythonic dict-style lookup behavior and tests are no longer coupled to the upstream wording. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 430d4ba commit 3bdca9e

3 files changed

Lines changed: 21 additions & 12 deletions

File tree

crates/core/src/context.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,18 +1072,27 @@ impl PySessionContext {
10721072
self.ctx.deregister_udwf(name);
10731073
}
10741074

1075-
pub fn udf(&self, name: &str) -> PyDataFusionResult<PyScalarUDF> {
1076-
let function = (*self.ctx.udf(name)?).clone();
1075+
pub fn udf(&self, name: &str) -> PyResult<PyScalarUDF> {
1076+
if !self.ctx.udfs().contains(name) {
1077+
return Err(PyKeyError::new_err(format!("no UDF named '{name}'")));
1078+
}
1079+
let function = (*self.ctx.udf(name).map_err(py_datafusion_err)?).clone();
10771080
Ok(PyScalarUDF { function })
10781081
}
10791082

1080-
pub fn udaf(&self, name: &str) -> PyDataFusionResult<PyAggregateUDF> {
1081-
let function = (*self.ctx.udaf(name)?).clone();
1083+
pub fn udaf(&self, name: &str) -> PyResult<PyAggregateUDF> {
1084+
if !self.ctx.udafs().contains(name) {
1085+
return Err(PyKeyError::new_err(format!("no UDAF named '{name}'")));
1086+
}
1087+
let function = (*self.ctx.udaf(name).map_err(py_datafusion_err)?).clone();
10821088
Ok(PyAggregateUDF { function })
10831089
}
10841090

1085-
pub fn udwf(&self, name: &str) -> PyDataFusionResult<PyWindowUDF> {
1086-
let function = (*self.ctx.udwf(name)?).clone();
1091+
pub fn udwf(&self, name: &str) -> PyResult<PyWindowUDF> {
1092+
if !self.ctx.udwfs().contains(name) {
1093+
return Err(PyKeyError::new_err(format!("no UDWF named '{name}'")));
1094+
}
1095+
let function = (*self.ctx.udwf(name).map_err(py_datafusion_err)?).clone();
10871096
Ok(PyWindowUDF { function })
10881097
}
10891098

python/datafusion/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,7 +1323,7 @@ def udf(self, name: str) -> ScalarUDF:
13231323
name: Name of the registered scalar UDF.
13241324
13251325
Raises:
1326-
Exception: If no scalar UDF is registered under ``name``.
1326+
KeyError: If no scalar UDF is registered under ``name``.
13271327
13281328
Examples:
13291329
Register a UDF, then look it up by name and use it in the
@@ -1371,7 +1371,7 @@ def udaf(self, name: str) -> AggregateUDF:
13711371
name: Name of the registered aggregate UDF.
13721372
13731373
Raises:
1374-
Exception: If no aggregate UDF is registered under ``name``.
1374+
KeyError: If no aggregate UDF is registered under ``name``.
13751375
13761376
Examples:
13771377
Look up a built-in aggregate by name and use it in
@@ -1404,7 +1404,7 @@ def udwf(self, name: str) -> WindowUDF:
14041404
name: Name of the registered window UDF.
14051405
14061406
Raises:
1407-
Exception: If no window UDF is registered under ``name``.
1407+
KeyError: If no window UDF is registered under ``name``.
14081408
14091409
Examples:
14101410
Look up a built-in window function by name and use it in

python/tests/test_udf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def test_udf_lookup(ctx, df) -> None:
9393
result = df_result.collect()[0].column(0)
9494
assert result == pa.array([False, False, True])
9595

96-
with pytest.raises(Exception, match="no UDF named"):
96+
with pytest.raises(KeyError, match="no UDF named"):
9797
ctx.udf("does_not_exist")
9898

9999

@@ -133,7 +133,7 @@ def test_udaf_lookup_builtin(ctx, df) -> None:
133133
result = df.aggregate([], [sum_fn(column("a")).alias("total")]).collect()
134134
assert result[0].column(0).to_pylist() == [6]
135135

136-
with pytest.raises(Exception, match="no UDAF named"):
136+
with pytest.raises(KeyError, match="no UDAF named"):
137137
ctx.udaf("does_not_exist")
138138

139139

@@ -143,7 +143,7 @@ def test_udwf_lookup_builtin(ctx, df) -> None:
143143
result = df.select(column("a"), rn().alias("rn")).collect()
144144
assert result[0].column(1).to_pylist() == [1, 2, 3]
145145

146-
with pytest.raises(Exception, match="no UDWF named"):
146+
with pytest.raises(KeyError, match="no UDWF named"):
147147
ctx.udwf("does_not_exist")
148148

149149

0 commit comments

Comments
 (0)