Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
*.tgz

yarn.lock
bun.lock
package-lock.json
.cache/
npm-debug.log
yarn-error.log
/node_modules/
Expand Down
4 changes: 2 additions & 2 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[submodule "deps/mlx"]
path = deps/mlx
url = https://github.com/ml-explore/mlx
url = https://github.com/robert-johansson/mlx
[submodule "deps/kizunapi"]
path = deps/kizunapi
url = https://github.com/photoionization/kizunapi
url = https://github.com/robert-johansson/kizunapi
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ target_include_directories(${PROJECT_NAME} PRIVATE "deps/kizunapi")
option(MLX_BUILD_TESTS "Build tests for mlx" OFF)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" ON)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" ON)
add_subdirectory(deps/mlx)
target_link_libraries(${PROJECT_NAME} mlx)
2 changes: 1 addition & 1 deletion deps/kizunapi
2 changes: 1 addition & 1 deletion deps/mlx
Submodule mlx updated from b52951 to c60059
6 changes: 6 additions & 0 deletions node_mlx.node.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ declare module '*node_mlx.node' {
function argmin(array: ScalarOrArray, axis?: number, keepdims?: boolean, s?: StreamOrDevice): array;
function argpartition(array: ScalarOrArray, kth: number, axis?: number, s?: StreamOrDevice): array;
function argsort(array: ScalarOrArray, s?: StreamOrDevice): array;
function searchsorted(a: ScalarOrArray, v: ScalarOrArray, right?: boolean, s?: StreamOrDevice): array;
function arrayEqual(a: ScalarOrArray, b: ScalarOrArray, equalNan?: boolean, s?: StreamOrDevice): array;
function asStrided(array: ScalarOrArray, shape?: number[], strides?: number[], offset?: number, s?: StreamOrDevice): array;
function atleast1d(...arrays: array[]): array;
Expand Down Expand Up @@ -240,6 +241,10 @@ declare module '*node_mlx.node' {
function notEqual(a: ScalarOrArray, b: ScalarOrArray, s?: StreamOrDevice): array;
function erf(array: ScalarOrArray, s?: StreamOrDevice): array;
function erfinv(array: ScalarOrArray, s?: StreamOrDevice): array;
function lgamma(array: ScalarOrArray, s?: StreamOrDevice): array;
function digamma(array: ScalarOrArray, s?: StreamOrDevice): array;
function besselI0e(array: ScalarOrArray, s?: StreamOrDevice): array;
function besselI1e(array: ScalarOrArray, s?: StreamOrDevice): array;
function exp(array: ScalarOrArray, s?: StreamOrDevice): array;
function expm1(array: ScalarOrArray, s?: StreamOrDevice): array;
function expandDims(array: ScalarOrArray, dims: number | number[], s?: StreamOrDevice): array;
Expand Down Expand Up @@ -378,6 +383,7 @@ declare module '*node_mlx.node' {
function tidy<U>(func: () => U): U;
function dispose(...args: unknown[]): void;
function getWrappersCount(): number;
function sweepDeadArrays(): number;

// Metal.
namespace metal {
Expand Down
79 changes: 64 additions & 15 deletions src/array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ napi_value Item(mx::array* a, napi_env env) {
return nullptr;
}
a->eval();
a->detach();
return VisitArrayData([env](auto* data) {
return ki::ToNodeValue(env, *data);
}, a);
Expand Down Expand Up @@ -383,6 +384,7 @@ napi_value ToList(mx::array* a, napi_env env) {
if (a->ndim() == 0)
return Item(a, env);
a->eval();
a->detach();
return VisitArrayData([env, a](auto* data) {
return MxArrayToJsArray(env, *a, data);
}, a);
Expand All @@ -409,6 +411,7 @@ napi_value ToTypedArray(mx::array* a, napi_env env) {
return nullptr;
}
a->eval();
a->detach();
// Create a ArrayBuffer that stores a reference to array's data.
using DataType = std::shared_ptr<mx::array::Data>;
napi_value buffer;
Expand Down Expand Up @@ -462,38 +465,58 @@ std::stack<std::set<mx::array*>> g_tidy_arrays;

// Release all array pointers allocated during the call.
napi_value Tidy(napi_env env, std::function<napi_value()> func) {
// Push a new set to stack.
// Push a new set to stack. TypeBridge::Wrap inserts arrays here during func().
g_tidy_arrays.push(std::set<mx::array*>());
auto& top = g_tidy_arrays.top();
// Shared flag: tracks whether cpp_then already popped the stack.
// Prevents double-pop in nested tidy (inner finally must not pop outer set).
auto popped = std::make_shared<bool>(false);
return AwaitFunction(
env, std::move(func),
[&top](napi_env env, napi_value result) {
// Exclude the arrays in result from the stack.
[popped](napi_env env, napi_value result) {
// Move the set out of the stack so it's safe from concurrent modification.
auto top = std::move(g_tidy_arrays.top());
g_tidy_arrays.pop();
*popped = true;
// Exclude the arrays in result from the set.
TreeVisit(env, result, [&top](napi_env env, napi_value value) {
if (auto a = ki::FromNodeTo<mx::array*>(env, value); a)
top.erase(*a);
return napi_value();
});
// Clear the arrays in the stack.
// Clear the arrays in the set.
ki::InstanceData* instance_data = ki::InstanceData::Get(env);
for (mx::array* a : top) {
// The arary might be in 3 states:
// The array might be in 3 states:
// 1. Its JS object is well alive.
// 2. The JS object has been fully GCed.
// 2. The JS object has been fully GCed (finalizer ran, ptr freed).
// 3. The JS object is marked as dead, but the finalizer has not run.
// We have to unbind the JS object in 1, and only delete array in 1
// and 3.
// We must check wrapper validity BEFORE dereferencing the pointer,
// because in state 2 the pointer is dangling (already deleted by
// TypeBridge::Finalize during GC).
napi_value value;
if (instance_data->GetWrapper<mx::array>(a, &value))
bool has_wrapper = instance_data->GetWrapper<mx::array>(a, &value);
if (has_wrapper) {
// State 1: JS object alive — unbind it
napi_remove_wrap(env, value, nullptr);
if (instance_data->DeleteWrapper<mx::array>(a))
}
// Try to claim ownership (returns true for states 1 and 3)
if (instance_data->DeleteWrapper<mx::array>(a)) {
// Safe to dereference: pointer is still valid (not yet finalized)
int64_t ext = ki::internal::ExternalMemorySize<mx::array>::Get(a);
if (ext > 0) {
int64_t adjusted;
napi_adjust_external_memory(env, -ext, &adjusted);
}
delete a;
}
// State 2: fully GC'd — skip, pointer is dangling
}
return result;
},
[](napi_env env) {
// Always pop even when error happened.
g_tidy_arrays.pop();
[popped](napi_env env) {
// Only pop if cpp_then didn't already handle it.
if (!*popped)
g_tidy_arrays.pop();
});
}

Expand All @@ -504,8 +527,13 @@ void Dispose(const ki::Arguments& args) {
TreeVisit(args.Env(), args[i],
[instance_data](napi_env env, napi_value value) {
if (auto a = ki::FromNodeTo<mx::array*>(env, value); a) {
int64_t ext = ki::internal::ExternalMemorySize<mx::array>::Get(a.value());
napi_remove_wrap(env, value, nullptr);
instance_data->DeleteWrapper<mx::array>(a.value());
if (ext > 0) {
int64_t adjusted;
napi_adjust_external_memory(env, -ext, &adjusted);
}
delete a.value();
}
return napi_value();
Expand All @@ -520,6 +548,26 @@ size_t GetWrappersCount(napi_env env) {

} // namespace

// Synchronously sweep dead array wrappers.
// Finds arrays whose JS wrappers have been GC'd but whose deferred finalizers
// haven't run yet, and immediately frees the native Metal buffers.
// Called automatically from Eval() to prevent Metal resource accumulation.
// Returns the number of arrays swept.
size_t SweepDeadArrays(napi_env env) {
ki::InstanceData* instance_data = ki::InstanceData::Get(env);
auto dead_ptrs = instance_data->CollectDeadWrappers<mx::array>();
for (void* ptr : dead_ptrs) {
mx::array* a = static_cast<mx::array*>(ptr);
int64_t ext = ki::internal::ExternalMemorySize<mx::array>::Get(a);
if (ext > 0) {
int64_t adjusted;
napi_adjust_external_memory(env, -ext, &adjusted);
}
delete a;
}
return dead_ptrs.size();
}

namespace ki {

// Allow passing Dtype to JS directly, no memory management involved as they are
Expand Down Expand Up @@ -808,5 +856,6 @@ void InitArray(napi_env env, napi_value exports) {
ki::Set(env, exports,
"tidy", &Tidy,
"dispose", &Dispose,
"getWrappersCount", &GetWrappersCount);
"getWrappersCount", &GetWrappersCount,
"sweepDeadArrays", &SweepDeadArrays);
}
26 changes: 26 additions & 0 deletions src/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,32 @@ struct Type<mx::array> : public AllowPassByValue<mx::array> {
napi_value value);
};

namespace internal {

// Report external memory for mx::array to enable GC pressure signaling.
// MLX arrays hold Metal GPU buffers that are invisible to the JS GC.
// Without this, the GC doesn't know about GPU memory pressure and doesn't
// collect array wrappers fast enough, causing Metal resource exhaustion.
template<>
struct ExternalMemorySize<mx::array> {
static int64_t Get(mx::array* a) {
// Metal has a hard limit of 499K buffer allocations. We must create
// enough external memory pressure to force the GC to collect array
// wrappers before hitting it. Report 1MB per array as the minimum
// external cost — this is much larger than the actual data size but
// necessary to trigger sufficiently aggressive GC for GPU resources.
size_t n = a->nbytes();
constexpr int64_t min_cost = 1024 * 1024; // 1MB
return static_cast<int64_t>(n) > min_cost ? static_cast<int64_t>(n) : min_cost;
}
};

} // namespace internal

} // namespace ki

// Synchronously sweep dead array wrappers whose JS finalizers haven't run yet.
// Called automatically from Eval() to prevent Metal resource accumulation.
size_t SweepDeadArrays(napi_env env);

#endif // SRC_ARRAY_H_
6 changes: 3 additions & 3 deletions src/fast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,16 @@ mx::array ScaledDotProductAttention(
throw std::invalid_argument(msg.str());
}
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, mask_str, {}, s);
queries, keys, values, scale, mask_str, {}, {}, s);
} else {
auto mask_arr = std::get<mx::array>(mask);
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {mask_arr}, s);
queries, keys, values, scale, "", {mask_arr}, {}, s);
}

} else {
return mx::fast::scaled_dot_product_attention(
queries, keys, values, scale, "", {}, s);
queries, keys, values, scale, "", {}, {}, s);
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/fft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::function<mx::array(const mx::array& a,
mx::StreamOrDevice s)>
FFTNOpWrapper(const char* name,
mx::array(*func1)(const mx::array&,
const std::vector<int>&,
const mx::Shape&,
const std::vector<int>&,
mx::StreamOrDevice),
mx::array(*func2)(const mx::array&,
Expand All @@ -45,16 +45,17 @@ FFTNOpWrapper(const char* name,
std::optional<std::vector<int>> axes,
mx::StreamOrDevice s) {
if (n && axes) {
return mx::fft::fftn(a, std::move(*n), std::move(*axes), s);
mx::Shape shape_n(n->begin(), n->end());
return func1(a, shape_n, std::move(*axes), s);
} else if (axes) {
return mx::fft::fftn(a, std::move(*axes), s);
return func2(a, std::move(*axes), s);
} else if (n) {
std::ostringstream msg;
msg << "[" << name << "] "
<< "`axes` should not be `None` if `s` is not `None`.";
throw std::invalid_argument(msg.str());
} else {
return mx::fft::fftn(a, s);
return func3(a, s);
}
};
}
Expand All @@ -66,7 +67,7 @@ std::function<mx::array(const mx::array& a,
mx::StreamOrDevice s)>
FFT2OpWrapper(const char* name,
mx::array(*func1)(const mx::array&,
const std::vector<int>&,
const mx::Shape&,
const std::vector<int>&,
mx::StreamOrDevice),
mx::array(*func2)(const mx::array&,
Expand Down
2 changes: 1 addition & 1 deletion src/indexing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ ScatterResult ScatterArgsNDimentional(const mx::array* a,
a->shape().begin() + non_none_indices, a->shape().end());
up = mx::reshape(std::move(up), std::move(up_reshape));

mx::Shape axes(arr_indices.size(), 0);
std::vector<int> axes(arr_indices.size(), 0);
std::iota(axes.begin(), axes.end(), 0);
return {std::move(arr_indices), std::move(up), std::move(axes)};
}
Expand Down
4 changes: 3 additions & 1 deletion src/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ void InitMemory(napi_env env, napi_value exports) {
"setMemoryLimit", &mx::set_memory_limit,
"setWiredLimit", &mx::set_wired_limit,
"setCacheLimit", &mx::set_cache_limit,
"clearCache", &mx::clear_cache);
"clearCache", &mx::clear_cache,
"getNumResources", &mx::get_num_resources,
"getResourceLimit", &mx::get_resource_limit);
}
12 changes: 11 additions & 1 deletion src/metal.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
#include "src/bindings.h"
#include "mlx/backend/gpu/device_info.h"

namespace metal_ops {

const std::unordered_map<std::string, std::variant<std::string, size_t>>&
DeviceInfo() {
return mx::gpu::device_info(0);
}

} // namespace metal_ops

void InitMetal(napi_env env, napi_value exports) {
napi_value metal = ki::CreateObject(env);
Expand All @@ -8,5 +18,5 @@ void InitMetal(napi_env env, napi_value exports) {
"isAvailable", &mx::metal::is_available,
"startCapture", &mx::metal::start_capture,
"stopCapture", &mx::metal::stop_capture,
"deviceInfo", &mx::metal::device_info);
"deviceInfo", &metal_ops::DeviceInfo);
}
Loading