Skip to content

Commit e5b9e46

Browse files
ewiandaclaude
andcommitted
Address PR #19 review: defensive checks and TYPE_CHECKING tests
- parseMain: add ChildCount and nil checks before indexing into comparison operator children to prevent panics on malformed ASTs - parseImportStatements: guard import_from_statement with ChildCount check to handle incomplete/invalid syntax safely - Add TestParseTypeChecking covering TYPE_CHECKING, typing.TYPE_CHECKING, and non-TYPE_CHECKING conditional import blocks Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b574ffb commit e5b9e46

2 files changed

Lines changed: 81 additions & 14 deletions

File tree

python/file_parser.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,27 @@ func parseCode(code []byte) (*sitter.Node, error) {
105105
func (p *FileParser) parseMain(node *sitter.Node) bool {
106106
for i := 0; i < int(node.ChildCount()); i++ {
107107
child := node.Child(i)
108-
if child.Type() == "if_statement" &&
109-
child.ChildCount() > 1 &&
110-
child.Child(1).Type() == sitterNodeTypeComparisonOperator &&
111-
child.Child(1).Child(1).Type() == "==" {
112-
statement := child.Child(1)
113-
a, b := statement.Child(0), statement.Child(2)
114-
// convert "'__main__' == __name__" to "__name__ == '__main__'"
115-
if b.Type() == sitterNodeTypeIdentifier {
116-
a, b = b, a
117-
}
118-
if a.Type() == sitterNodeTypeIdentifier && a.Content(p.code) == "__name__" &&
119-
b.Type() == sitterNodeTypeString && string(p.code[b.StartByte()+1:b.EndByte()-1]) == "__main__" {
120-
return true
121-
}
108+
if child.Type() != "if_statement" || child.ChildCount() < 2 {
109+
continue
110+
}
111+
cond := child.Child(1)
112+
if cond == nil || cond.Type() != sitterNodeTypeComparisonOperator || cond.ChildCount() < 3 {
113+
continue
114+
}
115+
if cond.Child(1).Type() != "==" {
116+
continue
117+
}
118+
a, b := cond.Child(0), cond.Child(2)
119+
if a == nil || b == nil {
120+
continue
121+
}
122+
// convert "'__main__' == __name__" to "__name__ == '__main__'"
123+
if b.Type() == sitterNodeTypeIdentifier {
124+
a, b = b, a
125+
}
126+
if a.Type() == sitterNodeTypeIdentifier && a.Content(p.code) == "__name__" &&
127+
b.Type() == sitterNodeTypeString && string(p.code[b.StartByte()+1:b.EndByte()-1]) == "__main__" {
128+
return true
122129
}
123130
}
124131
return false
@@ -174,6 +181,9 @@ func (p *FileParser) parseImportStatements(node *sitter.Node) {
174181
p.output.Modules = append(p.output.Modules, m)
175182
}
176183
} else if node.Type() == sitterNodeTypeImportFromStatement {
184+
if node.ChildCount() < 4 {
185+
return
186+
}
177187
from := node.Child(1).Content(p.code)
178188
if strings.HasPrefix(from, ".") {
179189
return

python/file_parser_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,3 +254,60 @@ func TestParseFull(t *testing.T) {
254254
FileName: "a.py",
255255
}, *output)
256256
}
257+
258+
func TestParseTypeChecking(t *testing.T) {
259+
t.Parallel()
260+
units := []struct {
261+
name string
262+
code string
263+
filepath string
264+
result []module
265+
}{
266+
{
267+
name: "TYPE_CHECKING bare",
268+
code: `import os
269+
if TYPE_CHECKING:
270+
from foo import bar
271+
`,
272+
filepath: "abc.py",
273+
result: []module{
274+
{Name: "os", LineNumber: 1, Filepath: "abc.py"},
275+
{Name: "foo.bar", LineNumber: 3, Filepath: "abc.py", From: "foo"},
276+
},
277+
},
278+
{
279+
name: "typing.TYPE_CHECKING",
280+
code: `import os
281+
if typing.TYPE_CHECKING:
282+
from baz import qux
283+
import something
284+
`,
285+
filepath: "abc.py",
286+
result: []module{
287+
{Name: "os", LineNumber: 1, Filepath: "abc.py"},
288+
{Name: "baz.qux", LineNumber: 3, Filepath: "abc.py", From: "baz"},
289+
{Name: "something", LineNumber: 4, Filepath: "abc.py"},
290+
},
291+
},
292+
{
293+
name: "not TYPE_CHECKING",
294+
code: `if some_other_condition:
295+
from foo import bar
296+
`,
297+
filepath: "abc.py",
298+
result: []module{
299+
{Name: "foo.bar", LineNumber: 2, Filepath: "abc.py", From: "foo"},
300+
},
301+
},
302+
}
303+
for _, u := range units {
304+
t.Run(u.name, func(t *testing.T) {
305+
p := NewFileParser()
306+
code := []byte(u.code)
307+
p.SetCodeAndFile(code, "", u.filepath)
308+
output, err := p.Parse(context.Background())
309+
assert.NoError(t, err)
310+
assert.Equal(t, u.result, output.Modules)
311+
})
312+
}
313+
}

0 commit comments

Comments
 (0)