diff --git a/editdistance/_edit_distance_osa.cpp b/editdistance/_edit_distance_osa.cpp
index e1a6feb..632dee3 100644
--- a/editdistance/_edit_distance_osa.cpp
+++ b/editdistance/_edit_distance_osa.cpp
@@ -77,7 +77,7 @@ std::vector> backtrack_all_paths(
     const double tol = 1e-6;
 
     if (i > 0 && std::abs((dp[i-1][j] + delete_weight) - current_cost) < tol) {
-        CppEditop op(DELETE, i-1, i-1, delete_weight, std::string(1, a[i-1]));
+        CppEditop op(DELETE, i-1, j, delete_weight, std::string(1, a[i-1]));
         current_path.push_back(op);
         auto paths = backtrack_all_paths(a, b, dp, i-1, j, current_path, replace_weight, insert_weight, delete_weight, swap_weight);
         all_paths.insert(all_paths.end(), paths.begin(), paths.end());
@@ -85,7 +85,7 @@ std::vector> backtrack_all_paths(
     }
 
     if (j > 0 && std::abs((dp[i][j-1] + insert_weight) - current_cost) < tol) {
-        CppEditop op(INSERT, i, i, insert_weight, std::string(1, b[j-1]));
+        CppEditop op(INSERT, i, j-1, insert_weight, std::string(1, b[j-1]));
         current_path.push_back(op);
         auto paths = backtrack_all_paths(a, b, dp, i, j-1, current_path, replace_weight, insert_weight, delete_weight, swap_weight);
         all_paths.insert(all_paths.end(), paths.begin(), paths.end());
diff --git a/editdistance/edit_distance_osa.pyx b/editdistance/edit_distance_osa.pyx
index aa5dcc9..68356e9 100644
--- a/editdistance/edit_distance_osa.pyx
+++ b/editdistance/edit_distance_osa.pyx
@@ -98,6 +98,38 @@ def get_all_paths(
         python_paths.append(python_path)
     return python_paths
 
+def apply_editops(src, dst, editops):
+    """Replay ``editops`` on ``src`` and return the resulting string.
+
+    Source characters that no operation touches are copied through as-is.
+    INSERT and REPLACE emit ``dst[op.dst_idx]``; DELETE skips one source
+    character; SWAP emits ``src[op.src_idx + 1]`` followed by
+    ``src[op.src_idx]``. The source cursor only ever moves forward, and an
+    operation with any other name emits nothing of its own.
+    """
+    src_idx = 0
+    pieces = []  # joined once at the end instead of growing a str per character
+    for op in editops:
+        # Copy the untouched source characters that precede this operation.
+        if src_idx < op.src_idx:
+            pieces.extend(src[src_idx:op.src_idx])
+            src_idx = op.src_idx
+        if op.name == EditopName.INSERT:
+            pieces.append(dst[op.dst_idx])
+        elif op.name == EditopName.DELETE:
+            src_idx += 1
+        elif op.name == EditopName.REPLACE:
+            pieces.append(dst[op.dst_idx])
+            src_idx += 1
+        elif op.name == EditopName.SWAP:
+            # Emit the two adjacent source characters in reversed order.
+            pieces.append(src[op.src_idx + 1])
+            pieces.append(src[op.src_idx])
+            src_idx += 2
+    # Everything after the last operation is unchanged.
+    pieces.extend(src[src_idx:])
+    return "".join(pieces)
+
 
 def print_all_paths(
     str a,
diff --git a/tests/tests_osa.py b/tests/tests_osa.py
index 6035ab2..5647022 100644
--- a/tests/tests_osa.py
+++ b/tests/tests_osa.py
@@ -1,6 +1,7 @@
 import unittest
 
 from editdistance.osa import (
+    apply_editops,
     compute_distance,
     get_all_paths,
 )
@@ -22,6 +23,29 @@
     ("entirely different", "cab", "axb", 2),
 ]
 
+# (source, target) pairs; every path between them is replayed with apply_editops.
+EDITOPS_TRANSFORM_TEST_CASES = [
+    ("abc", "acb"),
+    ("kitten", "sitting"),
+    ("flaw", "lawn"),
+    ("", "abc"),
+    ("abc", ""),
+    ("abcdef", "azced"),
+    ("a", "a"),
+    ("a", ""),
+    ("", ""),
+    ("banana", "ban"),
+    ("intention", "execution"),
+    ("gumbo", "gambol"),
+    ("sunday", "saturday"),
+    ("ca", "abc"),
+    ("abcdef", "fedcba"),
+    ("racecar", "racecar"),
+    ("spelling", "spilling"),
+    ("distance", "instance"),
+    ("book", "back"),
+]
+
 
 class TestOsaDistance(unittest.TestCase):
     def test_compute_distance(self):
@@ -49,3 +73,13 @@ def test_get_all_paths(self):
         ):
             paths = get_all_paths(source, target)
             self.assertEqual(len(paths), expected_num_paths)
+
+    def test_editops_transform(self):
+        """Every path returned by get_all_paths must rebuild dst when applied to src."""
+        for src, dst in EDITOPS_TRANSFORM_TEST_CASES:
+            with self.subTest(src=src, dst=dst):
+                paths = get_all_paths(src, dst)
+                self.assertTrue(paths, f"No paths found for {src} -> {dst}")
+                for path in paths:
+                    result = apply_editops(src, dst, path)
+                    self.assertEqual(result, dst, f"Failed for {src} -> {dst} with path {path}")