diff --git a/.gitignore b/.gitignore index d12c881..b1eb4f3 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,9 @@ *.tgz yarn.lock +bun.lock package-lock.json +.cache/ npm-debug.log yarn-error.log /node_modules/ diff --git a/.gitmodules b/.gitmodules index 40a4a0f..d3fbe2b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,6 @@ [submodule "deps/mlx"] path = deps/mlx - url = https://github.com/ml-explore/mlx + url = https://github.com/robert-johansson/mlx [submodule "deps/kizunapi"] path = deps/kizunapi - url = https://github.com/photoionization/kizunapi + url = https://github.com/robert-johansson/kizunapi diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cb0999..c1b0b69 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,5 +38,6 @@ target_include_directories(${PROJECT_NAME} PRIVATE "deps/kizunapi") option(MLX_BUILD_TESTS "Build tests for mlx" OFF) option(MLX_BUILD_EXAMPLES "Build examples for mlx" OFF) option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" ON) +option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" ON) add_subdirectory(deps/mlx) target_link_libraries(${PROJECT_NAME} mlx) diff --git a/deps/kizunapi b/deps/kizunapi index b8d0622..0071f6c 160000 --- a/deps/kizunapi +++ b/deps/kizunapi @@ -1 +1 @@ -Subproject commit b8d06226897a0cfe42a6efab39c413efd35b2276 +Subproject commit 0071f6c30b2f7ce9b3a3ac13f0d17ab464f21073 diff --git a/deps/mlx b/deps/mlx index b529515..c60059d 160000 --- a/deps/mlx +++ b/deps/mlx @@ -1 +1 @@ -Subproject commit b529515eb158edd0919746ce4e545fe0879d6437 +Subproject commit c60059d58778cfd755ba2a9a5850200be870a501 diff --git a/node_mlx.node.d.ts b/node_mlx.node.d.ts index 2000c8a..3f859f9 100644 --- a/node_mlx.node.d.ts +++ b/node_mlx.node.d.ts @@ -195,6 +195,7 @@ declare module '*node_mlx.node' { function argmin(array: ScalarOrArray, axis?: number, keepdims?: boolean, s?: StreamOrDevice): array; function argpartition(array: ScalarOrArray, kth: number, axis?: number, s?: StreamOrDevice): array; function argsort(array: ScalarOrArray, s?: StreamOrDevice): array; + function searchsorted(a: ScalarOrArray, v: ScalarOrArray, right?: boolean, s?: StreamOrDevice): array; function arrayEqual(a: ScalarOrArray, b: ScalarOrArray, equalNan?: boolean, s?: StreamOrDevice): array; function asStrided(array: ScalarOrArray, shape?: number[], strides?: number[], offset?: number, s?: StreamOrDevice): array; function atleast1d(...arrays: array[]): array; @@ -240,6 +241,10 @@ declare module '*node_mlx.node' { function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array; function erf(array: ScalarOrArray, s?: StreamOrDevice): array; function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array; + function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function digamma(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array; + function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array; function exp(array: ScalarOrArray, s?: StreamOrDevice): array; function expm1(array: ScalarOrArray, s?: StreamOrDevice): array; function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array; @@ -378,6 +383,7 @@ declare module '*node_mlx.node' { function tidy(func: () => U): U; function dispose(...args: unknown[]): void; function getWrappersCount(): number; + function sweepDeadArrays(): number; // Metal. namespace metal { diff --git a/src/array.cc b/src/array.cc index 8edbb55..82afc58 100644 --- a/src/array.cc +++ b/src/array.cc @@ -349,6 +349,7 @@ napi_value Item(mx::array* a, napi_env env) { return nullptr; } a->eval(); + a->detach(); return VisitArrayData([env](auto* data) { return ki::ToNodeValue(env, *data); }, a); @@ -383,6 +384,7 @@ napi_value ToList(mx::array* a, napi_env env) { if (a->ndim() == 0) return Item(a, env); a->eval(); + a->detach(); return VisitArrayData([env, a](auto* data) { return MxArrayToJsArray(env, *a, data); }, a); @@ -409,6 +411,7 @@ napi_value ToTypedArray(mx::array* a, napi_env env) { return nullptr; } a->eval(); + a->detach(); // Create a ArrayBuffer that stores a reference to array's data. using DataType = std::shared_ptr; napi_value buffer; @@ -462,38 +465,58 @@ std::stack> g_tidy_arrays; // Release all array pointers allocated during the call. napi_value Tidy(napi_env env, std::function func) { - // Push a new set to stack. + // Push a new set to stack. TypeBridge::Wrap inserts arrays here during func(). g_tidy_arrays.push(std::set()); - auto& top = g_tidy_arrays.top(); + // Shared flag: tracks whether cpp_then already popped the stack. + // Prevents double-pop in nested tidy (inner finally must not pop outer set). + auto popped = std::make_shared(false); return AwaitFunction( env, std::move(func), - [&top](napi_env env, napi_value result) { - // Exclude the arrays in result from the stack. + [popped](napi_env env, napi_value result) { + // Move the set out of the stack so it's safe from concurrent modification. + auto top = std::move(g_tidy_arrays.top()); + g_tidy_arrays.pop(); + *popped = true; + // Exclude the arrays in result from the set. TreeVisit(env, result, [&top](napi_env env, napi_value value) { if (auto a = ki::FromNodeTo(env, value); a) top.erase(*a); return napi_value(); }); - // Clear the arrays in the stack. + // Clear the arrays in the set. ki::InstanceData* instance_data = ki::InstanceData::Get(env); for (mx::array* a : top) { - // The arary might be in 3 states: + // The array might be in 3 states: // 1. Its JS object is well alive. - // 2. The JS object has been fully GCed. + // 2. The JS object has been fully GCed (finalizer ran, ptr freed). // 3. The JS object is marked as dead, but the finalizer has not run. - // We have to unbind the JS object in 1, and only delete array in 1 - // and 3. + // We must check wrapper validity BEFORE dereferencing the pointer, + // because in state 2 the pointer is dangling (already deleted by + // TypeBridge::Finalize during GC). napi_value value; - if (instance_data->GetWrapper(a, &value)) + bool has_wrapper = instance_data->GetWrapper(a, &value); + if (has_wrapper) { + // State 1: JS object alive — unbind it napi_remove_wrap(env, value, nullptr); - if (instance_data->DeleteWrapper(a)) + } + // Try to claim ownership (returns true for states 1 and 3) + if (instance_data->DeleteWrapper(a)) { + // Safe to dereference: pointer is still valid (not yet finalized) + int64_t ext = ki::internal::ExternalMemorySize::Get(a); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a; + } + // State 2: fully GC'd — skip, pointer is dangling } return result; }, - [](napi_env env) { - // Always pop even when error happened. - g_tidy_arrays.pop(); + [popped](napi_env env) { + // Only pop if cpp_then didn't already handle it. + if (!*popped) + g_tidy_arrays.pop(); }); } @@ -504,8 +527,13 @@ void Dispose(const ki::Arguments& args) { TreeVisit(args.Env(), args[i], [instance_data](napi_env env, napi_value value) { if (auto a = ki::FromNodeTo(env, value); a) { + int64_t ext = ki::internal::ExternalMemorySize::Get(a.value()); napi_remove_wrap(env, value, nullptr); instance_data->DeleteWrapper(a.value()); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } delete a.value(); } return napi_value(); @@ -520,6 +548,26 @@ size_t GetWrappersCount(napi_env env) { } // namespace +// Synchronously sweep dead array wrappers. +// Finds arrays whose JS wrappers have been GC'd but whose deferred finalizers +// haven't run yet, and immediately frees the native Metal buffers. +// Called automatically from Eval() to prevent Metal resource accumulation. +// Returns the number of arrays swept. +size_t SweepDeadArrays(napi_env env) { + ki::InstanceData* instance_data = ki::InstanceData::Get(env); + auto dead_ptrs = instance_data->CollectDeadWrappers(); + for (void* ptr : dead_ptrs) { + mx::array* a = static_cast(ptr); + int64_t ext = ki::internal::ExternalMemorySize::Get(a); + if (ext > 0) { + int64_t adjusted; + napi_adjust_external_memory(env, -ext, &adjusted); + } + delete a; + } + return dead_ptrs.size(); +} + namespace ki { // Allow passing Dtype to JS directly, no memory management involved as they are @@ -808,5 +856,6 @@ void InitArray(napi_env env, napi_value exports) { ki::Set(env, exports, "tidy", &Tidy, "dispose", &Dispose, - "getWrappersCount", &GetWrappersCount); + "getWrappersCount", &GetWrappersCount, + "sweepDeadArrays", &SweepDeadArrays); } diff --git a/src/array.h b/src/array.h index 901e2a5..9c3346e 100644 --- a/src/array.h +++ b/src/array.h @@ -65,6 +65,32 @@ struct Type : public AllowPassByValue { napi_value value); }; +namespace internal { + +// Report external memory for mx::array to enable GC pressure signaling. +// MLX arrays hold Metal GPU buffers that are invisible to the JS GC. +// Without this, the GC doesn't know about GPU memory pressure and doesn't +// collect array wrappers fast enough, causing Metal resource exhaustion. +template<> +struct ExternalMemorySize { + static int64_t Get(mx::array* a) { + // Metal has a hard limit of 499K buffer allocations. We must create + // enough external memory pressure to force the GC to collect array + // wrappers before hitting it. Report 1MB per array as the minimum + // external cost — this is much larger than the actual data size but + // necessary to trigger sufficiently aggressive GC for GPU resources. + size_t n = a->nbytes(); + constexpr int64_t min_cost = 1024 * 1024; // 1MB + return static_cast(n) > min_cost ? static_cast(n) : min_cost; + } +}; + +} // namespace internal + } // namespace ki +// Synchronously sweep dead array wrappers whose JS finalizers haven't run yet. +// Called automatically from Eval() to prevent Metal resource accumulation. +size_t SweepDeadArrays(napi_env env); + #endif // SRC_ARRAY_H_ diff --git a/src/fast.cc b/src/fast.cc index 6cfe1e0..24aa539 100644 --- a/src/fast.cc +++ b/src/fast.cc @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention( throw std::invalid_argument(msg.str()); } return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, mask_str, {}, s); + queries, keys, values, scale, mask_str, {}, {}, s); } else { auto mask_arr = std::get(mask); return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {mask_arr}, s); + queries, keys, values, scale, "", {mask_arr}, {}, s); } } else { return mx::fast::scaled_dot_product_attention( - queries, keys, values, scale, "", {}, s); + queries, keys, values, scale, "", {}, {}, s); } } diff --git a/src/fft.cc b/src/fft.cc index 808889a..51cccca 100644 --- a/src/fft.cc +++ b/src/fft.cc @@ -32,7 +32,7 @@ std::function FFTNOpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name, std::optional> axes, mx::StreamOrDevice s) { if (n && axes) { - return mx::fft::fftn(a, std::move(*n), std::move(*axes), s); + mx::Shape shape_n(n->begin(), n->end()); + return func1(a, shape_n, std::move(*axes), s); } else if (axes) { - return mx::fft::fftn(a, std::move(*axes), s); + return func2(a, std::move(*axes), s); } else if (n) { std::ostringstream msg; msg << "[" << name << "] " << "`axes` should not be `None` if `s` is not `None`."; throw std::invalid_argument(msg.str()); } else { - return mx::fft::fftn(a, s); + return func3(a, s); } }; } @@ -66,7 +67,7 @@ std::function FFT2OpWrapper(const char* name, mx::array(*func1)(const mx::array&, - const std::vector&, + const mx::Shape&, const std::vector&, mx::StreamOrDevice), mx::array(*func2)(const mx::array&, diff --git a/src/indexing.cc b/src/indexing.cc index e5d7286..b61dc12 100644 --- a/src/indexing.cc +++ b/src/indexing.cc @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a, a->shape().begin() + non_none_indices, a->shape().end()); up = mx::reshape(std::move(up), std::move(up_reshape)); - mx::Shape axes(arr_indices.size(), 0); + std::vector axes(arr_indices.size(), 0); std::iota(axes.begin(), axes.end(), 0); return {std::move(arr_indices), std::move(up), std::move(axes)}; } diff --git a/src/memory.cc b/src/memory.cc index 64f789f..af641b6 100644 --- a/src/memory.cc +++ b/src/memory.cc @@ -9,5 +9,7 @@ void InitMemory(napi_env env, napi_value exports) { "setMemoryLimit", &mx::set_memory_limit, "setWiredLimit", &mx::set_wired_limit, "setCacheLimit", &mx::set_cache_limit, - "clearCache", &mx::clear_cache); + "clearCache", &mx::clear_cache, + "getNumResources", &mx::get_num_resources, + "getResourceLimit", &mx::get_resource_limit); } diff --git a/src/metal.cc b/src/metal.cc index 91f79ae..777f1b4 100644 --- a/src/metal.cc +++ b/src/metal.cc @@ -1,4 +1,14 @@ #include "src/bindings.h" +#include "mlx/backend/gpu/device_info.h" + +namespace metal_ops { + +const std::unordered_map>& +DeviceInfo() { + return mx::gpu::device_info(0); +} + +} // namespace metal_ops void InitMetal(napi_env env, napi_value exports) { napi_value metal = ki::CreateObject(env); @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) { "isAvailable", &mx::metal::is_available, "startCapture", &mx::metal::start_capture, "stopCapture", &mx::metal::stop_capture, - "deviceInfo", &mx::metal::device_info); + "deviceInfo", &metal_ops::DeviceInfo); } diff --git a/src/ops.cc b/src/ops.cc index aad9d05..bdbd461 100644 --- a/src/ops.cc +++ b/src/ops.cc @@ -191,7 +191,7 @@ mx::array Full(std::variant shape, ScalarOrArray vals, std::optional dtype, mx::StreamOrDevice s) { - return mx::full(PutIntoVector(std::move(shape)), + return mx::full(PutIntoShape(std::move(shape)), ToArray(std::move(vals), std::move(dtype)), s); } @@ -199,13 +199,13 @@ mx::array Full(std::variant shape, mx::array Zeros(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::zeros(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::zeros(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Ones(std::variant shape, std::optional dtype, mx::StreamOrDevice s) { - return mx::ones(PutIntoVector(std::move(shape)), dtype.value_or(mx::float32), s); + return mx::ones(PutIntoShape(std::move(shape)), dtype.value_or(mx::float32), s); } mx::array Eye(int n, @@ -303,8 +303,9 @@ std::vector Split(const mx::array& a, if (auto i = std::get_if(&indices); i) { return mx::split(a, *i, axis.value_or(0), s); } else { - return mx::split(a, std::move(std::get>(indices)), - axis.value_or(0), s); + auto& v = std::get>(indices); + mx::Shape shape_indices(v.begin(), v.end()); + return mx::split(a, std::move(shape_indices), axis.value_or(0), s); } } @@ -346,6 +347,13 @@ mx::array ArgSort(const mx::array& a, return mx::argsort(a, s); } +mx::array SearchSorted(const mx::array& a, + const mx::array& v, + std::optional right, + mx::StreamOrDevice s) { + return mx::searchsorted(a, v, right.value_or(false), s); +} + mx::array Softmax(const mx::array& a, OptionalAxes axis, std::optional precise, @@ -544,7 +552,7 @@ mx::array ConvTranspose1d( mx::StreamOrDevice s) { return mx::conv_transpose1d(input, weight, stride.value_or(1), padding.value_or(0), dilation.value_or(1), - groups.value_or(1), s); + /*output_padding=*/0, groups.value_or(1), s); } mx::array ConvTranspose2d( @@ -574,7 +582,7 @@ mx::array ConvTranspose2d( dilation_pair = std::move(*p); return mx::conv_transpose2d(input, weight, stride_pair, padding_pair, - dilation_pair, groups.value_or(1), s); + dilation_pair, {0, 0}, groups.value_or(1), s); } mx::array ConvTranspose3d( @@ -604,7 +612,7 @@ mx::array ConvTranspose3d( dilation_tuple = std::move(*p); return mx::conv_transpose3d(input, weight, stride_tuple, padding_tuple, - dilation_tuple, groups.value_or(1), s); + dilation_tuple, {0, 0, 0}, groups.value_or(1), s); } mx::array ConvGeneral( @@ -789,6 +797,10 @@ void InitOps(napi_env env, napi_value exports) { "expm1", &mx::expm1, "erf", &mx::erf, "erfinv", &mx::erfinv, + "lgamma", &mx::lgamma, + "digamma", &mx::digamma, + "besselI0e", &mx::bessel_i0e, + "besselI1e", &mx::bessel_i1e, "sin", &mx::sin, "cos", &mx::cos, "tan", &mx::tan, @@ -811,7 +823,9 @@ void InitOps(napi_env env, napi_value exports) { "stopGradient", &mx::stop_gradient, "sigmoid", &mx::sigmoid, "power", BinOpWrapper(&mx::power), - "arange", &ops::ARange, + "arange", &ops::ARange); + + ki::Set(env, exports, "linspace", &ops::Linspace, "kron", &mx::kron, "take", &ops::Take, @@ -848,15 +862,20 @@ void InitOps(napi_env env, napi_value exports) { "min", DimOpWrapper(&mx::min), "max", DimOpWrapper(&mx::max), "logcumsumexp", CumOpWrapper(&mx::logcumsumexp), - "logsumexp", DimOpWrapper(&mx::logsumexp), + "logsumexp", DimOpWrapper(&mx::logsumexp)); + + ki::Set(env, exports, "mean", DimOpWrapper(&mx::mean), "variance", &ops::Var, "std", &ops::Std, "split", &ops::Split, - "argmin", &ops::ArgMin, + "argmin", &ops::ArgMin); + + ki::Set(env, exports, "argmax", &ops::ArgMax, "sort", &ops::Sort, "argsort", &ops::ArgSort, + "searchsorted", &ops::SearchSorted, "partition", KthOpWrapper(&mx::partition, &mx::partition), "argpartition", KthOpWrapper(&mx::argpartition, &mx::argpartition), "topk", KthOpWrapper(&mx::topk, &mx::topk), @@ -864,7 +883,9 @@ void InitOps(napi_env env, napi_value exports) { "blockMaskedMM", &mx::block_masked_mm, "gatherMM", &mx::gather_mm, "gatherQMM", &mx::gather_qmm, - "softmax", &ops::Softmax, + "softmax", &ops::Softmax); + + ki::Set(env, exports, "concatenate", &ops::Concatenate, "concat", &ops::Concatenate, "stack", &ops::Stack, @@ -876,7 +897,9 @@ void InitOps(napi_env env, napi_value exports) { "cumsum", CumOpWrapper(&mx::cumsum), "cumprod", CumOpWrapper(&mx::cumprod), "cummax", CumOpWrapper(&mx::cummax), - "cummin", CumOpWrapper(&mx::cummin), + "cummin", CumOpWrapper(&mx::cummin)); + + ki::Set(env, exports, "conj", &mx::conjugate, "conjugate", &mx::conjugate, "convolve", &ops::Convolve, @@ -912,7 +935,9 @@ void InitOps(napi_env env, napi_value exports) { "bitwiseXor", BinOpWrapper(&mx::bitwise_xor), "leftShift", BinOpWrapper(&mx::left_shift), "rightShift", BinOpWrapper(&mx::right_shift), - "view", &mx::view, + "view", &mx::view); + + ki::Set(env, exports, "hadamardTransform", &mx::hadamard_transform, "einsumPath", &mx::einsum_path, "einsum", &mx::einsum, diff --git a/src/transforms.cc b/src/transforms.cc index c2be8c0..f03ad14 100644 --- a/src/transforms.cc +++ b/src/transforms.cc @@ -166,11 +166,14 @@ ValueAndGradImpl(const char* error_tag, std::iota(gradient_indices.begin(), gradient_indices.end(), 0); // The result of |js_func| execution. napi_value result = nullptr; + // Flag set when the JS callback fails during tracing. + bool callback_failed = false; // Call value_and_grad with the JS function. napi_env env = js_func.Env(); auto value_and_grad_func = mx::value_and_grad( [error_tag, scalar_func_only, - &js_func, &args, &argnums, &arrays, &strides, &result, &env]( + &js_func, &args, &argnums, &arrays, &strides, &result, + &callback_failed, &env]( const std::vector& primals) -> std::vector { // Read the args into |js_args| vector, and replace the arrays in it // with the traced |primals|. @@ -191,6 +194,7 @@ ValueAndGradImpl(const char* error_tag, js_args.size(), js_args.empty() ? nullptr : &js_args.front(), &result) != napi_ok) { + callback_failed = true; return {}; } // Validate the return value. @@ -240,6 +244,18 @@ ValueAndGradImpl(const char* error_tag, // Call the function immediately, because this C++ lambda is actually the // result of value_and_grad. const auto& [values, gradients] = value_and_grad_func(arrays); + // If the JS callback threw during tracing, propagate the error instead + // of continuing with garbage results (stale tracer Symbol objects). + if (callback_failed) { + // Re-throw if there's a pending exception, otherwise create one. + bool has_exception = false; + napi_is_exception_pending(env, &has_exception); + if (!has_exception) { + ki::ThrowError(env, error_tag, + " The function threw an error during tracing."); + } + return {nullptr, nullptr}; + } // Convert gradients to JS value. For array inputs the gradients will be // returned, for Array and Object inputs the original arg will be returned // with their array properties replaced with corresponding gradients. @@ -265,7 +281,18 @@ ValueAndGradImpl(const char* error_tag, namespace transforms_ops { void Eval(ki::Arguments* args) { - mx::eval(TreeFlatten(args)); + auto arrays = TreeFlatten(args); + mx::eval(arrays); + // Detach evaluated arrays from the computation graph. + // After eval, each array still holds shared_ptr references to its inputs + // (for potential re-evaluation or gradient computation). In long-running + // processes, these graph chains prevent Metal buffers from being freed, + // causing num_resources to grow monotonically until crash. + // Since node-mlx manages gradients explicitly via valueAndGrad/grad + // (which trace their own graphs), the forward graph is not needed after eval. + for (auto& a : arrays) { + a.detach(); + } } napi_value AsyncEval(ki::Arguments* args) { @@ -477,6 +504,7 @@ void InitTransforms(napi_env env, napi_value exports) { "grad", &transforms_ops::Grad, "vmap", &transforms_ops::VMap, "compile", &transforms_ops::Compile, + "compileClearCache", &mx::detail::compile_clear_cache, "disableCompile", &mx::disable_compile, "enableCompile", &mx::enable_compile); } diff --git a/src/utils.cc b/src/utils.cc index b827c53..f309c48 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -1,7 +1,7 @@ #include "src/array.h" #include "src/utils.h" -mx::Shape PutIntoVector(std::variant shape) { +mx::Shape PutIntoShape(std::variant shape) { if (auto i = std::get_if(&shape); i) return {*i}; return std::move(std::get(shape)); diff --git a/src/utils.h b/src/utils.h index 8e9fd41..3cf2fb2 100644 --- a/src/utils.h +++ b/src/utils.h @@ -8,12 +8,45 @@ namespace mx = mlx::core; +// Teach kizunapi how to serialize/deserialize SmallVector (used for Shape +// and other types in MLX >= 0.26). Mirrors the std::vector specialization. +namespace ki { + +template +struct Type> { + static constexpr const char* name = "Array"; + static napi_status ToNode(napi_env env, + const mlx::core::SmallVector& vec, + napi_value* result) { + napi_status s = napi_create_array_with_length(env, vec.size(), result); + if (s != napi_ok) return s; + for (size_t i = 0; i < vec.size(); ++i) { + napi_value el; + s = ConvertToNode(env, vec[i], &el); + if (s != napi_ok) return s; + s = napi_set_element(env, *result, i, el); + if (s != napi_ok) return s; + } + return napi_ok; + } + static std::optional> FromNode( + napi_env env, napi_value value) { + // Read as std::vector then convert to SmallVector. + auto vec = Type>::FromNode(env, value); + if (!vec) return std::nullopt; + return mlx::core::SmallVector(vec->begin(), vec->end()); + } +}; + +} // namespace ki + using OptionalAxes = std::variant>; using ScalarOrArray = std::variant; -// Read args into a vector of types. -template -bool ReadArgs(ki::Arguments* args, std::vector* results) { +// Read args into a container of types (vector or SmallVector). +template +bool ReadArgs(ki::Arguments* args, Container* results) { + using T = typename Container::value_type; while (args->RemainingsLength() > 0) { std::optional a = args->GetNext(); if (!a) { @@ -45,8 +78,15 @@ void DefineToString(napi_env env, napi_value prototype) { symbol, ki::MemberFunction(&ToString)); } +// If input is one int, put it into a Shape, otherwise just return the Shape. +mx::Shape PutIntoShape(std::variant shape); + // If input is one int, put it into a vector, otherwise just return the vector. -std::vector PutIntoVector(std::variant> shape); +inline std::vector PutIntoVector(std::variant> v) { + if (auto i = std::get_if(&v); i) + return {*i}; + return std::move(std::get>(v)); +} // Get axis arg from js value. std::vector GetReduceAxes(OptionalAxes value, int dims);