Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions sqlite/graph_net_sample_groups_insert.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def close(self):
"sample_uid",
"op_seq_bucket_id",
"input_shapes_bucket_id",
"input_dtypes_bucket_id",
"sample_type",
"sample_uids",
],
Expand All @@ -64,10 +65,11 @@ def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]):

grouped = defaultdict(list)
for bucket_info in sample_bucket_infos:
grouped[bucket_info.op_seq_bucket_id].append(bucket_info.sample_uid)
key = (bucket_info.op_seq_bucket_id, bucket_info.input_dtypes_bucket_id)
grouped[key].append(bucket_info.sample_uid)

grouped = dict(grouped)
for op_seq, sample_uids in grouped.items():
for key, sample_uids in grouped.items():
new_uuid = str(uuid_module.uuid4())
for sample_uid in sample_uids:
yield sample_uid, new_uuid
Expand All @@ -89,25 +91,26 @@ def main():
db.connect()

query_str = """
SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids
SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids
FROM (
SELECT
s.uuid AS sample_uid,
s.sample_type AS sample_type,
b.op_seq_bucket_id AS op_seq_bucket_id,
b.input_shapes_bucket_id AS input_shapes_bucket_id
b.input_shapes_bucket_id AS input_shapes_bucket_id,
b.input_dtypes_bucket_id AS input_dtypes_bucket_id
FROM graph_sample s
JOIN graph_net_sample_buckets b
ON s.uuid = b.sample_uid
order by s.create_at asc, s.uuid asc
) b
GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id;
GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id,b.input_dtypes_bucket_id;
"""

query_results = db.query(query_str)
print("Output:", len(query_results))

query_results = [SampleBucketInfo(row) for row in query_results]
query_results = [SampleBucketInfo(*row) for row in query_results]

session = get_session(args.db_path)

Expand Down
20 changes: 12 additions & 8 deletions sqlite/graphsample_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def insert_subgraph_source(
try:
parent_relative_path = get_parent_relative_path(relative_model_path)
if sample_type == "fusible_graph" or sample_type == "typical_graph":
parent_parts = parent_relative_path.split("/")
parent_parts = parent_parts[2:]
parent_relative_path = "/".join(parent_parts)
if sample_type == "sole_op_graph":
parent_parts = parent_relative_path.split("/")
parent_parts = parent_parts[1:]
parent_relative_path = "/".join(parent_parts)
Expand Down Expand Up @@ -255,7 +259,7 @@ def insert_datatype_generalization_source(


def _get_data_type(model_path_prefix: str, relative_model_path: str):
return "todo"
return relative_model_path.split("/")[0]


# SampleOpNameList and SampleOpName insert func
Expand Down Expand Up @@ -477,13 +481,13 @@ def main(args):
args.relative_model_path,
args.db_path,
)
insert_datatype_generalization_source(
subgraph_source_data["subgraph_uuid"],
subgraph_source_data["full_graph_uuid"],
args.model_path_prefix,
args.relative_model_path,
args.db_path,
)
insert_datatype_generalization_source(
subgraph_source_data["subgraph_uuid"],
subgraph_source_data["full_graph_uuid"],
args.model_path_prefix,
args.relative_model_path,
args.db_path,
)
print(f"success insert: {data['relative_model_path']}")
except sqlite3.IntegrityError as e:
print("insert failed: integrity error (possible duplicate uuid or graph_hash)")
Expand Down
38 changes: 12 additions & 26 deletions sqlite/graphsample_insert.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ set -x

GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))")
DB_PATH="${1:-${GRAPH_NET_ROOT}/sqlite/GraphNet.db}"
TORCH_MODEL_LIST="graph_net/config/torch_samples_list.txt"
PADDLE_MODEL_LIST="graph_net/config/small10_paddle_samples_list.txt"
TYPICAL_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/deduplicated_subgraph_sample_list.txt"
FUSIBLE_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/deduplicated_dimension_generalized_subgraph_sample_list.txt"
SOLE_OP_GRAPH_SAMPLES_LIST="subgraph_dataset_20260203/solo_sample_list.txt"
DATASET_ROOT="${GRAPH_NET_ROOT}/20260317"
TORCH_MODEL_LIST="${DATASET_ROOT}/full_graph.txt"
TYPICAL_GRAPH_SAMPLES_LIST="${DATASET_ROOT}/typical_graph.txt"
FUSIBLE_GRAPH_SAMPLES_LIST="${DATASET_ROOT}/fusible_graph.txt"
SOLE_OP_GRAPH_SAMPLES_LIST="${DATASET_ROOT}/sole_op_graph.txt"
ORDER_VALUE=0

if [ ! -f "$DB_PATH" ]; then
Expand All @@ -18,7 +18,7 @@ fi
while IFS= read -r model_rel_path; do
echo "insert : $model_rel_path"
python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
--model_path_prefix "$GRAPH_NET_ROOT" \
--model_path_prefix "$DATASET_ROOT/full_graph" \
--relative_model_path "$model_rel_path" \
--repo_uid "hf_torch_samples" \
--sample_type "full_graph" \
Expand All @@ -32,23 +32,9 @@ done < "$TORCH_MODEL_LIST"
while IFS= read -r model_rel_path; do
echo "insert : $model_rel_path"
python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
--model_path_prefix "$GRAPH_NET_ROOT" \
--model_path_prefix "${DATASET_ROOT}/typical_graph" \
--relative_model_path "$model_rel_path" \
--repo_uid "hf_paddle_samples" \
--sample_type "full_graph" \
--order_value "$ORDER_VALUE" \
--db_path "$DB_PATH"

((ORDER_VALUE++))

done < "$PADDLE_MODEL_LIST"

while IFS= read -r model_rel_path; do
echo "insert : $model_rel_path"
python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
--model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/typical_graph" \
--relative_model_path "$model_rel_path" \
--op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
--op_names_path_prefix "${DATASET_ROOT}/03_sample_op_names" \
--repo_uid "hf_torch_samples" \
--sample_type "typical_graph" \
--order_value "$ORDER_VALUE" \
Expand All @@ -61,9 +47,9 @@ done < "$TYPICAL_GRAPH_SAMPLES_LIST"
while IFS= read -r model_rel_path; do
echo "insert : $model_rel_path"
python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
--model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/fusible_graph" \
--model_path_prefix "${DATASET_ROOT}/fusible_graph" \
--relative_model_path "$model_rel_path" \
--op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
--op_names_path_prefix "${DATASET_ROOT}/03_sample_op_names" \
--repo_uid "hf_torch_samples" \
--sample_type "fusible_graph" \
--order_value "$ORDER_VALUE" \
Expand All @@ -76,9 +62,9 @@ done < "$FUSIBLE_GRAPH_SAMPLES_LIST"
while IFS= read -r model_rel_path; do
echo "insert : $model_rel_path"
python3 "${GRAPH_NET_ROOT}/sqlite/graphsample_insert.py" \
--model_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/sole_op_graph" \
--model_path_prefix "${DATASET_ROOT}/sole_op_graph" \
--relative_model_path "$model_rel_path" \
--op_names_path_prefix "${GRAPH_NET_ROOT}/subgraph_dataset_20260203/03_sample_op_names" \
--op_names_path_prefix "${DATASET_ROOT}/03_sample_op_names" \
--repo_uid "hf_torch_samples" \
--sample_type "sole_op_graph" \
--order_value "$ORDER_VALUE" \
Expand Down
Loading