Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 23 additions & 53 deletions git_sync_filtered/cli.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,15 @@
from fnmatch import translate as glob_translate
from pathlib import Path

import click
from pydantic import BaseModel, ConfigDict, FilePath, field_validator

from git_sync_filtered.sync import sync


class SyncConfig(BaseModel):
    """Validated, immutable set of options for a single sync run.

    The fields mirror the keyword arguments handed to ``sync``; the
    validators below reject empty keep lists, empty keep entries and
    obviously malformed branch names before any git work starts.
    """

    # frozen=True: instances cannot be mutated after construction.
    model_config = ConfigDict(frozen=True)

    # Source repository (the one that gets cloned and filtered).
    private: str
    # Destination repository, used as the URL of the "public" remote.
    public: str
    # Paths / glob patterns to keep; must contain at least one entry.
    keep: tuple[str, ...]
    # Optional file with additional keep paths; FilePath requires it to exist.
    keep_from_file: FilePath | None = None
    # Branch in the public repository that receives the filtered history.
    sync_branch: str = "upstream/sync"
    # Branch that the sync branch is merged into when `merge` is set.
    main_branch: str = "main"
    # Branch of the private repository whose history is pushed.
    private_branch: str = "main"
    # When set, nothing is pushed; the commits are only listed.
    dry_run: bool = False
    # When set, merge the sync branch into `main_branch` after the sync.
    merge: bool = False
    # When set, force-push the sync branch.
    force: bool = False

    @field_validator("keep", mode="before")
    @classmethod
    def ensure_non_empty(cls, v: tuple[str, ...]) -> tuple[str, ...]:
        """Reject a missing or empty ``keep`` value.

        Runs in "before" mode, i.e. on the raw input prior to type coercion.

        Raises:
            ValueError: if ``v`` is falsy (empty tuple/list, ``None``, ...).
        """
        if not v:
            raise ValueError("At least one --keep path required")
        return v

    @field_validator("keep", mode="after")
    @classmethod
    def validate_glob_paths(cls, v: tuple[str, ...]) -> tuple[str, ...]:
        """Reject empty entries inside ``keep``.

        Raises:
            ValueError: if any entry is an empty string.
        """
        for path in v:
            if not path:
                raise ValueError("Keep path cannot be empty")
            # The translated pattern is discarded; the call only matters if
            # translation itself raises.
            # NOTE(review): fnmatch.translate is not expected to raise for
            # str input, so this may be a no-op check -- confirm intent.
            glob_translate(path)
        return v

    @field_validator("sync_branch", "main_branch", "private_branch", mode="after")
    @classmethod
    def validate_branch_name(cls, v: str) -> str:
        """Apply basic sanity checks to a branch name.

        This is not a full implementation of git's ref-name rules; only an
        empty name, a leading "/" and any ".." sequence are rejected.

        Raises:
            ValueError: if the name is empty, starts with "/" or contains "..".
        """
        if not v:
            raise ValueError("Branch name cannot be empty")
        if v.startswith("/") or ".." in v:
            raise ValueError(f"Invalid branch name: {v!r}")
        return v
def _validate_branch(name: str) -> None:
if not name:
raise ValueError("Branch name cannot be empty")
if name.startswith("/") or ".." in name:
raise ValueError(f"Invalid branch name: {name!r}")


