Skip to content

Commit 5f422bc

Browse files
authored
py:tensor __getitem__ 已支持 (#74)
* tensor:__getitem__ 实现中 * py:tensor __getitem__
1 parent 73addbb commit 5f422bc

11 files changed

Lines changed: 154 additions & 32 deletions

File tree

front/py/deepx/nn/functional/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"mean",
3838
"rsqrt",
3939
"softmax",
40-
"squeeze","unsqueeze",
40+
"squeeze","unsqueeze","sliceselect","cat",
4141

4242
#other
4343
"calculate_fan_in_and_fan_out",

front/py/deepx/nn/functional/changeshape.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
from typing import Union
12
from deepx import Tensor
2-
from .leaffunc_changeshape import reshape
3-
3+
from .leaffunc_changeshape import reshape,indexselect, concat
4+
from .leaffunc_init import newtensor,arange
45
def squeeze(t:Tensor,dim:int)->Tensor:
56
assert isinstance(dim,int)
67
assert isinstance(t,Tensor)
@@ -15,4 +16,17 @@ def unsqueeze(t:Tensor,dim:int)->Tensor:
1516
dim=dim%t.ndim
1617
newshape=list(t.shape)
1718
newshape.insert(dim,1)
18-
return reshape(t,tuple(newshape))
19+
return reshape(t,tuple(newshape))
20+
21+
def sliceselect(t:Tensor,sliceobj:slice,dim:int=-1,out:Union[Tensor,str]='')->Tensor:
22+
assert isinstance(dim,int)
23+
assert isinstance(sliceobj,slice)
24+
assert isinstance(t,Tensor)
25+
dim=dim%t.ndim
26+
start=start = 0 if sliceobj.start is None else sliceobj.start % t.shape[dim]
27+
stop= t.shape[dim] if sliceobj.stop is None else sliceobj.stop % t.shape[dim]
28+
29+
index=arange(start,stop,dtype='int32')
30+
return indexselect(t,index,dim=dim,out=out)
31+
32+
cat= concat

front/py/deepx/nn/functional/leaffunc_changeshape.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,17 +80,16 @@ def broadcastTo(t:Tensor,new_shape:tuple[int,...],out:Union[Tensor,str]='',requi
8080
return outtensor
8181
broadcast_to = broadcastTo
8282

83-
def indexselect(input:Tensor,indices:Tensor,gatheraxis:int,out:Union[Tensor,str]='')->Tensor:
84-
assert gatheraxis>=0 and gatheraxis<input.ndim
85-
83+
def indexselect(input:Tensor,indices:Tensor,dim:int,out:Union[Tensor,str]='')->Tensor:
84+
assert dim>=0 and dim<input.ndim
8685
outtensor=out
8786
if isinstance(out,str) or out is None:
88-
outshape=Shape.indexselectshape(input.shape,indices.shape,gatheraxis)
87+
outshape=Shape.indexselectshape(input.shape,indices.shape,dim)
8988
outtensor=newtensor(outshape,dtype=input.dtype,name=out)
9089
assert outtensor.shape==outshape
9190

9291
from .rtf_changeshape import rtf_indexselect
93-
rtf_indexselect(input,indices,gatheraxis,outtensor,defaultauthor['indexselect'])
92+
rtf_indexselect(input,indices,dim,outtensor,defaultauthor['indexselect'])
9493
return outtensor
9594

9695
def repeat(input:Tensor,repeats:tuple[int,...],out:Union[Tensor,str]=''):

front/py/deepx/nn/functional/leaffunc_life.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from deepx.tensor import Tensor
22
from typing import Union
33

4-
def newtensor(shape:tuple[int,...],dtype:str='float32',name:str=None):
4+
def newtensor(shape:tuple[int,...],dtype:str='float32',name:str=None)->Tensor:
55
assert isinstance(shape,tuple)
66
for i in shape:
77
assert isinstance(i,int)

front/py/deepx/tensor/changeshape.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ def indexselect(self,index:Tensor,gatheraxis:int=0,out:Union[Tensor,str]='')->Te
6262
result=indexselect_func(self,index,gatheraxis,out)
6363
return result
6464

65+
@tensor_method
66+
def sliceselect(self,index:slice,dim:int=0,out:Union[Tensor,str]='')->Tensor:
67+
assert isinstance(index,slice)
68+
gatheraxis=dim%self.ndim
69+
from deepx.nn.functional import sliceselect as sliceselect_func
70+
result=sliceselect_func(self,index,gatheraxis,out)
71+
return result
72+
6573
@tensor_method
6674
def squeeze(self,dim:int)->Tensor:
6775
from deepx.nn.functional import squeeze as squeeze_func

front/py/deepx/tensor/tensor.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,48 @@ def __matmul__(self, other:'Tensor'):
159159
def __rmatmul__(self, other:'Tensor'):
160160
return other.matmul(self)
161161

162-
def __getitem__(self, index:'Tensor'):
163-
return self.indexselect(index)
162+
def __getitem__(self, idx):
163+
# 简单操作
164+
if isinstance(idx,Tensor):
165+
return self.indexselect(idx)
166+
if isinstance(idx, int):
167+
return self.sliceselect(slice(idx,idx+1)).squeeze(dim=0)
168+
169+
## 阶段1,
170+
if isinstance(idx, slice):
171+
indices = [idx]
172+
elif isinstance(idx, tuple):
173+
indices = list(idx)
174+
else:
175+
raise TypeError(f"Index must be an integer, slice, tuple, or Tensor, not {type(idx).__name__}")
176+
# 阶段2
177+
result = self
178+
new_axis_positions = []
179+
dim_cursor = 0
180+
181+
for item in indices:
182+
if item is None:
183+
# 如果是 None,则表示在该位置添加一个新的维度
184+
new_axis_positions.append(dim_cursor)
185+
continue
186+
if item == Ellipsis:
187+
num_ellipsis = self.ndim - len(indices) + 1
188+
dim_cursor += num_ellipsis
189+
continue
190+
# 如果是完整的切片 (e.g., ':'),则无需操作,直接进入下一维度
191+
if item == slice(None, None, None):
192+
dim_cursor += 1
193+
continue
194+
result=result.sliceselect(item,dim=dim_cursor)
195+
dim_cursor += 1
196+
197+
# 2. 在指定位置添加新维度(由 None 产生)
198+
i=0
199+
for pos in sorted(new_axis_positions):
200+
result = result.unsqueeze(pos+i)
201+
i += 1
202+
203+
return result
164204

165205
#shape操作
166206
@property

front/py/deepx/transformer/models/llama/attention.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from typing import Optional,Tuple
22
from deepx.nn.modules import Module,Linear
3-
from deepx import Tensor,matmul,softmax,concat,arange,dropout as dropout_func
3+
from deepx import Tensor,matmul,softmax,cat,dropout as dropout_func
44

55

66

77
def rotate_half(x:Tensor):
8-
index_front=arange(0,x.shape[-1]//2,dtype="int32")
9-
index_back=arange(x.shape[-1]//2,x.shape[-1],dtype="int32")
10-
x1 = x.indexselect(gatheraxis=-1,index=index_front)
11-
x2 = x.indexselect(gatheraxis=-1,index=index_back)
12-
return concat((-x2, x1,), dim=-1)
8+
x1 = x[..., : x.shape[-1] // 2]
9+
x2 = x[..., x.shape[-1] // 2 :]
10+
return cat((-x2, x1,), dim=-1)
1311

1412
def apply_rotary_pos_emb(q:Tensor, k:Tensor, cos:Tensor, sin:Tensor, unsqueeze_dim:int=1):
1513
cos = cos.unsqueeze(unsqueeze_dim)

front/py/examples/1_tensor/getitem.py

Lines changed: 0 additions & 14 deletions
This file was deleted.
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from deepx import newtensor,arange
2+
t = newtensor((2, 3, 13))
3+
t.arange_()
4+
print()
5+
t2 = t[None, :, None]
6+
t2.print()
7+
t3=t[:,None,:]
8+
t3.print()
9+
x=t
10+
x1 = x[..., : x.shape[-1] // 2]
11+
x2 = x[..., x.shape[-1] // 2 :]
12+
x1.print()
13+
x2.print()
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
3+
t = torch.full((2, 3, 13), 1)
4+
t2 = t[None, :, None]
5+
print(t2.shape)
6+
print(t2)
7+
x=t
8+
x1 = x[..., : x.shape[-1] // 2]
9+
x2 = x[..., x.shape[-1] // 2 :]
10+
print(x1)
11+
print(x2)

0 commit comments

Comments
 (0)