Refactor EXLA block lowering through EXLA.CustomCall protocol #1739
Chapaman wants to merge 11 commits into elixir-nx:main
| defprotocol EXLA.CustomCall do | ||
| @moduledoc """ | ||
| Protocol used by `EXLA.Defn` to lower specific `Nx.block/4` tags natively | ||
| instead of compiling the fallback callback. | ||
|
|
||
| Implementations receive the block tag struct, the output template (`out`), | ||
| the already-recursed MLIR `EXLA.MLIR.Value` arguments and the active | ||
| `EXLA.Client`. | ||
| """ | ||
|
|
||
| @fallback_to_any true | ||
|
|
||
| @doc """ | ||
| Returns `true` when EXLA should lower the block natively via `call/4`. | ||
|
|
||
| When it returns `false`, `EXLA.Defn` falls back to compiling the block's | ||
| default callback implementation. | ||
| """ | ||
| def apply?(struct, out, args, client) | ||
|
|
||
| @fallback_to_any true | ||
|
|
||
| @doc """ | ||
| Lowers the block natively. | ||
|
|
||
| Must return the list of `EXLA.MLIR.Value`s (or a single value) that | ||
| represents the block result, matching the shape of `out`. | ||
| """ | ||
| def call(struct, out, args, client) | ||
| end | ||
|
|
||
| defimpl EXLA.CustomCall, for: Any do | ||
| alias EXLA.MLIR.Value | ||
| alias EXLA.Defn | ||
|
|
||
| # --- apply?/4 --- | ||
|
|
||
| def apply?( | ||
| %Nx.Block.LinAlg.QR{}, | ||
| {%{type: {q_type_kind, _}}, _r}, | ||
| _args, | ||
| client | ||
| ) do | ||
| q_type_kind != :c and client.platform == :host | ||
| end |
We split this into `apply?/4` and `call/4` because a single callback would mix “should EXLA take the native path for this compile?” with “emit MLIR.” Keeping two functions separates eligibility from lowering.
`apply?` answers whether the native path applies for this block tag in context: e.g. for QR we still have `%Nx.Block.LinAlg.QR{}`, but we only native-lower when the output type is real (not complex) and the client is `:host`. Another block could gate on arity, struct fields, mesh, etc., without duplicating the full `call` body.
`call` only runs when `apply?` is true and contains the actual `Value.*` / Defn lowering.
`EXLA.Defn` mirrors that split: recurse operands → `apply?` → either `call` or the generic block fallback (`default_block_implementation`).
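A minimal sketch of that flow, runnable as a plain `.exs` script without Nx or EXLA. Everything under `Sketch.*`, the `{:native_qr, args}` and `{:fallback, args}` return tuples, and the plain-map `client` are hypothetical stand-ins for illustration, not the code in this PR:

```elixir
defmodule Sketch.QR do
  # Hypothetical stand-in for a block tag such as %Nx.Block.LinAlg.QR{}.
  defstruct []
end

defmodule Sketch.Other do
  # Hypothetical block tag with no native lowering.
  defstruct []
end

defprotocol Sketch.CustomCall do
  @fallback_to_any true

  @doc "Eligibility: should this block take the native path in this context?"
  def apply?(tag, out, args, client)

  @doc "Lowering: only invoked after apply?/4 returned true."
  def call(tag, out, args, client)
end

defimpl Sketch.CustomCall, for: Any do
  # QR is native only for non-complex outputs on the :host platform.
  def apply?(%Sketch.QR{}, {%{type: {kind, _bits}}, _r}, _args, client) do
    kind != :c and client.platform == :host
  end

  def apply?(_tag, _out, _args, _client), do: false

  # Stand-in for the real MLIR emission.
  def call(%Sketch.QR{}, _out, args, _client), do: {:native_qr, args}
end

defmodule Sketch.Lowering do
  # Mirrors the dispatch described above: apply? first, then call or fallback.
  def lower_block(tag, out, args, client) do
    if Sketch.CustomCall.apply?(tag, out, args, client) do
      Sketch.CustomCall.call(tag, out, args, client)
    else
      {:fallback, args}
    end
  end
end
```

With this shape, a complex output type or a non-host client never reaches `call/4`, so the lowering clause does not need to repeat the eligibility checks.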
You could have a single callback and instead return :skip in call when it cannot be lowered.
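A rough sketch of that single-callback shape, again as a standalone `.exs` script. The `SkipSketch.*` names and the `{:ok, values}` success tuple are assumptions made for illustration; only the `:skip` return comes from the suggestion itself:

```elixir
defmodule SkipSketch.QR do
  # Hypothetical stand-in for %Nx.Block.LinAlg.QR{}.
  defstruct []
end

defprotocol SkipSketch.CustomCall do
  @fallback_to_any true

  @doc "Returns {:ok, values} when lowered natively, or :skip otherwise."
  def call(tag, out, args, client)
end

defimpl SkipSketch.CustomCall, for: Any do
  # Eligibility lives in the clause head and guard instead of a second callback.
  def call(%SkipSketch.QR{}, {%{type: {kind, _bits}}, _r}, args, %{platform: :host})
      when kind != :c do
    {:ok, {:native_qr, args}}
  end

  def call(_tag, _out, _args, _client), do: :skip
end

defmodule SkipSketch.Lowering do
  def lower_block(tag, out, args, client) do
    case SkipSketch.CustomCall.call(tag, out, args, client) do
      {:ok, values} -> values
      :skip -> {:fallback, args}
    end
  end
end
```

The trade-off is that eligibility and lowering share one clause, which is compact when the check fits in a pattern and guard and harder to read when it does not.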
| def call(struct, out, args, client) | ||
| end | ||
|
|
||
| defimpl EXLA.CustomCall, for: Any do |
Should those be different implementations rather than having it all in Any?
This way the user can provide their own overrides!
Ah, I see! Let's add a comment explaining why then! :D
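For reference, a small sketch of why keeping the built-ins in `Any` leaves room for overrides: protocol dispatch prefers an implementation for the concrete struct and only falls back to `Any` when none exists. The `OverrideSketch.*` modules and return tuples are hypothetical:

```elixir
defmodule OverrideSketch.QR do
  defstruct []
end

defmodule OverrideSketch.Take do
  defstruct []
end

defprotocol OverrideSketch.CustomCall do
  @fallback_to_any true
  def call(tag, args)
end

# Built-in lowerings live in Any, selected by pattern matching on the tag.
defimpl OverrideSketch.CustomCall, for: Any do
  def call(%OverrideSketch.QR{}, args), do: {:builtin_qr, args}
  def call(%OverrideSketch.Take{}, args), do: {:builtin_take, args}
  def call(_tag, _args), do: :skip
end

# A user-provided implementation for one struct takes precedence over Any.
defimpl OverrideSketch.CustomCall, for: OverrideSketch.QR do
  def call(_tag, args), do: {:user_qr, args}
end
```

Once a struct-specific implementation exists, `Any` is no longer consulted for that struct, so an override that wants the built-in behavior in some cases has to delegate explicitly (here, to `OverrideSketch.CustomCall.Any.call/2`).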
@polvalente @Chapaman very nice! However, I think we should couple this more closely to actual custom calls (in C), because that's how it will work in practice. In other words, we should only move to the protocol the blocks that are implemented as custom C calls. This will make it easier to see what they have in common (and we won't get distracted by things like take, which is quite different).
| // Loads a shared library with RTLD_GLOBAL so XLA FFI static registrations run. | ||
| // Used from tests (e.g. alias custom-call registration plugin). | ||
| fine::Ok<> dlopen_test_plugin(ErlNifEnv *env, std::string path) { | ||
| void *handle = dlopen(path.c_str(), RTLD_NOW | RTLD_GLOBAL); | ||
| if (handle == nullptr) { | ||
| const char *err = dlerror(); | ||
| throw std::invalid_argument(err ? err : "dlopen failed"); | ||
| } | ||
| (void)handle; | ||
| return fine::Ok(); | ||
| } | ||
|
|
||
| FINE_NIF(dlopen_test_plugin, 0); |
Let's expose this as load_dylib so that it's not used just for tests
| // | ||
| // Built when BUILD_EXLA_TEST_PLUGIN=1 (Mix test). Load with RTLD_GLOBAL via | ||
| // EXLA.NIF.dlopen_test_plugin/1 before compiling or running graphs that emit | ||
| // the alias call_target_name. |
We can wrap this file in `#ifndef EXLA_PROD` to exclude its contents when MIX_ENV=prod.
Let's also remove the plugin suffix. I suggest exla_test/custom_calls.cc as the file path in c_src.
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),) | ||
| ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) | ||
| $(EXLA_SO): $(TEST_PLUGIN_SO) | ||
| endif | ||
| endif |
Suggested change
| - ifneq ($(BUILD_EXLA_TEST_PLUGIN),) | |
| - ifneq ($(BUILD_EXLA_TEST_PLUGIN),0) | |
| - $(EXLA_SO): $(TEST_PLUGIN_SO) | |
| - endif | |
| - endif | |
| "EXLA_VERSION" => "#{@version}", | ||
| "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0") |
Suggested change
| - "EXLA_VERSION" => "#{@version}", | |
| - "BUILD_EXLA_TEST_PLUGIN" => if(Mix.env() == :test, do: "1", else: "0") | |
| + "EXLA_VERSION" => "#{@version}" | |
| $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | ||
| @ mkdir -p $(dir $@) | ||
| $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | ||
|
|
||
| $(EXLA_SO): $(EXLA_CACHE_SO) |
Suggested change
| - $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | |
| - @ mkdir -p $(dir $@) | |
| - $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | |
| - $(EXLA_SO): $(EXLA_CACHE_SO) | |
| + $(TEST_PLUGIN_SO): $(TEST_PLUGIN_CC) | $(XLA_EXTENSION_DIR) | |
| + @ mkdir -p $(PRIV_DIR) | |
| + $(CXX) $(CFLAGS) -shared $(TEST_PLUGIN_CC) -o $@ $(LDFLAGS) | |
| + EXLA_SO_DEPS = $(EXLA_CACHE_SO) | |
| + ifeq ($(MIX_ENV),test) | |
| + EXLA_SO_DEPS += $(TEST_PLUGIN_SO) | |
| + endif | |
| + $(EXLA_SO): $(EXLA_SO_DEPS) | |
| @doc false | ||
| def qr_with_call_target( | ||
| %Value{function: func} = value, | ||
| q_typespec, | ||
| r_typespec, | ||
| call_target_name | ||
| ) | ||
| when is_binary(call_target_name) do |
What if the protocol instead had 2 callbacks:
@spec function_name(struct, output_container, input_templates_list, client) :: String.t() | :skip
@spec config(struct, output_container, input_templates_list, client) :: map() | nil
- `function_name` would be the string name registered by invoking the .so (what we have here as `call_target_name`).
- Input and output typespecs we can infer ourselves.
- `backend_config` would be obtained from the `config` callback. We'd have to validate that all map values are encodable as `mlir::DictionaryAttr` (but we can encode them ourselves).
There's a separate discussion on whether we expose other things such as has_side_effect and output_operand_aliases, but we can add these afterwards.
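To make the validation step concrete, here is a hedged sketch of what checking a `config` result could look like before encoding. The set of accepted terms (integers, floats, booleans, binaries, and nested lists or maps of those) is an assumption for illustration, not a statement of what `mlir::DictionaryAttr` supports, and `ConfigSketch` is a hypothetical module:

```elixir
defmodule ConfigSketch do
  # Assumption: these are the Elixir terms we would be willing to encode.
  def encodable?(v) when is_integer(v) or is_float(v) or is_boolean(v) or is_binary(v),
    do: true

  def encodable?(v) when is_list(v), do: Enum.all?(v, &encodable?/1)

  def encodable?(v) when is_map(v) and not is_struct(v) do
    Enum.all?(v, fn {key, value} -> valid_key?(key) and encodable?(value) end)
  end

  def encodable?(_other), do: false

  # nil means "no backend_config"; anything else must be a fully encodable map.
  def validate_config(nil), do: {:ok, %{}}

  def validate_config(config) when is_map(config) and not is_struct(config) do
    case Enum.reject(config, fn {key, value} -> valid_key?(key) and encodable?(value) end) do
      [] -> {:ok, config}
      bad -> {:error, {:not_encodable, Enum.map(bad, fn {key, _value} -> key end)}}
    end
  end

  defp valid_key?(key), do: is_atom(key) or is_binary(key)
end
```

Rejecting the config up front with the offending keys gives implementers a clearer error than a failure deep inside attribute encoding.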
| case {operand_type, computation_type} do | ||
| {{:f, 32}, {:f, 32}} -> "eigh_cpu_custom_call_f32" | ||
| {{:f, 64}, {:f, 64}} -> "eigh_cpu_custom_call_f64" | ||
| {{:s, 8}, {:f, 32}} -> "eigh_cpu_custom_call_s8" | ||
| {{:s, 16}, {:f, 32}} -> "eigh_cpu_custom_call_s16" | ||
| {{:s, 32}, {:f, 32}} -> "eigh_cpu_custom_call_s32" | ||
| {{:s, 64}, {:f, 32}} -> "eigh_cpu_custom_call_s64" | ||
| {{:u, 8}, {:f, 32}} -> "eigh_cpu_custom_call_u8" | ||
| {{:u, 16}, {:f, 32}} -> "eigh_cpu_custom_call_u16" | ||
| {{:u, 32}, {:f, 32}} -> "eigh_cpu_custom_call_u32" | ||
| {{:u, 64}, {:f, 32}} -> "eigh_cpu_custom_call_u64" | ||
| _ -> :skip | ||
| end |
I think the name is invariant to the output type. We can just use the input type
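A sketch of that simplification, keyed on the operand type alone. The target strings follow the naming pattern in the hunk above; the module name `EighTargetSketch` and the function `cpu_target/1` are hypothetical:

```elixir
defmodule EighTargetSketch do
  # Resolves the host custom-call target from the input element type only.
  def cpu_target({:f, bits}) when bits in [32, 64] do
    "eigh_cpu_custom_call_f#{bits}"
  end

  def cpu_target({kind, bits}) when kind in [:s, :u] and bits in [8, 16, 32, 64] do
    "eigh_cpu_custom_call_#{kind}#{bits}"
  end

  # Anything else (complex, f16, bf16, ...) is not lowered natively.
  def cpu_target(_type), do: :skip
end
```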
Also, why are these not defps in the protocol impl?
> Also, why are these not defps in the protocol impl?
I thought I needed to make it a def and not a defp because I'm using it in both EXLA.CustomCall and EXLA.MLIR.Value. Should I change it?
Are Value.qr and Value.eigh actually used? I think if they are, they should just be aliases to Value.custom_call (which will end up calling the protocol).
If this is indeed possible, then these functions will just be used in the protocol implementation.
😅😅 it was not being used lol
removed it :~
| defmodule EXLA.CustomCall.Builtins do | ||
| @moduledoc false | ||
|
|
||
| @doc """ | ||
| Host CPU `stablehlo.custom_call` target for `Nx.LinAlg.qr/2`, or `:skip`. | ||
|
|
||
| `operand_type` is the input matrix element type. | ||
| """ | ||
| def qr_cpu_target(operand_type) do | ||
| case operand_type do | ||
| {:f, 32} -> "qr_cpu_custom_call_f32" | ||
| {:f, 64} -> "qr_cpu_custom_call_f64" |
Let's get rid of this module and only define the target resolution inside the call sites
|
|
||
| } // namespace exla | ||
|
|
||
| // Host QR custom calls: integer operands with f32 Q/R (see Nx.Type.to_floating/1 |
This should not be here, right? Is this for test only?
In refactoring to use the protocol, we ended up losing the ability to do typecasts for the operands like we did before.
We would either have to add these, or add an extra callback to the protocol to make it work in this simpler version, which doesn't leak EXLA.MLIR.
In the final implementation these will be absorbed into elixir-nx/xla and we'll just use them.
@polvalente we could also allow the custom call to return the types of the inputs, and we cast them if they differ. For most cases, the types of the inputs are the ones given as arguments.
I guess this is a good middle ground. Let's do this and we'll have a proper struct for the return format with the input types, callback name, attributes and whatnot
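One possible shape for that return format, as a sketch only. The `SpecSketch.Spec` struct, its field names, and the `plan_casts/2` helper are all hypothetical; the f32 target name and the "integer operands computed in f32" rule are taken from the hunks in this PR:

```elixir
defmodule SpecSketch.Spec do
  # Hypothetical return format: one struct, so new fields can be added later.
  # operand_types: nil means "use the argument types as given".
  defstruct name: nil, operand_types: nil, config: %{}
end

defmodule SpecSketch do
  alias SpecSketch.Spec

  # Integer operands are cast to f32 before the call; floats are used as-is.
  def qr_spec({kind, _bits}) when kind in [:s, :u] do
    %Spec{name: "qr_cpu_custom_call_f32", operand_types: [{:f, 32}]}
  end

  def qr_spec({:f, 32}), do: %Spec{name: "qr_cpu_custom_call_f32"}
  def qr_spec({:f, 64}), do: %Spec{name: "qr_cpu_custom_call_f64"}
  def qr_spec(_type), do: :skip

  # Decides, per argument, whether the caller has to insert a cast.
  def plan_casts(%Spec{operand_types: nil}, arg_types) do
    Enum.map(arg_types, &{:keep, &1})
  end

  def plan_casts(%Spec{operand_types: wanted}, arg_types)
      when length(wanted) == length(arg_types) do
    arg_types
    |> Enum.zip(wanted)
    |> Enum.map(fn
      {same, same} -> {:keep, same}
      {_have, want} -> {:cast, want}
    end)
  end
end
```

Because the implementation only returns data, the cast itself stays inside EXLA and the protocol does not need to expose `EXLA.MLIR`.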
|
|
||
| def function_name(_, _, _, _), do: :skip | ||
|
|
||
| def config(_, _, _, _), do: nil |
We don't need a separate callback; make it return {:ok, name, config}. We also need to better document what the config is.
What about if we add output operand aliases too? Would you still prefer a single callback?
We return everything from a single callback, so we make sure the format supports extensions.
|
|
||
| - defp to_type(%Value{} = op, type) do | ||
| + @doc false | ||
| + def to_type(%Value{} = op, type) do |
Why are we converting all of these to public?
I think this is leftover from a previous iteration
Summary
- Adds `EXLA.CustomCall` (`apply?/4`, `call/4`) as the hook for native lowering of `Nx.block/4` in EXLA, with `@fallback_to_any true` and implementations on `Any` for the blocks that previously had dedicated clauses in `defn.ex`.
- `EXLA.Defn`: replaced the long chain of `:block` special cases with one `cached_recur_operator(:block, …)` path: recurse tensor args, then `apply?` → `call` if true, else `default_block_implementation/5` (the previous generic subfunction + `Value.call` path).
- Moved `Nx.Block.LinAlg.QR`, `Nx.Block.LinAlg.Eigh`, `Nx.Block.Take`, `Nx.Block.TopK`, `Nx.Block.FFT2`, `Nx.Block.IFFT2`, `Nx.Block.RFFT`, and `Nx.Block.IRFFT` into the protocol impl; behavior should match the old code paths.
- `EXLA.Defn` functions are now `@doc false` public (`to_type`, `op_type`, `op_shape`, `expr_to_typespec`, `axes_for_rank`, `fft`, `fft2`) so the protocol can call them; `fft`/`fft2` no longer take `state` (builder comes from `%Value{}.function`).
- QR eligibility checks the type kind (`{:f, _}` vs `{:c, _}`), not `q.type != :c` (types are tuples).