Skip to content

Commit b56caa6

Browse files
committed
refactor(Arg): Update with new C API
Update `Arg` with new C API; use new allocation/deallocation functions and implement arg type support PR: #29 Signed-off-by: Kostis Papazafeiropoulos <papazof@gmail.com> Reviewed-by: Anastassios Nanos <ananos@nubificus.co.uk> Approved-by: Anastassios Nanos <ananos@nubificus.co.uk>
1 parent 5de18dd commit b56caa6

4 files changed

Lines changed: 198 additions & 31 deletions

File tree

tests/test_genop.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import pytest
66

7-
from vaccel import Arg, OpType, Session
7+
from vaccel import Arg, ArgType, OpType, Session
88

99

1010
@pytest.fixture
@@ -41,9 +41,12 @@ def test_data():
4141

4242

4343
def test_exec(test_lib, test_args):
44-
arg_read = [OpType.EXEC, test_lib["path"], test_lib["symbol"]]
45-
arg_read.extend(test_args["read"])
46-
g_arg_read = [Arg(arg) for arg in arg_read]
44+
g_arg_read = [
45+
Arg(OpType.EXEC, ArgType.UINT8),
46+
Arg(test_lib["path"], ArgType.STRING),
47+
Arg(test_lib["symbol"], ArgType.STRING),
48+
]
49+
g_arg_read += [Arg(arg) for arg in test_args["read"]]
4750
g_arg_write = [Arg(arg) for arg in test_args["write"]]
4851

4952
session = Session()
@@ -55,20 +58,20 @@ def test_exec(test_lib, test_args):
5558

5659
def test_sgemm(test_data):
5760
arg_read = [
58-
Arg(OpType.BLAS_SGEMM),
59-
Arg(test_data["m"]),
60-
Arg(test_data["n"]),
61-
Arg(test_data["k"]),
62-
Arg(test_data["alpha"]),
63-
Arg(test_data["a"]),
64-
Arg(test_data["lda"]),
65-
Arg(test_data["b"]),
66-
Arg(test_data["ldb"]),
67-
Arg(test_data["beta"]),
68-
Arg(test_data["ldc"]),
61+
Arg(OpType.BLAS_SGEMM, ArgType.UINT8),
62+
Arg(test_data["m"], ArgType.INT64),
63+
Arg(test_data["n"], ArgType.INT64),
64+
Arg(test_data["k"], ArgType.INT64),
65+
Arg(test_data["alpha"], ArgType.FLOAT32),
66+
Arg(test_data["a"], ArgType.FLOAT32_ARRAY),
67+
Arg(test_data["lda"], ArgType.INT64),
68+
Arg(test_data["b"], ArgType.FLOAT32_ARRAY),
69+
Arg(test_data["ldb"], ArgType.INT64),
70+
Arg(test_data["beta"], ArgType.FLOAT32),
71+
Arg(test_data["ldc"], ArgType.INT64),
6972
]
7073
c = [float(0)] * test_data["m"] * test_data["n"]
71-
arg_write = [Arg(c)]
74+
arg_write = [Arg(c, ArgType.FLOAT32_ARRAY)]
7275

7376
session = Session()
7477
session.genop(arg_read, arg_write)

