diff --git a/bitnet_tools/planner.py b/bitnet_tools/planner.py new file mode 100644 index 0000000..85fe941 --- /dev/null +++ b/bitnet_tools/planner.py @@ -0,0 +1,268 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +import csv +import io +import random +import re +from typing import Any + + +@dataclass +class AnalysisIntent: + question: str + top_n: int | None = None + sample_n: int | None = None + threshold: float | None = None + threshold_column: str | None = None + region: str | None = None + region_column: str | None = None + compare_periods: bool = False + metric_column: str | None = None + + +@dataclass +class AnalysisPlan: + intent: AnalysisIntent + nodes: list[dict[str, Any]] + fallback: bool = False + warnings: list[str] = field(default_factory=list) + + +def _schema_columns(schema: dict[str, Any]) -> list[str]: + cols = schema.get("columns", []) + if isinstance(cols, list): + return [str(c) for c in cols] + return [] + + +def _schema_dtypes(schema: dict[str, Any]) -> dict[str, str]: + dtypes = schema.get("dtypes", {}) + if isinstance(dtypes, dict): + return {str(k): str(v) for k, v in dtypes.items()} + return {} + + +def _first_numeric_column(schema: dict[str, Any]) -> str | None: + dtypes = _schema_dtypes(schema) + for col, dtype in dtypes.items(): + if dtype in {"float", "int", "number", "numeric"}: + return col + return None + + +def _first_text_column(schema: dict[str, Any]) -> str | None: + dtypes = _schema_dtypes(schema) + for col in _schema_columns(schema): + if dtypes.get(col, "string") == "string": + return col + return _schema_columns(schema)[0] if _schema_columns(schema) else None + + +def _safe_float(value: Any) -> float | None: + try: + return float(str(value).strip().replace(",", "")) + except (ValueError, TypeError): + return None + + +def parse_question_to_intent(question: str, schema: dict[str, Any]) -> AnalysisIntent: + text = (question or "").strip() + lower = text.lower() + intent = AnalysisIntent(question=text) + + top_m = re.search(r"(?:top|상위)\s*(\d+)", lower) + if top_m: + intent.top_n = int(top_m.group(1)) + + sample_m = re.search(r"(?:sample|샘플)\s*(\d+)", lower) + if sample_m: + intent.sample_n = int(sample_m.group(1)) + + th_m = re.search(r"(?:threshold|임계값)\s*(\d+(?:\.\d+)?)", lower) + if th_m: + intent.threshold = float(th_m.group(1)) + + if intent.threshold is None: + above_m = re.search(r"(\d+(?:\.\d+)?)\s*(?:이상|초과)", text) + if above_m: + intent.threshold = float(above_m.group(1)) + + if any(token in text for token in ["전후", "전/후", "이전", "이후", "before", "after", "대비"]): + intent.compare_periods = True + + columns = _schema_columns(schema) + for col in columns: + if col.lower() in lower and intent.threshold is not None: + intent.threshold_column = col + break + + if intent.threshold is not None and not intent.threshold_column: + intent.threshold_column = _first_numeric_column(schema) + + region_col_candidates = [c for c in columns if any(k in c.lower() for k in ["region", "city", "area", "지역", "도시"])] + if region_col_candidates: + intent.region_column = region_col_candidates[0] + + known_regions = schema.get("region_values", []) + if not isinstance(known_regions, list): + known_regions = [] + for rg in known_regions: + if str(rg) and str(rg).lower() in lower: + intent.region = str(rg) + break + + if intent.region is None and intent.region_column: + tokens = [t for t in re.split(r"\s+", text) if t] + for tok in tokens: + if re.fullmatch(r"[가-힣A-Za-z][가-힣A-Za-z0-9_-]+", tok): + if tok.lower() not in {"top", "sample", "threshold", "임계값", "상위", "샘플"}: + if tok in columns: + continue + intent.region = tok + break + + intent.metric_column = _first_numeric_column(schema) + return intent + + +def build_plan(intent: AnalysisIntent, schema_profile: dict[str, Any]) -> AnalysisPlan: + warnings: list[str] = [] + group_col = intent.region_column or _first_text_column(schema_profile) + metric_col = intent.metric_column or _first_numeric_column(schema_profile) + + if metric_col is None: + warnings.append("numeric metric column not found") + + nodes = [ + { + "op": "filter", + "enabled": bool(intent.region or intent.threshold is not None), + "region_column": intent.region_column, + "region": intent.region, + "threshold_column": intent.threshold_column, + "threshold": intent.threshold, + }, + { + "op": "groupby", + "enabled": bool(group_col), + "columns": [group_col] if group_col else [], + }, + { + "op": "agg", + "enabled": bool(metric_col), + "metric": metric_col, + "fn": "sum", + }, + { + "op": "rank", + "enabled": bool(intent.top_n), + "top_n": intent.top_n, + "order": "desc", + }, + { + "op": "sample", + "enabled": bool(intent.sample_n), + "sample_n": intent.sample_n, + "seed": 42, + }, + { + "op": "export", + "enabled": True, + "include_meta": True, + }, + ] + + return AnalysisPlan(intent=intent, nodes=nodes, fallback=False, warnings=warnings) + + +def _execute_filter(rows: list[dict[str, Any]], node: dict[str, Any]) -> list[dict[str, Any]]: + out = rows + region = node.get("region") + region_col = node.get("region_column") + if region and region_col: + out = [r for r in out if str(r.get(region_col, "")).strip().lower() == str(region).strip().lower()] + + threshold = node.get("threshold") + threshold_col = node.get("threshold_column") + if threshold is not None and threshold_col: + fth = float(threshold) + filtered: list[dict[str, Any]] = [] + for r in out: + num = _safe_float(r.get(threshold_col)) + if num is not None and num >= fth: + filtered.append(r) + out = filtered + return out + + +def _execute_group_agg(rows: list[dict[str, Any]], group_col: str | None, metric_col: str | None) -> list[dict[str, Any]]: + if not group_col or not metric_col: + return rows + grouped: dict[str, float] = {} + counts: dict[str, int] = {} + for r in rows: + key = str(r.get(group_col, "")) + val = _safe_float(r.get(metric_col)) + if val is None: + continue + grouped[key] = grouped.get(key, 0.0) + val + counts[key] = counts.get(key, 0) + 1 + return [{group_col: k, f"sum_{metric_col}": v, "count": counts.get(k, 0)} for k, v in grouped.items()] + + +def execute_plan(plan: AnalysisPlan, data: list[dict[str, Any]]) -> dict[str, Any]: + rows = list(data) + meta: dict[str, Any] = {"node_count": len(plan.nodes), "warnings": list(plan.warnings)} + + try: + grouped_rows = rows + group_col: str | None = None + metric_col: str | None = None + + for node in plan.nodes: + if not node.get("enabled", False): + continue + op = node.get("op") + if op == "filter": + rows = _execute_filter(rows, node) + grouped_rows = rows + elif op == "groupby": + cols = node.get("columns", []) + group_col = cols[0] if cols else None + elif op == "agg": + metric_col = node.get("metric") + grouped_rows = _execute_group_agg(rows, group_col, metric_col) + elif op == "rank": + top_n = int(node.get("top_n") or 0) + if top_n > 0: + rank_key = f"sum_{metric_col}" if metric_col else None + if rank_key: + grouped_rows = sorted(grouped_rows, key=lambda r: _safe_float(r.get(rank_key)) or 0.0, reverse=True)[:top_n] + elif op == "sample": + sample_n = int(node.get("sample_n") or 0) + if sample_n > 0 and rows: + rnd = random.Random(int(node.get("seed") or 42)) + rows = rnd.sample(rows, k=min(sample_n, len(rows))) + elif op == "export": + pass + else: + raise ValueError(f"unsupported op: {op}") + + return { + "table": grouped_rows, + "sample": rows[: int(plan.intent.sample_n or 5)], + "meta": {**meta, "fallback": False, "filtered_row_count": len(rows)}, + } + except Exception as exc: + return { + "table": data[:10], + "sample": data[:5], + "meta": {**meta, "fallback": True, "error": str(exc)}, + } + + +def execute_plan_from_csv_text(plan: AnalysisPlan, csv_text: str) -> dict[str, Any]: + reader = csv.DictReader(io.StringIO(csv_text)) + rows = list(reader) + return execute_plan(plan, rows) diff --git a/bitnet_tools/web.py b/bitnet_tools/web.py index ed81905..c831904 100644 --- a/bitnet_tools/web.py +++ b/bitnet_tools/web.py @@ -23,6 +23,7 @@ from .analysis import build_analysis_payload_from_request from .document_extract import extract_document_tables_from_base64, table_to_analysis_request from .multi_csv import analyze_multiple_csv +from .planner import build_plan, execute_plan_from_csv_text, parse_question_to_intent from .visualize import create_multi_charts @@ -491,6 +492,7 @@ def do_POST(self) -> None: question = str(payload.get("question", "")).strip() if not question: question = "이 데이터의 핵심 인사이트를 알려줘" + use_planner = bool(payload.get("use_planner", False)) input_type = str(payload.get("input_type", "csv") or "csv").strip().lower() normalized_csv_text = str(payload.get("normalized_csv_text", "") or "") @@ -545,6 +547,14 @@ def do_POST(self) -> None: ), HTTPStatus.BAD_REQUEST, ) + if use_planner: + intent = parse_question_to_intent(question, result.get("summary", {})) + plan = build_plan(intent, result.get("summary", {})) + result["planner"] = { + "intent": intent.__dict__, + "plan": {"nodes": plan.nodes, "warnings": plan.warnings, "fallback": plan.fallback}, + "execution": execute_plan_from_csv_text(plan, str(request_payload.get("normalized_csv_text", "") or "")), + } return self._send_json(result) if route == '/api/preprocess/jobs': diff --git a/tests/test_planner.py b/tests/test_planner.py new file mode 100644 index 0000000..c01013d --- /dev/null +++ b/tests/test_planner.py @@ -0,0 +1,41 @@ +from bitnet_tools.planner import AnalysisPlan, build_plan, execute_plan, parse_question_to_intent + + +def _schema(): + return { + "columns": ["region", "sales", "period"], + "dtypes": {"region": "string", "sales": "float", "period": "string"}, + "region_values": ["서울", "부산"], + } + + +def test_parse_question_to_intent_extracts_controls(): + intent = parse_question_to_intent("서울 지역 top 3, sample 2, sales 임계값 100 전/후 비교", _schema()) + + assert intent.top_n == 3 + assert intent.sample_n == 2 + assert intent.threshold == 100 + assert intent.threshold_column == "sales" + assert intent.region == "서울" + assert intent.compare_periods is True + + +def test_build_plan_contains_execution_graph_nodes(): + intent = parse_question_to_intent("상위 5 샘플 2", _schema()) + plan = build_plan(intent, _schema()) + + assert [n["op"] for n in plan.nodes] == ["filter", "groupby", "agg", "rank", "sample", "export"] + assert any(node["op"] == "rank" and node["enabled"] for node in plan.nodes) + + +def test_execute_plan_fallback_on_invalid_node(): + plan = AnalysisPlan( + intent=parse_question_to_intent("기본", _schema()), + nodes=[{"op": "unknown", "enabled": True}], + ) + data = [{"region": "서울", "sales": "120", "period": "after"}] + + result = execute_plan(plan, data) + + assert result["meta"]["fallback"] is True + assert "unsupported op" in result["meta"]["error"] diff --git a/tests/test_web.py b/tests/test_web.py index aee4c60..5b1903a 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -282,3 +282,23 @@ def test_analyze_document_fallback_error_uses_error_and_error_detail(): finally: server.shutdown() thread.join(timeout=1) + + +def test_analyze_api_with_planner_returns_execution_payload(): + server, thread = _run_server() + base = f'http://127.0.0.1:{server.server_port}' + try: + code, body = _post_json(base + '/api/analyze', { + 'input_type': 'csv', + 'source_name': 'sales.csv', + 'normalized_csv_text': 'region,sales\n서울,120\n부산,80\n', + 'question': '서울 top 1 sample 1 임계값 100', + 'use_planner': True, + }) + assert code == 200 + assert 'planner' in body + assert body['planner']['intent']['top_n'] == 1 + assert body['planner']['execution']['meta']['fallback'] is False + finally: + server.shutdown() + thread.join(timeout=1)