-
Notifications
You must be signed in to change notification settings - Fork 0
Add planner-based analysis pipeline and optional /api/analyze planner flow #44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,268 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass, field | ||
| import csv | ||
| import io | ||
| import random | ||
| import re | ||
| from typing import Any | ||
|
|
||
|
|
||
@dataclass
class AnalysisIntent:
    """Controls extracted from a free-text analysis question.

    Only ``question`` is required; every other field keeps its default
    until ``parse_question_to_intent`` finds a matching cue.
    """

    # The stripped question text the intent was parsed from.
    question: str
    # Row limit for ranking ("top N" / "상위 N").
    top_n: int | None = None
    # Number of rows to sample ("sample N" / "샘플 N").
    sample_n: int | None = None
    # Lower bound applied to ``threshold_column`` when filtering.
    threshold: float | None = None
    threshold_column: str | None = None
    # Region value to keep, compared against ``region_column``.
    region: str | None = None
    region_column: str | None = None
    # True when the question contains a before/after style keyword.
    compare_periods: bool = False
    # Column that aggregation runs over.
    metric_column: str | None = None
|
|
||
|
|
||
@dataclass
class AnalysisPlan:
    """An ordered list of operation nodes plus the intent that produced it."""

    # The parsed question controls this plan was built from.
    intent: AnalysisIntent
    # Operation descriptors, executed in list order; each carries an "op" name
    # and an "enabled" flag alongside its own parameters.
    nodes: list[dict[str, Any]]
    # Plan-level fallback marker; ``build_plan`` always sets it to False.
    fallback: bool = False
    # Human-readable notes collected while building the plan.
    warnings: list[str] = field(default_factory=list)
|
|
||
|
|
||
| def _schema_columns(schema: dict[str, Any]) -> list[str]: | ||
| cols = schema.get("columns", []) | ||
| if isinstance(cols, list): | ||
| return [str(c) for c in cols] | ||
| return [] | ||
|
|
||
|
|
||
| def _schema_dtypes(schema: dict[str, Any]) -> dict[str, str]: | ||
| dtypes = schema.get("dtypes", {}) | ||
| if isinstance(dtypes, dict): | ||
| return {str(k): str(v) for k, v in dtypes.items()} | ||
| return {} | ||
|
|
||
|
|
||
def _first_numeric_column(schema: dict[str, Any]) -> str | None:
    """Return the first column whose dtype label is numeric, else ``None``.

    Columns are visited in the iteration order of the schema's dtype mapping,
    and only the exact labels ``float``, ``int``, ``number`` and ``numeric``
    count as numeric.
    """
    numeric_labels = {"float", "int", "number", "numeric"}
    matches = (
        name
        for name, label in _schema_dtypes(schema).items()
        if label in numeric_labels
    )
    return next(matches, None)
|
|
||
|
|
||
def _first_text_column(schema: dict[str, Any]) -> str | None:
    """Return the first string-typed column, falling back to the first column.

    A column with no dtype entry is treated as ``"string"``. Returns ``None``
    only when the schema lists no columns at all.
    """
    names = _schema_columns(schema)
    if not names:
        return None
    labels = _schema_dtypes(schema)
    text_columns = (name for name in names if labels.get(name, "string") == "string")
    # No string column: use the first column regardless of its dtype.
    return next(text_columns, names[0])
