Skip to content

Commit 27ed7bc

Browse files
authored
gather,save,load。deepxctl。 (#26)
* tokenizer:验证 * ompsimd:gather,save,load front/py:测试了save,load * deepxctl: deepxctl,golang实现,提供统一命令行运维 * deepxctl: deepxctl,golang实现,提供统一命令行运维 * deepxctl: deepxctl,golang实现,提供统一命令行运维 * deepxctl: deepxctl,golang实现,提供统一命令行运维 * gather:测试验证,和pytorch保持了一致 * gather:测试验证,和pytorch保持了一致 * cuda:load,save,gather,编译通过,测试有点异常。
1 parent 450cb78 commit 27ed7bc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

73 files changed

+2873
-1089
lines changed

deepxctl/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
.idea
2+
deepxctl

deepxctl/cmd/tensor/print.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
package tensor
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"os"
7+
8+
coretensor "github.com/array2d/deepx/deepxctl/tensor"
9+
)
10+
11+
func PrintCmd() {
12+
printCmd := flag.NewFlagSet("print", flag.ExitOnError)
13+
tensorPath := os.Args[0]
14+
if tensorPath == "" {
15+
fmt.Println("请指定文件路径")
16+
printCmd.Usage()
17+
return
18+
}
19+
var err error
20+
var shape coretensor.Shape
21+
shape, err = coretensor.LoadShape(tensorPath)
22+
if err != nil {
23+
fmt.Println("读取文件失败:", err)
24+
}
25+
switch shape.Dtype {
26+
case "bool":
27+
var t coretensor.Tensor[bool]
28+
t, err = coretensor.LoadTensor[bool](tensorPath)
29+
if err != nil {
30+
fmt.Println("读取文件失败:", err)
31+
}
32+
t.Print()
33+
case "int8":
34+
var t coretensor.Tensor[int8]
35+
t, err = coretensor.LoadTensor[int8](tensorPath)
36+
if err != nil {
37+
fmt.Println("读取文件失败:", err)
38+
}
39+
t.Print()
40+
case "int16":
41+
var t coretensor.Tensor[int16]
42+
t, err = coretensor.LoadTensor[int16](tensorPath)
43+
if err != nil {
44+
fmt.Println("读取文件失败:", err)
45+
}
46+
t.Print()
47+
case "int32":
48+
var t coretensor.Tensor[int32]
49+
t, err = coretensor.LoadTensor[int32](tensorPath)
50+
if err != nil {
51+
fmt.Println("读取文件失败:", err)
52+
}
53+
t.Print()
54+
case "int64":
55+
var t coretensor.Tensor[int64]
56+
t, err = coretensor.LoadTensor[int64](tensorPath)
57+
if err != nil {
58+
fmt.Println("读取文件失败:", err)
59+
}
60+
t.Print()
61+
case "float16":
62+
// var t coretensor.Tensor[float16]
63+
// t, err = coretensor.LoadTensor[float16](tensorPath)
64+
// if err != nil {
65+
// fmt.Println("读取文件失败:", err)
66+
// }
67+
// t.Print()
68+
case "float32":
69+
var t coretensor.Tensor[float32]
70+
t, err = coretensor.LoadTensor[float32](tensorPath)
71+
if err != nil {
72+
fmt.Println("读取文件失败:", err)
73+
}
74+
t.Print()
75+
case "float64":
76+
var t coretensor.Tensor[float64]
77+
t, err = coretensor.LoadTensor[float64](tensorPath)
78+
if err != nil {
79+
fmt.Println("读取文件失败:", err)
80+
}
81+
t.Print()
82+
}
83+
}

deepxctl/cmd/tensor/tensor.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package tensor
2+
3+
import (
4+
"fmt"
5+
"os"
6+
)
7+
8+
func PrintUsage() {
9+
fmt.Println("使用方法:")
10+
fmt.Println(" tensor print <文件路径>")
11+
fmt.Println(" tensor help")
12+
}
13+
14+
func Execute() {
15+
if len(os.Args) < 1 {
16+
PrintUsage()
17+
os.Exit(1)
18+
}
19+
20+
subCmd := "help"
21+
if len(os.Args) > 0 {
22+
subCmd = os.Args[0]
23+
}
24+
25+
switch subCmd {
26+
case "print":
27+
os.Args = os.Args[1:]
28+
PrintCmd()
29+
case "help":
30+
PrintUsage()
31+
default:
32+
fmt.Printf("未知的张量命令: %s\n", subCmd)
33+
PrintUsage()
34+
os.Exit(1)
35+
}
36+
}

