Skip to content

Commit 8d4e742

Browse files
committed
send/resp:通信增加了返回确认
1 parent e5f46e4 commit 8d4e742

20 files changed

Lines changed: 173 additions & 115 deletions

File tree

excuter/common/src/client/udpserver.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ namespace client
4242
while (true)
4343
{
4444
len = sizeof(cliaddr);
45-
n = recvfrom(sockfd, (char *)buffer, 1024, 0, (struct sockaddr *)&cliaddr, &len);
45+
n = recvfrom(sockfd, (char *)buffer, sizeof(buffer), 0, (struct sockaddr *)&cliaddr, &len);
4646
buffer[n] = '\0';
4747

4848
// 新增换行拆分逻辑
@@ -52,7 +52,9 @@ namespace client
5252
if (!line.empty()) {
5353
cout << "~" << line << endl;
5454
char *IR = const_cast<char *>(line.c_str());
55-
func(IR);
55+
string strresp=func(IR);
56+
sendto(sockfd, strresp.c_str(), strresp.size(), 0,
57+
(const struct sockaddr*)&cliaddr, len);
5658
}
5759
}
5860
}

excuter/common/src/client/udpserver.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace client{
2222
udpserver(int port);
2323
~udpserver();
2424
void start();
25-
using handlefunc = std::function<void(char *buffer)>;
25+
using handlefunc = std::function<std::string(const char *buffer)>;
2626
handlefunc func;
2727
};
2828

excuter/common/src/deepx/op/op.cpp

Lines changed: 85 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,83 +2,123 @@
22

