diff --git a/images/python.png b/images/python.png index c451cb3..4c3d104 100644 Binary files a/images/python.png and b/images/python.png differ diff --git a/lua/diffmantic/core/actions.lua b/lua/diffmantic/core/actions.lua index f69567c..1f5c05d 100644 --- a/lua/diffmantic/core/actions.lua +++ b/lua/diffmantic/core/actions.lua @@ -1,18 +1,497 @@ local M = {} +local semantic = require("diffmantic.core.semantic") +local roles = require("diffmantic.core.roles") +local analysis = require("diffmantic.core.analysis") + +local LITERAL_MEMBER_NODE_TYPES = { + pair = true, + keyed_element = true, + property_signature = true, + shorthand_property_identifier = true, +} + +local function range_text(buf, range) + if not buf or not range then + return nil + end + if range.start_row == nil or range.end_row == nil or range.start_col == nil or range.end_col == nil then + return nil + end + if range.start_row ~= range.end_row then + return nil + end + local line = vim.api.nvim_buf_get_lines(buf, range.start_row, range.start_row + 1, false)[1] or "" + if line == "" then + return nil + end + local start_col = range.start_col + 1 + local end_col = range.end_col + if end_col < start_col then + return nil + end + return line:sub(start_col, end_col) +end + +local function action_pair_key(action) + if action and action.src_node and action.dst_node then + return action.src_node:id() .. ":" .. 
action.dst_node:id() + end + if action and action.src and action.dst then + return table.concat({ + tostring(action.src.start_row), + tostring(action.src.start_col), + tostring(action.src.end_row), + tostring(action.src.end_col), + tostring(action.dst.start_row), + tostring(action.dst.start_col), + tostring(action.dst.end_row), + tostring(action.dst.end_col), + }, ":") + end + return nil +end + +local function has_effective_update_hunks(action, src_buf, dst_buf) + local action_analysis = action and action.analysis or nil + local hunks = action_analysis and action_analysis.hunks or nil + if not hunks or #hunks == 0 then + return false + end + + local rename_pairs = action_analysis.rename_pairs or {} + for _, hunk in ipairs(hunks) do + if hunk.kind == "insert" or hunk.kind == "delete" then + return true + end + if hunk.kind == "change" then + local src_text = range_text(src_buf, hunk.src) + local dst_text = range_text(dst_buf, hunk.dst) + if not src_text or not dst_text then + return true + end + if src_text ~= dst_text and rename_pairs[src_text] ~= dst_text then + return true + end + end + end + return false +end + +local function range_contains(outer, inner) + if not outer or not inner then + return false + end + if outer.start_row == nil or outer.end_row == nil or outer.start_col == nil or outer.end_col == nil then + return false + end + if inner.start_row == nil or inner.end_row == nil or inner.start_col == nil or inner.end_col == nil then + return false + end + if inner.start_row < outer.start_row or inner.end_row > outer.end_row then + return false + end + if inner.start_row == outer.start_row and inner.start_col < outer.start_col then + return false + end + if inner.end_row == outer.end_row and inner.end_col > outer.end_col then + return false + end + return true +end + +local function ranges_equal(a, b) + if not a or not b then + return false + end + return a.start_row == b.start_row + and a.start_col == b.start_col + and a.end_row == b.end_row + and 
a.end_col == b.end_col +end + +local function ranges_related(a, b) + return ranges_equal(a, b) or range_contains(a, b) or range_contains(b, a) +end + +local TEMP_RESULT_DECL_NODE_TYPES = { + lexical_declaration = true, + short_var_declaration = true, + init_declarator = true, + variable_declaration = true, + local_variable_declaration = true, +} + +local function line_text(buf, row) + if not buf or row == nil then + return nil + end + local lines = vim.api.nvim_buf_get_lines(buf, row, row + 1, false) + return lines and lines[1] or nil +end + +local function action_side_text(action, buf, side) + if not action or not buf then + return nil + end + local node = side == "src" and action.src_node or action.dst_node + if node then + return vim.treesitter.get_node_text(node, buf) + end + local range = side == "src" and action.src or action.dst + return range_text(buf, range) +end + +local function extract_assigned_identifier(text) + if not text or text == "" then + return nil + end + local id = text:match("^%s*[Cc]onst%s+([%a_][%w_]*)%s*=") + or text:match("^%s*[Ll]et%s+([%a_][%w_]*)%s*=") + or text:match("^%s*[Vv]ar%s+([%a_][%w_]*)%s*=") + or text:match("^%s*[Ll]ocal%s+([%a_][%w_]*)%s*=") + or text:match("^%s*([%a_][%w_]*)%s*:=") + if id then + return id + end + local lhs = text:match("^(.-)=") + if not lhs then + return nil + end + return lhs:match("([%a_][%w_]*)%s*$") +end + +local function extract_return_identifier(text) + if not text or text == "" then + return nil + end + local expr = text:match("^%s*return%s+(.+)$") + if not expr then + return nil + end + local trimmed = expr:gsub("%s*;%s*$", ""):match("^%s*(.-)%s*$") + if not trimmed then + return nil + end + return trimmed:match("^%(*%s*([%a_][%w_]*)%s*%)?$") +end + +local function mark_temp_result_refactor_overrides(actions_list, src_buf, dst_buf) + if not src_buf or not dst_buf then + return + end + + local function mark_insert_actions_for_rows(update_action, declared_name, decl_row, ret_row) + for _, action in 
ipairs(actions_list) do + if action.type == "insert" and action.dst and range_contains(update_action.dst, action.dst) then + local node_type = action.metadata and action.metadata.node_type or nil + if action.dst.start_row == decl_row and TEMP_RESULT_DECL_NODE_TYPES[node_type] then + local decl_text = action_side_text(action, dst_buf, "dst") or line_text(dst_buf, decl_row) + if extract_assigned_identifier(decl_text) == declared_name then + action.metadata = action.metadata or {} + action.metadata.render_as_change = true + end + elseif action.dst.start_row == ret_row and node_type == "return_statement" then + local ret_text = action_side_text(action, dst_buf, "dst") or line_text(dst_buf, ret_row) + if extract_return_identifier(ret_text) == declared_name then + action.metadata = action.metadata or {} + action.metadata.render_as_change = true + end + end + end + end + end + + for _, update_action in ipairs(actions_list) do + if update_action.type == "update" and update_action.dst then + local hunks = update_action.analysis and update_action.analysis.hunks or {} + for _, change_hunk in ipairs(hunks) do + if change_hunk.kind == "change" and change_hunk.src and change_hunk.dst then + local src_line = line_text(src_buf, change_hunk.src.start_row) + local dst_line = line_text(dst_buf, change_hunk.dst.start_row) + if src_line and dst_line and src_line:match("^%s*return%s+") then + local declared_name = extract_assigned_identifier(dst_line) + if declared_name then + local matched_insert = nil + for _, insert_hunk in ipairs(hunks) do + if + insert_hunk.kind == "insert" + and insert_hunk.dst + and insert_hunk.dst.start_row >= change_hunk.dst.start_row + and insert_hunk.dst.start_row <= change_hunk.dst.start_row + 2 + then + local inserted_line = line_text(dst_buf, insert_hunk.dst.start_row) + if extract_return_identifier(inserted_line) == declared_name then + matched_insert = insert_hunk + break + end + end + end + + if matched_insert then + matched_insert.render_as_change = 
true + mark_insert_actions_for_rows( + update_action, + declared_name, + change_hunk.dst.start_row, + matched_insert.dst.start_row + ) + end + end + end + end + end + end + end +end + +local function mark_literal_member_render_overrides(actions_list) + local update_contexts = {} + for _, action in ipairs(actions_list) do + if action.type == "update" then + local context = { + action = action, + src = action.src, + dst = action.dst, + src_change_hunks = {}, + dst_change_hunks = {}, + } + local hunks = action.analysis and action.analysis.hunks or {} + for _, hunk in ipairs(hunks) do + if hunk.kind == "change" then + if hunk.src then + table.insert(context.src_change_hunks, hunk.src) + end + if hunk.dst then + table.insert(context.dst_change_hunks, hunk.dst) + end + end + end + table.insert(update_contexts, context) + end + end + + local function overlaps_any(ranges, target) + if not target then + return false + end + for _, r in ipairs(ranges) do + if ranges_related(r, target) then + return true + end + end + return false + end + + for _, action in ipairs(actions_list) do + if action.type == "insert" or action.type == "delete" then + local metadata = action.metadata or {} + if LITERAL_MEMBER_NODE_TYPES[metadata.node_type] then + local target = action.type == "insert" and action.dst or action.src + for _, context in ipairs(update_contexts) do + local container = action.type == "insert" and context.dst or context.src + local change_hunks = action.type == "insert" and context.dst_change_hunks or context.src_change_hunks + -- Key-aware policy: only override inserts/deletes that overlap replacement hunks. 
+ if range_contains(container, target) and overlaps_any(change_hunks, target) then + action.metadata = action.metadata or {} + action.metadata.render_as_change = true + local hunks = context.action.analysis and context.action.analysis.hunks or {} + for _, hunk in ipairs(hunks) do + if action.type == "insert" and hunk.kind == "insert" and ranges_related(hunk.dst, action.dst) then + hunk.render_as_change = true + elseif action.type == "delete" and hunk.kind == "delete" and ranges_related(hunk.src, action.src) then + hunk.render_as_change = true + end + end + break + end + end + end + end + end +end + +local function range_metadata(node) + if not node then + return nil + end + local sr, sc, er, ec = node:range() + return { + start_row = sr, + start_col = sc, + end_row = er, + end_col = ec, + start_line = sr + 1, + end_line = er + 1, + } +end + +local function build_action(action_type, src_node, dst_node, extra) + local src = range_metadata(src_node) + local dst = range_metadata(dst_node) + local node = src_node or dst_node + local from_line = src and src.start_line or nil + local to_line = dst and dst.start_line or nil + + local action = { + type = action_type, + src_node = src_node, + dst_node = dst_node, + src = src and vim.tbl_extend("force", {}, src, { text = nil }) or nil, + dst = dst and vim.tbl_extend("force", {}, dst, { text = nil }) or nil, + metadata = { + node_type = node and node:type() or nil, + old_name = nil, + new_name = nil, + from_line = from_line, + to_line = to_line, + suppressed_renames = nil, + }, + analysis = nil, + } + + if extra then + if extra.context then + action.context = extra.context + end + if extra.analysis then + action.analysis = extra.analysis + end + if extra.metadata then + action.metadata = vim.tbl_extend("force", action.metadata, extra.metadata) + end + action.metadata.old_name = extra.old_name or extra.from or action.metadata.old_name + action.metadata.new_name = extra.new_name or extra.to or action.metadata.new_name + 
action.metadata.from_line = extra.from_line or action.metadata.from_line + action.metadata.to_line = extra.to_line or action.metadata.to_line + action.metadata.node_type = extra.node_type or action.metadata.node_type + action.metadata.suppressed_renames = extra.suppressed_renames or action.metadata.suppressed_renames + if extra.context and extra.context.suppressed_usages then + action.metadata.suppressed_renames = extra.context.suppressed_usages + end + end + + return action +end + +local function build_summary(actions_list) + local summary = { + counts = { + move = 0, + rename = 0, + rename_suppressed = 0, + update = 0, + insert = 0, + delete = 0, + total = #actions_list, + }, + moves = {}, + renames = {}, + suppressed_renames = {}, + updates = {}, + inserts = {}, + deletes = {}, + } + + for _, action in ipairs(actions_list) do + local t = action.type + if summary.counts[t] ~= nil then + summary.counts[t] = summary.counts[t] + 1 + end + + if t == "move" then + local metadata = action.metadata or {} + table.insert(summary.moves, { + node_type = metadata.node_type, + from_line = metadata.from_line, + to_line = metadata.to_line, + src_range = action.src, + dst_range = action.dst, + }) + elseif t == "rename" then + local metadata = action.metadata or {} + local suppressed_usages = metadata.suppressed_renames or (action.context and action.context.suppressed_usages or {}) + local suppressed_count = #suppressed_usages + summary.counts.rename_suppressed = summary.counts.rename_suppressed + suppressed_count + + table.insert(summary.renames, { + node_type = metadata.node_type, + from = metadata.old_name, + to = metadata.new_name, + from_line = metadata.from_line, + to_line = metadata.to_line, + src_range = action.src, + dst_range = action.dst, + suppressed_usage_count = suppressed_count, + }) + + for _, usage in ipairs(suppressed_usages) do + local usage_meta = usage.metadata or {} + table.insert(summary.suppressed_renames, { + from = usage_meta.old_name, + to = 
usage_meta.new_name, + from_line = usage_meta.from_line, + to_line = usage_meta.to_line, + src_range = usage.src, + dst_range = usage.dst, + suppressed_by = { + from = metadata.old_name, + to = metadata.new_name, + from_line = metadata.from_line, + to_line = metadata.to_line, + }, + }) + end + elseif t == "update" then + local metadata = action.metadata or {} + table.insert(summary.updates, { + node_type = metadata.node_type, + from_line = metadata.from_line, + to_line = metadata.to_line, + src_range = action.src, + dst_range = action.dst, + }) + elseif t == "insert" then + local metadata = action.metadata or {} + table.insert(summary.inserts, { + node_type = metadata.node_type, + line = metadata.to_line, + dst_range = action.dst, + }) + elseif t == "delete" then + local metadata = action.metadata or {} + table.insert(summary.deletes, { + node_type = metadata.node_type, + line = metadata.from_line, + src_range = action.src, + }) + end + end + + return summary +end -- Generate edit actions from node mappings --- Actions describe what changed: insert, delete, update, move +-- Actions describe what changed: insert, delete, update, move, rename function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, opts) + opts = opts or {} local actions = {} local timings = nil local hrtime = nil - if opts and opts.timings then + local src_has_parse_error = src_root and src_root.has_error and src_root:has_error() or false + local dst_has_parse_error = dst_root and dst_root.has_error and dst_root:has_error() or false + local has_parse_error = src_has_parse_error or dst_has_parse_error + if opts.timings then timings = {} if vim and vim.loop and vim.loop.hrtime then hrtime = vim.loop.hrtime end end + local src_role_index = opts.src_role_index or nil + local dst_role_index = opts.dst_role_index or nil + local src_buf = opts.src_buf or nil + local dst_buf = opts.dst_buf or nil + local function start_timer() if not hrtime then return nil @@ -27,6 +506,286 @@ function 
M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op timings[key] = (hrtime() - started_at) / 1e6 end + local function enrich_update_actions_with_semantics(actions_list) + if not src_buf or not dst_buf then + return + end + + for _, action in ipairs(actions_list) do + local src_node = action.src_node + local dst_node = action.dst_node + if action.type == "update" and src_node and dst_node then + local leaf_changes = semantic.find_leaf_changes(src_node, dst_node, src_buf, dst_buf, src_role_index, dst_role_index) + local rename_pairs = {} + for _, change in ipairs(leaf_changes) do + if + semantic.is_rename_identifier(change.src_node, src_role_index) + or semantic.is_rename_identifier(change.dst_node, dst_role_index) + then + rename_pairs[change.src_text] = change.dst_text + end + end + + action.analysis = { + leaf_changes = leaf_changes, + rename_pairs = rename_pairs, + } + end + end + end + + local function emit_rename_actions(actions_list) + local renames = {} + local seen = {} + local is_buf_available = src_buf and dst_buf + + local function pair_key(from_text, to_text) + return tostring(from_text) .. "\x1f" .. 
tostring(to_text) + end + + local function push_rename(src_node, dst_node, from_text, to_text, context) + if not src_node or not dst_node or not from_text or not to_text then + return + end + if from_text == to_text then + return + end + local key = table.concat({ + tostring(src_node:id()), + tostring(dst_node:id()), + from_text, + to_text, + }, ":") + if seen[key] then + return + end + seen[key] = true + table.insert(renames, build_action("rename", src_node, dst_node, { + from = from_text, + to = to_text, + context = context, + })) + end + + local function is_decl_rename(src_node, dst_node) + local function is_class_name_node(node, role_index) + return roles.has_capture(node, role_index, "diff.class.name") + end + + local function is_class_like_context(node, role_index) + local cur = node + while cur do + if roles.has_kind(cur, role_index, "class") then + return true + end + cur = cur:parent() + end + return false + end + + local src_is_class_name = is_class_name_node(src_node, src_role_index) + local dst_is_class_name = is_class_name_node(dst_node, dst_role_index) + if + (not src_is_class_name and is_class_like_context(src_node, src_role_index)) + or (not dst_is_class_name and is_class_like_context(dst_node, dst_role_index)) + then + return false + end + + return semantic.is_rename_identifier(src_node, src_role_index) + and semantic.is_rename_identifier(dst_node, dst_role_index) + end + + local seed_pairs = {} + local function add_seed(from_text, to_text) + if from_text and to_text and from_text ~= to_text then + seed_pairs[pair_key(from_text, to_text)] = true + end + end + + local function collect_param_identifiers(node, bufnr) + if not node or not bufnr then + return {} + end + + local parameter_kinds = { + parameters = true, + parameter_list = true, + formal_parameters = true, + } + + local function find_parameter_node(n) + if parameter_kinds[n:type()] then + return n + end + for child in n:iter_children() do + local found = find_parameter_node(child) + if 
found then + return found + end + end + return nil + end + + local params_root = find_parameter_node(node) + if not params_root then + return {} + end + + local out = {} + local function first_identifier(n) + if n:child_count() == 0 then + local t = n:type() + if t == "identifier" or t == "field_identifier" or t == "property_identifier" then + local text = vim.treesitter.get_node_text(n, bufnr) + if text and text:match("^[%a_][%w_]*$") then + return { node = n, text = text } + end + end + return nil + end + for child in n:iter_children() do + local found = first_identifier(child) + if found then + return found + end + end + return nil + end + + for child in params_root:iter_children() do + if child:named() then + local found = first_identifier(child) + if found then + table.insert(out, found) + end + end + end + return out + end + + -- Pass 1a: high-confidence declaration-like rename seeds from semantic leaf changes. + for _, action in ipairs(actions_list) do + if action.type == "update" and action.analysis and action.analysis.leaf_changes then + for _, change in ipairs(action.analysis.leaf_changes) do + local src_node = change.src_node + local dst_node = change.dst_node + if src_node and dst_node and change.src_text ~= change.dst_text and is_decl_rename(src_node, dst_node) then + add_seed(change.src_text, change.dst_text) + end + end + end + end + + -- Pass 1a.2: positional parameter rename seeds/actions for updated functions. 
+ if is_buf_available then + for _, action in ipairs(actions_list) do + if action.type == "update" and action.src_node and action.dst_node then + local src_params = collect_param_identifiers(action.src_node, src_buf) + local dst_params = collect_param_identifiers(action.dst_node, dst_buf) + if #src_params > 0 and #src_params == #dst_params and #src_params <= 16 then + for i = 1, #src_params do + local s = src_params[i] + local d = dst_params[i] + if s.text ~= d.text then + add_seed(s.text, d.text) + if is_decl_rename(s.node, d.node) then + push_rename(s.node, d.node, s.text, d.text, { + src_parent_type = action.src_node and action.src_node:type() or nil, + dst_parent_type = action.dst_node and action.dst_node:type() or nil, + source = "parameter_positional", + declaration = true, + }) + end + end + end + end + end + end + end + + -- Pass 2: emit leaf rename actions gated by seeds (and declaration renames). + for _, action in ipairs(actions_list) do + if action.type == "update" and action.analysis and action.analysis.leaf_changes then + for _, change in ipairs(action.analysis.leaf_changes) do + local src_node = change.src_node + local dst_node = change.dst_node + local src_text = change.src_text + local dst_text = change.dst_text + if src_node and dst_node and src_text and dst_text and src_text ~= dst_text then + local is_decl = is_decl_rename(src_node, dst_node) + local key = pair_key(src_text, dst_text) + if seed_pairs[key] or is_decl then + push_rename(src_node, dst_node, src_text, dst_text, { + src_parent_type = action.src_node and action.src_node:type() or nil, + dst_parent_type = action.dst_node and action.dst_node:type() or nil, + declaration = is_decl, + }) + end + end + end + end + end + + -- If a declaration rename exists for a pair, suppress usage-level duplicates for that pair. 
+ local declaration_pairs = {} + for _, rename_action in ipairs(renames) do + local metadata = rename_action.metadata or {} + local from_name = metadata.old_name + local to_name = metadata.new_name + if rename_action.context and rename_action.context.declaration then + declaration_pairs[pair_key(from_name, to_name)] = true + end + end + + local suppressed_by_pair = {} + local filtered_renames = {} + for _, rename_action in ipairs(renames) do + local metadata = rename_action.metadata or {} + local from_name = metadata.old_name + local to_name = metadata.new_name + local key = pair_key(from_name, to_name) + local is_declaration = rename_action.context and rename_action.context.declaration + if not declaration_pairs[key] or is_declaration then + table.insert(filtered_renames, rename_action) + else + suppressed_by_pair[key] = suppressed_by_pair[key] or {} + table.insert(suppressed_by_pair[key], { + src = rename_action.src, + dst = rename_action.dst, + metadata = { + old_name = from_name, + new_name = to_name, + from_line = metadata.from_line, + to_line = metadata.to_line, + }, + context = rename_action.context, + }) + end + end + + for _, rename_action in ipairs(filtered_renames) do + if rename_action.context and rename_action.context.declaration then + local metadata = rename_action.metadata or {} + local from_name = metadata.old_name + local to_name = metadata.new_name + local key = pair_key(from_name, to_name) + local suppressed_usages = suppressed_by_pair[key] + if suppressed_usages and #suppressed_usages > 0 then + rename_action.context.suppressed_usages = suppressed_usages + rename_action.context.suppressed_usage_count = #suppressed_usages + if rename_action.metadata then + rename_action.metadata.suppressed_renames = suppressed_usages + end + end + end + end + + for _, rename_action in ipairs(filtered_renames) do + table.insert(actions_list, rename_action) + end + end + -- Build O(1) lookup tables local precompute_start = start_timer() local src_to_dst = {} @@ 
-36,6 +795,17 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op dst_to_src[m.dst] = m.src end + local roles_start = start_timer() + if src_buf and dst_buf then + if not src_role_index then + src_role_index = roles.build_index(src_root, src_buf) + end + if not dst_role_index then + dst_role_index = roles.build_index(dst_root, dst_buf) + end + end + stop_timer(roles_start, "roles") + local significant_types = { function_declaration = true, variable_declaration = true, @@ -52,7 +822,6 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op assignment_statement = true, for_statement = true, while_statement = true, - function_call = true, -- Python class_definition = true, import_statement = true, @@ -80,6 +849,77 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op struct_specifier = true, } + local function has_kind(node, index, kind) + return index and roles.has_structural_kind(node, index, kind) or false + end + + local function has_role_captures(index) + if not index or not index.by_capture then + return false + end + for capture, _ in pairs(index.by_capture) do + if capture ~= "diff.fallback.node" then + return true + end + end + return false + end + + local src_uses_roles = has_role_captures(src_role_index) + local dst_uses_roles = has_role_captures(dst_role_index) + + local function is_significant(info, index) + local uses_roles = (index == src_role_index and src_uses_roles) or (index == dst_role_index and dst_uses_roles) + local node = info.node + if has_kind(node, index, "function") + or has_kind(node, index, "class") + or has_kind(node, index, "variable") + or has_kind(node, index, "assignment") + or has_kind(node, index, "import") + or has_kind(node, index, "return") + or has_kind(node, index, "preproc") + then + return true + end + if uses_roles then + return false + end + return significant_types[info.type] or false + end + + local function 
is_transparent_update_ancestor(info, index) + local uses_roles = (index == src_role_index and src_uses_roles) or (index == dst_role_index and dst_uses_roles) + if has_kind(info.node, index, "class") then + return true + end + if uses_roles then + return false + end + return transparent_update_ancestors[info.type] or false + end + + local function is_movable(info, index) + local uses_roles = (index == src_role_index and src_uses_roles) or (index == dst_role_index and dst_uses_roles) + if has_kind(info.node, index, "function") or has_kind(info.node, index, "class") then + return true + end + if uses_roles then + return false + end + return movable_types[info.type] or false + end + + local function should_emit_update(info, index) + if not is_significant(info, index) then + return false + end + -- Type/struct/class bodies are better represented by insert/delete or move+rename. + if has_kind(info.node, index, "class") then + return false + end + return true + end + -- Helper: check if node or any descendant has different content local function has_content_change(src_node, dst_node) local src_info_data = src_info[src_node:id()] @@ -100,65 +940,68 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op end end - -- Precompute ancestry flags for source nodes (unmapped significant ancestors) - local src_has_unmapped_sig_ancestor = {} - for id, info in pairs(src_info) do - local current = info.parent - while current do - local p_id = current:id() - local p_info = src_info[p_id] - if p_info then - if not src_to_dst[p_id] and significant_types[p_info.type] then - src_has_unmapped_sig_ancestor[id] = true - break - end - current = p_info.parent - else - break - end + local function parent_id_of(info) + if not info then + return nil + end + if info.parent_id ~= nil then + return info.parent_id end + return info.parent and info.parent:id() or nil end - -- Precompute ancestry flags for destination nodes (unmapped significant ancestors) - local 
dst_has_unmapped_sig_ancestor = {} - for id, info in pairs(dst_info) do - local current = info.parent - while current do - local p_id = current:id() - local p_info = dst_info[p_id] - if p_info then - if not dst_to_src[p_id] and significant_types[p_info.type] then - dst_has_unmapped_sig_ancestor[id] = true - break - end - current = p_info.parent - else - break + local function memoized_ancestor_flags(info_map, predicate) + local memo = {} + local function has_match(id) + local cached = memo[id] + if cached ~= nil then + return cached end + local info = info_map[id] + if not info then + memo[id] = false + return false + end + local parent_id = parent_id_of(info) + if not parent_id then + memo[id] = false + return false + end + local parent_info = info_map[parent_id] + if not parent_info then + memo[id] = false + return false + end + if predicate(parent_id, parent_info) then + memo[id] = true + return true + end + local result = has_match(parent_id) + memo[id] = result + return result end + for id in pairs(info_map) do + has_match(id) + end + return memo end + -- Precompute ancestry flags for source nodes (unmapped significant ancestors) + local src_has_unmapped_sig_ancestor = memoized_ancestor_flags(src_info, function(parent_id, parent_info) + return not src_to_dst[parent_id] and is_significant(parent_info, src_role_index) + end) + + -- Precompute ancestry flags for destination nodes (unmapped significant ancestors) + local dst_has_unmapped_sig_ancestor = memoized_ancestor_flags(dst_info, function(parent_id, parent_info) + return not dst_to_src[parent_id] and is_significant(parent_info, dst_role_index) + end) + -- Precompute ancestry flags for updated significant ancestors - local src_has_updated_sig_ancestor = {} - for id, info in pairs(src_info) do - local current = info.parent - while current do - local p_id = current:id() - local p_info = src_info[p_id] - if p_info then - if nodes_with_changes[p_id] - and significant_types[p_info.type] - and not 
transparent_update_ancestors[p_info.type] - then - src_has_updated_sig_ancestor[id] = true - break - end - current = p_info.parent - else - break - end - end - end + local src_has_updated_sig_ancestor = memoized_ancestor_flags(src_info, function(parent_id, parent_info) + return nodes_with_changes[parent_id] + and is_significant(parent_info, src_role_index) + and not is_transparent_update_ancestor(parent_info, src_role_index) + end) -- UPDATES: mapped nodes with different content, but only significant types without updated ancestors stop_timer(precompute_start, "precompute") @@ -166,9 +1009,9 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op for _, m in ipairs(mappings) do local s, d = src_info[m.src], dst_info[m.dst] - if nodes_with_changes[m.src] and significant_types[s.type] then + if nodes_with_changes[m.src] and should_emit_update(s, src_role_index) then if not src_has_updated_sig_ancestor[m.src] then - table.insert(actions, { type = "update", node = s.node, target = d.node }) + table.insert(actions, build_action("update", s.node, d.node)) end end end @@ -177,16 +1020,24 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op -- MOVES: use LCS to find which top-level mapped functions changed relative order -- Only functions not in the LCS are considered moved. 
local moves_start = start_timer() - local movable_pairs = {} + local movable_pairs = {} + local src_root_id = src_root:id() + local dst_root_id = dst_root:id() for _, m in ipairs(mappings) do local s = src_info[m.src] - if s and movable_types[s.type] then - local src_parent_is_root = s.parent and s.parent:id() == src_root:id() + if s and is_movable(s, src_role_index) then + local src_parent_is_root = (s.parent_id or (s.parent and s.parent:id()) or nil) == src_root_id local d = dst_info[m.dst] - local dst_parent_is_root = d and d.parent and d.parent:id() == dst_root:id() + local dst_parent_is_root = d and ((d.parent_id or (d.parent and d.parent:id()) or nil) == dst_root_id) if src_parent_is_root and dst_parent_is_root then - local src_line = s.node:range() - local dst_line = d.node:range() + local src_line = s.start_row + if src_line == nil then + src_line = s.node:range() + end + local dst_line = d.start_row + if dst_line == nil then + dst_line = d.node:range() + end table.insert(movable_pairs, { src_id = m.src, dst_id = m.dst, @@ -248,14 +1099,29 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op end -- Mark nodes NOT in LIS as moved (only if line difference is significant) + local function has_structural_move_evidence(index) + local dst_line = movable_pairs[index].dst_line + for j, other in ipairs(movable_pairs) do + if j ~= index then + local dst_inversion = (j < index and other.dst_line > dst_line) or (j > index and other.dst_line < dst_line) + if dst_inversion then + return true + end + end + end + return false + end + + -- Mark nodes NOT in LIS as moved only when there is real order inversion evidence. + -- If parse errors exist in either tree, keep move emission conservative. 
for i, pair in ipairs(movable_pairs) do if not in_lis[i] then local line_diff = math.abs(pair.dst_line - pair.src_line) - if line_diff > 3 then + if not has_parse_error and line_diff > 3 and has_structural_move_evidence(i) then local s = src_info[pair.src_id] local d = dst_info[pair.dst_id] if s and d then - table.insert(actions, { type = "move", node = s.node, target = d.node }) + table.insert(actions, build_action("move", s.node, d.node)) end end end @@ -265,9 +1131,9 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op -- DELETES: unmapped source nodes local deletes_start = start_timer() for id, info in pairs(src_info) do - if not src_to_dst[id] and significant_types[info.type] then + if not src_to_dst[id] and is_significant(info, src_role_index) then if not src_has_unmapped_sig_ancestor[id] then - table.insert(actions, { type = "delete", node = info.node }) + table.insert(actions, build_action("delete", info.node, nil)) end end end @@ -276,15 +1142,82 @@ function M.generate_actions(src_root, dst_root, mappings, src_info, dst_info, op -- INSERTS: unmapped destination nodes local inserts_start = start_timer() for id, info in pairs(dst_info) do - if not dst_to_src[id] and significant_types[info.type] then + if not dst_to_src[id] and is_significant(info, dst_role_index) then if not dst_has_unmapped_sig_ancestor[id] then - table.insert(actions, { type = "insert", node = info.node }) + table.insert(actions, build_action("insert", nil, info.node)) end end end stop_timer(inserts_start, "inserts") - return actions, timings + local semantic_start = start_timer() + enrich_update_actions_with_semantics(actions) + stop_timer(semantic_start, "semantic") + + local rename_start = start_timer() + emit_rename_actions(actions) + stop_timer(rename_start, "renames") + + local analysis_start = start_timer() + if src_buf and dst_buf then + analysis.enrich(actions, { + src_buf = src_buf, + dst_buf = dst_buf, + }) + end + stop_timer(analysis_start, 
"analysis") + + local update_suppress_start = start_timer() + local moved_pairs = {} + local declaration_rename_pairs = {} + for _, action in ipairs(actions) do + local key = action_pair_key(action) + if key then + if action.type == "move" then + moved_pairs[key] = true + elseif action.type == "rename" and action.context and action.context.declaration then + declaration_rename_pairs[key] = true + end + end + end + + local filtered_actions = {} + for _, action in ipairs(actions) do + if action.type == "update" then + local key = action_pair_key(action) + local action_analysis = action.analysis or {} + local has_effective_hunks = has_effective_update_hunks(action, src_buf, dst_buf) + local has_any_hunks = action_analysis.hunks and #action_analysis.hunks > 0 + local has_rename_pairs = action_analysis.rename_pairs and next(action_analysis.rename_pairs) ~= nil + local overlaps_decl_rename = key and declaration_rename_pairs[key] or false + local overlaps_move = key and moved_pairs[key] or false + if action_analysis.rename_only and has_rename_pairs and not has_effective_hunks then + goto continue + end + if overlaps_decl_rename and not has_effective_hunks then + goto continue + end + if overlaps_move and not has_effective_hunks then + goto continue + end + if not has_effective_hunks and not has_any_hunks then + goto continue + end + end + table.insert(filtered_actions, action) + ::continue:: + end + actions = filtered_actions + stop_timer(update_suppress_start, "update_suppress") + + mark_temp_result_refactor_overrides(actions, src_buf, dst_buf) + mark_literal_member_render_overrides(actions) + + local summary_start = start_timer() + local summary = build_summary(actions) + stop_timer(summary_start, "summary") + + return actions, timings, summary end return M diff --git a/lua/diffmantic/core/analysis.lua b/lua/diffmantic/core/analysis.lua new file mode 100644 index 0000000..10bfb6e --- /dev/null +++ b/lua/diffmantic/core/analysis.lua @@ -0,0 +1,747 @@ +local semantic = 
require("diffmantic.core.semantic") + +local M = {} + +local function range_metadata(node) + if not node then + return nil + end + local sr, sc, er, ec = node:range() + return { + start_row = sr, + start_col = sc, + end_row = er, + end_col = ec, + start_line = sr + 1, + end_line = er + 1, + } +end + +local function clone_range(r) + if not r then + return nil + end + return { + start_row = r.start_row, + start_col = r.start_col, + end_row = r.end_row, + end_col = r.end_col, + start_line = r.start_line, + end_line = r.end_line, + } +end + +local function line_text(buf, row) + local lines = vim.api.nvim_buf_get_lines(buf, row, row + 1, false) + return lines[1] or "" +end + +local function base_col_for_row(row, start_row, start_col) + if row == start_row then + return start_col + end + return 0 +end + +local function tokenize_line(text) + local tokens = {} + local i = 1 + local len = #text + while i <= len do + local ch = text:sub(i, i) + if ch:match("%s") then + i = i + 1 + elseif ch:match("[%w_]") then + local j = i + 1 + while j <= len and text:sub(j, j):match("[%w_]") do + j = j + 1 + end + table.insert(tokens, { text = text:sub(i, j - 1), start_col = i, end_col = j - 1 }) + i = j + else + -- Keep punctuation/granular symbols as single-char tokens so we can + -- align shared delimiters (e.g. closing quote) and avoid off-by-one spans. 
+ table.insert(tokens, { text = ch, start_col = i, end_col = i }) + i = i + 1 + end + end + return tokens + end + +local function tokens_equal(a, b, rename_map) + if a.text == b.text then + return true + end + if rename_map and rename_map[a.text] == b.text then + return true + end + return false +end + +local function lcs_matches(a, b, rename_map) + local n = #a + local m = #b + if n == 0 or m == 0 then + return {}, {} + end + + local dp = {} + for i = 1, n + 1 do + dp[i] = {} + for j = 1, m + 1 do + dp[i][j] = 0 + end + end + + for i = n, 1, -1 do + for j = m, 1, -1 do + if tokens_equal(a[i], b[j], rename_map) then + dp[i][j] = dp[i + 1][j + 1] + 1 + else + local skip_src = dp[i + 1][j] + local skip_dst = dp[i][j + 1] + dp[i][j] = (skip_src >= skip_dst) and skip_src or skip_dst + end + end + end + + local match_a = {} + local match_b = {} + local i = 1 + local j = 1 + while i <= n and j <= m do + if tokens_equal(a[i], b[j], rename_map) and dp[i][j] == (dp[i + 1][j + 1] + 1) then + match_a[i] = true + match_b[j] = true + i = i + 1 + j = j + 1 + else + local skip_src = dp[i + 1][j] + local skip_dst = dp[i][j + 1] + if skip_dst > skip_src then + j = j + 1 + elseif skip_src > skip_dst then + i = i + 1 + else + -- Tie: keep destination earlier by advancing source. 
+ i = i + 1 + end + end + end + return match_a, match_b +end + +local function unmatched_token_spans(tokens, matched, source_line) + local spans = {} + local i = 1 + while i <= #tokens do + if matched[i] then + i = i + 1 + else + local start_col = tokens[i].start_col + local end_col = tokens[i].end_col + local j = i + 1 + while j <= #tokens and not matched[j] do + local gap_start = end_col + 1 + local gap_end = tokens[j].start_col - 1 + local gap = "" + if gap_start <= gap_end then + gap = source_line:sub(gap_start, gap_end) + end + if gap ~= "" and not gap:match("^%s+$") then + break + end + end_col = tokens[j].end_col + j = j + 1 + end + table.insert(spans, { start_col = start_col, end_col = end_col }) + i = j + end + end + return spans +end + +local function make_range(row, start_col, end_col) + if row == nil or start_col == nil or end_col == nil or end_col <= start_col then + return nil + end + return { + start_row = row, + start_col = start_col, + end_row = row, + end_col = end_col, + start_line = row + 1, + end_line = row + 1, + } +end + +local function trimmed_range_for_line(buf, row, base_col, line_value) + if not line_value then + return nil + end + local first = line_value:find("%S") + if not first then + return nil + end + local rev_last = line_value:reverse():find("%S") + local last = #line_value - rev_last + 1 + return make_range(row, base_col + first - 1, base_col + last) +end + +local function hunk_change(src_range, dst_range) + return { kind = "change", src = src_range, dst = dst_range } +end + +local function hunk_insert(dst_range) + return { kind = "insert", src = nil, dst = dst_range } +end + +local function hunk_delete(src_range) + return { kind = "delete", src = src_range, dst = nil } +end + +local function range_contains(outer, inner) + if not outer or not inner then + return false + end + if outer.start_row == nil or outer.end_row == nil or outer.start_col == nil or outer.end_col == nil then + return false + end + if inner.start_row == nil or 
inner.end_row == nil or inner.start_col == nil or inner.end_col == nil then + return false + end + if inner.start_row < outer.start_row or inner.end_row > outer.end_row then + return false + end + if inner.start_row == outer.start_row and inner.start_col < outer.start_col then + return false + end + if inner.end_row == outer.end_row and inner.end_col > outer.end_col then + return false + end + return true +end + +local function clone_range_like(r) + return clone_range(r) +end + +local function collect_suppressed_rename_pairs(actions) + local usage_pairs = {} -- src/dst ranges of suppressed rename usages; named so the builtin pairs() is not shadowed + local declaration_pairs = {} + for _, action in ipairs(actions) do + if action.type == "rename" then + local metadata = action.metadata or {} + local old_name = metadata.old_name + local new_name = metadata.new_name + if action.context and action.context.declaration and old_name and new_name and old_name ~= new_name then + declaration_pairs[old_name] = new_name + end + + local suppressed = action.metadata and action.metadata.suppressed_renames or nil + if suppressed then + for _, usage in ipairs(suppressed) do + local src = clone_range_like(usage.src) + local dst = clone_range_like(usage.dst) + if src and dst then + table.insert(usage_pairs, { src = src, dst = dst }) + end + end + end + end + end + return usage_pairs, declaration_pairs +end + +local function text_for_range(buf, range) + if not buf or not range then + return nil + end + if range.start_row == nil or range.end_row == nil or range.start_col == nil or range.end_col == nil then + return nil + end + if range.start_row ~= range.end_row then + return nil + end + local line = line_text(buf, range.start_row) + if not line or line == "" then + return nil + end + local start_col = range.start_col + 1 + local end_col = range.end_col + if end_col < start_col then + return nil + end + return line:sub(start_col, end_col) +end + +local function is_suppressed_change_hunk(hunk_src, hunk_dst, suppressed_pairs, declaration_pairs, src_buf, dst_buf) + if not hunk_src or not hunk_dst
then + return false + end + for _, pair in ipairs(suppressed_pairs) do + if range_contains(hunk_src, pair.src) and range_contains(hunk_dst, pair.dst) then + return true + end + end + if declaration_pairs and next(declaration_pairs) then + local src_text = text_for_range(src_buf, hunk_src) + local dst_text = text_for_range(dst_buf, hunk_dst) + if src_text and dst_text and declaration_pairs[src_text] == dst_text then + return true + end + end + return false +end + +local function whitespace_between(buf, row, left_end_col, right_start_col) + if not buf or row == nil or left_end_col == nil or right_start_col == nil then + return false + end + if right_start_col < left_end_col then + return false + end + if right_start_col == left_end_col then + return true + end + local line = line_text(buf, row) + if not line then + return false + end + local gap = line:sub(left_end_col + 1, right_start_col) + return gap:match("^%s*$") ~= nil +end + +local function can_merge_ranges(prev, cur, buf) + if not prev or not cur then + return false + end + if prev.start_row ~= prev.end_row or cur.start_row ~= cur.end_row then + return false + end + if prev.start_row ~= cur.start_row then + return false + end + if cur.start_col < prev.end_col then + return false + end + return whitespace_between(buf, prev.start_row, prev.end_col, cur.start_col) +end + +local function clone_hunk(h) + return { + kind = h.kind, + src = clone_range(h.src), + dst = clone_range(h.dst), + } +end + +local function hunk_start(h) + local r = h.src or h.dst + if not r then + return math.huge, math.huge + end + return r.start_row or math.huge, r.start_col or math.huge +end + +local function extend_range(base, incoming) + if not base or not incoming then + return + end + if incoming.end_row > base.end_row or (incoming.end_row == base.end_row and incoming.end_col > base.end_col) then + base.end_row = incoming.end_row + base.end_col = incoming.end_col + base.end_line = incoming.end_line + end +end + +local function 
can_merge_hunks(prev, cur, src_buf, dst_buf) + if not prev or not cur or prev.kind ~= cur.kind then + return false + end + if prev.kind == "change" then + return can_merge_ranges(prev.src, cur.src, src_buf) and can_merge_ranges(prev.dst, cur.dst, dst_buf) + end + if prev.kind == "insert" then + return can_merge_ranges(prev.dst, cur.dst, dst_buf) + end + if prev.kind == "delete" then + return can_merge_ranges(prev.src, cur.src, src_buf) + end + return false +end + +local function merge_adjacent_hunks(hunks, src_buf, dst_buf) + if not hunks or #hunks <= 1 then + return hunks or {} + end + local ordered = {} + for _, h in ipairs(hunks) do + table.insert(ordered, clone_hunk(h)) + end + table.sort(ordered, function(a, b) + local ar, ac = hunk_start(a) + local br, bc = hunk_start(b) + if ar == br then + return ac < bc + end + return ar < br + end) + local merged = {} + for _, h in ipairs(ordered) do + local prev = merged[#merged] + if prev and can_merge_hunks(prev, h, src_buf, dst_buf) then + extend_range(prev.src, h.src) + extend_range(prev.dst, h.dst) + else + table.insert(merged, h) + end + end + return merged +end + +local function refine_single_line_change_hunks(hunks, src_buf, dst_buf) + if not hunks or #hunks == 0 then + return hunks or {} + end + for _, hunk in ipairs(hunks) do + if + hunk.kind == "change" + and hunk.src + and hunk.dst + and hunk.src.start_row == hunk.src.end_row + and hunk.dst.start_row == hunk.dst.end_row + then + local src_text = text_for_range(src_buf, hunk.src) + local dst_text = text_for_range(dst_buf, hunk.dst) + if src_text and dst_text and src_text ~= dst_text then + local fragment = semantic.diff_fragment(src_text, dst_text) + if fragment then + local refined_src = make_range( + hunk.src.start_row, + hunk.src.start_col + math.max(fragment.old_start - 1, 0), + hunk.src.start_col + math.max(fragment.old_end, 0) + ) + local refined_dst = make_range( + hunk.dst.start_row, + hunk.dst.start_col + math.max(fragment.new_start - 1, 0), + 
hunk.dst.start_col + math.max(fragment.new_end, 0) + ) + if refined_src and refined_dst then + hunk.src = refined_src + hunk.dst = refined_dst + end + end + end + end + end + return hunks +end + +local function collapse_identical_insert_delete_hunks(hunks, src_buf, dst_buf) + if not hunks or #hunks == 0 then + return hunks or {} + end + + local deletes_by_text = {} + local inserts_by_text = {} + for idx, hunk in ipairs(hunks) do + if hunk.kind == "delete" and hunk.src and hunk.src.start_row == hunk.src.end_row then + local text = text_for_range(src_buf, hunk.src) + if text and text ~= "" then + deletes_by_text[text] = deletes_by_text[text] or {} + table.insert(deletes_by_text[text], idx) + end + elseif hunk.kind == "insert" and hunk.dst and hunk.dst.start_row == hunk.dst.end_row then + local text = text_for_range(dst_buf, hunk.dst) + if text and text ~= "" then + inserts_by_text[text] = inserts_by_text[text] or {} + table.insert(inserts_by_text[text], idx) + end + end + end + + local drop = {} + for text, delete_idxs in pairs(deletes_by_text) do + local insert_idxs = inserts_by_text[text] + if insert_idxs and #insert_idxs > 0 then + local pair_count = math.min(#delete_idxs, #insert_idxs) + for i = 1, pair_count do + drop[delete_idxs[i]] = true + drop[insert_idxs[i]] = true + end + end + end + + if next(drop) == nil then + return hunks + end + + local filtered = {} + for idx, hunk in ipairs(hunks) do + if not drop[idx] then + table.insert(filtered, hunk) + end + end + return filtered +end + +local function fallback_hunks_from_diff(src_node, dst_node, src_buf, dst_buf, rename_pairs) + local src_text = vim.treesitter.get_node_text(src_node, src_buf) + local dst_text = vim.treesitter.get_node_text(dst_node, dst_buf) + if not src_text or not dst_text then + return {} + end + + local src_lines = vim.split(src_text, "\n", { plain = true }) + local dst_lines = vim.split(dst_text, "\n", { plain = true }) + local ok, hunks = pcall((vim.text and vim.text.diff) or vim.diff, src_text, dst_text, { -- prefer vim.text.diff; fall back to vim.diff on Neovim releases that only ship the older name +
result_type = "indices", + linematch = 60, + }) + if not ok or not hunks then + return {} + end + + local rename_map = rename_pairs or {} + local sr, sc = src_node:range() + local tr, tc = dst_node:range() + local out = {} + + local function push_change_line(src_row, dst_row, src_line, dst_line) + local src_base = base_col_for_row(src_row, sr, sc) + local dst_base = base_col_for_row(dst_row, tr, tc) + if not src_line or not dst_line or src_line == dst_line then + return + end + + local tokens_src = tokenize_line(src_line) + local tokens_dst = tokenize_line(dst_line) + if #tokens_src > 0 or #tokens_dst > 0 then + local match_src, match_dst = lcs_matches(tokens_src, tokens_dst, rename_map) + local src_spans = unmatched_token_spans(tokens_src, match_src, src_line) + local dst_spans = unmatched_token_spans(tokens_dst, match_dst, dst_line) + if #src_spans > 0 or #dst_spans > 0 then + local emitted = false + local span_count = math.min(#src_spans, #dst_spans) + + for i = 1, span_count do + local src_range = make_range( + src_row, + src_base + src_spans[i].start_col - 1, + src_base + src_spans[i].end_col + ) + local dst_range = make_range( + dst_row, + dst_base + dst_spans[i].start_col - 1, + dst_base + dst_spans[i].end_col + ) + if src_range and dst_range then + table.insert(out, hunk_change(src_range, dst_range)) + emitted = true + end + end + + for i = span_count + 1, #src_spans do + local src_range = make_range( + src_row, + src_base + src_spans[i].start_col - 1, + src_base + src_spans[i].end_col + ) + if src_range then + table.insert(out, hunk_change(src_range, nil)) + emitted = true + end + end + + for i = span_count + 1, #dst_spans do + local dst_range = make_range( + dst_row, + dst_base + dst_spans[i].start_col - 1, + dst_base + dst_spans[i].end_col + ) + if dst_range then + table.insert(out, hunk_change(nil, dst_range)) + emitted = true + end + end + + if emitted then + return + end + end + end + + local fragment = semantic.diff_fragment(src_line, dst_line) + if 
fragment then + local src_range = make_range( + src_row, + src_base + math.max(fragment.old_start - 1, 0), + src_base + math.max(fragment.old_end, 0) + ) + local dst_range = make_range( + dst_row, + dst_base + math.max(fragment.new_start - 1, 0), + dst_base + math.max(fragment.new_end, 0) + ) + if src_range and dst_range then + table.insert(out, hunk_change(src_range, dst_range)) + end + return + end + + local src_range = trimmed_range_for_line(src_buf, src_row, src_base, src_line) + local dst_range = trimmed_range_for_line(dst_buf, dst_row, dst_base, dst_line) + if src_range and dst_range then + table.insert(out, hunk_change(src_range, dst_range)) + end + end + + for _, h in ipairs(hunks) do + local start_a, count_a, start_b, count_b = h[1], h[2], h[3], h[4] + local overlap = math.min(count_a, count_b) + + for i = 0, overlap - 1 do + local src_row = sr + start_a - 1 + i + local dst_row = tr + start_b - 1 + i + push_change_line(src_row, dst_row, src_lines[start_a + i], dst_lines[start_b + i]) + end + + for i = overlap, count_a - 1 do + local src_row = sr + start_a - 1 + i + local src_line = src_lines[start_a + i] + local src_base = base_col_for_row(src_row, sr, sc) + local src_range = trimmed_range_for_line(src_buf, src_row, src_base, src_line) + if src_range then + table.insert(out, hunk_delete(src_range)) + end + end + + for i = overlap, count_b - 1 do + local dst_row = tr + start_b - 1 + i + local dst_line = dst_lines[start_b + i] + local dst_base = base_col_for_row(dst_row, tr, tc) + local dst_range = trimmed_range_for_line(dst_buf, dst_row, dst_base, dst_line) + if dst_range then + table.insert(out, hunk_insert(dst_range)) + end + end + end + + return out +end + +function M.enrich(actions, opts) + local src_buf = opts and opts.src_buf or nil + local dst_buf = opts and opts.dst_buf or nil + if not src_buf or not dst_buf then + return + end + local suppressed_pairs, declaration_pairs = collect_suppressed_rename_pairs(actions) + + for _, action in ipairs(actions) 
do + if action.type == "update" and action.src_node and action.dst_node then + local raw_leaf_changes = action.analysis and action.analysis.leaf_changes or {} + local rename_pairs = action.analysis and action.analysis.rename_pairs or {} + local normalized_leaf = {} + local hunks = {} + local has_non_rename_change = false + local saw_precise_candidate = false + + for _, change in ipairs(raw_leaf_changes) do + local src_range = range_metadata(change.src_node) + local dst_range = range_metadata(change.dst_node) + table.insert(normalized_leaf, { + src = clone_range(src_range), + dst = clone_range(dst_range), + src_text = change.src_text, + dst_text = change.dst_text, + }) + + if change.src_text ~= change.dst_text and rename_pairs[change.src_text] ~= change.dst_text then + has_non_rename_change = true + if src_range and dst_range then + saw_precise_candidate = true + end + local hunk_src = src_range + local hunk_dst = dst_range + + if + src_range + and dst_range + and src_range.start_row == src_range.end_row + and dst_range.start_row == dst_range.end_row + then + local fragment = semantic.diff_fragment(change.src_text or "", change.dst_text or "") + if fragment then + hunk_src = make_range( + src_range.start_row, + src_range.start_col + math.max(fragment.old_start - 1, 0), + src_range.start_col + math.max(fragment.old_end, 0) + ) + hunk_dst = make_range( + dst_range.start_row, + dst_range.start_col + math.max(fragment.new_start - 1, 0), + dst_range.start_col + math.max(fragment.new_end, 0) + ) + end + end + + if + hunk_src + and hunk_dst + and not is_suppressed_change_hunk( + hunk_src, + hunk_dst, + suppressed_pairs, + declaration_pairs, + src_buf, + dst_buf + ) + then + table.insert(hunks, hunk_change(hunk_src, hunk_dst)) + end + end + end + + local rename_only = (#normalized_leaf > 0) and (next(rename_pairs) ~= nil) and not has_non_rename_change + + if #hunks == 0 and not rename_only and not saw_precise_candidate then + hunks = 
fallback_hunks_from_diff(action.src_node, action.dst_node, src_buf, dst_buf, rename_pairs) + end + + if #hunks > 0 and (#suppressed_pairs > 0 or next(declaration_pairs) ~= nil) then + local filtered = {} + for _, hunk in ipairs(hunks) do + if + hunk.kind ~= "change" + or not is_suppressed_change_hunk( + hunk.src, + hunk.dst, + suppressed_pairs, + declaration_pairs, + src_buf, + dst_buf + ) + then + table.insert(filtered, hunk) + end + end + hunks = filtered + end + + hunks = merge_adjacent_hunks(hunks, src_buf, dst_buf) + hunks = refine_single_line_change_hunks(hunks, src_buf, dst_buf) + hunks = collapse_identical_insert_delete_hunks(hunks, src_buf, dst_buf) + + action.analysis = { + leaf_changes = normalized_leaf, + rename_pairs = rename_pairs, + hunks = hunks, + rename_only = rename_only, + } + end + end +end + +return M diff --git a/lua/diffmantic/core/bottom_up.lua b/lua/diffmantic/core/bottom_up.lua index 94caf90..fd7ca2e 100644 --- a/lua/diffmantic/core/bottom_up.lua +++ b/lua/diffmantic/core/bottom_up.lua @@ -1,8 +1,52 @@ local M = {} +local roles = require("diffmantic.core.roles") + +local function node_key(info) + if info.start_row ~= nil then + return info.start_row, info.start_col, info.end_row, info.end_col + end + local sr, sc, er, ec = info.node:range() + return sr, sc, er, ec +end + +local function compare_info_order(a_info, b_info) + local asr, asc, aer, aec = node_key(a_info) + local bsr, bsc, ber, bec = node_key(b_info) + if asr ~= bsr then + return asr < bsr + end + if asc ~= bsc then + return asc < bsc + end + if aer ~= ber then + return aer < ber + end + if aec ~= bec then + return aec < bec + end + return a_info.type < b_info.type +end -- Bottom-up matching: match nodes from leaves up, using parent mappings -- Tries to match nodes with the same type and label, and optionally name -function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src_buf, dst_buf) +function M.bottom_up_match(mappings, src_info, dst_info, src_root, 
dst_root, src_buf, dst_buf, opts) + opts = opts or {} + local src_role_index = opts.src_role_index or roles.build_index(src_root, src_buf) + local dst_role_index = opts.dst_role_index or roles.build_index(dst_root, dst_buf) + local src_root_id = src_root:id() + + local node_text_cache = {} + local function node_text(node, bufnr) + local key = tostring(bufnr) .. ":" .. tostring(node:id()) + local cached = node_text_cache[key] + if cached ~= nil then + return cached + end + local text = vim.treesitter.get_node_text(node, bufnr) + node_text_cache[key] = text + return text + end + -- Build O(1) lookup tables local src_to_dst = {} local dst_to_src = {} @@ -12,11 +56,62 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src end -- Get the name of a declaration node (function or variable) - local function get_declaration_name(node, bufnr) - if node:type() == "class_specifier" or node:type() == "struct_specifier" or node:type() == "enum_specifier" or node:type() == "union_specifier" then + local function get_declaration_name(node, bufnr, role_index, name_cache) + local node_id = node:id() + local cached = name_cache[node_id] + if cached ~= nil then + return cached or nil + end + + local function cache_and_return(value) + name_cache[node_id] = value or false + return value + end + + local function find_first_identifier(n) + if not n then + return nil + end + if n:child_count() == 0 then + local t = n:type() + if t == "identifier" or t == "field_identifier" or t == "property_identifier" then + return n + end + return nil + end + for child in n:iter_children() do + local found = find_first_identifier(child) + if found then + return found + end + end + return nil + end + + local function_name = roles.get_kind_name_text(node, role_index, bufnr, "function") + if function_name and #function_name > 0 then + return cache_and_return(function_name) + end + + local class_name = roles.get_kind_name_text(node, role_index, bufnr, "class") + if class_name and 
#class_name > 0 then + return cache_and_return(class_name) + end + + local variable_name = roles.get_kind_name_text(node, role_index, bufnr, "variable") + if variable_name and #variable_name > 0 then + return cache_and_return(variable_name) + end + + if + node:type() == "class_specifier" + or node:type() == "struct_specifier" + or node:type() == "enum_specifier" + or node:type() == "union_specifier" + then local name_node = node:field("name")[1] or node:field("tag")[1] if name_node then - return vim.treesitter.get_node_text(name_node, bufnr) + return cache_and_return(node_text(name_node, bufnr)) end end @@ -27,7 +122,7 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src end local ntype = name_node:type() if ntype == "identifier" then - return vim.treesitter.get_node_text(name_node, bufnr) + return node_text(name_node, bufnr) end if ntype == "dot_index_expression" then local tbl = name_node:field("table")[1] @@ -47,21 +142,21 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src return left .. ":" .. 
right end end - return vim.treesitter.get_node_text(name_node, bufnr) + return node_text(name_node, bufnr) end local name_nodes = node:field("name") if name_nodes and name_nodes[1] then local full_name = lua_name_from_node(name_nodes[1]) if full_name and #full_name > 0 then - return full_name + return cache_and_return(full_name) end end end for child in node:iter_children() do if child:type() == "identifier" then - return vim.treesitter.get_node_text(child, bufnr) + return cache_and_return(node_text(child, bufnr)) end end @@ -73,7 +168,7 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src if subchild:type() == "variable_list" then for id_node in subchild:iter_children() do if id_node:type() == "identifier" then - return vim.treesitter.get_node_text(id_node, bufnr) + return cache_and_return(node_text(id_node, bufnr)) end end end @@ -86,10 +181,9 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src if node:type() == "function_definition" then for child in node:iter_children() do if child:type() == "function_declarator" then - for subchild in child:iter_children() do - if subchild:type() == "identifier" then - return vim.treesitter.get_node_text(subchild, bufnr) - end + local found = find_first_identifier(child) + if found then + return cache_and_return(node_text(found, bufnr)) end end end @@ -103,7 +197,7 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src if decl then for subchild in decl:iter_children() do if subchild:type() == "identifier" or subchild:type() == "field_identifier" then - return vim.treesitter.get_node_text(subchild, bufnr) + return cache_and_return(node_text(subchild, bufnr)) end end end @@ -117,7 +211,7 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src if decl then for subchild in decl:iter_children() do if subchild:type() == "identifier" or subchild:type() == "field_identifier" then - return 
vim.treesitter.get_node_text(subchild, bufnr) + return cache_and_return(node_text(subchild, bufnr)) end end end @@ -129,21 +223,26 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src if child:type() == "assignment" then for subchild in child:iter_children() do if subchild:type() == "identifier" then - return vim.treesitter.get_node_text(subchild, bufnr) + return cache_and_return(node_text(subchild, bufnr)) end end end end end - return nil + return cache_and_return(nil) end -- Try to extract a stable "value hash" for assignments to disambiguate renames. - local function get_assignment_value_hash(node, info) + local function get_assignment_value_hash(node, info, cache) if not node then return nil end + local node_id = node:id() + local cached = cache[node_id] + if cached ~= nil then + return cached or nil + end -- Python: expression_statement (assignment left: ..., right: ...) if node:type() == "expression_statement" then for child in node:iter_children() do @@ -157,11 +256,13 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src right = last end if right and info[right:id()] then - return info[right:id()].hash + cache[node_id] = info[right:id()].hash + return cache[node_id] end end end end + cache[node_id] = false return nil end @@ -240,19 +341,124 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src struct_specifier = true, } + local function is_identifier_type(info, role_index) + if + roles.has_structural_kind(info.node, role_index, "function") + or roles.has_structural_kind(info.node, role_index, "class") + or roles.has_structural_kind(info.node, role_index, "variable") + or roles.has_structural_kind(info.node, role_index, "assignment") + then + return true + end + return identifier_types[info.type] or false + end + + local function is_unique_structure_fallback_type(info, role_index) + if + roles.has_structural_kind(info.node, role_index, "function") + or 
roles.has_structural_kind(info.node, role_index, "class") + then + return true + end + return unique_structure_fallback_types[info.type] or false + end + + local src_ids = {} + for id in pairs(src_info) do + table.insert(src_ids, id) + end + table.sort(src_ids, function(a, b) + local ah = src_info[a].height or 0 + local bh = src_info[b].height or 0 + if ah == bh then + return compare_info_order(src_info[a], src_info[b]) + end + return ah < bh + end) + + local src_decl_name_cache = {} + local dst_decl_name_cache = {} + local src_value_hash_cache = {} + local dst_value_hash_cache = {} + local parent_candidates = {} + + local function candidate_signature(info) + return info.type .. "\x1f" .. info.label + end + + local function build_parent_candidates(dest_parent_id) + local state = { by_sig = {} } + local function push_child(child) + local child_id = child:id() + if dst_to_src[child_id] then + return + end + local d_info = dst_info[child_id] + if not d_info then + return + end + local sig = candidate_signature(d_info) + local queue = state.by_sig[sig] + if not queue then + queue = { head = 1, items = {} } + state.by_sig[sig] = queue + end + local items = queue.items + items[#items + 1] = child_id + end + + if dest_parent_id then + local d_parent = dst_info[dest_parent_id] and dst_info[dest_parent_id].node or nil + if d_parent then + for child in d_parent:iter_children() do + push_child(child) + end + end + else + for child in dst_root:iter_children() do + push_child(child) + end + end + return state + end + + local function queue_for_parent_sig(dest_parent_id, sig) + local key = dest_parent_id or 0 + local state = parent_candidates[key] + if not state then + state = build_parent_candidates(dest_parent_id) + parent_candidates[key] = state + end + return state.by_sig[sig] + end + + local function first_unmapped_candidate_id(queue) + if not queue then + return nil + end + local items = queue.items + local head = queue.head + while head <= #items and 
dst_to_src[items[head]] do + head = head + 1 + end + queue.head = head + return items[head] + end + -- Try to match unmapped nodes whose parent is mapped - for id, info in pairs(src_info) do + for _, id in ipairs(src_ids) do + local info = src_info[id] if not src_to_dst[id] then - local parent = info.parent + local parent_id = info.parent_id local parent_mapped = false local dest_parent_id = nil - if not parent then + if not parent_id then parent_mapped = true - elseif parent:id() == src_root:id() then + elseif parent_id == src_root_id then parent_mapped = true else - local dst_id = src_to_dst[parent:id()] + local dst_id = src_to_dst[parent_id] if dst_id then parent_mapped = true dest_parent_id = dst_id @@ -260,77 +466,68 @@ function M.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src end if parent_mapped then - local candidates = {} - if dest_parent_id then - local d_parent = dst_info[dest_parent_id].node - for child in d_parent:iter_children() do - if not dst_to_src[child:id()] then - table.insert(candidates, child) - end - end - else - for child in dst_root:iter_children() do - if not dst_to_src[child:id()] then - table.insert(candidates, child) - end - end - end + local queue = queue_for_parent_sig(dest_parent_id, candidate_signature(info)) + local candidates = queue and queue.items or nil local src_name = nil - if identifier_types[info.type] then - src_name = get_declaration_name(info.node, src_buf) + if is_identifier_type(info, src_role_index) then + src_name = get_declaration_name(info.node, src_buf, src_role_index, src_decl_name_cache) end - local src_value_hash = get_assignment_value_hash(info.node, src_info) + local src_value_hash = get_assignment_value_hash(info.node, src_info, src_value_hash_cache) local rename_candidate = nil local structure_candidates = {} local rename_score = -1 local rename_tie = false - for _, cand in ipairs(candidates) do - local d_info = dst_info[cand:id()] - if d_info.type == info.type and d_info.label == 
info.label then - if src_name then - local dst_name = get_declaration_name(cand, dst_buf) + if not src_name then + local candidate_id = first_unmapped_candidate_id(queue) + if candidate_id then + table.insert(mappings, { src = id, dst = candidate_id }) + src_to_dst[id] = candidate_id + dst_to_src[candidate_id] = id + end + elseif candidates then + local start_idx = queue and queue.head or 1 + for i = start_idx, #candidates do + local candidate_id = candidates[i] + if not dst_to_src[candidate_id] then + local cand = dst_info[candidate_id].node + local d_info = dst_info[candidate_id] + local dst_name = get_declaration_name(cand, dst_buf, dst_role_index, dst_decl_name_cache) if src_name == dst_name then - table.insert(mappings, { src = id, dst = cand:id() }) - src_to_dst[id] = cand:id() - dst_to_src[cand:id()] = id + table.insert(mappings, { src = id, dst = candidate_id }) + src_to_dst[id] = candidate_id + dst_to_src[candidate_id] = id rename_candidate = nil break elseif dst_name and src_info[id].structure_hash == d_info.structure_hash then - local dst_value_hash = get_assignment_value_hash(cand, dst_info) + local dst_value_hash = get_assignment_value_hash(cand, dst_info, dst_value_hash_cache) if src_value_hash and dst_value_hash and src_value_hash ~= dst_value_hash then goto continue_candidate end - table.insert(structure_candidates, cand:id()) + table.insert(structure_candidates, candidate_id) local score = name_similarity(src_name, dst_name) if score < 0.8 then goto continue_candidate end if score > rename_score then - rename_candidate = cand:id() + rename_candidate = candidate_id rename_score = score rename_tie = false elseif score == rename_score and score > 0 then rename_tie = true end end - else - table.insert(mappings, { src = id, dst = cand:id() }) - src_to_dst[id] = cand:id() - dst_to_src[cand:id()] = id - rename_candidate = nil - break end + ::continue_candidate:: end - ::continue_candidate:: end if not src_to_dst[id] and rename_candidate and not rename_tie 
and rename_score > 0 then table.insert(mappings, { src = id, dst = rename_candidate }) src_to_dst[id] = rename_candidate dst_to_src[rename_candidate] = id - elseif not src_to_dst[id] and unique_structure_fallback_types[info.type] then + elseif not src_to_dst[id] and is_unique_structure_fallback_type(info, src_role_index) then if #structure_candidates == 1 then local candidate_id = structure_candidates[1] table.insert(mappings, { src = id, dst = candidate_id }) diff --git a/lua/diffmantic/core/recovery.lua b/lua/diffmantic/core/recovery.lua index b55d4b2..4a40094 100644 --- a/lua/diffmantic/core/recovery.lua +++ b/lua/diffmantic/core/recovery.lua @@ -1,8 +1,40 @@ local M = {} --- Recovery matching: tries to match remaining unmapped nodes using LCS and unique type -function M.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_buf, dst_buf) - -- Build O(1) lookup tables +local function node_key(info) + if info.start_row ~= nil then + return info.start_row, info.start_col, info.end_row, info.end_col + end + local sr, sc, er, ec = info.node:range() + return sr, sc, er, ec +end + +local function compare_info_order(a_info, b_info) + local asr, asc, aer, aec = node_key(a_info) + local bsr, bsc, ber, bec = node_key(b_info) + if asr ~= bsr then + return asr < bsr + end + if asc ~= bsc then + return asc < bsc + end + if aer ~= ber then + return aer < ber + end + if aec ~= bec then + return aec < bec + end + return a_info.type < b_info.type +end + +-- Recovery matching: tries to match remaining unmapped nodes using LCS and unique type. +function M.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_buf, dst_buf, opts) + opts = opts or {} + local lcs_cell_limit = opts.recovery_lcs_cell_limit or 6000 + local skip_unique_type_match = { + field_declaration = true, + } + + -- Build O(1) lookup tables. 
local src_to_dst = {} local dst_to_src = {} for _, m in ipairs(mappings) do @@ -10,82 +42,139 @@ function M.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_ dst_to_src[m.dst] = m.src end - -- Longest Common Subsequence (LCS) for matching children + local function can_match(src_node, dst_node, hash_key) + local s = src_info[src_node:id()] + local d = dst_info[dst_node:id()] + if not s or not d then + return false + end + return s[hash_key] == d[hash_key] and src_node:type() == dst_node:type() + end + + local function greedy_lcs(src_list, dst_list, hash_key) + local result = {} + local j = 1 + for i = 1, #src_list do + local src_node = src_list[i] + while j <= #dst_list do + local dst_node = dst_list[j] + j = j + 1 + if can_match(src_node, dst_node, hash_key) then + table.insert(result, { src = src_node, dst = dst_node }) + break + end + end + end + return result + end + + -- Longest Common Subsequence (LCS) for matching children. + -- Reconstructed left-to-right so duplicate-compatible nodes prefer earlier dst siblings. local function lcs(src_list, dst_list, hash_key) local m, n = #src_list, #dst_list if m == 0 or n == 0 then return {} end + if (m * n) > lcs_cell_limit then + return greedy_lcs(src_list, dst_list, hash_key) + end local dp = {} - for i = 0, m do + for i = 1, m + 1 do dp[i] = {} - for j = 0, n do + for j = 1, n + 1 do dp[i][j] = 0 end end - for i = 1, m do - for j = 1, n do - local s, d = src_list[i], dst_list[j] - if src_info[s:id()][hash_key] == dst_info[d:id()][hash_key] and s:type() == d:type() then - dp[i][j] = dp[i - 1][j - 1] + 1 + for i = m, 1, -1 do + for j = n, 1, -1 do + if can_match(src_list[i], dst_list[j], hash_key) then + dp[i][j] = dp[i + 1][j + 1] + 1 else - dp[i][j] = math.max(dp[i - 1][j], dp[i][j - 1]) + dp[i][j] = math.max(dp[i + 1][j], dp[i][j + 1]) end end end - -- Backtrack to find matches + -- Deterministic left-biased reconstruction. 
local result = {} - local i, j = m, n - while i > 0 and j > 0 do - local s, d = src_list[i], dst_list[j] - if src_info[s:id()][hash_key] == dst_info[d:id()][hash_key] and s:type() == d:type() then - table.insert(result, 1, { src = s, dst = d }) - i, j = i - 1, j - 1 - elseif dp[i - 1][j] > dp[i][j - 1] then - i = i - 1 + local i, j = 1, 1 + while i <= m and j <= n do + if can_match(src_list[i], dst_list[j], hash_key) and dp[i][j] == (dp[i + 1][j + 1] + 1) then + table.insert(result, { src = src_list[i], dst = dst_list[j] }) + i = i + 1 + j = j + 1 else - j = j - 1 + local skip_src = dp[i + 1][j] + local skip_dst = dp[i][j + 1] + if skip_dst > skip_src then + j = j + 1 + elseif skip_src > skip_dst then + i = i + 1 + else + -- Tie: advance src to keep earlier destination candidates. + i = i + 1 + end end end return result end - -- Helper to add a mapping and update lookup tables - local function add_mapping(src_id, dst_id) - table.insert(mappings, { src = src_id, dst = dst_id }) - src_to_dst[src_id] = dst_id - dst_to_src[dst_id] = src_id - end - - -- Try to match children using LCS and unique type - local function simple_recovery(src_node, dst_node) - local src_children, dst_children = {}, {} + local function has_unmatched_on_both_sides(src_node, dst_node) + local src_has = false for child in src_node:iter_children() do if not src_to_dst[child:id()] then - table.insert(src_children, child) + src_has = true + break end end + if not src_has then + return false + end for child in dst_node:iter_children() do if not dst_to_src[child:id()] then - table.insert(dst_children, child) + return true end end - if #src_children == 0 or #dst_children == 0 then + return false + end + + local pending = {} + local queued = {} + + local function queue_key(src_id, dst_id) + return tostring(src_id) .. ":" .. 
tostring(dst_id) + end + + local function enqueue(src_id, dst_id) + local key = queue_key(src_id, dst_id) + if queued[key] then return end + queued[key] = true + pending[#pending + 1] = { src_id = src_id, dst_id = dst_id, key = key } + end - -- Step 1: match children with same hash (exact match) - for _, match in ipairs(lcs(src_children, dst_children, "hash")) do - if not src_to_dst[match.src:id()] and not dst_to_src[match.dst:id()] then - add_mapping(match.src:id(), match.dst:id()) - end + -- Helper to add a mapping and update lookup tables. + local function add_mapping(src_id, dst_id) + if src_to_dst[src_id] or dst_to_src[dst_id] then + return false end + table.insert(mappings, { src = src_id, dst = dst_id }) + src_to_dst[src_id] = dst_id + dst_to_src[dst_id] = src_id + local src_entry = src_info[src_id] + local dst_entry = dst_info[dst_id] + if src_entry and dst_entry and has_unmatched_on_both_sides(src_entry.node, dst_entry.node) then + enqueue(src_id, dst_id) + end + return true + end - -- Step 2: match children with same structure_hash (for updates) - src_children, dst_children = {}, {} + local function unmatched_children(src_node, dst_node) + local src_children = {} + local dst_children = {} for child in src_node:iter_children() do if not src_to_dst[child:id()] then table.insert(src_children, child) @@ -96,25 +185,34 @@ function M.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_ table.insert(dst_children, child) end end - for _, match in ipairs(lcs(src_children, dst_children, "structure_hash")) do - if not src_to_dst[match.src:id()] and not dst_to_src[match.dst:id()] then - add_mapping(match.src:id(), match.dst:id()) - end + return src_children, dst_children + end + + -- Try to match children using LCS and unique type. 
+ local function simple_recovery(src_node, dst_node) + local src_children, dst_children = unmatched_children(src_node, dst_node) + if #src_children == 0 or #dst_children == 0 then + return end - -- Step 3: match children with unique type (type appears only once) - src_children, dst_children = {}, {} - for child in src_node:iter_children() do - if not src_to_dst[child:id()] then - table.insert(src_children, child) + -- Step 1: match children with same hash (exact match). + for _, match in ipairs(lcs(src_children, dst_children, "hash")) do + if add_mapping(match.src:id(), match.dst:id()) then + simple_recovery(match.src, match.dst) end end - for child in dst_node:iter_children() do - if not dst_to_src[child:id()] then - table.insert(dst_children, child) + + -- Step 2: match children with same structure_hash (for updates). + src_children, dst_children = unmatched_children(src_node, dst_node) + for _, match in ipairs(lcs(src_children, dst_children, "structure_hash")) do + if match.src:type() ~= "field_declaration" and add_mapping(match.src:id(), match.dst:id()) then + simple_recovery(match.src, match.dst) end end + -- Step 3: match children with unique type (type appears only once). 
+ src_children, dst_children = unmatched_children(src_node, dst_node) + local src_by_type, dst_by_type = {}, {} local src_type_count, dst_type_count = {}, {} for _, c in ipairs(src_children) do @@ -129,21 +227,31 @@ function M.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_ end for t, count in pairs(src_type_count) do - if count == 1 and dst_type_count[t] == 1 then + if count == 1 and dst_type_count[t] == 1 and not skip_unique_type_match[t] then local s, d = src_by_type[t], dst_by_type[t] - if not src_to_dst[s:id()] and not dst_to_src[d:id()] then - add_mapping(s:id(), d:id()) + if add_mapping(s:id(), d:id()) then simple_recovery(s, d) end end end end - -- Apply recovery to all mapped nodes - for id, info in pairs(src_info) do - local dst_id = src_to_dst[id] - if dst_id then - simple_recovery(info.node, dst_info[dst_id].node) + -- Seed worklist only with mapped pairs that still have unmatched children on both sides. + for _, mapping in ipairs(mappings) do + local src_entry = src_info[mapping.src] + local dst_entry = dst_info[mapping.dst] + if src_entry and dst_entry and has_unmatched_on_both_sides(src_entry.node, dst_entry.node) then + enqueue(mapping.src, mapping.dst) + end + end + + while #pending > 0 do + local pair = table.remove(pending) + queued[pair.key] = nil + local src_entry = src_info[pair.src_id] + local dst_entry = dst_info[pair.dst_id] + if src_entry and dst_entry and has_unmatched_on_both_sides(src_entry.node, dst_entry.node) then + simple_recovery(src_entry.node, dst_entry.node) end end diff --git a/lua/diffmantic/core/roles.lua b/lua/diffmantic/core/roles.lua new file mode 100644 index 0000000..75d5ea9 --- /dev/null +++ b/lua/diffmantic/core/roles.lua @@ -0,0 +1,210 @@ +local M = {} + +local CAPTURE_BY_KIND = { + ["function"] = { + "diff.function.outer", + "diff.function.name", + "diff.function.body", + }, + ["class"] = { + "diff.class.outer", + "diff.class.name", + "diff.class.body", + }, + ["variable"] = { + 
"diff.variable.outer", + "diff.variable.name", + }, + ["assignment"] = { + "diff.assignment.outer", + "diff.assignment.lhs", + "diff.assignment.rhs", + }, + ["import"] = { "diff.import.outer" }, + ["return"] = { "diff.return.outer" }, + ["preproc"] = { "diff.preproc.outer" }, + ["rename_identifier"] = { "diff.identifier.rename" }, +} + +local STRUCTURAL_CAPTURE_BY_KIND = { + ["function"] = { "diff.function.outer" }, + ["class"] = { "diff.class.outer" }, + ["variable"] = { "diff.variable.outer" }, + ["assignment"] = { "diff.assignment.outer" }, + ["import"] = { "diff.import.outer" }, + ["return"] = { "diff.return.outer" }, + ["preproc"] = { "diff.preproc.outer" }, +} + +local FALLBACK_QUERY = "((_) @diff.fallback.node)" + +local function add_capture(index, capture, node) + local id = node:id() + index.by_node[id] = index.by_node[id] or {} + index.by_node[id][capture] = true + + index.by_capture[capture] = index.by_capture[capture] or {} + if not index.by_capture[capture][id] then + index.by_capture[capture][id] = node + index.by_capture_list[capture] = index.by_capture_list[capture] or {} + table.insert(index.by_capture_list[capture], node) + end +end + +local function resolve_lang(bufnr) + local ft = vim.bo[bufnr].filetype + if not ft or ft == "" then + return nil + end + return vim.treesitter.language.get_lang(ft) or ft +end + +local function get_query(lang) + local ok, query = pcall(vim.treesitter.query.get, lang, "diffmantic") + if ok and query then + return query + end + + local parsed_ok, parsed = pcall(vim.treesitter.query.parse, lang, FALLBACK_QUERY) + if parsed_ok and parsed then + return parsed + end + + return nil +end + +function M.build_index(root, bufnr) + local lang = resolve_lang(bufnr) + if not lang then + return nil + end + + local query = get_query(lang) + if not query then + return nil + end + + local index = { + lang = lang, + by_node = {}, + by_capture = {}, + by_capture_list = {}, + _descendant_capture_cache = {}, + _kind_name_node_cache = {}, 
+ _kind_name_text_cache = {}, + } + + for id, node in query:iter_captures(root, bufnr, 0, -1) do + local capture = query.captures[id] + if capture then + add_capture(index, capture, node) + end + end + + return index +end + +function M.has_capture(node, index, capture) + if not node or not index then + return false + end + local by_node = index.by_node[node:id()] + return by_node and by_node[capture] or false +end + +function M.find_descendant_with_capture(node, index, capture) + if not node or not index then + return nil + end + + local cache_key = capture .. ":" .. tostring(node:id()) + local cache = index._descendant_capture_cache + local cached = cache[cache_key] + if cached ~= nil then + return cached or nil + end + + local by_capture = index.by_capture_list[capture] + if not by_capture then + cache[cache_key] = false + return nil + end + + for _, captured in ipairs(by_capture) do + if node:equal(captured) or node:child_with_descendant(captured) then + cache[cache_key] = captured + return captured + end + end + + cache[cache_key] = false + return nil +end + +function M.has_kind(node, index, kind) + local captures = CAPTURE_BY_KIND[kind] + if not captures then + return false + end + + for _, capture in ipairs(captures) do + if M.has_capture(node, index, capture) then + return true + end + end + + return false +end + +function M.has_structural_kind(node, index, kind) + local captures = STRUCTURAL_CAPTURE_BY_KIND[kind] + if not captures then + return M.has_kind(node, index, kind) + end + + for _, capture in ipairs(captures) do + if M.has_capture(node, index, capture) then + return true + end + end + + return false +end + +function M.get_kind_name_node(node, index, kind) + if not node or not index then + return nil + end + local key = kind .. ":" .. 
tostring(node:id()) + local cache = index._kind_name_node_cache + local cached = cache[key] + if cached ~= nil then + return cached or nil + end + local capture = string.format("diff.%s.name", kind) + local found = M.find_descendant_with_capture(node, index, capture) + cache[key] = found or false + return found +end + +function M.get_kind_name_text(node, index, bufnr, kind) + if not node or not index then + return nil + end + local key = kind .. ":" .. tostring(node:id()) + local cache = index._kind_name_text_cache + local cached = cache[key] + if cached ~= nil then + return cached or nil + end + local name_node = M.get_kind_name_node(node, index, kind) + if not name_node then + cache[key] = false + return nil + end + local text = vim.treesitter.get_node_text(name_node, bufnr) + cache[key] = text or false + return text +end + +return M diff --git a/lua/diffmantic/core/semantic.lua b/lua/diffmantic/core/semantic.lua new file mode 100644 index 0000000..ada78dd --- /dev/null +++ b/lua/diffmantic/core/semantic.lua @@ -0,0 +1,372 @@ +local M = {} +local roles = require("diffmantic.core.roles") + +local function in_class_like_context(node, role_index) + if not node then + return false + end + local cur = node + while cur do + local ctype = cur:type() + if ctype == "class_specifier" or ctype == "struct_specifier" or ctype == "union_specifier" or ctype == "class_definition" then + return true + end + if role_index and roles.has_kind(cur, role_index, "class") then + return true + end + cur = cur:parent() + end + return false +end + +function M.is_ambiguous_member_change(node, role_index) + if not node then + return false + end + local node_type = node:type() + if node_type ~= "identifier" and node_type ~= "field_identifier" and node_type ~= "type_identifier" then + return false + end + + local cur = node + local parent = node:parent() + while parent do + local ptype = parent:type() + if ptype == "field_declaration" or ptype == "field_declarator" then + return true + end + 
if ptype == "declarator" then + local grandparent = parent:parent() + if grandparent and (grandparent:type() == "field_declaration" or grandparent:type() == "field_declarator") then + return true + end + end + if (ptype == "class_specifier" or ptype == "struct_specifier" or ptype == "union_specifier" or ptype == "enum_specifier") + and not M.node_in_field(parent, "name", cur) + and not M.node_in_field(parent, "tag", cur) + and node_type == "field_identifier" + then + return true + end + cur = parent + parent = parent:parent() + end + + if role_index and in_class_like_context(node, role_index) and node_type == "field_identifier" then + return true + end + + return false +end + +-- Leaf-level diffs for small updates; otherwise return empty. +function M.find_leaf_changes(src_node, dst_node, src_buf, dst_buf, src_role_index, dst_role_index) + local changes = {} + + local function get_all_leaves(node, bufnr) + local leaves = {} + local function traverse(n) + if n:child_count() == 0 then + table.insert(leaves, { + node = n, + text = vim.treesitter.get_node_text(n, bufnr), + type = n:type(), + }) + else + for child in n:iter_children() do + traverse(child) + end + end + end + traverse(node) + return leaves + end + + local src_leaves = get_all_leaves(src_node, src_buf) + local dst_leaves = get_all_leaves(dst_node, dst_buf) + + if #src_leaves ~= #dst_leaves then + return {} + end + + if math.abs(#src_leaves - #dst_leaves) > 2 then + return {} + end + + local min_len = math.min(#src_leaves, #dst_leaves) + local max_len = math.max(#src_leaves, #dst_leaves) + local same_count = 0 + + for i = 1, min_len do + local sl, dl = src_leaves[i], dst_leaves[i] + if sl.type == dl.type and sl.text == dl.text then + same_count = same_count + 1 + elseif sl.type == dl.type and sl.text ~= dl.text then + local src_ambiguous = M.is_ambiguous_member_change(sl.node, src_role_index) + local dst_ambiguous = M.is_ambiguous_member_change(dl.node, dst_role_index) + if src_ambiguous or dst_ambiguous 
then + -- Member declaration internals are structurally noisy; avoid rename/update inference. + same_count = same_count + 1 + else + table.insert(changes, { + src_node = sl.node, + dst_node = dl.node, + src_text = sl.text, + dst_text = dl.text, + }) + end + end + end + + local similarity = same_count / max_len + + if similarity < 0.5 then + return {} + end + + if #changes > 5 then + return {} + end + + return changes +end + +function M.node_in_field(parent, field_name, node) + local nodes = parent:field(field_name) + if not nodes then + return false + end + for _, field_node in ipairs(nodes) do + if field_node:equal(node) or field_node:child_with_descendant(node) then + return true + end + end + return false +end + +function M.is_rename_identifier(node, role_index) + if not node then + return false + end + + local function is_class_name_node(n, index) + return index and roles.has_capture(n, index, "diff.class.name") or false + end + + if M.is_ambiguous_member_change(node, role_index) then + return false + end + + if role_index and roles.has_kind(node, role_index, "rename_identifier") then + -- Allow renaming struct/type/class declaration names, but not members inside them. 
+ if in_class_like_context(node, role_index) and not is_class_name_node(node, role_index) then + return false + end + return true + end + + local node_type = node:type() + if node_type ~= "identifier" and node_type ~= "type_identifier" and node_type ~= "field_identifier" then + return false + end + + local parent = node:parent() + if not parent then + return false + end + + local parent_type = parent:type() + local parameter_parent_kinds = { + parameters = true, + parameter_list = true, + formal_parameters = true, + required_parameter = true, + optional_parameter = true, + rest_parameter = true, + } + if parameter_parent_kinds[parent_type] then + return true + end + if parent_type == "parameter_declaration" and node_type == "identifier" then + return true + end + + if parent_type == "assignment" and M.node_in_field(parent, "left", node) then + return true + end + + if parent_type == "assignment_statement" and M.node_in_field(parent, "variable", node) then + return true + end + if parent_type == "variable_list" then + return true + end + + -- Language-specific name heuristics. + local current = node + while parent do + local ptype = parent:type() + if (ptype == "function_declaration" or ptype == "function_definition" or ptype == "class_definition" or ptype == "class_declaration") + and M.node_in_field(parent, "name", current) + then + return true + end + if (ptype == "class_specifier" or ptype == "struct_specifier" or ptype == "enum_specifier" or ptype == "union_specifier") + and (M.node_in_field(parent, "name", current) or M.node_in_field(parent, "tag", current)) + then + return true + end + if ptype == "function_declarator" then + return true + end + if ptype == "init_declarator" and M.node_in_field(parent, "declarator", current) then + return true + end + if ptype == "field_declaration" and M.node_in_field(parent, "declarator", current) then + -- C/C++ member fields often have no value context; rename inference here is noisy. 
+ return false + end + if ptype == "declarator" then + local grandparent = parent:parent() + if grandparent and grandparent:type() == "field_declaration" then + return false + end + return true + end + if ptype == "field" then + if M.node_in_field(parent, "name", current) or M.node_in_field(parent, "key", current) then + return true + end + end + current = parent + parent = parent:parent() + end + + return false +end + +function M.is_value_node(node, text) + local node_type = node and node:type() or "" + if node_type:find("string") or node_type:find("number") or node_type:find("integer") or node_type:find("float") or node_type:find("boolean") then + return true + end + if node_type == "char_literal" or node_type:find("char") and node_type:find("literal") then + return true + end + if text then + if text:match("^%s*['\"].*['\"]%s*$") then + return true + end + if text:match("^%s*[%d%.]+%s*$") then + return true + end + if text == "true" or text == "false" or text == "nil" then + return true + end + end + return false +end + +local function expand_word_fragment(text, start_idx, end_idx) + local s = start_idx + local e = end_idx + while s > 1 and text:sub(s - 1, s - 1):match("[%w_]") do + s = s - 1 + end + while e <= #text and text:sub(e, e):match("[%w_]") do + e = e + 1 + end + return s, e - 1 +end + +local function expand_assignment_suffix(text, end_idx) + local i = end_idx + 1 + while i <= #text and text:sub(i, i):match("%s") do + i = i + 1 + end + if i <= #text and text:sub(i, i) == "=" then + local j = i + 1 + while j <= #text and text:sub(j, j):match("%s") do + j = j + 1 + end + return j - 1 + end + return end_idx +end + +function M.classify_text_change(old_text, new_text) + if old_text == new_text then + return nil + end + + local max_prefix = math.min(#old_text, #new_text) + local prefix = 0 + while prefix < max_prefix and old_text:sub(prefix + 1, prefix + 1) == new_text:sub(prefix + 1, prefix + 1) do + prefix = prefix + 1 + end + + local max_suffix = 
math.min(#old_text - prefix, #new_text - prefix) + local suffix = 0 + while suffix < max_suffix do + local o = old_text:sub(#old_text - suffix, #old_text - suffix) + local n = new_text:sub(#new_text - suffix, #new_text - suffix) + if o ~= n then + break + end + suffix = suffix + 1 + end + + local old_start = prefix + 1 + local old_end = #old_text - suffix + local new_start = prefix + 1 + local new_end = #new_text - suffix + + local has_old = old_start <= old_end + local has_new = new_start <= new_end + if not has_old and not has_new then + return nil + end + + if has_old then + old_start, old_end = expand_word_fragment(old_text, old_start, old_end) + old_end = expand_assignment_suffix(old_text, old_end) + end + if has_new then + new_start, new_end = expand_word_fragment(new_text, new_start, new_end) + new_end = expand_assignment_suffix(new_text, new_end) + end + + local kind = "replace" + if has_new and not has_old then + kind = "insert" + elseif has_old and not has_new then + kind = "delete" + end + + return { + kind = kind, + old_start = old_start, + old_end = old_end, + new_start = new_start, + new_end = new_end, + old_fragment = has_old and old_text:sub(old_start, old_end) or "", + new_fragment = has_new and new_text:sub(new_start, new_end) or "", + } +end + +function M.diff_fragment(old_text, new_text) + local change = M.classify_text_change(old_text, new_text) + if not change or change.kind ~= "replace" then + return nil + end + return { + old_start = change.old_start, + old_end = change.old_end, + new_start = change.new_start, + new_end = change.new_end, + old_fragment = change.old_fragment, + new_fragment = change.new_fragment, + } +end + +return M diff --git a/lua/diffmantic/core/top_down.lua b/lua/diffmantic/core/top_down.lua index 727c005..ab33094 100644 --- a/lua/diffmantic/core/top_down.lua +++ b/lua/diffmantic/core/top_down.lua @@ -2,15 +2,41 @@ local ts_utils = require("diffmantic.treesitter") local M = {} +local function node_key(info) + return 
info.start_row, info.start_col, info.end_row, info.end_col +end + +local function compare_node_order(a, b) + local asr, asc, aer, aec = node_key(a) + local bsr, bsc, ber, bec = node_key(b) + if asr ~= bsr then + return asr < bsr + end + if asc ~= bsc then + return asc < bsc + end + if aer ~= ber then + return aer < ber + end + if aec ~= bec then + return aec < bec + end + return a.type < b.type +end + -- Top-down matching: match nodes from the top of the tree downwards -- Matches nodes with the same hash at each height level -function M.top_down_match(src_root, dst_root, src_buf, dst_buf) +function M.top_down_match(src_root, dst_root, src_buf, dst_buf, opts) + opts = opts or {} local mappings = {} - local src_info = ts_utils.preprocess_tree(src_root, src_buf) - local dst_info = ts_utils.preprocess_tree(dst_root, dst_buf) + local src_info = opts.src_info or ts_utils.preprocess_tree(src_root, src_buf, opts) + local dst_info = opts.dst_info or ts_utils.preprocess_tree(dst_root, dst_buf, opts) local src_mapped = {} local dst_mapped = {} + local src_to_dst = {} + local src_root_id = src_root:id() + local dst_root_id = dst_root:id() -- Group nodes by their height in the tree local function get_nodes_by_height(info) @@ -21,11 +47,38 @@ function M.top_down_match(src_root, dst_root, src_buf, dst_buf) end table.insert(by_height[data.height], data) end + + for _, nodes in pairs(by_height) do + table.sort(nodes, compare_node_order) + end + return by_height end local src_by_height = get_nodes_by_height(src_info) local dst_by_height = get_nodes_by_height(dst_info) + local function dst_parent_key(info) + local parent_id = info.parent_id + if not parent_id then + return 0 + end + if parent_id == dst_root_id then + return dst_root_id + end + return parent_id + end + + local function src_parent_key(info) + local parent_id = info.parent_id + if not parent_id then + return 0 + end + if parent_id == src_root_id then + return dst_root_id + end + return src_to_dst[parent_id] or -1 + end 
+ -- Find the maximum height in both trees local max_h = 0 for h in pairs(src_by_height) do @@ -39,32 +92,48 @@ function M.top_down_match(src_root, dst_root, src_buf, dst_buf) end end - -- For each height, match nodes with the same hash using hash indexing + -- For each height, match nodes with the same hash and compatible mapped parent. + -- Parent buckets avoid O(k) scans through all same-hash candidates. for h = max_h, 1, -1 do local s_nodes = src_by_height[h] or {} local d_nodes = dst_by_height[h] or {} - local dst_by_hash = {} + local dst_by_hash_parent = {} for _, d in ipairs(d_nodes) do if not dst_mapped[d.id] then - if not dst_by_hash[d.hash] then - dst_by_hash[d.hash] = {} + local hash_buckets = dst_by_hash_parent[d.hash] + if not hash_buckets then + hash_buckets = {} + dst_by_hash_parent[d.hash] = hash_buckets + end + + local pkey = dst_parent_key(d) + local queue = hash_buckets[pkey] + if not queue then + queue = { head = 1, items = {} } + hash_buckets[pkey] = queue end - table.insert(dst_by_hash[d.hash], d) + local items = queue.items + items[#items + 1] = d end end for _, s in ipairs(s_nodes) do if not src_mapped[s.id] then - local candidates = dst_by_hash[s.hash] - if candidates then - for i, d in ipairs(candidates) do - if not dst_mapped[d.id] then + local pkey = src_parent_key(s) + if pkey ~= -1 then + local hash_buckets = dst_by_hash_parent[s.hash] + local queue = hash_buckets and hash_buckets[pkey] or nil + if queue then + local head = queue.head + local items = queue.items + local d = items[head] + if d then + queue.head = head + 1 table.insert(mappings, { src = s.id, dst = d.id }) src_mapped[s.id] = true dst_mapped[d.id] = true - table.remove(candidates, i) - break + src_to_dst[s.id] = d.id end end end diff --git a/lua/diffmantic/init.lua b/lua/diffmantic/init.lua index 4e43eeb..27f26d5 100644 --- a/lua/diffmantic/init.lua +++ b/lua/diffmantic/init.lua @@ -2,24 +2,85 @@ local M = {} local core = require("diffmantic.core") local ui = 
require("diffmantic.ui") local debug_utils = require("diffmantic.debug_utils") +local roles = require("diffmantic.core.roles") + +local function hl(name) + local ok, value = pcall(vim.api.nvim_get_hl, 0, { name = name, link = false }) + if not ok then + return {} + end + return value or {} +end + +local function pick_bg(names) + for _, name in ipairs(names) do + local value = hl(name).bg + if value then + return value + end + end + return nil +end + +local function pick_fg(names, fallback) + for _, name in ipairs(names) do + local value = hl(name).fg + if value then + return value + end + end + return fallback +end local function setup_highlights() - local add_fg = vim.api.nvim_get_hl(0, { name = "DiffAdd" }).fg or 0xa6e3a1 - local delete_fg = vim.api.nvim_get_hl(0, { name = "DiffDelete" }).fg or 0xf38ba8 - local change_fg = vim.api.nvim_get_hl(0, { name = "DiffChange" }).fg or 0xf9e2af - - vim.api.nvim_set_hl(0, "DiffAddText", { fg = add_fg, bg = "NONE", ctermbg = "NONE" }) - vim.api.nvim_set_hl(0, "DiffDeleteText", { fg = delete_fg, bg = "NONE", ctermbg = "NONE" }) - vim.api.nvim_set_hl(0, "DiffChangeText", { fg = change_fg, bg = "NONE", ctermbg = "NONE" }) - vim.api.nvim_set_hl(0, "DiffMoveText", { fg = 0x89b4fa, bg = "NONE", ctermbg = "NONE" }) - vim.api.nvim_set_hl(0, "DiffRenameText", { fg = change_fg, bg = "NONE", ctermbg = "NONE", underline = true }) + local add_bg = pick_bg({ "DiffAdd", "DiffText" }) + local delete_bg = pick_bg({ "DiffDelete", "DiffText" }) + local change_bg = pick_bg({ "DiffText", "DiffChange" }) + local move_bg = pick_bg({ "DiffChange", "DiffText" }) + if move_bg == change_bg then + -- If move and change resolve to the same background, de-emphasize move fill so updates remain visible. 
+ move_bg = nil + end + + local add_sign_fg = pick_fg({ "DiffAdd" }, 0x49D17D) + local delete_sign_fg = pick_fg({ "DiffDelete" }, 0xFF6B6B) + local change_sign_fg = pick_fg({ "DiagnosticWarn", "DiffChange" }, 0xE8C95A) + local move_sign_fg = pick_fg({ "DiagnosticInfo", "DiffText" }, 0x5AA2FF) + + vim.api.nvim_set_hl(0, "DiffmanticAdd", { fg = add_sign_fg, bg = add_bg }) + vim.api.nvim_set_hl(0, "DiffmanticDelete", { fg = delete_sign_fg, bg = delete_bg }) + vim.api.nvim_set_hl(0, "DiffmanticChange", { fg = change_sign_fg, bg = change_bg }) + vim.api.nvim_set_hl( + 0, + "DiffmanticChangeAccent", + { fg = change_sign_fg, bg = "NONE", underline = true, bold = true } + ) + vim.api.nvim_set_hl(0, "DiffmanticMove", { fg = move_sign_fg, bg = move_bg or "NONE" }) + vim.api.nvim_set_hl(0, "DiffmanticRename", { fg = change_sign_fg, underline = true, bold = true, italic = true }) + + vim.api.nvim_set_hl(0, "DiffmanticAddSign", { fg = add_sign_fg, bg = "NONE" }) + vim.api.nvim_set_hl(0, "DiffmanticDeleteSign", { fg = delete_sign_fg, bg = "NONE" }) + vim.api.nvim_set_hl(0, "DiffmanticChangeSign", { fg = change_sign_fg, bg = "NONE" }) + vim.api.nvim_set_hl(0, "DiffmanticMoveSign", { fg = move_sign_fg, bg = "NONE" }) + vim.api.nvim_set_hl(0, "DiffmanticRenameSign", { fg = change_sign_fg, bg = "NONE" }) + + vim.api.nvim_set_hl(0, "DiffmanticAddFiller", { fg = add_sign_fg, bg = add_bg }) + vim.api.nvim_set_hl(0, "DiffmanticDeleteFiller", { fg = delete_sign_fg, bg = delete_bg }) + vim.api.nvim_set_hl(0, "DiffmanticMoveFiller", { fg = move_sign_fg, bg = move_bg }) end function M.setup(opts) setup_highlights() + + local aug = vim.api.nvim_create_augroup("diffmantic_highlights", { clear = true }) + vim.api.nvim_create_autocmd("ColorScheme", { + group = aug, + callback = setup_highlights, + }) end function M.diff(args) + setup_highlights() local parts = vim.split(args, " ", { trimempty = true }) if #parts == 0 then print("Please provide one or two files paths to compare.") @@ -28,36 
+89,27 @@ function M.diff(args) local file1, file2 = parts[1], parts[2] local buf1, buf2 + local win1, win2 if file2 then -- Case: 2 files provided. Open them in split. vim.cmd("tabnew") vim.cmd("edit " .. file1) buf1 = vim.api.nvim_get_current_buf() - local win1 = vim.api.nvim_get_current_win() + win1 = vim.api.nvim_get_current_win() vim.cmd("vsplit " .. file2) buf2 = vim.api.nvim_get_current_buf() - local win2 = vim.api.nvim_get_current_win() - - vim.wo[win1].scrollbind = true - vim.wo[win1].cursorbind = true - vim.wo[win2].scrollbind = true - vim.wo[win2].cursorbind = true + win2 = vim.api.nvim_get_current_win() else -- Case: 1 file provided. Compare current buffer vs file. buf1 = vim.api.nvim_get_current_buf() - local win1 = vim.api.nvim_get_current_win() + win1 = vim.api.nvim_get_current_win() local expanded_path = vim.fn.expand(file1) vim.cmd("vsplit " .. expanded_path) buf2 = vim.api.nvim_get_current_buf() - local win2 = vim.api.nvim_get_current_win() - - vim.wo[win1].scrollbind = true - vim.wo[win1].cursorbind = true - vim.wo[win2].scrollbind = true - vim.wo[win2].cursorbind = true + win2 = vim.api.nvim_get_current_win() end local lang = vim.treesitter.language.get_lang(vim.bo[buf1].filetype) @@ -74,23 +126,58 @@ function M.diff(args) end local root1 = parser1:parse()[1]:root() local root2 = parser2:parse()[1]:root() + local src_role_index = roles.build_index(root1, buf1) + local dst_role_index = roles.build_index(root2, buf2) - local mappings, src_info, dst_info = core.top_down_match(root1, root2, buf1, buf2) + local mappings, src_info, dst_info = core.top_down_match(root1, root2, buf1, buf2, { + adaptive_mode = true, + }) -- print("Top-down mappings: " .. 
#mappings) -- local before_bottom_up = #mappings - mappings = core.bottom_up_match(mappings, src_info, dst_info, root1, root2, buf1, buf2) + mappings = core.bottom_up_match(mappings, src_info, dst_info, root1, root2, buf1, buf2, { + src_role_index = src_role_index, + dst_role_index = dst_role_index, + adaptive_mode = true, + }) -- print("Mappings after Bottom-up: " .. #mappings .. " (+" .. (#mappings - before_bottom_up) .. " new)") -- local before_recovery = #mappings - mappings = core.recovery_match(root1, root2, mappings, src_info, dst_info, buf1, buf2) + local src_count = 0 + local dst_count = 0 + for _ in pairs(src_info) do + src_count = src_count + 1 + end + for _ in pairs(dst_info) do + dst_count = dst_count + 1 + end + local max_nodes = math.max(src_count, dst_count) + mappings = core.recovery_match(root1, root2, mappings, src_info, dst_info, buf1, buf2, { + recovery_lcs_cell_limit = max_nodes >= 25000 and 1500 or 6000, + adaptive_mode = true, + }) -- debug_utils.print_recovery_mappings(mappings, before_recovery, src_info, dst_info, buf1, buf2) - local actions = core.generate_actions(root1, root2, mappings, src_info, dst_info) + local actions = core.generate_actions(root1, root2, mappings, src_info, dst_info, { + src_buf = buf1, + dst_buf = buf2, + src_role_index = src_role_index, + dst_role_index = dst_role_index, + adaptive_mode = true, + }) + vim.diagnostic.enable(false, { bufnr = buf1 }) + vim.diagnostic.enable(false, { bufnr = buf2 }) -- debug_utils.print_actions(actions, buf1, buf2) -- debug_utils.print_mappings(mappings, src_info, dst_info, buf1, buf2) - ui.apply_highlights(buf1, buf2, actions) + ui.apply_highlights(buf1, buf2, actions, { mappings = mappings, src_info = src_info, dst_info = dst_info }) + + vim.api.nvim_win_set_cursor(win1, { 1, 0 }) + vim.api.nvim_win_set_cursor(win2, { 1, 0 }) + vim.wo[win1].scrollbind = true + vim.wo[win1].cursorbind = true + vim.wo[win2].scrollbind = true + vim.wo[win2].cursorbind = true end return M diff --git 
a/lua/diffmantic/treesitter.lua b/lua/diffmantic/treesitter.lua index 88b6e68..10bbb2b 100644 --- a/lua/diffmantic/treesitter.lua +++ b/lua/diffmantic/treesitter.lua @@ -10,6 +10,10 @@ local function string_hash(str) return h end +local function hash_combine(acc, value) + return ((acc * 33) + value + 97) % 4294967296 +end + -- A leaf node has no children (e.g., a variable name, number, string literal) local function is_leaf(node) return node:named_child_count() == 0 @@ -27,34 +31,54 @@ end -- Walk through the entire syntax tree and compute metadata for each node -- Returns a table mapping node IDs to their computed info -function M.preprocess_tree(root, bufnr) +function M.preprocess_tree(root, bufnr, opts) + opts = opts or {} local info = {} + local label_hash_cache = {} + local type_hash_cache = {} + + local function cached_string_hash(cache, text) + local value = cache[text] + if value == nil then + value = string_hash(text) + cache[text] = value + end + return value + end local function visit(node) local id = node:id() local type = node:type() local label = get_label(node, bufnr) + local sr, sc, er, ec = node:range() local height = 1 local size = 1 - local child_hashes = "" - local child_structure_hashes = "" + local hash_acc = cached_string_hash(type_hash_cache, type) + local structure_hash_acc = hash_acc -- Recursively process all children first (post-order traversal) for child in node:iter_children() do local child_info = visit(child) height = math.max(height, child_info.height + 1) size = size + child_info.size - child_hashes = child_hashes .. tostring(child_info.hash) - child_structure_hashes = child_structure_hashes .. 
tostring(child_info.structure_hash) - info[child:id()].parent = node + hash_acc = hash_combine(hash_acc, child_info.hash) + structure_hash_acc = hash_combine(structure_hash_acc, child_info.structure_hash) + child_info.parent = node + child_info.parent_id = id + end + + if label ~= "" then + hash_acc = hash_combine(hash_acc, cached_string_hash(label_hash_cache, label)) + else + hash_acc = hash_combine(hash_acc, 0) end -- hash: unique if type + label + children all match (exact match) - local hash = string_hash(type .. label .. child_hashes) + local hash = hash_acc -- structure_hash: unique if type + children structure match (ignores labels) -- useful for detecting moved/renamed code - local structure_hash = string_hash(type .. child_structure_hashes) + local structure_hash = structure_hash_acc info[id] = { node = node, @@ -65,6 +89,11 @@ function M.preprocess_tree(root, bufnr) type = type, label = label, id = id, + start_row = sr, + start_col = sc, + end_row = er, + end_col = ec, + parent_id = nil, } return info[id] end diff --git a/lua/diffmantic/ui.lua b/lua/diffmantic/ui.lua index a155533..c72b786 100644 --- a/lua/diffmantic/ui.lua +++ b/lua/diffmantic/ui.lua @@ -10,10 +10,10 @@ function M.clear_highlights(bufnr) end end -function M.apply_highlights(src_buf, dst_buf, actions) +function M.apply_highlights(src_buf, dst_buf, actions, opts) M.clear_highlights(src_buf) M.clear_highlights(dst_buf) - renderer.render(src_buf, dst_buf, actions, ns) + return renderer.render(src_buf, dst_buf, actions, ns, opts) end return M diff --git a/lua/diffmantic/ui/filler.lua b/lua/diffmantic/ui/filler.lua new file mode 100644 index 0000000..4a956f7 --- /dev/null +++ b/lua/diffmantic/ui/filler.lua @@ -0,0 +1,1013 @@ +local M = {} +local VIRT_LINE_LEN = 300 +local HL_ADD = "DiffmanticAddFiller" +local HL_DELETE = "DiffmanticDeleteFiller" +local HL_MOVE = "DiffmanticMoveFiller" + +local function make_virt_line(hl_group) + return { { string.rep("╱", VIRT_LINE_LEN), hl_group } } +end + 
+--- Return the number of lines a range spans (1-indexed count). +local function line_span(range) + if not range then + return 0 + end + local sr = range.start_row + local er = range.end_row + local ec = range.end_col + if sr == nil or er == nil or ec == nil then + return 0 + end + local count = er - sr + if ec > 0 then + count = count + 1 + end + if count <= 0 then + count = 1 + end + return count +end + +local function range_contains(outer, inner) + if not outer or not inner then + return false + end + if outer.start_row == nil or outer.end_row == nil or outer.start_col == nil or outer.end_col == nil then + return false + end + if inner.start_row == nil or inner.end_row == nil or inner.start_col == nil or inner.end_col == nil then + return false + end + if inner.start_row < outer.start_row or inner.end_row > outer.end_row then + return false + end + if inner.start_row == outer.start_row and inner.start_col < outer.start_col then + return false + end + if inner.end_row == outer.end_row and inner.end_col > outer.end_col then + return false + end + return true +end + +local function ranges_equal(a, b) + if not a or not b then + return false + end + return a.start_row == b.start_row + and a.start_col == b.start_col + and a.end_row == b.end_row + and a.end_col == b.end_col +end + +local function ranges_related(a, b) + return ranges_equal(a, b) or range_contains(a, b) or range_contains(b, a) +end + +local function hl_for_type(action_type) + if action_type == "insert" then + return HL_ADD + elseif action_type == "delete" then + return HL_DELETE + end + return HL_MOVE +end + +local function clamp_row(row, line_count) + row = tonumber(row) or 0 + if row < 0 then + return 0 + end + if row > line_count then + return line_count + end + return row +end + +local function count_trailing_blank_lines(buf, range) + if not buf or not range or range.end_row == nil then + return 0 + end + local line_count = vim.api.nvim_buf_line_count(buf) + if line_count <= 0 then + return 0 + end + 
local row = range.end_row + 1 + if row < 0 then + row = 0 + end + if row >= line_count then + return 0 + end + local lines = vim.api.nvim_buf_get_lines(buf, row, line_count, false) + local count = 0 + for _, line in ipairs(lines) do + if line:match("^%s*$") then + count = count + 1 + else + break + end + end + return count +end + +local function count_leading_blank_lines(buf, range) + if not buf or not range or range.start_row == nil then + return 0 + end + local row = range.start_row - 1 + if row < 0 then + return 0 + end + local count = 0 + while row >= 0 do + local line = vim.api.nvim_buf_get_lines(buf, row, row + 1, false)[1] or "" + if line:match("^%s*$") then + count = count + 1 + row = row - 1 + else + break + end + end + return count +end + +local function padding_mode_for_buf(buf) + if not buf then + return "default" + end + local ft = vim.bo[buf] and vim.bo[buf].filetype or "" + if ft == "python" then + return "python_bottom" + end + return "default" +end + +local function normalize_move_padding(leading, trailing, mode) + leading = math.max(0, tonumber(leading) or 0) + trailing = math.max(0, tonumber(trailing) or 0) + local gap = math.max(leading, trailing) + if gap <= 0 then + return 0, 0 + end + if leading > 0 and trailing == 0 then + local top = math.floor((gap + 1) / 2) + local bottom = gap - top + return top, bottom + end + if mode == "python_bottom" then + return 0, gap + end + local top = math.floor(gap / 2) + local bottom = gap - top + return top, bottom +end + +local function style_priority(style) + if not style then + return -1 + end + if style.transparent then + return 0 + end + local hl = style.hl_group + if hl == HL_ADD or hl == HL_DELETE then + return 2 + end + if hl == HL_MOVE then + return 1 + end + return 1 +end + +local function merge_style(existing, candidate) + if not existing then + return { + hl_group = candidate.hl_group, + transparent = candidate.transparent or false, + } + end + if candidate.transparent then + return existing + 
end + if existing.transparent then + return { + hl_group = candidate.hl_group, + transparent = false, + } + end + if style_priority(candidate) >= style_priority(existing) then + return { + hl_group = candidate.hl_group, + transparent = false, + } + end + return existing +end + +local function ensure_entry(side_rows, row) + local entry = side_rows[row] + if entry then + return entry + end + entry = { + row = row, + line_styles = {}, + } + side_rows[row] = entry + return entry +end + +local function find_first_group(line_styles, group) + for idx, style in ipairs(line_styles or {}) do + if style and not style.transparent and style.hl_group == group then + return idx + end + end + return nil +end + +local function find_last_group(line_styles, group) + local last = nil + for idx, style in ipairs(line_styles or {}) do + if style and not style.transparent and style.hl_group == group then + last = idx + end + end + return last +end + +local function add_filler_lines(side_rows, row, count, opts) + count = tonumber(count) or 0 + if count <= 0 then + return + end + opts = opts or {} + local entry = ensure_entry(side_rows, row) + local line_styles = entry.line_styles + local candidate + if opts.transparent then + candidate = { transparent = true } + else + candidate = { + hl_group = opts.hl_group or HL_MOVE, + transparent = false, + } + end + local offset = opts.offset + if offset == nil then + if opts.append_if_existing and not candidate.transparent then + local insert_at = #line_styles + 1 + if candidate.hl_group == HL_ADD then + local first_delete = find_first_group(line_styles, HL_DELETE) + if first_delete then + insert_at = first_delete + end + elseif candidate.hl_group == HL_DELETE then + local last_add = find_last_group(line_styles, HL_ADD) + if last_add then + insert_at = last_add + 1 + end + end + for i = 1, count do + table.insert(line_styles, insert_at + (i - 1), { + hl_group = candidate.hl_group, + transparent = false, + }) + end + return + end + offset = 
opts.append_if_existing and #line_styles or 0 + end + if offset < 0 then + offset = 0 + end + for i = 1, count do + local idx = offset + i + local last = #line_styles + if idx > last + 1 then + for gap = last + 1, idx - 1 do + line_styles[gap] = { transparent = true } + end + end + line_styles[idx] = merge_style(line_styles[idx], candidate) + end +end + +local function find_containing_move(moves, target, side_key) + for _, move in ipairs(moves) do + local container = move[side_key] + if range_contains(container, target) then + return move + end + end + return nil +end + +local function summed_contained_span(actions, action_type, side_key, container_range) + if not container_range then + return 0 + end + local total = 0 + for _, action in ipairs(actions or {}) do + local meta = action.metadata or {} + if action.type == action_type and not meta.render_as_change then + local target = action[side_key] + if range_contains(container_range, target) then + total = total + line_span(target) + end + end + end + return total +end + +local function uncovered_update_hunk_ranges(update_action, actions, kind, side_key) + local analysis = update_action and update_action.analysis or nil + local hunks = analysis and analysis.hunks or nil + if not hunks or #hunks == 0 then + return {} + end + local ranges = {} + for _, hunk in ipairs(hunks) do + if hunk.kind == kind and not hunk.render_as_change then + local target = hunk[side_key] + if target then + local covered = false + for _, action in ipairs(actions or {}) do + local meta = action.metadata or {} + if action.type == kind and not meta.render_as_change and action[side_key] then + if ranges_related(action[side_key], target) then + covered = true + break + end + end + end + if not covered then + table.insert(ranges, target) + end + end + end + end + table.sort(ranges, function(a, b) + if a.start_row ~= b.start_row then + return a.start_row < b.start_row + end + return (a.start_col or 0) < (b.start_col or 0) + end) + return ranges 
+end + +local function merge_adjacent_ranges(ranges) + if not ranges or #ranges == 0 then + return {} + end + table.sort(ranges, function(a, b) + if a.start_row ~= b.start_row then + return a.start_row < b.start_row + end + return (a.start_col or 0) < (b.start_col or 0) + end) + local out = {} + local current = nil + for _, r in ipairs(ranges) do + if not current then + current = { + start_row = r.start_row, + start_col = r.start_col, + end_row = r.end_row, + end_col = r.end_col, + } + else + local touch_or_overlap = r.start_row <= (current.end_row + 1) + if touch_or_overlap then + if + r.end_row > current.end_row + or (r.end_row == current.end_row and (r.end_col or 0) > (current.end_col or 0)) + then + current.end_row = r.end_row + current.end_col = r.end_col + end + else + table.insert(out, current) + current = { + start_row = r.start_row, + start_col = r.start_col, + end_row = r.end_row, + end_col = r.end_col, + } + end + end + end + if current then + table.insert(out, current) + end + return out +end + +local function is_projection_context(action) + if not action or not action.src or not action.dst then + return false + end + local atype = action.type + if atype == "insert" or atype == "delete" or atype == "rename" or atype == "move" then + return false + end + return true +end + +local function is_valid_range(range) + if not range then + return false + end + if range.start_row == nil or range.start_col == nil then + return false + end + if range.end_row == nil or range.end_col == nil then + return false + end + return true +end + +local function to_range(info) + if not info then + return nil + end + local range = { + start_row = info.start_row, + start_col = info.start_col, + end_row = info.end_row, + end_col = info.end_col, + } + if not is_valid_range(range) then + return nil + end + return range +end + +local function build_projection_contexts_from_actions(actions) + local contexts = {} + for _, action in ipairs(actions or {}) do + if 
is_projection_context(action) then + table.insert(contexts, { + src = action.src, + dst = action.dst, + }) + end + end + return contexts +end + +local function build_projection_contexts_from_mappings(mappings, src_info, dst_info) + local contexts = {} + local seen = {} + for _, mapping in ipairs(mappings or {}) do + local src = src_info and src_info[mapping.src] or nil + local dst = dst_info and dst_info[mapping.dst] or nil + if src and dst and src.parent_id ~= nil and dst.parent_id ~= nil then + local src_node = src.node + local dst_node = dst.node + if src_node and dst_node and src_node:named() and dst_node:named() then + local src_range = to_range(src) + local dst_range = to_range(dst) + if src_range and dst_range then + local key = table.concat({ + src_range.start_row, + src_range.start_col, + src_range.end_row, + src_range.end_col, + "|", + dst_range.start_row, + dst_range.start_col, + dst_range.end_row, + dst_range.end_col, + }, ":") + if not seen[key] then + seen[key] = true + table.insert(contexts, { + src = src_range, + dst = dst_range, + }) + end + end + end + end + end + return contexts +end + +local function sort_projection_contexts(contexts) + table.sort(contexts, function(a, b) + local as = (a.src and a.src.start_row) or 0 + local bs = (b.src and b.src.start_row) or 0 + if as ~= bs then + return as < bs + end + local ae = (a.src and a.src.end_row) or as + local be = (b.src and b.src.end_row) or bs + return ae < be + end) + return contexts +end + +local function build_projection_contexts(actions, mappings, src_info, dst_info) + local contexts = build_projection_contexts_from_mappings(mappings, src_info, dst_info) + if #contexts == 0 then + contexts = build_projection_contexts_from_actions(actions) + else + local action_contexts = build_projection_contexts_from_actions(actions) + for _, ctx in ipairs(action_contexts) do + table.insert(contexts, ctx) + end + end + return sort_projection_contexts(contexts) +end + +local function contains_row(range, row) + 
if not range or range.start_row == nil or range.end_row == nil then + return false + end + return row >= range.start_row and row <= range.end_row +end + +local function best_containing_context(contexts, from_key, row) + local best = nil + local best_span = math.huge + for _, ctx in ipairs(contexts or {}) do + local r = ctx[from_key] + if contains_row(r, row) then + local span = line_span(r) + if span < best_span then + best = ctx + best_span = span + end + end + end + return best +end + +local function project_within_range(from_range, to_range, row) + if not from_range or not to_range then + return row + end + local from_start = from_range.start_row or row + local from_end = from_range.end_row or from_start + local to_start = to_range.start_row or row + local to_end = to_range.end_row or to_start + local from_span = math.max(1, (from_end - from_start) + 1) + local to_span = math.max(1, (to_end - to_start) + 1) + if from_span == 1 or to_span == 1 then + local rel = row - from_start + if rel < 0 then + rel = 0 + end + if rel > to_span - 1 then + rel = to_span - 1 + end + return to_start + rel + end + local rel = row - from_start + if rel < 0 then + rel = 0 + elseif rel > from_span - 1 then + rel = from_span - 1 + end + local ratio = rel / (from_span - 1) + return math.floor(to_start + (ratio * (to_span - 1)) + 0.5) +end + +local function neighbor_projected_row(row, contexts, from_key, to_key, opts) + opts = opts or {} + local preserve_gap = opts.preserve_gap or false + local to_buf = opts.to_buf + + local function prev_anchor(prev_to) + if not prev_to or prev_to.end_row == nil then + return nil + end + local anchor = prev_to.end_row + 1 + if preserve_gap and to_buf then + anchor = anchor + count_trailing_blank_lines(to_buf, prev_to) + end + return anchor + end + + local function next_anchor(next_to) + if not next_to or next_to.start_row == nil then + return nil + end + local anchor = next_to.start_row + if preserve_gap and to_buf then + anchor = anchor - 
count_leading_blank_lines(to_buf, next_to) + end + return anchor + end + + local prev_ctx = nil + local next_ctx = nil + for _, ctx in ipairs(contexts) do + local from = ctx[from_key] + if from then + if from.end_row and from.end_row < row then + if not prev_ctx or from.end_row > ((prev_ctx[from_key] and prev_ctx[from_key].end_row) or -1) then + prev_ctx = ctx + end + end + if from.start_row and from.start_row > row then + if + not next_ctx + or from.start_row < ((next_ctx[from_key] and next_ctx[from_key].start_row) or math.huge) + then + next_ctx = ctx + end + end + end + end + + local raw_row + if prev_ctx and next_ctx then + local prev_from = prev_ctx[from_key] + local next_from = next_ctx[from_key] + local prev_to = prev_ctx[to_key] + local next_to = next_ctx[to_key] + local prev_distance = row - (prev_from and prev_from.end_row or row) + local next_distance = (next_from and next_from.start_row or row) - row + if next_distance < prev_distance and next_to and next_to.start_row ~= nil then + raw_row = next_anchor(next_to) + elseif prev_to and prev_to.end_row ~= nil then + raw_row = prev_anchor(prev_to) + end + end + if raw_row == nil and prev_ctx and prev_ctx[to_key] and prev_ctx[to_key].end_row ~= nil then + raw_row = prev_anchor(prev_ctx[to_key]) + elseif raw_row == nil and next_ctx and next_ctx[to_key] and next_ctx[to_key].start_row ~= nil then + raw_row = next_anchor(next_ctx[to_key]) + elseif raw_row == nil then + raw_row = row + end + return raw_row +end + +local function project_row_from_contexts(row, contexts, from_key, to_key, line_count, opts) + opts = opts or {} + if row == nil then + return 0 + end + if not contexts or #contexts == 0 then + return clamp_row(row, line_count) + end + local containing = nil + if not opts.prefer_neighbors then + containing = best_containing_context(contexts, from_key, row) + end + local raw_row + if containing then + raw_row = project_within_range(containing[from_key], containing[to_key], row) + else + raw_row = 
neighbor_projected_row(row, contexts, from_key, to_key, opts) + end + return clamp_row(raw_row, line_count) +end + +local function move_overlaps_context(action, contexts) + local src = action and action.src + local dst = action and action.dst + if not src or not dst then + return false + end + for _, ctx in ipairs(contexts or {}) do + if ranges_related(ctx.src, src) and ranges_related(ctx.dst, dst) then + return true + end + end + return false +end + +local function filter_contexts_for_move_anchor(contexts, move_action) + if not move_action then + return contexts or {} + end + local out = {} + for _, ctx in ipairs(contexts or {}) do + local same_pair = ranges_related(ctx.src, move_action.src) and ranges_related(ctx.dst, move_action.dst) + if not same_pair then + table.insert(out, ctx) + end + end + return out +end + +local function rows_to_fillers(side_rows) + local out = {} + for _, entry in pairs(side_rows) do + table.insert(out, { + row = entry.row, + line_styles = entry.line_styles, + }) + end + table.sort(out, function(a, b) + return a.row < b.row + end) + return out +end + +local function total_virtual_lines(side_rows) + local total = 0 + for _, entry in pairs(side_rows) do + total = total + #(entry.line_styles or {}) + end + return total +end + +function M.compute(actions, src_buf, dst_buf, opts) + opts = opts or {} + local src_line_count = src_buf and vim.api.nvim_buf_line_count(src_buf) or 0 + local dst_line_count = dst_buf and vim.api.nvim_buf_line_count(dst_buf) or 0 + local src_rows = {} + local dst_rows = {} + local moves = {} + local updates = {} + local move_layout = {} + local src_move_events = {} + local dst_move_events = {} + local projection_contexts = build_projection_contexts(actions, opts.mappings, opts.src_info, opts.dst_info) + + for _, action in ipairs(actions or {}) do + if action.type == "move" and action.src and action.dst then + table.insert(moves, action) + table.insert(src_move_events, action) + table.insert(dst_move_events, action) + 
elseif (action.type == "update" or action.type == "rename") and action.src and action.dst then + table.insert(updates, action) + end + end + + table.sort(src_move_events, function(a, b) + return a.dst.start_row < b.dst.start_row + end) + table.sort(dst_move_events, function(a, b) + return a.src.start_row < b.src.start_row + end) + + local src_shift = 0 + local single_src_move = #src_move_events == 1 + local src_move_padding_mode = padding_mode_for_buf(dst_buf) + for _, action in ipairs(src_move_events) do + local leading = count_leading_blank_lines(dst_buf, action.dst) + local trailing = count_trailing_blank_lines(dst_buf, action.dst) + local raw_row + if single_src_move then + local move_anchor_contexts = filter_contexts_for_move_anchor(projection_contexts, action) + raw_row = project_row_from_contexts( + action.dst.start_row, + move_anchor_contexts, + "dst", + "src", + src_line_count, + { preserve_gap = false, to_buf = src_buf, prefer_neighbors = true } + ) + if leading > 0 and trailing > 0 then + raw_row = clamp_row(raw_row + math.min(leading, trailing), src_line_count) + end + else + raw_row = clamp_row(action.dst.start_row, src_line_count) + end + local row = raw_row - src_shift + if row < 0 then + row = 0 + end + local entry = src_rows[row] + local base_offset = entry and #(entry.line_styles or {}) or 0 + local span = line_span(action.dst) + leading, trailing = normalize_move_padding(leading, trailing, src_move_padding_mode) + add_filler_lines(src_rows, row, leading, { offset = base_offset, transparent = true }) + add_filler_lines(src_rows, row, span, { offset = base_offset + leading, hl_group = HL_MOVE }) + add_filler_lines(src_rows, row, trailing, { offset = base_offset + leading + span, transparent = true }) + src_shift = src_shift + leading + span + trailing + move_layout[action] = move_layout[action] or {} + move_layout[action].src_row = row + move_layout[action].src_base_offset = base_offset + leading + end + + local dst_move_padding_mode = 
padding_mode_for_buf(src_buf) + for _, action in ipairs(dst_move_events) do + local move_anchor_contexts = filter_contexts_for_move_anchor(projection_contexts, action) + local row = project_row_from_contexts( + action.src.start_row, + move_anchor_contexts, + "src", + "dst", + dst_line_count, + { preserve_gap = true, to_buf = dst_buf } + ) + local entry = dst_rows[row] + local base_offset = entry and #(entry.line_styles or {}) or 0 + local span = line_span(action.src) + local leading = count_leading_blank_lines(src_buf, action.src) + local trailing = count_trailing_blank_lines(src_buf, action.src) + leading, trailing = normalize_move_padding(leading, trailing, dst_move_padding_mode) + add_filler_lines(dst_rows, row, leading, { offset = base_offset, transparent = true }) + add_filler_lines(dst_rows, row, span, { offset = base_offset + leading, hl_group = HL_MOVE }) + add_filler_lines(dst_rows, row, trailing, { offset = base_offset + leading + span, transparent = true }) + move_layout[action] = move_layout[action] or {} + move_layout[action].dst_row = row + move_layout[action].dst_base_offset = base_offset + leading + end + + for _, action in ipairs(actions or {}) do + local meta = action.metadata or {} + if meta.render_as_change then + goto continue + end + local atype = action.type + local src = action.src + local dst = action.dst + + if atype == "insert" and dst then + local span = line_span(dst) + local nested_move = find_containing_move(moves, dst, "dst") + if nested_move then + local layout = move_layout[nested_move] or {} + local row = layout.src_row or clamp_row(nested_move.dst.start_row, src_line_count) + local base_offset = layout.src_base_offset or 0 + local offset = base_offset + (dst.start_row - nested_move.dst.start_row) + add_filler_lines(src_rows, row, span, { hl_group = HL_ADD, offset = offset }) + else + local row = project_row_from_contexts( + dst.start_row, + projection_contexts, + "dst", + "src", + src_line_count, + { prefer_neighbors = true } + ) 
+ add_filler_lines(src_rows, row, span, { hl_group = HL_ADD, append_if_existing = true }) + end + elseif atype == "delete" and src then + local span = line_span(src) + local nested_move = find_containing_move(moves, src, "src") + if nested_move then + local layout = move_layout[nested_move] or {} + local row = layout.dst_row or clamp_row(nested_move.src.start_row, dst_line_count) + local base_offset = layout.dst_base_offset or 0 + local offset = base_offset + (src.start_row - nested_move.src.start_row) + add_filler_lines(dst_rows, row, span, { hl_group = HL_DELETE, offset = offset }) + else + local row = project_row_from_contexts( + src.start_row, + projection_contexts, + "src", + "dst", + dst_line_count, + { prefer_neighbors = true } + ) + add_filler_lines(dst_rows, row, span, { hl_group = HL_DELETE, append_if_existing = true }) + end + elseif (atype == "update" or atype == "rename") and src and dst then + local overlaps_move = move_overlaps_context(action, moves) + local src_lines = line_span(src) + local dst_lines = line_span(dst) + local uncovered_delete = + merge_adjacent_ranges(uncovered_update_hunk_ranges(action, actions, "delete", "src")) + local uncovered_insert = + merge_adjacent_ranges(uncovered_update_hunk_ranges(action, actions, "insert", "dst")) + local uncovered_delete_span = 0 + local uncovered_insert_span = 0 + + for _, hunk_src in ipairs(uncovered_delete) do + local span = line_span(hunk_src) + local nested_move = overlaps_move and find_containing_move(moves, hunk_src, "src") or nil + if nested_move then + local layout = move_layout[nested_move] or {} + local row = layout.dst_row or clamp_row(nested_move.src.start_row, dst_line_count) + local base_offset = layout.dst_base_offset or 0 + local offset = base_offset + (hunk_src.start_row - nested_move.src.start_row) + add_filler_lines(dst_rows, row, span, { hl_group = HL_DELETE, offset = offset }) + else + local row = project_row_from_contexts( + hunk_src.start_row, + projection_contexts, + "src", + 
"dst", + dst_line_count, + { prefer_neighbors = true } + ) + add_filler_lines(dst_rows, row, span, { hl_group = HL_DELETE, append_if_existing = true }) + end + uncovered_delete_span = uncovered_delete_span + span + end + + for _, hunk_dst in ipairs(uncovered_insert) do + local span = line_span(hunk_dst) + local nested_move = overlaps_move and find_containing_move(moves, hunk_dst, "dst") or nil + if nested_move then + local layout = move_layout[nested_move] or {} + local row = layout.src_row or clamp_row(nested_move.dst.start_row, src_line_count) + local base_offset = layout.src_base_offset or 0 + local offset = base_offset + (hunk_dst.start_row - nested_move.dst.start_row) + add_filler_lines(src_rows, row, span, { hl_group = HL_ADD, offset = offset }) + else + local row = project_row_from_contexts( + hunk_dst.start_row, + projection_contexts, + "dst", + "src", + src_line_count, + { prefer_neighbors = true } + ) + add_filler_lines(src_rows, row, span, { hl_group = HL_ADD, append_if_existing = true }) + end + uncovered_insert_span = uncovered_insert_span + span + end + + if overlaps_move then + goto continue + end + + if src_lines > dst_lines then + local delete_inside = summed_contained_span(actions, "delete", "src", src) + local remaining = (src_lines - dst_lines) - delete_inside - uncovered_delete_span + if remaining > 0 then + local row = clamp_row(dst.start_row + dst_lines, dst_line_count) + add_filler_lines(dst_rows, row, remaining, { hl_group = HL_DELETE, append_if_existing = true }) + end + elseif dst_lines > src_lines then + local insert_inside = summed_contained_span(actions, "insert", "dst", dst) + local remaining = (dst_lines - src_lines) - insert_inside - uncovered_insert_span + if remaining > 0 then + local row = clamp_row(src.start_row + src_lines, src_line_count) + add_filler_lines(src_rows, row, remaining, { hl_group = HL_ADD, append_if_existing = true }) + end + end + end + ::continue:: + end + + local src_total = total_virtual_lines(src_rows) + 
local dst_total = total_virtual_lines(dst_rows) + local src_visual = src_line_count + src_total + local dst_visual = dst_line_count + dst_total + if src_visual > dst_visual then + add_filler_lines(dst_rows, dst_line_count, src_visual - dst_visual, { + transparent = true, + append_if_existing = true, + }) + elseif dst_visual > src_visual then + add_filler_lines(src_rows, src_line_count, dst_visual - src_visual, { + transparent = true, + append_if_existing = true, + }) + end + + return rows_to_fillers(src_rows), rows_to_fillers(dst_rows) +end + +function M.apply(buf, ns, fillers) + if not buf or not ns or not fillers or #fillers == 0 then + return + end + local line_count = vim.api.nvim_buf_line_count(buf) + if line_count <= 0 then + return + end + for _, filler in ipairs(fillers) do + local row = clamp_row(filler.row or 0, line_count) + local virt_lines = {} + local line_styles = filler.line_styles + if line_styles and #line_styles > 0 then + for _, style in ipairs(line_styles) do + if style and style.transparent then + table.insert(virt_lines, {}) + else + local hl = (style and style.hl_group) or filler.hl_group or HL_MOVE + table.insert(virt_lines, make_virt_line(hl)) + end + end + else + local count = filler.count or 0 + local hl = filler.hl_group or HL_MOVE + for _ = 1, count do + table.insert(virt_lines, make_virt_line(hl)) + end + end + if #virt_lines > 0 then + pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, 0, { + virt_lines = virt_lines, + virt_lines_above = true, + }) + end + end +end + +M._private = { + line_span = line_span, + range_contains = range_contains, + clamp_row = clamp_row, + count_trailing_blank_lines = count_trailing_blank_lines, + count_leading_blank_lines = count_leading_blank_lines, + summed_contained_span = summed_contained_span, + add_filler_lines = add_filler_lines, + total_virtual_lines = total_virtual_lines, +} + +return M diff --git a/lua/diffmantic/ui/helpers.lua b/lua/diffmantic/ui/helpers.lua deleted file mode 100644 index 
f7ec8c6..0000000 --- a/lua/diffmantic/ui/helpers.lua +++ /dev/null @@ -1,513 +0,0 @@ -local M = {} - --- Leaf-level diffs for small updates; otherwise return empty. -function M.find_leaf_changes(src_node, dst_node, src_buf, dst_buf) - local changes = {} - - local function get_all_leaves(node, bufnr) - local leaves = {} - local function traverse(n) - if n:child_count() == 0 then - table.insert(leaves, { - node = n, - text = vim.treesitter.get_node_text(n, bufnr), - type = n:type(), - }) - else - for child in n:iter_children() do - traverse(child) - end - end - end - traverse(node) - return leaves - end - - local src_leaves = get_all_leaves(src_node, src_buf) - local dst_leaves = get_all_leaves(dst_node, dst_buf) - - if #src_leaves ~= #dst_leaves then - return {} - end - - if math.abs(#src_leaves - #dst_leaves) > 2 then - return {} - end - - local min_len = math.min(#src_leaves, #dst_leaves) - local max_len = math.max(#src_leaves, #dst_leaves) - local same_count = 0 - - for i = 1, min_len do - local sl, dl = src_leaves[i], dst_leaves[i] - if sl.type == dl.type and sl.text == dl.text then - same_count = same_count + 1 - elseif sl.type == dl.type and sl.text ~= dl.text then - table.insert(changes, { - src_node = sl.node, - dst_node = dl.node, - src_text = sl.text, - dst_text = dl.text, - }) - end - end - - local similarity = same_count / max_len - - if similarity < 0.5 then - return {} - end - - if #changes > 5 then - return {} - end - - return changes -end - -function M.node_in_field(parent, field_name, node) - local nodes = parent:field(field_name) - if not nodes then - return false - end - for _, field_node in ipairs(nodes) do - if field_node:equal(node) or field_node:child_with_descendant(node) then - return true - end - end - return false -end - -function M.is_rename_identifier(node) - if not node then - return false - end - - local node_type = node:type() - if node_type ~= "identifier" and node_type ~= "type_identifier" and node_type ~= "field_identifier" then - 
return false - end - - local parent = node:parent() - if not parent then - return false - end - - local parent_type = parent:type() - if parent_type == "parameters" or parent_type == "parameter_list" or parent_type == "formal_parameters" then - return true - end - - if parent_type == "assignment" and M.node_in_field(parent, "left", node) then - return true - end - - if parent_type == "assignment_statement" and M.node_in_field(parent, "variable", node) then - return true - end - if parent_type == "variable_list" then - return true - end - - -- Language-specific name heuristics. - local current = node - while parent do - local ptype = parent:type() - if (ptype == "function_declaration" or ptype == "function_definition" or ptype == "class_definition" or ptype == "class_declaration") - and M.node_in_field(parent, "name", current) - then - return true - end - if (ptype == "class_specifier" or ptype == "struct_specifier" or ptype == "enum_specifier" or ptype == "union_specifier") - and (M.node_in_field(parent, "name", current) or M.node_in_field(parent, "tag", current)) - then - return true - end - if ptype == "function_declarator" then - return true - end - if ptype == "init_declarator" and M.node_in_field(parent, "declarator", current) then - return true - end - if ptype == "field_declaration" and M.node_in_field(parent, "declarator", current) then - return true - end - if ptype == "declarator" then - return true - end - if ptype == "field" then - return true - end - current = parent - parent = parent:parent() - end - - return false -end - -function M.is_value_node(node, text) - local node_type = node and node:type() or "" - if node_type:find("string") or node_type:find("number") or node_type:find("integer") or node_type:find("float") or node_type:find("boolean") then - return true - end - if node_type == "char_literal" or node_type:find("char") and node_type:find("literal") then - return true - end - if text then - if text:match("^%s*['\"].*['\"]%s*$") then - return 
true - end - if text:match("^%s*[%d%.]+%s*$") then - return true - end - if text == "true" or text == "false" or text == "nil" then - return true - end - end - return false -end - -local function expand_word_fragment(text, start_idx, end_idx) - local s = start_idx - local e = end_idx - while s > 1 and text:sub(s - 1, s - 1):match("[%w_]") do - s = s - 1 - end - while e <= #text and text:sub(e, e):match("[%w_]") do - e = e + 1 - end - return s, e - 1 -end - -function M.diff_fragment(old_text, new_text) - if old_text == new_text then - return nil - end - - local max_prefix = math.min(#old_text, #new_text) - local prefix = 0 - while prefix < max_prefix and old_text:sub(prefix + 1, prefix + 1) == new_text:sub(prefix + 1, prefix + 1) do - prefix = prefix + 1 - end - - local max_suffix = math.min(#old_text - prefix, #new_text - prefix) - local suffix = 0 - while suffix < max_suffix do - local o = old_text:sub(#old_text - suffix, #old_text - suffix) - local n = new_text:sub(#new_text - suffix, #new_text - suffix) - if o ~= n then - break - end - suffix = suffix + 1 - end - - local old_start = prefix + 1 - local old_end = #old_text - suffix - local new_start = prefix + 1 - local new_end = #new_text - suffix - - if old_start > old_end or new_start > new_end then - return nil - end - - old_start, old_end = expand_word_fragment(old_text, old_start, old_end) - new_start, new_end = expand_word_fragment(new_text, new_start, new_end) - - return { - old_start = old_start, - old_end = old_end, - new_start = new_start, - new_end = new_end, - old_fragment = old_text:sub(old_start, old_end), - new_fragment = new_text:sub(new_start, new_end), - } -end - -function M.set_inline_virt_text(buf, ns, row, col, text, hl) - local opts = { - virt_text = { { text, hl } }, - virt_text_pos = "inline", - } - local ok = pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, col, opts) - if ok then - return - end - opts.virt_text_pos = "eol" - pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, col, opts) 
-end - -function M.highlight_internal_diff(src_node, dst_node, src_buf, dst_buf, ns, opts) - local src_text = vim.treesitter.get_node_text(src_node, src_buf) - local dst_text = vim.treesitter.get_node_text(dst_node, dst_buf) - if not src_text or not dst_text or src_text == "" or dst_text == "" then - return false - end - - local src_lines = vim.split(src_text, "\n", { plain = true }) - local dst_lines = vim.split(dst_text, "\n", { plain = true }) - - local ok, hunks = pcall(vim.text.diff, src_text, dst_text, { - result_type = "indices", - linematch = 60, - }) - - local sr, _, er, _ = src_node:range() - local tr, _, ter, _ = dst_node:range() - local src_end = er - 1 - local dst_end = ter - 1 - local signs_src = opts and opts.signs_src or nil - local signs_dst = opts and opts.signs_dst or nil - local rename_map = opts and opts.rename_map or nil - - local function mark_fragment(buf, row, start_col, end_col, hl_group) - if row < 0 or end_col <= start_col then - return false - end - return pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, start_col, { - end_row = row, - end_col = end_col, - hl_group = hl_group, - }) - end - - local function mark_sign(buf, row, text, hl_group, sign_rows) - if row < 0 then - return false - end - if sign_rows and sign_rows[row] then - return false - end - local ok = pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, 0, { - sign_text = text, - sign_hl_group = hl_group, - }) - if ok and sign_rows then - sign_rows[row] = true - end - return ok - end - - local function tokenize_line(text) - local tokens = {} - local i = 1 - local len = #text - while i <= len do - local ch = text:sub(i, i) - if ch:match("%s") then - i = i + 1 - elseif ch:match("[%w_]") then - local j = i + 1 - while j <= len and text:sub(j, j):match("[%w_]") do - j = j + 1 - end - table.insert(tokens, { text = text:sub(i, j - 1), start_col = i, end_col = j - 1 }) - i = j - else - local j = i + 1 - while j <= len and not text:sub(j, j):match("[%w_%s]") do - j = j + 1 - end - 
table.insert(tokens, { text = text:sub(i, j - 1), start_col = i, end_col = j - 1 }) - i = j - end - end - return tokens - end - - local function tokens_equal(a, b) - if a.text == b.text then - return true - end - if rename_map and rename_map[a.text] == b.text then - return true - end - return false - end - - local function lcs_matches(a, b) - local n = #a - local m = #b - if n == 0 or m == 0 then - return {}, {} - end - local dp = {} - for i = 0, n do - dp[i] = {} - dp[i][0] = 0 - end - for j = 1, m do - dp[0][j] = 0 - end - for i = 1, n do - for j = 1, m do - if tokens_equal(a[i], b[j]) then - dp[i][j] = dp[i - 1][j - 1] + 1 - else - local up = dp[i - 1][j] - local left = dp[i][j - 1] - dp[i][j] = (up >= left) and up or left - end - end - end - local match_a = {} - local match_b = {} - local i = n - local j = m - while i > 0 and j > 0 do - if tokens_equal(a[i], b[j]) then - match_a[i] = true - match_b[j] = true - i = i - 1 - j = j - 1 - else - local up = dp[i - 1][j] - local left = dp[i][j - 1] - if up >= left then - i = i - 1 - else - j = j - 1 - end - end - end - return match_a, match_b - end - - local function mark_full_line(buf, row, hl_group) - if row < 0 then - return false - end - return pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, 0, { - end_row = row + 1, - end_col = 0, - hl_group = hl_group, - hl_eol = true, - }) - end - - local function highlight_line_pair(src_row, dst_row, s_line, d_line) - if s_line and d_line and s_line ~= d_line then - local tokens_src = tokenize_line(s_line) - local tokens_dst = tokenize_line(d_line) - if #tokens_src > 0 or #tokens_dst > 0 then - local match_src, match_dst = lcs_matches(tokens_src, tokens_dst) - local did_src = false - local did_dst = false - if src_row <= src_end then - for i, tok in ipairs(tokens_src) do - if not match_src[i] then - did_src = mark_fragment(src_buf, src_row, tok.start_col - 1, tok.end_col, "DiffChangeText") or did_src - end - end - if did_src then - mark_sign(src_buf, src_row, "U", 
"DiffChangeText", signs_src) - end - end - if dst_row <= dst_end then - for i, tok in ipairs(tokens_dst) do - if not match_dst[i] then - did_dst = mark_fragment(dst_buf, dst_row, tok.start_col - 1, tok.end_col, "DiffChangeText") or did_dst - end - end - if did_dst then - mark_sign(dst_buf, dst_row, "U", "DiffChangeText", signs_dst) - end - end - if did_src or did_dst then - return true - end - return false - end - local fragment = M.diff_fragment(s_line, d_line) - if fragment then - local did = false - if src_row <= src_end then - did = mark_fragment(src_buf, src_row, fragment.old_start - 1, fragment.old_end, "DiffChangeText") or did - mark_sign(src_buf, src_row, "U", "DiffChangeText", signs_src) - end - if dst_row <= dst_end then - did = mark_fragment(dst_buf, dst_row, fragment.new_start - 1, fragment.new_end, "DiffChangeText") or did - mark_sign(dst_buf, dst_row, "U", "DiffChangeText", signs_dst) - end - return did - end - local did = false - if src_row <= src_end then - did = mark_full_line(src_buf, src_row, "DiffChangeText") or did - mark_sign(src_buf, src_row, "U", "DiffChangeText", signs_src) - end - if dst_row <= dst_end then - did = mark_full_line(dst_buf, dst_row, "DiffChangeText") or did - mark_sign(dst_buf, dst_row, "U", "DiffChangeText", signs_dst) - end - return did - end - if s_line and not d_line then - if src_row <= src_end then - mark_sign(src_buf, src_row, "-", "DiffDeleteText", signs_src) - return mark_full_line(src_buf, src_row, "DiffDeleteText") - end - elseif d_line and not s_line then - if dst_row <= dst_end then - mark_sign(dst_buf, dst_row, "+", "DiffAddText", signs_dst) - return mark_full_line(dst_buf, dst_row, "DiffAddText") - end - end - return false - end - - local did_highlight = false - - if ok and hunks and #hunks > 0 then - for _, h in ipairs(hunks) do - local start_a, count_a, start_b, count_b = h[1], h[2], h[3], h[4] - local overlap = math.min(count_a, count_b) - - for i = 0, overlap - 1 do - local src_row = sr + start_a - 1 + i - 
local dst_row = tr + start_b - 1 + i - local s_line = src_lines[start_a + i] - local d_line = dst_lines[start_b + i] - if highlight_line_pair(src_row, dst_row, s_line, d_line) then - did_highlight = true - end - end - - if count_a > overlap then - for i = overlap, count_a - 1 do - local src_row = sr + start_a - 1 + i - if src_row <= src_end then - mark_sign(src_buf, src_row, "-", "DiffDeleteText", signs_src) - did_highlight = mark_full_line(src_buf, src_row, "DiffDeleteText") or did_highlight - end - end - end - - if count_b > overlap then - for i = overlap, count_b - 1 do - local dst_row = tr + start_b - 1 + i - if dst_row <= dst_end then - mark_sign(dst_buf, dst_row, "+", "DiffAddText", signs_dst) - did_highlight = mark_full_line(dst_buf, dst_row, "DiffAddText") or did_highlight - end - end - end - end - - return did_highlight - end - - local max_lines = math.max(#src_lines, #dst_lines) - for i = 1, max_lines do - local src_row = sr + i - 1 - local dst_row = tr + i - 1 - local s_line = src_lines[i] - local d_line = dst_lines[i] - if highlight_line_pair(src_row, dst_row, s_line, d_line) then - did_highlight = true - end - end - - return did_highlight -end - -return M diff --git a/lua/diffmantic/ui/renderer.lua b/lua/diffmantic/ui/renderer.lua index df6b7f8..aaa65eb 100644 --- a/lua/diffmantic/ui/renderer.lua +++ b/lua/diffmantic/ui/renderer.lua @@ -1,331 +1,329 @@ -local helpers = require("diffmantic.ui.helpers") +local signs = require("diffmantic.ui.signs") +local filler = require("diffmantic.ui.filler") local M = {} -function M.render(src_buf, dst_buf, actions, ns) - -- Suppress insert/delete inside moved/updated ranges. 
- local src_suppress = {} - local dst_suppress = {} - local rename_map = {} +local HL_PRIORITY = { + DiffmanticMove = 10, + DiffmanticAdd = 20, + DiffmanticDelete = 20, + DiffmanticChange = 30, + DiffmanticChangeAccent = 35, + DiffmanticRename = 40, +} - local function add_range(ranges, node) - if not node then - return - end - local sr, _, er, _ = node:range() - table.insert(ranges, { start_row = sr, end_row = er }) +local function set_extmark(buf, ns, row, col, opts) + if opts and opts.hl_group and not opts.priority then + opts.priority = HL_PRIORITY[opts.hl_group] or 20 end + return pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, col, opts) +end - for _, action in ipairs(actions) do - if action.type == "move" or action.type == "update" then - add_range(src_suppress, action.node) - add_range(dst_suppress, action.target) - end +local function apply_span(buf, ns, range, hl_group) + if not range or not hl_group then + return + end + local sr = range.start_row + local sc = range.start_col + local er = range.end_row or sr + local ec = range.end_col + if sr == nil or sc == nil or ec == nil then + return + end + if sr == er and ec <= sc then + ec = sc + 1 end + set_extmark(buf, ns, sr, sc, { + end_row = er, + end_col = ec, + hl_group = hl_group, + }) +end - for _, action in ipairs(actions) do - if action.type == "update" then - local leaf_changes = helpers.find_leaf_changes(action.node, action.target, src_buf, dst_buf) - for _, change in ipairs(leaf_changes) do - if helpers.is_rename_identifier(change.src_node) or helpers.is_rename_identifier(change.dst_node) then - rename_map[change.src_text] = change.dst_text - end - end - end +local function apply_sign(buf, ns, row, text, hl_group, seen_rows) + if row == nil or not text or not hl_group then + return end + signs.mark(buf, ns, row, 0, text, hl_group, seen_rows) +end - local function is_suppressed(ranges, node) - if not node then - return false - end - local sr, _, er, _ = node:range() - for _, range in ipairs(ranges) 
do - if sr >= range.start_row and er <= range.end_row then - return true - end - end +local function ranges_overlap(a, b) + if not a or not b then + return false + end + if a.start_row == nil or a.end_row == nil or a.start_col == nil or a.end_col == nil then + return false + end + if b.start_row == nil or b.end_row == nil or b.start_col == nil or b.end_col == nil then + return false + end + if a.end_row < b.start_row or b.end_row < a.start_row then return false end + if a.start_row == b.end_row and a.start_col >= b.end_col then + return false + end + if b.start_row == a.end_row and b.start_col >= a.end_col then + return false + end + return true +end - for _, action in ipairs(actions) do - local node = action.node - local sr, sc, er, ec = node:range() +local function overlaps_any(range, ranges) + if not range or not ranges then + return false + end + for _, candidate in ipairs(ranges) do + if ranges_overlap(range, candidate) then + return true + end + end + return false +end - if action.type == "move" then - local target = action.target - local tr, tc, ter, tec = target:range() - local src_line = sr + 1 - local dst_line = tr + 1 +local function apply_virt(buf, ns, row, col, text, hl_group, pos) + if row == nil or not text then + return + end + local opts = { + virt_text = { { text, hl_group or "Comment" } }, + virt_text_pos = pos or "eol", + } + local ok = set_extmark(buf, ns, row, col or 0, opts) + if not ok and opts.virt_text_pos == "inline" then + opts.virt_text_pos = "eol" + set_extmark(buf, ns, row, col or 0, opts) + end +end - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, sr, sc, { - end_row = er, - end_col = ec, - hl_group = "DiffMoveText", - virt_text = { { string.format(" ⤷ moved L%d → L%d", src_line, dst_line), "Comment" } }, - virt_text_pos = "eol", - sign_text = "M", - sign_hl_group = "DiffMoveText", - }) - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, tr, tc, { - end_row = ter, - end_col = tec, - hl_group = "DiffMoveText", - virt_text = { { 
string.format(" ⤶ from L%d", src_line), "Comment" } }, - virt_text_pos = "eol", - sign_text = "M", - sign_hl_group = "DiffMoveText", - }) - elseif action.type == "update" then - local target = action.target - local tr, tc, ter, tec = target:range() +local TYPE_STYLE = { + move = { hl = "DiffmanticMove", sign = "M" }, + rename = { hl = "DiffmanticRename", sign = "R" }, + update = { hl = "DiffmanticChange", sign = "U" }, + insert = { hl = "DiffmanticAdd", sign = "+" }, + delete = { hl = "DiffmanticDelete", sign = "-" }, +} - local leaf_changes = helpers.find_leaf_changes(node, target, src_buf, dst_buf) +local HUNK_STYLE = { + change = { + src_hl = "DiffmanticChange", + dst_hl = "DiffmanticChange", + src_sign = "U", + dst_sign = "U", + }, + insert = { + src_hl = nil, + dst_hl = "DiffmanticAdd", + src_sign = nil, + dst_sign = "+", + }, + delete = { + src_hl = "DiffmanticDelete", + dst_hl = nil, + src_sign = "-", + dst_sign = nil, + }, +} - local signs_src = {} - local signs_dst = {} - if #leaf_changes > 0 then - local rename_signs = {} - local rename_signs_src = {} - local rename_inline_src = {} - local rename_inline_dst = {} - local update_signs_dst = {} - local update_signs_src = {} - local rename_pairs = {} +local function range_text(buf, range) + if not buf or not range then + return nil + end + if range.start_row == nil or range.end_row == nil or range.start_col == nil or range.end_col == nil then + return nil + end + if range.start_row ~= range.end_row then + return nil + end + local line = vim.api.nvim_buf_get_lines(buf, range.start_row, range.start_row + 1, false)[1] or "" + if line == "" then + return nil + end + local start_col = range.start_col + 1 + local end_col = range.end_col + if end_col < start_col then + return nil + end + return line:sub(start_col, end_col) +end - for src_text, dst_text in pairs(rename_map) do - rename_pairs[src_text] = dst_text - end +local function hunk_is_effective_non_rename(hunk, rename_pairs, src_buf, dst_buf) + if not hunk 
then + return false + end + if hunk.kind == "insert" or hunk.kind == "delete" then + return true + end + if hunk.kind ~= "change" then + return false + end + local src_text = range_text(src_buf, hunk.src) + local dst_text = range_text(dst_buf, hunk.dst) + if not src_text or not dst_text then + return true + end + return src_text ~= dst_text and rename_pairs[src_text] ~= dst_text +end - for _, change in ipairs(leaf_changes) do - local src_node = change.src_node - local dst_node = change.dst_node - if helpers.is_rename_identifier(src_node) or helpers.is_rename_identifier(dst_node) then - rename_pairs[change.src_text] = change.dst_text - end - end +local function effective_update_hunks(action, src_buf, dst_buf) + local analysis = action and action.analysis or nil + local hunks = analysis and analysis.hunks or nil + if not hunks or #hunks == 0 then + return {} + end + local rename_pairs = analysis.rename_pairs or {} + local effective = {} + for _, hunk in ipairs(hunks) do + if hunk_is_effective_non_rename(hunk, rename_pairs, src_buf, dst_buf) then + table.insert(effective, hunk) + end + end + return effective +end - for _, change in ipairs(leaf_changes) do - local src_node = change.src_node - local dst_node = change.dst_node - local ctr, ctc, cter, ctec = dst_node:range() - local csr, csc, cser, csec = src_node:range() +local function move_to_arrow(from_line, to_line) + if type(from_line) ~= "number" or type(to_line) ~= "number" then + return "⤴" + end + if to_line > from_line then + return "⤵" + end + return "⤴" +end - local is_rename_ref = false - if not (helpers.is_rename_identifier(src_node) or helpers.is_rename_identifier(dst_node)) then - local src_type = src_node:type() - local dst_type = dst_node:type() - if - ( - src_type == "identifier" - or src_type == "field_identifier" - or src_type == "type_identifier" - ) - and (dst_type == "identifier" or dst_type == "field_identifier" or dst_type == "type_identifier") - and ( - rename_pairs[change.src_text] == 
change.dst_text - or rename_map[change.src_text] == change.dst_text - ) - then - is_rename_ref = true - end - end +function M.render(src_buf, dst_buf, actions, ns, opts) + local src_sign_rows = {} + local dst_sign_rows = {} + local src_move_ranges = {} + local dst_move_ranges = {} + local src_fillers, dst_fillers = filler.compute(actions, src_buf, dst_buf, opts) - if is_rename_ref then - -- Identifier usage changed only due to rename; ignore to avoid noise. - goto continue_leaf - end + filler.apply(src_buf, ns, src_fillers) + filler.apply(dst_buf, ns, dst_fillers) - if helpers.is_rename_identifier(src_node) or helpers.is_rename_identifier(dst_node) then - -- Rename: highlight identifier with inline "was"/"->". - if not rename_signs_src[csr] then - rename_signs_src[csr] = true - signs_src[csr] = true - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc, { - end_row = cser, - end_col = csec, - hl_group = "DiffRenameText", - sign_text = "R", - sign_hl_group = "DiffRenameText", - }) - else - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc, { - end_row = cser, - end_col = csec, - hl_group = "DiffChangeText", - }) - end + for _, action in ipairs(actions) do + if action.type == "move" then + if action.src then + table.insert(src_move_ranges, action.src) + end + if action.dst then + table.insert(dst_move_ranges, action.dst) + end + end + end - if not rename_signs[ctr] then - rename_signs[ctr] = true - signs_dst[ctr] = true - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc, { - end_row = cter, - end_col = ctec, - hl_group = "DiffRenameText", - sign_text = "R", - sign_hl_group = "DiffRenameText", - }) - else - signs_dst[ctr] = true - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc, { - end_row = cter, - end_col = ctec, - hl_group = "DiffChangeText", - }) - end - local src_key = tostring(src_node:id()) - if not rename_inline_src[src_key] then - rename_inline_src[src_key] = true - helpers.set_inline_virt_text(src_buf, ns, csr, csec, " -> " 
.. change.dst_text, "Comment") - end - local dst_key = tostring(dst_node:id()) - if not rename_inline_dst[dst_key] then - rename_inline_dst[dst_key] = true - helpers.set_inline_virt_text( - dst_buf, - ns, - ctr, - ctec, - string.format(" (was %s)", change.src_text), - "Comment" - ) - end - elseif - helpers.is_value_node(src_node, change.src_text) - or helpers.is_value_node(dst_node, change.dst_text) - then - -- Value change: micro-diff only (no virtual text). - local fragment = helpers.diff_fragment(change.src_text, change.dst_text) - if fragment then - local rel_start = fragment.new_start - 1 - local rel_end = fragment.new_end - if cter == ctr then - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc + rel_start, { - end_row = cter, - end_col = ctc + rel_end, - hl_group = "DiffChange", - }) - end - if cser == csr then - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc + fragment.old_start - 1, { - end_row = cser, - end_col = csc + fragment.old_end, - hl_group = "DiffChange", - }) - end + for _, action in ipairs(actions) do + local base_style = TYPE_STYLE[action.type] + if base_style then + local src = action.src + local dst = action.dst + local meta = action.metadata or {} + local style = base_style - if not update_signs_dst[ctr] then - update_signs_dst[ctr] = true - signs_dst[ctr] = true - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc + rel_end, { - sign_text = "U", - sign_hl_group = "DiffChangeText", - }) - end - if not update_signs_src[csr] then - update_signs_src[csr] = true - signs_src[csr] = true - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc + fragment.old_end, { - sign_text = "U", - sign_hl_group = "DiffChangeText", - }) - end - else - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc, { - end_row = cter, - end_col = ctec, - hl_group = "DiffChange", - sign_text = "U", - sign_hl_group = "DiffChangeText", - }) - if cser == csr then - if not update_signs_src[csr] then - update_signs_src[csr] = true - 
signs_src[csr] = true - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc, { - end_row = cser, - end_col = csec, - hl_group = "DiffChange", - sign_text = "U", - sign_hl_group = "DiffChangeText", - }) - else - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc, { - end_row = cser, - end_col = csec, - hl_group = "DiffChange", - }) - end - end - end - else - if cser >= csr then - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc, { - end_row = cser, - end_col = csec, - hl_group = "DiffChangeText", - }) - end - if cter >= ctr then - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc, { - end_row = cter, - end_col = ctec, - hl_group = "DiffChangeText", - }) - end - if not update_signs_src[csr] then - update_signs_src[csr] = true - signs_src[csr] = true - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, csr, csc, { - sign_text = "U", - sign_hl_group = "DiffChangeText", - }) + if action.type == "update" then + local effective_hunks = effective_update_hunks(action, src_buf, dst_buf) + if #effective_hunks == 0 then + goto continue + end + local rendered_hunk = false + for _, hunk in ipairs(effective_hunks) do + local hstyle = HUNK_STYLE[hunk.kind] or HUNK_STYLE.change + if hunk.render_as_change then + hstyle = HUNK_STYLE.change + end + if hunk.src and hstyle.src_hl then + apply_span(src_buf, ns, hunk.src, hstyle.src_hl) + if hstyle.src_hl == "DiffmanticChange" and overlaps_any(hunk.src, src_move_ranges) then + -- Add a foreground/underline accent so updates stay visible over moved regions. 
+ apply_span(src_buf, ns, hunk.src, "DiffmanticChangeAccent") end - if not update_signs_dst[ctr] then - update_signs_dst[ctr] = true - signs_dst[ctr] = true - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, ctr, ctc, { - sign_text = "U", - sign_hl_group = "DiffChangeText", - }) + apply_sign(src_buf, ns, hunk.src.start_row, hstyle.src_sign, hstyle.src_hl, src_sign_rows) + rendered_hunk = true + end + if hunk.dst and hstyle.dst_hl then + apply_span(dst_buf, ns, hunk.dst, hstyle.dst_hl) + if hstyle.dst_hl == "DiffmanticChange" and overlaps_any(hunk.dst, dst_move_ranges) then + -- Add a foreground/underline accent so updates stay visible over moved regions. + apply_span(dst_buf, ns, hunk.dst, "DiffmanticChangeAccent") end + apply_sign(dst_buf, ns, hunk.dst.start_row, hstyle.dst_sign, hstyle.dst_hl, dst_sign_rows) + rendered_hunk = true end - - ::continue_leaf:: + end + if not rendered_hunk then + goto continue end else - helpers.highlight_internal_diff(node, target, src_buf, dst_buf, ns, { - signs_src = signs_src, - signs_dst = signs_dst, - rename_map = rename_map, - }) - end - elseif action.type == "delete" then - if is_suppressed(src_suppress, node) and node:type() ~= "field_declaration" then - goto continue_action - end - pcall(vim.api.nvim_buf_set_extmark, src_buf, ns, sr, sc, { - end_row = er, - end_col = ec, - hl_group = "DiffDeleteText", - sign_text = "-", - sign_hl_group = "DiffDeleteText", - }) - elseif action.type == "insert" then - if is_suppressed(dst_suppress, node) and node:type() ~= "field_declaration" then - goto continue_action + if (action.type == "insert" or action.type == "delete") and meta.render_as_change then + -- render_as_change inserts/deletes are represented by update hunks only. 
+ goto continue + end + if src then + apply_span(src_buf, ns, src, style.hl) + apply_sign(src_buf, ns, src.start_row, style.sign, style.hl, src_sign_rows) + end + if dst then + apply_span(dst_buf, ns, dst, style.hl) + apply_sign(dst_buf, ns, dst.start_row, style.sign, style.hl, dst_sign_rows) + end + + if action.type == "move" then + if src and meta.to_line then + local arrow = move_to_arrow(meta.from_line, meta.to_line) + apply_virt( + src_buf, + ns, + src.start_row, + src.end_col or 0, + string.format(" %s moved to L%d", arrow, meta.to_line), + "Comment", + "eol" + ) + end + if dst and meta.from_line then + apply_virt( + dst_buf, + ns, + dst.start_row, + dst.end_col or 0, + string.format(" ⤶ from L%d", meta.from_line), + "Comment", + "eol" + ) + end + elseif action.type == "rename" then + if src and meta.new_name then + apply_virt( + src_buf, + ns, + src.start_row, + src.end_col or 0, + " -> " .. meta.new_name, + "Comment", + "inline" + ) + end + if dst and meta.old_name then + apply_virt( + dst_buf, + ns, + dst.start_row, + dst.end_col or 0, + string.format(" (was %s)", meta.old_name), + "Comment", + "inline" + ) + end + end end - pcall(vim.api.nvim_buf_set_extmark, dst_buf, ns, sr, sc, { - end_row = er, - end_col = ec, - hl_group = "DiffAddText", - sign_text = "+", - sign_hl_group = "DiffAddText", - }) + ::continue:: end - - ::continue_action:: end + + return { + src_fillers = src_fillers, + dst_fillers = dst_fillers, + } end return M diff --git a/lua/diffmantic/ui/signs.lua b/lua/diffmantic/ui/signs.lua new file mode 100644 index 0000000..e194fe5 --- /dev/null +++ b/lua/diffmantic/ui/signs.lua @@ -0,0 +1,88 @@ +local M = {} + +local SIGN_GROUP_BY_TEXT_GROUP = { + DiffmanticAdd = "DiffmanticAddSign", + DiffmanticDelete = "DiffmanticDeleteSign", + DiffmanticChange = "DiffmanticChangeSign", + DiffmanticMove = "DiffmanticMoveSign", + DiffmanticRename = "DiffmanticRenameSign", +} + +local SIGN_PRIORITY_BY_TEXT_GROUP = { + DiffmanticAdd = 40, + DiffmanticDelete = 
40, + DiffmanticChange = 20, + DiffmanticMove = 10, + DiffmanticRename = 30, +} + +function M.glyph() + return vim.g.diffmantic_side_sign_glyph or "▎" +end + +function M.style() + return vim.g.diffmantic_sign_style or "both" +end + +local function normalize_sign_char(text) + if not text or text == "" then + return nil + end + return text:sub(1, 1) +end + +function M.sign_text(text) + local sign_char = normalize_sign_char(text) + if not sign_char then + return nil + end + + local style = M.style() + if style == "letter" then + return sign_char + end + if style == "gutter" then + return M.glyph() + end + return M.glyph() .. sign_char +end + +function M.group_for_hl(hl_group) + return SIGN_GROUP_BY_TEXT_GROUP[hl_group] or hl_group +end + +function M.priority_for_hl(hl_group) + return SIGN_PRIORITY_BY_TEXT_GROUP[hl_group] or 0 +end + +function M.mark(buf, ns, row, col, text, hl_group, sign_rows) + if row == nil or row < 0 then + return false + end + if not text or text == "" then + return false + end + + local priority = M.priority_for_hl(hl_group) + local existing_priority = -1 + if sign_rows and sign_rows[row] then + local existing = sign_rows[row] + existing_priority = type(existing) == "number" and existing or 0 + end + if existing_priority >= priority then + return false + end + + local ok = pcall(vim.api.nvim_buf_set_extmark, buf, ns, row, col or 0, { + sign_text = M.sign_text(text), + sign_hl_group = M.group_for_hl(hl_group), + priority = priority, + }) + + if ok and sign_rows then + sign_rows[row] = priority + end + return ok +end + +return M diff --git a/queries/c/diffmantic.scm b/queries/c/diffmantic.scm new file mode 100644 index 0000000..732c65b --- /dev/null +++ b/queries/c/diffmantic.scm @@ -0,0 +1,116 @@ +(function_definition + declarator: (function_declarator) + body: (compound_statement) @diff.function.body) @diff.function.outer + +(function_definition + declarator: (pointer_declarator + declarator: (function_declarator)) + body: (compound_statement) 
@diff.function.body) @diff.function.outer + +(function_declarator + declarator: (identifier) @diff.function.name) + +(function_declarator + declarator: (pointer_declarator + declarator: (identifier) @diff.function.name)) + +(struct_specifier + name: (type_identifier) @diff.class.name + body: (field_declaration_list) @diff.class.body) @diff.class.outer + +(union_specifier + name: (type_identifier) @diff.class.name + body: (field_declaration_list) @diff.class.body) @diff.class.outer + +(enum_specifier + name: (type_identifier) @diff.class.name + body: (enumerator_list) @diff.class.body) @diff.class.outer + +(init_declarator + declarator: [ + (identifier) @diff.variable.name + (parenthesized_declarator + (identifier) @diff.variable.name) + (pointer_declarator + declarator: (identifier) @diff.variable.name) + (array_declarator + declarator: (identifier) @diff.variable.name) + (pointer_declarator + declarator: (array_declarator + declarator: (identifier) @diff.variable.name)) + (array_declarator + declarator: (pointer_declarator + declarator: (identifier) @diff.variable.name)) + ]) @diff.variable.outer + +(field_declaration + declarator: [ + (field_identifier) @diff.variable.name + (pointer_declarator + declarator: (field_identifier) @diff.variable.name) + (array_declarator + declarator: (field_identifier) @diff.variable.name) + (pointer_declarator + declarator: (array_declarator + declarator: (field_identifier) @diff.variable.name)) + (array_declarator + declarator: (pointer_declarator + declarator: (field_identifier) @diff.variable.name)) + ]) @diff.variable.outer + +(field_declaration + [ + (field_identifier) @diff.variable.name + (pointer_declarator + declarator: (field_identifier) @diff.variable.name) + (array_declarator + declarator: (field_identifier) @diff.variable.name) + (pointer_declarator + declarator: (array_declarator + declarator: (field_identifier) @diff.variable.name)) + (array_declarator + declarator: (pointer_declarator + declarator: 
(field_identifier) @diff.variable.name)) + ]) @diff.variable.outer + +(assignment_expression + left: (_) @diff.assignment.lhs + right: (_) @diff.assignment.rhs) @diff.assignment.outer + +(return_statement) @diff.return.outer +(preproc_include) @diff.preproc.outer +(preproc_def) @diff.preproc.outer +(preproc_function_def) @diff.preproc.outer + +(function_declarator + declarator: (identifier) @diff.identifier.rename) + +(function_declarator + declarator: (pointer_declarator + declarator: (identifier) @diff.identifier.rename)) + +(struct_specifier + name: (type_identifier) @diff.identifier.rename) + +(union_specifier + name: (type_identifier) @diff.identifier.rename) + +(enum_specifier + name: (type_identifier) @diff.identifier.rename) + +(init_declarator + declarator: [ + (identifier) @diff.identifier.rename + (parenthesized_declarator + (identifier) @diff.identifier.rename) + (pointer_declarator + declarator: (identifier) @diff.identifier.rename) + (array_declarator + declarator: (identifier) @diff.identifier.rename) + (pointer_declarator + declarator: (array_declarator + declarator: (identifier) @diff.identifier.rename)) + (array_declarator + declarator: (pointer_declarator + declarator: (identifier) @diff.identifier.rename)) + ]) diff --git a/queries/cpp/diffmantic.scm b/queries/cpp/diffmantic.scm new file mode 100644 index 0000000..9526208 --- /dev/null +++ b/queries/cpp/diffmantic.scm @@ -0,0 +1,86 @@ +(function_definition + declarator: (function_declarator) + body: (compound_statement) @diff.function.body) @diff.function.outer + +(function_definition + declarator: (pointer_declarator + declarator: (function_declarator)) + body: (compound_statement) @diff.function.body) @diff.function.outer + +(function_declarator + declarator: (identifier) @diff.function.name) + +(function_declarator + declarator: (pointer_declarator + declarator: (identifier) @diff.function.name)) + +(class_specifier + name: (type_identifier) @diff.class.name + body: (field_declaration_list) 
@diff.class.body) @diff.class.outer + +(struct_specifier + name: (type_identifier) @diff.class.name + body: (field_declaration_list) @diff.class.body) @diff.class.outer + +(init_declarator + declarator: [ + (identifier) @diff.variable.name + (field_identifier) @diff.variable.name + ]) @diff.variable.outer + +(field_declaration + declarator: [ + (field_identifier) @diff.variable.name + (pointer_declarator + declarator: (field_identifier) @diff.variable.name) + (array_declarator + declarator: (field_identifier) @diff.variable.name) + (pointer_declarator + declarator: (array_declarator + declarator: (field_identifier) @diff.variable.name)) + (array_declarator + declarator: (pointer_declarator + declarator: (field_identifier) @diff.variable.name)) + ]) @diff.variable.outer + +(field_declaration + [ + (field_identifier) @diff.variable.name + (pointer_declarator + declarator: (field_identifier) @diff.variable.name) + (array_declarator + declarator: (field_identifier) @diff.variable.name) + (pointer_declarator + declarator: (array_declarator + declarator: (field_identifier) @diff.variable.name)) + (array_declarator + declarator: (pointer_declarator + declarator: (field_identifier) @diff.variable.name)) + ]) @diff.variable.outer + +(assignment_expression + left: (_) @diff.assignment.lhs + right: (_) @diff.assignment.rhs) @diff.assignment.outer + +(return_statement) @diff.return.outer +(preproc_include) @diff.preproc.outer +(preproc_def) @diff.preproc.outer +(preproc_function_def) @diff.preproc.outer + +(function_declarator + declarator: (identifier) @diff.identifier.rename) + +(function_declarator + declarator: (pointer_declarator + declarator: (identifier) @diff.identifier.rename)) + +(class_specifier + name: (type_identifier) @diff.identifier.rename) + +(struct_specifier + name: (type_identifier) @diff.identifier.rename) + +(init_declarator + declarator: [ + (identifier) @diff.identifier.rename + ]) diff --git a/queries/fallback.scm b/queries/fallback.scm new file mode 
100644 index 0000000..506932f --- /dev/null +++ b/queries/fallback.scm @@ -0,0 +1 @@ +((_) @diff.fallback.node) diff --git a/queries/go/diffmantic.scm b/queries/go/diffmantic.scm new file mode 100644 index 0000000..f8ede2d --- /dev/null +++ b/queries/go/diffmantic.scm @@ -0,0 +1,64 @@ +(function_declaration + name: (identifier) @diff.function.name + body: (block) @diff.function.body) @diff.function.outer + +(method_declaration + name: (field_identifier) @diff.function.name + body: (block) @diff.function.body) @diff.function.outer + +(type_declaration + (type_spec + name: (type_identifier) @diff.class.name + type: [ + (struct_type) @diff.class.body + (interface_type) @diff.class.body + ])) @diff.class.outer + +(var_declaration + (var_spec + name: (identifier) @diff.variable.name)) @diff.variable.outer + +(const_declaration + (const_spec + name: (identifier) @diff.variable.name)) @diff.variable.outer + +(short_var_declaration + left: (expression_list (identifier) @diff.variable.name)) @diff.variable.outer + +(field_declaration + name: (field_identifier) @diff.variable.name) @diff.variable.outer + +(field_declaration + (field_identifier) @diff.variable.name) @diff.variable.outer + +(assignment_statement + left: (_) @diff.assignment.lhs + right: (_) @diff.assignment.rhs) @diff.assignment.outer + +(keyed_element + key: (_) @diff.assignment.lhs + value: (_) @diff.assignment.rhs) @diff.assignment.outer + +(import_declaration) @diff.import.outer +(return_statement) @diff.return.outer + +(function_declaration + name: (identifier) @diff.identifier.rename) + +(method_declaration + name: (field_identifier) @diff.identifier.rename) + +(type_declaration + (type_spec + name: (type_identifier) @diff.identifier.rename)) + +(var_declaration + (var_spec + name: (identifier) @diff.identifier.rename)) + +(const_declaration + (const_spec + name: (identifier) @diff.identifier.rename)) + +(short_var_declaration + left: (expression_list (identifier) @diff.identifier.rename)) diff --git 
a/queries/javascript/diffmantic.scm b/queries/javascript/diffmantic.scm new file mode 100644 index 0000000..6477d72 --- /dev/null +++ b/queries/javascript/diffmantic.scm @@ -0,0 +1,41 @@ +(function_declaration + name: (identifier) @diff.function.name + body: (statement_block) @diff.function.body) @diff.function.outer + +(method_definition + name: (property_identifier) @diff.function.name + body: (statement_block) @diff.function.body) @diff.function.outer + +(class_declaration + name: (identifier) @diff.class.name + body: (class_body) @diff.class.body) @diff.class.outer + +(variable_declarator + name: [(identifier) (object_pattern) (array_pattern)] @diff.variable.name) @diff.variable.outer + +(lexical_declaration + (variable_declarator + name: (identifier) @diff.variable.name)) @diff.variable.outer + +(assignment_expression + left: (_) @diff.assignment.lhs + right: (_) @diff.assignment.rhs) @diff.assignment.outer + +(pair + key: (_) @diff.assignment.lhs + value: (_) @diff.assignment.rhs) @diff.assignment.outer + +(import_statement) @diff.import.outer +(return_statement) @diff.return.outer + +(function_declaration + name: (identifier) @diff.identifier.rename) + +(method_definition + name: (property_identifier) @diff.identifier.rename) + +(class_declaration + name: (identifier) @diff.identifier.rename) + +(variable_declarator + name: (identifier) @diff.identifier.rename) diff --git a/queries/lua/diffmantic.scm b/queries/lua/diffmantic.scm new file mode 100644 index 0000000..5cf06a0 --- /dev/null +++ b/queries/lua/diffmantic.scm @@ -0,0 +1,35 @@ +(function_declaration + name: [(identifier) (dot_index_expression) (method_index_expression)] @diff.function.name + body: (block) @diff.function.body) @diff.function.outer + +(variable_declaration + (assignment_statement + (variable_list name: (identifier) @diff.function.name) + (expression_list value: (function_definition + body: (block) @diff.function.body)))) @diff.function.outer + +(variable_declaration + 
(assignment_statement + (variable_list name: (_) @diff.variable.name) + (expression_list))) @diff.variable.outer + +(assignment_statement + (variable_list) @diff.assignment.lhs + (expression_list) @diff.assignment.rhs) @diff.assignment.outer + +(return_statement) @diff.return.outer + +(function_declaration + name: (identifier) @diff.identifier.rename) + +(variable_declaration + (assignment_statement + (variable_list name: (identifier) @diff.identifier.rename) + (expression_list))) + +(assignment_statement + (variable_list name: (identifier) @diff.identifier.rename) + (expression_list)) + +(field + name: (identifier) @diff.identifier.rename) diff --git a/queries/python/diffmantic.scm b/queries/python/diffmantic.scm new file mode 100644 index 0000000..9e84d3e --- /dev/null +++ b/queries/python/diffmantic.scm @@ -0,0 +1,24 @@ +(function_definition + name: (identifier) @diff.function.name + body: (block) @diff.function.body) @diff.function.outer + +(class_definition + name: (identifier) @diff.class.name + body: (block) @diff.class.body) @diff.class.outer + +(assignment + left: (_) @diff.assignment.lhs + right: (_) @diff.assignment.rhs) @diff.assignment.outer + +(import_statement) @diff.import.outer +(import_from_statement) @diff.import.outer +(return_statement) @diff.return.outer + +(function_definition + name: (identifier) @diff.identifier.rename) + +(class_definition + name: (identifier) @diff.identifier.rename) + +(assignment + left: (identifier) @diff.identifier.rename) diff --git a/queries/typescript/diffmantic.scm b/queries/typescript/diffmantic.scm new file mode 100644 index 0000000..0d5c7e4 --- /dev/null +++ b/queries/typescript/diffmantic.scm @@ -0,0 +1,61 @@ +(function_declaration + name: (identifier) @diff.function.name + body: (statement_block) @diff.function.body) @diff.function.outer + +(method_definition + name: (property_identifier) @diff.function.name + body: (statement_block) @diff.function.body) @diff.function.outer + +(class_declaration + name: 
(type_identifier) @diff.class.name + body: (class_body) @diff.class.body) @diff.class.outer + +(type_alias_declaration + name: (type_identifier) @diff.class.name + value: (object_type) @diff.class.body) @diff.class.outer + +(interface_declaration + name: (type_identifier) @diff.class.name + body: (interface_body) @diff.class.body) @diff.class.outer + +(variable_declarator + name: [(identifier) (object_pattern) (array_pattern)] @diff.variable.name) @diff.variable.outer + +(lexical_declaration + (variable_declarator + name: (identifier) @diff.variable.name)) @diff.variable.outer + +(property_signature + name: (property_identifier) @diff.variable.name) @diff.variable.outer + +(assignment_expression + left: (_) @diff.assignment.lhs + right: (_) @diff.assignment.rhs) @diff.assignment.outer + +(pair + key: (_) @diff.assignment.lhs + value: (_) @diff.assignment.rhs) @diff.assignment.outer + +(import_statement) @diff.import.outer +(return_statement) @diff.return.outer + +(function_declaration + name: (identifier) @diff.identifier.rename) + +(method_definition + name: (property_identifier) @diff.identifier.rename) + +(class_declaration + name: (type_identifier) @diff.identifier.rename) + +(type_alias_declaration + name: (type_identifier) @diff.identifier.rename) + +(interface_declaration + name: (type_identifier) @diff.identifier.rename) + +(variable_declarator + name: (identifier) @diff.identifier.rename) + +(property_signature + name: (property_identifier) @diff.identifier.rename) diff --git a/test/benchmark.lua b/test/benchmark.lua index 2f267fa..c3d269f 100644 --- a/test/benchmark.lua +++ b/test/benchmark.lua @@ -4,6 +4,7 @@ local core = require("diffmantic.core") local ts = require("diffmantic.treesitter") +local roles = require("diffmantic.core.roles") -- Redirect output to file for headless mode local output_file = io.open("/tmp/gumtree_benchmark.txt", "w") @@ -33,7 +34,7 @@ local function generate_lua_code(num_functions, vars_per_function) return table.concat(lines, 
"\n") end --- Generate modified version (swaps some functions, changes some values) +-- Generate modified version (swaps some functions, changes some values, renames symbols) local function generate_modified_lua(num_functions, vars_per_function, changes) local lines = {} table.insert(lines, "-- Auto-generated benchmark file (modified)") @@ -54,17 +55,37 @@ local function generate_modified_lua(num_functions, vars_per_function, changes) end end + local rename_fn_count = math.min(changes.renames or 0, num_functions) + for _, i in ipairs(order) do - table.insert(lines, string.format("function M.func_%d()", i)) + local function_name = string.format("func_%d", i) + local renamed_function_name = nil + if i <= rename_fn_count then + renamed_function_name = function_name .. "_renamed" + end + table.insert(lines, string.format("function M.%s()", renamed_function_name or function_name)) + + local first_var = string.format("var_%d_1", i) + local renamed_first_var = nil + if i <= rename_fn_count then + renamed_first_var = string.format("renamed_%d_1", i) + end + for j = 1, vars_per_function do + local var_name = string.format("var_%d_%d", i, j) + if j == 1 and renamed_first_var then + var_name = renamed_first_var + end local value = i * 100 + j -- Change some values if i <= (changes.updates or 0) and j == 1 then value = value + 1000 end - table.insert(lines, string.format(" local var_%d_%d = %d", i, j, value)) + table.insert(lines, string.format(" local %s = %d", var_name, value)) end - table.insert(lines, string.format(" return var_%d_1 + var_%d_%d", i, i, vars_per_function)) + + local return_first_var = renamed_first_var or first_var + table.insert(lines, string.format(" return %s + var_%d_%d", return_first_var, i, vars_per_function)) table.insert(lines, "end") table.insert(lines, "") end @@ -84,7 +105,7 @@ local function run_benchmark(name, num_functions, vars_per_function, changes) log(string.format("\n=== %s ===", name)) log(string.format("Source: %d lines, Dest: %d 
lines", #src_lines, #dst_lines)) log(string.format("Functions: %d, Vars/function: %d", num_functions, vars_per_function)) - log(string.format("Changes: %d swaps, %d updates", changes.swaps or 0, changes.updates or 0)) + log(string.format("Changes: %d swaps, %d updates, %d renames", changes.swaps or 0, changes.updates or 0, changes.renames or 0)) -- Create buffers local src_buf = vim.api.nvim_create_buf(false, true) @@ -101,39 +122,69 @@ local function run_benchmark(name, num_functions, vars_per_function, changes) local dst_tree = dst_parser:parse()[1] local src_root = src_tree:root() local dst_root = dst_tree:root() + local src_role_index = roles.build_index(src_root, src_buf) + local dst_role_index = roles.build_index(dst_root, dst_buf) -- Benchmark the full matching pipeline local start_total = vim.loop.hrtime() -- Top-down match (includes preprocessing internally) local start_topdown = vim.loop.hrtime() - local mappings, src_info, dst_info = core.top_down_match(src_root, dst_root, src_buf, dst_buf) + local mappings, src_info, dst_info = core.top_down_match(src_root, dst_root, src_buf, dst_buf, { + adaptive_mode = true, + }) local topdown_time = (vim.loop.hrtime() - start_topdown) / 1e6 + local src_node_count = 0 + local dst_node_count = 0 + for _ in pairs(src_info) do + src_node_count = src_node_count + 1 + end + for _ in pairs(dst_info) do + dst_node_count = dst_node_count + 1 + end + local max_nodes = math.max(src_node_count, dst_node_count) + -- Bottom-up match local start_bottomup = vim.loop.hrtime() - mappings = core.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src_buf, dst_buf) + mappings = core.bottom_up_match(mappings, src_info, dst_info, src_root, dst_root, src_buf, dst_buf, { + src_role_index = src_role_index, + dst_role_index = dst_role_index, + adaptive_mode = true, + }) local bottomup_time = (vim.loop.hrtime() - start_bottomup) / 1e6 -- Recovery match (simple recovery) local start_recovery = vim.loop.hrtime() - mappings = 
core.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_buf, dst_buf) + mappings = core.recovery_match(src_root, dst_root, mappings, src_info, dst_info, src_buf, dst_buf, { + recovery_lcs_cell_limit = max_nodes >= 25000 and 1500 or 6000, + adaptive_mode = true, + }) local recovery_time = (vim.loop.hrtime() - start_recovery) / 1e6 -- Generate actions local start_actions = vim.loop.hrtime() - local actions, action_timings = core.generate_actions(src_root, dst_root, mappings, src_info, dst_info, { timings = true }) + local actions, action_timings = core.generate_actions(src_root, dst_root, mappings, src_info, dst_info, { + timings = true, + src_buf = src_buf, + dst_buf = dst_buf, + src_role_index = src_role_index, + dst_role_index = dst_role_index, + adaptive_mode = true, + }) local actions_time = (vim.loop.hrtime() - start_actions) / 1e6 local total_time = (vim.loop.hrtime() - start_total) / 1e6 -- Count actions - local move_count, update_count, delete_count, insert_count = 0, 0, 0, 0 + local move_count, update_count, delete_count, insert_count, rename_count = 0, 0, 0, 0, 0 for _, action in ipairs(actions) do if action.type == "move" then move_count = move_count + 1 elseif action.type == "update" then update_count = update_count + 1 + elseif action.type == "rename" then + rename_count = rename_count + 1 elseif action.type == "delete" then delete_count = delete_count + 1 elseif action.type == "insert" then @@ -148,10 +199,15 @@ local function run_benchmark(name, num_functions, vars_per_function, changes) log(string.format(" Actions: %8.2f ms", actions_time)) if action_timings then log(string.format(" Prep: %8.2f ms", action_timings.precompute or 0)) + log(string.format(" Roles: %8.2f ms", action_timings.roles or 0)) log(string.format(" Updates: %8.2f ms", action_timings.updates or 0)) log(string.format(" Moves: %8.2f ms", action_timings.moves or 0)) + log(string.format(" Renames: %8.2f ms", action_timings.renames or 0)) log(string.format(" Deletes: 
%8.2f ms", action_timings.deletes or 0)) log(string.format(" Inserts: %8.2f ms", action_timings.inserts or 0)) + log(string.format(" Semantic: %8.2f ms", action_timings.semantic or 0)) + log(string.format(" Analysis: %8.2f ms", action_timings.analysis or 0)) + log(string.format(" Suppress: %8.2f ms", action_timings.update_suppress or 0)) end log(string.format(" TOTAL: %8.2f ms", total_time)) @@ -159,9 +215,10 @@ local function run_benchmark(name, num_functions, vars_per_function, changes) log(string.format(" Mappings: %d", #mappings)) log( string.format( - " Actions: %d (moves=%d, updates=%d, deletes=%d, inserts=%d)", + " Actions: %d (moves=%d, renames=%d, updates=%d, deletes=%d, inserts=%d)", #actions, move_count, + rename_count, update_count, delete_count, insert_count @@ -194,15 +251,15 @@ local results = {} -- Test cases: (name, functions, vars_per_func, {swaps, updates}) local test_cases = { - { "~100 lines", 10, 5, { swaps = 2, updates = 2 } }, - { "~250 lines", 25, 5, { swaps = 5, updates = 5 } }, - { "~500 lines", 50, 5, { swaps = 10, updates = 10 } }, - { "~750 lines", 75, 5, { swaps = 15, updates = 15 } }, - { "~1000 lines", 100, 5, { swaps = 20, updates = 20 } }, - { "~2500 lines", 280, 5, { swaps = 50, updates = 50 } }, - { "~5000 lines", 560, 5, { swaps = 100, updates = 100 } }, - { "~7500 lines", 840, 5, { swaps = 150, updates = 150 } }, - { "~10000 lines", 1120, 5, { swaps = 200, updates = 200 } }, + { "~100 lines", 10, 5, { swaps = 2, updates = 2, renames = 2 } }, + { "~250 lines", 25, 5, { swaps = 5, updates = 5, renames = 5 } }, + { "~500 lines", 50, 5, { swaps = 10, updates = 10, renames = 10 } }, + { "~750 lines", 75, 5, { swaps = 15, updates = 15, renames = 15 } }, + { "~1000 lines", 100, 5, { swaps = 20, updates = 20, renames = 20 } }, + { "~2500 lines", 280, 5, { swaps = 50, updates = 50, renames = 50 } }, + { "~5000 lines", 560, 5, { swaps = 100, updates = 100, renames = 100 } }, + { "~7500 lines", 840, 5, { swaps = 150, updates = 150, 
renames = 150 } }, + { "~10000 lines", 1120, 5, { swaps = 200, updates = 200, renames = 200 } }, } for _, tc in ipairs(test_cases) do diff --git a/test/comparison/after.c b/test/comparison/after.c new file mode 100644 index 0000000..0ca8a1a --- /dev/null +++ b/test/comparison/after.c @@ -0,0 +1,62 @@ +#include +#include + +#define MAX_USERS 500 +const char *ROLE = "viewer"; + +struct User { + char name[64]; + char email[64]; + int active; + const char *created_at; +}; + +void format_user_display(const struct User *user, char *out, size_t out_size) { + snprintf(out, out_size, "%s <%s>", user->name, user->email); +} + +int validate_email_address(const char *email_str) { + const char *at = strchr(email_str, '@'); + if (!at) { + return 0; + } + if (!strchr(at, '.')) { + return 0; + } + return 1; +} + +struct User create_user(const char *username, const char *email, const char *role) { + if (!validate_email_address(email)) { + fprintf(stderr, "Invalid email format\n"); + } + + struct User user; + memset(&user, 0, sizeof(user)); + strncpy(user.name, username, sizeof(user.name) - 1); + strncpy(user.email, email, sizeof(user.email) - 1); + user.active = 1; + user.created_at = NULL; + (void)role; + return user; +} + +const char *get_user_permissions(const char *role) { + if (strcmp(role, "member") == 0) { + return "read"; + } + if (strcmp(role, "editor") == 0) { + return "read,write"; + } + if (strcmp(role, "admin") == 0) { + return "read,write,delete,manage"; + } + if (strcmp(role, "superadmin") == 0) { + return "read,write,delete,manage,configure"; + } + return ""; +} + +void deactivate_user(int user_id) { + printf("Deactivating user %d\n", user_id); +} diff --git a/test/comparison/after.cpp b/test/comparison/after.cpp new file mode 100644 index 0000000..7edb81c --- /dev/null +++ b/test/comparison/after.cpp @@ -0,0 +1,62 @@ +#include +#include +#include + +const int MAX_USERS = 500; +const std::string ROLE = "viewer"; + +struct User { + std::string name; + std::string 
email; + bool active; + std::string created_at; +}; + +std::string format_user_display(const User &user) { + std::string result = user.name + " <" + user.email + ">"; + return result; +} + +bool validate_email_address(const std::string &email_str) { + if (email_str.find('@') == std::string::npos) { + return false; + } + if (email_str.find('.') == std::string::npos) { + return false; + } + return true; +} + +User create_user(const std::string &username, const std::string &email, const std::string &role = ROLE) { + if (!validate_email_address(email)) { + throw std::runtime_error("Invalid email format"); + } + + User user; + user.name = username; + user.email = email; + user.active = true; + user.created_at = ""; + (void)role; + return user; +} + +std::string get_user_permissions(const std::string &role) { + if (role == "member") { + return "read"; + } + if (role == "editor") { + return "read,write"; + } + if (role == "admin") { + return "read,write,delete,manage"; + } + if (role == "superadmin") { + return "read,write,delete,manage,configure"; + } + return ""; +} + +void deactivate_user(int user_id) { + std::printf("Deactivating user %d\n", user_id); +} diff --git a/test/comparison/after.go b/test/comparison/after.go new file mode 100644 index 0000000..1b530aa --- /dev/null +++ b/test/comparison/after.go @@ -0,0 +1,57 @@ +package main + +import "fmt" + +type User struct { + Name string + Email string + Role string + Active bool + CreatedAt *string +} + +const MAX_USERS = 500 +const ROLE = "viewer" + +func formatUserDisplay(user User) string { + result := fmt.Sprintf("%s <%s>", user.Name, user.Email) + return result +} + +func validateEmailAddress(emailStr string) bool { + if len(emailStr) == 0 { + return false + } + return true +} + +func createUser(username string, email string, role string) User { + if role == "" { + role = ROLE + } + if !validateEmailAddress(email) { + panic("invalid email") + } + return User{Name: username, Email: email, Role: role, Active: true, 
CreatedAt: nil} +} + +func getUserPermissions(role string) []string { + if role == "member" { + return []string{"read"} + } + if role == "editor" { + return []string{"read", "write"} + } + if role == "admin" { + return []string{"read", "write", "delete", "manage"} + } + if role == "superadmin" { + return []string{"read", "write", "delete", "manage", "configure"} + } + return []string{} +} + +func deactivateUser(userID int) bool { + fmt.Printf("Deactivating user %d\n", userID) + return true +} diff --git a/test/comparison/after.js b/test/comparison/after.js new file mode 100644 index 0000000..dde230c --- /dev/null +++ b/test/comparison/after.js @@ -0,0 +1,54 @@ +// User management module. + +const MAX_USERS = 500; +const ROLE = "viewer"; + +function formatUserDisplay(user) { + const result = `${user.name} <${user.email}>`; + return result; +} + +function validateEmailAddress(emailStr) { + if (!emailStr.includes("@")) { + return false; + } + if (!emailStr.split("@")[1].includes(".")) { + return false; + } + return true; +} + +function createUser(username, email, role = ROLE) { + if (!validateEmailAddress(email)) { + throw new Error("Invalid email format"); + } + + return { + name: username, + email, + role, + active: true, + createdAt: null, + }; +} + +function getUserPermissions(role) { + if (role === "member") { + return ["read"]; + } + if (role === "editor") { + return ["read", "write"]; + } + if (role === "admin") { + return ["read", "write", "delete", "manage"]; + } + if (role === "superadmin") { + return ["read", "write", "delete", "manage", "configure"]; + } + return []; +} + +function deactivateUser(userId) { + console.log(`Deactivating user ${userId}`); + return true; +} diff --git a/test/comparison/after.lua b/test/comparison/after.lua new file mode 100644 index 0000000..2c288dd --- /dev/null +++ b/test/comparison/after.lua @@ -0,0 +1,66 @@ +-- User management module. 
+ +local MAX_USERS = 500 +local ROLE = "viewer" + +local function format_user_display(username) + local x = 10 + return username.name .. " <" .. username.email .. ">" .. x +end + +local function validate_email_address(email_str) + if not string.find(email_str, "@") then + return false + end + local at = string.find(email_str, "@") + if not string.find(string.sub(email_str, at), ".") then + return false + end + return true +end + +local function create_user(username, email, role) + role = role or ROLE + if not validate_email_address(email) then + error("Invalid email format") + end + + local user = { + name = username, + email = email, + active = true, + created_at = nil, + } + return user +end + +local function get_user_permissions(user) + local permissions = { + member = { "read" }, + editor = { "read", "write" }, + admin = { "read", "write", "delete", "manage" }, + superadmin = { + "read", + "write", + "delete", + "manage", + "configure", + }, + } + return permissions[user.role] or {} +end + +local function deactivate_user(user_id) + print("Deactivating user " .. user_id) + return true +end + +return { + MAX_USERS = MAX_USERS, + ROLE = ROLE, + format_user_display = format_user_display, + validate_email_address = validate_email_address, + create_user = create_user, + get_user_permissions = get_user_permissions, + deactivate_user = deactivate_user, +} diff --git a/test/comparison/after.py b/test/comparison/after.py new file mode 100644 index 0000000..04c9fb0 --- /dev/null +++ b/test/comparison/after.py @@ -0,0 +1,54 @@ +"""User management module.""" + +MAX_USERS = 500 +ROLE = "viewer" + + +def format_user_display(user): + return f"{user['name']} <{user['email']}>" + + +def validate_email_address(email_str): + """Check if email format is valid.""" + if "@" not in email_str: + return False + if "." 
not in email_str.split("@")[1]: + return False + return True + + +def create_user(username, email, role=ROLE): + """Create a new user with the given details.""" + if not validate_email_address(email): + raise ValueError("Invalid email format") + + user = { + "name": username, + "email": email, + "role": role, + "active": True, + "created_at": None, + } + return user + + +def get_user_permissions(user): + """Get permissions based on user role.""" + permissions = { + "member": ["read"], + "admin": ["read", "write", "delete", "manage"], + "superadmin": [ + "read", + "write", + "delete", + "manage", + "configure", + ], + } + return permissions.get(user["role"], []) + + +def deactivate_user(user_id): + """Deactivate a user by ID instead of deleting.""" + print(f"Deactivating user {user_id}") + return True diff --git a/test/comparison/after.ts b/test/comparison/after.ts new file mode 100644 index 0000000..7846f05 --- /dev/null +++ b/test/comparison/after.ts @@ -0,0 +1,61 @@ +// User management module. 
+ +type Role = "viewer" | "editor" | "admin" | "superadmin"; + +type User = { + name: string; + email: string; + role: Role; + active: boolean; + createdAt: string | null; +}; + +const MAX_USERS: number = 500; +const ROLE: Role = "viewer"; + +function formatUserDisplay(user: User): string { + const result = `${user.name} <${user.email}>`; + return result; +} + +function validateEmailAddress(emailStr: string): boolean { + if (!emailStr.includes("@")) { + return false; + } + if (!emailStr.split("@")[1].includes(".")) { + return false; + } + return true; +} + +function createUser(username: string, email: string, role: Role = ROLE): User { + if (!validateEmailAddress(email)) { + throw new Error("Invalid email format"); + } + + return { + name: username, + email, + role, + active: true, + createdAt: null, + }; +} + +function getUserPermissions(role: Role): string[] { + if (role === "member") { + return ["read"]; + } + if (role === "editor") { + return ["read", "write"]; + } + if (role === "admin") { + return ["read", "write", "delete", "manage"]; + } + return ["read", "write", "delete", "manage", "configure"]; +} + +function deactivateUser(userId: number): boolean { + console.log(`Deactivating user ${userId}`); + return true; +} diff --git a/test/comparison/before.c b/test/comparison/before.c new file mode 100644 index 0000000..f08f15a --- /dev/null +++ b/test/comparison/before.c @@ -0,0 +1,58 @@ +#include +#include + +#define MAX_USERS 100 +const char *DEFAULT_ROLE = "viewer"; + +int validate_email(const char *email) { + const char *at = strchr(email, '@'); + if (!at) { + return 0; + } + if (!strchr(at, '.')) { + return 0; + } + return 1; +} + +struct User { + char name[64]; + char email[64]; + char role[32]; + int active; +}; + +struct User create_user(const char *name, const char *email, const char *role) { + if (!validate_email(email)) { + fprintf(stderr, "Invalid email format\n"); + } + + struct User user; + memset(&user, 0, sizeof(user)); + strncpy(user.name, 
name, sizeof(user.name) - 1); + strncpy(user.email, email, sizeof(user.email) - 1); + strncpy(user.role, role, sizeof(user.role) - 1); + user.active = 1; + return user; +} + +const char *get_user_permissions(const char *role) { + if (strcmp(role, "viewer") == 0) { + return "read"; + } + if (strcmp(role, "editor") == 0) { + return "read,write"; + } + if (strcmp(role, "admin") == 0) { + return "read,write,delete,manage"; + } + return ""; +} + +void delete_user(int user_id) { + printf("Deleting user %d\n", user_id); +} + +void format_user_display(const struct User *user, char *out, size_t out_size) { + snprintf(out, out_size, "%s <%s>", user->name, user->email); +} diff --git a/test/comparison/before.cpp b/test/comparison/before.cpp new file mode 100644 index 0000000..32cb591 --- /dev/null +++ b/test/comparison/before.cpp @@ -0,0 +1,57 @@ +#include +#include +#include + +const int MAX_USERS = 100; +const std::string DEFAULT_ROLE = "viewer"; + +bool validate_email(const std::string &email) { + if (email.find('@') == std::string::npos) { + return false; + } + if (email.find('.') == std::string::npos) { + return false; + } + return true; +} + +struct User { + std::string name; + std::string email; + std::string role; + bool active; +}; + +User create_user(const std::string &name, const std::string &email, const std::string &role = DEFAULT_ROLE) { + if (!validate_email(email)) { + throw std::runtime_error("Invalid email format"); + } + + User user; + user.name = name; + user.email = email; + user.role = role; + user.active = true; + return user; +} + +std::string get_user_permissions(const std::string &role) { + if (role == "viewer") { + return "read"; + } + if (role == "editor") { + return "read,write"; + } + if (role == "admin") { + return "read,write,delete,manage"; + } + return ""; +} + +void delete_user(int user_id) { + std::printf("Deleting user %d\n", user_id); +} + +std::string format_user_display(const User &user) { + return user.name + " <" + user.email + ">"; 
+} diff --git a/test/comparison/before.go b/test/comparison/before.go new file mode 100644 index 0000000..3a20275 --- /dev/null +++ b/test/comparison/before.go @@ -0,0 +1,52 @@ +package main + +import "fmt" + +type User struct { + Name string + Email string + Role string + Active bool +} + +const MAX_USERS = 100 +const DEFAULT_ROLE = "viewer" + +func validateEmail(email string) bool { + if len(email) == 0 { + return false + } + return true +} + +func createUser(name string, email string, role string) User { + if role == "" { + role = DEFAULT_ROLE + } + if !validateEmail(email) { + panic("invalid email") + } + return User{Name: name, Email: email, Role: role, Active: true} +} + +func getUserPermissions(role string) []string { + if role == "viewer" { + return []string{"read"} + } + if role == "editor" { + return []string{"read", "write"} + } + if role == "admin" { + return []string{"read", "write", "delete", "manage"} + } + return []string{} +} + +func deleteUser(userID int) bool { + fmt.Printf("Deleting user %d\n", userID) + return true +} + +func formatUserDisplay(user User) string { + return fmt.Sprintf("%s <%s>", user.Name, user.Email) +} diff --git a/test/comparison/before.js b/test/comparison/before.js new file mode 100644 index 0000000..cb18a04 --- /dev/null +++ b/test/comparison/before.js @@ -0,0 +1,49 @@ +// User management module. 
+ +const MAX_USERS = 100; +const DEFAULT_ROLE = "viewer"; + +function validateEmail(email) { + if (!email.includes("@")) { + return false; + } + if (!email.split("@")[1].includes(".")) { + return false; + } + return true; +} + +function createUser(name, email, role = DEFAULT_ROLE) { + if (!validateEmail(email)) { + throw new Error("Invalid email format"); + } + + return { + name, + email, + role, + active: true, + }; +} + +function getUserPermissions(role) { + if (role === "viewer") { + return ["read"]; + } + if (role === "editor") { + return ["read", "write"]; + } + if (role === "admin") { + return ["read", "write", "delete", "manage"]; + } + return []; +} + +function deleteUser(userId) { + console.log(`Deleting user ${userId}`); + return true; +} + +function formatUserDisplay(user) { + return `${user.name} <${user.email}>`; +} diff --git a/test/comparison/before.lua b/test/comparison/before.lua new file mode 100644 index 0000000..6b1324c --- /dev/null +++ b/test/comparison/before.lua @@ -0,0 +1,58 @@ +-- User management module. + +local MAX_USERS = 100 +local DEFAULT_ROLE = "viewer" + +local function validate_email(email) + if not string.find(email, "@") then + return false + end + local at = string.find(email, "@") + if not string.find(string.sub(email, at), ".") then + return false + end + return true +end + +local function create_user(name, email, role) + role = role or DEFAULT_ROLE + if not validate_email(email) then + error("Invalid email format") + end + + local user = { + name = name, + email = email, + role = role, + active = true, + } + return user +end + +local function get_user_permissions(user) + local permissions = { + viewer = { "read" }, + editor = { "read", "write" }, + admin = { "read", "write", "delete", "manage" }, + } + return permissions[user.role] or {} +end + +local function delete_user(user_id) + print("Deleting user " .. user_id) + return true +end + +local function format_user_display(user) + return user.name .. " <" .. user.email .. 
">" +end + +return { + MAX_USERS = MAX_USERS, + DEFAULT_ROLE = DEFAULT_ROLE, + validate_email = validate_email, + create_user = create_user, + get_user_permissions = get_user_permissions, + delete_user = delete_user, + format_user_display = format_user_display, +} diff --git a/test/comparison/before.py b/test/comparison/before.py new file mode 100644 index 0000000..9c469dd --- /dev/null +++ b/test/comparison/before.py @@ -0,0 +1,48 @@ +"""User management module.""" + +MAX_USERS = 100 +DEFAULT_ROLE = "viewer" + + +def validate_email(email): + """Check if email format is valid.""" + if "@" not in email: + return False + if "." not in email.split("@")[1]: + return False + return True + + +def create_user(name, email, role=DEFAULT_ROLE): + """Create a new user with the given details.""" + if not validate_email(email): + raise ValueError("Invalid email format") + + user = { + "name": name, + "email": email, + "role": role, + "active": True, + } + return user + + +def get_user_permissions(user): + """Get permissions based on user role.""" + permissions = { + "viewer": ["read"], + "editor": ["read", "write"], + "admin": ["read", "write", "delete", "manage"], + } + return permissions.get(user["role"], []) + + +def format_user_display(user): + """Format user for display.""" + return f"{user['name']} <{user['email']}>" + + +def delete_user(user_id): + """Delete a user by ID.""" + print(f"Deleting user {user_id}") + return True diff --git a/test/comparison/before.ts b/test/comparison/before.ts new file mode 100644 index 0000000..e3e10d4 --- /dev/null +++ b/test/comparison/before.ts @@ -0,0 +1,55 @@ +// User management module. 
+ +type Role = "viewer" | "editor" | "admin"; + +type User = { + name: string; + email: string; + role: Role; + active: boolean; +}; + +const MAX_USERS: number = 100; +const DEFAULT_ROLE: Role = "viewer"; + +function validateEmail(email: string): boolean { + if (!email.includes("@")) { + return false; + } + if (!email.split("@")[1].includes(".")) { + return false; + } + return true; +} + +function createUser(name: string, email: string, role: Role = DEFAULT_ROLE): User { + if (!validateEmail(email)) { + throw new Error("Invalid email format"); + } + + return { + name, + email, + role, + active: true, + }; +} + +function getUserPermissions(role: Role): string[] { + if (role === "viewer") { + return ["read"]; + } + if (role === "editor") { + return ["read", "write"]; + } + return ["read", "write", "delete", "manage"]; +} + +function deleteUser(userId: number): boolean { + console.log(`Deleting user ${userId}`); + return true; +} + +function formatUserDisplay(user: User): string { + return `${user.name} <${user.email}>`; +} diff --git a/test/test2.py b/test/test2.py index 1654080..94786f9 100644 --- a/test/test2.py +++ b/test/test2.py @@ -2,11 +2,6 @@ # Tests: Move, Update, Insert, Rename -def calculate_difference(a, b): - """Subtract two numbers.""" - return a - b - - def calculate_sum(a, b): """Add two numbers.""" return a + b @@ -35,6 +30,11 @@ def fetch_data(self, path): return self.base_url + path +def calculate_difference(x, y): + """Subtract two numbers.""" + return x - y + + # Configuration API_URL = "https://api.example.com/v2" # Updated URL CACHE_DIR = "/tmp/app/cache-v2"