@click.command()
Expand All @@ -64,6 +29,16 @@ def validate_branch_name(cls, v: str) -> str:
)
@click.option("--merge", is_flag=True, help="Merge into main branch after sync")
@click.option("--force", is_flag=True, help="Force push")
@click.option(
"--marker-prefix",
default="synced",
help="Prefix for sync marker in commit messages",
)
@click.option(
"--reset",
is_flag=True,
help="Reset sync state and re-sync all commits from beginning",
)
def main(
private: str,
public: str,
Expand All @@ -75,11 +50,16 @@ def main(
dry_run: bool,
merge: bool,
force: bool,
marker_prefix: str,
reset: bool,
) -> None:
"""Sync filtered commits from private to public repository."""

try:
config = SyncConfig(
for branch in (sync_branch, main_branch, private_branch):
_validate_branch(branch)

result = sync(
private=private,
public=public,
keep=keep,
Expand All @@ -90,18 +70,8 @@ def main(
dry_run=dry_run,
merge=merge,
force=force,
)
result = sync(
private=config.private,
public=config.public,
keep=config.keep,
keep_from_file=config.keep_from_file,
sync_branch=config.sync_branch,
main_branch=config.main_branch,
private_branch=config.private_branch,
dry_run=config.dry_run,
merge=config.merge,
force=config.force,
marker_prefix=marker_prefix,
reset=reset,
)
except ValueError as e:
raise click.ClickException(str(e))
Expand Down
14 changes: 14 additions & 0 deletions git_sync_filtered/lock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import git


def check_sync_lock(repo: git.Repo, remote_name: str, sync_branch: str) -> bool:
    """Report whether *sync_branch* exists among the refs of *remote_name*.

    The remote is not fetched here; the caller is expected to have done so.
    Each ref is compared by its ``remote_head`` (the bare branch name), not
    by its full ``remote/branch`` name.

    Any error while reading the refs is treated as "no lock" and yields
    ``False`` rather than propagating.
    """
    try:
        remote = repo.remote(remote_name)
        head_names = [ref.remote_head for ref in remote.refs]
        is_locked = sync_branch in head_names
    except Exception:
        return False
    return is_locked
29 changes: 29 additions & 0 deletions git_sync_filtered/marker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import re


def parse_marker(message: str, prefix: str) -> str | None:
"""Extract SHA from commit message marker. Returns None if no marker found."""
pattern = rf"\[{prefix}:\s*([^\]]+)\]"
match = re.search(pattern, message)
if match:
return match.group(1)
return None


def append_marker_to_commit(message: str, sha: str, prefix: str) -> str:
    """Return *message* ending in exactly one ``[<prefix>: <sha>]`` marker.

    Any existing marker with the same prefix is removed first, so calling
    this repeatedly updates the marker instead of stacking several.

    Args:
        message: Original commit message.
        sha: Value to record in the marker.
        prefix: Marker prefix; matched literally (regex-special characters
            are escaped).

    Returns:
        The message body, a blank line, then the marker. When the body is
        empty the marker alone is returned, so the result never starts with
        blank lines.
    """
    marker = f"[{prefix}: {sha}]"
    # Escape the prefix so e.g. "sync.v1" cannot match "syncXv1".
    pattern = rf"\[{re.escape(prefix)}:\s*[^\]]+\]"

    body = re.sub(pattern, "", message).rstrip()
    if not body:
        return marker
    return f"{body}\n\n{marker}"


def find_last_synced_sha(commit_messages: list[str], prefix: str) -> str | None:
"""Find the SHA from the most recent commit with a sync marker."""
return next(
(sha for msg in commit_messages if (sha := parse_marker(msg, prefix))),
None,
)
141 changes: 118 additions & 23 deletions git_sync_filtered/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
import git
from git_filter_repo import FilteringOptions, RepoFilter

from git_sync_filtered.lock import check_sync_lock
from git_sync_filtered.marker import (
append_marker_to_commit,
find_last_synced_sha,
parse_marker,
)


class SyncResult(TypedDict):
paths_to_keep: list[str]
Expand All @@ -31,6 +38,7 @@ def collect_paths_to_keep(


def run_filter_repo(repo_path: Path | str, paths_to_keep: list[str]) -> None:
"""Run git-filter-repo to filter repository to only keep specified paths."""
old_cwd = os.getcwd()
os.chdir(repo_path)

Expand Down Expand Up @@ -59,24 +67,17 @@ def push_to_remote(
else:
repo.remote("public").set_url(public_url)

repo.remote("public").fetch()

if dry_run:
commits = []
for commit in repo.iter_commits(private_branch):
commits.append(f" {commit.hexsha[:8]} {commit.summary}")
return commits
else:
refspec = f"refs/heads/{private_branch}:refs/heads/{sync_branch}"
repo.remote("public").push(refspec=refspec, force=force)
return []
return [
f" {c.hexsha[:8]} {c.summary}" for c in repo.iter_commits(private_branch)
]

refspec = f"refs/heads/{private_branch}:refs/heads/{sync_branch}"
repo.remote("public").push(refspec=refspec, force=force)
return []

def merge_into_main(
repo: git.Repo,
main_branch: str,
sync_branch: str,
) -> bool:

def merge_into_main(repo: git.Repo, main_branch: str, sync_branch: str) -> bool:
repo.heads[main_branch].checkout()

try:
Expand All @@ -91,6 +92,55 @@ def merge_into_main(
return False


def _get_last_synced_sha_from_remote(
repo: git.Repo, sync_branch: str, marker_prefix: str
) -> str | None:
"""Get the last synced private SHA from the public repo's sync branch commit markers."""
try:
messages = [
c.message.decode("utf-8") if isinstance(c.message, bytes) else c.message
for c in repo.iter_commits(f"public/{sync_branch}")
]
return find_last_synced_sha(messages, marker_prefix)
except Exception:
return None


def _rewrite_commits_with_markers(
    repo: git.Repo, branch: str, marker_prefix: str
) -> None:
    """Rewrite commit messages to include sync markers for commits not yet marked.

    Walks the commits of *branch* in the reverse of the order yielded by
    ``repo.iter_commits`` and, for every commit whose message has no marker
    yet, issues ``git commit --amend`` with the message plus a marker that
    contains that commit's ``hexsha``. Does nothing when the branch has no
    commits.

    NOTE(review): ``git commit --amend`` rewrites only the commit currently
    checked out as HEAD, so each iteration appears to replace the same tip
    commit instead of the commit being iterated; after the loop only the tip
    would carry a marker. Confirm whether that is intended.
    NOTE(review): the recorded SHA is taken from the repository as it is when
    this function runs (the caller invokes it after filtering) -- verify it
    is the SHA the incremental-sync lookup expects.
    """
    commits = list(repo.iter_commits(branch))

    if not commits:
        return

    # Reuse the committer identity of the first commit yielded (presumably
    # the branch tip) so that `git commit` has user.name/user.email set.
    first_commit = commits[0]
    with repo.config_writer() as config:
        config.set_value("user", "name", first_commit.committer.name)
        config.set_value("user", "email", first_commit.committer.email)

    for commit in reversed(commits):
        # The message may be bytes or str; normalise to str.
        message = (
            commit.message.decode("utf-8")
            if isinstance(commit.message, bytes)
            else commit.message
        )

        # Already marked: leave the commit alone.
        if parse_marker(message, marker_prefix):
            continue

        new_message = append_marker_to_commit(message, commit.hexsha, marker_prefix)

        try:
            repo.git.commit(message=new_message, amend=True)
        except git.GitCommandError:
            # Best effort: a failed amend is ignored silently.
            # NOTE(review): consider at least logging the failure.
            pass


def sync(
private: str,
public: str,
Expand All @@ -102,6 +152,8 @@ def sync(
dry_run: bool,
merge: bool,
force: bool,
marker_prefix: str = "synced",
reset: bool = False,
) -> SyncResult:
paths_to_keep = collect_paths_to_keep(keep, keep_from_file)

Expand All @@ -110,25 +162,68 @@ def sync(

with TemporaryDirectory(prefix="git-sync-") as work_dir:
work_dir_path = Path(work_dir)

# Step 1: Fetch public state to determine what was last synced
last_synced_sha: str | None = None
if not dry_run and not reset:
probe_path = work_dir_path / "probe"
probe_repo = git.Repo.clone_from(private, str(probe_path))
probe_repo.create_remote("public", public)
try:
probe_repo.remote("public").fetch()
last_synced_sha = _get_last_synced_sha_from_remote(
probe_repo, sync_branch, marker_prefix
)
except Exception: # noqa: BLE001
last_synced_sha = None

lock_branch = f"{sync_branch}-in-progress"
if check_sync_lock(probe_repo, "public", lock_branch):
raise ValueError(
f"Sync already in progress: {lock_branch} branch exists"
)

# Step 2: Clone the private repo; apply graft for incremental sync
private_clone = work_dir_path / "private"
private_repo = git.Repo.clone_from(private, str(private_clone))

if last_synced_sha:
# Graft: treat last_synced_sha as a root so filter-repo only
# rewrites commits after it
try:
private_repo.commit(last_synced_sha)
grafts_file = private_clone / ".git" / "info" / "grafts"
grafts_file.parent.mkdir(parents=True, exist_ok=True)
grafts_file.write_text(f"{last_synced_sha}\n")
except Exception: # noqa: BLE001
last_synced_sha = None # SHA gone; fall back to full sync

# Step 3: Filter the (possibly grafted) history
run_filter_repo(str(private_clone), paths_to_keep)

# Re-open Repo after filter-repo rewrites history
private_repo.close()
private_repo = git.Repo(private_clone)

# Step 4: Rewrite commit messages with sync markers
_rewrite_commits_with_markers(private_repo, private_branch, marker_prefix)

# Step 5: Push — force when incremental since SHAs are rewritten by filter-repo
dry_run_commits = push_to_remote(
private_repo, public, sync_branch, private_branch, force, dry_run
private_repo,
public,
sync_branch,
private_branch,
force=force or bool(last_synced_sha),
dry_run=dry_run,
)

merge_success: bool | None = None
if merge and not dry_run:
success = merge_into_main(private_repo, main_branch, sync_branch)
return {
"paths_to_keep": paths_to_keep,
"dry_run_commits": dry_run_commits,
"merge_success": success,
}
merge_success = merge_into_main(private_repo, main_branch, sync_branch)

return {
"paths_to_keep": paths_to_keep,
"dry_run_commits": dry_run_commits,
"merge_success": None,
"merge_success": merge_success,
}
41 changes: 41 additions & 0 deletions git_sync_filtered/verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import git


def get_file_hashes(
    repo: git.Repo, paths: list[str], ref: str = "HEAD"
) -> dict[str, str]:
    """Map each file under *paths* at *ref* to its git object hash.

    Runs ``git ls-tree -r <ref> -- <path>`` for every path and parses the
    output lines, which have the form ``<mode> <type> <object>\\t<file>``.

    Args:
        repo: Repository to inspect.
        paths: Paths (files or directories) to list recursively.
        ref: Tree-ish to read from; defaults to ``HEAD``.

    Returns:
        ``{file_path: object_hash}`` for every listed entry. Paths that
        produce no output contribute nothing.
    """
    hashes: dict[str, str] = {}
    for path in paths:
        listing = repo.git.ls_tree("-r", ref, "--", path)
        if not listing:
            continue
        for line in listing.splitlines():
            # The file name follows the first TAB; splitting the whole line
            # on whitespace would truncate names that contain spaces.
            meta, tab, file_path = line.partition("\t")
            fields = meta.split()
            if tab and len(fields) >= 3:
                hashes[file_path] = fields[2]
    return hashes


def verify_sync_integrity(
    private_repo: git.Repo,
    public_repo: git.Repo,
    paths_to_keep: list[str],
    private_ref: str = "HEAD",
    public_ref: str = "HEAD",
) -> bool:
    """Check that the kept files are identical in both repositories.

    The per-file object hashes under *paths_to_keep* are collected from the
    private repository at *private_ref* and from the public repository at
    *public_ref*; the result is ``True`` only when both mappings are equal.
    With no paths to compare the check trivially succeeds.
    """
    if paths_to_keep:
        expected = get_file_hashes(private_repo, paths_to_keep, private_ref)
        actual = get_file_hashes(public_repo, paths_to_keep, public_ref)
        return expected == actual

    return True
Loading