From 96d2c4a205915f0077473e326e573e8bb9b3c3b4 Mon Sep 17 00:00:00 2001 From: shivasurya Date: Sat, 2 May 2026 09:19:18 -0400 Subject: [PATCH] feat(graph/clike): shared C/C++ AST extraction helpers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the cross-cutting primitives that the C parser (parser_c.go) and the C++ parser (parser_cpp.go) will share. C and C++ are dispatched as two distinct languages in graph/parser.go, but the extraction logic for declarations, types, parameters, calls, and keyword filtering is largely identical between the two grammars — centralising it here avoids two parallel implementations drifting apart. Helpers added in this PR: - ExtractFunctionInfo / FunctionInfo — name, return type, parameters, declaration-vs-definition flag from a function_definition node. Forward declarations (no compound_statement body) carry IsDeclaration=true so the call-graph builder can distinguish them from in-translation-unit definitions. - ExtractStructFields / FieldInfo — field name+type pairs from a field_declaration_list, used for both C structs and C++ class bodies. - ExtractTypeString — assembles the canonical type string (qualifiers + base + pointer/reference suffixes) from a (typeNode, declarator) pair. Handles primitive_type, type_identifier, qualified_identifier (std::string), template_type (vector), nested pointers (int**), and reference_declarator (T&) — the latter requires walking past the C++ grammar's anonymous inner declarator child via innerDeclarator. - ExtractParameters — (names, types) parallel slices for parameter_list nodes. Variadics emit ("...", "...") so callers can preserve arity. - ExtractCallInfo / CallInfo — classifies call_expression into free / method-dot / method-arrow / qualified shapes and captures the target, args, and receiver text for the call-resolution layer. - IsCKeyword / IsCppKeyword — backed by cKeywords (C89..C23) and a C++-only addition map; IsCppKeyword unions both. Used by statement extraction to drop reserved words from identifier lists. The package documentation moves from an inline comment in detection.go to a dedicated doc.go that summarises each subsystem and explains how parser_c.go / parser_cpp.go will consume the helpers in subsequent PRs. Tests cover every shape the helpers must handle — including the C++-specific reference_declarator and qualified_identifier cases that required teaching the declarator walker to fall back to scanning named children when the grammar omits the field-named "declarator" child. Co-Authored-By: Claude --- sast-engine/graph/clike/declarations.go | 166 ++++++++++ sast-engine/graph/clike/declarations_test.go | 255 +++++++++++++++ sast-engine/graph/clike/detection.go | 8 - sast-engine/graph/clike/doc.go | 42 +++ sast-engine/graph/clike/helpers.go | 248 +++++++++++++++ sast-engine/graph/clike/helpers_test.go | 309 +++++++++++++++++++ sast-engine/graph/clike/testhelpers_test.go | 81 +++++ sast-engine/graph/clike/types.go | 145 +++++++++ sast-engine/graph/clike/types_test.go | 150 +++++++++ 9 files changed, 1396 insertions(+), 8 deletions(-) create mode 100644 sast-engine/graph/clike/declarations.go create mode 100644 sast-engine/graph/clike/declarations_test.go create mode 100644 sast-engine/graph/clike/doc.go create mode 100644 sast-engine/graph/clike/helpers.go create mode 100644 sast-engine/graph/clike/helpers_test.go create mode 100644 sast-engine/graph/clike/testhelpers_test.go create mode 100644 sast-engine/graph/clike/types.go create mode 100644 sast-engine/graph/clike/types_test.go diff --git a/sast-engine/graph/clike/declarations.go b/sast-engine/graph/clike/declarations.go new file mode 100644 index 00000000..dbc092b3 --- /dev/null +++ b/sast-engine/graph/clike/declarations.go @@ -0,0 +1,166 @@ +package clike + +import sitter "github.com/smacker/go-tree-sitter" + +// FunctionInfo holds extracted information from a C or C++ function_definition +// (or function-shaped declaration) node. The structure is identical for both +// languages — the dispatcher in graph/parser_c.go / graph/parser_cpp.go is +// responsible for setting Node.Language and any C++-specific fields (class +// context, namespace) on the resulting graph.Node. +// +// IsDeclaration is true when the node carries no compound_statement body — +// i.e. a forward declaration in a header (`int compute(int);`) rather than +// a definition. Forward-only declarations are still recorded so that callers +// can be linked to them; PR-03's call-graph builder later uses IsDeclaration +// to decide whether the function is callable in this translation unit or +// whether resolution must reach across files. +type FunctionInfo struct { + Name string + ReturnType string + ParamNames []string + ParamTypes []string + IsDeclaration bool + LineNumber uint32 +} + +// FieldInfo holds a single field name + type extracted from a struct, union, +// or class body. Anonymous fields (rare but legal in C11 and C++) carry an +// empty Name; callers may either skip them or synthesize a name from TypeStr. +type FieldInfo struct { + Name string + TypeStr string +} + +// ExtractFunctionInfo extracts the function name, return type, parameter +// names, and parameter types from a function_definition node. The same +// implementation works for both C and C++ because tree-sitter exposes the +// same field names ("type", "declarator", "parameters", "body") in both +// grammars. +// +// The C/C++ AST shape is: +// +// function_definition +// ├── type ← return type (primitive_type, type_identifier, …) +// ├── declarator ← function_declarator wrapping the name and parameters +// │ ├── declarator (identifier or pointer_declarator) ← name +// │ └── parameters ← parameter_list +// └── body ← compound_statement (omitted for forward declarations) +// +// Returns nil if node is nil or not a function_definition. Empty parameter +// lists yield empty (non-nil) slices. +func ExtractFunctionInfo(node *sitter.Node, sourceCode []byte) *FunctionInfo { + if node == nil { + return nil + } + + info := &FunctionInfo{ + ParamNames: []string{}, + ParamTypes: []string{}, + LineNumber: node.StartPoint().Row + 1, + IsDeclaration: node.ChildByFieldName("body") == nil, + } + + typeNode := node.ChildByFieldName("type") + declarator := node.ChildByFieldName("declarator") + + // The function name lives at the bottom of the declarator chain, after + // any pointer_declarator wrappers used for return-type pointers + // (e.g. char* foo()). The function_declarator itself is reached by + // walking through pointer_declarator nodes. + funcDecl := unwrapToFunctionDeclarator(declarator) + info.ReturnType = ExtractTypeString(typeNode, returnTypeDeclarator(declarator), sourceCode) + + if funcDecl == nil { + // Best-effort fallback: the node isn't well-formed. Return what we + // have so the caller can still record a partial function entry. + return info + } + + if nameNode := funcDecl.ChildByFieldName("declarator"); nameNode != nil { + info.Name = nameNode.Content(sourceCode) + } + + if paramList := funcDecl.ChildByFieldName("parameters"); paramList != nil { + names, types := ExtractParameters(paramList, sourceCode) + info.ParamNames = names + info.ParamTypes = types + } + + return info +} + +// ExtractStructFields walks a field_declaration_list node and returns a +// FieldInfo for every field_declaration child. Bitfields keep the bare type +// (the bit count is dropped) because the type registry does not yet track +// bitfield widths and storing them in the type string would defeat downstream +// type comparison. +// +// Returns nil if list is nil. An empty struct returns an empty (non-nil) +// slice so callers can range without nil-checking the result. +func ExtractStructFields(list *sitter.Node, sourceCode []byte) []FieldInfo { + if list == nil { + return nil + } + + fields := []FieldInfo{} + for i := 0; i < int(list.NamedChildCount()); i++ { + child := list.NamedChild(i) + if child == nil || child.Type() != "field_declaration" { + continue + } + + typeNode := child.ChildByFieldName("type") + declarator := child.ChildByFieldName("declarator") + typeStr := ExtractTypeString(typeNode, declarator, sourceCode) + + name := fieldDeclaratorName(declarator, sourceCode) + fields = append(fields, FieldInfo{Name: name, TypeStr: typeStr}) + } + return fields +} + +// unwrapToFunctionDeclarator walks past any pointer_declarator wrappers and +// returns the function_declarator at the centre. Returns nil if no +// function_declarator is reachable. +func unwrapToFunctionDeclarator(node *sitter.Node) *sitter.Node { + for cur := node; cur != nil; cur = cur.ChildByFieldName("declarator") { + if cur.Type() == "function_declarator" { + return cur + } + } + return nil +} + +// returnTypeDeclarator returns the chain of pointer_declarator nodes that sit +// between the function_definition and its function_declarator. These +// declarators contribute * suffixes to the return type, not to the parameter +// list. For "char* foo()" the chain is one pointer_declarator deep; for plain +// "int foo()" it is empty (returns nil). +func returnTypeDeclarator(node *sitter.Node) *sitter.Node { + if node == nil || node.Type() == "function_declarator" { + return nil + } + // Walk the pointer chain into a synthetic declarator that pointerRefSuffix + // can consume. The shape of the AST already matches what pointerRefSuffix + // expects, so we can pass node directly: pointerRefSuffix stops as soon + // as it sees the function_declarator. + return node +} + +// fieldDeclaratorName extracts the bare identifier name from a field +// declarator, stripping pointer / array / reference wrappers. Returns "" +// for anonymous fields (legal in C11 and common with bitfields like +// `int : 3;`) so callers can decide whether to keep or skip them. +func fieldDeclaratorName(declarator *sitter.Node, sourceCode []byte) string { + for cur := declarator; cur != nil; { + switch cur.Type() { + case "field_identifier", "identifier": + return cur.Content(sourceCode) + case "pointer_declarator", "array_declarator", "reference_declarator": + cur = innerDeclarator(cur) + continue + } + return "" + } + return "" +} diff --git a/sast-engine/graph/clike/declarations_test.go b/sast-engine/graph/clike/declarations_test.go new file mode 100644 index 00000000..ca751358 --- /dev/null +++ b/sast-engine/graph/clike/declarations_test.go @@ -0,0 +1,255 @@ +package clike + +import ( + "testing" +) + +// TestExtractFunctionInfo covers the full range of C and C++ function shapes: +// definitions vs forward declarations, void/typed returns, pointer returns, +// variadic functions, and member functions in C++ classes. +func TestExtractFunctionInfo(t *testing.T) { + tests := []struct { + name string + language string // "c" or "cpp" + code string + wantName string + wantReturn string + wantParamNames []string + wantParamTypes []string + wantIsDeclaration bool + }{ + { + name: "C function with body", + language: "c", + code: "int add(int a, int b) { return a + b; }", + wantName: "add", + wantReturn: "int", + wantParamNames: []string{"a", "b"}, + wantParamTypes: []string{"int", "int"}, + wantIsDeclaration: false, + }, + { + name: "C void function", + language: "c", + code: "void log_msg(const char* fmt) { (void)fmt; }", + wantName: "log_msg", + wantReturn: "void", + wantParamNames: []string{"fmt"}, + wantParamTypes: []string{"const char*"}, + wantIsDeclaration: false, + }, + { + name: "C function returning pointer", + language: "c", + code: "char* allocate(size_t n) { return 0; }", + wantName: "allocate", + wantReturn: "char*", + wantParamNames: []string{"n"}, + wantParamTypes: []string{"size_t"}, + wantIsDeclaration: false, + }, + { + name: "C variadic function", + language: "c", + code: "int printf(const char* fmt, ...) { return 0; }", + wantName: "printf", + wantReturn: "int", + wantParamNames: []string{"fmt", "..."}, + wantParamTypes: []string{"const char*", "..."}, + wantIsDeclaration: false, + }, + { + name: "C++ method definition", + language: "cpp", + code: "int Socket::send(const std::string& msg) { return 0; }", + wantName: "Socket::send", + wantReturn: "int", + wantParamNames: []string{"msg"}, + wantParamTypes: []string{"const std::string&"}, + wantIsDeclaration: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr, rt := snippet(t, tt.language, tt.code) + defer tr.Close() + + fnDef := findNode(rt, "function_definition") + if fnDef == nil { + t.Fatal("function_definition not found") + } + + info := ExtractFunctionInfo(fnDef, []byte(tt.code)) + if info == nil { + t.Fatal("ExtractFunctionInfo returned nil") + } + + if info.Name != tt.wantName { + t.Errorf("Name = %q, want %q", info.Name, tt.wantName) + } + if info.ReturnType != tt.wantReturn { + t.Errorf("ReturnType = %q, want %q", info.ReturnType, tt.wantReturn) + } + if !equalStringSlices(info.ParamNames, tt.wantParamNames) { + t.Errorf("ParamNames = %v, want %v", info.ParamNames, tt.wantParamNames) + } + if !equalStringSlices(info.ParamTypes, tt.wantParamTypes) { + t.Errorf("ParamTypes = %v, want %v", info.ParamTypes, tt.wantParamTypes) + } + if info.IsDeclaration != tt.wantIsDeclaration { + t.Errorf("IsDeclaration = %v, want %v", info.IsDeclaration, tt.wantIsDeclaration) + } + if info.LineNumber == 0 { + t.Error("LineNumber should be > 0") + } + }) + } +} + +// TestExtractFunctionInfo_NilNode verifies the nil guard. +func TestExtractFunctionInfo_NilNode(t *testing.T) { + if got := ExtractFunctionInfo(nil, nil); got != nil { + t.Errorf("ExtractFunctionInfo(nil) = %+v, want nil", got) + } +} + +// TestExtractStructFields covers field extraction for C structs and C++ +// classes, including pointer fields and primitive fields. +func TestExtractStructFields(t *testing.T) { + tests := []struct { + name string + language string + code string + want []FieldInfo + }{ + { + name: "C struct with primitive and pointer fields", + language: "c", + code: "struct Buffer { char* data; size_t len; int capacity; };", + want: []FieldInfo{ + {Name: "data", TypeStr: "char*"}, + {Name: "len", TypeStr: "size_t"}, + {Name: "capacity", TypeStr: "int"}, + }, + }, + { + name: "Empty C struct", + language: "c", + code: "struct Empty { };", + want: []FieldInfo{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr, rt := snippet(t, tt.language, tt.code) + defer tr.Close() + + list := findNode(rt, "field_declaration_list") + if list == nil { + t.Fatal("field_declaration_list not found") + } + + got := ExtractStructFields(list, []byte(tt.code)) + if len(got) != len(tt.want) { + t.Fatalf("got %d fields, want %d (%v)", len(got), len(tt.want), got) + } + for i := range tt.want { + if got[i] != tt.want[i] { + t.Errorf("field[%d] = %+v, want %+v", i, got[i], tt.want[i]) + } + } + }) + } +} + +// TestExtractStructFields_NilList verifies the nil guard. +func TestExtractStructFields_NilList(t *testing.T) { + if got := ExtractStructFields(nil, nil); got != nil { + t.Errorf("ExtractStructFields(nil) = %v, want nil", got) + } +} + +// TestExtractStructFields_BitfieldAndArray covers the array_declarator +// path inside fieldDeclaratorName and the bitfield case where the +// declarator is a plain field_identifier with a sibling bitfield_clause. +func TestExtractStructFields_BitfieldAndArray(t *testing.T) { + code := "struct S { int x : 3; int arr[10]; };" + tr, rt := parseCSnippet(t, code) + defer tr.Close() + + list := findNode(rt, "field_declaration_list") + if list == nil { + t.Fatal("field_declaration_list not found") + } + + got := ExtractStructFields(list, []byte(code)) + want := []FieldInfo{ + {Name: "x", TypeStr: "int"}, + {Name: "arr", TypeStr: "int"}, + } + if len(got) != len(want) { + t.Fatalf("got %d fields, want %d (%+v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("field[%d] = %+v, want %+v", i, got[i], want[i]) + } + } +} + +// TestExtractStructFields_SkipNonFieldChildren covers the filter that +// rejects non-field_declaration children inside a class body — C++ class +// bodies routinely interleave access_specifier nodes with the actual +// fields, and those must not show up in the FieldInfo slice. +func TestExtractStructFields_SkipNonFieldChildren(t *testing.T) { + code := "class C { public: int x; private: int y; };" + tr, rt := parseCppSnippet(t, code) + defer tr.Close() + + list := findNode(rt, "field_declaration_list") + if list == nil { + t.Fatal("field_declaration_list not found") + } + + got := ExtractStructFields(list, []byte(code)) + want := []FieldInfo{ + {Name: "x", TypeStr: "int"}, + {Name: "y", TypeStr: "int"}, + } + if len(got) != len(want) { + t.Fatalf("got %d fields, want %d (%+v)", len(got), len(want), got) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("field[%d] = %+v, want %+v", i, got[i], want[i]) + } + } +} + +// TestExtractFunctionInfo_NoFunctionDeclarator covers the defensive +// fallback that triggers when the declarator chain never reaches a +// function_declarator (malformed AST, or a bare declaration like +// "int x;" passed by mistake). The returned info is still non-nil so +// callers can record a partial entry. +func TestExtractFunctionInfo_NoFunctionDeclarator(t *testing.T) { + code := "int x;" + tr, rt := parseCSnippet(t, code) + defer tr.Close() + + decl := findNode(rt, "declaration") + if decl == nil { + t.Fatal("declaration not found") + } + info := ExtractFunctionInfo(decl, []byte(code)) + if info == nil { + t.Fatal("expected non-nil partial info") + } + if info.Name != "" { + t.Errorf("Name = %q, want empty", info.Name) + } + if !info.IsDeclaration { + t.Error("expected IsDeclaration=true for body-less node") + } +} diff --git a/sast-engine/graph/clike/detection.go b/sast-engine/graph/clike/detection.go index afb74d4f..377ec95b 100644 --- a/sast-engine/graph/clike/detection.go +++ b/sast-engine/graph/clike/detection.go @@ -1,11 +1,3 @@ -// Package clike contains shared helpers for parsing C and C++ source files. -// -// The parsing pipeline treats C and C++ as two distinct languages with separate -// tree-sitter grammars but a large amount of shared structure (declarations, -// statements, type strings). The helpers in this package live alongside the -// language-specific siblings (graph/golang, graph/python, graph/java) and -// provide the cross-cutting primitives — language detection today, AST and -// type extraction in subsequent PRs. package clike import ( diff --git a/sast-engine/graph/clike/doc.go b/sast-engine/graph/clike/doc.go new file mode 100644 index 00000000..4dc8ef10 --- /dev/null +++ b/sast-engine/graph/clike/doc.go @@ -0,0 +1,42 @@ +// Package clike contains shared helpers for parsing C and C++ source files. +// +// The parsing pipeline treats C and C++ as two distinct languages — separate +// tree-sitter grammars, separate Node.Language values ("c" and "cpp"), +// separate dispatchers in graph/parser.go — but the AST node structure for +// declarations, statements, and types is largely shared between the two +// grammars. Rather than duplicate the extraction logic in two parallel +// dispatchers, the cross-cutting primitives live here. +// +// # Detection (this PR) +// +// - IsCSourceFile / IsCppSourceFile route a file to the correct grammar +// - DetectCppInHeader is a best-effort heuristic for the .h ambiguity +// (.h is shared between C and C++; the worker calls this once per file +// and CacheHeaderLanguage stores the result for zero-I/O hot-path lookups) +// +// # Declarations +// +// - FunctionInfo / ExtractFunctionInfo extract name, return type, params, +// and the declaration-vs-definition flag from a function_definition node +// - FieldInfo / ExtractStructFields extract struct/class field name+type +// +// # Types +// +// - ExtractTypeString assembles a complete C/C++ type string from the +// primitive_type / type_identifier / qualified_identifier / template_type +// and pointer_declarator / reference_declarator / type_qualifier nodes +// produced by tree-sitter (e.g. "const std::vector&", "char*") +// +// # Helpers +// +// - ExtractParameters extracts (names, types) from a parameter_list +// - ExtractCallInfo extracts target, arguments, and call-shape metadata +// (free function vs method vs qualified) from a call_expression +// - IsCKeyword / IsCppKeyword are used by statement extraction to filter +// reserved words out of identifier lists +// +// All helpers are pure AST operations: they take *sitter.Node and []byte and +// return plain values. They have no dependency on graph.Node, graph.CodeGraph, +// or any other higher-level type. The parsers in graph/parser_c.go (PR-03) +// and graph/parser_cpp.go (PR-04) consume them. +package clike diff --git a/sast-engine/graph/clike/helpers.go b/sast-engine/graph/clike/helpers.go new file mode 100644 index 00000000..0402d50e --- /dev/null +++ b/sast-engine/graph/clike/helpers.go @@ -0,0 +1,248 @@ +package clike + +import ( + "strings" + + sitter "github.com/smacker/go-tree-sitter" +) + +// CallInfo describes a single C/C++ call_expression. The dispatcher in +// graph/parser_c.go (PR-03) and graph/parser_cpp.go (PR-04) uses this +// structure to decide which call_resolution strategy to apply: free function +// calls go through the file-scope index, qualified calls go through the +// namespace index, and method calls (the obj.foo() and obj->foo() shapes) +// require the receiver type to resolve. +type CallInfo struct { + // Target is the source-level callee text. For free functions this is + // the bare name ("malloc"); for qualified calls it includes the + // namespace chain ("std::move", "mylib::Socket::connect"); for method + // and arrow calls it is the field/method name only ("free", "size"). + Target string + + // Args holds each argument expression as raw source text. Argument + // parsing is deferred to a later PR — at this layer we only need the + // arity and approximate text for diagnostics. + Args []string + + // IsMethod is true for the field_expression shapes: + // obj.foo() (dot operator) + // ptr->foo() (arrow operator) + IsMethod bool + + // IsArrow distinguishes -> from . on a method call. Both set IsMethod; + // IsArrow=true additionally tells the resolver that Receiver is + // pointer-typed, which matters for member access through smart + // pointers and forward-declared classes. + IsArrow bool + + // IsQualified is true for namespace-qualified calls like + // std::move(x) or mylib::ns::func(a). When true, Target already + // contains the full chain. + IsQualified bool + + // Receiver is the source-level expression on the left of '.' or '->' + // for method calls, empty otherwise. It is captured raw (not resolved) + // because the type-inference step that turns it into a class FQN runs + // in a later pass. + Receiver string +} + +// ExtractParameters extracts the parameter names and types from a +// parameter_list node. The two slices are returned as parallel arrays +// (names[i] corresponds to types[i]) to match the convention used by +// graph/golang/helpers.go and the Java parser. +// +// The extractor handles every shape produced by tree-sitter for C and C++: +// +// - parameter_declaration: int x → "x", "int" +// - parameter_declaration with pointer: char* buf → "buf", "char*" +// - parameter_declaration with reference: const T& v → "v", "const T&" +// - parameter_declaration without name: int → "", "int" +// - optional_parameter_declaration: int x = 0 → "x", "int" +// - variadic_parameter: ... → "...", "..." +// +// Returns empty (non-nil) slices when paramList is nil or empty. +func ExtractParameters(paramList *sitter.Node, sourceCode []byte) (names []string, types []string) { + names = []string{} + types = []string{} + if paramList == nil { + return names, types + } + + for i := 0; i < int(paramList.NamedChildCount()); i++ { + param := paramList.NamedChild(i) + if param == nil { + continue + } + + switch param.Type() { + case "parameter_declaration", "optional_parameter_declaration": + name, typ := extractSingleParameter(param, sourceCode) + names = append(names, name) + types = append(types, typ) + case "variadic_parameter", "variadic_parameter_declaration", "...": + names = append(names, "...") + types = append(types, "...") + } + } + return names, types +} + +// extractSingleParameter pulls the name and type from a single parameter_declaration. +func extractSingleParameter(param *sitter.Node, sourceCode []byte) (string, string) { + typeNode := param.ChildByFieldName("type") + declarator := param.ChildByFieldName("declarator") + typeStr := ExtractTypeString(typeNode, declarator, sourceCode) + name := parameterDeclaratorName(declarator, sourceCode) + return name, typeStr +} + +// parameterDeclaratorName walks a parameter declarator chain to find the +// bare identifier name, stripping pointer / reference / array wrappers. +// Returns "" when the parameter is unnamed (legal in C and common in +// abstract declarators used for casts). +func parameterDeclaratorName(declarator *sitter.Node, sourceCode []byte) string { + for cur := declarator; cur != nil; { + switch cur.Type() { + case "identifier": + return cur.Content(sourceCode) + case "pointer_declarator", "reference_declarator", "array_declarator": + cur = innerDeclarator(cur) + continue + } + // Abstract declarators (abstract_pointer_declarator etc.) and any + // other shape have no identifier we can extract. + return "" + } + return "" +} + +// ExtractCallInfo extracts the callee, arguments, and call shape from a +// call_expression node. The function returns nil when node is nil or not a +// call_expression so callers can pass it through unchecked AST traversals. +// +// tree-sitter's C / C++ call_expression has two named fields: +// +// function ← identifier, field_expression, qualified_identifier, … +// arguments ← argument_list (named children are the args) +// +// The function field's node type is what determines IsMethod / IsArrow / +// IsQualified — see the cases inside. +func ExtractCallInfo(node *sitter.Node, sourceCode []byte) *CallInfo { + if node == nil || node.Type() != "call_expression" { + return nil + } + + info := &CallInfo{Args: []string{}} + + if fn := node.ChildByFieldName("function"); fn != nil { + populateCallTarget(info, fn, sourceCode) + } + if argList := node.ChildByFieldName("arguments"); argList != nil { + for i := 0; i < int(argList.NamedChildCount()); i++ { + if arg := argList.NamedChild(i); arg != nil { + info.Args = append(info.Args, arg.Content(sourceCode)) + } + } + } + return info +} + +// populateCallTarget classifies the function expression and writes the +// derived shape flags / target / receiver back into info. +func populateCallTarget(info *CallInfo, fn *sitter.Node, sourceCode []byte) { + switch fn.Type() { + case "identifier": + info.Target = fn.Content(sourceCode) + + case "field_expression": + // obj.method() or obj->method(). + // tree-sitter exposes "argument" (the receiver) and "field" (the + // method name); the access kind is the operator child between them. + info.IsMethod = true + if recv := fn.ChildByFieldName("argument"); recv != nil { + info.Receiver = recv.Content(sourceCode) + } + if field := fn.ChildByFieldName("field"); field != nil { + info.Target = field.Content(sourceCode) + } + info.IsArrow = strings.Contains(fn.Content(sourceCode), "->") + + case "qualified_identifier": + info.IsQualified = true + info.Target = strings.TrimSpace(fn.Content(sourceCode)) + + default: + // Fallback covers function pointers, lambdas, parenthesized + // expressions, etc. We record the raw text so downstream code can + // still match on the source form even when we cannot classify it. + info.Target = strings.TrimSpace(fn.Content(sourceCode)) + } +} + +// cKeywords is the canonical set of C reserved words plus a handful of +// common constants that statement extraction should never report as +// referenced variables. The list spans C89 through C23. +// +// Bool/null constants (NULL, EOF, true, false, nullptr) are included here +// because real C code references them as identifiers; treating them as +// keywords prevents def-use chains from carrying noise entries. +var cKeywords = map[string]bool{ + // C89/C90 + "auto": true, "break": true, "case": true, "char": true, "const": true, + "continue": true, "default": true, "do": true, "double": true, "else": true, + "enum": true, "extern": true, "float": true, "for": true, "goto": true, + "if": true, "int": true, "long": true, "register": true, "return": true, + "short": true, "signed": true, "sizeof": true, "static": true, "struct": true, + "switch": true, "typedef": true, "union": true, "unsigned": true, "void": true, + "volatile": true, "while": true, + // C99 + "restrict": true, "inline": true, "_Bool": true, + "_Complex": true, "_Imaginary": true, + // C11 + "_Alignas": true, "_Alignof": true, "_Atomic": true, "_Generic": true, + "_Noreturn": true, "_Static_assert": true, "_Thread_local": true, + // C23 + "bool": true, "true": true, "false": true, + "nullptr": true, "constexpr": true, "typeof": true, + // Common constants treated as keywords for identifier filtering + "NULL": true, "EOF": true, +} + +// cppKeywords contains C++-only additions on top of cKeywords. It deliberately +// does NOT duplicate any entry already in cKeywords — IsCppKeyword merges +// both maps so that "const" and "class" both resolve correctly without the +// definitions drifting out of sync. +var cppKeywords = map[string]bool{ + "class": true, "namespace": true, "template": true, "typename": true, + "public": true, "private": true, "protected": true, "virtual": true, + "override": true, "final": true, "new": true, "delete": true, + "this": true, "throw": true, "try": true, "catch": true, + "using": true, "operator": true, "friend": true, "mutable": true, + "explicit": true, "export": true, + "consteval": true, "constinit": true, + "co_await": true, "co_return": true, "co_yield": true, + "concept": true, "requires": true, "decltype": true, + "noexcept": true, "static_assert": true, "thread_local": true, + "alignas": true, "alignof": true, + // Common C++ standard-library types frequently used as bare identifiers. + // Treating them as keywords prevents def-use chains from including them + // as plain variable references when no qualifier is present. + "string": true, "vector": true, "map": true, "set": true, + "size_t": true, "ptrdiff_t": true, + "wchar_t": true, "char8_t": true, "char16_t": true, "char32_t": true, +} + +// IsCKeyword reports whether name is a C reserved word or one of the common +// C constants (NULL, EOF) that should be filtered out of variable lists. +func IsCKeyword(name string) bool { + return cKeywords[name] +} + +// IsCppKeyword reports whether name is reserved in C++ — either as a C +// keyword that C++ inherits, or as a C++-only addition. The caller should +// use this for C++ source files; C source should use IsCKeyword to avoid +// rejecting identifiers like "class" or "new" that are legal in C. +func IsCppKeyword(name string) bool { + return cKeywords[name] || cppKeywords[name] +} diff --git a/sast-engine/graph/clike/helpers_test.go b/sast-engine/graph/clike/helpers_test.go new file mode 100644 index 00000000..9bdfac5f --- /dev/null +++ b/sast-engine/graph/clike/helpers_test.go @@ -0,0 +1,309 @@ +package clike + +import ( + "testing" +) + +// TestExtractParameters covers the parameter shapes that statement +// extraction (PR-05) needs to consume cleanly: typed parameters, pointer +// and reference parameters, variadics, and unnamed (abstract) parameters. +func TestExtractParameters(t *testing.T) { + tests := []struct { + name string + language string + code string + wantNames []string + wantTypes []string + }{ + { + name: "two named C parameters", + language: "c", + code: "void f(int a, int b);", + wantNames: []string{"a", "b"}, + wantTypes: []string{"int", "int"}, + }, + { + name: "C variadic", + language: "c", + code: "int printf(const char* fmt, ...);", + wantNames: []string{"fmt", "..."}, + wantTypes: []string{"const char*", "..."}, + }, + { + name: "C unnamed parameter", + language: "c", + code: "void f(int);", + wantNames: []string{""}, + wantTypes: []string{"int"}, + }, + { + name: "C void parameter list", + language: "c", + code: "void f(void);", + wantNames: []string{""}, + wantTypes: []string{"void"}, + }, + { + name: "C++ const reference", + language: "cpp", + code: "void f(const std::string& s, int n);", + wantNames: []string{"s", "n"}, + wantTypes: []string{"const std::string&", "int"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr, rt := snippet(t, tt.language, tt.code) + defer tr.Close() + + pl := findNode(rt, "parameter_list") + if pl == nil { + t.Fatal("parameter_list not found") + } + + names, types := ExtractParameters(pl, []byte(tt.code)) + if !equalStringSlices(names, tt.wantNames) { + t.Errorf("names = %v, want %v", names, tt.wantNames) + } + if !equalStringSlices(types, tt.wantTypes) { + t.Errorf("types = %v, want %v", types, tt.wantTypes) + } + }) + } +} + +// TestExtractParameters_Nil verifies the nil-safe behaviour required by +// callers who pass through unchecked AST traversals. +func TestExtractParameters_Nil(t *testing.T) { + names, types := ExtractParameters(nil, nil) + if len(names) != 0 || len(types) != 0 { + t.Errorf("ExtractParameters(nil) = (%v, %v), want empty", names, types) + } +} + +// TestExtractParameters_AbstractPointer covers the abstract_pointer_declarator +// path: forward declarations and function-pointer typedefs commonly omit +// parameter names (`void f(int*)`). The * suffix must still appear in the +// type and the name must come back empty. +func TestExtractParameters_AbstractPointer(t *testing.T) { + code := "void f(int*);" + tr, rt := parseCSnippet(t, code) + defer tr.Close() + + pl := findNode(rt, "parameter_list") + if pl == nil { + t.Fatal("parameter_list not found") + } + names, types := ExtractParameters(pl, []byte(code)) + if !equalStringSlices(names, []string{""}) { + t.Errorf("names = %v, want [\"\"]", names) + } + if !equalStringSlices(types, []string{"int*"}) { + t.Errorf("types = %v, want [int*]", types) + } +} + +// TestExtractCallInfo covers the four call shapes emitted by tree-sitter +// for C and C++: free function, method (.), arrow method (->), and +// qualified (namespace) calls. The receiver and method-flag combinations +// drive the call-resolution logic in PR-03/PR-04. +func TestExtractCallInfo(t *testing.T) { + tests := []struct { + name string + language string + code string + want CallInfo + }{ + { + name: "simple free function call", + language: "c", + code: "void f() { malloc(128); }", + want: CallInfo{ + Target: "malloc", + Args: []string{"128"}, + }, + }, + { + name: "free function with two args", + language: "c", + code: "void f() { strcpy(dst, src); }", + want: CallInfo{ + Target: "strcpy", + Args: []string{"dst", "src"}, + }, + }, + { + name: "method call via dot", + language: "cpp", + code: "void f(Buffer b) { b.free(); }", + want: CallInfo{ + Target: "free", + Args: []string{}, + IsMethod: true, + Receiver: "b", + }, + }, + { + name: "method call via arrow", + language: "cpp", + code: "void f(Buffer* b) { b->free(); }", + want: CallInfo{ + Target: "free", + Args: []string{}, + IsMethod: true, + IsArrow: true, + Receiver: "b", + }, + }, + { + name: "qualified namespace call", + language: "cpp", + code: "void f() { std::move(x); }", + want: CallInfo{ + Target: "std::move", + Args: []string{"x"}, + IsQualified: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tr, rt := snippet(t, tt.language, tt.code) + defer tr.Close() + + call := findNode(rt, "call_expression") + if call == nil { + t.Fatal("call_expression not found") + } + + got := ExtractCallInfo(call, []byte(tt.code)) + if got == nil { + t.Fatal("ExtractCallInfo returned nil") + } + if got.Target != tt.want.Target { + t.Errorf("Target = %q, want %q", got.Target, tt.want.Target) + } + if got.IsMethod != tt.want.IsMethod { + t.Errorf("IsMethod = %v, want %v", got.IsMethod, tt.want.IsMethod) + } + if got.IsArrow != tt.want.IsArrow { + t.Errorf("IsArrow = %v, want %v", got.IsArrow, tt.want.IsArrow) + } + if got.IsQualified != tt.want.IsQualified { + t.Errorf("IsQualified = %v, want %v", got.IsQualified, tt.want.IsQualified) + } + if got.Receiver != tt.want.Receiver { + t.Errorf("Receiver = %q, want %q", got.Receiver, tt.want.Receiver) + } + if !equalStringSlices(got.Args, tt.want.Args) { + t.Errorf("Args = %v, want %v", got.Args, tt.want.Args) + } + }) + } +} + +// TestExtractCallInfo_FunctionPointerCall covers the populateCallTarget +// default branch — calls through a function pointer or any other +// non-classifiable function expression should preserve the raw source so +// downstream code can still match on it. +func TestExtractCallInfo_FunctionPointerCall(t *testing.T) { + code := "void f() { (*fp)(x); }" + tr, rt := parseCSnippet(t, code) + defer tr.Close() + + call := findNode(rt, "call_expression") + if call == nil { + t.Fatal("call_expression not found") + } + got := ExtractCallInfo(call, []byte(code)) + if got == nil { + t.Fatal("ExtractCallInfo returned nil") + } + if got.Target != "(*fp)" { + t.Errorf("Target = %q, want %q", got.Target, "(*fp)") + } + if got.IsMethod || got.IsQualified || got.IsArrow { + t.Errorf("expected unclassified call, got method=%v qualified=%v arrow=%v", + got.IsMethod, got.IsQualified, got.IsArrow) + } + if !equalStringSlices(got.Args, []string{"x"}) { + t.Errorf("Args = %v, want [x]", got.Args) + } +} + +// TestExtractCallInfo_NilOrWrongNode verifies the guards that protect +// callers from passing arbitrary nodes through this helper. +func TestExtractCallInfo_NilOrWrongNode(t *testing.T) { + if got := ExtractCallInfo(nil, nil); got != nil { + t.Errorf("ExtractCallInfo(nil) = %+v, want nil", got) + } + + tr, rt := parseCSnippet(t, "int x;") + defer tr.Close() + if got := ExtractCallInfo(rt, nil); got != nil { + t.Errorf("ExtractCallInfo(non-call) = %+v, want nil", got) + } +} + +// TestIsCKeyword verifies the C reserved-word set covers the spans the +// grammar emits as identifiers in tree-sitter so statement extraction in +// PR-05 can filter them safely. +func TestIsCKeyword(t *testing.T) { + cases := []struct { + name string + want bool + }{ + // C89/C90 core + {"int", true}, {"void", true}, {"return", true}, {"struct", true}, + {"const", true}, {"static", true}, {"sizeof", true}, + // C99 + {"inline", true}, {"restrict", true}, + // C11 + {"_Atomic", true}, {"_Generic", true}, + // C23 + {"true", true}, {"false", true}, {"nullptr", true}, + // Common constants + {"NULL", true}, {"EOF", true}, + // Non-keywords + {"foo", false}, {"bar", false}, {"my_func", false}, + // C++-only — must NOT be a C keyword + {"class", false}, {"namespace", false}, {"new", false}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + if got := IsCKeyword(tt.name); got != tt.want { + t.Errorf("IsCKeyword(%q) = %v, want %v", tt.name, got, tt.want) + } + }) + } +} + +// TestIsCppKeyword verifies that C++ keyword recognition includes both the +// inherited C keywords and the C++-only additions. +func TestIsCppKeyword(t *testing.T) { + cases := []struct { + name string + want bool + }{ + // Inherited from C + {"int", true}, {"const", true}, {"static", true}, {"return", true}, + // C++-only additions + {"class", true}, {"namespace", true}, {"template", true}, + {"new", true}, {"delete", true}, {"this", true}, + {"throw", true}, {"try", true}, {"catch", true}, + {"public", true}, {"private", true}, {"protected", true}, + // Common stdlib types treated as keywords for filtering + {"string", true}, {"vector", true}, {"size_t", true}, + // Non-keywords + {"foo", false}, {"my_class", false}, {"compute", false}, + } + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + if got := IsCppKeyword(tt.name); got != tt.want { + t.Errorf("IsCppKeyword(%q) = %v, want %v", tt.name, got, tt.want) + } + }) + } +} diff --git a/sast-engine/graph/clike/testhelpers_test.go b/sast-engine/graph/clike/testhelpers_test.go new file mode 100644 index 00000000..b6d8a12d --- /dev/null +++ b/sast-engine/graph/clike/testhelpers_test.go @@ -0,0 +1,81 @@ +package clike + +import ( + "context" + "testing" + + sitter "github.com/smacker/go-tree-sitter" + clang "github.com/smacker/go-tree-sitter/c" + cpplang "github.com/smacker/go-tree-sitter/cpp" +) + +// parseCSnippet parses the given C source and returns (tree, root). Caller +// must defer tree.Close(). +func parseCSnippet(t *testing.T, code string) (*sitter.Tree, *sitter.Node) { + t.Helper() + parser := sitter.NewParser() + parser.SetLanguage(clang.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(code)) + if err != nil { + t.Fatalf("parse C: %v", err) + } + return tree, tree.RootNode() +} + +// parseCppSnippet parses the given C++ source and returns (tree, root). +// Caller must defer tree.Close(). +func parseCppSnippet(t *testing.T, code string) (*sitter.Tree, *sitter.Node) { + t.Helper() + parser := sitter.NewParser() + parser.SetLanguage(cpplang.GetLanguage()) + defer parser.Close() + + tree, err := parser.ParseCtx(context.Background(), nil, []byte(code)) + if err != nil { + t.Fatalf("parse C++: %v", err) + } + return tree, tree.RootNode() +} + +// snippet parses code with the C or C++ grammar selected by language. +func snippet(t *testing.T, language, code string) (*sitter.Tree, *sitter.Node) { + t.Helper() + if language == "cpp" { + return parseCppSnippet(t, code) + } + return parseCSnippet(t, code) +} + +// equalStringSlices reports whether two string slices are element-wise equal. +// nil and empty are treated as equal so tests can compare against literal +// []string{} without distinguishing the two representations. +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +// findNode performs a pre-order search and returns the first descendant of +// node whose Type() matches nodeType. Returns nil when no match exists. +func findNode(node *sitter.Node, nodeType string) *sitter.Node { + if node == nil { + return nil + } + if node.Type() == nodeType { + return node + } + for i := 0; i < int(node.ChildCount()); i++ { + if found := findNode(node.Child(i), nodeType); found != nil { + return found + } + } + return nil +} diff --git a/sast-engine/graph/clike/types.go b/sast-engine/graph/clike/types.go new file mode 100644 index 00000000..438267de --- /dev/null +++ b/sast-engine/graph/clike/types.go @@ -0,0 +1,145 @@ +package clike + +import ( + "strings" + + sitter "github.com/smacker/go-tree-sitter" +) + +// ExtractTypeString assembles a complete C/C++ type string from the AST nodes +// produced by tree-sitter for a declaration. The result includes type +// qualifiers (const / volatile / restrict), the base type (primitive_type, +// type_identifier, qualified_identifier, or template_type), and any pointer +// or reference suffixes derived from the declarator chain. +// +// The function is invoked with the "type" node from a parameter_declaration, +// declaration, or field_declaration; the caller passes in the matching +// "declarator" node to pick up the * / & / [] suffixes that tree-sitter +// places on the declarator side of the AST. Either argument may be nil: +// a nil typeNode produces "void" (the empty-return-type convention used by +// tree-sitter's C grammar), and a nil declarator simply means no suffixes. +// +// Examples (typeNode + declarator → result): +// +// primitive_type "int" + identifier "x" → "int" +// primitive_type "char" + pointer_declarator → "char*" +// primitive_type "int" + pointer_declarator(2) → "int**" +// type_identifier "FILE" + pointer_declarator → "FILE*" +// primitive_type "int" with const + identifier "x" → "const int" +// qualified_identifier "std::string" + reference_declarator → "std::string&" +// template_type "std::vector" + identifier "v" → "std::vector" +// primitive_type "long" with unsigned + identifier "n" → "unsigned long" +// +// The helper is whitespace-conservative: qualifiers are joined with single +// spaces and pointer/reference suffixes are appended without spaces, which +// matches the canonical form used by every other type registry in the +// codebase (Python typeshed, Go types.go, Java fully-qualified names). +func ExtractTypeString(typeNode, declarator *sitter.Node, sourceCode []byte) string { + if typeNode == nil { + return "void" + } + + // Qualifiers (const, volatile, restrict, unsigned, signed) live as + // sibling type_qualifier / sized_type_specifier nodes on the parent. + qualifiers := collectTypeQualifiers(typeNode, sourceCode) + + base := baseTypeString(typeNode, sourceCode) + suffix := pointerRefSuffix(declarator, sourceCode) + + if len(qualifiers) == 0 { + return base + suffix + } + return strings.Join(qualifiers, " ") + " " + base + suffix +} + +// baseTypeString returns the human-readable base type from typeNode, without +// any qualifiers or pointer/reference suffixes. +// +// Tree-sitter emits five shapes in the spots we call this from — +// primitive_type, type_identifier, sized_type_specifier, qualified_identifier +// ("std::string"), and template_type ("vector") — and all of them +// serialise verbatim from source, so a single content fetch suffices. +func baseTypeString(typeNode *sitter.Node, sourceCode []byte) string { + return strings.TrimSpace(typeNode.Content(sourceCode)) +} + +// collectTypeQualifiers walks the parent of typeNode looking for qualifier +// siblings (type_qualifier nodes such as "const", "volatile", "restrict", +// and the C-specific signedness markers "unsigned" / "signed" expressed as +// sized_type_specifier siblings on certain grammars). Order is preserved +// from source. +func collectTypeQualifiers(typeNode *sitter.Node, sourceCode []byte) []string { + parent := typeNode.Parent() + if parent == nil { + return nil + } + + var quals []string + for i := 0; i < int(parent.NamedChildCount()); i++ { + sib := parent.NamedChild(i) + if sib == nil || sib.Equal(typeNode) { + continue + } + if sib.Type() == "type_qualifier" { + quals = append(quals, strings.TrimSpace(sib.Content(sourceCode))) + } + } + return quals +} + +// pointerRefSuffix walks down a declarator chain and emits one * for each +// pointer_declarator and one & for each reference_declarator (C++ only). +// The traversal stops at the first non-pointer/reference node, which is +// usually the identifier that names the entity being declared. +// +// tree-sitter nests pointer declarators left-to-right, so "int**" appears +// as pointer_declarator(pointer_declarator(identifier)); collecting them +// inside-out produces the correct "**" suffix order. +func pointerRefSuffix(declarator *sitter.Node, _ []byte) string { + suffix := "" + for cur := declarator; cur != nil; { + switch cur.Type() { + case "pointer_declarator", "abstract_pointer_declarator": + suffix += "*" + case "reference_declarator", "abstract_reference_declarator": + suffix += "&" + default: + return suffix + } + next := innerDeclarator(cur) + if next == nil { + return suffix + } + cur = next + } + return suffix +} + +// innerDeclarator returns the next declarator inside a wrapper such as +// pointer_declarator, reference_declarator, or array_declarator. The C +// grammar exposes this as a field-named "declarator" child, but the C++ +// reference_declarator (and several abstract_* variants) does not name the +// inner child — in that case we fall back to the first non-anonymous named +// child, which is the wrapped declarator or identifier in every grammar +// rule that uses these node types. +func innerDeclarator(wrapper *sitter.Node) *sitter.Node { + if wrapper == nil { + return nil + } + if c := wrapper.ChildByFieldName("declarator"); c != nil { + return c + } + for i := 0; i < int(wrapper.NamedChildCount()); i++ { + c := wrapper.NamedChild(i) + if c == nil { + continue + } + // Skip type-side qualifiers that occasionally appear inside + // declarators (e.g., const inside `int * const p`). + if c.Type() == "type_qualifier" { + continue + } + return c + } + return nil +} diff --git a/sast-engine/graph/clike/types_test.go b/sast-engine/graph/clike/types_test.go new file mode 100644 index 00000000..c62a7a18 --- /dev/null +++ b/sast-engine/graph/clike/types_test.go @@ -0,0 +1,150 @@ +package clike + +import ( + "testing" + + sitter "github.com/smacker/go-tree-sitter" +) + +// TestExtractTypeString covers the type strings produced for parameter and +// field declarations across both C and C++. Each case parses a real source +// snippet, locates the first parameter_declaration or field_declaration, and +// invokes ExtractTypeString on its (type, declarator) pair. +func TestExtractTypeString(t *testing.T) { + tests := []struct { + name string + // language: "c" or "cpp" + language string + code string + // node selects which AST node to feed into ExtractTypeString: + // "param" for the first parameter_declaration, "field" for the + // first field_declaration. + nodeKind string + want string + }{ + { + name: "plain int parameter", + language: "c", + code: "void f(int x);", + nodeKind: "param", + want: "int", + }, + { + name: "char pointer parameter", + language: "c", + code: "void f(char* buf);", + nodeKind: "param", + want: "char*", + }, + { + name: "double pointer parameter", + language: "c", + code: "void f(int** pp);", + nodeKind: "param", + want: "int**", + }, + { + name: "const char pointer parameter", + language: "c", + code: "void f(const char* fmt);", + nodeKind: "param", + want: "const char*", + }, + { + name: "FILE pointer parameter", + language: "c", + code: "void f(FILE* fp);", + nodeKind: "param", + want: "FILE*", + }, + { + name: "unsigned long long parameter", + language: "c", + code: "void f(unsigned long long n);", + nodeKind: "param", + want: "unsigned long long", + }, + { + name: "void pointer parameter", + language: "c", + code: "void f(void* p);", + nodeKind: "param", + want: "void*", + }, + { + name: "struct field pointer", + language: "c", + code: "struct S { char* name; };", + nodeKind: "field", + want: "char*", + }, + { + name: "C++ string reference", + language: "cpp", + code: "void f(std::string& s);", + nodeKind: "param", + want: "std::string&", + }, + { + name: "C++ const string reference", + language: "cpp", + code: "void f(const std::string& s);", + nodeKind: "param", + want: "const std::string&", + }, + { + name: "C++ vector of int", + language: "cpp", + code: "void f(std::vector v);", + nodeKind: "param", + want: "std::vector", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var tree *sitter.Tree + var root *sitter.Node + if tt.language == "cpp" { + tree, root = parseCppSnippet(t, tt.code) + } else { + tree, root = parseCSnippet(t, tt.code) + } + defer tree.Close() + + var typeNode, declarator *sitter.Node + switch tt.nodeKind { + case "param": + p := findNode(root, "parameter_declaration") + if p == nil { + t.Fatal("parameter_declaration not found") + } + typeNode = p.ChildByFieldName("type") + declarator = p.ChildByFieldName("declarator") + case "field": + f := findNode(root, "field_declaration") + if f == nil { + t.Fatal("field_declaration not found") + } + typeNode = f.ChildByFieldName("type") + declarator = f.ChildByFieldName("declarator") + default: + t.Fatalf("unknown nodeKind %q", tt.nodeKind) + } + + got := ExtractTypeString(typeNode, declarator, []byte(tt.code)) + if got != tt.want { + t.Errorf("ExtractTypeString = %q, want %q", got, tt.want) + } + }) + } +} + +// TestExtractTypeString_NilType verifies the void fallback for a nil type +// node (used when tree-sitter omits the type slot for void returns). +func TestExtractTypeString_NilType(t *testing.T) { + got := ExtractTypeString(nil, nil, nil) + if got != "void" { + t.Errorf("ExtractTypeString(nil, nil, nil) = %q, want %q", got, "void") + } +}