|
|
||
|
|
||
| def _safe_float(value: Any) -> float | None: | ||
| try: | ||
| return float(str(value).strip().replace(",", "")) | ||
| except (ValueError, TypeError): | ||
| return None | ||
|
|
||
|
|
||
def parse_question_to_intent(question: str, schema: dict[str, Any]) -> AnalysisIntent:
    """Extract analysis controls from a free-text question with keyword heuristics.

    Recognised cues (English and Korean):

    * ``top N`` / ``상위 N`` sets ``top_n``.
    * ``sample N`` / ``샘플 N`` sets ``sample_n``.
    * ``threshold X`` / ``임계값 X``, or failing that ``X 이상`` / ``X 초과``,
      sets ``threshold``.
    * Any before/after style keyword sets ``compare_periods``.

    Column choices come from *schema*: the threshold column is the first
    schema column named in the question (else the first numeric column), the
    region column is the first column whose name contains a region-like word,
    and the metric column is the first numeric column.

    The region value is taken from ``schema["region_values"]`` when one of
    them appears in the question. Otherwise, if a region column exists, the
    first word-like token that is not a control keyword, a column name, or a
    before/after keyword is used.

    Args:
        question: The user's question; ``None`` or empty is treated as "".
        schema: Mapping that may carry ``"columns"``, ``"dtypes"`` and
            ``"region_values"`` entries.

    Returns:
        A populated ``AnalysisIntent``.
    """
    text = (question or "").strip()
    lower = text.lower()
    intent = AnalysisIntent(question=text)

    period_keywords = ("전후", "전/후", "이전", "이후", "before", "after", "대비")
    control_keywords = {"top", "sample", "threshold", "임계값", "상위", "샘플"}

    top_m = re.search(r"(?:top|상위)\s*(\d+)", lower)
    if top_m:
        intent.top_n = int(top_m.group(1))

    sample_m = re.search(r"(?:sample|샘플)\s*(\d+)", lower)
    if sample_m:
        intent.sample_n = int(sample_m.group(1))

    th_m = re.search(r"(?:threshold|임계값)\s*(\d+(?:\.\d+)?)", lower)
    if th_m:
        intent.threshold = float(th_m.group(1))

    if intent.threshold is None:
        above_m = re.search(r"(\d+(?:\.\d+)?)\s*(?:이상|초과)", text)
        if above_m:
            intent.threshold = float(above_m.group(1))

    # Match against the lower-cased text so "Before" / "AFTER" are detected too.
    if any(keyword in lower for keyword in period_keywords):
        intent.compare_periods = True

    columns = _schema_columns(schema)
    if intent.threshold is not None:
        for col in columns:
            if col.lower() in lower:
                intent.threshold_column = col
                break
        if not intent.threshold_column:
            intent.threshold_column = _first_numeric_column(schema)

    region_col_candidates = [
        c for c in columns
        if any(k in c.lower() for k in ["region", "city", "area", "지역", "도시"])
    ]
    if region_col_candidates:
        intent.region_column = region_col_candidates[0]

    known_regions = schema.get("region_values", [])
    if not isinstance(known_regions, list):
        known_regions = []
    for rg in known_regions:
        if str(rg) and str(rg).lower() in lower:
            intent.region = str(rg)
            break

    if intent.region is None and intent.region_column:
        column_names = {c.lower() for c in columns}
        tokens = [t for t in re.split(r"\s+", text) if t]
        for tok in tokens:
            if not re.fullmatch(r"[가-힣A-Za-z][가-힣A-Za-z0-9_-]+", tok):
                continue
            tok_lower = tok.lower()
            if tok_lower in control_keywords or tok_lower in column_names:
                continue
            # A token carrying a before/after keyword is what switched on
            # compare_periods; using it as a region would filter out every row.
            if any(keyword in tok_lower for keyword in period_keywords):
                continue
            intent.region = tok
            break

    intent.metric_column = _first_numeric_column(schema)
    return intent
|
|
||
|
|
||
def build_plan(intent: AnalysisIntent, schema_profile: dict[str, Any]) -> AnalysisPlan:
    """Build the fixed six-node plan (filter, groupby, agg, rank, sample, export).

    Every node is always present; its ``"enabled"`` flag says whether the
    intent and schema supply what that step needs. A warning is recorded when
    no numeric metric column can be found.
    """
    grouping = intent.region_column or _first_text_column(schema_profile)
    metric = intent.metric_column or _first_numeric_column(schema_profile)

    notes: list[str] = [] if metric is not None else ["numeric metric column not found"]

    filter_node = {
        "op": "filter",
        "enabled": bool(intent.region or intent.threshold is not None),
        "region_column": intent.region_column,
        "region": intent.region,
        "threshold_column": intent.threshold_column,
        "threshold": intent.threshold,
    }
    groupby_node = {
        "op": "groupby",
        "enabled": bool(grouping),
        "columns": [grouping] if grouping else [],
    }
    agg_node = {
        "op": "agg",
        "enabled": bool(metric),
        "metric": metric,
        "fn": "sum",
    }
    rank_node = {
        "op": "rank",
        "enabled": bool(intent.top_n),
        "top_n": intent.top_n,
        "order": "desc",
    }
    sample_node = {
        "op": "sample",
        "enabled": bool(intent.sample_n),
        "sample_n": intent.sample_n,
        "seed": 42,
    }
    export_node = {
        "op": "export",
        "enabled": True,
        "include_meta": True,
    }

    return AnalysisPlan(
        intent=intent,
        nodes=[filter_node, groupby_node, agg_node, rank_node, sample_node, export_node],
        fallback=False,
        warnings=notes,
    )
