|
2 | 2 |
|
3 | 3 | namespace deepx::op |
4 | 4 | { |
5 | | - void Op::load(const char *str) |
6 | | - { |
7 | | - // 新格式示例:mul@float32 a(a_grad) b(b_grad) -> a(a_grad) requires_grad=true |
| 5 | + //与deepx/front/py/deepx/nn/deepxir.py对应 |
| 6 | + |
| 7 | + // 新格式示例:mul@float32 a(a_grad) b(b_grad) -> a(a_grad) //id=1 create_time=1714512000 send_time=1714512000 recv_time=1714512000 |
| 8 | + void Op::load(const char* str) { |
8 | 9 | string input(str); |
9 | | - size_t arrow_pos = input.find("->"); |
10 | | - string head = input.substr(0, arrow_pos); |
11 | | - string tail = arrow_pos != string::npos ? input.substr(arrow_pos + 2) : ""; |
| 10 | + |
| 11 | + // 分割元数据部分 |
| 12 | + size_t meta_pos = input.find("//"); |
| 13 | + string body = input.substr(0, meta_pos); |
| 14 | + string meta = (meta_pos != string::npos) ? input.substr(meta_pos + 2) : ""; |
| 15 | + |
| 16 | + // 解析操作主体 |
| 17 | + size_t arrow_pos = body.find("->"); |
| 18 | + if (arrow_pos == string::npos) { |
| 19 | + arrow_pos = body.find("<-"); |
| 20 | + if (arrow_pos != string::npos) { |
| 21 | + grad = true; // 反向传播标记 |
| 22 | + } |
| 23 | + } |
| 24 | + |
| 25 | + if (arrow_pos == string::npos) { |
| 26 | + throw runtime_error("Invalid IR format: missing arrow"); |
| 27 | + } |
12 | 28 |
|
13 | | - // 解析操作名和类型 |
| 29 | + string head = body.substr(0, arrow_pos); |
| 30 | + string tail = body.substr(arrow_pos + 2); |
| 31 | + |
| 32 | + // 解析操作名和数据类型 |
14 | 33 | size_t at_pos = head.find('@'); |
15 | | - if (at_pos != string::npos) |
16 | | - { |
| 34 | + if (at_pos != string::npos) { |
17 | 35 | name = head.substr(0, at_pos); |
18 | | - dtype = head.substr(at_pos + 1, head.find(' ') - at_pos - 1); |
19 | | - head = head.substr(head.find(' ') + 1); |
20 | | - } |
21 | | - else |
22 | | - { |
23 | | - name = head.substr(0, head.find(' ')); |
24 | | - dtype = "any"; |
25 | | - head = head.substr(name.size() + 1); |
| 36 | + size_t space_pos = head.find(' ', at_pos); |
| 37 | + if (space_pos != string::npos) { |
| 38 | + dtype = head.substr(at_pos + 1, space_pos - at_pos - 1); |
| 39 | + head = head.substr(space_pos + 1); |
| 40 | + } else { |
| 41 | + dtype = head.substr(at_pos + 1); |
| 42 | + head.clear(); |
| 43 | + } |
| 44 | + } else { |
| 45 | + size_t space_pos = head.find(' '); |
| 46 | + if (space_pos != string::npos) { |
| 47 | + name = head.substr(0, space_pos); |
| 48 | + head = head.substr(space_pos + 1); |
| 49 | + dtype = "any"; |
| 50 | + } else { |
| 51 | + name = head; |
| 52 | + head.clear(); |
| 53 | + dtype = "any"; |
| 54 | + } |
26 | 55 | } |
27 | 56 |
|
28 | | - // 解析输入参数(支持带括号的梯度名) |
| 57 | + // 解析输入参数 |
29 | 58 | stringstream head_ss(head); |
30 | 59 | string token; |
31 | | - while (head_ss >> token) |
32 | | - { |
33 | | - size_t bracket = token.find('('); |
34 | | - if (bracket != string::npos) |
35 | | - { |
| 60 | + while (head_ss >> token) { |
| 61 | + size_t bracket = token.find('('); |
| 62 | + if (bracket != string::npos && token.back() == ')') { |
36 | 63 | args.push_back(token.substr(0, bracket)); |
37 | | - args_grad.push_back(token.substr(bracket + 1, token.find(')') - bracket - 1)); |
38 | | - require_grad = true; |
39 | | - } |
40 | | - else |
41 | | - { |
| 64 | + args_grad.push_back(token.substr(bracket + 1, token.size() - bracket - 2)); |
| 65 | + } else { |
42 | 66 | args.push_back(token); |
| 67 | + args_grad.emplace_back(""); // 保持梯度与参数数量一致 |
43 | 68 | } |
44 | 69 | } |
45 | 70 |
|
46 | | - // 解析输出参数和标志 |
| 71 | + // 解析输出参数 |
47 | 72 | stringstream tail_ss(tail); |
48 | | - while (tail_ss >> token) |
49 | | - { |
50 | | - if (token.find('(') != string::npos) |
51 | | - { |
52 | | - size_t bracket = token.find('('); |
| 73 | + while (tail_ss >> token) { |
| 74 | + size_t bracket = token.find('('); |
| 75 | + if (bracket != string::npos && token.back() == ')') { |
53 | 76 | returns.push_back(token.substr(0, bracket)); |
54 | | - returns_grad.push_back(token.substr(bracket + 1, token.find(')') - bracket - 1)); |
55 | | - } |
56 | | - else if (token == "requires_grad=true") |
57 | | - { |
58 | | - require_grad = true; |
59 | | - } |
60 | | - else |
61 | | - { |
| 77 | + returns_grad.push_back(token.substr(bracket + 1, token.size() - bracket - 2)); |
| 78 | + } else { |
62 | 79 | returns.push_back(token); |
| 80 | + returns_grad.emplace_back(""); // 保持梯度与参数数量一致 |
| 81 | + } |
| 82 | + } |
| 83 | + |
| 84 | + // 解析元数据 |
| 85 | + if (!meta.empty()) { |
| 86 | + stringstream meta_ss(meta); |
| 87 | + string key, value; |
| 88 | + while (meta_ss >> key) { |
| 89 | + size_t eq_pos = key.find('='); |
| 90 | + if (eq_pos != string::npos) { |
| 91 | + value = key.substr(eq_pos + 1); |
| 92 | + key = key.substr(0, eq_pos); |
| 93 | + |
| 94 | + if (key == "id") { |
| 95 | + id = stoi(value); |
| 96 | + } else if (key == "created_at") { |
| 97 | + created_at = system_clock::from_time_t(stod(value)); |
| 98 | + } else if (key == "sent_at") { |
| 99 | + sent_at = system_clock::from_time_t(stod(value)); |
| 100 | + } |
| 101 | + } |
63 | 102 | } |
64 | 103 | } |
65 | 104 | } |
66 | 105 |
|
| 106 | + |
67 | 107 | void Op::init(const string &opname, |
68 | 108 | const string &dtype, |
69 | 109 | const vector<string> &args, |
70 | 110 | const vector<string> &returns, |
71 | | - bool require_grad, |
| 111 | + bool grad, |
72 | 112 | const vector<string> &args_grad, |
73 | 113 | const vector<string> &returns_grad) |
74 | 114 | { |
75 | 115 | this->name = opname; |
76 | 116 | this->dtype = dtype; |
77 | 117 | this->args = args; |
78 | 118 | this->returns = returns; |
79 | | - this->require_grad = require_grad; |
| 119 | + this->grad = grad; |
80 | 120 |
|
81 | | - if (require_grad) |
| 121 | + if (grad) |
82 | 122 | { |
83 | 123 | // 如果提供了梯度变量名,就使用提供的名字 |
84 | 124 | if (!args_grad.empty()) |
|
0 commit comments