33
namespace deepx::op
44
{
5-
void Op::load(const char *str)
6-
{
7-
// 新格式示例:mul@float32 a(a_grad) b(b_grad) -> a(a_grad) requires_grad=true
5+
//与deepx/front/py/deepx/nn/deepxir.py对应
6+
7+
// 新格式示例:mul@float32 a(a_grad) b(b_grad) -> a(a_grad) //id=1 created_at=1714512000 sent_at=1714512000  (recv_at is stamped by the receiver, not serialized)
8+
void Op::load(const char* str) {
89
string input(str);
9-
size_t arrow_pos = input.find("->");
10-
string head = input.substr(0, arrow_pos);
11-
string tail = arrow_pos != string::npos ? input.substr(arrow_pos + 2) : "";
10+
11+
// 分割元数据部分
12+
size_t meta_pos = input.find("//");
13+
string body = input.substr(0, meta_pos);
14+
string meta = (meta_pos != string::npos) ? input.substr(meta_pos + 2) : "";
15+
16+
// 解析操作主体
17+
size_t arrow_pos = body.find("->");
18+
if (arrow_pos == string::npos) {
19+
arrow_pos = body.find("<-");
20+
if (arrow_pos != string::npos) {
21+
grad = true; // 反向传播标记
22+
}
23+
}
24+
25+
if (arrow_pos == string::npos) {
26+
throw runtime_error("Invalid IR format: missing arrow");
27+
}
1228

13-
// 解析操作名和类型
29+
string head = body.substr(0, arrow_pos);
30+
string tail = body.substr(arrow_pos + 2);
31+
32+
// 解析操作名和数据类型
1433
size_t at_pos = head.find('@');
15-
if (at_pos != string::npos)
16-
{
34+
if (at_pos != string::npos) {
1735
name = head.substr(0, at_pos);
18-
dtype = head.substr(at_pos + 1, head.find(' ') - at_pos - 1);
19-
head = head.substr(head.find(' ') + 1);
20-
}
21-
else
22-
{
23-
name = head.substr(0, head.find(' '));
24-
dtype = "any";
25-
head = head.substr(name.size() + 1);
36+
size_t space_pos = head.find(' ', at_pos);
37+
if (space_pos != string::npos) {
38+
dtype = head.substr(at_pos + 1, space_pos - at_pos - 1);
39+
head = head.substr(space_pos + 1);
40+
} else {
41+
dtype = head.substr(at_pos + 1);
42+
head.clear();
43+
}
44+
} else {
45+
size_t space_pos = head.find(' ');
46+
if (space_pos != string::npos) {
47+
name = head.substr(0, space_pos);
48+
head = head.substr(space_pos + 1);
49+
dtype = "any";
50+
} else {
51+
name = head;
52+
head.clear();
53+
dtype = "any";
54+
}
2655
}
2756

28-
// 解析输入参数(支持带括号的梯度名)
57+
// 解析输入参数
2958
stringstream head_ss(head);
3059
string token;
31-
while (head_ss >> token)
32-
{
33-
size_t bracket = token.find('(');
34-
if (bracket != string::npos)
35-
{
60+
while (head_ss >> token) {
61+
size_t bracket = token.find('(');
62+
if (bracket != string::npos && token.back() == ')') {
3663
args.push_back(token.substr(0, bracket));
37-
args_grad.push_back(token.substr(bracket + 1, token.find(')') - bracket - 1));
38-
require_grad = true;
39-
}
40-
else
41-
{
64+
args_grad.push_back(token.substr(bracket + 1, token.size() - bracket - 2));
65+
} else {
4266
args.push_back(token);
67+
args_grad.emplace_back(""); // 保持梯度与参数数量一致
4368
}
4469
}
4570

46-
// 解析输出参数和标志
71+
// 解析输出参数
4772
stringstream tail_ss(tail);
48-
while (tail_ss >> token)
49-
{
50-
if (token.find('(') != string::npos)
51-
{
52-
size_t bracket = token.find('(');
73+
while (tail_ss >> token) {
74+
size_t bracket = token.find('(');
75+
if (bracket != string::npos && token.back() == ')') {
5376
returns.push_back(token.substr(0, bracket));
54-
returns_grad.push_back(token.substr(bracket + 1, token.find(')') - bracket - 1));
55-
}
56-
else if (token == "requires_grad=true")
57-
{
58-
require_grad = true;
59-
}
60-
else
61-
{
77+
returns_grad.push_back(token.substr(bracket + 1, token.size() - bracket - 2));
78+
} else {
6279
returns.push_back(token);
80+
returns_grad.emplace_back(""); // 保持梯度与参数数量一致
81+
}
82+
}
83+
84+
// 解析元数据
85+
if (!meta.empty()) {
86+
stringstream meta_ss(meta);
87+
string key, value;
88+
while (meta_ss >> key) {
89+
size_t eq_pos = key.find('=');
90+
if (eq_pos != string::npos) {
91+
value = key.substr(eq_pos + 1);
92+
key = key.substr(0, eq_pos);
93+
94+
if (key == "id") {
95+
id = stoi(value);
96+
} else if (key == "created_at") {
97+
created_at = system_clock::from_time_t(stod(value));
98+
} else if (key == "sent_at") {
99+
sent_at = system_clock::from_time_t(stod(value));
100+
}
101+
}
63102
}
64103
}
65104
}
66105

106+
67107
void Op::init(const string &opname,
68108
const string &dtype,
69109
const vector<string> &args,
70110
const vector<string> &returns,
71-
bool require_grad,
111+
bool grad,
72112
const vector<string> &args_grad,
73113
const vector<string> &returns_grad)
74114
{
75115
this->name = opname;
76116
this->dtype = dtype;
77117
this->args = args;
78118
this->returns = returns;
79-
this->require_grad = require_grad;
119+
this->grad = grad;
80120

81-
if (require_grad)
121+
if (grad)
82122
{
83123
// 如果提供了梯度变量名,就使用提供的名字
84124
if (!args_grad.empty())

excuter/common/src/deepx/op/op.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include <string>
88
#include <iostream>
99
#include <sstream>
10-
10+
#include <chrono>
1111
#include "deepx/tensor.hpp"
1212
#include "deepx/mem/mem.hpp"
1313
#include "deepx/dtype.hpp"
@@ -16,18 +16,21 @@ namespace deepx::op
1616
{
1717
using deepx::mem::Mem;
1818
using namespace std;
19-
19+
using namespace std::chrono;
2020
class Op
2121
{
2222
public:
2323
string name;
2424
string dtype;
2525
vector<string> args;
2626
vector<string> args_grad;
27-
bool require_grad;
27+
bool grad=false;
2828
vector<string> returns;
2929
vector<string> returns_grad;
30-
30+
int id;
31+
system_clock::time_point created_at;
32+
system_clock::time_point sent_at;
33+
system_clock::time_point recv_at;
3134
public:
3235
Op() = default;
3336
Op(const Op &) = default;
@@ -56,7 +59,7 @@ namespace deepx::op
5659
const string &dtype,
5760
const vector<string> &args,
5861
const vector<string> &returns,
59-
bool require_grad,
62+
bool grad,
6063
const vector<string> &args_grad,
6164
const vector<string> &returns_grad);
6265
};

excuter/op-mem-ompsimd/src/client/main.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#include <mutex>
22

33
#include <deepx/tensorfunc/init.hpp>
4-
#include <deepx/tensorfunc/new.hpp>
5-
#include <deepx/tensorfunc/print.hpp>
6-
74
#include "deepx/op/op.hpp"
85
#include "deepx/op/opfactory.hpp"
96
#include "deepx/mem/mem.hpp"
@@ -21,25 +18,31 @@ int main()
2118
deepx::op::OpFactory opfactory;
2219
register_all(opfactory);
2320

24-
server.func = [&mem, &opfactory, &memmutex](const char *buffer)
21+
server.func = [&mem, &opfactory, &memmutex](const char *buffer)->std::string
2522
{
2623
deepx::op::Op op;
24+
op.recv_at = chrono::system_clock::now();
2725
op.load(buffer);
2826
if (opfactory.ops.find(op.name)==opfactory.ops.end()){
2927
cout<<"<op> "<<op.name<<" not found"<<endl;
30-
return;
28+
return "error";
3129
}
3230
auto &type_map = opfactory.ops.find(op.name)->second;
3331
if (type_map.find(op.dtype)==type_map.end()){
3432
cout<<"<op>"<<op.name<<" "<<op.dtype<<" not found"<<endl;
35-
return;
33+
return "error";
3634
}
3735
auto src = type_map.find(op.dtype)->second;
38-
39-
(*src).init(op.name, op.dtype, op.args, op.returns, op.require_grad, op.args_grad, op.returns_grad);
36+
37+
(*src).init(op.name, op.dtype, op.args, op.returns, op.grad, op.args_grad, op.returns_grad);
4038
memmutex.lock();
41-
(*src).forward(mem);
39+
if (op.grad) {
40+
(*src).backward(mem);
41+
}else {
42+
(*src).forward(mem);
43+
}
4244
memmutex.unlock();
45+
return to_string(op.id);
4346
};
4447
server.start();
4548
return 0;

excuter/op-mem-ompsimd/src/client/udpsend.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

excuter/op-mem-ompsimd/src/deepx/op/transpose.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77

88
namespace deepx::op{
9-
109
template <typename T>
1110
class Transpose : public OpT<T> {
1211
public:

front/py/deepx/nn/deepxir.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Tuple, List, Optional
2+
import time
23

4+
_id_counter=0
35
class DeepxIR:
46
def __init__(self,
57
name:str,
@@ -37,12 +39,9 @@ def __init__(self,
3739
self._grad = grad
3840
self._args_grad = args_grad if grad else []
3941
self._returns_grad = returns_grad if grad else []
40-
41-
def forward(self, *input) -> Tuple:
42-
raise NotImplementedError
43-
44-
def backward(self, *grad_outputs) -> Tuple:
45-
raise NotImplementedError
42+
self._id=++_id_counter
43+
self._created_at=time.time()
44+
self._sent_at=None
4645

4746
def __str__(self):
4847
if self._dtype == None or self._dtype == '':
@@ -67,6 +66,10 @@ def __str__(self):
6766
if self._grad and self._returns_grad[i]:
6867
ret_part += f"({self._returns_grad[i]})"
6968
parts.append(ret_part)
70-
69+
70+
parts.append("//")
71+
parts.append(f"id={self._id}")
72+
parts.append(f"created_at={self._created_at}")
73+
parts.append(f"sent_at={self._sent_at}")
7174
return ' '.join(parts)
7275

front/py/deepx/nn/functional/activite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,6 @@ def relu(t: Tensor,inplace:bool=False)->Tensor:
77
if not inplace:
88
out=Tensor(shape=t.shape, dtype=t.dtype, device=t.device)
99
ir=DeepxIR("max_scalar",t.dtype,[t._node.name,0], [out._node.name])
10-
send(str(ir))
10+
send(ir)
1111
return out
1212

front/py/deepx/nn/functional/elementwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def _A_B_elementwiseop_C(
1515
out.node.add_input(opnode)
1616
if a.graph.eager:
1717
ir=DeepxIR(op, a.dtype, [a.node.name, b.node.name], [out.node.name])
18-
send(str(ir))
18+
send(ir)
1919
def _A_b_elementwiseop_C(
2020
a:Tensor,
2121
b: Optional[Union[ float, int]] = None,
@@ -30,7 +30,7 @@ def _A_b_elementwiseop_C(
3030
varir=DeepxIR("argset", a.dtype, [b], [varnode.name])
3131
send(str(varir))
3232
ir=DeepxIR(op, a.dtype, [a.node.name,varnode.name], [out.node.name])
33-
send(str(ir))
33+
send(ir)
3434
#add
3535
OpNode.register("add")
3636
OpNode.register("add_scalar")

0 commit comments

Comments
 (0)