@@ -159,8 +159,48 @@ def __matmul__(self, other:'Tensor'):
159159 def __rmatmul__ (self , other :'Tensor' ):
160160 return other .matmul (self )
161161
162- def __getitem__ (self , index :'Tensor' ):
163- return self .indexselect (index )
162+ def __getitem__ (self , idx ):
163+ # 简单操作
164+ if isinstance (idx ,Tensor ):
165+ return self .indexselect (idx )
166+ if isinstance (idx , int ):
167+ return self .sliceselect (slice (idx ,idx + 1 )).squeeze (dim = 0 )
168+
169+ ## 阶段1,
170+ if isinstance (idx , slice ):
171+ indices = [idx ]
172+ elif isinstance (idx , tuple ):
173+ indices = list (idx )
174+ else :
175+ raise TypeError (f"Index must be an integer, slice, tuple, or Tensor, not { type (idx ).__name__ } " )
176+ # 阶段2
177+ result = self
178+ new_axis_positions = []
179+ dim_cursor = 0
180+
181+ for item in indices :
182+ if item is None :
183+ # 如果是 None,则表示在该位置添加一个新的维度
184+ new_axis_positions .append (dim_cursor )
185+ continue
186+ if item == Ellipsis :
187+ num_ellipsis = self .ndim - len (indices ) + 1
188+ dim_cursor += num_ellipsis
189+ continue
190+ # 如果是完整的切片 (e.g., ':'),则无需操作,直接进入下一维度
191+ if item == slice (None , None , None ):
192+ dim_cursor += 1
193+ continue
194+ result = result .sliceselect (item ,dim = dim_cursor )
195+ dim_cursor += 1
196+
197+ # 2. 在指定位置添加新维度(由 None 产生)
198+ i = 0
199+ for pos in sorted (new_axis_positions ):
200+ result = result .unsqueeze (pos + i )
201+ i += 1
202+
203+ return result
164204
165205 #shape操作
166206 @property
0 commit comments