diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py old mode 100644 new mode 100755 index c41d32dcd..03c71498e --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -43,6 +43,7 @@ def close(self): "sample_uid", "op_seq_bucket_id", "input_shapes_bucket_id", + "input_dtypes_bucket_id", "sample_type", "sample_uids", ], @@ -64,10 +65,11 @@ def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]): grouped = defaultdict(list) for bucket_info in sample_bucket_infos: - grouped[bucket_info.op_seq_bucket_id].append(bucket_info.sample_uid) + key = (bucket_info.op_seq_bucket_id, bucket_info.input_dtypes_bucket_id) + grouped[key].append(bucket_info.sample_uid) grouped = dict(grouped) - for op_seq, sample_uids in grouped.items(): + for key, sample_uids in grouped.items(): new_uuid = str(uuid_module.uuid4()) for sample_uid in sample_uids: yield sample_uid, new_uuid @@ -89,25 +91,26 @@ def main(): db.connect() query_str = """ -SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids +SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids FROM ( SELECT s.uuid AS sample_uid, s.sample_type AS sample_type, b.op_seq_bucket_id AS op_seq_bucket_id, - b.input_shapes_bucket_id AS input_shapes_bucket_id + b.input_shapes_bucket_id AS input_shapes_bucket_id, + b.input_dtypes_bucket_id AS input_dtypes_bucket_id FROM graph_sample s JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid order by s.create_at asc, s.uuid asc ) b -GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id; +GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id,b.input_dtypes_bucket_id; """ query_results = db.query(query_str) print("Output:", len(query_results)) - query_results = [SampleBucketInfo(row) for row in query_results] + query_results = [SampleBucketInfo(*row) for row in query_results] session = get_session(args.db_path) diff --git a/sqlite/graphsample_insert.py b/sqlite/graphsample_insert.py index a5f623cba..afb40816a 100755 --- a/sqlite/graphsample_insert.py +++ b/sqlite/graphsample_insert.py @@ -70,6 +70,10 @@ def insert_subgraph_source( try: parent_relative_path = get_parent_relative_path(relative_model_path) if sample_type == "fusible_graph" or sample_type == "typical_graph": + parent_parts = parent_relative_path.split("/") + parent_parts = parent_parts[2:] + parent_relative_path = "/".join(parent_parts) + if sample_type == "sole_op_graph": parent_parts = parent_relative_path.split("/") parent_parts = parent_parts[1:] parent_relative_path = "/".join(parent_parts) @@ -255,7 +259,7 @@ def insert_datatype_generalization_source( def _get_data_type(model_path_prefix: str, relative_model_path: str): - return "todo" + return relative_model_path.split("/")[0] # SampleOpNameList and SampleOpName insert func @@ -477,13 +481,13 @@ def main(args): args.relative_model_path, args.db_path, ) - insert_datatype_generalization_source( - subgraph_source_data["subgraph_uuid"], - subgraph_source_data["full_graph_uuid"], - args.model_path_prefix, - args.relative_model_path, - args.db_path, - ) + insert_datatype_generalization_source( + subgraph_source_data["subgraph_uuid"], + subgraph_source_data["full_graph_uuid"], + args.model_path_prefix, + args.relative_model_path, + args.db_path, + ) print(f"success insert: {data['relative_model_path']}") except sqlite3.IntegrityError as e: print("insert failed: integrity error (possible duplicate uuid or graph_hash)") diff --git a/sqlite/graphsample_insert.sh b/sqlite/graphsample_insert.sh index 043ee63eb..98ef9b901 100755 --- a/sqlite/graphsample_insert.sh +++ b/sqlite/graphsample_insert.sh @@ -3,11 +3,11 @@ set -x GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") DB_PATH="${1:-${GRAPH_NET_ROOT}/sqlite/GraphNet.db}" -TORCH_MODEL_LIST="graph_net/config/torch_samples_list.txt" -PADDLE_MODEL_LIST="graph_net/config/small10_paddle_samples_list.txt" -TYPICAL_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/deduplicated_subgraph_sample_list.txt" -FUSIBLE_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/deduplicated_dimension_generalized_subgraph_sample_list.txt" -SOLE_OP_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/solo_sample_list.txt" +DATASET_ROOT="${GRAPH_NET_ROOT}/20260317" +TORCH_MODEL_LIST="${DATASET_ROOT}/full_graph.txt" +TYPICAL_GRAPH_SAMPLES_LIST="${DATASET_ROOT}/typical_graph.txt" +FUSIBLE_GRAPH_SAMPLES_LIST="${DATASET_ROOT}/fusible_graph.txt" +SOLE_OP_GRAPH_SAMPLES_LIST="${DATASET_ROOT}/sole_op_graph.txt" ORDER_VALUE=0 if [ ! -f "$DB_PATH" ]; then @@ -18,7 +18,7 @@ fi while IFS= read -r model_rel_path; do echo "insert : $model_rel_path" python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "$GRAPH_NET_ROOT" \ + --model_path_prefix "$DATASET_ROOT/full_graph" \ --relative_model_path "$model_rel_path" \ --repo_uid "hf_torch_samples" \ --sample_type "full_graph" \ @@ -32,23 +32,9 @@ done < "$TORCH_MODEL_LIST" while IFS= read -r model_rel_path; do echo "insert : $model_rel_path" python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "$GRAPH_NET_ROOT" \ + --model_path_prefix "${DATASET_ROOT}/typical_graph" \ --relative_model_path "$model_rel_path" \ - --repo_uid "hf_paddle_samples" \ - --sample_type "full_graph" \ - --order_value "$ORDER_VALUE" \ - --db_path "$DB_PATH" - - ((ORDER_VALUE++)) - -done < "$PADDLE_MODEL_LIST" - -while IFS= read -r model_rel_path; do - echo "insert : $model_rel_path" - python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/typical_graph" \ - --relative_model_path "$model_rel_path" \ - --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \ + --op_names_path_prefix "${DATASET_ROOT}/03_sample_op_names" \ --repo_uid "hf_torch_samples" \ --sample_type "typical_graph" \ --order_value "$ORDER_VALUE" \ @@ -61,9 +47,9 @@ done < "$TYPICAL_GRAPH_SAMPLES_LIST" while IFS= read -r model_rel_path; do echo "insert : $model_rel_path" python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/fusible_graph" \ + --model_path_prefix "${DATASET_ROOT}/fusible_graph" \ --relative_model_path "$model_rel_path" \ - --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \ + --op_names_path_prefix "${DATASET_ROOT}/03_sample_op_names" \ --repo_uid "hf_torch_samples" \ --sample_type "fusible_graph" \ --order_value "$ORDER_VALUE" \ @@ -76,9 +62,9 @@ done < "$FUSIBLE_GRAPH_SAMPLES_LIST" while IFS= read -r model_rel_path; do echo "insert : $model_rel_path" python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \ - --model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/sole_op_graph" \ + --model_path_prefix "${DATASET_ROOT}/sole_op_graph" \ --relative_model_path "$model_rel_path" \ - --op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \ + --op_names_path_prefix "${DATASET_ROOT}/03_sample_op_names" \ --repo_uid "hf_torch_samples" \ --sample_type "sole_op_graph" \ --order_value "$ORDER_VALUE" \