|
|
||
|
|
||
def _execute_filter(rows: list[dict[str, Any]], node: dict[str, Any]) -> list[dict[str, Any]]:
    """Apply the node's region and threshold filters to *rows*.

    The region filter keeps rows whose region column equals the requested
    region, ignoring case and surrounding whitespace. The threshold filter
    keeps rows whose threshold column parses as a number greater than or
    equal to the threshold. A filter is skipped when either its value or
    its column is missing from the node.
    """
    kept = rows

    region, region_col = node.get("region"), node.get("region_column")
    if region and region_col:
        wanted = str(region).strip().lower()
        kept = [
            row for row in kept
            if str(row.get(region_col, "")).strip().lower() == wanted
        ]

    threshold, threshold_col = node.get("threshold"), node.get("threshold_column")
    if threshold is None or not threshold_col:
        return kept

    limit = float(threshold)

    def _meets_limit(row: dict[str, Any]) -> bool:
        # Rows whose value does not parse as a number are dropped.
        parsed = _safe_float(row.get(threshold_col))
        return parsed is not None and parsed >= limit

    return [row for row in kept if _meets_limit(row)]
|
|
||
|
|
||
def _execute_group_agg(rows: list[dict[str, Any]], group_col: str | None, metric_col: str | None) -> list[dict[str, Any]]:
    """Sum *metric_col* per distinct *group_col* value.

    Returns *rows* unchanged when either column is missing. Otherwise returns
    one dict per group holding the group value, ``sum_<metric_col>``, and a
    ``count`` of contributing rows, ordered by first appearance. Rows whose
    metric does not parse as a number are ignored.
    """
    if not (group_col and metric_col):
        return rows

    # group key -> [running total, contributing row count]
    buckets: dict[str, list[Any]] = {}
    for row in rows:
        label = str(row.get(group_col, "<missing>"))
        amount = _safe_float(row.get(metric_col))
        if amount is None:
            continue
        bucket = buckets.setdefault(label, [0.0, 0])
        bucket[0] += amount
        bucket[1] += 1

    sum_key = f"sum_{metric_col}"
    return [
        {group_col: label, sum_key: total, "count": seen}
        for label, (total, seen) in buckets.items()
    ]
|
|
||
|
|
||
def execute_plan(plan: AnalysisPlan, data: list[dict[str, Any]]) -> dict[str, Any]:
    """Run the enabled nodes of *plan* over *data*, in list order.

    Node behaviour:

    * ``filter`` narrows the working rows.
    * ``groupby`` records the grouping column.
    * ``agg`` builds the result table from the working rows.
    * ``rank`` keeps the ``top_n`` table rows with the largest
      ``sum_<metric>`` value.
    * ``sample`` replaces the working rows with a seeded random sample.
    * ``export`` is a no-op marker.

    Args:
        plan: Plan whose ``nodes``, ``warnings`` and ``intent.sample_n`` are read.
        data: Input rows; copied up front and never mutated.

    Returns:
        A dict with ``"table"``, ``"sample"`` and ``"meta"`` entries.
        ``meta["filtered_row_count"]`` is the row count after the most recent
        filter step (the input size when no filter ran); sampling does not
        change it. If any step raises, including an unknown ``op``, the first
        rows of the input are returned instead, with ``meta["fallback"]`` set
        to True and the message in ``meta["error"]``.
    """
    source_rows = list(data)
    meta: dict[str, Any] = {"node_count": len(plan.nodes), "warnings": list(plan.warnings)}

    try:
        rows = source_rows
        grouped_rows = rows
        # Tracked separately because the sample step rebinds ``rows``.
        filtered_row_count = len(rows)
        group_col: str | None = None
        metric_col: str | None = None

        for node in plan.nodes:
            if not node.get("enabled", False):
                continue
            op = node.get("op")
            if op == "filter":
                rows = _execute_filter(rows, node)
                grouped_rows = rows
                filtered_row_count = len(rows)
            elif op == "groupby":
                cols = node.get("columns", [])
                group_col = cols[0] if cols else None
            elif op == "agg":
                metric_col = node.get("metric")
                grouped_rows = _execute_group_agg(rows, group_col, metric_col)
            elif op == "rank":
                top_n = int(node.get("top_n") or 0)
                if top_n > 0:
                    rank_key = f"sum_{metric_col}" if metric_col else None
                    if rank_key:
                        grouped_rows = sorted(
                            grouped_rows,
                            key=lambda r: _safe_float(r.get(rank_key)) or 0.0,
                            reverse=True,
                        )[:top_n]
            elif op == "sample":
                sample_n = int(node.get("sample_n") or 0)
                if sample_n > 0 and rows:
                    rnd = random.Random(int(node.get("seed") or 42))
                    rows = rnd.sample(rows, k=min(sample_n, len(rows)))
            elif op == "export":
                pass
            else:
                raise ValueError(f"unsupported op: {op}")

        return {
            "table": grouped_rows,
            "sample": rows[: int(plan.intent.sample_n or 5)],
            "meta": {**meta, "fallback": False, "filtered_row_count": filtered_row_count},
        }
    except Exception as exc:
        # Deliberate best-effort boundary: a broken plan still yields a preview
        # of the raw input rather than failing the whole request.
        return {
            "table": source_rows[:10],
            "sample": source_rows[:5],
            "meta": {**meta, "fallback": True, "error": str(exc)},
        }