deepxctl/go.mod

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module github.com/array2d/deepx/deepxctl
2+
3+
go 1.23.2
4+
5+
require gopkg.in/yaml.v2 v2.4.0 // indirect

deepxctl/go.sum

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
2+
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
3+
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=

deepxctl/main.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
9+
"github.com/array2d/deepx/deepxctl/cmd/tensor"
10+
)
11+
12+
var version = "0.1.0"
13+
14+
func printUsage() {
15+
execName := filepath.Base(os.Args[0])
16+
fmt.Printf("用法: %s [命令] [参数]\n\n", execName)
17+
fmt.Println("可用命令:")
18+
fmt.Println(" tensor 张量操作相关命令")
19+
fmt.Println(" version 显示版本信息")
20+
fmt.Println(" help 显示帮助信息")
21+
fmt.Println("\n使用 '%s help [命令]' 获取命令的详细信息", execName)
22+
}
23+
24+
func main() {
25+
flag.Usage = printUsage
26+
27+
if len(os.Args) < 2 {
28+
printUsage()
29+
os.Exit(1)
30+
}
31+
32+
// 获取子命令
33+
cmd := os.Args[1]
34+
35+
// 根据子命令执行相应操作
36+
switch cmd {
37+
case "tensor":
38+
// 移除子命令,让子命令处理剩余的参数
39+
os.Args = os.Args[2:]
40+
tensor.Execute()
41+
42+
case "version":
43+
fmt.Printf("deepxctl 版本 %s\n", version)
44+
45+
case "help":
46+
if len(os.Args) > 2 {
47+
helpCmd := os.Args[2]
48+
switch helpCmd {
49+
case "tensor":
50+
tensor.PrintUsage()
51+
default:
52+
fmt.Printf("未知命令: %s\n", helpCmd)
53+
printUsage()
54+
}
55+
} else {
56+
printUsage()
57+
}
58+
59+
default:
60+
fmt.Printf("未知命令: %s\n", cmd)
61+
printUsage()
62+
os.Exit(1)
63+
}
64+
}

deepxctl/tensor/fp16.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package tensor
2+
3+
import (
4+
"encoding/binary"
5+
"math"
6+
)
7+
8+
func Byte2ToFloat16(value []byte) float32 {
9+
bits := binary.BigEndian.Uint16(value)
10+
// 这里需要实现float16到float32的转换
11+
// 简化实现,实际项目中需要更完整的实现
12+
sign := float32(1)
13+
if bits&0x8000 != 0 {
14+
sign = -1
15+
}
16+
exp := int((bits & 0x7C00) >> 10)
17+
frac := float32(bits&0x03FF) / 1024.0
18+
19+
if exp == 0 {
20+
return sign * frac * float32(1.0/16384.0) // 非规格化数
21+
} else if exp == 31 {
22+
if frac == 0 {
23+
return sign * float32(math.Inf(1)) // 无穷大
24+
}
25+
return float32(math.NaN()) // NaN
26+
}
27+
return sign * float32(math.Pow(2, float64(exp-15))) * (1.0 + frac) // 规格化数
28+
}

deepxctl/tensor/io.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package tensor
2+
3+
import (
4+
"encoding/binary"
5+
"os"
6+
7+
"gopkg.in/yaml.v2"
8+
)
9+
10+
func LoadShape(filePath string) (shape Shape, err error) {
11+
var shapeData []byte
12+
shapeData, err = os.ReadFile(filePath + ".shape")
13+
if err != nil {
14+
return
15+
}
16+
17+
err = yaml.Unmarshal(shapeData, &shape)
18+
if err != nil {
19+
return
20+
}
21+
return
22+
}
23+
func LoadTensor[T Number](filePath string) (tensor Tensor[T], err error) {
24+
25+
_, err = os.ReadFile(filePath + ".shape")
26+
if err != nil {
27+
return
28+
}
29+
var shape Shape
30+
shape, err = LoadShape(filePath)
31+
if err != nil {
32+
return
33+
}
34+
file, err := os.Open(filePath + ".data")
35+
if err != nil {
36+
return
37+
}
38+
defer file.Close()
39+
data := make([]T, shape.Size)
40+
41+
err = binary.Read(file, binary.LittleEndian, data)
42+
if err != nil {
43+
return
44+
}
45+
tensor = Tensor[T]{Data: data, Shape: shape}
46+
return
47+
}

