Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions bitnet_tools/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def _build_parser() -> argparse.ArgumentParser:
help="Optional directory to save visualization charts",
)
multi_parser.add_argument("--no-cache", action="store_true", help="Disable file profile cache")
multi_parser.add_argument("--workers", type=int, default=None, help="Optional worker count for parallel file profiling")

report_parser = subparsers.add_parser("report", help="Build markdown summary report from CSV")
report_parser.add_argument("csv", type=Path, help="Input CSV path")
Expand Down Expand Up @@ -126,6 +127,7 @@ def main(argv: list[str] | None = None) -> int:
group_column=args.group_column,
target_column=args.target_column,
use_cache=not args.no_cache,
max_workers=args.workers,
)
if args.charts_dir is not None:
try:
Expand Down
45 changes: 36 additions & 9 deletions bitnet_tools/multi_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import math
import random
from concurrent.futures import ThreadPoolExecutor
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path
Expand Down Expand Up @@ -342,30 +343,56 @@ def _generate_insights(files: list[dict[str, Any]], schema_drift: dict[str, Any]
return insights[:30]


def _load_or_profile_file(
    path: Path,
    group_column: str | None,
    target_column: str | None,
    use_cache: bool,
) -> dict[str, Any]:
    """Return the profile for one CSV file, consulting the cache when allowed.

    When ``use_cache`` is true, a cached profile for ``path`` (looked up with
    ``group_column`` and ``target_column``) is returned if one is found.
    Otherwise the file is profiled via ``_profile_csv_stream`` and, when
    ``use_cache`` is true, the fresh result is stored before being returned.
    """
    if use_cache:
        cached = _load_cached_profile(path, group_column, target_column)
        if cached is not None:
            return cached

    # Cache miss or caching disabled: compute the profile from the file.
    fresh = _profile_csv_stream(
        path,
        group_column=group_column,
        target_column=target_column,
    )
    if use_cache:
        _save_cached_profile(path, group_column, target_column, fresh)
    return fresh


def analyze_multiple_csv(
csv_paths: list[Path],
question: str,
group_column: str | None = None,
target_column: str | None = None,
use_cache: bool = True,
max_workers: int | None = None,
) -> dict[str, Any]:
if not csv_paths:
raise ValueError('at least one CSV path is required')

files: list[dict[str, Any]] = []
all_columns: list[set[str]] = []
total_rows = 0

for path in csv_paths:
if not path.exists():
raise FileNotFoundError(f'CSV file not found: {path}')

profiled = _load_cached_profile(path, group_column, target_column) if use_cache else None
if profiled is None:
profiled = _profile_csv_stream(path, group_column=group_column, target_column=target_column)
if use_cache:
_save_cached_profile(path, group_column, target_column, profiled)
worker_count = max_workers if (max_workers is not None and max_workers > 0) else min(4, len(csv_paths))

if worker_count == 1 or len(csv_paths) == 1:
profiled_list = [
_load_or_profile_file(path, group_column, target_column, use_cache)
for path in csv_paths
]
else:
with ThreadPoolExecutor(max_workers=worker_count) as executor:
profiled_list = list(
executor.map(
lambda p: _load_or_profile_file(p, group_column, target_column, use_cache),
csv_paths,
)
)

files: list[dict[str, Any]] = []
all_columns: list[set[str]] = []
total_rows = 0

for path, profiled in zip(csv_paths, profiled_list):
total_rows += profiled['summary']['row_count']
all_columns.append(set(profiled['summary']['columns']))
files.append(
Expand Down
12 changes: 12 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,15 @@ def test_multi_csv_top_values_capped_marker(monkeypatch, tmp_path):

assert prof["top_values_capped"] is True
assert any(x["value"] == "__OTHER__" for x in prof["top_values"])


def test_multi_csv_with_parallel_workers(tmp_path):
    """Two CSVs analyzed with max_workers=2 are both reported, in input order."""
    first_csv = tmp_path / "a.csv"
    second_csv = tmp_path / "b.csv"
    first_csv.write_text("city,val\nseoul,1\n", encoding="utf-8")
    second_csv.write_text("city,val\nbusan,2\n", encoding="utf-8")
    inputs = [first_csv, second_csv]

    result = analyze_multiple_csv(inputs, "병렬", max_workers=2)

    assert result["file_count"] == 2
    # The reported file order must match the order the paths were passed in.
    reported_paths = [entry["path"] for entry in result["files"]]
    assert reported_paths == [str(first_csv), str(second_csv)]
33 changes: 32 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,9 @@ def test_cli_multi_analyze_no_cache_flag(tmp_path, monkeypatch):
p1.write_text("city,val\nseoul,1\n", encoding="utf-8")
called = {}

def fake_analyze(csv, question, group_column=None, target_column=None, use_cache=True):
def fake_analyze(csv, question, group_column=None, target_column=None, use_cache=True, max_workers=None):
called["use_cache"] = use_cache
called["max_workers"] = max_workers
return {
"question": question,
"file_count": 1,
Expand All @@ -204,3 +205,33 @@ def fake_analyze(csv, question, group_column=None, target_column=None, use_cache

assert code == 0
assert called["use_cache"] is False
assert called["max_workers"] is None


def test_cli_multi_analyze_workers_flag(tmp_path, monkeypatch):
    """`multi-analyze --workers 3` must pass max_workers=3 to analyze_multiple_csv."""
    csv_path = tmp_path / "a.csv"
    csv_path.write_text("city,val\nseoul,1\n", encoding="utf-8")
    captured = {}

    def fake_analyze(csv, question, group_column=None, target_column=None, use_cache=True, max_workers=None):
        # Record the worker count the CLI passed, then return a canned result.
        captured["max_workers"] = max_workers
        column_profiles = {
            "city": {
                "dtype": "string",
                "missing_ratio": 0.0,
                "unique_ratio": 1.0,
                "dominant_value_ratio": 1.0,
            },
            "val": {
                "dtype": "float",
                "missing_ratio": 0.0,
                "unique_ratio": 1.0,
                "dominant_value_ratio": 1.0,
            },
        }
        file_entry = {
            "path": str(csv_path),
            "summary": {"row_count": 1, "column_count": 2, "columns": ["city", "val"]},
            "column_profiles": column_profiles,
            "group_target_ratio": None,
        }
        return {
            "question": question,
            "file_count": 1,
            "total_row_count": 1,
            "shared_columns": ["city"],
            "union_columns": ["city", "val"],
            "files": [file_entry],
            "schema_drift": {},
            "insights": [],
            "code_guidance": {"recommended_steps": "", "pandas_example": ""},
        }

    monkeypatch.setattr(cli, "analyze_multiple_csv", fake_analyze)

    json_target = tmp_path / "ow.json"
    report_target = tmp_path / "ow.md"
    argv = [
        "multi-analyze",
        str(csv_path),
        "--question",
        "q",
        "--workers",
        "3",
        "--out-json",
        str(json_target),
        "--out-report",
        str(report_target),
    ]
    exit_code = cli.main(argv)

    assert exit_code == 0
    assert captured["max_workers"] == 3