|
|
||
|
|
||
def execute_plan_from_csv_text(plan: AnalysisPlan, csv_text: str) -> dict[str, Any]:
    """Parse *csv_text* into row dicts keyed by its header line and run *plan* on them."""
    records = list(csv.DictReader(io.StringIO(csv_text)))
    return execute_plan(plan, records)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,7 @@ | |
| from .analysis import build_analysis_payload_from_request | ||
| from .document_extract import extract_document_tables_from_base64, table_to_analysis_request | ||
| from .multi_csv import analyze_multiple_csv | ||
| from .planner import build_plan, execute_plan_from_csv_text, parse_question_to_intent | ||
| from .visualize import create_multi_charts | ||
|
|
||
|
|
||
|
|
@@ -491,6 +492,7 @@ def do_POST(self) -> None: | |
| question = str(payload.get("question", "")).strip() | ||
| if not question: | ||
| question = "이 데이터의 핵심 인사이트를 알려줘" | ||
| use_planner = bool(payload.get("use_planner", False)) | ||
|
|
||
| input_type = str(payload.get("input_type", "csv") or "csv").strip().lower() | ||
| normalized_csv_text = str(payload.get("normalized_csv_text", "") or "") | ||
|
|
@@ -545,6 +547,14 @@ def do_POST(self) -> None: | |
| ), | ||
| HTTPStatus.BAD_REQUEST, | ||
| ) | ||
| if use_planner: | ||
| intent = parse_question_to_intent(question, result.get("summary", {})) | ||
| plan = build_plan(intent, result.get("summary", {})) | ||
| result["planner"] = { | ||
| "intent": intent.__dict__, | ||
| "plan": {"nodes": plan.nodes, "warnings": plan.warnings, "fallback": plan.fallback}, | ||
| "execution": execute_plan_from_csv_text(plan, str(request_payload.get("normalized_csv_text", "") or "")), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
This execution step always reads Useful? React with 👍 / 👎. |
||
| } | ||
| return self._send_json(result) | ||
|
|
||
| if route == '/api/preprocess/jobs': | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| from bitnet_tools.planner import AnalysisPlan, build_plan, execute_plan, parse_question_to_intent | ||
|
|
||
|
|
||
| def _schema(): | ||
| return { | ||
| "columns": ["region", "sales", "period"], | ||
| "dtypes": {"region": "string", "sales": "float", "period": "string"}, | ||
| "region_values": ["서울", "부산"], | ||
| } | ||
|
|
||
|
|
||
def test_parse_question_to_intent_extracts_controls():
    """A question naming every control populates the matching intent fields."""
    question = "서울 지역 top 3, sample 2, sales 임계값 100 전/후 비교"
    parsed = parse_question_to_intent(question, _schema())

    extracted = (
        parsed.top_n,
        parsed.sample_n,
        parsed.threshold,
        parsed.threshold_column,
        parsed.region,
    )
    assert extracted == (3, 2, 100, "sales", "서울")
    assert parsed.compare_periods is True
|
|
||
|
|
||
def test_build_plan_contains_execution_graph_nodes():
    """The plan lists all six ops in order, with rank enabled for a top-N question."""
    parsed = parse_question_to_intent("상위 5 샘플 2", _schema())
    plan = build_plan(parsed, _schema())

    op_names = [node["op"] for node in plan.nodes]
    assert op_names == ["filter", "groupby", "agg", "rank", "sample", "export"]

    rank_nodes = [node for node in plan.nodes if node["op"] == "rank"]
    assert any(node["enabled"] for node in rank_nodes)
|
|
||
|
|
||
def test_execute_plan_fallback_on_invalid_node():
    """An unknown op makes execution fall back and report the error in meta."""
    broken_plan = AnalysisPlan(
        intent=parse_question_to_intent("기본", _schema()),
        nodes=[{"op": "unknown", "enabled": True}],
    )
    rows = [{"region": "서울", "sales": "120", "period": "after"}]

    meta = execute_plan(broken_plan, rows)["meta"]

    assert meta["fallback"] is True
    assert "unsupported op" in meta["error"]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When a question asks for a before/after comparison without naming a region (for example, one containing `before`, `after`, `이전`, or `이후`), this fallback path treats the first such token as `intent.region`. That makes the filter node run against `region_column` with a non-region value, which can drop all rows and return empty planner results for valid compare-period queries.
Useful? React with 👍 / 👎.