deepxctl/tensor/print.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package tensor
2+
3+
import "fmt"
4+
5+
func (t *Tensor[T]) Range(dimCount int, f func(indices []int)) {
6+
Shape := t.Shape
7+
if dimCount > len(Shape.Shape) {
8+
panic("dimCount exceeds the number of dimensions in the Tensor.")
9+
}
10+
11+
totalSize := 1
12+
13+
// 计算总的循环次数
14+
for i := 0; i < dimCount; i++ {
15+
totalSize *= Shape.At(i)
16+
}
17+
indices := make([]int, dimCount) // 初始化索引向量
18+
// 遍历所有可能的索引组合
19+
for idx := 0; idx < totalSize; idx++ {
20+
// 反算出 indices 数组
21+
idx_ := idx
22+
for dim := dimCount - 1; dim >= 0; dim-- {
23+
indices[dim] = idx_ % Shape.At(dim) // 计算当前维度的索引
24+
idx_ /= Shape.At(dim) // 更新 idx
25+
}
26+
f(indices) // 调用传入的函数
27+
}
28+
}
29+
30+
func AutoFormat(dtype string) string {
31+
switch dtype {
32+
case "bool":
33+
return "%v"
34+
case "int8":
35+
return "%d"
36+
case "int16":
37+
return "%d"
38+
case "int32":
39+
return "%d"
40+
case "int64":
41+
return "%d"
42+
case "float16":
43+
return "%f"
44+
case "float32":
45+
return "%f"
46+
case "float64":
47+
return "%f"
48+
default:
49+
return "%v"
50+
}
51+
}
52+
53+
// Print 打印Tensor的值
54+
func (t *Tensor[T]) Print(format_ ...string) {
55+
Shape := t.Shape
56+
format := AutoFormat(t.Dtype)
57+
if len(format_) > 0 {
58+
format = format_[0]
59+
}
60+
fmt.Print("shape:[")
61+
for i := 0; i < Shape.Dim; i++ {
62+
fmt.Print(Shape.At(i))
63+
if i < Shape.Dim-1 {
64+
fmt.Print(", ")
65+
}
66+
}
67+
fmt.Println("]")
68+
if Shape.Dim == 1 {
69+
fmt.Print("[")
70+
for i := 0; i < Shape.At(0); i++ {
71+
if i > 0 {
72+
fmt.Print(" ")
73+
}
74+
fmt.Printf(format, t.Get(i))
75+
}
76+
fmt.Println("]")
77+
} else if Shape.Dim == 2 {
78+
fmt.Println("[")
79+
for i := 0; i < Shape.At(0); i++ {
80+
fmt.Print(" [")
81+
for j := 0; j < Shape.At(1); j++ {
82+
if j > 0 {
83+
fmt.Print(" ")
84+
}
85+
fmt.Printf(format, t.Get(i, j))
86+
}
87+
88+
fmt.Print("]")
89+
if i < Shape.At(0)-1 {
90+
fmt.Print(",")
91+
}
92+
fmt.Println()
93+
}
94+
fmt.Println("]")
95+
} else {
96+
t.Range(Shape.Dim-2, func(indices []int) {
97+
fmt.Print(indices)
98+
m, n := Shape.At(Shape.Dim-2), Shape.At(Shape.Dim-1)
99+
fmt.Print([]int{m, n})
100+
fmt.Println("=")
101+
102+
fmt.Println("[")
103+
for i := 0; i < m; i++ {
104+
fmt.Print(" [")
105+
for j := 0; j < n; j++ {
106+
if j > 0 {
107+
fmt.Print(" ")
108+
}
109+
fmt.Printf(format, t.Get(append(indices, i, j)...))
110+
}
111+
112+
fmt.Print("]")
113+
if i < m-1 {
114+
fmt.Print(",")
115+
}
116+
fmt.Println()
117+
}
118+
fmt.Println("]")
119+
})
120+
}
121+
}

0 commit comments

Comments
 (0)