
Commit 34f66e5

authored
Joint validation of min, max, less, greater, equal, notequal (#48)

* front: min, max; less, greater, equal, notequal
* joint validation of less, greater, equal, notequal
* equal, notequal: ompsimd validation
1 parent e6cd0cb commit 34f66e5

File tree

29 files changed: +982 −81 lines


README.md

Lines changed: 21 additions & 13 deletions

````diff
@@ -1,18 +1,19 @@
 # deepx
-deepx proposes a natively distributed, auto-parallel deep learning framework unifying training and inference
+deepx proposes a natively distributed, auto-parallel deep learning framework unifying training and inference, centered on an IR compute graph: through multiple levels of equivalent substitution, a compute graph in simple mathematical form is adaptively rewritten into a distributed, parallel, auto-differentiating engineering architecture
 
 ## 1. deepx overview
-deepx supports two execution modes, eager and auto
-+ eager executes functions immediately
-+ auto runs through the compute-graph compiler/optimizer
+deepx is divided into a frontend expression side, a compile/rewrite layer, and an executor layer
++ Frontend expression side: algorithm engineers describe the computation in near-mathematical notation, as a concise single-node, single-threaded process, without touching device types, compute-node counts, and so on.
++ Compile/rewrite layer: registers multiple rounds of IR compilers of different types to perform equivalent substitutions; custom capabilities (e.g. a bespoke kvcache) can be added as plugins that locally rewrite the compute graph to gain new abilities.
++ Executor layer: performs the actual tensor computation, massively parallelized.
 
 ### Frontend
 
 The python sdk provides an API close to pytorch
 sdks in other languages may also plug in
 
 + IR communication scheduling. Unlike pytorch and other py+bind-C++ designs that dispatch on-stack function calls inside a single process, the deepx programs (the front python sdk, the back compute-graph compiler/optimizer, and excuters such as ompsimd) schedule one another over the network via IR, each running as its own process.
 
@@ -21,36 +22,43 @@ The python sdk provides an API close to pytorch
 |--------------|-----------------------|-------------------------|
 | Execution model | in-process function-stack dispatch | multi-process distributed cooperation |
 | Communication | direct memory access | IR network compute-scheduling protocol exchange |
-| Component coupling | tight (Python binds C++) | loose (gRPC/custom protocol) |
+| Component coupling | tight (Python binds C++) | loose |
+| Tensor lifecycle | controlled by the python side | managed explicitly by the deltensor IR instruction |
 
-### Scheduling plane
+### Compile/rewrite layer
 
 + Registry: collects the operator lists of the executors currently ready, along with per-operator time and memory-usage statistics
 + Compute-graph compiler/optimizer: operator fusion, graph-node elimination, automatic generation of tensor-split parallel subgraphs that replace the original node
 + Execution scheduler: data parallelism, pipeline parallelism (forward/backward overlap), model parallelism.
 + front emits basic IR; the compiler fuses it into the high-level operators registered by the excuter.
 
-### Executor
+### Execution layer
 
-Handles low-level operator computation, with Op as the core unit of execution
+The execution layer comprises two kinds of executors, op and mem, though the current implementation has a single program manage both.
+
+Handles low-level operator computation, with IR as the core unit of execution
 ```
-Op{args(args_grad),returns(returns_grad)|func forward,backward}
+Op{args(args_grad),returns(returns_grad)|func run}
 ```
 
-Most Ops must implement both forward and backward, although some inference-only fusionOps may implement forward alone as needed.
+An Op must implement a run method
 
 As for excuters: anything that can execute a deepxIR sequence and return results can join the deepx distributed scheduling framework, so hardware, instruction sets, acceleration libraries, and high-level frameworks (training and inference engines alike) can all be adapted with minor changes.
 
+The current
+
 #### Default executor
 + cpu executor: ompsimd is implemented; its supported operators are listed at [ompsimd](doc/excuter/op-mem-ompsimd/list.md)
 
 #### GPU executors
-+ cuda executor [in progress]
++ cuda executor; its supported operators are listed at [cuda](doc/excuter/op-mem-cuda/list.md)
 
 cuda contributions are welcome
 
 + rocm
++ apple
++ other hardware accelerators
 
 #### Tensor-computation frameworks / function-level executors
````
doc/excuter/op-mem-cuda/list.md

Lines changed: 5 additions & 3 deletions

```diff
@@ -40,6 +40,7 @@
 | Operation | Author | Math Formula | IR Instruction |
 |-----------|--------|--------------|----------------|
 | normal | miaobyte | normal(mean,stddev,seed)->T1 | normal(var<any> mean, var<any> stddev, var<int32> seed)->(tensor<any> t) |
+| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
 | uniform | miaobyte | uniform(low,high,seed)->T1 | uniform(var<any> low, var<any> high, var<int32> seed)->(tensor<any> t) |
 | arange | miaobyte | arange(start,step)->T1 | arange(var<any> start, var<any> step)->(tensor<any> t) |
 | constant | miaobyte | constant(value)->T1 | constant(var<any> value)->(tensor<any> t) |
@@ -50,18 +51,19 @@
 |-----------|--------|--------------|----------------|
 | switch | miaobyte | C=switch(tensors,cases) | switch(listtensor<any> tensors, tensor<int8> cases)->(tensor<any> result) |
 | greaterscalar | miaobyte | mask=compare(T1, scalar) | greaterscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
-| equalscalar | miaobyte | mask=compare(T1, scalar) | equalscalar(tensor<any> A, var<any> scalar, var<float64> epsilon)->(tensor<bool> mask) |
+| notequal | miaobyte | T1!=T2->mask | notequal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
+| equalscalar | miaobyte | T1==scalar->mask | equalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
 | min | miaobyte | T3=min(T1, T2) | min(tensor<any> A, tensor<any> B)->(tensor<any> C) |
 | maxscalar | miaobyte | T3=max(T1, scalar) | maxscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
 | tan | miaobyte | T3=tan(T1) | tan(tensor<float64|float32> A)->(tensor<float64|float32> C) |
 | sin | miaobyte | T3=sin(T1) | sin(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
-| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
 | divscalar | miaobyte | T3=scalar/T1 | divscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
 | log | miaobyte | T3=log(T1) | log(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
 | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
 | greater | miaobyte | mask=compare(T1, T2) | greater(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
 | lessscalar | miaobyte | mask=compare(T1, scalar) | lessscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
 | cos | miaobyte | T3=cos(T1) | cos(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
+| notequalscalar | miaobyte | T1!=scalar->mask | notequalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
 | minscalar | miaobyte | T3=min(T1, scalar) | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
 | rpowscalar | miaobyte | T3=pow(scalar, T1) | rpowscalar(var<float64|int32> scalar, tensor<float64|float32> A)->(tensor<float64|float32> C) |
 | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
@@ -75,7 +77,7 @@
 | subscalar | miaobyte | T3=T1-scalar | subscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
 | exp | miaobyte | T3=exp(T1) | exp(tensor<float64|float32|float16|bfloat16> A)->(tensor<float64|float32|float16|bfloat16> C) |
 | mul | miaobyte | T3=T1*T2 | mul(tensor<any> A, tensor<any> B)->(tensor<any> C) |
-| equal | miaobyte | mask=compare(T1, T2) | equal(tensor<any> A, tensor<any> B, var<float64> epsilon)->(tensor<bool> mask) |
+| equal | miaobyte | T1==T2->mask | equal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
 | mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
 | div | miaobyte | T3=T1/T2 | div(tensor<any> A, tensor<any> B)->(tensor<any> C) |
 | invert | miaobyte | T3=~T1 | invert(tensor<int64|int32|int16|int8> A)->(tensor<int64|int32|int16|int8> C) |
```
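The rows above give `equal`/`notequal` a `float32 epsilon` parameter, i.e. floating-point comparison with a tolerance rather than exact bit equality. A minimal sketch of those semantics (standalone helper names, not the deepx kernels; plain `std::vector` stands in for `Tensor`):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// equal(A, B, epsilon)->mask: elementwise |a - b| <= epsilon.
std::vector<bool> equal_eps(const std::vector<float> &A,
                            const std::vector<float> &B, float epsilon) {
    std::vector<bool> mask(A.size());
    for (std::size_t i = 0; i < A.size(); ++i)
        mask[i] = std::fabs(A[i] - B[i]) <= epsilon;
    return mask;
}

// notequal(A, B, epsilon)->mask: the complement of equal.
std::vector<bool> notequal_eps(const std::vector<float> &A,
                               const std::vector<float> &B, float epsilon) {
    std::vector<bool> mask = equal_eps(A, B, epsilon);
    mask.flip();  // invert every element of the bool mask
    return mask;
}
```

The epsilon makes the operators robust to rounding noise, which is presumably why the commit narrows it to `float32` to match the common tensor dtype.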

doc/excuter/op-mem-ompsimd/list.md

Lines changed: 5 additions & 3 deletions

```diff
@@ -41,6 +41,7 @@
 | Operation | Author | Math Formula | IR Instruction |
 |-----------|--------|--------------|----------------|
 | normal | miaobyte | normal(mean,stddev,seed)->T1 | normal(var<any> mean, var<any> std, var<int32> seed)->(tensor<any> t) |
+| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
 | uniform | miaobyte | uniform(low,high,seed)->T1 | uniform(var<any> low, var<any> high, var<int32> seed)->(tensor<any> t) |
 | arange | miaobyte | arange(start,step)->T1 | arange(var<any> start, var<any> step)->(tensor<any> t) |
 | constant | miaobyte | constant(value)->T1 | constant(var<any> value)->(tensor<any> t) |
@@ -51,15 +52,16 @@
 |-----------|--------|--------------|----------------|
 | switch | miaobyte | C=switch([tensors],case) | switch(listtensor<any> tensors, tensor<int8> cases)->(tensor<any> C) |
 | greaterscalar | miaobyte | mask=greater(T1,scalar) | greaterscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
-| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
+| notequal | miaobyte | notequal(T1,T2)->mask | notequal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
+| equalscalar | miaobyte | mask=equal(T1,scalar) | equalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
 | min | miaobyte | T3=min(T1,T2) | min(tensor<any> A, tensor<any> B)->(tensor<any> C) |
 | maxscalar | miaobyte | T3=max(T1,scalar) | maxscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
-| dropout | miaobyte | dropout(p,seed)->A | dropout(var<float32> p, var<int32> seed)->(tensor<any> A) |
 | divscalar | miaobyte | T3=T1/scalar | divscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
 | log | miaobyte | T3=log(T1) | log(tensor<any> A)->(tensor<any> C) |
 | addscalar | miaobyte | T3=T1+scalar | addscalar(tensor<any> a, var<any> scalar)->(tensor<any> c) |
 | greater | miaobyte | mask=greater(T1,T2) | greater(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
 | lessscalar | miaobyte | mask=less(T1,scalar) | lessscalar(tensor<any> A, var<any> scalar)->(tensor<bool> mask) |
+| notequalscalar | miaobyte | mask=notequal(T1,scalar) | notequalscalar(tensor<any> A, var<any> scalar, var<float32> epsilon)->(tensor<bool> mask) |
 | minscalar | miaobyte | T3=min(T1,scalar) | minscalar(tensor<any> A, var<any> scalar)->(tensor<any> C) |
 | rpowscalar | miaobyte | T3=scalar^T1 | rpowscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
 | rdivscalar | miaobyte | T3=scalar/T1 | rdivscalar(var<any> scalar, tensor<any> A)->(tensor<any> C) |
@@ -73,7 +75,7 @@
 | subscalar | miaobyte | T3=T1-scalar | subscalar(tensor<any> a, var<any> scalar)->(tensor<any> c) |
 | exp | miaobyte | T3=exp(T1) | exp(tensor<any> A)->(tensor<any> C) |
 | mul | miaobyte | T3=T1*T2 | mul(tensor<any> A, tensor<any> B)->(tensor<any> C) |
-| equal | miaobyte | mask=equal(T1,T2) | equal(tensor<any> A, tensor<any> B)->(tensor<bool> mask) |
+| equal | miaobyte | equal(T1,T2)->mask | equal(tensor<any> A, tensor<any> B, var<float32> epsilon)->(tensor<bool> mask) |
 | mulscalar | miaobyte | T3=T1*scalar | mulscalar(tensor<any> A, var<any> b)->(tensor<any> C) |
 | div | miaobyte | T3=T1/T2 | div(tensor<any> A, tensor<any> B)->(tensor<any> C) |
 | invert | miaobyte | T3=~T1 | invert(tensor<int64|int32|int16|int8> A)->(tensor<int64|int32|int16|int8> C) |
```

excuter/cpp-common/src/deepx/tensorfunc/elementwise.hpp

Lines changed: 25 additions & 0 deletions

```diff
@@ -333,6 +333,31 @@ namespace deepx::tensorfunc
     {
         equalscalarDispatcher<Author, T, MaskT>::equalscalar(A, scalar, epsilon, mask);
     }
+    // notequal(A,B)=>mask
+    template <typename Author, typename T, typename MaskT>
+    struct notequalDispatcher
+    {
+        static void notequal(const Tensor<T> &A, const Tensor<T> &B, const float epsilon, Tensor<MaskT> &mask) = delete;
+    };
+
+    template <typename Author, typename T, typename MaskT>
+    void notequal(const Tensor<T> &A, const Tensor<T> &B, const float epsilon, Tensor<MaskT> &mask)
+    {
+        notequalDispatcher<Author, T, MaskT>::notequal(A, B, epsilon, mask);
+    }
+
+    // notequal(A,scalar)=>mask
+    template <typename Author, typename T, typename MaskT>
+    struct notequalscalarDispatcher
+    {
+        static void notequalscalar(const Tensor<T> &A, const T scalar, const float epsilon, Tensor<MaskT> &mask) = delete;
+    };
+
+    template <typename Author, typename T, typename MaskT>
+    void notequalscalar(const Tensor<T> &A, const T scalar, const float epsilon, Tensor<MaskT> &mask)
+    {
+        notequalscalarDispatcher<Author, T, MaskT>::notequalscalar(A, scalar, epsilon, mask);
+    }
 
     // less(A,B)=>mask
     template <typename Author, typename T, typename MaskT>
```
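The dispatcher idiom above declares the primary template's method `= delete`, so only Author/type combinations that provide a specialization compile; the free function just forwards. A self-contained sketch of the same pattern (simplified names: `std::vector` stands in for `Tensor`, and `cpu` is an illustrative Author tag, not a deepx type):

```cpp
#include <cstddef>
#include <vector>

struct cpu {};  // illustrative Author tag

// Primary template: no generic implementation, so an unsupported
// Author/T combination is a compile-time error, not a runtime one.
template <typename Author, typename T>
struct notequalDispatcherSketch {
    static void notequal(const std::vector<T> &, const std::vector<T> &,
                         float, std::vector<bool> &) = delete;
};

// An Author opts in by providing a specialization.
template <typename T>
struct notequalDispatcherSketch<cpu, T> {
    static void notequal(const std::vector<T> &A, const std::vector<T> &B,
                         float epsilon, std::vector<bool> &mask) {
        mask.resize(A.size());
        for (std::size_t i = 0; i < A.size(); ++i) {
            T d = A[i] > B[i] ? A[i] - B[i] : B[i] - A[i];  // |A[i] - B[i]|
            mask[i] = d > static_cast<T>(epsilon);
        }
    }
};

// The public free function forwards to whichever specialization exists.
template <typename Author, typename T>
void notequal_sketch(const std::vector<T> &A, const std::vector<T> &B,
                     float epsilon, std::vector<bool> &mask) {
    notequalDispatcherSketch<Author, T>::notequal(A, B, epsilon, mask);
}
```

This is why the header can live in cpp-common: each executor (ompsimd, cuda) supplies its own specializations against the shared interface.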

excuter/op-mem-cuda/src/client/tfs.cpp

Lines changed: 22 additions & 2 deletions

```diff
@@ -406,7 +406,7 @@ namespace deepx::tf
     {
         Param("A", DataCategory::Tensor, Precision::Any),
         Param("B", DataCategory::Tensor, Precision::Any),
-        Param("epsilon", DataCategory::Var, Precision::Float64),
+        Param("epsilon", DataCategory::Var, Precision::Float32),
     }),
     vector<Param>(
     {
@@ -416,7 +416,27 @@
     {
         Param("A", DataCategory::Tensor, Precision::Any),
         Param("scalar", DataCategory::Var, Precision::Any),
-        Param("epsilon", DataCategory::Var, Precision::Float64),
+        Param("epsilon", DataCategory::Var, Precision::Float32),
+    }),
+    vector<Param>(
+    {
+        Param("mask", DataCategory::Tensor, Precision::Bool),
+    })));
+    tffactory.add_tf(std::make_shared<NotEqual<miaobyte>>(vector<Param>(
+    {
+        Param("A", DataCategory::Tensor, Precision::Any),
+        Param("B", DataCategory::Tensor, Precision::Any),
+        Param("epsilon", DataCategory::Var, Precision::Float32),
+    }),
+    vector<Param>(
+    {
+        Param("mask", DataCategory::Tensor, Precision::Bool),
+    })));
+    tffactory.add_tf(std::make_shared<NotEqualScalar<miaobyte>>(vector<Param>(
+    {
+        Param("A", DataCategory::Tensor, Precision::Any),
+        Param("scalar", DataCategory::Var, Precision::Any),
+        Param("epsilon", DataCategory::Var, Precision::Float32),
     }),
     vector<Param>(
     {
```
excuter/op-mem-cuda/src/deepx/mem/mem_cuda.hpp

Lines changed: 6 additions & 1 deletion

```diff
@@ -107,7 +107,12 @@ namespace deepx::mem
             result->data = ptr_tensor->data;
             break;
         }
-
+        case Precision::Bool:
+        {
+            auto ptr_tensor = std::static_pointer_cast<Tensor<bool>>(ptr);
+            result->data = ptr_tensor->data;
+            break;
+        }
         default:
             throw std::runtime_error("Unsupported dtype: " + precision_str(ptr->shape.dtype));
         }
```
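The mem_cuda.hpp hunk adds a missing `Bool` branch to a dtype switch: a type-erased pointer is downcast to the concrete `Tensor<T>` selected by a runtime precision tag, which the new bool-mask outputs of equal/notequal now require. A self-contained sketch of that pattern, assuming simplified stand-ins (`Precision`, `TensorSketch` here are illustrative, not the deepx types):

```cpp
#include <memory>
#include <stdexcept>
#include <vector>

enum class Precision { Float32, Bool };  // illustrative subset

template <typename T>
struct TensorSketch { std::vector<T> data; };

// Downcasts the erased pointer to the concrete tensor named by the
// runtime dtype tag and reports its element count.
std::size_t element_count(const std::shared_ptr<void> &ptr, Precision dtype) {
    switch (dtype) {
    case Precision::Float32:
        return std::static_pointer_cast<TensorSketch<float>>(ptr)->data.size();
    case Precision::Bool:  // without this branch, bool tensors hit the throw below
        return std::static_pointer_cast<TensorSketch<bool>>(ptr)->data.size();
    default:
        throw std::runtime_error("Unsupported dtype");
    }
}
```

Every dtype the executor can produce must have a branch here, which is why adding bool masks to the operator set forced this memory-layer change in the same commit.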

0 commit comments
