diff --git a/bitnet_tools/cli.py b/bitnet_tools/cli.py index f278035..fbc442c 100644 --- a/bitnet_tools/cli.py +++ b/bitnet_tools/cli.py @@ -80,6 +80,7 @@ def _build_parser() -> argparse.ArgumentParser: help="Optional directory to save visualization charts", ) multi_parser.add_argument("--no-cache", action="store_true", help="Disable file profile cache") + multi_parser.add_argument("--workers", type=int, default=None, help="Optional worker count for parallel file profiling") report_parser = subparsers.add_parser("report", help="Build markdown summary report from CSV") report_parser.add_argument("csv", type=Path, help="Input CSV path") @@ -126,6 +127,7 @@ def main(argv: list[str] | None = None) -> int: group_column=args.group_column, target_column=args.target_column, use_cache=not args.no_cache, + max_workers=args.workers, ) if args.charts_dir is not None: try: diff --git a/bitnet_tools/multi_csv.py b/bitnet_tools/multi_csv.py index 667813c..70c0f51 100644 --- a/bitnet_tools/multi_csv.py +++ b/bitnet_tools/multi_csv.py @@ -5,6 +5,7 @@ import json import math import random +from concurrent.futures import ThreadPoolExecutor from collections import Counter, defaultdict from datetime import datetime from pathlib import Path @@ -342,30 +343,56 @@ def _generate_insights(files: list[dict[str, Any]], schema_drift: dict[str, Any] return insights[:30] +def _load_or_profile_file( + path: Path, + group_column: str | None, + target_column: str | None, + use_cache: bool, +) -> dict[str, Any]: + profiled = _load_cached_profile(path, group_column, target_column) if use_cache else None + if profiled is None: + profiled = _profile_csv_stream(path, group_column=group_column, target_column=target_column) + if use_cache: + _save_cached_profile(path, group_column, target_column, profiled) + return profiled + + def analyze_multiple_csv( csv_paths: list[Path], question: str, group_column: str | None = None, target_column: str | None = None, use_cache: bool = True, + max_workers: int | None = None, ) -> dict[str, Any]: if not csv_paths: raise ValueError('at least one CSV path is required') - files: list[dict[str, Any]] = [] - all_columns: list[set[str]] = [] - total_rows = 0 - for path in csv_paths: if not path.exists(): raise FileNotFoundError(f'CSV file not found: {path}') - profiled = _load_cached_profile(path, group_column, target_column) if use_cache else None - if profiled is None: - profiled = _profile_csv_stream(path, group_column=group_column, target_column=target_column) - if use_cache: - _save_cached_profile(path, group_column, target_column, profiled) + worker_count = max_workers if (max_workers is not None and max_workers > 0) else min(4, len(csv_paths)) + + if worker_count == 1 or len(csv_paths) == 1: + profiled_list = [ + _load_or_profile_file(path, group_column, target_column, use_cache) + for path in csv_paths + ] + else: + with ThreadPoolExecutor(max_workers=worker_count) as executor: + profiled_list = list( + executor.map( + lambda p: _load_or_profile_file(p, group_column, target_column, use_cache), + csv_paths, + ) + ) + + files: list[dict[str, Any]] = [] + all_columns: list[set[str]] = [] + total_rows = 0 + for path, profiled in zip(csv_paths, profiled_list): total_rows += profiled['summary']['row_count'] all_columns.append(set(profiled['summary']['columns'])) files.append( diff --git a/tests/test_analysis.py b/tests/test_analysis.py index cc88aa7..e789069 100644 --- a/tests/test_analysis.py +++ b/tests/test_analysis.py @@ -140,3 +140,15 @@ def test_multi_csv_top_values_capped_marker(monkeypatch, tmp_path): assert prof["top_values_capped"] is True assert any(x["value"] == "__OTHER__" for x in prof["top_values"]) + + +def test_multi_csv_with_parallel_workers(tmp_path): + p1 = tmp_path / "a.csv" + p2 = tmp_path / "b.csv" + p1.write_text("city,val\nseoul,1\n", encoding="utf-8") + p2.write_text("city,val\nbusan,2\n", encoding="utf-8") + + result = analyze_multiple_csv([p1, p2], "병렬", max_workers=2) + + assert result["file_count"] == 2 + assert [f["path"] for f in result["files"]] == [str(p1), str(p2)] diff --git a/tests/test_cli.py b/tests/test_cli.py index 290a418..00027c6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -182,8 +182,9 @@ def test_cli_multi_analyze_no_cache_flag(tmp_path, monkeypatch): p1.write_text("city,val\nseoul,1\n", encoding="utf-8") called = {} - def fake_analyze(csv, question, group_column=None, target_column=None, use_cache=True): + def fake_analyze(csv, question, group_column=None, target_column=None, use_cache=True, max_workers=None): called["use_cache"] = use_cache + called["max_workers"] = max_workers return { "question": question, "file_count": 1, @@ -204,3 +205,33 @@ def fake_analyze(csv, question, group_column=None, target_column=None, use_cache assert code == 0 assert called["use_cache"] is False + assert called["max_workers"] is None + + +def test_cli_multi_analyze_workers_flag(tmp_path, monkeypatch): + p1 = tmp_path / "a.csv" + p1.write_text("city,val\nseoul,1\n", encoding="utf-8") + called = {} + + def fake_analyze(csv, question, group_column=None, target_column=None, use_cache=True, max_workers=None): + called["max_workers"] = max_workers + return { + "question": question, + "file_count": 1, + "total_row_count": 1, + "shared_columns": ["city"], + "union_columns": ["city", "val"], + "files": [{"path": str(p1), "summary": {"row_count": 1, "column_count": 2, "columns": ["city", "val"]}, "column_profiles": {"city": {"dtype": "string", "missing_ratio": 0.0, "unique_ratio": 1.0, "dominant_value_ratio": 1.0}, "val": {"dtype": "float", "missing_ratio": 0.0, "unique_ratio": 1.0, "dominant_value_ratio": 1.0}}, "group_target_ratio": None}], + "schema_drift": {}, + "insights": [], + "code_guidance": {"recommended_steps": "", "pandas_example": ""}, + } + + monkeypatch.setattr(cli, "analyze_multiple_csv", fake_analyze) + + out_json = tmp_path / "ow.json" + out_md = tmp_path / "ow.md" + code = cli.main(["multi-analyze", str(p1), "--question", "q", "--workers", "3", "--out-json", str(out_json), "--out-report", str(out_md)]) + + assert code == 0 + assert called["max_workers"] == 3