vaccel/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""Python API for vAccel."""
44

55
from ._version import __version__
6-
from .arg import Arg
6+
from .arg import Arg, ArgType
77
from .config import Config
88
from .op import OpType
99
from .resource import Resource, ResourceType
@@ -12,6 +12,7 @@
1212

1313
__all__ = [
1414
"Arg",
15+
"ArgType",
1516
"Config",
1617
"OpType",
1718
"Resource",

vaccel/arg.py

Lines changed: 165 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,106 @@
22

33
"""Interface to the `struct vaccel_arg` C object."""
44

5-
from typing import Any
5+
import logging
6+
from typing import Any, Final
67

78
from ._c_types import CAny, CType
8-
from ._libvaccel import ffi
9-
from .error import NullPointerError, ptr_or_raise
9+
from ._c_types.utils import CEnumBuilder
10+
from ._libvaccel import ffi, lib
11+
from .error import FFIError, NullPointerError, ptr_or_raise
12+
13+
logger = logging.getLogger(__name__)
14+
15+
enum_builder = CEnumBuilder(lib)
16+
ArgType = enum_builder.from_prefix("ArgType", "VACCEL_ARG_")
17+
18+
19+
class ArgTypeMapper:
20+
"""Utility for mapping between `ArgType` and other common types."""
21+
22+
_NUMERIC_TYPES: Final[set[ArgType]] = {
23+
ArgType.INT8,
24+
ArgType.INT8_ARRAY,
25+
ArgType.INT16,
26+
ArgType.INT16_ARRAY,
27+
ArgType.INT32,
28+
ArgType.INT32_ARRAY,
29+
ArgType.INT64,
30+
ArgType.INT64_ARRAY,
31+
ArgType.UINT8,
32+
ArgType.UINT8_ARRAY,
33+
ArgType.UINT16,
34+
ArgType.UINT16_ARRAY,
35+
ArgType.UINT32,
36+
ArgType.UINT32_ARRAY,
37+
ArgType.UINT64,
38+
ArgType.UINT64_ARRAY,
39+
ArgType.FLOAT32,
40+
ArgType.FLOAT32_ARRAY,
41+
ArgType.FLOAT64,
42+
}
43+
44+
_ARG_TYPE_TO_C: Final[dict[ArgType, str]] = {
45+
ArgType.INT8: "int8_t",
46+
ArgType.INT8_ARRAY: "int8_t *",
47+
ArgType.INT16: "int16_t",
48+
ArgType.INT16_ARRAY: "int16_t *",
49+
ArgType.INT32: "int32_t",
50+
ArgType.INT32_ARRAY: "int32_t *",
51+
ArgType.INT64: "int64_t",
52+
ArgType.INT64_ARRAY: "int64_t *",
53+
ArgType.UINT8: "uint8_t",
54+
ArgType.UINT8_ARRAY: "uint8_t *",
55+
ArgType.UINT16: "uint16_t",
56+
ArgType.UINT16_ARRAY: "uint16_t *",
57+
ArgType.UINT32: "uint32_t",
58+
ArgType.UINT32_ARRAY: "uint32_t *",
59+
ArgType.UINT64: "uint64_t",
60+
ArgType.UINT64_ARRAY: "uint64_t *",
61+
ArgType.FLOAT32: "float",
62+
ArgType.FLOAT32_ARRAY: "float *",
63+
ArgType.FLOAT64: "double",
64+
ArgType.FLOAT64_ARRAY: "double *",
65+
ArgType.BOOL: "bool",
66+
ArgType.BOOL_ARRAY: "bool *",
67+
ArgType.CHAR: "char",
68+
ArgType.CHAR_ARRAY: "char *",
69+
ArgType.UCHAR: "unsigned char",
70+
ArgType.UCHAR_ARRAY: "unsigned char *",
71+
ArgType.STRING: "char *",
72+
ArgType.BUFFER: "void *",
73+
}
74+
75+
@classmethod
76+
def is_numeric(cls, arg_type: ArgType) -> bool:
77+
"""Checks if the arg type represents a numeric type.
78+
79+
Args:
80+
arg_type: The arg type value.
81+
82+
Returns:
83+
True if the arg type represents a numeric type.
84+
"""
85+
return arg_type in cls._NUMERIC_TYPES
86+
87+
@classmethod
88+
def type_to_c_type(cls, arg_type: ArgType) -> str:
89+
"""Converts an `ArgType` to a C type string.
90+
91+
Args:
92+
arg_type: The arg type value.
93+
94+
Returns:
95+
A corresponding C type as a string (e.g., "float", "int64_t").
96+
97+
Raises:
98+
ValueError: If the `arg_type` value is not supported.
99+
"""
100+
if arg_type not in cls._ARG_TYPE_TO_C:
101+
supported = ", ".join(str(d) for d in cls._ARG_TYPE_TO_C)
102+
msg = f"Unsupported ArgType: {arg_type}. Supported: {supported}"
103+
raise ValueError(msg)
104+
return cls._ARG_TYPE_TO_C[arg_type]
10105

11106

12107
class Arg(CType):
@@ -20,23 +115,49 @@ class Arg(CType):
20115
21116
Attributes:
22117
_c_data (CAny): The encapsulated C data that is passed to the C struct.
118+
_c_obj_ptr (ffi.CData): A double pointer to the underlying
119+
`struct vaccel_arg` C object.
120+
type_ (ArgType): The type of the arg.
121+
custom_type_id (int): The user-specified type ID of the arg if the type
122+
is `ArgType.CUSTOM`.
23123
"""
24124

25-
def __init__(self, data: Any):
125+
def __init__(
126+
self, data: Any, type_: ArgType = ArgType.RAW, custom_type_id: int = 0
127+
):
26128
"""Initializes a new `Arg` object.
27129
28130
Args:
29131
data: The input data to be passed to the C struct.
132+
type_: The type of the arg.
133+
custom_type_id: The user-specified type ID of the arg if the type is
134+
`ArgType.CUSTOM`.
30135
"""
31-
self._c_data = CAny(data)
136+
if ArgType != ArgType.RAW and ArgTypeMapper.is_numeric(type_):
137+
precision = ArgTypeMapper.type_to_c_type(type_)
138+
self._c_data = CAny(data, precision=precision)
139+
else:
140+
self._c_data = CAny(data)
141+
self._c_obj_ptr = ffi.NULL
142+
self._type = type_
143+
self._custom_type_id = custom_type_id
32144
super().__init__()
33145

