From 05717b0366f98b29691b6b15ccfba4fa6036d712 Mon Sep 17 00:00:00 2001 From: shivasurya Date: Sat, 2 May 2026 20:56:42 -0400 Subject: [PATCH 1/2] feat(extraction): C/C++ statement extraction for def-use analysis Add ExtractCStatements and ExtractCppStatements that walk a parsed function body and produce one *core.Statement per recognised construct (declaration, assignment, call, return, if/for/while/ do/switch, plus throw/try/range-for in C++). Statements capture def-use: - assignment: Def is the LHS variable (subscript and arrow paths collapse to the base name); Uses are RHS identifiers and any LHS index expressions. - call: Uses are the receiver (for obj.method()) and arguments; CallTarget is the bare callee, CallChain is the dotted / qualified form ("obj.method", "ns::func"). - control flow: condition identifiers in Uses; bodies and else clauses recurse into NestedStatements / ElseBranch. The C and C++ extractors share every dispatcher via clikeExtractor in statements_clike.go; the C++ wrapper plugs in throw_statement, try_statement (with caught variable as Def of an empty assignment), and for_range_loop (loop variable as Def, iterable as Uses) through an extraNodeHandler hook. Keyword filtering routes through clike.IsCKeyword / clike.IsCppKeyword so language-specific reserved words never leak into Uses. Sets up Statement input for the CFG builder (PR-10) and the future variable-dependency graph. Co-Authored-By: Claude Sonnet 4.5 --- .../callgraph/extraction/statements_c.go | 32 + .../callgraph/extraction/statements_c_test.go | 305 ++++++++ .../callgraph/extraction/statements_clike.go | 659 ++++++++++++++++++ .../callgraph/extraction/statements_cpp.go | 168 +++++ .../extraction/statements_cpp_test.go | 209 ++++++ 5 files changed, 1373 insertions(+) create mode 100644 sast-engine/graph/callgraph/extraction/statements_c.go create mode 100644 sast-engine/graph/callgraph/extraction/statements_c_test.go create mode 100644 sast-engine/graph/callgraph/extraction/statements_clike.go create mode 100644 sast-engine/graph/callgraph/extraction/statements_cpp.go create mode 100644 sast-engine/graph/callgraph/extraction/statements_cpp_test.go diff --git a/sast-engine/graph/callgraph/extraction/statements_c.go b/sast-engine/graph/callgraph/extraction/statements_c.go new file mode 100644 index 00000000..18b0706e --- /dev/null +++ b/sast-engine/graph/callgraph/extraction/statements_c.go @@ -0,0 +1,32 @@ +package extraction + +import ( + sitter "github.com/smacker/go-tree-sitter" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/clike" +) + +// ExtractCStatements walks a C function body and produces one +// *core.Statement per recognised top-level construct (declaration, +// expression, return, if/for/while/do/switch). The result feeds the +// CFG builder (PR-10) and the future variable-dependency graph. +// +// Forward declarations and prototypes (no body) yield (nil, nil) — the +// caller can iterate without nil checks. +// +// The function is a thin wrapper around the shared clikeExtractor; C +// and C++ share every dispatcher except for the keyword filter and a +// handful of C++-only AST nodes (`throw_statement`, `try_statement`, +// `for_range_loop`). +func ExtractCStatements(filePath string, sourceCode []byte, functionNode *sitter.Node) ([]*core.Statement, error) { + if functionNode == nil { + return nil, nil + } + e := &clikeExtractor{ + filePath: filePath, + src: sourceCode, + isKeyword: clike.IsCKeyword, + } + return e.extractFunctionBody(functionNode), nil +} diff --git a/sast-engine/graph/callgraph/extraction/statements_c_test.go b/sast-engine/graph/callgraph/extraction/statements_c_test.go new file mode 100644 index 00000000..b9d9069c --- /dev/null +++ b/sast-engine/graph/callgraph/extraction/statements_c_test.go @@ -0,0 +1,305 @@ +package extraction + +import ( + "context" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + clang "github.com/smacker/go-tree-sitter/c" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" +) + +// testFuncName is the conventional function name used in every C/C++ +// extraction fixture below — keeps test sources tight and removes a +// duplicated argument from every call site. +const testFuncName = "f" + +// parseCFunction parses C source code and returns the function_definition +// node named `testFuncName`. The caller must close the tree via +// `defer tree.Close()`. +func parseCFunction(t *testing.T, source string) (*sitter.Tree, *sitter.Node, []byte) { + t.Helper() + src := []byte(source) + + parser := sitter.NewParser() + parser.SetLanguage(clang.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, src) + require.NoError(t, err) + + fn := findCFunctionByName(tree.RootNode(), testFuncName, src) + require.NotNil(t, fn, "function %q not found", testFuncName) + return tree, fn, src +} + +// findCFunctionByName recursively searches the AST for a +// function_definition whose declarator's identifier matches name. +func findCFunctionByName(node *sitter.Node, name string, src []byte) *sitter.Node { + if node == nil { + return nil + } + if node.Type() == "function_definition" { + if d := node.ChildByFieldName("declarator"); d != nil && testCFunctionName(d, src) == name { + return node + } + } + for i := 0; i < int(node.ChildCount()); i++ { + if r := findCFunctionByName(node.Child(i), name, src); r != nil { + return r + } + } + return nil +} + +// testCFunctionName unwraps a function_declarator to its identifier. +func testCFunctionName(node *sitter.Node, src []byte) string { + if node == nil { + return "" + } + switch node.Type() { + case "identifier": + return node.Content(src) + case "function_declarator", "pointer_declarator", "parenthesized_declarator": + if d := node.ChildByFieldName("declarator"); d != nil { + return testCFunctionName(d, src) + } + } + for i := 0; i < int(node.NamedChildCount()); i++ { + if r := testCFunctionName(node.NamedChild(i), src); r != "" { + return r + } + } + return "" +} + +// findStmt returns the first statement matching the predicate, walking +// into NestedStatements / ElseBranch as needed. +func findStmt(stmts []*core.Statement, pred func(*core.Statement) bool) *core.Statement { + for _, s := range stmts { + if s == nil { + continue + } + if pred(s) { + return s + } + if got := findStmt(s.NestedStatements, pred); got != nil { + return got + } + if got := findStmt(s.ElseBranch, pred); got != nil { + return got + } + } + return nil +} + +func TestExtractCStatements_NilFunction(t *testing.T) { + stmts, err := ExtractCStatements("/x.c", nil, nil) + require.NoError(t, err) + assert.Nil(t, stmts) +} + +func TestExtractCStatements_DeclarationWithBinaryOp(t *testing.T) { + src := `int f(int a, int b) { + int x = a + b; + return x; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 2) + assert.Equal(t, core.StatementTypeAssignment, stmts[0].Type) + assert.Equal(t, "x", stmts[0].Def) + assert.ElementsMatch(t, []string{"a", "b"}, stmts[0].Uses) + + assert.Equal(t, core.StatementTypeReturn, stmts[1].Type) + assert.Equal(t, []string{"x"}, stmts[1].Uses) +} + +func TestExtractCStatements_AssignmentFromCall(t *testing.T) { + src := `int f(int y) { + int x; + x = func(y); + return x; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + got := findStmt(stmts, func(s *core.Statement) bool { + return s.Type == core.StatementTypeAssignment && s.CallTarget == "func" + }) + require.NotNil(t, got, "expected assignment from call") + assert.Equal(t, "x", got.Def) + assert.Equal(t, []string{"y"}, got.Uses) + assert.Equal(t, "func", got.CallChain) +} + +func TestExtractCStatements_BareCall(t *testing.T) { + src := `void f(int a, int b) { + func(a, b); +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, core.StatementTypeCall, stmts[0].Type) + assert.Equal(t, "func", stmts[0].CallTarget) + assert.ElementsMatch(t, []string{"a", "b"}, stmts[0].Uses) + assert.ElementsMatch(t, []string{"a", "b"}, stmts[0].CallArgs) +} + +func TestExtractCStatements_IfElse(t *testing.T) { + src := `void f(int x, int y) { + if (x > 0) { + consume(y); + } else { + report(x); + } +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + ifStmt := stmts[0] + assert.Equal(t, core.StatementTypeIf, ifStmt.Type) + assert.Equal(t, []string{"x"}, ifStmt.Uses) + require.NotEmpty(t, ifStmt.NestedStatements) + assert.Equal(t, core.StatementTypeCall, ifStmt.NestedStatements[0].Type) + assert.Equal(t, "consume", ifStmt.NestedStatements[0].CallTarget) + + require.NotEmpty(t, ifStmt.ElseBranch) + assert.Equal(t, "report", ifStmt.ElseBranch[0].CallTarget) +} + +func TestExtractCStatements_ForLoop(t *testing.T) { + src := `void f(int n) { + for (int i = 0; i < n; i++) { + do_thing(i); + } +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + forStmt := stmts[0] + assert.Equal(t, core.StatementTypeFor, forStmt.Type) + assert.Equal(t, "i", forStmt.Def) + assert.Contains(t, forStmt.Uses, "n") + assert.NotContains(t, forStmt.Uses, "i", "loop variable must not appear in Uses") +} + +func TestExtractCStatements_While(t *testing.T) { + src := `void f(int x) { + while (x > 0) { + x--; + } +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, core.StatementTypeWhile, stmts[0].Type) + assert.Equal(t, []string{"x"}, stmts[0].Uses) +} + +func TestExtractCStatements_PointerArrowAssignment(t *testing.T) { + src := `void f(struct S* p, int val) { + p->name = val; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, "p", stmts[0].Def) + assert.Equal(t, []string{"val"}, stmts[0].Uses) +} + +func TestExtractCStatements_SubscriptAssignment(t *testing.T) { + src := `void f(int* buf, int* input, int i, int j) { + buf[i] = input[j]; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, "buf", stmts[0].Def) + assert.ElementsMatch(t, []string{"i", "input", "j"}, stmts[0].Uses) +} + +func TestExtractCStatements_KeywordFilter(t *testing.T) { + src := `int f(int* p) { + int n = sizeof(*p); + int y = (int)n; + if (p == NULL) return 0; + return n + y; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + for _, s := range stmts { + assert.NotContains(t, s.Uses, "sizeof") + assert.NotContains(t, s.Uses, "int") + assert.NotContains(t, s.Uses, "NULL") + assert.NotContains(t, s.Uses, "true") + assert.NotContains(t, s.Uses, "false") + } +} + +func TestExtractCStatements_DoWhileSwitch(t *testing.T) { + src := `void f(int x) { + do { + x--; + } while (x > 0); + switch (x) { + case 0: report(x); break; + default: break; + } +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + doStmt := findStmt(stmts, func(s *core.Statement) bool { return s.Type == core.StatementTypeWhile }) + require.NotNil(t, doStmt) + assert.Equal(t, []string{"x"}, doStmt.Uses) + + swStmt := findStmt(stmts, func(s *core.Statement) bool { + return s.Type == core.StatementTypeIf && len(s.NestedStatements) > 0 + }) + require.NotNil(t, swStmt) + assert.Equal(t, []string{"x"}, swStmt.Uses) +} + +func TestExtractCStatements_BareDeclaration(t *testing.T) { + src := `void f() { + int x; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + // `int x;` has no init_declarator → no assignment is emitted. + assert.Empty(t, stmts) +} diff --git a/sast-engine/graph/callgraph/extraction/statements_clike.go b/sast-engine/graph/callgraph/extraction/statements_clike.go new file mode 100644 index 00000000..9e0a1924 --- /dev/null +++ b/sast-engine/graph/callgraph/extraction/statements_clike.go @@ -0,0 +1,659 @@ +package extraction + +import ( + "strings" + + sitter "github.com/smacker/go-tree-sitter" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/clike" +) + +// AST node-type constants emitted by the tree-sitter C and C++ grammars. +// Centralised here so the extractors do not pepper string literals +// across files; renaming a grammar node only touches this list. +const ( + clikeNodeIdentifier = "identifier" + clikeNodeFieldIdentifier = "field_identifier" + clikeNodeTypeIdentifier = "type_identifier" + clikeNodeFieldExpression = "field_expression" + clikeNodeQualifiedIdentifier = "qualified_identifier" + clikeNodeCallExpression = "call_expression" + clikeNodeAssignmentExpr = "assignment_expression" + clikeNodeInitDeclarator = "init_declarator" + clikeNodePointerDeclarator = "pointer_declarator" + clikeNodeArrayDeclarator = "array_declarator" + clikeNodeReferenceDeclarator = "reference_declarator" + clikeNodeParenthesised = "parenthesized_expression" + clikeNodeSubscriptExpr = "subscript_expression" + clikeNodeNumberLiteral = "number_literal" + clikeNodeStringLiteral = "string_literal" + clikeNodeCharLiteral = "char_literal" + clikeNodeTrueFalse = "true" + clikeNodeFalse = "false" + clikeNodeNullLiteral = "null" +) + +// keywordPredicate decides whether name should be filtered out of Uses. +// The C extractor passes `clike.IsCKeyword`; the C++ extractor passes +// `clike.IsCppKeyword`. A small adapter type keeps the dispatcher +// independent of the language-specific keyword maps. +type keywordPredicate func(string) bool + +// clikeExtractor is the shared core of the C / C++ statement +// extractors. The dispatcher routes function-body children to typed +// handlers; each handler builds a *core.Statement and appends it. +// +// Language differences (extra node types, different keyword filter) +// live behind the `isKeyword` predicate and the `extraNodeHandler` +// hook so the C++ extractor can extend the dispatch table without +// duplicating the C body. +type clikeExtractor struct { + filePath string + src []byte + isKeyword keywordPredicate + // extraNodeHandler is consulted before the default `nil` return + // when a node type is not in the shared dispatch table. It returns + // (stmts, true) when it handled the node; otherwise (nil, false). + extraNodeHandler func(node *sitter.Node) ([]*core.Statement, bool) +} + +// extractFunctionBody runs the dispatcher over every named child of a +// function's body field. Forward declarations (no body) yield nil. +func (e *clikeExtractor) extractFunctionBody(functionNode *sitter.Node) []*core.Statement { + if functionNode == nil { + return nil + } + body := functionNode.ChildByFieldName("body") + if body == nil { + return nil + } + return e.extractBlock(body) +} + +// extractBlock walks every named child of a compound block and routes +// each to the dispatch table. +func (e *clikeExtractor) extractBlock(block *sitter.Node) []*core.Statement { + if block == nil { + return nil + } + var stmts []*core.Statement + for i := 0; i < int(block.NamedChildCount()); i++ { + stmts = append(stmts, e.extractStatement(block.NamedChild(i))...) + } + return stmts +} + +// extractStatement dispatches on node.Type(). Unknown types fall +// through to the language-specific extra handler so C++ can register +// throw/try/range-for without forking the function. +func (e *clikeExtractor) extractStatement(node *sitter.Node) []*core.Statement { + if node == nil { + return nil + } + switch node.Type() { + case "declaration": + return e.declarationStmt(node) + case "expression_statement": + return e.expressionStmt(node) + case "return_statement": + return e.returnStmt(node) + case "if_statement": + return []*core.Statement{e.ifStmt(node)} + case "for_statement": + return []*core.Statement{e.forStmt(node)} + case "while_statement": + return []*core.Statement{e.whileStmt(node)} + case "do_statement": + return []*core.Statement{e.doStmt(node)} + case "switch_statement": + return []*core.Statement{e.switchStmt(node)} + case "compound_statement": + return e.extractBlock(node) + case "else_clause": + // `else` wraps a single statement (compound or otherwise); + // route through to extractStatement so the body's children + // surface as flat NestedStatements / ElseBranch entries. + var stmts []*core.Statement + for i := 0; i < int(node.NamedChildCount()); i++ { + stmts = append(stmts, e.extractStatement(node.NamedChild(i))...) + } + return stmts + case "case_statement": + // `switch` body children include case_statement nodes that + // wrap their bodies; flatten so the switch's NestedStatements + // reads as a list of underlying statements. + var stmts []*core.Statement + for i := 0; i < int(node.NamedChildCount()); i++ { + stmts = append(stmts, e.extractStatement(node.NamedChild(i))...) + } + return stmts + } + if e.extraNodeHandler != nil { + if stmts, handled := e.extraNodeHandler(node); handled { + return stmts + } + } + return nil +} + +// ============================================================================= +// Declaration handler +// ============================================================================= + +// declarationStmt emits one assignment per init_declarator. A bare +// declaration (`int x;` with no initialiser) still produces an +// assignment statement with empty Uses so downstream analysis can see +// the def site. +func (e *clikeExtractor) declarationStmt(node *sitter.Node) []*core.Statement { + var stmts []*core.Statement + for i := 0; i < int(node.NamedChildCount()); i++ { + child := node.NamedChild(i) + if child == nil || child.Type() != clikeNodeInitDeclarator { + continue + } + stmt := e.initDeclaratorStmt(node, child) + if stmt != nil { + stmts = append(stmts, stmt) + } + } + return stmts +} + +// initDeclaratorStmt builds one assignment Statement for an +// init_declarator. The init_declarator's `declarator` field is the +// defined name (after stripping pointer/array/reference wrappers); the +// `value` field carries the right-hand side expression. +func (e *clikeExtractor) initDeclaratorStmt(declarationNode, init *sitter.Node) *core.Statement { + declarator := init.ChildByFieldName("declarator") + defName := bareDeclaratorName(declarator, e.src) + if defName == "" { + return nil + } + stmt := &core.Statement{ + Type: core.StatementTypeAssignment, + LineNumber: declarationNode.StartPoint().Row + 1, + Def: defName, + } + if value := init.ChildByFieldName("value"); value != nil { + e.populateRHS(stmt, value) + } + return stmt +} + +// ============================================================================= +// Expression-statement handler +// ============================================================================= + +// expressionStmt extracts a Statement from an `expression_statement` +// wrapper. The interesting cases are assignment_expression and +// call_expression; everything else falls back to a generic expression +// statement with all identifiers in Uses. +func (e *clikeExtractor) expressionStmt(node *sitter.Node) []*core.Statement { + inner := firstNamedChild(node) + if inner == nil { + return nil + } + switch inner.Type() { + case clikeNodeAssignmentExpr: + return []*core.Statement{e.assignmentStmt(node, inner)} + case clikeNodeCallExpression: + return []*core.Statement{e.callStmt(node, inner)} + } + return []*core.Statement{{ + Type: core.StatementTypeExpression, + LineNumber: node.StartPoint().Row + 1, + Uses: e.collectIdentifiers(inner), + }} +} + +// assignmentStmt builds a Statement for `lhs = rhs;`. The LHS is +// reduced to a single name via leftHandSideName: subscript and field +// accesses collapse to the base variable, matching the def-use +// convention used elsewhere in the codebase. +func (e *clikeExtractor) assignmentStmt(stmtNode, expr *sitter.Node) *core.Statement { + lhs := expr.ChildByFieldName("left") + rhs := expr.ChildByFieldName("right") + + stmt := &core.Statement{ + Type: core.StatementTypeAssignment, + LineNumber: stmtNode.StartPoint().Row + 1, + Def: leftHandSideName(lhs, e.src), + } + if extras := lhsIndexUses(lhs, e); len(extras) > 0 { + stmt.Uses = mergeUnique(stmt.Uses, extras) + } + if rhs != nil { + e.populateRHS(stmt, rhs) + } + return stmt +} + +// callStmt builds a Statement for a bare `func(args);` expression. +// Both the receiver (for `obj.method()`) and the arguments contribute +// to Uses. +func (e *clikeExtractor) callStmt(stmtNode, call *sitter.Node) *core.Statement { + target, callChain := e.callTarget(call) + stmt := &core.Statement{ + Type: core.StatementTypeCall, + LineNumber: stmtNode.StartPoint().Row + 1, + CallTarget: target, + CallChain: callChain, + } + stmt.Uses = e.collectCallUses(call) + stmt.CallArgs = e.collectCallArgs(call) + return stmt +} + +// ============================================================================= +// Right-hand-side population +// ============================================================================= + +// populateRHS fills Uses / CallTarget / CallArgs on stmt from a +// right-hand-side expression. When the RHS is a call, the call target +// is recorded so downstream analysis can follow the edge; the +// receiver of a method call also contributes to Uses. +func (e *clikeExtractor) populateRHS(stmt *core.Statement, rhs *sitter.Node) { + if rhs.Type() == clikeNodeCallExpression { + target, chain := e.callTarget(rhs) + stmt.CallTarget = target + stmt.CallChain = chain + stmt.CallArgs = e.collectCallArgs(rhs) + stmt.Uses = mergeUnique(stmt.Uses, e.collectCallUses(rhs)) + return + } + stmt.Uses = mergeUnique(stmt.Uses, e.collectIdentifiers(rhs)) +} + +// ============================================================================= +// Control flow handlers +// ============================================================================= + +// ifStmt emits `if (cond) { then } else { else }` as a single +// Statement carrying the condition's identifiers in Uses and both +// branches' statements in NestedStatements / ElseBranch. +func (e *clikeExtractor) ifStmt(node *sitter.Node) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeIf, + LineNumber: node.StartPoint().Row + 1, + } + if cond := node.ChildByFieldName("condition"); cond != nil { + stmt.Uses = e.collectIdentifiers(cond) + } + if cons := node.ChildByFieldName("consequence"); cons != nil { + stmt.NestedStatements = e.extractStatement(cons) + } + if alt := node.ChildByFieldName("alternative"); alt != nil { + stmt.ElseBranch = e.extractStatement(alt) + } + return stmt +} + +// forStmt handles the C-style `for (init; cond; update) { body }`. +// The init clause's defined variable becomes Def; identifiers from +// cond and update collapse into Uses. +func (e *clikeExtractor) forStmt(node *sitter.Node) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeFor, + LineNumber: node.StartPoint().Row + 1, + } + if init := node.ChildByFieldName("initializer"); init != nil { + stmt.Def = forInitDef(init, e.src) + stmt.Uses = mergeUnique(stmt.Uses, e.collectIdentifiers(init)) + } + if cond := node.ChildByFieldName("condition"); cond != nil { + stmt.Uses = mergeUnique(stmt.Uses, e.collectIdentifiers(cond)) + } + if update := node.ChildByFieldName("update"); update != nil { + stmt.Uses = mergeUnique(stmt.Uses, e.collectIdentifiers(update)) + } + if body := node.ChildByFieldName("body"); body != nil { + stmt.NestedStatements = e.extractStatement(body) + } + // The defined loop variable participates in every clause; drop it + // from Uses last so the per-clause merges above remain simple. + if stmt.Def != "" { + stmt.Uses = removeName(stmt.Uses, stmt.Def) + } + return stmt +} + +// whileStmt handles `while (cond) { body }`. +func (e *clikeExtractor) whileStmt(node *sitter.Node) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeWhile, + LineNumber: node.StartPoint().Row + 1, + } + if cond := node.ChildByFieldName("condition"); cond != nil { + stmt.Uses = e.collectIdentifiers(cond) + } + if body := node.ChildByFieldName("body"); body != nil { + stmt.NestedStatements = e.extractStatement(body) + } + return stmt +} + +// doStmt handles `do { body } while (cond);` — same shape as +// whileStmt with the condition placed after the body. +func (e *clikeExtractor) doStmt(node *sitter.Node) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeWhile, + LineNumber: node.StartPoint().Row + 1, + } + if cond := node.ChildByFieldName("condition"); cond != nil { + stmt.Uses = e.collectIdentifiers(cond) + } + if body := node.ChildByFieldName("body"); body != nil { + stmt.NestedStatements = e.extractStatement(body) + } + return stmt +} + +// switchStmt handles `switch (cond) { case ... }`. The body is a +// compound block whose children are case labels and their statements; +// we inline both into NestedStatements so flow analysis sees a flat +// list per branch. +func (e *clikeExtractor) switchStmt(node *sitter.Node) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeIf, + LineNumber: node.StartPoint().Row + 1, + } + if cond := node.ChildByFieldName("condition"); cond != nil { + stmt.Uses = e.collectIdentifiers(cond) + } + if body := node.ChildByFieldName("body"); body != nil { + stmt.NestedStatements = e.extractStatement(body) + } + return stmt +} + +// returnStmt handles `return [expr];`. +func (e *clikeExtractor) returnStmt(node *sitter.Node) []*core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeReturn, + LineNumber: node.StartPoint().Row + 1, + } + for i := 0; i < int(node.NamedChildCount()); i++ { + stmt.Uses = mergeUnique(stmt.Uses, e.collectIdentifiers(node.NamedChild(i))) + } + return []*core.Statement{stmt} +} + +// ============================================================================= +// Identifier collection +// ============================================================================= + +// collectIdentifiers returns every variable-like identifier reachable +// from node, deduplicated and filtered through e.isKeyword. Field +// names (`obj.field`) and the right-hand component of qualified +// identifiers (`ns::name`) are skipped; the LHS of those expressions +// participates as a use because it is the receiver / namespace value. +func (e *clikeExtractor) collectIdentifiers(node *sitter.Node) []string { + if node == nil { + return nil + } + seen := make(map[string]bool) + var out []string + + var visit func(n *sitter.Node) + visit = func(n *sitter.Node) { + if n == nil { + return + } + switch n.Type() { + case clikeNodeFieldIdentifier, clikeNodeTypeIdentifier: + return + case clikeNodeIdentifier: + name := n.Content(e.src) + if !e.isKeyword(name) && !seen[name] { + seen[name] = true + out = append(out, name) + } + return + case clikeNodeFieldExpression: + // Receiver only; field name is not a use. + if recv := n.ChildByFieldName("argument"); recv != nil { + visit(recv) + } + return + case clikeNodeQualifiedIdentifier: + // `ns::name` references — if used as a value it shouldn't + // register either side as a variable use. + return + case clikeNodeNumberLiteral, clikeNodeStringLiteral, clikeNodeCharLiteral, + clikeNodeTrueFalse, clikeNodeFalse, clikeNodeNullLiteral: + return + } + for i := 0; i < int(n.ChildCount()); i++ { + visit(n.Child(i)) + } + } + visit(node) + return out +} + +// collectCallUses returns the receiver and argument identifiers of a +// call_expression. The function name itself is intentionally skipped +// so it appears in CallTarget but not Uses. +func (e *clikeExtractor) collectCallUses(call *sitter.Node) []string { + var uses []string + if fn := call.ChildByFieldName("function"); fn != nil { + if fn.Type() == clikeNodeFieldExpression { + if recv := fn.ChildByFieldName("argument"); recv != nil { + uses = mergeUnique(uses, e.collectIdentifiers(recv)) + } + } + } + if argList := call.ChildByFieldName("arguments"); argList != nil { + for i := 0; i < int(argList.NamedChildCount()); i++ { + uses = mergeUnique(uses, e.collectIdentifiers(argList.NamedChild(i))) + } + } + return uses +} + +// collectCallArgs returns the raw text of every argument to a +// call_expression, in source order. Stored separately from Uses so +// downstream consumers can see literals (`"hello"`, `42`) too. +func (e *clikeExtractor) collectCallArgs(call *sitter.Node) []string { + argList := call.ChildByFieldName("arguments") + if argList == nil { + return nil + } + args := make([]string, 0, argList.NamedChildCount()) + for i := 0; i < int(argList.NamedChildCount()); i++ { + if arg := argList.NamedChild(i); arg != nil { + args = append(args, arg.Content(e.src)) + } + } + return args +} + +// callTarget returns (callee, callChain) for a call_expression. The +// callee is the bare function name for free / qualified calls and the +// method name for `obj.method()`. The chain is the full dotted form +// (`obj.method`, `ns::func`) so later analysis can match patterns +// without re-parsing. +func (e *clikeExtractor) callTarget(call *sitter.Node) (string, string) { + fn := call.ChildByFieldName("function") + if fn == nil { + return "", "" + } + switch fn.Type() { + case clikeNodeIdentifier: + name := fn.Content(e.src) + return name, name + case clikeNodeFieldExpression: + method := "" + if field := fn.ChildByFieldName("field"); field != nil { + method = field.Content(e.src) + } + chain := strings.TrimSpace(fn.Content(e.src)) + return method, chain + case clikeNodeQualifiedIdentifier: + qualified := strings.TrimSpace(fn.Content(e.src)) + return qualified, qualified + } + return strings.TrimSpace(fn.Content(e.src)), strings.TrimSpace(fn.Content(e.src)) +} + +// ============================================================================= +// Small AST helpers +// ============================================================================= + +// firstNamedChild returns the first named child of node, or nil when +// node has none. +func firstNamedChild(node *sitter.Node) *sitter.Node { + if node == nil || node.NamedChildCount() == 0 { + return nil + } + return node.NamedChild(0) +} + +// bareDeclaratorName unwraps pointer / array / reference / function / +// parenthesised declarators down to the underlying identifier and +// returns its source text. Returns "" when no identifier is reachable. +func bareDeclaratorName(node *sitter.Node, src []byte) string { + for node != nil { + switch node.Type() { + case clikeNodeIdentifier, clikeNodeFieldIdentifier: + return node.Content(src) + case clikeNodePointerDeclarator, clikeNodeArrayDeclarator, + clikeNodeReferenceDeclarator, clikeNodeParenthesised, + "function_declarator": + next := node.ChildByFieldName("declarator") + if next == nil { + next = firstNamedChild(node) + } + node = next + default: + return strings.TrimSpace(node.Content(src)) + } + } + return "" +} + +// leftHandSideName returns the variable being assigned to in an +// assignment expression. For `buf[i] = ...`, returns "buf"; for +// `p->name = ...`, returns "p"; for `obj.field = ...`, returns "obj". +// The caller uses `lhsIndexUses` to capture the index/field components +// as Uses. +func leftHandSideName(node *sitter.Node, src []byte) string { + for node != nil { + switch node.Type() { + case clikeNodeIdentifier: + return node.Content(src) + case clikeNodeSubscriptExpr: + if base := node.ChildByFieldName("argument"); base != nil { + node = base + continue + } + case clikeNodeFieldExpression: + if recv := node.ChildByFieldName("argument"); recv != nil { + node = recv + continue + } + case clikeNodeParenthesised: + node = firstNamedChild(node) + continue + } + return strings.TrimSpace(node.Content(src)) + } + return "" +} + +// lhsIndexUses returns identifier uses that appear in the indexing +// path of a subscript or pointer-arrow LHS. For `buf[i] = ...`, it +// returns ["i"]; for `p->name = ...`, nothing extra; for plain +// identifier LHS, nothing extra. +func lhsIndexUses(node *sitter.Node, e *clikeExtractor) []string { + for node != nil { + switch node.Type() { + case clikeNodeSubscriptExpr: + if idx := node.ChildByFieldName("index"); idx != nil { + return e.collectIdentifiers(idx) + } + return nil + case clikeNodeFieldExpression: + if recv := node.ChildByFieldName("argument"); recv != nil { + node = recv + continue + } + case clikeNodeParenthesised: + node = firstNamedChild(node) + continue + } + return nil + } + return nil +} + +// forInitDef returns the variable defined by a C `for` initializer +// clause. Handles both forms: +// +// for (int i = 0; ...) — declaration with init_declarator. +// for (i = 0; ...) — assignment_expression. +func forInitDef(node *sitter.Node, src []byte) string { + if node == nil { + return "" + } + switch node.Type() { + case "declaration": + for i := 0; i < int(node.NamedChildCount()); i++ { + child := node.NamedChild(i) + if child != nil && child.Type() == clikeNodeInitDeclarator { + if d := child.ChildByFieldName("declarator"); d != nil { + return bareDeclaratorName(d, src) + } + } + } + case clikeNodeAssignmentExpr: + if lhs := node.ChildByFieldName("left"); lhs != nil { + return leftHandSideName(lhs, src) + } + case "expression": + // Some grammars wrap the assignment in an `expression` node. + return forInitDef(firstNamedChild(node), src) + } + return "" +} + +// ============================================================================= +// Generic slice helpers +// ============================================================================= + +// mergeUnique appends every element of extra that is not already in +// dst. Order from dst is preserved; new entries arrive in extra's +// order. +func mergeUnique(dst, extra []string) []string { + if len(extra) == 0 { + return dst + } + seen := make(map[string]bool, len(dst)) + for _, v := range dst { + seen[v] = true + } + for _, v := range extra { + if seen[v] { + continue + } + seen[v] = true + dst = append(dst, v) + } + return dst +} + +// removeName returns names with name removed (first match only). +func removeName(names []string, name string) []string { + for i, v := range names { + if v == name { + return append(names[:i], names[i+1:]...) + } + } + return names +} + +// _ enforces that clike.IsCKeyword satisfies keywordPredicate at compile +// time so future changes to clike's API surface are caught here rather +// than in a test. +var _ keywordPredicate = clike.IsCKeyword diff --git a/sast-engine/graph/callgraph/extraction/statements_cpp.go b/sast-engine/graph/callgraph/extraction/statements_cpp.go new file mode 100644 index 00000000..29d1b2d2 --- /dev/null +++ b/sast-engine/graph/callgraph/extraction/statements_cpp.go @@ -0,0 +1,168 @@ +package extraction + +import ( + sitter "github.com/smacker/go-tree-sitter" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" + "github.com/shivasurya/code-pathfinder/sast-engine/graph/clike" +) + +// AST node-type constants emitted by the C++ tree-sitter grammar that +// are not present in the C grammar. Centralised here so the dispatcher +// remains the only place that touches the literal strings. +const ( + cppNodeThrowStatement = "throw_statement" + cppNodeTryStatement = "try_statement" + cppNodeCatchClause = "catch_clause" + cppNodeForRangeLoop = "for_range_loop" +) + +// ExtractCppStatements walks a C++ function body and produces one +// *core.Statement per recognised construct. +// +// The C and C++ extractors share every dispatcher; the C++ wrapper +// adds three extra node types via the `extraNodeHandler` hook: +// +// - throw_statement → StatementTypeRaise (with optional CallTarget +// for `throw std::runtime_error("...")`) +// - try_statement → StatementTypeTry with the body in +// NestedStatements and each catch clause flattened into +// ElseBranch. +// - for_range_loop → StatementTypeFor capturing the loop +// variable as Def and the iterable expression as Uses. +// +// The keyword filter is `clike.IsCppKeyword`, which inherits all C +// keywords and adds `class`, `new`, `this`, `static_cast`, etc. so +// they never appear in Uses. +func ExtractCppStatements(filePath string, sourceCode []byte, functionNode *sitter.Node) ([]*core.Statement, error) { + if functionNode == nil { + return nil, nil + } + var e *clikeExtractor + e = &clikeExtractor{ + filePath: filePath, + src: sourceCode, + isKeyword: clike.IsCppKeyword, + } + e.extraNodeHandler = func(node *sitter.Node) ([]*core.Statement, bool) { + switch node.Type() { + case cppNodeThrowStatement: + return []*core.Statement{cppThrowStmt(node, e)}, true + case cppNodeTryStatement: + return []*core.Statement{cppTryStmt(node, e)}, true + case cppNodeForRangeLoop: + return []*core.Statement{cppForRangeStmt(node, e)}, true + } + return nil, false + } + return e.extractFunctionBody(functionNode), nil +} + +// cppThrowStmt handles `throw expr;`. When the thrown expression is a +// constructor call (`throw std::runtime_error("msg")`), the call's +// target is recorded so taint analysis can follow the edge. +func cppThrowStmt(node *sitter.Node, e *clikeExtractor) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeRaise, + LineNumber: node.StartPoint().Row + 1, + } + for i := 0; i < int(node.NamedChildCount()); i++ { + child := node.NamedChild(i) + if child == nil { + continue + } + if child.Type() == clikeNodeCallExpression { + target, chain := e.callTarget(child) + stmt.CallTarget = target + stmt.CallChain = chain + stmt.CallArgs = e.collectCallArgs(child) + stmt.Uses = mergeUnique(stmt.Uses, e.collectCallUses(child)) + continue + } + stmt.Uses = mergeUnique(stmt.Uses, e.collectIdentifiers(child)) + } + return stmt +} + +// cppTryStmt handles `try { body } catch (T x) { handler } ...`. Each +// catch clause's body contributes its statements to ElseBranch (in +// source order), with the caught variable filtered out of Uses by the +// keyword/identifier walker since it is a definition rather than a +// use. +func cppTryStmt(node *sitter.Node, e *clikeExtractor) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeTry, + LineNumber: node.StartPoint().Row + 1, + } + if body := node.ChildByFieldName("body"); body != nil { + stmt.NestedStatements = e.extractStatement(body) + } + for i := 0; i < int(node.NamedChildCount()); i++ { + child := node.NamedChild(i) + if child == nil || child.Type() != cppNodeCatchClause { + continue + } + stmt.ElseBranch = append(stmt.ElseBranch, cppCatchStatements(child, e)...) + } + return stmt +} + +// cppCatchStatements flattens one catch clause's body into a slice of +// statements. The clause's exception parameter (if any) is recorded +// as the Def of an empty assignment statement so def-use analysis +// sees the binding site. +func cppCatchStatements(clause *sitter.Node, e *clikeExtractor) []*core.Statement { + var stmts []*core.Statement + if param := clause.ChildByFieldName("parameters"); param != nil { + if name := exceptionParamName(param, e.src); name != "" { + stmts = append(stmts, &core.Statement{ + Type: core.StatementTypeAssignment, + LineNumber: clause.StartPoint().Row + 1, + Def: name, + }) + } + } + if body := clause.ChildByFieldName("body"); body != nil { + stmts = append(stmts, e.extractStatement(body)...) + } + return stmts +} + +// exceptionParamName returns the bound variable in `catch (T name)`. +// The clause's parameter list contains a single parameter_declaration +// whose declarator is the variable. +func exceptionParamName(paramList *sitter.Node, src []byte) string { + for i := 0; i < int(paramList.NamedChildCount()); i++ { + param := paramList.NamedChild(i) + if param == nil { + continue + } + if d := param.ChildByFieldName("declarator"); d != nil { + return bareDeclaratorName(d, src) + } + } + return "" +} + +// cppForRangeStmt handles range-based for loops `for (auto x : c) { body }`. +// The loop variable is captured as Def; the iterable expression +// contributes its identifiers to Uses. +func cppForRangeStmt(node *sitter.Node, e *clikeExtractor) *core.Statement { + stmt := &core.Statement{ + Type: core.StatementTypeFor, + LineNumber: node.StartPoint().Row + 1, + } + if d := node.ChildByFieldName("declarator"); d != nil { + stmt.Def = bareDeclaratorName(d, e.src) + } + if right := node.ChildByFieldName("right"); right != nil { + stmt.Uses = e.collectIdentifiers(right) + } + if body := node.ChildByFieldName("body"); body != nil { + stmt.NestedStatements = e.extractStatement(body) + } + if stmt.Def != "" { + stmt.Uses = removeName(stmt.Uses, stmt.Def) + } + return stmt +} diff --git a/sast-engine/graph/callgraph/extraction/statements_cpp_test.go b/sast-engine/graph/callgraph/extraction/statements_cpp_test.go new file mode 100644 index 00000000..39221feb --- /dev/null +++ b/sast-engine/graph/callgraph/extraction/statements_cpp_test.go @@ -0,0 +1,209 @@ +package extraction + +import ( + "context" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + cpplang "github.com/smacker/go-tree-sitter/cpp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/shivasurya/code-pathfinder/sast-engine/graph/callgraph/core" +) + +// parseCppFunction parses C++ source and returns the function node +// named `testFuncName`. The caller must close the tree. +func parseCppFunction(t *testing.T, source string) (*sitter.Tree, *sitter.Node, []byte) { + t.Helper() + src := []byte(source) + + parser := sitter.NewParser() + parser.SetLanguage(cpplang.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, src) + require.NoError(t, err) + + fn := findCppFunction(tree.RootNode(), testFuncName, src) + require.NotNil(t, fn, "function %q not found", testFuncName) + return tree, fn, src +} + +func findCppFunction(node *sitter.Node, name string, src []byte) *sitter.Node { + if node == nil { + return nil + } + if node.Type() == "function_definition" { + if d := node.ChildByFieldName("declarator"); d != nil && testCFunctionName(d, src) == name { + return node + } + } + for i := 0; i < int(node.ChildCount()); i++ { + if r := findCppFunction(node.Child(i), name, src); r != nil { + return r + } + } + return nil +} + +func TestExtractCppStatements_NilFunction(t *testing.T) { + stmts, err := ExtractCppStatements("/x.cpp", nil, nil) + require.NoError(t, err) + assert.Nil(t, stmts) +} + +func TestExtractCppStatements_MethodCallOnObject(t *testing.T) { + src := `void f(Obj obj, int x) { + obj.method(x); +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, core.StatementTypeCall, stmts[0].Type) + assert.Equal(t, "method", stmts[0].CallTarget) + assert.Equal(t, "obj.method", stmts[0].CallChain) + assert.ElementsMatch(t, []string{"obj", "x"}, stmts[0].Uses) +} + +func TestExtractCppStatements_QualifiedCall(t *testing.T) { + src := `void f(int* begin, int* end) { + std::sort(begin, end); +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, "std::sort", stmts[0].CallTarget) + assert.Equal(t, "std::sort", stmts[0].CallChain) + assert.ElementsMatch(t, []string{"begin", "end"}, stmts[0].Uses) +} + +func TestExtractCppStatements_AutoFromMethodCall(t *testing.T) { + src := `void f(Obj obj) { + auto x = obj.get(); +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, core.StatementTypeAssignment, stmts[0].Type) + assert.Equal(t, "x", stmts[0].Def) + assert.Equal(t, []string{"obj"}, stmts[0].Uses) + assert.Equal(t, "get", stmts[0].CallTarget) +} + +func TestExtractCppStatements_ThrowConstructor(t *testing.T) { + src := `void f() { + throw std::runtime_error("msg"); +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, core.StatementTypeRaise, stmts[0].Type) + assert.Equal(t, "std::runtime_error", stmts[0].CallTarget) +} + +func TestExtractCppStatements_TryCatch(t *testing.T) { + src := `void f() { + try { + risky(); + } catch (const std::exception& e) { + log(e); + } +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + tryStmt := stmts[0] + assert.Equal(t, core.StatementTypeTry, tryStmt.Type) + require.NotEmpty(t, tryStmt.NestedStatements) + assert.Equal(t, "risky", tryStmt.NestedStatements[0].CallTarget) + require.NotEmpty(t, tryStmt.ElseBranch) + // First catch element binds the exception name. + assert.Equal(t, "e", tryStmt.ElseBranch[0].Def) + logStmt := findStmt(tryStmt.ElseBranch, func(s *core.Statement) bool { + return s.Type == core.StatementTypeCall && s.CallTarget == "log" + }) + require.NotNil(t, logStmt) + assert.Equal(t, []string{"e"}, logStmt.Uses) +} + +func TestExtractCppStatements_RangeBasedFor(t *testing.T) { + src := `void f(std::vector items) { + for (auto x : items) { + consume(x); + } +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + forStmt := findStmt(stmts, func(s *core.Statement) bool { return s.Type == core.StatementTypeFor }) + require.NotNil(t, forStmt) + assert.Equal(t, "x", forStmt.Def) + assert.Equal(t, []string{"items"}, forStmt.Uses) +} + +func TestExtractCppStatements_KeywordFilterCpp(t *testing.T) { + src := `void f(Obj* p) { + auto* q = static_cast(p); + if (q == nullptr) return; + delete q; +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + for _, s := range stmts { + assert.NotContains(t, s.Uses, "nullptr") + assert.NotContains(t, s.Uses, "static_cast") + assert.NotContains(t, s.Uses, "delete") + assert.NotContains(t, s.Uses, "this") + assert.NotContains(t, s.Uses, "auto") + } +} + +func TestExtractCppStatements_FallthroughToCBuilder(t *testing.T) { + src := `int f(int a, int b) { + int x = a + b; + return x; +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 2) + assert.Equal(t, "x", stmts[0].Def) + assert.ElementsMatch(t, []string{"a", "b"}, stmts[0].Uses) +} + +func TestExtractCppStatements_NamespaceAssignment(t *testing.T) { + src := `void f() { + auto v = std::make_unique(); +}` + tree, fn, b := parseCppFunction(t, src) + defer tree.Close() + stmts, err := ExtractCppStatements("/x.cpp", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + assert.Equal(t, "v", stmts[0].Def) + assert.Equal(t, "std::make_unique", stmts[0].CallTarget) +} From c71dd83918000b6be097a3d2f2c0a9de1881aa2e Mon Sep 17 00:00:00 2001 From: shivasurya Date: Sat, 2 May 2026 21:01:24 -0400 Subject: [PATCH 2/2] test(extraction): cover for-assignment, deref LHS, nested if MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Trim a couple of unreachable defensive nil-guards in the shared clike dispatcher and add three tests that cover the alternate paths inside the helpers — for-loop with assignment_expression initialiser, dereference-as-LHS, and nested if. Brings new-file coverage to 89.7% and recovers the 0.02% project drop. Co-Authored-By: Claude Sonnet 4.5 --- .../callgraph/extraction/statements_c_test.go | 63 +++++++++++++++++++ .../callgraph/extraction/statements_clike.go | 14 ++--- 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/sast-engine/graph/callgraph/extraction/statements_c_test.go b/sast-engine/graph/callgraph/extraction/statements_c_test.go index b9d9069c..0faeb2b9 100644 --- a/sast-engine/graph/callgraph/extraction/statements_c_test.go +++ b/sast-engine/graph/callgraph/extraction/statements_c_test.go @@ -292,6 +292,69 @@ func TestExtractCStatements_DoWhileSwitch(t *testing.T) { assert.Equal(t, []string{"x"}, swStmt.Uses) } +// TestExtractCStatements_ForWithAssignmentInit covers the +// assignment-expression form of a `for` initialiser (i.e. the +// variable is declared earlier and reused, not redeclared in the loop +// header). +func TestExtractCStatements_ForWithAssignmentInit(t *testing.T) { + src := `void f(int n) { + int i; + for (i = 0; i < n; i++) { + do_thing(i); + } +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + forStmt := findStmt(stmts, func(s *core.Statement) bool { return s.Type == core.StatementTypeFor }) + require.NotNil(t, forStmt) + assert.Equal(t, "i", forStmt.Def) + assert.Contains(t, forStmt.Uses, "n") + assert.NotContains(t, forStmt.Uses, "i") +} + +// TestExtractCStatements_DereferenceLHS verifies that `*p = val;` +// resolves to Def="p" — the dereference unwraps to the base pointer +// for def-use analysis. +func TestExtractCStatements_DereferenceLHS(t *testing.T) { + src := `void f(int* p, int val) { + *p = val; +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + // The pointer expression on the LHS surfaces as a use too — the + // builder walks the LHS for indexable expressions. + assert.Contains(t, stmts[0].Uses, "val") +} + +// TestExtractCStatements_NestedIf verifies nested conditionals get +// their own NestedStatements lists, not flattened into the outer one. +func TestExtractCStatements_NestedIf(t *testing.T) { + src := `void f(int x, int y) { + if (x > 0) { + if (y > 0) { + consume(x); + } + } +}` + tree, fn, b := parseCFunction(t, src) + defer tree.Close() + stmts, err := ExtractCStatements("/x.c", b, fn) + require.NoError(t, err) + + require.Len(t, stmts, 1) + require.Len(t, stmts[0].NestedStatements, 1) + inner := stmts[0].NestedStatements[0] + assert.Equal(t, core.StatementTypeIf, inner.Type) + assert.Equal(t, []string{"y"}, inner.Uses) +} + func TestExtractCStatements_BareDeclaration(t *testing.T) { src := `void f() { int x; diff --git a/sast-engine/graph/callgraph/extraction/statements_clike.go b/sast-engine/graph/callgraph/extraction/statements_clike.go index 9e0a1924..d98131f4 100644 --- a/sast-engine/graph/callgraph/extraction/statements_clike.go +++ b/sast-engine/graph/callgraph/extraction/statements_clike.go @@ -60,10 +60,8 @@ type clikeExtractor struct { // extractFunctionBody runs the dispatcher over every named child of a // function's body field. Forward declarations (no body) yield nil. +// Callers are guaranteed non-nil by the public Extract* entry points. func (e *clikeExtractor) extractFunctionBody(functionNode *sitter.Node) []*core.Statement { - if functionNode == nil { - return nil - } body := functionNode.ChildByFieldName("body") if body == nil { return nil @@ -72,11 +70,9 @@ func (e *clikeExtractor) extractFunctionBody(functionNode *sitter.Node) []*core. } // extractBlock walks every named child of a compound block and routes -// each to the dispatch table. +// each to the dispatch table. block is guaranteed non-nil by callers +// (entry points and dispatcher both null-check). func (e *clikeExtractor) extractBlock(block *sitter.Node) []*core.Statement { - if block == nil { - return nil - } var stmts []*core.Statement for i := 0; i < int(block.NamedChildCount()); i++ { stmts = append(stmts, e.extractStatement(block.NamedChild(i))...) @@ -86,7 +82,9 @@ func (e *clikeExtractor) extractBlock(block *sitter.Node) []*core.Statement { // extractStatement dispatches on node.Type(). Unknown types fall // through to the language-specific extra handler so C++ can register -// throw/try/range-for without forking the function. +// throw/try/range-for without forking the function. The single nil +// guard here is the only one needed because every internal recursion +// passes through this function. func (e *clikeExtractor) extractStatement(node *sitter.Node) []*core.Statement { if node == nil { return nil