34146
def _init_c_obj(self):
35147
"""Initializes the underlying `struct vaccel_arg` C object."""
36-
c_data = self._c_data
37-
self._c_obj = ffi.new("struct vaccel_arg *")
38-
self._c_obj.size = c_data.c_size
39-
self._c_obj.buf = c_data._c_ptr
148+
self._c_obj_ptr = ffi.new("struct vaccel_arg **")
149+
ret = lib.vaccel_arg_from_buf(
150+
self._c_obj_ptr,
151+
self._c_data._c_ptr,
152+
self._c_data.c_size,
153+
self._type,
154+
self._custom_type_id,
155+
)
156+
if ret != 0:
157+
raise FFIError(ret, "Could not initialize arg")
158+
159+
self._c_obj = self._c_obj_ptr[0]
160+
self._c_size = ffi.sizeof("struct vaccel_arg")
40161

41162
@property
42163
def value(self) -> ffi.CData:
@@ -47,6 +168,24 @@ def value(self) -> ffi.CData:
47168
"""
48169
return self._c_ptr_or_raise[0]
49170

171+
def _del_c_obj(self):
172+
"""Deletes the underlying `struct vaccel_arg` C object.
173+
174+
Raises:
175+
FFIError: If arg deletion fails.
176+
"""
177+
ret = lib.vaccel_arg_delete(self._c_ptr_or_raise)
178+
if ret != 0:
179+
raise FFIError(ret, "Could not delete arg")
180+
181+
def __del__(self):
182+
try:
183+
self._del_c_obj()
184+
except NullPointerError:
185+
pass
186+
except FFIError:
187+
logger.exception("Failed to clean up Arg")
188+
50189
@property
51190
def buf(self) -> Any:
52191
"""Returns the buffer value from the underlying C struct.
@@ -62,6 +201,15 @@ def buf(self) -> Any:
62201
self._c_data, f"{self.__class__.__name__}._c_data"
63202
).value
64203

204+
@property
205+
def type(self) -> ArgType:
206+
"""The arg type.
207+
208+
Returns:
209+
The type of the arg.
210+
"""
211+
return ArgType(self._c_ptr_or_raise.type)
212+
65213
def __repr__(self):
66214
try:
67215
_c_ptr = (
@@ -70,6 +218,13 @@ def __repr__(self):
70218
else "NULL"
71219
)
72220
size = self._c_obj.size if self._c_obj != ffi.NULL else 0
221+
type_ = self.type
222+
type_name = getattr(type_, "name", repr(type_))
73223
except (AttributeError, TypeError, NullPointerError):
74224
return f"<{self.__class__.__name__} (uninitialized or invalid)>"
75-
return f"<{self.__class__.__name__} size={size} at {_c_ptr}>"
225+
return (
226+
f"<{self.__class__.__name__} "
227+
f"size={size} "
228+
f"type={type_name} "
229+
f"at {_c_ptr}>"
230+
)

vaccel/ops/exec.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def exec(
5151
FFIError: If the C operation fails.
5252
"""
5353
if arg_read is not None:
54-
c_arg_read = CList([Arg(arg) for arg in arg_read])
54+
c_arg_read = CList(
55+
[arg if isinstance(arg, Arg) else Arg(arg) for arg in arg_read]
56+
)
5557
c_arg_read_ptr = c_arg_read._c_ptr
5658
c_arg_read_len = len(c_arg_read)
5759
else:
@@ -60,7 +62,9 @@ def exec(
6062
c_arg_read_len = 0
6163

6264
if arg_write is not None:
63-
c_arg_write = CList([Arg(arg) for arg in arg_write])
65+
c_arg_write = CList(
66+
[arg if isinstance(arg, Arg) else Arg(arg) for arg in arg_write]
67+
)
6468
c_arg_write_ptr = c_arg_write._c_ptr
6569
c_arg_write_len = len(c_arg_write)
6670
else:
@@ -112,7 +116,9 @@ def exec_with_resource(
112116
FFIError: If the C operation fails.
113117
"""
114118
if arg_read is not None:
115-
c_arg_read = CList([Arg(arg) for arg in arg_read])
119+
c_arg_read = CList(
120+
[arg if isinstance(arg, Arg) else Arg(arg) for arg in arg_read]
121+
)
116122
c_arg_read_ptr = c_arg_read._c_ptr
117123
c_arg_read_len = len(c_arg_read)
118124
else:
@@ -121,7 +127,9 @@ def exec_with_resource(
121127
c_arg_read_len = 0
122128

123129
if arg_write is not None:
124-
c_arg_write = CList([Arg(arg) for arg in arg_write])
130+
c_arg_write = CList(
131+
[arg if isinstance(arg, Arg) else Arg(arg) for arg in arg_write]
132+
)
125133
c_arg_write_ptr = c_arg_write._c_ptr
126134
c_arg_write_len = len(c_arg_write)
127135
else:

0 commit comments

Comments
 (0)