diff --git a/app/ui/app/src/utils/mergeModels.test.ts b/app/ui/app/src/utils/mergeModels.test.ts
index 90092182a7e..a962dc37c89 100644
--- a/app/ui/app/src/utils/mergeModels.test.ts
+++ b/app/ui/app/src/utils/mergeModels.test.ts
@@ -41,14 +41,14 @@ describe("Model merging logic", () => {
expect(merged.length).toBe(FEATURED_MODELS.length + 2);
});
- it("should hide cloud models in airplane mode", () => {
+ it("should hide cloud models when cloud is disabled", () => {
const localModels: Model[] = [
new Model({ model: "gpt-oss:120b-cloud" }),
new Model({ model: "llama3:latest" }),
new Model({ model: "mistral:latest" }),
];
- const merged = mergeModels(localModels, true); // airplane mode = true
+ const merged = mergeModels(localModels, true); // cloud disabled = true
// No cloud models should be present
const cloudModels = merged.filter((m) => m.isCloud());
diff --git a/app/ui/app/src/utils/mergeModels.ts b/app/ui/app/src/utils/mergeModels.ts
index abbfe00b8dc..814d2af42a5 100644
--- a/app/ui/app/src/utils/mergeModels.ts
+++ b/app/ui/app/src/utils/mergeModels.ts
@@ -32,7 +32,7 @@ function alphabeticalSort(a: Model, b: Model): number {
//Merges models, sorting cloud models first, then other models
export function mergeModels(
localModels: Model[],
- airplaneMode: boolean = false,
+ hideCloudModels: boolean = false,
): Model[] {
const allModels = (localModels || []).map((model) => model);
@@ -95,7 +95,7 @@ export function mergeModels(
remainingModels.sort(alphabeticalSort);
- return airplaneMode
+ return hideCloudModels
? [...featuredModels, ...remainingModels]
: [...cloudModels, ...featuredModels, ...remainingModels];
}
diff --git a/app/ui/ui.go b/app/ui/ui.go
index 0b32f917e68..ed9acc06061 100644
--- a/app/ui/ui.go
+++ b/app/ui/ui.go
@@ -284,12 +284,15 @@ func (s *Server) Handler() http.Handler {
mux.Handle("POST /api/v1/model/upstream", handle(s.modelUpstream))
mux.Handle("GET /api/v1/settings", handle(s.getSettings))
mux.Handle("POST /api/v1/settings", handle(s.settings))
+ mux.Handle("GET /api/v1/cloud", handle(s.getCloudSetting))
+ mux.Handle("POST /api/v1/cloud", handle(s.cloudSetting))
// Ollama proxy endpoints
ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/version", ollamaProxy)
+ mux.Handle("GET /api/status", ollamaProxy)
mux.Handle("HEAD /api/version", ollamaProxy)
mux.Handle("POST /api/me", ollamaProxy)
mux.Handle("POST /api/signout", ollamaProxy)
@@ -1460,6 +1463,40 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
})
}
+func (s *Server) cloudSetting(w http.ResponseWriter, r *http.Request) error {
+ var req struct {
+ Enabled bool `json:"enabled"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ return fmt.Errorf("invalid request body: %w", err)
+ }
+
+ if err := s.Store.SetCloudEnabled(req.Enabled); err != nil {
+ return fmt.Errorf("failed to persist cloud setting: %w", err)
+ }
+
+ s.Restart()
+
+ return s.writeCloudStatus(w)
+}
+
+func (s *Server) getCloudSetting(w http.ResponseWriter, r *http.Request) error {
+ return s.writeCloudStatus(w)
+}
+
+func (s *Server) writeCloudStatus(w http.ResponseWriter) error {
+ disabled, source, err := s.Store.CloudStatus()
+ if err != nil {
+ return fmt.Errorf("failed to load cloud status: %w", err)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ return json.NewEncoder(w).Encode(map[string]any{
+ "disabled": disabled,
+ "source": source,
+ })
+}
+
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
diff --git a/app/ui/ui_test.go b/app/ui/ui_test.go
index 9807974356f..5a523600744 100644
--- a/app/ui/ui_test.go
+++ b/app/ui/ui_test.go
@@ -115,6 +115,107 @@ func TestHandlePostApiSettings(t *testing.T) {
}
}
+func TestHandlePostApiCloudSetting(t *testing.T) {
+ tmpHome := t.TempDir()
+ t.Setenv("HOME", tmpHome)
+ t.Setenv("OLLAMA_NO_CLOUD", "")
+
+ testStore := &store.Store{
+ DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
+ }
+ defer testStore.Close()
+
+ restartCount := 0
+ server := &Server{
+ Store: testStore,
+ Restart: func() {
+ restartCount++
+ },
+ }
+
+ for _, tc := range []struct {
+ name string
+ body string
+ wantEnabled bool
+ }{
+ {name: "disable cloud", body: `{"enabled": false}`, wantEnabled: false},
+ {name: "enable cloud", body: `{"enabled": true}`, wantEnabled: true},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ req := httptest.NewRequest("POST", "/api/v1/cloud", bytes.NewBufferString(tc.body))
+ req.Header.Set("Content-Type", "application/json")
+ rr := httptest.NewRecorder()
+
+ if err := server.cloudSetting(rr, req); err != nil {
+ t.Fatalf("cloudSetting() error = %v", err)
+ }
+ if rr.Code != http.StatusOK {
+ t.Fatalf("cloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
+ }
+
+ var got map[string]any
+ if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
+ t.Fatalf("cloudSetting() invalid response JSON: %v", err)
+ }
+ if got["disabled"] != !tc.wantEnabled {
+ t.Fatalf("response disabled = %v, want %v", got["disabled"], !tc.wantEnabled)
+ }
+
+ disabled, err := testStore.CloudDisabled()
+ if err != nil {
+ t.Fatalf("CloudDisabled() error = %v", err)
+ }
+ if gotEnabled := !disabled; gotEnabled != tc.wantEnabled {
+ t.Fatalf("cloud enabled = %v, want %v", gotEnabled, tc.wantEnabled)
+ }
+ })
+ }
+
+ if restartCount != 2 {
+ t.Fatalf("Restart called %d times, want 2", restartCount)
+ }
+}
+
+func TestHandleGetApiCloudSetting(t *testing.T) {
+ tmpHome := t.TempDir()
+ t.Setenv("HOME", tmpHome)
+ t.Setenv("OLLAMA_NO_CLOUD", "")
+
+ testStore := &store.Store{
+ DBPath: filepath.Join(t.TempDir(), "db.sqlite"),
+ }
+ defer testStore.Close()
+
+ if err := testStore.SetCloudEnabled(false); err != nil {
+ t.Fatalf("SetCloudEnabled(false) error = %v", err)
+ }
+
+ server := &Server{
+ Store: testStore,
+ Restart: func() {},
+ }
+
+ req := httptest.NewRequest("GET", "/api/v1/cloud", nil)
+ rr := httptest.NewRecorder()
+ if err := server.getCloudSetting(rr, req); err != nil {
+ t.Fatalf("getCloudSetting() error = %v", err)
+ }
+ if rr.Code != http.StatusOK {
+ t.Fatalf("getCloudSetting() status = %d, want %d", rr.Code, http.StatusOK)
+ }
+
+ var got map[string]any
+ if err := json.Unmarshal(rr.Body.Bytes(), &got); err != nil {
+ t.Fatalf("getCloudSetting() invalid response JSON: %v", err)
+ }
+ if got["disabled"] != true {
+ t.Fatalf("response disabled = %v, want true", got["disabled"])
+ }
+ if got["source"] != "config" {
+ t.Fatalf("response source = %v, want config", got["source"])
+ }
+}
+
func TestAuthenticationMiddleware(t *testing.T) {
tests := []struct {
name string
diff --git a/cmd/background_unix.go b/cmd/background_unix.go
new file mode 100644
index 00000000000..a4eea48c5b3
--- /dev/null
+++ b/cmd/background_unix.go
@@ -0,0 +1,13 @@
+//go:build !windows
+
+package cmd
+
+import "syscall"
+
+// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Unix.
+// Setpgid prevents the server from being killed when the parent process exits.
+func backgroundServerSysProcAttr() *syscall.SysProcAttr {
+ return &syscall.SysProcAttr{
+ Setpgid: true,
+ }
+}
diff --git a/cmd/background_windows.go b/cmd/background_windows.go
new file mode 100644
index 00000000000..fa43b740093
--- /dev/null
+++ b/cmd/background_windows.go
@@ -0,0 +1,12 @@
+package cmd
+
+import "syscall"
+
+// backgroundServerSysProcAttr returns SysProcAttr for running the server in the background on Windows.
+// CREATE_NO_WINDOW (0x08000000) prevents a console window from appearing.
+func backgroundServerSysProcAttr() *syscall.SysProcAttr {
+ return &syscall.SysProcAttr{
+ CreationFlags: 0x08000000,
+ HideWindow: true,
+ }
+}
diff --git a/cmd/cmd.go b/cmd/cmd.go
index 4f974ce8488..fd2349853ec 100644
--- a/cmd/cmd.go
+++ b/cmd/cmd.go
@@ -15,6 +15,7 @@ import (
"net"
"net/http"
"os"
+ "os/exec"
"os/signal"
"path/filepath"
"runtime"
@@ -29,12 +30,15 @@ import (
"github.com/containerd/console"
"github.com/mattn/go-runewidth"
"github.com/olekukonko/tablewriter"
+ "github.com/pkg/browser"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"golang.org/x/term"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/cmd/config"
+ "github.com/ollama/ollama/cmd/tui"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
@@ -51,7 +55,44 @@ import (
"github.com/ollama/ollama/x/imagegen"
)
-const ConnectInstructions = "To sign in, navigate to:\n %s\n\n"
+func init() {
+ // Override default selectors to use Bubbletea TUI instead of raw terminal I/O.
+ config.DefaultSingleSelector = func(title string, items []config.ModelItem, current string) (string, error) {
+ tuiItems := tui.ReorderItems(tui.ConvertItems(items))
+ result, err := tui.SelectSingle(title, tuiItems, current)
+ if errors.Is(err, tui.ErrCancelled) {
+ return "", config.ErrCancelled
+ }
+ return result, err
+ }
+
+ config.DefaultMultiSelector = func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
+ tuiItems := tui.ReorderItems(tui.ConvertItems(items))
+ result, err := tui.SelectMultiple(title, tuiItems, preChecked)
+ if errors.Is(err, tui.ErrCancelled) {
+ return nil, config.ErrCancelled
+ }
+ return result, err
+ }
+
+ config.DefaultSignIn = func(modelName, signInURL string) (string, error) {
+ userName, err := tui.RunSignIn(modelName, signInURL)
+ if errors.Is(err, tui.ErrCancelled) {
+ return "", config.ErrCancelled
+ }
+ return userName, err
+ }
+
+ config.DefaultConfirmPrompt = func(prompt string) (bool, error) {
+ ok, err := tui.RunConfirm(prompt)
+ if errors.Is(err, tui.ErrCancelled) {
+ return false, config.ErrCancelled
+ }
+ return ok, err
+ }
+}
+
+const ConnectInstructions = "If your browser did not open, navigate to:\n %s\n\n"
// ensureThinkingSupport emits a warning if the model does not advertise thinking support
func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
@@ -141,6 +182,10 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
mfConfig.System = cmd.Args
case "license":
mfConfig.License = cmd.Args
+ case "parser":
+ mfConfig.Parser = cmd.Args
+ case "renderer":
+ mfConfig.Renderer = cmd.Args
}
}
@@ -365,14 +410,25 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
return err
} else if info.RemoteHost != "" {
// Cloud model, no need to load/unload
+
+ isCloud := strings.HasPrefix(info.RemoteHost, "https://ollama.com")
+
+ // Check if user is signed in for ollama.com cloud models
+ if isCloud {
+ if _, err := client.Whoami(cmd.Context()); err != nil {
+ return err
+ }
+ }
+
if opts.ShowConnect {
p.StopAndClear()
- if strings.HasPrefix(info.RemoteHost, "https://ollama.com") {
+ if isCloud {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on 'ollama.com' ⚡\n", info.RemoteModel)
} else {
fmt.Fprintf(os.Stderr, "Connecting to '%s' on '%s'\n", info.RemoteModel, info.RemoteHost)
}
}
+
return nil
}
@@ -529,6 +585,17 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.WordWrap = !nowrap
+ useImagegen := false
+ if cmd.Flags().Lookup("imagegen") != nil {
+ useImagegen, err = cmd.Flags().GetBool("imagegen")
+ if err != nil {
+ return err
+ }
+ }
+ if useImagegen {
+ opts.Options["use_imagegen_runner"] = true
+ }
+
// Fill out the rest of the options based on information about the
// model.
client, err := api.ClientFromEnvironment()
@@ -662,6 +729,7 @@ func SigninHandler(cmd *cobra.Command, args []string) error {
fmt.Println()
if aErr.SigninURL != "" {
+ _ = browser.OpenURL(aErr.SigninURL)
fmt.Printf(ConnectInstructions, aErr.SigninURL)
}
return nil
@@ -1018,8 +1086,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
}
if resp.ModelInfo != nil {
- arch := resp.ModelInfo["general.architecture"].(string)
- rows = append(rows, []string{"", "architecture", arch})
+ arch, _ := resp.ModelInfo["general.architecture"].(string)
+ if arch != "" {
+ rows = append(rows, []string{"", "architecture", arch})
+ }
var paramStr string
if resp.Details.ParameterSize != "" {
@@ -1029,7 +1099,9 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
paramStr = format.HumanNumber(uint64(f))
}
}
- rows = append(rows, []string{"", "parameters", paramStr})
+ if paramStr != "" {
+ rows = append(rows, []string{"", "parameters", paramStr})
+ }
if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok {
if f, ok := v.(float64); ok {
@@ -1745,7 +1817,7 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
- return fmt.Errorf("ollama server not responding - %w", err)
+ return err
}
}
return nil
@@ -1786,6 +1858,216 @@ Environment Variables:
cmd.SetUsageTemplate(cmd.UsageTemplate() + envUsage)
}
+// ensureServerRunning checks if the ollama server is running and starts it in the background if not.
+func ensureServerRunning(ctx context.Context) error {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+
+ // Check if server is already running
+ if err := client.Heartbeat(ctx); err == nil {
+ return nil // server is already running
+ }
+
+ // Server not running, start it in the background
+ exe, err := os.Executable()
+ if err != nil {
+ return fmt.Errorf("could not find executable: %w", err)
+ }
+
+ serverCmd := exec.CommandContext(ctx, exe, "serve")
+ serverCmd.Env = os.Environ()
+ serverCmd.SysProcAttr = backgroundServerSysProcAttr()
+ if err := serverCmd.Start(); err != nil {
+ return fmt.Errorf("failed to start server: %w", err)
+ }
+
+ // Wait for the server to be ready
+ for {
+ time.Sleep(500 * time.Millisecond)
+ if err := client.Heartbeat(ctx); err == nil {
+ return nil // server has started
+ }
+ }
+}
+
+// runInteractiveTUI runs the main interactive TUI menu.
+func runInteractiveTUI(cmd *cobra.Command) {
+ // Ensure the server is running before showing the TUI
+ if err := ensureServerRunning(cmd.Context()); err != nil {
+ fmt.Fprintf(os.Stderr, "Error starting server: %v\n", err)
+ return
+ }
+
+ // Selector adapters for tui
+ singleSelector := func(title string, items []config.ModelItem, current string) (string, error) {
+ tuiItems := tui.ReorderItems(tui.ConvertItems(items))
+ result, err := tui.SelectSingle(title, tuiItems, current)
+ if errors.Is(err, tui.ErrCancelled) {
+ return "", config.ErrCancelled
+ }
+ return result, err
+ }
+
+ multiSelector := func(title string, items []config.ModelItem, preChecked []string) ([]string, error) {
+ tuiItems := tui.ReorderItems(tui.ConvertItems(items))
+ result, err := tui.SelectMultiple(title, tuiItems, preChecked)
+ if errors.Is(err, tui.ErrCancelled) {
+ return nil, config.ErrCancelled
+ }
+ return result, err
+ }
+
+ for {
+ result, err := tui.Run()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+ return
+ }
+
+ runModel := func(modelName string) {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+ return
+ }
+ if err := config.ShowOrPull(cmd.Context(), client, modelName); err != nil {
+ if errors.Is(err, config.ErrCancelled) {
+ return
+ }
+ fmt.Fprintf(os.Stderr, "Error: %v\n", err)
+ return
+ }
+ _ = config.SetLastModel(modelName)
+ opts := runOptions{
+ Model: modelName,
+ WordWrap: os.Getenv("TERM") == "xterm-256color",
+ Options: map[string]any{},
+ ShowConnect: true,
+ }
+ if err := loadOrUnloadModel(cmd, &opts); err != nil {
+ fmt.Fprintf(os.Stderr, "Error loading model: %v\n", err)
+ return
+ }
+ if err := generateInteractive(cmd, opts); err != nil {
+ fmt.Fprintf(os.Stderr, "Error running model: %v\n", err)
+ }
+ }
+
+ launchIntegration := func(name string) bool {
+ // If not configured or model no longer exists, prompt for model selection
+ configuredModel := config.IntegrationModel(name)
+ if configuredModel == "" || !config.ModelExists(cmd.Context(), configuredModel) || config.IsCloudModelDisabled(cmd.Context(), configuredModel) {
+ err := config.ConfigureIntegrationWithSelectors(cmd.Context(), name, singleSelector, multiSelector)
+ if errors.Is(err, config.ErrCancelled) {
+ return false // Return to main menu
+ }
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", name, err)
+ return true
+ }
+ }
+ if err := config.LaunchIntegration(name); err != nil {
+ fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", name, err)
+ }
+ return true
+ }
+
+ switch result.Selection {
+ case tui.SelectionNone:
+ // User quit
+ return
+ case tui.SelectionRunModel:
+ _ = config.SetLastSelection("run")
+ if modelName := config.LastModel(); modelName != "" && !config.IsCloudModelDisabled(cmd.Context(), modelName) {
+ runModel(modelName)
+ } else {
+ modelName, err := config.SelectModelWithSelector(cmd.Context(), singleSelector)
+ if errors.Is(err, config.ErrCancelled) {
+ continue // Return to main menu
+ }
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
+ continue
+ }
+ runModel(modelName)
+ }
+ case tui.SelectionChangeRunModel:
+ _ = config.SetLastSelection("run")
+ // Use model from modal if selected, otherwise show picker
+ modelName := result.Model
+ if modelName == "" {
+ var err error
+ modelName, err = config.SelectModelWithSelector(cmd.Context(), singleSelector)
+ if errors.Is(err, config.ErrCancelled) {
+ continue // Return to main menu
+ }
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error selecting model: %v\n", err)
+ continue
+ }
+ }
+ if config.IsCloudModelDisabled(cmd.Context(), modelName) {
+ continue // Return to main menu
+ }
+ runModel(modelName)
+ case tui.SelectionIntegration:
+ _ = config.SetLastSelection(result.Integration)
+ if !launchIntegration(result.Integration) {
+ continue // Return to main menu
+ }
+ case tui.SelectionChangeIntegration:
+ _ = config.SetLastSelection(result.Integration)
+ if len(result.Models) > 0 {
+ // Filter out cloud-disabled models
+ var filtered []string
+ for _, m := range result.Models {
+ if !config.IsCloudModelDisabled(cmd.Context(), m) {
+ filtered = append(filtered, m)
+ }
+ }
+ if len(filtered) == 0 {
+ continue
+ }
+ result.Models = filtered
+ // Multi-select from modal (Editor integrations)
+ if err := config.SaveAndEditIntegration(result.Integration, result.Models); err != nil {
+ fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
+ continue
+ }
+ if err := config.LaunchIntegrationWithModel(result.Integration, result.Models[0]); err != nil {
+ fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
+ }
+ } else if result.Model != "" {
+ if config.IsCloudModelDisabled(cmd.Context(), result.Model) {
+ continue
+ }
+ // Single-select from modal - save and launch
+ if err := config.SaveIntegration(result.Integration, []string{result.Model}); err != nil {
+ fmt.Fprintf(os.Stderr, "Error saving config: %v\n", err)
+ continue
+ }
+ if err := config.LaunchIntegrationWithModel(result.Integration, result.Model); err != nil {
+ fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
+ }
+ } else {
+ err := config.ConfigureIntegrationWithSelectors(cmd.Context(), result.Integration, singleSelector, multiSelector)
+ if errors.Is(err, config.ErrCancelled) {
+ continue // Return to main menu
+ }
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error configuring %s: %v\n", result.Integration, err)
+ continue
+ }
+ if err := config.LaunchIntegration(result.Integration); err != nil {
+ fmt.Fprintf(os.Stderr, "Error launching %s: %v\n", result.Integration, err)
+ }
+ }
+ }
+ }
+}
+
func NewCLI() *cobra.Command {
log.SetFlags(log.LstdFlags | log.Lshortfile)
cobra.EnableCommandSorting = false
@@ -1808,11 +2090,13 @@ func NewCLI() *cobra.Command {
return
}
- cmd.Print(cmd.UsageString())
+ runInteractiveTUI(cmd)
},
}
rootCmd.Flags().BoolP("version", "v", false, "Show version information")
+ rootCmd.Flags().Bool("verbose", false, "Show timings for response")
+ rootCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
createCmd := &cobra.Command{
Use: "create MODEL",
@@ -1872,6 +2156,9 @@ func NewCLI() *cobra.Command {
// Image generation flags (width, height, steps, seed, etc.)
imagegen.RegisterFlags(runCmd)
+ runCmd.Flags().Bool("imagegen", false, "Use the imagegen runner for LLM inference")
+ runCmd.Flags().MarkHidden("imagegen")
+
stopCmd := &cobra.Command{
Use: "stop MODEL",
Short: "Stop a running model",
@@ -1883,7 +2170,7 @@ func NewCLI() *cobra.Command {
serveCmd := &cobra.Command{
Use: "serve",
Aliases: []string{"start"},
- Short: "Start ollama",
+ Short: "Start Ollama",
Args: cobra.ExactArgs(0),
RunE: RunServer,
}
@@ -1916,6 +2203,15 @@ func NewCLI() *cobra.Command {
RunE: SigninHandler,
}
+ loginCmd := &cobra.Command{
+ Use: "login",
+ Short: "Sign in to ollama.com",
+ Hidden: true,
+ Args: cobra.ExactArgs(0),
+ PreRunE: checkServerHeartbeat,
+ RunE: SigninHandler,
+ }
+
signoutCmd := &cobra.Command{
Use: "signout",
Short: "Sign out from ollama.com",
@@ -1924,6 +2220,15 @@ func NewCLI() *cobra.Command {
RunE: SignoutHandler,
}
+ logoutCmd := &cobra.Command{
+ Use: "logout",
+ Short: "Sign out from ollama.com",
+ Hidden: true,
+ Args: cobra.ExactArgs(0),
+ PreRunE: checkServerHeartbeat,
+ RunE: SignoutHandler,
+ }
+
listCmd := &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
@@ -1998,7 +2303,7 @@ func NewCLI() *cobra.Command {
switch cmd {
case runCmd:
imagegen.AppendFlagsDocs(cmd)
- appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
+ appendEnvDocs(cmd, []envconfig.EnvVar{envVars["OLLAMA_EDITOR"], envVars["OLLAMA_HOST"], envVars["OLLAMA_NOHISTORY"]})
case serveCmd:
appendEnvDocs(cmd, []envconfig.EnvVar{
envVars["OLLAMA_DEBUG"],
@@ -2009,6 +2314,7 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_MAX_QUEUE"],
envVars["OLLAMA_MODELS"],
envVars["OLLAMA_NUM_PARALLEL"],
+ envVars["OLLAMA_NO_CLOUD"],
envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
@@ -2033,13 +2339,16 @@ func NewCLI() *cobra.Command {
pullCmd,
pushCmd,
signinCmd,
+ loginCmd,
signoutCmd,
+ logoutCmd,
listCmd,
psCmd,
copyCmd,
deleteCmd,
runnerCmd,
rpcCmd,
+ config.LaunchCmd(checkServerHeartbeat, runInteractiveTUI),
)
return rootCmd
diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go
index eedd0c61dbe..7217c3d13c3 100644
--- a/cmd/cmd_test.go
+++ b/cmd/cmd_test.go
@@ -3,6 +3,7 @@ package cmd
import (
"bytes"
"encoding/json"
+ "errors"
"fmt"
"io"
"net/http"
@@ -1553,7 +1554,7 @@ func TestShowInfoImageGen(t *testing.T) {
Details: api.ModelDetails{
Family: "ZImagePipeline",
ParameterSize: "10.3B",
- QuantizationLevel: "FP8",
+ QuantizationLevel: "Q8",
},
Capabilities: []model.Capability{model.CapabilityImage},
Requires: "0.14.0",
@@ -1565,7 +1566,7 @@ func TestShowInfoImageGen(t *testing.T) {
expect := " Model\n" +
" architecture ZImagePipeline \n" +
" parameters 10.3B \n" +
- " quantization FP8 \n" +
+ " quantization Q8 \n" +
" requires 0.14.0 \n" +
"\n" +
" Capabilities\n" +
@@ -1659,3 +1660,103 @@ func TestRunOptions_Copy_Independence(t *testing.T) {
t.Error("Copy Think should not be affected by original modification")
}
}
+
+func TestLoadOrUnloadModel_CloudModelAuth(t *testing.T) {
+ tests := []struct {
+ name string
+ remoteHost string
+ whoamiStatus int
+ whoamiResp any
+ expectedError string
+ }{
+ {
+ name: "ollama.com cloud model - user signed in",
+ remoteHost: "https://ollama.com",
+ whoamiStatus: http.StatusOK,
+ whoamiResp: api.UserResponse{Name: "testuser"},
+ },
+ {
+ name: "ollama.com cloud model - user not signed in",
+ remoteHost: "https://ollama.com",
+ whoamiStatus: http.StatusUnauthorized,
+ whoamiResp: map[string]string{
+ "error": "unauthorized",
+ "signin_url": "https://ollama.com/signin",
+ },
+ expectedError: "unauthorized",
+ },
+ {
+ name: "non-ollama.com remote - no auth check",
+ remoteHost: "https://other-remote.com",
+ whoamiStatus: http.StatusUnauthorized, // should not be called
+ whoamiResp: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ whoamiCalled := false
+ mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/show":
+ w.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w).Encode(api.ShowResponse{
+ RemoteHost: tt.remoteHost,
+ RemoteModel: "test-model",
+ }); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ case "/api/me":
+ whoamiCalled = true
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(tt.whoamiStatus)
+ if tt.whoamiResp != nil {
+ if err := json.NewEncoder(w).Encode(tt.whoamiResp); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ }
+ }
+ default:
+ http.NotFound(w, r)
+ }
+ }))
+ defer mockServer.Close()
+
+ t.Setenv("OLLAMA_HOST", mockServer.URL)
+
+ cmd := &cobra.Command{}
+ cmd.SetContext(t.Context())
+
+ opts := &runOptions{
+ Model: "test-cloud-model",
+ ShowConnect: false,
+ }
+
+ err := loadOrUnloadModel(cmd, opts)
+
+ if strings.HasPrefix(tt.remoteHost, "https://ollama.com") {
+ if !whoamiCalled {
+ t.Error("expected whoami to be called for ollama.com cloud model")
+ }
+ } else {
+ if whoamiCalled {
+ t.Error("whoami should not be called for non-ollama.com remote")
+ }
+ }
+
+ if tt.expectedError != "" {
+ if err == nil {
+ t.Errorf("expected error containing %q, got nil", tt.expectedError)
+ } else {
+ var authErr api.AuthorizationError
+ if !errors.As(err, &authErr) {
+ t.Errorf("expected AuthorizationError, got %T: %v", err, err)
+ }
+ }
+ } else {
+ if err != nil {
+ t.Errorf("expected no error, got %v", err)
+ }
+ }
+ })
+ }
+}
diff --git a/cmd/config/claude.go b/cmd/config/claude.go
new file mode 100644
index 00000000000..b7ed02af1d9
--- /dev/null
+++ b/cmd/config/claude.go
@@ -0,0 +1,192 @@
+package config
+
+import (
+ "context"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/envconfig"
+)
+
+// Claude implements Runner and AliasConfigurer for Claude Code integration
+type Claude struct{}
+
+// Compile-time check that Claude implements AliasConfigurer
+var _ AliasConfigurer = (*Claude)(nil)
+
+func (c *Claude) String() string { return "Claude Code" }
+
+func (c *Claude) args(model string, extra []string) []string {
+ var args []string
+ if model != "" {
+ args = append(args, "--model", model)
+ }
+ args = append(args, extra...)
+ return args
+}
+
+func (c *Claude) findPath() (string, error) {
+ if p, err := exec.LookPath("claude"); err == nil {
+ return p, nil
+ }
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", err
+ }
+ name := "claude"
+ if runtime.GOOS == "windows" {
+ name = "claude.exe"
+ }
+ fallback := filepath.Join(home, ".claude", "local", name)
+ if _, err := os.Stat(fallback); err != nil {
+ return "", err
+ }
+ return fallback, nil
+}
+
+func (c *Claude) Run(model string, args []string) error {
+ claudePath, err := c.findPath()
+ if err != nil {
+ return fmt.Errorf("claude is not installed, install from https://code.claude.com/docs/en/quickstart")
+ }
+
+ cmd := exec.Command(claudePath, c.args(model, args)...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ env := append(os.Environ(),
+ "ANTHROPIC_BASE_URL="+envconfig.Host().String(),
+ "ANTHROPIC_API_KEY=",
+ "ANTHROPIC_AUTH_TOKEN=ollama",
+ )
+
+ env = append(env, c.modelEnvVars(model)...)
+
+ cmd.Env = env
+ return cmd.Run()
+}
+
+// modelEnvVars returns Claude Code env vars that route all model tiers through Ollama.
+func (c *Claude) modelEnvVars(model string) []string {
+ primary := model
+ fast := model
+ if cfg, err := loadIntegration("claude"); err == nil && cfg.Aliases != nil {
+ if p := cfg.Aliases["primary"]; p != "" {
+ primary = p
+ }
+ if f := cfg.Aliases["fast"]; f != "" {
+ fast = f
+ }
+ }
+ return []string{
+ "ANTHROPIC_DEFAULT_OPUS_MODEL=" + primary,
+ "ANTHROPIC_DEFAULT_SONNET_MODEL=" + primary,
+ "ANTHROPIC_DEFAULT_HAIKU_MODEL=" + fast,
+ "CLAUDE_CODE_SUBAGENT_MODEL=" + primary,
+ }
+}
+
+// ConfigureAliases sets up model aliases for Claude Code.
+// model: the model to use (if empty, user will be prompted to select)
+// aliases: existing alias configuration to preserve/update
+// Cloud-only: subagent routing (fast model) is gated to cloud models only until
+// there is a better strategy for prompt caching on local models.
+func (c *Claude) ConfigureAliases(ctx context.Context, model string, existingAliases map[string]string, force bool) (map[string]string, bool, error) {
+ aliases := make(map[string]string)
+ for k, v := range existingAliases {
+ aliases[k] = v
+ }
+
+ if model != "" {
+ aliases["primary"] = model
+ }
+
+ if !force && aliases["primary"] != "" {
+ client, _ := api.ClientFromEnvironment()
+ if isCloudModel(ctx, client, aliases["primary"]) {
+ if isCloudModel(ctx, client, aliases["fast"]) {
+ return aliases, false, nil
+ }
+ } else {
+ delete(aliases, "fast")
+ return aliases, false, nil
+ }
+ }
+
+ items, existingModels, cloudModels, client, err := listModels(ctx)
+ if err != nil {
+ return nil, false, err
+ }
+
+ fmt.Fprintf(os.Stderr, "\n%sModel Configuration%s\n\n", ansiBold, ansiReset)
+
+ if aliases["primary"] == "" || force {
+ primary, err := DefaultSingleSelector("Select model:", items, aliases["primary"])
+ if err != nil {
+ return nil, false, err
+ }
+ if err := pullIfNeeded(ctx, client, existingModels, primary); err != nil {
+ return nil, false, err
+ }
+ if err := ensureAuth(ctx, client, cloudModels, []string{primary}); err != nil {
+ return nil, false, err
+ }
+ aliases["primary"] = primary
+ }
+
+ if isCloudModel(ctx, client, aliases["primary"]) {
+ if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
+ aliases["fast"] = aliases["primary"]
+ }
+ } else {
+ delete(aliases, "fast")
+ }
+
+ return aliases, true, nil
+}
+
+// SetAliases syncs the configured aliases to the Ollama server using prefix matching.
+// Cloud-only: for local models (fast is empty), we delete any existing aliases to
+// prevent stale routing to a previous cloud model.
+func (c *Claude) SetAliases(ctx context.Context, aliases map[string]string) error {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+
+ prefixes := []string{"claude-sonnet-", "claude-haiku-"}
+
+ if aliases["fast"] == "" {
+ for _, prefix := range prefixes {
+ _ = client.DeleteAliasExperimental(ctx, &api.AliasDeleteRequest{Alias: prefix})
+ }
+ return nil
+ }
+
+ prefixAliases := map[string]string{
+ "claude-sonnet-": aliases["primary"],
+ "claude-haiku-": aliases["fast"],
+ }
+
+ var errs []string
+ for prefix, target := range prefixAliases {
+ req := &api.AliasRequest{
+ Alias: prefix,
+ Target: target,
+ PrefixMatching: true,
+ }
+ if err := client.SetAliasExperimental(ctx, req); err != nil {
+ errs = append(errs, prefix)
+ }
+ }
+
+ if len(errs) > 0 {
+ return fmt.Errorf("failed to set aliases: %v", errs)
+ }
+ return nil
+}
diff --git a/cmd/config/claude_test.go b/cmd/config/claude_test.go
new file mode 100644
index 00000000000..e5ad16a20b1
--- /dev/null
+++ b/cmd/config/claude_test.go
@@ -0,0 +1,198 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+ "slices"
+ "strings"
+ "testing"
+)
+
+func TestClaudeIntegration(t *testing.T) {
+ c := &Claude{}
+
+ t.Run("String", func(t *testing.T) {
+ if got := c.String(); got != "Claude Code" {
+ t.Errorf("String() = %q, want %q", got, "Claude Code")
+ }
+ })
+
+ t.Run("implements Runner", func(t *testing.T) {
+ var _ Runner = c
+ })
+}
+
+func TestClaudeFindPath(t *testing.T) {
+ c := &Claude{}
+
+ t.Run("finds claude in PATH", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ name := "claude"
+ if runtime.GOOS == "windows" {
+ name = "claude.exe"
+ }
+ fakeBin := filepath.Join(tmpDir, name)
+ os.WriteFile(fakeBin, []byte("#!/bin/sh\n"), 0o755)
+ t.Setenv("PATH", tmpDir)
+
+ got, err := c.findPath()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != fakeBin {
+ t.Errorf("findPath() = %q, want %q", got, fakeBin)
+ }
+ })
+
+ t.Run("falls back to ~/.claude/local/claude", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
+
+ name := "claude"
+ if runtime.GOOS == "windows" {
+ name = "claude.exe"
+ }
+ fallback := filepath.Join(tmpDir, ".claude", "local", name)
+ os.MkdirAll(filepath.Dir(fallback), 0o755)
+ os.WriteFile(fallback, []byte("#!/bin/sh\n"), 0o755)
+
+ got, err := c.findPath()
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if got != fallback {
+ t.Errorf("findPath() = %q, want %q", got, fallback)
+ }
+ })
+
+ t.Run("returns error when neither PATH nor fallback exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ t.Setenv("PATH", t.TempDir()) // empty dir, no claude binary
+
+ _, err := c.findPath()
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ })
+}
+
+func TestClaudeArgs(t *testing.T) {
+ c := &Claude{}
+
+ tests := []struct {
+ name string
+ model string
+ args []string
+ want []string
+ }{
+ {"with model", "llama3.2", nil, []string{"--model", "llama3.2"}},
+ {"empty model", "", nil, nil},
+ {"with model and verbose", "llama3.2", []string{"--verbose"}, []string{"--model", "llama3.2", "--verbose"}},
+ {"empty model with help", "", []string{"--help"}, []string{"--help"}},
+ {"with allowed tools", "llama3.2", []string{"--allowedTools", "Read,Write,Bash"}, []string{"--model", "llama3.2", "--allowedTools", "Read,Write,Bash"}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := c.args(tt.model, tt.args)
+ if !slices.Equal(got, tt.want) {
+ t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestClaudeModelEnvVars(t *testing.T) {
+ c := &Claude{}
+
+ envMap := func(envs []string) map[string]string {
+ m := make(map[string]string)
+ for _, e := range envs {
+ k, v, _ := strings.Cut(e, "=")
+ m[k] = v
+ }
+ return m
+ }
+
+ t.Run("falls back to model param when no aliases saved", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ got := envMap(c.modelEnvVars("llama3.2"))
+ if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2" {
+ t.Errorf("OPUS = %q, want llama3.2", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
+ }
+ if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2" {
+ t.Errorf("SONNET = %q, want llama3.2", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
+ }
+ if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2" {
+ t.Errorf("HAIKU = %q, want llama3.2", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
+ }
+ if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2" {
+ t.Errorf("SUBAGENT = %q, want llama3.2", got["CLAUDE_CODE_SUBAGENT_MODEL"])
+ }
+ })
+
+ t.Run("uses primary alias for opus sonnet and subagent", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ SaveIntegration("claude", []string{"qwen3:8b"})
+ saveAliases("claude", map[string]string{"primary": "qwen3:8b"})
+
+ got := envMap(c.modelEnvVars("qwen3:8b"))
+ if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "qwen3:8b" {
+ t.Errorf("OPUS = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
+ }
+ if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "qwen3:8b" {
+ t.Errorf("SONNET = %q, want qwen3:8b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
+ }
+ if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "qwen3:8b" {
+ t.Errorf("HAIKU = %q, want qwen3:8b (no fast alias)", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
+ }
+ if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "qwen3:8b" {
+ t.Errorf("SUBAGENT = %q, want qwen3:8b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
+ }
+ })
+
+ t.Run("uses fast alias for haiku", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ SaveIntegration("claude", []string{"llama3.2:70b"})
+ saveAliases("claude", map[string]string{
+ "primary": "llama3.2:70b",
+ "fast": "llama3.2:8b",
+ })
+
+ got := envMap(c.modelEnvVars("llama3.2:70b"))
+ if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "llama3.2:70b" {
+ t.Errorf("OPUS = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
+ }
+ if got["ANTHROPIC_DEFAULT_SONNET_MODEL"] != "llama3.2:70b" {
+ t.Errorf("SONNET = %q, want llama3.2:70b", got["ANTHROPIC_DEFAULT_SONNET_MODEL"])
+ }
+ if got["ANTHROPIC_DEFAULT_HAIKU_MODEL"] != "llama3.2:8b" {
+ t.Errorf("HAIKU = %q, want llama3.2:8b", got["ANTHROPIC_DEFAULT_HAIKU_MODEL"])
+ }
+ if got["CLAUDE_CODE_SUBAGENT_MODEL"] != "llama3.2:70b" {
+ t.Errorf("SUBAGENT = %q, want llama3.2:70b", got["CLAUDE_CODE_SUBAGENT_MODEL"])
+ }
+ })
+
+ t.Run("alias primary overrides model param", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ SaveIntegration("claude", []string{"saved-model"})
+ saveAliases("claude", map[string]string{"primary": "saved-model"})
+
+ got := envMap(c.modelEnvVars("different-model"))
+ if got["ANTHROPIC_DEFAULT_OPUS_MODEL"] != "saved-model" {
+ t.Errorf("OPUS = %q, want saved-model", got["ANTHROPIC_DEFAULT_OPUS_MODEL"])
+ }
+ })
+}
diff --git a/cmd/config/cline.go b/cmd/config/cline.go
new file mode 100644
index 00000000000..847d8d43103
--- /dev/null
+++ b/cmd/config/cline.go
@@ -0,0 +1,123 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+
+ "github.com/ollama/ollama/envconfig"
+)
+
+// Cline implements Runner and Editor for the Cline CLI integration
+type Cline struct{}
+
+func (c *Cline) String() string { return "Cline" }
+
+func (c *Cline) Run(model string, args []string) error {
+ if _, err := exec.LookPath("cline"); err != nil {
+ return fmt.Errorf("cline is not installed, install with: npm install -g cline")
+ }
+
+ models := []string{model}
+ if config, err := loadIntegration("cline"); err == nil && len(config.Models) > 0 {
+ models = config.Models
+ }
+ var err error
+ models, err = resolveEditorModels("cline", models, func() ([]string, error) {
+ return selectModels(context.Background(), "cline", "")
+ })
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ if err := c.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+
+ cmd := exec.Command("cline", args...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ return cmd.Run()
+}
+
+func (c *Cline) Paths() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+ p := filepath.Join(home, ".cline", "data", "globalState.json")
+ if _, err := os.Stat(p); err == nil {
+ return []string{p}
+ }
+ return nil
+}
+
+func (c *Cline) Edit(models []string) error {
+ if len(models) == 0 {
+ return nil
+ }
+
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ configPath := filepath.Join(home, ".cline", "data", "globalState.json")
+ if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+ return err
+ }
+
+ config := make(map[string]any)
+ if data, err := os.ReadFile(configPath); err == nil {
+ if err := json.Unmarshal(data, &config); err != nil {
+ return fmt.Errorf("failed to parse config: %w, at: %s", err, configPath)
+ }
+ }
+
+ // Set Ollama as the provider for both act and plan modes
+ baseURL := envconfig.Host().String()
+ config["ollamaBaseUrl"] = baseURL
+ config["actModeApiProvider"] = "ollama"
+ config["actModeOllamaModelId"] = models[0]
+ config["actModeOllamaBaseUrl"] = baseURL
+ config["planModeApiProvider"] = "ollama"
+ config["planModeOllamaModelId"] = models[0]
+ config["planModeOllamaBaseUrl"] = baseURL
+
+ config["welcomeViewCompleted"] = true
+
+ data, err := json.MarshalIndent(config, "", " ")
+ if err != nil {
+ return err
+ }
+ return writeWithBackup(configPath, data)
+}
+
+func (c *Cline) Models() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+
+ config, err := readJSONFile(filepath.Join(home, ".cline", "data", "globalState.json"))
+ if err != nil {
+ return nil
+ }
+
+ if config["actModeApiProvider"] != "ollama" {
+ return nil
+ }
+
+ modelID, _ := config["actModeOllamaModelId"].(string)
+ if modelID == "" {
+ return nil
+ }
+ return []string{modelID}
+}
diff --git a/cmd/config/cline_test.go b/cmd/config/cline_test.go
new file mode 100644
index 00000000000..7e9f7f07c89
--- /dev/null
+++ b/cmd/config/cline_test.go
@@ -0,0 +1,204 @@
+package config
+
+import (
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestClineIntegration(t *testing.T) {
+ c := &Cline{}
+
+ t.Run("String", func(t *testing.T) {
+ if got := c.String(); got != "Cline" {
+ t.Errorf("String() = %q, want %q", got, "Cline")
+ }
+ })
+
+ t.Run("implements Runner", func(t *testing.T) {
+ var _ Runner = c
+ })
+
+ t.Run("implements Editor", func(t *testing.T) {
+ var _ Editor = c
+ })
+}
+
+func TestClineEdit(t *testing.T) {
+ c := &Cline{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".cline", "data")
+ configPath := filepath.Join(configDir, "globalState.json")
+
+ readConfig := func() map[string]any {
+ data, _ := os.ReadFile(configPath)
+ var config map[string]any
+ json.Unmarshal(data, &config)
+ return config
+ }
+
+ t.Run("creates config from scratch", func(t *testing.T) {
+ os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+ if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
+ t.Fatal(err)
+ }
+
+ config := readConfig()
+ if config["actModeApiProvider"] != "ollama" {
+ t.Errorf("actModeApiProvider = %v, want ollama", config["actModeApiProvider"])
+ }
+ if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
+ t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud", config["actModeOllamaModelId"])
+ }
+ if config["planModeApiProvider"] != "ollama" {
+ t.Errorf("planModeApiProvider = %v, want ollama", config["planModeApiProvider"])
+ }
+ if config["planModeOllamaModelId"] != "kimi-k2.5:cloud" {
+ t.Errorf("planModeOllamaModelId = %v, want kimi-k2.5:cloud", config["planModeOllamaModelId"])
+ }
+ if config["welcomeViewCompleted"] != true {
+ t.Errorf("welcomeViewCompleted = %v, want true", config["welcomeViewCompleted"])
+ }
+ })
+
+ t.Run("preserves existing fields", func(t *testing.T) {
+ os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+ os.MkdirAll(configDir, 0o755)
+
+ existing := map[string]any{
+ "remoteRulesToggles": map[string]any{},
+ "remoteWorkflowToggles": map[string]any{},
+ "customSetting": "keep-me",
+ }
+ data, _ := json.Marshal(existing)
+ os.WriteFile(configPath, data, 0o644)
+
+ if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
+ t.Fatal(err)
+ }
+
+ config := readConfig()
+ if config["customSetting"] != "keep-me" {
+ t.Errorf("customSetting was not preserved")
+ }
+ if config["actModeOllamaModelId"] != "glm-5:cloud" {
+ t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
+ }
+ })
+
+ t.Run("updates model on re-edit", func(t *testing.T) {
+ os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+ if err := c.Edit([]string{"kimi-k2.5:cloud"}); err != nil {
+ t.Fatal(err)
+ }
+ if err := c.Edit([]string{"glm-5:cloud"}); err != nil {
+ t.Fatal(err)
+ }
+
+ config := readConfig()
+ if config["actModeOllamaModelId"] != "glm-5:cloud" {
+ t.Errorf("actModeOllamaModelId = %v, want glm-5:cloud", config["actModeOllamaModelId"])
+ }
+ if config["planModeOllamaModelId"] != "glm-5:cloud" {
+ t.Errorf("planModeOllamaModelId = %v, want glm-5:cloud", config["planModeOllamaModelId"])
+ }
+ })
+
+ t.Run("empty models is no-op", func(t *testing.T) {
+ os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+ if err := c.Edit(nil); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(configPath); !os.IsNotExist(err) {
+ t.Error("expected no config file to be created for empty models")
+ }
+ })
+
+ t.Run("uses first model as primary", func(t *testing.T) {
+ os.RemoveAll(filepath.Join(tmpDir, ".cline"))
+
+ if err := c.Edit([]string{"kimi-k2.5:cloud", "glm-5:cloud"}); err != nil {
+ t.Fatal(err)
+ }
+
+ config := readConfig()
+ if config["actModeOllamaModelId"] != "kimi-k2.5:cloud" {
+ t.Errorf("actModeOllamaModelId = %v, want kimi-k2.5:cloud (first model)", config["actModeOllamaModelId"])
+ }
+ })
+}
+
+func TestClineModels(t *testing.T) {
+ c := &Cline{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".cline", "data")
+ configPath := filepath.Join(configDir, "globalState.json")
+
+ t.Run("returns nil when no config", func(t *testing.T) {
+ if models := c.Models(); models != nil {
+ t.Errorf("Models() = %v, want nil", models)
+ }
+ })
+
+ t.Run("returns nil when provider is not ollama", func(t *testing.T) {
+ os.MkdirAll(configDir, 0o755)
+ config := map[string]any{
+ "actModeApiProvider": "anthropic",
+ "actModeOllamaModelId": "some-model",
+ }
+ data, _ := json.Marshal(config)
+ os.WriteFile(configPath, data, 0o644)
+
+ if models := c.Models(); models != nil {
+ t.Errorf("Models() = %v, want nil", models)
+ }
+ })
+
+ t.Run("returns model when ollama is configured", func(t *testing.T) {
+ os.MkdirAll(configDir, 0o755)
+ config := map[string]any{
+ "actModeApiProvider": "ollama",
+ "actModeOllamaModelId": "kimi-k2.5:cloud",
+ }
+ data, _ := json.Marshal(config)
+ os.WriteFile(configPath, data, 0o644)
+
+ models := c.Models()
+ if len(models) != 1 || models[0] != "kimi-k2.5:cloud" {
+ t.Errorf("Models() = %v, want [kimi-k2.5:cloud]", models)
+ }
+ })
+}
+
+func TestClinePaths(t *testing.T) {
+ c := &Cline{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("returns nil when no config exists", func(t *testing.T) {
+ if paths := c.Paths(); paths != nil {
+ t.Errorf("Paths() = %v, want nil", paths)
+ }
+ })
+
+ t.Run("returns path when config exists", func(t *testing.T) {
+ configDir := filepath.Join(tmpDir, ".cline", "data")
+ os.MkdirAll(configDir, 0o755)
+ configPath := filepath.Join(configDir, "globalState.json")
+ os.WriteFile(configPath, []byte("{}"), 0o644)
+
+ paths := c.Paths()
+ if len(paths) != 1 || paths[0] != configPath {
+ t.Errorf("Paths() = %v, want [%s]", paths, configPath)
+ }
+ })
+}
diff --git a/cmd/config/codex.go b/cmd/config/codex.go
new file mode 100644
index 00000000000..ee2c70542cf
--- /dev/null
+++ b/cmd/config/codex.go
@@ -0,0 +1,67 @@
+package config
+
+import (
+ "fmt"
+ "os"
+ "os/exec"
+ "strings"
+
+ "github.com/ollama/ollama/envconfig"
+ "golang.org/x/mod/semver"
+)
+
+// Codex implements Runner for Codex integration
+type Codex struct{}
+
+func (c *Codex) String() string { return "Codex" }
+
+func (c *Codex) args(model string, extra []string) []string {
+ args := []string{"--oss"}
+ if model != "" {
+ args = append(args, "-m", model)
+ }
+ args = append(args, extra...)
+ return args
+}
+
+func (c *Codex) Run(model string, args []string) error {
+ if err := checkCodexVersion(); err != nil {
+ return err
+ }
+
+ cmd := exec.Command("codex", c.args(model, args)...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ cmd.Env = append(os.Environ(),
+ "OPENAI_BASE_URL="+envconfig.Host().String()+"/v1/",
+ "OPENAI_API_KEY=ollama",
+ )
+ return cmd.Run()
+}
+
+func checkCodexVersion() error {
+ if _, err := exec.LookPath("codex"); err != nil {
+ return fmt.Errorf("codex is not installed, install with: npm install -g @openai/codex")
+ }
+
+ out, err := exec.Command("codex", "--version").Output()
+ if err != nil {
+ return fmt.Errorf("failed to get codex version: %w", err)
+ }
+
+ // Parse output like "codex-cli 0.87.0"
+ fields := strings.Fields(strings.TrimSpace(string(out)))
+ if len(fields) < 2 {
+ return fmt.Errorf("unexpected codex version output: %s", string(out))
+ }
+
+ version := "v" + fields[len(fields)-1]
+ minVersion := "v0.81.0"
+
+ if semver.Compare(version, minVersion) < 0 {
+ return fmt.Errorf("codex version %s is too old, minimum required is %s, update with: npm update -g @openai/codex", fields[len(fields)-1], "0.81.0")
+ }
+
+ return nil
+}
diff --git a/cmd/config/codex_test.go b/cmd/config/codex_test.go
new file mode 100644
index 00000000000..9c18910be5a
--- /dev/null
+++ b/cmd/config/codex_test.go
@@ -0,0 +1,31 @@
+package config
+
+import (
+ "slices"
+ "testing"
+)
+
+func TestCodexArgs(t *testing.T) {
+ c := &Codex{}
+
+ tests := []struct {
+ name string
+ model string
+ args []string
+ want []string
+ }{
+ {"with model", "llama3.2", nil, []string{"--oss", "-m", "llama3.2"}},
+ {"empty model", "", nil, []string{"--oss"}},
+ {"with model and profile", "qwen3-coder", []string{"-p", "myprofile"}, []string{"--oss", "-m", "qwen3-coder", "-p", "myprofile"}},
+ {"with sandbox flag", "llama3.2", []string{"--sandbox", "workspace-write"}, []string{"--oss", "-m", "llama3.2", "--sandbox", "workspace-write"}},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := c.args(tt.model, tt.args)
+ if !slices.Equal(got, tt.want) {
+ t.Errorf("args(%q, %v) = %v, want %v", tt.model, tt.args, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/cmd/config/config.go b/cmd/config/config.go
new file mode 100644
index 00000000000..ce9374ce501
--- /dev/null
+++ b/cmd/config/config.go
@@ -0,0 +1,280 @@
+// Package config provides integration configuration for external coding tools
+// (Claude Code, Codex, Droid, OpenCode) to use Ollama models.
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+type integration struct {
+ Models []string `json:"models"`
+ Aliases map[string]string `json:"aliases,omitempty"`
+}
+
+type config struct {
+ Integrations map[string]*integration `json:"integrations"`
+ LastModel string `json:"last_model,omitempty"`
+ LastSelection string `json:"last_selection,omitempty"` // "run" or integration name
+}
+
+func configPath() (string, error) {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(home, ".ollama", "config.json"), nil
+}
+
+func legacyConfigPath() (string, error) {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return "", err
+ }
+ return filepath.Join(home, ".ollama", "config", "config.json"), nil
+}
+
+// migrateConfig moves the config from the legacy path to ~/.ollama/config.json
+func migrateConfig() (bool, error) {
+ oldPath, err := legacyConfigPath()
+ if err != nil {
+ return false, err
+ }
+
+ oldData, err := os.ReadFile(oldPath)
+ if err != nil {
+ if os.IsNotExist(err) {
+ return false, nil
+ }
+ return false, err
+ }
+
+ // Ignore legacy files with invalid JSON and continue startup.
+ if !json.Valid(oldData) {
+ return false, nil
+ }
+
+ newPath, err := configPath()
+ if err != nil {
+ return false, err
+ }
+
+ if err := os.MkdirAll(filepath.Dir(newPath), 0o755); err != nil {
+ return false, err
+ }
+ if err := os.WriteFile(newPath, oldData, 0o644); err != nil {
+ return false, fmt.Errorf("write new config: %w", err)
+ }
+
+ _ = os.Remove(oldPath)
+ _ = os.Remove(filepath.Dir(oldPath)) // clean up empty directory
+
+ return true, nil
+}
+
+func load() (*config, error) {
+ path, err := configPath()
+ if err != nil {
+ return nil, err
+ }
+
+ data, err := os.ReadFile(path)
+ if err != nil && os.IsNotExist(err) {
+ if migrated, merr := migrateConfig(); merr == nil && migrated {
+ data, err = os.ReadFile(path)
+ }
+ }
+ if err != nil {
+ if os.IsNotExist(err) {
+ return &config{Integrations: make(map[string]*integration)}, nil
+ }
+ return nil, err
+ }
+
+ var cfg config
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ return nil, fmt.Errorf("failed to parse config: %w, at: %s", err, path)
+ }
+ if cfg.Integrations == nil {
+ cfg.Integrations = make(map[string]*integration)
+ }
+ return &cfg, nil
+}
+
+func save(cfg *config) error {
+ path, err := configPath()
+ if err != nil {
+ return err
+ }
+
+ if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
+ return err
+ }
+
+ data, err := json.MarshalIndent(cfg, "", " ")
+ if err != nil {
+ return err
+ }
+
+ return writeWithBackup(path, data)
+}
+
+func SaveIntegration(appName string, models []string) error {
+ if appName == "" {
+ return errors.New("app name cannot be empty")
+ }
+
+ cfg, err := load()
+ if err != nil {
+ return err
+ }
+
+ key := strings.ToLower(appName)
+ existing := cfg.Integrations[key]
+ var aliases map[string]string
+ if existing != nil && existing.Aliases != nil {
+ aliases = existing.Aliases
+ }
+
+ cfg.Integrations[key] = &integration{
+ Models: models,
+ Aliases: aliases,
+ }
+
+ return save(cfg)
+}
+
+// IntegrationModel returns the first configured model for an integration, or empty string if not configured.
+func IntegrationModel(appName string) string {
+ ic, err := loadIntegration(appName)
+ if err != nil || len(ic.Models) == 0 {
+ return ""
+ }
+ return ic.Models[0]
+}
+
+// IntegrationModels returns all configured models for an integration, or nil.
+func IntegrationModels(appName string) []string {
+ ic, err := loadIntegration(appName)
+ if err != nil || len(ic.Models) == 0 {
+ return nil
+ }
+ return ic.Models
+}
+
+// LastModel returns the last model that was run, or empty string if none.
+func LastModel() string {
+ cfg, err := load()
+ if err != nil {
+ return ""
+ }
+ return cfg.LastModel
+}
+
+// SetLastModel saves the last model that was run.
+func SetLastModel(model string) error {
+ cfg, err := load()
+ if err != nil {
+ return err
+ }
+ cfg.LastModel = model
+ return save(cfg)
+}
+
+// LastSelection returns the last menu selection ("run" or integration name), or empty string if none.
+func LastSelection() string {
+ cfg, err := load()
+ if err != nil {
+ return ""
+ }
+ return cfg.LastSelection
+}
+
+// SetLastSelection saves the last menu selection ("run" or integration name).
+func SetLastSelection(selection string) error {
+ cfg, err := load()
+ if err != nil {
+ return err
+ }
+ cfg.LastSelection = selection
+ return save(cfg)
+}
+
+// ModelExists checks if a model exists on the Ollama server.
+func ModelExists(ctx context.Context, name string) bool {
+ if name == "" {
+ return false
+ }
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return false
+ }
+ models, err := client.List(ctx)
+ if err != nil {
+ return false
+ }
+ for _, m := range models.Models {
+ if m.Name == name || strings.HasPrefix(m.Name, name+":") {
+ return true
+ }
+ }
+ return false
+}
+
+func loadIntegration(appName string) (*integration, error) {
+ cfg, err := load()
+ if err != nil {
+ return nil, err
+ }
+
+ ic, ok := cfg.Integrations[strings.ToLower(appName)]
+ if !ok {
+ return nil, os.ErrNotExist
+ }
+
+ return ic, nil
+}
+
+func saveAliases(appName string, aliases map[string]string) error {
+ if appName == "" {
+ return errors.New("app name cannot be empty")
+ }
+
+ cfg, err := load()
+ if err != nil {
+ return err
+ }
+
+ key := strings.ToLower(appName)
+ existing := cfg.Integrations[key]
+ if existing == nil {
+ existing = &integration{}
+ }
+
+ // Replace aliases entirely (not merge) so deletions are persisted
+ existing.Aliases = aliases
+
+ cfg.Integrations[key] = existing
+ return save(cfg)
+}
+
+func listIntegrations() ([]integration, error) {
+ cfg, err := load()
+ if err != nil {
+ return nil, err
+ }
+
+ result := make([]integration, 0, len(cfg.Integrations))
+ for _, ic := range cfg.Integrations {
+ result = append(result, *ic)
+ }
+
+ return result, nil
+}
diff --git a/cmd/config/config_cloud_test.go b/cmd/config/config_cloud_test.go
new file mode 100644
index 00000000000..23e7313d955
--- /dev/null
+++ b/cmd/config/config_cloud_test.go
@@ -0,0 +1,677 @@
+package config
+
+import (
+ "context"
+ "errors"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestSetAliases_CloudModel(t *testing.T) {
+ // Test the SetAliases logic by checking the alias map behavior
+ aliases := map[string]string{
+ "primary": "kimi-k2.5:cloud",
+ "fast": "kimi-k2.5:cloud",
+ }
+
+ // Verify fast is set (cloud model behavior)
+ if aliases["fast"] == "" {
+ t.Error("cloud model should have fast alias set")
+ }
+ if aliases["fast"] != aliases["primary"] {
+ t.Errorf("fast should equal primary for auto-set, got fast=%q primary=%q", aliases["fast"], aliases["primary"])
+ }
+}
+
+func TestSetAliases_LocalModel(t *testing.T) {
+ aliases := map[string]string{
+ "primary": "llama3.2:latest",
+ }
+ // Simulate local model behavior: fast should be empty
+ delete(aliases, "fast")
+
+ if aliases["fast"] != "" {
+ t.Error("local model should have empty fast alias")
+ }
+}
+
+func TestSaveAliases_ReplacesNotMerges(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // First save with both primary and fast
+ initial := map[string]string{
+ "primary": "cloud-model",
+ "fast": "cloud-model",
+ }
+ if err := saveAliases("claude", initial); err != nil {
+ t.Fatalf("failed to save initial aliases: %v", err)
+ }
+
+ // Verify both are saved
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if loaded.Aliases["fast"] != "cloud-model" {
+ t.Errorf("expected fast=cloud-model, got %q", loaded.Aliases["fast"])
+ }
+
+ // Now save without fast (simulating switch to local model)
+ updated := map[string]string{
+ "primary": "local-model",
+ // fast intentionally missing
+ }
+ if err := saveAliases("claude", updated); err != nil {
+ t.Fatalf("failed to save updated aliases: %v", err)
+ }
+
+ // Verify fast is GONE (not merged/preserved)
+ loaded, err = loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load after update: %v", err)
+ }
+ if loaded.Aliases["fast"] != "" {
+ t.Errorf("fast should be removed after saving without it, got %q", loaded.Aliases["fast"])
+ }
+ if loaded.Aliases["primary"] != "local-model" {
+ t.Errorf("primary should be updated to local-model, got %q", loaded.Aliases["primary"])
+ }
+}
+
+func TestSaveAliases_PreservesModels(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // First save integration with models
+ if err := SaveIntegration("claude", []string{"model1", "model2"}); err != nil {
+ t.Fatalf("failed to save integration: %v", err)
+ }
+
+ // Then update aliases
+ aliases := map[string]string{"primary": "new-model"}
+ if err := saveAliases("claude", aliases); err != nil {
+ t.Fatalf("failed to save aliases: %v", err)
+ }
+
+ // Verify models are preserved
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if len(loaded.Models) != 2 || loaded.Models[0] != "model1" {
+ t.Errorf("models should be preserved, got %v", loaded.Models)
+ }
+}
+
+// TestSaveAliases_EmptyMap clears all aliases
+func TestSaveAliases_EmptyMap(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Save with aliases
+ if err := saveAliases("claude", map[string]string{"primary": "model", "fast": "model"}); err != nil {
+ t.Fatalf("failed to save: %v", err)
+ }
+
+ // Save empty map
+ if err := saveAliases("claude", map[string]string{}); err != nil {
+ t.Fatalf("failed to save empty: %v", err)
+ }
+
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if len(loaded.Aliases) != 0 {
+ t.Errorf("aliases should be empty, got %v", loaded.Aliases)
+ }
+}
+
+// TestSaveAliases_NilMap handles nil gracefully
+func TestSaveAliases_NilMap(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Save with aliases first
+ if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
+ t.Fatalf("failed to save: %v", err)
+ }
+
+ // Save nil map - should clear aliases
+ if err := saveAliases("claude", nil); err != nil {
+ t.Fatalf("failed to save nil: %v", err)
+ }
+
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if len(loaded.Aliases) > 0 {
+ t.Errorf("aliases should be nil or empty, got %v", loaded.Aliases)
+ }
+}
+
+// TestSaveAliases_EmptyAppName returns error
+func TestSaveAliases_EmptyAppName(t *testing.T) {
+ err := saveAliases("", map[string]string{"primary": "model"})
+ if err == nil {
+ t.Error("expected error for empty app name")
+ }
+}
+
+func TestSaveAliases_CaseInsensitive(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ if err := saveAliases("Claude", map[string]string{"primary": "model1"}); err != nil {
+ t.Fatalf("failed to save: %v", err)
+ }
+
+ // Load with different case
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if loaded.Aliases["primary"] != "model1" {
+ t.Errorf("expected primary=model1, got %q", loaded.Aliases["primary"])
+ }
+
+ // Update with different case
+ if err := saveAliases("CLAUDE", map[string]string{"primary": "model2"}); err != nil {
+ t.Fatalf("failed to update: %v", err)
+ }
+
+ loaded, err = loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load after update: %v", err)
+ }
+ if loaded.Aliases["primary"] != "model2" {
+ t.Errorf("expected primary=model2, got %q", loaded.Aliases["primary"])
+ }
+}
+
+// TestSaveAliases_CreatesIntegration creates integration if it doesn't exist
+func TestSaveAliases_CreatesIntegration(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Save aliases for non-existent integration
+ if err := saveAliases("newintegration", map[string]string{"primary": "model"}); err != nil {
+ t.Fatalf("failed to save: %v", err)
+ }
+
+ loaded, err := loadIntegration("newintegration")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if loaded.Aliases["primary"] != "model" {
+ t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
+ }
+}
+
+func TestConfigureAliases_AliasMap(t *testing.T) {
+ t.Run("cloud model auto-sets fast to primary", func(t *testing.T) {
+ aliases := make(map[string]string)
+ aliases["primary"] = "cloud-model"
+
+ // Simulate cloud model behavior
+ isCloud := true
+ if isCloud {
+ if aliases["fast"] == "" {
+ aliases["fast"] = aliases["primary"]
+ }
+ }
+
+ if aliases["fast"] != "cloud-model" {
+ t.Errorf("expected fast=cloud-model, got %q", aliases["fast"])
+ }
+ })
+
+ t.Run("cloud model preserves custom fast", func(t *testing.T) {
+ aliases := map[string]string{
+ "primary": "cloud-model",
+ "fast": "custom-fast-model",
+ }
+
+ // Simulate cloud model behavior - should preserve existing fast
+ isCloud := true
+ if isCloud {
+ if aliases["fast"] == "" {
+ aliases["fast"] = aliases["primary"]
+ }
+ }
+
+ if aliases["fast"] != "custom-fast-model" {
+ t.Errorf("expected fast=custom-fast-model (preserved), got %q", aliases["fast"])
+ }
+ })
+
+ t.Run("local model clears fast", func(t *testing.T) {
+ aliases := map[string]string{
+ "primary": "local-model",
+ "fast": "should-be-cleared",
+ }
+
+ // Simulate local model behavior
+ isCloud := false
+ if !isCloud {
+ delete(aliases, "fast")
+ }
+
+ if aliases["fast"] != "" {
+ t.Errorf("expected fast to be cleared, got %q", aliases["fast"])
+ }
+ })
+
+ t.Run("switching cloud to local clears fast", func(t *testing.T) {
+ // Start with cloud config
+ aliases := map[string]string{
+ "primary": "cloud-model",
+ "fast": "cloud-model",
+ }
+
+ // Switch to local
+ aliases["primary"] = "local-model"
+ isCloud := false
+ if !isCloud {
+ delete(aliases, "fast")
+ }
+
+ if aliases["fast"] != "" {
+ t.Errorf("fast should be cleared when switching to local, got %q", aliases["fast"])
+ }
+ if aliases["primary"] != "local-model" {
+ t.Errorf("primary should be updated, got %q", aliases["primary"])
+ }
+ })
+
+ t.Run("switching local to cloud sets fast", func(t *testing.T) {
+ // Start with local config (no fast)
+ aliases := map[string]string{
+ "primary": "local-model",
+ }
+
+ // Switch to cloud
+ aliases["primary"] = "cloud-model"
+ isCloud := true
+ if isCloud {
+ if aliases["fast"] == "" {
+ aliases["fast"] = aliases["primary"]
+ }
+ }
+
+ if aliases["fast"] != "cloud-model" {
+ t.Errorf("fast should be set when switching to cloud, got %q", aliases["fast"])
+ }
+ })
+}
+
+func TestSetAliases_PrefixMapping(t *testing.T) {
+ // This tests the expected mapping without needing a real client
+ aliases := map[string]string{
+ "primary": "my-cloud-model",
+ "fast": "my-fast-model",
+ }
+
+ expectedMappings := map[string]string{
+ "claude-sonnet-": aliases["primary"],
+ "claude-haiku-": aliases["fast"],
+ }
+
+ if expectedMappings["claude-sonnet-"] != "my-cloud-model" {
+ t.Errorf("claude-sonnet- should map to primary")
+ }
+ if expectedMappings["claude-haiku-"] != "my-fast-model" {
+ t.Errorf("claude-haiku- should map to fast")
+ }
+}
+
+func TestSetAliases_LocalDeletesPrefixes(t *testing.T) {
+ aliases := map[string]string{
+ "primary": "local-model",
+ // fast is empty/missing - indicates local model
+ }
+
+ prefixesToDelete := []string{"claude-sonnet-", "claude-haiku-"}
+
+ // Verify the logic: when fast is empty, we should delete
+ if aliases["fast"] != "" {
+ t.Error("fast should be empty for local model")
+ }
+
+ // Verify we have the right prefixes to delete
+ if len(prefixesToDelete) != 2 {
+ t.Errorf("expected 2 prefixes to delete, got %d", len(prefixesToDelete))
+ }
+}
+
+// TestAtomicUpdate_ServerFailsConfigNotSaved simulates atomic update behavior
+func TestAtomicUpdate_ServerFailsConfigNotSaved(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Simulate: server fails, config should NOT be saved
+ serverErr := errors.New("server unavailable")
+
+ if serverErr == nil {
+ t.Error("config should NOT be saved when server fails")
+ }
+}
+
+// TestAtomicUpdate_ServerSucceedsConfigSaved simulates successful atomic update
+func TestAtomicUpdate_ServerSucceedsConfigSaved(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Simulate: server succeeds, config should be saved
+ var serverErr error
+ if serverErr != nil {
+ t.Fatal("server should succeed")
+ }
+
+ if err := saveAliases("claude", map[string]string{"primary": "model"}); err != nil {
+ t.Fatalf("saveAliases failed: %v", err)
+ }
+
+ // Verify it was actually saved
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if loaded.Aliases["primary"] != "model" {
+ t.Errorf("expected primary=model, got %q", loaded.Aliases["primary"])
+ }
+}
+
+func TestConfigFile_PreservesUnknownFields(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Write config with extra fields
+ configPath := filepath.Join(tmpDir, ".ollama", "config.json")
+ os.MkdirAll(filepath.Dir(configPath), 0o755)
+
+ // Note: Our config struct only has Integrations, so top-level unknown fields
+ // won't be preserved by our current implementation. This test documents that.
+ initialConfig := `{
+ "integrations": {
+ "claude": {
+ "models": ["model1"],
+ "aliases": {"primary": "model1"},
+ "unknownField": "should be lost"
+ }
+ },
+ "topLevelUnknown": "will be lost"
+}`
+ os.WriteFile(configPath, []byte(initialConfig), 0o644)
+
+ // Update aliases
+ if err := saveAliases("claude", map[string]string{"primary": "model2"}); err != nil {
+ t.Fatalf("failed to save: %v", err)
+ }
+
+ // Read raw file to check
+ data, _ := os.ReadFile(configPath)
+ content := string(data)
+
+ // models should be preserved
+ if !contains(content, "model1") {
+ t.Error("models should be preserved")
+ }
+
+ // primary should be updated
+ if !contains(content, "model2") {
+ t.Error("primary should be updated to model2")
+ }
+}
+
+func contains(s, substr string) bool {
+ return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
+}
+
+func containsHelper(s, substr string) bool {
+ for i := 0; i <= len(s)-len(substr); i++ {
+ if s[i:i+len(substr)] == substr {
+ return true
+ }
+ }
+ return false
+}
+
+func TestClaudeImplementsAliasConfigurer(t *testing.T) {
+ c := &Claude{}
+ var _ AliasConfigurer = c // Compile-time check
+}
+
+func TestModelNameEdgeCases(t *testing.T) {
+ testCases := []struct {
+ name string
+ model string
+ }{
+ {"simple", "llama3.2"},
+ {"with tag", "llama3.2:latest"},
+ {"with cloud tag", "kimi-k2.5:cloud"},
+ {"with namespace", "library/llama3.2"},
+ {"with dots", "glm-4.7-flash"},
+ {"with numbers", "qwen3:8b"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ aliases := map[string]string{"primary": tc.model}
+ if err := saveAliases("claude", aliases); err != nil {
+ t.Fatalf("failed to save model %q: %v", tc.model, err)
+ }
+
+ loaded, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatalf("failed to load: %v", err)
+ }
+ if loaded.Aliases["primary"] != tc.model {
+ t.Errorf("expected primary=%q, got %q", tc.model, loaded.Aliases["primary"])
+ }
+ })
+ }
+}
+
+func TestSwitchingScenarios(t *testing.T) {
+ t.Run("cloud to local removes fast", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Initial cloud config
+ if err := saveAliases("claude", map[string]string{
+ "primary": "cloud-model",
+ "fast": "cloud-model",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Switch to local (no fast)
+ if err := saveAliases("claude", map[string]string{
+ "primary": "local-model",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ := loadIntegration("claude")
+ if loaded.Aliases["fast"] != "" {
+ t.Errorf("fast should be removed, got %q", loaded.Aliases["fast"])
+ }
+ if loaded.Aliases["primary"] != "local-model" {
+ t.Errorf("primary should be local-model, got %q", loaded.Aliases["primary"])
+ }
+ })
+
+ t.Run("local to cloud adds fast", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Initial local config
+ if err := saveAliases("claude", map[string]string{
+ "primary": "local-model",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Switch to cloud (with fast)
+ if err := saveAliases("claude", map[string]string{
+ "primary": "cloud-model",
+ "fast": "cloud-model",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ := loadIntegration("claude")
+ if loaded.Aliases["fast"] != "cloud-model" {
+ t.Errorf("fast should be cloud-model, got %q", loaded.Aliases["fast"])
+ }
+ })
+
+ t.Run("cloud to different cloud updates both", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Initial cloud config
+ if err := saveAliases("claude", map[string]string{
+ "primary": "cloud-model-1",
+ "fast": "cloud-model-1",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Switch to different cloud
+ if err := saveAliases("claude", map[string]string{
+ "primary": "cloud-model-2",
+ "fast": "cloud-model-2",
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ := loadIntegration("claude")
+ if loaded.Aliases["primary"] != "cloud-model-2" {
+ t.Errorf("primary should be cloud-model-2, got %q", loaded.Aliases["primary"])
+ }
+ if loaded.Aliases["fast"] != "cloud-model-2" {
+ t.Errorf("fast should be cloud-model-2, got %q", loaded.Aliases["fast"])
+ }
+ })
+}
+
+func TestToolCapabilityFiltering(t *testing.T) {
+ t.Run("all models checked for tool capability", func(t *testing.T) {
+ // Both cloud and local models are checked for tool capability via Show API
+ // Only models with "tools" in capabilities are included
+ m := modelInfo{Name: "tool-model", Remote: false, ToolCapable: true}
+ if !m.ToolCapable {
+ t.Error("tool capable model should be marked as such")
+ }
+ })
+
+ t.Run("modelInfo includes ToolCapable field", func(t *testing.T) {
+ m := modelInfo{Name: "test", Remote: true, ToolCapable: true}
+ if !m.ToolCapable {
+ t.Error("ToolCapable field should be accessible")
+ }
+ })
+}
+
+func TestIsCloudModel_RequiresClient(t *testing.T) {
+ t.Run("nil client always returns false", func(t *testing.T) {
+ // isCloudModel now only uses Show API, no suffix detection
+ if isCloudModel(context.Background(), nil, "model:cloud") {
+ t.Error("nil client should return false regardless of suffix")
+ }
+ if isCloudModel(context.Background(), nil, "local-model") {
+ t.Error("nil client should return false")
+ }
+ })
+}
+
+func TestModelsAndAliasesMustStayInSync(t *testing.T) {
+ t.Run("saveAliases followed by saveIntegration keeps them in sync", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Save aliases with one model
+ if err := saveAliases("claude", map[string]string{"primary": "model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Save integration with same model (this is the pattern we use)
+ if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ := loadIntegration("claude")
+ if loaded.Aliases["primary"] != loaded.Models[0] {
+ t.Errorf("aliases.primary (%q) != models[0] (%q)", loaded.Aliases["primary"], loaded.Models[0])
+ }
+ })
+
+ t.Run("out of sync config is detectable", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Simulate out-of-sync state (like manual edit or bug)
+ if err := SaveIntegration("claude", []string{"old-model"}); err != nil {
+ t.Fatal(err)
+ }
+ if err := saveAliases("claude", map[string]string{"primary": "new-model"}); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ := loadIntegration("claude")
+
+ // They should be different (this is the bug state)
+ if loaded.Models[0] == loaded.Aliases["primary"] {
+ t.Error("expected out-of-sync state for this test")
+ }
+
+ // The fix: when updating aliases, also update models
+ if err := SaveIntegration("claude", []string{loaded.Aliases["primary"]}); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ = loadIntegration("claude")
+ if loaded.Models[0] != loaded.Aliases["primary"] {
+ t.Errorf("after fix: models[0] (%q) should equal aliases.primary (%q)",
+ loaded.Models[0], loaded.Aliases["primary"])
+ }
+ })
+
+ t.Run("updating primary alias updates models too", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Initial state
+ if err := SaveIntegration("claude", []string{"initial-model"}); err != nil {
+ t.Fatal(err)
+ }
+ if err := saveAliases("claude", map[string]string{"primary": "initial-model"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Update aliases AND models together
+ newAliases := map[string]string{"primary": "updated-model"}
+ if err := saveAliases("claude", newAliases); err != nil {
+ t.Fatal(err)
+ }
+ if err := SaveIntegration("claude", []string{newAliases["primary"]}); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, _ := loadIntegration("claude")
+ if loaded.Models[0] != "updated-model" {
+ t.Errorf("models[0] should be updated-model, got %q", loaded.Models[0])
+ }
+ if loaded.Aliases["primary"] != "updated-model" {
+ t.Errorf("aliases.primary should be updated-model, got %q", loaded.Aliases["primary"])
+ }
+ })
+}
diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go
new file mode 100644
index 00000000000..fedde7af88d
--- /dev/null
+++ b/cmd/config/config_test.go
@@ -0,0 +1,595 @@
+package config
+
+import (
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+)
+
+// setTestHome sets both HOME (Unix) and USERPROFILE (Windows) for cross-platform tests
+func setTestHome(t *testing.T, dir string) {
+ t.Setenv("HOME", dir)
+ t.Setenv("USERPROFILE", dir)
+}
+
+// editorPaths is a test helper that safely calls Paths if the runner implements Editor
+func editorPaths(r Runner) []string {
+ if editor, ok := r.(Editor); ok {
+ return editor.Paths()
+ }
+ return nil
+}
+
+func TestIntegrationConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("save and load round-trip", func(t *testing.T) {
+ models := []string{"llama3.2", "mistral", "qwen2.5"}
+ if err := SaveIntegration("claude", models); err != nil {
+ t.Fatal(err)
+ }
+
+ config, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(config.Models) != len(models) {
+ t.Errorf("expected %d models, got %d", len(models), len(config.Models))
+ }
+ for i, m := range models {
+ if config.Models[i] != m {
+ t.Errorf("model %d: expected %s, got %s", i, m, config.Models[i])
+ }
+ }
+ })
+
+ t.Run("save and load aliases", func(t *testing.T) {
+ models := []string{"llama3.2"}
+ if err := SaveIntegration("claude", models); err != nil {
+ t.Fatal(err)
+ }
+ aliases := map[string]string{
+ "primary": "llama3.2:70b",
+ "fast": "llama3.2:8b",
+ }
+ if err := saveAliases("claude", aliases); err != nil {
+ t.Fatal(err)
+ }
+
+ config, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if config.Aliases == nil {
+ t.Fatal("expected aliases to be saved")
+ }
+ for k, v := range aliases {
+ if config.Aliases[k] != v {
+ t.Errorf("alias %s: expected %s, got %s", k, v, config.Aliases[k])
+ }
+ }
+ })
+
+ t.Run("saveIntegration preserves aliases", func(t *testing.T) {
+ if err := SaveIntegration("claude", []string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+ if err := saveAliases("claude", map[string]string{"primary": "model-a", "fast": "model-small"}); err != nil {
+ t.Fatal(err)
+ }
+
+ if err := SaveIntegration("claude", []string{"model-b"}); err != nil {
+ t.Fatal(err)
+ }
+ config, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if config.Aliases["primary"] != "model-a" {
+ t.Errorf("expected aliases to be preserved, got %v", config.Aliases)
+ }
+ })
+
+ t.Run("defaultModel returns first model", func(t *testing.T) {
+ SaveIntegration("codex", []string{"model-a", "model-b"})
+
+ config, _ := loadIntegration("codex")
+ defaultModel := ""
+ if len(config.Models) > 0 {
+ defaultModel = config.Models[0]
+ }
+ if defaultModel != "model-a" {
+ t.Errorf("expected model-a, got %s", defaultModel)
+ }
+ })
+
+ t.Run("defaultModel returns empty for no models", func(t *testing.T) {
+ config := &integration{Models: []string{}}
+ defaultModel := ""
+ if len(config.Models) > 0 {
+ defaultModel = config.Models[0]
+ }
+ if defaultModel != "" {
+ t.Errorf("expected empty string, got %s", defaultModel)
+ }
+ })
+
+ t.Run("app name is case-insensitive", func(t *testing.T) {
+ SaveIntegration("Claude", []string{"model-x"})
+
+ config, err := loadIntegration("claude")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defaultModel := ""
+ if len(config.Models) > 0 {
+ defaultModel = config.Models[0]
+ }
+ if defaultModel != "model-x" {
+ t.Errorf("expected model-x, got %s", defaultModel)
+ }
+ })
+
+ t.Run("multiple integrations in single file", func(t *testing.T) {
+ SaveIntegration("app1", []string{"model-1"})
+ SaveIntegration("app2", []string{"model-2"})
+
+ config1, _ := loadIntegration("app1")
+ config2, _ := loadIntegration("app2")
+
+ defaultModel1 := ""
+ if len(config1.Models) > 0 {
+ defaultModel1 = config1.Models[0]
+ }
+ defaultModel2 := ""
+ if len(config2.Models) > 0 {
+ defaultModel2 = config2.Models[0]
+ }
+ if defaultModel1 != "model-1" {
+ t.Errorf("expected model-1, got %s", defaultModel1)
+ }
+ if defaultModel2 != "model-2" {
+ t.Errorf("expected model-2, got %s", defaultModel2)
+ }
+ })
+}
+
+func TestListIntegrations(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("returns empty when no integrations", func(t *testing.T) {
+ configs, err := listIntegrations()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(configs) != 0 {
+ t.Errorf("expected 0 integrations, got %d", len(configs))
+ }
+ })
+
+ t.Run("returns all saved integrations", func(t *testing.T) {
+ SaveIntegration("claude", []string{"model-1"})
+ SaveIntegration("droid", []string{"model-2"})
+
+ configs, err := listIntegrations()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(configs) != 2 {
+ t.Errorf("expected 2 integrations, got %d", len(configs))
+ }
+ })
+}
+
+func TestEditorPaths(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("returns empty for claude (no Editor)", func(t *testing.T) {
+ r := integrations["claude"]
+ paths := editorPaths(r)
+ if len(paths) != 0 {
+ t.Errorf("expected no paths for claude, got %v", paths)
+ }
+ })
+
+ t.Run("returns empty for codex (no Editor)", func(t *testing.T) {
+ r := integrations["codex"]
+ paths := editorPaths(r)
+ if len(paths) != 0 {
+ t.Errorf("expected no paths for codex, got %v", paths)
+ }
+ })
+
+ t.Run("returns empty for droid when no config exists", func(t *testing.T) {
+ r := integrations["droid"]
+ paths := editorPaths(r)
+ if len(paths) != 0 {
+ t.Errorf("expected no paths, got %v", paths)
+ }
+ })
+
+ t.Run("returns path for droid when config exists", func(t *testing.T) {
+ settingsDir, _ := os.UserHomeDir()
+ settingsDir = filepath.Join(settingsDir, ".factory")
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(filepath.Join(settingsDir, "settings.json"), []byte(`{}`), 0o644)
+
+ r := integrations["droid"]
+ paths := editorPaths(r)
+ if len(paths) != 1 {
+ t.Errorf("expected 1 path, got %d", len(paths))
+ }
+ })
+
+ t.Run("returns paths for opencode when configs exist", func(t *testing.T) {
+ home, _ := os.UserHomeDir()
+ configDir := filepath.Join(home, ".config", "opencode")
+ stateDir := filepath.Join(home, ".local", "state", "opencode")
+ os.MkdirAll(configDir, 0o755)
+ os.MkdirAll(stateDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "opencode.json"), []byte(`{}`), 0o644)
+ os.WriteFile(filepath.Join(stateDir, "model.json"), []byte(`{}`), 0o644)
+
+ r := integrations["opencode"]
+ paths := editorPaths(r)
+ if len(paths) != 2 {
+ t.Errorf("expected 2 paths, got %d: %v", len(paths), paths)
+ }
+ })
+}
+
+func TestLoadIntegration_CorruptedJSON(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ dir := filepath.Join(tmpDir, ".ollama")
+ os.MkdirAll(dir, 0o755)
+ os.WriteFile(filepath.Join(dir, "config.json"), []byte(`{corrupted json`), 0o644)
+
+ _, err := loadIntegration("test")
+ if err == nil {
+ t.Error("expected error for nonexistent integration in corrupted file")
+ }
+}
+
+func TestSaveIntegration_NilModels(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ if err := SaveIntegration("test", nil); err != nil {
+ t.Fatalf("saveIntegration with nil models failed: %v", err)
+ }
+
+ config, err := loadIntegration("test")
+ if err != nil {
+ t.Fatalf("loadIntegration failed: %v", err)
+ }
+
+ if config.Models == nil {
+ // nil is acceptable
+ } else if len(config.Models) != 0 {
+ t.Errorf("expected empty or nil models, got %v", config.Models)
+ }
+}
+
+func TestSaveIntegration_EmptyAppName(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ err := SaveIntegration("", []string{"model"})
+ if err == nil {
+ t.Error("expected error for empty app name, got nil")
+ }
+ if err != nil && !strings.Contains(err.Error(), "app name cannot be empty") {
+ t.Errorf("expected 'app name cannot be empty' error, got: %v", err)
+ }
+}
+
+func TestLoadIntegration_NonexistentIntegration(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ _, err := loadIntegration("nonexistent")
+ if err == nil {
+ t.Error("expected error for nonexistent integration, got nil")
+ }
+ if !os.IsNotExist(err) {
+ t.Logf("error type is os.ErrNotExist as expected: %v", err)
+ }
+}
+
+func TestConfigPath(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ path, err := configPath()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ expected := filepath.Join(tmpDir, ".ollama", "config.json")
+ if path != expected {
+ t.Errorf("expected %s, got %s", expected, path)
+ }
+}
+
+func TestLoad(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("returns empty config when file does not exist", func(t *testing.T) {
+ cfg, err := load()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if cfg == nil {
+ t.Fatal("expected non-nil config")
+ }
+ if cfg.Integrations == nil {
+ t.Error("expected non-nil Integrations map")
+ }
+ if len(cfg.Integrations) != 0 {
+ t.Errorf("expected empty Integrations, got %d", len(cfg.Integrations))
+ }
+ })
+
+ t.Run("loads existing config", func(t *testing.T) {
+ path, _ := configPath()
+ os.MkdirAll(filepath.Dir(path), 0o755)
+ os.WriteFile(path, []byte(`{"integrations":{"test":{"models":["model-a"]}}}`), 0o644)
+
+ cfg, err := load()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if cfg.Integrations["test"] == nil {
+ t.Fatal("expected test integration")
+ }
+ if len(cfg.Integrations["test"].Models) != 1 {
+ t.Errorf("expected 1 model, got %d", len(cfg.Integrations["test"].Models))
+ }
+ })
+
+ t.Run("returns error for corrupted JSON", func(t *testing.T) {
+ path, _ := configPath()
+ os.MkdirAll(filepath.Dir(path), 0o755)
+ os.WriteFile(path, []byte(`{corrupted`), 0o644)
+
+ _, err := load()
+ if err == nil {
+ t.Error("expected error for corrupted JSON")
+ }
+ })
+}
+
+func TestMigrateConfig(t *testing.T) {
+ t.Run("migrates legacy file to new location", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ data := []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), data, 0o644)
+
+ migrated, err := migrateConfig()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !migrated {
+ t.Fatal("expected migration to occur")
+ }
+
+ newPath, _ := configPath()
+ got, err := os.ReadFile(newPath)
+ if err != nil {
+ t.Fatalf("new config not found: %v", err)
+ }
+ if string(got) != string(data) {
+ t.Errorf("content mismatch: got %s", got)
+ }
+
+ if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
+ t.Error("legacy file should have been removed")
+ }
+
+ if _, err := os.Stat(legacyDir); !os.IsNotExist(err) {
+ t.Error("legacy directory should have been removed")
+ }
+ })
+
+ t.Run("no-op when no legacy file exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ migrated, err := migrateConfig()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if migrated {
+ t.Error("expected no migration")
+ }
+ })
+
+ t.Run("skips corrupt legacy file", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{corrupt`), 0o644)
+
+ migrated, err := migrateConfig()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if migrated {
+ t.Error("should not migrate corrupt file")
+ }
+
+ if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); os.IsNotExist(err) {
+ t.Error("corrupt legacy file should not have been deleted")
+ }
+ })
+
+ t.Run("new path takes precedence over legacy", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"old":{"models":["old-model"]}}}`), 0o644)
+
+ newDir := filepath.Join(tmpDir, ".ollama")
+ os.WriteFile(filepath.Join(newDir, "config.json"), []byte(`{"integrations":{"new":{"models":["new-model"]}}}`), 0o644)
+
+ cfg, err := load()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, ok := cfg.Integrations["new"]; !ok {
+ t.Error("expected new-path integration to be loaded")
+ }
+ if _, ok := cfg.Integrations["old"]; ok {
+ t.Error("legacy integration should not have been loaded")
+ }
+ })
+
+ t.Run("idempotent when called twice", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
+
+ if _, err := migrateConfig(); err != nil {
+ t.Fatal(err)
+ }
+
+ migrated, err := migrateConfig()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if migrated {
+ t.Error("second migration should be a no-op")
+ }
+ })
+
+ t.Run("legacy directory preserved if not empty", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{}}`), 0o644)
+ os.WriteFile(filepath.Join(legacyDir, "other-file.txt"), []byte("keep me"), 0o644)
+
+ if _, err := migrateConfig(); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(legacyDir); os.IsNotExist(err) {
+ t.Error("directory with other files should not have been removed")
+ }
+ if _, err := os.Stat(filepath.Join(legacyDir, "other-file.txt")); os.IsNotExist(err) {
+ t.Error("other files in legacy directory should be untouched")
+ }
+ })
+
+ t.Run("save writes to new path after migration", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
+
+ // load triggers migration, then save should write to new path
+ if err := SaveIntegration("codex", []string{"qwen2.5"}); err != nil {
+ t.Fatal(err)
+ }
+
+ newPath := filepath.Join(tmpDir, ".ollama", "config.json")
+ if _, err := os.Stat(newPath); os.IsNotExist(err) {
+ t.Error("save should write to new path")
+ }
+
+ // old path should not be recreated
+ if _, err := os.Stat(filepath.Join(legacyDir, "config.json")); !os.IsNotExist(err) {
+ t.Error("save should not recreate legacy path")
+ }
+ })
+
+ t.Run("load triggers migration transparently", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ legacyDir := filepath.Join(tmpDir, ".ollama", "config")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "config.json"), []byte(`{"integrations":{"claude":{"models":["llama3.2"]}}}`), 0o644)
+
+ cfg, err := load()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if cfg.Integrations["claude"] == nil || cfg.Integrations["claude"].Models[0] != "llama3.2" {
+ t.Error("migration via load() did not preserve data")
+ }
+ })
+}
+
+func TestSave(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("creates config file", func(t *testing.T) {
+ cfg := &config{
+ Integrations: map[string]*integration{
+ "test": {Models: []string{"model-a", "model-b"}},
+ },
+ }
+
+ if err := save(cfg); err != nil {
+ t.Fatal(err)
+ }
+
+ path, _ := configPath()
+ if _, err := os.Stat(path); os.IsNotExist(err) {
+ t.Error("config file was not created")
+ }
+ })
+
+ t.Run("round-trip preserves data", func(t *testing.T) {
+ cfg := &config{
+ Integrations: map[string]*integration{
+ "claude": {Models: []string{"llama3.2", "mistral"}},
+ "codex": {Models: []string{"qwen2.5"}},
+ },
+ }
+
+ if err := save(cfg); err != nil {
+ t.Fatal(err)
+ }
+
+ loaded, err := load()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(loaded.Integrations) != 2 {
+ t.Errorf("expected 2 integrations, got %d", len(loaded.Integrations))
+ }
+ if loaded.Integrations["claude"] == nil {
+ t.Error("missing claude integration")
+ }
+ if len(loaded.Integrations["claude"].Models) != 2 {
+ t.Errorf("expected 2 models for claude, got %d", len(loaded.Integrations["claude"].Models))
+ }
+ })
+}
diff --git a/cmd/config/droid.go b/cmd/config/droid.go
new file mode 100644
index 00000000000..d1a9f54dcbb
--- /dev/null
+++ b/cmd/config/droid.go
@@ -0,0 +1,207 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "slices"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/envconfig"
+)
+
+// Droid implements Runner and Editor for Droid integration
+type Droid struct{}
+
+// droidSettings represents the Droid settings.json file (only fields we use)
+type droidSettings struct {
+ CustomModels []modelEntry `json:"customModels"`
+ SessionDefaultSettings sessionSettings `json:"sessionDefaultSettings"`
+}
+
+type sessionSettings struct {
+ Model string `json:"model"`
+ ReasoningEffort string `json:"reasoningEffort"`
+}
+
+type modelEntry struct {
+ Model string `json:"model"`
+ DisplayName string `json:"displayName"`
+ BaseURL string `json:"baseUrl"`
+ APIKey string `json:"apiKey"`
+ Provider string `json:"provider"`
+ MaxOutputTokens int `json:"maxOutputTokens"`
+ SupportsImages bool `json:"supportsImages"`
+ ID string `json:"id"`
+ Index int `json:"index"`
+}
+
+func (d *Droid) String() string { return "Droid" }
+
+func (d *Droid) Run(model string, args []string) error {
+ if _, err := exec.LookPath("droid"); err != nil {
+ return fmt.Errorf("droid is not installed, install from https://docs.factory.ai/cli/getting-started/quickstart")
+ }
+
+ // Call Edit() to ensure config is up-to-date before launch
+ models := []string{model}
+ if config, err := loadIntegration("droid"); err == nil && len(config.Models) > 0 {
+ models = config.Models
+ }
+ var err error
+ models, err = resolveEditorModels("droid", models, func() ([]string, error) {
+ return selectModels(context.Background(), "droid", "")
+ })
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ if err := d.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+
+ cmd := exec.Command("droid", args...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ return cmd.Run()
+}
+
+func (d *Droid) Paths() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+ p := filepath.Join(home, ".factory", "settings.json")
+ if _, err := os.Stat(p); err == nil {
+ return []string{p}
+ }
+ return nil
+}
+
+func (d *Droid) Edit(models []string) error {
+ if len(models) == 0 {
+ return nil
+ }
+
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ settingsPath := filepath.Join(home, ".factory", "settings.json")
+ if err := os.MkdirAll(filepath.Dir(settingsPath), 0o755); err != nil {
+ return err
+ }
+
+ // Read file once, unmarshal twice:
+ // map preserves unknown fields for writing back (including extra fields in model entries)
+ settingsMap := make(map[string]any)
+ var settings droidSettings
+ if data, err := os.ReadFile(settingsPath); err == nil {
+ if err := json.Unmarshal(data, &settingsMap); err != nil {
+ return fmt.Errorf("failed to parse settings file: %w, at: %s", err, settingsPath)
+ }
+ json.Unmarshal(data, &settings) // ignore error, zero values are fine
+ }
+
+ // Keep only non-Ollama models from the raw map (preserves extra fields)
+ // Rebuild Ollama models
+ var nonOllamaModels []any
+ if rawModels, ok := settingsMap["customModels"].([]any); ok {
+ for _, raw := range rawModels {
+ if m, ok := raw.(map[string]any); ok {
+ if m["apiKey"] != "ollama" {
+ nonOllamaModels = append(nonOllamaModels, raw)
+ }
+ }
+ }
+ }
+
+ // Build new Ollama model entries with sequential indices (0, 1, 2, ...)
+ client, _ := api.ClientFromEnvironment()
+
+ var newModels []any
+ var defaultModelID string
+ for i, model := range models {
+ maxOutput := 64000
+ if isCloudModel(context.Background(), client, model) {
+ if l, ok := lookupCloudModelLimit(model); ok {
+ maxOutput = l.Output
+ }
+ }
+ modelID := fmt.Sprintf("custom:%s-%d", model, i)
+ newModels = append(newModels, modelEntry{
+ Model: model,
+ DisplayName: model,
+ BaseURL: envconfig.Host().String() + "/v1",
+ APIKey: "ollama",
+ Provider: "generic-chat-completion-api",
+ MaxOutputTokens: maxOutput,
+ SupportsImages: false,
+ ID: modelID,
+ Index: i,
+ })
+ if i == 0 {
+ defaultModelID = modelID
+ }
+ }
+
+ settingsMap["customModels"] = append(newModels, nonOllamaModels...)
+
+ // Update session default settings (preserve unknown fields in the nested object)
+ sessionSettings, ok := settingsMap["sessionDefaultSettings"].(map[string]any)
+ if !ok {
+ sessionSettings = make(map[string]any)
+ }
+ sessionSettings["model"] = defaultModelID
+
+ if !isValidReasoningEffort(settings.SessionDefaultSettings.ReasoningEffort) {
+ sessionSettings["reasoningEffort"] = "none"
+ }
+
+ settingsMap["sessionDefaultSettings"] = sessionSettings
+
+ data, err := json.MarshalIndent(settingsMap, "", " ")
+ if err != nil {
+ return err
+ }
+ return writeWithBackup(settingsPath, data)
+}
+
+func (d *Droid) Models() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+
+ data, err := os.ReadFile(filepath.Join(home, ".factory", "settings.json"))
+ if err != nil {
+ return nil
+ }
+
+ var settings droidSettings
+ if err := json.Unmarshal(data, &settings); err != nil {
+ return nil
+ }
+
+ var result []string
+ for _, m := range settings.CustomModels {
+ if m.APIKey == "ollama" {
+ result = append(result, m.Model)
+ }
+ }
+ return result
+}
+
+var validReasoningEfforts = []string{"high", "medium", "low", "none"}
+
+func isValidReasoningEffort(effort string) bool {
+ return slices.Contains(validReasoningEfforts, effort)
+}
diff --git a/cmd/config/droid_test.go b/cmd/config/droid_test.go
new file mode 100644
index 00000000000..f13c3e93603
--- /dev/null
+++ b/cmd/config/droid_test.go
@@ -0,0 +1,1351 @@
+package config
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestDroidIntegration(t *testing.T) {
+ d := &Droid{}
+
+ t.Run("String", func(t *testing.T) {
+ if got := d.String(); got != "Droid" {
+ t.Errorf("String() = %q, want %q", got, "Droid")
+ }
+ })
+
+ t.Run("implements Runner", func(t *testing.T) {
+ var _ Runner = d
+ })
+
+ t.Run("implements Editor", func(t *testing.T) {
+ var _ Editor = d
+ })
+}
+
+func TestDroidEdit(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ cleanup := func() {
+ os.RemoveAll(settingsDir)
+ }
+
+ readSettings := func() map[string]any {
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+ return settings
+ }
+
+ getCustomModels := func(settings map[string]any) []map[string]any {
+ models, ok := settings["customModels"].([]any)
+ if !ok {
+ return nil
+ }
+ var result []map[string]any
+ for _, m := range models {
+ if entry, ok := m.(map[string]any); ok {
+ result = append(result, entry)
+ }
+ }
+ return result
+ }
+
+ t.Run("fresh install creates models with sequential indices", func(t *testing.T) {
+ cleanup()
+ if err := d.Edit([]string{"model-a", "model-b"}); err != nil {
+ t.Fatal(err)
+ }
+
+ settings := readSettings()
+ models := getCustomModels(settings)
+
+ if len(models) != 2 {
+ t.Fatalf("expected 2 models, got %d", len(models))
+ }
+
+ // Check first model
+ if models[0]["model"] != "model-a" {
+ t.Errorf("expected model-a, got %s", models[0]["model"])
+ }
+ if models[0]["id"] != "custom:model-a-0" {
+ t.Errorf("expected custom:model-a-0, got %s", models[0]["id"])
+ }
+ if models[0]["index"] != float64(0) {
+ t.Errorf("expected index 0, got %v", models[0]["index"])
+ }
+
+ // Check second model
+ if models[1]["model"] != "model-b" {
+ t.Errorf("expected model-b, got %s", models[1]["model"])
+ }
+ if models[1]["id"] != "custom:model-b-1" {
+ t.Errorf("expected custom:model-b-1, got %s", models[1]["id"])
+ }
+ if models[1]["index"] != float64(1) {
+ t.Errorf("expected index 1, got %v", models[1]["index"])
+ }
+ })
+
+ t.Run("sets sessionDefaultSettings.model to first model ID", func(t *testing.T) {
+ cleanup()
+ if err := d.Edit([]string{"model-a", "model-b"}); err != nil {
+ t.Fatal(err)
+ }
+
+ settings := readSettings()
+ session, ok := settings["sessionDefaultSettings"].(map[string]any)
+ if !ok {
+ t.Fatal("sessionDefaultSettings not found")
+ }
+ if session["model"] != "custom:model-a-0" {
+ t.Errorf("expected custom:model-a-0, got %s", session["model"])
+ }
+ })
+
+ t.Run("re-indexes when models removed", func(t *testing.T) {
+ cleanup()
+ // Add three models
+ d.Edit([]string{"model-a", "model-b", "model-c"})
+
+ // Remove middle model
+ d.Edit([]string{"model-a", "model-c"})
+
+ settings := readSettings()
+ models := getCustomModels(settings)
+
+ if len(models) != 2 {
+ t.Fatalf("expected 2 models, got %d", len(models))
+ }
+
+ // Check indices are sequential 0, 1
+ if models[0]["index"] != float64(0) {
+ t.Errorf("expected index 0, got %v", models[0]["index"])
+ }
+ if models[1]["index"] != float64(1) {
+ t.Errorf("expected index 1, got %v", models[1]["index"])
+ }
+
+ // Check IDs match new indices
+ if models[0]["id"] != "custom:model-a-0" {
+ t.Errorf("expected custom:model-a-0, got %s", models[0]["id"])
+ }
+ if models[1]["id"] != "custom:model-c-1" {
+ t.Errorf("expected custom:model-c-1, got %s", models[1]["id"])
+ }
+ })
+
+ t.Run("preserves non-Ollama custom models", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(settingsDir, 0o755)
+ // Pre-existing non-Ollama model
+ os.WriteFile(settingsPath, []byte(`{
+ "customModels": [
+ {"model": "gpt-4", "displayName": "GPT-4", "provider": "openai"}
+ ]
+ }`), 0o644)
+
+ d.Edit([]string{"model-a"})
+
+ settings := readSettings()
+ models := getCustomModels(settings)
+
+ if len(models) != 2 {
+ t.Fatalf("expected 2 models (1 Ollama + 1 non-Ollama), got %d", len(models))
+ }
+
+ // Ollama model should be first
+ if models[0]["model"] != "model-a" {
+ t.Errorf("expected Ollama model first, got %s", models[0]["model"])
+ }
+
+ // Non-Ollama model should be preserved at end
+ if models[1]["model"] != "gpt-4" {
+ t.Errorf("expected gpt-4 preserved, got %s", models[1]["model"])
+ }
+ })
+
+ t.Run("preserves other settings", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(`{
+ "theme": "dark",
+ "enableHooks": true,
+ "sessionDefaultSettings": {"autonomyMode": "auto-high"}
+ }`), 0o644)
+
+ d.Edit([]string{"model-a"})
+
+ settings := readSettings()
+
+ if settings["theme"] != "dark" {
+ t.Error("theme was not preserved")
+ }
+ if settings["enableHooks"] != true {
+ t.Error("enableHooks was not preserved")
+ }
+
+ session := settings["sessionDefaultSettings"].(map[string]any)
+ if session["autonomyMode"] != "auto-high" {
+ t.Error("autonomyMode was not preserved")
+ }
+ })
+
+ t.Run("required fields present", func(t *testing.T) {
+ cleanup()
+ d.Edit([]string{"test-model"})
+
+ settings := readSettings()
+ models := getCustomModels(settings)
+
+ if len(models) != 1 {
+ t.Fatal("expected 1 model")
+ }
+
+ model := models[0]
+ requiredFields := []string{"model", "displayName", "baseUrl", "apiKey", "provider", "maxOutputTokens", "id", "index"}
+ for _, field := range requiredFields {
+ if model[field] == nil {
+ t.Errorf("missing required field: %s", field)
+ }
+ }
+
+ if model["baseUrl"] != "http://127.0.0.1:11434/v1" {
+ t.Errorf("unexpected baseUrl: %s", model["baseUrl"])
+ }
+ if model["apiKey"] != "ollama" {
+ t.Errorf("unexpected apiKey: %s", model["apiKey"])
+ }
+ if model["provider"] != "generic-chat-completion-api" {
+ t.Errorf("unexpected provider: %s", model["provider"])
+ }
+ })
+
+ t.Run("fixes invalid reasoningEffort", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(settingsDir, 0o755)
+ // Pre-existing settings with invalid reasoningEffort
+ os.WriteFile(settingsPath, []byte(`{
+ "sessionDefaultSettings": {"reasoningEffort": "off"}
+ }`), 0o644)
+
+ d.Edit([]string{"model-a"})
+
+ settings := readSettings()
+ session := settings["sessionDefaultSettings"].(map[string]any)
+
+ if session["reasoningEffort"] != "none" {
+ t.Errorf("expected reasoningEffort to be fixed to 'none', got %s", session["reasoningEffort"])
+ }
+ })
+
+ t.Run("preserves valid reasoningEffort", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(`{
+ "sessionDefaultSettings": {"reasoningEffort": "high"}
+ }`), 0o644)
+
+ d.Edit([]string{"model-a"})
+
+ settings := readSettings()
+ session := settings["sessionDefaultSettings"].(map[string]any)
+
+ if session["reasoningEffort"] != "high" {
+ t.Errorf("expected reasoningEffort to remain 'high', got %s", session["reasoningEffort"])
+ }
+ })
+}
+
+// Edge case tests for droid.go
+
+func TestDroidEdit_CorruptedJSON(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(`{corrupted json content`), 0o644)
+
+ // Corrupted JSON should return an error so user knows something is wrong
+ err := d.Edit([]string{"model-a"})
+ if err == nil {
+ t.Fatal("expected error for corrupted JSON, got nil")
+ }
+
+ // Original corrupted file should be preserved (not overwritten)
+ data, _ := os.ReadFile(settingsPath)
+ if string(data) != `{corrupted json content` {
+ t.Errorf("corrupted file was modified: got %s", string(data))
+ }
+}
+
+func TestDroidEdit_WrongTypeCustomModels(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ // customModels is a string instead of array
+ os.WriteFile(settingsPath, []byte(`{"customModels": "not an array"}`), 0o644)
+
+ // Should not panic - wrong type should be handled gracefully
+ err := d.Edit([]string{"model-a"})
+ if err != nil {
+ t.Fatalf("Edit failed with wrong type customModels: %v", err)
+ }
+
+ // Verify models were added correctly
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ customModels, ok := settings["customModels"].([]any)
+ if !ok {
+ t.Fatalf("customModels should be array after setup, got %T", settings["customModels"])
+ }
+ if len(customModels) != 1 {
+ t.Errorf("expected 1 model, got %d", len(customModels))
+ }
+}
+
+func TestDroidEdit_EmptyModels(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ originalContent := `{"customModels": [{"model": "existing"}]}`
+ os.WriteFile(settingsPath, []byte(originalContent), 0o644)
+
+ // Empty models should be no-op
+ err := d.Edit([]string{})
+ if err != nil {
+ t.Fatalf("Edit with empty models failed: %v", err)
+ }
+
+ // Original content should be preserved (file not modified)
+ data, _ := os.ReadFile(settingsPath)
+ if string(data) != originalContent {
+ t.Errorf("empty models should not modify file, but content changed")
+ }
+}
+
+func TestDroidEdit_DuplicateModels(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ // Add same model twice
+ err := d.Edit([]string{"model-a", "model-a"})
+ if err != nil {
+ t.Fatalf("Edit with duplicates failed: %v", err)
+ }
+
+ settings, err := readJSONFile(settingsPath)
+ if err != nil {
+ t.Fatalf("readJSONFile failed: %v", err)
+ }
+
+ customModels, _ := settings["customModels"].([]any)
+ // Document current behavior: duplicates are kept as separate entries
+ if len(customModels) != 2 {
+ t.Logf("Note: duplicates result in %d entries (documenting behavior)", len(customModels))
+ }
+}
+
+func TestDroidEdit_MalformedModelEntry(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ // Model entry is a string instead of a map
+ os.WriteFile(settingsPath, []byte(`{"customModels": ["not a map", 123]}`), 0o644)
+
+ err := d.Edit([]string{"model-a"})
+ if err != nil {
+ t.Fatalf("Edit with malformed entries failed: %v", err)
+ }
+
+ // Malformed entries (non-object) are dropped - only valid model objects are preserved
+ settings, _ := readJSONFile(settingsPath)
+ customModels, _ := settings["customModels"].([]any)
+
+ // Should have: 1 new Ollama model only (malformed entries dropped)
+ if len(customModels) != 1 {
+ t.Errorf("expected 1 entry (malformed entries dropped), got %d", len(customModels))
+ }
+}
+
+func TestDroidEdit_WrongTypeSessionSettings(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ // sessionDefaultSettings is a string instead of map
+ os.WriteFile(settingsPath, []byte(`{"sessionDefaultSettings": "not a map"}`), 0o644)
+
+ err := d.Edit([]string{"model-a"})
+ if err != nil {
+ t.Fatalf("Edit with wrong type sessionDefaultSettings failed: %v", err)
+ }
+
+ // Should create proper sessionDefaultSettings
+ settings, _ := readJSONFile(settingsPath)
+ session, ok := settings["sessionDefaultSettings"].(map[string]any)
+ if !ok {
+ t.Fatalf("sessionDefaultSettings should be map after setup, got %T", settings["sessionDefaultSettings"])
+ }
+ if session["model"] == nil {
+ t.Error("expected model to be set in sessionDefaultSettings")
+ }
+}
+
+// testDroidSettingsFixture is a representative settings.json fixture for testing.
+// It covers: simple fields, arrays, nested objects, and customModels.
+const testDroidSettingsFixture = `{
+ "commandAllowlist": ["ls", "pwd", "git status"],
+ "diffMode": "github",
+ "enableHooks": true,
+ "hooks": {
+ "claudeHooksImported": true,
+ "importedClaudeHooks": ["uv run ruff check", "echo test"]
+ },
+ "ideExtensionPromptedAt": {
+ "cursor": 1763081579486,
+ "vscode": 1762992990179
+ },
+ "customModels": [
+ {
+ "model": "existing-ollama-model",
+ "displayName": "existing-ollama-model",
+ "baseUrl": "http://127.0.0.1:11434/v1",
+ "apiKey": "ollama",
+ "provider": "generic-chat-completion-api",
+ "maxOutputTokens": 64000,
+ "supportsImages": false,
+ "id": "custom:existing-ollama-model-0",
+ "index": 0
+ },
+ {
+ "model": "gpt-4",
+ "displayName": "GPT-4",
+ "baseUrl": "https://api.openai.com/v1",
+ "apiKey": "sk-xxx",
+ "provider": "openai",
+ "maxOutputTokens": 4096,
+ "supportsImages": true,
+ "id": "openai-gpt4",
+ "index": 1,
+ "customField": "should be preserved"
+ }
+ ],
+ "sessionDefaultSettings": {
+ "autonomyMode": "auto-medium",
+ "model": "custom:existing-ollama-model-0",
+ "reasoningEffort": "high"
+ },
+ "todoDisplayMode": "pinned"
+}`
+
+func TestDroidEdit_RoundTrip(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(testDroidSettingsFixture), 0o644)
+
+ // Edit with new models
+ if err := d.Edit([]string{"llama3", "mistral"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Read back and verify
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ // Verify unknown top-level fields preserved
+ if settings["diffMode"] != "github" {
+ t.Error("diffMode not preserved")
+ }
+ if settings["enableHooks"] != true {
+ t.Error("enableHooks not preserved")
+ }
+ if settings["todoDisplayMode"] != "pinned" {
+ t.Error("todoDisplayMode not preserved")
+ }
+
+ // Verify arrays preserved
+ allowlist, ok := settings["commandAllowlist"].([]any)
+ if !ok || len(allowlist) != 3 {
+ t.Error("commandAllowlist not preserved")
+ }
+
+ // Verify nested objects preserved
+ hooks, ok := settings["hooks"].(map[string]any)
+ if !ok {
+ t.Fatal("hooks not preserved")
+ }
+ if hooks["claudeHooksImported"] != true {
+ t.Error("hooks.claudeHooksImported not preserved")
+ }
+ importedHooks, ok := hooks["importedClaudeHooks"].([]any)
+ if !ok || len(importedHooks) != 2 {
+ t.Error("hooks.importedClaudeHooks not preserved")
+ }
+
+ // Verify deeply nested numeric values preserved
+ idePrompted, ok := settings["ideExtensionPromptedAt"].(map[string]any)
+ if !ok {
+ t.Fatal("ideExtensionPromptedAt not preserved")
+ }
+ if idePrompted["cursor"] != float64(1763081579486) {
+ t.Error("ideExtensionPromptedAt.cursor not preserved")
+ }
+
+ // Verify sessionDefaultSettings unknown fields preserved
+ session, ok := settings["sessionDefaultSettings"].(map[string]any)
+ if !ok {
+ t.Fatal("sessionDefaultSettings not preserved")
+ }
+ if session["autonomyMode"] != "auto-medium" {
+ t.Error("sessionDefaultSettings.autonomyMode not preserved")
+ }
+ if session["reasoningEffort"] != "high" {
+ t.Error("sessionDefaultSettings.reasoningEffort not preserved (was valid)")
+ }
+ // model should be updated
+ if session["model"] != "custom:llama3-0" {
+ t.Errorf("sessionDefaultSettings.model not updated, got %s", session["model"])
+ }
+
+ // Verify customModels: old ollama replaced, non-ollama preserved with extra fields
+ models, ok := settings["customModels"].([]any)
+ if !ok {
+ t.Fatal("customModels not preserved")
+ }
+ if len(models) != 3 { // 2 new ollama + 1 non-ollama
+ t.Fatalf("expected 3 models, got %d", len(models))
+ }
+
+ // First two should be new Ollama models
+ m0 := models[0].(map[string]any)
+ if m0["model"] != "llama3" || m0["apiKey"] != "ollama" {
+ t.Error("first model should be llama3")
+ }
+ m1 := models[1].(map[string]any)
+ if m1["model"] != "mistral" || m1["apiKey"] != "ollama" {
+ t.Error("second model should be mistral")
+ }
+
+ // Third should be preserved non-Ollama with extra field
+ m2 := models[2].(map[string]any)
+ if m2["model"] != "gpt-4" {
+ t.Error("non-Ollama model not preserved")
+ }
+ if m2["customField"] != "should be preserved" {
+ t.Error("non-Ollama model's extra field not preserved")
+ }
+}
+
+func TestDroidEdit_PreservesUnknownFields(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ readSettings := func() map[string]any {
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+ return settings
+ }
+
+ t.Run("preserves all JSON value types", func(t *testing.T) {
+ os.RemoveAll(settingsDir)
+ os.MkdirAll(settingsDir, 0o755)
+
+ original := `{
+ "stringField": "value",
+ "numberField": 42,
+ "floatField": 3.14,
+ "boolField": true,
+ "nullField": null,
+ "arrayField": [1, "two", true],
+ "objectField": {"nested": "value"},
+ "customModels": [],
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ settings := readSettings()
+
+ if settings["stringField"] != "value" {
+ t.Error("stringField not preserved")
+ }
+ if settings["numberField"] != float64(42) {
+ t.Error("numberField not preserved")
+ }
+ if settings["floatField"] != 3.14 {
+ t.Error("floatField not preserved")
+ }
+ if settings["boolField"] != true {
+ t.Error("boolField not preserved")
+ }
+ if settings["nullField"] != nil {
+ t.Error("nullField not preserved")
+ }
+ arr, ok := settings["arrayField"].([]any)
+ if !ok || len(arr) != 3 {
+ t.Error("arrayField not preserved")
+ }
+ obj, ok := settings["objectField"].(map[string]any)
+ if !ok || obj["nested"] != "value" {
+ t.Error("objectField not preserved")
+ }
+ })
+
+ t.Run("preserves extra fields in non-Ollama models", func(t *testing.T) {
+ os.RemoveAll(settingsDir)
+ os.MkdirAll(settingsDir, 0o755)
+
+ original := `{
+ "customModels": [{
+ "model": "gpt-4",
+ "apiKey": "sk-xxx",
+ "extraField": "preserved",
+ "nestedExtra": {"foo": "bar"}
+ }]
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"llama3"}); err != nil {
+ t.Fatal(err)
+ }
+
+ settings := readSettings()
+ models := settings["customModels"].([]any)
+ gpt4 := models[1].(map[string]any) // non-Ollama is second
+
+ if gpt4["extraField"] != "preserved" {
+ t.Error("extraField not preserved")
+ }
+ nested := gpt4["nestedExtra"].(map[string]any)
+ if nested["foo"] != "bar" {
+ t.Error("nestedExtra not preserved")
+ }
+ })
+}
+
+func TestIsValidReasoningEffort(t *testing.T) {
+ tests := []struct {
+ effort string
+ valid bool
+ }{
+ {"high", true},
+ {"medium", true},
+ {"low", true},
+ {"none", true},
+ {"off", false},
+ {"", false},
+ {"HIGH", false}, // case sensitive
+ {"max", false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.effort, func(t *testing.T) {
+ got := isValidReasoningEffort(tt.effort)
+ if got != tt.valid {
+ t.Errorf("isValidReasoningEffort(%q) = %v, want %v", tt.effort, got, tt.valid)
+ }
+ })
+ }
+}
+
+func TestDroidEdit_Idempotent(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(testDroidSettingsFixture), 0o644)
+
+ // Edit twice with same models
+ d.Edit([]string{"llama3", "mistral"})
+ firstData, _ := os.ReadFile(settingsPath)
+
+ d.Edit([]string{"llama3", "mistral"})
+ secondData, _ := os.ReadFile(settingsPath)
+
+ // Results should be identical
+ if string(firstData) != string(secondData) {
+ t.Error("repeated edits with same models produced different results")
+ }
+}
+
+func TestDroidEdit_MultipleConsecutiveEdits(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(testDroidSettingsFixture), 0o644)
+
+ // Multiple edits shouldn't accumulate garbage or lose data
+ for i := range 10 {
+ models := []string{"model-a", "model-b"}
+ if i%2 == 0 {
+ models = []string{"model-x", "model-y", "model-z"}
+ }
+ if err := d.Edit(models); err != nil {
+ t.Fatalf("edit %d failed: %v", i, err)
+ }
+ }
+
+ // Verify file is still valid JSON and preserves original fields
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ if err := json.Unmarshal(data, &settings); err != nil {
+ t.Fatalf("file is not valid JSON after multiple edits: %v", err)
+ }
+
+ // Original fields should still be there
+ if settings["diffMode"] != "github" {
+ t.Error("diffMode lost after multiple edits")
+ }
+ if settings["enableHooks"] != true {
+ t.Error("enableHooks lost after multiple edits")
+ }
+
+ // Non-Ollama model should still be preserved
+ models := settings["customModels"].([]any)
+ foundOther := false
+ for _, m := range models {
+ if entry, ok := m.(map[string]any); ok {
+ if entry["model"] == "gpt-4" {
+ foundOther = true
+ if entry["customField"] != "should be preserved" {
+ t.Error("other customField lost after multiple edits")
+ }
+ }
+ }
+ }
+ if !foundOther {
+ t.Error("other model lost after multiple edits")
+ }
+}
+
+func TestDroidEdit_UnicodeAndSpecialCharacters(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // Settings with unicode and special characters
+ original := `{
+ "userName": "日本語テスト",
+ "emoji": "🚀🎉💻",
+ "specialChars": "quotes: \"test\" and 'test', backslash: \\, newline: \n, tab: \t",
+ "unicodeEscape": "\u0048\u0065\u006c\u006c\u006f",
+ "customModels": [],
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ if settings["userName"] != "日本語テスト" {
+ t.Error("Japanese characters not preserved")
+ }
+ if settings["emoji"] != "🚀🎉💻" {
+ t.Error("emoji not preserved")
+ }
+ // Note: JSON encoding will normalize escape sequences
+ if settings["unicodeEscape"] != "Hello" {
+ t.Error("unicode escape sequence not preserved")
+ }
+}
+
+func TestDroidEdit_LargeNumbers(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // Large numbers and timestamps (common in settings files)
+ original := `{
+ "timestamp": 1763081579486,
+ "largeInt": 9007199254740991,
+ "negativeNum": -12345,
+ "floatNum": 3.141592653589793,
+ "scientificNotation": 1.23e10,
+ "customModels": [],
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ if settings["timestamp"] != float64(1763081579486) {
+ t.Errorf("timestamp not preserved: got %v", settings["timestamp"])
+ }
+ if settings["largeInt"] != float64(9007199254740991) {
+ t.Errorf("largeInt not preserved: got %v", settings["largeInt"])
+ }
+ if settings["negativeNum"] != float64(-12345) {
+ t.Error("negativeNum not preserved")
+ }
+ if settings["floatNum"] != 3.141592653589793 {
+ t.Error("floatNum not preserved")
+ }
+}
+
+func TestDroidEdit_EmptyAndNullValues(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ original := `{
+ "emptyString": "",
+ "nullValue": null,
+ "emptyArray": [],
+ "emptyObject": {},
+ "falseBool": false,
+ "zeroNumber": 0,
+ "customModels": [],
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ if settings["emptyString"] != "" {
+ t.Error("emptyString not preserved")
+ }
+ if settings["nullValue"] != nil {
+ t.Error("nullValue not preserved as null")
+ }
+ if arr, ok := settings["emptyArray"].([]any); !ok || len(arr) != 0 {
+ t.Error("emptyArray not preserved")
+ }
+ if obj, ok := settings["emptyObject"].(map[string]any); !ok || len(obj) != 0 {
+ t.Error("emptyObject not preserved")
+ }
+ if settings["falseBool"] != false {
+ t.Error("falseBool not preserved (false vs missing)")
+ }
+ if settings["zeroNumber"] != float64(0) {
+ t.Error("zeroNumber not preserved")
+ }
+}
+
+func TestDroidEdit_DeeplyNestedStructures(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ original := `{
+ "level1": {
+ "level2": {
+ "level3": {
+ "level4": {
+ "deepValue": "found me",
+ "deepArray": [1, 2, {"nested": true}]
+ }
+ }
+ }
+ },
+ "customModels": [],
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ // Navigate to deeply nested value
+ l1 := settings["level1"].(map[string]any)
+ l2 := l1["level2"].(map[string]any)
+ l3 := l2["level3"].(map[string]any)
+ l4 := l3["level4"].(map[string]any)
+
+ if l4["deepValue"] != "found me" {
+ t.Error("deeply nested value not preserved")
+ }
+
+ deepArray := l4["deepArray"].([]any)
+ if len(deepArray) != 3 {
+ t.Error("deeply nested array not preserved")
+ }
+ nestedInArray := deepArray[2].(map[string]any)
+ if nestedInArray["nested"] != true {
+ t.Error("object nested in array not preserved")
+ }
+}
+
+func TestDroidEdit_ModelNamesWithSpecialCharacters(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ // Test model names with colons, slashes, special chars
+ specialModels := []string{
+ "qwen3:480b-cloud",
+ "llama3.2:70b",
+ "model/with/slashes",
+ "model-with-dashes",
+ "model_with_underscores",
+ }
+
+ if err := d.Edit(specialModels); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ models := settings["customModels"].([]any)
+ if len(models) != len(specialModels) {
+ t.Fatalf("expected %d models, got %d", len(specialModels), len(models))
+ }
+
+ for i, expected := range specialModels {
+ m := models[i].(map[string]any)
+ if m["model"] != expected {
+ t.Errorf("model %d: expected %s, got %s", i, expected, m["model"])
+ }
+ }
+}
+
+func TestDroidEdit_MissingCustomModelsKey(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // No customModels key at all
+ original := `{
+ "diffMode": "github",
+ "sessionDefaultSettings": {"autonomyMode": "auto-high"}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ // Original fields preserved
+ if settings["diffMode"] != "github" {
+ t.Error("diffMode not preserved")
+ }
+
+ // customModels created
+ models, ok := settings["customModels"].([]any)
+ if !ok || len(models) != 1 {
+ t.Error("customModels not created properly")
+ }
+}
+
+func TestDroidEdit_NullCustomModels(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ original := `{
+ "customModels": null,
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ models, ok := settings["customModels"].([]any)
+ if !ok || len(models) != 1 {
+ t.Error("null customModels not handled properly")
+ }
+}
+
+func TestDroidEdit_MinifiedJSON(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // Minified JSON (no whitespace)
+ original := `{"diffMode":"github","enableHooks":true,"hooks":{"imported":["cmd1","cmd2"]},"customModels":[],"sessionDefaultSettings":{}}`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ if err := json.Unmarshal(data, &settings); err != nil {
+ t.Fatal("output is not valid JSON")
+ }
+
+ if settings["diffMode"] != "github" {
+ t.Error("diffMode not preserved from minified JSON")
+ }
+ if settings["enableHooks"] != true {
+ t.Error("enableHooks not preserved from minified JSON")
+ }
+}
+
+func TestDroidEdit_CreatesDirectoryIfMissing(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+
+ // Directory doesn't exist
+ if _, err := os.Stat(settingsDir); !os.IsNotExist(err) {
+ t.Fatal("directory should not exist before test")
+ }
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Directory should be created
+ if _, err := os.Stat(settingsDir); os.IsNotExist(err) {
+ t.Fatal("directory was not created")
+ }
+
+ // File should exist and be valid
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+ data, err := os.ReadFile(settingsPath)
+ if err != nil {
+ t.Fatal("settings file not created")
+ }
+
+ var settings map[string]any
+ if err := json.Unmarshal(data, &settings); err != nil {
+ t.Fatal("created file is not valid JSON")
+ }
+}
+
+func TestDroidEdit_PreservesFileAfterError(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // Valid original content
+ original := `{"diffMode": "github", "customModels": [], "sessionDefaultSettings": {}}`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ // Empty models list is a no-op, should not modify file
+ d.Edit([]string{})
+
+ data, _ := os.ReadFile(settingsPath)
+ if string(data) != original {
+ t.Error("file was modified when it should not have been")
+ }
+}
+
+func TestDroidEdit_BackupCreated(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+ backupDir := filepath.Join(os.TempDir(), "ollama-backups")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // Use a unique marker to identify our backup
+ uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
+ original := fmt.Sprintf(`{"diffMode": "%s", "customModels": [], "sessionDefaultSettings": {}}`, uniqueMarker)
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Find backup containing our unique marker
+ backups, _ := filepath.Glob(filepath.Join(backupDir, "settings.json.*"))
+ foundBackup := false
+ for _, backup := range backups {
+ data, err := os.ReadFile(backup)
+ if err != nil {
+ continue
+ }
+ if string(data) == original {
+ foundBackup = true
+ break
+ }
+ }
+
+ if !foundBackup {
+ t.Error("backup with original content not found")
+ }
+
+ // Main file should be modified
+ newData, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(newData, &settings)
+
+ models := settings["customModels"].([]any)
+ if len(models) != 1 {
+ t.Error("main file was not updated")
+ }
+}
+
+func TestDroidEdit_LargeNumberOfModels(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+ os.WriteFile(settingsPath, []byte(`{"customModels": [], "sessionDefaultSettings": {}}`), 0o644)
+
+ // Add many models
+ var models []string
+ for i := range 100 {
+ models = append(models, fmt.Sprintf("model-%d", i))
+ }
+
+ if err := d.Edit(models); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ customModels := settings["customModels"].([]any)
+ if len(customModels) != 100 {
+ t.Errorf("expected 100 models, got %d", len(customModels))
+ }
+
+ // Verify indices are correct
+ for i, m := range customModels {
+ entry := m.(map[string]any)
+ if entry["index"] != float64(i) {
+ t.Errorf("model %d has wrong index: %v", i, entry["index"])
+ }
+ }
+}
+
+func TestDroidEdit_LocalModelDefaultMaxOutput(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ if err := d.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ models := settings["customModels"].([]any)
+ entry := models[0].(map[string]any)
+ if entry["maxOutputTokens"] != float64(64000) {
+ t.Errorf("local model maxOutputTokens = %v, want 64000", entry["maxOutputTokens"])
+ }
+}
+
+func TestDroidEdit_CloudModelLimitsUsed(t *testing.T) {
+ // Verify that every cloud model in cloudModelLimits has a valid output
+ // value that would be used for maxOutputTokens when isCloudModel returns true.
+ // :cloud suffix stripping must also work since that's how users specify them.
+ for name, expected := range cloudModelLimits {
+ t.Run(name, func(t *testing.T) {
+ l, ok := lookupCloudModelLimit(name)
+ if !ok {
+ t.Fatalf("lookupCloudModelLimit(%q) returned false", name)
+ }
+ if l.Output != expected.Output {
+ t.Errorf("output = %d, want %d", l.Output, expected.Output)
+ }
+ // Also verify :cloud suffix lookup
+ cloudName := name + ":cloud"
+ l2, ok := lookupCloudModelLimit(cloudName)
+ if !ok {
+ t.Fatalf("lookupCloudModelLimit(%q) returned false", cloudName)
+ }
+ if l2.Output != expected.Output {
+ t.Errorf(":cloud output = %d, want %d", l2.Output, expected.Output)
+ }
+ })
+ }
+}
+
+func TestDroidEdit_ArraysWithMixedTypes(t *testing.T) {
+ d := &Droid{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ settingsDir := filepath.Join(tmpDir, ".factory")
+ settingsPath := filepath.Join(settingsDir, "settings.json")
+
+ os.MkdirAll(settingsDir, 0o755)
+
+ // Arrays with mixed types (valid JSON)
+ original := `{
+ "mixedArray": [1, "two", true, null, {"nested": "obj"}, [1,2,3]],
+ "customModels": [],
+ "sessionDefaultSettings": {}
+ }`
+ os.WriteFile(settingsPath, []byte(original), 0o644)
+
+ if err := d.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(settingsPath)
+ var settings map[string]any
+ json.Unmarshal(data, &settings)
+
+ arr := settings["mixedArray"].([]any)
+ if len(arr) != 6 {
+ t.Error("mixedArray length not preserved")
+ }
+ if arr[0] != float64(1) {
+ t.Error("number in mixed array not preserved")
+ }
+ if arr[1] != "two" {
+ t.Error("string in mixed array not preserved")
+ }
+ if arr[2] != true {
+ t.Error("bool in mixed array not preserved")
+ }
+ if arr[3] != nil {
+ t.Error("null in mixed array not preserved")
+ }
+ if nested, ok := arr[4].(map[string]any); !ok || nested["nested"] != "obj" {
+ t.Error("object in mixed array not preserved")
+ }
+ if innerArr, ok := arr[5].([]any); !ok || len(innerArr) != 3 {
+ t.Error("array in mixed array not preserved")
+ }
+}
diff --git a/cmd/config/files.go b/cmd/config/files.go
new file mode 100644
index 00000000000..545e25c4df5
--- /dev/null
+++ b/cmd/config/files.go
@@ -0,0 +1,99 @@
+package config
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "time"
+)
+
+func readJSONFile(path string) (map[string]any, error) {
+ data, err := os.ReadFile(path)
+ if err != nil {
+ return nil, err
+ }
+ var result map[string]any
+ if err := json.Unmarshal(data, &result); err != nil {
+ return nil, err
+ }
+ return result, nil
+}
+
+func copyFile(src, dst string) error {
+ info, err := os.Stat(src)
+ if err != nil {
+ return err
+ }
+ data, err := os.ReadFile(src)
+ if err != nil {
+ return err
+ }
+ return os.WriteFile(dst, data, info.Mode().Perm())
+}
+
+func backupDir() string {
+ return filepath.Join(os.TempDir(), "ollama-backups")
+}
+
+func backupToTmp(srcPath string) (string, error) {
+ dir := backupDir()
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return "", err
+ }
+
+ backupPath := filepath.Join(dir, fmt.Sprintf("%s.%d", filepath.Base(srcPath), time.Now().Unix()))
+ if err := copyFile(srcPath, backupPath); err != nil {
+ return "", err
+ }
+ return backupPath, nil
+}
+
+// writeWithBackup writes data to path via temp file + rename, backing up any existing file first
+func writeWithBackup(path string, data []byte) error {
+ var backupPath string
+ // backup must be created before any writes to the target file
+ if existingContent, err := os.ReadFile(path); err == nil {
+ if !bytes.Equal(existingContent, data) {
+ backupPath, err = backupToTmp(path)
+ if err != nil {
+ return fmt.Errorf("backup failed: %w", err)
+ }
+ }
+ } else if !os.IsNotExist(err) {
+ return fmt.Errorf("read existing file: %w", err)
+ }
+
+ dir := filepath.Dir(path)
+ tmp, err := os.CreateTemp(dir, ".tmp-*")
+ if err != nil {
+ return fmt.Errorf("create temp failed: %w", err)
+ }
+ tmpPath := tmp.Name()
+
+ if _, err := tmp.Write(data); err != nil {
+ _ = tmp.Close()
+ _ = os.Remove(tmpPath)
+ return fmt.Errorf("write failed: %w", err)
+ }
+ if err := tmp.Sync(); err != nil {
+ _ = tmp.Close()
+ _ = os.Remove(tmpPath)
+ return fmt.Errorf("sync failed: %w", err)
+ }
+ if err := tmp.Close(); err != nil {
+ _ = os.Remove(tmpPath)
+ return fmt.Errorf("close failed: %w", err)
+ }
+
+ if err := os.Rename(tmpPath, path); err != nil {
+ _ = os.Remove(tmpPath)
+ if backupPath != "" {
+ _ = copyFile(backupPath, path)
+ }
+ return fmt.Errorf("rename failed: %w", err)
+ }
+
+ return nil
+}
diff --git a/cmd/config/files_test.go b/cmd/config/files_test.go
new file mode 100644
index 00000000000..e0aaea2b5de
--- /dev/null
+++ b/cmd/config/files_test.go
@@ -0,0 +1,502 @@
+package config
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "runtime"
+ "testing"
+)
+
+func mustMarshal(t *testing.T, v any) []byte {
+ t.Helper()
+ data, err := json.MarshalIndent(v, "", " ")
+ if err != nil {
+ t.Fatal(err)
+ }
+ return data
+}
+
+func TestWriteWithBackup(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ t.Run("creates file", func(t *testing.T) {
+ path := filepath.Join(tmpDir, "new.json")
+ data := mustMarshal(t, map[string]string{"key": "value"})
+
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatal(err)
+ }
+
+ content, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var result map[string]string
+ if err := json.Unmarshal(content, &result); err != nil {
+ t.Fatal(err)
+ }
+ if result["key"] != "value" {
+ t.Errorf("expected value, got %s", result["key"])
+ }
+ })
+
+ t.Run("creates backup in /tmp/ollama-backups", func(t *testing.T) {
+ path := filepath.Join(tmpDir, "backup.json")
+
+ os.WriteFile(path, []byte(`{"original": true}`), 0o644)
+
+ data := mustMarshal(t, map[string]bool{"updated": true})
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatal(err)
+ }
+
+ entries, err := os.ReadDir(backupDir())
+ if err != nil {
+ t.Fatal("backup directory not created")
+ }
+
+ var foundBackup bool
+ for _, entry := range entries {
+ if filepath.Ext(entry.Name()) != ".json" {
+ name := entry.Name()
+ if len(name) > len("backup.json.") && name[:len("backup.json.")] == "backup.json." {
+ backupPath := filepath.Join(backupDir(), name)
+ backup, err := os.ReadFile(backupPath)
+ if err == nil {
+ var backupData map[string]bool
+ json.Unmarshal(backup, &backupData)
+ if backupData["original"] {
+ foundBackup = true
+ os.Remove(backupPath)
+ break
+ }
+ }
+ }
+ }
+ }
+
+ if !foundBackup {
+ t.Error("backup file not created in /tmp/ollama-backups")
+ }
+
+ current, _ := os.ReadFile(path)
+ var currentData map[string]bool
+ json.Unmarshal(current, ¤tData)
+ if !currentData["updated"] {
+ t.Error("file doesn't contain updated data")
+ }
+ })
+
+ t.Run("no backup for new file", func(t *testing.T) {
+ path := filepath.Join(tmpDir, "nobak.json")
+
+ data := mustMarshal(t, map[string]string{"new": "file"})
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatal(err)
+ }
+
+ entries, _ := os.ReadDir(backupDir())
+ for _, entry := range entries {
+ if len(entry.Name()) > len("nobak.json.") && entry.Name()[:len("nobak.json.")] == "nobak.json." {
+ t.Error("backup should not exist for new file")
+ }
+ }
+ })
+
+ t.Run("no backup when content unchanged", func(t *testing.T) {
+ path := filepath.Join(tmpDir, "unchanged.json")
+
+ data := mustMarshal(t, map[string]string{"key": "value"})
+
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatal(err)
+ }
+
+ entries1, _ := os.ReadDir(backupDir())
+ countBefore := 0
+ for _, e := range entries1 {
+ if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
+ countBefore++
+ }
+ }
+
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatal(err)
+ }
+
+ entries2, _ := os.ReadDir(backupDir())
+ countAfter := 0
+ for _, e := range entries2 {
+ if len(e.Name()) > len("unchanged.json.") && e.Name()[:len("unchanged.json.")] == "unchanged.json." {
+ countAfter++
+ }
+ }
+
+ if countAfter != countBefore {
+ t.Errorf("backup was created when content unchanged (before=%d, after=%d)", countBefore, countAfter)
+ }
+ })
+
+ t.Run("backup filename contains unix timestamp", func(t *testing.T) {
+ path := filepath.Join(tmpDir, "timestamped.json")
+
+ os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
+ data := mustMarshal(t, map[string]int{"v": 2})
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatal(err)
+ }
+
+ entries, _ := os.ReadDir(backupDir())
+ var found bool
+ for _, entry := range entries {
+ name := entry.Name()
+ if len(name) > len("timestamped.json.") && name[:len("timestamped.json.")] == "timestamped.json." {
+ timestamp := name[len("timestamped.json."):]
+ for _, c := range timestamp {
+ if c < '0' || c > '9' {
+ t.Errorf("backup filename timestamp contains non-numeric character: %s", name)
+ }
+ }
+ found = true
+ os.Remove(filepath.Join(backupDir(), name))
+ break
+ }
+ }
+ if !found {
+ t.Error("backup file with timestamp not found")
+ }
+ })
+}
+
+// Edge case tests for files.go
+
+// TestWriteWithBackup_FailsIfBackupFails documents critical behavior: if backup fails, we must not proceed.
+// User could lose their config with no way to recover.
+func TestWriteWithBackup_FailsIfBackupFails(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("permission tests unreliable on Windows")
+ }
+
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "config.json")
+
+ // Create original file
+ originalContent := []byte(`{"original": true}`)
+ os.WriteFile(path, originalContent, 0o644)
+
+ // Make backup directory read-only to force backup failure
+ backupDir := backupDir()
+ os.MkdirAll(backupDir, 0o755)
+ os.Chmod(backupDir, 0o444) // Read-only
+ defer os.Chmod(backupDir, 0o755)
+
+ newContent := []byte(`{"updated": true}`)
+ err := writeWithBackup(path, newContent)
+
+ // Should fail because backup couldn't be created
+ if err == nil {
+ t.Error("expected error when backup fails, got nil")
+ }
+
+ // Original file should be preserved
+ current, _ := os.ReadFile(path)
+ if string(current) != string(originalContent) {
+ t.Errorf("original file was modified despite backup failure: got %s", string(current))
+ }
+}
+
+// TestWriteWithBackup_PermissionDenied verifies clear error when target file has wrong permissions.
+// Common issue when config owned by root or wrong perms.
+func TestWriteWithBackup_PermissionDenied(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("permission tests unreliable on Windows")
+ }
+
+ tmpDir := t.TempDir()
+
+ // Create a read-only directory
+ readOnlyDir := filepath.Join(tmpDir, "readonly")
+ os.MkdirAll(readOnlyDir, 0o755)
+ os.Chmod(readOnlyDir, 0o444)
+ defer os.Chmod(readOnlyDir, 0o755)
+
+ path := filepath.Join(readOnlyDir, "config.json")
+ err := writeWithBackup(path, []byte(`{"test": true}`))
+
+ if err == nil {
+ t.Error("expected permission error, got nil")
+ }
+}
+
+// TestWriteWithBackup_DirectoryDoesNotExist verifies behavior when target directory doesn't exist.
+// writeWithBackup doesn't create directories - caller is responsible.
+func TestWriteWithBackup_DirectoryDoesNotExist(t *testing.T) {
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "nonexistent", "subdir", "config.json")
+
+ err := writeWithBackup(path, []byte(`{"test": true}`))
+
+ // Should fail because directory doesn't exist
+ if err == nil {
+ t.Error("expected error for nonexistent directory, got nil")
+ }
+}
+
+// TestWriteWithBackup_SymlinkTarget documents behavior when target is a symlink.
+// Documents what happens if user symlinks their config file.
+func TestWriteWithBackup_SymlinkTarget(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("symlink tests may require admin on Windows")
+ }
+
+ tmpDir := t.TempDir()
+ realFile := filepath.Join(tmpDir, "real.json")
+ symlink := filepath.Join(tmpDir, "link.json")
+
+ // Create real file and symlink
+ os.WriteFile(realFile, []byte(`{"v": 1}`), 0o644)
+ os.Symlink(realFile, symlink)
+
+ // Write through symlink
+ err := writeWithBackup(symlink, []byte(`{"v": 2}`))
+ if err != nil {
+ t.Fatalf("writeWithBackup through symlink failed: %v", err)
+ }
+
+ // The real file should be updated (symlink followed for temp file creation)
+ content, _ := os.ReadFile(symlink)
+ if string(content) != `{"v": 2}` {
+ t.Errorf("symlink target not updated correctly: got %s", string(content))
+ }
+}
+
+// TestBackupToTmp_SpecialCharsInFilename verifies backup works with special characters.
+// User may have config files with unusual names.
+func TestBackupToTmp_SpecialCharsInFilename(t *testing.T) {
+ tmpDir := t.TempDir()
+
+ // File with spaces and special chars
+ path := filepath.Join(tmpDir, "my config (backup).json")
+ os.WriteFile(path, []byte(`{"test": true}`), 0o644)
+
+ backupPath, err := backupToTmp(path)
+ if err != nil {
+ t.Fatalf("backupToTmp with special chars failed: %v", err)
+ }
+
+ // Verify backup exists and has correct content
+ content, err := os.ReadFile(backupPath)
+ if err != nil {
+ t.Fatalf("could not read backup: %v", err)
+ }
+ if string(content) != `{"test": true}` {
+ t.Errorf("backup content mismatch: got %s", string(content))
+ }
+
+ os.Remove(backupPath)
+}
+
+// TestCopyFile_PreservesPermissions verifies that copyFile preserves file permissions.
+func TestCopyFile_PreservesPermissions(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("permission preservation tests unreliable on Windows")
+ }
+
+ tmpDir := t.TempDir()
+ src := filepath.Join(tmpDir, "src.json")
+ dst := filepath.Join(tmpDir, "dst.json")
+
+ // Create source with specific permissions
+ os.WriteFile(src, []byte(`{"test": true}`), 0o600)
+
+ err := copyFile(src, dst)
+ if err != nil {
+ t.Fatalf("copyFile failed: %v", err)
+ }
+
+ srcInfo, _ := os.Stat(src)
+ dstInfo, _ := os.Stat(dst)
+
+ if srcInfo.Mode().Perm() != dstInfo.Mode().Perm() {
+ t.Errorf("permissions not preserved: src=%v, dst=%v", srcInfo.Mode().Perm(), dstInfo.Mode().Perm())
+ }
+}
+
+// TestCopyFile_SourceNotFound verifies clear error when source doesn't exist.
+func TestCopyFile_SourceNotFound(t *testing.T) {
+ tmpDir := t.TempDir()
+ src := filepath.Join(tmpDir, "nonexistent.json")
+ dst := filepath.Join(tmpDir, "dst.json")
+
+ err := copyFile(src, dst)
+ if err == nil {
+ t.Error("expected error for nonexistent source, got nil")
+ }
+}
+
+// TestWriteWithBackup_TargetIsDirectory verifies error when path points to a directory.
+func TestWriteWithBackup_TargetIsDirectory(t *testing.T) {
+ tmpDir := t.TempDir()
+ dirPath := filepath.Join(tmpDir, "actualdir")
+ os.MkdirAll(dirPath, 0o755)
+
+ err := writeWithBackup(dirPath, []byte(`{"test": true}`))
+ if err == nil {
+ t.Error("expected error when target is a directory, got nil")
+ }
+}
+
+// TestWriteWithBackup_EmptyData verifies writing zero bytes works correctly.
+func TestWriteWithBackup_EmptyData(t *testing.T) {
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "empty.json")
+
+ err := writeWithBackup(path, []byte{})
+ if err != nil {
+ t.Fatalf("writeWithBackup with empty data failed: %v", err)
+ }
+
+ content, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatalf("could not read file: %v", err)
+ }
+ if len(content) != 0 {
+ t.Errorf("expected empty file, got %d bytes", len(content))
+ }
+}
+
+// TestWriteWithBackup_FileUnreadableButDirWritable verifies behavior when existing file
+// cannot be read (for backup comparison) but directory is writable.
+func TestWriteWithBackup_FileUnreadableButDirWritable(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("permission tests unreliable on Windows")
+ }
+
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "unreadable.json")
+
+ // Create file and make it unreadable
+ os.WriteFile(path, []byte(`{"original": true}`), 0o644)
+ os.Chmod(path, 0o000)
+ defer os.Chmod(path, 0o644)
+
+ // Should fail because we can't read the file to compare/backup
+ err := writeWithBackup(path, []byte(`{"updated": true}`))
+ if err == nil {
+ t.Error("expected error when file is unreadable, got nil")
+ }
+}
+
+// TestWriteWithBackup_RapidSuccessiveWrites verifies backup works with multiple writes
+// within the same second (timestamp collision scenario).
+func TestWriteWithBackup_RapidSuccessiveWrites(t *testing.T) {
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "rapid.json")
+
+ // Create initial file
+ os.WriteFile(path, []byte(`{"v": 0}`), 0o644)
+
+ // Rapid successive writes
+ for i := 1; i <= 3; i++ {
+ data := []byte(fmt.Sprintf(`{"v": %d}`, i))
+ if err := writeWithBackup(path, data); err != nil {
+ t.Fatalf("write %d failed: %v", i, err)
+ }
+ }
+
+ // Verify final content
+ content, _ := os.ReadFile(path)
+ if string(content) != `{"v": 3}` {
+ t.Errorf("expected final content {\"v\": 3}, got %s", string(content))
+ }
+
+ // Verify at least one backup exists
+ entries, _ := os.ReadDir(backupDir())
+ var backupCount int
+ for _, e := range entries {
+ if len(e.Name()) > len("rapid.json.") && e.Name()[:len("rapid.json.")] == "rapid.json." {
+ backupCount++
+ }
+ }
+ if backupCount == 0 {
+ t.Error("expected at least one backup file from rapid writes")
+ }
+}
+
+// TestWriteWithBackup_BackupDirIsFile verifies error when backup directory path is a file.
+func TestWriteWithBackup_BackupDirIsFile(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("test modifies system temp directory")
+ }
+
+ // Create a file at the backup directory path
+ backupPath := backupDir()
+ // Clean up any existing directory first
+ os.RemoveAll(backupPath)
+ // Create a file instead of directory
+ os.WriteFile(backupPath, []byte("not a directory"), 0o644)
+ defer func() {
+ os.Remove(backupPath)
+ os.MkdirAll(backupPath, 0o755)
+ }()
+
+ tmpDir := t.TempDir()
+ path := filepath.Join(tmpDir, "test.json")
+ os.WriteFile(path, []byte(`{"original": true}`), 0o644)
+
+ err := writeWithBackup(path, []byte(`{"updated": true}`))
+ if err == nil {
+ t.Error("expected error when backup dir is a file, got nil")
+ }
+}
+
+// TestWriteWithBackup_NoOrphanTempFiles verifies temp files are cleaned up on failure.
+func TestWriteWithBackup_NoOrphanTempFiles(t *testing.T) {
+ if runtime.GOOS == "windows" {
+ t.Skip("permission tests unreliable on Windows")
+ }
+
+ tmpDir := t.TempDir()
+
+ // Count existing temp files
+ countTempFiles := func() int {
+ entries, _ := os.ReadDir(tmpDir)
+ count := 0
+ for _, e := range entries {
+ if len(e.Name()) > 4 && e.Name()[:4] == ".tmp" {
+ count++
+ }
+ }
+ return count
+ }
+
+ before := countTempFiles()
+
+ // Create a file, then make directory read-only to cause rename failure
+ path := filepath.Join(tmpDir, "orphan.json")
+ os.WriteFile(path, []byte(`{"v": 1}`), 0o644)
+
+ // Make a subdirectory and try to write there after making parent read-only
+ subDir := filepath.Join(tmpDir, "subdir")
+ os.MkdirAll(subDir, 0o755)
+ subPath := filepath.Join(subDir, "config.json")
+ os.WriteFile(subPath, []byte(`{"v": 1}`), 0o644)
+
+ // Make subdir read-only after creating temp file would succeed but rename would fail
+ // This is tricky to test - the temp file is created in the same dir, so if we can't
+ // rename, we also couldn't create. Let's just verify normal failure cleanup works.
+
+ // Force a failure by making the target a directory
+ badPath := filepath.Join(tmpDir, "isdir")
+ os.MkdirAll(badPath, 0o755)
+
+ _ = writeWithBackup(badPath, []byte(`{"test": true}`))
+
+ after := countTempFiles()
+ if after > before {
+ t.Errorf("orphan temp files left behind: before=%d, after=%d", before, after)
+ }
+}
diff --git a/cmd/config/integrations.go b/cmd/config/integrations.go
new file mode 100644
index 00000000000..1480de5f30a
--- /dev/null
+++ b/cmd/config/integrations.go
@@ -0,0 +1,1418 @@
+package config
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "runtime"
+ "slices"
+ "strings"
+ "time"
+
+ "github.com/ollama/ollama/api"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
+ "github.com/ollama/ollama/progress"
+ "github.com/spf13/cobra"
+)
+
+// Runners execute the launching of a model with the integration - claude, codex
+// Editors can edit config files (supports multi-model selection) - opencode, droid
+// They are composable interfaces where in some cases an editor is also a runner - opencode, droid
+// Runner can run an integration with a model.
+
+type Runner interface {
+ Run(model string, args []string) error
+ // String returns the human-readable name of the integration
+ String() string
+}
+
+// Editor can edit config files (supports multi-model selection)
+type Editor interface {
+ // Paths returns the paths to the config files for the integration
+ Paths() []string
+ // Edit updates the config files for the integration with the given models
+ Edit(models []string) error
+ // Models returns the models currently configured for the integration
+ // TODO(parthsareen): add error return to Models()
+ Models() []string
+}
+
+// AliasConfigurer can configure model aliases (e.g., for subagent routing).
+// Integrations like Claude and Codex use this to route model requests to local models.
+type AliasConfigurer interface {
+ // ConfigureAliases prompts the user to configure aliases and returns the updated map.
+ ConfigureAliases(ctx context.Context, primaryModel string, existing map[string]string, force bool) (map[string]string, bool, error)
+ // SetAliases syncs the configured aliases to the server
+ SetAliases(ctx context.Context, aliases map[string]string) error
+}
+
+// integrations is the registry of available integrations.
+var integrations = map[string]Runner{
+ "claude": &Claude{},
+ "clawdbot": &Openclaw{},
+ "cline": &Cline{},
+ "codex": &Codex{},
+ "moltbot": &Openclaw{},
+ "droid": &Droid{},
+ "opencode": &OpenCode{},
+ "openclaw": &Openclaw{},
+ "pi": &Pi{},
+}
+
+// recommendedModels are shown when the user has no models or as suggestions.
+// Order matters: local models first, then cloud models.
+var recommendedModels = []ModelItem{
+ {Name: "minimax-m2.5:cloud", Description: "Fast, efficient coding and real-world productivity", Recommended: true},
+ {Name: "glm-5:cloud", Description: "Reasoning and code generation", Recommended: true},
+ {Name: "kimi-k2.5:cloud", Description: "Multimodal reasoning with subagents", Recommended: true},
+ {Name: "glm-4.7-flash", Description: "Reasoning and code generation locally", Recommended: true},
+ {Name: "qwen3:8b", Description: "Efficient all-purpose assistant", Recommended: true},
+}
+
+// cloudModelLimits maps cloud model base names to their token limits.
+// TODO(parthsareen): grab context/output limits from model info instead of hardcoding
+var cloudModelLimits = map[string]cloudModelLimit{
+ "minimax-m2.5": {Context: 204_800, Output: 128_000},
+ "cogito-2.1:671b": {Context: 163_840, Output: 65_536},
+ "deepseek-v3.1:671b": {Context: 163_840, Output: 163_840},
+ "deepseek-v3.2": {Context: 163_840, Output: 65_536},
+ "glm-4.6": {Context: 202_752, Output: 131_072},
+ "glm-4.7": {Context: 202_752, Output: 131_072},
+ "gpt-oss:120b": {Context: 131_072, Output: 131_072},
+ "gpt-oss:20b": {Context: 131_072, Output: 131_072},
+ "kimi-k2:1t": {Context: 262_144, Output: 262_144},
+ "kimi-k2.5": {Context: 262_144, Output: 262_144},
+ "kimi-k2-thinking": {Context: 262_144, Output: 262_144},
+ "nemotron-3-nano:30b": {Context: 1_048_576, Output: 131_072},
+ "qwen3-coder:480b": {Context: 262_144, Output: 65_536},
+ "qwen3-coder-next": {Context: 262_144, Output: 32_768},
+ "qwen3-next:80b": {Context: 262_144, Output: 32_768},
+}
+
+// recommendedVRAM maps local recommended models to their approximate VRAM requirement.
+var recommendedVRAM = map[string]string{
+ "glm-4.7-flash": "~25GB",
+ "qwen3:8b": "~11GB",
+}
+
+// integrationAliases are hidden from the interactive selector but work as CLI arguments.
+var integrationAliases = map[string]bool{
+ "clawdbot": true,
+ "moltbot": true,
+}
+
+// integrationInstallHints maps integration names to install URLs.
+var integrationInstallHints = map[string]string{
+ "claude": "https://code.claude.com/docs/en/quickstart",
+ "cline": "https://cline.bot/cli",
+ "openclaw": "https://docs.openclaw.ai",
+ "codex": "https://developers.openai.com/codex/cli/",
+ "droid": "https://docs.factory.ai/cli/getting-started/quickstart",
+ "opencode": "https://opencode.ai",
+ "pi": "https://github.com/badlogic/pi-mono",
+}
+
+// hyperlink wraps text in an OSC 8 terminal hyperlink so it is cmd+clickable.
+func hyperlink(url, text string) string {
+ return fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", url, text)
+}
+
+// IntegrationInfo contains display information about a registered integration.
+type IntegrationInfo struct {
+ Name string // registry key, e.g. "claude"
+ DisplayName string // human-readable, e.g. "Claude Code"
+ Description string // short description, e.g. "Anthropic's agentic coding tool"
+}
+
+// integrationDescriptions maps integration names to short descriptions.
+var integrationDescriptions = map[string]string{
+ "claude": "Anthropic's coding tool with subagents",
+ "cline": "Autonomous coding agent with parallel execution",
+ "codex": "OpenAI's open-source coding agent",
+ "openclaw": "Personal AI with 100+ skills",
+ "droid": "Factory's coding agent across terminal and IDEs",
+ "opencode": "Anomaly's open-source coding agent",
+ "pi": "Minimal AI agent toolkit with plugin support",
+}
+
+// integrationOrder defines a custom display order for integrations.
+// Integrations listed here are placed at the end in the given order;
+// all others appear first, sorted alphabetically.
+var integrationOrder = []string{"opencode", "droid", "pi", "cline"}
+
+// ListIntegrationInfos returns all non-alias registered integrations, sorted by name
+// with integrationOrder entries placed at the end.
+func ListIntegrationInfos() []IntegrationInfo {
+ var result []IntegrationInfo
+ for name, r := range integrations {
+ if integrationAliases[name] {
+ continue
+ }
+ result = append(result, IntegrationInfo{
+ Name: name,
+ DisplayName: r.String(),
+ Description: integrationDescriptions[name],
+ })
+ }
+
+ orderRank := make(map[string]int, len(integrationOrder))
+ for i, name := range integrationOrder {
+ orderRank[name] = i + 1 // 1-indexed so 0 means "not in the list"
+ }
+
+ slices.SortFunc(result, func(a, b IntegrationInfo) int {
+ aRank, bRank := orderRank[a.Name], orderRank[b.Name]
+ // Both have custom order: sort by their rank
+ if aRank > 0 && bRank > 0 {
+ return aRank - bRank
+ }
+ // Only one has custom order: it goes last
+ if aRank > 0 {
+ return 1
+ }
+ if bRank > 0 {
+ return -1
+ }
+ // Neither has custom order: alphabetical
+ return strings.Compare(a.Name, b.Name)
+ })
+ return result
+}
+
+// IntegrationInstallHint returns a user-friendly install hint for the given integration,
+// or an empty string if none is available. The URL is wrapped in an OSC 8 hyperlink
+// so it is cmd+clickable in supported terminals.
+func IntegrationInstallHint(name string) string {
+ url := integrationInstallHints[name]
+ if url == "" {
+ return ""
+ }
+ return "Install from " + hyperlink(url, url)
+}
+
+// IsIntegrationInstalled checks if an integration binary is installed.
+func IsIntegrationInstalled(name string) bool {
+ switch name {
+ case "claude":
+ c := &Claude{}
+ _, err := c.findPath()
+ return err == nil
+ case "openclaw":
+ if _, err := exec.LookPath("openclaw"); err == nil {
+ return true
+ }
+ if _, err := exec.LookPath("clawdbot"); err == nil {
+ return true
+ }
+ return false
+ case "codex":
+ _, err := exec.LookPath("codex")
+ return err == nil
+ case "droid":
+ _, err := exec.LookPath("droid")
+ return err == nil
+ case "cline":
+ _, err := exec.LookPath("cline")
+ return err == nil
+ case "opencode":
+ _, err := exec.LookPath("opencode")
+ return err == nil
+ case "pi":
+ _, err := exec.LookPath("pi")
+ return err == nil
+ default:
+ return true // Assume installed for unknown integrations
+ }
+}
+
+// IsEditorIntegration returns true if the named integration uses multi-model
+// selection (implements the Editor interface).
+func IsEditorIntegration(name string) bool {
+ r, ok := integrations[strings.ToLower(name)]
+ if !ok {
+ return false
+ }
+ _, isEditor := r.(Editor)
+ return isEditor
+}
+
+// SelectModel lets the user select a model to run.
+// ModelItem represents a model for selection.
+type ModelItem struct {
+ Name string
+ Description string
+ Recommended bool
+}
+
+// SingleSelector is a function type for single item selection.
+// current is the name of the previously selected item to highlight; empty means no pre-selection.
+type SingleSelector func(title string, items []ModelItem, current string) (string, error)
+
+// MultiSelector is a function type for multi item selection.
+type MultiSelector func(title string, items []ModelItem, preChecked []string) ([]string, error)
+
+// SelectModelWithSelector prompts the user to select a model using the provided selector.
+func SelectModelWithSelector(ctx context.Context, selector SingleSelector) (string, error) {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return "", err
+ }
+
+ models, err := client.List(ctx)
+ if err != nil {
+ return "", err
+ }
+
+ var existing []modelInfo
+ for _, m := range models.Models {
+ existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
+ }
+
+ cloudDisabled, _ := cloudStatusDisabled(ctx, client)
+ if cloudDisabled {
+ existing = filterCloudModels(existing)
+ }
+
+ lastModel := LastModel()
+ var preChecked []string
+ if lastModel != "" {
+ preChecked = []string{lastModel}
+ }
+
+ items, _, existingModels, cloudModels := buildModelList(existing, preChecked, lastModel)
+
+ if cloudDisabled {
+ items = filterCloudItems(items)
+ }
+
+ if len(items) == 0 {
+ return "", fmt.Errorf("no models available, run 'ollama pull
' first")
+ }
+
+ selected, err := selector("Select model to run:", items, "")
+ if err != nil {
+ return "", err
+ }
+
+ // If the selected model isn't installed, pull it first
+ if !existingModels[selected] {
+ if cloudModels[selected] {
+ // Cloud models only pull a small manifest; no confirmation needed
+ if err := pullModel(ctx, client, selected); err != nil {
+ return "", fmt.Errorf("failed to pull %s: %w", selected, err)
+ }
+ } else {
+ msg := fmt.Sprintf("Download %s?", selected)
+ if ok, err := confirmPrompt(msg); err != nil {
+ return "", err
+ } else if !ok {
+ return "", errCancelled
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+ if err := pullModel(ctx, client, selected); err != nil {
+ return "", fmt.Errorf("failed to pull %s: %w", selected, err)
+ }
+ }
+ }
+
+ // If it's a cloud model, ensure user is signed in
+ if cloudModels[selected] {
+ user, err := client.Whoami(ctx)
+ if err == nil && user != nil && user.Name != "" {
+ return selected, nil
+ }
+
+ var aErr api.AuthorizationError
+ if !errors.As(err, &aErr) || aErr.SigninURL == "" {
+ return "", err
+ }
+
+ yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", selected))
+ if err != nil || !yes {
+ return "", fmt.Errorf("%s requires sign in", selected)
+ }
+
+ fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
+
+ // Auto-open browser (best effort, fail silently)
+ switch runtime.GOOS {
+ case "darwin":
+ _ = exec.Command("open", aErr.SigninURL).Start()
+ case "linux":
+ _ = exec.Command("xdg-open", aErr.SigninURL).Start()
+ case "windows":
+ _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", aErr.SigninURL).Start()
+ }
+
+ spinnerFrames := []string{"|", "/", "-", "\\"}
+ frame := 0
+
+ fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
+
+ ticker := time.NewTicker(200 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ fmt.Fprintf(os.Stderr, "\r\033[K")
+ return "", ctx.Err()
+ case <-ticker.C:
+ frame++
+ fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
+
+ // poll every 10th frame (~2 seconds)
+ if frame%10 == 0 {
+ u, err := client.Whoami(ctx)
+ if err == nil && u != nil && u.Name != "" {
+ fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
+ return selected, nil
+ }
+ }
+ }
+ }
+ }
+
+ return selected, nil
+}
+
+func SelectModel(ctx context.Context) (string, error) {
+ return SelectModelWithSelector(ctx, DefaultSingleSelector)
+}
+
+// DefaultSingleSelector is the default single-select implementation.
+var DefaultSingleSelector SingleSelector
+
+// DefaultMultiSelector is the default multi-select implementation.
+var DefaultMultiSelector MultiSelector
+
+// DefaultSignIn provides a TUI-based sign-in flow.
+// When set, ensureAuth uses it instead of plain text prompts.
+// Returns the signed-in username or an error.
+var DefaultSignIn func(modelName, signInURL string) (string, error)
+
+func selectIntegration() (string, error) {
+ if DefaultSingleSelector == nil {
+ return "", fmt.Errorf("no selector configured")
+ }
+ if len(integrations) == 0 {
+ return "", fmt.Errorf("no integrations available")
+ }
+
+ var items []ModelItem
+ for name, r := range integrations {
+ if integrationAliases[name] {
+ continue
+ }
+ description := r.String()
+ if conn, err := loadIntegration(name); err == nil && len(conn.Models) > 0 {
+ description = fmt.Sprintf("%s (%s)", r.String(), conn.Models[0])
+ }
+ items = append(items, ModelItem{Name: name, Description: description})
+ }
+
+ orderRank := make(map[string]int, len(integrationOrder))
+ for i, name := range integrationOrder {
+ orderRank[name] = i + 1
+ }
+ slices.SortFunc(items, func(a, b ModelItem) int {
+ aRank, bRank := orderRank[a.Name], orderRank[b.Name]
+ if aRank > 0 && bRank > 0 {
+ return aRank - bRank
+ }
+ if aRank > 0 {
+ return 1
+ }
+ if bRank > 0 {
+ return -1
+ }
+ return strings.Compare(a.Name, b.Name)
+ })
+
+ return DefaultSingleSelector("Select integration:", items, "")
+}
+
+// selectModelsWithSelectors lets the user select models for an integration using provided selectors.
+func selectModelsWithSelectors(ctx context.Context, name, current string, single SingleSelector, multi MultiSelector) ([]string, error) {
+ r, ok := integrations[name]
+ if !ok {
+ return nil, fmt.Errorf("unknown integration: %s", name)
+ }
+
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return nil, err
+ }
+
+ models, err := client.List(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ var existing []modelInfo
+ for _, m := range models.Models {
+ existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
+ }
+
+ cloudDisabled, _ := cloudStatusDisabled(ctx, client)
+ if cloudDisabled {
+ existing = filterCloudModels(existing)
+ }
+
+ var preChecked []string
+ if saved, err := loadIntegration(name); err == nil {
+ preChecked = saved.Models
+ } else if editor, ok := r.(Editor); ok {
+ preChecked = editor.Models()
+ }
+
+ items, preChecked, existingModels, cloudModels := buildModelList(existing, preChecked, current)
+
+ if cloudDisabled {
+ items = filterCloudItems(items)
+ }
+
+ if len(items) == 0 {
+ return nil, fmt.Errorf("no models available")
+ }
+
+ var selected []string
+ if _, ok := r.(Editor); ok {
+ selected, err = multi(fmt.Sprintf("Select models for %s:", r), items, preChecked)
+ if err != nil {
+ return nil, err
+ }
+ } else {
+ prompt := fmt.Sprintf("Select model for %s:", r)
+ if _, ok := r.(AliasConfigurer); ok {
+ prompt = fmt.Sprintf("Select Primary model for %s:", r)
+ }
+ model, err := single(prompt, items, current)
+ if err != nil {
+ return nil, err
+ }
+ selected = []string{model}
+ }
+
+ var toPull []string
+ for _, m := range selected {
+ if !existingModels[m] {
+ toPull = append(toPull, m)
+ }
+ }
+ if len(toPull) > 0 {
+ msg := fmt.Sprintf("Download %s?", strings.Join(toPull, ", "))
+ if ok, err := confirmPrompt(msg); err != nil {
+ return nil, err
+ } else if !ok {
+ return nil, errCancelled
+ }
+ for _, m := range toPull {
+ fmt.Fprintf(os.Stderr, "\n")
+ if err := pullModel(ctx, client, m); err != nil {
+ return nil, fmt.Errorf("failed to pull %s: %w", m, err)
+ }
+ }
+ }
+
+ if err := ensureAuth(ctx, client, cloudModels, selected); err != nil {
+ return nil, err
+ }
+
+ return selected, nil
+}
+
+func pullIfNeeded(ctx context.Context, client *api.Client, existingModels map[string]bool, model string) error {
+ if existingModels[model] {
+ return nil
+ }
+ msg := fmt.Sprintf("Download %s?", model)
+ if ok, err := confirmPrompt(msg); err != nil {
+ return err
+ } else if !ok {
+ return errCancelled
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+ if err := pullModel(ctx, client, model); err != nil {
+ return fmt.Errorf("failed to pull %s: %w", model, err)
+ }
+ return nil
+}
+
+// TODO(parthsareen): pull this out to tui package
+// ShowOrPull checks if a model exists via client.Show and offers to pull it if not found.
+func ShowOrPull(ctx context.Context, client *api.Client, model string) error {
+ if _, err := client.Show(ctx, &api.ShowRequest{Model: model}); err == nil {
+ return nil
+ }
+ // Cloud models only pull a small manifest; skip the download confirmation
+ // TODO(parthsareen): consolidate with cloud config changes
+ if strings.HasSuffix(model, "cloud") {
+ return pullModel(ctx, client, model)
+ }
+ if ok, err := confirmPrompt(fmt.Sprintf("Download %s?", model)); err != nil {
+ return err
+ } else if !ok {
+ return errCancelled
+ }
+ fmt.Fprintf(os.Stderr, "\n")
+ return pullModel(ctx, client, model)
+}
+
+func listModels(ctx context.Context) ([]ModelItem, map[string]bool, map[string]bool, *api.Client, error) {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ models, err := client.List(ctx)
+ if err != nil {
+ return nil, nil, nil, nil, err
+ }
+
+ var existing []modelInfo
+ for _, m := range models.Models {
+ existing = append(existing, modelInfo{
+ Name: m.Name,
+ Remote: m.RemoteModel != "",
+ })
+ }
+
+ cloudDisabled, _ := cloudStatusDisabled(ctx, client)
+ if cloudDisabled {
+ existing = filterCloudModels(existing)
+ }
+
+ items, _, existingModels, cloudModels := buildModelList(existing, nil, "")
+
+ if cloudDisabled {
+ items = filterCloudItems(items)
+ }
+
+ if len(items) == 0 {
+ return nil, nil, nil, nil, fmt.Errorf("no models available, run 'ollama pull ' first")
+ }
+
+ return items, existingModels, cloudModels, client, nil
+}
+
+func OpenBrowser(url string) {
+ switch runtime.GOOS {
+ case "darwin":
+ _ = exec.Command("open", url).Start()
+ case "linux":
+ _ = exec.Command("xdg-open", url).Start()
+ case "windows":
+ _ = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start()
+ }
+}
+
+func ensureAuth(ctx context.Context, client *api.Client, cloudModels map[string]bool, selected []string) error {
+ var selectedCloudModels []string
+ for _, m := range selected {
+ if cloudModels[m] {
+ selectedCloudModels = append(selectedCloudModels, m)
+ }
+ }
+ if len(selectedCloudModels) == 0 {
+ return nil
+ }
+ if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
+ return errors.New(internalcloud.DisabledError("remote inference is unavailable"))
+ }
+
+ user, err := client.Whoami(ctx)
+ if err == nil && user != nil && user.Name != "" {
+ return nil
+ }
+
+ var aErr api.AuthorizationError
+ if !errors.As(err, &aErr) || aErr.SigninURL == "" {
+ return err
+ }
+
+ modelList := strings.Join(selectedCloudModels, ", ")
+
+ if DefaultSignIn != nil {
+ _, err := DefaultSignIn(modelList, aErr.SigninURL)
+ if err != nil {
+ return fmt.Errorf("%s requires sign in", modelList)
+ }
+ return nil
+ }
+
+ // Fallback: plain text sign-in flow
+ yes, err := confirmPrompt(fmt.Sprintf("sign in to use %s?", modelList))
+ if err != nil || !yes {
+ return fmt.Errorf("%s requires sign in", modelList)
+ }
+
+ fmt.Fprintf(os.Stderr, "\nTo sign in, navigate to:\n %s\n\n", aErr.SigninURL)
+
+ OpenBrowser(aErr.SigninURL)
+
+ spinnerFrames := []string{"|", "/", "-", "\\"}
+ frame := 0
+
+ fmt.Fprintf(os.Stderr, "\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[0])
+
+ ticker := time.NewTicker(200 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ fmt.Fprintf(os.Stderr, "\r\033[K")
+ return ctx.Err()
+ case <-ticker.C:
+ frame++
+ fmt.Fprintf(os.Stderr, "\r\033[90mwaiting for sign in to complete... %s\033[0m", spinnerFrames[frame%len(spinnerFrames)])
+
+ // poll every 10th frame (~2 seconds)
+ if frame%10 == 0 {
+ u, err := client.Whoami(ctx)
+ if err == nil && u != nil && u.Name != "" {
+ fmt.Fprintf(os.Stderr, "\r\033[K\033[A\r\033[K\033[1msigned in:\033[0m %s\n", u.Name)
+ return nil
+ }
+ }
+ }
+ }
+}
+
+// selectModels lets the user select models for an integration using default selectors.
+func selectModels(ctx context.Context, name, current string) ([]string, error) {
+ return selectModelsWithSelectors(ctx, name, current, DefaultSingleSelector, DefaultMultiSelector)
+}
+
+func runIntegration(name, modelName string, args []string) error {
+ r, ok := integrations[name]
+ if !ok {
+ return fmt.Errorf("unknown integration: %s", name)
+ }
+
+ fmt.Fprintf(os.Stderr, "\nLaunching %s with %s...\n", r, modelName)
+ return r.Run(modelName, args)
+}
+
+// syncAliases syncs aliases to server and saves locally for an AliasConfigurer.
+func syncAliases(ctx context.Context, client *api.Client, ac AliasConfigurer, name, model string, existing map[string]string) error {
+ aliases := make(map[string]string)
+ for k, v := range existing {
+ aliases[k] = v
+ }
+ aliases["primary"] = model
+
+ if isCloudModel(ctx, client, model) {
+ if aliases["fast"] == "" || !isCloudModel(ctx, client, aliases["fast"]) {
+ aliases["fast"] = model
+ }
+ } else {
+ delete(aliases, "fast")
+ }
+
+ if err := ac.SetAliases(ctx, aliases); err != nil {
+ return err
+ }
+ return saveAliases(name, aliases)
+}
+
+// LaunchIntegration launches the named integration using saved config or prompts for setup.
+func LaunchIntegration(name string) error {
+ r, ok := integrations[name]
+ if !ok {
+ return fmt.Errorf("unknown integration: %s", name)
+ }
+
+ // Try to use saved config
+ if ic, err := loadIntegration(name); err == nil && len(ic.Models) > 0 {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+ if err := ShowOrPull(context.Background(), client, ic.Models[0]); err != nil {
+ return err
+ }
+ return runIntegration(name, ic.Models[0], nil)
+ }
+
+ // No saved config - prompt user to run setup
+ return fmt.Errorf("%s is not configured. Run 'ollama launch %s' to set it up", r, name)
+}
+
+// LaunchIntegrationWithModel launches the named integration with the specified model.
+func LaunchIntegrationWithModel(name, modelName string) error {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+ if err := ShowOrPull(context.Background(), client, modelName); err != nil {
+ return err
+ }
+ return runIntegration(name, modelName, nil)
+}
+
+// SaveAndEditIntegration saves the models for an Editor integration and runs its Edit method
+// to write the integration's config files.
+func SaveAndEditIntegration(name string, models []string) error {
+ r, ok := integrations[strings.ToLower(name)]
+ if !ok {
+ return fmt.Errorf("unknown integration: %s", name)
+ }
+ if err := SaveIntegration(name, models); err != nil {
+ return fmt.Errorf("failed to save: %w", err)
+ }
+ if editor, isEditor := r.(Editor); isEditor {
+ if err := editor.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+ }
+ return nil
+}
+
+// resolveEditorModels filters out cloud-disabled models before editor launch.
+// If no models remain, it invokes picker to collect a valid replacement list.
+func resolveEditorModels(name string, models []string, picker func() ([]string, error)) ([]string, error) {
+ filtered := filterDisabledCloudModels(models)
+ if len(filtered) != len(models) {
+ if err := SaveIntegration(name, filtered); err != nil {
+ return nil, fmt.Errorf("failed to save: %w", err)
+ }
+ }
+ if len(filtered) > 0 {
+ return filtered, nil
+ }
+
+ selected, err := picker()
+ if err != nil {
+ return nil, err
+ }
+ if err := SaveIntegration(name, selected); err != nil {
+ return nil, fmt.Errorf("failed to save: %w", err)
+ }
+ return selected, nil
+}
+
+// ConfigureIntegrationWithSelectors allows the user to select/change the model for an integration using custom selectors.
+func ConfigureIntegrationWithSelectors(ctx context.Context, name string, single SingleSelector, multi MultiSelector) error {
+ r, ok := integrations[name]
+ if !ok {
+ return fmt.Errorf("unknown integration: %s", name)
+ }
+
+ models, err := selectModelsWithSelectors(ctx, name, "", single, multi)
+ if errors.Is(err, errCancelled) {
+ return errCancelled
+ }
+ if err != nil {
+ return err
+ }
+
+ if editor, isEditor := r.(Editor); isEditor {
+ paths := editor.Paths()
+ if len(paths) > 0 {
+ fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
+ for _, p := range paths {
+ fmt.Fprintf(os.Stderr, " %s\n", p)
+ }
+ fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
+
+ if ok, _ := confirmPrompt("Proceed?"); !ok {
+ return nil
+ }
+ }
+
+ if err := editor.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+ }
+
+ if err := SaveIntegration(name, models); err != nil {
+ return fmt.Errorf("failed to save: %w", err)
+ }
+
+ if len(models) == 1 {
+ fmt.Fprintf(os.Stderr, "Configured %s with %s\n", r, models[0])
+ } else {
+ fmt.Fprintf(os.Stderr, "Configured %s with %d models (default: %s)\n", r, len(models), models[0])
+ }
+
+ return nil
+}
+
+// ConfigureIntegration allows the user to select/change the model for an integration.
+func ConfigureIntegration(ctx context.Context, name string) error {
+ return ConfigureIntegrationWithSelectors(ctx, name, DefaultSingleSelector, DefaultMultiSelector)
+}
+
+// LaunchCmd returns the cobra command for launching integrations.
+// The runTUI callback is called when no arguments are provided (alias for main TUI).
+func LaunchCmd(checkServerHeartbeat func(cmd *cobra.Command, args []string) error, runTUI func(cmd *cobra.Command)) *cobra.Command {
+ var modelFlag string
+ var configFlag bool
+
+ cmd := &cobra.Command{
+ Use: "launch [INTEGRATION] [-- [EXTRA_ARGS...]]",
+ Short: "Launch the Ollama menu or an integration",
+ Long: `Launch the Ollama interactive menu, or directly launch a specific integration.
+
+Without arguments, this is equivalent to running 'ollama' directly.
+
+Supported integrations:
+ claude Claude Code
+ cline Cline
+ codex Codex
+ droid Droid
+ opencode OpenCode
+ openclaw OpenClaw (aliases: clawdbot, moltbot)
+ pi Pi
+
+Examples:
+ ollama launch
+ ollama launch claude
+ ollama launch claude --model
+ ollama launch droid --config (does not auto-launch)
+ ollama launch codex -- -p myprofile (pass extra args to integration)
+ ollama launch codex -- --sandbox workspace-write`,
+ Args: cobra.ArbitraryArgs,
+ PreRunE: checkServerHeartbeat,
+ RunE: func(cmd *cobra.Command, args []string) error {
+ // No args and no flags - show the full TUI (same as bare 'ollama')
+ if len(args) == 0 && modelFlag == "" && !configFlag {
+ runTUI(cmd)
+ return nil
+ }
+
+ // Extract integration name and args to pass through using -- separator
+ var name string
+ var passArgs []string
+ dashIdx := cmd.ArgsLenAtDash()
+
+ if dashIdx == -1 {
+ // No "--" separator: only allow 0 or 1 args (integration name)
+ if len(args) > 1 {
+ return fmt.Errorf("unexpected arguments: %v\nUse '--' to pass extra arguments to the integration", args[1:])
+ }
+ if len(args) == 1 {
+ name = args[0]
+ }
+ } else {
+ // "--" was used: args before it = integration name, args after = passthrough
+ if dashIdx > 1 {
+ return fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
+ }
+ if dashIdx == 1 {
+ name = args[0]
+ }
+ passArgs = args[dashIdx:]
+ }
+
+ if name == "" {
+ var err error
+ name, err = selectIntegration()
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ r, ok := integrations[strings.ToLower(name)]
+ if !ok {
+ return fmt.Errorf("unknown integration: %s", name)
+ }
+
+ if modelFlag != "" && IsCloudModelDisabled(cmd.Context(), modelFlag) {
+ modelFlag = ""
+ }
+
+ // Handle AliasConfigurer integrations (claude, codex)
+ if ac, ok := r.(AliasConfigurer); ok {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+
+ // Validate --model flag if provided
+ if modelFlag != "" {
+ if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil {
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ return err
+ }
+ }
+
+ var model string
+ var existingAliases map[string]string
+
+ // Load saved config
+ if cfg, err := loadIntegration(name); err == nil {
+ existingAliases = cfg.Aliases
+ if len(cfg.Models) > 0 {
+ model = cfg.Models[0]
+ // AliasConfigurer integrations use single model; sanitize if multiple
+ if len(cfg.Models) > 1 {
+ _ = SaveIntegration(name, []string{model})
+ }
+ }
+ }
+
+ // --model flag overrides saved model
+ if modelFlag != "" {
+ model = modelFlag
+ }
+
+ // Validate saved model still exists
+ if model != "" && modelFlag == "" {
+ if disabled, _ := cloudStatusDisabled(cmd.Context(), client); disabled && isCloudModelName(model) {
+ model = ""
+ } else if _, err := client.Show(cmd.Context(), &api.ShowRequest{Model: model}); err != nil {
+ fmt.Fprintf(os.Stderr, "%sConfigured model %q not found%s\n\n", ansiGray, model, ansiReset)
+ if err := ShowOrPull(cmd.Context(), client, model); err != nil {
+ model = ""
+ }
+ }
+ }
+
+ // Show picker so user can change model (skip when --model flag provided)
+ aliases, _, err := ac.ConfigureAliases(cmd.Context(), model, existingAliases, modelFlag == "")
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ model = aliases["primary"]
+ existingAliases = aliases
+
+ // Ensure cloud models are authenticated
+ if isCloudModel(cmd.Context(), client, model) {
+ if err := ensureAuth(cmd.Context(), client, map[string]bool{model: true}, []string{model}); err != nil {
+ return err
+ }
+ }
+
+ // Sync aliases and save
+ if err := syncAliases(cmd.Context(), client, ac, name, model, existingAliases); err != nil {
+ fmt.Fprintf(os.Stderr, "%sWarning: Could not sync aliases: %v%s\n", ansiGray, err, ansiReset)
+ }
+ if err := SaveIntegration(name, []string{model}); err != nil {
+ return fmt.Errorf("failed to save: %w", err)
+ }
+
+ // Launch (unless --config without confirmation)
+ if configFlag {
+ if launch, _ := confirmPrompt(fmt.Sprintf("Launch %s now?", r)); launch {
+ return runIntegration(name, model, passArgs)
+ }
+ return nil
+ }
+ return runIntegration(name, model, passArgs)
+ }
+
+ // Validate --model flag for non-AliasConfigurer integrations
+ if modelFlag != "" {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return err
+ }
+ if err := ShowOrPull(cmd.Context(), client, modelFlag); err != nil {
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ return err
+ }
+ }
+
+ var models []string
+ if modelFlag != "" {
+ models = []string{modelFlag}
+ if existing, err := loadIntegration(name); err == nil && len(existing.Models) > 0 {
+ for _, m := range existing.Models {
+ if m != modelFlag {
+ models = append(models, m)
+ }
+ }
+ }
+ models = filterDisabledCloudModels(models)
+ if len(models) == 0 {
+ var err error
+ models, err = selectModels(cmd.Context(), name, "")
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ }
+ } else {
+ current := ""
+ if saved, err := loadIntegration(name); err == nil && len(saved.Models) > 0 {
+ current = saved.Models[0]
+ }
+ var err error
+ models, err = selectModels(cmd.Context(), name, current)
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ if editor, isEditor := r.(Editor); isEditor {
+ paths := editor.Paths()
+ if len(paths) > 0 {
+ fmt.Fprintf(os.Stderr, "This will modify your %s configuration:\n", r)
+ for _, p := range paths {
+ fmt.Fprintf(os.Stderr, " %s\n", p)
+ }
+ fmt.Fprintf(os.Stderr, "Backups will be saved to %s/\n\n", backupDir())
+
+ if ok, _ := confirmPrompt("Proceed?"); !ok {
+ return nil
+ }
+ }
+ }
+
+ if err := SaveIntegration(name, models); err != nil {
+ return fmt.Errorf("failed to save: %w", err)
+ }
+
+ if editor, isEditor := r.(Editor); isEditor {
+ if err := editor.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+ }
+
+ if _, isEditor := r.(Editor); isEditor {
+ if len(models) == 1 {
+ fmt.Fprintf(os.Stderr, "Added %s to %s\n", models[0], r)
+ } else {
+ fmt.Fprintf(os.Stderr, "Added %d models to %s (default: %s)\n", len(models), r, models[0])
+ }
+ }
+
+ if configFlag {
+ if launch, _ := confirmPrompt(fmt.Sprintf("\nLaunch %s now?", r)); launch {
+ return runIntegration(name, models[0], passArgs)
+ }
+ fmt.Fprintf(os.Stderr, "Run 'ollama launch %s' to start with %s\n", strings.ToLower(name), models[0])
+ return nil
+ }
+
+ return runIntegration(name, models[0], passArgs)
+ },
+ }
+
+ cmd.Flags().StringVar(&modelFlag, "model", "", "Model to use")
+ cmd.Flags().BoolVar(&configFlag, "config", false, "Configure without launching")
+ return cmd
+}
+
+type modelInfo struct {
+ Name string
+ Remote bool
+ ToolCapable bool
+}
+
+// buildModelList merges existing models with recommendations, sorts them, and returns
+// the ordered items along with maps of existing and cloud model names.
+func buildModelList(existing []modelInfo, preChecked []string, current string) (items []ModelItem, orderedChecked []string, existingModels, cloudModels map[string]bool) {
+ existingModels = make(map[string]bool)
+ cloudModels = make(map[string]bool)
+ recommended := make(map[string]bool)
+ var hasLocalModel, hasCloudModel bool
+
+ recDesc := make(map[string]string)
+ for _, rec := range recommendedModels {
+ recommended[rec.Name] = true
+ recDesc[rec.Name] = rec.Description
+ }
+
+ for _, m := range existing {
+ existingModels[m.Name] = true
+ if m.Remote {
+ cloudModels[m.Name] = true
+ hasCloudModel = true
+ } else {
+ hasLocalModel = true
+ }
+ displayName := strings.TrimSuffix(m.Name, ":latest")
+ existingModels[displayName] = true
+ item := ModelItem{Name: displayName, Recommended: recommended[displayName], Description: recDesc[displayName]}
+ items = append(items, item)
+ }
+
+ for _, rec := range recommendedModels {
+ if existingModels[rec.Name] || existingModels[rec.Name+":latest"] {
+ continue
+ }
+ items = append(items, rec)
+ if isCloudModelName(rec.Name) {
+ cloudModels[rec.Name] = true
+ }
+ }
+
+ checked := make(map[string]bool, len(preChecked))
+ for _, n := range preChecked {
+ checked[n] = true
+ }
+
+ // Resolve current to full name (e.g., "llama3.2" -> "llama3.2:latest")
+ for _, item := range items {
+ if item.Name == current || strings.HasPrefix(item.Name, current+":") {
+ current = item.Name
+ break
+ }
+ }
+
+ if checked[current] {
+ preChecked = append([]string{current}, slices.DeleteFunc(preChecked, func(m string) bool { return m == current })...)
+ }
+
+ // Non-existing models get "install?" suffix and are pushed to the bottom.
+ // When user has no models, preserve recommended order.
+ notInstalled := make(map[string]bool)
+ for i := range items {
+ if !existingModels[items[i].Name] {
+ notInstalled[items[i].Name] = true
+ var parts []string
+ if items[i].Description != "" {
+ parts = append(parts, items[i].Description)
+ }
+ if vram := recommendedVRAM[items[i].Name]; vram != "" {
+ parts = append(parts, vram)
+ }
+ parts = append(parts, "(not downloaded)")
+ items[i].Description = strings.Join(parts, ", ")
+ }
+ }
+
+ // Build a recommended rank map to preserve ordering within tiers.
+ recRank := make(map[string]int)
+ for i, rec := range recommendedModels {
+ recRank[rec.Name] = i + 1 // 1-indexed; 0 means not recommended
+ }
+
+ onlyLocal := hasLocalModel && !hasCloudModel
+
+ if hasLocalModel || hasCloudModel {
+ slices.SortStableFunc(items, func(a, b ModelItem) int {
+ ac, bc := checked[a.Name], checked[b.Name]
+ aNew, bNew := notInstalled[a.Name], notInstalled[b.Name]
+ aRec, bRec := recRank[a.Name] > 0, recRank[b.Name] > 0
+ aCloud, bCloud := cloudModels[a.Name], cloudModels[b.Name]
+
+ // Checked/pre-selected always first
+ if ac != bc {
+ if ac {
+ return -1
+ }
+ return 1
+ }
+
+ // Recommended above non-recommended
+ if aRec != bRec {
+ if aRec {
+ return -1
+ }
+ return 1
+ }
+
+ // Both recommended
+ if aRec && bRec {
+ if aCloud != bCloud {
+ if onlyLocal {
+ // Local before cloud when only local installed
+ if aCloud {
+ return 1
+ }
+ return -1
+ }
+ // Cloud before local in mixed case
+ if aCloud {
+ return -1
+ }
+ return 1
+ }
+ return recRank[a.Name] - recRank[b.Name]
+ }
+
+ // Both non-recommended: installed before not-installed
+ if aNew != bNew {
+ if aNew {
+ return 1
+ }
+ return -1
+ }
+
+ return strings.Compare(strings.ToLower(a.Name), strings.ToLower(b.Name))
+ })
+ }
+
+ return items, preChecked, existingModels, cloudModels
+}
+
+// IsCloudModelDisabled reports whether the given model name looks like a cloud
+// model and cloud features are currently disabled on the server.
+func IsCloudModelDisabled(ctx context.Context, name string) bool {
+ if !isCloudModelName(name) {
+ return false
+ }
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return false
+ }
+ disabled, _ := cloudStatusDisabled(ctx, client)
+ return disabled
+}
+
+func isCloudModelName(name string) bool {
+ return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
+}
+
+func filterCloudModels(existing []modelInfo) []modelInfo {
+ filtered := existing[:0]
+ for _, m := range existing {
+ if !m.Remote {
+ filtered = append(filtered, m)
+ }
+ }
+ return filtered
+}
+
+// filterDisabledCloudModels removes cloud models from a list when cloud is disabled.
+func filterDisabledCloudModels(models []string) []string {
+ var filtered []string
+ for _, m := range models {
+ if !IsCloudModelDisabled(context.Background(), m) {
+ filtered = append(filtered, m)
+ }
+ }
+ return filtered
+}
+
+func filterCloudItems(items []ModelItem) []ModelItem {
+ filtered := items[:0]
+ for _, item := range items {
+ if !isCloudModelName(item.Name) {
+ filtered = append(filtered, item)
+ }
+ }
+ return filtered
+}
+
+func isCloudModel(ctx context.Context, client *api.Client, name string) bool {
+ if client == nil {
+ return false
+ }
+ resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
+ if err != nil {
+ return false
+ }
+ return resp.RemoteModel != ""
+}
+
+// GetModelItems returns a list of model items including recommendations for the TUI.
+// It includes all locally available models plus recommended models that aren't installed.
+func GetModelItems(ctx context.Context) ([]ModelItem, map[string]bool) {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return nil, nil
+ }
+
+ models, err := client.List(ctx)
+ if err != nil {
+ return nil, nil
+ }
+
+ var existing []modelInfo
+ for _, m := range models.Models {
+ existing = append(existing, modelInfo{Name: m.Name, Remote: m.RemoteModel != ""})
+ }
+
+ cloudDisabled, _ := cloudStatusDisabled(ctx, client)
+ if cloudDisabled {
+ existing = filterCloudModels(existing)
+ }
+
+ lastModel := LastModel()
+ var preChecked []string
+ if lastModel != "" {
+ preChecked = []string{lastModel}
+ }
+
+ items, _, existingModels, _ := buildModelList(existing, preChecked, lastModel)
+
+ if cloudDisabled {
+ items = filterCloudItems(items)
+ }
+
+ return items, existingModels
+}
+
+func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
+ status, err := client.CloudStatusExperimental(ctx)
+ if err != nil {
+ var statusErr api.StatusError
+ if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
+ return false, false
+ }
+ return false, false
+ }
+ return status.Cloud.Disabled, true
+}
+
+func pullModel(ctx context.Context, client *api.Client, model string) error {
+ p := progress.NewProgress(os.Stderr)
+ defer p.Stop()
+
+ bars := make(map[string]*progress.Bar)
+ var status string
+ var spinner *progress.Spinner
+
+ fn := func(resp api.ProgressResponse) error {
+ if resp.Digest != "" {
+ if resp.Completed == 0 {
+ return nil
+ }
+
+ if spinner != nil {
+ spinner.Stop()
+ }
+
+ bar, ok := bars[resp.Digest]
+ if !ok {
+ name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
+ name = strings.TrimSpace(name)
+ if isDigest {
+ name = name[:min(12, len(name))]
+ }
+ bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
+ bars[resp.Digest] = bar
+ p.Add(resp.Digest, bar)
+ }
+
+ bar.Set(resp.Completed)
+ } else if status != resp.Status {
+ if spinner != nil {
+ spinner.Stop()
+ }
+
+ status = resp.Status
+ spinner = progress.NewSpinner(status)
+ p.Add(status, spinner)
+ }
+
+ return nil
+ }
+
+ request := api.PullRequest{Name: model}
+ return client.Pull(ctx, &request, fn)
+}
diff --git a/cmd/config/integrations_test.go b/cmd/config/integrations_test.go
new file mode 100644
index 00000000000..914a8f6610e
--- /dev/null
+++ b/cmd/config/integrations_test.go
@@ -0,0 +1,1429 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "slices"
+ "strings"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ "github.com/ollama/ollama/api"
+ "github.com/spf13/cobra"
+)
+
+type stubEditorRunner struct {
+ edited [][]string
+ ranModel string
+}
+
+func (s *stubEditorRunner) Run(model string, args []string) error {
+ s.ranModel = model
+ return nil
+}
+
+func (s *stubEditorRunner) String() string { return "StubEditor" }
+
+func (s *stubEditorRunner) Paths() []string { return nil }
+
+func (s *stubEditorRunner) Edit(models []string) error {
+ cloned := append([]string(nil), models...)
+ s.edited = append(s.edited, cloned)
+ return nil
+}
+
+func (s *stubEditorRunner) Models() []string { return nil }
+
+func TestIntegrationLookup(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantFound bool
+ wantName string
+ }{
+ {"claude lowercase", "claude", true, "Claude Code"},
+ {"claude uppercase", "CLAUDE", true, "Claude Code"},
+ {"claude mixed case", "Claude", true, "Claude Code"},
+ {"codex", "codex", true, "Codex"},
+ {"droid", "droid", true, "Droid"},
+ {"opencode", "opencode", true, "OpenCode"},
+ {"unknown integration", "unknown", false, ""},
+ {"empty string", "", false, ""},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ r, found := integrations[strings.ToLower(tt.input)]
+ if found != tt.wantFound {
+ t.Errorf("integrations[%q] found = %v, want %v", tt.input, found, tt.wantFound)
+ }
+ if found && r.String() != tt.wantName {
+ t.Errorf("integrations[%q].String() = %q, want %q", tt.input, r.String(), tt.wantName)
+ }
+ })
+ }
+}
+
+func TestIntegrationRegistry(t *testing.T) {
+ expectedIntegrations := []string{"claude", "codex", "droid", "opencode"}
+
+ for _, name := range expectedIntegrations {
+ t.Run(name, func(t *testing.T) {
+ r, ok := integrations[name]
+ if !ok {
+ t.Fatalf("integration %q not found in registry", name)
+ }
+ if r.String() == "" {
+ t.Error("integration.String() should not be empty")
+ }
+ })
+ }
+}
+
+func TestHasLocalModel(t *testing.T) {
+ tests := []struct {
+ name string
+ models []string
+ want bool
+ }{
+ {"empty list", []string{}, false},
+ {"single local model", []string{"llama3.2"}, true},
+ {"single cloud model", []string{"cloud-model"}, false},
+ {"mixed models", []string{"cloud-model", "llama3.2"}, true},
+ {"multiple local models", []string{"llama3.2", "qwen2.5"}, true},
+ {"multiple cloud models", []string{"cloud-a", "cloud-b"}, false},
+ {"local model first", []string{"llama3.2", "cloud-model"}, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := slices.ContainsFunc(tt.models, func(m string) bool {
+ return !strings.Contains(m, "cloud")
+ })
+ if got != tt.want {
+ t.Errorf("hasLocalModel(%v) = %v, want %v", tt.models, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestLaunchCmd(t *testing.T) {
+ // Mock checkServerHeartbeat that always succeeds
+ mockCheck := func(cmd *cobra.Command, args []string) error {
+ return nil
+ }
+ mockTUI := func(cmd *cobra.Command) {}
+ cmd := LaunchCmd(mockCheck, mockTUI)
+
+ t.Run("command structure", func(t *testing.T) {
+ if cmd.Use != "launch [INTEGRATION] [-- [EXTRA_ARGS...]]" {
+ t.Errorf("Use = %q, want %q", cmd.Use, "launch [INTEGRATION] [-- [EXTRA_ARGS...]]")
+ }
+ if cmd.Short == "" {
+ t.Error("Short description should not be empty")
+ }
+ if cmd.Long == "" {
+ t.Error("Long description should not be empty")
+ }
+ })
+
+ t.Run("flags exist", func(t *testing.T) {
+ modelFlag := cmd.Flags().Lookup("model")
+ if modelFlag == nil {
+ t.Error("--model flag should exist")
+ }
+
+ configFlag := cmd.Flags().Lookup("config")
+ if configFlag == nil {
+ t.Error("--config flag should exist")
+ }
+ })
+
+ t.Run("PreRunE is set", func(t *testing.T) {
+ if cmd.PreRunE == nil {
+ t.Error("PreRunE should be set to checkServerHeartbeat")
+ }
+ })
+}
+
+func TestLaunchCmd_TUICallback(t *testing.T) {
+ mockCheck := func(cmd *cobra.Command, args []string) error {
+ return nil
+ }
+
+ t.Run("no args calls TUI", func(t *testing.T) {
+ tuiCalled := false
+ mockTUI := func(cmd *cobra.Command) {
+ tuiCalled = true
+ }
+
+ cmd := LaunchCmd(mockCheck, mockTUI)
+ cmd.SetArgs([]string{})
+ _ = cmd.Execute()
+
+ if !tuiCalled {
+ t.Error("TUI callback should be called when no args provided")
+ }
+ })
+
+ t.Run("integration arg bypasses TUI", func(t *testing.T) {
+ srv := httptest.NewServer(http.NotFoundHandler())
+ defer srv.Close()
+ t.Setenv("OLLAMA_HOST", srv.URL)
+
+ tuiCalled := false
+ mockTUI := func(cmd *cobra.Command) {
+ tuiCalled = true
+ }
+
+ cmd := LaunchCmd(mockCheck, mockTUI)
+ cmd.SetArgs([]string{"claude"})
+ // Will error because claude isn't configured, but that's OK
+ _ = cmd.Execute()
+
+ if tuiCalled {
+ t.Error("TUI callback should NOT be called when integration arg provided")
+ }
+ })
+
+ t.Run("--model flag bypasses TUI", func(t *testing.T) {
+ tuiCalled := false
+ mockTUI := func(cmd *cobra.Command) {
+ tuiCalled = true
+ }
+
+ cmd := LaunchCmd(mockCheck, mockTUI)
+ cmd.SetArgs([]string{"--model", "test-model"})
+ // Will error because no integration specified, but that's OK
+ _ = cmd.Execute()
+
+ if tuiCalled {
+ t.Error("TUI callback should NOT be called when --model flag provided")
+ }
+ })
+
+ t.Run("--config flag bypasses TUI", func(t *testing.T) {
+ tuiCalled := false
+ mockTUI := func(cmd *cobra.Command) {
+ tuiCalled = true
+ }
+
+ cmd := LaunchCmd(mockCheck, mockTUI)
+ cmd.SetArgs([]string{"--config"})
+ // Will error because no integration specified, but that's OK
+ _ = cmd.Execute()
+
+ if tuiCalled {
+ t.Error("TUI callback should NOT be called when --config flag provided")
+ }
+ })
+}
+
+func TestRunIntegration_UnknownIntegration(t *testing.T) {
+ err := runIntegration("unknown-integration", "model", nil)
+ if err == nil {
+ t.Error("expected error for unknown integration, got nil")
+ }
+ if !strings.Contains(err.Error(), "unknown integration") {
+ t.Errorf("error should mention 'unknown integration', got: %v", err)
+ }
+}
+
+func TestHasLocalModel_DocumentsHeuristic(t *testing.T) {
+ tests := []struct {
+ name string
+ models []string
+ want bool
+ reason string
+ }{
+ {"empty list", []string{}, false, "empty list has no local models"},
+ {"contains-cloud-substring", []string{"deepseek-r1:cloud"}, false, "model with 'cloud' substring is considered cloud"},
+ {"cloud-in-name", []string{"my-cloud-model"}, false, "'cloud' anywhere in name = cloud model"},
+ {"cloudless", []string{"cloudless-model"}, false, "'cloudless' still contains 'cloud'"},
+ {"local-model", []string{"llama3.2"}, true, "no 'cloud' = local"},
+ {"mixed", []string{"cloud-model", "llama3.2"}, true, "one local model = hasLocalModel true"},
+ {"all-cloud", []string{"cloud-a", "cloud-b"}, false, "all contain 'cloud'"},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := slices.ContainsFunc(tt.models, func(m string) bool {
+ return !strings.Contains(m, "cloud")
+ })
+ if got != tt.want {
+ t.Errorf("hasLocalModel(%v) = %v, want %v (%s)", tt.models, got, tt.want, tt.reason)
+ }
+ })
+ }
+}
+
+func TestLaunchCmd_NilHeartbeat(t *testing.T) {
+ // This should not panic - cmd creation should work even with nil
+ cmd := LaunchCmd(nil, nil)
+ if cmd == nil {
+ t.Fatal("LaunchCmd returned nil")
+ }
+
+ // PreRunE should be nil when passed nil
+ if cmd.PreRunE != nil {
+ t.Log("Note: PreRunE is set even when nil is passed (acceptable)")
+ }
+}
+
+func TestAllIntegrations_HaveRequiredMethods(t *testing.T) {
+ for name, r := range integrations {
+ t.Run(name, func(t *testing.T) {
+ displayName := r.String()
+ if displayName == "" {
+ t.Error("String() should not return empty")
+ }
+ var _ func(string, []string) error = r.Run
+ })
+ }
+}
+
+func TestParseArgs(t *testing.T) {
+ // Tests reflect cobra's ArgsLenAtDash() semantics:
+ // - cobra strips "--" from args
+ // - ArgsLenAtDash() returns the index where "--" was, or -1
+ tests := []struct {
+ name string
+ args []string // args as cobra delivers them (no "--")
+ dashIdx int // what ArgsLenAtDash() returns
+ wantName string
+ wantArgs []string
+ wantErr bool
+ }{
+ {
+ name: "no extra args, no dash",
+ args: []string{"claude"},
+ dashIdx: -1,
+ wantName: "claude",
+ },
+ {
+ name: "with extra args after --",
+ args: []string{"codex", "-p", "myprofile"},
+ dashIdx: 1,
+ wantName: "codex",
+ wantArgs: []string{"-p", "myprofile"},
+ },
+ {
+ name: "extra args only after --",
+ args: []string{"codex", "--sandbox", "workspace-write"},
+ dashIdx: 1,
+ wantName: "codex",
+ wantArgs: []string{"--sandbox", "workspace-write"},
+ },
+ {
+ name: "-- at end with no args after",
+ args: []string{"claude"},
+ dashIdx: 1,
+ wantName: "claude",
+ },
+ {
+ name: "-- with no integration name",
+ args: []string{"--verbose"},
+ dashIdx: 0,
+ wantName: "",
+ wantArgs: []string{"--verbose"},
+ },
+ {
+ name: "multiple args before -- is error",
+ args: []string{"claude", "codex", "--verbose"},
+ dashIdx: 2,
+ wantErr: true,
+ },
+ {
+ name: "multiple args without -- is error",
+ args: []string{"claude", "codex"},
+ dashIdx: -1,
+ wantErr: true,
+ },
+ {
+ name: "no args, no dash",
+ args: []string{},
+ dashIdx: -1,
+ wantName: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Simulate the parsing logic from LaunchCmd using dashIdx
+ var name string
+ var parsedArgs []string
+ var err error
+
+ dashIdx := tt.dashIdx
+ args := tt.args
+
+ if dashIdx == -1 {
+ if len(args) > 1 {
+ err = fmt.Errorf("unexpected arguments: %v", args[1:])
+ } else if len(args) == 1 {
+ name = args[0]
+ }
+ } else {
+ if dashIdx > 1 {
+ err = fmt.Errorf("expected at most 1 integration name before '--', got %d", dashIdx)
+ } else {
+ if dashIdx == 1 {
+ name = args[0]
+ }
+ parsedArgs = args[dashIdx:]
+ }
+ }
+
+ if tt.wantErr {
+ if err == nil {
+ t.Fatal("expected error, got nil")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if name != tt.wantName {
+ t.Errorf("name = %q, want %q", name, tt.wantName)
+ }
+ if !slices.Equal(parsedArgs, tt.wantArgs) {
+ t.Errorf("args = %v, want %v", parsedArgs, tt.wantArgs)
+ }
+ })
+ }
+}
+
+func TestIsCloudModel(t *testing.T) {
+ // isCloudModel now only uses Show API, so nil client always returns false
+ t.Run("nil client returns false", func(t *testing.T) {
+ models := []string{"glm-5:cloud", "kimi-k2.5:cloud", "local-model"}
+ for _, model := range models {
+ if isCloudModel(context.Background(), nil, model) {
+ t.Errorf("isCloudModel(%q) with nil client should return false", model)
+ }
+ }
+ })
+}
+
+func names(items []ModelItem) []string {
+ var out []string
+ for _, item := range items {
+ out = append(out, item.Name)
+ }
+ return out
+}
+
+func TestBuildModelList_NoExistingModels(t *testing.T) {
+ items, _, _, _ := buildModelList(nil, nil, "")
+
+ want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
+ if diff := cmp.Diff(want, names(items)); diff != "" {
+ t.Errorf("with no existing models, items should be recommended in order (-want +got):\n%s", diff)
+ }
+
+ for _, item := range items {
+ if !strings.HasSuffix(item.Description, "(not downloaded)") {
+ t.Errorf("item %q should have description ending with '(not downloaded)', got %q", item.Name, item.Description)
+ }
+ }
+}
+
+func TestBuildModelList_OnlyLocalModels_CloudRecsAtBottom(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "qwen2.5:latest", Remote: false},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // Recommended pinned at top (local recs first, then cloud recs when only-local), then installed non-recs
+ want := []string{"glm-4.7-flash", "qwen3:8b", "minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "llama3.2", "qwen2.5"}
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("recs pinned at top, local recs before cloud recs (-want +got):\n%s", diff)
+ }
+}
+
+func TestBuildModelList_BothCloudAndLocal_RegularSort(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // All recs pinned at top (cloud before local in mixed case), then non-recs
+ want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("recs pinned at top, cloud recs first in mixed case (-want +got):\n%s", diff)
+ }
+}
+
+func TestBuildModelList_PreCheckedFirst(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
+ got := names(items)
+
+ if got[0] != "llama3.2" {
+ t.Errorf("pre-checked model should be first, got %v", got)
+ }
+}
+
+func TestBuildModelList_ExistingRecommendedMarked(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "glm-4.7-flash", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+
+ for _, item := range items {
+ switch item.Name {
+ case "glm-4.7-flash", "glm-5:cloud":
+ if strings.HasSuffix(item.Description, "(not downloaded)") {
+ t.Errorf("installed recommended %q should not have '(not downloaded)' suffix, got %q", item.Name, item.Description)
+ }
+ case "minimax-m2.5:cloud", "kimi-k2.5:cloud", "qwen3:8b":
+ if !strings.HasSuffix(item.Description, "(not downloaded)") {
+ t.Errorf("non-installed recommended %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
+ }
+ }
+ }
+}
+
+func TestBuildModelList_ExistingCloudModelsNotPushedToBottom(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "glm-4.7-flash", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // glm-4.7-flash and glm-5:cloud are installed so they sort normally;
+ // kimi-k2.5:cloud and qwen3:8b are not installed so they go to the bottom
+ // All recs: cloud first in mixed case, then local, in rec order within each
+ want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b"}
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("all recs, cloud first in mixed case (-want +got):\n%s", diff)
+ }
+}
+
+func TestBuildModelList_HasRecommendedCloudModel_OnlyNonInstalledAtBottom(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "kimi-k2.5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // kimi-k2.5:cloud is installed so it sorts normally;
+ // the rest of the recommendations are not installed so they go to the bottom
+ // All recs pinned at top (cloud first in mixed case), then non-recs
+ want := []string{"minimax-m2.5:cloud", "glm-5:cloud", "kimi-k2.5:cloud", "glm-4.7-flash", "qwen3:8b", "llama3.2"}
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("recs pinned at top, cloud first in mixed case (-want +got):\n%s", diff)
+ }
+
+ for _, item := range items {
+ if !slices.Contains([]string{"kimi-k2.5:cloud", "llama3.2"}, item.Name) {
+ if !strings.HasSuffix(item.Description, "(not downloaded)") {
+ t.Errorf("non-installed %q should have '(not downloaded)' suffix, got %q", item.Name, item.Description)
+ }
+ }
+ }
+}
+
+func TestBuildModelList_LatestTagStripped(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "glm-4.7-flash:latest", Remote: false},
+ {Name: "llama3.2:latest", Remote: false},
+ }
+
+ items, _, existingModels, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // :latest should be stripped from display names
+ for _, name := range got {
+ if strings.HasSuffix(name, ":latest") {
+ t.Errorf("name %q should not have :latest suffix", name)
+ }
+ }
+
+ // glm-4.7-flash should not be duplicated (existing :latest matches the recommendation)
+ count := 0
+ for _, name := range got {
+ if name == "glm-4.7-flash" {
+ count++
+ }
+ }
+ if count != 1 {
+ t.Errorf("glm-4.7-flash should appear exactly once, got %d in %v", count, got)
+ }
+
+ // Stripped name should be in existingModels so it won't be pulled
+ if !existingModels["glm-4.7-flash"] {
+ t.Error("glm-4.7-flash should be in existingModels")
+ }
+}
+
+func TestBuildModelList_ReturnsExistingAndCloudMaps(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ _, _, existingModels, cloudModels := buildModelList(existing, nil, "")
+
+ if !existingModels["llama3.2"] {
+ t.Error("llama3.2 should be in existingModels")
+ }
+ if !existingModels["glm-5:cloud"] {
+ t.Error("glm-5:cloud should be in existingModels")
+ }
+ if existingModels["glm-4.7-flash"] {
+ t.Error("glm-4.7-flash should not be in existingModels (it's a recommendation)")
+ }
+
+ if !cloudModels["glm-5:cloud"] {
+ t.Error("glm-5:cloud should be in cloudModels")
+ }
+ if !cloudModels["kimi-k2.5:cloud"] {
+ t.Error("kimi-k2.5:cloud should be in cloudModels (recommended cloud)")
+ }
+ if cloudModels["llama3.2"] {
+ t.Error("llama3.2 should not be in cloudModels")
+ }
+}
+
+func TestBuildModelList_RecommendedFieldSet(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "glm-4.7-flash", Remote: false},
+ {Name: "llama3.2:latest", Remote: false},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+
+ for _, item := range items {
+ switch item.Name {
+ case "glm-4.7-flash", "qwen3:8b", "glm-5:cloud", "kimi-k2.5:cloud":
+ if !item.Recommended {
+ t.Errorf("%q should have Recommended=true", item.Name)
+ }
+ case "llama3.2":
+ if item.Recommended {
+ t.Errorf("%q should have Recommended=false", item.Name)
+ }
+ }
+ }
+}
+
+func TestBuildModelList_MixedCase_CloudRecsFirst(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // Cloud recs should sort before local recs in mixed case
+ cloudIdx := slices.Index(got, "glm-5:cloud")
+ localIdx := slices.Index(got, "glm-4.7-flash")
+ if cloudIdx > localIdx {
+ t.Errorf("cloud recs should be before local recs in mixed case, got %v", got)
+ }
+}
+
+func TestBuildModelList_OnlyLocal_LocalRecsFirst(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // Local recs should sort before cloud recs in only-local case
+ localIdx := slices.Index(got, "glm-4.7-flash")
+ cloudIdx := slices.Index(got, "glm-5:cloud")
+ if localIdx > cloudIdx {
+ t.Errorf("local recs should be before cloud recs in only-local case, got %v", got)
+ }
+}
+
+func TestBuildModelList_RecsAboveNonRecs(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "custom-model", Remote: false},
+ }
+
+ items, _, _, _ := buildModelList(existing, nil, "")
+ got := names(items)
+
+ // All recommended models should appear before non-recommended installed models
+ lastRecIdx := -1
+ firstNonRecIdx := len(got)
+ for i, name := range got {
+ isRec := name == "glm-4.7-flash" || name == "qwen3:8b" || name == "minimax-m2.5:cloud" || name == "glm-5:cloud" || name == "kimi-k2.5:cloud"
+ if isRec && i > lastRecIdx {
+ lastRecIdx = i
+ }
+ if !isRec && i < firstNonRecIdx {
+ firstNonRecIdx = i
+ }
+ }
+ if lastRecIdx > firstNonRecIdx {
+ t.Errorf("all recs should be above non-recs, got %v", got)
+ }
+}
+
+func TestBuildModelList_CheckedBeforeRecs(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "llama3.2:latest", Remote: false},
+ {Name: "glm-5:cloud", Remote: true},
+ }
+
+ items, _, _, _ := buildModelList(existing, []string{"llama3.2"}, "")
+ got := names(items)
+
+ if got[0] != "llama3.2" {
+ t.Errorf("checked model should be first even before recs, got %v", got)
+ }
+}
+
+func TestEditorIntegration_SavedConfigSkipsSelection(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Save a config for opencode so it looks like a previous launch
+ if err := SaveIntegration("opencode", []string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Verify loadIntegration returns the saved models
+ saved, err := loadIntegration("opencode")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(saved.Models) == 0 {
+ t.Fatal("expected saved models")
+ }
+ if saved.Models[0] != "llama3.2" {
+ t.Errorf("expected llama3.2, got %s", saved.Models[0])
+ }
+}
+
+func TestResolveEditorLaunchModels_PicksWhenAllFiltered(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/status":
+ fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+ t.Setenv("OLLAMA_HOST", srv.URL)
+
+ pickerCalled := false
+ models, err := resolveEditorModels("opencode", []string{"glm-5:cloud"}, func() ([]string, error) {
+ pickerCalled = true
+ return []string{"llama3.2"}, nil
+ })
+ if err != nil {
+ t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
+ }
+ if !pickerCalled {
+ t.Fatal("expected model picker to be called when all models are filtered")
+ }
+ if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
+ t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
+ }
+
+ saved, err := loadIntegration("opencode")
+ if err != nil {
+ t.Fatalf("failed to reload integration config: %v", err)
+ }
+ if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
+ t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestResolveEditorLaunchModels_FiltersAndSkipsPickerWhenLocalRemains(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/status":
+ fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+ t.Setenv("OLLAMA_HOST", srv.URL)
+
+ pickerCalled := false
+ models, err := resolveEditorModels("droid", []string{"llama3.2", "glm-5:cloud"}, func() ([]string, error) {
+ pickerCalled = true
+ return []string{"qwen3:8b"}, nil
+ })
+ if err != nil {
+ t.Fatalf("resolveEditorLaunchModels returned error: %v", err)
+ }
+ if pickerCalled {
+ t.Fatal("picker should not be called when a local model remains")
+ }
+ if diff := cmp.Diff([]string{"llama3.2"}, models); diff != "" {
+ t.Fatalf("resolved models mismatch (-want +got):\n%s", diff)
+ }
+
+ saved, err := loadIntegration("droid")
+ if err != nil {
+ t.Fatalf("failed to reload integration config: %v", err)
+ }
+ if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
+ t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
+ }
+}
+
+func TestLaunchCmd_ModelFlagFiltersDisabledCloudFromSavedConfig(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ if err := SaveIntegration("stubeditor", []string{"glm-5:cloud"}); err != nil {
+ t.Fatalf("failed to seed saved config: %v", err)
+ }
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/status":
+ fmt.Fprintf(w, `{"cloud":{"disabled":true,"source":"config"}}`)
+ case "/api/show":
+ fmt.Fprintf(w, `{"model":"llama3.2"}`)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+ t.Setenv("OLLAMA_HOST", srv.URL)
+
+ stub := &stubEditorRunner{}
+ old, existed := integrations["stubeditor"]
+ integrations["stubeditor"] = stub
+ defer func() {
+ if existed {
+ integrations["stubeditor"] = old
+ } else {
+ delete(integrations, "stubeditor")
+ }
+ }()
+
+ cmd := LaunchCmd(func(cmd *cobra.Command, args []string) error { return nil }, func(cmd *cobra.Command) {})
+ cmd.SetArgs([]string{"stubeditor", "--model", "llama3.2"})
+ if err := cmd.Execute(); err != nil {
+ t.Fatalf("launch command failed: %v", err)
+ }
+
+ saved, err := loadIntegration("stubeditor")
+ if err != nil {
+ t.Fatalf("failed to reload integration config: %v", err)
+ }
+ if diff := cmp.Diff([]string{"llama3.2"}, saved.Models); diff != "" {
+ t.Fatalf("saved models mismatch (-want +got):\n%s", diff)
+ }
+ if diff := cmp.Diff([][]string{{"llama3.2"}}, stub.edited); diff != "" {
+ t.Fatalf("editor models mismatch (-want +got):\n%s", diff)
+ }
+ if stub.ranModel != "llama3.2" {
+ t.Fatalf("expected launch to run with llama3.2, got %q", stub.ranModel)
+ }
+}
+
+func TestAliasConfigurerInterface(t *testing.T) {
+ t.Run("claude implements AliasConfigurer", func(t *testing.T) {
+ claude := &Claude{}
+ if _, ok := interface{}(claude).(AliasConfigurer); !ok {
+ t.Error("Claude should implement AliasConfigurer")
+ }
+ })
+
+ t.Run("codex does not implement AliasConfigurer", func(t *testing.T) {
+ codex := &Codex{}
+ if _, ok := interface{}(codex).(AliasConfigurer); ok {
+ t.Error("Codex should not implement AliasConfigurer")
+ }
+ })
+}
+
+func TestShowOrPull_ModelExists(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintf(w, `{"model":"test-model"}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ err := ShowOrPull(context.Background(), client, "test-model")
+ if err != nil {
+ t.Errorf("showOrPull should return nil when model exists, got: %v", err)
+ }
+}
+
+func TestShowOrPull_ModelNotFound_NoTerminal(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ fmt.Fprintf(w, `{"error":"model not found"}`)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ // confirmPrompt will fail in test (no terminal), so showOrPull should return an error
+ err := ShowOrPull(context.Background(), client, "missing-model")
+ if err == nil {
+ t.Error("showOrPull should return error when model not found and no terminal available")
+ }
+}
+
+func TestShowOrPull_ShowCalledWithCorrectModel(t *testing.T) {
+ var receivedModel string
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ var req api.ShowRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err == nil {
+ receivedModel = req.Model
+ }
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintf(w, `{"model":"%s"}`, receivedModel)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ _ = ShowOrPull(context.Background(), client, "qwen3:8b")
+ if receivedModel != "qwen3:8b" {
+ t.Errorf("expected Show to be called with %q, got %q", "qwen3:8b", receivedModel)
+ }
+}
+
+func TestShowOrPull_ModelNotFound_ConfirmYes_Pulls(t *testing.T) {
+ // Set up hook so confirmPrompt doesn't need a terminal
+ oldHook := DefaultConfirmPrompt
+ DefaultConfirmPrompt = func(prompt string) (bool, error) {
+ if !strings.Contains(prompt, "missing-model") {
+ t.Errorf("expected prompt to contain model name, got %q", prompt)
+ }
+ return true, nil
+ }
+ defer func() { DefaultConfirmPrompt = oldHook }()
+
+ var pullCalled bool
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/show":
+ w.WriteHeader(http.StatusNotFound)
+ fmt.Fprintf(w, `{"error":"model not found"}`)
+ case "/api/pull":
+ pullCalled = true
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintf(w, `{"status":"success"}`)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ err := ShowOrPull(context.Background(), client, "missing-model")
+ if err != nil {
+ t.Errorf("ShowOrPull should succeed after pull, got: %v", err)
+ }
+ if !pullCalled {
+ t.Error("expected pull to be called when user confirms download")
+ }
+}
+
+func TestShowOrPull_ModelNotFound_ConfirmNo_Cancelled(t *testing.T) {
+ oldHook := DefaultConfirmPrompt
+ DefaultConfirmPrompt = func(prompt string) (bool, error) {
+ return false, ErrCancelled
+ }
+ defer func() { DefaultConfirmPrompt = oldHook }()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/show":
+ w.WriteHeader(http.StatusNotFound)
+ fmt.Fprintf(w, `{"error":"model not found"}`)
+ case "/api/pull":
+ t.Error("pull should not be called when user declines")
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ err := ShowOrPull(context.Background(), client, "missing-model")
+ if err == nil {
+ t.Error("ShowOrPull should return error when user declines")
+ }
+}
+
+func TestShowOrPull_CloudModel_SkipsConfirmation(t *testing.T) {
+ // Confirm prompt should NOT be called for cloud models
+ oldHook := DefaultConfirmPrompt
+ DefaultConfirmPrompt = func(prompt string) (bool, error) {
+ t.Error("confirm prompt should not be called for cloud models")
+ return false, nil
+ }
+ defer func() { DefaultConfirmPrompt = oldHook }()
+
+ var pullCalled bool
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ switch r.URL.Path {
+ case "/api/show":
+ w.WriteHeader(http.StatusNotFound)
+ fmt.Fprintf(w, `{"error":"model not found"}`)
+ case "/api/pull":
+ pullCalled = true
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintf(w, `{"status":"success"}`)
+ default:
+ w.WriteHeader(http.StatusNotFound)
+ }
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ err := ShowOrPull(context.Background(), client, "glm-5:cloud")
+ if err != nil {
+ t.Errorf("ShowOrPull should succeed for cloud model, got: %v", err)
+ }
+ if !pullCalled {
+ t.Error("expected pull to be called for cloud model without confirmation")
+ }
+}
+
+func TestConfirmPrompt_DelegatesToHook(t *testing.T) {
+ oldHook := DefaultConfirmPrompt
+ var hookCalled bool
+ DefaultConfirmPrompt = func(prompt string) (bool, error) {
+ hookCalled = true
+ if prompt != "test prompt?" {
+ t.Errorf("expected prompt %q, got %q", "test prompt?", prompt)
+ }
+ return true, nil
+ }
+ defer func() { DefaultConfirmPrompt = oldHook }()
+
+ ok, err := confirmPrompt("test prompt?")
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if !ok {
+ t.Error("expected true from hook")
+ }
+ if !hookCalled {
+ t.Error("expected DefaultConfirmPrompt hook to be called")
+ }
+}
+
+func TestEnsureAuth_NoCloudModels(t *testing.T) {
+ // ensureAuth should be a no-op when no cloud models are selected
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ t.Error("no API calls expected when no cloud models selected")
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ err := ensureAuth(context.Background(), client, map[string]bool{}, []string{"local-model"})
+ if err != nil {
+ t.Errorf("ensureAuth should return nil for non-cloud models, got: %v", err)
+ }
+}
+
+func TestEnsureAuth_CloudModelFilteredCorrectly(t *testing.T) {
+ // ensureAuth should only care about models in cloudModels map
+ var whoamiCalled bool
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/me" {
+ whoamiCalled = true
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintf(w, `{"name":"testuser"}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cloudModels := map[string]bool{"cloud-model:cloud": true}
+ selected := []string{"cloud-model:cloud", "local-model"}
+
+ err := ensureAuth(context.Background(), client, cloudModels, selected)
+ if err != nil {
+ t.Errorf("ensureAuth should succeed when user is authenticated, got: %v", err)
+ }
+ if !whoamiCalled {
+ t.Error("expected whoami to be called for cloud model")
+ }
+}
+
+func TestEnsureAuth_SkipsWhenNoCloudSelected(t *testing.T) {
+ var whoamiCalled bool
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/me" {
+ whoamiCalled = true
+ }
+ w.WriteHeader(http.StatusOK)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ // cloudModels has entries but none are in selected
+ cloudModels := map[string]bool{"cloud-model:cloud": true}
+ selected := []string{"local-model"}
+
+ err := ensureAuth(context.Background(), client, cloudModels, selected)
+ if err != nil {
+ t.Errorf("expected nil error, got: %v", err)
+ }
+ if whoamiCalled {
+ t.Error("whoami should not be called when no cloud models are selected")
+ }
+}
+
+func TestHyperlink(t *testing.T) {
+ tests := []struct {
+ name string
+ url string
+ text string
+ wantURL string
+ wantText string
+ }{
+ {
+ name: "basic link",
+ url: "https://example.com",
+ text: "click here",
+ wantURL: "https://example.com",
+ wantText: "click here",
+ },
+ {
+ name: "url with path",
+ url: "https://example.com/docs/install",
+ text: "install docs",
+ wantURL: "https://example.com/docs/install",
+ wantText: "install docs",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := hyperlink(tt.url, tt.text)
+
+ // Should contain OSC 8 escape sequences
+ if !strings.Contains(got, "\033]8;;") {
+ t.Error("should contain OSC 8 open sequence")
+ }
+ if !strings.Contains(got, tt.wantURL) {
+ t.Errorf("should contain URL %q", tt.wantURL)
+ }
+ if !strings.Contains(got, tt.wantText) {
+ t.Errorf("should contain text %q", tt.wantText)
+ }
+
+ // Should have closing OSC 8 sequence
+ wantSuffix := "\033]8;;\033\\"
+ if !strings.HasSuffix(got, wantSuffix) {
+ t.Error("should end with OSC 8 close sequence")
+ }
+ })
+ }
+}
+
+func TestIntegrationInstallHint(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ wantEmpty bool
+ wantURL string
+ }{
+ {
+ name: "claude has hint",
+ input: "claude",
+ wantURL: "https://code.claude.com/docs/en/quickstart",
+ },
+ {
+ name: "codex has hint",
+ input: "codex",
+ wantURL: "https://developers.openai.com/codex/cli/",
+ },
+ {
+ name: "openclaw has hint",
+ input: "openclaw",
+ wantURL: "https://docs.openclaw.ai",
+ },
+ {
+ name: "unknown has no hint",
+ input: "unknown",
+ wantEmpty: true,
+ },
+ {
+ name: "empty name has no hint",
+ input: "",
+ wantEmpty: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := IntegrationInstallHint(tt.input)
+ if tt.wantEmpty {
+ if got != "" {
+ t.Errorf("expected empty hint, got %q", got)
+ }
+ return
+ }
+ if !strings.Contains(got, "Install from") {
+ t.Errorf("hint should start with 'Install from', got %q", got)
+ }
+ if !strings.Contains(got, tt.wantURL) {
+ t.Errorf("hint should contain URL %q, got %q", tt.wantURL, got)
+ }
+ // Should be a clickable hyperlink
+ if !strings.Contains(got, "\033]8;;") {
+ t.Error("hint URL should be wrapped in OSC 8 hyperlink")
+ }
+ })
+ }
+}
+
+func TestListIntegrationInfos(t *testing.T) {
+ infos := ListIntegrationInfos()
+
+ t.Run("excludes aliases", func(t *testing.T) {
+ for _, info := range infos {
+ if integrationAliases[info.Name] {
+ t.Errorf("alias %q should not appear in ListIntegrationInfos", info.Name)
+ }
+ }
+ })
+
+ t.Run("sorted with custom order at end", func(t *testing.T) {
+ // integrationOrder entries (cline, opencode) should appear last, in that order.
+ // All other entries should be sorted alphabetically before them.
+ orderRank := make(map[string]int)
+ for i, name := range integrationOrder {
+ orderRank[name] = i + 1
+ }
+ for i := 1; i < len(infos); i++ {
+ aRank, bRank := orderRank[infos[i-1].Name], orderRank[infos[i].Name]
+ switch {
+ case aRank == 0 && bRank == 0:
+ if infos[i-1].Name >= infos[i].Name {
+ t.Errorf("non-ordered items not sorted: %q >= %q", infos[i-1].Name, infos[i].Name)
+ }
+ case aRank > 0 && bRank == 0:
+ t.Errorf("ordered item %q should come after non-ordered %q", infos[i-1].Name, infos[i].Name)
+ case aRank > 0 && bRank > 0:
+ if aRank >= bRank {
+ t.Errorf("ordered items wrong: %q (rank %d) before %q (rank %d)", infos[i-1].Name, aRank, infos[i].Name, bRank)
+ }
+ }
+ }
+ })
+
+ t.Run("all fields populated", func(t *testing.T) {
+ for _, info := range infos {
+ if info.Name == "" {
+ t.Error("Name should not be empty")
+ }
+ if info.DisplayName == "" {
+ t.Errorf("DisplayName for %q should not be empty", info.Name)
+ }
+ }
+ })
+
+ t.Run("includes known integrations", func(t *testing.T) {
+ known := map[string]bool{"claude": false, "codex": false, "opencode": false}
+ for _, info := range infos {
+ if _, ok := known[info.Name]; ok {
+ known[info.Name] = true
+ }
+ }
+ for name, found := range known {
+ if !found {
+ t.Errorf("expected %q in ListIntegrationInfos", name)
+ }
+ }
+ })
+}
+
+func TestBuildModelList_Descriptions(t *testing.T) {
+ t.Run("installed recommended has base description", func(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "qwen3:8b", Remote: false},
+ }
+ items, _, _, _ := buildModelList(existing, nil, "")
+
+ for _, item := range items {
+ if item.Name == "qwen3:8b" {
+ if strings.HasSuffix(item.Description, "install?") {
+ t.Errorf("installed model should not have 'install?' suffix, got %q", item.Description)
+ }
+ if item.Description == "" {
+ t.Error("installed recommended model should have a description")
+ }
+ return
+ }
+ }
+ t.Error("qwen3:8b not found in items")
+ })
+
+ t.Run("not-installed local rec has VRAM in description", func(t *testing.T) {
+ items, _, _, _ := buildModelList(nil, nil, "")
+
+ for _, item := range items {
+ if item.Name == "qwen3:8b" {
+ if !strings.Contains(item.Description, "~11GB") {
+ t.Errorf("not-installed qwen3:8b should show VRAM hint, got %q", item.Description)
+ }
+ return
+ }
+ }
+ t.Error("qwen3:8b not found in items")
+ })
+
+ t.Run("installed local rec omits VRAM", func(t *testing.T) {
+ existing := []modelInfo{
+ {Name: "qwen3:8b", Remote: false},
+ }
+ items, _, _, _ := buildModelList(existing, nil, "")
+
+ for _, item := range items {
+ if item.Name == "qwen3:8b" {
+ if strings.Contains(item.Description, "~11GB") {
+ t.Errorf("installed qwen3:8b should not show VRAM hint, got %q", item.Description)
+ }
+ return
+ }
+ }
+ t.Error("qwen3:8b not found in items")
+ })
+}
+
+func TestLaunchIntegration_UnknownIntegration(t *testing.T) {
+ err := LaunchIntegration("nonexistent-integration")
+ if err == nil {
+ t.Fatal("expected error for unknown integration")
+ }
+ if !strings.Contains(err.Error(), "unknown integration") {
+ t.Errorf("error should mention 'unknown integration', got: %v", err)
+ }
+}
+
+func TestLaunchIntegration_NotConfigured(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Claude is a known integration but not configured in temp dir
+ err := LaunchIntegration("claude")
+ if err == nil {
+ t.Fatal("expected error when integration is not configured")
+ }
+ if !strings.Contains(err.Error(), "not configured") {
+ t.Errorf("error should mention 'not configured', got: %v", err)
+ }
+}
+
+func TestIsEditorIntegration(t *testing.T) {
+ tests := []struct {
+ name string
+ want bool
+ }{
+ {"droid", true},
+ {"opencode", true},
+ {"openclaw", true},
+ {"claude", false},
+ {"codex", false},
+ {"nonexistent", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := IsEditorIntegration(tt.name); got != tt.want {
+ t.Errorf("IsEditorIntegration(%q) = %v, want %v", tt.name, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestIntegrationModels(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("returns nil when not configured", func(t *testing.T) {
+ if got := IntegrationModels("droid"); got != nil {
+ t.Errorf("expected nil, got %v", got)
+ }
+ })
+
+ t.Run("returns all saved models", func(t *testing.T) {
+ if err := SaveIntegration("droid", []string{"llama3.2", "qwen3:8b"}); err != nil {
+ t.Fatal(err)
+ }
+ got := IntegrationModels("droid")
+ want := []string{"llama3.2", "qwen3:8b"}
+ if diff := cmp.Diff(want, got); diff != "" {
+ t.Errorf("IntegrationModels mismatch (-want +got):\n%s", diff)
+ }
+ })
+}
+
+func TestSaveAndEditIntegration_UnknownIntegration(t *testing.T) {
+ err := SaveAndEditIntegration("nonexistent", []string{"model"})
+ if err == nil {
+ t.Fatal("expected error for unknown integration")
+ }
+ if !strings.Contains(err.Error(), "unknown integration") {
+ t.Errorf("error should mention 'unknown integration', got: %v", err)
+ }
+}
diff --git a/cmd/config/openclaw.go b/cmd/config/openclaw.go
new file mode 100644
index 00000000000..c64c2630e9f
--- /dev/null
+++ b/cmd/config/openclaw.go
@@ -0,0 +1,264 @@
+package config
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "strings"
+
+ "github.com/ollama/ollama/envconfig"
+)
+
+type Openclaw struct{}
+
+func (c *Openclaw) String() string { return "OpenClaw" }
+
+func (c *Openclaw) Run(model string, args []string) error {
+ bin := "openclaw"
+ if _, err := exec.LookPath(bin); err != nil {
+ bin = "clawdbot"
+ if _, err := exec.LookPath(bin); err != nil {
+ return fmt.Errorf("openclaw is not installed, install from https://docs.openclaw.ai")
+ }
+ }
+
+ models := []string{model}
+ if config, err := loadIntegration("openclaw"); err == nil && len(config.Models) > 0 {
+ models = config.Models
+ } else if config, err := loadIntegration("clawdbot"); err == nil && len(config.Models) > 0 {
+ models = config.Models
+ }
+ var err error
+ models, err = resolveEditorModels("openclaw", models, func() ([]string, error) {
+ return selectModels(context.Background(), "openclaw", "")
+ })
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ if err := c.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+
+ if !c.onboarded() {
+ // Onboarding not completed: run it (model already set via Edit)
+ // Use "ollama" as gateway token for simple local access
+ cmd := exec.Command(bin, "onboard",
+ "--auth-choice", "skip",
+ "--gateway-token", "ollama",
+ )
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ return cmd.Run()
+ }
+
+ // Onboarding completed: run gateway
+ cmd := exec.Command(bin, append([]string{"gateway"}, args...)...)
+ cmd.Stdin = os.Stdin
+
+ // Capture output to detect "already running" message
+ var outputBuf bytes.Buffer
+ cmd.Stdout = io.MultiWriter(os.Stdout, &outputBuf)
+ cmd.Stderr = io.MultiWriter(os.Stderr, &outputBuf)
+
+ err = cmd.Run()
+ if err != nil && strings.Contains(outputBuf.String(), "Gateway already running") {
+ fmt.Fprintf(os.Stderr, "%sOpenClaw has been configured with Ollama. Gateway is already running.%s\n", ansiGreen, ansiReset)
+ return nil
+ }
+ return err
+}
+
+// onboarded checks if OpenClaw onboarding wizard was completed
+// by looking for the wizard.lastRunAt marker in the config
+func (c *Openclaw) onboarded() bool {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return false
+ }
+
+ configPath := filepath.Join(home, ".openclaw", "openclaw.json")
+ legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
+
+ config := make(map[string]any)
+ if data, err := os.ReadFile(configPath); err == nil {
+ _ = json.Unmarshal(data, &config)
+ } else if data, err := os.ReadFile(legacyPath); err == nil {
+ _ = json.Unmarshal(data, &config)
+ } else {
+ return false
+ }
+
+ // Check for wizard.lastRunAt marker (set when onboarding completes)
+ wizard, _ := config["wizard"].(map[string]any)
+ if wizard == nil {
+ return false
+ }
+ lastRunAt, _ := wizard["lastRunAt"].(string)
+ return lastRunAt != ""
+}
+
+func (c *Openclaw) Paths() []string {
+ home, _ := os.UserHomeDir()
+ p := filepath.Join(home, ".openclaw", "openclaw.json")
+ if _, err := os.Stat(p); err == nil {
+ return []string{p}
+ }
+ legacy := filepath.Join(home, ".clawdbot", "clawdbot.json")
+ if _, err := os.Stat(legacy); err == nil {
+ return []string{legacy}
+ }
+ return nil
+}
+
+func (c *Openclaw) Edit(models []string) error {
+ if len(models) == 0 {
+ return nil
+ }
+
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ configPath := filepath.Join(home, ".openclaw", "openclaw.json")
+ legacyPath := filepath.Join(home, ".clawdbot", "clawdbot.json")
+ if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+ return err
+ }
+
+ // Read into map[string]any to preserve unknown fields
+ config := make(map[string]any)
+ if data, err := os.ReadFile(configPath); err == nil {
+ _ = json.Unmarshal(data, &config)
+ } else if data, err := os.ReadFile(legacyPath); err == nil {
+ _ = json.Unmarshal(data, &config)
+ }
+
+ // Navigate/create: models.providers.ollama (preserving other providers)
+ modelsSection, _ := config["models"].(map[string]any)
+ if modelsSection == nil {
+ modelsSection = make(map[string]any)
+ }
+ providers, _ := modelsSection["providers"].(map[string]any)
+ if providers == nil {
+ providers = make(map[string]any)
+ }
+ ollama, _ := providers["ollama"].(map[string]any)
+ if ollama == nil {
+ ollama = make(map[string]any)
+ }
+
+ ollama["baseUrl"] = envconfig.Host().String() + "/v1"
+ // needed to register provider
+ ollama["apiKey"] = "ollama-local"
+ // TODO(parthsareen): potentially move to responses
+ ollama["api"] = "openai-completions"
+
+ // Build map of existing models to preserve user customizations
+ existingModels, _ := ollama["models"].([]any)
+ existingByID := make(map[string]map[string]any)
+ for _, m := range existingModels {
+ if entry, ok := m.(map[string]any); ok {
+ if id, ok := entry["id"].(string); ok {
+ existingByID[id] = entry
+ }
+ }
+ }
+
+ var newModels []any
+ for _, model := range models {
+ entry := map[string]any{
+ "id": model,
+ "name": model,
+ "reasoning": false,
+ "input": []any{"text"},
+ "cost": map[string]any{
+ "input": 0,
+ "output": 0,
+ "cacheRead": 0,
+ "cacheWrite": 0,
+ },
+ // TODO(parthsareen): get these values from API
+ "contextWindow": 131072,
+ "maxTokens": 16384,
+ }
+ // Merge existing fields (user customizations)
+ if existing, ok := existingByID[model]; ok {
+ for k, v := range existing {
+ if _, isNew := entry[k]; !isNew {
+ entry[k] = v
+ }
+ }
+ }
+ newModels = append(newModels, entry)
+ }
+ ollama["models"] = newModels
+
+ providers["ollama"] = ollama
+ modelsSection["providers"] = providers
+ config["models"] = modelsSection
+
+ // Update agents.defaults.model.primary (preserving other agent settings)
+ agents, _ := config["agents"].(map[string]any)
+ if agents == nil {
+ agents = make(map[string]any)
+ }
+ defaults, _ := agents["defaults"].(map[string]any)
+ if defaults == nil {
+ defaults = make(map[string]any)
+ }
+ modelConfig, _ := defaults["model"].(map[string]any)
+ if modelConfig == nil {
+ modelConfig = make(map[string]any)
+ }
+ modelConfig["primary"] = "ollama/" + models[0]
+ defaults["model"] = modelConfig
+ agents["defaults"] = defaults
+ config["agents"] = agents
+
+ data, err := json.MarshalIndent(config, "", " ")
+ if err != nil {
+ return err
+ }
+ return writeWithBackup(configPath, data)
+}
+
+func (c *Openclaw) Models() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+
+ config, err := readJSONFile(filepath.Join(home, ".openclaw", "openclaw.json"))
+ if err != nil {
+ config, err = readJSONFile(filepath.Join(home, ".clawdbot", "clawdbot.json"))
+ if err != nil {
+ return nil
+ }
+ }
+
+ modelsSection, _ := config["models"].(map[string]any)
+ providers, _ := modelsSection["providers"].(map[string]any)
+ ollama, _ := providers["ollama"].(map[string]any)
+ modelList, _ := ollama["models"].([]any)
+
+ var result []string
+ for _, m := range modelList {
+ if entry, ok := m.(map[string]any); ok {
+ if id, ok := entry["id"].(string); ok {
+ result = append(result, id)
+ }
+ }
+ }
+ return result
+}
diff --git a/cmd/config/openclaw_test.go b/cmd/config/openclaw_test.go
new file mode 100644
index 00000000000..439f51a3538
--- /dev/null
+++ b/cmd/config/openclaw_test.go
@@ -0,0 +1,878 @@
+package config
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestOpenclawIntegration(t *testing.T) {
+ c := &Openclaw{}
+
+ t.Run("String", func(t *testing.T) {
+ if got := c.String(); got != "OpenClaw" {
+ t.Errorf("String() = %q, want %q", got, "OpenClaw")
+ }
+ })
+
+ t.Run("implements Runner", func(t *testing.T) {
+ var _ Runner = c
+ })
+
+ t.Run("implements Editor", func(t *testing.T) {
+ var _ Editor = c
+ })
+}
+
+func TestOpenclawEdit(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+
+ cleanup := func() { os.RemoveAll(configDir) }
+
+ t.Run("fresh install", func(t *testing.T) {
+ cleanup()
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenclawModelExists(t, configPath, "llama3.2")
+ assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
+ })
+
+ t.Run("multiple models - first is primary", func(t *testing.T) {
+ cleanup()
+ if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenclawModelExists(t, configPath, "llama3.2")
+ assertOpenclawModelExists(t, configPath, "mistral")
+ assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2")
+ })
+
+ t.Run("preserve other providers", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":{"providers":{"anthropic":{"apiKey":"xxx"}}}}`), 0o644)
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ models := cfg["models"].(map[string]any)
+ providers := models["providers"].(map[string]any)
+ if providers["anthropic"] == nil {
+ t.Error("anthropic provider was removed")
+ }
+ })
+
+ t.Run("preserve top-level keys", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"theme":"dark","mcp":{"servers":{}}}`), 0o644)
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ if cfg["theme"] != "dark" {
+ t.Error("theme was removed")
+ }
+ if cfg["mcp"] == nil {
+ t.Error("mcp was removed")
+ }
+ })
+
+ t.Run("preserve user customizations on models", func(t *testing.T) {
+ cleanup()
+ c.Edit([]string{"llama3.2"})
+
+ // User adds custom field
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ models := cfg["models"].(map[string]any)
+ providers := models["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelList := ollama["models"].([]any)
+ entry := modelList[0].(map[string]any)
+ entry["customField"] = "user-value"
+ configData, _ := json.MarshalIndent(cfg, "", " ")
+ os.WriteFile(configPath, configData, 0o644)
+
+ // Re-run Edit
+ c.Edit([]string{"llama3.2"})
+
+ data, _ = os.ReadFile(configPath)
+ json.Unmarshal(data, &cfg)
+ models = cfg["models"].(map[string]any)
+ providers = models["providers"].(map[string]any)
+ ollama = providers["ollama"].(map[string]any)
+ modelList = ollama["models"].([]any)
+ entry = modelList[0].(map[string]any)
+ if entry["customField"] != "user-value" {
+ t.Error("custom field was lost")
+ }
+ })
+
+ t.Run("edit replaces models list", func(t *testing.T) {
+ cleanup()
+ c.Edit([]string{"llama3.2", "mistral"})
+ c.Edit([]string{"llama3.2"})
+
+ assertOpenclawModelExists(t, configPath, "llama3.2")
+ assertOpenclawModelNotExists(t, configPath, "mistral")
+ })
+
+ t.Run("empty models is no-op", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ original := `{"existing":"data"}`
+ os.WriteFile(configPath, []byte(original), 0o644)
+
+ c.Edit([]string{})
+
+ data, _ := os.ReadFile(configPath)
+ if string(data) != original {
+ t.Error("empty models should not modify file")
+ }
+ })
+
+ t.Run("corrupted JSON treated as empty", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
+
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Error("result should be valid JSON")
+ }
+ })
+
+ t.Run("wrong type models section", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":"not a map"}`), 0o644)
+
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenclawModelExists(t, configPath, "llama3.2")
+ })
+}
+
+func TestOpenclawModels(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ t.Run("no config returns nil", func(t *testing.T) {
+ if models := c.Models(); len(models) > 0 {
+ t.Errorf("expected nil/empty, got %v", models)
+ }
+ })
+
+ t.Run("returns all ollama models", func(t *testing.T) {
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{
+ "models":{"providers":{"ollama":{"models":[
+ {"id":"llama3.2"},
+ {"id":"mistral"}
+ ]}}}
+ }`), 0o644)
+
+ models := c.Models()
+ if len(models) != 2 {
+ t.Errorf("expected 2 models, got %v", models)
+ }
+ })
+}
+
+// Helper functions
+func assertOpenclawModelExists(t *testing.T, path, model string) {
+ t.Helper()
+ data, _ := os.ReadFile(path)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ models := cfg["models"].(map[string]any)
+ providers := models["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelList := ollama["models"].([]any)
+ for _, m := range modelList {
+ if entry, ok := m.(map[string]any); ok {
+ if entry["id"] == model {
+ return
+ }
+ }
+ }
+ t.Errorf("model %s not found", model)
+}
+
+func assertOpenclawModelNotExists(t *testing.T, path, model string) {
+ t.Helper()
+ data, _ := os.ReadFile(path)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ models, _ := cfg["models"].(map[string]any)
+ providers, _ := models["providers"].(map[string]any)
+ ollama, _ := providers["ollama"].(map[string]any)
+ modelList, _ := ollama["models"].([]any)
+ for _, m := range modelList {
+ if entry, ok := m.(map[string]any); ok {
+ if entry["id"] == model {
+ t.Errorf("model %s should not exist", model)
+ }
+ }
+ }
+}
+
+func assertOpenclawPrimaryModel(t *testing.T, path, expected string) {
+ t.Helper()
+ data, _ := os.ReadFile(path)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ agents := cfg["agents"].(map[string]any)
+ defaults := agents["defaults"].(map[string]any)
+ model := defaults["model"].(map[string]any)
+ if model["primary"] != expected {
+ t.Errorf("primary model = %v, want %v", model["primary"], expected)
+ }
+}
+
+func TestOpenclawPaths(t *testing.T) {
+ c := &Openclaw{}
+
+ t.Run("returns path when config exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{}`), 0o644)
+
+ paths := c.Paths()
+ if len(paths) != 1 {
+ t.Errorf("expected 1 path, got %d", len(paths))
+ }
+ })
+
+ t.Run("returns nil when config missing", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ if paths := c.Paths(); paths != nil {
+ t.Errorf("expected nil, got %v", paths)
+ }
+ })
+}
+
+func TestOpenclawModelsEdgeCases(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+ cleanup := func() { os.RemoveAll(configDir) }
+
+ t.Run("corrupted JSON returns nil", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{corrupted`), 0o644)
+ if models := c.Models(); models != nil {
+ t.Errorf("expected nil, got %v", models)
+ }
+ })
+
+ t.Run("wrong type at models level", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":"string"}`), 0o644)
+ if models := c.Models(); models != nil {
+ t.Errorf("expected nil, got %v", models)
+ }
+ })
+
+ t.Run("wrong type at providers level", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":{"providers":"string"}}`), 0o644)
+ if models := c.Models(); models != nil {
+ t.Errorf("expected nil, got %v", models)
+ }
+ })
+
+ t.Run("wrong type at ollama level", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":"string"}}}`), 0o644)
+ if models := c.Models(); models != nil {
+ t.Errorf("expected nil, got %v", models)
+ }
+ })
+
+ t.Run("model entry missing id", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"name":"test"}]}}}}`), 0o644)
+ if len(c.Models()) != 0 {
+ t.Error("expected empty for missing id")
+ }
+ })
+
+ t.Run("model id is not string", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"models":{"providers":{"ollama":{"models":[{"id":123}]}}}}`), 0o644)
+ if len(c.Models()) != 0 {
+ t.Error("expected empty for non-string id")
+ }
+ })
+}
+
+func TestOpenclawEditSchemaFields(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
+
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ models := cfg["models"].(map[string]any)
+ providers := models["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelList := ollama["models"].([]any)
+ entry := modelList[0].(map[string]any)
+
+ // Verify required schema fields
+ if entry["reasoning"] != false {
+ t.Error("reasoning should be false")
+ }
+ if entry["input"] == nil {
+ t.Error("input should be set")
+ }
+ if entry["contextWindow"] == nil {
+ t.Error("contextWindow should be set")
+ }
+ if entry["maxTokens"] == nil {
+ t.Error("maxTokens should be set")
+ }
+ cost := entry["cost"].(map[string]any)
+ if cost["cacheRead"] == nil {
+ t.Error("cost.cacheRead should be set")
+ }
+ if cost["cacheWrite"] == nil {
+ t.Error("cost.cacheWrite should be set")
+ }
+}
+
+func TestOpenclawEditModelNames(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
+ cleanup := func() { os.RemoveAll(filepath.Join(tmpDir, ".openclaw")) }
+
+ t.Run("model with colon tag", func(t *testing.T) {
+ cleanup()
+ if err := c.Edit([]string{"llama3.2:70b"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenclawModelExists(t, configPath, "llama3.2:70b")
+ assertOpenclawPrimaryModel(t, configPath, "ollama/llama3.2:70b")
+ })
+
+ t.Run("model with slash", func(t *testing.T) {
+ cleanup()
+ if err := c.Edit([]string{"library/model:tag"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenclawModelExists(t, configPath, "library/model:tag")
+ assertOpenclawPrimaryModel(t, configPath, "ollama/library/model:tag")
+ })
+
+ t.Run("model with hyphen", func(t *testing.T) {
+ cleanup()
+ if err := c.Edit([]string{"test-model"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenclawModelExists(t, configPath, "test-model")
+ })
+}
+
+func TestOpenclawEditAgentsPreservation(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+ cleanup := func() { os.RemoveAll(configDir) }
+
+ t.Run("preserve other agent defaults", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"agents":{"defaults":{"model":{"primary":"old"},"temperature":0.7}}}`), 0o644)
+
+ c.Edit([]string{"llama3.2"})
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ agents := cfg["agents"].(map[string]any)
+ defaults := agents["defaults"].(map[string]any)
+ if defaults["temperature"] != 0.7 {
+ t.Error("temperature setting was lost")
+ }
+ })
+
+ t.Run("preserve other agents besides defaults", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"agents":{"defaults":{},"custom-agent":{"foo":"bar"}}}`), 0o644)
+
+ c.Edit([]string{"llama3.2"})
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ agents := cfg["agents"].(map[string]any)
+ if agents["custom-agent"] == nil {
+ t.Error("custom-agent was lost")
+ }
+ })
+}
+
+const testOpenclawFixture = `{
+ "theme": "dark",
+ "mcp": {"servers": {"custom": {"enabled": true}}},
+ "models": {
+ "providers": {
+ "anthropic": {"apiKey": "xxx"},
+ "ollama": {
+ "baseUrl": "http://127.0.0.1:11434/v1",
+ "models": [{"id": "old-model", "customField": "preserved"}]
+ }
+ }
+ },
+ "agents": {
+ "defaults": {"model": {"primary": "old"}, "temperature": 0.7},
+ "custom-agent": {"foo": "bar"}
+ }
+}`
+
+func TestOpenclawEdit_RoundTrip(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
+
+ if err := c.Edit([]string{"llama3.2", "mistral"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+
+ // Verify top-level preserved
+ if cfg["theme"] != "dark" {
+ t.Error("theme not preserved")
+ }
+ mcp := cfg["mcp"].(map[string]any)
+ servers := mcp["servers"].(map[string]any)
+ if servers["custom"] == nil {
+ t.Error("mcp.servers.custom not preserved")
+ }
+
+ // Verify other providers preserved
+ models := cfg["models"].(map[string]any)
+ providers := models["providers"].(map[string]any)
+ if providers["anthropic"] == nil {
+ t.Error("anthropic provider not preserved")
+ }
+
+ // Verify agents preserved
+ agents := cfg["agents"].(map[string]any)
+ if agents["custom-agent"] == nil {
+ t.Error("custom-agent not preserved")
+ }
+ defaults := agents["defaults"].(map[string]any)
+ if defaults["temperature"] != 0.7 {
+ t.Error("temperature not preserved")
+ }
+}
+
+func TestOpenclawEdit_Idempotent(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
+
+ c.Edit([]string{"llama3.2", "mistral"})
+ firstData, _ := os.ReadFile(configPath)
+
+ c.Edit([]string{"llama3.2", "mistral"})
+ secondData, _ := os.ReadFile(configPath)
+
+ if string(firstData) != string(secondData) {
+ t.Error("repeated edits with same models produced different results")
+ }
+}
+
+func TestOpenclawEdit_MultipleConsecutiveEdits(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(testOpenclawFixture), 0o644)
+
+ for i := range 10 {
+ models := []string{"model-a", "model-b"}
+ if i%2 == 0 {
+ models = []string{"model-x", "model-y", "model-z"}
+ }
+ if err := c.Edit(models); err != nil {
+ t.Fatalf("edit %d failed: %v", i, err)
+ }
+ }
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Fatalf("file is not valid JSON after multiple edits: %v", err)
+ }
+
+ if cfg["theme"] != "dark" {
+ t.Error("theme lost after multiple edits")
+ }
+}
+
+func TestOpenclawEdit_BackupCreated(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ configPath := filepath.Join(configDir, "openclaw.json")
+ backupDir := filepath.Join(os.TempDir(), "ollama-backups")
+
+ os.MkdirAll(configDir, 0o755)
+ uniqueMarker := fmt.Sprintf("test-marker-%d", os.Getpid())
+ original := fmt.Sprintf(`{"theme": "%s"}`, uniqueMarker)
+ os.WriteFile(configPath, []byte(original), 0o644)
+
+ if err := c.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ backups, _ := filepath.Glob(filepath.Join(backupDir, "openclaw.json.*"))
+ foundBackup := false
+ for _, backup := range backups {
+ data, _ := os.ReadFile(backup)
+ if string(data) == original {
+ foundBackup = true
+ break
+ }
+ }
+
+ if !foundBackup {
+ t.Error("backup with original content not found")
+ }
+}
+
+func TestOpenclawClawdbotAlias(t *testing.T) {
+ for _, alias := range []string{"clawdbot", "moltbot"} {
+ t.Run(alias+" alias resolves to Openclaw runner", func(t *testing.T) {
+ r, ok := integrations[alias]
+ if !ok {
+ t.Fatalf("%s not found in integrations", alias)
+ }
+ if _, ok := r.(*Openclaw); !ok {
+ t.Errorf("%s integration is %T, want *Openclaw", alias, r)
+ }
+ })
+
+ t.Run(alias+" is hidden from selector", func(t *testing.T) {
+ if !integrationAliases[alias] {
+ t.Errorf("%s should be in integrationAliases", alias)
+ }
+ })
+ }
+}
+
+func TestOpenclawLegacyPaths(t *testing.T) {
+ c := &Openclaw{}
+
+ t.Run("falls back to legacy clawdbot path", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
+
+ paths := c.Paths()
+ if len(paths) != 1 {
+ t.Fatalf("expected 1 path, got %d", len(paths))
+ }
+ if paths[0] != filepath.Join(legacyDir, "clawdbot.json") {
+ t.Errorf("expected legacy path, got %s", paths[0])
+ }
+ })
+
+ t.Run("prefers new path over legacy", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ newDir := filepath.Join(tmpDir, ".openclaw")
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(newDir, 0o755)
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{}`), 0o644)
+
+ paths := c.Paths()
+ if len(paths) != 1 {
+ t.Fatalf("expected 1 path, got %d", len(paths))
+ }
+ if paths[0] != filepath.Join(newDir, "openclaw.json") {
+ t.Errorf("expected new path, got %s", paths[0])
+ }
+ })
+
+ t.Run("Models reads from legacy path", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
+ "models":{"providers":{"ollama":{"models":[{"id":"llama3.2"}]}}}
+ }`), 0o644)
+
+ models := c.Models()
+ if len(models) != 1 || models[0] != "llama3.2" {
+ t.Errorf("expected [llama3.2], got %v", models)
+ }
+ })
+
+ t.Run("Models prefers new path over legacy", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ newDir := filepath.Join(tmpDir, ".openclaw")
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(newDir, 0o755)
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{
+ "models":{"providers":{"ollama":{"models":[{"id":"new-model"}]}}}
+ }`), 0o644)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{
+ "models":{"providers":{"ollama":{"models":[{"id":"legacy-model"}]}}}
+ }`), 0o644)
+
+ models := c.Models()
+ if len(models) != 1 || models[0] != "new-model" {
+ t.Errorf("expected [new-model], got %v", models)
+ }
+ })
+
+ t.Run("Edit reads new path over legacy when both exist", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ newDir := filepath.Join(tmpDir, ".openclaw")
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(newDir, 0o755)
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{"theme":"new"}`), 0o644)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"legacy"}`), 0o644)
+
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(filepath.Join(newDir, "openclaw.json"))
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ if cfg["theme"] != "new" {
+ t.Errorf("expected theme from new config, got %v", cfg["theme"])
+ }
+ })
+
+ t.Run("Edit migrates from legacy config", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"theme":"dark"}`), 0o644)
+
+ if err := c.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Should write to new path
+ newPath := filepath.Join(tmpDir, ".openclaw", "openclaw.json")
+ data, err := os.ReadFile(newPath)
+ if err != nil {
+ t.Fatal("expected new config file to be created")
+ }
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ if cfg["theme"] != "dark" {
+ t.Error("legacy theme setting was not migrated")
+ }
+ })
+}
+
+func TestOpenclawEdit_CreatesDirectoryIfMissing(t *testing.T) {
+ c := &Openclaw{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+
+ if _, err := os.Stat(configDir); !os.IsNotExist(err) {
+ t.Fatal("directory should not exist before test")
+ }
+
+ if err := c.Edit([]string{"model-a"}); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := os.Stat(configDir); os.IsNotExist(err) {
+ t.Fatal("directory was not created")
+ }
+}
+
+func TestOpenclawOnboarded(t *testing.T) {
+ c := &Openclaw{}
+
+ t.Run("returns false when no config exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ if c.onboarded() {
+ t.Error("expected false when no config exists")
+ }
+ })
+
+ t.Run("returns false when config exists but no wizard section", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"theme":"dark"}`), 0o644)
+
+ if c.onboarded() {
+ t.Error("expected false when no wizard section")
+ }
+ })
+
+ t.Run("returns false when wizard section exists but no lastRunAt", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{}}`), 0o644)
+
+ if c.onboarded() {
+ t.Error("expected false when wizard.lastRunAt is missing")
+ }
+ })
+
+ t.Run("returns false when wizard.lastRunAt is empty string", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":""}}`), 0o644)
+
+ if c.onboarded() {
+ t.Error("expected false when wizard.lastRunAt is empty")
+ }
+ })
+
+ t.Run("returns true when wizard.lastRunAt is set", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
+
+ if !c.onboarded() {
+ t.Error("expected true when wizard.lastRunAt is set")
+ }
+ })
+
+ t.Run("checks legacy clawdbot path", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(legacyDir, 0o755)
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
+
+ if !c.onboarded() {
+ t.Error("expected true when legacy config has wizard.lastRunAt")
+ }
+ })
+
+ t.Run("prefers new path over legacy", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ newDir := filepath.Join(tmpDir, ".openclaw")
+ legacyDir := filepath.Join(tmpDir, ".clawdbot")
+ os.MkdirAll(newDir, 0o755)
+ os.MkdirAll(legacyDir, 0o755)
+ // New path has no wizard marker
+ os.WriteFile(filepath.Join(newDir, "openclaw.json"), []byte(`{}`), 0o644)
+ // Legacy has wizard marker
+ os.WriteFile(filepath.Join(legacyDir, "clawdbot.json"), []byte(`{"wizard":{"lastRunAt":"2024-01-01T00:00:00Z"}}`), 0o644)
+
+ if c.onboarded() {
+ t.Error("expected false - should prefer new path which has no wizard marker")
+ }
+ })
+
+ t.Run("handles corrupted JSON gracefully", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{corrupted`), 0o644)
+
+ if c.onboarded() {
+ t.Error("expected false for corrupted JSON")
+ }
+ })
+
+ t.Run("handles wrong type for wizard section", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+ configDir := filepath.Join(tmpDir, ".openclaw")
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(filepath.Join(configDir, "openclaw.json"), []byte(`{"wizard":"not a map"}`), 0o644)
+
+ if c.onboarded() {
+ t.Error("expected false when wizard is wrong type")
+ }
+ })
+}
diff --git a/cmd/config/opencode.go b/cmd/config/opencode.go
new file mode 100644
index 00000000000..b4715beb94a
--- /dev/null
+++ b/cmd/config/opencode.go
@@ -0,0 +1,279 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "maps"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/envconfig"
+)
+
+// OpenCode implements Runner and Editor for OpenCode integration
+type OpenCode struct{}
+
+// cloudModelLimit holds context and output token limits for a cloud model.
+type cloudModelLimit struct {
+ Context int
+ Output int
+}
+
+// lookupCloudModelLimit returns the token limits for a cloud model.
+// It tries the exact name first, then strips the ":cloud" suffix.
+func lookupCloudModelLimit(name string) (cloudModelLimit, bool) {
+ if l, ok := cloudModelLimits[name]; ok {
+ return l, true
+ }
+ base := strings.TrimSuffix(name, ":cloud")
+ if base != name {
+ if l, ok := cloudModelLimits[base]; ok {
+ return l, true
+ }
+ }
+ return cloudModelLimit{}, false
+}
+
+func (o *OpenCode) String() string { return "OpenCode" }
+
+func (o *OpenCode) Run(model string, args []string) error {
+ if _, err := exec.LookPath("opencode"); err != nil {
+ return fmt.Errorf("opencode is not installed, install from https://opencode.ai")
+ }
+
+ // Call Edit() to ensure config is up-to-date before launch
+ models := []string{model}
+ if config, err := loadIntegration("opencode"); err == nil && len(config.Models) > 0 {
+ models = config.Models
+ }
+ var err error
+ models, err = resolveEditorModels("opencode", models, func() ([]string, error) {
+ return selectModels(context.Background(), "opencode", "")
+ })
+ if errors.Is(err, errCancelled) {
+ return nil
+ }
+ if err != nil {
+ return err
+ }
+ if err := o.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+
+ cmd := exec.Command("opencode", args...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ return cmd.Run()
+}
+
+func (o *OpenCode) Paths() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+
+ var paths []string
+ p := filepath.Join(home, ".config", "opencode", "opencode.json")
+ if _, err := os.Stat(p); err == nil {
+ paths = append(paths, p)
+ }
+ sp := filepath.Join(home, ".local", "state", "opencode", "model.json")
+ if _, err := os.Stat(sp); err == nil {
+ paths = append(paths, sp)
+ }
+ return paths
+}
+
+func (o *OpenCode) Edit(modelList []string) error {
+ if len(modelList) == 0 {
+ return nil
+ }
+
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ configPath := filepath.Join(home, ".config", "opencode", "opencode.json")
+ if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+ return err
+ }
+
+ config := make(map[string]any)
+ if data, err := os.ReadFile(configPath); err == nil {
+ _ = json.Unmarshal(data, &config) // Ignore parse errors; treat missing/corrupt files as empty
+ }
+
+ config["$schema"] = "https://opencode.ai/config.json"
+
+ provider, ok := config["provider"].(map[string]any)
+ if !ok {
+ provider = make(map[string]any)
+ }
+
+ ollama, ok := provider["ollama"].(map[string]any)
+ if !ok {
+ ollama = map[string]any{
+ "npm": "@ai-sdk/openai-compatible",
+ "name": "Ollama (local)",
+ "options": map[string]any{
+ "baseURL": envconfig.Host().String() + "/v1",
+ },
+ }
+ }
+
+ models, ok := ollama["models"].(map[string]any)
+ if !ok {
+ models = make(map[string]any)
+ }
+
+ selectedSet := make(map[string]bool)
+ for _, m := range modelList {
+ selectedSet[m] = true
+ }
+
+ for name, cfg := range models {
+ if cfgMap, ok := cfg.(map[string]any); ok {
+ if isOllamaModel(cfgMap) && !selectedSet[name] {
+ delete(models, name)
+ }
+ }
+ }
+
+ client, _ := api.ClientFromEnvironment()
+
+ for _, model := range modelList {
+ if existing, ok := models[model].(map[string]any); ok {
+ // migrate existing models without _launch marker
+ if isOllamaModel(existing) {
+ existing["_launch"] = true
+ if name, ok := existing["name"].(string); ok {
+ existing["name"] = strings.TrimSuffix(name, " [Ollama]")
+ }
+ }
+ if isCloudModel(context.Background(), client, model) {
+ if l, ok := lookupCloudModelLimit(model); ok {
+ existing["limit"] = map[string]any{
+ "context": l.Context,
+ "output": l.Output,
+ }
+ }
+ }
+ continue
+ }
+ entry := map[string]any{
+ "name": model,
+ "_launch": true,
+ }
+ if isCloudModel(context.Background(), client, model) {
+ if l, ok := lookupCloudModelLimit(model); ok {
+ entry["limit"] = map[string]any{
+ "context": l.Context,
+ "output": l.Output,
+ }
+ }
+ }
+ models[model] = entry
+ }
+
+ ollama["models"] = models
+ provider["ollama"] = ollama
+ config["provider"] = provider
+
+ configData, err := json.MarshalIndent(config, "", " ")
+ if err != nil {
+ return err
+ }
+ if err := writeWithBackup(configPath, configData); err != nil {
+ return err
+ }
+
+ statePath := filepath.Join(home, ".local", "state", "opencode", "model.json")
+ if err := os.MkdirAll(filepath.Dir(statePath), 0o755); err != nil {
+ return err
+ }
+
+ state := map[string]any{
+ "recent": []any{},
+ "favorite": []any{},
+ "variant": map[string]any{},
+ }
+ if data, err := os.ReadFile(statePath); err == nil {
+ _ = json.Unmarshal(data, &state) // Ignore parse errors; use defaults
+ }
+
+ recent, _ := state["recent"].([]any)
+
+ modelSet := make(map[string]bool)
+ for _, m := range modelList {
+ modelSet[m] = true
+ }
+
+ // Filter out existing Ollama models we're about to re-add
+ newRecent := slices.DeleteFunc(slices.Clone(recent), func(entry any) bool {
+ e, ok := entry.(map[string]any)
+ if !ok || e["providerID"] != "ollama" {
+ return false
+ }
+ modelID, _ := e["modelID"].(string)
+ return modelSet[modelID]
+ })
+
+ // Prepend models in reverse order so first model ends up first
+ for _, model := range slices.Backward(modelList) {
+ newRecent = slices.Insert(newRecent, 0, any(map[string]any{
+ "providerID": "ollama",
+ "modelID": model,
+ }))
+ }
+
+ const maxRecentModels = 10
+ newRecent = newRecent[:min(len(newRecent), maxRecentModels)]
+
+ state["recent"] = newRecent
+
+ stateData, err := json.MarshalIndent(state, "", " ")
+ if err != nil {
+ return err
+ }
+ return writeWithBackup(statePath, stateData)
+}
+
+func (o *OpenCode) Models() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+ config, err := readJSONFile(filepath.Join(home, ".config", "opencode", "opencode.json"))
+ if err != nil {
+ return nil
+ }
+ provider, _ := config["provider"].(map[string]any)
+ ollama, _ := provider["ollama"].(map[string]any)
+ models, _ := ollama["models"].(map[string]any)
+ if len(models) == 0 {
+ return nil
+ }
+ keys := slices.Collect(maps.Keys(models))
+ slices.Sort(keys)
+ return keys
+}
+
+// isOllamaModel reports whether a model config entry is managed by us
+func isOllamaModel(cfg map[string]any) bool {
+ if v, ok := cfg["_launch"].(bool); ok && v {
+ return true
+ }
+ // previously used [Ollama] as a suffix for the model managed by ollama launch
+ if name, ok := cfg["name"].(string); ok {
+ return strings.HasSuffix(name, "[Ollama]")
+ }
+ return false
+}
diff --git a/cmd/config/opencode_test.go b/cmd/config/opencode_test.go
new file mode 100644
index 00000000000..8de17458860
--- /dev/null
+++ b/cmd/config/opencode_test.go
@@ -0,0 +1,668 @@
+package config
+
+import (
+ "encoding/json"
+ "fmt"
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestOpenCodeIntegration(t *testing.T) {
+ o := &OpenCode{}
+
+ t.Run("String", func(t *testing.T) {
+ if got := o.String(); got != "OpenCode" {
+ t.Errorf("String() = %q, want %q", got, "OpenCode")
+ }
+ })
+
+ t.Run("implements Runner", func(t *testing.T) {
+ var _ Runner = o
+ })
+
+ t.Run("implements Editor", func(t *testing.T) {
+ var _ Editor = o
+ })
+}
+
+func TestOpenCodeEdit(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+ stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
+ statePath := filepath.Join(stateDir, "model.json")
+
+ cleanup := func() {
+ os.RemoveAll(configDir)
+ os.RemoveAll(stateDir)
+ }
+
+ t.Run("fresh install", func(t *testing.T) {
+ cleanup()
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
+ })
+
+ t.Run("preserve other providers", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"provider":{"anthropic":{"apiKey":"xxx"}}}`), 0o644)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ provider := cfg["provider"].(map[string]any)
+ if provider["anthropic"] == nil {
+ t.Error("anthropic provider was removed")
+ }
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ })
+
+ t.Run("preserve other models", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"mistral":{"name":"Mistral"}}}}}`), 0o644)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenCodeModelExists(t, configPath, "mistral")
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ })
+
+ t.Run("update existing model", func(t *testing.T) {
+ cleanup()
+ o.Edit([]string{"llama3.2"})
+ o.Edit([]string{"llama3.2"})
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ })
+
+ t.Run("preserve top-level keys", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"theme":"dark","keybindings":{}}`), 0o644)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ if cfg["theme"] != "dark" {
+ t.Error("theme was removed")
+ }
+ if cfg["keybindings"] == nil {
+ t.Error("keybindings was removed")
+ }
+ })
+
+ t.Run("model state - insert at index 0", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(stateDir, 0o755)
+ os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
+ assertOpenCodeRecentModel(t, statePath, 1, "anthropic", "claude")
+ })
+
+ t.Run("model state - preserve favorites and variants", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(stateDir, 0o755)
+ os.WriteFile(statePath, []byte(`{"recent":[],"favorite":[{"providerID":"x","modelID":"y"}],"variant":{"a":"b"}}`), 0o644)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(statePath)
+ var state map[string]any
+ json.Unmarshal(data, &state)
+ if len(state["favorite"].([]any)) != 1 {
+ t.Error("favorite was modified")
+ }
+ if state["variant"].(map[string]any)["a"] != "b" {
+ t.Error("variant was modified")
+ }
+ })
+
+ t.Run("model state - deduplicate on re-add", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(stateDir, 0o755)
+ os.WriteFile(statePath, []byte(`{"recent":[{"providerID":"ollama","modelID":"llama3.2"},{"providerID":"anthropic","modelID":"claude"}],"favorite":[],"variant":{}}`), 0o644)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+ data, _ := os.ReadFile(statePath)
+ var state map[string]any
+ json.Unmarshal(data, &state)
+ recent := state["recent"].([]any)
+ if len(recent) != 2 {
+ t.Errorf("expected 2 recent entries, got %d", len(recent))
+ }
+ assertOpenCodeRecentModel(t, statePath, 0, "ollama", "llama3.2")
+ })
+
+ t.Run("remove model", func(t *testing.T) {
+ cleanup()
+ // First add two models
+ o.Edit([]string{"llama3.2", "mistral"})
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ assertOpenCodeModelExists(t, configPath, "mistral")
+
+ // Then remove one by only selecting the other
+ o.Edit([]string{"llama3.2"})
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ assertOpenCodeModelNotExists(t, configPath, "mistral")
+ })
+
+ t.Run("preserve user customizations on managed models", func(t *testing.T) {
+ cleanup()
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Add custom fields to the model entry (simulating user edits)
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ provider := cfg["provider"].(map[string]any)
+ ollama := provider["ollama"].(map[string]any)
+ models := ollama["models"].(map[string]any)
+ entry := models["llama3.2"].(map[string]any)
+ entry["_myPref"] = "custom-value"
+ entry["_myNum"] = 42
+ configData, _ := json.MarshalIndent(cfg, "", " ")
+ os.WriteFile(configPath, configData, 0o644)
+
+ // Re-run Edit — should preserve custom fields
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ = os.ReadFile(configPath)
+ json.Unmarshal(data, &cfg)
+ provider = cfg["provider"].(map[string]any)
+ ollama = provider["ollama"].(map[string]any)
+ models = ollama["models"].(map[string]any)
+ entry = models["llama3.2"].(map[string]any)
+
+ if entry["_myPref"] != "custom-value" {
+ t.Errorf("_myPref was lost: got %v", entry["_myPref"])
+ }
+ if entry["_myNum"] != float64(42) {
+ t.Errorf("_myNum was lost: got %v", entry["_myNum"])
+ }
+ if v, ok := entry["_launch"].(bool); !ok || !v {
+ t.Errorf("_launch marker missing or false: got %v", entry["_launch"])
+ }
+ })
+
+ t.Run("migrate legacy [Ollama] suffix entries", func(t *testing.T) {
+ cleanup()
+ // Write a config with a legacy entry (has [Ollama] suffix but no _launch marker)
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"llama3.2":{"name":"llama3.2 [Ollama]"}}}}}`), 0o644)
+
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ provider := cfg["provider"].(map[string]any)
+ ollama := provider["ollama"].(map[string]any)
+ models := ollama["models"].(map[string]any)
+ entry := models["llama3.2"].(map[string]any)
+
+ // _launch marker should be added
+ if v, ok := entry["_launch"].(bool); !ok || !v {
+ t.Errorf("_launch marker not added during migration: got %v", entry["_launch"])
+ }
+ // [Ollama] suffix should be stripped
+ if name, ok := entry["name"].(string); !ok || name != "llama3.2" {
+ t.Errorf("name suffix not stripped: got %q", entry["name"])
+ }
+ })
+
+ t.Run("remove model preserves non-ollama models", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+ // Add a non-Ollama model manually
+ os.WriteFile(configPath, []byte(`{"provider":{"ollama":{"models":{"external":{"name":"External Model"}}}}}`), 0o644)
+
+ o.Edit([]string{"llama3.2"})
+ assertOpenCodeModelExists(t, configPath, "llama3.2")
+ assertOpenCodeModelExists(t, configPath, "external") // Should be preserved
+ })
+}
+
+func assertOpenCodeModelExists(t *testing.T, path, model string) {
+ t.Helper()
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Fatal(err)
+ }
+ provider, ok := cfg["provider"].(map[string]any)
+ if !ok {
+ t.Fatal("provider not found")
+ }
+ ollama, ok := provider["ollama"].(map[string]any)
+ if !ok {
+ t.Fatal("ollama provider not found")
+ }
+ models, ok := ollama["models"].(map[string]any)
+ if !ok {
+ t.Fatal("models not found")
+ }
+ if models[model] == nil {
+ t.Errorf("model %s not found", model)
+ }
+}
+
+func assertOpenCodeModelNotExists(t *testing.T, path, model string) {
+ t.Helper()
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Fatal(err)
+ }
+ provider, ok := cfg["provider"].(map[string]any)
+ if !ok {
+ return // No provider means no model
+ }
+ ollama, ok := provider["ollama"].(map[string]any)
+ if !ok {
+ return // No ollama means no model
+ }
+ models, ok := ollama["models"].(map[string]any)
+ if !ok {
+ return // No models means no model
+ }
+ if models[model] != nil {
+ t.Errorf("model %s should not exist but was found", model)
+ }
+}
+
+func assertOpenCodeRecentModel(t *testing.T, path string, index int, providerID, modelID string) {
+ t.Helper()
+ data, err := os.ReadFile(path)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var state map[string]any
+ if err := json.Unmarshal(data, &state); err != nil {
+ t.Fatal(err)
+ }
+ recent, ok := state["recent"].([]any)
+ if !ok {
+ t.Fatal("recent not found")
+ }
+ if index >= len(recent) {
+ t.Fatalf("index %d out of range (len=%d)", index, len(recent))
+ }
+ entry, ok := recent[index].(map[string]any)
+ if !ok {
+ t.Fatal("entry is not a map")
+ }
+ if entry["providerID"] != providerID {
+ t.Errorf("expected providerID %s, got %s", providerID, entry["providerID"])
+ }
+ if entry["modelID"] != modelID {
+ t.Errorf("expected modelID %s, got %s", modelID, entry["modelID"])
+ }
+}
+
+// Edge case tests for opencode.go
+
+func TestOpenCodeEdit_CorruptedConfigJSON(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{corrupted json content`), 0o644)
+
+ // Should not panic - corrupted JSON should be treated as empty
+ err := o.Edit([]string{"llama3.2"})
+ if err != nil {
+ t.Fatalf("Edit failed with corrupted config: %v", err)
+ }
+
+ // Verify valid JSON was created
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Errorf("resulting config is not valid JSON: %v", err)
+ }
+}
+
+func TestOpenCodeEdit_CorruptedStateJSON(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
+ statePath := filepath.Join(stateDir, "model.json")
+
+ os.MkdirAll(stateDir, 0o755)
+ os.WriteFile(statePath, []byte(`{corrupted state`), 0o644)
+
+ err := o.Edit([]string{"llama3.2"})
+ if err != nil {
+ t.Fatalf("Edit failed with corrupted state: %v", err)
+ }
+
+ // Verify valid state was created
+ data, _ := os.ReadFile(statePath)
+ var state map[string]any
+ if err := json.Unmarshal(data, &state); err != nil {
+ t.Errorf("resulting state is not valid JSON: %v", err)
+ }
+}
+
+func TestOpenCodeEdit_WrongTypeProvider(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{"provider": "not a map"}`), 0o644)
+
+ err := o.Edit([]string{"llama3.2"})
+ if err != nil {
+ t.Fatalf("Edit with wrong type provider failed: %v", err)
+ }
+
+ // Verify provider is now correct type
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+
+ provider, ok := cfg["provider"].(map[string]any)
+ if !ok {
+ t.Fatalf("provider should be map after setup, got %T", cfg["provider"])
+ }
+ if provider["ollama"] == nil {
+ t.Error("ollama provider should be created")
+ }
+}
+
+func TestOpenCodeEdit_WrongTypeRecent(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ stateDir := filepath.Join(tmpDir, ".local", "state", "opencode")
+ statePath := filepath.Join(stateDir, "model.json")
+
+ os.MkdirAll(stateDir, 0o755)
+ os.WriteFile(statePath, []byte(`{"recent": "not an array", "favorite": [], "variant": {}}`), 0o644)
+
+ err := o.Edit([]string{"llama3.2"})
+ if err != nil {
+ t.Fatalf("Edit with wrong type recent failed: %v", err)
+ }
+
+ // The function should handle this gracefully
+ data, _ := os.ReadFile(statePath)
+ var state map[string]any
+ json.Unmarshal(data, &state)
+
+ // recent should be properly set after setup
+ recent, ok := state["recent"].([]any)
+ if !ok {
+ t.Logf("Note: recent type after setup is %T (documenting behavior)", state["recent"])
+ } else if len(recent) == 0 {
+ t.Logf("Note: recent is empty (documenting behavior)")
+ }
+}
+
+func TestOpenCodeEdit_EmptyModels(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+
+ os.MkdirAll(configDir, 0o755)
+ originalContent := `{"provider":{"ollama":{"models":{"existing":{}}}}}`
+ os.WriteFile(configPath, []byte(originalContent), 0o644)
+
+ // Empty models should be no-op
+ err := o.Edit([]string{})
+ if err != nil {
+ t.Fatalf("Edit with empty models failed: %v", err)
+ }
+
+ // Original content should be preserved (file not modified)
+ data, _ := os.ReadFile(configPath)
+ if string(data) != originalContent {
+ t.Errorf("empty models should not modify file, but content changed")
+ }
+}
+
+func TestOpenCodeEdit_SpecialCharsInModelName(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ // Model name with special characters (though unusual)
+ specialModel := `model-with-"quotes"`
+
+ err := o.Edit([]string{specialModel})
+ if err != nil {
+ t.Fatalf("Edit with special chars failed: %v", err)
+ }
+
+ // Verify it was stored correctly
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+ data, _ := os.ReadFile(configPath)
+
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Fatalf("resulting config is invalid JSON: %v", err)
+ }
+
+ // Model should be accessible
+ provider, _ := cfg["provider"].(map[string]any)
+ ollama, _ := provider["ollama"].(map[string]any)
+ models, _ := ollama["models"].(map[string]any)
+
+ if models[specialModel] == nil {
+ t.Errorf("model with special chars not found in config")
+ }
+}
+
+func readOpenCodeModel(t *testing.T, configPath, model string) map[string]any {
+ t.Helper()
+ data, err := os.ReadFile(configPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ provider := cfg["provider"].(map[string]any)
+ ollama := provider["ollama"].(map[string]any)
+ models := ollama["models"].(map[string]any)
+ entry, ok := models[model].(map[string]any)
+ if !ok {
+ t.Fatalf("model %s not found in config", model)
+ }
+ return entry
+}
+
+func TestOpenCodeEdit_LocalModelNoLimit(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configPath := filepath.Join(tmpDir, ".config", "opencode", "opencode.json")
+
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ entry := readOpenCodeModel(t, configPath, "llama3.2")
+ if entry["limit"] != nil {
+ t.Errorf("local model should not have limit set, got %v", entry["limit"])
+ }
+}
+
+func TestOpenCodeEdit_PreservesUserLimit(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+
+ // Set up a model with a user-configured limit
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(`{
+ "provider": {
+ "ollama": {
+ "models": {
+ "llama3.2": {
+ "name": "llama3.2",
+ "_launch": true,
+ "limit": {"context": 8192, "output": 4096}
+ }
+ }
+ }
+ }
+ }`), 0o644)
+
+ // Re-edit should preserve the user's limit (not delete it)
+ if err := o.Edit([]string{"llama3.2"}); err != nil {
+ t.Fatal(err)
+ }
+
+ entry := readOpenCodeModel(t, configPath, "llama3.2")
+ limit, ok := entry["limit"].(map[string]any)
+ if !ok {
+ t.Fatal("user-configured limit was removed")
+ }
+ if limit["context"] != float64(8192) {
+ t.Errorf("context limit changed: got %v, want 8192", limit["context"])
+ }
+ if limit["output"] != float64(4096) {
+ t.Errorf("output limit changed: got %v, want 4096", limit["output"])
+ }
+}
+
+func TestOpenCodeEdit_CloudModelLimitStructure(t *testing.T) {
+ // Verify that when a cloud model entry has limits set (as Edit would do),
+ // the structure matches what opencode expects and re-edit preserves them.
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".config", "opencode")
+ configPath := filepath.Join(configDir, "opencode.json")
+
+ expected := cloudModelLimits["glm-4.7"]
+
+ // Simulate a cloud model that already has the limit set by a previous Edit
+ os.MkdirAll(configDir, 0o755)
+ os.WriteFile(configPath, []byte(fmt.Sprintf(`{
+ "provider": {
+ "ollama": {
+ "models": {
+ "glm-4.7:cloud": {
+ "name": "glm-4.7:cloud",
+ "_launch": true,
+ "limit": {"context": %d, "output": %d}
+ }
+ }
+ }
+ }
+ }`, expected.Context, expected.Output)), 0o644)
+
+ // Re-edit should preserve the cloud model limit
+ if err := o.Edit([]string{"glm-4.7:cloud"}); err != nil {
+ t.Fatal(err)
+ }
+
+ entry := readOpenCodeModel(t, configPath, "glm-4.7:cloud")
+ limit, ok := entry["limit"].(map[string]any)
+ if !ok {
+ t.Fatal("cloud model limit was removed on re-edit")
+ }
+ if limit["context"] != float64(expected.Context) {
+ t.Errorf("context = %v, want %d", limit["context"], expected.Context)
+ }
+ if limit["output"] != float64(expected.Output) {
+ t.Errorf("output = %v, want %d", limit["output"], expected.Output)
+ }
+}
+
+func TestLookupCloudModelLimit(t *testing.T) {
+ tests := []struct {
+ name string
+ wantOK bool
+ wantContext int
+ wantOutput int
+ }{
+ {"glm-4.7", true, 202_752, 131_072},
+ {"glm-4.7:cloud", true, 202_752, 131_072},
+ {"kimi-k2.5", true, 262_144, 262_144},
+ {"kimi-k2.5:cloud", true, 262_144, 262_144},
+ {"deepseek-v3.2", true, 163_840, 65_536},
+ {"deepseek-v3.2:cloud", true, 163_840, 65_536},
+ {"qwen3-coder:480b", true, 262_144, 65_536},
+ {"qwen3-coder-next:cloud", true, 262_144, 32_768},
+ {"llama3.2", false, 0, 0},
+ {"unknown-model:cloud", false, 0, 0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ l, ok := lookupCloudModelLimit(tt.name)
+ if ok != tt.wantOK {
+ t.Errorf("lookupCloudModelLimit(%q) ok = %v, want %v", tt.name, ok, tt.wantOK)
+ }
+ if ok {
+ if l.Context != tt.wantContext {
+ t.Errorf("context = %d, want %d", l.Context, tt.wantContext)
+ }
+ if l.Output != tt.wantOutput {
+ t.Errorf("output = %d, want %d", l.Output, tt.wantOutput)
+ }
+ }
+ })
+ }
+}
+
+func TestOpenCodeModels_NoConfig(t *testing.T) {
+ o := &OpenCode{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ models := o.Models()
+ if len(models) > 0 {
+ t.Errorf("expected nil/empty for missing config, got %v", models)
+ }
+}
diff --git a/cmd/config/pi.go b/cmd/config/pi.go
new file mode 100644
index 00000000000..9dd84ee8779
--- /dev/null
+++ b/cmd/config/pi.go
@@ -0,0 +1,237 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "slices"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/envconfig"
+ "github.com/ollama/ollama/types/model"
+)
+
+// Pi implements Runner and Editor for Pi (Pi Coding Agent) integration
+type Pi struct{}
+
+func (p *Pi) String() string { return "Pi" }
+
+func (p *Pi) Run(model string, args []string) error {
+ if _, err := exec.LookPath("pi"); err != nil {
+ return fmt.Errorf("pi is not installed, install with: npm install -g @mariozechner/pi-coding-agent")
+ }
+
+ // Call Edit() to ensure config is up-to-date before launch
+ models := []string{model}
+ if config, err := loadIntegration("pi"); err == nil && len(config.Models) > 0 {
+ models = config.Models
+ }
+ if err := p.Edit(models); err != nil {
+ return fmt.Errorf("setup failed: %w", err)
+ }
+
+ cmd := exec.Command("pi", args...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+ return cmd.Run()
+}
+
+func (p *Pi) Paths() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+
+ var paths []string
+ modelsPath := filepath.Join(home, ".pi", "agent", "models.json")
+ if _, err := os.Stat(modelsPath); err == nil {
+ paths = append(paths, modelsPath)
+ }
+ settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
+ if _, err := os.Stat(settingsPath); err == nil {
+ paths = append(paths, settingsPath)
+ }
+ return paths
+}
+
+func (p *Pi) Edit(models []string) error {
+ if len(models) == 0 {
+ return nil
+ }
+
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return err
+ }
+
+ configPath := filepath.Join(home, ".pi", "agent", "models.json")
+ if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+ return err
+ }
+
+ config := make(map[string]any)
+ if data, err := os.ReadFile(configPath); err == nil {
+ _ = json.Unmarshal(data, &config)
+ }
+
+ providers, ok := config["providers"].(map[string]any)
+ if !ok {
+ providers = make(map[string]any)
+ }
+
+ ollama, ok := providers["ollama"].(map[string]any)
+ if !ok {
+ ollama = map[string]any{
+ "baseUrl": envconfig.Host().String() + "/v1",
+ "api": "openai-completions",
+ "apiKey": "ollama",
+ }
+ }
+
+ existingModels, ok := ollama["models"].([]any)
+ if !ok {
+ existingModels = make([]any, 0)
+ }
+
+ // Build set of selected models to track which need to be added
+ selectedSet := make(map[string]bool, len(models))
+ for _, m := range models {
+ selectedSet[m] = true
+ }
+
+ // Build new models list:
+ // 1. Keep user-managed models (no _launch marker) - untouched
+ // 2. Keep ollama-managed models (_launch marker) that are still selected
+ // 3. Add new ollama-managed models
+ var newModels []any
+ for _, m := range existingModels {
+ if modelObj, ok := m.(map[string]any); ok {
+ if id, ok := modelObj["id"].(string); ok {
+ // User-managed model (no _launch marker) - always preserve
+ if !isPiOllamaModel(modelObj) {
+ newModels = append(newModels, m)
+ } else if selectedSet[id] {
+ // Ollama-managed and still selected - keep it
+ newModels = append(newModels, m)
+ selectedSet[id] = false
+ }
+ }
+ }
+ }
+
+ // Add newly selected models that weren't already in the list
+ client := api.NewClient(envconfig.Host(), http.DefaultClient)
+ ctx := context.Background()
+ for _, model := range models {
+ if selectedSet[model] {
+ newModels = append(newModels, createConfig(ctx, client, model))
+ }
+ }
+
+ ollama["models"] = newModels
+ providers["ollama"] = ollama
+ config["providers"] = providers
+
+ configData, err := json.MarshalIndent(config, "", " ")
+ if err != nil {
+ return err
+ }
+ if err := writeWithBackup(configPath, configData); err != nil {
+ return err
+ }
+
+ // Update settings.json with default provider and model
+ settingsPath := filepath.Join(home, ".pi", "agent", "settings.json")
+ settings := make(map[string]any)
+ if data, err := os.ReadFile(settingsPath); err == nil {
+ _ = json.Unmarshal(data, &settings)
+ }
+
+ settings["defaultProvider"] = "ollama"
+ settings["defaultModel"] = models[0]
+
+ settingsData, err := json.MarshalIndent(settings, "", " ")
+ if err != nil {
+ return err
+ }
+ return writeWithBackup(settingsPath, settingsData)
+}
+
+func (p *Pi) Models() []string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return nil
+ }
+
+ configPath := filepath.Join(home, ".pi", "agent", "models.json")
+ config, err := readJSONFile(configPath)
+ if err != nil {
+ return nil
+ }
+
+ providers, _ := config["providers"].(map[string]any)
+ ollama, _ := providers["ollama"].(map[string]any)
+ models, _ := ollama["models"].([]any)
+
+ var result []string
+ for _, m := range models {
+ if modelObj, ok := m.(map[string]any); ok {
+ if id, ok := modelObj["id"].(string); ok {
+ result = append(result, id)
+ }
+ }
+ }
+ slices.Sort(result)
+ return result
+}
+
+// isPiOllamaModel reports whether a model config entry is managed by ollama launch
+func isPiOllamaModel(cfg map[string]any) bool {
+ if v, ok := cfg["_launch"].(bool); ok && v {
+ return true
+ }
+ return false
+}
+
+// createConfig builds Pi model config with capability detection
+func createConfig(ctx context.Context, client *api.Client, modelID string) map[string]any {
+ cfg := map[string]any{
+ "id": modelID,
+ "_launch": true,
+ }
+
+ resp, err := client.Show(ctx, &api.ShowRequest{Model: modelID})
+ if err != nil {
+ return cfg
+ }
+
+ // Set input types based on vision capability
+ if slices.Contains(resp.Capabilities, model.CapabilityVision) {
+ cfg["input"] = []string{"text", "image"}
+ } else {
+ cfg["input"] = []string{"text"}
+ }
+
+ // Set reasoning based on thinking capability
+ if slices.Contains(resp.Capabilities, model.CapabilityThinking) {
+ cfg["reasoning"] = true
+ }
+
+ // Extract context window from ModelInfo
+ for key, val := range resp.ModelInfo {
+ if strings.HasSuffix(key, ".context_length") {
+ if ctxLen, ok := val.(float64); ok && ctxLen > 0 {
+ cfg["contextWindow"] = int(ctxLen)
+ }
+ break
+ }
+ }
+
+ return cfg
+}
diff --git a/cmd/config/pi_test.go b/cmd/config/pi_test.go
new file mode 100644
index 00000000000..18c62be9498
--- /dev/null
+++ b/cmd/config/pi_test.go
@@ -0,0 +1,830 @@
+package config
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/types/model"
+)
+
+func TestPiIntegration(t *testing.T) {
+ pi := &Pi{}
+
+ t.Run("String", func(t *testing.T) {
+ if got := pi.String(); got != "Pi" {
+ t.Errorf("String() = %q, want %q", got, "Pi")
+ }
+ })
+
+ t.Run("implements Runner", func(t *testing.T) {
+ var _ Runner = pi
+ })
+
+ t.Run("implements Editor", func(t *testing.T) {
+ var _ Editor = pi
+ })
+}
+
+func TestPiPaths(t *testing.T) {
+ pi := &Pi{}
+
+ t.Run("returns empty when no config exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ paths := pi.Paths()
+ if len(paths) != 0 {
+ t.Errorf("Paths() = %v, want empty", paths)
+ }
+ })
+
+ t.Run("returns path when config exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".pi", "agent")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ configPath := filepath.Join(configDir, "models.json")
+ if err := os.WriteFile(configPath, []byte("{}"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ paths := pi.Paths()
+ if len(paths) != 1 || paths[0] != configPath {
+ t.Errorf("Paths() = %v, want [%s]", paths, configPath)
+ }
+ })
+}
+
+func TestPiEdit(t *testing.T) {
+ // Mock Ollama server for createConfig calls during Edit
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":[],"model_info":{}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+ t.Setenv("OLLAMA_HOST", srv.URL)
+
+ pi := &Pi{}
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".pi", "agent")
+ configPath := filepath.Join(configDir, "models.json")
+
+ cleanup := func() {
+ os.RemoveAll(configDir)
+ }
+
+ readConfig := func() map[string]any {
+ data, _ := os.ReadFile(configPath)
+ var cfg map[string]any
+ json.Unmarshal(data, &cfg)
+ return cfg
+ }
+
+ t.Run("returns nil for empty models", func(t *testing.T) {
+ if err := pi.Edit([]string{}); err != nil {
+ t.Errorf("Edit([]) error = %v, want nil", err)
+ }
+ })
+
+ t.Run("creates config with models", func(t *testing.T) {
+ cleanup()
+
+ models := []string{"llama3.2", "qwen3:8b"}
+ if err := pi.Edit(models); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ cfg := readConfig()
+
+ providers, ok := cfg["providers"].(map[string]any)
+ if !ok {
+ t.Error("Config missing providers")
+ }
+
+ ollama, ok := providers["ollama"].(map[string]any)
+ if !ok {
+ t.Error("Providers missing ollama")
+ }
+
+ modelsArray, ok := ollama["models"].([]any)
+ if !ok || len(modelsArray) != 2 {
+ t.Errorf("Expected 2 models, got %v", modelsArray)
+ }
+
+ if ollama["baseUrl"] == nil {
+ t.Error("Missing baseUrl")
+ }
+ if ollama["api"] != "openai-completions" {
+ t.Errorf("Expected api=openai-completions, got %v", ollama["api"])
+ }
+ if ollama["apiKey"] != "ollama" {
+ t.Errorf("Expected apiKey=ollama, got %v", ollama["apiKey"])
+ }
+ })
+
+ t.Run("updates existing config preserving ollama provider settings", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ existingConfig := `{
+ "providers": {
+ "ollama": {
+ "baseUrl": "http://custom:8080/v1",
+ "api": "custom-api",
+ "apiKey": "custom-key",
+ "models": [
+ {"id": "old-model", "_launch": true}
+ ]
+ }
+ }
+ }`
+ if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := []string{"new-model"}
+ if err := pi.Edit(models); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ cfg := readConfig()
+ providers := cfg["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+
+ if ollama["baseUrl"] != "http://custom:8080/v1" {
+ t.Errorf("Custom baseUrl not preserved, got %v", ollama["baseUrl"])
+ }
+ if ollama["api"] != "custom-api" {
+ t.Errorf("Custom api not preserved, got %v", ollama["api"])
+ }
+ if ollama["apiKey"] != "custom-key" {
+ t.Errorf("Custom apiKey not preserved, got %v", ollama["apiKey"])
+ }
+
+ modelsArray := ollama["models"].([]any)
+ if len(modelsArray) != 1 {
+ t.Errorf("Expected 1 model after update, got %d", len(modelsArray))
+ } else {
+ modelEntry := modelsArray[0].(map[string]any)
+ if modelEntry["id"] != "new-model" {
+ t.Errorf("Expected new-model, got %v", modelEntry["id"])
+ }
+ // Verify _launch marker is present
+ if modelEntry["_launch"] != true {
+ t.Errorf("Expected _launch marker to be true")
+ }
+ }
+ })
+
+ t.Run("replaces old models with new ones", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ // Old models must have _launch marker to be managed by us
+ existingConfig := `{
+ "providers": {
+ "ollama": {
+ "baseUrl": "http://localhost:11434/v1",
+ "api": "openai-completions",
+ "apiKey": "ollama",
+ "models": [
+ {"id": "old-model-1", "_launch": true},
+ {"id": "old-model-2", "_launch": true}
+ ]
+ }
+ }
+ }`
+ if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ newModels := []string{"new-model-1", "new-model-2"}
+ if err := pi.Edit(newModels); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ cfg := readConfig()
+ providers := cfg["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelsArray := ollama["models"].([]any)
+
+ if len(modelsArray) != 2 {
+ t.Errorf("Expected 2 models, got %d", len(modelsArray))
+ }
+
+ modelIDs := make(map[string]bool)
+ for _, m := range modelsArray {
+ modelObj := m.(map[string]any)
+ id := modelObj["id"].(string)
+ modelIDs[id] = true
+ }
+
+ if !modelIDs["new-model-1"] || !modelIDs["new-model-2"] {
+ t.Errorf("Expected new models, got %v", modelIDs)
+ }
+ if modelIDs["old-model-1"] || modelIDs["old-model-2"] {
+ t.Errorf("Old models should have been removed, got %v", modelIDs)
+ }
+ })
+
+ t.Run("handles partial overlap in model list", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ // Models must have _launch marker to be managed
+ existingConfig := `{
+ "providers": {
+ "ollama": {
+ "baseUrl": "http://localhost:11434/v1",
+ "api": "openai-completions",
+ "apiKey": "ollama",
+ "models": [
+ {"id": "keep-model", "_launch": true},
+ {"id": "remove-model", "_launch": true}
+ ]
+ }
+ }
+ }`
+ if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ newModels := []string{"keep-model", "add-model"}
+ if err := pi.Edit(newModels); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ cfg := readConfig()
+ providers := cfg["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelsArray := ollama["models"].([]any)
+
+ if len(modelsArray) != 2 {
+ t.Errorf("Expected 2 models, got %d", len(modelsArray))
+ }
+
+ modelIDs := make(map[string]bool)
+ for _, m := range modelsArray {
+ modelObj := m.(map[string]any)
+ id := modelObj["id"].(string)
+ modelIDs[id] = true
+ }
+
+ if !modelIDs["keep-model"] || !modelIDs["add-model"] {
+ t.Errorf("Expected keep-model and add-model, got %v", modelIDs)
+ }
+ if modelIDs["remove-model"] {
+ t.Errorf("remove-model should have been removed")
+ }
+ })
+
+ t.Run("handles corrupt config gracefully", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := []string{"test-model"}
+ if err := pi.Edit(models); err != nil {
+ t.Fatalf("Edit() should not fail with corrupt config, got %v", err)
+ }
+
+ data, err := os.ReadFile(configPath)
+ if err != nil {
+ t.Fatalf("Failed to read config: %v", err)
+ }
+
+ var cfg map[string]any
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ t.Fatalf("Config should be valid after Edit, got parse error: %v", err)
+ }
+
+ providers := cfg["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelsArray := ollama["models"].([]any)
+
+ if len(modelsArray) != 1 {
+ t.Errorf("Expected 1 model, got %d", len(modelsArray))
+ }
+ })
+
+ // CRITICAL SAFETY TEST: verifies we don't stomp on user configs
+ t.Run("preserves user-managed models without _launch marker", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ // User has manually configured models in ollama provider (no _launch marker)
+ existingConfig := `{
+ "providers": {
+ "ollama": {
+ "baseUrl": "http://localhost:11434/v1",
+ "api": "openai-completions",
+ "apiKey": "ollama",
+ "models": [
+ {"id": "user-model-1"},
+ {"id": "user-model-2", "customField": "preserved"},
+ {"id": "ollama-managed", "_launch": true}
+ ]
+ }
+ }
+ }`
+ if err := os.WriteFile(configPath, []byte(existingConfig), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ // Add a new ollama-managed model
+ newModels := []string{"new-ollama-model"}
+ if err := pi.Edit(newModels); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ cfg := readConfig()
+ providers := cfg["providers"].(map[string]any)
+ ollama := providers["ollama"].(map[string]any)
+ modelsArray := ollama["models"].([]any)
+
+ // Should have: new-ollama-model (managed) + 2 user models (preserved)
+ if len(modelsArray) != 3 {
+ t.Errorf("Expected 3 models (1 new managed + 2 preserved user models), got %d", len(modelsArray))
+ }
+
+ modelIDs := make(map[string]map[string]any)
+ for _, m := range modelsArray {
+ modelObj := m.(map[string]any)
+ id := modelObj["id"].(string)
+ modelIDs[id] = modelObj
+ }
+
+ // Verify new model has _launch marker
+ if m, ok := modelIDs["new-ollama-model"]; !ok {
+ t.Errorf("new-ollama-model should be present")
+ } else if m["_launch"] != true {
+ t.Errorf("new-ollama-model should have _launch marker")
+ }
+
+ // Verify user models are preserved
+ if _, ok := modelIDs["user-model-1"]; !ok {
+ t.Errorf("user-model-1 should be preserved")
+ }
+ if _, ok := modelIDs["user-model-2"]; !ok {
+ t.Errorf("user-model-2 should be preserved")
+ } else if modelIDs["user-model-2"]["customField"] != "preserved" {
+ t.Errorf("user-model-2 customField should be preserved")
+ }
+
+ // Verify old ollama-managed model is removed (not in new list)
+ if _, ok := modelIDs["ollama-managed"]; ok {
+ t.Errorf("ollama-managed should be removed (old ollama model not in new selection)")
+ }
+ })
+
+ t.Run("updates settings.json with default provider and model", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ // Create existing settings with other fields
+ settingsPath := filepath.Join(configDir, "settings.json")
+ existingSettings := `{
+ "theme": "dark",
+ "customSetting": "value",
+ "defaultProvider": "anthropic",
+ "defaultModel": "claude-3"
+ }`
+ if err := os.WriteFile(settingsPath, []byte(existingSettings), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := []string{"llama3.2"}
+ if err := pi.Edit(models); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ data, err := os.ReadFile(settingsPath)
+ if err != nil {
+ t.Fatalf("Failed to read settings: %v", err)
+ }
+
+ var settings map[string]any
+ if err := json.Unmarshal(data, &settings); err != nil {
+ t.Fatalf("Failed to parse settings: %v", err)
+ }
+
+ // Verify defaultProvider is set to ollama
+ if settings["defaultProvider"] != "ollama" {
+ t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
+ }
+
+ // Verify defaultModel is set to first model
+ if settings["defaultModel"] != "llama3.2" {
+ t.Errorf("defaultModel = %v, want llama3.2", settings["defaultModel"])
+ }
+
+ // Verify other fields are preserved
+ if settings["theme"] != "dark" {
+ t.Errorf("theme = %v, want dark (preserved)", settings["theme"])
+ }
+ if settings["customSetting"] != "value" {
+ t.Errorf("customSetting = %v, want value (preserved)", settings["customSetting"])
+ }
+ })
+
+ t.Run("creates settings.json if it does not exist", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ models := []string{"qwen3:8b"}
+ if err := pi.Edit(models); err != nil {
+ t.Fatalf("Edit() error = %v", err)
+ }
+
+ settingsPath := filepath.Join(configDir, "settings.json")
+ data, err := os.ReadFile(settingsPath)
+ if err != nil {
+ t.Fatalf("settings.json should be created: %v", err)
+ }
+
+ var settings map[string]any
+ if err := json.Unmarshal(data, &settings); err != nil {
+ t.Fatalf("Failed to parse settings: %v", err)
+ }
+
+ if settings["defaultProvider"] != "ollama" {
+ t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
+ }
+ if settings["defaultModel"] != "qwen3:8b" {
+ t.Errorf("defaultModel = %v, want qwen3:8b", settings["defaultModel"])
+ }
+ })
+
+ t.Run("handles corrupt settings.json gracefully", func(t *testing.T) {
+ cleanup()
+ os.MkdirAll(configDir, 0o755)
+
+ // Create corrupt settings
+ settingsPath := filepath.Join(configDir, "settings.json")
+ if err := os.WriteFile(settingsPath, []byte("{invalid"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := []string{"test-model"}
+ if err := pi.Edit(models); err != nil {
+ t.Fatalf("Edit() should not fail with corrupt settings, got %v", err)
+ }
+
+ data, err := os.ReadFile(settingsPath)
+ if err != nil {
+ t.Fatalf("Failed to read settings: %v", err)
+ }
+
+ var settings map[string]any
+ if err := json.Unmarshal(data, &settings); err != nil {
+ t.Fatalf("settings.json should be valid after Edit, got parse error: %v", err)
+ }
+
+ if settings["defaultProvider"] != "ollama" {
+ t.Errorf("defaultProvider = %v, want ollama", settings["defaultProvider"])
+ }
+ if settings["defaultModel"] != "test-model" {
+ t.Errorf("defaultModel = %v, want test-model", settings["defaultModel"])
+ }
+ })
+}
+
+func TestPiModels(t *testing.T) {
+ pi := &Pi{}
+
+ t.Run("returns nil when no config exists", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ models := pi.Models()
+ if models != nil {
+ t.Errorf("Models() = %v, want nil", models)
+ }
+ })
+
+ t.Run("returns models from config", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".pi", "agent")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ config := `{
+ "providers": {
+ "ollama": {
+ "models": [
+ {"id": "llama3.2"},
+ {"id": "qwen3:8b"}
+ ]
+ }
+ }
+ }`
+ configPath := filepath.Join(configDir, "models.json")
+ if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := pi.Models()
+ if len(models) != 2 {
+ t.Errorf("Models() returned %d models, want 2", len(models))
+ }
+ if models[0] != "llama3.2" || models[1] != "qwen3:8b" {
+ t.Errorf("Models() = %v, want [llama3.2 qwen3:8b] (sorted)", models)
+ }
+ })
+
+ t.Run("returns sorted models", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".pi", "agent")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ config := `{
+ "providers": {
+ "ollama": {
+ "models": [
+ {"id": "z-model"},
+ {"id": "a-model"},
+ {"id": "m-model"}
+ ]
+ }
+ }
+ }`
+ configPath := filepath.Join(configDir, "models.json")
+ if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := pi.Models()
+ if models[0] != "a-model" || models[1] != "m-model" || models[2] != "z-model" {
+ t.Errorf("Models() = %v, want [a-model m-model z-model] (sorted)", models)
+ }
+ })
+
+ t.Run("returns nil when models array is missing", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".pi", "agent")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ config := `{
+ "providers": {
+ "ollama": {}
+ }
+ }`
+ configPath := filepath.Join(configDir, "models.json")
+ if err := os.WriteFile(configPath, []byte(config), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := pi.Models()
+ if models != nil {
+ t.Errorf("Models() = %v, want nil when models array is missing", models)
+ }
+ })
+
+ t.Run("handles corrupt config gracefully", func(t *testing.T) {
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configDir := filepath.Join(tmpDir, ".pi", "agent")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ configPath := filepath.Join(configDir, "models.json")
+ if err := os.WriteFile(configPath, []byte("{invalid json}"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ models := pi.Models()
+ if models != nil {
+ t.Errorf("Models() = %v, want nil for corrupt config", models)
+ }
+ })
+}
+
+func TestIsPiOllamaModel(t *testing.T) {
+ tests := []struct {
+ name string
+ cfg map[string]any
+ want bool
+ }{
+ {"with _launch true", map[string]any{"id": "m", "_launch": true}, true},
+ {"with _launch false", map[string]any{"id": "m", "_launch": false}, false},
+ {"without _launch", map[string]any{"id": "m"}, false},
+ {"with _launch non-bool", map[string]any{"id": "m", "_launch": "yes"}, false},
+ {"empty map", map[string]any{}, false},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := isPiOllamaModel(tt.cfg); got != tt.want {
+ t.Errorf("isPiOllamaModel(%v) = %v, want %v", tt.cfg, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestCreateConfig(t *testing.T) {
+ t.Run("sets vision input when model has vision capability", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":["vision"],"model_info":{}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "llava:7b")
+
+ if cfg["id"] != "llava:7b" {
+ t.Errorf("id = %v, want llava:7b", cfg["id"])
+ }
+ if cfg["_launch"] != true {
+ t.Error("expected _launch = true")
+ }
+ input, ok := cfg["input"].([]string)
+ if !ok || len(input) != 2 || input[0] != "text" || input[1] != "image" {
+ t.Errorf("input = %v, want [text image]", cfg["input"])
+ }
+ })
+
+ t.Run("sets text-only input when model lacks vision", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":["completion"],"model_info":{}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "llama3.2")
+
+ input, ok := cfg["input"].([]string)
+ if !ok || len(input) != 1 || input[0] != "text" {
+ t.Errorf("input = %v, want [text]", cfg["input"])
+ }
+ if _, ok := cfg["reasoning"]; ok {
+ t.Error("reasoning should not be set for non-thinking model")
+ }
+ })
+
+ t.Run("sets reasoning when model has thinking capability", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":["thinking"],"model_info":{}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "qwq")
+
+ if cfg["reasoning"] != true {
+ t.Error("expected reasoning = true for thinking model")
+ }
+ })
+
+ t.Run("extracts context window from model info", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":131072}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "llama3.2")
+
+ if cfg["contextWindow"] != 131072 {
+ t.Errorf("contextWindow = %v, want 131072", cfg["contextWindow"])
+ }
+ })
+
+ t.Run("handles all capabilities together", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":["vision","thinking"],"model_info":{"qwen3.context_length":32768}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "qwen3-vision")
+
+ input := cfg["input"].([]string)
+ if len(input) != 2 || input[0] != "text" || input[1] != "image" {
+ t.Errorf("input = %v, want [text image]", input)
+ }
+ if cfg["reasoning"] != true {
+ t.Error("expected reasoning = true")
+ }
+ if cfg["contextWindow"] != 32768 {
+ t.Errorf("contextWindow = %v, want 32768", cfg["contextWindow"])
+ }
+ })
+
+ t.Run("returns minimal config when show fails", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusNotFound)
+ fmt.Fprintf(w, `{"error":"model not found"}`)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "missing-model")
+
+ if cfg["id"] != "missing-model" {
+ t.Errorf("id = %v, want missing-model", cfg["id"])
+ }
+ if cfg["_launch"] != true {
+ t.Error("expected _launch = true")
+ }
+ // Should not have capability fields
+ if _, ok := cfg["input"]; ok {
+ t.Error("input should not be set when show fails")
+ }
+ if _, ok := cfg["reasoning"]; ok {
+ t.Error("reasoning should not be set when show fails")
+ }
+ if _, ok := cfg["contextWindow"]; ok {
+ t.Error("contextWindow should not be set when show fails")
+ }
+ })
+
+ t.Run("skips zero context length", func(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if r.URL.Path == "/api/show" {
+ fmt.Fprintf(w, `{"capabilities":[],"model_info":{"llama.context_length":0}}`)
+ return
+ }
+ w.WriteHeader(http.StatusNotFound)
+ }))
+ defer srv.Close()
+
+ u, _ := url.Parse(srv.URL)
+ client := api.NewClient(u, srv.Client())
+
+ cfg := createConfig(context.Background(), client, "test-model")
+
+ if _, ok := cfg["contextWindow"]; ok {
+ t.Error("contextWindow should not be set for zero value")
+ }
+ })
+}
+
+// Ensure Capability constants used in createConfig match expected values
+func TestPiCapabilityConstants(t *testing.T) {
+ if model.CapabilityVision != "vision" {
+ t.Errorf("CapabilityVision = %q, want %q", model.CapabilityVision, "vision")
+ }
+ if model.CapabilityThinking != "thinking" {
+ t.Errorf("CapabilityThinking = %q, want %q", model.CapabilityThinking, "thinking")
+ }
+}
diff --git a/cmd/config/selector.go b/cmd/config/selector.go
new file mode 100644
index 00000000000..bcd0b749f92
--- /dev/null
+++ b/cmd/config/selector.go
@@ -0,0 +1,58 @@
+package config
+
+import (
+ "errors"
+ "fmt"
+ "os"
+
+ "golang.org/x/term"
+)
+
+// ANSI escape sequences for terminal formatting.
+const (
+ ansiBold = "\033[1m"
+ ansiReset = "\033[0m"
+ ansiGray = "\033[37m"
+ ansiGreen = "\033[32m"
+)
+
+// ErrCancelled is returned when the user cancels a selection.
+var ErrCancelled = errors.New("cancelled")
+
+// errCancelled is kept as an alias for backward compatibility within the package.
+var errCancelled = ErrCancelled
+
+// DefaultConfirmPrompt provides a TUI-based confirmation prompt.
+// When set, confirmPrompt delegates to it instead of using raw terminal I/O.
+var DefaultConfirmPrompt func(prompt string) (bool, error)
+
+func confirmPrompt(prompt string) (bool, error) {
+ if DefaultConfirmPrompt != nil {
+ return DefaultConfirmPrompt(prompt)
+ }
+
+ fd := int(os.Stdin.Fd())
+ oldState, err := term.MakeRaw(fd)
+ if err != nil {
+ return false, err
+ }
+ defer term.Restore(fd, oldState)
+
+ fmt.Fprintf(os.Stderr, "%s (\033[1my\033[0m/n) ", prompt)
+
+ buf := make([]byte, 1)
+ for {
+ if _, err := os.Stdin.Read(buf); err != nil {
+ return false, err
+ }
+
+ switch buf[0] {
+ case 'Y', 'y', 13:
+ fmt.Fprintf(os.Stderr, "yes\r\n")
+ return true, nil
+ case 'N', 'n', 27, 3:
+ fmt.Fprintf(os.Stderr, "no\r\n")
+ return false, nil
+ }
+ }
+}
diff --git a/cmd/config/selector_test.go b/cmd/config/selector_test.go
new file mode 100644
index 00000000000..3e84d1b5db3
--- /dev/null
+++ b/cmd/config/selector_test.go
@@ -0,0 +1,19 @@
+package config
+
+import (
+ "testing"
+)
+
+func TestErrCancelled(t *testing.T) {
+ t.Run("NotNil", func(t *testing.T) {
+ if errCancelled == nil {
+ t.Error("errCancelled should not be nil")
+ }
+ })
+
+ t.Run("Message", func(t *testing.T) {
+ if errCancelled.Error() != "cancelled" {
+ t.Errorf("expected 'cancelled', got %q", errCancelled.Error())
+ }
+ })
+}
diff --git a/cmd/editor_unix.go b/cmd/editor_unix.go
new file mode 100644
index 00000000000..0a7848c83b6
--- /dev/null
+++ b/cmd/editor_unix.go
@@ -0,0 +1,5 @@
+//go:build !windows
+
+package cmd
+
+const defaultEditor = "vi"
diff --git a/cmd/editor_windows.go b/cmd/editor_windows.go
new file mode 100644
index 00000000000..ed428859d3b
--- /dev/null
+++ b/cmd/editor_windows.go
@@ -0,0 +1,5 @@
+//go:build windows
+
+package cmd
+
+const defaultEditor = "edit"
diff --git a/cmd/interactive.go b/cmd/interactive.go
index aad3eccfca2..1f91f9eca30 100644
--- a/cmd/interactive.go
+++ b/cmd/interactive.go
@@ -7,6 +7,7 @@ import (
"io"
"net/http"
"os"
+ "os/exec"
"path/filepath"
"regexp"
"slices"
@@ -79,6 +80,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " Ctrl + w Delete the word before the cursor")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, " Ctrl + l Clear the screen")
+ fmt.Fprintln(os.Stderr, " Ctrl + g Open default editor to compose a prompt")
fmt.Fprintln(os.Stderr, " Ctrl + c Stop the model from responding")
fmt.Fprintln(os.Stderr, " Ctrl + d Exit ollama (/bye)")
fmt.Fprintln(os.Stderr, "")
@@ -147,6 +149,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
scanner.Prompt.UseAlt = false
sb.Reset()
+ continue
+ case errors.Is(err, readline.ErrEditPrompt):
+ sb.Reset()
+ content, err := editInExternalEditor(line)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "error: %v\n", err)
+ continue
+ }
+ if strings.TrimSpace(content) == "" {
+ continue
+ }
+ scanner.Prefill = content
continue
case err != nil:
return err
@@ -159,6 +173,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
sb.WriteString(before)
if !ok {
fmt.Fprintln(&sb)
+ scanner.Prompt.UseAlt = true
continue
}
@@ -597,6 +612,57 @@ func extractFileData(input string) (string, []api.ImageData, error) {
return strings.TrimSpace(input), imgs, nil
}
+func editInExternalEditor(content string) (string, error) {
+ editor := envconfig.Editor()
+ if editor == "" {
+ editor = os.Getenv("VISUAL")
+ }
+ if editor == "" {
+ editor = os.Getenv("EDITOR")
+ }
+ if editor == "" {
+ editor = defaultEditor
+ }
+
+ // Check that the editor binary exists
+ name := strings.Fields(editor)[0]
+ if _, err := exec.LookPath(name); err != nil {
+ return "", fmt.Errorf("editor %q not found, set OLLAMA_EDITOR to the path of your preferred editor", name)
+ }
+
+ tmpFile, err := os.CreateTemp("", "ollama-prompt-*.txt")
+ if err != nil {
+ return "", fmt.Errorf("creating temp file: %w", err)
+ }
+ defer os.Remove(tmpFile.Name())
+
+ if content != "" {
+ if _, err := tmpFile.WriteString(content); err != nil {
+ tmpFile.Close()
+ return "", fmt.Errorf("writing to temp file: %w", err)
+ }
+ }
+ tmpFile.Close()
+
+ args := strings.Fields(editor)
+ args = append(args, tmpFile.Name())
+ cmd := exec.Command(args[0], args[1:]...)
+ cmd.Stdin = os.Stdin
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Run(); err != nil {
+ return "", fmt.Errorf("editor exited with error: %w", err)
+ }
+
+ data, err := os.ReadFile(tmpFile.Name())
+ if err != nil {
+ return "", fmt.Errorf("reading temp file: %w", err)
+ }
+
+ return strings.TrimRight(string(data), "\n"), nil
+}
+
func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
diff --git a/cmd/start_darwin.go b/cmd/start_darwin.go
index 05a4551e1ea..008adf15ece 100644
--- a/cmd/start_darwin.go
+++ b/cmd/start_darwin.go
@@ -10,19 +10,21 @@ import (
"github.com/ollama/ollama/api"
)
+var errNotRunning = errors.New("could not connect to ollama server, run 'ollama serve' to start it")
+
func startApp(ctx context.Context, client *api.Client) error {
exe, err := os.Executable()
if err != nil {
- return err
+ return errNotRunning
}
link, err := os.Readlink(exe)
if err != nil {
- return err
+ return errNotRunning
}
r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
m := r.FindStringSubmatch(link)
if len(m) != 1 {
- return errors.New("could not find ollama app")
+ return errNotRunning
}
if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
return err
diff --git a/cmd/tui/confirm.go b/cmd/tui/confirm.go
new file mode 100644
index 00000000000..b8f92b12480
--- /dev/null
+++ b/cmd/tui/confirm.go
@@ -0,0 +1,109 @@
+package tui
+
+import (
+ "fmt"
+
+ tea "github.com/charmbracelet/bubbletea"
+ "github.com/charmbracelet/lipgloss"
+)
+
+var (
+ confirmActiveStyle = lipgloss.NewStyle().
+ Bold(true).
+ Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
+
+ confirmInactiveStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
+)
+
+type confirmModel struct {
+ prompt string
+ yes bool
+ confirmed bool
+ cancelled bool
+ width int
+}
+
+func (m confirmModel) Init() tea.Cmd {
+ return nil
+}
+
+func (m confirmModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ wasSet := m.width > 0
+ m.width = msg.Width
+ if wasSet {
+ return m, tea.EnterAltScreen
+ }
+ return m, nil
+
+ case tea.KeyMsg:
+ switch msg.String() {
+ case "ctrl+c", "esc", "n":
+ m.cancelled = true
+ return m, tea.Quit
+ case "y":
+ m.yes = true
+ m.confirmed = true
+ return m, tea.Quit
+ case "enter":
+ m.confirmed = true
+ return m, tea.Quit
+ case "left", "h":
+ m.yes = true
+ case "right", "l":
+ m.yes = false
+ case "tab":
+ m.yes = !m.yes
+ }
+ }
+
+ return m, nil
+}
+
+func (m confirmModel) View() string {
+ if m.confirmed || m.cancelled {
+ return ""
+ }
+
+ var yesBtn, noBtn string
+ if m.yes {
+ yesBtn = confirmActiveStyle.Render(" Yes ")
+ noBtn = confirmInactiveStyle.Render(" No ")
+ } else {
+ yesBtn = confirmInactiveStyle.Render(" Yes ")
+ noBtn = confirmActiveStyle.Render(" No ")
+ }
+
+ s := selectorTitleStyle.Render(m.prompt) + "\n\n"
+ s += " " + yesBtn + " " + noBtn + "\n\n"
+ s += selectorHelpStyle.Render("←/→ navigate • enter confirm • esc cancel")
+
+ if m.width > 0 {
+ return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
+ }
+ return s
+}
+
+// RunConfirm shows a bubbletea yes/no confirmation prompt.
+// Returns true if the user confirmed, false if cancelled.
+func RunConfirm(prompt string) (bool, error) {
+ m := confirmModel{
+ prompt: prompt,
+ yes: true, // default to yes
+ }
+
+ p := tea.NewProgram(m)
+ finalModel, err := p.Run()
+ if err != nil {
+ return false, fmt.Errorf("error running confirm: %w", err)
+ }
+
+ fm := finalModel.(confirmModel)
+ if fm.cancelled {
+ return false, ErrCancelled
+ }
+
+ return fm.yes, nil
+}
diff --git a/cmd/tui/confirm_test.go b/cmd/tui/confirm_test.go
new file mode 100644
index 00000000000..4279d18ebda
--- /dev/null
+++ b/cmd/tui/confirm_test.go
@@ -0,0 +1,208 @@
+package tui
+
+import (
+ "strings"
+ "testing"
+
+ tea "github.com/charmbracelet/bubbletea"
+)
+
+func TestConfirmModel_DefaultsToYes(t *testing.T) {
+ m := confirmModel{prompt: "Download test?", yes: true}
+ if !m.yes {
+ t.Error("should default to yes")
+ }
+}
+
+func TestConfirmModel_View_ContainsPrompt(t *testing.T) {
+ m := confirmModel{prompt: "Download qwen3:8b?", yes: true}
+ got := m.View()
+ if !strings.Contains(got, "Download qwen3:8b?") {
+ t.Error("should contain the prompt text")
+ }
+}
+
+func TestConfirmModel_View_ContainsButtons(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+ got := m.View()
+ if !strings.Contains(got, "Yes") {
+ t.Error("should contain Yes button")
+ }
+ if !strings.Contains(got, "No") {
+ t.Error("should contain No button")
+ }
+}
+
+func TestConfirmModel_View_ContainsHelp(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+ got := m.View()
+ if !strings.Contains(got, "enter confirm") {
+ t.Error("should contain help text")
+ }
+}
+
+func TestConfirmModel_View_ClearsAfterConfirm(t *testing.T) {
+ m := confirmModel{prompt: "Download?", confirmed: true}
+ if m.View() != "" {
+ t.Error("View should return empty string after confirmation")
+ }
+}
+
+func TestConfirmModel_View_ClearsAfterCancel(t *testing.T) {
+ m := confirmModel{prompt: "Download?", cancelled: true}
+ if m.View() != "" {
+ t.Error("View should return empty string after cancellation")
+ }
+}
+
+func TestConfirmModel_EnterConfirmsYes(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
+ fm := updated.(confirmModel)
+ if !fm.confirmed {
+ t.Error("enter should set confirmed=true")
+ }
+ if !fm.yes {
+ t.Error("enter with yes selected should keep yes=true")
+ }
+ if cmd == nil {
+ t.Error("enter should return tea.Quit")
+ }
+}
+
+func TestConfirmModel_EnterConfirmsNo(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: false}
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
+ fm := updated.(confirmModel)
+ if !fm.confirmed {
+ t.Error("enter should set confirmed=true")
+ }
+ if fm.yes {
+ t.Error("enter with no selected should keep yes=false")
+ }
+ if cmd == nil {
+ t.Error("enter should return tea.Quit")
+ }
+}
+
+func TestConfirmModel_EscCancels(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
+ fm := updated.(confirmModel)
+ if !fm.cancelled {
+ t.Error("esc should set cancelled=true")
+ }
+ if cmd == nil {
+ t.Error("esc should return tea.Quit")
+ }
+}
+
+func TestConfirmModel_CtrlCCancels(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
+ fm := updated.(confirmModel)
+ if !fm.cancelled {
+ t.Error("ctrl+c should set cancelled=true")
+ }
+ if cmd == nil {
+ t.Error("ctrl+c should return tea.Quit")
+ }
+}
+
+func TestConfirmModel_NCancels(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'n'}})
+ fm := updated.(confirmModel)
+ if !fm.cancelled {
+ t.Error("'n' should set cancelled=true")
+ }
+ if cmd == nil {
+ t.Error("'n' should return tea.Quit")
+ }
+}
+
+func TestConfirmModel_YConfirmsYes(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: false}
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'y'}})
+ fm := updated.(confirmModel)
+ if !fm.confirmed {
+ t.Error("'y' should set confirmed=true")
+ }
+ if !fm.yes {
+ t.Error("'y' should set yes=true")
+ }
+ if cmd == nil {
+ t.Error("'y' should return tea.Quit")
+ }
+}
+
+func TestConfirmModel_ArrowKeysNavigate(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+
+ // Right moves to No
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'l'}})
+ fm := updated.(confirmModel)
+ if fm.yes {
+ t.Error("right/l should move to No")
+ }
+ if fm.confirmed || fm.cancelled {
+ t.Error("navigation should not confirm or cancel")
+ }
+
+ // Left moves back to Yes
+ updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'h'}})
+ fm = updated.(confirmModel)
+ if !fm.yes {
+ t.Error("left/h should move to Yes")
+ }
+}
+
+func TestConfirmModel_TabToggles(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true}
+
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
+ fm := updated.(confirmModel)
+ if fm.yes {
+ t.Error("tab should toggle from Yes to No")
+ }
+
+ updated, _ = fm.Update(tea.KeyMsg{Type: tea.KeyTab})
+ fm = updated.(confirmModel)
+ if !fm.yes {
+ t.Error("tab should toggle from No to Yes")
+ }
+}
+
+func TestConfirmModel_WindowSizeUpdatesWidth(t *testing.T) {
+ m := confirmModel{prompt: "Download?"}
+ updated, _ := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
+ fm := updated.(confirmModel)
+ if fm.width != 100 {
+ t.Errorf("expected width 100, got %d", fm.width)
+ }
+}
+
+func TestConfirmModel_ResizeEntersAltScreen(t *testing.T) {
+ m := confirmModel{prompt: "Download?", width: 80}
+ _, cmd := m.Update(tea.WindowSizeMsg{Width: 100, Height: 40})
+ if cmd == nil {
+ t.Error("resize (width already set) should return a command")
+ }
+}
+
+func TestConfirmModel_InitialWindowSizeNoAltScreen(t *testing.T) {
+ m := confirmModel{prompt: "Download?"}
+ _, cmd := m.Update(tea.WindowSizeMsg{Width: 80, Height: 40})
+ if cmd != nil {
+ t.Error("initial WindowSizeMsg should not return a command")
+ }
+}
+
+func TestConfirmModel_ViewMaxWidth(t *testing.T) {
+ m := confirmModel{prompt: "Download?", yes: true, width: 40}
+ got := m.View()
+ // Just ensure it doesn't panic and returns content
+ if got == "" {
+ t.Error("View with width set should still return content")
+ }
+}
diff --git a/cmd/tui/selector.go b/cmd/tui/selector.go
new file mode 100644
index 00000000000..7bf8180be2d
--- /dev/null
+++ b/cmd/tui/selector.go
@@ -0,0 +1,824 @@
+package tui
+
+import (
+ "errors"
+ "fmt"
+ "strings"
+
+ tea "github.com/charmbracelet/bubbletea"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/ollama/ollama/cmd/config"
+)
+
+var (
+ selectorTitleStyle = lipgloss.NewStyle().
+ Bold(true)
+
+ selectorItemStyle = lipgloss.NewStyle().
+ PaddingLeft(4)
+
+ selectorSelectedItemStyle = lipgloss.NewStyle().
+ PaddingLeft(2).
+ Bold(true).
+ Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
+
+ selectorDescStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
+
+ selectorDescLineStyle = selectorDescStyle.
+ PaddingLeft(6)
+
+ selectorFilterStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
+ Italic(true)
+
+ selectorInputStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "235", Dark: "252"})
+
+ selectorDefaultTagStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
+ Italic(true)
+
+ selectorHelpStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "244", Dark: "244"})
+
+ selectorMoreStyle = lipgloss.NewStyle().
+ PaddingLeft(6).
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
+ Italic(true)
+
+ sectionHeaderStyle = lipgloss.NewStyle().
+ PaddingLeft(2).
+ Bold(true).
+ Foreground(lipgloss.AdaptiveColor{Light: "240", Dark: "249"})
+)
+
+const maxSelectorItems = 10
+
+// ErrCancelled is returned when the user cancels the selection.
+var ErrCancelled = errors.New("cancelled")
+
+type SelectItem struct {
+ Name string
+ Description string
+ Recommended bool
+}
+
+// ConvertItems converts config.ModelItem slice to SelectItem slice.
+func ConvertItems(items []config.ModelItem) []SelectItem {
+ out := make([]SelectItem, len(items))
+ for i, item := range items {
+ out[i] = SelectItem{Name: item.Name, Description: item.Description, Recommended: item.Recommended}
+ }
+ return out
+}
+
+// ReorderItems returns a copy with recommended items first, then non-recommended,
+// preserving relative order within each group. This ensures the data order matches
+// the visual section layout (Recommended / More).
+func ReorderItems(items []SelectItem) []SelectItem {
+ var rec, other []SelectItem
+ for _, item := range items {
+ if item.Recommended {
+ rec = append(rec, item)
+ } else {
+ other = append(other, item)
+ }
+ }
+ return append(rec, other...)
+}
+
+// selectorModel is the bubbletea model for single selection.
+type selectorModel struct {
+ title string
+ items []SelectItem
+ filter string
+ cursor int
+ scrollOffset int
+ selected string
+ cancelled bool
+ helpText string
+ width int
+}
+
+func (m selectorModel) filteredItems() []SelectItem {
+ if m.filter == "" {
+ return m.items
+ }
+ filterLower := strings.ToLower(m.filter)
+ var result []SelectItem
+ for _, item := range m.items {
+ if strings.Contains(strings.ToLower(item.Name), filterLower) {
+ result = append(result, item)
+ }
+ }
+ return result
+}
+
+func (m selectorModel) Init() tea.Cmd {
+ return nil
+}
+
+// otherStart returns the index of the first non-recommended item in the filtered list.
+// When filtering, all items scroll together so this returns 0.
+func (m selectorModel) otherStart() int {
+ if m.filter != "" {
+ return 0
+ }
+ filtered := m.filteredItems()
+ for i, item := range filtered {
+ if !item.Recommended {
+ return i
+ }
+ }
+ return len(filtered)
+}
+
+// updateNavigation handles navigation keys (up/down/pgup/pgdown/filter/backspace).
+// It does NOT handle Enter, Esc, or CtrlC. This is used by both the standalone
+// selector and the TUI modal (which intercepts Enter/Esc for its own logic).
+func (m *selectorModel) updateNavigation(msg tea.KeyMsg) {
+ filtered := m.filteredItems()
+ otherStart := m.otherStart()
+
+ switch msg.Type {
+ case tea.KeyUp:
+ if m.cursor > 0 {
+ m.cursor--
+ m.updateScroll(otherStart)
+ }
+
+ case tea.KeyDown:
+ if m.cursor < len(filtered)-1 {
+ m.cursor++
+ m.updateScroll(otherStart)
+ }
+
+ case tea.KeyPgUp:
+ m.cursor -= maxSelectorItems
+ if m.cursor < 0 {
+ m.cursor = 0
+ }
+ m.updateScroll(otherStart)
+
+ case tea.KeyPgDown:
+ m.cursor += maxSelectorItems
+ if m.cursor >= len(filtered) {
+ m.cursor = len(filtered) - 1
+ }
+ m.updateScroll(otherStart)
+
+ case tea.KeyBackspace:
+ if len(m.filter) > 0 {
+ m.filter = m.filter[:len(m.filter)-1]
+ m.cursor = 0
+ m.scrollOffset = 0
+ }
+
+ case tea.KeyRunes:
+ m.filter += string(msg.Runes)
+ m.cursor = 0
+ m.scrollOffset = 0
+ }
+}
+
+// updateScroll adjusts scrollOffset based on cursor position.
+// When not filtering, scrollOffset is relative to the "More" (non-recommended) section.
+// When filtering, it's relative to the full filtered list.
+func (m *selectorModel) updateScroll(otherStart int) {
+ if m.filter != "" {
+ if m.cursor < m.scrollOffset {
+ m.scrollOffset = m.cursor
+ }
+ if m.cursor >= m.scrollOffset+maxSelectorItems {
+ m.scrollOffset = m.cursor - maxSelectorItems + 1
+ }
+ return
+ }
+
+ // Cursor is in recommended section — reset "More" scroll to top
+ if m.cursor < otherStart {
+ m.scrollOffset = 0
+ return
+ }
+
+ // Cursor is in "More" section — scroll relative to others
+ posInOthers := m.cursor - otherStart
+ maxOthers := maxSelectorItems - otherStart
+ if maxOthers < 3 {
+ maxOthers = 3
+ }
+ if posInOthers < m.scrollOffset {
+ m.scrollOffset = posInOthers
+ }
+ if posInOthers >= m.scrollOffset+maxOthers {
+ m.scrollOffset = posInOthers - maxOthers + 1
+ }
+}
+
+func (m selectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ wasSet := m.width > 0
+ m.width = msg.Width
+ if wasSet {
+ return m, tea.EnterAltScreen
+ }
+ return m, nil
+
+ case tea.KeyMsg:
+ switch msg.Type {
+ case tea.KeyCtrlC, tea.KeyEsc:
+ m.cancelled = true
+ return m, tea.Quit
+
+ case tea.KeyEnter:
+ filtered := m.filteredItems()
+ if len(filtered) > 0 && m.cursor < len(filtered) {
+ m.selected = filtered[m.cursor].Name
+ }
+ return m, tea.Quit
+
+ default:
+ m.updateNavigation(msg)
+ }
+ }
+
+ return m, nil
+}
+
+func (m selectorModel) renderItem(s *strings.Builder, item SelectItem, idx int) {
+ if idx == m.cursor {
+ s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
+ } else {
+ s.WriteString(selectorItemStyle.Render(item.Name))
+ }
+ s.WriteString("\n")
+ if item.Description != "" {
+ s.WriteString(selectorDescLineStyle.Render(item.Description))
+ s.WriteString("\n")
+ }
+}
+
+// renderContent renders the selector content (title, items, help text) without
+// checking the cancelled/selected state. This is used by both View() (standalone mode)
+// and by the TUI modal which embeds a selectorModel.
+func (m selectorModel) renderContent() string {
+ var s strings.Builder
+
+ s.WriteString(selectorTitleStyle.Render(m.title))
+ s.WriteString(" ")
+ if m.filter == "" {
+ s.WriteString(selectorFilterStyle.Render("Type to filter..."))
+ } else {
+ s.WriteString(selectorInputStyle.Render(m.filter))
+ }
+ s.WriteString("\n\n")
+
+ filtered := m.filteredItems()
+
+ if len(filtered) == 0 {
+ s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
+ s.WriteString("\n")
+ } else if m.filter != "" {
+ s.WriteString(sectionHeaderStyle.Render("Top Results"))
+ s.WriteString("\n")
+
+ displayCount := min(len(filtered), maxSelectorItems)
+ for i := range displayCount {
+ idx := m.scrollOffset + i
+ if idx >= len(filtered) {
+ break
+ }
+ m.renderItem(&s, filtered[idx], idx)
+ }
+
+ if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
+ s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
+ s.WriteString("\n")
+ }
+ } else {
+ // Split into pinned recommended and scrollable others
+ var recItems, otherItems []int
+ for i, item := range filtered {
+ if item.Recommended {
+ recItems = append(recItems, i)
+ } else {
+ otherItems = append(otherItems, i)
+ }
+ }
+
+ // Always render all recommended items (pinned)
+ if len(recItems) > 0 {
+ s.WriteString(sectionHeaderStyle.Render("Recommended"))
+ s.WriteString("\n")
+ for _, idx := range recItems {
+ m.renderItem(&s, filtered[idx], idx)
+ }
+ }
+
+ if len(otherItems) > 0 {
+ s.WriteString("\n")
+ s.WriteString(sectionHeaderStyle.Render("More"))
+ s.WriteString("\n")
+
+ maxOthers := maxSelectorItems - len(recItems)
+ if maxOthers < 3 {
+ maxOthers = 3
+ }
+ displayCount := min(len(otherItems), maxOthers)
+
+ for i := range displayCount {
+ idx := m.scrollOffset + i
+ if idx >= len(otherItems) {
+ break
+ }
+ m.renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
+ }
+
+ if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
+ s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
+ s.WriteString("\n")
+ }
+ }
+ }
+
+ s.WriteString("\n")
+ help := "↑/↓ navigate • enter select • esc cancel"
+ if m.helpText != "" {
+ help = m.helpText
+ }
+ s.WriteString(selectorHelpStyle.Render(help))
+
+ return s.String()
+}
+
+func (m selectorModel) View() string {
+ if m.cancelled || m.selected != "" {
+ return ""
+ }
+
+ s := m.renderContent()
+ if m.width > 0 {
+ return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
+ }
+ return s
+}
+
+// cursorForCurrent returns the item index matching current, or 0 if not found.
+func cursorForCurrent(items []SelectItem, current string) int {
+ if current != "" {
+ for i, item := range items {
+ if item.Name == current || strings.HasPrefix(item.Name, current+":") || strings.HasPrefix(current, item.Name+":") {
+ return i
+ }
+ }
+ }
+ return 0
+}
+
+func SelectSingle(title string, items []SelectItem, current string) (string, error) {
+ if len(items) == 0 {
+ return "", fmt.Errorf("no items to select from")
+ }
+
+ m := selectorModel{
+ title: title,
+ items: items,
+ cursor: cursorForCurrent(items, current),
+ }
+
+ p := tea.NewProgram(m)
+ finalModel, err := p.Run()
+ if err != nil {
+ return "", fmt.Errorf("error running selector: %w", err)
+ }
+
+ fm := finalModel.(selectorModel)
+ if fm.cancelled {
+ return "", ErrCancelled
+ }
+
+ return fm.selected, nil
+}
+
+// multiSelectorModel is the bubbletea model for multi selection.
+type multiSelectorModel struct {
+ title string
+ items []SelectItem
+ itemIndex map[string]int
+ filter string
+ cursor int
+ scrollOffset int
+ checked map[int]bool
+ checkOrder []int
+ cancelled bool
+ confirmed bool
+ width int
+
+ // multi enables full multi-select editing mode. The zero value (false)
+ // shows a single-select picker where Enter adds the chosen model to
+ // the existing list. Tab toggles between modes.
+ multi bool
+ singleAdd string // model picked in single mode
+}
+
+func newMultiSelectorModel(title string, items []SelectItem, preChecked []string) multiSelectorModel {
+ m := multiSelectorModel{
+ title: title,
+ items: items,
+ itemIndex: make(map[string]int, len(items)),
+ checked: make(map[int]bool),
+ }
+
+ for i, item := range items {
+ m.itemIndex[item.Name] = i
+ }
+
+ // Reverse order so preChecked[0] (the current default) ends up last
+ // in checkOrder, matching the "last checked = default" convention.
+ for i := len(preChecked) - 1; i >= 0; i-- {
+ if idx, ok := m.itemIndex[preChecked[i]]; ok {
+ m.checked[idx] = true
+ m.checkOrder = append(m.checkOrder, idx)
+ }
+ }
+
+ // Position cursor on the current default model
+ if len(preChecked) > 0 {
+ if idx, ok := m.itemIndex[preChecked[0]]; ok {
+ m.cursor = idx
+ m.updateScroll(m.otherStart())
+ }
+ }
+
+ return m
+}
+
+func (m multiSelectorModel) filteredItems() []SelectItem {
+ if m.filter == "" {
+ return m.items
+ }
+ filterLower := strings.ToLower(m.filter)
+ var result []SelectItem
+ for _, item := range m.items {
+ if strings.Contains(strings.ToLower(item.Name), filterLower) {
+ result = append(result, item)
+ }
+ }
+ return result
+}
+
+// otherStart returns the index of the first non-recommended item in the filtered list.
+func (m multiSelectorModel) otherStart() int {
+ if m.filter != "" {
+ return 0
+ }
+ filtered := m.filteredItems()
+ for i, item := range filtered {
+ if !item.Recommended {
+ return i
+ }
+ }
+ return len(filtered)
+}
+
+// updateScroll adjusts scrollOffset for section-based scrolling (matches single-select).
+func (m *multiSelectorModel) updateScroll(otherStart int) {
+ if m.filter != "" {
+ if m.cursor < m.scrollOffset {
+ m.scrollOffset = m.cursor
+ }
+ if m.cursor >= m.scrollOffset+maxSelectorItems {
+ m.scrollOffset = m.cursor - maxSelectorItems + 1
+ }
+ return
+ }
+
+ if m.cursor < otherStart {
+ m.scrollOffset = 0
+ return
+ }
+
+ posInOthers := m.cursor - otherStart
+ maxOthers := maxSelectorItems - otherStart
+ if maxOthers < 3 {
+ maxOthers = 3
+ }
+ if posInOthers < m.scrollOffset {
+ m.scrollOffset = posInOthers
+ }
+ if posInOthers >= m.scrollOffset+maxOthers {
+ m.scrollOffset = posInOthers - maxOthers + 1
+ }
+}
+
+func (m *multiSelectorModel) toggleItem() {
+ filtered := m.filteredItems()
+ if len(filtered) == 0 || m.cursor >= len(filtered) {
+ return
+ }
+
+ item := filtered[m.cursor]
+ origIdx := m.itemIndex[item.Name]
+
+ if m.checked[origIdx] {
+ delete(m.checked, origIdx)
+ for i, idx := range m.checkOrder {
+ if idx == origIdx {
+ m.checkOrder = append(m.checkOrder[:i], m.checkOrder[i+1:]...)
+ break
+ }
+ }
+ } else {
+ m.checked[origIdx] = true
+ m.checkOrder = append(m.checkOrder, origIdx)
+ }
+}
+
+func (m multiSelectorModel) selectedCount() int {
+ return len(m.checkOrder)
+}
+
+func (m multiSelectorModel) Init() tea.Cmd {
+ return nil
+}
+
+func (m multiSelectorModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ wasSet := m.width > 0
+ m.width = msg.Width
+ if wasSet {
+ return m, tea.EnterAltScreen
+ }
+ return m, nil
+
+ case tea.KeyMsg:
+ filtered := m.filteredItems()
+
+ switch msg.Type {
+ case tea.KeyCtrlC, tea.KeyEsc:
+ m.cancelled = true
+ return m, tea.Quit
+
+ case tea.KeyTab:
+ m.multi = !m.multi
+
+ case tea.KeyEnter:
+ if !m.multi {
+ if len(filtered) > 0 && m.cursor < len(filtered) {
+ m.singleAdd = filtered[m.cursor].Name
+ m.confirmed = true
+ return m, tea.Quit
+ }
+ } else if len(m.checkOrder) > 0 {
+ m.confirmed = true
+ return m, tea.Quit
+ }
+
+ case tea.KeySpace:
+ if m.multi {
+ m.toggleItem()
+ }
+
+ case tea.KeyUp:
+ if m.cursor > 0 {
+ m.cursor--
+ m.updateScroll(m.otherStart())
+ }
+
+ case tea.KeyDown:
+ if m.cursor < len(filtered)-1 {
+ m.cursor++
+ m.updateScroll(m.otherStart())
+ }
+
+ case tea.KeyPgUp:
+ m.cursor -= maxSelectorItems
+ if m.cursor < 0 {
+ m.cursor = 0
+ }
+ m.updateScroll(m.otherStart())
+
+ case tea.KeyPgDown:
+ m.cursor += maxSelectorItems
+ if m.cursor >= len(filtered) {
+ m.cursor = len(filtered) - 1
+ }
+ m.updateScroll(m.otherStart())
+
+ case tea.KeyBackspace:
+ if len(m.filter) > 0 {
+ m.filter = m.filter[:len(m.filter)-1]
+ m.cursor = 0
+ m.scrollOffset = 0
+ }
+
+ case tea.KeyRunes:
+ // On some terminals (e.g. Windows PowerShell), space arrives as
+ // KeyRunes instead of KeySpace. Intercept it so toggle still works.
+ if len(msg.Runes) == 1 && msg.Runes[0] == ' ' {
+ if m.multi {
+ m.toggleItem()
+ }
+ } else {
+ m.filter += string(msg.Runes)
+ m.cursor = 0
+ m.scrollOffset = 0
+ }
+ }
+ }
+
+ return m, nil
+}
+
+func (m multiSelectorModel) renderSingleItem(s *strings.Builder, item SelectItem, idx int) {
+ if idx == m.cursor {
+ s.WriteString(selectorSelectedItemStyle.Render("▸ " + item.Name))
+ } else {
+ s.WriteString(selectorItemStyle.Render(item.Name))
+ }
+ s.WriteString("\n")
+ if item.Description != "" {
+ s.WriteString(selectorDescLineStyle.Render(item.Description))
+ s.WriteString("\n")
+ }
+}
+
+func (m multiSelectorModel) renderMultiItem(s *strings.Builder, item SelectItem, idx int) {
+ origIdx := m.itemIndex[item.Name]
+
+ var check string
+ if m.checked[origIdx] {
+ check = "[x] "
+ } else {
+ check = "[ ] "
+ }
+
+ suffix := ""
+ if len(m.checkOrder) > 0 && m.checkOrder[len(m.checkOrder)-1] == origIdx {
+ suffix = " " + selectorDefaultTagStyle.Render("(default)")
+ }
+
+ if idx == m.cursor {
+ s.WriteString(selectorSelectedItemStyle.Render("▸ " + check + item.Name))
+ } else {
+ s.WriteString(selectorItemStyle.Render(check + item.Name))
+ }
+ s.WriteString(suffix)
+ s.WriteString("\n")
+ if item.Description != "" {
+ s.WriteString(selectorDescLineStyle.Render(item.Description))
+ s.WriteString("\n")
+ }
+}
+
+func (m multiSelectorModel) View() string {
+ if m.cancelled || m.confirmed {
+ return ""
+ }
+
+ renderItem := m.renderSingleItem
+ if m.multi {
+ renderItem = m.renderMultiItem
+ }
+
+ var s strings.Builder
+
+ s.WriteString(selectorTitleStyle.Render(m.title))
+ s.WriteString(" ")
+ if m.filter == "" {
+ s.WriteString(selectorFilterStyle.Render("Type to filter..."))
+ } else {
+ s.WriteString(selectorInputStyle.Render(m.filter))
+ }
+ s.WriteString("\n\n")
+
+ filtered := m.filteredItems()
+
+ if len(filtered) == 0 {
+ s.WriteString(selectorItemStyle.Render(selectorDescStyle.Render("(no matches)")))
+ s.WriteString("\n")
+ } else if m.filter != "" {
+ // Filtering: flat scroll through all matches
+ displayCount := min(len(filtered), maxSelectorItems)
+ for i := range displayCount {
+ idx := m.scrollOffset + i
+ if idx >= len(filtered) {
+ break
+ }
+ renderItem(&s, filtered[idx], idx)
+ }
+
+ if remaining := len(filtered) - m.scrollOffset - displayCount; remaining > 0 {
+ s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
+ s.WriteString("\n")
+ }
+ } else {
+ // Split into pinned recommended and scrollable others (matches single-select layout)
+ var recItems, otherItems []int
+ for i, item := range filtered {
+ if item.Recommended {
+ recItems = append(recItems, i)
+ } else {
+ otherItems = append(otherItems, i)
+ }
+ }
+
+ // Always render all recommended items (pinned)
+ if len(recItems) > 0 {
+ s.WriteString(sectionHeaderStyle.Render("Recommended"))
+ s.WriteString("\n")
+ for _, idx := range recItems {
+ renderItem(&s, filtered[idx], idx)
+ }
+ }
+
+ if len(otherItems) > 0 {
+ s.WriteString("\n")
+ s.WriteString(sectionHeaderStyle.Render("More"))
+ s.WriteString("\n")
+
+ maxOthers := maxSelectorItems - len(recItems)
+ if maxOthers < 3 {
+ maxOthers = 3
+ }
+ displayCount := min(len(otherItems), maxOthers)
+
+ for i := range displayCount {
+ idx := m.scrollOffset + i
+ if idx >= len(otherItems) {
+ break
+ }
+ renderItem(&s, filtered[otherItems[idx]], otherItems[idx])
+ }
+
+ if remaining := len(otherItems) - m.scrollOffset - displayCount; remaining > 0 {
+ s.WriteString(selectorMoreStyle.Render(fmt.Sprintf("... and %d more", remaining)))
+ s.WriteString("\n")
+ }
+ }
+ }
+
+ s.WriteString("\n")
+
+ if !m.multi {
+ s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • enter select • tab add multiple • esc cancel"))
+ } else {
+ count := m.selectedCount()
+ if count == 0 {
+ s.WriteString(selectorDescStyle.Render(" Select at least one model."))
+ } else {
+ s.WriteString(selectorDescStyle.Render(fmt.Sprintf(" %d selected - press enter to continue", count)))
+ }
+ s.WriteString("\n\n")
+ s.WriteString(selectorHelpStyle.Render("↑/↓ navigate • space toggle • tab select single • enter confirm • esc cancel"))
+ }
+
+ result := s.String()
+ if m.width > 0 {
+ return lipgloss.NewStyle().MaxWidth(m.width).Render(result)
+ }
+ return result
+}
+
+func SelectMultiple(title string, items []SelectItem, preChecked []string) ([]string, error) {
+ if len(items) == 0 {
+ return nil, fmt.Errorf("no items to select from")
+ }
+
+ m := newMultiSelectorModel(title, items, preChecked)
+
+ p := tea.NewProgram(m)
+ finalModel, err := p.Run()
+ if err != nil {
+ return nil, fmt.Errorf("error running selector: %w", err)
+ }
+
+ fm := finalModel.(multiSelectorModel)
+ if fm.cancelled || !fm.confirmed {
+ return nil, ErrCancelled
+ }
+
+ // Single-add mode: prepend the picked model, keep existing models deduped
+ if fm.singleAdd != "" {
+ result := []string{fm.singleAdd}
+ for _, name := range preChecked {
+ if name != fm.singleAdd {
+ result = append(result, name)
+ }
+ }
+ return result, nil
+ }
+
+ // Multi-edit mode: last checked is default (first in result)
+ last := fm.checkOrder[len(fm.checkOrder)-1]
+ result := []string{fm.items[last].Name}
+ for _, idx := range fm.checkOrder {
+ if idx != last {
+ result = append(result, fm.items[idx].Name)
+ }
+ }
+ return result, nil
+}
diff --git a/cmd/tui/selector_test.go b/cmd/tui/selector_test.go
new file mode 100644
index 00000000000..fa8ff4dc47f
--- /dev/null
+++ b/cmd/tui/selector_test.go
@@ -0,0 +1,811 @@
+package tui
+
+import (
+ "strings"
+ "testing"
+
+ tea "github.com/charmbracelet/bubbletea"
+)
+
+func items(names ...string) []SelectItem {
+ var out []SelectItem
+ for _, n := range names {
+ out = append(out, SelectItem{Name: n})
+ }
+ return out
+}
+
+func recItems(names ...string) []SelectItem {
+ var out []SelectItem
+ for _, n := range names {
+ out = append(out, SelectItem{Name: n, Recommended: true})
+ }
+ return out
+}
+
+func mixedItems() []SelectItem {
+ return []SelectItem{
+ {Name: "rec-a", Recommended: true},
+ {Name: "rec-b", Recommended: true},
+ {Name: "other-1"},
+ {Name: "other-2"},
+ {Name: "other-3"},
+ {Name: "other-4"},
+ {Name: "other-5"},
+ {Name: "other-6"},
+ {Name: "other-7"},
+ {Name: "other-8"},
+ {Name: "other-9"},
+ {Name: "other-10"},
+ }
+}
+
+func TestFilteredItems(t *testing.T) {
+ tests := []struct {
+ name string
+ items []SelectItem
+ filter string
+ want []string
+ }{
+ {
+ name: "no filter returns all",
+ items: items("alpha", "beta", "gamma"),
+ filter: "",
+ want: []string{"alpha", "beta", "gamma"},
+ },
+ {
+ name: "filter matches substring",
+ items: items("llama3.2", "qwen3:8b", "llama2"),
+ filter: "llama",
+ want: []string{"llama3.2", "llama2"},
+ },
+ {
+ name: "filter is case insensitive",
+ items: items("Qwen3:8b", "llama3.2"),
+ filter: "QWEN",
+ want: []string{"Qwen3:8b"},
+ },
+ {
+ name: "no matches",
+ items: items("alpha", "beta"),
+ filter: "zzz",
+ want: nil,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m := selectorModel{items: tt.items, filter: tt.filter}
+ got := m.filteredItems()
+ var gotNames []string
+ for _, item := range got {
+ gotNames = append(gotNames, item.Name)
+ }
+ if len(gotNames) != len(tt.want) {
+ t.Fatalf("got %v, want %v", gotNames, tt.want)
+ }
+ for i := range tt.want {
+ if gotNames[i] != tt.want[i] {
+ t.Errorf("index %d: got %q, want %q", i, gotNames[i], tt.want[i])
+ }
+ }
+ })
+ }
+}
+
+func TestOtherStart(t *testing.T) {
+ tests := []struct {
+ name string
+ items []SelectItem
+ filter string
+ want int
+ }{
+ {
+ name: "all recommended",
+ items: recItems("a", "b", "c"),
+ want: 3,
+ },
+ {
+ name: "none recommended",
+ items: items("a", "b"),
+ want: 0,
+ },
+ {
+ name: "mixed",
+ items: []SelectItem{
+ {Name: "rec-a", Recommended: true},
+ {Name: "rec-b", Recommended: true},
+ {Name: "other-1"},
+ {Name: "other-2"},
+ },
+ want: 2,
+ },
+ {
+ name: "empty",
+ items: nil,
+ want: 0,
+ },
+ {
+ name: "filtering returns 0",
+ items: []SelectItem{
+ {Name: "rec-a", Recommended: true},
+ {Name: "other-1"},
+ },
+ filter: "rec",
+ want: 0,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m := selectorModel{items: tt.items, filter: tt.filter}
+ if got := m.otherStart(); got != tt.want {
+ t.Errorf("otherStart() = %d, want %d", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestUpdateScroll(t *testing.T) {
+ tests := []struct {
+ name string
+ cursor int
+ offset int
+ otherStart int
+ filter string
+ wantOffset int
+ }{
+ {
+ name: "cursor in recommended resets scroll",
+ cursor: 1,
+ offset: 5,
+ otherStart: 3,
+ wantOffset: 0,
+ },
+ {
+ name: "cursor at start of others",
+ cursor: 2,
+ offset: 0,
+ otherStart: 2,
+ wantOffset: 0,
+ },
+ {
+ name: "cursor scrolls down in others",
+ cursor: 12,
+ offset: 0,
+ otherStart: 2,
+ wantOffset: 3, // posInOthers=10, maxOthers=8, 10-8+1=3
+ },
+ {
+ name: "cursor scrolls up in others",
+ cursor: 4,
+ offset: 5,
+ otherStart: 2,
+ wantOffset: 2, // posInOthers=2 < offset=5
+ },
+ {
+ name: "filter mode standard scroll down",
+ cursor: 12,
+ offset: 0,
+ filter: "x",
+ otherStart: 0,
+ wantOffset: 3, // 12 - 10 + 1 = 3
+ },
+ {
+ name: "filter mode standard scroll up",
+ cursor: 2,
+ offset: 5,
+ filter: "x",
+ otherStart: 0,
+ wantOffset: 2,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m := selectorModel{
+ cursor: tt.cursor,
+ scrollOffset: tt.offset,
+ filter: tt.filter,
+ }
+ m.updateScroll(tt.otherStart)
+ if m.scrollOffset != tt.wantOffset {
+ t.Errorf("scrollOffset = %d, want %d", m.scrollOffset, tt.wantOffset)
+ }
+ })
+ }
+}
+
+func TestRenderContent_SectionHeaders(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: []SelectItem{
+ {Name: "rec-a", Recommended: true},
+ {Name: "other-1"},
+ },
+ }
+ content := m.renderContent()
+
+ if !strings.Contains(content, "Recommended") {
+ t.Error("should contain 'Recommended' header")
+ }
+ if !strings.Contains(content, "More") {
+ t.Error("should contain 'More' header")
+ }
+}
+
+func TestRenderContent_FilteredHeader(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: items("alpha", "beta", "alphabet"),
+ filter: "alpha",
+ }
+ content := m.renderContent()
+
+ if !strings.Contains(content, "Top Results") {
+ t.Error("filtered view should contain 'Top Results' header")
+ }
+ if strings.Contains(content, "Recommended") {
+ t.Error("filtered view should not contain 'Recommended' header")
+ }
+}
+
+func TestRenderContent_NoMatches(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: items("alpha"),
+ filter: "zzz",
+ }
+ content := m.renderContent()
+
+ if !strings.Contains(content, "(no matches)") {
+ t.Error("should show '(no matches)' when filter has no results")
+ }
+}
+
+func TestRenderContent_SelectedItemIndicator(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: items("alpha", "beta"),
+ cursor: 0,
+ }
+ content := m.renderContent()
+
+ if !strings.Contains(content, "▸") {
+ t.Error("selected item should have ▸ indicator")
+ }
+}
+
+func TestRenderContent_Description(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: []SelectItem{
+ {Name: "alpha", Description: "the first letter"},
+ },
+ }
+ content := m.renderContent()
+
+ if !strings.Contains(content, "the first letter") {
+ t.Error("should render item description")
+ }
+}
+
+func TestRenderContent_PinnedRecommended(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: mixedItems(),
+ // cursor deep in "More" section
+ cursor: 8,
+ scrollOffset: 3,
+ }
+ content := m.renderContent()
+
+ // Recommended items should always be visible (pinned)
+ if !strings.Contains(content, "rec-a") {
+ t.Error("recommended items should always be rendered (pinned)")
+ }
+ if !strings.Contains(content, "rec-b") {
+ t.Error("recommended items should always be rendered (pinned)")
+ }
+}
+
+func TestRenderContent_MoreOverflowIndicator(t *testing.T) {
+ m := selectorModel{
+ title: "Pick:",
+ items: mixedItems(), // 2 rec + 10 other = 12 total, maxSelectorItems=10
+ }
+ content := m.renderContent()
+
+ if !strings.Contains(content, "... and") {
+ t.Error("should show overflow indicator when more items than visible")
+ }
+}
+
+func TestUpdateNavigation_CursorBounds(t *testing.T) {
+ m := selectorModel{
+ items: items("a", "b", "c"),
+ cursor: 0,
+ }
+
+ // Up at top stays at 0
+ m.updateNavigation(keyMsg(KeyUp))
+ if m.cursor != 0 {
+ t.Errorf("cursor should stay at 0 when pressing up at top, got %d", m.cursor)
+ }
+
+ // Down moves to 1
+ m.updateNavigation(keyMsg(KeyDown))
+ if m.cursor != 1 {
+ t.Errorf("cursor should be 1 after down, got %d", m.cursor)
+ }
+
+ // Down to end
+ m.updateNavigation(keyMsg(KeyDown))
+ m.updateNavigation(keyMsg(KeyDown))
+ if m.cursor != 2 {
+ t.Errorf("cursor should be 2 at bottom, got %d", m.cursor)
+ }
+}
+
+func TestUpdateNavigation_FilterResetsState(t *testing.T) {
+ m := selectorModel{
+ items: items("alpha", "beta"),
+ cursor: 1,
+ scrollOffset: 5,
+ }
+
+ m.updateNavigation(runeMsg('x'))
+ if m.filter != "x" {
+ t.Errorf("filter should be 'x', got %q", m.filter)
+ }
+ if m.cursor != 0 {
+ t.Errorf("cursor should reset to 0 on filter, got %d", m.cursor)
+ }
+ if m.scrollOffset != 0 {
+ t.Errorf("scrollOffset should reset to 0 on filter, got %d", m.scrollOffset)
+ }
+}
+
+func TestUpdateNavigation_Backspace(t *testing.T) {
+ m := selectorModel{
+ items: items("alpha"),
+ filter: "abc",
+ cursor: 1,
+ }
+
+ m.updateNavigation(keyMsg(KeyBackspace))
+ if m.filter != "ab" {
+ t.Errorf("filter should be 'ab' after backspace, got %q", m.filter)
+ }
+ if m.cursor != 0 {
+ t.Errorf("cursor should reset to 0 on backspace, got %d", m.cursor)
+ }
+}
+
+// --- cursorForCurrent ---
+
+func TestCursorForCurrent(t *testing.T) {
+ testItems := []SelectItem{
+ {Name: "llama3.2", Recommended: true},
+ {Name: "qwen3:8b", Recommended: true},
+ {Name: "gemma3:latest"},
+ {Name: "deepseek-r1"},
+ {Name: "glm-5:cloud"},
+ }
+
+ tests := []struct {
+ name string
+ current string
+ want int
+ }{
+ {"empty current", "", 0},
+ {"exact match", "qwen3:8b", 1},
+ {"no match returns 0", "nonexistent", 0},
+ {"bare name matches with :latest suffix", "gemma3", 2},
+ {"full tag matches bare item", "llama3.2:latest", 0},
+ {"cloud model exact match", "glm-5:cloud", 4},
+ {"cloud model bare name", "glm-5", 4},
+ {"recommended item exact match", "llama3.2", 0},
+ {"recommended item with tag", "qwen3", 1},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := cursorForCurrent(testItems, tt.current); got != tt.want {
+ t.Errorf("cursorForCurrent(%q) = %d, want %d", tt.current, got, tt.want)
+ }
+ })
+ }
+}
+
+// --- ReorderItems ---
+
+func TestReorderItems(t *testing.T) {
+ input := []SelectItem{
+ {Name: "local-1"},
+ {Name: "rec-a", Recommended: true},
+ {Name: "local-2"},
+ {Name: "rec-b", Recommended: true},
+ }
+ got := ReorderItems(input)
+ want := []string{"rec-a", "rec-b", "local-1", "local-2"}
+ for i, item := range got {
+ if item.Name != want[i] {
+ t.Errorf("index %d: got %q, want %q", i, item.Name, want[i])
+ }
+ }
+}
+
+func TestReorderItems_AllRecommended(t *testing.T) {
+ input := recItems("a", "b", "c")
+ got := ReorderItems(input)
+ if len(got) != 3 {
+ t.Fatalf("expected 3 items, got %d", len(got))
+ }
+ for i, item := range got {
+ if item.Name != input[i].Name {
+ t.Errorf("order should be preserved, index %d: got %q, want %q", i, item.Name, input[i].Name)
+ }
+ }
+}
+
+func TestReorderItems_NoneRecommended(t *testing.T) {
+ input := items("x", "y")
+ got := ReorderItems(input)
+ if len(got) != 2 || got[0].Name != "x" || got[1].Name != "y" {
+ t.Errorf("order should be preserved, got %v", got)
+ }
+}
+
+// --- Multi-select otherStart ---
+
+func TestMultiOtherStart(t *testing.T) {
+ tests := []struct {
+ name string
+ items []SelectItem
+ filter string
+ want int
+ }{
+ {"all recommended", recItems("a", "b"), "", 2},
+ {"none recommended", items("a", "b"), "", 0},
+ {"mixed", mixedItems(), "", 2},
+ {"with filter returns 0", mixedItems(), "other", 0},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m := newMultiSelectorModel("test", tt.items, nil)
+ m.filter = tt.filter
+ if got := m.otherStart(); got != tt.want {
+ t.Errorf("otherStart() = %d, want %d", got, tt.want)
+ }
+ })
+ }
+}
+
+// --- Multi-select updateScroll ---
+
+func TestMultiUpdateScroll(t *testing.T) {
+ tests := []struct {
+ name string
+ cursor int
+ offset int
+ otherStart int
+ wantOffset int
+ }{
+ {"cursor in recommended resets scroll", 1, 5, 3, 0},
+ {"cursor at start of others", 2, 0, 2, 0},
+ {"cursor scrolls down in others", 12, 0, 2, 3},
+ {"cursor scrolls up in others", 4, 5, 2, 2},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ m := newMultiSelectorModel("test", nil, nil)
+ m.cursor = tt.cursor
+ m.scrollOffset = tt.offset
+ m.updateScroll(tt.otherStart)
+ if m.scrollOffset != tt.wantOffset {
+ t.Errorf("scrollOffset = %d, want %d", m.scrollOffset, tt.wantOffset)
+ }
+ })
+ }
+}
+
+// --- Multi-select View section headers ---
+
+func TestMultiView_SectionHeaders(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", []SelectItem{
+ {Name: "rec-a", Recommended: true},
+ {Name: "other-1"},
+ }, nil)
+ content := m.View()
+
+ if !strings.Contains(content, "Recommended") {
+ t.Error("should contain 'Recommended' header")
+ }
+ if !strings.Contains(content, "More") {
+ t.Error("should contain 'More' header")
+ }
+}
+
+func TestMultiView_CursorIndicator(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
+ m.cursor = 0
+ content := m.View()
+
+ if !strings.Contains(content, "▸") {
+ t.Error("should show ▸ cursor indicator")
+ }
+}
+
+func TestMultiView_CheckedItemShowsX(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), []string{"a"})
+ m.multi = true
+ content := m.View()
+
+ if !strings.Contains(content, "[x]") {
+ t.Error("checked item should show [x]")
+ }
+ if !strings.Contains(content, "[ ]") {
+ t.Error("unchecked item should show [ ]")
+ }
+}
+
+func TestMultiView_DefaultTag(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b"})
+ m.multi = true
+ content := m.View()
+
+ if !strings.Contains(content, "(default)") {
+ t.Error("should have (default) tag")
+ }
+ // preChecked[0] ("a") should be the default (last in checkOrder)
+ aIdx := strings.Index(content, "a")
+ defaultIdx := strings.Index(content, "(default)")
+ if defaultIdx < aIdx {
+ t.Error("(default) tag should appear after 'a' (the current default)")
+ }
+}
+
+func TestMultiView_PinnedRecommended(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", mixedItems(), nil)
+ m.cursor = 8
+ m.scrollOffset = 3
+ content := m.View()
+
+ if !strings.Contains(content, "rec-a") {
+ t.Error("recommended items should always be visible (pinned)")
+ }
+ if !strings.Contains(content, "rec-b") {
+ t.Error("recommended items should always be visible (pinned)")
+ }
+}
+
+func TestMultiView_OverflowIndicator(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", mixedItems(), nil)
+ content := m.View()
+
+ if !strings.Contains(content, "... and") {
+ t.Error("should show overflow indicator when more items than visible")
+ }
+}
+
+// --- Multi-select space toggle (including KeyRunes fallback for Windows PowerShell) ---
+
+func TestMultiUpdate_SpaceTogglesItem(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
+ m.multi = true
+ m.cursor = 1
+
+ // Simulate space delivered as tea.KeySpace
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
+ m = updated.(multiSelectorModel)
+
+ if !m.checked[1] {
+ t.Error("space (KeySpace) should toggle the item at cursor")
+ }
+ if m.filter != "" {
+ t.Error("space should not modify filter")
+ }
+}
+
+func TestMultiUpdate_SpaceRuneTogglesItem(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
+ m.multi = true
+ m.cursor = 1
+
+ // Simulate space delivered as tea.KeyRunes (Windows PowerShell behavior)
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
+ m = updated.(multiSelectorModel)
+
+ if !m.checked[1] {
+ t.Error("space (KeyRunes) should toggle the item at cursor")
+ }
+ if m.filter != "" {
+ t.Error("space rune should not be added to filter")
+ }
+ if m.cursor != 1 {
+ t.Errorf("cursor should stay at 1, got %d", m.cursor)
+ }
+}
+
+// --- Single-add mode ---
+
+func TestMulti_StartsInSingleMode(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
+ if m.multi {
+ t.Error("should start in single mode (multi=false)")
+ }
+}
+
+func TestMulti_SingleModeNoCheckboxes(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
+ content := m.View()
+ if strings.Contains(content, "[x]") || strings.Contains(content, "[ ]") {
+ t.Error("single mode should not show checkboxes")
+ }
+ if !strings.Contains(content, "▸") {
+ t.Error("single mode should show cursor indicator")
+ }
+}
+
+func TestMulti_SingleModeEnterPicksItem(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b", "c"), nil)
+ m.cursor = 1
+
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyEnter})
+ m = updated.(multiSelectorModel)
+
+ if m.singleAdd != "b" {
+ t.Errorf("enter in single mode should pick cursor item, got %q", m.singleAdd)
+ }
+ if !m.confirmed {
+ t.Error("should set confirmed")
+ }
+}
+
+func TestMulti_SingleModeSpaceIsNoop(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
+ m.cursor = 0
+
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeySpace})
+ m = updated.(multiSelectorModel)
+
+ if len(m.checked) != 0 {
+ t.Error("space in single mode should not toggle items")
+ }
+}
+
+func TestMulti_SingleModeSpaceRuneIsNoop(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
+ m.cursor = 0
+
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{' '}})
+ m = updated.(multiSelectorModel)
+
+ if len(m.checked) != 0 {
+ t.Error("space rune in single mode should not toggle items")
+ }
+ if m.filter != "" {
+ t.Error("space rune in single mode should not add to filter")
+ }
+}
+
+func TestMulti_TabTogglesMode(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a", "b"), nil)
+
+ if m.multi {
+ t.Fatal("should start in single mode")
+ }
+
+ updated, _ := m.Update(tea.KeyMsg{Type: tea.KeyTab})
+ m = updated.(multiSelectorModel)
+ if !m.multi {
+ t.Error("tab should switch to multi mode")
+ }
+
+ updated, _ = m.Update(tea.KeyMsg{Type: tea.KeyTab})
+ m = updated.(multiSelectorModel)
+ if m.multi {
+ t.Error("tab should switch back to single mode")
+ }
+}
+
+func TestMulti_SingleModeHelpText(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a"), nil)
+ content := m.View()
+ if !strings.Contains(content, "tab add multiple") {
+ t.Error("single mode should show 'tab add multiple' in help")
+ }
+}
+
+func TestMulti_MultiModeHelpText(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("a"), nil)
+ m.multi = true
+ content := m.View()
+ if !strings.Contains(content, "tab select single") {
+ t.Error("multi mode should show 'tab select single' in help")
+ }
+}
+
+// --- preChecked initialization order ---
+
+func TestMulti_PreCheckedDefaultIsLast(t *testing.T) {
+ // preChecked[0] ("a") is the current default and should end up
+ // last in checkOrder so it gets the (default) tag.
+ m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"a", "b", "c"})
+
+ if len(m.checkOrder) != 3 {
+ t.Fatalf("expected 3 in checkOrder, got %d", len(m.checkOrder))
+ }
+ lastIdx := m.checkOrder[len(m.checkOrder)-1]
+ if m.items[lastIdx].Name != "a" {
+ t.Errorf("preChecked[0] should be last in checkOrder, got %q", m.items[lastIdx].Name)
+ }
+}
+
+func TestMulti_CursorOnDefaultModel(t *testing.T) {
+ // preChecked[0] ("b") is the default; cursor should start on it
+ m := newMultiSelectorModel("Pick:", items("a", "b", "c"), []string{"b", "c"})
+
+ if m.cursor != 1 {
+ t.Errorf("cursor should be on preChecked[0] ('b') at index 1, got %d", m.cursor)
+ }
+}
+
+// --- Multi-mode last-checked is default ---
+
+func TestMulti_LastCheckedIsDefault(t *testing.T) {
+ m := newMultiSelectorModel("Pick:", items("alpha", "beta", "gamma"), nil)
+ m.multi = true
+
+ // Check "alpha" then "gamma"
+ m.cursor = 0
+ m.toggleItem()
+ m.cursor = 2
+ m.toggleItem()
+
+ // Last checked ("gamma") should be at the end of checkOrder
+ lastIdx := m.checkOrder[len(m.checkOrder)-1]
+ if m.items[lastIdx].Name != "gamma" {
+ t.Errorf("last checked should be 'gamma', got %q", m.items[lastIdx].Name)
+ }
+
+ // The (default) tag renders based on checkOrder[len-1]
+ content := m.View()
+ if !strings.Contains(content, "(default)") {
+ t.Fatal("should show (default) tag")
+ }
+ // "alpha" line should NOT have the default tag
+ for _, line := range strings.Split(content, "\n") {
+ if strings.Contains(line, "alpha") && strings.Contains(line, "(default)") {
+ t.Error("'alpha' (first checked) should not have (default) tag")
+ }
+ }
+}
+
+// Key message helpers for testing
+
+type keyType = int
+
+const (
+ KeyUp keyType = iota
+ KeyDown keyType = iota
+ KeyBackspace keyType = iota
+)
+
+func keyMsg(k keyType) tea.KeyMsg {
+ switch k {
+ case KeyUp:
+ return tea.KeyMsg{Type: tea.KeyUp}
+ case KeyDown:
+ return tea.KeyMsg{Type: tea.KeyDown}
+ case KeyBackspace:
+ return tea.KeyMsg{Type: tea.KeyBackspace}
+ default:
+ return tea.KeyMsg{}
+ }
+}
+
+func runeMsg(r rune) tea.KeyMsg {
+ return tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{r}}
+}
diff --git a/cmd/tui/signin.go b/cmd/tui/signin.go
new file mode 100644
index 00000000000..118dbdf1c4b
--- /dev/null
+++ b/cmd/tui/signin.go
@@ -0,0 +1,128 @@
+package tui
+
+import (
+ "fmt"
+ "strings"
+ "time"
+
+ tea "github.com/charmbracelet/bubbletea"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/ollama/ollama/cmd/config"
+)
+
+type signInModel struct {
+ modelName string
+ signInURL string
+ spinner int
+ width int
+ userName string
+ cancelled bool
+}
+
+func (m signInModel) Init() tea.Cmd {
+ return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
+ return signInTickMsg{}
+ })
+}
+
+func (m signInModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ switch msg := msg.(type) {
+ case tea.WindowSizeMsg:
+ wasSet := m.width > 0
+ m.width = msg.Width
+ if wasSet {
+ return m, tea.EnterAltScreen
+ }
+ return m, nil
+
+ case tea.KeyMsg:
+ switch msg.Type {
+ case tea.KeyCtrlC, tea.KeyEsc:
+ m.cancelled = true
+ return m, tea.Quit
+ }
+
+ case signInTickMsg:
+ m.spinner++
+ if m.spinner%5 == 0 {
+ return m, tea.Batch(
+ tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
+ return signInTickMsg{}
+ }),
+ checkSignIn,
+ )
+ }
+ return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
+ return signInTickMsg{}
+ })
+
+ case signInCheckMsg:
+ if msg.signedIn {
+ m.userName = msg.userName
+ return m, tea.Quit
+ }
+ }
+
+ return m, nil
+}
+
+func (m signInModel) View() string {
+ if m.userName != "" {
+ return ""
+ }
+ return renderSignIn(m.modelName, m.signInURL, m.spinner, m.width)
+}
+
+func renderSignIn(modelName, signInURL string, spinner, width int) string {
+ spinnerFrames := []string{"⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"}
+ frame := spinnerFrames[spinner%len(spinnerFrames)]
+
+ urlColor := lipgloss.NewStyle().
+ Foreground(lipgloss.Color("117"))
+ urlWrap := lipgloss.NewStyle().PaddingLeft(2)
+ if width > 4 {
+ urlWrap = urlWrap.Width(width - 4)
+ }
+
+ var s strings.Builder
+
+ fmt.Fprintf(&s, "To use %s, please sign in.\n\n", selectorSelectedItemStyle.Render(modelName))
+
+ // Wrap in OSC 8 hyperlink so the entire URL is clickable even when wrapped.
+ // Padding is outside the hyperlink so spaces don't get underlined.
+ link := fmt.Sprintf("\033]8;;%s\033\\%s\033]8;;\033\\", signInURL, urlColor.Render(signInURL))
+ s.WriteString("Navigate to:\n")
+ s.WriteString(urlWrap.Render(link))
+ s.WriteString("\n\n")
+
+ s.WriteString(lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).Render(
+ frame + " Waiting for sign in to complete..."))
+ s.WriteString("\n\n")
+
+ s.WriteString(selectorHelpStyle.Render("esc cancel"))
+
+ return lipgloss.NewStyle().PaddingLeft(2).Render(s.String())
+}
+
+// RunSignIn shows a bubbletea sign-in dialog and polls until the user signs in or cancels.
+func RunSignIn(modelName, signInURL string) (string, error) {
+ config.OpenBrowser(signInURL)
+
+ m := signInModel{
+ modelName: modelName,
+ signInURL: signInURL,
+ }
+
+ p := tea.NewProgram(m)
+ finalModel, err := p.Run()
+ if err != nil {
+ return "", fmt.Errorf("error running sign-in: %w", err)
+ }
+
+ fm := finalModel.(signInModel)
+ if fm.cancelled {
+ return "", ErrCancelled
+ }
+
+ return fm.userName, nil
+}
diff --git a/cmd/tui/signin_test.go b/cmd/tui/signin_test.go
new file mode 100644
index 00000000000..0af9ddc6ea8
--- /dev/null
+++ b/cmd/tui/signin_test.go
@@ -0,0 +1,175 @@
+package tui
+
+import (
+ "strings"
+ "testing"
+
+ tea "github.com/charmbracelet/bubbletea"
+)
+
+func TestRenderSignIn_ContainsModelName(t *testing.T) {
+ got := renderSignIn("glm-4.7:cloud", "https://example.com/signin", 0, 80)
+ if !strings.Contains(got, "glm-4.7:cloud") {
+ t.Error("should contain model name")
+ }
+ if !strings.Contains(got, "please sign in") {
+ t.Error("should contain sign-in prompt")
+ }
+}
+
+func TestRenderSignIn_ContainsURL(t *testing.T) {
+ url := "https://ollama.com/connect?key=abc123"
+ got := renderSignIn("test:cloud", url, 0, 120)
+ if !strings.Contains(got, url) {
+ t.Errorf("should contain URL %q", url)
+ }
+}
+
+func TestRenderSignIn_OSC8Hyperlink(t *testing.T) {
+ url := "https://ollama.com/connect?key=abc123"
+ got := renderSignIn("test:cloud", url, 0, 120)
+
+ // Should contain OSC 8 open sequence with the URL
+ osc8Open := "\033]8;;" + url + "\033\\"
+ if !strings.Contains(got, osc8Open) {
+ t.Error("should contain OSC 8 open sequence with URL")
+ }
+
+ // Should contain OSC 8 close sequence
+ osc8Close := "\033]8;;\033\\"
+ if !strings.Contains(got, osc8Close) {
+ t.Error("should contain OSC 8 close sequence")
+ }
+}
+
+func TestRenderSignIn_ContainsSpinner(t *testing.T) {
+ got := renderSignIn("test:cloud", "https://example.com", 0, 80)
+ if !strings.Contains(got, "Waiting for sign in to complete") {
+ t.Error("should contain waiting message")
+ }
+ if !strings.Contains(got, "⠋") {
+ t.Error("should contain first spinner frame at spinner=0")
+ }
+}
+
+func TestRenderSignIn_SpinnerAdvances(t *testing.T) {
+ got0 := renderSignIn("test:cloud", "https://example.com", 0, 80)
+ got1 := renderSignIn("test:cloud", "https://example.com", 1, 80)
+ if got0 == got1 {
+ t.Error("different spinner values should produce different output")
+ }
+}
+
+func TestRenderSignIn_ContainsEscHelp(t *testing.T) {
+ got := renderSignIn("test:cloud", "https://example.com", 0, 80)
+ if !strings.Contains(got, "esc cancel") {
+ t.Error("should contain esc cancel help text")
+ }
+}
+
+func TestSignInModel_EscCancels(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ }
+
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyEsc})
+ fm := updated.(signInModel)
+ if !fm.cancelled {
+ t.Error("esc should set cancelled=true")
+ }
+ if cmd == nil {
+ t.Error("esc should return tea.Quit")
+ }
+}
+
+func TestSignInModel_CtrlCCancels(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ }
+
+ updated, cmd := m.Update(tea.KeyMsg{Type: tea.KeyCtrlC})
+ fm := updated.(signInModel)
+ if !fm.cancelled {
+ t.Error("ctrl+c should set cancelled=true")
+ }
+ if cmd == nil {
+ t.Error("ctrl+c should return tea.Quit")
+ }
+}
+
+func TestSignInModel_SignedInQuitsClean(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ }
+
+ updated, cmd := m.Update(signInCheckMsg{signedIn: true, userName: "alice"})
+ fm := updated.(signInModel)
+ if fm.userName != "alice" {
+ t.Errorf("expected userName 'alice', got %q", fm.userName)
+ }
+ if cmd == nil {
+ t.Error("successful sign-in should return tea.Quit")
+ }
+}
+
+func TestSignInModel_SignedInViewClears(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ userName: "alice",
+ }
+
+ got := m.View()
+ if got != "" {
+ t.Errorf("View should return empty string after sign-in, got %q", got)
+ }
+}
+
+func TestSignInModel_NotSignedInContinues(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ }
+
+ updated, _ := m.Update(signInCheckMsg{signedIn: false})
+ fm := updated.(signInModel)
+ if fm.userName != "" {
+ t.Error("should not set userName when not signed in")
+ }
+ if fm.cancelled {
+ t.Error("should not cancel when check returns not signed in")
+ }
+}
+
+func TestSignInModel_WindowSizeUpdatesWidth(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ }
+
+ updated, _ := m.Update(tea.WindowSizeMsg{Width: 120, Height: 40})
+ fm := updated.(signInModel)
+ if fm.width != 120 {
+ t.Errorf("expected width 120, got %d", fm.width)
+ }
+}
+
+func TestSignInModel_TickAdvancesSpinner(t *testing.T) {
+ m := signInModel{
+ modelName: "test:cloud",
+ signInURL: "https://example.com",
+ spinner: 0,
+ }
+
+ updated, cmd := m.Update(signInTickMsg{})
+ fm := updated.(signInModel)
+ if fm.spinner != 1 {
+ t.Errorf("expected spinner=1, got %d", fm.spinner)
+ }
+ if cmd == nil {
+ t.Error("tick should return a command")
+ }
+}
diff --git a/cmd/tui/tui.go b/cmd/tui/tui.go
new file mode 100644
index 00000000000..b9f1ef7b100
--- /dev/null
+++ b/cmd/tui/tui.go
@@ -0,0 +1,734 @@
+package tui
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ tea "github.com/charmbracelet/bubbletea"
+ "github.com/charmbracelet/lipgloss"
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/cmd/config"
+ "github.com/ollama/ollama/version"
+)
+
+var (
+ versionStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
+
+ menuItemStyle = lipgloss.NewStyle().
+ PaddingLeft(2)
+
+ menuSelectedItemStyle = lipgloss.NewStyle().
+ Bold(true).
+ Background(lipgloss.AdaptiveColor{Light: "254", Dark: "236"})
+
+ menuDescStyle = selectorDescStyle.
+ PaddingLeft(4)
+
+ greyedStyle = menuItemStyle.
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
+
+ greyedSelectedStyle = menuSelectedItemStyle.
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"})
+
+ modelStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "243", Dark: "250"})
+
+ notInstalledStyle = lipgloss.NewStyle().
+ Foreground(lipgloss.AdaptiveColor{Light: "242", Dark: "246"}).
+ Italic(true)
+)
+
+type menuItem struct {
+ title string
+ description string
+ integration string // integration name for loading model config, empty if not an integration
+ isRunModel bool
+ isOthers bool
+}
+
+var mainMenuItems = []menuItem{
+ {
+ title: "Run a model",
+ description: "Start an interactive chat with a model",
+ isRunModel: true,
+ },
+ {
+ title: "Launch Claude Code",
+ description: "Agentic coding across large codebases",
+ integration: "claude",
+ },
+ {
+ title: "Launch Codex",
+ description: "OpenAI's open-source coding agent",
+ integration: "codex",
+ },
+ {
+ title: "Launch OpenClaw",
+ description: "Personal AI with 100+ skills",
+ integration: "openclaw",
+ },
+}
+
+var othersMenuItem = menuItem{
+ title: "More...",
+ description: "Show additional integrations",
+ isOthers: true,
+}
+
+// getOtherIntegrations dynamically builds the "Others" list from the integration
+// registry, excluding any integrations already present in the pinned mainMenuItems.
+func getOtherIntegrations() []menuItem {
+ pinned := map[string]bool{
+ "run": true, // not an integration but in the pinned list
+ }
+ for _, item := range mainMenuItems {
+ if item.integration != "" {
+ pinned[item.integration] = true
+ }
+ }
+
+ var others []menuItem
+ for _, info := range config.ListIntegrationInfos() {
+ if pinned[info.Name] {
+ continue
+ }
+ desc := info.Description
+ if desc == "" {
+ desc = "Open " + info.DisplayName + " integration"
+ }
+ others = append(others, menuItem{
+ title: "Launch " + info.DisplayName,
+ description: desc,
+ integration: info.Name,
+ })
+ }
+ return others
+}
+
+type model struct {
+ items []menuItem
+ cursor int
+ quitting bool
+ selected bool
+ changeModel bool
+ changeModels []string // multi-select result for Editor integrations
+ showOthers bool
+ availableModels map[string]bool
+ err error
+
+ showingModal bool
+ modalSelector selectorModel
+ modalItems []SelectItem
+
+ showingMultiModal bool
+ multiModalSelector multiSelectorModel
+
+ showingSignIn bool
+ signInURL string
+ signInModel string
+ signInSpinner int
+ signInFromModal bool // true if sign-in was triggered from modal (not main menu)
+
+ width int // terminal width from WindowSizeMsg
+ statusMsg string // temporary status message shown near help text
+}
+
+type signInTickMsg struct{}
+
+type signInCheckMsg struct {
+ signedIn bool
+ userName string
+}
+
+type clearStatusMsg struct{}
+
+func (m *model) modelExists(name string) bool {
+ if m.availableModels == nil || name == "" {
+ return false
+ }
+ if m.availableModels[name] {
+ return true
+ }
+ // Check for prefix match (e.g., "llama2" matches "llama2:latest")
+ for modelName := range m.availableModels {
+ if strings.HasPrefix(modelName, name+":") {
+ return true
+ }
+ }
+ return false
+}
+
+func (m *model) buildModalItems() []SelectItem {
+ modelItems, _ := config.GetModelItems(context.Background())
+ return ReorderItems(ConvertItems(modelItems))
+}
+
+func (m *model) openModelModal(currentModel string) {
+ m.modalItems = m.buildModalItems()
+ cursor := 0
+ if currentModel != "" {
+ for i, item := range m.modalItems {
+ if item.Name == currentModel || strings.HasPrefix(item.Name, currentModel+":") || strings.HasPrefix(currentModel, item.Name+":") {
+ cursor = i
+ break
+ }
+ }
+ }
+ m.modalSelector = selectorModel{
+ title: "Select model:",
+ items: m.modalItems,
+ cursor: cursor,
+ helpText: "↑/↓ navigate • enter select • ← back",
+ }
+ m.modalSelector.updateScroll(m.modalSelector.otherStart())
+ m.showingModal = true
+}
+
+func (m *model) openMultiModelModal(integration string) {
+ items := m.buildModalItems()
+ var preChecked []string
+ if models := config.IntegrationModels(integration); len(models) > 0 {
+ preChecked = models
+ }
+ m.multiModalSelector = newMultiSelectorModel("Select models:", items, preChecked)
+ // Set cursor to the first pre-checked (last used) model
+ if len(preChecked) > 0 {
+ for i, item := range items {
+ if item.Name == preChecked[0] {
+ m.multiModalSelector.cursor = i
+ m.multiModalSelector.updateScroll(m.multiModalSelector.otherStart())
+ break
+ }
+ }
+ }
+ m.showingMultiModal = true
+}
+
+func isCloudModel(name string) bool {
+ return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
+}
+
+func cloudStatusDisabled(client *api.Client) bool {
+ status, err := client.CloudStatusExperimental(context.Background())
+ if err != nil {
+ return false
+ }
+ return status.Cloud.Disabled
+}
+
+func cloudModelDisabled(name string) bool {
+ if !isCloudModel(name) {
+ return false
+ }
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return false
+ }
+ return cloudStatusDisabled(client)
+}
+
+// checkCloudSignIn checks if a cloud model needs sign-in.
+// Returns a command to start sign-in if needed, or nil if already signed in.
+func (m *model) checkCloudSignIn(modelName string, fromModal bool) tea.Cmd {
+ if modelName == "" || !isCloudModel(modelName) {
+ return nil
+ }
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return nil
+ }
+ if cloudStatusDisabled(client) {
+ return nil
+ }
+ user, err := client.Whoami(context.Background())
+ if err == nil && user != nil && user.Name != "" {
+ return nil
+ }
+ var aErr api.AuthorizationError
+ if errors.As(err, &aErr) && aErr.SigninURL != "" {
+ return m.startSignIn(modelName, aErr.SigninURL, fromModal)
+ }
+ return nil
+}
+
+// startSignIn initiates the sign-in flow for a cloud model.
+// fromModal indicates if this was triggered from the model picker modal.
+func (m *model) startSignIn(modelName, signInURL string, fromModal bool) tea.Cmd {
+ m.showingModal = false
+ m.showingSignIn = true
+ m.signInURL = signInURL
+ m.signInModel = modelName
+ m.signInSpinner = 0
+ m.signInFromModal = fromModal
+
+ config.OpenBrowser(signInURL)
+
+ return tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
+ return signInTickMsg{}
+ })
+}
+
+func checkSignIn() tea.Msg {
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return signInCheckMsg{signedIn: false}
+ }
+ user, err := client.Whoami(context.Background())
+ if err == nil && user != nil && user.Name != "" {
+ return signInCheckMsg{signedIn: true, userName: user.Name}
+ }
+ return signInCheckMsg{signedIn: false}
+}
+
+func (m *model) loadAvailableModels() {
+ m.availableModels = make(map[string]bool)
+ client, err := api.ClientFromEnvironment()
+ if err != nil {
+ return
+ }
+ models, err := client.List(context.Background())
+ if err != nil {
+ return
+ }
+ cloudDisabled := cloudStatusDisabled(client)
+ for _, mdl := range models.Models {
+ if cloudDisabled && mdl.RemoteModel != "" {
+ continue
+ }
+ m.availableModels[mdl.Name] = true
+ }
+}
+
+func (m *model) buildItems() {
+ others := getOtherIntegrations()
+ m.items = make([]menuItem, 0, len(mainMenuItems)+1+len(others))
+ m.items = append(m.items, mainMenuItems...)
+
+ if m.showOthers {
+ m.items = append(m.items, others...)
+ } else {
+ m.items = append(m.items, othersMenuItem)
+ }
+}
+
+func isOthersIntegration(name string) bool {
+ for _, item := range getOtherIntegrations() {
+ if item.integration == name {
+ return true
+ }
+ }
+ return false
+}
+
+func initialModel() model {
+ m := model{
+ cursor: 0,
+ }
+ m.loadAvailableModels()
+
+ lastSelection := config.LastSelection()
+ if isOthersIntegration(lastSelection) {
+ m.showOthers = true
+ }
+
+ m.buildItems()
+
+ if lastSelection != "" {
+ for i, item := range m.items {
+ if lastSelection == "run" && item.isRunModel {
+ m.cursor = i
+ break
+ } else if item.integration == lastSelection {
+ m.cursor = i
+ break
+ }
+ }
+ }
+
+ return m
+}
+
+func (m model) Init() tea.Cmd {
+ return nil
+}
+
+func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
+ if wmsg, ok := msg.(tea.WindowSizeMsg); ok {
+ wasSet := m.width > 0
+ m.width = wmsg.Width
+ if wasSet {
+ return m, tea.EnterAltScreen
+ }
+ return m, nil
+ }
+
+ if _, ok := msg.(clearStatusMsg); ok {
+ m.statusMsg = ""
+ return m, nil
+ }
+
+ if m.showingSignIn {
+ switch msg := msg.(type) {
+ case tea.KeyMsg:
+ switch msg.Type {
+ case tea.KeyCtrlC, tea.KeyEsc:
+ m.showingSignIn = false
+ if m.signInFromModal {
+ m.showingModal = true
+ }
+ return m, nil
+ }
+
+ case signInTickMsg:
+ m.signInSpinner++
+ // Check sign-in status every 5th tick (~1 second)
+ if m.signInSpinner%5 == 0 {
+ return m, tea.Batch(
+ tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
+ return signInTickMsg{}
+ }),
+ checkSignIn,
+ )
+ }
+ return m, tea.Tick(200*time.Millisecond, func(t time.Time) tea.Msg {
+ return signInTickMsg{}
+ })
+
+ case signInCheckMsg:
+ if msg.signedIn {
+ if m.signInFromModal {
+ m.modalSelector.selected = m.signInModel
+ m.changeModel = true
+ } else {
+ m.selected = true
+ }
+ m.quitting = true
+ return m, tea.Quit
+ }
+ }
+ return m, nil
+ }
+
+ if m.showingMultiModal {
+ switch msg := msg.(type) {
+ case tea.KeyMsg:
+ if msg.Type == tea.KeyLeft {
+ m.showingMultiModal = false
+ return m, nil
+ }
+ updated, cmd := m.multiModalSelector.Update(msg)
+ m.multiModalSelector = updated.(multiSelectorModel)
+
+ if m.multiModalSelector.cancelled {
+ m.showingMultiModal = false
+ return m, nil
+ }
+ if m.multiModalSelector.confirmed {
+ var selected []string
+ if m.multiModalSelector.singleAdd != "" {
+ // Single-add mode: prepend picked model, keep existing deduped
+ selected = []string{m.multiModalSelector.singleAdd}
+ for _, name := range config.IntegrationModels(m.items[m.cursor].integration) {
+ if name != m.multiModalSelector.singleAdd {
+ selected = append(selected, name)
+ }
+ }
+ } else {
+ // Last checked is default (first in result)
+ co := m.multiModalSelector.checkOrder
+ last := co[len(co)-1]
+ selected = []string{m.multiModalSelector.items[last].Name}
+ for _, idx := range co {
+ if idx != last {
+ selected = append(selected, m.multiModalSelector.items[idx].Name)
+ }
+ }
+ }
+ if len(selected) > 0 {
+ m.changeModels = selected
+ m.changeModel = true
+ m.quitting = true
+ return m, tea.Quit
+ }
+ m.multiModalSelector.confirmed = false
+ return m, nil
+ }
+ return m, cmd
+ }
+ return m, nil
+ }
+
+ if m.showingModal {
+ switch msg := msg.(type) {
+ case tea.KeyMsg:
+ switch msg.Type {
+ case tea.KeyCtrlC, tea.KeyEsc, tea.KeyLeft:
+ m.showingModal = false
+ return m, nil
+
+ case tea.KeyEnter:
+ filtered := m.modalSelector.filteredItems()
+ if len(filtered) > 0 && m.modalSelector.cursor < len(filtered) {
+ m.modalSelector.selected = filtered[m.modalSelector.cursor].Name
+ }
+ if m.modalSelector.selected != "" {
+ if cmd := m.checkCloudSignIn(m.modalSelector.selected, true); cmd != nil {
+ return m, cmd
+ }
+ m.changeModel = true
+ m.quitting = true
+ return m, tea.Quit
+ }
+ return m, nil
+
+ default:
+ // Delegate navigation (up/down/pgup/pgdown/filter/backspace) to selectorModel
+ m.modalSelector.updateNavigation(msg)
+ }
+ }
+ return m, nil
+ }
+
+ switch msg := msg.(type) {
+ case tea.KeyMsg:
+ switch msg.String() {
+ case "ctrl+c", "q", "esc":
+ m.quitting = true
+ return m, tea.Quit
+
+ case "up", "k":
+ if m.cursor > 0 {
+ m.cursor--
+ }
+ // Auto-collapse "Others" when cursor moves back into pinned items
+ if m.showOthers && m.cursor < len(mainMenuItems) {
+ m.showOthers = false
+ m.buildItems()
+ }
+
+ case "down", "j":
+ if m.cursor < len(m.items)-1 {
+ m.cursor++
+ }
+ // Auto-expand "Others..." when cursor lands on it
+ if m.cursor < len(m.items) && m.items[m.cursor].isOthers && !m.showOthers {
+ m.showOthers = true
+ m.buildItems()
+ // cursor now points at the first "other" integration
+ }
+
+ case "enter", " ":
+ item := m.items[m.cursor]
+
+ if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
+ return m, nil
+ }
+
+ var configuredModel string
+ if item.isRunModel {
+ configuredModel = config.LastModel()
+ } else if item.integration != "" {
+ configuredModel = config.IntegrationModel(item.integration)
+ }
+ if cmd := m.checkCloudSignIn(configuredModel, false); cmd != nil {
+ return m, cmd
+ }
+
+ if configuredModel != "" && isCloudModel(configuredModel) && cloudModelDisabled(configuredModel) {
+ if item.integration != "" && config.IsEditorIntegration(item.integration) {
+ m.openMultiModelModal(item.integration)
+ } else {
+ m.openModelModal(configuredModel)
+ }
+ return m, nil
+ }
+
+ m.selected = true
+ m.quitting = true
+ return m, tea.Quit
+
+ case "right", "l":
+ item := m.items[m.cursor]
+ if item.integration != "" || item.isRunModel {
+ if item.integration != "" && !config.IsIntegrationInstalled(item.integration) {
+ return m, nil
+ }
+ if item.integration != "" && config.IsEditorIntegration(item.integration) {
+ m.openMultiModelModal(item.integration)
+ } else {
+ var currentModel string
+ if item.isRunModel {
+ currentModel = config.LastModel()
+ } else if item.integration != "" {
+ currentModel = config.IntegrationModel(item.integration)
+ }
+ m.openModelModal(currentModel)
+ }
+ }
+ }
+ }
+
+ return m, nil
+}
+
+func (m model) View() string {
+ if m.quitting {
+ return ""
+ }
+
+ if m.showingSignIn {
+ return m.renderSignInDialog()
+ }
+
+ if m.showingMultiModal {
+ return m.multiModalSelector.View()
+ }
+
+ if m.showingModal {
+ return m.renderModal()
+ }
+
+ s := selectorTitleStyle.Render("Ollama "+versionStyle.Render(version.Version)) + "\n\n"
+
+ for i, item := range m.items {
+ cursor := ""
+ style := menuItemStyle
+ isInstalled := true
+
+ if item.integration != "" {
+ isInstalled = config.IsIntegrationInstalled(item.integration)
+ }
+
+ if m.cursor == i {
+ cursor = "▸ "
+ if isInstalled {
+ style = menuSelectedItemStyle
+ } else {
+ style = greyedSelectedStyle
+ }
+ } else if !isInstalled && item.integration != "" {
+ style = greyedStyle
+ }
+
+ title := item.title
+ var modelSuffix string
+ if item.integration != "" {
+ if !isInstalled {
+ title += " " + notInstalledStyle.Render("(not installed)")
+ } else if m.cursor == i {
+ if mdl := config.IntegrationModel(item.integration); mdl != "" && m.modelExists(mdl) {
+ modelSuffix = " " + modelStyle.Render("("+mdl+")")
+ }
+ }
+ } else if item.isRunModel && m.cursor == i {
+ if mdl := config.LastModel(); mdl != "" && m.modelExists(mdl) {
+ modelSuffix = " " + modelStyle.Render("("+mdl+")")
+ }
+ }
+
+ s += style.Render(cursor+title) + modelSuffix + "\n"
+
+ desc := item.description
+ if !isInstalled && item.integration != "" && m.cursor == i {
+ if hint := config.IntegrationInstallHint(item.integration); hint != "" {
+ desc = hint
+ } else {
+ desc = "not installed"
+ }
+ }
+ s += menuDescStyle.Render(desc) + "\n\n"
+ }
+
+ if m.statusMsg != "" {
+ s += "\n" + lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "124", Dark: "210"}).Render(m.statusMsg) + "\n"
+ }
+
+ s += "\n" + selectorHelpStyle.Render("↑/↓ navigate • enter launch • → change model • esc quit")
+
+ if m.width > 0 {
+ return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
+ }
+ return s
+}
+
+func (m model) renderModal() string {
+ modalStyle := lipgloss.NewStyle().
+ PaddingBottom(1).
+ PaddingRight(2)
+
+ s := modalStyle.Render(m.modalSelector.renderContent())
+ if m.width > 0 {
+ return lipgloss.NewStyle().MaxWidth(m.width).Render(s)
+ }
+ return s
+}
+
+func (m model) renderSignInDialog() string {
+ return renderSignIn(m.signInModel, m.signInURL, m.signInSpinner, m.width)
+}
+
+type Selection int
+
+const (
+ SelectionNone Selection = iota
+ SelectionRunModel
+ SelectionChangeRunModel
+ SelectionIntegration // Generic integration selection
+ SelectionChangeIntegration // Generic change model for integration
+)
+
+type Result struct {
+ Selection Selection
+ Integration string // integration name if applicable
+ Model string // model name if selected from single-select modal
+ Models []string // models selected from multi-select modal (Editor integrations)
+}
+
+func Run() (Result, error) {
+ m := initialModel()
+ p := tea.NewProgram(m)
+
+ finalModel, err := p.Run()
+ if err != nil {
+ return Result{Selection: SelectionNone}, fmt.Errorf("error running TUI: %w", err)
+ }
+
+ fm := finalModel.(model)
+ if fm.err != nil {
+ return Result{Selection: SelectionNone}, fm.err
+ }
+
+ if !fm.selected && !fm.changeModel {
+ return Result{Selection: SelectionNone}, nil
+ }
+
+ item := fm.items[fm.cursor]
+
+ if fm.changeModel {
+ if item.isRunModel {
+ return Result{
+ Selection: SelectionChangeRunModel,
+ Model: fm.modalSelector.selected,
+ }, nil
+ }
+ return Result{
+ Selection: SelectionChangeIntegration,
+ Integration: item.integration,
+ Model: fm.modalSelector.selected,
+ Models: fm.changeModels,
+ }, nil
+ }
+
+ if item.isRunModel {
+ return Result{Selection: SelectionRunModel}, nil
+ }
+
+ return Result{
+ Selection: SelectionIntegration,
+ Integration: item.integration,
+ }, nil
+}
diff --git a/convert/convert.go b/convert/convert.go
index b2e6f5e3700..1f318be9082 100644
--- a/convert/convert.go
+++ b/convert/convert.go
@@ -313,6 +313,12 @@ func LoadModelMetadata(fsys fs.FS) (ModelKV, *Tokenizer, error) {
conv = &deepseek2Model{}
case "Glm4MoeLiteForCausalLM":
conv = &glm4MoeLiteModel{}
+ case "GlmOcrForConditionalGeneration":
+ conv = &glmOcrModel{}
+ case "Lfm2ForCausalLM":
+ conv = &lfm2Model{}
+ case "Qwen3NextForCausalLM":
+ conv = &qwen3NextModel{}
default:
return nil, nil, fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}
diff --git a/convert/convert_glm4moelite.go b/convert/convert_glm4moelite.go
index a74a2fee63b..492266e6c63 100644
--- a/convert/convert_glm4moelite.go
+++ b/convert/convert_glm4moelite.go
@@ -6,6 +6,10 @@ import (
"log/slog"
"regexp"
"strconv"
+ "strings"
+
+ "github.com/pdevine/tensor"
+ "github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
@@ -69,6 +73,9 @@ func (p *glm4MoeLiteModel) KV(t *Tokenizer) KV {
kv["glm4moelite.rope.dimension_count"] = p.QKRopeHeadDim
kv["glm4moelite.rope.freq_base"] = cmp.Or(p.RopeTheta, float32(1000000.0))
+ kv["glm4moelite.attention.key_length_mla"] = p.KVLoraRank + p.QKRopeHeadDim
+ kv["glm4moelite.attention.value_length_mla"] = p.KVLoraRank
+
kv["tokenizer.ggml.pre"] = "glm4"
return kv
@@ -100,6 +107,67 @@ func (p *glm4MoeLiteModel) Replacements() []string {
}
}
+// repackKVB extracts K or V from the combined KV_B tensor for MLA absorption.
+// K output row-major: [n_head, kv_lora_rank, qk_nope] -> GGML ne[]={qk_nope, kv_lora_rank, n_head}
+// V output row-major: [n_head, v_head, kv_lora_rank] -> GGML ne[]={kv_lora_rank, v_head, n_head}
+func (p *glm4MoeLiteModel) repackKVB(extractK bool, kvFirst bool, numHeads int) Repacker {
+ qkNope := int(p.QKNopeHeadDim)
+ vHeadDim := int(p.VHeadDim)
+ kvLoraRank := int(p.KVLoraRank)
+ kvPerHead := qkNope + vHeadDim
+
+ return func(_ string, data []float32, shape []uint64) ([]float32, error) {
+ dims := make([]int, len(shape))
+ for i := range shape {
+ dims[i] = int(shape[i])
+ }
+
+ var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+ var err error
+
+ // Normalize to [n_head * (qk_nope + v_head), kv_lora_rank] layout
+ if kvFirst {
+ tt, err = tensor.Transpose(tt, 1, 0)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+ }
+
+ // Reshape to [n_head, qk_nope + v_head, kv_lora_rank]
+ if err := tt.Reshape(numHeads, kvPerHead, kvLoraRank); err != nil {
+ return nil, err
+ }
+
+ if extractK {
+ // Slice K: [n_head, qk_nope, kv_lora_rank]
+ tt, err = tt.Slice(nil, tensor.S(0, qkNope), nil)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+ // Transpose to [n_head, kv_lora_rank, qk_nope]
+ tt, err = tensor.Transpose(tt, 0, 2, 1)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+ } else {
+ // Slice V: [n_head, v_head, kv_lora_rank] - already correct layout
+ tt, err = tt.Slice(nil, tensor.S(qkNope, kvPerHead), nil)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+ }
+
+ if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
+ return nil, err
+ }
+ return native.VectorF32(tt.(*tensor.Dense))
+ }
+}
+
func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
@@ -139,6 +207,52 @@ func (p *glm4MoeLiteModel) Tensors(s []Tensor) (out []*ggml.Tensor) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
+
+ // Split attn_kv_b into separate attn_k_b and attn_v_b for MLA absorption
+ if strings.HasSuffix(t.Name(), ".attn_kv_b.weight") {
+ qkNope := int(p.QKNopeHeadDim)
+ vHeadDim := int(p.VHeadDim)
+ kvLoraRank := int(p.KVLoraRank)
+ kvPerHead := qkNope + vHeadDim
+ numHeads := int(p.NumAttentionHeads)
+ kvFirst := true
+ if len(t.Shape()) == 2 {
+ switch {
+ case int(t.Shape()[0]) == kvLoraRank:
+ if kvPerHead > 0 && int(t.Shape()[1])%kvPerHead == 0 {
+ numHeads = int(t.Shape()[1]) / kvPerHead
+ }
+ kvFirst = true
+ case int(t.Shape()[1]) == kvLoraRank:
+ if kvPerHead > 0 && int(t.Shape()[0])%kvPerHead == 0 {
+ numHeads = int(t.Shape()[0]) / kvPerHead
+ }
+ kvFirst = false
+ default:
+ slog.Warn("glm4moelite: unexpected attn_kv_b layout", "name", t.Name(), "shape", t.Shape())
+ }
+ }
+
+ kTensor := t.Clone()
+ kTensor.SetRepacker(p.repackKVB(true, kvFirst, numHeads))
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(t.Name(), "attn_kv_b", "attn_k_b", 1),
+ Kind: t.Kind(),
+ Shape: []uint64{uint64(numHeads), uint64(kvLoraRank), uint64(qkNope)},
+ WriterTo: kTensor,
+ })
+
+ vTensor := t.Clone()
+ vTensor.SetRepacker(p.repackKVB(false, kvFirst, numHeads))
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(t.Name(), "attn_kv_b", "attn_v_b", 1),
+ Kind: t.Kind(),
+ Shape: []uint64{uint64(numHeads), uint64(vHeadDim), uint64(kvLoraRank)},
+ WriterTo: vTensor,
+ })
+ continue
+ }
+
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
diff --git a/convert/convert_glmocr.go b/convert/convert_glmocr.go
new file mode 100644
index 00000000000..c8524fdbf26
--- /dev/null
+++ b/convert/convert_glmocr.go
@@ -0,0 +1,455 @@
+package convert
+
+import (
+ "cmp"
+ "encoding/json"
+ "io/fs"
+ "log/slog"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "github.com/ollama/ollama/fs/ggml"
+ "github.com/pdevine/tensor"
+ "github.com/pdevine/tensor/native"
+)
+
+// normalToNeoXRepacker creates a repacker that permutes Q/K weights from interleaved (LLaMA)
+// to NeoX ordering for compatibility with GGML's M-RoPE kernel.
+//
+// For weights: reshape [out, in] -> [n_heads, head_dim, in], permute rotary dims, reshape back
+// For biases: reshape [out] -> [n_heads, head_dim], permute rotary dims, reshape back
+func normalToNeoXRepacker(nHeads, headDim int, partialRotaryFactor float32) func(string, []float32, []uint64) ([]float32, error) {
+ return func(_ string, data []float32, shape []uint64) ([]float32, error) {
+ rotaryDim := int(float32(headDim) * partialRotaryFactor)
+ if rotaryDim%2 != 0 {
+ rotaryDim = (rotaryDim / 2) * 2 // Round down to even
+ }
+
+ // Handle 1D (bias) or 2D (weight) tensors
+ is1D := len(shape) == 1
+ var inFeatures int
+ if is1D {
+ inFeatures = 1
+ } else {
+ inFeatures = int(shape[1])
+ }
+ outFeatures := int(shape[0])
+ nEffectiveHeads := outFeatures / headDim
+
+ if nEffectiveHeads != nHeads {
+ slog.Warn("normalToNeoX: unexpected head count", "effective", nEffectiveHeads, "expected", nHeads)
+ }
+
+ // Reshape to [n_heads, head_dim, in_features]
+ reshaped := make([]float32, len(data))
+ copy(reshaped, data)
+
+ // Permute the rotary dimensions: even indices first, then odd
+ // For each head, reorder [0,1,2,3,4,5...] to [0,2,4...,1,3,5...]
+ result := make([]float32, len(data))
+ halfRotary := rotaryDim / 2
+
+ for h := range nEffectiveHeads {
+ for f := range inFeatures {
+ for i := range halfRotary {
+ // Even dim (0, 2, 4, ...) -> position i
+ srcIdx := h*headDim*inFeatures + (2*i)*inFeatures + f
+ dstIdx := h*headDim*inFeatures + i*inFeatures + f
+ result[dstIdx] = reshaped[srcIdx]
+
+ // Odd dim (1, 3, 5, ...) -> position halfRotary + i
+ srcIdx = h*headDim*inFeatures + (2*i+1)*inFeatures + f
+ dstIdx = h*headDim*inFeatures + (halfRotary+i)*inFeatures + f
+ result[dstIdx] = reshaped[srcIdx]
+ }
+
+ // Non-rotary part: copy as-is
+ for i := rotaryDim; i < headDim; i++ {
+ srcIdx := h*headDim*inFeatures + i*inFeatures + f
+ result[srcIdx] = reshaped[srcIdx]
+ }
+ }
+ }
+
+ return result, nil
+ }
+}
+
+type glmOcrModel struct {
+ ModelParameters
+
+ TextConfig struct {
+ HiddenSize uint32 `json:"hidden_size"`
+ IntermediateSize uint32 `json:"intermediate_size"`
+ NumHiddenLayers uint32 `json:"num_hidden_layers"`
+ NumAttentionHeads uint32 `json:"num_attention_heads"`
+ NumKeyValueHeads uint32 `json:"num_key_value_heads"`
+ HeadDim uint32 `json:"head_dim"`
+ MaxPositionEmbed uint32 `json:"max_position_embeddings"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ PartialRotaryFactor float32 `json:"partial_rotary_factor"`
+ RopeParameters struct {
+ RopeType string `json:"rope_type"`
+ MRopeSection []int32 `json:"mrope_section"`
+ RopeTheta float32 `json:"rope_theta"`
+ PartialRotaryFactor float32 `json:"partial_rotary_factor"`
+ } `json:"rope_parameters"`
+ } `json:"text_config"`
+
+ VisionConfig struct {
+ HiddenSize uint32 `json:"hidden_size"`
+ IntermediateSize uint32 `json:"intermediate_size"`
+ Depth uint32 `json:"depth"`
+ NumHeads uint32 `json:"num_heads"`
+ ImageSize uint32 `json:"image_size"`
+ PatchSize uint32 `json:"patch_size"`
+ OutHiddenSize uint32 `json:"out_hidden_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ SpatialMergeSize uint32 `json:"spatial_merge_size"`
+ TemporalPatchSize uint32 `json:"temporal_patch_size"`
+ } `json:"vision_config"`
+
+ ImageStartTokenID uint32 `json:"image_start_token_id"`
+ ImageEndTokenID uint32 `json:"image_end_token_id"`
+ VideoStartTokenID uint32 `json:"video_start_token_id"`
+ VideoEndTokenID uint32 `json:"video_end_token_id"`
+ ImageTokenID uint32 `json:"image_token_id"`
+ VideoTokenID uint32 `json:"video_token_id"`
+
+ // Preprocessor config (preprocessor_config.json)
+ Preprocessor struct {
+ Size struct {
+ ShortestEdge uint32 `json:"shortest_edge"`
+ LongestEdge uint32 `json:"longest_edge"`
+ } `json:"size"`
+ PatchSize uint32 `json:"patch_size"`
+ TemporalPatchSize uint32 `json:"temporal_patch_size"`
+ MergeSize uint32 `json:"merge_size"`
+ ImageMean []float32 `json:"image_mean"`
+ ImageStd []float32 `json:"image_std"`
+ } `json:"-"`
+}
+
+var _ ModelConverter = (*glmOcrModel)(nil)
+
+func (m *glmOcrModel) parseMore(fsys fs.FS) error {
+ bts, err := fs.ReadFile(fsys, "preprocessor_config.json")
+ if err != nil {
+ return err
+ }
+
+ return json.Unmarshal(bts, &m.Preprocessor)
+}
+
+func (m *glmOcrModel) KV(t *Tokenizer) KV {
+ kv := m.ModelParameters.KV(t)
+ kv["general.architecture"] = "glmocr"
+
+ // Text model parameters
+ kv["glmocr.block_count"] = cmp.Or(m.TextConfig.NumHiddenLayers, 16)
+ kv["glmocr.embedding_length"] = cmp.Or(m.TextConfig.HiddenSize, 1536)
+ kv["glmocr.attention.head_count"] = cmp.Or(m.TextConfig.NumAttentionHeads, 16)
+ kv["glmocr.attention.head_count_kv"] = cmp.Or(m.TextConfig.NumKeyValueHeads, 8)
+ headDim := cmp.Or(m.TextConfig.HeadDim, m.TextConfig.HiddenSize/m.TextConfig.NumAttentionHeads)
+ kv["glmocr.attention.key_length"] = headDim
+ kv["glmocr.attention.value_length"] = headDim
+ kv["glmocr.feed_forward_length"] = cmp.Or(m.TextConfig.IntermediateSize, 4608)
+ kv["glmocr.attention.layer_norm_rms_epsilon"] = cmp.Or(m.TextConfig.RMSNormEps, 1e-5)
+ kv["glmocr.context_length"] = cmp.Or(m.TextConfig.MaxPositionEmbed, 131072)
+ kv["glmocr.rope.freq_base"] = cmp.Or(m.TextConfig.RopeParameters.RopeTheta, float32(10000))
+ kv["glmocr.rope.partial_rotary_factor"] = cmp.Or(m.TextConfig.RopeParameters.PartialRotaryFactor, m.TextConfig.PartialRotaryFactor, float32(1.0))
+ if len(m.TextConfig.RopeParameters.MRopeSection) > 0 {
+ kv["glmocr.rope.mrope_section"] = m.TextConfig.RopeParameters.MRopeSection
+ }
+
+ // Vision model parameters
+ kv["glmocr.vision.block_count"] = cmp.Or(m.VisionConfig.Depth, 24)
+ kv["glmocr.vision.embedding_length"] = cmp.Or(m.VisionConfig.HiddenSize, 1024)
+ kv["glmocr.vision.attention.head_count"] = cmp.Or(m.VisionConfig.NumHeads, 16)
+ kv["glmocr.vision.image_size"] = cmp.Or(m.VisionConfig.ImageSize, 336)
+ kv["glmocr.vision.patch_size"] = cmp.Or(m.VisionConfig.PatchSize, m.Preprocessor.PatchSize, 14)
+ kv["glmocr.vision.spatial_merge_size"] = cmp.Or(m.VisionConfig.SpatialMergeSize, m.Preprocessor.MergeSize, 2)
+ kv["glmocr.vision.temporal_patch_size"] = cmp.Or(m.VisionConfig.TemporalPatchSize, m.Preprocessor.TemporalPatchSize, 2)
+ kv["glmocr.vision.out_hidden_size"] = cmp.Or(m.VisionConfig.OutHiddenSize, 1536)
+ kv["glmocr.vision.intermediate_size"] = cmp.Or(m.VisionConfig.IntermediateSize, 4096)
+ kv["glmocr.vision.attention.layer_norm_rms_epsilon"] = cmp.Or(m.VisionConfig.RMSNormEps, 1e-5)
+
+ // Preprocessor-derived image settings (min/max pixels and normalization)
+ // Note: fs.Config.keyValue() auto-prepends architecture prefix, so use full key
+ if m.Preprocessor.Size.ShortestEdge > 0 {
+ kv["glmocr.vision.min_pixels"] = m.Preprocessor.Size.ShortestEdge
+ }
+ if m.Preprocessor.Size.LongestEdge > 0 {
+ kv["glmocr.vision.max_pixels"] = m.Preprocessor.Size.LongestEdge
+ }
+ if len(m.Preprocessor.ImageMean) == 3 {
+ kv["glmocr.vision.image_mean"] = m.Preprocessor.ImageMean
+ }
+ if len(m.Preprocessor.ImageStd) == 3 {
+ kv["glmocr.vision.image_std"] = m.Preprocessor.ImageStd
+ }
+
+ // Special tokens
+ kv["glmocr.image_token_id"] = m.ImageTokenID
+ kv["glmocr.image_start_token_id"] = m.ImageStartTokenID
+ kv["glmocr.image_end_token_id"] = m.ImageEndTokenID
+ kv["glmocr.video_token_id"] = m.VideoTokenID
+ kv["glmocr.video_start_token_id"] = m.VideoStartTokenID
+ kv["glmocr.video_end_token_id"] = m.VideoEndTokenID
+
+ return kv
+}
+
+func (m *glmOcrModel) Tensors(ts []Tensor) []*ggml.Tensor {
+ var out []*ggml.Tensor
+
+ // Skip layers >= num_hidden_layers (Multi-Token Prediction layers not needed for basic inference)
+ numLayers := int(cmp.Or(m.TextConfig.NumHiddenLayers, 16))
+ skipLayer := func(name string) bool {
+ // Tensor names are already replaced to "blk.N.xxx" format
+ re := regexp.MustCompile(`^blk\.(\d+)`)
+ matches := re.FindStringSubmatch(name)
+ if matches == nil {
+ return false
+ }
+ blkNum, err := strconv.Atoi(matches[1])
+ if err != nil {
+ return false
+ }
+ return blkNum >= numLayers
+ }
+
+ for _, t := range ts {
+ name := t.Name()
+
+ // Skip next-n prediction layers (layers >= num_hidden_layers)
+ if skipLayer(name) {
+ continue
+ }
+
+ // Split ffn_gate_up into separate gate and up projections
+ if strings.Contains(name, "ffn_gate_up") {
+ for t := range splitDim(t, 0,
+ split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_gate")},
+ split{Replacer: strings.NewReplacer("ffn_gate_up", "ffn_up")},
+ ) {
+ out = append(out, t)
+ }
+ continue
+ }
+
+ if strings.HasSuffix(name, "patch_embd.weight") {
+ shape := t.Shape()
+ if len(shape) == 5 && shape[2] == 2 {
+ newShape := []uint64{shape[0], shape[1], shape[3], shape[4]}
+
+ t0 := t.Clone()
+ t0.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
+ dims := make([]int, len(shape))
+ for i := range shape {
+ dims[i] = int(shape[i])
+ }
+ var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+ tt, err := tt.Slice(nil, nil, tensor.S(0, 1), nil, nil)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+ newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
+ if err := tt.Reshape(newDims...); err != nil {
+ return nil, err
+ }
+ if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
+ return nil, err
+ }
+ return native.VectorF32(tt.(*tensor.Dense))
+ })
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
+ Kind: t.Kind(),
+ Shape: newShape,
+ WriterTo: t0,
+ })
+
+ t1 := t.Clone()
+ t1.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
+ dims := make([]int, len(shape))
+ for i := range shape {
+ dims[i] = int(shape[i])
+ }
+ var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+ tt, err := tt.Slice(nil, nil, tensor.S(1, 2), nil, nil)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+ newDims := []int{int(shape[0]), int(shape[1]), int(shape[3]), int(shape[4])}
+ if err := tt.Reshape(newDims...); err != nil {
+ return nil, err
+ }
+ if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
+ return nil, err
+ }
+ return native.VectorF32(tt.(*tensor.Dense))
+ })
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.weight", "patch_embd_1.weight", 1),
+ Kind: t.Kind(),
+ Shape: newShape,
+ WriterTo: t1,
+ })
+
+ continue
+ }
+
+ if len(shape) == 4 {
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.weight", "patch_embd_0.weight", 1),
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: t,
+ })
+ continue
+ }
+
+ slog.Warn("glmocr: patch_embed weight has unexpected shape - not splitting", "shape", shape)
+ // Fall through to default handling
+ }
+
+ // Handle pre-split patch embedding weights
+ // Pattern 1: v.patch_embd.0.weight, v.patch_embd.1.weight -> patch_embd_0.weight, patch_embd_1.weight
+ // Pattern 2: v.patch_embd.weight.0, v.patch_embd.weight.1 -> patch_embd_0.weight, patch_embd_1.weight
+ if strings.Contains(name, "patch_embd.0.") {
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.0.", "patch_embd_0.", 1),
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: t,
+ })
+ continue
+ }
+ if strings.Contains(name, "patch_embd.1.") {
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.1.", "patch_embd_1.", 1),
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: t,
+ })
+ continue
+ }
+ // Handle .weight.0 and .weight.1 suffix patterns
+ if strings.HasSuffix(name, "patch_embd.weight.0") {
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.weight.0", "patch_embd_0.weight", 1),
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: t,
+ })
+ continue
+ }
+ if strings.HasSuffix(name, "patch_embd.weight.1") {
+ out = append(out, &ggml.Tensor{
+ Name: strings.Replace(name, "patch_embd.weight.1", "patch_embd_1.weight", 1),
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: t,
+ })
+ continue
+ }
+
+ // Permute Q/K weights for M-RoPE compatibility (interleaved -> NeoX ordering)
+ // GGML's M-RoPE kernel uses NeoX-style rotation, but GLM-OCR uses interleaved (LLaMA-style)
+ // We permute at conversion time so the weights work correctly with GGML's kernel
+ // This aligns Q/K rotary dimensions with GGML's NeoX-style rotation
+ if len(m.TextConfig.RopeParameters.MRopeSection) > 0 &&
+ strings.Contains(name, "blk.") && (strings.Contains(name, "attn_q.") || strings.Contains(name, "attn_k.")) {
+ // Get config values for permutation
+ nHeads := int(cmp.Or(m.TextConfig.NumAttentionHeads, 16))
+ nKVHeads := int(cmp.Or(m.TextConfig.NumKeyValueHeads, 8))
+ hiddenSize := int(cmp.Or(m.TextConfig.HiddenSize, 1536))
+ headDim := int(cmp.Or(m.TextConfig.HeadDim, uint32(hiddenSize/nHeads)))
+ partialRotaryFactor := cmp.Or(m.TextConfig.PartialRotaryFactor, m.TextConfig.RopeParameters.PartialRotaryFactor, float32(1.0))
+
+ // Use appropriate head count: nHeads for Q, nKVHeads for K
+ effectiveHeads := nHeads
+ if strings.Contains(name, "attn_k.") {
+ effectiveHeads = nKVHeads
+ }
+
+ permutedT := t.Clone()
+ permutedT.SetRepacker(normalToNeoXRepacker(effectiveHeads, headDim, partialRotaryFactor))
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: permutedT,
+ })
+ continue
+ }
+
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: t.Shape(),
+ WriterTo: t,
+ })
+ }
+
+ return out
+}
+
+func (m *glmOcrModel) Replacements() []string {
+ return []string{
+ // Vision encoder
+ "model.visual.patch_embed.proj_1", "v.patch_embd_1", // Second temporal split
+ "model.visual.patch_embed.proj", "v.patch_embd",
+ "model.visual.blocks", "v.blk",
+ "model.visual.post_layernorm", "v.post_ln",
+ "model.visual.downsample", "mm.patch_merger",
+
+ // Vision attention
+ "attn.qkv", "attn_qkv",
+ "attn.proj", "attn_out",
+ "attn.q_norm", "attn_q_norm",
+ "attn.k_norm", "attn_k_norm",
+
+ // Vision norms
+ "norm1", "ln1",
+ "norm2", "ln2",
+
+ // Vision MLP
+ "mlp.gate_proj", "ffn_gate",
+ "mlp.up_proj", "ffn_up",
+ "mlp.down_proj", "ffn_down",
+
+ // Merger (multimodal projector)
+ "model.visual.merger.proj", "mm.model.fc",
+ "model.visual.merger.post_projection_norm", "mm.post_norm",
+ "model.visual.merger.gate_proj", "mm.gate",
+ "model.visual.merger.up_proj", "mm.up",
+ "model.visual.merger.down_proj", "mm.down",
+
+ // Language model
+ "model.language_model.embed_tokens", "token_embd",
+ "model.language_model.layers", "blk",
+ "model.language_model.norm", "output_norm",
+ "lm_head", "output",
+
+ // Language model attention
+ "self_attn.q_proj", "attn_q",
+ "self_attn.k_proj", "attn_k",
+ "self_attn.v_proj", "attn_v",
+ "self_attn.o_proj", "attn_out",
+
+ // Language model norms
+ "input_layernorm", "attn_norm",
+ "post_attention_layernorm", "ffn_norm",
+ "post_self_attn_layernorm", "post_attn_norm",
+ "post_mlp_layernorm", "post_ffn_norm",
+
+ // Language model MLP (remove mlp. prefix so ffn_* names work)
+ "mlp.gate_up_proj", "ffn_gate_up",
+ "mlp.down_proj", "ffn_down",
+ }
+}
diff --git a/convert/convert_lfm2.go b/convert/convert_lfm2.go
new file mode 100644
index 00000000000..fdae1074c34
--- /dev/null
+++ b/convert/convert_lfm2.go
@@ -0,0 +1,100 @@
+package convert
+
+import (
+ "slices"
+ "strings"
+
+ "github.com/ollama/ollama/fs/ggml"
+)
+
+type lfm2Model struct {
+ ModelParameters
+ HiddenSize uint32 `json:"hidden_size"`
+ NumHiddenLayers uint32 `json:"num_hidden_layers"`
+ MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
+ IntermediateSize uint32 `json:"intermediate_size"`
+ NumAttentionHeads uint32 `json:"num_attention_heads"`
+ NumKeyValueHeads uint32 `json:"num_key_value_heads"`
+ RopeTheta float32 `json:"rope_theta"`
+ NormEps float32 `json:"norm_eps"`
+ ConvLCache uint32 `json:"conv_L_cache"`
+ LayerTypes []string `json:"layer_types"`
+ TieEmbedding bool `json:"tie_embedding"`
+}
+
+var _ ModelConverter = (*lfm2Model)(nil)
+
+func (p *lfm2Model) KV(t *Tokenizer) KV {
+ kv := p.ModelParameters.KV(t)
+ kv["general.architecture"] = "lfm2"
+ kv["lfm2.vocab_size"] = p.VocabSize
+ kv["lfm2.block_count"] = p.NumHiddenLayers
+ kv["lfm2.embedding_length"] = p.HiddenSize
+ kv["lfm2.feed_forward_length"] = p.IntermediateSize
+ kv["lfm2.context_length"] = p.MaxPositionEmbeddings
+
+ // Build per-layer KV head count array based on layer_types
+ // (0 = shortconv layer, non-zero = attention layer with that many KV heads)
+ kvHeadCounts := make([]uint32, p.NumHiddenLayers)
+ for i := range p.NumHiddenLayers {
+ if int(i) < len(p.LayerTypes) && p.LayerTypes[i] == "full_attention" {
+ kvHeadCounts[i] = p.NumKeyValueHeads
+ }
+ }
+
+ kv["lfm2.attention.head_count"] = p.NumAttentionHeads
+ kv["lfm2.attention.head_count_kv"] = kvHeadCounts
+ kv["lfm2.attention.key_length"] = p.HiddenSize / p.NumAttentionHeads
+ kv["lfm2.attention.value_length"] = p.HiddenSize / p.NumAttentionHeads
+ kv["lfm2.attention.layer_norm_rms_epsilon"] = p.NormEps
+ kv["lfm2.rope.freq_base"] = p.RopeTheta
+ kv["lfm2.shortconv.l_cache"] = p.ConvLCache
+
+ return kv
+}
+
+func (p *lfm2Model) Tensors(ts []Tensor) []*ggml.Tensor {
+ var out []*ggml.Tensor
+
+ for _, t := range ts {
+ shape := t.Shape()
+
+ // Squeeze conv weights: [D, 1, K] -> [D, K]
+ if strings.HasSuffix(t.Name(), "shortconv.conv.weight") {
+ if len(shape) == 3 && shape[1] == 1 {
+ shape = []uint64{shape[0], shape[2]}
+ }
+ }
+
+ out = append(out, &ggml.Tensor{
+ Name: t.Name(),
+ Kind: t.Kind(),
+ Shape: slices.Clone(shape),
+ WriterTo: t,
+ })
+ }
+
+ return out
+}
+
+func (p *lfm2Model) Replacements() []string {
+ return []string{
+ "model.embed_tokens", "token_embd",
+ "model.embedding_norm", "output_norm",
+ "model.layers", "blk",
+ "operator_norm", "attn_norm",
+ "self_attn.q_proj", "attn_q",
+ "self_attn.k_proj", "attn_k",
+ "self_attn.v_proj", "attn_v",
+ "self_attn.out_proj", "attn_output",
+ "self_attn.q_layernorm", "attn_q_norm",
+ "self_attn.k_layernorm", "attn_k_norm",
+ "conv.conv", "shortconv.conv",
+ "conv.in_proj", "shortconv.in_proj",
+ "conv.out_proj", "shortconv.out_proj",
+ "feed_forward.w1", "ffn_gate",
+ "feed_forward.w2", "ffn_down",
+ "feed_forward.w3", "ffn_up",
+ "ffn_norm", "ffn_norm",
+ }
+}
diff --git a/convert/convert_qwen3next.go b/convert/convert_qwen3next.go
new file mode 100644
index 00000000000..84db35e14df
--- /dev/null
+++ b/convert/convert_qwen3next.go
@@ -0,0 +1,512 @@
+package convert
+
+import (
+ "fmt"
+ "io/fs"
+ "math"
+ "slices"
+ "strings"
+
+ "github.com/pdevine/tensor"
+ "github.com/pdevine/tensor/native"
+
+ "github.com/ollama/ollama/fs/ggml"
+)
+
+type qwen3NextModel struct {
+ ModelParameters
+ MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
+ HiddenSize uint32 `json:"hidden_size"`
+ NumHiddenLayers uint32 `json:"num_hidden_layers"`
+ IntermediateSize uint32 `json:"intermediate_size"`
+ NumAttentionHeads uint32 `json:"num_attention_heads"`
+ NumKeyValueHeads uint32 `json:"num_key_value_heads"`
+ HeadDim uint32 `json:"head_dim"`
+ RopeTheta float32 `json:"rope_theta"`
+ RMSNormEPS float32 `json:"rms_norm_eps"`
+
+ // MoE config
+ NumExperts uint32 `json:"num_experts"`
+ NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
+ NormTopkProb bool `json:"norm_topk_prob"`
+ MoEIntermediateSize uint32 `json:"moe_intermediate_size"`
+ SharedExpertIntermSize uint32 `json:"shared_expert_intermediate_size"`
+
+ // Hybrid attention config
+ FullAttentionInterval uint32 `json:"full_attention_interval"`
+
+ // Linear attention (Gated Delta Net) config
+ LinearConvKernelDim uint32 `json:"linear_conv_kernel_dim"`
+ LinearKeyHeadDim uint32 `json:"linear_key_head_dim"`
+ LinearNumKeyHeads uint32 `json:"linear_num_key_heads"`
+ LinearNumValueHeads uint32 `json:"linear_num_value_heads"`
+ LinearValueHeadDim uint32 `json:"linear_value_head_dim"`
+
+ // RoPE config
+ PartialRotaryFactor float32 `json:"partial_rotary_factor"`
+ RopeScaling struct {
+ Type string `json:"type"`
+ Factor ropeFactor `json:"factor"`
+ } `json:"rope_scaling"`
+}
+
+var _ ModelConverter = (*qwen3NextModel)(nil)
+
+func (q *qwen3NextModel) parseMore(_ fs.FS) error {
+ if q.NumHiddenLayers == 0 {
+ return fmt.Errorf("qwen3next: num_hidden_layers must be set")
+ }
+ if q.NumAttentionHeads == 0 {
+ return fmt.Errorf("qwen3next: num_attention_heads must be set")
+ }
+ if q.NumKeyValueHeads == 0 {
+ return fmt.Errorf("qwen3next: num_key_value_heads must be set")
+ }
+ if q.HeadDim == 0 {
+ return fmt.Errorf("qwen3next: head_dim must be set")
+ }
+ if q.RopeTheta == 0 {
+ return fmt.Errorf("qwen3next: rope_theta must be set")
+ }
+ if q.PartialRotaryFactor <= 0 || q.PartialRotaryFactor > 1 {
+ return fmt.Errorf("qwen3next: partial_rotary_factor must be in (0,1], got %v", q.PartialRotaryFactor)
+ }
+ if q.LinearNumKeyHeads == 0 || q.LinearNumValueHeads == 0 || q.LinearKeyHeadDim == 0 || q.LinearValueHeadDim == 0 {
+ return fmt.Errorf("qwen3next: linear attention config must be set (linear_num_key_heads, linear_num_value_heads, linear_key_head_dim, linear_value_head_dim)")
+ }
+ if q.FullAttentionInterval == 0 {
+ return fmt.Errorf("qwen3next: full_attention_interval must be set")
+ }
+ if q.FullAttentionInterval > q.NumHiddenLayers {
+ return fmt.Errorf("qwen3next: full_attention_interval (%d) exceeds num_hidden_layers (%d)", q.FullAttentionInterval, q.NumHiddenLayers)
+ }
+
+ hasFull := false
+ for i := range q.NumHiddenLayers {
+ if (i+1)%q.FullAttentionInterval == 0 {
+ hasFull = true
+ break
+ }
+ }
+ if !hasFull {
+ return fmt.Errorf("qwen3next: head_count_kv would be all zeros (full_attention_interval=%d, num_hidden_layers=%d)", q.FullAttentionInterval, q.NumHiddenLayers)
+ }
+
+ return nil
+}
+
+func (q *qwen3NextModel) KV(t *Tokenizer) KV {
+ kv := q.ModelParameters.KV(t)
+ kv["general.architecture"] = "qwen3next"
+ kv["tokenizer.ggml.pre"] = "qwen2"
+ kv["block_count"] = q.NumHiddenLayers
+ kv["context_length"] = q.MaxPositionEmbeddings
+ kv["embedding_length"] = q.HiddenSize
+ kv["feed_forward_length"] = q.IntermediateSize
+ kv["attention.head_count"] = q.NumAttentionHeads
+ headDim := q.HeadDim
+ if headDim == 0 && q.NumAttentionHeads > 0 {
+ headDim = q.HiddenSize / q.NumAttentionHeads
+ }
+ kv["attention.key_length"] = headDim
+ kv["attention.value_length"] = headDim
+ kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
+ kv["rope.freq_base"] = q.RopeTheta
+
+ // RoPE dimension count (partial rotary)
+ // partial_rotary_factor = 0.25 means only 25% of head_dim uses RoPE
+ partialRotary := q.PartialRotaryFactor
+ if partialRotary > 0 && partialRotary <= 1 {
+ kv["rope.dimension_count"] = uint32(float32(headDim) * partialRotary)
+ }
+
+ // MoE config
+ if q.NumExperts > 0 {
+ kv["expert_count"] = q.NumExperts
+ kv["expert_used_count"] = q.NumExpertsPerToken
+ kv["norm_top_k_prob"] = q.NormTopkProb
+ if q.MoEIntermediateSize > 0 {
+ kv["expert_feed_forward_length"] = q.MoEIntermediateSize
+ }
+ if q.SharedExpertIntermSize > 0 {
+ kv["expert_shared_feed_forward_length"] = q.SharedExpertIntermSize
+ }
+ }
+
+ // SSM/Linear attention config
+ // d_inner = linear_value_head_dim * linear_num_value_heads
+ dInner := q.LinearValueHeadDim * q.LinearNumValueHeads
+ kv["ssm.inner_size"] = dInner
+ kv["ssm.state_size"] = q.LinearKeyHeadDim // head_k_dim
+ kv["ssm.group_count"] = q.LinearNumKeyHeads // num_k_heads
+ kv["ssm.time_step_rank"] = q.LinearNumValueHeads // num_v_heads
+ kv["ssm.conv_kernel"] = q.LinearConvKernelDim
+ interval := q.FullAttentionInterval
+ kv["full_attention_interval"] = interval
+
+ // Build per-layer KV head count array to identify layer types
+ // 0 = recurrent (linear attention), non-zero = full attention
+ kvHeadCounts := make([]uint32, q.NumHiddenLayers)
+ for i := range q.NumHiddenLayers {
+ // Full attention every full_attention_interval layers (starting at interval-1)
+ if interval > 0 && (i+1)%interval == 0 {
+ kvHeadCounts[i] = q.NumKeyValueHeads
+ }
+ // else stays 0 (recurrent layer)
+ }
+ kv["attention.head_count_kv"] = kvHeadCounts
+
+ // RoPE scaling
+ if q.RopeScaling.Type != "" {
+ kv["rope.scaling.type"] = q.RopeScaling.Type
+ kv["rope.scaling.factor"] = q.RopeScaling.Factor
+ }
+
+ return kv
+}
+
+func (q *qwen3NextModel) Tensors(ts []Tensor) []*ggml.Tensor {
+ var out []*ggml.Tensor
+
+ // Create merges for expert tensors - stack individual experts into batched tensors
+ merges := make([]merge, q.NumHiddenLayers*3)
+ for i := range q.NumHiddenLayers {
+ merges[i*3+0] = merge{
+ fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
+ fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+ }
+ merges[i*3+1] = merge{
+ fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
+ fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+ }
+ merges[i*3+2] = merge{
+ fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
+ fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+ }
+ }
+
+ // Merge expert tensors
+ merged, remaining := mergeTensors(ts, merges...)
+ out = append(out, merged...)
+
+ // Process remaining tensors
+ for _, t := range remaining {
+ name := t.Name()
+ shape := t.Shape()
+
+ // Split linear_attn.in_proj_qkvz (ssm_in) into attn_qkv + attn_gate when possible
+ if strings.HasSuffix(name, ".ssm_in.weight") {
+ if qkv, gate, ok := q.splitQKVZTensor(t); ok {
+ out = append(out, qkv, gate)
+ continue
+ }
+ panic(fmt.Sprintf("qwen3next: failed to split %s into attn_qkv/attn_gate (shape=%v)", name, shape))
+ }
+
+ switch {
+ // Add 1 to norm weights (except ssm_norm which is linear_attn.norm)
+ // This matches the Python converter behavior for qwen3next
+ case strings.HasSuffix(name, "_norm.weight") && !strings.HasSuffix(name, ".ssm_norm.weight"):
+ t.SetRepacker(q.addOne)
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: slices.Clone(shape),
+ WriterTo: t,
+ })
+
+ // Handle linear attention A_log -> ssm_a (negate and exp)
+ // Note: name has already been transformed by Replacements at this point
+ case strings.HasSuffix(name, ".ssm_a"):
+ t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
+ // Compute -exp(A_log)
+ result := make([]float32, len(data))
+ for i, v := range data {
+ // -exp(v)
+ result[i] = -float32(math.Exp(float64(v)))
+ }
+ return result, nil
+ })
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: slices.Clone(shape),
+ WriterTo: t,
+ })
+
+ // Squeeze conv1d weights: [1, D, K] or [D, 1, K] -> [D, K]
+ case strings.HasSuffix(name, ".ssm_conv1d.weight"):
+ newShape := slices.Clone(shape)
+ if len(shape) == 3 {
+ if shape[0] == 1 {
+ // [1, D, K] -> [D, K]
+ newShape = []uint64{shape[1], shape[2]}
+ } else if shape[1] == 1 {
+ // [D, 1, K] -> [D, K]
+ newShape = []uint64{shape[0], shape[2]}
+ }
+ }
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: newShape,
+ WriterTo: t,
+ })
+ // Squeeze shared expert gate: [D, 1] or [1, D] -> [D]
+ case strings.HasSuffix(name, ".ffn_gate_inp_shexp.weight"):
+ newShape := slices.Clone(shape)
+ if len(shape) == 2 {
+ if shape[0] == 1 && shape[1] > 1 {
+ newShape = []uint64{shape[1]}
+ } else if shape[1] == 1 && shape[0] > 1 {
+ newShape = []uint64{shape[0]}
+ }
+ }
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: newShape,
+ WriterTo: t,
+ })
+
+ default:
+ out = append(out, &ggml.Tensor{
+ Name: name,
+ Kind: t.Kind(),
+ Shape: slices.Clone(shape),
+ WriterTo: t,
+ })
+ }
+ }
+
+ return out
+}
+
+type qkvzSplitSpec struct {
+ hidden int
+ headKDim int
+ headVDim int
+ numKHeads int
+ numVHeads int
+ qkvzDim int
+ qkvOut int
+ gateOut int
+}
+
+func (q *qwen3NextModel) qkvzSpec(shape []uint64) (qkvzSplitSpec, bool) {
+ if len(shape) != 2 {
+ return qkvzSplitSpec{}, false
+ }
+
+ numKHeads := int(q.LinearNumKeyHeads)
+ numVHeads := int(q.LinearNumValueHeads)
+ headKDim := int(q.LinearKeyHeadDim)
+ headVDim := int(q.LinearValueHeadDim)
+ if numKHeads == 0 || numVHeads == 0 || headKDim == 0 || headVDim == 0 {
+ return qkvzSplitSpec{}, false
+ }
+ if numVHeads%numKHeads != 0 {
+ return qkvzSplitSpec{}, false
+ }
+
+ hidden := int(shape[1])
+ vPerHead := headVDim * (numVHeads / numKHeads)
+ qkvzDim := 2*headKDim + 2*vPerHead
+ expectedOut := qkvzDim * numKHeads
+ if int(shape[0]) != expectedOut {
+ return qkvzSplitSpec{}, false
+ }
+
+ return qkvzSplitSpec{
+ hidden: hidden,
+ headKDim: headKDim,
+ headVDim: headVDim,
+ numKHeads: numKHeads,
+ numVHeads: numVHeads,
+ qkvzDim: qkvzDim,
+ qkvOut: 2*headKDim*numKHeads + headVDim*numVHeads,
+ gateOut: headVDim * numVHeads,
+ }, true
+}
+
+func (q *qwen3NextModel) splitQKVZTensor(t Tensor) (*ggml.Tensor, *ggml.Tensor, bool) {
+ spec, ok := q.qkvzSpec(t.Shape())
+ if !ok {
+ return nil, nil, false
+ }
+
+ qkvTensor := t.Clone()
+ qkvTensor.SetRepacker(q.repackQKVZ(spec, false))
+
+ gateTensor := t.Clone()
+ gateTensor.SetRepacker(q.repackQKVZ(spec, true))
+
+ qkvName := strings.Replace(t.Name(), "ssm_in", "attn_qkv", 1)
+ gateName := strings.Replace(t.Name(), "ssm_in", "attn_gate", 1)
+
+ return &ggml.Tensor{
+ Name: qkvName,
+ Kind: t.Kind(),
+ Shape: []uint64{uint64(spec.qkvOut), uint64(spec.hidden)},
+ WriterTo: qkvTensor,
+ }, &ggml.Tensor{
+ Name: gateName,
+ Kind: t.Kind(),
+ Shape: []uint64{uint64(spec.gateOut), uint64(spec.hidden)},
+ WriterTo: gateTensor,
+ }, true
+}
+
+func (q *qwen3NextModel) repackQKVZ(spec qkvzSplitSpec, extractGate bool) Repacker {
+ vPerHead := spec.headVDim * (spec.numVHeads / spec.numKHeads)
+
+ return func(_ string, data []float32, shape []uint64) ([]float32, error) {
+ dims := make([]int, len(shape))
+ for i := range shape {
+ dims[i] = int(shape[i])
+ }
+
+ var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+ var err error
+
+ // Convert to [hidden, out_features] layout for slicing
+ tt, err = tensor.Transpose(tt, 1, 0)
+ if err != nil {
+ return nil, err
+ }
+ tt = tensor.Materialize(tt)
+
+ if err := tt.Reshape(spec.hidden, spec.numKHeads, spec.qkvzDim); err != nil {
+ return nil, err
+ }
+
+ offset := 0
+ qSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
+ if err != nil {
+ return nil, err
+ }
+ offset += spec.headKDim
+ kSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+spec.headKDim))
+ if err != nil {
+ return nil, err
+ }
+ offset += spec.headKDim
+ vSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
+ if err != nil {
+ return nil, err
+ }
+ offset += vPerHead
+ zSlice, err := tt.Slice(nil, nil, tensor.S(offset, offset+vPerHead))
+ if err != nil {
+ return nil, err
+ }
+
+ qMat := tensor.Materialize(qSlice).(*tensor.Dense)
+ kMat := tensor.Materialize(kSlice).(*tensor.Dense)
+ vMat := tensor.Materialize(vSlice).(*tensor.Dense)
+ zMat := tensor.Materialize(zSlice).(*tensor.Dense)
+
+ if err := qMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
+ return nil, err
+ }
+ if err := kMat.Reshape(spec.hidden, spec.numKHeads*spec.headKDim); err != nil {
+ return nil, err
+ }
+ if err := vMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
+ return nil, err
+ }
+ if err := zMat.Reshape(spec.hidden, spec.numKHeads*vPerHead); err != nil {
+ return nil, err
+ }
+
+ var out tensor.Tensor
+ if extractGate {
+ out = zMat
+ } else {
+ out, err = tensor.Concat(1, qMat, kMat, vMat)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ out = tensor.Materialize(out)
+ out, err = tensor.Transpose(out, 1, 0)
+ if err != nil {
+ return nil, err
+ }
+ out = tensor.Materialize(out)
+
+ if err := out.Reshape(out.Shape().TotalSize()); err != nil {
+ return nil, err
+ }
+
+ return native.VectorF32(out.(*tensor.Dense))
+ }
+}
+
+// addOne adds 1.0 to all elements in the tensor (for norm weights)
+func (*qwen3NextModel) addOne(_ string, data []float32, shape []uint64) ([]float32, error) {
+ n := tensor.New(tensor.WithShape(int(shape[0])), tensor.WithBacking(data))
+ ones := tensor.Ones(tensor.Float32, int(shape[0]))
+
+ n, err := n.Add(ones)
+ if err != nil {
+ return nil, err
+ }
+
+ ts, err := native.SelectF32(n, 0)
+ if err != nil {
+ return nil, err
+ }
+
+ var f32s []float32
+ for _, t := range ts {
+ f32s = append(f32s, t...)
+ }
+
+ return f32s, nil
+}
+
+func (q *qwen3NextModel) Replacements() []string {
+ return []string{
+ // Embeddings and output
+ "lm_head", "output",
+ "model.embed_tokens", "token_embd",
+ "model.norm", "output_norm",
+ "model.layers", "blk",
+
+ // Layer norms
+ "input_layernorm", "attn_norm",
+ "post_attention_layernorm", "post_attention_norm",
+
+ // Full attention (self_attn)
+ "self_attn.q_proj", "attn_q",
+ "self_attn.q_norm", "attn_q_norm",
+ "self_attn.k_proj", "attn_k",
+ "self_attn.k_norm", "attn_k_norm",
+ "self_attn.v_proj", "attn_v",
+ "self_attn.o_proj", "attn_output",
+
+ // Linear attention (Gated Delta Net)
+ "linear_attn.in_proj_qkvz", "ssm_in",
+ "linear_attn.in_proj_ba", "ssm_ba",
+ "linear_attn.conv1d", "ssm_conv1d",
+ "linear_attn.dt_bias", "ssm_dt",
+ "linear_attn.dt_proj", "ssm_dt",
+ "linear_attn.A_log", "ssm_a",
+ "linear_attn.norm", "ssm_norm",
+ "linear_attn.out_proj", "ssm_out",
+
+ // MoE (experts are stacked via mergeTensors, not replaced here)
+ "mlp.gate.weight", "ffn_gate_inp.weight",
+ "mlp.shared_expert.down_proj", "ffn_down_shexp",
+ "mlp.shared_expert.gate_proj", "ffn_gate_shexp",
+ "mlp.shared_expert.up_proj", "ffn_up_shexp",
+ "mlp.shared_expert_gate", "ffn_gate_inp_shexp",
+
+ // Dense FFN (if any layers use it)
+ "mlp.down_proj", "ffn_down",
+ "mlp.gate_proj", "ffn_gate",
+ "mlp.up_proj", "ffn_up",
+ }
+}
diff --git a/convert/reader.go b/convert/reader.go
index 75764f018b9..0cff12a2296 100644
--- a/convert/reader.go
+++ b/convert/reader.go
@@ -40,6 +40,8 @@ const (
func (t tensorBase) Kind() uint32 {
if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
strings.HasSuffix(t.name, ".bias") ||
+ strings.HasSuffix(t.name, ".shortconv.conv.weight") ||
+ strings.HasSuffix(t.name, ".ssm_conv1d.weight") || // SSM conv kernel must be F32 for Metal
t.name == "token_types.weight" ||
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
diff --git a/convert/reader_safetensors.go b/convert/reader_safetensors.go
index f7d9754f03c..f7dae0646b7 100644
--- a/convert/reader_safetensors.go
+++ b/convert/reader_safetensors.go
@@ -99,6 +99,8 @@ func (st safetensor) Kind() uint32 {
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
+ !strings.HasPrefix(st.name, "mm.") &&
+ !strings.Contains(st.name, "ffn_gate_inp_shexp.weight") &&
kind != tensorKindFP32 {
kind = tensorKindBF16
}
diff --git a/docs/api/anthropic-compatibility.mdx b/docs/api/anthropic-compatibility.mdx
index f12a0beb121..81ec04d4764 100644
--- a/docs/api/anthropic-compatibility.mdx
+++ b/docs/api/anthropic-compatibility.mdx
@@ -4,16 +4,6 @@ title: Anthropic compatibility
Ollama provides compatibility with the [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) to help connect existing applications to Ollama, including tools like Claude Code.
-## Recommended models
-
-For coding use cases, models like `glm-4.7:cloud`, `minimax-m2.1:cloud`, and `qwen3-coder` are recommended.
-
-Pull a model before use:
-```shell
-ollama pull qwen3-coder
-ollama pull glm-4.7:cloud
-```
-
## Usage
### Environment variables
@@ -22,8 +12,8 @@ To use Ollama with tools that expect the Anthropic API (like Claude Code), set t
```shell
export ANTHROPIC_AUTH_TOKEN=ollama # required but ignored
+export ANTHROPIC_API_KEY="" # required but ignored
export ANTHROPIC_BASE_URL=http://localhost:11434
-export ANTHROPIC_API_KEY=ollama # required but ignored
```
### Simple `/v1/messages` example
@@ -245,10 +235,41 @@ curl -X POST http://localhost:11434/v1/messages \
## Using with Claude Code
-[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend:
+[Claude Code](https://code.claude.com/docs/en/overview) can be configured to use Ollama as its backend.
+
+### Recommended models
+
+For coding use cases, models like `glm-4.7`, `minimax-m2.1`, and `qwen3-coder` are recommended.
+
+Download a model before use:
```shell
-ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY=ollama claude --model qwen3-coder
+ollama pull qwen3-coder
+```
+> Note: Qwen 3 coder is a 30B parameter model requiring at least 24GB of VRAM to run smoothly. More is required for longer context lengths.
+
+```shell
+ollama pull glm-4.7:cloud
+```
+
+### Quick setup
+
+```shell
+ollama launch claude
+```
+
+This will prompt you to select a model, configure Claude Code automatically, and launch it. To configure without launching:
+
+```shell
+ollama launch claude --config
+```
+
+### Manual setup
+
+Set the environment variables and run Claude Code:
+
+```shell
+ANTHROPIC_AUTH_TOKEN=ollama ANTHROPIC_BASE_URL=http://localhost:11434 ANTHROPIC_API_KEY="" claude --model qwen3-coder
```
Or set the environment variables in your shell profile:
@@ -256,19 +277,13 @@ Or set the environment variables in your shell profile:
```shell
export ANTHROPIC_AUTH_TOKEN=ollama
export ANTHROPIC_BASE_URL=http://localhost:11434
-export ANTHROPIC_API_KEY=ollama
+export ANTHROPIC_API_KEY=""
```
Then run Claude Code with any Ollama model:
```shell
-# Local models
claude --model qwen3-coder
-claude --model gpt-oss:20b
-
-# Cloud models
-claude --model glm-4.7:cloud
-claude --model minimax-m2.1:cloud
```
## Endpoints
diff --git a/docs/cli.mdx b/docs/cli.mdx
index 97810e64a70..ecceee41d1d 100644
--- a/docs/cli.mdx
+++ b/docs/cli.mdx
@@ -8,6 +8,47 @@ title: CLI Reference
ollama run gemma3
```
+### Launch integrations
+
+```
+ollama launch
+```
+
+Configure and launch external applications to use Ollama models. This provides an interactive way to set up and start integrations with supported apps.
+
+#### Supported integrations
+
+- **OpenCode** - Open-source coding assistant
+- **Claude Code** - Anthropic's agentic coding tool
+- **Codex** - OpenAI's coding assistant
+- **Droid** - Factory's AI coding agent
+
+#### Examples
+
+Launch an integration interactively:
+
+```
+ollama launch
+```
+
+Launch a specific integration:
+
+```
+ollama launch claude
+```
+
+Launch with a specific model:
+
+```
+ollama launch claude --model qwen3-coder
+```
+
+Configure without launching:
+
+```
+ollama launch droid --config
+```
+
#### Multiline input
For multiline input, you can wrap text with `"""`:
diff --git a/docs/cloud.mdx b/docs/cloud.mdx
index 4f4c3722b9b..a1dacd20476 100644
--- a/docs/cloud.mdx
+++ b/docs/cloud.mdx
@@ -3,8 +3,6 @@ title: Cloud
sidebarTitle: Cloud
---
-Ollama's cloud is currently in preview.
-
## Cloud Models
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
@@ -228,3 +226,7 @@ curl https://ollama.com/api/chat \
+
+## Local only
+
+Ollama can run in local-only mode by [disabling Ollama's cloud](./faq#how-do-i-disable-ollama-cloud) features.
\ No newline at end of file
diff --git a/docs/context-length.mdx b/docs/context-length.mdx
index 43bcf0d3178..06ae21a39de 100644
--- a/docs/context-length.mdx
+++ b/docs/context-length.mdx
@@ -5,10 +5,13 @@ title: Context length
Context length is the maximum number of tokens that the model has access to in memory.
- The default context length in Ollama is 4096 tokens.
+ Ollama defaults to the following context lengths based on VRAM:
+ - < 24 GiB VRAM: 4k context
+ - 24-48 GiB VRAM: 32k context
+ - >= 48 GiB VRAM: 256k context
-Tasks which require large context like web search, agents, and coding tools should be set to at least 32000 tokens.
+Tasks which require large context like web search, agents, and coding tools should be set to at least 64000 tokens.
## Setting context length
@@ -24,7 +27,7 @@ Change the slider in the Ollama app under settings to your desired context lengt
### CLI
If editing the context length for Ollama is not possible, the context length can also be updated when serving Ollama.
```
-OLLAMA_CONTEXT_LENGTH=32000 ollama serve
+OLLAMA_CONTEXT_LENGTH=64000 ollama serve
```
### Check allocated context length and model offloading
diff --git a/docs/docs.json b/docs/docs.json
index 921c9e34e06..3f8b5c1a891 100644
--- a/docs/docs.json
+++ b/docs/docs.json
@@ -71,6 +71,10 @@
{
"source": "/api",
"destination": "/api/introduction"
+ },
+ {
+ "source": "/integrations/clawdbot",
+ "destination": "/integrations/openclaw"
}
],
"navigation": {
@@ -101,19 +105,55 @@
{
"group": "Integrations",
"pages": [
- "/integrations/claude-code",
- "/integrations/vscode",
- "/integrations/jetbrains",
- "/integrations/codex",
- "/integrations/cline",
- "/integrations/droid",
- "/integrations/goose",
- "/integrations/zed",
- "/integrations/roo-code",
- "/integrations/n8n",
- "/integrations/xcode",
- "/integrations/onyx",
- "/integrations/marimo"
+ "/integrations/index",
+ {
+ "group": "Assistants",
+ "expanded": true,
+ "pages": [
+ "/integrations/openclaw"
+ ]
+ },
+ {
+ "group": "Coding",
+ "expanded": true,
+ "pages": [
+ "/integrations/claude-code",
+ "/integrations/codex",
+ "/integrations/opencode",
+ "/integrations/droid",
+ "/integrations/goose",
+ "/integrations/pi"
+ ]
+ },
+ {
+ "group": "IDEs & Editors",
+ "pages": [
+ "/integrations/cline",
+ "/integrations/jetbrains",
+ "/integrations/roo-code",
+ "/integrations/vscode",
+ "/integrations/xcode",
+ "/integrations/zed"
+ ]
+ },
+ {
+ "group": "Chat & RAG",
+ "pages": [
+ "/integrations/onyx"
+ ]
+ },
+ {
+ "group": "Automation",
+ "pages": [
+ "/integrations/n8n"
+ ]
+ },
+ {
+ "group": "Notebooks",
+ "pages": [
+ "/integrations/marimo"
+ ]
+ }
]
},
{
diff --git a/docs/faq.mdx b/docs/faq.mdx
index a751c0c3ff4..47f06529d5f 100644
--- a/docs/faq.mdx
+++ b/docs/faq.mdx
@@ -14,11 +14,11 @@ curl -fsSL https://ollama.com/install.sh | sh
## How can I view the logs?
-Review the [Troubleshooting](./troubleshooting) docs for more about using logs.
+Review the [Troubleshooting](./troubleshooting.mdx) docs for more about using logs.
## Is my GPU compatible with Ollama?
-Please refer to the [GPU docs](./gpu).
+Please refer to the [GPU docs](./gpu.mdx).
## How can I specify the context window size?
@@ -66,7 +66,7 @@ llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
-The `Processor` column will show which memory the model was loaded in to:
+The `Processor` column will show which memory the model was loaded into:
- `100% GPU` means the model was loaded entirely into the GPU
- `100% CPU` means the model was loaded entirely in system memory
@@ -158,7 +158,27 @@ docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-
## Does Ollama send my prompts and answers back to ollama.com?
-No. Ollama runs locally, and conversation data does not leave your machine.
+Ollama runs locally. We don't see your prompts or data when you run locally. When using cloud-hosted models, we process your prompts and responses to provide the service but do not store or log that content and never train on it. We collect basic account info and limited usage metadata to provide the service that does not include prompt or response content. We don't sell your data. You can delete your account anytime.
+
+## How do I disable Ollama's cloud features?
+
+Ollama can run in local only mode by disabling Ollama's cloud features. By turning off Ollama's cloud features, you will lose the ability to use Ollama's cloud models and web search.
+
+Set `disable_ollama_cloud` in `~/.ollama/server.json`:
+
+```json
+{
+ "disable_ollama_cloud": true
+}
+```
+
+You can also set the environment variable:
+
+```shell
+OLLAMA_NO_CLOUD=1
+```
+
+Restart Ollama after changing configuration. Once disabled, Ollama's logs will show `Ollama cloud disabled: true`.
## How can I expose Ollama on my network?
@@ -183,7 +203,7 @@ server {
## How can I use Ollama with ngrok?
-Ollama can be accessed using a range of tools for tunneling tools. For example with Ngrok:
+Ollama can be accessed using a range of tunneling apps. For example with Ngrok:
```shell
ngrok http 11434 --host-header="localhost:11434"
@@ -240,7 +260,7 @@ GPU acceleration is not available for Docker Desktop in macOS due to the lack of
This can impact both installing Ollama, as well as downloading models.
-Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernel (WSL)` adapter, right click and select `Properties`.
+Open `Control Panel > Networking and Internet > View network status and tasks` and click on `Change adapter settings` on the left panel. Find the `vEthernet (WSL)` adapter, right click and select `Properties`.
Click on `Configure` and open the `Advanced` tab. Search through each of the properties until you find `Large Send Offload Version 2 (IPv4)` and `Large Send Offload Version 2 (IPv6)`. _Disable_ both of these
properties.
@@ -299,7 +319,7 @@ The `keep_alive` API parameter with the `/api/generate` and `/api/chat` API endp
## How do I manage the maximum number of requests the Ollama server can queue?
-If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queue by setting `OLLAMA_MAX_QUEUE`.
+If too many requests are sent to the server, it will respond with a 503 error indicating the server is overloaded. You can adjust how many requests may be queued by setting `OLLAMA_MAX_QUEUE`.
## How does Ollama handle concurrent requests?
@@ -312,10 +332,10 @@ Parallel request processing for a given model results in increasing the context
The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:
- `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 \* the number of GPUs or 3 for CPU inference.
-- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
+- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time, default 1. Required RAM will scale by `OLLAMA_NUM_PARALLEL` * `OLLAMA_CONTEXT_LENGTH`.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
-Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
+Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPU's VRAM.
## How does Ollama load models on multiple GPUs?
@@ -382,7 +402,7 @@ ollama signin
Replace <username> with your actual Windows user name.
-## How can I stop Ollama from starting when I login to my computer
+## How can I stop Ollama from starting when I login to my computer?
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
@@ -390,4 +410,4 @@ Ollama for Windows and macOS register as a login item during installation. You
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
-- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
+- Open `Settings` and search for "Login Items", find the `Ollama` entry under `Allow in the Background`, then click the slider to disable.
diff --git a/docs/gpu.mdx b/docs/gpu.mdx
index 9cb2d3abc53..60e2f5e3c14 100644
--- a/docs/gpu.mdx
+++ b/docs/gpu.mdx
@@ -10,6 +10,7 @@ Check your compute compatibility to see if your card is supported:
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
+| 12.1 | NVIDIA | `GB10 (DGX Spark)` |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
@@ -163,4 +164,4 @@ To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
-by setting `GGML_VK_VISIBLE_DEVICES=-1`
\ No newline at end of file
+by setting `GGML_VK_VISIBLE_DEVICES=-1`
diff --git a/docs/import.mdx b/docs/import.mdx
index 26a6365640c..6870124ac8f 100644
--- a/docs/import.mdx
+++ b/docs/import.mdx
@@ -138,22 +138,12 @@ success
### Supported Quantizations
-- `q4_0`
-- `q4_1`
-- `q5_0`
-- `q5_1`
- `q8_0`
#### K-means Quantizations
-- `q3_K_S`
-- `q3_K_M`
-- `q3_K_L`
- `q4_K_S`
- `q4_K_M`
-- `q5_K_S`
-- `q5_K_M`
-- `q6_K`
## Sharing your model on ollama.com
diff --git a/docs/index.mdx b/docs/index.mdx
index 669d30cfbc9..ac1c744ea11 100644
--- a/docs/index.mdx
+++ b/docs/index.mdx
@@ -9,7 +9,7 @@ sidebarTitle: Welcome
- Get up and running with your first model
+ Get up and running with your first model or integrate Ollama with your favorite tools
-```
-
-3. Run Claude Code with a cloud model:
-
-```shell
-claude --model glm-4.7:cloud
-```
+**Note:** Claude Code requires a large context window. We recommend at least 64k tokens. See the [context length documentation](/context-length) for how to adjust context length in Ollama.
## Recommended Models
-### Cloud models
-- `glm-4.7:cloud` - High-performance cloud model
-- `minimax-m2.1:cloud` - Fast cloud model
-- `qwen3-coder:480b` - Large coding model
+- `qwen3-coder`
+- `glm-4.7`
+- `gpt-oss:20b`
+- `gpt-oss:120b`
+
+Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
-### Local models
-- `qwen3-coder` - Excellent for coding tasks
-- `gpt-oss:20b` - Strong general-purpose model
-- `gpt-oss:120b` - Larger general-purpose model for more complex tasks
\ No newline at end of file
diff --git a/docs/integrations/codex.mdx b/docs/integrations/codex.mdx
index f9df1b85897..7a79d39ab59 100644
--- a/docs/integrations/codex.mdx
+++ b/docs/integrations/codex.mdx
@@ -13,7 +13,21 @@ npm install -g @openai/codex
## Usage with Ollama
-Codex requires a larger context window. It is recommended to use a context window of at least 32K tokens.
+Codex requires a larger context window. It is recommended to use a context window of at least 64k tokens.
+
+### Quick setup
+
+```
+ollama launch codex
+```
+
+To configure without launching:
+
+```shell
+ollama launch codex --config
+```
+
+### Manual setup
To use `codex` with Ollama, use the `--oss` flag:
diff --git a/docs/integrations/droid.mdx b/docs/integrations/droid.mdx
index b1ba37710a6..24955451024 100644
--- a/docs/integrations/droid.mdx
+++ b/docs/integrations/droid.mdx
@@ -11,10 +11,24 @@ Install the [Droid CLI](https://factory.ai/):
curl -fsSL https://app.factory.ai/cli | sh
```
-Droid requires a larger context window. It is recommended to use a context window of at least 32K tokens. See [Context length](/context-length) for more information.
+Droid requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.
## Usage with Ollama
+### Quick setup
+
+```bash
+ollama launch droid
+```
+
+To configure without launching:
+
+```shell
+ollama launch droid --config
+```
+
+### Manual setup
+
Add a local configuration block to `~/.factory/config.json`:
```json
@@ -73,4 +87,4 @@ Add the cloud configuration block to `~/.factory/config.json`:
}
```
-Run `droid` in a new terminal to load the new settings.
\ No newline at end of file
+Run `droid` in a new terminal to load the new settings.
diff --git a/docs/integrations/index.mdx b/docs/integrations/index.mdx
new file mode 100644
index 00000000000..5ae2fe6707f
--- /dev/null
+++ b/docs/integrations/index.mdx
@@ -0,0 +1,51 @@
+---
+title: Overview
+---
+
+Ollama integrates with a wide range of tools.
+
+## Coding Agents
+
+Coding assistants that can read, modify, and execute code in your projects.
+
+- [Claude Code](/integrations/claude-code)
+- [Codex](/integrations/codex)
+- [OpenCode](/integrations/opencode)
+- [Droid](/integrations/droid)
+- [Goose](/integrations/goose)
+- [Pi](/integrations/pi)
+
+## Assistants
+
+AI assistants that help with everyday tasks.
+
+- [OpenClaw](/integrations/openclaw)
+
+## IDEs & Editors
+
+Native integrations for popular development environments.
+
+- [VS Code](/integrations/vscode)
+- [Cline](/integrations/cline)
+- [Roo Code](/integrations/roo-code)
+- [JetBrains](/integrations/jetbrains)
+- [Xcode](/integrations/xcode)
+- [Zed](/integrations/zed)
+
+## Chat & RAG
+
+Chat interfaces and retrieval-augmented generation platforms.
+
+- [Onyx](/integrations/onyx)
+
+## Automation
+
+Workflow automation platforms with AI integration.
+
+- [n8n](/integrations/n8n)
+
+## Notebooks
+
+Interactive computing environments with AI capabilities.
+
+- [marimo](/integrations/marimo)
diff --git a/docs/integrations/openclaw.mdx b/docs/integrations/openclaw.mdx
new file mode 100644
index 00000000000..1a4a79905c3
--- /dev/null
+++ b/docs/integrations/openclaw.mdx
@@ -0,0 +1,50 @@
+---
+title: OpenClaw
+---
+
+OpenClaw is a personal AI assistant that runs on your own devices. It bridges messaging services (WhatsApp, Telegram, Slack, Discord, iMessage, and more) to AI coding agents through a centralized gateway.
+
+## Install
+
+Install [OpenClaw](https://openclaw.ai/)
+
+```bash
+npm install -g openclaw@latest
+```
+
+Then run the onboarding wizard:
+
+```bash
+openclaw onboard --install-daemon
+```
+
+OpenClaw requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.
+
+## Usage with Ollama
+
+### Quick setup
+
+```bash
+ollama launch openclaw
+```
+
+Previously known as Clawdbot. `ollama launch clawdbot` still works as an alias.
+
+This configures OpenClaw to use Ollama and starts the gateway.
+If the gateway is already running, no changes need to be made as the gateway will auto-reload the changes.
+
+
+To configure without launching:
+
+```shell
+ollama launch openclaw --config
+```
+
+## Recommended Models
+
+- `qwen3-coder`
+- `glm-4.7`
+- `gpt-oss:20b`
+- `gpt-oss:120b`
+
+Cloud models are also available at [ollama.com/search?c=cloud](https://ollama.com/search?c=cloud).
diff --git a/docs/integrations/opencode.mdx b/docs/integrations/opencode.mdx
new file mode 100644
index 00000000000..6f5707688c6
--- /dev/null
+++ b/docs/integrations/opencode.mdx
@@ -0,0 +1,106 @@
+---
+title: OpenCode
+---
+
+OpenCode is an open-source AI coding assistant that runs in your terminal.
+
+## Install
+
+Install the [OpenCode CLI](https://opencode.ai):
+
+```bash
+curl -fsSL https://opencode.ai/install | bash
+```
+
+OpenCode requires a larger context window. It is recommended to use a context window of at least 64k tokens. See [Context length](/context-length) for more information.
+
+## Usage with Ollama
+
+### Quick setup
+
+```bash
+ollama launch opencode
+```
+
+To configure without launching:
+
+```shell
+ollama launch opencode --config
+```
+
+### Manual setup
+
+Add a configuration block to `~/.config/opencode/opencode.json`:
+
+```json
+{
+ "$schema": "https://opencode.ai/config.json",
+ "provider": {
+ "ollama": {
+ "npm": "@ai-sdk/openai-compatible",
+ "name": "Ollama",
+ "options": {
+ "baseURL": "http://localhost:11434/v1"
+ },
+ "models": {
+ "qwen3-coder": {
+ "name": "qwen3-coder"
+ }
+ }
+ }
+ }
+}
+```
+
+## Cloud Models
+
+`glm-4.7:cloud` is the recommended model for use with OpenCode.
+
+Add the cloud configuration to `~/.config/opencode/opencode.json`:
+
+```json
+{
+ "$schema": "https://opencode.ai/config.json",
+ "provider": {
+ "ollama": {
+ "npm": "@ai-sdk/openai-compatible",
+ "name": "Ollama",
+ "options": {
+ "baseURL": "http://localhost:11434/v1"
+ },
+ "models": {
+ "glm-4.7:cloud": {
+ "name": "glm-4.7:cloud"
+ }
+ }
+ }
+ }
+}
+```
+
+## Connecting to ollama.com
+
+1. Create an [API key](https://ollama.com/settings/keys) from ollama.com and export it as `OLLAMA_API_KEY`.
+2. Update `~/.config/opencode/opencode.json` to point to ollama.com:
+
+```json
+{
+ "$schema": "https://opencode.ai/config.json",
+ "provider": {
+ "ollama": {
+ "npm": "@ai-sdk/openai-compatible",
+ "name": "Ollama Cloud",
+ "options": {
+ "baseURL": "https://ollama.com/v1"
+ },
+ "models": {
+ "glm-4.7:cloud": {
+ "name": "glm-4.7:cloud"
+ }
+ }
+ }
+ }
+}
+```
+
+Run `opencode` in a new terminal to load the new settings.
diff --git a/docs/integrations/pi.mdx b/docs/integrations/pi.mdx
new file mode 100644
index 00000000000..fd2dadbed37
--- /dev/null
+++ b/docs/integrations/pi.mdx
@@ -0,0 +1,57 @@
+---
+title: Pi
+---
+
+Pi is a minimal AI agent toolkit with plugin support.
+
+## Install
+
+Install [Pi](https://github.com/badlogic/pi-mono):
+
+```bash
+npm install -g @mariozechner/pi-coding-agent
+```
+
+## Usage with Ollama
+
+### Quick setup
+
+```bash
+ollama launch pi
+```
+
+To configure without launching:
+
+```shell
+ollama launch pi --config
+```
+
+### Manual setup
+
+Add a configuration block to `~/.pi/agent/models.json`:
+
+```json
+{
+ "providers": {
+ "ollama": {
+ "baseUrl": "http://localhost:11434/v1",
+ "api": "openai-completions",
+ "apiKey": "ollama",
+ "models": [
+ {
+ "id": "qwen3-coder"
+ }
+ ]
+ }
+ }
+}
+```
+
+Update `~/.pi/agent/settings.json` to set the default provider:
+
+```json
+{
+ "defaultProvider": "ollama",
+ "defaultModel": "qwen3-coder"
+}
+```
diff --git a/docs/openapi.yaml b/docs/openapi.yaml
index 4817bcb41f5..d225a4769ff 100644
--- a/docs/openapi.yaml
+++ b/docs/openapi.yaml
@@ -596,6 +596,15 @@ components:
name:
type: string
description: Model name
+ model:
+ type: string
+ description: Model name
+ remote_model:
+ type: string
+ description: Name of the upstream model, if the model is remote
+ remote_host:
+ type: string
+ description: URL of the upstream Ollama host, if the model is remote
modified_at:
type: string
description: Last modified timestamp in ISO 8601 format
@@ -636,6 +645,9 @@ components:
Ps:
type: object
properties:
+ name:
+ type: string
+ description: Name of the running model
model:
type: string
description: Name of the running model
@@ -1137,6 +1149,7 @@ paths:
example:
models:
- name: "gemma3"
+ model: "gemma3"
modified_at: "2025-10-03T23:34:03.409490317-07:00"
size: 3338801804
digest: "a2af6cc3eb7fa8be8504abaf9b04e88f17a119ec3f04a3addf55f92841195f5a"
@@ -1168,7 +1181,8 @@ paths:
$ref: "#/components/schemas/PsResponse"
example:
models:
- - model: "gemma3"
+ - name: "gemma3"
+ model: "gemma3"
size: 6591830464
digest: "a2af6cc3eb7fa8be8504abaf9b04e88f17a119ec3f04a3addf55f92841195f5a"
details:
diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx
index 5ef9fa825d8..62f2f99b4c2 100644
--- a/docs/quickstart.mdx
+++ b/docs/quickstart.mdx
@@ -2,7 +2,7 @@
title: Quickstart
---
-This quickstart will walk your through running your first model with Ollama. To get started, download Ollama on macOS, Windows or Linux.
+Ollama is available on macOS, Windows, and Linux.
-## Run a model
-
-
-
- Open a terminal and run the command:
-
- ```
- ollama run gemma3
- ```
-
-
-
- ```
- ollama pull gemma3
- ```
-
- Lastly, chat with the model:
-
- ```shell
- curl http://localhost:11434/api/chat -d '{
- "model": "gemma3",
- "messages": [{
- "role": "user",
- "content": "Hello there!"
- }],
- "stream": false
- }'
- ```
-
-
-
- Start by downloading a model:
-
- ```
- ollama pull gemma3
- ```
-
- Then install Ollama's Python library:
-
- ```
- pip install ollama
- ```
-
- Lastly, chat with the model:
-
- ```python
- from ollama import chat
- from ollama import ChatResponse
-
- response: ChatResponse = chat(model='gemma3', messages=[
- {
- 'role': 'user',
- 'content': 'Why is the sky blue?',
- },
- ])
- print(response['message']['content'])
- # or access fields directly from the response object
- print(response.message.content)
- ```
-
-
-
- Start by downloading a model:
-
- ```
- ollama pull gemma3
- ```
-
- Then install the Ollama JavaScript library:
- ```
- npm i ollama
- ```
-
- Lastly, chat with the model:
-
- ```shell
- import ollama from 'ollama'
-
- const response = await ollama.chat({
- model: 'gemma3',
- messages: [{ role: 'user', content: 'Why is the sky blue?' }],
- })
- console.log(response.message.content)
- ```
-
-
-
-
-See a full list of available models [here](https://ollama.com/models).
+## Get Started
+
+Run `ollama` in your terminal to open the interactive menu:
+
+```sh
+ollama
+```
+
+Navigate with `↑/↓`, press `enter` to launch, `→` to change model, and `esc` to quit.
+
+The menu provides quick access to:
+- **Run a model** - Start an interactive chat
+- **Launch tools** - Claude Code, Codex, OpenClaw, and more
+- **Additional integrations** - Available under "More..."
+
+## Assistants
+
+Launch [OpenClaw](/integrations/openclaw), a personal AI with 100+ skills:
+
+```sh
+ollama launch openclaw
+```
+
+## Coding
+
+Launch [Claude Code](/integrations/claude-code) and other coding tools with Ollama models:
+
+```sh
+ollama launch claude
+```
+
+```sh
+ollama launch codex
+```
+
+```sh
+ollama launch opencode
+```
+
+See [integrations](/integrations) for all supported tools.
+
+## API
+
+Use the [API](/api) to integrate Ollama into your applications:
+
+```sh
+curl http://localhost:11434/api/chat -d '{
+ "model": "gemma3",
+ "messages": [{ "role": "user", "content": "Hello!" }]
+}'
+```
+
+See the [API documentation](/api) for Python, JavaScript, and other integrations.
diff --git a/envconfig/config.go b/envconfig/config.go
index aa5ce302793..40346f0ca51 100644
--- a/envconfig/config.go
+++ b/envconfig/config.go
@@ -1,6 +1,8 @@
package envconfig
import (
+ "encoding/json"
+ "errors"
"fmt"
"log/slog"
"math"
@@ -11,6 +13,7 @@ import (
"runtime"
"strconv"
"strings"
+ "sync"
"time"
)
@@ -203,11 +206,13 @@ var (
// Enable the new Ollama engine
NewEngine = Bool("OLLAMA_NEW_ENGINE")
// ContextLength sets the default context length
- ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
+ ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 0)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
+ // NoCloudEnv checks the OLLAMA_NO_CLOUD environment variable.
+ NoCloudEnv = Bool("OLLAMA_NO_CLOUD")
)
func String(s string) func() string {
@@ -218,6 +223,7 @@ func String(s string) func() string {
var (
LLMLibrary = String("OLLAMA_LLM_LIBRARY")
+ Editor = String("OLLAMA_EDITOR")
CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES")
HipVisibleDevices = String("HIP_VISIBLE_DEVICES")
@@ -286,6 +292,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners(), "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueue(), "Maximum number of queued requests"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", Models(), "The path to the models directory"},
+ "OLLAMA_NO_CLOUD": {"OLLAMA_NO_CLOUD", NoCloud(), "Disable Ollama cloud features (remote inference and web search)"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
@@ -293,7 +300,8 @@ func AsMap() map[string]EnvVar {
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
"OLLAMA_RPC_SERVERS": {"OLLAMA_RPC_SERVERS", RPCServers(), "A comma separated list of RPC server to disribute models to"},
- "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
+ "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4k/32k/256k based on VRAM)"},
+ "OLLAMA_EDITOR": {"OLLAMA_EDITOR", Editor(), "Path to editor for interactive prompt editing (Ctrl+G)"},
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
"OLLAMA_REMOTES": {"OLLAMA_REMOTES", Remotes(), "Allowed hosts for remote models (default \"ollama.com\")"},
@@ -335,3 +343,91 @@ func Values() map[string]string {
func Var(key string) string {
return strings.Trim(strings.TrimSpace(os.Getenv(key)), "\"'")
}
+
+// serverConfigData holds the parsed fields from ~/.ollama/server.json.
+type serverConfigData struct {
+ DisableOllamaCloud bool `json:"disable_ollama_cloud,omitempty"`
+}
+
+var (
+ serverCfgMu sync.RWMutex
+ serverCfgLoaded bool
+ serverCfg serverConfigData
+)
+
+func loadServerConfig() {
+ serverCfgMu.RLock()
+ if serverCfgLoaded {
+ serverCfgMu.RUnlock()
+ return
+ }
+ serverCfgMu.RUnlock()
+
+ cfg := serverConfigData{}
+ home, err := os.UserHomeDir()
+ if err == nil {
+ path := filepath.Join(home, ".ollama", "server.json")
+ data, err := os.ReadFile(path)
+ if err != nil {
+ if !errors.Is(err, os.ErrNotExist) {
+ slog.Debug("envconfig: could not read server config", "error", err)
+ }
+ } else if err := json.Unmarshal(data, &cfg); err != nil {
+ slog.Debug("envconfig: could not parse server config", "error", err)
+ }
+ }
+
+ serverCfgMu.Lock()
+ defer serverCfgMu.Unlock()
+ if serverCfgLoaded {
+ return
+ }
+ serverCfg = cfg
+ serverCfgLoaded = true
+}
+
+func cachedServerConfig() serverConfigData {
+ serverCfgMu.RLock()
+ defer serverCfgMu.RUnlock()
+ return serverCfg
+}
+
+// ReloadServerConfig refreshes the cached ~/.ollama/server.json settings.
+func ReloadServerConfig() {
+ serverCfgMu.Lock()
+ serverCfgLoaded = false
+ serverCfg = serverConfigData{}
+ serverCfgMu.Unlock()
+
+ loadServerConfig()
+}
+
+// NoCloud returns true if Ollama cloud features are disabled,
+// checking both the OLLAMA_NO_CLOUD environment variable and
+// the disable_ollama_cloud field in ~/.ollama/server.json.
+func NoCloud() bool {
+ if NoCloudEnv() {
+ return true
+ }
+ loadServerConfig()
+ return cachedServerConfig().DisableOllamaCloud
+}
+
+// NoCloudSource returns the source of the cloud-disabled decision.
+// Returns "none", "env", "config", or "both".
+func NoCloudSource() string {
+ envDisabled := NoCloudEnv()
+ loadServerConfig()
+ configDisabled := cachedServerConfig().DisableOllamaCloud
+
+ switch {
+ case envDisabled && configDisabled:
+ return "both"
+ case envDisabled:
+ return "env"
+ case configDisabled:
+ return "config"
+ default:
+ return "none"
+ }
+}
diff --git a/envconfig/config_test.go b/envconfig/config_test.go
index ddd86a1147b..9242e66f5d1 100644
--- a/envconfig/config_test.go
+++ b/envconfig/config_test.go
@@ -3,6 +3,8 @@ package envconfig
import (
"log/slog"
"math"
+ "os"
+ "path/filepath"
"testing"
"time"
@@ -282,7 +284,7 @@ func TestVar(t *testing.T) {
func TestContextLength(t *testing.T) {
cases := map[string]uint{
- "": 4096,
+ "": 0,
"2048": 2048,
}
@@ -326,3 +328,81 @@ func TestLogLevel(t *testing.T) {
})
}
}
+
+func TestNoCloud(t *testing.T) {
+ tests := []struct {
+ name string
+ envValue string
+ configContent string
+ wantDisabled bool
+ wantSource string
+ }{
+ {
+ name: "neither env nor config",
+ wantDisabled: false,
+ wantSource: "none",
+ },
+ {
+ name: "env only",
+ envValue: "1",
+ wantDisabled: true,
+ wantSource: "env",
+ },
+ {
+ name: "config only",
+ configContent: `{"disable_ollama_cloud": true}`,
+ wantDisabled: true,
+ wantSource: "config",
+ },
+ {
+ name: "both env and config",
+ envValue: "1",
+ configContent: `{"disable_ollama_cloud": true}`,
+ wantDisabled: true,
+ wantSource: "both",
+ },
+ {
+ name: "config false",
+ configContent: `{"disable_ollama_cloud": false}`,
+ wantDisabled: false,
+ wantSource: "none",
+ },
+ {
+ name: "invalid config ignored",
+ configContent: `{invalid json`,
+ wantDisabled: false,
+ wantSource: "none",
+ },
+ {
+ name: "no config file",
+ wantDisabled: false,
+ wantSource: "none",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ home := t.TempDir()
+ if tt.configContent != "" {
+ configDir := filepath.Join(home, ".ollama")
+ if err := os.MkdirAll(configDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(filepath.Join(configDir, "server.json"), []byte(tt.configContent), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ setTestHome(t, home)
+ t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
+
+ if got := NoCloud(); got != tt.wantDisabled {
+ t.Errorf("NoCloud() = %v, want %v", got, tt.wantDisabled)
+ }
+
+ if got := NoCloudSource(); got != tt.wantSource {
+ t.Errorf("NoCloudSource() = %q, want %q", got, tt.wantSource)
+ }
+ })
+ }
+}
diff --git a/envconfig/test_home_test.go b/envconfig/test_home_test.go
new file mode 100644
index 00000000000..993f1c0aaad
--- /dev/null
+++ b/envconfig/test_home_test.go
@@ -0,0 +1,10 @@
+package envconfig
+
+import "testing"
+
+func setTestHome(t *testing.T, home string) {
+ t.Helper()
+ t.Setenv("HOME", home)
+ t.Setenv("USERPROFILE", home)
+ ReloadServerConfig()
+}
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 0035dca0615..6ef63d76c6f 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -292,8 +292,11 @@ func (kv KV) OllamaEngineRequired() bool {
"olmo3",
"qwen25vl",
"qwen3", "qwen3moe",
+ "qwen3next",
"qwen3vl", "qwen3vlmoe",
"glm4moelite",
+ "glmocr",
+ "lfm2",
}, kv.Architecture())
}
@@ -1029,10 +1032,13 @@ func (f MetaGGML) FlashAttention() bool {
"bert",
"gemma3",
"glm4moelite",
+ "glmocr",
"gptoss", "gpt-oss",
+ "lfm2",
"mistral3",
"olmo3",
"qwen3", "qwen3moe",
+ "qwen3next",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))
}
diff --git a/go.mod b/go.mod
index 0f7bca5f29f..1a7a6f8a229 100644
--- a/go.mod
+++ b/go.mod
@@ -13,7 +13,7 @@ require (
github.com/mattn/go-sqlite3 v1.14.24
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
- github.com/stretchr/testify v1.9.0
+ github.com/stretchr/testify v1.10.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.17.0
golang.org/x/sys v0.37.0
@@ -21,13 +21,19 @@ require (
require (
github.com/agnivade/levenshtein v1.1.1
+ github.com/charmbracelet/bubbletea v1.3.10
+ github.com/charmbracelet/lipgloss v1.1.0
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
- github.com/mattn/go-runewidth v0.0.14
+ github.com/klauspost/compress v1.18.3
+ github.com/mattn/go-runewidth v0.0.16
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
+ github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/tkrajina/typescriptify-golang-structs v0.2.0
+ github.com/tree-sitter/go-tree-sitter v0.25.0
+ github.com/tree-sitter/tree-sitter-cpp v0.23.4
github.com/wk8/go-ordered-map/v2 v2.1.8
golang.org/x/image v0.22.0
golang.org/x/mod v0.30.0
@@ -37,22 +43,35 @@ require (
require (
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
+ github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
+ github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
+ github.com/charmbracelet/x/ansi v0.10.1 // indirect
+ github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect
+ github.com/charmbracelet/x/term v0.2.1 // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.11.0 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v24.3.25+incompatible // indirect
github.com/kr/text v0.2.0 // indirect
+ github.com/lucasb-eyer/go-colorful v1.2.0 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
+ github.com/mattn/go-localereader v0.0.1 // indirect
+ github.com/mattn/go-pointer v0.0.1 // indirect
+ github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
+ github.com/muesli/cancelreader v0.2.2 // indirect
+ github.com/muesli/termenv v0.16.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
- github.com/rivo/uniseg v0.2.0 // indirect
+ github.com/rivo/uniseg v0.4.7 // indirect
github.com/tkrajina/go-reflector v0.5.5 // indirect
+ github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
diff --git a/go.sum b/go.sum
index 83014fc5b05..13dd3563839 100644
--- a/go.sum
+++ b/go.sum
@@ -14,6 +14,8 @@ github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6IC
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
+github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
+github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk=
github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
@@ -24,6 +26,18 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
+github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
+github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc h1:4pZI35227imm7yK2bGPcfpFEmuY1gc2YSTShr4iJBfs=
+github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc/go.mod h1:X4/0JoqgTIPSFcRA/P6INZzIuyqdFY5rm8tb41s9okk=
+github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
+github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
+github.com/charmbracelet/x/ansi v0.10.1 h1:rL3Koar5XvX0pHGfovN03f5cxLbCF2YvLeyz7D2jVDQ=
+github.com/charmbracelet/x/ansi v0.10.1/go.mod h1:3RQDQ6lDnROptfpWuUVIUG64bD2g2BgntdxH0Ya5TeE=
+github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd h1:vy0GVL4jeHEwG5YOXDmi86oYw2yuYUGqz6a8sLwg0X8=
+github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd/go.mod h1:xe0nKWGd3eJgtqZRaN9RjMtK7xUYchjzPr7q6kcvCCs=
+github.com/charmbracelet/x/term v0.2.1 h1:AQeHeLZ1OqSXhrAWpYUtZyX1T3zVxfpZuEQMIQaGIAQ=
+github.com/charmbracelet/x/term v0.2.1/go.mod h1:oQ4enTYFV7QN4m0i9mzHrViD7TQKvNEEkHUMCmsxdUg=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
@@ -59,6 +73,8 @@ github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.m
github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk=
github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0/go.mod h1:hliV/p42l8fGbc6Y9bQ70uLwIvmJyVE5k4iMKlh8wCQ=
github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c=
+github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
+github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -106,7 +122,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
-github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI=
@@ -134,8 +149,9 @@ github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
-github.com/klauspost/compress v1.13.1 h1:wXr2uRxZTJXHLly6qhJabee5JqIhTRoLBhDOA74hDEQ=
github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg=
+github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw=
+github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
@@ -148,13 +164,19 @@ github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728 h1:QwWKgMY28TAXaDl+
github.com/ledongthuc/pdf v0.0.0-20250511090121-5959a4027728/go.mod h1:1fEHWurg7pvf5SG6XNE5Q8UZmOwex51Mkx3SLhrW5B4=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
+github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY=
+github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
+github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
+github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
+github.com/mattn/go-pointer v0.0.1 h1:n+XhsuGeVO6MEAp7xyEukFINEa+Quek5psIR/ylA6o0=
+github.com/mattn/go-pointer v0.0.1/go.mod h1:2zXcozF6qYGgmsG+SeTZz3oAbFLdD3OWqnUbNvJZAlc=
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
-github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
-github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
+github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc=
+github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -162,6 +184,12 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
+github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
+github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
+github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
+github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
+github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
+github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
@@ -174,14 +202,17 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
+github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
-github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
+github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
+github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
@@ -204,12 +235,39 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
-github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
+github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
+github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tkrajina/go-reflector v0.5.5 h1:gwoQFNye30Kk7NrExj8zm3zFtrGPqOkzFMLuQZg1DtQ=
github.com/tkrajina/go-reflector v0.5.5/go.mod h1:ECbqLgccecY5kPmPmXg1MrHW585yMcDkVl6IvJe64T4=
github.com/tkrajina/typescriptify-golang-structs v0.2.0 h1:ZedWk82egydDspGTryAatbX0/1NZDQbdiZLoCbOk4f8=
github.com/tkrajina/typescriptify-golang-structs v0.2.0/go.mod h1:sjU00nti/PMEOZb07KljFlR+lJ+RotsC0GBQMv9EKls=
+github.com/tree-sitter/go-tree-sitter v0.25.0 h1:sx6kcg8raRFCvc9BnXglke6axya12krCJF5xJ2sftRU=
+github.com/tree-sitter/go-tree-sitter v0.25.0/go.mod h1:r77ig7BikoZhHrrsjAnv8RqGti5rtSyvDHPzgTPsUuU=
+github.com/tree-sitter/tree-sitter-c v0.23.4 h1:nBPH3FV07DzAD7p0GfNvXM+Y7pNIoPenQWBpvM++t4c=
+github.com/tree-sitter/tree-sitter-c v0.23.4/go.mod h1:MkI5dOiIpeN94LNjeCp8ljXN/953JCwAby4bClMr6bw=
+github.com/tree-sitter/tree-sitter-cpp v0.23.4 h1:LaWZsiqQKvR65yHgKmnaqA+uz6tlDJTJFCyFIeZU/8w=
+github.com/tree-sitter/tree-sitter-cpp v0.23.4/go.mod h1:doqNW64BriC7WBCQ1klf0KmJpdEvfxyXtoEybnBo6v8=
+github.com/tree-sitter/tree-sitter-embedded-template v0.23.2 h1:nFkkH6Sbe56EXLmZBqHHcamTpmz3TId97I16EnGy4rg=
+github.com/tree-sitter/tree-sitter-embedded-template v0.23.2/go.mod h1:HNPOhN0qF3hWluYLdxWs5WbzP/iE4aaRVPMsdxuzIaQ=
+github.com/tree-sitter/tree-sitter-go v0.23.4 h1:yt5KMGnTHS+86pJmLIAZMWxukr8W7Ae1STPvQUuNROA=
+github.com/tree-sitter/tree-sitter-go v0.23.4/go.mod h1:Jrx8QqYN0v7npv1fJRH1AznddllYiCMUChtVjxPK040=
+github.com/tree-sitter/tree-sitter-html v0.23.2 h1:1UYDV+Yd05GGRhVnTcbP58GkKLSHHZwVaN+lBZV11Lc=
+github.com/tree-sitter/tree-sitter-html v0.23.2/go.mod h1:gpUv/dG3Xl/eebqgeYeFMt+JLOY9cgFinb/Nw08a9og=
+github.com/tree-sitter/tree-sitter-java v0.23.5 h1:J9YeMGMwXYlKSP3K4Us8CitC6hjtMjqpeOf2GGo6tig=
+github.com/tree-sitter/tree-sitter-java v0.23.5/go.mod h1:NRKlI8+EznxA7t1Yt3xtraPk1Wzqh3GAIC46wxvc320=
+github.com/tree-sitter/tree-sitter-javascript v0.23.1 h1:1fWupaRC0ArlHJ/QJzsfQ3Ibyopw7ZfQK4xXc40Zveo=
+github.com/tree-sitter/tree-sitter-javascript v0.23.1/go.mod h1:lmGD1EJdCA+v0S1u2fFgepMg/opzSg/4pgFym2FPGAs=
+github.com/tree-sitter/tree-sitter-json v0.24.8 h1:tV5rMkihgtiOe14a9LHfDY5kzTl5GNUYe6carZBn0fQ=
+github.com/tree-sitter/tree-sitter-json v0.24.8/go.mod h1:F351KK0KGvCaYbZ5zxwx/gWWvZhIDl0eMtn+1r+gQbo=
+github.com/tree-sitter/tree-sitter-php v0.23.11 h1:iHewsLNDmznh8kgGyfWfujsZxIz1YGbSd2ZTEM0ZiP8=
+github.com/tree-sitter/tree-sitter-php v0.23.11/go.mod h1:T/kbfi+UcCywQfUNAJnGTN/fMSUjnwPXA8k4yoIks74=
+github.com/tree-sitter/tree-sitter-python v0.23.6 h1:qHnWFR5WhtMQpxBZRwiaU5Hk/29vGju6CVtmvu5Haas=
+github.com/tree-sitter/tree-sitter-python v0.23.6/go.mod h1:cpdthSy/Yoa28aJFBscFHlGiU+cnSiSh1kuDVtI8YeM=
+github.com/tree-sitter/tree-sitter-ruby v0.23.1 h1:T/NKHUA+iVbHM440hFx+lzVOzS4dV6z8Qw8ai+72bYo=
+github.com/tree-sitter/tree-sitter-ruby v0.23.1/go.mod h1:kUS4kCCQloFcdX6sdpr8p6r2rogbM6ZjTox5ZOQy8cA=
+github.com/tree-sitter/tree-sitter-rust v0.23.2 h1:6AtoooCW5GqNrRpfnvl0iUhxTAZEovEmLKDbyHlfw90=
+github.com/tree-sitter/tree-sitter-rust v0.23.2/go.mod h1:hfeGWic9BAfgTrc7Xf6FaOAguCFJRo3RBbs7QJ6D7MI=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
@@ -218,6 +276,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/
github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw=
github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
+github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
+github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@@ -304,6 +364,8 @@ golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
diff --git a/integration/basic_test.go b/integration/basic_test.go
index 414061479bc..351a1e3888a 100644
--- a/integration/basic_test.go
+++ b/integration/basic_test.go
@@ -144,3 +144,47 @@ func TestUnicodeModelDir(t *testing.T) {
}
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
+
+// TestNumPredict verifies that when num_predict is set, the model generates
+// exactly that many tokens. It uses logprobs to count the actual tokens output.
+func TestNumPredict(t *testing.T) {
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
+ defer cancel()
+
+ client, _, cleanup := InitServerConnection(ctx, t)
+ defer cleanup()
+
+ if err := PullIfMissing(ctx, client, "qwen3:0.6b"); err != nil {
+ t.Fatalf("failed to pull model: %v", err)
+ }
+
+ req := api.GenerateRequest{
+ Model: "qwen3:0.6b",
+ Prompt: "Write a long story.",
+ Stream: &stream,
+ Logprobs: true,
+ Options: map[string]any{
+ "num_predict": 10,
+ "temperature": 0,
+ "seed": 123,
+ },
+ }
+
+ logprobCount := 0
+ var finalResponse api.GenerateResponse
+ err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
+ logprobCount += len(resp.Logprobs)
+ if resp.Done {
+ finalResponse = resp
+ }
+ return nil
+ })
+ if err != nil {
+ t.Fatalf("generate failed: %v", err)
+ }
+
+ if logprobCount != 10 {
+ t.Errorf("expected 10 tokens (logprobs), got %d (EvalCount=%d, DoneReason=%s)",
+ logprobCount, finalResponse.EvalCount, finalResponse.DoneReason)
+ }
+}
diff --git a/integration/imagegen_test.go b/integration/imagegen_test.go
index 0b7e6e0e064..6f2b65eccc2 100644
--- a/integration/imagegen_test.go
+++ b/integration/imagegen_test.go
@@ -3,18 +3,14 @@
package integration
import (
- "bytes"
"context"
"encoding/base64"
- "encoding/json"
"fmt"
- "net/http"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
- imagegenapi "github.com/ollama/ollama/x/imagegen/api"
)
func TestImageGeneration(t *testing.T) {
@@ -41,7 +37,7 @@ func TestImageGeneration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
defer cancel()
- client, testEndpoint, cleanup := InitServerConnection(ctx, t)
+ client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Pull both models
@@ -54,7 +50,7 @@ func TestImageGeneration(t *testing.T) {
// Generate the image
t.Logf("Generating image with prompt: %s", tc.prompt)
- imageBase64, err := generateImage(ctx, testEndpoint, tc.imageGenModel, tc.prompt)
+ imageBase64, err := generateImage(ctx, client, tc.imageGenModel, tc.prompt)
if err != nil {
if strings.Contains(err.Error(), "image generation not available") {
t.Skip("Target system does not support image generation")
@@ -127,48 +123,26 @@ func TestImageGeneration(t *testing.T) {
}
}
-// generateImage calls the OpenAI-compatible image generation API and returns the base64 image data
-func generateImage(ctx context.Context, endpoint, model, prompt string) (string, error) {
- reqBody := imagegenapi.ImageGenerationRequest{
- Model: model,
- Prompt: prompt,
- N: 1,
- Size: "512x512",
- ResponseFormat: "b64_json",
- }
-
- jsonBody, err := json.Marshal(reqBody)
- if err != nil {
- return "", fmt.Errorf("failed to marshal request: %w", err)
- }
-
- url := fmt.Sprintf("http://%s/v1/images/generations", endpoint)
- req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(jsonBody))
- if err != nil {
- return "", fmt.Errorf("failed to create request: %w", err)
- }
- req.Header.Set("Content-Type", "application/json")
-
- resp, err := http.DefaultClient.Do(req)
+// generateImage calls the Ollama API to generate an image and returns the base64 image data
+func generateImage(ctx context.Context, client *api.Client, model, prompt string) (string, error) {
+ var imageBase64 string
+
+ err := client.Generate(ctx, &api.GenerateRequest{
+ Model: model,
+ Prompt: prompt,
+ }, func(resp api.GenerateResponse) error {
+ if resp.Image != "" {
+ imageBase64 = resp.Image
+ }
+ return nil
+ })
if err != nil {
- return "", fmt.Errorf("failed to send request: %w", err)
- }
- defer resp.Body.Close()
-
- if resp.StatusCode != http.StatusOK {
- var buf bytes.Buffer
- buf.ReadFrom(resp.Body)
- return "", fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, buf.String())
- }
-
- var genResp imagegenapi.ImageGenerationResponse
- if err := json.NewDecoder(resp.Body).Decode(&genResp); err != nil {
- return "", fmt.Errorf("failed to decode response: %w", err)
+ return "", fmt.Errorf("failed to generate image: %w", err)
}
- if len(genResp.Data) == 0 {
+ if imageBase64 == "" {
return "", fmt.Errorf("no image data in response")
}
- return genResp.Data[0].B64JSON, nil
+ return imageBase64, nil
}
diff --git a/integration/utils_test.go b/integration/utils_test.go
index 2dd39ecb0b8..6112edf2743 100644
--- a/integration/utils_test.go
+++ b/integration/utils_test.go
@@ -38,6 +38,7 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
+ "lfm2.5-thinking",
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
@@ -143,6 +144,7 @@ var (
"granite3.3",
"hermes3",
"internlm2",
+ "lfm2.5-thinking",
"llama-guard3",
"llama-pro",
"llama2-chinese",
@@ -263,6 +265,7 @@ var (
"snowflake-arctic-embed2",
}
libraryToolsModels = []string{
+ "lfm2.5-thinking",
"qwen3-vl",
"gpt-oss:20b",
"gpt-oss:120b",
diff --git a/internal/cloud/policy.go b/internal/cloud/policy.go
new file mode 100644
index 00000000000..c540bff67f7
--- /dev/null
+++ b/internal/cloud/policy.go
@@ -0,0 +1,25 @@
+package cloud
+
+import (
+ "github.com/ollama/ollama/envconfig"
+)
+
+const DisabledMessagePrefix = "ollama cloud is disabled"
+
+// Status returns whether cloud is disabled and the source of the decision.
+// Source is one of: "none", "env", "config", "both".
+func Status() (disabled bool, source string) {
+ return envconfig.NoCloud(), envconfig.NoCloudSource()
+}
+
+func Disabled() bool {
+ return envconfig.NoCloud()
+}
+
+func DisabledError(operation string) string {
+ if operation == "" {
+ return DisabledMessagePrefix
+ }
+
+ return DisabledMessagePrefix + ": " + operation
+}
diff --git a/internal/cloud/policy_test.go b/internal/cloud/policy_test.go
new file mode 100644
index 00000000000..28c36eba3e7
--- /dev/null
+++ b/internal/cloud/policy_test.go
@@ -0,0 +1,85 @@
+package cloud
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+)
+
+func TestStatus(t *testing.T) {
+ tests := []struct {
+ name string
+ envValue string
+ configContent string
+ disabled bool
+ source string
+ }{
+ {
+ name: "none",
+ disabled: false,
+ source: "none",
+ },
+ {
+ name: "env only",
+ envValue: "1",
+ disabled: true,
+ source: "env",
+ },
+ {
+ name: "config only",
+ configContent: `{"disable_ollama_cloud": true}`,
+ disabled: true,
+ source: "config",
+ },
+ {
+ name: "both",
+ envValue: "1",
+ configContent: `{"disable_ollama_cloud": true}`,
+ disabled: true,
+ source: "both",
+ },
+ {
+ name: "invalid config ignored",
+ configContent: `{invalid json`,
+ disabled: false,
+ source: "none",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ home := t.TempDir()
+ if tt.configContent != "" {
+ configPath := filepath.Join(home, ".ollama", "server.json")
+ if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(configPath, []byte(tt.configContent), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ setTestHome(t, home)
+ t.Setenv("OLLAMA_NO_CLOUD", tt.envValue)
+
+ disabled, source := Status()
+ if disabled != tt.disabled {
+ t.Fatalf("disabled: expected %v, got %v", tt.disabled, disabled)
+ }
+ if source != tt.source {
+ t.Fatalf("source: expected %q, got %q", tt.source, source)
+ }
+ })
+ }
+}
+
+func TestDisabledError(t *testing.T) {
+ if got := DisabledError(""); got != DisabledMessagePrefix {
+ t.Fatalf("expected %q, got %q", DisabledMessagePrefix, got)
+ }
+
+ want := DisabledMessagePrefix + ": remote inference is unavailable"
+ if got := DisabledError("remote inference is unavailable"); got != want {
+ t.Fatalf("expected %q, got %q", want, got)
+ }
+}
diff --git a/internal/cloud/test_home_test.go b/internal/cloud/test_home_test.go
new file mode 100644
index 00000000000..5da8c3a69a2
--- /dev/null
+++ b/internal/cloud/test_home_test.go
@@ -0,0 +1,14 @@
+package cloud
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/envconfig"
+)
+
+func setTestHome(t *testing.T, home string) {
+ t.Helper()
+ t.Setenv("HOME", home)
+ t.Setenv("USERPROFILE", home)
+ envconfig.ReloadServerConfig()
+}
diff --git a/kvcache/cache.go b/kvcache/cache.go
index 405c797332e..5c6fc250bfa 100644
--- a/kvcache/cache.go
+++ b/kvcache/cache.go
@@ -75,3 +75,10 @@ type Cache interface {
// removed by calling Remove(seq, 0, math.MaxInt32)
Remove(seq int, beginIndex, endIndex int32) error
}
+
+// CheckpointCache optionally supports restoring recurrent state to a prior
+// position to avoid full prompt reprocessing when a prefix mismatch occurs.
+// The returned position is the number of tokens that can be kept (prefix length).
+type CheckpointCache interface {
+ PrepareRestore(seq int, targetPos int32) (int32, bool)
+}
diff --git a/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch
new file mode 100644
index 00000000000..abd7df93014
--- /dev/null
+++ b/llama/patches/0032-ggml-enable-MLA-flash-attention-for-GLM-4.7-flash.patch
@@ -0,0 +1,309 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: nobody <>
+Date: Sat, 24 Jan 2026 02:31:01 +0000
+Subject: [PATCH] ggml: enable MLA flash attention for GLM-4.7-flash
+
+Add support for gqa_ratio 4 in MLA flash attention kernels. GLM-4.7-flash
+uses head size 576 with gqa_ratio 4, which was previously only supported
+for gqa_ratio 16 (DeepSeek).
+
+Metal changes:
+- Enable head size 576 for flash attention
+- Increase simdgroups to 8 for large heads (>=512)
+- Add case 8 kernel dispatch for 8 simdgroups
+
+CUDA changes:
+- Add gqa_ratio 4 support for head 576/512
+- Add tile configs for (576, 512, 4) and (576, 512, 8)
+- Add MMA config cases for ncols 4
+- Add template instances for ncols2=4
+- Fix nbatch_fa values in nvidia_fp32 config (32->64)
+---
+ ggml/src/ggml-cuda/fattn-mma-f16.cuh | 40 +++++++++++++++----
+ ggml/src/ggml-cuda/fattn-tile.cuh | 16 ++++++++
+ ggml/src/ggml-cuda/fattn.cu | 12 ++++--
+ ...ttn-mma-f16-instance-ncols1_16-ncols2_4.cu | 1 +
+ ...attn-mma-f16-instance-ncols1_2-ncols2_4.cu | 1 +
+ ...attn-mma-f16-instance-ncols1_4-ncols2_4.cu | 1 +
+ ...attn-mma-f16-instance-ncols1_8-ncols2_4.cu | 1 +
+ ggml/src/ggml-metal/ggml-metal-device.m | 8 +---
+ ggml/src/ggml-metal/ggml-metal-ops.cpp | 2 +-
+ ggml/src/ggml-metal/ggml-metal.metal | 1 +
+ 10 files changed, 64 insertions(+), 19 deletions(-)
+
+diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+index 7bd1044c1..3dea2205e 100644
+--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
++++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
+
+- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
+
+- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
+@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
+ }
+
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
++ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
+@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int cols_per_warp = T_B_KQ::I;
+ constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
++ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
+@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+ }
+ }
+ } else {
+- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
+ #pragma unroll
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
+ load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
+@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
+ T_A_KQ K_A;
+ load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
+
+- // Wide version of KQ_C is column-major => swap A and B.
+- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
++ if constexpr (cols_per_warp == 8) {
++ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
++ } else {
++ // Wide version of KQ_C is column-major
++#if defined(AMD_WMMA_AVAILABLE)
++ // RDNA matrix C is column-major.
++ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
++#else
++ // swap A and B for CUDA.
++ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
++#endif // defined(AMD_WMMA_AVAILABLE)
++ }
+ }
+ }
+ }
+@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+
+ constexpr int cols_per_warp = T_B_KQ::I;
+ constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
+- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
++ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
+@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
+ NO_DEVICE_CODE;
+ return;
+ }
++#ifdef VOLTA_MMA_AVAILABLE
++ if (ncols1*ncols2 < 32) {
++ NO_DEVICE_CODE;
++ return;
++ }
++#endif // VOLTA_MMA_AVAILABLE
++
+ #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ if (ncols1*ncols2 > 32) {
+ NO_DEVICE_CODE;
+@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
++
++// For GLM 4.7 Flash
++extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
++extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
++extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh
+index 7c4d6fe67..371be7442 100644
+--- a/ggml/src/ggml-cuda/fattn-tile.cuh
++++ b/ggml/src/ggml-cuda/fattn-tile.cuh
+@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+
+ return 0;
+@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
+
+ return 0;
+@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
+
+@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
++ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
+
+@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
+ launch_fattn_tile_switch_ncols1(ctx, dst);
+ return;
+ }
++ if (use_gqa_opt && gqa_ratio % 8 == 0) {
++ launch_fattn_tile_switch_ncols1(ctx, dst);
++ return;
++ }
++ if (use_gqa_opt && gqa_ratio % 4 == 0) {
++ launch_fattn_tile_switch_ncols1(ctx, dst);
++ return;
++ }
+ }
+
+ if constexpr (DV <= 256) {
+diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
+index 015540666..1693479cb 100644
+--- a/ggml/src/ggml-cuda/fattn.cu
++++ b/ggml/src/ggml-cuda/fattn.cu
+@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
+ break;
+ case 576: {
+- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
++ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ GGML_ASSERT(V->ne[0] == 512);
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+- GGML_ASSERT(gqa_ratio % 16 == 0);
+- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
++ GGML_ASSERT(gqa_ratio % 4 == 0);
++ if (gqa_ratio % 16 == 0) {
++ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
++ } else {
++ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
++ }
+ } break;
+ default:
+ GGML_ABORT("fatal error");
+@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
+ if (V->ne[0] != 512) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
++ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+index 2074e954a..517993cb0 100644
+--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
+ DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
+ DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
+ DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
++DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
+diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+index 24c64cf00..97b19c67a 100644
+--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
+ DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
+ DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
+ DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
++DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
+diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+index 1ada657f1..989626dfa 100644
+--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
+ DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
+ DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
+ DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
++DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+index 86d4ffae2..173de7aac 100644
+--- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
++++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
+ DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
+ DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
+ DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
++DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
+index f24270bb1..7b5ee968c 100644
+--- a/ggml/src/ggml-metal/ggml-metal-device.m
++++ b/ggml/src/ggml-metal/ggml-metal-device.m
+@@ -1071,12 +1071,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
+ op->src[0]->ne[0] != 112 &&
+ op->src[0]->ne[0] != 128 &&
+ op->src[0]->ne[0] != 192 &&
+- op->src[0]->ne[0] != 256) {
+- return false;
+- }
+- if (op->src[0]->ne[0] == 576) {
+- // DeepSeek sizes
+- // TODO: disabled for now, until optmized
++ op->src[0]->ne[0] != 256 &&
++ op->src[0]->ne[0] != 576) {
+ return false;
+ }
+ if (op->src[1]->type != op->src[2]->type) {
+diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
+index e99c1763f..80864f303 100644
+--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
++++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
+@@ -2456,7 +2456,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
+
+ // simdgroups per threadgroup (a.k.a. warps)
+ //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
+- int32_t nsg = 4;
++ int32_t nsg = ne00 >= 512 ? 8 : 4;
+
+ const size_t smem = FATTN_SMEM(nsg);
+
+diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
+index c98d269d1..d33c16079 100644
+--- a/ggml/src/ggml-metal/ggml-metal.metal
++++ b/ggml/src/ggml-metal/ggml-metal.metal
+@@ -6166,6 +6166,7 @@ kernel void kernel_flash_attn_ext(
+ //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break;
+ //case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break;
+ case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break;
++ case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break;
+ }
+ #undef FWD_TMPL
+ #undef FWD_ARGS
diff --git a/llama/patches/0033-ggml-metal-solve_tri.patch b/llama/patches/0033-ggml-metal-solve_tri.patch
new file mode 100644
index 00000000000..7bc65fda791
--- /dev/null
+++ b/llama/patches/0033-ggml-metal-solve_tri.patch
@@ -0,0 +1,276 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Jeffrey Morgan
+Date: Tue, 3 Feb 2026 12:00:00 -0800
+Subject: [PATCH] ggml: metal solve_tri
+
+---
+ ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++++++
+ ggml/src/ggml-metal/ggml-metal-device.h | 1 +
+ ggml/src/ggml-metal/ggml-metal-device.m | 11 ++++
+ ggml/src/ggml-metal/ggml-metal-impl.h | 21 ++++++++
+ ggml/src/ggml-metal/ggml-metal-ops.cpp | 63 +++++++++++++++++++++++
+ ggml/src/ggml-metal/ggml-metal-ops.h | 1 +
+ ggml/src/ggml-metal/ggml-metal.metal | 60 +++++++++++++++++++++
+ 7 files changed, 177 insertions(+)
+
+diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
+index 680904d13..83385c9ef 100644
+--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
++++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
+@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
+ return res;
+ }
+
++ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
++ assert(op->op == GGML_OP_SOLVE_TRI);
++
++ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
++ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
++
++ char base[256];
++ char name[256];
++
++ snprintf(base, 256, "kernel_solve_tri_f32");
++ snprintf(name, 256, "%s", base);
++
++ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
++ if (!res.pipeline) {
++ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
++ }
++
++ return res;
++}
++
+ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_GROUP_NORM);
+
+diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
+index 0a8b9211a..8a9d17460 100644
+--- a/ggml/src/ggml-metal/ggml-metal-device.h
++++ b/ggml/src/ggml-metal/ggml-metal-device.h
+@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
++struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
+ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
+diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
+index 7b5ee968c..4e5acfbe5 100644
+--- a/ggml/src/ggml-metal/ggml-metal-device.m
++++ b/ggml/src/ggml-metal/ggml-metal-device.m
+@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
+ return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
+ case GGML_OP_L2_NORM:
+ return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
++ case GGML_OP_SOLVE_TRI:
++ return ggml_is_contiguous(op->src[0]) &&
++ ggml_is_contiguous(op->src[1]) &&
++ op->src[0]->type == GGML_TYPE_F32 &&
++ op->src[1]->type == GGML_TYPE_F32 &&
++ op->type == GGML_TYPE_F32;
++ case GGML_OP_COUNT_EQUAL:
++ return has_simdgroup_reduction &&
++ op->src[0]->type == GGML_TYPE_I32 &&
++ op->src[1]->type == GGML_TYPE_I32 &&
++ op->type == GGML_TYPE_I64;
+ case GGML_OP_ARGMAX:
+ return has_simdgroup_reduction;
+ case GGML_OP_NORM:
+diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
+index 8944b07e9..cfdea9c07 100644
+--- a/ggml/src/ggml-metal/ggml-metal-impl.h
++++ b/ggml/src/ggml-metal/ggml-metal-impl.h
+@@ -500,6 +500,27 @@ typedef struct {
+ float eps;
+ } ggml_metal_kargs_l2_norm;
+
++typedef struct {
++ int32_t ne00;
++ int32_t ne01;
++ int32_t ne02;
++ int32_t ne03;
++ uint64_t nb00;
++ uint64_t nb01;
++ uint64_t nb02;
++ uint64_t nb03;
++ int32_t ne10;
++ int32_t ne11;
++ uint64_t nb10;
++ uint64_t nb11;
++ uint64_t nb12;
++ uint64_t nb13;
++ uint64_t nb0;
++ uint64_t nb1;
++ uint64_t nb2;
++ uint64_t nb3;
++} ggml_metal_kargs_solve_tri;
++
+ typedef struct {
+ int64_t ne00;
+ int64_t ne01;
+diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
+index 80864f303..4ac135603 100644
+--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
++++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
+@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
+ {
+ n_fuse = ggml_metal_op_l2_norm(ctx, idx);
+ } break;
++ case GGML_OP_SOLVE_TRI:
++ {
++ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
++ } break;
+ case GGML_OP_GROUP_NORM:
+ {
+ n_fuse = ggml_metal_op_group_norm(ctx, idx);
+@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
+ return 1;
+ }
+
++int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
++ ggml_tensor * op = ctx->node(idx);
++
++ ggml_metal_library_t lib = ctx->lib;
++ ggml_metal_encoder_t enc = ctx->enc;
++
++ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
++ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
++ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
++ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
++ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
++ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
++
++ ggml_metal_kargs_solve_tri args = {
++ /*.ne00 =*/ ne00,
++ /*.ne01 =*/ ne01,
++ /*.ne02 =*/ ne02,
++ /*.ne03 =*/ ne03,
++ /*.nb00 =*/ nb00,
++ /*.nb01 =*/ nb01,
++ /*.nb02 =*/ nb02,
++ /*.nb03 =*/ nb03,
++ /*.ne10 =*/ ne10,
++ /*.ne11 =*/ ne11,
++ /*.nb10 =*/ nb10,
++ /*.nb11 =*/ nb11,
++ /*.nb12 =*/ nb12,
++ /*.nb13 =*/ nb13,
++ /*.nb0 =*/ nb0,
++ /*.nb1 =*/ nb1,
++ /*.nb2 =*/ nb2,
++ /*.nb3 =*/ nb3,
++ };
++
++ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
++
++ const int64_t ncols = ne10;
++ const int64_t n_batches = (int64_t)ne02 * ne03;
++ const int64_t nr = n_batches * ncols;
++
++ int nth = 64;
++ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
++ if (nth < 1) {
++ nth = 1;
++ }
++
++ const int64_t n_tg = (nr + nth - 1) / nth;
++
++ ggml_metal_encoder_set_pipeline(enc, pipeline);
++ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
++ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
++ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
++ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
++
++ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
++
++ return 1;
++}
++
+ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
+index 902b54452..a475183d3 100644
+--- a/ggml/src/ggml-metal/ggml-metal-ops.h
++++ b/ggml/src/ggml-metal/ggml-metal-ops.h
+@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
+ int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
+ int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
+ int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
++int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
+ int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
+ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
+ int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
+diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
+index d33c16079..c37447a10 100644
+--- a/ggml/src/ggml-metal/ggml-metal.metal
++++ b/ggml/src/ggml-metal/ggml-metal.metal
+@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
+ }
+ }
+
++kernel void kernel_solve_tri_f32(
++ constant ggml_metal_kargs_solve_tri & args,
++ device const char * src0,
++ device const char * src1,
++ device char * dst,
++ uint tgpig[[threadgroup_position_in_grid]],
++ ushort tpitg[[thread_position_in_threadgroup]],
++ ushort ntg[[threads_per_threadgroup]]) {
++ const uint64_t ncols = (uint64_t) args.ne10;
++ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
++ const uint64_t nr = n_batches * ncols;
++
++ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
++ if (gid >= nr) {
++ return;
++ }
++
++ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
++ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
++ const uint64_t i02 = rem / ncols;
++ const uint64_t i01 = rem - i02 * ncols;
++
++ const uint64_t sa0 = args.nb00 / sizeof(float);
++ const uint64_t sa1 = args.nb01 / sizeof(float);
++ const uint64_t sa2 = args.nb02 / sizeof(float);
++ const uint64_t sa3 = args.nb03 / sizeof(float);
++
++ const uint64_t sb0 = args.nb10 / sizeof(float);
++ const uint64_t sb1 = args.nb11 / sizeof(float);
++ const uint64_t sb2 = args.nb12 / sizeof(float);
++ const uint64_t sb3 = args.nb13 / sizeof(float);
++
++ const uint64_t sx0 = args.nb0 / sizeof(float);
++ const uint64_t sx1 = args.nb1 / sizeof(float);
++ const uint64_t sx2 = args.nb2 / sizeof(float);
++ const uint64_t sx3 = args.nb3 / sizeof(float);
++
++ device const float * A = (device const float *) src0;
++ device const float * B = (device const float *) src1;
++ device float * X = (device float *) dst;
++
++ const uint64_t A_base = i02 * sa2 + i03 * sa3;
++ const uint64_t B_base = i02 * sb2 + i03 * sb3;
++ const uint64_t X_base = i02 * sx2 + i03 * sx3;
++
++ const uint64_t n = (uint64_t) args.ne11;
++
++ for (uint64_t i00 = 0; i00 < n; ++i00) {
++ float sum = 0.0f;
++ for (uint64_t t = 0; t < i00; ++t) {
++ sum += A[A_base + i00 * sa1 + t * sa0] *
++ X[X_base + t * sx1 + i01 * sx0];
++ }
++
++ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
++ X[X_base + i00 * sx1 + i01 * sx0] =
++ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
++ }
++}
++
+ kernel void kernel_group_norm_f32(
+ constant ggml_metal_kargs_group_norm & args,
+ device const float * src0,
diff --git a/llm/server.go b/llm/server.go
index 846ef8ffa5a..6ecc17b646b 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -34,6 +34,7 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/tokenizer"
)
type filteredEnv []string
@@ -80,6 +81,7 @@ type LlamaServer interface {
GetPort() int
GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
HasExited() bool
+ ContextLength() int
}
// llmServer is an instance of a runner hosting a single model
@@ -116,7 +118,7 @@ type llamaServer struct {
type ollamaServer struct {
llmServer
- textProcessor model.TextProcessor // textProcessor handles text encoding/decoding
+ tokenizer tokenizer.Tokenizer // tokenizer handles text encoding/decoding
}
// LoadModel will load a model from disk. The model must be in the GGML format.
@@ -185,11 +187,11 @@ func LoadModel(model string, extraModels []string, maxArraySize int, reliefSplit
// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, extraModelPaths []string, f *ggml.MetaGGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
var llamaModel *llama.Model
- var textProcessor model.TextProcessor
+ var tok tokenizer.Tokenizer
var err error
if envconfig.NewEngine() || f.KV().OllamaEngineRequired() {
if len(projectors) == 0 {
- textProcessor, err = model.NewTextProcessor(modelPath)
+ tok, err = model.NewTextProcessor(modelPath)
} else {
err = errors.New("split vision models aren't supported")
}
@@ -228,7 +230,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}
- if textProcessor == nil {
+ if tok == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, extraModelPaths, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
@@ -284,7 +286,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
kvct := strings.ToLower(envconfig.KvCacheType())
- if textProcessor == nil {
+ if tok == nil {
flashAttention := ml.FlashAttentionAuto
if faUserSet {
if fa {
@@ -334,7 +336,7 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
gpuLibs := ml.LibraryPaths(gpus)
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
- textProcessor != nil,
+ tok != nil,
modelPath,
extraModelPaths,
gpuLibs,
@@ -385,8 +387,8 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
}
}()
- if textProcessor != nil {
- return &ollamaServer{llmServer: s, textProcessor: textProcessor}, nil
+ if tok != nil {
+ return &ollamaServer{llmServer: s, tokenizer: tok}, nil
} else {
return &llamaServer{llmServer: s, ggml: f}, nil
}
@@ -1302,7 +1304,8 @@ func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation Lo
resp, err := http.DefaultClient.Do(r)
if err != nil {
- return nil, fmt.Errorf("do load request: %w", err)
+ slog.Error("do load request", "error", err)
+ return nil, errors.New("model failed to load, this may be due to resource limitations or an internal error, check ollama server logs for details")
}
defer resp.Body.Close()
@@ -1874,7 +1877,7 @@ func (s *llamaServer) Tokenize(ctx context.Context, content string) ([]int, erro
}
func (s *ollamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
- tokens, err := s.textProcessor.Encode(content, false)
+ tokens, err := s.tokenizer.Encode(content, false)
if err != nil {
return nil, err
}
@@ -1909,7 +1912,7 @@ func (s *ollamaServer) Detokenize(ctx context.Context, tokens []int) (string, er
toks[i] = int32(t)
}
- content, err := s.textProcessor.Decode(toks)
+ content, err := s.tokenizer.Decode(toks)
if err != nil {
return "", err
}
@@ -2003,6 +2006,10 @@ func (s *llmServer) VRAMByGPU(id ml.DeviceID) uint64 {
return 0
}
+func (s *llmServer) ContextLength() int {
+ return s.options.NumCtx
+}
+
func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
devices, err := ml.GetDevicesFromRunner(ctx, s)
if err != nil {
diff --git a/server/layer.go b/manifest/layer.go
similarity index 88%
rename from server/layer.go
rename to manifest/layer.go
index 4baabe35cf1..82d44953aac 100644
--- a/server/layer.go
+++ b/manifest/layer.go
@@ -1,4 +1,4 @@
-package server
+package manifest
import (
"crypto/sha256"
@@ -14,7 +14,7 @@ type Layer struct {
Size int64 `json:"size"`
From string `json:"from,omitempty"`
Name string `json:"name,omitempty"` // tensor name, e.g., "text_encoder/model.embed_tokens.weight"
- status string
+ Status string `json:"-"`
}
const (
@@ -22,7 +22,7 @@ const (
)
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
- blobs, err := GetBlobsPath("")
+ blobs, err := BlobsPath("")
if err != nil {
return Layer{}, err
}
@@ -45,7 +45,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
- blob, err := GetBlobsPath(digest)
+ blob, err := BlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -65,7 +65,7 @@ func NewLayer(r io.Reader, mediatype string) (Layer, error) {
MediaType: mediatype,
Digest: digest,
Size: n,
- status: fmt.Sprintf("%s %s", status, digest),
+ Status: fmt.Sprintf("%s %s", status, digest),
}, nil
}
@@ -74,7 +74,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
- blob, err := GetBlobsPath(digest)
+ blob, err := BlobsPath(digest)
if err != nil {
return Layer{}, err
}
@@ -89,7 +89,7 @@ func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
Digest: digest,
Size: fi.Size(),
From: from,
- status: fmt.Sprintf("using existing layer %s", digest),
+ Status: fmt.Sprintf("using existing layer %s", digest),
}, nil
}
@@ -98,7 +98,7 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return nil, errors.New("opening layer with empty digest")
}
- blob, err := GetBlobsPath(l.Digest)
+ blob, err := BlobsPath(l.Digest)
if err != nil {
return nil, err
}
@@ -126,7 +126,7 @@ func (l *Layer) Remove() error {
}
}
- blob, err := GetBlobsPath(l.Digest)
+ blob, err := BlobsPath(l.Digest)
if err != nil {
return err
}
diff --git a/server/manifest.go b/manifest/manifest.go
similarity index 81%
rename from server/manifest.go
rename to manifest/manifest.go
index da596f6582f..c0277e9a572 100644
--- a/server/manifest.go
+++ b/manifest/manifest.go
@@ -1,10 +1,9 @@
-package server
+package manifest
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
- "errors"
"fmt"
"io"
"log/slog"
@@ -33,12 +32,38 @@ func (m *Manifest) Size() (size int64) {
return
}
+func (m *Manifest) Digest() string {
+ return m.digest
+}
+
+func (m *Manifest) FileInfo() os.FileInfo {
+ return m.fi
+}
+
+// ReadConfigJSON reads and unmarshals a config layer as JSON.
+func (m *Manifest) ReadConfigJSON(configPath string, v any) error {
+ for _, layer := range m.Layers {
+ if layer.MediaType == "application/vnd.ollama.image.json" && layer.Name == configPath {
+ blobPath, err := BlobsPath(layer.Digest)
+ if err != nil {
+ return err
+ }
+ data, err := os.ReadFile(blobPath)
+ if err != nil {
+ return err
+ }
+ return json.Unmarshal(data, v)
+ }
+ }
+ return fmt.Errorf("config %q not found in manifest", configPath)
+}
+
func (m *Manifest) Remove() error {
if err := os.Remove(m.filepath); err != nil {
return err
}
- manifests, err := GetManifestPath()
+ manifests, err := Path()
if err != nil {
return err
}
@@ -70,11 +95,11 @@ func (m *Manifest) RemoveLayers() error {
if _, used := inUse[layer.Digest]; used {
continue
}
- blob, err := GetBlobsPath(layer.Digest)
+ blob, err := BlobsPath(layer.Digest)
if err != nil {
return err
}
- if err := os.Remove(blob); errors.Is(err, os.ErrNotExist) {
+ if err := os.Remove(blob); os.IsNotExist(err) {
slog.Debug("layer does not exist", "digest", layer.Digest)
} else if err != nil {
return err
@@ -89,7 +114,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, model.Unqualified(n)
}
- manifests, err := GetManifestPath()
+ manifests, err := Path()
if err != nil {
return nil, err
}
@@ -121,7 +146,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
}
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
- manifests, err := GetManifestPath()
+ manifests, err := Path()
if err != nil {
return err
}
@@ -148,7 +173,7 @@ func WriteManifest(name model.Name, config Layer, layers []Layer) error {
}
func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) {
- manifests, err := GetManifestPath()
+ manifests, err := Path()
if err != nil {
return nil, err
}
diff --git a/server/manifest_test.go b/manifest/manifest_test.go
similarity index 99%
rename from server/manifest_test.go
rename to manifest/manifest_test.go
index d94deefb443..9eb83789ec2 100644
--- a/server/manifest_test.go
+++ b/manifest/manifest_test.go
@@ -1,4 +1,4 @@
-package server
+package manifest
import (
"encoding/json"
diff --git a/manifest/paths.go b/manifest/paths.go
new file mode 100644
index 00000000000..4451c81aa12
--- /dev/null
+++ b/manifest/paths.go
@@ -0,0 +1,95 @@
+package manifest
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "path/filepath"
+ "regexp"
+ "strings"
+
+ "github.com/ollama/ollama/envconfig"
+ "github.com/ollama/ollama/types/model"
+)
+
+var ErrInvalidDigestFormat = errors.New("invalid digest format")
+
+func Path() (string, error) {
+ path := filepath.Join(envconfig.Models(), "manifests")
+ if err := os.MkdirAll(path, 0o755); err != nil {
+ return "", fmt.Errorf("%w: ensure path elements are traversable", err)
+ }
+
+ return path, nil
+}
+
+// PathForName returns the path to the manifest file for a specific model name.
+func PathForName(n model.Name) (string, error) {
+ if !n.IsValid() {
+ return "", os.ErrNotExist
+ }
+
+ manifests, err := Path()
+ if err != nil {
+ return "", err
+ }
+
+ return filepath.Join(manifests, n.Filepath()), nil
+}
+
+func BlobsPath(digest string) (string, error) {
+ // only accept actual sha256 digests
+ pattern := "^sha256[:-][0-9a-fA-F]{64}$"
+ re := regexp.MustCompile(pattern)
+
+ if digest != "" && !re.MatchString(digest) {
+ return "", ErrInvalidDigestFormat
+ }
+
+ digest = strings.ReplaceAll(digest, ":", "-")
+ path := filepath.Join(envconfig.Models(), "blobs", digest)
+ dirPath := filepath.Dir(path)
+ if digest == "" {
+ dirPath = path
+ }
+
+ if err := os.MkdirAll(dirPath, 0o755); err != nil {
+ return "", fmt.Errorf("%w: ensure path elements are traversable", err)
+ }
+
+ return path, nil
+}
+
+// PruneDirectory removes empty directories recursively.
+func PruneDirectory(path string) error {
+ info, err := os.Lstat(path)
+ if err != nil {
+ return err
+ }
+
+ if info.IsDir() && info.Mode()&os.ModeSymlink == 0 {
+ entries, err := os.ReadDir(path)
+ if err != nil {
+ return err
+ }
+
+ for _, entry := range entries {
+ if err := PruneDirectory(filepath.Join(path, entry.Name())); err != nil {
+ return err
+ }
+ }
+
+ entries, err = os.ReadDir(path)
+ if err != nil {
+ return err
+ }
+
+ if len(entries) > 0 {
+ return nil
+ }
+
+ return os.Remove(path)
+ }
+
+ return nil
+}
diff --git a/middleware/anthropic.go b/middleware/anthropic.go
index ff55b6ebfcf..85c95e60c07 100644
--- a/middleware/anthropic.go
+++ b/middleware/anthropic.go
@@ -2,15 +2,22 @@ package middleware
import (
"bytes"
+ "context"
"encoding/json"
"fmt"
"io"
+ "log/slog"
"net/http"
+ "strings"
+ "time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/anthropic"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/envconfig"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
+ "github.com/ollama/ollama/logutil"
)
// AnthropicWriter wraps the response writer to transform Ollama responses to Anthropic format
@@ -18,7 +25,6 @@ type AnthropicWriter struct {
BaseWriter
stream bool
id string
- model string
converter *anthropic.StreamConverter
}
@@ -31,7 +37,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
- err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.ResponseWriter.Status(), errData.Error))
+ err := json.NewEncoder(w.ResponseWriter).Encode(anthropic.NewError(w.Status(), errData.Error))
if err != nil {
return 0, err
}
@@ -40,18 +46,7 @@ func (w *AnthropicWriter) writeError(data []byte) (int, error) {
}
func (w *AnthropicWriter) writeEvent(eventType string, data any) error {
- d, err := json.Marshal(data)
- if err != nil {
- return err
- }
- _, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("event: %s\ndata: %s\n\n", eventType, d)))
- if err != nil {
- return err
- }
- if f, ok := w.ResponseWriter.(http.Flusher); ok {
- f.Flush()
- }
- return nil
+ return writeSSE(w.ResponseWriter, eventType, data)
}
func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
@@ -65,6 +60,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
events := w.converter.Process(chatResponse)
+ logutil.Trace("anthropic middleware: stream chunk", "resp", anthropic.TraceChatResponse(chatResponse), "events", len(events))
for _, event := range events {
if err := w.writeEvent(event.Event, event.Data); err != nil {
return 0, err
@@ -75,6 +71,7 @@ func (w *AnthropicWriter) writeResponse(data []byte) (int, error) {
w.ResponseWriter.Header().Set("Content-Type", "application/json")
response := anthropic.ToMessagesResponse(w.id, chatResponse)
+ logutil.Trace("anthropic middleware: converted response", "resp", anthropic.TraceMessagesResponse(response))
return len(data), json.NewEncoder(w.ResponseWriter).Encode(response)
}
@@ -87,9 +84,743 @@ func (w *AnthropicWriter) Write(data []byte) (int, error) {
return w.writeResponse(data)
}
+// WebSearchAnthropicWriter intercepts responses containing web_search tool calls,
+// executes the search, re-invokes the model with results, and assembles the
+// Anthropic-format response (server_tool_use + web_search_tool_result + text).
+type WebSearchAnthropicWriter struct {
+ BaseWriter
+ newLoopContext func() (context.Context, context.CancelFunc)
+ inner *AnthropicWriter
+ req anthropic.MessagesRequest // original Anthropic request
+ chatReq *api.ChatRequest // converted Ollama request (for followup calls)
+ stream bool
+
+ estimatedInputTokens int
+
+ terminalSent bool
+
+ observedPromptEvalCount int
+ observedEvalCount int
+
+ loopInFlight bool
+ loopBaseInputTok int
+ loopBaseOutputTok int
+ loopResultCh chan webSearchLoopResult
+
+ streamMessageStarted bool
+ streamHasOpenBlock bool
+ streamOpenBlockIndex int
+ streamNextIndex int
+}
+
+const maxWebSearchLoops = 3
+
+type webSearchLoopResult struct {
+ response anthropic.MessagesResponse
+ loopErr *webSearchLoopError
+}
+
+type webSearchLoopError struct {
+ code string
+ query string
+ usage anthropic.Usage
+ err error
+}
+
+func (e *webSearchLoopError) Error() string {
+ if e.err == nil {
+ return e.code
+ }
+ return fmt.Sprintf("%s: %v", e.code, e.err)
+}
+
+func (w *WebSearchAnthropicWriter) Write(data []byte) (int, error) {
+ if w.terminalSent {
+ return len(data), nil
+ }
+
+ code := w.Status()
+ if code != http.StatusOK {
+ return w.inner.writeError(data)
+ }
+
+ var chatResponse api.ChatResponse
+ if err := json.Unmarshal(data, &chatResponse); err != nil {
+ return 0, err
+ }
+ w.recordObservedUsage(chatResponse.Metrics)
+
+ if w.stream && w.loopInFlight {
+ if !chatResponse.Done {
+ return len(data), nil
+ }
+ if err := w.writeLoopResult(); err != nil {
+ return len(data), err
+ }
+ return len(data), nil
+ }
+
+ webSearchCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(chatResponse.Message.ToolCalls)
+ logutil.Trace("anthropic middleware: upstream chunk",
+ "resp", anthropic.TraceChatResponse(chatResponse),
+ "web_search", hasWebSearch,
+ "other_tools", hasOtherTools,
+ )
+ if hasWebSearch && hasOtherTools {
+ // Prefer web_search if both server and client tools are present in one chunk.
+ slog.Debug("preferring web_search tool call over client tool calls in mixed tool response")
+ }
+
+ if !hasWebSearch {
+ if w.stream {
+ if err := w.writePassthroughStreamChunk(chatResponse); err != nil {
+ return 0, err
+ }
+ return len(data), nil
+ }
+ return w.inner.writeResponse(data)
+ }
+
+ if w.stream {
+ // Let the original generation continue to completion while web search runs in parallel.
+ logutil.Trace("anthropic middleware: starting async web_search loop",
+ "tool_call", anthropic.TraceToolCall(webSearchCall),
+ "resp", anthropic.TraceChatResponse(chatResponse),
+ )
+ w.startLoopWorker(chatResponse, webSearchCall)
+ if chatResponse.Done {
+ if err := w.writeLoopResult(); err != nil {
+ return len(data), err
+ }
+ }
+ return len(data), nil
+ }
+
+ loopCtx, cancel := w.startLoopContext()
+ defer cancel()
+
+ initialUsage := anthropic.Usage{
+ InputTokens: max(w.observedPromptEvalCount, chatResponse.Metrics.PromptEvalCount),
+ OutputTokens: max(w.observedEvalCount, chatResponse.Metrics.EvalCount),
+ }
+ logutil.Trace("anthropic middleware: starting sync web_search loop",
+ "tool_call", anthropic.TraceToolCall(webSearchCall),
+ "resp", anthropic.TraceChatResponse(chatResponse),
+ "usage", initialUsage,
+ )
+ response, loopErr := w.runWebSearchLoop(loopCtx, chatResponse, webSearchCall, initialUsage)
+ if loopErr != nil {
+ return len(data), w.sendError(loopErr.code, loopErr.query, loopErr.usage)
+ }
+
+ if err := w.writeTerminalResponse(response); err != nil {
+ return 0, err
+ }
+
+ return len(data), nil
+}
+
+func (w *WebSearchAnthropicWriter) runWebSearchLoop(ctx context.Context, initialResponse api.ChatResponse, initialToolCall api.ToolCall, initialUsage anthropic.Usage) (anthropic.MessagesResponse, *webSearchLoopError) {
+ followUpMessages := make([]api.Message, 0, len(w.chatReq.Messages)+maxWebSearchLoops*2)
+ followUpMessages = append(followUpMessages, w.chatReq.Messages...)
+
+ followUpTools := append(api.Tools(nil), w.chatReq.Tools...)
+ usage := initialUsage
+ logutil.TraceContext(ctx, "anthropic middleware: web_search loop init",
+ "model", w.req.Model,
+ "tool_call", anthropic.TraceToolCall(initialToolCall),
+ "messages", len(followUpMessages),
+ "tools", len(followUpTools),
+ "max_loops", maxWebSearchLoops,
+ )
+
+ currentResponse := initialResponse
+ currentToolCall := initialToolCall
+
+ var serverContent []anthropic.ContentBlock
+
+ if !isCloudModelName(w.req.Model) {
+ logutil.TraceContext(ctx, "anthropic middleware: web_search execution blocked", "reason", "non_cloud_model")
+ return anthropic.MessagesResponse{}, &webSearchLoopError{
+ code: "web_search_not_supported_for_local_models",
+ query: extractQueryFromToolCall(&initialToolCall),
+ usage: usage,
+ }
+ }
+
+ for loop := 1; loop <= maxWebSearchLoops; loop++ {
+ query := extractQueryFromToolCall(¤tToolCall)
+ logutil.TraceContext(ctx, "anthropic middleware: web_search loop iteration",
+ "loop", loop,
+ "query", anthropic.TraceTruncateString(query),
+ "messages", len(followUpMessages),
+ )
+ if query == "" {
+ return anthropic.MessagesResponse{}, &webSearchLoopError{
+ code: "invalid_request",
+ query: "",
+ usage: usage,
+ }
+ }
+
+ const defaultMaxResults = 5
+ searchResp, err := anthropic.WebSearch(ctx, query, defaultMaxResults)
+ if err != nil {
+ logutil.TraceContext(ctx, "anthropic middleware: web_search request failed",
+ "loop", loop,
+ "query", query,
+ "error", err,
+ )
+ return anthropic.MessagesResponse{}, &webSearchLoopError{
+ code: "unavailable",
+ query: query,
+ usage: usage,
+ err: err,
+ }
+ }
+ logutil.TraceContext(ctx, "anthropic middleware: web_search results",
+ "loop", loop,
+ "results", len(searchResp.Results),
+ )
+
+ toolUseID := loopServerToolUseID(w.inner.id, loop)
+ searchResults := anthropic.ConvertOllamaToAnthropicResults(searchResp)
+ serverContent = append(serverContent,
+ anthropic.ContentBlock{
+ Type: "server_tool_use",
+ ID: toolUseID,
+ Name: "web_search",
+ Input: map[string]any{"query": query},
+ },
+ anthropic.ContentBlock{
+ Type: "web_search_tool_result",
+ ToolUseID: toolUseID,
+ Content: searchResults,
+ },
+ )
+
+ assistantMsg := buildWebSearchAssistantMessage(currentResponse, currentToolCall)
+ toolResultMsg := api.Message{
+ Role: "tool",
+ Content: formatWebSearchResultsForToolMessage(searchResp.Results),
+ ToolCallID: currentToolCall.ID,
+ }
+ followUpMessages = append(followUpMessages, assistantMsg, toolResultMsg)
+
+ followUpResponse, err := w.callFollowUpChat(ctx, followUpMessages, followUpTools)
+ if err != nil {
+ logutil.TraceContext(ctx, "anthropic middleware: followup /api/chat failed",
+ "loop", loop,
+ "query", query,
+ "error", err,
+ )
+ return anthropic.MessagesResponse{}, &webSearchLoopError{
+ code: "api_error",
+ query: query,
+ usage: usage,
+ err: err,
+ }
+ }
+ logutil.TraceContext(ctx, "anthropic middleware: followup response",
+ "loop", loop,
+ "resp", anthropic.TraceChatResponse(followUpResponse),
+ )
+
+ usage.InputTokens += followUpResponse.Metrics.PromptEvalCount
+ usage.OutputTokens += followUpResponse.Metrics.EvalCount
+
+ nextToolCall, hasWebSearch, hasOtherTools := findWebSearchToolCall(followUpResponse.Message.ToolCalls)
+ if hasWebSearch && hasOtherTools {
+ // Prefer web_search if both server and client tools are present in one chunk.
+ slog.Debug("preferring web_search tool call over client tool calls in mixed followup response")
+ }
+
+ if !hasWebSearch {
+ finalResponse := w.combineServerAndFinalContent(serverContent, followUpResponse, usage)
+ logutil.TraceContext(ctx, "anthropic middleware: web_search loop complete",
+ "loop", loop,
+ "resp", anthropic.TraceMessagesResponse(finalResponse),
+ )
+ return finalResponse, nil
+ }
+
+ currentResponse = followUpResponse
+ currentToolCall = nextToolCall
+ }
+
+ maxLoopQuery := extractQueryFromToolCall(¤tToolCall)
+ maxLoopToolUseID := loopServerToolUseID(w.inner.id, maxWebSearchLoops+1)
+ serverContent = append(serverContent,
+ anthropic.ContentBlock{
+ Type: "server_tool_use",
+ ID: maxLoopToolUseID,
+ Name: "web_search",
+ Input: map[string]any{"query": maxLoopQuery},
+ },
+ anthropic.ContentBlock{
+ Type: "web_search_tool_result",
+ ToolUseID: maxLoopToolUseID,
+ Content: anthropic.WebSearchToolResultError{
+ Type: "web_search_tool_result_error",
+ ErrorCode: "max_uses_exceeded",
+ },
+ },
+ )
+
+ maxResponse := anthropic.MessagesResponse{
+ ID: w.inner.id,
+ Type: "message",
+ Role: "assistant",
+ Model: w.req.Model,
+ Content: serverContent,
+ StopReason: "end_turn",
+ Usage: usage,
+ }
+ logutil.TraceContext(ctx, "anthropic middleware: web_search loop max reached",
+ "resp", anthropic.TraceMessagesResponse(maxResponse),
+ )
+ return maxResponse, nil
+}
+
+func (w *WebSearchAnthropicWriter) startLoopWorker(initialResponse api.ChatResponse, initialToolCall api.ToolCall) {
+ if w.loopInFlight {
+ return
+ }
+
+ initialUsage := anthropic.Usage{
+ InputTokens: max(w.observedPromptEvalCount, initialResponse.Metrics.PromptEvalCount),
+ OutputTokens: max(w.observedEvalCount, initialResponse.Metrics.EvalCount),
+ }
+ w.loopBaseInputTok = initialUsage.InputTokens
+ w.loopBaseOutputTok = initialUsage.OutputTokens
+ w.loopResultCh = make(chan webSearchLoopResult, 1)
+ w.loopInFlight = true
+ logutil.Trace("anthropic middleware: loop worker started",
+ "usage", initialUsage,
+ "tool_call", anthropic.TraceToolCall(initialToolCall),
+ )
+
+ go func() {
+ ctx, cancel := w.startLoopContext()
+ defer cancel()
+
+ response, loopErr := w.runWebSearchLoop(ctx, initialResponse, initialToolCall, initialUsage)
+ w.loopResultCh <- webSearchLoopResult{
+ response: response,
+ loopErr: loopErr,
+ }
+ }()
+}
+
+func (w *WebSearchAnthropicWriter) writeLoopResult() error {
+ if w.loopResultCh == nil {
+ return w.sendError("api_error", "", w.currentObservedUsage())
+ }
+
+ result := <-w.loopResultCh
+ w.loopResultCh = nil
+ w.loopInFlight = false
+ if result.loopErr != nil {
+ logutil.Trace("anthropic middleware: loop worker returned error",
+ "code", result.loopErr.code,
+ "query", result.loopErr.query,
+ "usage", result.loopErr.usage,
+ "error", result.loopErr.err,
+ )
+ usage := result.loopErr.usage
+ w.applyObservedUsageDeltaToUsage(&usage)
+ return w.sendError(result.loopErr.code, result.loopErr.query, usage)
+ }
+ logutil.Trace("anthropic middleware: loop worker done", "resp", anthropic.TraceMessagesResponse(result.response))
+
+ w.applyObservedUsageDelta(&result.response)
+ return w.writeTerminalResponse(result.response)
+}
+
+func (w *WebSearchAnthropicWriter) applyObservedUsageDelta(response *anthropic.MessagesResponse) {
+ w.applyObservedUsageDeltaToUsage(&response.Usage)
+}
+
+func (w *WebSearchAnthropicWriter) recordObservedUsage(metrics api.Metrics) {
+ if metrics.PromptEvalCount > w.observedPromptEvalCount {
+ w.observedPromptEvalCount = metrics.PromptEvalCount
+ }
+ if metrics.EvalCount > w.observedEvalCount {
+ w.observedEvalCount = metrics.EvalCount
+ }
+}
+
+func (w *WebSearchAnthropicWriter) applyObservedUsageDeltaToUsage(usage *anthropic.Usage) {
+ if deltaIn := w.observedPromptEvalCount - w.loopBaseInputTok; deltaIn > 0 {
+ usage.InputTokens += deltaIn
+ }
+ if deltaOut := w.observedEvalCount - w.loopBaseOutputTok; deltaOut > 0 {
+ usage.OutputTokens += deltaOut
+ }
+}
+
+func (w *WebSearchAnthropicWriter) currentObservedUsage() anthropic.Usage {
+ return anthropic.Usage{
+ InputTokens: w.observedPromptEvalCount,
+ OutputTokens: w.observedEvalCount,
+ }
+}
+
+func (w *WebSearchAnthropicWriter) startLoopContext() (context.Context, context.CancelFunc) {
+ if w.newLoopContext != nil {
+ return w.newLoopContext()
+ }
+ return context.WithTimeout(context.Background(), 5*time.Minute)
+}
+
+func (w *WebSearchAnthropicWriter) combineServerAndFinalContent(serverContent []anthropic.ContentBlock, finalResponse api.ChatResponse, usage anthropic.Usage) anthropic.MessagesResponse {
+ converted := anthropic.ToMessagesResponse(w.inner.id, finalResponse)
+
+ content := make([]anthropic.ContentBlock, 0, len(serverContent)+len(converted.Content))
+ content = append(content, serverContent...)
+ content = append(content, converted.Content...)
+
+ return anthropic.MessagesResponse{
+ ID: w.inner.id,
+ Type: "message",
+ Role: "assistant",
+ Model: w.req.Model,
+ Content: content,
+ StopReason: converted.StopReason,
+ StopSequence: converted.StopSequence,
+ Usage: usage,
+ }
+}
+
+func buildWebSearchAssistantMessage(response api.ChatResponse, webSearchCall api.ToolCall) api.Message {
+ assistantMsg := api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{webSearchCall},
+ }
+ if response.Message.Content != "" {
+ assistantMsg.Content = response.Message.Content
+ }
+ if response.Message.Thinking != "" {
+ assistantMsg.Thinking = response.Message.Thinking
+ }
+ return assistantMsg
+}
+
+func formatWebSearchResultsForToolMessage(results []anthropic.OllamaWebSearchResult) string {
+ var resultText strings.Builder
+ for _, r := range results {
+ fmt.Fprintf(&resultText, "Title: %s\nURL: %s\n", r.Title, r.URL)
+ if r.Content != "" {
+ fmt.Fprintf(&resultText, "Content: %s\n", r.Content)
+ }
+ resultText.WriteString("\n")
+ }
+ return resultText.String()
+}
+
+func findWebSearchToolCall(toolCalls []api.ToolCall) (api.ToolCall, bool, bool) {
+ var webSearchCall api.ToolCall
+ hasWebSearch := false
+ hasOtherTools := false
+
+ for _, toolCall := range toolCalls {
+ if toolCall.Function.Name == "web_search" {
+ if !hasWebSearch {
+ webSearchCall = toolCall
+ hasWebSearch = true
+ }
+ continue
+ }
+ hasOtherTools = true
+ }
+
+ return webSearchCall, hasWebSearch, hasOtherTools
+}
+
+func loopServerToolUseID(messageID string, loop int) string {
+ base := serverToolUseID(messageID)
+ if loop <= 1 {
+ return base
+ }
+ return fmt.Sprintf("%s_%d", base, loop)
+}
+
+func (w *WebSearchAnthropicWriter) callFollowUpChat(ctx context.Context, messages []api.Message, tools api.Tools) (api.ChatResponse, error) {
+ streaming := false
+ followUp := api.ChatRequest{
+ Model: w.chatReq.Model,
+ Messages: messages,
+ Stream: &streaming,
+ Tools: tools,
+ Options: w.chatReq.Options,
+ }
+
+ body, err := json.Marshal(followUp)
+ if err != nil {
+ return api.ChatResponse{}, err
+ }
+
+ chatURL := envconfig.Host().String() + "/api/chat"
+ logutil.TraceContext(ctx, "anthropic middleware: followup request",
+ "url", chatURL,
+ "req", anthropic.TraceChatRequest(&followUp),
+ )
+ httpReq, err := http.NewRequestWithContext(ctx, "POST", chatURL, bytes.NewReader(body))
+ if err != nil {
+ return api.ChatResponse{}, err
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+
+ resp, err := http.DefaultClient.Do(httpReq)
+ if err != nil {
+ return api.ChatResponse{}, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ respBody, _ := io.ReadAll(resp.Body)
+ logutil.TraceContext(ctx, "anthropic middleware: followup non-200 response",
+ "status", resp.StatusCode,
+ "response", strings.TrimSpace(string(respBody)),
+ )
+ return api.ChatResponse{}, fmt.Errorf("followup /api/chat returned status %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
+ }
+
+ var chatResp api.ChatResponse
+ if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
+ return api.ChatResponse{}, err
+ }
+ logutil.TraceContext(ctx, "anthropic middleware: followup decoded", "resp", anthropic.TraceChatResponse(chatResp))
+
+ return chatResp, nil
+}
+
+func (w *WebSearchAnthropicWriter) writePassthroughStreamChunk(chatResponse api.ChatResponse) error {
+ events := w.inner.converter.Process(chatResponse)
+ for _, event := range events {
+ switch e := event.Data.(type) {
+ case anthropic.MessageStartEvent:
+ w.streamMessageStarted = true
+ case anthropic.ContentBlockStartEvent:
+ w.streamHasOpenBlock = true
+ w.streamOpenBlockIndex = e.Index
+ if e.Index+1 > w.streamNextIndex {
+ w.streamNextIndex = e.Index + 1
+ }
+ case anthropic.ContentBlockStopEvent:
+ if w.streamHasOpenBlock && w.streamOpenBlockIndex == e.Index {
+ w.streamHasOpenBlock = false
+ }
+ if e.Index+1 > w.streamNextIndex {
+ w.streamNextIndex = e.Index + 1
+ }
+ case anthropic.MessageStopEvent:
+ w.terminalSent = true
+ }
+
+ if err := writeSSE(w.ResponseWriter, event.Event, event.Data); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (w *WebSearchAnthropicWriter) ensureStreamMessageStart(usage anthropic.Usage) error {
+ if w.streamMessageStarted {
+ return nil
+ }
+
+ inputTokens := usage.InputTokens
+ if inputTokens == 0 {
+ inputTokens = w.estimatedInputTokens
+ }
+
+ if err := writeSSE(w.ResponseWriter, "message_start", anthropic.MessageStartEvent{
+ Type: "message_start",
+ Message: anthropic.MessagesResponse{
+ ID: w.inner.id,
+ Type: "message",
+ Role: "assistant",
+ Model: w.req.Model,
+ Content: []anthropic.ContentBlock{},
+ Usage: anthropic.Usage{
+ InputTokens: inputTokens,
+ },
+ },
+ }); err != nil {
+ return err
+ }
+
+ w.streamMessageStarted = true
+ return nil
+}
+
+func (w *WebSearchAnthropicWriter) closeOpenStreamBlock() error {
+ if !w.streamHasOpenBlock {
+ return nil
+ }
+
+ if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
+ Type: "content_block_stop",
+ Index: w.streamOpenBlockIndex,
+ }); err != nil {
+ return err
+ }
+
+ if w.streamOpenBlockIndex+1 > w.streamNextIndex {
+ w.streamNextIndex = w.streamOpenBlockIndex + 1
+ }
+ w.streamHasOpenBlock = false
+ return nil
+}
+
+func (w *WebSearchAnthropicWriter) writeStreamContentBlocks(content []anthropic.ContentBlock) error {
+ for _, block := range content {
+ index := w.streamNextIndex
+ if block.Type == "text" {
+ emptyText := ""
+ if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
+ Type: "content_block_start",
+ Index: index,
+ ContentBlock: anthropic.ContentBlock{
+ Type: "text",
+ Text: &emptyText,
+ },
+ }); err != nil {
+ return err
+ }
+
+ text := ""
+ if block.Text != nil {
+ text = *block.Text
+ }
+ if err := writeSSE(w.ResponseWriter, "content_block_delta", anthropic.ContentBlockDeltaEvent{
+ Type: "content_block_delta",
+ Index: index,
+ Delta: anthropic.Delta{
+ Type: "text_delta",
+ Text: text,
+ },
+ }); err != nil {
+ return err
+ }
+ } else {
+ if err := writeSSE(w.ResponseWriter, "content_block_start", anthropic.ContentBlockStartEvent{
+ Type: "content_block_start",
+ Index: index,
+ ContentBlock: block,
+ }); err != nil {
+ return err
+ }
+ }
+
+ if err := writeSSE(w.ResponseWriter, "content_block_stop", anthropic.ContentBlockStopEvent{
+ Type: "content_block_stop",
+ Index: index,
+ }); err != nil {
+ return err
+ }
+
+ w.streamNextIndex++
+ }
+
+ return nil
+}
+
+func (w *WebSearchAnthropicWriter) writeTerminalResponse(response anthropic.MessagesResponse) error {
+ if w.terminalSent {
+ return nil
+ }
+
+ if !w.stream {
+ w.ResponseWriter.Header().Set("Content-Type", "application/json")
+ if err := json.NewEncoder(w.ResponseWriter).Encode(response); err != nil {
+ return err
+ }
+ w.terminalSent = true
+ return nil
+ }
+
+ if err := w.ensureStreamMessageStart(response.Usage); err != nil {
+ return err
+ }
+ if err := w.closeOpenStreamBlock(); err != nil {
+ return err
+ }
+ if err := w.writeStreamContentBlocks(response.Content); err != nil {
+ return err
+ }
+
+ if err := writeSSE(w.ResponseWriter, "message_delta", anthropic.MessageDeltaEvent{
+ Type: "message_delta",
+ Delta: anthropic.MessageDelta{
+ StopReason: response.StopReason,
+ },
+ Usage: anthropic.DeltaUsage{
+ InputTokens: response.Usage.InputTokens,
+ OutputTokens: response.Usage.OutputTokens,
+ },
+ }); err != nil {
+ return err
+ }
+
+ if err := writeSSE(w.ResponseWriter, "message_stop", anthropic.MessageStopEvent{
+ Type: "message_stop",
+ }); err != nil {
+ return err
+ }
+
+ w.terminalSent = true
+ return nil
+}
+
+// streamResponse emits a complete MessagesResponse as SSE events.
+func (w *WebSearchAnthropicWriter) streamResponse(response anthropic.MessagesResponse) error {
+ return w.writeTerminalResponse(response)
+}
+
+func (w *WebSearchAnthropicWriter) webSearchErrorResponse(errorCode, query string, usage anthropic.Usage) anthropic.MessagesResponse {
+ toolUseID := serverToolUseID(w.inner.id)
+
+ return anthropic.MessagesResponse{
+ ID: w.inner.id,
+ Type: "message",
+ Role: "assistant",
+ Model: w.req.Model,
+ Content: []anthropic.ContentBlock{
+ {
+ Type: "server_tool_use",
+ ID: toolUseID,
+ Name: "web_search",
+ Input: map[string]any{"query": query},
+ },
+ {
+ Type: "web_search_tool_result",
+ ToolUseID: toolUseID,
+ Content: anthropic.WebSearchToolResultError{
+ Type: "web_search_tool_result_error",
+ ErrorCode: errorCode,
+ },
+ },
+ },
+ StopReason: "end_turn",
+ Usage: usage,
+ }
+}
+
+// sendError sends a web search error response.
+func (w *WebSearchAnthropicWriter) sendError(errorCode, query string, usage anthropic.Usage) error {
+ response := w.webSearchErrorResponse(errorCode, query, usage)
+ logutil.Trace("anthropic middleware: web_search error", "code", errorCode, "query", query, "usage", usage)
+ return w.writeTerminalResponse(response)
+}
+
// AnthropicMessagesMiddleware handles Anthropic Messages API requests
func AnthropicMessagesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
+ requestCtx := c.Request.Context()
+
var req anthropic.MessagesRequest
err := c.ShouldBindJSON(&req)
if err != nil {
@@ -131,12 +862,14 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
messageID := anthropic.GenerateMessageID()
- w := &AnthropicWriter{
+ // Estimate input tokens for streaming (actual count not available until generation completes)
+ estimatedTokens := anthropic.EstimateInputTokens(req)
+
+ innerWriter := &AnthropicWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: messageID,
- model: req.Model,
- converter: anthropic.NewStreamConverter(messageID, req.Model),
+ converter: anthropic.NewStreamConverter(messageID, req.Model, estimatedTokens),
}
if req.Stream {
@@ -145,8 +878,78 @@ func AnthropicMessagesMiddleware() gin.HandlerFunc {
c.Writer.Header().Set("Connection", "keep-alive")
}
- c.Writer = w
+ if hasWebSearchTool(req.Tools) {
+ // Guard against runtime cloud-disable policy (OLLAMA_NO_CLOUD/server.json)
+ // for cloud models. Local models may still receive web_search tool definitions;
+ // execution is validated when the model actually emits a web_search tool call.
+ if isCloudModelName(req.Model) {
+ if disabled, _ := internalcloud.Status(); disabled {
+ c.AbortWithStatusJSON(http.StatusForbidden, anthropic.NewError(http.StatusForbidden, internalcloud.DisabledError("web search is unavailable")))
+ return
+ }
+ }
+
+ c.Writer = &WebSearchAnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+ newLoopContext: func() (context.Context, context.CancelFunc) {
+ return context.WithTimeout(requestCtx, 5*time.Minute)
+ },
+ inner: innerWriter,
+ req: req,
+ chatReq: chatReq,
+ stream: req.Stream,
+ estimatedInputTokens: estimatedTokens,
+ }
+ } else {
+ c.Writer = innerWriter
+ }
c.Next()
}
}
+
+// hasWebSearchTool checks if the request tools include a web_search tool
+func hasWebSearchTool(tools []anthropic.Tool) bool {
+ for _, tool := range tools {
+ if strings.HasPrefix(tool.Type, "web_search") {
+ return true
+ }
+ }
+ return false
+}
+
+func isCloudModelName(name string) bool {
+ return strings.HasSuffix(name, ":cloud") || strings.HasSuffix(name, "-cloud")
+}
+
+// extractQueryFromToolCall extracts the search query from a web_search tool call
+func extractQueryFromToolCall(tc *api.ToolCall) string {
+ q, ok := tc.Function.Arguments.Get("query")
+ if !ok {
+ return ""
+ }
+ if s, ok := q.(string); ok {
+ return s
+ }
+ return ""
+}
+
+// writeSSE writes a Server-Sent Event
+func writeSSE(w http.ResponseWriter, eventType string, data any) error {
+ d, err := json.Marshal(data)
+ if err != nil {
+ return err
+ }
+ if _, err := fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, d); err != nil {
+ return err
+ }
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
+ return nil
+}
+
+// serverToolUseID derives a server tool use ID from a message ID
+func serverToolUseID(messageID string) string {
+ return "srvtoolu_" + strings.TrimPrefix(messageID, "msg_")
+}
diff --git a/middleware/anthropic_test.go b/middleware/anthropic_test.go
index a913fd3c49a..dacdd5ce654 100644
--- a/middleware/anthropic_test.go
+++ b/middleware/anthropic_test.go
@@ -605,3 +605,2375 @@ func TestAnthropicMessagesMiddleware_SetsRelaxThinkingFlag(t *testing.T) {
t.Error("expected relax_thinking flag to be set in context")
}
}
+
+// Web Search Tests
+
+func TestHasWebSearchTool(t *testing.T) {
+ tests := []struct {
+ name string
+ tools []anthropic.Tool
+ expected bool
+ }{
+ {
+ name: "no tools",
+ tools: nil,
+ expected: false,
+ },
+ {
+ name: "regular tool only",
+ tools: []anthropic.Tool{
+ {Type: "custom", Name: "get_weather"},
+ },
+ expected: false,
+ },
+ {
+ name: "web search tool",
+ tools: []anthropic.Tool{
+ {Type: "web_search_20250305", Name: "web_search"},
+ },
+ expected: true,
+ },
+ {
+ name: "mixed tools",
+ tools: []anthropic.Tool{
+ {Type: "custom", Name: "get_weather"},
+ {Type: "web_search_20250305", Name: "web_search"},
+ },
+ expected: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := hasWebSearchTool(tt.tools)
+ if result != tt.expected {
+ t.Errorf("expected %v, got %v", tt.expected, result)
+ }
+ })
+ }
+}
+
+func TestExtractQueryFromToolCall(t *testing.T) {
+ tests := []struct {
+ name string
+ tc *api.ToolCall
+ expected string
+ }{
+ {
+ name: "valid query",
+ tc: &api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "test search"),
+ },
+ },
+ expected: "test search",
+ },
+ {
+ name: "empty arguments",
+ tc: &api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ },
+ },
+ expected: "",
+ },
+ {
+ name: "no query key",
+ tc: &api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("other", "value"),
+ },
+ },
+ expected: "",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := extractQueryFromToolCall(tt.tc)
+ if result != tt.expected {
+ t.Errorf("expected %q, got %q", tt.expected, result)
+ }
+ })
+ }
+}
+
+// makeArgs is a test helper that creates ToolCallFunctionArguments
+func makeArgs(key string, value any) api.ToolCallFunctionArguments {
+ args := api.NewToolCallFunctionArguments()
+ args.Set(key, value)
+ return args
+}
+
+// --- Web Search Integration Tests ---
+
+// TestWebSearchServerToolUseID tests the ID derivation logic.
+func TestWebSearchServerToolUseID(t *testing.T) {
+ tests := []struct {
+ msgID string
+ expected string
+ }{
+ {"msg_abc123", "srvtoolu_abc123"},
+ {"msg_", "srvtoolu_"},
+ {"nomsgprefix", "srvtoolu_nomsgprefix"},
+ }
+ for _, tt := range tests {
+ got := serverToolUseID(tt.msgID)
+ if got != tt.expected {
+ t.Errorf("serverToolUseID(%q) = %q, want %q", tt.msgID, got, tt.expected)
+ }
+ }
+}
+
+// TestWebSearchNoWebSearchTool verifies that when there is no web_search tool,
+// requests pass through to the normal AnthropicWriter without interception.
+func TestWebSearchNoWebSearchTool(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ Content: "Normal response",
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"test-model","max_tokens":100,"messages":[{"role":"user","content":"Hello"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ if result.Type != "message" {
+ t.Errorf("expected type 'message', got %q", result.Type)
+ }
+ if len(result.Content) != 1 || result.Content[0].Type != "text" {
+ t.Fatalf("expected single text block, got %d blocks", len(result.Content))
+ }
+ if *result.Content[0].Text != "Normal response" {
+ t.Errorf("expected text 'Normal response', got %q", *result.Content[0].Text)
+ }
+}
+
+// TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming verifies that when
+// the web_search tool is present but the model does not call it, the response
+// passes through normally (non-streaming case).
+func TestWebSearchToolPresent_ModelDoesNotCallIt_NonStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ Content: "I can answer that without searching.",
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 8},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"What is 2+2?"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ if result.Type != "message" {
+ t.Errorf("expected type 'message', got %q", result.Type)
+ }
+ if len(result.Content) != 1 || result.Content[0].Type != "text" {
+ t.Fatalf("expected single text block, got %+v", result.Content)
+ }
+ if *result.Content[0].Text != "I can answer that without searching." {
+ t.Errorf("unexpected text: %q", *result.Content[0].Text)
+ }
+ if result.StopReason != "end_turn" {
+ t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
+ }
+}
+
+// TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming verifies the streaming
+// pass-through case when the model does not invoke web_search.
+func TestWebSearchToolPresent_ModelDoesNotCallIt_Streaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ // Simulate streaming: two partial chunks then a final chunk
+ chunks := []api.ChatResponse{
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Hello "},
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "world"},
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: ""},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ },
+ }
+ c.Writer.WriteHeader(http.StatusOK)
+ for _, chunk := range chunks {
+ data, _ := json.Marshal(chunk)
+ _, _ = c.Writer.Write(data)
+ }
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "stream":true,
+ "messages":[{"role":"user","content":"Hi"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ // Parse SSE events
+ events := parseSSEEvents(t, resp.Body.String())
+
+ // Should have standard streaming event flow
+ if len(events) == 0 {
+ t.Fatal("expected SSE events, got none")
+ }
+
+ // First event should be message_start
+ if events[0].event != "message_start" {
+ t.Errorf("first event should be message_start, got %q", events[0].event)
+ }
+
+ // Should have content_block_start for text
+ hasTextStart := false
+ hasTextDelta := false
+ hasMessageStop := false
+ for _, e := range events {
+ if e.event == "content_block_start" {
+ var cbs anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(e.data), &cbs); err == nil {
+ if cbs.ContentBlock.Type == "text" {
+ hasTextStart = true
+ }
+ }
+ }
+ if e.event == "content_block_delta" {
+ var cbd anthropic.ContentBlockDeltaEvent
+ if err := json.Unmarshal([]byte(e.data), &cbd); err == nil {
+ if cbd.Delta.Type == "text_delta" {
+ hasTextDelta = true
+ }
+ }
+ }
+ if e.event == "message_stop" {
+ hasMessageStop = true
+ }
+ }
+ if !hasTextStart {
+ t.Error("expected content_block_start with text type")
+ }
+ if !hasTextDelta {
+ t.Error("expected content_block_delta with text_delta")
+ }
+ if !hasMessageStop {
+ t.Error("expected message_stop event")
+ }
+}
+
+// TestWebSearchToolPresent_ModelCallsIt_NonStreaming tests the full web search flow
+// in non-streaming mode. It mocks the followup /api/chat call using a local HTTP server.
+func TestWebSearchToolPresent_ModelCallsIt_NonStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ // Create a mock Ollama server that responds to the followup /api/chat call
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ Content: "Based on my search, the answer is 42.",
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 50, EvalCount: 20},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+
+ // Set OLLAMA_HOST to our mock server so the followup call goes there
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ // Also mock the web search API
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Test Result", URL: "https://example.com/result", Content: "Some content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+
+ // Point DoWebSearch at our mock search server
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_001",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "meaning of life"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 15, EvalCount: 3},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"What is the meaning of life?"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v\nbody: %s", err, resp.Body.String())
+ }
+
+ if result.Type != "message" {
+ t.Errorf("expected type 'message', got %q", result.Type)
+ }
+ if result.Role != "assistant" {
+ t.Errorf("expected role 'assistant', got %q", result.Role)
+ }
+
+ // Should have 3 blocks: server_tool_use + web_search_tool_result + text
+ if len(result.Content) != 3 {
+ t.Fatalf("expected 3 content blocks, got %d: %+v", len(result.Content), result.Content)
+ }
+
+ if result.Content[0].Type != "server_tool_use" {
+ t.Errorf("expected first block type 'server_tool_use', got %q", result.Content[0].Type)
+ }
+ if result.Content[0].Name != "web_search" {
+ t.Errorf("expected name 'web_search', got %q", result.Content[0].Name)
+ }
+
+ if result.Content[1].Type != "web_search_tool_result" {
+ t.Errorf("expected second block type 'web_search_tool_result', got %q", result.Content[1].Type)
+ }
+ if result.Content[1].ToolUseID != result.Content[0].ID {
+ t.Errorf("tool_use_id mismatch: %q != %q", result.Content[1].ToolUseID, result.Content[0].ID)
+ }
+
+ if result.Content[2].Type != "text" {
+ t.Errorf("expected third block type 'text', got %q", result.Content[2].Type)
+ }
+ if result.Content[2].Text == nil || *result.Content[2].Text == "" {
+ t.Error("expected non-empty text in third block")
+ }
+
+ if result.StopReason != "end_turn" {
+ t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
+ }
+}
+
+// TestWebSearchToolPresent_ModelCallsIt_Streaming tests the streaming SSE output
+// when the model calls web_search with mocked search and followup endpoints.
+func TestWebSearchToolPresent_ModelCallsIt_Streaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ // Mock followup /api/chat server
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Here are the latest news."},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 40, EvalCount: 15},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ // Mock web search API
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "News Result", URL: "https://example.com/news", Content: "Breaking news"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ // Simulate buffered streaming: non-final chunk then final with tool call
+ chunks := []api.ChatResponse{
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant"},
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_002",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "latest news"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2},
+ },
+ }
+ c.Writer.WriteHeader(http.StatusOK)
+ for _, chunk := range chunks {
+ data, _ := json.Marshal(chunk)
+ _, _ = c.Writer.Write(data)
+ }
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "stream":true,
+ "messages":[{"role":"user","content":"What is the latest news?"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ events := parseSSEEvents(t, resp.Body.String())
+
+ // Success path: 10 events (3 blocks: server_tool_use, web_search_tool_result, text with delta)
+ expectedEventTypes := []string{
+ "message_start",
+ "content_block_start", // server_tool_use
+ "content_block_stop",
+ "content_block_start", // web_search_tool_result
+ "content_block_stop",
+ "content_block_start", // text (empty)
+ "content_block_delta", // text_delta with actual content
+ "content_block_stop",
+ "message_delta",
+ "message_stop",
+ }
+
+ if len(events) != len(expectedEventTypes) {
+ t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events))
+ }
+
+ for i, expected := range expectedEventTypes {
+ if events[i].event != expected {
+ t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event)
+ }
+ }
+
+ // Verify text delta has the followup model's content
+ var textDelta anthropic.ContentBlockDeltaEvent
+ if err := json.Unmarshal([]byte(events[6].data), &textDelta); err != nil {
+ t.Fatalf("failed to parse text delta: %v", err)
+ }
+ if textDelta.Delta.Type != "text_delta" {
+ t.Errorf("expected delta type 'text_delta', got %q", textDelta.Delta.Type)
+ }
+ if textDelta.Delta.Text != "Here are the latest news." {
+ t.Errorf("expected followup text, got %q", textDelta.Delta.Text)
+ }
+}
+
+// TestWebSearchStreamResponse tests the streamResponse method directly by constructing
+// a WebSearchAnthropicWriter and calling streamResponse with a known response.
+func TestWebSearchStreamResponse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ text := "Here is the answer."
+
+ response := anthropic.MessagesResponse{
+ ID: "msg_test123",
+ Type: "message",
+ Role: "assistant",
+ Model: "test-model",
+ Content: []anthropic.ContentBlock{
+ {
+ Type: "server_tool_use",
+ ID: "srvtoolu_test123",
+ Name: "web_search",
+ Input: map[string]any{"query": "test query"},
+ },
+ {
+ Type: "web_search_tool_result",
+ ToolUseID: "srvtoolu_test123",
+ Content: []anthropic.WebSearchResult{
+ {Type: "web_search_result", URL: "https://example.com", Title: "Example"},
+ },
+ },
+ {
+ Type: "text",
+ Text: &text,
+ },
+ },
+ StopReason: "end_turn",
+ Usage: anthropic.Usage{InputTokens: 20, OutputTokens: 10},
+ }
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+
+ innerWriter := &AnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ stream: true,
+ id: "msg_test123",
+ }
+ wsWriter := &WebSearchAnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ inner: innerWriter,
+ stream: true,
+ req: anthropic.MessagesRequest{Model: "test-model"},
+ }
+
+ if err := wsWriter.streamResponse(response); err != nil {
+ t.Fatalf("streamResponse error: %v", err)
+ }
+
+ events := parseSSEEvents(t, rec.Body.String())
+
+ // Verify full event sequence
+ expectedEventTypes := []string{
+ "message_start",
+ "content_block_start", // server_tool_use (index 0)
+ "content_block_stop", // index 0
+ "content_block_start", // web_search_tool_result (index 1)
+ "content_block_stop", // index 1
+ "content_block_start", // text (index 2)
+ "content_block_delta", // text_delta
+ "content_block_stop", // index 2
+ "message_delta",
+ "message_stop",
+ }
+
+ if len(events) != len(expectedEventTypes) {
+ t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events))
+ }
+
+ for i, expected := range expectedEventTypes {
+ if events[i].event != expected {
+ t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event)
+ }
+ }
+
+ // Verify message_start content
+ var msgStart anthropic.MessageStartEvent
+ if err := json.Unmarshal([]byte(events[0].data), &msgStart); err != nil {
+ t.Fatalf("failed to parse message_start: %v", err)
+ }
+ if msgStart.Message.ID != "msg_test123" {
+ t.Errorf("expected message ID 'msg_test123', got %q", msgStart.Message.ID)
+ }
+ if msgStart.Message.Role != "assistant" {
+ t.Errorf("expected role 'assistant', got %q", msgStart.Message.Role)
+ }
+ if len(msgStart.Message.Content) != 0 {
+ t.Errorf("expected empty content in message_start, got %d blocks", len(msgStart.Message.Content))
+ }
+
+ // Verify content_block_start for server_tool_use (event index 1)
+ var toolStart anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(events[1].data), &toolStart); err != nil {
+ t.Fatalf("failed to parse server_tool_use start: %v", err)
+ }
+ if toolStart.Index != 0 {
+ t.Errorf("expected index 0, got %d", toolStart.Index)
+ }
+ if toolStart.ContentBlock.Type != "server_tool_use" {
+ t.Errorf("expected type 'server_tool_use', got %q", toolStart.ContentBlock.Type)
+ }
+ if toolStart.ContentBlock.ID != "srvtoolu_test123" {
+ t.Errorf("expected ID 'srvtoolu_test123', got %q", toolStart.ContentBlock.ID)
+ }
+
+ // Verify content_block_start for web_search_tool_result (event index 3)
+ var searchStart anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(events[3].data), &searchStart); err != nil {
+ t.Fatalf("failed to parse web_search_tool_result start: %v", err)
+ }
+ if searchStart.Index != 1 {
+ t.Errorf("expected index 1, got %d", searchStart.Index)
+ }
+ if searchStart.ContentBlock.Type != "web_search_tool_result" {
+ t.Errorf("expected type 'web_search_tool_result', got %q", searchStart.ContentBlock.Type)
+ }
+
+ // Verify text block: content_block_start (event index 5)
+ var textStart anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(events[5].data), &textStart); err != nil {
+ t.Fatalf("failed to parse text start: %v", err)
+ }
+ if textStart.Index != 2 {
+ t.Errorf("expected index 2, got %d", textStart.Index)
+ }
+ if textStart.ContentBlock.Type != "text" {
+ t.Errorf("expected type 'text', got %q", textStart.ContentBlock.Type)
+ }
+ // Text in start should be empty
+ if textStart.ContentBlock.Text == nil || *textStart.ContentBlock.Text != "" {
+ t.Errorf("expected empty text in content_block_start, got %v", textStart.ContentBlock.Text)
+ }
+
+ // Verify text delta (event index 6)
+ var textDelta anthropic.ContentBlockDeltaEvent
+ if err := json.Unmarshal([]byte(events[6].data), &textDelta); err != nil {
+ t.Fatalf("failed to parse text delta: %v", err)
+ }
+ if textDelta.Index != 2 {
+ t.Errorf("expected index 2, got %d", textDelta.Index)
+ }
+ if textDelta.Delta.Type != "text_delta" {
+ t.Errorf("expected delta type 'text_delta', got %q", textDelta.Delta.Type)
+ }
+ if textDelta.Delta.Text != "Here is the answer." {
+ t.Errorf("expected delta text 'Here is the answer.', got %q", textDelta.Delta.Text)
+ }
+
+ // Verify message_delta (event index 8)
+ var msgDelta anthropic.MessageDeltaEvent
+ if err := json.Unmarshal([]byte(events[8].data), &msgDelta); err != nil {
+ t.Fatalf("failed to parse message_delta: %v", err)
+ }
+ if msgDelta.Delta.StopReason != "end_turn" {
+ t.Errorf("expected stop_reason 'end_turn', got %q", msgDelta.Delta.StopReason)
+ }
+ if msgDelta.Usage.InputTokens != 20 {
+ t.Errorf("expected input_tokens 20, got %d", msgDelta.Usage.InputTokens)
+ }
+ if msgDelta.Usage.OutputTokens != 10 {
+ t.Errorf("expected output_tokens 10, got %d", msgDelta.Usage.OutputTokens)
+ }
+}
+
+// TestWebSearchSendError_NonStreaming tests sendError produces correct response shape.
+func TestWebSearchSendError_NonStreaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+
+ innerWriter := &AnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ stream: false,
+ id: "msg_err001",
+ }
+ wsWriter := &WebSearchAnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ inner: innerWriter,
+ stream: false,
+ req: anthropic.MessagesRequest{Model: "test-model"},
+ }
+
+ errorUsage := anthropic.Usage{InputTokens: 7, OutputTokens: 2}
+ if err := wsWriter.sendError("unavailable", "test query", errorUsage); err != nil {
+ t.Fatalf("sendError error: %v", err)
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v\nbody: %s", err, rec.Body.String())
+ }
+
+ if result.Type != "message" {
+ t.Errorf("expected type 'message', got %q", result.Type)
+ }
+ if result.ID != "msg_err001" {
+ t.Errorf("expected ID 'msg_err001', got %q", result.ID)
+ }
+
+ // Should have exactly 2 blocks: server_tool_use + web_search_tool_result
+ if len(result.Content) != 2 {
+ t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
+ }
+
+ // Block 0: server_tool_use
+ if result.Content[0].Type != "server_tool_use" {
+ t.Errorf("expected 'server_tool_use', got %q", result.Content[0].Type)
+ }
+ expectedToolID := "srvtoolu_err001"
+ if result.Content[0].ID != expectedToolID {
+ t.Errorf("expected ID %q, got %q", expectedToolID, result.Content[0].ID)
+ }
+ if result.Content[0].Name != "web_search" {
+ t.Errorf("expected name 'web_search', got %q", result.Content[0].Name)
+ }
+ // Verify input contains the query
+ inputMap, ok := result.Content[0].Input.(map[string]any)
+ if !ok {
+ t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
+ }
+ if inputMap["query"] != "test query" {
+ t.Errorf("expected query 'test query', got %v", inputMap["query"])
+ }
+
+ // Block 1: web_search_tool_result with error
+ if result.Content[1].Type != "web_search_tool_result" {
+ t.Errorf("expected 'web_search_tool_result', got %q", result.Content[1].Type)
+ }
+ if result.Content[1].ToolUseID != expectedToolID {
+ t.Errorf("expected tool_use_id %q, got %q", expectedToolID, result.Content[1].ToolUseID)
+ }
+
+ // The Content field should be a WebSearchToolResultError
+ contentJSON, _ := json.Marshal(result.Content[1].Content)
+ var errContent anthropic.WebSearchToolResultError
+ if err := json.Unmarshal(contentJSON, &errContent); err != nil {
+ t.Fatalf("failed to parse error content: %v\nraw: %s", err, string(contentJSON))
+ }
+ if errContent.Type != "web_search_tool_result_error" {
+ t.Errorf("expected error type 'web_search_tool_result_error', got %q", errContent.Type)
+ }
+ if errContent.ErrorCode != "unavailable" {
+ t.Errorf("expected error_code 'unavailable', got %q", errContent.ErrorCode)
+ }
+
+ if result.StopReason != "end_turn" {
+ t.Errorf("expected stop_reason 'end_turn', got %q", result.StopReason)
+ }
+ if result.Usage != errorUsage {
+ t.Errorf("expected usage %+v, got %+v", errorUsage, result.Usage)
+ }
+}
+
+// TestWebSearchSendError_Streaming tests sendError in streaming mode produces proper SSE.
+func TestWebSearchSendError_Streaming(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+
+ innerWriter := &AnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ stream: true,
+ id: "msg_err002",
+ }
+ wsWriter := &WebSearchAnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ inner: innerWriter,
+ stream: true,
+ req: anthropic.MessagesRequest{Model: "test-model"},
+ }
+
+ errorUsage := anthropic.Usage{InputTokens: 9, OutputTokens: 4}
+ if err := wsWriter.sendError("invalid_request", "bad query", errorUsage); err != nil {
+ t.Fatalf("sendError error: %v", err)
+ }
+
+ events := parseSSEEvents(t, rec.Body.String())
+
+ // Error response has 2 blocks: server_tool_use + web_search_tool_result
+ // Expected events: message_start,
+ // content_block_start(server_tool_use), content_block_stop,
+ // content_block_start(web_search_tool_result), content_block_stop,
+ // message_delta, message_stop
+ expectedEventTypes := []string{
+ "message_start",
+ "content_block_start",
+ "content_block_stop",
+ "content_block_start",
+ "content_block_stop",
+ "message_delta",
+ "message_stop",
+ }
+
+ if len(events) != len(expectedEventTypes) {
+ t.Fatalf("expected %d events, got %d.\nEvents: %v", len(expectedEventTypes), len(events), eventNames(events))
+ }
+
+ for i, expected := range expectedEventTypes {
+ if events[i].event != expected {
+ t.Errorf("event[%d]: expected %q, got %q", i, expected, events[i].event)
+ }
+ }
+
+ // Verify the server_tool_use block
+ var toolStart anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(events[1].data), &toolStart); err != nil {
+ t.Fatalf("failed to parse server_tool_use start: %v", err)
+ }
+ if toolStart.ContentBlock.Type != "server_tool_use" {
+ t.Errorf("expected 'server_tool_use', got %q", toolStart.ContentBlock.Type)
+ }
+
+ // Verify the web_search_tool_result block
+ var resultStart anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(events[3].data), &resultStart); err != nil {
+ t.Fatalf("failed to parse web_search_tool_result start: %v", err)
+ }
+ if resultStart.ContentBlock.Type != "web_search_tool_result" {
+ t.Errorf("expected 'web_search_tool_result', got %q", resultStart.ContentBlock.Type)
+ }
+
+ var msgDelta anthropic.MessageDeltaEvent
+ if err := json.Unmarshal([]byte(events[5].data), &msgDelta); err != nil {
+ t.Fatalf("failed to parse message_delta: %v", err)
+ }
+ if msgDelta.Usage.InputTokens != errorUsage.InputTokens || msgDelta.Usage.OutputTokens != errorUsage.OutputTokens {
+ t.Fatalf("expected usage %+v in message_delta, got %+v", errorUsage, msgDelta.Usage)
+ }
+}
+
+// TestWebSearchSendError_EmptyQuery tests sendError with an empty query.
+func TestWebSearchSendError_EmptyQuery(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+
+ innerWriter := &AnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ stream: false,
+ id: "msg_empty001",
+ }
+ wsWriter := &WebSearchAnthropicWriter{
+ BaseWriter: BaseWriter{ResponseWriter: ginCtx.Writer},
+ inner: innerWriter,
+ stream: false,
+ req: anthropic.MessagesRequest{Model: "test-model"},
+ }
+
+ if err := wsWriter.sendError("invalid_request", "", anthropic.Usage{}); err != nil {
+ t.Fatalf("sendError error: %v", err)
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(rec.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ if len(result.Content) != 2 {
+ t.Fatalf("expected 2 content blocks, got %d", len(result.Content))
+ }
+
+ // Verify the input has empty query
+ inputMap, ok := result.Content[0].Input.(map[string]any)
+ if !ok {
+ t.Fatalf("expected Input to be map, got %T", result.Content[0].Input)
+ }
+ if inputMap["query"] != "" {
+ t.Errorf("expected empty query, got %v", inputMap["query"])
+ }
+}
+
+// --- SSE parsing helpers ---
+
+type sseEvent struct {
+ event string
+ data string
+}
+
+// parseSSEEvents parses Server-Sent Events from a string.
+func parseSSEEvents(t *testing.T, body string) []sseEvent {
+ t.Helper()
+ var events []sseEvent
+ var currentEvent string
+ var currentData strings.Builder
+
+ for _, line := range strings.Split(body, "\n") {
+ if strings.HasPrefix(line, "event: ") {
+ currentEvent = strings.TrimPrefix(line, "event: ")
+ } else if strings.HasPrefix(line, "data: ") {
+ currentData.WriteString(strings.TrimPrefix(line, "data: "))
+ } else if line == "" && currentEvent != "" {
+ events = append(events, sseEvent{event: currentEvent, data: currentData.String()})
+ currentEvent = ""
+ currentData.Reset()
+ }
+ }
+ return events
+}
+
+// eventNames returns a list of event type names for debugging.
+func eventNames(events []sseEvent) []string {
+ names := make([]string, len(events))
+ for i, e := range events {
+ names[i] = e.event
+ }
+ return names
+}
+
+// TestWebSearchCloudModelGating tests web_search behavior across model types.
+func TestWebSearchCloudModelGating(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ t.Run("local model allowed when web_search is not called", func(t *testing.T) {
+ handlerCalled := false
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ handlerCalled = true
+ resp := api.ChatResponse{
+ Model: "llama3.2",
+ Message: api.Message{Role: "assistant", Content: "hello"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ if !handlerCalled {
+ t.Error("handler should be called for local model when web_search is not called")
+ }
+ })
+
+ t.Run("local model emits web_search and gets structured error", func(t *testing.T) {
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "llama3.2",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_local_ws",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "hello"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 8, EvalCount: 2},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+ if len(result.Content) != 2 {
+ t.Fatalf("expected 2 content blocks for local model web_search error, got %d", len(result.Content))
+ }
+ contentJSON, _ := json.Marshal(result.Content[1].Content)
+ var errContent anthropic.WebSearchToolResultError
+ if err := json.Unmarshal(contentJSON, &errContent); err != nil {
+ t.Fatalf("failed to parse web_search error content: %v", err)
+ }
+ if errContent.ErrorCode != "web_search_not_supported_for_local_models" {
+ t.Fatalf("expected web_search_not_supported_for_local_models, got %q", errContent.ErrorCode)
+ }
+ })
+
+ t.Run("model ending in cloud without cloud suffix treated as local", func(t *testing.T) {
+ handlerCalled := false
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ handlerCalled = true
+ resp := api.ChatResponse{
+ Model: "notreallycloud",
+ Message: api.Message{Role: "assistant", Content: "hello"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"notreallycloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if !handlerCalled {
+ t.Error("handler should be called for non-cloud model when web_search is not called")
+ }
+ if resp.Code != http.StatusOK {
+ t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ })
+
+ t.Run("cloud model with size tag allowed", func(t *testing.T) {
+ handlerCalled := false
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ handlerCalled = true
+ resp := api.ChatResponse{
+ Model: "gpt-oss:120b",
+ Message: api.Message{Role: "assistant", Content: "hello"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"gpt-oss:120b-cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if !handlerCalled {
+ t.Error("handler should be called for cloud model")
+ }
+ if resp.Code != http.StatusOK {
+ t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ })
+
+ t.Run("cloud model allowed", func(t *testing.T) {
+ handlerCalled := false
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ handlerCalled = true
+ resp := api.ChatResponse{
+ Model: "kimi-k2.5",
+ Message: api.Message{Role: "assistant", Content: "hello"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if !handlerCalled {
+ t.Error("handler should be called for cloud model")
+ }
+ if resp.Code != http.StatusOK {
+ t.Errorf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ })
+
+ t.Run("cloud disabled blocks web search for cloud model", func(t *testing.T) {
+ t.Setenv("OLLAMA_NO_CLOUD", "1")
+
+ handlerCalled := false
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ handlerCalled = true
+ })
+
+ body := `{"model":"kimi-k2.5:cloud","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusForbidden {
+ t.Fatalf("expected 403, got %d: %s", resp.Code, resp.Body.String())
+ }
+ if handlerCalled {
+ t.Fatal("handler should not be called when cloud is disabled")
+ }
+
+ var errResp anthropic.ErrorResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+ t.Fatalf("failed to parse error response: %v", err)
+ }
+ if !strings.Contains(errResp.Error.Message, "ollama cloud is disabled") {
+ t.Fatalf("expected cloud disabled error, got: %q", errResp.Error.Message)
+ }
+ })
+
+ t.Run("cloud disabled does not block local model if web_search is not called", func(t *testing.T) {
+ t.Setenv("OLLAMA_NO_CLOUD", "1")
+
+ handlerCalled := false
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ handlerCalled = true
+ resp := api.ChatResponse{
+ Model: "llama3.2",
+ Message: api.Message{Role: "assistant", Content: "hello"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 5},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{"model":"llama3.2","max_tokens":100,"messages":[{"role":"user","content":"hello"}],"tools":[{"type":"web_search_20250305","name":"web_search"}]}`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ if !handlerCalled {
+ t.Fatal("handler should be called for local model when web_search is not called")
+ }
+ })
+}
+
+func TestWebSearchDoesNotRequireAuthorizationHeaderForMockEndpoint(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ var authHeader string
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ authHeader = r.Header.Get("Authorization")
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "done"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 5, EvalCount: 2},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_auth",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "auth test"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 4, EvalCount: 1},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"test auth"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ if authHeader != "" {
+ t.Fatalf("expected no Authorization header for mock web search endpoint, got %q", authHeader)
+ }
+}
+
+// TestWebSearchSearchAPIError tests that a failing search API returns a proper error response.
+func TestWebSearchSearchAPIError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ // Mock search server that returns 500
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "internal error", http.StatusInternalServerError)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_err",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "test"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"test"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ // Error response: server_tool_use + web_search_tool_result with error
+ if len(result.Content) != 2 {
+ t.Fatalf("expected 2 content blocks for error, got %d", len(result.Content))
+ }
+ if result.Content[0].Type != "server_tool_use" {
+ t.Errorf("expected 'server_tool_use', got %q", result.Content[0].Type)
+ }
+ if result.Content[1].Type != "web_search_tool_result" {
+ t.Errorf("expected 'web_search_tool_result', got %q", result.Content[1].Type)
+ }
+ if result.Usage.InputTokens != 10 || result.Usage.OutputTokens != 2 {
+ t.Fatalf("expected usage input=10 output=2, got %+v", result.Usage)
+ }
+}
+
+func TestWebSearchStreamingImmediateTakeover(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "After search."},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 10},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ chunks := []api.ChatResponse{
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Preface "},
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_stream_1",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "latest updates"),
+ },
+ },
+ },
+ },
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "ignored chunk"},
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 9, EvalCount: 4},
+ },
+ }
+ c.Writer.WriteHeader(http.StatusOK)
+ for _, chunk := range chunks {
+ data, _ := json.Marshal(chunk)
+ _, _ = c.Writer.Write(data)
+ }
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "stream":true,
+ "messages":[{"role":"user","content":"Find updates"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ events := parseSSEEvents(t, resp.Body.String())
+ if countEventsByName(events, "message_start") != 1 {
+ t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start"))
+ }
+ if countEventsByName(events, "message_stop") != 1 {
+ t.Fatalf("expected exactly one message_stop, got %d", countEventsByName(events, "message_stop"))
+ }
+
+ textDeltas := collectTextDeltas(t, events)
+ if !containsString(textDeltas, "Preface ") {
+ t.Fatalf("expected passthrough text delta, got %v", textDeltas)
+ }
+ if !containsString(textDeltas, "After search.") {
+ t.Fatalf("expected post-search text delta, got %v", textDeltas)
+ }
+ if containsString(textDeltas, "ignored chunk") {
+ t.Fatalf("unexpected text from chunks after takeover: %v", textDeltas)
+ }
+}
+
+func TestWebSearchStreamingUsageUsesObservedChunkMetrics(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "After search."},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 7},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ chunks := []api.ChatResponse{
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Preface "},
+ Done: false,
+ Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4},
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_stream_usage",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "latest updates"),
+ },
+ },
+ },
+ },
+ Done: false,
+ Metrics: api.Metrics{PromptEvalCount: 0, EvalCount: 0},
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 12, EvalCount: 4},
+ },
+ }
+ c.Writer.WriteHeader(http.StatusOK)
+ for _, chunk := range chunks {
+ data, _ := json.Marshal(chunk)
+ _, _ = c.Writer.Write(data)
+ }
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "stream":true,
+ "messages":[{"role":"user","content":"Find updates"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ events := parseSSEEvents(t, resp.Body.String())
+ var messageDelta anthropic.MessageDeltaEvent
+ found := false
+ for _, event := range events {
+ if event.event != "message_delta" {
+ continue
+ }
+ if err := json.Unmarshal([]byte(event.data), &messageDelta); err != nil {
+ t.Fatalf("failed to unmarshal message_delta: %v", err)
+ }
+ found = true
+ break
+ }
+ if !found {
+ t.Fatal("expected message_delta event")
+ }
+ if messageDelta.Usage.InputTokens != 32 {
+ t.Fatalf("expected aggregated input tokens 32 (12 passthrough + 20 followup), got %d", messageDelta.Usage.InputTokens)
+ }
+ if messageDelta.Usage.OutputTokens != 11 {
+ t.Fatalf("expected aggregated output tokens 11 (4 passthrough + 7 followup), got %d", messageDelta.Usage.OutputTokens)
+ }
+}
+
+func TestWebSearchMixedToolCallsPreferWebSearch(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Search answer."},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 11, EvalCount: 6},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_other",
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: makeArgs("location", "SF"),
+ },
+ },
+ {
+ ID: "call_ws_mixed",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "latest weather"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 2},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"Weather?"}],
+ "tools":[
+ {"type":"web_search_20250305","name":"web_search"},
+ {"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
+ ]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ if len(result.Content) < 3 {
+ t.Fatalf("expected at least 3 blocks, got %d", len(result.Content))
+ }
+ if result.Content[0].Type != "server_tool_use" {
+ t.Fatalf("expected server_tool_use first, got %q", result.Content[0].Type)
+ }
+ if result.Content[1].Type != "web_search_tool_result" {
+ t.Fatalf("expected web_search_tool_result second, got %q", result.Content[1].Type)
+ }
+
+ for _, block := range result.Content {
+ if block.Type == "tool_use" && block.Name == "get_weather" {
+ t.Fatalf("did not expect get_weather tool_use in mixed web_search-preferred path: %+v", result.Content)
+ }
+ }
+}
+
+func TestWebSearchFollowupClientToolStopReasonToolUse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_weather_final",
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: makeArgs("location", "New York"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 25, EvalCount: 7},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_tool_use",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "forecast"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 15, EvalCount: 3},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"Do I need an umbrella?"}],
+ "tools":[
+ {"type":"web_search_20250305","name":"web_search"},
+ {"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
+ ]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ if result.StopReason != "tool_use" {
+ t.Fatalf("expected stop_reason tool_use, got %q", result.StopReason)
+ }
+ if len(result.Content) < 3 {
+ t.Fatalf("expected server blocks + tool_use, got %d blocks", len(result.Content))
+ }
+ last := result.Content[len(result.Content)-1]
+ if last.Type != "tool_use" {
+ t.Fatalf("expected final block tool_use, got %q", last.Type)
+ }
+ if last.Name != "get_weather" {
+ t.Fatalf("expected final tool name get_weather, got %q", last.Name)
+ }
+ if result.Usage.InputTokens != 40 || result.Usage.OutputTokens != 10 {
+ t.Fatalf("unexpected aggregated usage: %+v", result.Usage)
+ }
+}
+
+func TestWebSearchMultiIterationLoop(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupCall := 0
+ followupDecodeErr := false
+ missingWebSearchTool := false
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var followupReq api.ChatRequest
+ if err := json.NewDecoder(r.Body).Decode(&followupReq); err != nil {
+ followupDecodeErr = true
+ http.Error(w, "bad request", http.StatusBadRequest)
+ return
+ }
+ hasWebSearchTool := false
+ for _, tool := range followupReq.Tools {
+ if tool.Function.Name == "web_search" {
+ hasWebSearchTool = true
+ break
+ }
+ }
+ if !hasWebSearchTool {
+ missingWebSearchTool = true
+ }
+
+ followupCall++
+ switch followupCall {
+ case 1:
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_2",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "loop query 2"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 20, EvalCount: 2},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ case 2:
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Final answer after 2 searches."},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 30, EvalCount: 3},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ default:
+ t.Fatalf("unexpected extra followup call: %d", followupCall)
+ }
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_1",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "loop query 1"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 1},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"do multiple searches"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ if followupCall != 2 {
+ t.Fatalf("expected 2 followup calls, got %d", followupCall)
+ }
+ if followupDecodeErr {
+ t.Fatal("failed to decode followup request body")
+ }
+ if missingWebSearchTool {
+ t.Fatal("expected followup requests to retain web_search tool definition")
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ serverToolUses := 0
+ webResults := 0
+ for _, block := range result.Content {
+ if block.Type == "server_tool_use" {
+ serverToolUses++
+ }
+ if block.Type == "web_search_tool_result" {
+ webResults++
+ }
+ }
+ if serverToolUses != 2 || webResults != 2 {
+ t.Fatalf("expected two search iterations, got server_tool_use=%d web_search_tool_result=%d", serverToolUses, webResults)
+ }
+
+ if result.Usage.InputTokens != 60 || result.Usage.OutputTokens != 6 {
+ t.Fatalf("unexpected aggregated usage: %+v", result.Usage)
+ }
+}
+
+func TestWebSearchLoopMaxLimit(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupCall := 0
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ followupCall++
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_loop_limit",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "loop query next"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 7, EvalCount: 2},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_initial",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "loop query 1"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 5, EvalCount: 1},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"keep searching"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+ if followupCall != 3 {
+ t.Fatalf("expected 3 followup calls before max loop error, got %d", followupCall)
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+
+ last := result.Content[len(result.Content)-1]
+ if last.Type != "web_search_tool_result" {
+ t.Fatalf("expected last block web_search_tool_result, got %q", last.Type)
+ }
+ contentJSON, _ := json.Marshal(last.Content)
+ var errContent anthropic.WebSearchToolResultError
+ if err := json.Unmarshal(contentJSON, &errContent); err != nil {
+ t.Fatalf("failed to parse web search error content: %v", err)
+ }
+ if errContent.ErrorCode != "max_uses_exceeded" {
+ t.Fatalf("expected max_uses_exceeded error, got %q", errContent.ErrorCode)
+ }
+ if result.StopReason != "end_turn" {
+ t.Fatalf("expected end_turn, got %q", result.StopReason)
+ }
+}
+
+func TestWebSearchStreamingFinalStopReasonToolUse(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_weather_stream",
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: makeArgs("location", "Seattle"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 14, EvalCount: 5},
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ chunks := []api.ChatResponse{
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant", Content: "Let me check. "},
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_stream_tool_use",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "weather seattle"),
+ },
+ },
+ },
+ },
+ Done: false,
+ },
+ {
+ Model: "test-model",
+ Message: api.Message{Role: "assistant"},
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 10, EvalCount: 3},
+ },
+ }
+ c.Writer.WriteHeader(http.StatusOK)
+ for _, chunk := range chunks {
+ data, _ := json.Marshal(chunk)
+ _, _ = c.Writer.Write(data)
+ }
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "stream":true,
+ "messages":[{"role":"user","content":"Should I take a jacket?"}],
+ "tools":[
+ {"type":"web_search_20250305","name":"web_search"},
+ {"type":"custom","name":"get_weather","input_schema":{"type":"object"}}
+ ]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ events := parseSSEEvents(t, resp.Body.String())
+ if countEventsByName(events, "message_start") != 1 {
+ t.Fatalf("expected exactly one message_start, got %d", countEventsByName(events, "message_start"))
+ }
+
+ var messageDelta anthropic.MessageDeltaEvent
+ foundMessageDelta := false
+ foundToolUse := false
+ for _, event := range events {
+ if event.event == "message_delta" {
+ foundMessageDelta = true
+ if err := json.Unmarshal([]byte(event.data), &messageDelta); err != nil {
+ t.Fatalf("failed to unmarshal message_delta: %v", err)
+ }
+ }
+ if event.event == "content_block_start" {
+ var start anthropic.ContentBlockStartEvent
+ if err := json.Unmarshal([]byte(event.data), &start); err != nil {
+ t.Fatalf("failed to unmarshal content_block_start: %v", err)
+ }
+ if start.ContentBlock.Type == "tool_use" && start.ContentBlock.Name == "get_weather" {
+ foundToolUse = true
+ }
+ }
+ }
+
+ if !foundMessageDelta {
+ t.Fatal("expected message_delta event")
+ }
+ if messageDelta.Delta.StopReason != "tool_use" {
+ t.Fatalf("expected stop_reason tool_use, got %q", messageDelta.Delta.StopReason)
+ }
+ if !foundToolUse {
+ t.Fatal("expected tool_use content block for get_weather")
+ }
+}
+
+func TestWebSearchFollowupNon200ReturnsApiError(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ enableCloudForTest(t)
+
+ followupServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "boom", http.StatusInternalServerError)
+ }))
+ defer followupServer.Close()
+ t.Setenv("OLLAMA_HOST", followupServer.URL)
+
+ searchServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ resp := anthropic.OllamaWebSearchResponse{
+ Results: []anthropic.OllamaWebSearchResult{
+ {Title: "Result", URL: "https://example.com", Content: "content"},
+ },
+ }
+ _ = json.NewEncoder(w).Encode(resp)
+ }))
+ defer searchServer.Close()
+ originalEndpoint := anthropic.WebSearchEndpoint
+ anthropic.WebSearchEndpoint = searchServer.URL
+ defer func() { anthropic.WebSearchEndpoint = originalEndpoint }()
+
+ router := gin.New()
+ router.Use(AnthropicMessagesMiddleware())
+ router.POST("/v1/messages", func(c *gin.Context) {
+ resp := api.ChatResponse{
+ Model: "test-model",
+ Message: api.Message{
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ ID: "call_ws_non200",
+ Function: api.ToolCallFunction{
+ Name: "web_search",
+ Arguments: makeArgs("query", "test"),
+ },
+ },
+ },
+ },
+ Done: true,
+ DoneReason: "stop",
+ Metrics: api.Metrics{PromptEvalCount: 9, EvalCount: 1},
+ }
+ data, _ := json.Marshal(resp)
+ c.Writer.WriteHeader(http.StatusOK)
+ _, _ = c.Writer.Write(data)
+ })
+
+ body := `{
+ "model":"test-model:cloud",
+ "max_tokens":100,
+ "messages":[{"role":"user","content":"test"}],
+ "tools":[{"type":"web_search_20250305","name":"web_search"}]
+ }`
+ req, _ := http.NewRequest(http.MethodPost, "/v1/messages", strings.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ var result anthropic.MessagesResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
+ t.Fatalf("unmarshal error: %v", err)
+ }
+ if len(result.Content) != 2 {
+ t.Fatalf("expected 2 blocks in error response, got %d", len(result.Content))
+ }
+
+ contentJSON, _ := json.Marshal(result.Content[1].Content)
+ var errContent anthropic.WebSearchToolResultError
+ if err := json.Unmarshal(contentJSON, &errContent); err != nil {
+ t.Fatalf("failed to parse error content: %v", err)
+ }
+ if errContent.ErrorCode != "api_error" {
+ t.Fatalf("expected api_error, got %q", errContent.ErrorCode)
+ }
+ if result.Usage.InputTokens != 9 || result.Usage.OutputTokens != 1 {
+ t.Fatalf("expected usage input=9 output=1, got %+v", result.Usage)
+ }
+}
+
+func countEventsByName(events []sseEvent, eventName string) int {
+ count := 0
+ for _, event := range events {
+ if event.event == eventName {
+ count++
+ }
+ }
+ return count
+}
+
+func collectTextDeltas(t *testing.T, events []sseEvent) []string {
+ t.Helper()
+
+ var deltas []string
+ for _, event := range events {
+ if event.event != "content_block_delta" {
+ continue
+ }
+
+ var delta anthropic.ContentBlockDeltaEvent
+ if err := json.Unmarshal([]byte(event.data), &delta); err != nil {
+ t.Fatalf("failed to unmarshal content_block_delta: %v", err)
+ }
+ if delta.Delta.Type == "text_delta" {
+ deltas = append(deltas, delta.Delta.Text)
+ }
+ }
+
+ return deltas
+}
+
+func containsString(values []string, target string) bool {
+ for _, value := range values {
+ if value == target {
+ return true
+ }
+ }
+ return false
+}
diff --git a/middleware/openai.go b/middleware/openai.go
index beaa9ee9769..dc40fa35194 100644
--- a/middleware/openai.go
+++ b/middleware/openai.go
@@ -11,6 +11,7 @@ import (
"time"
"github.com/gin-gonic/gin"
+ "github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -496,6 +497,17 @@ func (w *ResponsesWriter) Write(data []byte) (int, error) {
func ResponsesMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
+ if c.GetHeader("Content-Encoding") == "zstd" {
+ reader, err := zstd.NewReader(c.Request.Body, zstd.WithDecoderMaxMemory(8<<20))
+ if err != nil {
+ c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "failed to decompress zstd body"))
+ return
+ }
+ defer reader.Close()
+ c.Request.Body = io.NopCloser(reader)
+ c.Request.Header.Del("Content-Encoding")
+ }
+
var req openai.ResponsesRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
@@ -609,3 +621,49 @@ func ImageGenerationsMiddleware() gin.HandlerFunc {
c.Next()
}
}
+
+func ImageEditsMiddleware() gin.HandlerFunc {
+ return func(c *gin.Context) {
+ var req openai.ImageEditRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+ return
+ }
+
+ if req.Prompt == "" {
+ c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "prompt is required"))
+ return
+ }
+
+ if req.Model == "" {
+ c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "model is required"))
+ return
+ }
+
+ if req.Image == "" {
+ c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "image is required"))
+ return
+ }
+
+ genReq, err := openai.FromImageEditRequest(req)
+ if err != nil {
+ c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
+ return
+ }
+
+ var b bytes.Buffer
+ if err := json.NewEncoder(&b).Encode(genReq); err != nil {
+ c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
+ return
+ }
+
+ c.Request.Body = io.NopCloser(&b)
+
+ w := &ImageWriter{
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
+ }
+
+ c.Writer = w
+ c.Next()
+ }
+}
diff --git a/middleware/openai_test.go b/middleware/openai_test.go
index cc7c3c215ea..79b595a726a 100644
--- a/middleware/openai_test.go
+++ b/middleware/openai_test.go
@@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
+ "github.com/klauspost/compress/zstd"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
@@ -1112,3 +1113,228 @@ func TestImageWriterResponse(t *testing.T) {
t.Errorf("expected image data 'dGVzdC1pbWFnZS1kYXRh', got %s", imageResp.Data[0].B64JSON)
}
}
+
+func TestImageEditsMiddleware(t *testing.T) {
+ type testCase struct {
+ name string
+ body string
+ req api.GenerateRequest
+ err openai.ErrorResponse
+ }
+
+ var capturedRequest *api.GenerateRequest
+
+ // Base64-encoded test image (1x1 pixel PNG)
+ testImage := "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII="
+ decodedImage, _ := base64.StdEncoding.DecodeString("iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=")
+
+ testCases := []testCase{
+ {
+ name: "image edit basic",
+ body: `{
+ "model": "test-model",
+ "prompt": "make it blue",
+ "image": "` + testImage + `"
+ }`,
+ req: api.GenerateRequest{
+ Model: "test-model",
+ Prompt: "make it blue",
+ Images: []api.ImageData{decodedImage},
+ },
+ },
+ {
+ name: "image edit with size",
+ body: `{
+ "model": "test-model",
+ "prompt": "make it blue",
+ "image": "` + testImage + `",
+ "size": "512x768"
+ }`,
+ req: api.GenerateRequest{
+ Model: "test-model",
+ Prompt: "make it blue",
+ Images: []api.ImageData{decodedImage},
+ Width: 512,
+ Height: 768,
+ },
+ },
+ {
+ name: "image edit missing prompt",
+ body: `{
+ "model": "test-model",
+ "image": "` + testImage + `"
+ }`,
+ err: openai.ErrorResponse{
+ Error: openai.Error{
+ Message: "prompt is required",
+ Type: "invalid_request_error",
+ },
+ },
+ },
+ {
+ name: "image edit missing model",
+ body: `{
+ "prompt": "make it blue",
+ "image": "` + testImage + `"
+ }`,
+ err: openai.ErrorResponse{
+ Error: openai.Error{
+ Message: "model is required",
+ Type: "invalid_request_error",
+ },
+ },
+ },
+ {
+ name: "image edit missing image",
+ body: `{
+ "model": "test-model",
+ "prompt": "make it blue"
+ }`,
+ err: openai.ErrorResponse{
+ Error: openai.Error{
+ Message: "image is required",
+ Type: "invalid_request_error",
+ },
+ },
+ },
+ }
+
+ endpoint := func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ }
+
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ router.Use(ImageEditsMiddleware(), captureRequestMiddleware(&capturedRequest))
+ router.Handle(http.MethodPost, "/api/generate", endpoint)
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
+ req.Header.Set("Content-Type", "application/json")
+
+ defer func() { capturedRequest = nil }()
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if tc.err.Error.Message != "" {
+ var errResp openai.ErrorResponse
+ if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
+ t.Fatal(err)
+ }
+ if diff := cmp.Diff(tc.err, errResp); diff != "" {
+ t.Fatalf("errors did not match:\n%s", diff)
+ }
+ return
+ }
+
+ if resp.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
+ }
+
+ if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
+ t.Fatalf("requests did not match:\n%s", diff)
+ }
+ })
+ }
+}
+
+func zstdCompress(t *testing.T, data []byte) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ w, err := zstd.NewWriter(&buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := w.Write(data); err != nil {
+ t.Fatal(err)
+ }
+ if err := w.Close(); err != nil {
+ t.Fatal(err)
+ }
+ return buf.Bytes()
+}
+
+func TestResponsesMiddlewareZstd(t *testing.T) {
+ tests := []struct {
+ name string
+ body string
+ useZstd bool
+ oversized bool
+ wantCode int
+ wantModel string
+ wantMessage string
+ }{
+ {
+ name: "plain JSON",
+ body: `{"model": "test-model", "input": "Hello"}`,
+ wantCode: http.StatusOK,
+ wantModel: "test-model",
+ wantMessage: "Hello",
+ },
+ {
+ name: "zstd compressed",
+ body: `{"model": "test-model", "input": "Hello"}`,
+ useZstd: true,
+ wantCode: http.StatusOK,
+ wantModel: "test-model",
+ wantMessage: "Hello",
+ },
+ {
+ name: "zstd over max decompressed size",
+ oversized: true,
+ useZstd: true,
+ wantCode: http.StatusBadRequest,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var capturedRequest *api.ChatRequest
+
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ router.Use(ResponsesMiddleware(), captureRequestMiddleware(&capturedRequest))
+ router.Handle(http.MethodPost, "/v1/responses", func(c *gin.Context) {
+ c.Status(http.StatusOK)
+ })
+
+ var bodyReader io.Reader
+ if tt.oversized {
+ bodyReader = bytes.NewReader(zstdCompress(t, bytes.Repeat([]byte("A"), 9<<20)))
+ } else if tt.useZstd {
+ bodyReader = bytes.NewReader(zstdCompress(t, []byte(tt.body)))
+ } else {
+ bodyReader = strings.NewReader(tt.body)
+ }
+
+ req, _ := http.NewRequest(http.MethodPost, "/v1/responses", bodyReader)
+ req.Header.Set("Content-Type", "application/json")
+ if tt.useZstd || tt.oversized {
+ req.Header.Set("Content-Encoding", "zstd")
+ }
+
+ resp := httptest.NewRecorder()
+ router.ServeHTTP(resp, req)
+
+ if resp.Code != tt.wantCode {
+ t.Fatalf("expected status %d, got %d: %s", tt.wantCode, resp.Code, resp.Body.String())
+ }
+
+ if tt.wantCode != http.StatusOK {
+ return
+ }
+
+ if capturedRequest == nil {
+ t.Fatal("expected captured request, got nil")
+ }
+ if capturedRequest.Model != tt.wantModel {
+ t.Fatalf("expected model %q, got %q", tt.wantModel, capturedRequest.Model)
+ }
+ if len(capturedRequest.Messages) != 1 || capturedRequest.Messages[0].Content != tt.wantMessage {
+ t.Fatalf("expected single user message %q, got %+v", tt.wantMessage, capturedRequest.Messages)
+ }
+ })
+ }
+}
diff --git a/middleware/test_home_test.go b/middleware/test_home_test.go
new file mode 100644
index 00000000000..6c013c147aa
--- /dev/null
+++ b/middleware/test_home_test.go
@@ -0,0 +1,22 @@
+package middleware
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/envconfig"
+)
+
+func setTestHome(t *testing.T, home string) {
+ t.Helper()
+ t.Setenv("HOME", home)
+ t.Setenv("USERPROFILE", home)
+ envconfig.ReloadServerConfig()
+}
+
+// enableCloudForTest sets HOME to a clean temp dir and clears OLLAMA_NO_CLOUD
+// so that cloud features are enabled for the duration of the test.
+func enableCloudForTest(t *testing.T) {
+ t.Helper()
+ t.Setenv("OLLAMA_NO_CLOUD", "")
+ setTestHome(t, t.TempDir())
+}
diff --git a/ml/backend.go b/ml/backend.go
index e6d0ae59971..a4e12451a0b 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -165,6 +165,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor
+ SSMConv(ctx Context, kernel Tensor) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
@@ -172,10 +173,12 @@ type Tensor interface {
Cos(ctx Context) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context, up ...Tensor) Tensor
+ GELU_ERF(ctx Context) Tensor
QuickGELU(ctx Context, up ...Tensor) Tensor
SILU(ctx Context, up ...Tensor) Tensor
RELU(ctx Context, up ...Tensor) Tensor
Sigmoid(ctx Context) Tensor
+ SigmoidOut(ctx Context) Tensor
// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
@@ -208,6 +211,32 @@ type Tensor interface {
Stddev(ctx Context) Tensor
Sqr(ctx Context) Tensor
Sqrt(ctx Context) Tensor
+ Exp(ctx Context) Tensor
+ Neg(ctx Context) Tensor
+
+ // Clamp clamps values to [min, max] range
+ Clamp(ctx Context, min, max float32) Tensor
+
+ // Softplus computes ln(1 + exp(x))
+ Softplus(ctx Context) Tensor
+
+ // CumSum computes cumulative sum along dimension 0
+ CumSum(ctx Context) Tensor
+
+ // Diag creates a diagonal matrix from a 1D tensor
+ Diag(ctx Context) Tensor
+
+ // Tri converts a matrix to triangular form (0=upper+diag, 1=upper, 2=lower+diag, 3=lower)
+ Tri(ctx Context, triType int) Tensor
+
+ // Fill fills a tensor with a constant value (in-place)
+ Fill(ctx Context, value float32) Tensor
+
+ // Repeat4D repeats tensor to match target shape
+ Repeat4D(ctx Context, dim0, dim1, dim2, dim3 int) Tensor
+
+ // SolveTri solves a triangular system Ax = B
+ SolveTri(ctx Context, b Tensor, lower, left, unitDiag bool) Tensor
Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
}
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 6e10821fffa..ef1d1447326 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -517,7 +517,7 @@ func New(modelPath string, extraModelPaths []string, params ml.BackendParams) (m
}
}
- maxGraphNodes := max(1024, len(meta.Tensors.Items())*8)
+ maxGraphNodes := max(1024, len(meta.Tensors.Items())*32)
sched := C.ggml_backend_sched_new_ext(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
@@ -1608,6 +1608,13 @@ func (t *Tensor) Sigmoid(ctx ml.Context) ml.Tensor {
}
}
+func (t *Tensor) SigmoidOut(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_sigmoid(ctx.(*Context).ctx, t.t),
+ }
+}
+
func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
@@ -1721,6 +1728,13 @@ func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
}
}
+func (t *Tensor) GELU_ERF(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_gelu_erf_inplace(ctx.(*Context).ctx, t.t),
+ }
+}
+
func (t *Tensor) QuickGELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
var tt *C.struct_ggml_tensor
if len(t2) > 0 {
@@ -1781,6 +1795,13 @@ func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2,
return tt
}
+func (t *Tensor) SSMConv(ctx ml.Context, kernel ml.Tensor) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_ssm_conv(ctx.(*Context).ctx, t.t, kernel.(*Tensor).t),
+ }
+}
+
func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1905,6 +1926,76 @@ func (t *Tensor) Sqrt(ctx ml.Context) ml.Tensor {
}
}
+func (t *Tensor) Exp(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_exp(ctx.(*Context).ctx, t.t),
+ }
+}
+
+func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_neg(ctx.(*Context).ctx, t.t),
+ }
+}
+
+func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_clamp(ctx.(*Context).ctx, t.t, C.float(min), C.float(max)),
+ }
+}
+
+func (t *Tensor) Softplus(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_softplus(ctx.(*Context).ctx, t.t),
+ }
+}
+
+func (t *Tensor) CumSum(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_cumsum(ctx.(*Context).ctx, t.t),
+ }
+}
+
+func (t *Tensor) Diag(ctx ml.Context) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_diag(ctx.(*Context).ctx, t.t),
+ }
+}
+
+func (t *Tensor) Tri(ctx ml.Context, triType int) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_tri(ctx.(*Context).ctx, t.t, C.enum_ggml_tri_type(triType)),
+ }
+}
+
+func (t *Tensor) Fill(ctx ml.Context, value float32) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_fill_inplace(ctx.(*Context).ctx, t.t, C.float(value)),
+ }
+}
+
+func (t *Tensor) Repeat4D(ctx ml.Context, dim0, dim1, dim2, dim3 int) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_repeat_4d(ctx.(*Context).ctx, t.t, C.int64_t(dim0), C.int64_t(dim1), C.int64_t(dim2), C.int64_t(dim3)),
+ }
+}
+
+func (t *Tensor) SolveTri(ctx ml.Context, b ml.Tensor, lower, left, unitDiag bool) ml.Tensor {
+ return &Tensor{
+ b: t.b,
+ t: C.ggml_solve_tri(ctx.(*Context).ctx, t.t, b.(*Tensor).t, C._Bool(lower), C._Bool(left), C._Bool(unitDiag)),
+ }
+}
+
func (t *Tensor) Interpolate(ctx ml.Context, dims [4]int, samplingMode ml.SamplingMode) ml.Tensor {
var mode C.uint32_t
switch samplingMode {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index 7bd1044c19f..3dea2205e55 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -66,7 +66,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -80,7 +81,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 96, 64, 128, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
@@ -89,7 +91,8 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
}
static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
- GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 4, 64, 4, 32, 288, 256, 64, 1, false);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, true);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
@@ -397,7 +400,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -467,7 +470,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
}
}
} else {
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +481,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
T_A_KQ K_A;
load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
- // Wide version of KQ_C is column-major => swap A and B.
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+ if constexpr (cols_per_warp == 8) {
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+ } else {
+ // Wide version of KQ_C is column-major
+#if defined(AMD_WMMA_AVAILABLE)
+ // RDNA matrix C is column-major.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+#else
+ // swap A and B for CUDA.
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+#endif // defined(AMD_WMMA_AVAILABLE)
+ }
}
}
}
@@ -841,7 +853,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -1353,6 +1365,13 @@ static __global__ void flash_attn_ext_f16(
NO_DEVICE_CODE;
return;
}
+#ifdef VOLTA_MMA_AVAILABLE
+ if (ncols1*ncols2 < 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // VOLTA_MMA_AVAILABLE
+
#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
if (ncols1*ncols2 > 32) {
NO_DEVICE_CODE;
@@ -1585,3 +1604,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh
index 7c4d6fe67fe..371be74421c 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
@@ -1187,6 +1195,14 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
launch_fattn_tile_switch_ncols1(ctx, dst);
return;
}
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1(ctx, dst);
+ return;
+ }
}
if constexpr (DV <= 256) {
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
index 0155406665c..1693479cb54 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
@@ -111,7 +111,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
break;
case 576: {
- // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
+ // For Deepseek/GLM4, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
GGML_ASSERT(V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
- GGML_ASSERT(gqa_ratio % 16 == 0);
- ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ GGML_ASSERT(gqa_ratio % 4 == 0);
+ if (gqa_ratio % 16 == 0) {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst);
+ }
} break;
default:
GGML_ABORT("fatal error");
@@ -251,7 +255,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
- if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+ if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
index 2074e954a32..517993cb068 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
index 24c64cf000f..97b19c67ade 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
index 1ada657f194..989626dfa5e 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
index 86d4ffae27c..173de7aac7d 100644
--- a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
+++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
index 680904d132d..83385c9ef60 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -1370,6 +1370,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
return res;
}
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
+ assert(op->op == GGML_OP_SOLVE_TRI);
+
+ GGML_ASSERT(ggml_is_contiguous(op->src[0]));
+ GGML_ASSERT(ggml_is_contiguous(op->src[1]));
+
+ char base[256];
+ char name[256];
+
+ snprintf(base, 256, "kernel_solve_tri_f32");
+ snprintf(name, 256, "%s", base);
+
+ ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+ if (!res.pipeline) {
+ res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+ }
+
+ return res;
+}
+
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_GROUP_NORM);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a76..8a9d1746018 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
@@ -133,6 +133,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
index f24270bb1c5..4e5acfbe5fd 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1023,6 +1023,17 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
+ case GGML_OP_SOLVE_TRI:
+ return ggml_is_contiguous(op->src[0]) &&
+ ggml_is_contiguous(op->src[1]) &&
+ op->src[0]->type == GGML_TYPE_F32 &&
+ op->src[1]->type == GGML_TYPE_F32 &&
+ op->type == GGML_TYPE_F32;
+ case GGML_OP_COUNT_EQUAL:
+ return has_simdgroup_reduction &&
+ op->src[0]->type == GGML_TYPE_I32 &&
+ op->src[1]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
@@ -1071,12 +1082,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
- op->src[0]->ne[0] != 256) {
- return false;
- }
- if (op->src[0]->ne[0] == 576) {
- // DeepSeek sizes
- // TODO: disabled for now, until optmized
+ op->src[0]->ne[0] != 256 &&
+ op->src[0]->ne[0] != 576) {
return false;
}
if (op->src[1]->type != op->src[2]->type) {
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
index 13c6715ba24..9404c93cebe 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -2385,6 +2385,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
typedef struct {
int64_t ne00;
int64_t ne01;
@@ -5813,6 +5834,66 @@ kernel void kernel_l2_norm_f32(
}
}
+kernel void kernel_solve_tri_f32(
+ constant ggml_metal_kargs_solve_tri & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const uint64_t ncols = (uint64_t) args.ne10;
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
+ const uint64_t nr = n_batches * ncols;
+
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
+ if (gid >= nr) {
+ return;
+ }
+
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
+ const uint64_t i02 = rem / ncols;
+ const uint64_t i01 = rem - i02 * ncols;
+
+ const uint64_t sa0 = args.nb00 / sizeof(float);
+ const uint64_t sa1 = args.nb01 / sizeof(float);
+ const uint64_t sa2 = args.nb02 / sizeof(float);
+ const uint64_t sa3 = args.nb03 / sizeof(float);
+
+ const uint64_t sb0 = args.nb10 / sizeof(float);
+ const uint64_t sb1 = args.nb11 / sizeof(float);
+ const uint64_t sb2 = args.nb12 / sizeof(float);
+ const uint64_t sb3 = args.nb13 / sizeof(float);
+
+ const uint64_t sx0 = args.nb0 / sizeof(float);
+ const uint64_t sx1 = args.nb1 / sizeof(float);
+ const uint64_t sx2 = args.nb2 / sizeof(float);
+ const uint64_t sx3 = args.nb3 / sizeof(float);
+
+ device const float * A = (device const float *) src0;
+ device const float * B = (device const float *) src1;
+ device float * X = (device float *) dst;
+
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
+
+ const uint64_t n = (uint64_t) args.ne11;
+
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
+ float sum = 0.0f;
+ for (uint64_t t = 0; t < i00; ++t) {
+ sum += A[A_base + i00 * sa1 + t * sa0] *
+ X[X_base + t * sx1 + i01 * sx0];
+ }
+
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
+ X[X_base + i00 * sx1 + i01 * sx0] =
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
+ }
+}
+
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -8967,6 +9048,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break;
+ case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e907..cfdea9c0721 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -500,6 +500,27 @@ typedef struct {
float eps;
} ggml_metal_kargs_l2_norm;
+typedef struct {
+ int32_t ne00;
+ int32_t ne01;
+ int32_t ne02;
+ int32_t ne03;
+ uint64_t nb00;
+ uint64_t nb01;
+ uint64_t nb02;
+ uint64_t nb03;
+ int32_t ne10;
+ int32_t ne11;
+ uint64_t nb10;
+ uint64_t nb11;
+ uint64_t nb12;
+ uint64_t nb13;
+ uint64_t nb0;
+ uint64_t nb1;
+ uint64_t nb2;
+ uint64_t nb3;
+} ggml_metal_kargs_solve_tri;
+
typedef struct {
int64_t ne00;
int64_t ne01;
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f63..4ac135603cd 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -357,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_l2_norm(ctx, idx);
} break;
+ case GGML_OP_SOLVE_TRI:
+ {
+ n_fuse = ggml_metal_op_solve_tri(ctx, idx);
+ } break;
case GGML_OP_GROUP_NORM:
{
n_fuse = ggml_metal_op_group_norm(ctx, idx);
@@ -2456,7 +2460,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
// simdgroups per threadgroup (a.k.a. warps)
//nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4;
- int32_t nsg = 4;
+ int32_t nsg = ne00 >= 512 ? 8 : 4;
const size_t smem = FATTN_SMEM(nsg);
@@ -2931,6 +2935,65 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
return 1;
}
+int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
+ ggml_tensor * op = ctx->node(idx);
+
+ ggml_metal_library_t lib = ctx->lib;
+ ggml_metal_encoder_t enc = ctx->enc;
+
+ GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+ GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
+ GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
+
+ ggml_metal_kargs_solve_tri args = {
+ /*.ne00 =*/ ne00,
+ /*.ne01 =*/ ne01,
+ /*.ne02 =*/ ne02,
+ /*.ne03 =*/ ne03,
+ /*.nb00 =*/ nb00,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb03 =*/ nb03,
+ /*.ne10 =*/ ne10,
+ /*.ne11 =*/ ne11,
+ /*.nb10 =*/ nb10,
+ /*.nb11 =*/ nb11,
+ /*.nb12 =*/ nb12,
+ /*.nb13 =*/ nb13,
+ /*.nb0 =*/ nb0,
+ /*.nb1 =*/ nb1,
+ /*.nb2 =*/ nb2,
+ /*.nb3 =*/ nb3,
+ };
+
+ auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op);
+
+ const int64_t ncols = ne10;
+ const int64_t n_batches = (int64_t)ne02 * ne03;
+ const int64_t nr = n_batches * ncols;
+
+ int nth = 64;
+ nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+ if (nth < 1) {
+ nth = 1;
+ }
+
+ const int64_t n_tg = (nr + nth - 1) / nth;
+
+ ggml_metal_encoder_set_pipeline(enc, pipeline);
+ ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
+
+ ggml_metal_encoder_dispatch_threadgroups(enc, n_tg, 1, 1, nth, 1, 1);
+
+ return 1;
+}
+
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b5445232..a475183d367 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -68,6 +68,7 @@ int ggml_metal_op_add_id (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_flash_attn_ext (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_bin (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_group_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx);
diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
index c98d269d133..c37447a1045 100644
--- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
@@ -3012,6 +3012,66 @@ kernel void kernel_l2_norm_f32(
}
}
+kernel void kernel_solve_tri_f32(
+ constant ggml_metal_kargs_solve_tri & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ const uint64_t ncols = (uint64_t) args.ne10;
+ const uint64_t n_batches = (uint64_t) args.ne02 * (uint64_t) args.ne03;
+ const uint64_t nr = n_batches * ncols;
+
+ const uint64_t gid = (uint64_t) tgpig * (uint64_t) ntg + (uint64_t) tpitg;
+ if (gid >= nr) {
+ return;
+ }
+
+ const uint64_t i03 = gid / ((uint64_t) args.ne02 * ncols);
+ const uint64_t rem = gid - i03 * (uint64_t) args.ne02 * ncols;
+ const uint64_t i02 = rem / ncols;
+ const uint64_t i01 = rem - i02 * ncols;
+
+ const uint64_t sa0 = args.nb00 / sizeof(float);
+ const uint64_t sa1 = args.nb01 / sizeof(float);
+ const uint64_t sa2 = args.nb02 / sizeof(float);
+ const uint64_t sa3 = args.nb03 / sizeof(float);
+
+ const uint64_t sb0 = args.nb10 / sizeof(float);
+ const uint64_t sb1 = args.nb11 / sizeof(float);
+ const uint64_t sb2 = args.nb12 / sizeof(float);
+ const uint64_t sb3 = args.nb13 / sizeof(float);
+
+ const uint64_t sx0 = args.nb0 / sizeof(float);
+ const uint64_t sx1 = args.nb1 / sizeof(float);
+ const uint64_t sx2 = args.nb2 / sizeof(float);
+ const uint64_t sx3 = args.nb3 / sizeof(float);
+
+ device const float * A = (device const float *) src0;
+ device const float * B = (device const float *) src1;
+ device float * X = (device float *) dst;
+
+ const uint64_t A_base = i02 * sa2 + i03 * sa3;
+ const uint64_t B_base = i02 * sb2 + i03 * sb3;
+ const uint64_t X_base = i02 * sx2 + i03 * sx3;
+
+ const uint64_t n = (uint64_t) args.ne11;
+
+ for (uint64_t i00 = 0; i00 < n; ++i00) {
+ float sum = 0.0f;
+ for (uint64_t t = 0; t < i00; ++t) {
+ sum += A[A_base + i00 * sa1 + t * sa0] *
+ X[X_base + t * sx1 + i01 * sx0];
+ }
+
+ const float diag = A[A_base + i00 * sa1 + i00 * sa0];
+ X[X_base + i00 * sx1 + i01 * sx0] =
+ (B[B_base + i00 * sb1 + i01 * sb0] - sum) / diag;
+ }
+}
+
kernel void kernel_group_norm_f32(
constant ggml_metal_kargs_group_norm & args,
device const float * src0,
@@ -6166,6 +6226,7 @@ kernel void kernel_flash_attn_ext(
//case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break;
//case 2: kernel_flash_attn_ext_impl(FWD_ARGS); break;
case 4: kernel_flash_attn_ext_impl(FWD_ARGS); break;
+ case 8: kernel_flash_attn_ext_impl(FWD_ARGS); break;
}
#undef FWD_TMPL
#undef FWD_ARGS
diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go
deleted file mode 100644
index 765331bf813..00000000000
--- a/model/bytepairencoding.go
+++ /dev/null
@@ -1,272 +0,0 @@
-package model
-
-import (
- "cmp"
- "iter"
- "slices"
- "strings"
-
- "github.com/dlclark/regexp2"
- heap "github.com/emirpasic/gods/v2/trees/binaryheap"
- "github.com/ollama/ollama/logutil"
-)
-
-type BytePairEncoding struct {
- vocab *Vocabulary
- regexps []*regexp2.Regexp
-}
-
-var _ TextProcessor = (*BytePairEncoding)(nil)
-
-func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
- if len(pretokenizers) == 0 {
- // set default byte-level pretokenizer if none provided, e.g.
- // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
- pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
- }
-
- return BytePairEncoding{
- vocab: vocab,
- regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
- for _, p := range pretokenizers {
- if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
- return
- }
- }
- }),
- }
-}
-
-func (bpe BytePairEncoding) Vocabulary() *Vocabulary {
- return bpe.vocab
-}
-
-func (bpe BytePairEncoding) Is(id int32, special Special) bool {
- return bpe.vocab.Is(id, special)
-}
-
-func (bpe *BytePairEncoding) split(s string) iter.Seq[string] {
- parts := []string{s}
- for _, re := range bpe.regexps {
- parts = slices.Collect(func(yield func(string) bool) {
- for _, part := range parts {
- r := []rune(part)
- var offset int
- for m, _ := re.FindRunesMatch(r); m != nil; m, _ = re.FindNextMatch(m) {
- if offset-m.Index != 0 {
- if !yield(string(r[:m.Index])) {
- return
- }
- }
-
- if !yield(m.String()) {
- return
- }
-
- offset = m.Index + m.Length
- }
-
- if offset < len(r) {
- if !yield(string(r[offset:])) {
- return
- }
- }
- }
- })
- }
-
- return slices.Values(parts)
-}
-
-// fragment is a string fragment and their corresponding token IDs
-type fragment struct {
- value string
- ids []int32
-}
-
-// pair is a pair of runes and its rank
-type pair struct {
- a, b int
- rank int
- value string
-}
-
-type merge struct {
- p, n int
- runes []rune
-}
-
-func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
- fragments := []fragment{{value: s}}
- for _, special := range bpe.vocab.SpecialVocabulary() {
- // TODO: process special tokens concurrently
- id := bpe.vocab.Encode(special)
- for i := 0; i < len(fragments); i++ {
- frag := fragments[i]
- if len(frag.ids) > 0 {
- continue
- }
-
- var middle []fragment
- switch i := strings.Index(frag.value, special); {
- case i < 0:
- middle = append(middle, frag)
- case i > 0:
- middle = append(middle, fragment{value: frag.value[:i]})
- fallthrough
- default:
- middle = append(middle, fragment{value: special, ids: []int32{id}})
- if rest := frag.value[i+len(special):]; rest != "" {
- middle = append(middle, fragment{value: rest})
- }
- }
-
- fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
- }
- }
-
- var ids []int32
- for _, frag := range fragments {
- if len(frag.ids) > 0 {
- ids = append(ids, frag.ids...)
- continue
- }
-
- for split := range bpe.split(frag.value) {
- // TODO: process splits concurrently
- var sb strings.Builder
- for _, b := range []byte(split) {
- r := rune(b)
- switch {
- case r == 0x00ad:
- r = 0x0143
- case r <= 0x0020:
- r = r + 0x0100
- case r >= 0x007f && r <= 0x00a0:
- r = r + 0x00a2
- }
-
- sb.WriteRune(r)
- }
-
- // short circuit if the fragment is in the vocabulary
- if id := bpe.vocab.Encode(sb.String()); id >= 0 {
- ids = append(ids, id)
- continue
- }
-
- runes := []rune(sb.String())
- merges := make([]merge, len(runes))
- for r := range runes {
- merges[r] = merge{
- p: r - 1,
- n: r + 1,
- runes: []rune{runes[r]},
- }
- }
-
- pairwise := func(a, b int) *pair {
- if a < 0 || b >= len(runes) {
- return nil
- }
-
- left, right := string(merges[a].runes), string(merges[b].runes)
- rank := bpe.vocab.Merge(left, right)
- if rank < 0 {
- return nil
- }
-
- return &pair{
- a: a,
- b: b,
- rank: rank,
- value: left + right,
- }
- }
-
- pairs := heap.NewWith(func(i, j *pair) int {
- return cmp.Compare(i.rank, j.rank)
- })
-
- for i := range len(runes) - 1 {
- if pair := pairwise(i, i+1); pair != nil {
- pairs.Push(pair)
- }
- }
-
- for !pairs.Empty() {
- pair, _ := pairs.Pop()
-
- left, right := merges[pair.a], merges[pair.b]
- if len(left.runes) == 0 || len(right.runes) == 0 ||
- string(left.runes)+string(right.runes) != pair.value {
- continue
- }
-
- if id := bpe.vocab.Encode(pair.value); id < 0 {
- continue
- }
-
- merges[pair.a].runes = append(left.runes, right.runes...)
- merges[pair.b].runes = nil
-
- merges[pair.a].n = right.n
- if right.n < len(merges) {
- merges[right.n].p = pair.a
- }
-
- if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
- pairs.Push(pair)
- }
-
- if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
- pairs.Push(pair)
- }
- }
-
- for _, merge := range merges {
- if len(merge.runes) > 0 {
- // TODO: handle the edge case where the rune isn't in the vocabulary
- if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 {
- ids = append(ids, id)
- }
- }
- }
- }
- }
-
- if addSpecial {
- ids = bpe.vocab.addSpecials(ids)
- }
-
- logutil.Trace("encoded", "string", s, "ids", ids)
- return ids, nil
-}
-
-func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
- var sb strings.Builder
- for _, id := range ids {
- for _, r := range bpe.vocab.Decode(id) {
- switch {
- case r == 0x0100:
- // this produces 0x00 aka NULL
- continue
- case r == 0x0143:
- r = 0x00ad
- case r > 0x0100 && r <= 0x0120:
- r = r - 0x0100
- case r > 0x0120 && r <= 0x0142:
- r = r - 0x00a2
- }
-
- // NOTE: not using WriteRune here because it writes the UTF-8
- // encoding of the rune which is _not_ what we want
- if err := sb.WriteByte(byte(r)); err != nil {
- return "", err
- }
- }
- }
-
- logutil.Trace("decoded", "string", sb.String(), "from", ids)
- return sb.String(), nil
-}
diff --git a/model/model.go b/model/model.go
index d45e0311175..ab3a068dafd 100644
--- a/model/model.go
+++ b/model/model.go
@@ -23,6 +23,7 @@ import (
_ "github.com/ollama/ollama/ml/backend"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
var (
@@ -39,6 +40,13 @@ type Model interface {
Config() config
}
+// Validator is an optional interface that models can implement to perform
+// validation after tensors have been loaded. If validation fails, model
+// loading will fail with the returned error.
+type Validator interface {
+ Validate() error
+}
+
// MultimodalProcessor must be implemented by multimodal models.
type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
@@ -116,10 +124,17 @@ func New(modelPath string, extraModelPaths []string, params ml.BackendParams) (M
base := Base{b: b, config: m.Config()}
v := reflect.ValueOf(m)
v.Elem().Set(populateFields(base, v.Elem()))
+
+ if validator, ok := m.(Validator); ok {
+ if err := validator.Validate(); err != nil {
+ return nil, err
+ }
+ }
+
return m, nil
}
-func NewTextProcessor(s string) (TextProcessor, error) {
+func NewTextProcessor(s string) (tokenizer.Tokenizer, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
@@ -136,7 +151,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
- tp, ok := m.(TextProcessor)
+ tp, ok := m.(tokenizer.Tokenizer)
if !ok {
return nil, ErrUnsupportedTokenizer
}
diff --git a/model/model_test.go b/model/model_test.go
index f6d75b2302f..ed2868ff3e5 100644
--- a/model/model_test.go
+++ b/model/model_test.go
@@ -56,6 +56,18 @@ type fakeTensor struct {
Name string
}
+// Stub methods to satisfy ml.Tensor interface
+func (f *fakeTensor) Exp(ctx ml.Context) ml.Tensor { return f }
+func (f *fakeTensor) Neg(ctx ml.Context) ml.Tensor { return f }
+func (f *fakeTensor) Clamp(ctx ml.Context, _, _ float32) ml.Tensor { return f }
+func (f *fakeTensor) Softplus(ctx ml.Context) ml.Tensor { return f }
+func (f *fakeTensor) CumSum(ctx ml.Context) ml.Tensor { return f }
+func (f *fakeTensor) Diag(ctx ml.Context) ml.Tensor { return f }
+func (f *fakeTensor) Tri(ctx ml.Context, _ int) ml.Tensor { return f }
+func (f *fakeTensor) Fill(ctx ml.Context, _ float32) ml.Tensor { return f }
+func (f *fakeTensor) Repeat4D(ctx ml.Context, _, _, _, _ int) ml.Tensor { return f }
+func (f *fakeTensor) SolveTri(ctx ml.Context, _ ml.Tensor, _, _, _ bool) ml.Tensor { return f }
+
func (m *fakeBackend) Get(name string) ml.Tensor {
if slices.Contains(m.names, name) {
return &fakeTensor{Name: name}
diff --git a/model/models/bert/embed.go b/model/models/bert/embed.go
index 79cb3a3c7d7..3bce8bc0bc3 100644
--- a/model/models/bert/embed.go
+++ b/model/models/bert/embed.go
@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -129,7 +130,7 @@ func (o Options) headDim() int {
}
func New(c fs.Config) (model.Model, error) {
- vocab := &model.Vocabulary{
+ vocab := &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -153,17 +154,17 @@ func New(c fs.Config) (model.Model, error) {
},
}
- var processor model.TextProcessor
+ var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
- processor = model.NewWordPiece(vocab, true)
+ t = tokenizer.NewWordPiece(vocab, true)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
- TextProcessor: processor,
- Layers: make([]EncoderLayer, c.Uint("block_count")),
+ Tokenizer: t,
+ Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
diff --git a/model/models/deepseek2/model.go b/model/models/deepseek2/model.go
index 576076aab55..0bf9e7da3ee 100644
--- a/model/models/deepseek2/model.go
+++ b/model/models/deepseek2/model.go
@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -222,7 +223,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -277,8 +278,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/deepseekocr/model.go b/model/models/deepseekocr/model.go
index 4fc069b6995..9bfb5596a7f 100644
--- a/model/models/deepseekocr/model.go
+++ b/model/models/deepseekocr/model.go
@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
Sam *samModel `gguf:"s"`
Vision *visionModel `gguf:"v"`
@@ -134,8 +135,8 @@ func init() {
}
m := Model{
- TextProcessor: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go
index 7b0aa2f01ab..56ac6992211 100644
--- a/model/models/gemma2/model.go
+++ b/model/models/gemma2/model.go
@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -27,7 +28,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
- model.SentencePiece
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -43,8 +44,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
- SentencePiece: model.NewSentencePiece(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewSentencePiece(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
diff --git a/model/models/gemma3/embed.go b/model/models/gemma3/embed.go
index 9251111cfd1..6ad7f82cbb9 100644
--- a/model/models/gemma3/embed.go
+++ b/model/models/gemma3/embed.go
@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
- model.SentencePiece
+ tokenizer.Tokenizer
*TextModel
poolingType pooling.Type
@@ -31,8 +32,8 @@ func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, erro
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
- SentencePiece: model.NewSentencePiece(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewSentencePiece(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index e595f186305..4f5f0e40478 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -54,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
- vocabulary := model.Vocabulary{
+ vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -70,19 +71,19 @@ func New(c fs.Config) (model.Model, error) {
),
}
- var processor model.TextProcessor
+ var t tokenizer.Tokenizer
switch c.String("tokenizer.ggml.model") {
case "gpt2":
- processor = model.NewBytePairEncoding(&vocabulary)
+ t = tokenizer.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
- processor = model.NewSentencePiece(&vocabulary)
+ t = tokenizer.NewSentencePiece(&vocabulary)
}
m := Model{
- TextProcessor: processor,
+ Tokenizer: t,
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
diff --git a/model/models/gemma3n/model.go b/model/models/gemma3n/model.go
index e59e3193f24..758745c5e64 100644
--- a/model/models/gemma3n/model.go
+++ b/model/models/gemma3n/model.go
@@ -6,11 +6,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.SentencePiece
+ tokenizer.Tokenizer
*TextModel
}
@@ -23,8 +24,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
- SentencePiece: model.NewSentencePiece(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewSentencePiece(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
diff --git a/model/models/glm4moelite/model.go b/model/models/glm4moelite/model.go
index 2e51f7d56a2..d5d50ad8230 100644
--- a/model/models/glm4moelite/model.go
+++ b/model/models/glm4moelite/model.go
@@ -1,6 +1,7 @@
package glm4moelite
import (
+ "errors"
"math"
"github.com/ollama/ollama/fs"
@@ -9,8 +10,11 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
+var ErrOldModelFormat = errors.New("this model uses a weight format that is no longer supported; please re-download it")
+
type Options struct {
numExpertsUsed int
numExperts int
@@ -47,7 +51,9 @@ type Attention struct {
KVA *nn.Linear `gguf:"attn_kv_a_mqa"`
KVANorm *nn.RMSNorm `gguf:"attn_kv_a_norm"`
- KVB *nn.Linear `gguf:"attn_kv_b"`
+
+ KB *nn.Linear `gguf:"attn_k_b"`
+ VB *nn.Linear `gguf:"attn_v_b"`
Output *nn.Linear `gguf:"attn_out,alt:attn_output"`
}
@@ -78,15 +84,16 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
- kPass = attn.KVB.Forward(ctx, kPass)
- kv := kPass.Reshape(ctx, kPass.Dim(0)/opts.numKVHeads, opts.numKVHeads, seqLength)
- kvChunks := kv.ChunkSections(ctx, 0, opts.kqNopeHeadDim, opts.vHeadDim)
+ // MLA absorption: absorb K projection into query
+ qPass := queryChunks[0].Permute(ctx, 0, 2, 1, 3)
+ qPassAbsorb := attn.KB.Forward(ctx, qPass).Permute(ctx, 0, 2, 1, 3)
+ query = qRot.Concat(ctx, qPassAbsorb, 0)
- kRot = kRot.Repeat(ctx, 1, queryChunks[0].Dim(1))
- query = qRot.Concat(ctx, queryChunks[0], 0)
- key := kRot.Concat(ctx, kvChunks[0], 0)
- attention := nn.Attention(ctx, query, key, kvChunks[1], opts.kqScale, cache)
+ kPass = kPass.Reshape(ctx, opts.kvLoraRank, 1, seqLength)
+ key := kRot.Concat(ctx, kPass, 0)
+
+ attention := nn.AttentionWithVMLA(ctx, query, key, kPass, nil, attn.VB.Weight, opts.kqScale, cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), seqLength)
return attn.Output.Forward(ctx, attention)
@@ -192,7 +199,7 @@ func (t *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -217,7 +224,6 @@ func New(c fs.Config) (model.Model, error) {
keyLength := int(c.Uint("attention.key_length"))
valueLength := int(c.Uint("attention.value_length"))
-
kqScale := 1.0 / math.Sqrt(float64(keyLength))
var pre []string
@@ -231,12 +237,12 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
@@ -279,6 +285,15 @@ func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
}
+func (m *Model) Validate() error {
+ for _, layer := range m.Layers {
+ if layer.Attention != nil && (layer.Attention.KB == nil || layer.Attention.VB == nil) {
+ return ErrOldModelFormat
+ }
+ }
+ return nil
+}
+
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
diff --git a/model/models/glm4moelite/model_test.go b/model/models/glm4moelite/model_test.go
new file mode 100644
index 00000000000..fbc3b460d99
--- /dev/null
+++ b/model/models/glm4moelite/model_test.go
@@ -0,0 +1,73 @@
+package glm4moelite
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/ml/nn"
+)
+
+func TestValidate(t *testing.T) {
+ tests := []struct {
+ name string
+ model *Model
+ wantErr bool
+ }{
+ {
+ name: "valid model with KB and VB",
+ model: &Model{
+ Layers: []Layer{
+ {Attention: &Attention{KB: &nn.Linear{}, VB: &nn.Linear{}}},
+ },
+ },
+ wantErr: false,
+ },
+ {
+ name: "missing KB",
+ model: &Model{
+ Layers: []Layer{
+ {Attention: &Attention{VB: &nn.Linear{}}},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "missing VB",
+ model: &Model{
+ Layers: []Layer{
+ {Attention: &Attention{KB: &nn.Linear{}}},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "missing both KB and VB",
+ model: &Model{
+ Layers: []Layer{
+ {Attention: &Attention{}},
+ },
+ },
+ wantErr: true,
+ },
+ {
+ name: "nil Attention is ok",
+ model: &Model{
+ Layers: []Layer{
+ {Attention: nil},
+ },
+ },
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := tt.model.Validate()
+ if (err != nil) != tt.wantErr {
+ t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if tt.wantErr && err != ErrOldModelFormat {
+ t.Errorf("Validate() error = %v, want %v", err, ErrOldModelFormat)
+ }
+ })
+ }
+}
diff --git a/model/models/glmocr/imageprocessor.go b/model/models/glmocr/imageprocessor.go
new file mode 100644
index 00000000000..a26f42c4c73
--- /dev/null
+++ b/model/models/glmocr/imageprocessor.go
@@ -0,0 +1,174 @@
+package glmocr
+
+import (
+ "image"
+ "log/slog"
+ "math"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/model/imageproc"
+)
+
+type ImageProcessor struct {
+ imageSize int
+ patchSize int
+ temporalPatchSize int
+ spatialMergeSize int
+ minPixels int
+ maxPixels int
+ factor int
+ imageMean [3]float32
+ imageStd [3]float32
+}
+
+func newImageProcessor(c fs.Config) ImageProcessor {
+ patchSize := int(c.Uint("vision.patch_size", 14))
+ spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
+ temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
+
+ // Read normalization values from config if available, otherwise use CLIP defaults
+ imageMean := c.Floats("vision.image_mean", imageproc.ClipDefaultMean[:])
+ imageStd := c.Floats("vision.image_std", imageproc.ClipDefaultSTD[:])
+
+ // Default max_pixels: 2048 * patchSize^2 * mergeSize^2 * temporal = ~3.2M pixels
+ // This limits to ~16k patches (4k output tokens) to keep memory stable without flash attention
+ defaultMaxPixels := 2048 * patchSize * patchSize * spatialMergeSize * spatialMergeSize * temporalPatchSize
+
+ return ImageProcessor{
+ imageSize: int(c.Uint("vision.image_size", 336)),
+ patchSize: patchSize,
+ temporalPatchSize: temporalPatchSize,
+ spatialMergeSize: spatialMergeSize,
+ minPixels: int(c.Uint("vision.min_pixels", uint32(8*patchSize*patchSize*spatialMergeSize*spatialMergeSize*temporalPatchSize))),
+ maxPixels: int(c.Uint("vision.max_pixels", uint32(defaultMaxPixels))),
+ factor: patchSize * spatialMergeSize,
+ imageMean: [3]float32{imageMean[0], imageMean[1], imageMean[2]},
+ imageStd: [3]float32{imageStd[0], imageStd[1], imageStd[2]},
+ }
+}
+
+func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
+ factor := p.factor
+ temporalFactor := p.temporalPatchSize
+ numFrames := temporalFactor // single image
+
+ if height < factor || width < factor {
+ // Scale up small images
+ scale := float64(factor) / float64(min(height, width))
+ height = int(math.Ceil(float64(height) * scale))
+ width = int(math.Ceil(float64(width) * scale))
+ }
+
+ if temporalFactor <= 0 {
+ slog.Warn("temporal_patch_size must be > 0, defaulting to 1")
+ temporalFactor = 1
+ }
+ if numFrames < temporalFactor {
+ slog.Warn("num_frames must be >= temporal_patch_size, adjusting num_frames", "num_frames", numFrames, "temporal_patch_size", temporalFactor)
+ numFrames = temporalFactor
+ }
+ if aspectRatio := float64(max(height, width)) / float64(min(height, width)); aspectRatio > 200 {
+ slog.Warn("aspect ratio exceeds 200, image quality may be affected", "aspect_ratio", aspectRatio)
+ }
+
+ round := func(x float64) int { return int(math.RoundToEven(x)) }
+
+ hBar := round(float64(height)/float64(factor)) * factor
+ wBar := round(float64(width)/float64(factor)) * factor
+ tBar := round(float64(numFrames)/float64(temporalFactor)) * temporalFactor
+
+ if tBar*hBar*wBar > p.maxPixels {
+ beta := math.Sqrt(float64(numFrames*height*width) / float64(p.maxPixels))
+ hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
+ wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
+ } else if tBar*hBar*wBar < p.minPixels {
+ beta := math.Sqrt(float64(p.minPixels) / float64(numFrames*height*width))
+ hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
+ wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
+ }
+
+ return hBar, wBar
+}
+
+func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
+ img = imageproc.Composite(img)
+
+ origWidth := img.Bounds().Dx()
+ origHeight := img.Bounds().Dy()
+
+ // Calculate smart resize dimensions
+ resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
+
+ // Resize image
+ resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeCatmullrom)
+
+ // Normalize pixels - output format is [C, H, W] with rescale and channelFirst
+ // We keep [C, H, W] for patch extraction
+ normalizedPixels := imageproc.Normalize(resizedImg, p.imageMean, p.imageStd, true, true)
+
+ // Calculate grid dimensions (after Conv2D patching)
+ grid := &Grid{
+ Height: resizedHeight / p.patchSize,
+ Width: resizedWidth / p.patchSize,
+ Temporal: 1, // Single image
+ ImageHeight: resizedHeight,
+ ImageWidth: resizedWidth,
+ }
+
+ patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return patches, grid, nil
+}
+
+func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
+ channels := 3
+ patchSize := p.patchSize
+ mergeSize := p.spatialMergeSize
+ temporalPatchSize := p.temporalPatchSize
+
+ numPatches := grid.Temporal * grid.Height * grid.Width
+ patchDim := channels * temporalPatchSize * patchSize * patchSize
+ result := make([]float32, numPatches*patchDim)
+ patchIndex := 0
+
+ // Single temporal frame handling (copies to all frames)
+ for range grid.Temporal {
+ for h := 0; h < grid.Height; h += mergeSize {
+ for w := 0; w < grid.Width; w += mergeSize {
+ for mh := range mergeSize {
+ for mw := range mergeSize {
+ baseOffset := patchIndex * patchDim
+ for c := range channels {
+ channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
+ for py := range patchSize {
+ for px := range patchSize {
+ y := (h+mh)*patchSize + py
+ x := (w+mw)*patchSize + px
+ srcIdx := c*height*width + y*width + x
+ dstIdx := channelOffset + (py * patchSize) + px
+ result[dstIdx] = pixels[srcIdx]
+ }
+ }
+
+ if temporalPatchSize > 1 {
+ frameSize := patchSize * patchSize
+ for tp := 1; tp < temporalPatchSize; tp++ {
+ currentFrameOffset := channelOffset + (tp * frameSize)
+ copy(result[currentFrameOffset:currentFrameOffset+frameSize],
+ result[channelOffset:channelOffset+frameSize])
+ }
+ }
+ }
+
+ patchIndex++
+ }
+ }
+ }
+ }
+ }
+
+ return result, nil
+}
diff --git a/model/models/glmocr/model.go b/model/models/glmocr/model.go
new file mode 100644
index 00000000000..895a766c343
--- /dev/null
+++ b/model/models/glmocr/model.go
@@ -0,0 +1,236 @@
+package glmocr
+
+import (
+ "bytes"
+ "errors"
+ "image"
+ "slices"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
+)
+
+type Model struct {
+ model.Base
+ tokenizer.Tokenizer
+
+ *TextModel
+ *VisionModel `gguf:"v"`
+ VisionDownsample *VisionDownsample `gguf:"mm.patch_merger"`
+ PatchMerger *PatchMerger `gguf:"mm"`
+
+ ImageProcessor
+
+ imageTokenID int32
+ imageStartTokenID int32
+ imageEndTokenID int32
+}
+
+var _ model.MultimodalProcessor = (*Model)(nil)
+
+func New(c fs.Config) (model.Model, error) {
+ eosTokenID := int32(c.Uint("tokenizer.ggml.eos_token_id"))
+ eosTokenIDs := c.Ints("tokenizer.ggml.eos_token_ids")
+ allEOS := append([]int32{eosTokenID}, eosTokenIDs...)
+
+ m := &Model{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ Merges: c.Strings("tokenizer.ggml.merges"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
+ BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+ EOS: allEOS,
+ },
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+ ),
+ TextModel: newTextModel(c),
+ VisionModel: newVisionModel(c),
+ ImageProcessor: newImageProcessor(c),
+ imageTokenID: int32(c.Uint("image_token_id", 59280)),
+ imageStartTokenID: int32(c.Uint("image_start_token_id", 59256)),
+ imageEndTokenID: int32(c.Uint("image_end_token_id", 59257)),
+ }
+
+ m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
+
+ return m, nil
+}
+
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
+ if len(m.VisionModel.Blocks) == 0 {
+ return nil, model.ErrNoVisionModel
+ }
+
+ img, _, err := image.Decode(bytes.NewReader(multimodalData))
+ if err != nil {
+ return nil, err
+ }
+
+ f32s, grid, err := m.ImageProcessor.ProcessImage(img)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create pixel values tensor from flattened patches
+ // Shape: [patchDim, numPatches]
+ patchDim := m.VisionModel.numChannels * m.temporalPatchSize * m.patchSize * m.patchSize
+ numPatches := grid.Temporal * grid.Height * grid.Width
+ pixelValues := ctx.Input().FromFloats(f32s, patchDim, numPatches)
+
+ // Forward through vision encoder
+ visionOutputs := m.VisionModel.Forward(ctx, pixelValues, grid)
+
+ // Forward through downsample (patch merger)
+ if m.VisionDownsample == nil || m.VisionDownsample.Weight == nil {
+ return nil, errors.New("glmocr: missing vision downsample weights")
+ }
+ visionOutputs = m.VisionDownsample.Forward(ctx, visionOutputs, grid, m.VisionModel.VisionModelOptions)
+
+ // Forward through patch merger (FC + LayerNorm + GELU + SwiGLU FFN)
+ if m.PatchMerger == nil {
+ return nil, errors.New("glmocr: missing patch merger weights")
+ }
+ visionOutputs = m.PatchMerger.Forward(ctx, visionOutputs, m.VisionModel.VisionModelOptions)
+
+ return []input.Multimodal{{Tensor: visionOutputs, Data: grid}}, nil
+}
+
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+ var result []*input.Input
+
+ // Reset position cache
+ m.TextModel.positionCache = m.TextModel.positionCache[:0]
+ m.TextModel.ropeDelta = 0
+
+ pos := int32(0)
+ for _, inp := range inputs {
+ if inp.Multimodal == nil {
+ result = append(result, inp)
+ m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
+ pos++
+ continue
+ }
+
+ // Get grid info for position calculation
+ grid := inp.Multimodal[0].Data.(*Grid)
+ mergedH := grid.Height / m.VisionModel.spatialMergeSize
+ mergedW := grid.Width / m.VisionModel.spatialMergeSize
+
+ // Add image start token
+ result = append(result, &input.Input{Token: m.imageStartTokenID})
+ m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
+ pos++
+
+ // Add image tokens with multimodal data
+ // All image tokens share the same base position for temporal dimension
+ tokensPerGrid := inp.Multimodal[0].Tensor.Dim(1)
+ basePos := pos
+ sameBatch := tokensPerGrid - 1
+ if sameBatch < 0 {
+ sameBatch = 0
+ }
+ result = append(result, &input.Input{
+ Token: m.imageTokenID,
+ Multimodal: inp.Multimodal,
+ MultimodalHash: inp.MultimodalHash,
+ SameBatch: sameBatch,
+ })
+ m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
+
+ // Add placeholder tokens for remaining positions
+ // All image tokens use the same base position (temporal stays constant)
+ for range tokensPerGrid - 1 {
+ result = append(result, &input.Input{Token: m.imageTokenID})
+ m.TextModel.positionCache = append(m.TextModel.positionCache, basePos)
+ }
+
+ // Advance position by max(mergedH, mergedW) after image tokens
+ pos = basePos + int32(max(mergedH, mergedW))
+
+ // Add image end token
+ result = append(result, &input.Input{Token: m.imageEndTokenID})
+ m.TextModel.positionCache = append(m.TextModel.positionCache, pos)
+ pos++
+ }
+
+ // Compute rope delta for continuation after the prefill segment:
+ // delta = (max_position_id + 1) - sequence_length
+ if len(m.TextModel.positionCache) > 0 {
+ last := m.TextModel.positionCache[len(m.TextModel.positionCache)-1]
+ m.TextModel.ropeDelta = last + 1 - int32(len(m.TextModel.positionCache))
+ }
+
+ return result, nil
+}
+
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ // Initial token embedding
+ hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)
+ ctx.Forward(hiddenStates)
+
+ // Build position slices for M-RoPE
+ positionSlice := func() [][]int32 {
+ s := [][]int32{
+ make([]int32, len(batch.Positions)), // temporal
+ make([]int32, len(batch.Positions)), // height
+ make([]int32, len(batch.Positions)), // width
+ make([]int32, len(batch.Positions)), // unused (zeros)
+ }
+ for i, position := range batch.Positions {
+ // Translate through position cache or continue sequence
+ if position < int32(len(m.TextModel.positionCache)) {
+ position = m.TextModel.positionCache[position]
+ } else if len(m.TextModel.positionCache) > 0 {
+ // Continue sequence after cached positions using ropeDelta
+ position = position + m.TextModel.ropeDelta
+ }
+
+ s[0][i] = position
+ s[1][i] = position
+ s[2][i] = position
+ }
+ return s
+ }()
+
+ // Inject vision embeddings and adjust positions for image tokens
+ for _, mi := range batch.Multimodal {
+ img := mi.Multimodal[0].Tensor
+ ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
+
+ if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
+ w := grid.Width / m.VisionModel.spatialMergeSize
+ for i := range img.Dim(1) {
+ positionSlice[1][mi.Index+i] += int32(i / w)
+ positionSlice[2][mi.Index+i] += int32(i % w)
+ }
+ }
+ }
+
+ positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
+
+ // Process through transformer layers
+ for i, layer := range m.TextModel.Layers {
+ m.Cache.SetLayer(i)
+
+ var lastLayerOutputs ml.Tensor
+ if i == len(m.TextModel.Layers)-1 {
+ lastLayerOutputs = batch.Outputs
+ }
+
+ hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, m.Cache, m.TextModel.TextModelOptions)
+ }
+
+ hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.TextModel.eps)
+ return m.Output.Forward(ctx, hiddenStates), nil
+}
+
+func init() {
+ model.Register("glmocr", New)
+}
diff --git a/model/models/glmocr/model_text.go b/model/models/glmocr/model_text.go
new file mode 100644
index 00000000000..ec9cd730181
--- /dev/null
+++ b/model/models/glmocr/model_text.go
@@ -0,0 +1,190 @@
+package glmocr
+
+import (
+ "math"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/rope"
+)
+
+type TextModelOptions struct {
+ hiddenSize int
+ numHeads int
+ numKVHeads int
+ headDim int
+ rotaryDim int
+ intermediateSize int
+ eps float32
+ ropeBase float32
+ mropeSections []int
+}
+
+func (o *TextModelOptions) applyMRoPE(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+ // With 4 sections for [temporal, height, width, unused]
+ return nn.RoPE(ctx, states, positions, o.rotaryDim, o.ropeBase, 1.0, rope.WithMRoPE(o.mropeSections))
+}
+
+type TextSelfAttention struct {
+ Query *nn.Linear `gguf:"attn_q"`
+ Key *nn.Linear `gguf:"attn_k"`
+ Value *nn.Linear `gguf:"attn_v"`
+ Output *nn.Linear `gguf:"attn_out"`
+}
+
+func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
+ batchSize := hiddenStates.Dim(1)
+
+ // Separate Q, K, V projections
+ q := sa.Query.Forward(ctx, hiddenStates)
+ k := sa.Key.Forward(ctx, hiddenStates)
+ v := sa.Value.Forward(ctx, hiddenStates)
+
+ // Reshape for GQA
+ q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
+ k = k.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
+ v = v.Reshape(ctx, opts.headDim, opts.numKVHeads, batchSize)
+
+ // Apply M-RoPE (multi-resolution rotary position embeddings)
+ q = opts.applyMRoPE(ctx, q, positions)
+ k = opts.applyMRoPE(ctx, k, positions)
+
+ // Scaled dot-product attention with KV cache
+ scaleFactor := 1.0 / math.Sqrt(float64(opts.headDim))
+ kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
+ // Reshape attention output: [headDim, numHeads, batchSize] -> [numHeads*headDim, batchSize]
+ // Note: numHeads * headDim = 16 * 128 = 2048, which is the attention hidden size
+ kqv = kqv.Reshape(ctx, opts.numHeads*opts.headDim, batchSize)
+
+ return sa.Output.Forward(ctx, kqv)
+}
+
+type TextMLP struct {
+ Gate *nn.Linear `gguf:"ffn_gate"`
+ Up *nn.Linear `gguf:"ffn_up"`
+ Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextModelOptions) ml.Tensor {
+ // SwiGLU: down(silu(gate(x)) * up(x))
+ gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ return mlp.Down.Forward(ctx, gate)
+}
+
+type TextDecoderLayer struct {
+ // Input layernorm (before attention)
+ AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+ SelfAttention *TextSelfAttention
+ // Post self-attention layernorm (after attention, before residual add)
+ PostAttnNorm *nn.RMSNorm `gguf:"post_attn_norm"`
+
+ // FFN input layernorm (after first residual, before MLP)
+ FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
+ MLP *TextMLP
+ // Post MLP layernorm (after MLP, before residual add)
+ PostFFNNorm *nn.RMSNorm `gguf:"post_ffn_norm"`
+}
+
+func (l *TextDecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *TextModelOptions) ml.Tensor {
+ // Attention block
+ residual := hiddenStates
+ hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+ hiddenStates = l.SelfAttention.Forward(ctx, hiddenStates, positions, cache, opts)
+ hiddenStates = l.PostAttnNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ // Prune to output positions in final layer
+ if outputs != nil {
+ hiddenStates = hiddenStates.Rows(ctx, outputs)
+ residual = residual.Rows(ctx, outputs)
+ }
+
+ hiddenStates = hiddenStates.Add(ctx, residual)
+
+ // MLP block
+ residual = hiddenStates
+ hiddenStates = l.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
+ hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
+ hiddenStates = l.PostFFNNorm.Forward(ctx, hiddenStates, opts.eps)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+
+ return hiddenStates
+}
+
+type TextModel struct {
+ TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+ Layers []TextDecoderLayer `gguf:"blk"`
+ OutputNorm *nn.RMSNorm `gguf:"output_norm"`
+ Output *nn.Linear `gguf:"output,alt:token_embd"`
+
+ *TextModelOptions
+
+ // positionCache stores the M-RoPE position for each token in the sequence.
+ // This is needed because image tokens share the same base position but have
+ // different height/width offsets, and the end token position depends on the
+ // image grid dimensions.
+ positionCache []int32
+ ropeDelta int32
+}
+
+func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+ // Clear position cache when KV cache shifts
+ m.positionCache = nil
+ m.ropeDelta = 0
+ return m.applyMRoPE(ctx, key, shift), nil
+}
+
+func newTextModel(c fs.Config) *TextModel {
+ hiddenSize := int(c.Uint("embedding_length", 1536))
+ numHeads := int(c.Uint("attention.head_count", 16))
+ numKVHeads := int(c.Uint("attention.head_count_kv", 8))
+ intermediateSize := int(c.Uint("feed_forward_length", 4608))
+ eps := c.Float("attention.layer_norm_rms_epsilon", 1e-5)
+ ropeBase := c.Float("rope.freq_base", 10000)
+
+ headDim := int(c.Uint("attention.key_length", uint32(hiddenSize/numHeads)))
+ ropeDim := int(c.Uint("rope.dimension_count", uint32(headDim)))
+ if ropeDim <= 0 {
+ ropeDim = headDim
+ }
+
+ mropeSections := c.Ints("rope.mrope_section")
+ var sectionInts []int
+
+ if len(mropeSections) > 0 {
+ sectionInts = make([]int, len(mropeSections))
+ for i, section := range mropeSections {
+ sectionInts[i] = int(section)
+ }
+ } else {
+ // Default to GLM-OCR's HF ratio (2:3:3) scaled to rotaryDim/2.
+ // For rotaryDim=64 this yields [8, 12, 12].
+ total := ropeDim / 2
+ if total <= 0 {
+ total = 32
+ }
+ s0 := total * 2 / 8
+ s1 := total * 3 / 8
+ s2 := total - s0 - s1
+ sectionInts = []int{s0, s1, s2}
+ }
+
+ // GGML rope_multi: sector = (dim_pair) % sum(sections), mapping each pair to its position dim
+ rotaryDim := ropeDim
+
+ return &TextModel{
+ Layers: make([]TextDecoderLayer, c.Uint("block_count", 16)),
+ TextModelOptions: &TextModelOptions{
+ hiddenSize: hiddenSize,
+ numHeads: numHeads,
+ numKVHeads: numKVHeads,
+ headDim: headDim,
+ rotaryDim: rotaryDim,
+ intermediateSize: intermediateSize,
+ eps: eps,
+ ropeBase: ropeBase,
+ mropeSections: sectionInts,
+ },
+ }
+}
diff --git a/model/models/glmocr/model_vision.go b/model/models/glmocr/model_vision.go
new file mode 100644
index 00000000000..6f8d1931150
--- /dev/null
+++ b/model/models/glmocr/model_vision.go
@@ -0,0 +1,355 @@
+package glmocr
+
+import (
+ "log/slog"
+ "math"
+ "slices"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/rope"
+)
+
+type Grid struct {
+ Height int // Number of patches in height direction
+ Width int // Number of patches in width direction
+ Temporal int
+ ImageHeight int // Full image height in pixels
+ ImageWidth int // Full image width in pixels
+}
+
+type VisionModelOptions struct {
+ hiddenSize int
+ numHeads int
+ headDim int
+ numChannels int
+ patchSize int
+ temporalPatchSize int
+ imageSize int
+ spatialMergeSize int
+ outHiddenSize int
+ intermediateSize int
+ eps float32
+}
+
+type VisionPatchEmbed struct {
+ Proj *nn.Conv2D `gguf:"patch_embd_0"`
+ Proj1 *nn.Conv2D `gguf:"patch_embd_1"`
+ Bias ml.Tensor `gguf:"patch_embd.bias"`
+}
+
+func (pe *VisionPatchEmbed) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
+ _ = grid // patches are already in merge-block order
+
+ // pixelValues shape: [patchDim, numPatches]
+ numPatches := pixelValues.Shape()[1]
+
+ // Reshape to [patchSize*patchSize, temporalPatchSize, numChannels, numPatches]
+ pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
+ // Permute to [temporalPatchSize, patchSize*patchSize, numChannels, numPatches]
+ pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+ // Slice temporal frames for Conv2D (simulate Conv3D)
+ in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
+ in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
+
+ s0, s1 := opts.patchSize, opts.patchSize
+ p0, p1 := 0, 0
+ d0, d1 := 1, 1
+ hiddenStates := pe.Proj.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
+
+ if pe.Proj1 != nil && opts.temporalPatchSize > 1 {
+ in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
+ in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
+ out1 := pe.Proj1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
+ hiddenStates = hiddenStates.Add(ctx, out1)
+ }
+
+ // Flatten to [hidden_size, num_patches]
+ hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, numPatches)
+
+ // Add patch bias - reshape from [hidden_size] to [hidden_size, 1] for broadcasting
+ if pe.Bias != nil {
+ hiddenStates = hiddenStates.Add(ctx, pe.Bias.Reshape(ctx, opts.hiddenSize, 1))
+ }
+
+ return hiddenStates
+}
+
+type VisionSelfAttention struct {
+ QKV *nn.Linear `gguf:"attn_qkv"`
+ QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
+ KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+ Output *nn.Linear `gguf:"attn_out"`
+}
+
+func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+ batchSize := hiddenStates.Dim(1)
+
+ // Combined QKV projection: [3*hidden_size, batch_size]
+ qkv := sa.QKV.Forward(ctx, hiddenStates)
+
+ // Split using ChunkSections along dim 0 (handles byte offsets correctly)
+ // ChunkSections returns views - must make contiguous before further operations
+ chunks := qkv.ChunkSections(ctx, 0, opts.hiddenSize, opts.hiddenSize, opts.hiddenSize)
+ q := chunks[0].Contiguous(ctx)
+ k := chunks[1].Contiguous(ctx)
+ v := chunks[2].Contiguous(ctx)
+
+ // Reshape for multi-head attention: [hiddenSize, N] -> [headDim, numHeads, N]
+ q = q.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
+ k = k.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
+ v = v.Reshape(ctx, opts.headDim, opts.numHeads, batchSize)
+
+ // Apply Q-norm and K-norm after head reshape
+ // Weights are [headDim]=64, tensor is [headDim, numHeads, N]
+ q = sa.QNorm.Forward(ctx, q, opts.eps)
+ k = sa.KNorm.Forward(ctx, k, opts.eps)
+
+ // Apply rotary position embeddings with vision-style 2D positions.
+ // ggml's vision RoPE uses two position dimensions (H/W) with half-rotation pairs.
+ // We provide H/W sections and leave the remaining sections empty.
+ ropeFreqBase := float32(10000.0)
+ section := opts.headDim / 4
+ if section <= 0 {
+ section = 1
+ }
+ sections := []int{section, section, 0, 0}
+ q = nn.RoPE(ctx, q, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
+ k = nn.RoPE(ctx, k, positions, opts.headDim/2, ropeFreqBase, 1.0, rope.WithVision(sections))
+
+ // Scale factor for scaled dot-product attention
+ scale := 1.0 / math.Sqrt(float64(opts.headDim))
+
+ // Try flash attention first (ScaledDotProductAttention), fall back to manual
+ if sdpa, ok := q.(ml.ScaledDotProductAttention); ok {
+ attention := sdpa.ScaledDotProductAttention(ctx, k, v, nil, nil, nil, scale, false)
+ attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+ return sa.Output.Forward(ctx, attention)
+ }
+
+ slog.Warn("glmocr: vision attention falling back to manual attention",
+ "batchSize", batchSize, "numHeads", opts.numHeads,
+ "hint", "set OLLAMA_FLASH_ATTENTION=1 to enable flash attention")
+
+ // Manual attention fallback
+ // q, k, v are [headDim, numHeads, batchSize] - GGML treats as 4D with implicit dim 3 = 1
+ q = q.Permute(ctx, 0, 2, 1, 3)
+ k = k.Permute(ctx, 0, 2, 1, 3)
+ v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
+
+ // Attention scores
+ kq := k.MulmatFullPrec(ctx, q)
+ kq = kq.Scale(ctx, scale)
+ kq = kq.Softmax(ctx)
+
+ // Attention output: v @ kq (note: v first)
+ kqv := v.Mulmat(ctx, kq)
+ attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
+ attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+
+ return sa.Output.Forward(ctx, attention)
+}
+
+type VisionMLP struct {
+ Gate *nn.Linear `gguf:"ffn_gate"`
+ Up *nn.Linear `gguf:"ffn_up"`
+ Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
+ // SwiGLU: down(silu(gate(x)) * up(x))
+ gate := mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ return mlp.Down.Forward(ctx, gate)
+}
+
+type VisionBlock struct {
+ Norm1 *nn.RMSNorm `gguf:"ln1"`
+ SelfAttention *VisionSelfAttention
+ Norm2 *nn.RMSNorm `gguf:"ln2"`
+ MLP *VisionMLP
+}
+
+func (b *VisionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+ // Pre-norm architecture
+ residual := hiddenStates
+ hiddenStates = b.Norm1.Forward(ctx, hiddenStates, opts.eps)
+ hiddenStates = b.SelfAttention.Forward(ctx, hiddenStates, positions, opts)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+
+ residual = hiddenStates
+ hiddenStates = b.Norm2.Forward(ctx, hiddenStates, opts.eps)
+ hiddenStates = b.MLP.Forward(ctx, hiddenStates)
+ hiddenStates = hiddenStates.Add(ctx, residual)
+
+ return hiddenStates
+}
+
+type VisionDownsample struct {
+ *nn.Conv2D
+}
+
+func (d *VisionDownsample) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts *VisionModelOptions) ml.Tensor {
+ // Apply spatial downsampling via Conv2D
+ // Input: [hidden_size, num_patches] where patches are in merge-block order
+
+ if d.Conv2D == nil || d.Weight == nil {
+ slog.Error("VisionDownsample weights not loaded - model may be corrupted or incompatible")
+ return hiddenStates // Return input unchanged as fallback
+ }
+
+ merge := opts.spatialMergeSize
+ numOutputTokens := (grid.Height / merge) * (grid.Width / merge)
+
+ // Step 1: Reshape to [hidden_size, merge, merge, num_output_tokens]
+ hiddenStates = hiddenStates.Reshape(ctx, opts.hiddenSize, merge, merge, numOutputTokens)
+
+ // Step 2: Permute to [merge, merge, hidden_size, num_output_tokens]
+ // ggml semantics: result.ne[perm[i]] = input.ne[i]
+ // So permute(2,0,1,3) on [1024,2,2,N] gives: ne[2]=1024, ne[0]=2, ne[1]=2, ne[3]=N -> [2,2,1024,N]
+ hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+
+ // Step 3: Apply Conv2D without bias (bias added after reshape)
+ // Note: ggml_conv_2d takes (kernel, input) - kernel must be receiver in ollama
+ s0, s1 := merge, merge
+ p0, p1 := 0, 0
+ d0, d1 := 1, 1
+ hiddenStates = d.Weight.Conv2D(ctx, hiddenStates, s0, s1, p0, p1, d0, d1)
+
+ // Step 4: Reshape to [out_hidden_size, num_output_tokens]
+ hiddenStates = hiddenStates.Reshape(ctx, opts.outHiddenSize, numOutputTokens)
+
+ // Step 5: Add bias after reshape
+ // Reshape bias from [out_hidden_size] to [out_hidden_size, 1] for proper broadcasting
+ if d.Bias != nil {
+ hiddenStates = hiddenStates.Add(ctx, d.Bias.Reshape(ctx, opts.outHiddenSize, 1))
+ }
+
+ return hiddenStates
+}
+
+type PatchMerger struct {
+ // GGUF tags align with mm.* keys used by the model
+ Proj *nn.Linear `gguf:"model.fc"` // mm.model.fc.weight
+ PostLN *nn.LayerNorm `gguf:"post_norm"` // mm.post_norm.weight/bias
+ GateProj *nn.Linear `gguf:"gate"` // mm.gate.weight
+ UpProj *nn.Linear `gguf:"up"` // mm.up.weight
+ DownProj *nn.Linear `gguf:"down"` // mm.down.weight
+}
+
+func (m *PatchMerger) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
+ // Linear projection
+ hiddenStates = m.Proj.Forward(ctx, hiddenStates)
+
+ // Post-projection layer norm + GELU ERF
+ hiddenStates = m.PostLN.Forward(ctx, hiddenStates, opts.eps)
+ hiddenStates = hiddenStates.GELU_ERF(ctx)
+ // Force a copy to avoid in-place mutation issues with GELU_ERF
+ hiddenStates = hiddenStates.Contiguous(ctx)
+
+ // SwiGLU MLP: down(silu(gate(x)) * up(x))
+ gateOut := m.GateProj.Forward(ctx, hiddenStates)
+ upOut := m.UpProj.Forward(ctx, hiddenStates)
+ gate := gateOut.SILU(ctx, upOut)
+ return m.DownProj.Forward(ctx, gate)
+}
+
+type VisionModel struct {
+ PatchEmbed *VisionPatchEmbed
+ Blocks []VisionBlock `gguf:"blk"`
+ PostLN *nn.RMSNorm `gguf:"post_ln"`
+ // Note: Downsample is applied at the model level so mm.patch_merger stays separate
+
+ *VisionModelOptions
+}
+
+func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
+ // Extract patch embeddings from flattened patches
+ hiddenStates := m.PatchEmbed.Forward(ctx, pixelValues, grid, m.VisionModelOptions)
+
+ // Create position IDs for RoPE (spatial grid)
+ // Patches are already in merge-block order from preprocessing
+ positions := m.createPositions(ctx, grid)
+
+ // Process through vision blocks
+ for _, block := range m.Blocks {
+ hiddenStates = block.Forward(ctx, hiddenStates, positions, m.VisionModelOptions)
+ }
+
+ // Post-layernorm
+ hiddenStates = m.PostLN.Forward(ctx, hiddenStates, m.eps)
+
+ // Note: Downsample is now applied separately in Model.EncodeMultimodal
+ // so mm.patch_merger remains a distinct module
+
+ return hiddenStates
+}
+
+func (m *VisionModel) createPositions(ctx ml.Context, grid *Grid) ml.Tensor {
+ // Create spatial position IDs for vision RoPE
+ // Position layout: [height, width, height, width] - 4 sections for mrope
+ // Patches are in MERGE-BLOCK order after VisionPatchEmbed interleaving
+ // This follows the GLM-OCR rot_pos_emb layout
+ numPatches := grid.Height * grid.Width
+ mergeRatio := m.spatialMergeSize
+
+ // Build position arrays in merge-block order
+ // Each merge_ratio x merge_ratio block of patches is grouped together
+ hpos := make([]int32, numPatches)
+ wpos := make([]int32, numPatches)
+ ptr := 0
+ for y := 0; y < grid.Height; y += mergeRatio {
+ for x := 0; x < grid.Width; x += mergeRatio {
+ for dy := range mergeRatio {
+ for dx := range mergeRatio {
+ hpos[ptr] = int32(y + dy)
+ wpos[ptr] = int32(x + dx)
+ ptr++
+ }
+ }
+ }
+ }
+
+ // Build position arrays for 4 sections (mrope). ggml vision RoPE uses only H/W;
+ // keep remaining sections zeroed to match its conventions.
+ zeros := make([]int32, numPatches)
+ s := [][]int32{
+ hpos, // Section 0: height
+ wpos, // Section 1: width
+ zeros, // Section 2: unused
+ zeros, // Section 3: unused
+ }
+
+ return ctx.Input().FromInts(slices.Concat(s...), numPatches*4)
+}
+
+func newVisionModel(c fs.Config) *VisionModel {
+ hiddenSize := int(c.Uint("vision.embedding_length", 1024))
+ numHeads := int(c.Uint("vision.attention.head_count", 16))
+ numChannels := int(c.Uint("vision.num_channels", 3))
+ patchSize := int(c.Uint("vision.patch_size", 14))
+ temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
+ imageSize := int(c.Uint("vision.image_size", 336))
+ spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
+ outHiddenSize := int(c.Uint("vision.out_hidden_size", 1536))
+ intermediateSize := int(c.Uint("vision.intermediate_size", 4096))
+ eps := c.Float("vision.attention.layer_norm_rms_epsilon", 1e-5)
+
+ return &VisionModel{
+ Blocks: make([]VisionBlock, c.Uint("vision.block_count", 24)),
+ VisionModelOptions: &VisionModelOptions{
+ hiddenSize: hiddenSize,
+ numHeads: numHeads,
+ headDim: hiddenSize / numHeads,
+ numChannels: numChannels,
+ patchSize: patchSize,
+ temporalPatchSize: temporalPatchSize,
+ imageSize: imageSize,
+ spatialMergeSize: spatialMergeSize,
+ outHiddenSize: outHiddenSize,
+ intermediateSize: intermediateSize,
+ eps: eps,
+ },
+ }
+}
diff --git a/model/models/gptoss/model.go b/model/models/gptoss/model.go
index 9d1520bf346..2a3610a5617 100644
--- a/model/models/gptoss/model.go
+++ b/model/models/gptoss/model.go
@@ -12,11 +12,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Transformer struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TransformerBlocks []TransformerBlock `gguf:"blk"`
@@ -196,8 +197,8 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Optio
func New(c fs.Config) (model.Model, error) {
m := Transformer{
TransformerBlocks: make([]TransformerBlock, c.Uint("block_count")),
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/lfm2/cache.go b/model/models/lfm2/cache.go
new file mode 100644
index 00000000000..7e9d35f5f67
--- /dev/null
+++ b/model/models/lfm2/cache.go
@@ -0,0 +1,410 @@
+package lfm2
+
+import (
+ "slices"
+
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
+)
+
+var _ kvcache.Cache = (*HybridCache)(nil)
+
+// HybridCache stores:
+// - a standard causal KV cache for attention layers
+// - a per-sequence recurrent conv state for shortconv layers
+//
+// Conv state shape (per layer, per sequence): [dConv, hiddenSize] where dConv = L_cache - 1.
+// Stored internally as a tensor of shape [dConv * hiddenSize, maxSlots].
+type HybridCache struct {
+ kv *kvcache.Causal
+
+ backend ml.Backend
+ dtype ml.DType
+ maxSequences int
+
+ hiddenSize int
+ dConv int
+
+ // slot mapping for recurrent state
+ slotForSeq map[int]int
+ refCount []int
+ freeSlots []int
+
+ // per-layer conv state buffers (allocated lazily)
+ convCtxs map[int]ml.Context
+ convStates map[int]ml.Tensor // [dConv*hiddenSize, maxSlots]
+
+ // current forward batch (derived in StartForward)
+ curSeqs []int
+ curSlots []int
+ curSlotsInput ml.Tensor
+ curSeqTokens int
+
+ // track if EnsureWritable has been called for this forward pass
+ writableEnsured bool
+ // track any error from EnsureWritable to propagate later
+ writableError error
+}
+
+func NewHybridCache(shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error), hiddenSize, dConv int) *HybridCache {
+ return &HybridCache{
+ kv: kvcache.NewCausalCache(shift),
+ hiddenSize: hiddenSize,
+ dConv: dConv,
+ slotForSeq: make(map[int]int),
+ convCtxs: make(map[int]ml.Context),
+ convStates: make(map[int]ml.Tensor),
+ }
+}
+
+func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
+ c.backend = backend
+ c.dtype = dtype
+ c.maxSequences = maxSequences
+
+ // initialize slot allocator
+ c.refCount = make([]int, maxSequences)
+ c.freeSlots = c.freeSlots[:0]
+ for i := maxSequences - 1; i >= 0; i-- {
+ c.freeSlots = append(c.freeSlots, i)
+ }
+
+ c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
+}
+
+func (c *HybridCache) Close() {
+ for _, ctx := range c.convCtxs {
+ ctx.Close()
+ }
+ c.kv.Close()
+}
+
+func (c *HybridCache) SetConfig(config ml.CacheConfig) {
+ c.kv.SetConfig(config)
+}
+
+func (c *HybridCache) SetLayer(layer int) {
+ c.kv.SetLayer(layer)
+}
+
+func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
+ return c.kv.Get(ctx)
+}
+
+func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
+ c.kv.Put(ctx, key, value)
+}
+
+func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+ if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
+ return err
+ }
+
+ // Derive equal-length sequence layout for shortconv.
+ // LFM2 shortconv assumes tokens form a [seq_tokens, seqs] grid.
+ seqCounts := make(map[int]int)
+ c.curSeqs = c.curSeqs[:0]
+ for _, s := range batch.Sequences {
+ if _, ok := seqCounts[s]; !ok {
+ c.curSeqs = append(c.curSeqs, s)
+ }
+ seqCounts[s]++
+ }
+
+ if len(c.curSeqs) == 0 {
+ return nil
+ }
+
+ nTokens := len(batch.Sequences)
+ nSeqs := len(c.curSeqs)
+ want := nTokens / nSeqs
+ for _, s := range c.curSeqs {
+ if seqCounts[s] != want {
+ return kvcache.ErrNotSupported
+ }
+ }
+
+ c.curSeqTokens = want
+
+ // When reserving memory for estimation, use fake slot assignments
+ // without modifying permanent state (slotForSeq, refCount)
+ if reserve {
+ c.curSlots = c.curSlots[:0]
+ slots := make([]int32, nSeqs)
+ for i := range nSeqs {
+ c.curSlots = append(c.curSlots, i)
+ slots[i] = int32(i)
+ }
+ c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
+ return nil
+ }
+
+ // Ensure slots exist for sequences in this batch
+ c.curSlots = c.curSlots[:0]
+ var newSlots []int // track newly allocated slots that need zeroing
+ for _, s := range c.curSeqs {
+ slot, ok := c.slotForSeq[s]
+ if !ok {
+ var err error
+ slot, err = c.allocSlot()
+ if err != nil {
+ return err
+ }
+ c.slotForSeq[s] = slot
+ c.refCount[slot] = 1
+ newSlots = append(newSlots, slot)
+ }
+ c.curSlots = append(c.curSlots, slot)
+ }
+
+ // Zero conv state for newly allocated slots to clear stale data from previous sequences
+ if len(newSlots) > 0 {
+ c.zeroConvSlots(ctx, newSlots)
+ }
+
+ // Create a tensor for the current slots
+ slots := make([]int32, len(c.curSlots))
+ for i, v := range c.curSlots {
+ slots[i] = int32(v)
+ }
+ c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
+
+ // Reset writable state for new forward pass
+ c.writableEnsured = false
+ c.writableError = nil
+
+ return nil
+}
+
+func (c *HybridCache) allocSlot() (int, error) {
+ if len(c.freeSlots) == 0 {
+ return 0, kvcache.ErrKvCacheFull
+ }
+ slot := c.freeSlots[len(c.freeSlots)-1]
+ c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
+ return slot, nil
+}
+
+func (c *HybridCache) freeSlot(slot int) {
+ // Bounds check before freeing
+ if slot >= 0 && slot < c.maxSequences {
+ c.freeSlots = append(c.freeSlots, slot)
+ }
+}
+
+// zeroConvSlots zeros the conv state for the given slots across all layers.
+// This must be called when recycling slots to prevent stale state from affecting new sequences.
+func (c *HybridCache) zeroConvSlots(ctx ml.Context, slots []int) {
+ if len(slots) == 0 || len(c.convStates) == 0 {
+ return
+ }
+
+ // Use input context for creating tensors
+ inputCtx := ctx.Input()
+
+ // Create slot indices tensor
+ slotIndices := make([]int32, len(slots))
+ for i, s := range slots {
+ slotIndices[i] = int32(s)
+ }
+ slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
+
+ // Create zero tensor for the slots (SetRows requires F32 source)
+ zeros := inputCtx.Zeros(ml.DTypeF32, c.dConv*c.hiddenSize, len(slots))
+
+ // Zero each layer's conv state for these slots
+ for _, buf := range c.convStates {
+ ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
+ }
+}
+
+// EnsureWritable ensures that sequences in the current batch have private (non-shared) conv slots.
+// Returns an error if slot allocation fails.
+func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
+ for i, seq := range c.curSeqs {
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ continue
+ }
+
+ // Bounds check
+ if slot < 0 || slot >= len(c.refCount) {
+ continue
+ }
+
+ if c.refCount[slot] <= 1 {
+ continue
+ }
+
+ newSlot, err := c.allocSlot()
+ if err != nil {
+ return err
+ }
+ c.refCount[slot]--
+ c.refCount[newSlot] = 1
+ c.slotForSeq[seq] = newSlot
+ c.curSlots[i] = newSlot
+
+ // Copy existing conv state for all initialized layers
+ for _, buf := range c.convStates {
+ // buf: [dConv*hiddenSize, maxSlots]
+ src := buf.Rows(ctx, ctx.Input().FromInts([]int32{int32(slot)}, 1))
+ // SetRows requires F32 source
+ srcF32 := src.Cast(ctx, ml.DTypeF32)
+ ctx.Forward(buf.SetRows(ctx, srcF32, ctx.Input().FromInts([]int32{int32(newSlot)}, 1)))
+ }
+ }
+
+ // Rebuild current slots tensor
+ slots := make([]int32, len(c.curSlots))
+ for i, v := range c.curSlots {
+ slots[i] = int32(v)
+ }
+ c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
+
+ return nil
+}
+
+func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
+ // KV cache shares prefix metadata (no copy) which is correct for prefix reuse.
+ c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
+
+ // For shortconv state we implement copy-on-write: dst shares the same slot as src.
+ // On the first write to dst, EnsureWritable will create a private slot.
+ if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
+ // Bounds check before decrementing
+ if dstSlot >= 0 && dstSlot < len(c.refCount) {
+ c.refCount[dstSlot]--
+ if c.refCount[dstSlot] <= 0 {
+ c.refCount[dstSlot] = 0
+ c.freeSlot(dstSlot)
+ }
+ }
+ delete(c.slotForSeq, dstSeq)
+ }
+
+ srcSlot, ok := c.slotForSeq[srcSeq]
+ if !ok {
+ // src may not have a slot yet; dst will allocate on demand
+ return
+ }
+
+ // Bounds check before incrementing
+ if srcSlot >= 0 && srcSlot < len(c.refCount) {
+ c.slotForSeq[dstSeq] = srcSlot
+ c.refCount[srcSlot]++
+ }
+}
+
+func (c *HybridCache) CanResume(seq int, pos int32) bool {
+ return c.kv.CanResume(seq, pos)
+}
+
+func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
+ if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
+ return err
+ }
+
+ // For recurrent state, any removal invalidates the state because
+ // the state at position N depends on all previous positions.
+ // Drop the slot mapping so it resets on next use.
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ return nil
+ }
+
+ // Bounds check
+ if slot < 0 || slot >= len(c.refCount) {
+ delete(c.slotForSeq, seq)
+ return nil
+ }
+
+ c.refCount[slot]--
+ if c.refCount[slot] <= 0 {
+ c.refCount[slot] = 0
+ c.freeSlot(slot)
+ }
+ delete(c.slotForSeq, seq)
+
+ return nil
+}
+
+func (c *HybridCache) slotsTensor() ml.Tensor {
+ return c.curSlotsInput
+}
+
+func (c *HybridCache) seqTokens() int {
+ return c.curSeqTokens
+}
+
+func (c *HybridCache) numSeqs() int {
+ return len(c.curSeqs)
+}
+
+func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
+ if buf, ok := c.convStates[layer]; ok {
+ return buf
+ }
+
+ if _, ok := c.convCtxs[layer]; !ok {
+ c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
+ }
+
+ buf := c.convCtxs[layer].Zeros(c.dtype, c.dConv*c.hiddenSize, c.maxSequences)
+ c.convStates[layer] = buf
+ return buf
+}
+
+// ConvState returns the conv state for current batch sequences as shape [dConv, hiddenSize, nSeqs].
+// Returns an error if copy-on-write allocation fails.
+func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
+ if !c.writableEnsured {
+ needsWritable := false
+ for _, seq := range c.curSeqs {
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ continue
+ }
+ if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
+ needsWritable = true
+ break
+ }
+ }
+
+ if needsWritable {
+ if err := c.EnsureWritable(ctx); err != nil {
+ c.writableError = err
+ }
+ }
+ c.writableEnsured = true
+ }
+
+ if c.writableError != nil {
+ return nil, c.writableError
+ }
+
+ buf := c.convBuffer(ctx, layer)
+ cur := buf.Rows(ctx, c.slotsTensor())
+ return cur.Reshape(ctx, c.dConv, c.hiddenSize, c.numSeqs()), nil
+}
+
+// UpdateConvState writes a new conv state for current batch sequences.
+// newState must have shape [dConv, hiddenSize, nSeqs].
+func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
+ buf := c.convBuffer(ctx, layer)
+ src := newState.Reshape(ctx, c.dConv*c.hiddenSize, c.numSeqs())
+ // SetRows requires F32 source
+ srcF32 := src.Cast(ctx, ml.DTypeF32)
+ ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
+}
+
+// IsSupportedForBatch returns true if the current batch layout supports shortconv.
+func (c *HybridCache) IsSupportedForBatch() bool {
+ return c.curSeqTokens > 0 && len(c.curSeqs) > 0
+}
+
+// Seqs returns the ordered unique sequences for the current forward pass.
+func (c *HybridCache) Seqs() []int {
+ return slices.Clone(c.curSeqs)
+}
diff --git a/model/models/lfm2/cache_test.go b/model/models/lfm2/cache_test.go
new file mode 100644
index 00000000000..f4c493c20cf
--- /dev/null
+++ b/model/models/lfm2/cache_test.go
@@ -0,0 +1,444 @@
+package lfm2
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+)
+
+// TestHybridCache tests verify the slot management logic of HybridCache.
+// These tests focus on the recurrent state slot allocation, reference counting,
+// and copy-on-write semantics without requiring a full ML backend.
+
+// createSlotOnlyCache creates a HybridCache with only the slot management
+// fields initialized. Used to test slot logic in isolation.
+func createSlotOnlyCache(maxSequences int) *HybridCache {
+ return &HybridCache{
+ hiddenSize: 256,
+ dConv: 3,
+ maxSequences: maxSequences,
+ refCount: make([]int, maxSequences),
+ freeSlots: initFreeSlots(maxSequences),
+ slotForSeq: make(map[int]int),
+ convCtxs: make(map[int]ml.Context),
+ convStates: make(map[int]ml.Tensor),
+ }
+}
+
+func initFreeSlots(n int) []int {
+ slots := make([]int, 0, n)
+ for i := n - 1; i >= 0; i-- {
+ slots = append(slots, i)
+ }
+ return slots
+}
+
+func TestHybridCache_SlotAllocation(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Verify initial state
+ if len(cache.freeSlots) != 4 {
+ t.Errorf("expected 4 free slots, got %d", len(cache.freeSlots))
+ }
+
+ // Allocate all slots
+ for range 4 {
+ slot, err := cache.allocSlot()
+ if err != nil {
+ t.Fatalf("allocSlot failed: %v", err)
+ }
+ cache.refCount[slot] = 1
+ }
+
+ // Should be full now
+ if len(cache.freeSlots) != 0 {
+ t.Errorf("expected 0 free slots, got %d", len(cache.freeSlots))
+ }
+
+ // Trying to allocate another should fail
+ _, err := cache.allocSlot()
+ if err != kvcache.ErrKvCacheFull {
+ t.Errorf("expected ErrKvCacheFull, got %v", err)
+ }
+}
+
+func TestHybridCache_SlotReuse(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Allocate a slot
+ slot1, _ := cache.allocSlot()
+ cache.refCount[slot1] = 1
+
+ // Free it
+ cache.refCount[slot1] = 0
+ cache.freeSlot(slot1)
+
+ // Allocate again - should get the same slot back (LIFO)
+ slot2, _ := cache.allocSlot()
+ if slot2 != slot1 {
+ t.Errorf("expected slot %d to be reused, got %d", slot1, slot2)
+ }
+}
+
+func TestHybridCache_SlotRefCounting_ShareSlot(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Allocate slot for seq 1
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ // Simulate sharing slot with seq 2 (copy-on-write style)
+ cache.slotForSeq[2] = slot1
+ cache.refCount[slot1]++
+
+ // Should share the same slot
+ if cache.slotForSeq[2] != slot1 {
+ t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
+ }
+
+ // Ref count should be 2
+ if cache.refCount[slot1] != 2 {
+ t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
+ }
+}
+
+func TestHybridCache_SlotRefCounting_DecRef(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Allocate slot for seq 1
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ // Share with seq 2
+ cache.slotForSeq[2] = slot1
+ cache.refCount[slot1]++
+
+ // Unshare seq 2
+ cache.refCount[slot1]--
+ delete(cache.slotForSeq, 2)
+
+ // Ref count should be back to 1
+ if cache.refCount[slot1] != 1 {
+ t.Errorf("expected refCount 1 after unshare, got %d", cache.refCount[slot1])
+ }
+
+ // Seq 2 should no longer have a slot
+ if _, ok := cache.slotForSeq[2]; ok {
+ t.Error("seq 2 should not have a slot after unshare")
+ }
+}
+
+func TestHybridCache_SlotFreeWhenUnused(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ initialFreeSlots := len(cache.freeSlots)
+
+ // Allocate slot for seq 1
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ // Free the slot when refCount drops to 0
+ cache.refCount[slot1]--
+ if cache.refCount[slot1] <= 0 {
+ cache.refCount[slot1] = 0
+ cache.freeSlot(slot1)
+ }
+ delete(cache.slotForSeq, 1)
+
+ // Slot should be freed
+ if len(cache.freeSlots) != initialFreeSlots {
+ t.Errorf("expected %d free slots, got %d", initialFreeSlots, len(cache.freeSlots))
+ }
+
+ // Ref count should be 0
+ if cache.refCount[slot1] != 0 {
+ t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
+ }
+}
+
+func TestHybridCache_SlotOverwrite(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Allocate slots for seq 1 and seq 2
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ slot2, _ := cache.allocSlot()
+ cache.slotForSeq[2] = slot2
+ cache.refCount[slot2] = 1
+
+ initialFreeSlots := len(cache.freeSlots)
+
+ // Simulate overwriting seq 2's slot with slot1 (sharing)
+ // First free the old slot
+ cache.refCount[slot2]--
+ if cache.refCount[slot2] <= 0 {
+ cache.refCount[slot2] = 0
+ cache.freeSlot(slot2)
+ }
+ // Then share slot1
+ cache.slotForSeq[2] = slot1
+ cache.refCount[slot1]++
+
+ // Seq 2 should now share slot1
+ if cache.slotForSeq[2] != slot1 {
+ t.Errorf("expected seq 2 to share slot %d, got %d", slot1, cache.slotForSeq[2])
+ }
+
+ // Old slot2 should be freed
+ if len(cache.freeSlots) != initialFreeSlots+1 {
+ t.Errorf("expected %d free slots, got %d", initialFreeSlots+1, len(cache.freeSlots))
+ }
+}
+
+func TestHybridCache_BoundsChecking(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Test freeing invalid slot (should not panic)
+ cache.freeSlot(-1)
+ cache.freeSlot(100) // out of bounds
+
+ // freeSlot does bounds checking, so invalid slots should be ignored
+ if len(cache.freeSlots) != 4 {
+ t.Errorf("invalid slots should not affect free list, got %d slots", len(cache.freeSlots))
+ }
+}
+
+func TestHybridCache_MultipleSequences_RefCounting(t *testing.T) {
+ cache := createSlotOnlyCache(8)
+
+ // Allocate slot for seq 1
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ // Fork to seq 2, 3, 4 (all share slot1)
+ for _, seq := range []int{2, 3, 4} {
+ cache.slotForSeq[seq] = slot1
+ cache.refCount[slot1]++
+ }
+
+ // Ref count should be 4
+ if cache.refCount[slot1] != 4 {
+ t.Errorf("expected refCount 4, got %d", cache.refCount[slot1])
+ }
+
+ // Remove seq 2, 3
+ for _, seq := range []int{2, 3} {
+ delete(cache.slotForSeq, seq)
+ cache.refCount[slot1]--
+ }
+
+ if cache.refCount[slot1] != 2 {
+ t.Errorf("expected refCount 2, got %d", cache.refCount[slot1])
+ }
+
+ // Slot should still be allocated (not in free list)
+ found := false
+ for _, s := range cache.freeSlots {
+ if s == slot1 {
+ found = true
+ break
+ }
+ }
+ if found {
+ t.Error("slot1 should not be in free list yet")
+ }
+
+ // Remove remaining sequences
+ for _, seq := range []int{1, 4} {
+ delete(cache.slotForSeq, seq)
+ cache.refCount[slot1]--
+ }
+
+ if cache.refCount[slot1] != 0 {
+ t.Errorf("expected refCount 0, got %d", cache.refCount[slot1])
+ }
+}
+
+func TestHybridCache_ChainedSharing(t *testing.T) {
+ cache := createSlotOnlyCache(8)
+
+ // Create seq 1
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ // Share 1 -> 2
+ cache.slotForSeq[2] = slot1
+ cache.refCount[slot1]++
+
+ // Share 2 -> 3 (should still share slot1)
+ cache.slotForSeq[3] = cache.slotForSeq[2] // which is slot1
+ cache.refCount[slot1]++
+
+ // All should share slot1
+ if cache.slotForSeq[1] != slot1 || cache.slotForSeq[2] != slot1 || cache.slotForSeq[3] != slot1 {
+ t.Error("all sequences should share slot1")
+ }
+
+ if cache.refCount[slot1] != 3 {
+ t.Errorf("expected refCount 3, got %d", cache.refCount[slot1])
+ }
+}
+
+func TestHybridCache_CacheParameters(t *testing.T) {
+ cache := NewHybridCache(nil, 512, 5) // hiddenSize=512, dConv=5
+
+ if cache.hiddenSize != 512 {
+ t.Errorf("expected hiddenSize 512, got %d", cache.hiddenSize)
+ }
+ if cache.dConv != 5 {
+ t.Errorf("expected dConv 5, got %d", cache.dConv)
+ }
+}
+
+func TestHybridCache_NumSeqs(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Initially no sequences
+ if cache.numSeqs() != 0 {
+ t.Errorf("expected 0 seqs, got %d", cache.numSeqs())
+ }
+
+ // Manually set up current batch state
+ cache.curSeqs = []int{1, 2, 3}
+
+ if cache.numSeqs() != 3 {
+ t.Errorf("expected 3 seqs, got %d", cache.numSeqs())
+ }
+}
+
+func TestHybridCache_SeqTokens(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Initially 0
+ if cache.seqTokens() != 0 {
+ t.Errorf("expected 0 seqTokens, got %d", cache.seqTokens())
+ }
+
+ // Manually set up current batch state
+ cache.curSeqTokens = 16
+
+ if cache.seqTokens() != 16 {
+ t.Errorf("expected 16 seqTokens, got %d", cache.seqTokens())
+ }
+}
+
+// Test that Seqs returns a clone of curSeqs
+func TestHybridCache_Seqs_ReturnsClone(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ cache.curSeqs = []int{1, 2, 3}
+
+ seqs := cache.Seqs()
+
+ // Modify returned slice
+ seqs[0] = 999
+
+ // Original should be unchanged
+ if cache.curSeqs[0] != 1 {
+ t.Error("Seqs should return a clone, not the original slice")
+ }
+}
+
+func TestHybridCache_IsSupportedForBatch(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Initially not supported (no batch set up)
+ if cache.IsSupportedForBatch() {
+ t.Error("expected IsSupportedForBatch to be false initially")
+ }
+
+ // Set up a valid batch
+ cache.curSeqTokens = 1
+ cache.curSeqs = []int{1}
+
+ if !cache.IsSupportedForBatch() {
+ t.Error("expected IsSupportedForBatch to be true with valid batch")
+ }
+}
+
+func TestHybridCache_ZeroConvSlots_EmptyInputs(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // zeroConvSlots should handle empty slots without panicking
+ cache.zeroConvSlots(nil, nil)
+ cache.zeroConvSlots(nil, []int{})
+
+ // zeroConvSlots should handle empty convStates without panicking
+ cache.zeroConvSlots(nil, []int{0, 1, 2})
+}
+
+func TestHybridCache_SlotRecycling_TracksNewSlots(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Allocate slot for seq 1
+ slot1, _ := cache.allocSlot()
+ cache.slotForSeq[1] = slot1
+ cache.refCount[slot1] = 1
+
+ // Free the slot (simulating sequence removal)
+ cache.refCount[slot1]--
+ cache.freeSlot(slot1)
+ delete(cache.slotForSeq, 1)
+
+ // Verify slot is in free list
+ if len(cache.freeSlots) != 4 {
+ t.Errorf("expected 4 free slots after freeing, got %d", len(cache.freeSlots))
+ }
+
+ // Allocate for new seq 2 - should get recycled slot
+ slot2, _ := cache.allocSlot()
+ if slot2 != slot1 {
+ t.Errorf("expected recycled slot %d, got %d", slot1, slot2)
+ }
+
+ // This recycled slot would need zeroing in the real implementation
+ // The actual zeroing is tested via integration tests since it requires ML context
+}
+
+func TestHybridCache_NewSequence_GetsTrackedForZeroing(t *testing.T) {
+ cache := createSlotOnlyCache(4)
+
+ // Simulate the slot allocation flow from StartForward
+ // When a sequence doesn't have a slot, it gets allocated and tracked as "new"
+
+ newSlots := []int{}
+
+ // Seq 1 doesn't have a slot - allocate and track
+ seq := 1
+ if _, ok := cache.slotForSeq[seq]; !ok {
+ slot, err := cache.allocSlot()
+ if err != nil {
+ t.Fatalf("allocSlot failed: %v", err)
+ }
+ cache.slotForSeq[seq] = slot
+ cache.refCount[slot] = 1
+ newSlots = append(newSlots, slot)
+ }
+
+ // Verify newSlots contains the allocated slot
+ if len(newSlots) != 1 {
+ t.Errorf("expected 1 new slot, got %d", len(newSlots))
+ }
+
+ // Seq 1 already has a slot - should NOT be tracked as new
+ newSlots2 := []int{}
+ if _, ok := cache.slotForSeq[seq]; !ok {
+ slot, _ := cache.allocSlot()
+ cache.slotForSeq[seq] = slot
+ cache.refCount[slot] = 1
+ newSlots2 = append(newSlots2, slot)
+ }
+
+ // Verify no new slots for existing sequence
+ if len(newSlots2) != 0 {
+ t.Errorf("expected 0 new slots for existing sequence, got %d", len(newSlots2))
+ }
+}
diff --git a/model/models/lfm2/model.go b/model/models/lfm2/model.go
new file mode 100644
index 00000000000..51e40d3c35b
--- /dev/null
+++ b/model/models/lfm2/model.go
@@ -0,0 +1,254 @@
+package lfm2
+
+import (
+ "cmp"
+ "math"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/rope"
+ "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
+)
+
+type Options struct {
+ hiddenSize int
+ headDim, ropeDim int
+
+ eps, ropeBase, ropeScale float32
+
+ ropeType string
+ originalContextLength int
+
+ // per-layer head counts (LFM2 alternates attention and recurrent layers)
+ numHeadsByLayer []int
+ numKVHeadsByLayer []int
+}
+
+func (o Options) headDimValue() int {
+ // Head dim is shared across layers; fall back to first attention layer head count.
+ for _, h := range o.numHeadsByLayer {
+ if h > 0 {
+ return cmp.Or(o.headDim, o.hiddenSize/h)
+ }
+ }
+ return cmp.Or(o.headDim, o.hiddenSize)
+}
+
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+ opts := []func(*rope.Options){rope.WithTypeNeoX()}
+ if o.ropeType == "yarn" {
+ attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+ opts = append(opts,
+ rope.WithOriginalContextLength(o.originalContextLength),
+ rope.WithExtrapolationFactor(1.),
+ rope.WithAttentionFactor(attnFactor),
+ )
+ }
+
+ headCount := 1
+ for _, h := range o.numHeadsByLayer {
+ if h > 0 {
+ headCount = h
+ break
+ }
+ }
+ return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/headCount), o.ropeBase, 1./o.ropeScale, opts...)
+}
+
+type Model struct {
+ model.Base
+ tokenizer.Tokenizer
+
+ TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+ Layers []Layer `gguf:"blk"`
+ OutputNorm *nn.RMSNorm `gguf:"output_norm,alt:token_embd_norm"`
+ Output *nn.Linear `gguf:"output,alt:token_embd"`
+
+ Options
+}
+
+func New(c fs.Config) (model.Model, error) {
+ if c.Uint("expert_count") > 0 {
+ return nil, model.ErrUnsupportedModel
+ }
+
+ if c.String("tokenizer.ggml.model") != "gpt2" {
+ return nil, model.ErrUnsupportedTokenizer
+ }
+
+ vocabulary := tokenizer.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Scores: c.Floats("tokenizer.ggml.scores"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ Merges: c.Strings("tokenizer.ggml.merges"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+ EOS: append(
+ []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+ c.Ints("tokenizer.ggml.eos_token_ids")...,
+ ),
+ }
+
+ var pretokenizers []string
+ switch c.String("tokenizer.ggml.pre") {
+ case "default":
+ // use default BPE pretokenizer
+ default:
+ // llama-bpe style (default for LFM2)
+ pretokenizers = []string{
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+ }
+ }
+
+ m := Model{
+ Tokenizer: tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...),
+ Layers: make([]Layer, c.Uint("block_count")),
+ Options: Options{
+ hiddenSize: int(c.Uint("embedding_length")),
+ headDim: int(c.Uint("attention.key_length")),
+ ropeDim: int(c.Uint("rope.dimension_count")),
+ eps: c.Float("attention.layer_norm_rms_epsilon"),
+ ropeType: c.String("rope.scaling.type"),
+ ropeBase: c.Float("rope.freq_base"),
+ ropeScale: c.Float("rope.scaling.factor", 1),
+ originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
+ },
+ }
+
+ type headCounts interface {
+ HeadCount() []uint64
+ HeadCountKV() []uint64
+ }
+ hc, ok := c.(headCounts)
+ if !ok {
+ return nil, model.ErrUnsupportedModel
+ }
+
+ headCount := hc.HeadCount()
+ headCountKV := hc.HeadCountKV()
+
+ m.numHeadsByLayer = make([]int, len(m.Layers))
+ m.numKVHeadsByLayer = make([]int, len(m.Layers))
+ for i := range m.Layers {
+ m.numHeadsByLayer[i] = int(headCount[i])
+ m.numKVHeadsByLayer[i] = int(headCountKV[i])
+
+ if m.numKVHeadsByLayer[i] == 0 {
+ m.Layers[i].Operator = &ShortConv{}
+ } else {
+ m.Layers[i].Operator = &Attention{}
+ }
+ }
+
+ lCache := int(c.Uint("shortconv.l_cache"))
+ dConv := max(0, lCache-1)
+ m.Cache = NewHybridCache(m.Shift, m.hiddenSize, dConv)
+ return &m, nil
+}
+
+type Operator interface {
+ Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor
+}
+
+type Attention struct {
+ Query *nn.Linear `gguf:"attn_q"`
+ QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
+ Key *nn.Linear `gguf:"attn_k"`
+ KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+ Value *nn.Linear `gguf:"attn_v"`
+ Output *nn.Linear `gguf:"attn_output,alt:attn_out"`
+}
+
+func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
+ batchSize := hiddenStates.Dim(1)
+ headDim := opts.headDimValue()
+ numHeads := opts.numHeadsByLayer[layer]
+ numKVHeads := opts.numKVHeadsByLayer[layer]
+
+ query := sa.Query.Forward(ctx, hiddenStates)
+ key := sa.Key.Forward(ctx, hiddenStates)
+ value := sa.Value.Forward(ctx, hiddenStates)
+
+ query = query.Reshape(ctx, headDim, numHeads, batchSize)
+ key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
+ value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
+
+ query = sa.QueryNorm.Forward(ctx, query, opts.eps)
+ key = sa.KeyNorm.Forward(ctx, key, opts.eps)
+
+ query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+ key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
+
+ attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
+ attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
+ return sa.Output.Forward(ctx, attention)
+}
+
+type MLP struct {
+ Up *nn.Linear `gguf:"ffn_up"`
+ Down *nn.Linear `gguf:"ffn_down"`
+ Gate *nn.Linear `gguf:"ffn_gate"`
+}
+
+func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
+ hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
+ return mlp.Down.Forward(ctx, hiddenState)
+}
+
+type Layer struct {
+ AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+ Operator Operator
+ MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
+ MLP *MLP
+}
+
+func (l *Layer) Forward(ctx ml.Context, layer int, hiddenState, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) ml.Tensor {
+ residual := hiddenState
+
+ hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
+ hiddenState = l.Operator.Forward(ctx, hiddenState, positions, cache, layer, opts)
+
+ if outputs != nil {
+ hiddenState = hiddenState.Rows(ctx, outputs)
+ residual = residual.Rows(ctx, outputs)
+ }
+
+ hiddenState = hiddenState.Add(ctx, residual)
+ residual = hiddenState
+
+ hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
+ hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
+ return hiddenState.Add(ctx, residual)
+}
+
+func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+ return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
+}
+
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
+
+ hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+
+ for i, layer := range m.Layers {
+ m.Cache.SetLayer(i)
+
+ var outputs ml.Tensor
+ if i == len(m.Layers)-1 {
+ outputs = batch.Outputs
+ }
+
+ hiddenState = layer.Forward(ctx, i, hiddenState, positions, outputs, m.Cache.(*HybridCache), &m.Options)
+ }
+
+ hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+ return m.Output.Forward(ctx, hiddenState), nil
+}
+
+func init() {
+ model.Register("lfm2", New)
+}
diff --git a/model/models/lfm2/shortconv.go b/model/models/lfm2/shortconv.go
new file mode 100644
index 00000000000..d1f6c15feef
--- /dev/null
+++ b/model/models/lfm2/shortconv.go
@@ -0,0 +1,50 @@
+package lfm2
+
+import (
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+)
+
+type shortConvKernel struct {
+ Weight ml.Tensor `gguf:"weight"`
+}
+
+// ShortConv implements the LFM2 short-convolution block (GGML_OP_SSM_CONV) with a recurrent
+// state stored in the HybridCache.
+type ShortConv struct {
+ Conv *shortConvKernel `gguf:"shortconv.conv"`
+ InProj *nn.Linear `gguf:"shortconv.in_proj"`
+ OutProj *nn.Linear `gguf:"shortconv.out_proj"`
+}
+
+func (sc *ShortConv) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ ml.Tensor, cache *HybridCache, layer int, opts *Options) ml.Tensor {
+ nSeqs := cache.numSeqs()
+ seqTokens := cache.seqTokens()
+ hiddenSize := hiddenStates.Dim(0)
+ if nSeqs <= 0 || seqTokens <= 0 || hiddenStates.Dim(1) != nSeqs*seqTokens {
+ panic("lfm2: unsupported batch layout for shortconv")
+ }
+
+ bcx := sc.InProj.Forward(ctx, hiddenStates).Reshape(ctx, 3*hiddenSize, seqTokens, nSeqs)
+
+ elementSize := bcx.Stride(0)
+ b := bcx.View(ctx, 0*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
+ c := bcx.View(ctx, 1*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
+ x := bcx.View(ctx, 2*hiddenSize*elementSize, hiddenSize, bcx.Stride(1), seqTokens, bcx.Stride(2), nSeqs)
+
+ bx := b.Mul(ctx, x).Permute(ctx, 1, 0, 2, 3)
+
+ state, err := cache.ConvState(ctx, layer)
+ if err != nil {
+ panic("lfm2: failed to get conv state: " + err.Error())
+ }
+ sx := state.Concat(ctx, bx, 0)
+
+ convOut := sx.SSMConv(ctx, sc.Conv.Weight)
+ y := c.Mul(ctx, convOut)
+
+ dConv := sx.Dim(0) - seqTokens
+ cache.UpdateConvState(ctx, layer, sx.Slice(ctx, 0, sx.Dim(0)-dConv, sx.Dim(0), 1))
+
+ return sc.OutProj.Forward(ctx, y.Reshape(ctx, hiddenSize, seqTokens*nSeqs))
+}
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 5ff4894e47d..ad95c5c0029 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -25,7 +26,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -41,8 +42,8 @@ func New(c fs.Config) (model.Model, error) {
return nil, model.ErrUnsupportedModel
}
- var processor model.TextProcessor
- vocabulary := model.Vocabulary{
+ var processor tokenizer.Tokenizer
+ vocabulary := tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
@@ -80,16 +81,16 @@ func New(c fs.Config) (model.Model, error) {
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
- processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...)
+ processor = tokenizer.NewBytePairEncoding(&vocabulary, pretokenizers...)
case "llama":
- processor = model.NewSentencePiece(&vocabulary)
+ processor = tokenizer.NewSentencePiece(&vocabulary)
default:
return nil, model.ErrUnsupportedTokenizer
}
m := Model{
- TextProcessor: processor,
- Layers: make([]Layer, c.Uint("block_count")),
+ Tokenizer: processor,
+ Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
diff --git a/model/models/llama4/model.go b/model/models/llama4/model.go
index 4a22bc4bb38..c8373b7ebdb 100644
--- a/model/models/llama4/model.go
+++ b/model/models/llama4/model.go
@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
ImageProcessor
*VisionModel `gguf:"v"`
@@ -33,8 +34,8 @@ func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
func New(c fs.Config) (model.Model, error) {
m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index 8230dde3942..8485d34ca1a 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -28,12 +29,12 @@ type Model struct {
var _ model.MultimodalProcessor = (*Model)(nil)
// Implement TextProcessor interface
-var _ model.TextProcessor = (*Model)(nil)
+var _ tokenizer.Tokenizer = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 58fd5adcfca..2d032467de7 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
*VisionModel `gguf:"v"`
*TextModel
@@ -32,8 +33,8 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/models.go b/model/models/models.go
index d900f7cc39e..d4a8dc536a0 100644
--- a/model/models/models.go
+++ b/model/models/models.go
@@ -8,7 +8,9 @@ import (
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"
_ "github.com/ollama/ollama/model/models/glm4moelite"
+ _ "github.com/ollama/ollama/model/models/glmocr"
_ "github.com/ollama/ollama/model/models/gptoss"
+ _ "github.com/ollama/ollama/model/models/lfm2"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
@@ -18,5 +20,6 @@ import (
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"
+ _ "github.com/ollama/ollama/model/models/qwen3next"
_ "github.com/ollama/ollama/model/models/qwen3vl"
)
diff --git a/model/models/nomicbert/model.go b/model/models/nomicbert/model.go
index 096d046a061..1d60b178cc3 100644
--- a/model/models/nomicbert/model.go
+++ b/model/models/nomicbert/model.go
@@ -11,11 +11,12 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
@@ -178,29 +179,6 @@ func New(c fs.Config) (model.Model, error) {
numHeads := int(c.Uint("attention.head_count"))
headDim := hiddenSize / numHeads
- processor := model.NewWordPiece(
- &model.Vocabulary{
- Values: c.Strings("tokenizer.ggml.tokens"),
- Scores: c.Floats("tokenizer.ggml.scores"),
- Types: c.Ints("tokenizer.ggml.token_type"),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
- BOS: []int32{
- int32(cmp.Or(
- c.Uint("tokenizer.ggml.cls_token_id"),
- c.Uint("tokenizer.ggml.bos_token_id"),
- )),
- },
- AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
- EOS: []int32{
- int32(cmp.Or(
- c.Uint("tokenizer.ggml.separator_token_id"),
- c.Uint("tokenizer.ggml.eos_token_id"),
- )),
- },
- },
- false,
- )
-
blockCount := int(c.Uint("block_count"))
moeEveryNLayers := int(c.Uint("moe_every_n_layers", 0))
layers := make([]EncoderLayer, blockCount)
@@ -219,8 +197,29 @@ func New(c fs.Config) (model.Model, error) {
}
return &Model{
- TextProcessor: processor,
- Layers: layers,
+ Tokenizer: tokenizer.NewWordPiece(
+ &tokenizer.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Scores: c.Floats("tokenizer.ggml.scores"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+ BOS: []int32{
+ int32(cmp.Or(
+ c.Uint("tokenizer.ggml.cls_token_id"),
+ c.Uint("tokenizer.ggml.bos_token_id"),
+ )),
+ },
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
+ EOS: []int32{
+ int32(cmp.Or(
+ c.Uint("tokenizer.ggml.separator_token_id"),
+ c.Uint("tokenizer.ggml.eos_token_id"),
+ )),
+ },
+ },
+ false,
+ ),
+ Layers: layers,
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,
diff --git a/model/models/olmo3/model.go b/model/models/olmo3/model.go
index 523c00e688d..44a746cd2c1 100644
--- a/model/models/olmo3/model.go
+++ b/model/models/olmo3/model.go
@@ -11,6 +11,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
const (
@@ -33,7 +34,7 @@ type Options struct {
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -44,28 +45,24 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
- vocabulary := model.Vocabulary{
- Values: c.Strings("tokenizer.ggml.tokens"),
- Scores: c.Floats("tokenizer.ggml.scores"),
- Types: c.Ints("tokenizer.ggml.token_type"),
- Merges: c.Strings("tokenizer.ggml.merges"),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
- BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
- AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
- EOS: append(
- []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
- c.Ints("tokenizer.ggml.eos_token_ids")...,
- ),
- }
-
- processor := model.NewBytePairEncoding(
- &vocabulary,
- "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
- )
-
m := Model{
- TextProcessor: processor,
- Layers: make([]Layer, c.Uint("block_count")),
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Scores: c.Floats("tokenizer.ggml.scores"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ Merges: c.Strings("tokenizer.ggml.merges"),
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
+ BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+ EOS: append(
+ []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+ c.Ints("tokenizer.ggml.eos_token_ids")...,
+ ),
+ },
+ "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+ ),
+ Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
diff --git a/model/models/qwen2/model.go b/model/models/qwen2/model.go
index 66f546ae617..17ed0d32722 100644
--- a/model/models/qwen2/model.go
+++ b/model/models/qwen2/model.go
@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -92,7 +93,7 @@ func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs m
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
@@ -139,8 +140,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index 81296a81bc8..68230707fc5 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -27,8 +28,8 @@ var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/qwen3/embed.go b/model/models/qwen3/embed.go
index c03888d45c6..c10390ff7e3 100644
--- a/model/models/qwen3/embed.go
+++ b/model/models/qwen3/embed.go
@@ -7,11 +7,12 @@ import (
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type embedModel struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
*Model
poolingType pooling.Type
@@ -34,8 +35,8 @@ func newEmbed(c fs.Config) (model.Model, error) {
layers[i].MLP = &dense{}
}
m := embedModel{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go
index d7747364e52..2602be44ed0 100644
--- a/model/models/qwen3/model.go
+++ b/model/models/qwen3/model.go
@@ -12,6 +12,7 @@ import (
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Options struct {
@@ -159,7 +160,7 @@ func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tens
type Model struct {
model.Base
- model.BytePairEncoding
+ tokenizer.Tokenizer
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -218,8 +219,8 @@ func New(c fs.Config) (model.Model, error) {
}
m := Model{
- BytePairEncoding: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/models/qwen3next/attention.go b/model/models/qwen3next/attention.go
new file mode 100644
index 00000000000..ee4a06beaa4
--- /dev/null
+++ b/model/models/qwen3next/attention.go
@@ -0,0 +1,103 @@
+package qwen3next
+
+import (
+ "errors"
+ "math"
+
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+)
+
+// ErrUnsupportedBatchLayout is returned when the batch layout is incompatible
+// with the attention layer requirements.
+var ErrUnsupportedBatchLayout = errors.New("qwen3next: unsupported batch layout")
+
+// FullAttention implements gated attention with QK normalization and sigmoid-gated output.
+// Key differences from standard attention:
+// - Q projection outputs 2x size (Q + gate interleaved)
+// - Both Q and K have RMSNorm
+// - Output is gated: attn * sigmoid(gate)
+type FullAttention struct {
+ Query *nn.Linear `gguf:"attn_q"` // outputs [n_embd_head * 2, n_head]
+ QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
+ Key *nn.Linear `gguf:"attn_k"`
+ KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
+ Value *nn.Linear `gguf:"attn_v"`
+ Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (sa *FullAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
+ // Use Dim() instead of Shape() for consistent behavior during graph construction
+ hiddenDim := hiddenStates.Dim(0)
+ batchSize := hiddenStates.Dim(1)
+ nSeqs := hiddenStates.Dim(2) // 0 if 2D tensor
+
+ if cache != nil && cache.IsSupportedForBatch() {
+ seqTokens := cache.seqTokens()
+ seqs := cache.numSeqs()
+ if seqTokens > 0 && seqs > 0 {
+ if nSeqs > 0 {
+ // 3D tensor: [hiddenDim, seqTokens, nSeqs]
+ if batchSize != seqTokens || nSeqs != seqs {
+ return nil, ErrUnsupportedBatchLayout
+ }
+ hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, seqTokens*seqs)
+ batchSize = seqTokens * seqs
+ } else if batchSize != seqTokens*seqs {
+ return nil, ErrUnsupportedBatchLayout
+ }
+ }
+ }
+ headDim := opts.headDim()
+ numHeads := opts.numHeads
+
+ // Q projection outputs query + gate interleaved
+ qFull := sa.Query.Forward(ctx, hiddenStates)
+
+ // Reshape to [headDim * 2, numHeads, batchSize]
+ qFull = qFull.Reshape(ctx, headDim*2, numHeads, batchSize)
+
+ // Split Q and gate along dimension 0
+ // Q: first headDim elements, gate: second headDim elements
+ query := qFull.Slice(ctx, 0, 0, headDim, 1)
+ gate := qFull.Slice(ctx, 0, headDim, headDim*2, 1)
+
+ // Make query contiguous for further operations
+ query = query.Contiguous(ctx, headDim, numHeads, batchSize)
+
+ // K and V projections
+ key := sa.Key.Forward(ctx, hiddenStates)
+ value := sa.Value.Forward(ctx, hiddenStates)
+
+ // Derive numKVHeads from tensor dimensions (per-layer value)
+ numKVHeads := key.Dim(0) / headDim
+
+ key = key.Reshape(ctx, headDim, numKVHeads, batchSize)
+ value = value.Reshape(ctx, headDim, numKVHeads, batchSize)
+
+ // Apply QK normalization
+ query = sa.QueryNorm.Forward(ctx, query, opts.eps)
+ key = sa.KeyNorm.Forward(ctx, key, opts.eps)
+
+ // Apply RoPE
+ query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
+ key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
+
+ // Standard attention computation
+ scale := opts.attentionScale
+ if scale == 0 {
+ scale = 1.0 / math.Sqrt(float64(headDim))
+ }
+ attention := nn.Attention(ctx, query, key, value, scale, cache)
+
+ // Flatten heads
+ attention = attention.Reshape(ctx, headDim*numHeads, batchSize)
+
+ // Apply sigmoid gate
+ // gate shape: [headDim, numHeads, batchSize] -> [headDim*numHeads, batchSize]
+ gate = gate.Contiguous(ctx, headDim*numHeads, batchSize)
+ gateSigmoid := gate.Sigmoid(ctx)
+ attention = attention.Mul(ctx, gateSigmoid)
+
+ return sa.Output.Forward(ctx, attention), nil
+}
diff --git a/model/models/qwen3next/cache.go b/model/models/qwen3next/cache.go
new file mode 100644
index 00000000000..86ee2b58d66
--- /dev/null
+++ b/model/models/qwen3next/cache.go
@@ -0,0 +1,596 @@
+package qwen3next
+
+import (
+ "math"
+ "slices"
+
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
+)
+
+var _ kvcache.Cache = (*HybridCache)(nil)
+
+// HybridCache stores:
+// - a standard causal KV cache for full attention layers
+// - per-sequence conv state for linear attention layers
+// - per-sequence delta state for linear attention layers
+//
+// Conv state shape (per layer, per sequence): [convKernelSize-1, convChannels]
+// Delta state shape (per layer, per sequence): [headVDim, headVDim * numVHeads]
+type HybridCache struct {
+ kv *kvcache.Causal
+
+ backend ml.Backend
+ dtype ml.DType
+ maxSequences int
+
+ // Conv state dimensions
+ convDim int // convKernelSize - 1
+ convChannels int // d_inner + 2 * num_k_heads * head_k_dim
+
+ // Delta state dimensions
+ deltaStateSize int // headVDim * headVDim * numVHeads
+
+ // slot mapping for recurrent state (copy-on-write)
+ slotForSeq map[int]int
+ refCount []int
+ freeSlots []int
+
+ // per-layer conv state buffers (allocated lazily)
+ convCtxs map[int]ml.Context
+ convStates map[int]ml.Tensor // [convDim*convChannels, maxSlots]
+
+ // per-layer delta state buffers (allocated lazily)
+ deltaCtxs map[int]ml.Context
+ deltaStates map[int]ml.Tensor // [deltaStateSize, maxSlots]
+
+ // recurrent checkpoints (per slot)
+ checkpointCount int
+ checkpointMinPos int32
+ checkpointInterval int32
+ checkpointCtxSize int
+ checkpoints map[int]*slotCheckpointStore
+ pendingRestore map[int]checkpointRestore
+ curCheckpointPos []int32
+ curCheckpointSlots map[int]int
+ reserveCheckpoints bool
+ checkpointConvCtxs map[int]ml.Context
+ checkpointDeltaCtxs map[int]ml.Context
+ checkpointReserved map[int]struct{}
+
+ // current forward batch (derived in StartForward)
+ curSeqs []int
+ curSlots []int
+ curSlotsInput ml.Tensor
+ curSeqTokens int
+
+ // track if EnsureWritable has been called for this forward pass
+ writableEnsured bool
+ writableError error
+}
+
+func NewHybridCache(
+ shift func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error),
+ convDim, convChannels, deltaStateSize int,
+) *HybridCache {
+ return &HybridCache{
+ kv: kvcache.NewCausalCache(shift),
+ convDim: convDim,
+ convChannels: convChannels,
+ deltaStateSize: deltaStateSize,
+ slotForSeq: make(map[int]int),
+ convCtxs: make(map[int]ml.Context),
+ convStates: make(map[int]ml.Tensor),
+ deltaCtxs: make(map[int]ml.Context),
+ deltaStates: make(map[int]ml.Tensor),
+ checkpointCount: checkpointCountDefault,
+ checkpointMinPos: checkpointMinPosDefault,
+ checkpointInterval: checkpointIntervalDefault,
+ checkpoints: make(map[int]*slotCheckpointStore),
+ pendingRestore: make(map[int]checkpointRestore),
+ curCheckpointSlots: make(map[int]int),
+ checkpointConvCtxs: make(map[int]ml.Context),
+ checkpointDeltaCtxs: make(map[int]ml.Context),
+ checkpointReserved: make(map[int]struct{}),
+ }
+}
+
+func (c *HybridCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
+ c.backend = backend
+ c.dtype = dtype
+ c.maxSequences = maxSequences
+ c.checkpoints = make(map[int]*slotCheckpointStore)
+ c.pendingRestore = make(map[int]checkpointRestore)
+ c.curCheckpointPos = c.curCheckpointPos[:0]
+ c.curCheckpointSlots = make(map[int]int)
+ c.checkpointReserved = make(map[int]struct{})
+ c.checkpointCtxSize = c.checkpointCount * c.maxSequences
+ if c.checkpointCtxSize < 8 {
+ c.checkpointCtxSize = 8
+ }
+
+ // initialize slot allocator
+ c.refCount = make([]int, maxSequences)
+ c.freeSlots = c.freeSlots[:0]
+ for i := maxSequences - 1; i >= 0; i-- {
+ c.freeSlots = append(c.freeSlots, i)
+ }
+
+ c.kv.Init(backend, dtype, maxSequences, capacity, maxBatch)
+}
+
+func (c *HybridCache) Close() {
+ for _, ctx := range c.convCtxs {
+ ctx.Close()
+ }
+ for _, ctx := range c.deltaCtxs {
+ ctx.Close()
+ }
+ for _, ctx := range c.checkpointConvCtxs {
+ ctx.Close()
+ }
+ for _, ctx := range c.checkpointDeltaCtxs {
+ ctx.Close()
+ }
+ c.kv.Close()
+}
+
+func (c *HybridCache) SetConfig(config ml.CacheConfig) {
+ c.kv.SetConfig(config)
+}
+
+func (c *HybridCache) SetLayer(layer int) {
+ c.kv.SetLayer(layer)
+}
+
+func (c *HybridCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
+ return c.kv.Get(ctx)
+}
+
+func (c *HybridCache) Put(ctx ml.Context, key, value ml.Tensor) {
+ c.kv.Put(ctx, key, value)
+}
+
+func (c *HybridCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
+ if err := c.kv.StartForward(ctx, batch, reserve); err != nil {
+ return err
+ }
+
+ // Derive equal-length sequence layout for recurrent layers
+ seqCounts := make(map[int]int)
+ c.curSeqs = c.curSeqs[:0]
+ for _, s := range batch.Sequences {
+ if _, ok := seqCounts[s]; !ok {
+ c.curSeqs = append(c.curSeqs, s)
+ }
+ seqCounts[s]++
+ }
+
+ if len(c.curSeqs) == 0 {
+ return nil
+ }
+
+ nTokens := len(batch.Sequences)
+ nSeqs := len(c.curSeqs)
+ want := nTokens / nSeqs
+ for _, s := range c.curSeqs {
+ if seqCounts[s] != want {
+ return kvcache.ErrNotSupported
+ }
+ }
+
+ c.curSeqTokens = want
+
+ // When reserving memory for estimation, use fake slot assignments
+ if reserve {
+ c.curSlots = c.curSlots[:0]
+ slots := make([]int32, nSeqs)
+ for i := range nSeqs {
+ c.curSlots = append(c.curSlots, i)
+ slots[i] = int32(i)
+ }
+ c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
+ c.reserveCheckpoints = true
+ c.planCheckpoints(batch)
+ return nil
+ }
+
+ // Ensure slots exist for sequences in this batch
+ c.curSlots = c.curSlots[:0]
+ var newSlots []int
+ for _, s := range c.curSeqs {
+ slot, ok := c.slotForSeq[s]
+ if !ok {
+ var err error
+ slot, err = c.allocSlot()
+ if err != nil {
+ return err
+ }
+ c.slotForSeq[s] = slot
+ c.refCount[slot] = 1
+ newSlots = append(newSlots, slot)
+ }
+ c.curSlots = append(c.curSlots, slot)
+ }
+
+ // Zero state for newly allocated slots
+ if len(newSlots) > 0 {
+ c.zeroSlots(ctx, newSlots)
+ }
+
+ // Create a tensor for the current slots
+ slots := make([]int32, len(c.curSlots))
+ for i, v := range c.curSlots {
+ slots[i] = int32(v)
+ }
+ c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
+
+ // Reset writable state for new forward pass
+ c.writableEnsured = false
+ c.writableError = nil
+ c.reserveCheckpoints = false
+ c.planCheckpoints(batch)
+
+ return nil
+}
+
+func (c *HybridCache) allocSlot() (int, error) {
+ if len(c.freeSlots) == 0 {
+ return 0, kvcache.ErrKvCacheFull
+ }
+ slot := c.freeSlots[len(c.freeSlots)-1]
+ c.freeSlots = c.freeSlots[:len(c.freeSlots)-1]
+ return slot, nil
+}
+
+func (c *HybridCache) freeSlot(slot int) {
+ if slot >= 0 && slot < c.maxSequences {
+ c.freeSlots = append(c.freeSlots, slot)
+ }
+}
+
+// zeroSlots zeros the recurrent state for the given slots across all layers.
+func (c *HybridCache) zeroSlots(ctx ml.Context, slots []int) {
+ if len(slots) == 0 {
+ return
+ }
+
+ inputCtx := ctx.Input()
+
+ slotIndices := make([]int32, len(slots))
+ for i, s := range slots {
+ slotIndices[i] = int32(s)
+ }
+ slotsTensor := inputCtx.FromInts(slotIndices, len(slotIndices))
+
+ // Zero conv states
+ if len(c.convStates) > 0 {
+ zeros := inputCtx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, len(slots))
+ for _, buf := range c.convStates {
+ ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
+ }
+ }
+
+ // Zero delta states
+ if len(c.deltaStates) > 0 {
+ zeros := inputCtx.Zeros(ml.DTypeF32, c.deltaStateSize, len(slots))
+ for _, buf := range c.deltaStates {
+ ctx.Forward(buf.SetRows(ctx, zeros, slotsTensor))
+ }
+ }
+}
+
+// EnsureWritable ensures sequences have private slots (copy-on-write).
+func (c *HybridCache) EnsureWritable(ctx ml.Context) error {
+ for i, seq := range c.curSeqs {
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ continue
+ }
+
+ if slot < 0 || slot >= len(c.refCount) {
+ continue
+ }
+
+ if c.refCount[slot] <= 1 {
+ continue
+ }
+
+ newSlot, err := c.allocSlot()
+ if err != nil {
+ return err
+ }
+ c.refCount[slot]--
+ c.refCount[newSlot] = 1
+ c.slotForSeq[seq] = newSlot
+ c.curSlots[i] = newSlot
+
+ c.copyRecurrentState(ctx, slot, newSlot)
+ c.copyCheckpoints(ctx, slot, newSlot)
+ }
+
+ // Rebuild current slots tensor
+ slots := make([]int32, len(c.curSlots))
+ for i, v := range c.curSlots {
+ slots[i] = int32(v)
+ }
+ c.curSlotsInput = ctx.Input().FromInts(slots, len(slots))
+
+ return nil
+}
+
+func (c *HybridCache) copyRecurrentState(ctx ml.Context, srcSlot, dstSlot int) {
+ src := ctx.Input().FromInts([]int32{int32(srcSlot)}, 1)
+ dst := ctx.Input().FromInts([]int32{int32(dstSlot)}, 1)
+
+ for _, buf := range c.convStates {
+ rows := buf.Rows(ctx, src)
+ rowsF32 := rows.Cast(ctx, ml.DTypeF32)
+ ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
+ }
+
+ for _, buf := range c.deltaStates {
+ rows := buf.Rows(ctx, src)
+ rowsF32 := rows.Cast(ctx, ml.DTypeF32)
+ ctx.Forward(buf.SetRows(ctx, rowsF32, dst))
+ }
+}
+
+func (c *HybridCache) CopyPrefix(srcSeq, dstSeq int, prefixLen int32) {
+ c.kv.CopyPrefix(srcSeq, dstSeq, prefixLen)
+
+ // Copy-on-write for recurrent state
+ if dstSlot, ok := c.slotForSeq[dstSeq]; ok {
+ if c.validSlot(dstSlot) {
+ c.refCount[dstSlot]--
+ if c.refCount[dstSlot] <= 0 {
+ c.refCount[dstSlot] = 0
+ c.freeSlot(dstSlot)
+ }
+ }
+ delete(c.slotForSeq, dstSeq)
+ }
+
+ srcSlot, ok := c.slotForSeq[srcSeq]
+ if !ok {
+ return
+ }
+
+ if c.validSlot(srcSlot) {
+ c.slotForSeq[dstSeq] = srcSlot
+ c.refCount[srcSlot]++
+ }
+}
+
+func (c *HybridCache) CanResume(seq int, pos int32) bool {
+ if !c.kv.CanResume(seq, pos) {
+ return false
+ }
+ if pos == 0 {
+ return true
+ }
+ return c.hasCheckpoint(seq, pos)
+}
+
+func (c *HybridCache) Remove(seq int, beginIndex, endIndex int32) error {
+ if beginIndex > 0 && endIndex != math.MaxInt32 {
+ return kvcache.ErrNotSupported
+ }
+
+ if beginIndex > 0 {
+ restore, ok := c.pendingRestore[seq]
+ if !ok || restore.pos+1 != beginIndex {
+ return kvcache.ErrNotSupported
+ }
+ if !c.restoreComplete(restore) {
+ return kvcache.ErrNotSupported
+ }
+ // If the recurrent slot is shared, detach it before applying a restore.
+ if slot, ok := c.slotForSeq[seq]; ok && c.validSlot(slot) && c.refCount[slot] > 1 {
+ newSlot, err := c.allocSlot()
+ if err != nil {
+ return err
+ }
+ ctx := c.backend.NewContext()
+ c.copyRecurrentState(ctx, slot, newSlot)
+ c.copyCheckpoints(ctx, slot, newSlot)
+ if len(c.convStates) > 0 || len(c.deltaStates) > 0 {
+ ctx.Compute()
+ }
+ ctx.Close()
+
+ c.refCount[slot]--
+ c.refCount[newSlot] = 1
+ c.slotForSeq[seq] = newSlot
+
+ restore.slot = newSlot
+ c.pendingRestore[seq] = restore
+ }
+ }
+
+ if err := c.kv.Remove(seq, beginIndex, endIndex); err != nil {
+ return err
+ }
+
+ if beginIndex > 0 {
+ restore := c.pendingRestore[seq]
+ delete(c.pendingRestore, seq)
+ return c.applyCheckpointRestore(restore)
+ }
+
+ // Removal invalidates recurrent state
+ slot, ok := c.slotForSeq[seq]
+ delete(c.pendingRestore, seq)
+ if !ok {
+ return nil
+ }
+
+ if !c.validSlot(slot) {
+ delete(c.slotForSeq, seq)
+ return nil
+ }
+
+ c.refCount[slot]--
+ if c.refCount[slot] <= 0 {
+ c.refCount[slot] = 0
+ c.clearCheckpoints(slot)
+ c.freeSlot(slot)
+ }
+ delete(c.slotForSeq, seq)
+
+ return nil
+}
+
+func (c *HybridCache) validSlot(slot int) bool {
+ return slot >= 0 && slot < len(c.refCount)
+}
+
+func (c *HybridCache) slotsTensor() ml.Tensor {
+ return c.curSlotsInput
+}
+
+// contiguousSlots returns the starting slot if current slots are contiguous and ordered.
+func (c *HybridCache) contiguousSlots() (int, bool) {
+ if len(c.curSlots) == 0 {
+ return 0, false
+ }
+ start := c.curSlots[0]
+ for i, s := range c.curSlots {
+ if s != start+i {
+ return 0, false
+ }
+ }
+ return start, true
+}
+
+func (c *HybridCache) seqTokens() int {
+ return c.curSeqTokens
+}
+
+func (c *HybridCache) numSeqs() int {
+ return len(c.curSeqs)
+}
+
+func (c *HybridCache) convBuffer(ctx ml.Context, layer int) ml.Tensor {
+ if buf, ok := c.convStates[layer]; ok {
+ return buf
+ }
+
+ if _, ok := c.convCtxs[layer]; !ok {
+ c.convCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
+ }
+
+ // Recurrent state must stay in F32 (ssm_conv kernels are F32-only).
+ buf := c.convCtxs[layer].Zeros(ml.DTypeF32, c.convDim*c.convChannels, c.maxSequences)
+ c.convStates[layer] = buf
+ return buf
+}
+
+func (c *HybridCache) deltaBuffer(ctx ml.Context, layer int) ml.Tensor {
+ if buf, ok := c.deltaStates[layer]; ok {
+ return buf
+ }
+
+ if _, ok := c.deltaCtxs[layer]; !ok {
+ c.deltaCtxs[layer] = c.backend.NewContextSize(1).Layer(layer)
+ }
+
+ // Recurrent delta state must stay in F32.
+ buf := c.deltaCtxs[layer].Zeros(ml.DTypeF32, c.deltaStateSize, c.maxSequences)
+ c.deltaStates[layer] = buf
+ return buf
+}
+
+func (c *HybridCache) ensureWritableOnce(ctx ml.Context) {
+ if !c.writableEnsured {
+ needsWritable := false
+ for _, seq := range c.curSeqs {
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ continue
+ }
+ if slot >= 0 && slot < len(c.refCount) && c.refCount[slot] > 1 {
+ needsWritable = true
+ break
+ }
+ }
+
+ if needsWritable {
+ if err := c.EnsureWritable(ctx); err != nil {
+ c.writableError = err
+ }
+ }
+ c.writableEnsured = true
+ }
+}
+
+// ConvState returns the conv state for current batch sequences as [convDim, convChannels, nSeqs].
+func (c *HybridCache) ConvState(ctx ml.Context, layer int) (ml.Tensor, error) {
+ c.ensureWritableOnce(ctx)
+
+ if c.writableError != nil {
+ return nil, c.writableError
+ }
+
+ buf := c.convBuffer(ctx, layer)
+ cur := buf.Rows(ctx, c.slotsTensor())
+ return cur.Reshape(ctx, c.convDim, c.convChannels, c.numSeqs()), nil
+}
+
+// UpdateConvState writes a new conv state for current batch sequences.
+func (c *HybridCache) UpdateConvState(ctx ml.Context, layer int, newState ml.Tensor) {
+ buf := c.convBuffer(ctx, layer)
+ src := newState.Reshape(ctx, c.convDim*c.convChannels, c.numSeqs())
+ srcF32 := src.Cast(ctx, ml.DTypeF32)
+ if start, ok := c.contiguousSlots(); ok {
+ // Fast path: contiguous slots allow a single view + copy
+ offset := start * buf.Stride(1)
+ view := buf.View(ctx, offset, c.convDim*c.convChannels, buf.Stride(1), c.numSeqs())
+ ctx.Forward(srcF32.Copy(ctx, view))
+ } else {
+ ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
+ }
+
+ c.captureConvCheckpoint(ctx, layer, srcF32)
+}
+
+// DeltaState returns the delta state for current batch sequences as [headVDim, headVDim*numVHeads, nSeqs].
+func (c *HybridCache) DeltaState(ctx ml.Context, layer int, headVDim, numVHeads int) (ml.Tensor, error) {
+ c.ensureWritableOnce(ctx)
+
+ if c.writableError != nil {
+ return nil, c.writableError
+ }
+
+ buf := c.deltaBuffer(ctx, layer)
+ cur := buf.Rows(ctx, c.slotsTensor())
+ return cur.Reshape(ctx, headVDim, headVDim*numVHeads, c.numSeqs()), nil
+}
+
+// UpdateDeltaState writes a new delta state for current batch sequences.
+func (c *HybridCache) UpdateDeltaState(ctx ml.Context, layer int, newState ml.Tensor) {
+ buf := c.deltaBuffer(ctx, layer)
+ src := newState.Reshape(ctx, c.deltaStateSize, c.numSeqs())
+ srcF32 := src.Cast(ctx, ml.DTypeF32)
+ if start, ok := c.contiguousSlots(); ok {
+ // Fast path: contiguous slots allow a single view + copy
+ offset := start * buf.Stride(1)
+ view := buf.View(ctx, offset, c.deltaStateSize, buf.Stride(1), c.numSeqs())
+ ctx.Forward(srcF32.Copy(ctx, view))
+ } else {
+ ctx.Forward(buf.SetRows(ctx, srcF32, c.slotsTensor()))
+ }
+
+ c.captureDeltaCheckpoint(ctx, layer, srcF32)
+}
+
+// IsSupportedForBatch returns true if the current batch layout supports recurrent layers.
+func (c *HybridCache) IsSupportedForBatch() bool {
+ return c.curSeqTokens > 0 && len(c.curSeqs) > 0
+}
+
+// Seqs returns the ordered unique sequences for the current forward pass.
+func (c *HybridCache) Seqs() []int {
+ return slices.Clone(c.curSeqs)
+}
diff --git a/model/models/qwen3next/checkpoints.go b/model/models/qwen3next/checkpoints.go
new file mode 100644
index 00000000000..913af1c054d
--- /dev/null
+++ b/model/models/qwen3next/checkpoints.go
@@ -0,0 +1,498 @@
+package qwen3next
+
+import (
+ "log/slog"
+ "math"
+
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/model/input"
+)
+
+const (
+ checkpointCountDefault = 32
+ checkpointMinPosDefault = int32(16)
+ checkpointIntervalDefault = int32(1280)
+)
+
+// TODO(jmorganca): Add byte-serialized host-RAM checkpoints to reduce GPU
+// memory usage while preserving prefix reuse for recurrent state.
+
+type checkpointEntry struct {
+ pos int32
+ conv map[int]ml.Tensor
+ delta map[int]ml.Tensor
+}
+
+type slotCheckpointStore struct {
+ entries []checkpointEntry
+ size int
+ next int
+ lastPos int32
+}
+
+type checkpointRestore struct {
+ slot int
+ idx int
+ pos int32
+}
+
+func newSlotCheckpointStore(n int) *slotCheckpointStore {
+ entries := make([]checkpointEntry, n)
+ for i := range entries {
+ entries[i].pos = -1
+ }
+ return &slotCheckpointStore{
+ entries: entries,
+ lastPos: -1,
+ }
+}
+
+func (s *slotCheckpointStore) reset() {
+ s.size = 0
+ s.next = 0
+ s.lastPos = -1
+ for i := range s.entries {
+ s.entries[i].pos = -1
+ }
+}
+
+func (s *slotCheckpointStore) record(pos int32) int {
+ if len(s.entries) == 0 {
+ return -1
+ }
+ idx := s.next
+ s.next = (s.next + 1) % len(s.entries)
+ if s.size < len(s.entries) {
+ s.size++
+ }
+ s.entries[idx].pos = pos
+ s.lastPos = pos
+ return idx
+}
+
+func (s *slotCheckpointStore) bestIndex(targetPos int32) (int, int32, bool) {
+ bestIdx := -1
+ bestPos := int32(-1)
+ for i := range s.entries {
+ pos := s.entries[i].pos
+ if pos < 0 || pos >= targetPos {
+ continue
+ }
+ if pos > bestPos {
+ bestPos = pos
+ bestIdx = i
+ }
+ }
+ if bestIdx < 0 {
+ return -1, -1, false
+ }
+ return bestIdx, bestPos, true
+}
+
+func (s *slotCheckpointStore) pruneAfter(pos int32) {
+ if len(s.entries) == 0 {
+ s.size = 0
+ s.next = 0
+ s.lastPos = -1
+ return
+ }
+
+ size := 0
+ next := -1
+ minPos := int32(math.MaxInt32)
+ minIdx := 0
+ for i := range s.entries {
+ if s.entries[i].pos > pos {
+ s.entries[i].pos = -1
+ }
+ if s.entries[i].pos >= 0 {
+ size++
+ if s.entries[i].pos < minPos {
+ minPos = s.entries[i].pos
+ minIdx = i
+ }
+ } else if next == -1 {
+ next = i
+ }
+ }
+
+ s.size = size
+ if size == 0 {
+ s.next = 0
+ s.lastPos = -1
+ return
+ }
+ if next != -1 {
+ s.next = next
+ } else {
+ // Full ring: overwrite the oldest checkpoint next.
+ s.next = minIdx
+ }
+ s.lastPos = pos
+}
+
+func (s *slotCheckpointStore) window() (size int, minPos, maxPos, lastPos int32) {
+ minPos = int32(math.MaxInt32)
+ maxPos = int32(-1)
+ for i := range s.entries {
+ pos := s.entries[i].pos
+ if pos < 0 {
+ continue
+ }
+ size++
+ if pos < minPos {
+ minPos = pos
+ }
+ if pos > maxPos {
+ maxPos = pos
+ }
+ }
+ if size == 0 {
+ minPos = -1
+ maxPos = -1
+ }
+ return size, minPos, maxPos, s.lastPos
+}
+
+func (c *HybridCache) planCheckpoints(batch input.Batch) {
+ if c.checkpointCount == 0 || len(c.curSeqs) == 0 {
+ c.curCheckpointPos = c.curCheckpointPos[:0]
+ for k := range c.curCheckpointSlots {
+ delete(c.curCheckpointSlots, k)
+ }
+ return
+ }
+
+ if cap(c.curCheckpointPos) < len(c.curSeqs) {
+ c.curCheckpointPos = make([]int32, len(c.curSeqs))
+ } else {
+ c.curCheckpointPos = c.curCheckpointPos[:len(c.curSeqs)]
+ }
+ for i := range c.curCheckpointPos {
+ c.curCheckpointPos[i] = -1
+ }
+ for k := range c.curCheckpointSlots {
+ delete(c.curCheckpointSlots, k)
+ }
+
+ posMax := make(map[int]int32, len(c.curSeqs))
+ for i, seq := range batch.Sequences {
+ pos := batch.Positions[i]
+ if cur, ok := posMax[seq]; !ok || pos > cur {
+ posMax[seq] = pos
+ }
+ }
+
+ for i, seq := range c.curSeqs {
+ pos, ok := posMax[seq]
+ if !ok {
+ continue
+ }
+ if pos < c.checkpointMinPos {
+ continue
+ }
+ slot := c.curSlots[i]
+ store := c.checkpointStore(slot)
+ lastPos := store.lastPos
+ if lastPos < 0 || pos-lastPos >= c.checkpointInterval {
+ c.curCheckpointPos[i] = pos
+ }
+ }
+}
+
+func (c *HybridCache) checkpointStore(slot int) *slotCheckpointStore {
+ store, ok := c.checkpoints[slot]
+ if ok {
+ return store
+ }
+ store = newSlotCheckpointStore(c.checkpointCount)
+ c.checkpoints[slot] = store
+ return store
+}
+
+func (c *HybridCache) checkpointIndexForSlot(slot int, pos int32) int {
+ if c.checkpointCount == 0 {
+ return -1
+ }
+ if idx, ok := c.curCheckpointSlots[slot]; ok {
+ return idx
+ }
+ store := c.checkpointStore(slot)
+ idx := store.record(pos)
+ if idx >= 0 {
+ c.curCheckpointSlots[slot] = idx
+ }
+ return idx
+}
+
+func (c *HybridCache) hasCheckpoint(seq int, pos int32) bool {
+ if pos <= 0 {
+ return false
+ }
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ return false
+ }
+ store, ok := c.checkpoints[slot]
+ if !ok {
+ return false
+ }
+ _, _, ok = store.bestIndex(pos)
+ return ok
+}
+
+func (c *HybridCache) PrepareRestore(seq int, targetPos int32) (int32, bool) {
+ if targetPos <= 0 {
+ return 0, false
+ }
+ slot, ok := c.slotForSeq[seq]
+ if !ok {
+ return 0, false
+ }
+ store, ok := c.checkpoints[slot]
+ if !ok {
+ slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", 0)
+ return 0, false
+ }
+ idx, pos, ok := store.bestIndex(targetPos)
+ if !ok {
+ size, minPos, maxPos, lastPos := store.window()
+ slog.Debug("qwen3next: checkpoint miss", "seq", seq, "slot", slot, "target", targetPos, "size", size,
+ "min", minPos, "max", maxPos, "last", lastPos)
+ return 0, false
+ }
+ c.pendingRestore[seq] = checkpointRestore{
+ slot: slot,
+ idx: idx,
+ pos: pos,
+ }
+ return pos + 1, true
+}
+
+func (c *HybridCache) applyCheckpointRestore(restore checkpointRestore) error {
+ entry, ok := c.restoreEntry(restore)
+ if !ok {
+ return kvcache.ErrNotSupported
+ }
+
+ ctx := c.backend.NewContext()
+ defer ctx.Close()
+
+ slotIdx := ctx.Input().FromInts([]int32{int32(restore.slot)}, 1)
+ for layer, src := range entry.conv {
+ buf := c.convBuffer(ctx, layer)
+ ctx.Forward(buf.SetRows(ctx, src, slotIdx))
+ }
+ for layer, src := range entry.delta {
+ buf := c.deltaBuffer(ctx, layer)
+ ctx.Forward(buf.SetRows(ctx, src, slotIdx))
+ }
+
+ if len(entry.conv) > 0 || len(entry.delta) > 0 {
+ ctx.Compute()
+ }
+ store := c.checkpoints[restore.slot]
+ store.pruneAfter(restore.pos)
+ return nil
+}
+
+func (c *HybridCache) restoreComplete(restore checkpointRestore) bool {
+ _, ok := c.restoreEntry(restore)
+ return ok
+}
+
+func (c *HybridCache) restoreEntry(restore checkpointRestore) (*checkpointEntry, bool) {
+ store, ok := c.checkpoints[restore.slot]
+ if !ok || restore.idx < 0 || restore.idx >= len(store.entries) {
+ return nil, false
+ }
+ entry := &store.entries[restore.idx]
+ if entry.pos < 0 {
+ return nil, false
+ }
+ if !c.entryComplete(entry) {
+ return nil, false
+ }
+ return entry, true
+}
+
+func (c *HybridCache) entryComplete(entry *checkpointEntry) bool {
+ for layer := range c.convStates {
+ if entry.conv == nil || entry.conv[layer] == nil {
+ return false
+ }
+ }
+ for layer := range c.deltaStates {
+ if entry.delta == nil || entry.delta[layer] == nil {
+ return false
+ }
+ }
+ return true
+}
+
+func (c *HybridCache) clearCheckpoints(slot int) {
+ if store, ok := c.checkpoints[slot]; ok {
+ store.reset()
+ }
+}
+
+func (c *HybridCache) copyCheckpoints(ctx ml.Context, srcSlot, dstSlot int) {
+ if c.checkpointCount == 0 {
+ return
+ }
+ srcStore, ok := c.checkpoints[srcSlot]
+ if !ok || srcStore.size == 0 {
+ return
+ }
+ dstStore := c.checkpointStore(dstSlot)
+ dstStore.size = srcStore.size
+ dstStore.next = srcStore.next
+ dstStore.lastPos = srcStore.lastPos
+
+ for i := range srcStore.entries {
+ srcEntry := &srcStore.entries[i]
+ dstEntry := &dstStore.entries[i]
+ dstEntry.pos = srcEntry.pos
+ if srcEntry.conv != nil {
+ if dstEntry.conv == nil {
+ dstEntry.conv = make(map[int]ml.Tensor)
+ }
+ for layer, src := range srcEntry.conv {
+ dst := c.ensureCheckpointConv(layer, dstEntry)
+ ctx.Forward(src.Copy(ctx, dst))
+ }
+ }
+ if srcEntry.delta != nil {
+ if dstEntry.delta == nil {
+ dstEntry.delta = make(map[int]ml.Tensor)
+ }
+ for layer, src := range srcEntry.delta {
+ dst := c.ensureCheckpointDelta(layer, dstEntry)
+ ctx.Forward(src.Copy(ctx, dst))
+ }
+ }
+ }
+}
+
+func (c *HybridCache) captureConvCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
+ if c.checkpointCount == 0 {
+ return
+ }
+ if c.reserveCheckpoints {
+ c.reserveCheckpointConv(layer)
+ return
+ }
+ if len(c.curCheckpointPos) == 0 {
+ return
+ }
+ for i, pos := range c.curCheckpointPos {
+ if pos < 0 {
+ continue
+ }
+ slot := c.curSlots[i]
+ idx := c.checkpointIndexForSlot(slot, pos)
+ if idx < 0 {
+ continue
+ }
+ entry := &c.checkpoints[slot].entries[idx]
+ dst := c.ensureCheckpointConv(layer, entry)
+ seqSlice := src.Slice(ctx, 1, i, i+1, 1)
+ ctx.Forward(seqSlice.Copy(ctx, dst))
+ }
+}
+
+func (c *HybridCache) captureDeltaCheckpoint(ctx ml.Context, layer int, src ml.Tensor) {
+ if c.checkpointCount == 0 {
+ return
+ }
+ if c.reserveCheckpoints {
+ c.reserveCheckpointDelta(layer)
+ return
+ }
+ if len(c.curCheckpointPos) == 0 {
+ return
+ }
+ for i, pos := range c.curCheckpointPos {
+ if pos < 0 {
+ continue
+ }
+ slot := c.curSlots[i]
+ idx := c.checkpointIndexForSlot(slot, pos)
+ if idx < 0 {
+ continue
+ }
+ entry := &c.checkpoints[slot].entries[idx]
+ dst := c.ensureCheckpointDelta(layer, entry)
+ seqSlice := src.Slice(ctx, 1, i, i+1, 1)
+ ctx.Forward(seqSlice.Copy(ctx, dst))
+ }
+}
+
+func (c *HybridCache) ensureCheckpointConv(layer int, entry *checkpointEntry) ml.Tensor {
+ if entry.conv == nil {
+ entry.conv = make(map[int]ml.Tensor)
+ }
+ if t, ok := entry.conv[layer]; ok {
+ return t
+ }
+ ctx, ok := c.checkpointConvCtxs[layer]
+ if !ok {
+ ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
+ c.checkpointConvCtxs[layer] = ctx
+ }
+ t := ctx.Zeros(ml.DTypeF32, c.convDim*c.convChannels, 1)
+ entry.conv[layer] = t
+ return t
+}
+
+func (c *HybridCache) ensureCheckpointDelta(layer int, entry *checkpointEntry) ml.Tensor {
+ if entry.delta == nil {
+ entry.delta = make(map[int]ml.Tensor)
+ }
+ if t, ok := entry.delta[layer]; ok {
+ return t
+ }
+ ctx, ok := c.checkpointDeltaCtxs[layer]
+ if !ok {
+ ctx = c.backend.NewContextSize(c.checkpointCtxSize).Layer(layer)
+ c.checkpointDeltaCtxs[layer] = ctx
+ }
+ t := ctx.Zeros(ml.DTypeF32, c.deltaStateSize, 1)
+ entry.delta[layer] = t
+ return t
+}
+
+func (c *HybridCache) reserveCheckpointConv(layer int) {
+ key := checkpointReserveKey(layer, 0)
+ if _, ok := c.checkpointReserved[key]; ok {
+ return
+ }
+ for slot := range c.maxSequences {
+ store := c.checkpointStore(slot)
+ for i := range store.entries {
+ entry := &store.entries[i]
+ _ = c.ensureCheckpointConv(layer, entry)
+ }
+ }
+ c.checkpointReserved[key] = struct{}{}
+}
+
+func (c *HybridCache) reserveCheckpointDelta(layer int) {
+ key := checkpointReserveKey(layer, 1)
+ if _, ok := c.checkpointReserved[key]; ok {
+ return
+ }
+ for slot := range c.maxSequences {
+ store := c.checkpointStore(slot)
+ for i := range store.entries {
+ entry := &store.entries[i]
+ _ = c.ensureCheckpointDelta(layer, entry)
+ }
+ }
+ c.checkpointReserved[key] = struct{}{}
+}
+
+func checkpointReserveKey(layer int, kind int) int {
+ return layer*2 + kind
+}
diff --git a/model/models/qwen3next/checkpoints_test.go b/model/models/qwen3next/checkpoints_test.go
new file mode 100644
index 00000000000..440a3a2cfc7
--- /dev/null
+++ b/model/models/qwen3next/checkpoints_test.go
@@ -0,0 +1,300 @@
+package qwen3next
+
+import (
+ "errors"
+ "math"
+ "os"
+ "testing"
+
+ "github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/kvcache"
+ "github.com/ollama/ollama/ml"
+)
+
+func newTestBackend(tb testing.TB) ml.Backend {
+ tb.Helper()
+
+ f, err := os.CreateTemp(tb.TempDir(), "*.gguf")
+ if err != nil {
+ tb.Fatal(err)
+ }
+ if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil {
+ _ = f.Close()
+ tb.Fatal(err)
+ }
+ if err := f.Close(); err != nil {
+ tb.Fatal(err)
+ }
+
+ b, err := ml.NewBackend(f.Name(), ml.BackendParams{AllocMemory: true})
+ if err != nil {
+ tb.Fatal(err)
+ }
+ tb.Cleanup(func() {
+ b.Close()
+ })
+
+ return b
+}
+
+func TestSlotCheckpointStoreBestIndex(t *testing.T) {
+ store := newSlotCheckpointStore(2)
+ store.record(10)
+ store.record(20)
+
+ _, pos, ok := store.bestIndex(15)
+ if !ok || pos != 10 {
+ t.Fatalf("expected best pos 10, got pos=%d ok=%v", pos, ok)
+ }
+
+ store.record(30) // overwrite oldest (10)
+
+ if _, _, ok := store.bestIndex(15); ok {
+ t.Fatalf("expected no checkpoint for targetPos=15 after overwrite")
+ }
+
+ _, pos, ok = store.bestIndex(40)
+ if !ok || pos != 30 {
+ t.Fatalf("expected best pos 30, got pos=%d ok=%v", pos, ok)
+ }
+}
+
+func TestHybridCachePrepareRestore(t *testing.T) {
+ cache := NewHybridCache(nil, 1, 1, 1)
+ cache.checkpointCount = 3
+ cache.checkpoints = make(map[int]*slotCheckpointStore)
+ cache.pendingRestore = make(map[int]checkpointRestore)
+
+ cache.slotForSeq[1] = 0
+ store := cache.checkpointStore(0)
+ store.record(5)
+ store.record(9)
+ store.record(15)
+
+ restorePos, ok := cache.PrepareRestore(1, 12)
+ if !ok {
+ t.Fatalf("expected restore ok")
+ }
+ if restorePos != 10 {
+ t.Fatalf("expected restorePos 10, got %d", restorePos)
+ }
+ rest, ok := cache.pendingRestore[1]
+ if !ok {
+ t.Fatalf("expected pending restore entry")
+ }
+ if rest.pos != 9 {
+ t.Fatalf("expected pending restore pos 9, got %d", rest.pos)
+ }
+}
+
+func TestSlotCheckpointStorePruneAfter(t *testing.T) {
+ store := newSlotCheckpointStore(3)
+ store.record(10)
+ store.record(20)
+ store.record(30)
+
+ store.pruneAfter(20)
+
+ if store.lastPos != 20 {
+ t.Fatalf("expected lastPos 20, got %d", store.lastPos)
+ }
+
+ _, pos, ok := store.bestIndex(25)
+ if !ok || pos != 20 {
+ t.Fatalf("expected best pos 20 after prune, got pos=%d ok=%v", pos, ok)
+ }
+
+ _, pos, ok = store.bestIndex(35)
+ if !ok || pos != 20 {
+ t.Fatalf("expected pruned best pos 20 for targetPos=35, got pos=%d ok=%v", pos, ok)
+ }
+}
+
+func TestHybridCacheRestoreDetachesSharedSlot(t *testing.T) {
+ backend := newTestBackend(t)
+
+ cache := NewHybridCache(nil, 1, 2, 2)
+ cache.Init(backend, ml.DTypeF16, 2, 8, 2)
+
+ cache.slotForSeq[1] = 0
+ cache.slotForSeq[2] = 0
+ cache.refCount[0] = 2
+ cache.refCount[1] = 0
+ cache.freeSlots = []int{1}
+
+ store := cache.checkpointStore(0)
+ idx := store.record(9)
+ cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
+
+ if err := cache.Remove(1, 10, math.MaxInt32); err != nil {
+ t.Fatalf("Remove failed: %v", err)
+ }
+
+ if cache.slotForSeq[1] == cache.slotForSeq[2] {
+ t.Fatalf("expected restore to detach shared slot, got same slot %d", cache.slotForSeq[1])
+ }
+ if cache.slotForSeq[1] != 1 {
+ t.Fatalf("expected seq 1 to move to slot 1, got %d", cache.slotForSeq[1])
+ }
+ if cache.slotForSeq[2] != 0 {
+ t.Fatalf("expected seq 2 to remain on slot 0, got %d", cache.slotForSeq[2])
+ }
+ if cache.refCount[0] != 1 || cache.refCount[1] != 1 {
+ t.Fatalf("unexpected refCounts: slot0=%d slot1=%d", cache.refCount[0], cache.refCount[1])
+ }
+ if _, ok := cache.pendingRestore[1]; ok {
+ t.Fatalf("expected pending restore to be cleared")
+ }
+}
+
+func TestHybridCacheRestoreRejectsIncompleteCheckpoint(t *testing.T) {
+ cache := NewHybridCache(nil, 1, 2, 2)
+ cache.checkpointCount = 3
+ cache.checkpoints = make(map[int]*slotCheckpointStore)
+ cache.pendingRestore = make(map[int]checkpointRestore)
+
+ cache.slotForSeq[1] = 0
+ cache.refCount = []int{1}
+ cache.freeSlots = nil
+
+ // Simulate that layer 0 has both conv and delta state (so entryComplete expects both)
+ cache.convStates[0] = nil // placeholder to indicate layer 0 exists
+ cache.deltaStates[0] = nil // placeholder to indicate layer 0 exists
+
+ store := cache.checkpointStore(0)
+ idx := store.record(9)
+ entry := &store.entries[idx]
+ // Only set conv checkpoint, not delta - making it incomplete
+ entry.conv = map[int]ml.Tensor{0: nil}
+ // entry.delta is not set, so checkpoint is incomplete
+
+ cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
+
+ err := cache.Remove(1, 10, math.MaxInt32)
+ if !errors.Is(err, kvcache.ErrNotSupported) {
+ t.Fatalf("expected ErrNotSupported for incomplete checkpoint, got %v", err)
+ }
+}
+
+func TestHybridCacheRestoreAcceptsCompleteCheckpoint(t *testing.T) {
+ cache := NewHybridCache(nil, 1, 2, 2)
+ cache.checkpointCount = 3
+ cache.checkpoints = make(map[int]*slotCheckpointStore)
+ cache.pendingRestore = make(map[int]checkpointRestore)
+
+ cache.slotForSeq[1] = 0
+ cache.refCount = []int{1}
+ cache.freeSlots = nil
+
+ // Don't set convStates/deltaStates - with no layers to check,
+ // entryComplete will return true as long as entry.pos >= 0
+
+ store := cache.checkpointStore(0)
+ idx := store.record(9)
+
+ cache.pendingRestore[1] = checkpointRestore{slot: 0, idx: idx, pos: 9}
+
+ // Test that restoreComplete returns true when no layers need checkpoints
+ restore := cache.pendingRestore[1]
+ if !cache.restoreComplete(restore) {
+ t.Fatalf("expected restoreComplete to return true for complete checkpoint")
+ }
+}
+
+func TestSlotCheckpointStoreRingBufferWrapAround(t *testing.T) {
+ // Test that ring buffer wrap-around reuses entries without clearing maps.
+ store := newSlotCheckpointStore(3)
+
+ // Fill the buffer
+ store.record(10)
+ store.record(20)
+ store.record(30)
+
+ // Create fake tensor data in the first entry's maps
+ store.entries[0].conv = make(map[int]ml.Tensor)
+ store.entries[0].conv[0] = nil // Simulated tensor reference
+ store.entries[0].delta = make(map[int]ml.Tensor)
+ store.entries[0].delta[0] = nil // Simulated tensor reference
+
+ // Record another entry, which should wrap around and overwrite entry 0
+ store.record(40)
+
+ // Verify the maps are still present (we reuse tensors)
+ if store.entries[0].conv == nil {
+ t.Fatalf("expected conv map to be preserved on reuse")
+ }
+ if store.entries[0].delta == nil {
+ t.Fatalf("expected delta map to be preserved on reuse")
+ }
+
+ // Verify the new position was recorded
+ if store.entries[0].pos != 40 {
+ t.Fatalf("expected entry 0 pos to be 40, got %d", store.entries[0].pos)
+ }
+}
+
+func TestSlotCheckpointStoreFullCapacity(t *testing.T) {
+ // Test behavior when buffer is exactly at capacity
+ store := newSlotCheckpointStore(2)
+
+ idx1 := store.record(10)
+ idx2 := store.record(20)
+
+ if idx1 != 0 || idx2 != 1 {
+ t.Fatalf("expected indices 0, 1, got %d, %d", idx1, idx2)
+ }
+
+ if store.size != 2 {
+ t.Fatalf("expected size 2, got %d", store.size)
+ }
+
+ // Verify both checkpoints are accessible
+ _, pos1, ok1 := store.bestIndex(15)
+ _, pos2, ok2 := store.bestIndex(25)
+
+ if !ok1 || pos1 != 10 {
+ t.Fatalf("expected best pos 10 for target 15, got pos=%d ok=%v", pos1, ok1)
+ }
+ if !ok2 || pos2 != 20 {
+ t.Fatalf("expected best pos 20 for target 25, got pos=%d ok=%v", pos2, ok2)
+ }
+}
+
+func TestSlotCheckpointStoreEmptyBuffer(t *testing.T) {
+ // Test behavior with zero-size buffer
+ store := newSlotCheckpointStore(0)
+
+ idx := store.record(10)
+ if idx != -1 {
+ t.Fatalf("expected record to return -1 for empty buffer, got %d", idx)
+ }
+
+ _, _, ok := store.bestIndex(15)
+ if ok {
+ t.Fatalf("expected no checkpoint for empty buffer")
+ }
+}
+
+func TestSlotCheckpointStorePruneAfterAll(t *testing.T) {
+ // Test pruning that removes all checkpoints
+ store := newSlotCheckpointStore(3)
+ store.record(10)
+ store.record(20)
+ store.record(30)
+
+ // Prune everything by setting threshold below all positions
+ store.pruneAfter(5)
+
+ if store.size != 0 {
+ t.Fatalf("expected size 0 after pruning all, got %d", store.size)
+ }
+ // When all checkpoints are pruned, lastPos is reset to -1
+ if store.lastPos != -1 {
+ t.Fatalf("expected lastPos -1 after pruning all, got %d", store.lastPos)
+ }
+
+ _, _, ok := store.bestIndex(100)
+ if ok {
+ t.Fatalf("expected no checkpoint after pruning all")
+ }
+}
diff --git a/model/models/qwen3next/deltanet.go b/model/models/qwen3next/deltanet.go
new file mode 100644
index 00000000000..e0a6f7b25fe
--- /dev/null
+++ b/model/models/qwen3next/deltanet.go
@@ -0,0 +1,472 @@
+package qwen3next
+
+import (
+ "errors"
+ "log/slog"
+ "math"
+
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+)
+
+const chunkSize = 64
+
+// TriType constants for triangular matrix operations
+const (
+ TriTypeUpperDiag = 0
+ TriTypeUpper = 1
+ TriTypeLowerDiag = 2
+ TriTypeLower = 3
+)
+
+// convKernel wraps the 1D convolution kernel tensor
+type convKernel struct {
+ Weight ml.Tensor `gguf:"weight"`
+}
+
+// Masks holds pre-computed mask tensors for chunked attention
+type Masks struct {
+ Causal ml.Tensor // Lower triangular [chunkSize, chunkSize]
+ Identity ml.Tensor // Diagonal [chunkSize, chunkSize]
+ Diag ml.Tensor // causal + identity
+}
+
+// GatedDeltaNet implements linear attention with SSM convolution and recurrent state.
+// It implements the Operator interface directly.
+type GatedDeltaNet struct {
+ // Optimized path: pre-split QKV and gate
+ SSMQKV *nn.Linear `gguf:"attn_qkv"` // -> Q, K, V (concatenated)
+ SSMQKVGate *nn.Linear `gguf:"attn_gate"` // -> Z gate
+ SSMBetaAlpha *nn.Linear `gguf:"ssm_ba"` // -> beta, alpha
+ SSMConv1D *convKernel `gguf:"ssm_conv1d"`
+ SSMDT ml.Tensor `gguf:"ssm_dt"` // alpha bias
+ SSMA ml.Tensor `gguf:"ssm_a"` // -A_log.exp()
+ SSMNorm *nn.RMSNorm `gguf:"ssm_norm"`
+ SSMOut *nn.Linear `gguf:"ssm_out"`
+
+ // Layer index for cache access (set during model construction)
+ Layer int
+}
+
+// createMasks builds the constant mask tensors (called once, reused for all chunks)
+func createMasks(ctx ml.Context) *Masks {
+ ones := ctx.Input().Zeros(ml.DTypeF32, chunkSize, chunkSize)
+ ones = ones.Fill(ctx, 1.0)
+ causalMask := ones.Tri(ctx, TriTypeLower)
+
+ onesVec := ctx.Input().Zeros(ml.DTypeF32, chunkSize)
+ onesVec = onesVec.Fill(ctx, 1.0)
+ identity := onesVec.Diag(ctx)
+
+ diagMask := causalMask.Add(ctx, identity)
+
+ return &Masks{
+ Causal: causalMask,
+ Identity: identity,
+ Diag: diagMask,
+ }
+}
+
+func (gdn *GatedDeltaNet) Forward(ctx ml.Context, hiddenStates, _ ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
+ layer := gdn.Layer
+ nSeqTokens := hiddenStates.Dim(1)
+ nSeqs := hiddenStates.Dim(2)
+ if cache != nil && cache.IsSupportedForBatch() {
+ seqTokens := cache.seqTokens()
+ seqs := cache.numSeqs()
+ if seqTokens > 0 && seqs > 0 {
+ if nSeqs > 1 {
+ if nSeqTokens != seqTokens || nSeqs != seqs {
+ return nil, ErrUnsupportedBatchLayout
+ }
+ } else {
+ if nSeqTokens != seqTokens*seqs {
+ return nil, ErrUnsupportedBatchLayout
+ }
+ hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), seqTokens, seqs)
+ nSeqTokens = seqTokens
+ nSeqs = seqs
+ }
+ }
+ }
+
+ headKDim := opts.ssmDState
+ numKHeads := opts.ssmNGroup
+ numVHeads := opts.ssmDtRank
+ headVDim := opts.ssmDInner / numVHeads
+ convKernelSize := opts.convKernelSize
+
+ mixedBA := gdn.SSMBetaAlpha.Forward(ctx, hiddenStates)
+ qkvDim := headKDim*numKHeads*2 + headVDim*numVHeads
+
+ if gdn.SSMQKV == nil || gdn.SSMQKVGate == nil {
+ return nil, errors.New("qwen3next: missing attn_qkv/attn_gate projections (legacy ssm_in is not supported)")
+ }
+ // Optimized path: pre-split QKV and gate
+ qkvMixed := gdn.SSMQKV.Forward(ctx, hiddenStates).Reshape(ctx, qkvDim, nSeqTokens, nSeqs)
+ z := gdn.SSMQKVGate.Forward(ctx, hiddenStates)
+
+ baNewDim := 2 * numVHeads / numKHeads
+ mixedBAReshaped := mixedBA.Reshape(ctx, baNewDim, numKHeads, nSeqTokens, nSeqs)
+
+ // Split beta and alpha
+ betaSize := numVHeads / numKHeads
+ alphaSize := numVHeads / numKHeads
+
+ b := mixedBAReshaped.Slice(ctx, 0, 0, betaSize, 1)
+ a := mixedBAReshaped.Slice(ctx, 0, betaSize, betaSize+alphaSize, 1)
+
+ // Reshape to merge head dimensions
+ beta := b.Contiguous(ctx, numVHeads, 1, nSeqTokens, nSeqs)
+ alpha := a.Contiguous(ctx, numVHeads, nSeqTokens, nSeqs)
+
+ // Compute gate: softplus(alpha + dt_bias) * -A
+ alphaBiased := alpha.Add(ctx, gdn.SSMDT)
+ alphaSoftplus := alphaBiased.Softplus(ctx)
+ gate := alphaSoftplus.Mul(ctx, gdn.SSMA)
+ qkvMixed = qkvMixed.Permute(ctx, 1, 0, 2, 3)
+
+ // Get conv state from cache
+ convStates, err := cache.ConvState(ctx, layer)
+ if err != nil {
+ // Log this - if it happens, short-term context will be lost
+ slog.Warn("qwen3next: failed to get conv state, using zeros", "layer", layer, "error", err)
+ convStates = ctx.Input().Zeros(ml.DTypeF32, convKernelSize-1, qkvDim, nSeqs)
+ }
+
+ // Reshape conv states
+ convStates = convStates.Reshape(ctx, convKernelSize-1, qkvDim, nSeqs)
+
+ // Concatenate with input for convolution
+ convInput := convStates.Concat(ctx, qkvMixed, 0)
+
+ // Save new conv state (last convKernelSize-1 tokens)
+ lastConvStates := convInput.Slice(ctx, 0, nSeqTokens, nSeqTokens+convKernelSize-1, 1)
+ cache.UpdateConvState(ctx, layer, lastConvStates)
+
+ // Apply SSM convolution (kernel must be F32 for Metal)
+ convOutput := convInput.SSMConv(ctx, gdn.SSMConv1D.Weight)
+ convOutput = convOutput.SILU(ctx)
+
+ // Reshape for extraction
+ convQKVMix := convOutput.Contiguous(ctx, qkvDim, nSeqTokens*nSeqs)
+
+ // Extract convolved Q, K, V
+ qConv := convQKVMix.Slice(ctx, 0, 0, headKDim*numKHeads, 1)
+ kConv := convQKVMix.Slice(ctx, 0, headKDim*numKHeads, 2*headKDim*numKHeads, 1)
+ vConv := convQKVMix.Slice(ctx, 0, 2*headKDim*numKHeads, qkvDim, 1)
+
+ // Reshape to 4D
+ qConv = qConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
+ kConv = kConv.Contiguous(ctx, headKDim, numKHeads, nSeqTokens, nSeqs)
+ vConv = vConv.Contiguous(ctx, headVDim, numVHeads, nSeqTokens, nSeqs)
+
+ // Get delta state from cache
+ state, err := cache.DeltaState(ctx, layer, headVDim, numVHeads)
+ if err != nil {
+ // Log this - if it happens frequently, context will degrade
+ slog.Warn("qwen3next: failed to get delta state, using zeros", "layer", layer, "error", err)
+ state = ctx.Input().Zeros(ml.DTypeF32, headVDim, headVDim*numVHeads, nSeqs)
+ }
+ state = state.Reshape(ctx, headVDim, headVDim*numVHeads, 1, nSeqs)
+
+ // Repeat interleave Q and K if numKHeads != numVHeads
+ if numKHeads != numVHeads {
+ repeatFactor := numVHeads / numKHeads
+
+ qReshaped := qConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
+ kReshaped := kConv.Reshape(ctx, headKDim, 1, numKHeads*nSeqTokens*nSeqs)
+
+ qRepeated := qReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
+ kRepeated := kReshaped.Repeat4D(ctx, headKDim, repeatFactor, numKHeads*nSeqTokens*nSeqs, 1)
+
+ qConv = qRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
+ kConv = kRepeated.Reshape(ctx, headKDim, numKHeads*repeatFactor, nSeqTokens, nSeqs)
+ }
+
+ // Choose computation mode based on sequence length
+ var attnOut ml.Tensor
+ if nSeqTokens == 1 {
+ attnOut = gdn.deltaNetAutoregressive(ctx, qConv, kConv, vConv, gate, beta, state, opts, layer, cache)
+ } else {
+ // Use pre-computed masks from opts (created once in Model.Forward)
+ attnOut = gdn.deltaNetChunked(ctx, qConv, kConv, vConv, gate, beta, state, opts.masks, opts, layer, cache)
+ }
+
+ // Apply gated normalization
+ attnOut2D := attnOut.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
+ z2D := z.Contiguous(ctx, headVDim, numVHeads*nSeqTokens*nSeqs)
+
+ // norm(attnOut, z) = RMSNorm(attnOut) * silu(z)
+ attnOutNorm := gdn.SSMNorm.Forward(ctx, attnOut2D, opts.eps)
+ zSilu := z2D.SILU(ctx)
+ attnOutGated := attnOutNorm.Mul(ctx, zSilu)
+
+ // Reshape for output projection
+ finalOutput := attnOutGated.Reshape(ctx, headVDim*numVHeads, nSeqTokens, nSeqs)
+
+ out := gdn.SSMOut.Forward(ctx, finalOutput)
+ return out.Reshape(ctx, out.Dim(0), nSeqTokens*nSeqs), nil
+}
+
+// deltaNetAutoregressive implements single-token state update.
+// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
+func (gdn *GatedDeltaNet) deltaNetAutoregressive(
+ ctx ml.Context,
+ q, k, v, gate, beta, state ml.Tensor,
+ opts *Options,
+ layer int,
+ cache *HybridCache,
+) ml.Tensor {
+ numVHeads := v.Dim(1)
+ headVDim := v.Dim(0)
+ nSeqs := q.Dim(3)
+
+ // L2 normalize Q and K
+ q = q.L2Norm(ctx, opts.eps)
+ k = k.L2Norm(ctx, opts.eps)
+
+ // Scale Q
+ scale := 1.0 / math.Sqrt(float64(headVDim))
+ q = q.Scale(ctx, scale)
+
+ // Sigmoid beta
+ beta = beta.Sigmoid(ctx)
+
+ // Reshape state: [headVDim, headVDim, numVHeads, nSeqs]
+ state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
+
+ // Reshape gate and beta for broadcasting
+ gT := gate.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
+ betaT := beta.Permute(ctx, 1, 0, 2, 3).Reshape(ctx, 1, 1, numVHeads, nSeqs)
+
+ // Apply exponential to gate
+ gT = gT.Exp(ctx)
+
+ // state = state * g_t
+ state = state.Mul(ctx, gT)
+
+ // kv_mem = (state * k_t.unsqueeze(-1)).sum(dim=-2)
+ kTUnsqueezed := k.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
+ kvMem := state.Mul(ctx, kTUnsqueezed)
+ // Sum over dim=-2 (second dimension after permute)
+ kvMem = kvMem.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ kvMem = kvMem.SumRows(ctx)
+ kvMem = kvMem.Permute(ctx, 1, 0, 2, 3)
+
+ // v_t with singleton dimension
+ vT := v.Reshape(ctx, headVDim, 1, numVHeads, nSeqs)
+
+ // delta = (v_t - kv_mem) * beta_t
+ vDiff := vT.Sub(ctx, kvMem)
+ delta := vDiff.Mul(ctx, betaT)
+
+ // state = state + k_t.unsqueeze(-1) * delta
+ kTUnsqueezedBroad := kTUnsqueezed.Repeat4D(ctx, headVDim, headVDim, numVHeads, nSeqs)
+ kTDelta := kTUnsqueezedBroad.Mul(ctx, delta)
+ state = state.Add(ctx, kTDelta)
+
+ // core_attn_out = (state * q_t.unsqueeze(-1)).sum(dim=-2)
+ qTUnsqueezed := q.Reshape(ctx, 1, headVDim, numVHeads, nSeqs)
+ stateQ := state.Mul(ctx, qTUnsqueezed)
+ stateQ = stateQ.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ coreAttnOut := stateQ.SumRows(ctx)
+ coreAttnOut = coreAttnOut.Permute(ctx, 1, 0, 2, 3)
+
+ // Update delta state in cache
+ cache.UpdateDeltaState(ctx, layer, state.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
+
+ return coreAttnOut.Reshape(ctx, headVDim, numVHeads, 1, nSeqs)
+}
+
+// deltaNetChunked implements chunked computation for prefill.
+// NOTE: Assumes headKDim == headVDim (state shape is [headVDim, headVDim, numVHeads, nSeqs]).
+func (gdn *GatedDeltaNet) deltaNetChunked(
+ ctx ml.Context,
+ q, k, v, gate, beta, state ml.Tensor,
+ masks *Masks,
+ opts *Options,
+ layer int,
+ cache *HybridCache,
+) ml.Tensor {
+ headKDim := q.Dim(0)
+ numVHeads := v.Dim(1)
+ headVDim := v.Dim(0)
+ nTokens := q.Dim(2)
+ nSeqs := q.Dim(3)
+
+ // L2 normalize Q and K
+ q = q.L2Norm(ctx, opts.eps)
+ k = k.L2Norm(ctx, opts.eps)
+
+ // Scale Q
+ scale := 1.0 / math.Sqrt(float64(headVDim))
+ q = q.Scale(ctx, scale)
+
+ // Sigmoid beta
+ beta = beta.Sigmoid(ctx)
+
+ // Permute tensors for chunked computation
+ q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
+ k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headKDim, nTokens, numVHeads, nSeqs)
+ v = v.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, nTokens, numVHeads, nSeqs)
+ gate = gate.Permute(ctx, 2, 0, 3, 1).Contiguous(ctx, nTokens, 1, numVHeads, nSeqs)
+
+ beta = beta.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
+ state = state.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs)
+
+ // Compute padding
+ pad := (chunkSize - nTokens%chunkSize) % chunkSize
+ nChunks := (nTokens + pad) / chunkSize
+
+ // Pad tensors
+ if pad > 0 {
+ q = q.Pad(ctx, 0, pad, 0, 0)
+ k = k.Pad(ctx, 0, pad, 0, 0)
+ v = v.Pad(ctx, 0, pad, 0, 0)
+ gate = gate.Pad(ctx, pad, 0, 0, 0)
+ beta = beta.Pad(ctx, 0, pad, 0, 0)
+ }
+
+ // Use pre-computed masks (passed in, not recreated)
+ causalMask := masks.Causal
+ identity := masks.Identity
+ diagMask := masks.Diag
+ identity4D := identity.Reshape(ctx, chunkSize, chunkSize, 1, 1)
+
+ // v_beta = v * beta, k_beta = k * beta
+ vBeta := v.Mul(ctx, beta)
+ kBeta := k.Mul(ctx, beta)
+
+ // Reshape for chunked computation
+ q = q.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
+ k = k.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
+ kBeta = kBeta.Reshape(ctx, headKDim, chunkSize, nChunks, numVHeads*nSeqs)
+ vBeta = vBeta.Reshape(ctx, headVDim, chunkSize, nChunks, numVHeads*nSeqs)
+
+ gate = gate.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
+
+ // g_cumsum = cumsum(gate)
+ gCumsum := gate.CumSum(ctx)
+
+ // Compute decay mask
+ gcsI := gCumsum.Reshape(ctx, chunkSize, 1, nChunks, numVHeads*nSeqs)
+ gcsJ := gCumsum.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
+ gcsBroadcast := gcsJ.Repeat4D(ctx, chunkSize, chunkSize, nChunks, numVHeads*nSeqs)
+ decayMask := gcsBroadcast.Sub(ctx, gcsI)
+
+ decayMask = decayMask.Mul(ctx, diagMask)
+ decayMask = decayMask.Exp(ctx)
+ decayMask = decayMask.Mul(ctx, diagMask)
+
+ // k @ k_beta^T
+ kMulKBeta := k.Mulmat(ctx, kBeta)
+
+ // k_decay = k @ k_beta^T * decay_mask
+ kDecay := kMulKBeta.Mul(ctx, decayMask)
+
+ // attn = -k_decay * causal_mask
+ attn := kDecay.Neg(ctx).Mul(ctx, causalMask)
+
+ // Triangular solve: (I - attn_lower)^-1 @ attn
+ attnLower := attn.Mul(ctx, causalMask)
+ lhs := attnLower.Neg(ctx).Add(ctx, identity4D)
+ linSolve := lhs.SolveTri(ctx, attn, true, true, false)
+ attn = linSolve.Mul(ctx, causalMask)
+ attn = attn.Add(ctx, identity4D)
+
+ // v = v_beta^T @ attn
+ vBetaT := vBeta.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ v = vBetaT.Mulmat(ctx, attn)
+
+ // Compute g_exp for state update
+ gCumsumT := gCumsum.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ gExp := gCumsumT.Exp(ctx)
+
+ // kbeta_gexp = k_beta * g_exp
+ kBetaGExp := kBeta.Mul(ctx, gExp)
+
+ // k_cumdecay = attn @ kbeta_gexp^T
+ kBetaGExpT := kBetaGExp.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+ kCumdecay := attn.Mulmat(ctx, kBetaGExpT)
+ kCumdecay = kCumdecay.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+ // Pre-compute attn_kq = (k @ q) * decay_mask * diag_mask
+ attnKQ := k.Mulmat(ctx, q)
+ attnKQ = attnKQ.Mul(ctx, decayMask)
+ attnKQ = attnKQ.Mul(ctx, diagMask)
+
+ // Pre-compute g_last and key_gdiff
+ // g_last = view of last element in g_cumsum along chunk_size dimension
+ // We need to get the last row of gCumsum: shape [chunkSize, 1, nChunks, H*n_seqs] -> [1, 1, nChunks, H*n_seqs]
+ gLast := gCumsum.Slice(ctx, 0, chunkSize-1, chunkSize, 1).Contiguous(ctx, 1, 1, nChunks, numVHeads*nSeqs)
+ gLastExp := gLast.Exp(ctx)
+
+ // g_diff = -(g_cumsum - g_last) = g_last - g_cumsum
+ gDiff := gCumsum.Neg(ctx).Add(ctx, gLast)
+ gDiffExp := gDiff.Exp(ctx)
+
+ // Reshapes g_diff_exp to [1, chunkSize, nChunks, ...]
+ gDiffExpReshaped := gDiffExp.Reshape(ctx, 1, chunkSize, nChunks, numVHeads*nSeqs)
+ keyGDiff := k.Mul(ctx, gDiffExpReshaped)
+ keyGDiffT := keyGDiff.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+ // Process chunks and update state
+ var coreAttnOut ml.Tensor
+ newState := state
+
+ for chunk := range nChunks {
+ qChunk := q.Slice(ctx, 2, chunk, chunk+1, 1)
+ vChunk := v.Slice(ctx, 2, chunk, chunk+1, 1)
+ gExpChunk := gExp.Slice(ctx, 2, chunk, chunk+1, 1)
+ kCumdecayChunk := kCumdecay.Slice(ctx, 2, chunk, chunk+1, 1)
+ attnChunk := attnKQ.Slice(ctx, 2, chunk, chunk+1, 1) // Pre-computed!
+
+ // state^T - permute is needed but Contiguous creates a copy
+ stateT := newState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx, headVDim, headVDim, 1, numVHeads*nSeqs)
+
+ // v_prime = k_cumdecay @ state
+ vPrime := stateT.Mulmat(ctx, kCumdecayChunk)
+
+ // v_new = v - v_prime
+ vNew := vChunk.Sub(ctx, vPrime)
+ vNewT := vNew.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+
+ // attn_inter = (q * g_exp) @ state
+ qGExp := qChunk.Mul(ctx, gExpChunk)
+ attnInter := stateT.Mulmat(ctx, qGExp)
+
+ // core_attn_out = attn_inter + attn @ v_new
+ vAttn := vNewT.Mulmat(ctx, attnChunk)
+ coreAttnOutChunk := attnInter.Add(ctx, vAttn)
+
+ if coreAttnOut == nil {
+ coreAttnOut = coreAttnOutChunk
+ } else {
+ coreAttnOut = coreAttnOut.Concat(ctx, coreAttnOutChunk, 1)
+ }
+
+ // Update state for next chunk
+ gExpLastChunk := gLastExp.Slice(ctx, 2, chunk, chunk+1, 1)
+ kGDiffChunkT := keyGDiffT.Slice(ctx, 2, chunk, chunk+1, 1)
+ kgdMulVNew := vNewT.Mulmat(ctx, kGDiffChunkT)
+
+ // state = state * g_last + kgdmulvnew
+ gExpLastReshaped := gExpLastChunk.Contiguous(ctx).Reshape(ctx, 1, 1, numVHeads, nSeqs)
+ newState = newState.Mul(ctx, gExpLastReshaped)
+ newState = newState.Add(ctx, kgdMulVNew.Reshape(ctx, headVDim, headVDim, numVHeads, nSeqs))
+ }
+
+ // Final reshape
+ coreAttnOut = coreAttnOut.Contiguous(ctx, headVDim, chunkSize*nChunks, numVHeads, nSeqs)
+
+ // Slice to remove padding
+ if pad > 0 {
+ coreAttnOut = coreAttnOut.Slice(ctx, 1, 0, nTokens, 1)
+ }
+
+ // Update delta state in cache
+ cache.UpdateDeltaState(ctx, layer, newState.Reshape(ctx, headVDim, headVDim*numVHeads, nSeqs))
+
+ return coreAttnOut.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, headVDim, numVHeads, nTokens, nSeqs)
+}
diff --git a/model/models/qwen3next/model.go b/model/models/qwen3next/model.go
new file mode 100644
index 00000000000..f3515e6994e
--- /dev/null
+++ b/model/models/qwen3next/model.go
@@ -0,0 +1,384 @@
+package qwen3next
+
+import (
+ "cmp"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/fs"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/ml/nn"
+ "github.com/ollama/ollama/ml/nn/rope"
+ "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
+)
+
+// Options contains model configuration
+type Options struct {
+ hiddenSize int
+ numHeads int
+ numKVHeads int
+ keyLength int
+ valueLength int
+ ropeDim int
+
+ eps float32
+ ropeBase float32
+ ropeScale float32
+ ropeType string
+ originalContextLength int
+ attentionScale float64
+
+ // MoE config
+ numExperts int
+ numExpertsUsed int
+ normTopKProb bool
+
+ // Linear attention (Gated Delta Net) config
+ ssmDInner int // d_inner = head_v_dim * num_v_heads
+ ssmDState int // head_k_dim
+ ssmNGroup int // num_k_heads
+ ssmDtRank int // num_v_heads
+ convKernelSize int // SSM conv kernel size
+
+ // Per-layer type from GGUF metadata
+ isRecurrent []bool
+
+ // Pre-computed masks for chunked attention (created once per forward pass)
+ masks *Masks
+}
+
+func (o Options) headDim() int {
+ return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
+}
+
+func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
+ opts := []func(*rope.Options){rope.WithTypeNeoX()}
+ if o.ropeType == "yarn" {
+ attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
+ opts = append(opts,
+ rope.WithOriginalContextLength(o.originalContextLength),
+ rope.WithExtrapolationFactor(1.),
+ rope.WithAttentionFactor(attnFactor),
+ )
+ }
+ ropeDim := cmp.Or(o.ropeDim, o.headDim())
+ return nn.RoPE(ctx, states, positions, ropeDim, o.ropeBase, 1./o.ropeScale, opts...)
+}
+
+// Operator is the interface for attention-like operators
+type Operator interface {
+ Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error)
+}
+
+// MLP is the interface for feedforward networks
+type MLP interface {
+ Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor
+}
+
+// sparse implements MoE with shared experts
+type sparse struct {
+ Router *nn.Linear `gguf:"ffn_gate_inp"`
+ Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+ Up *nn.LinearBatch `gguf:"ffn_up_exps"`
+ Down *nn.LinearBatch `gguf:"ffn_down_exps"`
+
+ // Shared experts
+ SharedGateInp *nn.Linear `gguf:"ffn_gate_inp_shexp"`
+ SharedGate *nn.Linear `gguf:"ffn_gate_shexp"`
+ SharedUp *nn.Linear `gguf:"ffn_up_shexp"`
+ SharedDown *nn.Linear `gguf:"ffn_down_shexp"`
+}
+
+func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+ hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
+ if batchSize == 0 {
+ batchSize = 1
+ }
+ hiddenStates2D := hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
+
+ // Router logits
+ routerLogits := mlp.Router.Forward(ctx, hiddenStates2D)
+
+ // Softmax routing weights
+ routingWeights := routerLogits.Softmax(ctx)
+ selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
+ routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates2D.Dim(1)).Rows(ctx, selectedExperts)
+ if opts.normTopKProb {
+ routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates2D.Dim(1))
+ routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
+ routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates2D.Dim(1))
+ }
+
+ hiddenStates3D := hiddenStates2D.Reshape(ctx, hiddenStates2D.Dim(0), 1, hiddenStates2D.Dim(1))
+
+ // Expert computation with SILU activation
+ gateOut := mlp.Gate.Forward(ctx, hiddenStates3D, selectedExperts)
+ upOut := mlp.Up.Forward(ctx, hiddenStates3D, selectedExperts)
+ experts := gateOut.SILU(ctx, upOut)
+ experts = mlp.Down.Forward(ctx, experts, selectedExperts)
+ experts = experts.Mul(ctx, routingWeights)
+
+ // Sum over experts
+ moeOut := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
+ for i := 1; i < opts.numExpertsUsed; i++ {
+ moeOut = moeOut.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
+ }
+
+ // Add shared experts if present
+ if mlp.SharedUp != nil {
+ sharedGate := mlp.SharedGate.Forward(ctx, hiddenStates2D)
+ sharedUp := mlp.SharedUp.Forward(ctx, hiddenStates2D)
+ sharedOut := sharedGate.SILU(ctx, sharedUp)
+ sharedOut = mlp.SharedDown.Forward(ctx, sharedOut)
+
+ // Apply shared expert gating
+ if mlp.SharedGateInp != nil {
+ sharedGateVal := mlp.SharedGateInp.Forward(ctx, hiddenStates2D)
+ sharedGateVal = sharedGateVal.SigmoidOut(ctx)
+ // Broadcast gate to match dimensions
+ sharedGateVal = sharedGateVal.Repeat(ctx, 0, sharedOut.Dim(0))
+ sharedOut = sharedOut.Mul(ctx, sharedGateVal)
+ }
+
+ moeOut = moeOut.Add(ctx, sharedOut)
+ }
+
+ return moeOut
+}
+
+// dense implements standard feedforward
+type dense struct {
+ Gate *nn.Linear `gguf:"ffn_gate"`
+ Up *nn.Linear `gguf:"ffn_up"`
+ Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
+ hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
+ return mlp.Down.Forward(ctx, hiddenStates)
+}
+
+// Layer represents a single transformer layer
+type Layer struct {
+ AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
+ AttentionPostNorm *nn.RMSNorm `gguf:"post_attention_norm"` // Post-attention norm before FFN
+ Operator Operator
+
+ FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
+ MLP MLP
+}
+
+func (l *Layer) Forward(ctx ml.Context, layer int, hiddenStates, positions, outputs ml.Tensor, cache *HybridCache, opts *Options) (ml.Tensor, error) {
+ residual := hiddenStates
+
+ // Pre-attention norm
+ hiddenStates = l.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ // Attention (full or linear)
+ var err error
+ hiddenStates, err = l.Operator.Forward(ctx, hiddenStates, positions, cache, opts)
+ if err != nil {
+ return nil, err
+ }
+
+ // Output projection for last layer
+ if outputs != nil {
+ hiddenStates = hiddenStates.Rows(ctx, outputs)
+ residual = residual.Rows(ctx, outputs)
+ }
+
+ // First residual connection
+ hiddenStates = hiddenStates.Add(ctx, residual)
+
+ // Save for FFN residual
+ ffnResidual := hiddenStates
+
+ // Post-attention norm (before FFN)
+ hiddenStates = l.AttentionPostNorm.Forward(ctx, hiddenStates, opts.eps)
+
+ // FFN
+ hiddenStates = l.MLP.Forward(ctx, hiddenStates, opts)
+
+ // Second residual connection
+ return hiddenStates.Add(ctx, ffnResidual), nil
+}
+
+// Model is the main Qwen3-Next model
+type Model struct {
+ model.Base
+ tokenizer.Tokenizer
+
+ TokenEmbedding *nn.Embedding `gguf:"token_embd"`
+ OutputNorm *nn.RMSNorm `gguf:"output_norm"`
+ Output *nn.Linear `gguf:"output,alt:token_embd"`
+
+ Layers []Layer `gguf:"blk"`
+
+ *Options
+}
+
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+ positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
+
+ hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+
+ cache := m.Cache.(*HybridCache)
+
+ // Create masks once per forward pass
+ m.Options.masks = createMasks(ctx)
+
+ for i, layer := range m.Layers {
+ cache.SetLayer(i)
+
+ var outputs ml.Tensor
+ if i == len(m.Layers)-1 {
+ outputs = batch.Outputs
+ }
+
+ var err error
+ hiddenStates, err = layer.Forward(ctx, i, hiddenStates, positions, outputs, cache, m.Options)
+ if err != nil {
+ return nil, err
+ }
+ }
+
+ hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
+ return m.Output.Forward(ctx, hiddenStates), nil
+}
+
+func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
+ return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
+}
+
+var _ model.Model = (*Model)(nil)
+
+func New(c fs.Config) (model.Model, error) {
+ numLayers := int(c.Uint("block_count"))
+ layers := make([]Layer, numLayers)
+
+ // Get per-layer head counts (for detecting layer type)
+ type headCounts interface {
+ HeadCount() []uint64
+ HeadCountKV() []uint64
+ }
+
+ var isRecurrent []bool
+ var headCountKV []uint64
+ if hc, ok := c.(headCounts); ok {
+ headCountKV = hc.HeadCountKV()
+ }
+
+ isRecurrent = make([]bool, numLayers)
+ hasZero := false
+ hasFull := false
+ for i := range numLayers {
+ // If KV head count is 0, it's a recurrent layer
+ if i < len(headCountKV) && headCountKV[i] == 0 {
+ isRecurrent[i] = true
+ hasZero = true
+ } else if i < len(headCountKV) && headCountKV[i] > 0 {
+ hasFull = true
+ }
+ }
+ if !hasZero || !hasFull {
+ return nil, fmt.Errorf("qwen3next: invalid attention.head_count_kv array; expected mix of zero and non-zero values")
+ }
+
+ // Determine if MoE
+ isMoE := c.Uint("expert_count") > 0
+
+ for i := range layers {
+ if isRecurrent[i] {
+ layers[i].Operator = &GatedDeltaNet{Layer: i}
+ } else {
+ layers[i].Operator = &FullAttention{}
+ }
+
+ if isMoE {
+ layers[i].MLP = &sparse{}
+ } else {
+ layers[i].MLP = &dense{}
+ }
+ }
+
+ opts := &Options{
+ hiddenSize: int(c.Uint("embedding_length")),
+ numHeads: int(c.Uint("attention.head_count")),
+ numKVHeads: func() int {
+ for _, v := range headCountKV {
+ if v > 0 {
+ return int(v)
+ }
+ }
+ return 0
+ }(),
+ keyLength: int(c.Uint("attention.key_length")),
+ valueLength: int(c.Uint("attention.value_length")),
+ ropeDim: int(c.Uint("rope.dimension_count")),
+ eps: c.Float("attention.layer_norm_rms_epsilon"),
+ ropeType: c.String("rope.scaling.type"),
+ ropeBase: c.Float("rope.freq_base"),
+ ropeScale: c.Float("rope.scaling.factor", 1),
+ originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
+ attentionScale: float64(c.Float("attention.scale")),
+ numExperts: int(c.Uint("expert_count")),
+ numExpertsUsed: int(c.Uint("expert_used_count")),
+ normTopKProb: c.Bool("norm_top_k_prob", true),
+ ssmDInner: int(c.Uint("ssm.inner_size")),
+ ssmDState: int(c.Uint("ssm.state_size")),
+ ssmNGroup: int(c.Uint("ssm.group_count")),
+ ssmDtRank: int(c.Uint("ssm.time_step_rank")),
+ convKernelSize: int(c.Uint("ssm.conv_kernel")),
+ isRecurrent: isRecurrent,
+ }
+ if opts.numKVHeads == 0 {
+ return nil, fmt.Errorf("qwen3next: attention.head_count_kv array must include at least one non-zero value")
+ }
+
+ // Calculate cache dimensions
+ convDim := max(0, opts.convKernelSize-1)
+ convChannels := opts.ssmDInner + 2*opts.ssmNGroup*opts.ssmDState
+ headVDim := 0
+ numVHeads := opts.ssmDtRank
+ if numVHeads > 0 {
+ headVDim = opts.ssmDInner / numVHeads
+ }
+ deltaStateSize := headVDim * headVDim * numVHeads
+
+ // Validate dimension assumption: headKDim == headVDim is required for state computations
+ headKDim := opts.ssmDState
+ if headKDim != headVDim && headKDim > 0 && headVDim > 0 {
+ return nil, fmt.Errorf("qwen3next: headKDim (%d) != headVDim (%d) not supported; state computations require equal dimensions", headKDim, headVDim)
+ }
+
+ m := Model{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
+ Values: c.Strings("tokenizer.ggml.tokens"),
+ Types: c.Ints("tokenizer.ggml.token_type"),
+ Merges: c.Strings("tokenizer.ggml.merges"),
+ // Qwen3 tokenizers typically set add_bos_token=false and bos_token=null.
+ // Default to false when the GGUF key is missing to avoid injecting a spurious BOS.
+ AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
+ BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
+ AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
+ EOS: append(
+ []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
+ c.Ints("tokenizer.ggml.eos_token_ids")...,
+ ),
+ },
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+ ),
+ Layers: layers,
+ Options: opts,
+ }
+
+ m.Cache = NewHybridCache(m.Shift, convDim, convChannels, deltaStateSize)
+ return &m, nil
+}
+
+func init() {
+ model.Register("qwen3next", New)
+}
diff --git a/model/models/qwen3vl/model.go b/model/models/qwen3vl/model.go
index cb1ce8d2c6b..740c548ff40 100644
--- a/model/models/qwen3vl/model.go
+++ b/model/models/qwen3vl/model.go
@@ -10,11 +10,12 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
+ "github.com/ollama/ollama/tokenizer"
)
type Model struct {
model.Base
- model.TextProcessor
+ tokenizer.Tokenizer
*TextModel
*VisionModel `gguf:"v"`
@@ -172,8 +173,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
- TextProcessor: model.NewBytePairEncoding(
- &model.Vocabulary{
+ Tokenizer: tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
diff --git a/model/parsers/glmocr.go b/model/parsers/glmocr.go
new file mode 100644
index 00000000000..671ba939cd3
--- /dev/null
+++ b/model/parsers/glmocr.go
@@ -0,0 +1,17 @@
+package parsers
+
+import "github.com/ollama/ollama/api"
+
+// GlmOcrParser is the GLM46 parser with thinking disabled.
+type GlmOcrParser struct {
+ GLM46Parser
+}
+
+func (p *GlmOcrParser) HasThinkingSupport() bool {
+ return false
+}
+
+func (p *GlmOcrParser) Init(tools []api.Tool, _ *api.Message, _ *api.ThinkValue) []api.Tool {
+ p.tools = tools
+ return tools
+}
diff --git a/model/parsers/lfm2.go b/model/parsers/lfm2.go
new file mode 100644
index 00000000000..4aade692670
--- /dev/null
+++ b/model/parsers/lfm2.go
@@ -0,0 +1,498 @@
+package parsers
+
+import (
+ "encoding/json"
+ "errors"
+ "log/slog"
+ "strconv"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/api"
+)
+
+type LFM2ParserState int
+
+const (
+ LFM2CollectingThinking LFM2ParserState = iota
+ LFM2CollectingContent
+ LFM2CollectingToolCalls
+)
+
+const (
+ lfm2ThinkingOpenTag = ""
+ lfm2ThinkingCloseTag = ""
+ lfm2ToolCallStartTag = "<|tool_call_start|>"
+ lfm2ToolCallEndTag = "<|tool_call_end|>"
+)
+
+type LFM2Parser struct {
+ state LFM2ParserState
+ buffer strings.Builder
+ hasThinkingSupport bool
+ needsThinkingLeadingTrim bool // trim leading whitespace after tag
+ needsContentLeadingTrim bool // trim leading whitespace after tag
+}
+
+func (p *LFM2Parser) HasToolSupport() bool {
+ return true
+}
+
+func (p *LFM2Parser) HasThinkingSupport() bool {
+ return p.hasThinkingSupport
+}
+
+func (p *LFM2Parser) setInitialState(lastMessage *api.Message, thinkValue *api.ThinkValue) {
+ prefill := lastMessage != nil && lastMessage.Role == "assistant"
+
+ // Check both model capability AND request preference
+ thinkingEnabled := p.HasThinkingSupport() && (thinkValue != nil && thinkValue.Bool())
+
+ if !thinkingEnabled {
+ p.state = LFM2CollectingContent
+ return
+ }
+
+ if prefill && lastMessage.Content != "" {
+ p.state = LFM2CollectingContent
+ return
+ }
+
+ p.state = LFM2CollectingThinking
+ p.needsThinkingLeadingTrim = true
+}
+
+func (p *LFM2Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+ p.setInitialState(lastMessage, thinkValue)
+ return tools
+}
+
+type lfm2Event interface {
+ isLFM2Event()
+}
+
+type lfm2EventThinkingContent struct {
+ content string
+}
+
+type lfm2EventContent struct {
+ content string
+}
+
+type lfm2EventToolCall struct {
+ toolCall api.ToolCall
+}
+
+func (lfm2EventThinkingContent) isLFM2Event() {}
+func (lfm2EventContent) isLFM2Event() {}
+func (lfm2EventToolCall) isLFM2Event() {}
+
+func (p *LFM2Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.buffer.WriteString(s)
+ events := p.parseEvents()
+
+ var toolCalls []api.ToolCall
+ var contentSb strings.Builder
+ var thinkingSb strings.Builder
+ for _, event := range events {
+ switch event := event.(type) {
+ case lfm2EventToolCall:
+ toolCalls = append(toolCalls, event.toolCall)
+ case lfm2EventThinkingContent:
+ thinkingSb.WriteString(event.content)
+ case lfm2EventContent:
+ contentSb.WriteString(event.content)
+ }
+ }
+
+ return contentSb.String(), thinkingSb.String(), toolCalls, nil
+}
+
+func (p *LFM2Parser) parseEvents() []lfm2Event {
+ var all []lfm2Event
+
+ keepLooping := true
+ for keepLooping {
+ var events []lfm2Event
+ events, keepLooping = p.eat()
+ if len(events) > 0 {
+ all = append(all, events...)
+ }
+ }
+
+ return all
+}
+
+func (p *LFM2Parser) eat() ([]lfm2Event, bool) {
+ var events []lfm2Event
+ bufStr := p.buffer.String()
+ if bufStr == "" {
+ return events, false
+ }
+
+ switch p.state {
+ case LFM2CollectingThinking:
+ // Strip opening tag if present
+ if strings.HasPrefix(bufStr, lfm2ThinkingOpenTag) {
+ bufStr = bufStr[len(lfm2ThinkingOpenTag):]
+ p.needsThinkingLeadingTrim = true
+ p.buffer.Reset()
+ p.buffer.WriteString(bufStr)
+ }
+
+ // Trim leading whitespace after tag (may span multiple chunks)
+ if p.needsThinkingLeadingTrim {
+ if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
+ bufStr = trimmed
+ p.buffer.Reset()
+ p.buffer.WriteString(bufStr)
+ }
+ // Clear flag once we have non-whitespace content or buffer is empty
+ if len(bufStr) > 0 {
+ p.needsThinkingLeadingTrim = false
+ }
+ }
+
+ if strings.Contains(bufStr, lfm2ThinkingCloseTag) { // thinking[] -> content
+ split := strings.SplitN(bufStr, lfm2ThinkingCloseTag, 2)
+ thinking := split[0]
+ thinking = strings.TrimRightFunc(thinking, unicode.IsSpace)
+
+ remaining := split[1]
+ remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
+
+ p.buffer.Reset()
+ p.buffer.WriteString(remaining)
+ p.state = LFM2CollectingContent
+ p.needsThinkingLeadingTrim = false
+ // Set flag to trim any additional whitespace that may arrive in later chunks
+ p.needsContentLeadingTrim = len(remaining) == 0
+
+ if len(thinking) > 0 {
+ events = append(events, lfm2EventThinkingContent{content: thinking})
+ }
+ return events, true
+ } else if overlapLen := overlap(bufStr, lfm2ThinkingCloseTag); overlapLen > 0 { // partial
+ beforePartialTag := bufStr[:len(bufStr)-overlapLen]
+ trailingLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingLen
+
+ unambiguous := bufStr[:ambiguousStart]
+ ambiguous := bufStr[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, lfm2EventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ } else { // otherwise its thinking content
+ whitespaceLen := trailingWhitespaceLen(bufStr)
+ ambiguousStart := len(bufStr) - whitespaceLen
+
+ unambiguous := bufStr[:ambiguousStart]
+ ambiguous := bufStr[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, lfm2EventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case LFM2CollectingContent:
+ // Trim leading whitespace after tag (may span multiple chunks)
+ if p.needsContentLeadingTrim {
+ if trimmed := strings.TrimLeftFunc(bufStr, unicode.IsSpace); trimmed != bufStr {
+ bufStr = trimmed
+ p.buffer.Reset()
+ p.buffer.WriteString(bufStr)
+ }
+ // Clear flag once we have non-whitespace content
+ if len(bufStr) > 0 {
+ p.needsContentLeadingTrim = false
+ }
+ }
+
+ if strings.Contains(bufStr, lfm2ToolCallStartTag) { // content[<|tool_call_start|>] -> tool calls
+ split := strings.SplitN(bufStr, lfm2ToolCallStartTag, 2)
+ contentBefore := strings.TrimRightFunc(split[0], unicode.IsSpace)
+ remaining := split[1]
+
+ p.buffer.Reset()
+ p.buffer.WriteString(remaining)
+ p.state = LFM2CollectingToolCalls
+
+ if len(contentBefore) > 0 {
+ events = append(events, lfm2EventContent{content: contentBefore})
+ }
+ return events, true
+ } else { // otherwise its content
+ p.buffer.Reset()
+ if len(bufStr) > 0 {
+ events = append(events, lfm2EventContent{content: bufStr})
+ }
+ return events, false
+ }
+
+ case LFM2CollectingToolCalls:
+ // Look for complete tool call JSON between tags
+ if idx := strings.Index(bufStr, lfm2ToolCallEndTag); idx != -1 {
+ toolCallContent := bufStr[:idx]
+
+ if toolCalls, err := p.parseToolCallsContent(toolCallContent); err == nil && len(toolCalls) > 0 {
+ remaining := bufStr[idx+len(lfm2ToolCallEndTag):]
+
+ // Check if there's another tool call
+ if strings.HasPrefix(remaining, lfm2ToolCallStartTag) {
+ remaining = remaining[len(lfm2ToolCallStartTag):]
+ } else {
+ // No more tool calls, go back to content
+ remaining = strings.TrimLeftFunc(remaining, unicode.IsSpace)
+ p.state = LFM2CollectingContent
+ }
+
+ p.buffer.Reset()
+ p.buffer.WriteString(remaining)
+
+ for _, tc := range toolCalls {
+ events = append(events, lfm2EventToolCall{toolCall: tc})
+ }
+ return events, true
+ } else if err != nil {
+ slog.Warn("lfm2 tool call parsing failed", "error", err, "content", toolCallContent)
+ }
+ }
+
+ return events, false
+ }
+
+ return events, false
+}
+
+// parseToolCallsContent parses one or more tool calls from content
+// Supports JSON format and Python-style format including multiple calls: [func1(...),func2(...)]
+func (p *LFM2Parser) parseToolCallsContent(content string) ([]api.ToolCall, error) {
+ content = strings.TrimSpace(content)
+
+ // Try JSON format first: {"name": "func", "arguments": {...}}
+ var parsed struct {
+ Name string `json:"name"`
+ Arguments json.RawMessage `json:"arguments"`
+ }
+
+ if err := json.Unmarshal([]byte(content), &parsed); err == nil && parsed.Name != "" {
+ var args api.ToolCallFunctionArguments
+ if len(parsed.Arguments) > 0 {
+ if err := json.Unmarshal(parsed.Arguments, &args); err != nil {
+ return nil, err
+ }
+ } else {
+ args = api.NewToolCallFunctionArguments()
+ }
+
+ return []api.ToolCall{{
+ Function: api.ToolCallFunction{
+ Name: parsed.Name,
+ Arguments: args,
+ },
+ }}, nil
+ }
+
+ // Try Python-style format: [func(arg1='val1'),func2(arg2='val2')] or func(arg1='val1')
+ return p.parsePythonStyleToolCalls(content)
+}
+
+// parsePythonStyleToolCalls parses one or more Python-style tool calls
+// Examples: [bash(command='ls'),bash(command='pwd')] or bash(command='ls')
+func (p *LFM2Parser) parsePythonStyleToolCalls(content string) ([]api.ToolCall, error) {
+ content = strings.TrimSpace(content)
+
+ // Strip outer brackets if present: [func(...)] -> func(...)
+ if strings.HasPrefix(content, "[") && strings.HasSuffix(content, "]") {
+ content = content[1 : len(content)-1]
+ }
+
+ var toolCalls []api.ToolCall
+
+ // Parse multiple function calls separated by commas at the top level
+ for len(content) > 0 {
+ content = strings.TrimSpace(content)
+ if content == "" {
+ break
+ }
+
+ // Skip leading comma from previous iteration
+ if strings.HasPrefix(content, ",") {
+ content = strings.TrimSpace(content[1:])
+ if content == "" {
+ break
+ }
+ }
+
+ // Find function name
+ parenIdx := strings.Index(content, "(")
+ if parenIdx == -1 {
+ return nil, errors.New("invalid tool call: no opening parenthesis")
+ }
+
+ funcName := strings.TrimSpace(content[:parenIdx])
+ if funcName == "" {
+ return nil, errors.New("invalid tool call: empty function name")
+ }
+
+ // Find matching closing parenthesis
+ closeIdx := findMatchingParen(content, parenIdx)
+ if closeIdx == -1 {
+ return nil, errors.New("invalid tool call: no matching closing parenthesis")
+ }
+
+ argsStr := content[parenIdx+1 : closeIdx]
+ args := api.NewToolCallFunctionArguments()
+
+ if argsStr != "" {
+ if err := parsePythonArgs(argsStr, &args); err != nil {
+ return nil, err
+ }
+ }
+
+ toolCalls = append(toolCalls, api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: funcName,
+ Arguments: args,
+ },
+ })
+
+ // Move past this function call
+ content = content[closeIdx+1:]
+ }
+
+ if len(toolCalls) == 0 {
+ return nil, errors.New("no tool calls found")
+ }
+
+ return toolCalls, nil
+}
+
+// findMatchingParen finds the index of the closing parenthesis matching the one at openIdx
+// Returns -1 if not found. Handles nested parentheses and quoted strings.
+func findMatchingParen(s string, openIdx int) int {
+ depth := 1
+ i := openIdx + 1
+ for i < len(s) && depth > 0 {
+ switch s[i] {
+ case '(':
+ depth++
+ case ')':
+ depth--
+ if depth == 0 {
+ return i
+ }
+ case '\'', '"':
+ // Skip quoted string
+ quote := s[i]
+ i++
+ for i < len(s) && s[i] != quote {
+ if s[i] == '\\' && i+1 < len(s) {
+ i++ // skip escaped char
+ }
+ i++
+ }
+ }
+ i++
+ }
+ return -1
+}
+
+// parseToolCallContent parses a single tool call (for backward compatibility with tests)
+func (p *LFM2Parser) parseToolCallContent(content string) (api.ToolCall, error) {
+ calls, err := p.parseToolCallsContent(content)
+ if err != nil {
+ return api.ToolCall{}, err
+ }
+ if len(calls) == 0 {
+ return api.ToolCall{}, errors.New("no tool call found")
+ }
+ return calls[0], nil
+}
+
+// parsePythonArgs parses Python-style keyword arguments: key='value', key2="value2"
+func parsePythonArgs(argsStr string, args *api.ToolCallFunctionArguments) error {
+ // Simple state machine to parse key='value' pairs
+ // Handles: command='ls', flag="-la", count=42, enabled=true
+ var key string
+ i := 0
+
+ for i < len(argsStr) {
+ // Skip whitespace
+ for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
+ i++
+ }
+ if i >= len(argsStr) {
+ break
+ }
+
+ // Parse key
+ keyStart := i
+ for i < len(argsStr) && argsStr[i] != '=' && argsStr[i] != ',' {
+ i++
+ }
+ if i >= len(argsStr) || argsStr[i] != '=' {
+ return errors.New("invalid argument: expected '='")
+ }
+ key = strings.TrimSpace(argsStr[keyStart:i])
+ i++ // skip '='
+
+ // Skip whitespace after =
+ for i < len(argsStr) && (argsStr[i] == ' ' || argsStr[i] == '\t') {
+ i++
+ }
+
+ // Parse value
+ var value string
+ if i < len(argsStr) && (argsStr[i] == '\'' || argsStr[i] == '"') {
+ // Quoted string
+ quote := argsStr[i]
+ i++
+ valueStart := i
+ for i < len(argsStr) && argsStr[i] != quote {
+ if argsStr[i] == '\\' && i+1 < len(argsStr) {
+ i += 2 // skip escaped char
+ } else {
+ i++
+ }
+ }
+ value = argsStr[valueStart:i]
+ if i < len(argsStr) {
+ i++ // skip closing quote
+ }
+ args.Set(key, value)
+ } else {
+ // Unquoted value (number, bool, etc)
+ valueStart := i
+ for i < len(argsStr) && argsStr[i] != ',' {
+ i++
+ }
+ value = strings.TrimSpace(argsStr[valueStart:i])
+
+ // Try to parse as number or bool
+ if v, err := strconv.ParseInt(value, 10, 64); err == nil {
+ args.Set(key, v)
+ } else if v, err := strconv.ParseFloat(value, 64); err == nil {
+ args.Set(key, v)
+ } else if value == "true" {
+ args.Set(key, true)
+ } else if value == "false" {
+ args.Set(key, false)
+ } else {
+ args.Set(key, value)
+ }
+ }
+
+ // Skip comma and whitespace
+ for i < len(argsStr) && (argsStr[i] == ',' || argsStr[i] == ' ' || argsStr[i] == '\t' || argsStr[i] == '\n') {
+ i++
+ }
+ }
+
+ return nil
+}
diff --git a/model/parsers/lfm2_test.go b/model/parsers/lfm2_test.go
new file mode 100644
index 00000000000..3e139b8117a
--- /dev/null
+++ b/model/parsers/lfm2_test.go
@@ -0,0 +1,1088 @@
+package parsers
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestLFM2Parser(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expectedContent string
+ expectedThinking string
+ expectedCalls []api.ToolCall
+ hasThinking bool
+ }{
+ {
+ name: "simple_content",
+ input: "Hello, how are you?",
+ expectedContent: "Hello, how are you?",
+ hasThinking: false,
+ },
+ {
+ name: "thinking_content",
+ input: "I need to think about this...The answer is 42.",
+ expectedThinking: "I need to think about this...",
+ expectedContent: "The answer is 42.",
+ hasThinking: true,
+ },
+ {
+ name: "thinking_with_newlines",
+ input: "Let me think:\n- Point 1\n- Point 2\n\nHere's my answer.",
+ expectedThinking: "Let me think:\n- Point 1\n- Point 2",
+ expectedContent: "Here's my answer.",
+ hasThinking: true,
+ },
+ {
+ name: "tool_call_simple",
+ input: "I'll check the weather.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
+ expectedContent: "I'll check the weather.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "multiple_tool_calls",
+ input: "Getting weather for both cities.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|><|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"London\"}}<|tool_call_end|>",
+ expectedContent: "Getting weather for both cities.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "London",
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "complex_tool_arguments",
+ input: "Processing data.<|tool_call_start|>{\"name\":\"process_data\",\"arguments\":{\"items\":[\"item1\",\"item2\"],\"config\":{\"enabled\":true,\"threshold\":0.95}}}<|tool_call_end|>",
+ expectedContent: "Processing data.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "process_data",
+ Arguments: testArgs(map[string]any{
+ "items": []interface{}{"item1", "item2"},
+ "config": map[string]interface{}{"enabled": true, "threshold": 0.95},
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "thinking_with_tool_call",
+ input: "Let me check the weather...I'll get that for you.<|tool_call_start|>{\"name\":\"get_weather\",\"arguments\":{\"location\":\"Paris\"}}<|tool_call_end|>",
+ expectedThinking: "Let me check the weather...",
+ expectedContent: "I'll get that for you.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ hasThinking: true,
+ },
+ {
+ name: "empty_content",
+ input: "",
+ expectedContent: "",
+ hasThinking: false,
+ },
+ {
+ name: "only_thinking",
+ input: "Just thinking content",
+ expectedThinking: "Just thinking content",
+ expectedContent: "",
+ hasThinking: true,
+ },
+ {
+ name: "unicode_content",
+ input: "مرحبا بالعالم! 你好世界! 🌍",
+ expectedContent: "مرحبا بالعالم! 你好世界! 🌍",
+ hasThinking: false,
+ },
+ {
+ name: "newlines_and_whitespace",
+ input: "Line 1\n\nLine 3\t\tTabbed content",
+ expectedContent: "Line 1\n\nLine 3\t\tTabbed content",
+ hasThinking: false,
+ },
+ {
+ name: "thinking_with_unicode",
+ input: "我在思考这个问题...答案是42。",
+ expectedThinking: "我在思考这个问题...",
+ expectedContent: "答案是42。",
+ hasThinking: true,
+ },
+ {
+ name: "tool_call_with_unicode_args",
+ input: "Searching for information.<|tool_call_start|>{\"name\":\"search\",\"arguments\":{\"query\":\"北京天气\",\"language\":\"中文\"}}<|tool_call_end|>",
+ expectedContent: "Searching for information.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "search",
+ Arguments: testArgs(map[string]any{
+ "query": "北京天气",
+ "language": "中文",
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "thinking_with_special_chars",
+ input: "Let me calculate: 2+2=4 & 3*3=9...The results are correct!",
+ expectedThinking: "Let me calculate: 2+2=4 & 3*3=9...",
+ expectedContent: "The results are correct!",
+ hasThinking: true,
+ },
+ {
+ name: "empty_tool_call_args",
+ input: "Pinging server.<|tool_call_start|>{\"name\":\"ping\",\"arguments\":{}}<|tool_call_end|>",
+ expectedContent: "Pinging server.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "ping",
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ // Python-style tool call tests (from Liquid AI docs)
+ {
+ name: "python_style_tool_call",
+ input: "Let me check that.<|tool_call_start|>[get_candidate_status(candidate_id=\"12345\")]<|tool_call_end|>",
+ expectedContent: "Let me check that.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_candidate_status",
+ Arguments: testArgs(map[string]any{
+ "candidate_id": "12345",
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "python_style_multiple_calls",
+ input: "Running commands.<|tool_call_start|>[bash(command='ls'),bash(command='pwd')]<|tool_call_end|>",
+ expectedContent: "Running commands.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "ls",
+ }),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "pwd",
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "thinking_then_python_tool_call",
+ input: "I should check the status...Let me look that up.<|tool_call_start|>[get_status(id=\"123\")]<|tool_call_end|>",
+ expectedThinking: "I should check the status...",
+ expectedContent: "Let me look that up.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_status",
+ Arguments: testArgs(map[string]any{
+ "id": "123",
+ }),
+ },
+ },
+ },
+ hasThinking: true,
+ },
+ {
+ name: "python_style_no_args",
+ input: "Pinging.<|tool_call_start|>[ping()]<|tool_call_end|>",
+ expectedContent: "Pinging.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "ping",
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "python_style_mixed_types",
+ input: "Processing.<|tool_call_start|>[process(name=\"test\", count=42, enabled=true)]<|tool_call_end|>",
+ expectedContent: "Processing.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "process",
+ Arguments: testArgs(map[string]any{
+ "name": "test",
+ "count": int64(42),
+ "enabled": true,
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "tool_call_only_no_content",
+ input: "<|tool_call_start|>[check()]<|tool_call_end|>",
+ expectedContent: "",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "check",
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "thinking_directly_to_tool_call",
+ input: "Let me run this command...<|tool_call_start|>[bash(command='ls')]<|tool_call_end|>",
+ expectedThinking: "Let me run this command...",
+ expectedContent: "",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "ls",
+ }),
+ },
+ },
+ },
+ hasThinking: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
+ parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
+
+ content, thinking, calls, err := parser.Add(tt.input, true)
+ if err != nil {
+ t.Fatalf("Add() error = %v", err)
+ }
+
+ if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
+ t.Errorf("Content mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
+ t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(tt.expectedCalls, calls, argsComparer); diff != "" {
+ t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLFM2Parser_Streaming(t *testing.T) {
+ tests := []struct {
+ name string
+ chunks []string
+ expectedContent string
+ expectedThinking string
+ expectedCalls []api.ToolCall
+ hasThinking bool
+ }{
+ {
+ name: "streaming_simple_content",
+ chunks: []string{"Hello, ", "how are ", "you?"},
+ expectedContent: "Hello, how are you?",
+ hasThinking: false,
+ },
+ {
+ name: "streaming_thinking",
+ chunks: []string{"I need to ", "think about this", "...", "The answer is 42."},
+ expectedThinking: "I need to think about this...",
+ expectedContent: "The answer is 42.",
+ hasThinking: true,
+ },
+ {
+ name: "streaming_tool_call",
+ chunks: []string{"I'll check weather.", "<|tool_call_start|>", "{\"name\":\"get_weather\",", "\"arguments\":{\"location\":\"Paris\"}}", "<|tool_call_end|>"},
+ expectedContent: "I'll check weather.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ name: "streaming_thinking_with_partial_tag",
+ chunks: []string{"Thinking about this", "...", "think>", "Done thinking."},
+ expectedThinking: "Thinking about this...",
+ expectedContent: "Done thinking.",
+ hasThinking: true,
+ },
+ {
+ name: "streaming_unicode_content",
+ chunks: []string{"مرحبا ", "بالعالم! ", "你好", "世界!"},
+ expectedContent: "مرحبا بالعالم! 你好世界!",
+ hasThinking: false,
+ },
+ {
+ name: "streaming_tool_call_with_split_json",
+ chunks: []string{"Processing.", "<|tool_call_start|>{\"name\":\"calc\",\"arguments\":{\"x\":", "42,\"y\":", "24}}<|tool_call_end|>"},
+ expectedContent: "Processing.",
+ expectedCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "calc",
+ Arguments: testArgs(map[string]any{
+ "x": float64(42),
+ "y": float64(24),
+ }),
+ },
+ },
+ },
+ hasThinking: false,
+ },
+ {
+ // Test that leading whitespace after is trimmed even when in separate chunks
+ name: "streaming_thinking_whitespace_after_tag",
+ chunks: []string{"", "\n\n ", "Actual thinking content", "", "Response"},
+ expectedThinking: "Actual thinking content",
+ expectedContent: "Response",
+ hasThinking: true,
+ },
+ {
+ // Test whitespace between and content in streaming
+ name: "streaming_whitespace_after_close_tag",
+ chunks: []string{"Thinking", "\n\n\n", "Response content"},
+ expectedThinking: "Thinking",
+ expectedContent: "Response content",
+ hasThinking: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
+ parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
+
+ var allContent, allThinking string
+ var allCalls []api.ToolCall
+
+ for i, chunk := range tt.chunks {
+ done := i == len(tt.chunks)-1
+ content, thinking, calls, err := parser.Add(chunk, done)
+ if err != nil {
+ t.Fatalf("Add() error = %v", err)
+ }
+
+ allContent += content
+ allThinking += thinking
+ allCalls = append(allCalls, calls...)
+ }
+
+ if diff := cmp.Diff(tt.expectedContent, allContent); diff != "" {
+ t.Errorf("Content mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(tt.expectedThinking, allThinking); diff != "" {
+ t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(tt.expectedCalls, allCalls, argsComparer); diff != "" {
+ t.Errorf("Tool calls mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLFM2Parser_HasThinkingSupport(t *testing.T) {
+ tests := []struct {
+ name string
+ hasThinking bool
+ expectedSupport bool
+ }{
+ {
+ name: "thinking_enabled",
+ hasThinking: true,
+ expectedSupport: true,
+ },
+ {
+ name: "thinking_disabled",
+ hasThinking: false,
+ expectedSupport: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
+ if got := parser.HasThinkingSupport(); got != tt.expectedSupport {
+ t.Errorf("HasThinkingSupport() = %v, want %v", got, tt.expectedSupport)
+ }
+ })
+ }
+}
+
+func TestLFM2Parser_HasToolSupport(t *testing.T) {
+ parser := &LFM2Parser{}
+ if !parser.HasToolSupport() {
+ t.Error("HasToolSupport() should return true")
+ }
+}
+
+func TestLFM2Parser_Init(t *testing.T) {
+ parser := &LFM2Parser{hasThinkingSupport: true}
+ tools := []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "test_tool",
+ },
+ },
+ }
+
+ returnedTools := parser.Init(tools, nil, &api.ThinkValue{Value: true})
+
+ if diff := cmp.Diff(tools, returnedTools, toolsComparer); diff != "" {
+ t.Errorf("Init() returned tools mismatch (-want +got):\n%s", diff)
+ }
+
+ // Test initial state is set to thinking when enabled
+ if parser.state != LFM2CollectingThinking {
+ t.Errorf("Expected initial state to be LFM2CollectingThinking, got %v", parser.state)
+ }
+}
+
+func TestLFM2Parser_parseToolCallContent(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ expected api.ToolCall
+ expectError bool
+ }{
+ {
+ name: "valid_tool_call",
+ content: `{"name":"get_weather","arguments":{"location":"Paris"}}`,
+ expected: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ {
+ name: "complex_arguments",
+ content: `{"name":"process_data","arguments":{"items":["a","b"],"config":{"enabled":true}}}`,
+ expected: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "process_data",
+ Arguments: testArgs(map[string]any{
+ "items": []interface{}{"a", "b"},
+ "config": map[string]interface{}{"enabled": true},
+ }),
+ },
+ },
+ },
+ {
+ name: "empty_arguments",
+ content: `{"name":"ping","arguments":{}}`,
+ expected: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "ping",
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ },
+ },
+ {
+ name: "unicode_in_tool_name",
+ content: `{"name":"获取天气","arguments":{"城市":"北京"}}`,
+ expected: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "获取天气",
+ Arguments: testArgs(map[string]any{
+ "城市": "北京",
+ }),
+ },
+ },
+ },
+ {
+ name: "numeric_arguments",
+ content: `{"name":"calculate","arguments":{"x":3.14,"y":42,"enabled":true}}`,
+ expected: api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: "calculate",
+ Arguments: testArgs(map[string]any{
+ "x": 3.14,
+ "y": float64(42),
+ "enabled": true,
+ }),
+ },
+ },
+ },
+ {
+ name: "invalid_json",
+ content: `{invalid json}`,
+ expectError: true,
+ },
+ {
+ name: "missing_name",
+ content: `{"arguments":{"arg":"value"}}`,
+ expectError: true,
+ },
+ {
+ name: "empty_name",
+ content: `{"name":"","arguments":{"arg":"value"}}`,
+ expectError: true,
+ },
+ }
+
+ parser := &LFM2Parser{}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := parser.parseToolCallContent(tt.content)
+
+ if tt.expectError {
+ if err == nil {
+ t.Error("Expected error but got none")
+ }
+ return
+ }
+
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
+ t.Errorf("parseToolCallContent() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLFM2Parser_parseToolCallsContent(t *testing.T) {
+ tests := []struct {
+ name string
+ content string
+ expected []api.ToolCall
+ expectError bool
+ }{
+ {
+ name: "multiple_python_style_calls",
+ content: `[bash(command='curl google.com'),bash(command='curl example.com')]`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "curl google.com",
+ }),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "curl example.com",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "single_python_style_call",
+ content: `bash(command='ls -la')`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "ls -la",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "single_bracketed_call",
+ content: `[bash(command='pwd')]`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "pwd",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "multiple_different_functions",
+ content: `[get_weather(location='Paris'),search(query='news')]`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "search",
+ Arguments: testArgs(map[string]any{
+ "query": "news",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "nested_parentheses_in_arg",
+ content: `bash(command='echo "(hello)"')`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": `echo "(hello)"`,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "comma_inside_quotes",
+ content: `bash(command='echo "hello, world"')`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": `echo "hello, world"`,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "equals_inside_quotes",
+ content: `bash(command='export FOO=bar')`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": `export FOO=bar`,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "double_quotes_with_single_inside",
+ content: `bash(command="echo 'hello'")`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": `echo 'hello'`,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "multiple_args",
+ content: `bash(command='ls', flag='-la', count=42)`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": "ls",
+ "flag": "-la",
+ "count": int64(42),
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "no_args",
+ content: `ping()`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "ping",
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ },
+ },
+ },
+ {
+ name: "three_calls",
+ content: `[a(x='1'),b(y='2'),c(z='3')]`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "a",
+ Arguments: testArgs(map[string]any{"x": "1"}),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "b",
+ Arguments: testArgs(map[string]any{"y": "2"}),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "c",
+ Arguments: testArgs(map[string]any{"z": "3"}),
+ },
+ },
+ },
+ },
+ {
+ // Note: backslash escapes are preserved as-is, not processed
+ name: "escaped_quote_in_value",
+ content: `bash(command='echo \'hello\'')`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "bash",
+ Arguments: testArgs(map[string]any{
+ "command": `echo \'hello\'`,
+ }),
+ },
+ },
+ },
+ },
+ // Tests based on Liquid AI documentation examples
+ {
+ name: "docs_example_candidate_status",
+ content: `[get_candidate_status(candidate_id="12345")]`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_candidate_status",
+ Arguments: testArgs(map[string]any{
+ "candidate_id": "12345",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "boolean_true_arg",
+ content: `configure(enabled=true)`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "configure",
+ Arguments: testArgs(map[string]any{
+ "enabled": true,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "boolean_false_arg",
+ content: `configure(enabled=false)`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "configure",
+ Arguments: testArgs(map[string]any{
+ "enabled": false,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "float_arg",
+ content: `set_threshold(value=0.95)`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "set_threshold",
+ Arguments: testArgs(map[string]any{
+ "value": 0.95,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "negative_number_arg",
+ content: `adjust(offset=-10)`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "adjust",
+ Arguments: testArgs(map[string]any{
+ "offset": int64(-10),
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "mixed_arg_types",
+ content: `process(name="test", count=42, ratio=3.14, active=true)`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "process",
+ Arguments: testArgs(map[string]any{
+ "name": "test",
+ "count": int64(42),
+ "ratio": 3.14,
+ "active": true,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "newline_in_string_arg",
+ content: `write_file(content="line1\nline2\nline3")`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "write_file",
+ Arguments: testArgs(map[string]any{
+ "content": "line1\\nline2\\nline3",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "empty_string_arg",
+ content: `search(query="")`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "search",
+ Arguments: testArgs(map[string]any{
+ "query": "",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "underscore_function_name",
+ content: `get_user_profile(user_id="abc123")`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_user_profile",
+ Arguments: testArgs(map[string]any{
+ "user_id": "abc123",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "whitespace_around_args",
+ content: `func( arg1 = "value1" , arg2 = 42 )`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "func",
+ Arguments: testArgs(map[string]any{
+ "arg1": "value1",
+ "arg2": int64(42),
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "json_in_string_arg",
+ content: `send_data(payload='{"key": "value", "num": 123}')`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "send_data",
+ Arguments: testArgs(map[string]any{
+ "payload": `{"key": "value", "num": 123}`,
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "url_in_arg",
+ content: `fetch(url="https://example.com/api?foo=bar&baz=qux")`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "fetch",
+ Arguments: testArgs(map[string]any{
+ "url": "https://example.com/api?foo=bar&baz=qux",
+ }),
+ },
+ },
+ },
+ },
+ {
+ name: "path_with_spaces",
+ content: `read_file(path="/home/user/My Documents/file.txt")`,
+ expected: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "read_file",
+ Arguments: testArgs(map[string]any{
+ "path": "/home/user/My Documents/file.txt",
+ }),
+ },
+ },
+ },
+ },
+ }
+
+ parser := &LFM2Parser{}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result, err := parser.parseToolCallsContent(tt.content)
+
+ if tt.expectError {
+ if err == nil {
+ t.Error("Expected error but got none")
+ }
+ return
+ }
+
+ if err != nil {
+ t.Fatalf("Unexpected error: %v", err)
+ }
+
+ if diff := cmp.Diff(tt.expected, result, argsComparer); diff != "" {
+ t.Errorf("parseToolCallsContent() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
+
+func TestLFM2Parser_EdgeCases(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expectedContent string
+ expectedThinking string
+ hasThinking bool
+ }{
+ {
+ name: "multiple_think_close_tags",
+ input: "First thoughtSecond thoughtFinal content",
+ expectedThinking: "First thought",
+ expectedContent: "Second thoughtFinal content",
+ hasThinking: true,
+ },
+ {
+ name: "empty_thinking_content",
+ input: "Just content",
+ expectedThinking: "",
+ expectedContent: "Just content",
+ hasThinking: true,
+ },
+ {
+ name: "thinking_disabled_with_think_tags",
+ input: "Some contentMore content",
+ expectedContent: "Some contentMore content",
+ hasThinking: false,
+ },
+ {
+ name: "whitespace_only_content",
+ input: " \n\t ",
+ expectedContent: " \n\t ",
+ hasThinking: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ parser := &LFM2Parser{hasThinkingSupport: tt.hasThinking}
+ parser.Init([]api.Tool{}, nil, &api.ThinkValue{Value: tt.hasThinking})
+
+ content, thinking, _, err := parser.Add(tt.input, true)
+ if err != nil {
+ t.Fatalf("Add() error = %v", err)
+ }
+
+ if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
+ t.Errorf("Content mismatch (-want +got):\n%s", diff)
+ }
+
+ if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
+ t.Errorf("Thinking mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/model/parsers/ministral.go b/model/parsers/ministral.go
index 2acf10c5f8f..5df9ff32987 100644
--- a/model/parsers/ministral.go
+++ b/model/parsers/ministral.go
@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"strings"
+ "unicode"
"github.com/ollama/ollama/api"
)
@@ -17,12 +18,34 @@ const (
ministralCollectingToolArgs
)
+// ministralEvent represents an event emitted during parsing
+type ministralEvent interface {
+ isMinistralEvent()
+}
+
+type ministralEventContent struct {
+ content string
+}
+
+type ministralEventThinking struct {
+ thinking string
+}
+
+type ministralEventToolCall struct {
+ name string
+ args string // raw JSON string
+}
+
+func (ministralEventContent) isMinistralEvent() {}
+func (ministralEventThinking) isMinistralEvent() {}
+func (ministralEventToolCall) isMinistralEvent() {}
+
type MinistralParser struct {
state ministralParserState
buffer strings.Builder
tools []api.Tool
hasThinkingSupport bool
- currentTool *api.Tool
+ pendingToolName string // stores tool name while collecting args
}
func (p *MinistralParser) HasToolSupport() bool {
@@ -63,74 +86,251 @@ func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
return nil, fmt.Errorf("tool '%s' not found", n)
}
-func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
- p.buffer.WriteString(s)
+const (
+ ministralToolCallsTag = "[TOOL_CALLS]"
+ ministralThinkTag = "[THINK]"
+ ministralThinkEndTag = "[/THINK]"
+ ministralArgsTag = "[ARGS]"
+)
+
+// eat consumes the parser's buffer, and returns a list of any unambiguous
+// events from the current parser state. The second return value indicates
+// whether to keep looping (true when state transitions, false when waiting
+// for more data).
+func (p *MinistralParser) eat() ([]ministralEvent, bool) {
+ var events []ministralEvent
switch p.state {
case ministralCollectingContent:
- if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
- before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
- if before != "" {
- return before, "", calls, nil
+ bufStr := p.buffer.String()
+
+ // Check for [TOOL_CALLS] tag
+ if strings.Contains(bufStr, ministralToolCallsTag) {
+ split := strings.SplitN(bufStr, ministralToolCallsTag, 2)
+ before := strings.TrimRightFunc(split[0], unicode.IsSpace)
+ if len(before) > 0 {
+ events = append(events, ministralEventContent{content: before})
}
+ after := split[1]
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
p.state = ministralCollectingToolName
- } else if strings.Contains(p.buffer.String(), "[THINK]") {
+ return events, true
+ }
+
+ // Check for [THINK] tag
+ if strings.Contains(bufStr, ministralThinkTag) {
+ split := strings.SplitN(bufStr, ministralThinkTag, 2)
+ before := strings.TrimRightFunc(split[0], unicode.IsSpace)
+ if len(before) > 0 {
+ events = append(events, ministralEventContent{content: before})
+ }
+ after := split[1]
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
p.state = ministralCollectingThinkingContent
- return "", "", calls, nil
- } else {
+ return events, true
+ }
+
+ // Check for partial tag overlap with [TOOL_CALLS] or [THINK]
+ overlapToolCalls := overlap(bufStr, ministralToolCallsTag)
+ overlapThink := overlap(bufStr, ministralThinkTag)
+ maxOverlap := max(overlapToolCalls, overlapThink)
+
+ if maxOverlap > 0 {
+ // Withhold the potential partial tag
+ beforePartialTag := bufStr[:len(bufStr)-maxOverlap]
+ trailingWS := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWS
+ unambiguous := bufStr[:ambiguousStart]
+ ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
- return s, "", calls, nil
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, ministralEventContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ // No tag found: emit content but withhold trailing whitespace
+ whitespaceLen := trailingWhitespaceLen(bufStr)
+ ambiguousStart := len(bufStr) - whitespaceLen
+ unambiguous := bufStr[:ambiguousStart]
+ ambiguous := bufStr[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, ministralEventContent{content: unambiguous})
}
+ return events, false
+
case ministralCollectingThinkingContent:
- if strings.Contains(p.buffer.String(), "[/THINK]") {
- thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
- p.state = ministralCollectingContent
- if after != "" {
- p.buffer.Reset()
- return after, thinkingContent, calls, nil
+ bufStr := p.buffer.String()
+
+ if strings.Contains(bufStr, ministralThinkEndTag) {
+ split := strings.SplitN(bufStr, ministralThinkEndTag, 2)
+ thinkingContent := split[0]
+ after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ if len(thinkingContent) > 0 {
+ events = append(events, ministralEventThinking{thinking: thinkingContent})
}
- return "", thinkingContent, calls, nil
- } else {
+ p.state = ministralCollectingContent
+ return events, true
+ }
+
+ // Check for partial overlap with [/THINK]
+ if overlapLen := overlap(bufStr, ministralThinkEndTag); overlapLen > 0 {
+ unambiguous := bufStr[:len(bufStr)-overlapLen]
+ ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
- return "", s, calls, nil
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, ministralEventThinking{thinking: unambiguous})
+ }
+ return events, false
}
+
+ // No tag found: emit all thinking content
+ p.buffer.Reset()
+ if len(bufStr) > 0 {
+ events = append(events, ministralEventThinking{thinking: bufStr})
+ }
+ return events, false
+
case ministralCollectingToolName:
- if strings.Contains(p.buffer.String(), "[ARGS]") {
- name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
+ bufStr := p.buffer.String()
- t, err := toolByName(p.tools, name)
- if err != nil {
- return "", "", calls, err
- }
- p.currentTool = t
+ if strings.Contains(bufStr, ministralArgsTag) {
+ split := strings.SplitN(bufStr, ministralArgsTag, 2)
+ toolName := split[0]
+ after := split[1]
+ p.pendingToolName = toolName
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
p.state = ministralCollectingToolArgs
- return "", "", calls, nil
+ return events, true
}
- return "", "", calls, nil
+ // Wait for more data
+ return events, false
+
case ministralCollectingToolArgs:
- if strings.Contains(p.buffer.String(), "}") {
- before, _ := splitAtTag(&p.buffer, "}", false)
- before += "}"
+ bufStr := p.buffer.String()
+ jsonEnd := findJSONEnd(bufStr)
- var args api.ToolCallFunctionArguments
- if err := json.Unmarshal([]byte(before), &args); err != nil {
- // todo - throw a better error
- return "", "", calls, err
- }
+ if jsonEnd != -1 {
+ jsonStr := bufStr[:jsonEnd+1]
+ remaining := bufStr[jsonEnd+1:]
+
+ events = append(events, ministralEventToolCall{
+ name: p.pendingToolName,
+ args: jsonStr,
+ })
+ p.pendingToolName = ""
+ p.buffer.Reset()
+ p.buffer.WriteString(remaining)
p.state = ministralCollectingContent
+ return events, true
+ }
+ // Wait for more data
+ return events, false
+
+ default:
+ panic("unexpected ministral event")
+ }
+}
+
+// parseEvents loops calling eat() until it returns false
+func (p *MinistralParser) parseEvents() []ministralEvent {
+ var all []ministralEvent
+ keepLooping := true
+ for keepLooping {
+ var events []ministralEvent
+ events, keepLooping = p.eat()
+ all = append(all, events...)
+ }
+ return all
+}
- call := api.ToolCall{
+func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.buffer.WriteString(s)
+
+ events := p.parseEvents()
+
+ var contentBuilder, thinkingBuilder strings.Builder
+ var toolCalls []api.ToolCall
+
+ for _, event := range events {
+ switch e := event.(type) {
+ case ministralEventContent:
+ contentBuilder.WriteString(e.content)
+ case ministralEventThinking:
+ thinkingBuilder.WriteString(e.thinking)
+ case ministralEventToolCall:
+ // Validate tool exists
+ tool, toolErr := toolByName(p.tools, e.name)
+ if toolErr != nil {
+ return contentBuilder.String(), thinkingBuilder.String(), toolCalls, toolErr
+ }
+ // Parse JSON arguments
+ var args api.ToolCallFunctionArguments
+ if jsonErr := json.Unmarshal([]byte(e.args), &args); jsonErr != nil {
+ return contentBuilder.String(), thinkingBuilder.String(), toolCalls, jsonErr
+ }
+ toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
- Name: p.currentTool.Function.Name,
+ Name: tool.Function.Name,
Arguments: args,
},
+ })
+ }
+ }
+
+ return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
+}
+
+// findJSONEnd finds the index of the closing brace that completes a JSON object.
+// It properly handles nested objects, arrays, and strings (including escaped characters).
+// Returns -1 if the JSON is not yet complete.
+func findJSONEnd(s string) int {
+ depth := 0
+ inString := false
+ escaped := false
+
+ for i, r := range s {
+ if inString {
+ switch {
+ case escaped:
+ // If the previous character was a backslash, skip this character
+ escaped = false
+ case r == '\\':
+ // Mark the next character as escaped
+ escaped = true
+ case r == '"':
+ // End of string literal
+ inString = false
+ }
+ continue
+ }
+
+ switch r {
+ case '"':
+ // Start of string literal
+ inString = true
+ case '{', '[':
+ // Increase nesting level for objects and arrays
+ depth++
+ case '}', ']':
+ // Decrease nesting level
+ depth--
+ if depth == 0 {
+ // Reached the end of the root JSON structure
+ return i
}
- calls = append(calls, call)
- return "", "", calls, nil
}
- return "", "", calls, nil
}
- return p.buffer.String(), thinking, calls, nil
+ return -1
}
diff --git a/model/parsers/ministral_test.go b/model/parsers/ministral_test.go
new file mode 100644
index 00000000000..a04590b0765
--- /dev/null
+++ b/model/parsers/ministral_test.go
@@ -0,0 +1,545 @@
+package parsers
+
+import (
+ "reflect"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestMinistralParserStreaming(t *testing.T) {
+ type step struct {
+ input string
+ wantEvents []ministralEvent
+ }
+
+ cases := []struct {
+ desc string
+ tools []api.Tool
+ steps []step
+ think bool // whether to enable thinking support
+ }{
+ // Content streaming
+ {
+ desc: "simple content",
+ steps: []step{
+ {input: "Hello, how can I help you?", wantEvents: []ministralEvent{
+ ministralEventContent{content: "Hello, how can I help you?"},
+ }},
+ },
+ },
+ {
+ desc: "streaming content word by word",
+ steps: []step{
+ {input: "Hello,", wantEvents: []ministralEvent{ministralEventContent{content: "Hello,"}}},
+ {input: " how", wantEvents: []ministralEvent{ministralEventContent{content: " how"}}},
+ {input: " can I help?", wantEvents: []ministralEvent{ministralEventContent{content: " can I help?"}}},
+ },
+ },
+
+ // Simple tool calls
+ {
+ desc: "simple tool call",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]get_weather[ARGS]{"location": "San Francisco"}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "get_weather", args: `{"location": "San Francisco"}`},
+ }},
+ },
+ },
+ {
+ desc: "tool call with nested object",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]create_entities[ARGS]{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
+ }},
+ },
+ },
+ {
+ desc: "tool call with deeply nested objects",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "update_config"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]update_config[ARGS]{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "update_config", args: `{"settings": {"user": {"profile": {"name": "John", "age": 30}}, "theme": "dark"}}`},
+ }},
+ },
+ },
+ {
+ desc: "tool call with array of objects",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "process_items"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]process_items[ARGS]{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "process_items", args: `{"items": [{"id": 1}, {"id": 2}, {"id": 3}]}`},
+ }},
+ },
+ },
+ {
+ desc: "tool call with escaped quotes in string",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "search"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]search[ARGS]{"query": "say \"hello\""}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "search", args: `{"query": "say \"hello\""}`},
+ }},
+ },
+ },
+ {
+ desc: "tool call with braces inside string",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "format"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]format[ARGS]{"template": "Hello {name}!"}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "format", args: `{"template": "Hello {name}!"}`},
+ }},
+ },
+ },
+ {
+ desc: "empty JSON object",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "no_args"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]no_args[ARGS]{}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "no_args", args: `{}`},
+ }},
+ },
+ },
+ {
+ desc: "JSON with newlines in string",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "write"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]write[ARGS]{"content": "line1\nline2\nline3"}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "write", args: `{"content": "line1\nline2\nline3"}`},
+ }},
+ },
+ },
+ {
+ desc: "backslash in string value",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "path"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]path[ARGS]{"dir": "C:\\Users\\test"}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "path", args: `{"dir": "C:\\Users\\test"}`},
+ }},
+ },
+ },
+
+ // Content after tool call
+ {
+ desc: "content after tool call",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ // NOTE: It's unclear if this is valid Ministral output, but the parser
+ // currently treats text after a tool call as regular content. This test
+ // documents that behavior so we notice if it changes.
+ {input: `[TOOL_CALLS]test[ARGS]{"a": 1}some content after`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "test", args: `{"a": 1}`},
+ ministralEventContent{content: "some content after"},
+ }},
+ },
+ },
+
+ // Multiple tool calls
+ {
+ desc: "multiple tool calls in sequence",
+ tools: []api.Tool{
+ {Function: api.ToolFunction{Name: "get_weather"}},
+ {Function: api.ToolFunction{Name: "get_time"}},
+ },
+ steps: []step{
+ {input: `[TOOL_CALLS]get_weather[ARGS]{"location": "NYC"}[TOOL_CALLS]get_time[ARGS]{"timezone": "EST"}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "get_weather", args: `{"location": "NYC"}`},
+ ministralEventToolCall{name: "get_time", args: `{"timezone": "EST"}`},
+ }},
+ },
+ },
+ {
+ desc: "multiple tool calls streamed separately",
+ tools: []api.Tool{
+ {Function: api.ToolFunction{Name: "tool_a"}},
+ {Function: api.ToolFunction{Name: "tool_b"}},
+ },
+ steps: []step{
+ {input: `[TOOL_CALLS]tool_a[ARGS]{"x": 1}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "tool_a", args: `{"x": 1}`},
+ }},
+ {input: `[TOOL_CALLS]tool_b[ARGS]{"y": 2}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "tool_b", args: `{"y": 2}`},
+ }},
+ },
+ },
+
+ // Streaming tool calls
+ {
+ desc: "streaming tool call with nested objects",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "create_entities"}}},
+ steps: []step{
+ {input: "[TOOL_CALLS]create_entities[ARGS]", wantEvents: []ministralEvent{}},
+ {input: `{"entities": [{"entityType": "Person",`, wantEvents: []ministralEvent{}},
+ {input: ` "name": "Jack",`, wantEvents: []ministralEvent{}},
+ {input: ` "observations": ["Works`, wantEvents: []ministralEvent{}},
+ {input: ` as a baker"]}`, wantEvents: []ministralEvent{}},
+ {input: `]}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "create_entities", args: `{"entities": [{"entityType": "Person", "name": "Jack", "observations": ["Works as a baker"]}]}`},
+ }},
+ },
+ },
+ {
+ desc: "streaming with incomplete JSON waits for completion",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ {input: "[TOOL_CALLS]test[ARGS]{", wantEvents: []ministralEvent{}},
+ {input: `"a": {`, wantEvents: []ministralEvent{}},
+ {input: `"b": 1`, wantEvents: []ministralEvent{}},
+ {input: `}`, wantEvents: []ministralEvent{}},
+ {input: `}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "test", args: `{"a": {"b": 1}}`},
+ }},
+ },
+ },
+
+ // Partial tag handling
+ {
+ desc: "partial tool tag fakeout",
+ steps: []step{
+ {input: "abc[TOOL", wantEvents: []ministralEvent{ministralEventContent{content: "abc"}}},
+ {input: " not a tag", wantEvents: []ministralEvent{ministralEventContent{content: "[TOOL not a tag"}}},
+ },
+ },
+ {
+ desc: "tool call tag split across chunks",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ {input: "[TOOL_", wantEvents: []ministralEvent{}},
+ {input: "CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "test", args: `{}`},
+ }},
+ },
+ },
+ {
+ desc: "content before tool call",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "get_weather"}}},
+ steps: []step{
+ {input: "hello [TOOL_CALLS]get_weather[ARGS]{}", wantEvents: []ministralEvent{
+ ministralEventContent{content: "hello"},
+ ministralEventToolCall{name: "get_weather", args: `{}`},
+ }},
+ },
+ },
+ {
+ desc: "whitespace between content and tool call is trimmed",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ {input: "content \n [TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
+ ministralEventContent{content: "content"},
+ ministralEventToolCall{name: "test", args: `{}`},
+ }},
+ },
+ },
+ {
+ desc: "tabs and newlines before tool call are trimmed",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ {input: "content\t\n\t[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
+ ministralEventContent{content: "content"},
+ ministralEventToolCall{name: "test", args: `{}`},
+ }},
+ },
+ },
+ {
+ desc: "non-breaking space before tool call is trimmed",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ // \u00a0 is non-breaking space, which unicode.IsSpace considers whitespace
+ {input: "content\u00a0[TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
+ ministralEventContent{content: "content"},
+ ministralEventToolCall{name: "test", args: `{}`},
+ }},
+ },
+ },
+ {
+ desc: "whitespace before THINK tag is trimmed",
+ steps: []step{
+ {input: "content \n [THINK]thinking[/THINK]after", wantEvents: []ministralEvent{
+ ministralEventContent{content: "content"},
+ ministralEventThinking{thinking: "thinking"},
+ ministralEventContent{content: "after"},
+ }},
+ },
+ },
+ {
+ desc: "trailing whitespace withheld then emitted",
+ steps: []step{
+ {input: "Hello ", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
+ {input: "world", wantEvents: []ministralEvent{ministralEventContent{content: " world"}}},
+ },
+ },
+ {
+ desc: "trailing newline withheld then emitted",
+ steps: []step{
+ {input: "Hello\n", wantEvents: []ministralEvent{ministralEventContent{content: "Hello"}}},
+ {input: "world", wantEvents: []ministralEvent{ministralEventContent{content: "\nworld"}}},
+ },
+ },
+
+ // Thinking support
+ {
+ desc: "thinking content",
+ think: true,
+ steps: []step{
+ {input: "thinking here[/THINK]", wantEvents: []ministralEvent{
+ ministralEventThinking{thinking: "thinking here"},
+ }},
+ {input: "content after", wantEvents: []ministralEvent{
+ ministralEventContent{content: "content after"},
+ }},
+ },
+ },
+ {
+ desc: "thinking with whitespace after end tag",
+ think: true,
+ steps: []step{
+ {input: "my thoughts[/THINK] \n response", wantEvents: []ministralEvent{
+ ministralEventThinking{thinking: "my thoughts"},
+ ministralEventContent{content: "response"},
+ }},
+ },
+ },
+ {
+ desc: "non-breaking space after think end tag is trimmed",
+ think: true,
+ steps: []step{
+ // \u00a0 is non-breaking space
+ {input: "thinking[/THINK]\u00a0response", wantEvents: []ministralEvent{
+ ministralEventThinking{thinking: "thinking"},
+ ministralEventContent{content: "response"},
+ }},
+ },
+ },
+ {
+ desc: "partial think end tag",
+ think: true,
+ steps: []step{
+ {input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
+ {input: "NK]after", wantEvents: []ministralEvent{ministralEventContent{content: "after"}}},
+ },
+ },
+ {
+ desc: "think tag fakeout",
+ think: true,
+ steps: []step{
+ {input: "thinking[/THI", wantEvents: []ministralEvent{ministralEventThinking{thinking: "thinking"}}},
+ {input: "not end tag", wantEvents: []ministralEvent{ministralEventThinking{thinking: "[/THInot end tag"}}},
+ },
+ },
+ {
+ desc: "thinking then tool call",
+ think: true,
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "test"}}},
+ steps: []step{
+ {input: "let me think[/THINK][TOOL_CALLS]test[ARGS]{}", wantEvents: []ministralEvent{
+ ministralEventThinking{thinking: "let me think"},
+ ministralEventToolCall{name: "test", args: `{}`},
+ }},
+ },
+ },
+
+ // Content then THINK tag transition
+ {
+ desc: "content then think tag",
+ steps: []step{
+ {input: "content[THINK]thinking[/THINK]more", wantEvents: []ministralEvent{
+ ministralEventContent{content: "content"},
+ ministralEventThinking{thinking: "thinking"},
+ ministralEventContent{content: "more"},
+ }},
+ },
+ },
+
+ // Unicode handling
+ {
+ desc: "unicode content",
+ steps: []step{
+ {input: "你好 🌍 مرحبا", wantEvents: []ministralEvent{
+ ministralEventContent{content: "你好 🌍 مرحبا"},
+ }},
+ },
+ },
+ {
+ desc: "unicode in tool args",
+ tools: []api.Tool{{Function: api.ToolFunction{Name: "greet"}}},
+ steps: []step{
+ {input: `[TOOL_CALLS]greet[ARGS]{"message": "你好 🌍"}`, wantEvents: []ministralEvent{
+ ministralEventToolCall{name: "greet", args: `{"message": "你好 🌍"}`},
+ }},
+ },
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.desc, func(t *testing.T) {
+ parser := MinistralParser{}
+ parser.hasThinkingSupport = tc.think
+ parser.Init(tc.tools, nil, nil)
+
+ for i, step := range tc.steps {
+ parser.buffer.WriteString(step.input)
+ gotEvents := parser.parseEvents()
+
+ if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
+ // avoid deep equal on empty vs. nil slices
+ continue
+ }
+
+ if !reflect.DeepEqual(gotEvents, step.wantEvents) {
+ t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
+ }
+ }
+ })
+ }
+}
+
+func TestMinistralParser_Errors(t *testing.T) {
+ t.Run("unknown tool returns error", func(t *testing.T) {
+ p := &MinistralParser{}
+ p.Init([]api.Tool{{Function: api.ToolFunction{Name: "known_tool"}}}, nil, nil)
+
+ _, _, _, err := p.Add(`[TOOL_CALLS]unknown_tool[ARGS]{"a": 1}`, true)
+ if err == nil {
+ t.Fatal("expected error for unknown tool")
+ }
+ })
+
+ t.Run("invalid JSON returns error", func(t *testing.T) {
+ p := &MinistralParser{}
+ p.Init([]api.Tool{{Function: api.ToolFunction{Name: "test"}}}, nil, nil)
+
+ _, _, _, err := p.Add(`[TOOL_CALLS]test[ARGS]{invalid json}`, true)
+ if err == nil {
+ t.Fatal("expected error for invalid JSON")
+ }
+ })
+}
+
+func TestFindJSONEnd(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected int
+ }{
+ {
+ name: "simple object",
+ input: `{"a": 1}`,
+ expected: 7,
+ },
+ {
+ name: "nested object",
+ input: `{"a": {"b": 2}}`,
+ expected: 14,
+ },
+ {
+ name: "array inside object",
+ input: `{"items": [1, 2, 3]}`,
+ expected: 19,
+ },
+ {
+ name: "braces in string",
+ input: `{"template": "Hello {name}!"}`,
+ expected: 28,
+ },
+ {
+ name: "escaped quotes",
+ input: `{"msg": "say \"hi\""}`,
+ expected: 20,
+ },
+ {
+ name: "incomplete object",
+ input: `{"a": {"b": 1}`,
+ expected: -1,
+ },
+ {
+ name: "deeply nested",
+ input: `{"a": {"b": {"c": {"d": 1}}}}`,
+ expected: 28,
+ },
+ {
+ name: "object with trailing content",
+ input: `{"a": 1} extra`,
+ expected: 7,
+ },
+ {
+ name: "array",
+ input: `[{"a": 1}, {"b": 2}]`,
+ expected: 19,
+ },
+ {
+ name: "escaped backslash before quote",
+ input: `{"path": "C:\\"}`,
+ expected: 15,
+ },
+ {
+ name: "empty string",
+ input: "",
+ expected: -1,
+ },
+ {
+ name: "no opening brace",
+ input: "hello world",
+ expected: -1,
+ },
+ {
+ name: "only opening brace",
+ input: "{",
+ expected: -1,
+ },
+ {
+ name: "unclosed string",
+ input: `{"key": "unclosed`,
+ expected: -1,
+ },
+ {
+ name: "double escaped backslash then quote",
+ input: `{"path": "C:\\\\"}`,
+ expected: 17,
+ },
+ {
+ name: "unicode in key and value",
+ input: `{"키": "값"}`,
+ expected: 13,
+ },
+ {
+ name: "nested arrays",
+ input: `{"matrix": [[1, 2], [3, 4]]}`,
+ expected: 27,
+ },
+ {
+ name: "mixed nesting",
+ input: `{"a": [{"b": {"c": [1, 2, 3]}}]}`,
+ expected: 31,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := findJSONEnd(tt.input)
+ if result != tt.expected {
+ t.Errorf("findJSONEnd(%q) = %d, want %d", tt.input, result, tt.expected)
+ }
+ })
+ }
+}
+
+func TestMinistralParser_HasToolSupport(t *testing.T) {
+ p := &MinistralParser{}
+ if !p.HasToolSupport() {
+ t.Error("expected HasToolSupport to return true")
+ }
+}
+
+func TestMinistralParser_HasThinkingSupport(t *testing.T) {
+ p := &MinistralParser{hasThinkingSupport: false}
+ if p.HasThinkingSupport() {
+ t.Error("expected HasThinkingSupport to return false")
+ }
+
+ p = &MinistralParser{hasThinkingSupport: true}
+ if !p.HasThinkingSupport() {
+ t.Error("expected HasThinkingSupport to return true")
+ }
+}
diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go
index 3a3261a04b2..fa9f8b59836 100644
--- a/model/parsers/parsers.go
+++ b/model/parsers/parsers.go
@@ -3,6 +3,7 @@ package parsers
import (
"strings"
"unicode"
+ "unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
@@ -44,6 +45,10 @@ func ParserForName(name string) Parser {
var p Parser
switch name {
+ case "qwen3":
+ p = &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+ case "qwen3-thinking":
+ p = &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
case "qwen3-coder":
p = &Qwen3CoderParser{}
case "qwen3-vl-instruct":
@@ -70,6 +75,12 @@ func ParserForName(name string) Parser {
return &FunctionGemmaParser{}
case "glm-4.7":
return &GLM47Parser{}
+ case "glm-ocr":
+ return &GlmOcrParser{}
+ case "lfm2":
+ return &LFM2Parser{hasThinkingSupport: false}
+ case "lfm2-thinking":
+ return &LFM2Parser{hasThinkingSupport: true}
default:
return nil
}
@@ -110,3 +121,33 @@ func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string
sb.WriteString(after)
return before, after // return events
}
+
+// overlap returns the longest overlap between the suffix of s and the prefix of delim
+func overlap(s, delim string) int {
+ max := min(len(delim), len(s))
+ for i := max; i > 0; i-- {
+ if strings.HasSuffix(s, delim[:i]) {
+ return i
+ }
+ }
+ return 0
+}
+
+// trailingWhitespaceLen returns the length in bytes of trailing whitespace in s
+func trailingWhitespaceLen(s string) int {
+ remaining := s
+ total := 0
+ for len(remaining) > 0 {
+ r, size := utf8.DecodeLastRuneInString(remaining)
+ // if it's an invalid utf8 rune, assume it isn't whitespace
+ if r == utf8.RuneError && size == 1 {
+ break
+ }
+ if !unicode.IsSpace(r) {
+ break
+ }
+ total += size
+ remaining = remaining[:len(remaining)-size]
+ }
+ return total
+}
diff --git a/model/parsers/parsers_test.go b/model/parsers/parsers_test.go
index 4f8566de309..15c2f664f37 100644
--- a/model/parsers/parsers_test.go
+++ b/model/parsers/parsers_test.go
@@ -54,6 +54,8 @@ func TestBuiltInParsersStillWork(t *testing.T) {
name string
}{
{"passthrough"},
+ {"qwen3"},
+ {"qwen3-thinking"},
{"qwen3-coder"},
{"harmony"},
}
diff --git a/model/parsers/qwen3.go b/model/parsers/qwen3.go
new file mode 100644
index 00000000000..e49111fb5f7
--- /dev/null
+++ b/model/parsers/qwen3.go
@@ -0,0 +1,335 @@
+package parsers
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/logutil"
+)
+
+type qwen3ParserState int
+
+const (
+ qwen3ParserStateLookingForThinkingOpen qwen3ParserState = iota
+ qwen3ParserStateThinkingStartedEatingWhitespace
+ qwen3ParserStateCollectingThinking
+ qwen3ParserStateThinkingDoneEatingWhitespace
+ qwen3ParserStateCollectingContent
+ qwen3ParserStateToolStartedEatingWhitespace
+ qwen3ParserStateCollectingToolContent
+)
+
+const (
+ qwen3ThinkingOpenTag = ""
+ qwen3ThinkingCloseTag = ""
+ qwen3ToolOpenTag = ""
+ qwen3ToolCloseTag = ""
+)
+
+// Qwen3Parser parses Qwen3 output to extract thinking and tool calls.
+// Qwen3 prompts end with when thinking is enabled, so output begins
+// with thinking content directly (without an opening tag).
+type Qwen3Parser struct {
+ state qwen3ParserState
+ buffer strings.Builder
+ tools []api.Tool
+ hasThinkingSupport bool
+ defaultThinking bool
+ maybeThinkingOpenAtBOL bool
+}
+
+func (p *Qwen3Parser) HasToolSupport() bool {
+ return true
+}
+
+func (p *Qwen3Parser) HasThinkingSupport() bool {
+ return p.hasThinkingSupport
+}
+
+func (p *Qwen3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+ p.tools = tools
+ p.buffer.Reset()
+
+ thinkingEnabled := thinkValue != nil && thinkValue.Bool()
+ if thinkValue == nil {
+ thinkingEnabled = p.defaultThinking
+ }
+
+ if p.hasThinkingSupport && thinkingEnabled {
+ p.state = qwen3ParserStateCollectingThinking
+ p.maybeThinkingOpenAtBOL = true
+ } else {
+ p.state = qwen3ParserStateCollectingContent
+ p.maybeThinkingOpenAtBOL = false
+ }
+ return tools
+}
+
+type qwen3Event interface {
+ isQwen3Event()
+}
+
+type qwen3EventContent struct {
+ content string
+}
+
+func (qwen3EventContent) isQwen3Event() {}
+
+type qwen3EventRawToolCall struct {
+ raw string
+}
+
+func (qwen3EventRawToolCall) isQwen3Event() {}
+
+type qwen3EventThinkingContent struct {
+ content string
+}
+
+func (qwen3EventThinkingContent) isQwen3Event() {}
+
+func (p *Qwen3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.buffer.WriteString(s)
+ events := p.parseEvents()
+
+ var contentSb strings.Builder
+ var thinkingSb strings.Builder
+ for _, event := range events {
+ switch event := event.(type) {
+ case qwen3EventRawToolCall:
+ toolCall, err := parseQwen3ToolCall(event, p.tools)
+ if err != nil {
+ slog.Warn("qwen3 tool call parsing failed", "error", err)
+ return "", "", nil, err
+ }
+ calls = append(calls, toolCall)
+ case qwen3EventThinkingContent:
+ thinkingSb.WriteString(event.content)
+ case qwen3EventContent:
+ contentSb.WriteString(event.content)
+ }
+ }
+
+ return contentSb.String(), thinkingSb.String(), calls, nil
+}
+
+func (p *Qwen3Parser) parseEvents() []qwen3Event {
+ var all []qwen3Event
+
+ keepLooping := true
+ for keepLooping {
+ var events []qwen3Event
+ events, keepLooping = p.eat()
+ if len(events) > 0 {
+ all = append(all, events...)
+ }
+ }
+
+ if len(all) > 0 {
+ slog.Log(context.TODO(), logutil.LevelTrace, "qwen3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
+ }
+
+ return all
+}
+
+func (p *Qwen3Parser) eatLeadingWhitespaceAndTransitionTo(nextState qwen3ParserState) ([]qwen3Event, bool) {
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ p.buffer.Reset()
+ if trimmed == "" {
+ return nil, false
+ }
+ p.state = nextState
+ p.buffer.WriteString(trimmed)
+ return nil, true
+}
+
+func (p *Qwen3Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
+ return splitAtTag(&p.buffer, tag, trimAfter)
+}
+
+func (p *Qwen3Parser) eat() ([]qwen3Event, bool) {
+ var events []qwen3Event
+
+ switch p.state {
+ case qwen3ParserStateLookingForThinkingOpen:
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
+ after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ if after == "" {
+ p.state = qwen3ParserStateThinkingStartedEatingWhitespace
+ } else {
+ p.state = qwen3ParserStateCollectingThinking
+ }
+ return events, true
+ } else if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
+ return events, false
+ } else if trimmed == "" {
+ return events, false
+ }
+ p.state = qwen3ParserStateCollectingContent
+ return events, true
+
+ case qwen3ParserStateThinkingStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingThinking)
+
+ case qwen3ParserStateCollectingThinking:
+ acc := p.buffer.String()
+
+ // Some qwen3 checkpoints emit an explicit opening tag even
+ // though the prompt already ended with . Strip exactly one
+ // leading opening tag if present.
+ if p.maybeThinkingOpenAtBOL {
+ trimmed := strings.TrimLeftFunc(acc, unicode.IsSpace)
+ if strings.HasPrefix(trimmed, qwen3ThinkingOpenTag) {
+ after := strings.TrimPrefix(trimmed, qwen3ThinkingOpenTag)
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ if after == "" {
+ return events, false
+ }
+ p.maybeThinkingOpenAtBOL = false
+ return events, true
+ }
+ if strings.HasPrefix(qwen3ThinkingOpenTag, trimmed) {
+ return events, false
+ }
+ p.maybeThinkingOpenAtBOL = false
+ }
+
+ if strings.Contains(acc, qwen3ThinkingCloseTag) {
+ thinking, remaining := p.splitAtTag(qwen3ThinkingCloseTag, true)
+ if len(thinking) > 0 {
+ events = append(events, qwen3EventThinkingContent{content: thinking})
+ }
+ if remaining == "" {
+ p.state = qwen3ParserStateThinkingDoneEatingWhitespace
+ } else {
+ p.state = qwen3ParserStateCollectingContent
+ }
+ return events, true
+ } else if overlapLen := overlap(acc, qwen3ThinkingCloseTag); overlapLen > 0 {
+ beforePartialTag := acc[:len(acc)-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, qwen3EventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ whitespaceLen := trailingWhitespaceLen(acc)
+ ambiguousStart := len(acc) - whitespaceLen
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, qwen3EventThinkingContent{content: unambiguous})
+ }
+ return events, false
+
+ case qwen3ParserStateThinkingDoneEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingContent)
+
+ case qwen3ParserStateCollectingContent:
+ acc := p.buffer.String()
+ if strings.Contains(acc, qwen3ToolOpenTag) {
+ before, after := p.splitAtTag(qwen3ToolOpenTag, true)
+ if len(before) > 0 {
+ events = append(events, qwen3EventContent{content: before})
+ }
+ if after == "" {
+ p.state = qwen3ParserStateToolStartedEatingWhitespace
+ } else {
+ p.state = qwen3ParserStateCollectingToolContent
+ }
+ return events, true
+ } else if overlapLen := overlap(acc, qwen3ToolOpenTag); overlapLen > 0 {
+ beforePartialTag := acc[:len(acc)-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, qwen3EventContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ whitespaceLen := trailingWhitespaceLen(acc)
+ ambiguousStart := len(acc) - whitespaceLen
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, qwen3EventContent{content: unambiguous})
+ }
+ return events, false
+
+ case qwen3ParserStateToolStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(qwen3ParserStateCollectingToolContent)
+
+ case qwen3ParserStateCollectingToolContent:
+ acc := p.buffer.String()
+ if strings.Contains(acc, qwen3ToolCloseTag) {
+ toolContent, _ := p.splitAtTag(qwen3ToolCloseTag, true)
+ if len(toolContent) == 0 {
+ slog.Warn("qwen3 tool call closing tag found but no content before it")
+ }
+ events = append(events, qwen3EventRawToolCall{raw: toolContent})
+ p.state = qwen3ParserStateCollectingContent
+ return events, true
+ }
+ return events, false
+
+ default:
+ panic("unreachable")
+ }
+}
+
+func parseQwen3ToolCall(raw qwen3EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+ var parsed struct {
+ Name string `json:"name"`
+ Arguments map[string]any `json:"arguments"`
+ }
+
+ if err := json.Unmarshal([]byte(raw.raw), &parsed); err != nil {
+ return api.ToolCall{}, fmt.Errorf("failed to parse JSON: %w", err)
+ }
+
+ if parsed.Name == "" {
+ return api.ToolCall{}, fmt.Errorf("empty function name")
+ }
+
+ _ = tools // qwen3 uses direct JSON args and does not require schema coercion here.
+
+ toolCall := api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: parsed.Name,
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ }
+
+ for key, value := range parsed.Arguments {
+ toolCall.Function.Arguments.Set(key, value)
+ }
+
+ return toolCall, nil
+}
diff --git a/model/parsers/qwen3_test.go b/model/parsers/qwen3_test.go
new file mode 100644
index 00000000000..853874ded5b
--- /dev/null
+++ b/model/parsers/qwen3_test.go
@@ -0,0 +1,147 @@
+package parsers
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestQwen3ParserThinkingEnabled(t *testing.T) {
+ parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
+ parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+ content, thinking, calls, err := parser.Add("Let me think...Answer.", true)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ if thinking != "Let me think..." {
+ t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
+ }
+ if content != "Answer." {
+ t.Fatalf("expected content %q, got %q", "Answer.", content)
+ }
+ if len(calls) != 0 {
+ t.Fatalf("expected no tool calls, got %d", len(calls))
+ }
+}
+
+func TestQwen3ParserThinkingEnabledWithExplicitOpeningTag(t *testing.T) {
+ parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
+ parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+ content, thinking, calls, err := parser.Add("\nLet me think...Answer.", true)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ if thinking != "Let me think..." {
+ t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
+ }
+ if content != "Answer." {
+ t.Fatalf("expected content %q, got %q", "Answer.", content)
+ }
+ if len(calls) != 0 {
+ t.Fatalf("expected no tool calls, got %d", len(calls))
+ }
+}
+
+func TestQwen3ParserThinkingEnabledWithSplitOpeningTag(t *testing.T) {
+ parser := &Qwen3Parser{hasThinkingSupport: true, defaultThinking: true}
+ parser.Init(nil, nil, &api.ThinkValue{Value: true})
+
+ content, thinking, calls, err := parser.Add("Let me think...Answer.", true)
+ if err != nil {
+ t.Fatalf("parse failed on second chunk: %v", err)
+ }
+ if thinking != "Let me think..." {
+ t.Fatalf("expected thinking %q, got %q", "Let me think...", thinking)
+ }
+ if content != "Answer." {
+ t.Fatalf("expected content %q, got %q", "Answer.", content)
+ }
+ if len(calls) != 0 {
+ t.Fatalf("expected no tool calls, got %d", len(calls))
+ }
+}
+
+func TestQwen3ParserThinkingDisabled(t *testing.T) {
+ parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+ parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+ content, thinking, calls, err := parser.Add("Direct answer", true)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ if thinking != "" {
+ t.Fatalf("expected no thinking, got %q", thinking)
+ }
+ if content != "Direct answer" {
+ t.Fatalf("expected content %q, got %q", "Direct answer", content)
+ }
+ if len(calls) != 0 {
+ t.Fatalf("expected no tool calls, got %d", len(calls))
+ }
+}
+
+func TestQwen3ParserNilThinkDefaultsToContentForInstructParser(t *testing.T) {
+ parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+ parser.Init(nil, nil, nil)
+
+ content, thinking, calls, err := parser.Add("Direct answer", true)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ if thinking != "" {
+ t.Fatalf("expected no thinking, got %q", thinking)
+ }
+ if content != "Direct answer" {
+ t.Fatalf("expected content %q, got %q", "Direct answer", content)
+ }
+ if len(calls) != 0 {
+ t.Fatalf("expected no tool calls, got %d", len(calls))
+ }
+}
+
+func TestQwen3ParserToolCall(t *testing.T) {
+ parser := &Qwen3Parser{hasThinkingSupport: false, defaultThinking: false}
+ parser.Init(nil, nil, &api.ThinkValue{Value: false})
+
+ input := "{\"name\":\"get_weather\",\"arguments\":{\"location\":\"San Francisco\",\"unit\":\"celsius\"}}"
+ content, thinking, calls, err := parser.Add(input, true)
+ if err != nil {
+ t.Fatalf("parse failed: %v", err)
+ }
+
+ if content != "" {
+ t.Fatalf("expected empty content, got %q", content)
+ }
+ if thinking != "" {
+ t.Fatalf("expected empty thinking, got %q", thinking)
+ }
+ if len(calls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
+ }
+ if calls[0].Function.Name != "get_weather" {
+ t.Fatalf("expected tool name %q, got %q", "get_weather", calls[0].Function.Name)
+ }
+
+ location, ok := calls[0].Function.Arguments.Get("location")
+ if !ok || location != "San Francisco" {
+ t.Fatalf("expected location %q, got %v", "San Francisco", location)
+ }
+ unit, ok := calls[0].Function.Arguments.Get("unit")
+ if !ok || unit != "celsius" {
+ t.Fatalf("expected unit %q, got %v", "celsius", unit)
+ }
+}
diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go
index cf8f214e204..5604988ec29 100644
--- a/model/parsers/qwen3coder.go
+++ b/model/parsers/qwen3coder.go
@@ -11,7 +11,6 @@ import (
"strconv"
"strings"
"unicode"
- "unicode/utf8"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
@@ -194,36 +193,6 @@ func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
}
}
-// TODO(drifkin): move this to a shared location
-// longest overlap between suffix of s and prefix of delim
-func overlap(s, delim string) int {
- max := min(len(delim), len(s))
- for i := max; i > 0; i-- {
- if strings.HasSuffix(s, delim[:i]) {
- return i
- }
- }
- return 0
-}
-
-func trailingWhitespaceLen(s string) int {
- remaining := s
- total := 0
- for len(remaining) > 0 {
- r, size := utf8.DecodeLastRuneInString(remaining)
- // if it's an invalid utf8 rune, assume it isn't whitespace
- if r == utf8.RuneError && size == 1 {
- break
- }
- if !unicode.IsSpace(r) {
- break
- }
- total += size
- remaining = remaining[:len(remaining)-size]
- }
- return total
-}
-
type XMLFunctionCall struct {
XMLName xml.Name `xml:"function"`
Name string `xml:"name,attr"`
diff --git a/model/renderers/glmocr.go b/model/renderers/glmocr.go
new file mode 100644
index 00000000000..b141da07d2b
--- /dev/null
+++ b/model/renderers/glmocr.go
@@ -0,0 +1,109 @@
+package renderers
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+type GlmOcrRenderer struct{}
+
+func (r *GlmOcrRenderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
+ var sb strings.Builder
+
+ sb.WriteString("[gMASK]")
+
+ if len(tools) > 0 {
+ sb.WriteString("<|system|>\n")
+ sb.WriteString("# Tools\n\n")
+ sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
+ sb.WriteString("You are provided with function signatures within XML tags:\n")
+ sb.WriteString("\n")
+ for _, tool := range tools {
+ d, _ := json.Marshal(tool)
+ sb.WriteString(formatGLM47ToolJSON(d))
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\n\n")
+ sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
+ sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...")
+ }
+
+ enableThinking := false
+ thinkingExplicitlySet := false
+ if thinkValue != nil {
+ enableThinking = thinkValue.Bool()
+ thinkingExplicitlySet = true
+ }
+
+ for i, message := range messages {
+ switch message.Role {
+ case "user":
+ sb.WriteString("<|user|>\n")
+ sb.WriteString(message.Content)
+ if thinkingExplicitlySet && !enableThinking && !strings.HasSuffix(message.Content, "/nothink") {
+ sb.WriteString("/nothink")
+ }
+ case "assistant":
+ sb.WriteString("<|assistant|>\n")
+ if message.Thinking != "" {
+ sb.WriteString("" + strings.TrimSpace(message.Thinking) + "")
+ } else {
+ sb.WriteString("")
+ }
+ if message.Content != "" {
+ sb.WriteString("\n" + strings.TrimSpace(message.Content))
+ }
+ if len(message.ToolCalls) > 0 {
+ for _, toolCall := range message.ToolCalls {
+ sb.WriteString("\n" + toolCall.Function.Name)
+ sb.WriteString(renderGlmOcrToolArguments(toolCall.Function.Arguments))
+ sb.WriteString("")
+ }
+ }
+ sb.WriteString("\n")
+ case "tool":
+ if i == 0 || messages[i-1].Role != "tool" {
+ sb.WriteString("<|observation|>")
+ }
+ sb.WriteString("\n\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("\n\n")
+ case "system":
+ sb.WriteString("<|system|>\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("\n")
+ }
+ }
+
+ sb.WriteString("<|assistant|>\n")
+ if thinkingExplicitlySet && !enableThinking {
+ sb.WriteString("\n")
+ }
+
+ return sb.String(), nil
+}
+
+func renderGlmOcrToolArguments(args api.ToolCallFunctionArguments) string {
+ var sb strings.Builder
+ for key, value := range args.All() {
+ sb.WriteString("" + key + "")
+ var valueStr string
+ if str, ok := value.(string); ok {
+ valueStr = str
+ } else {
+ jsonBytes, err := json.Marshal(value)
+ if err != nil {
+ valueStr = fmt.Sprintf("%v", value)
+ } else {
+ valueStr = string(jsonBytes)
+ }
+ }
+
+ sb.WriteString("" + valueStr + "")
+ }
+
+ return sb.String()
+}
diff --git a/model/renderers/lfm2.go b/model/renderers/lfm2.go
new file mode 100644
index 00000000000..5c046835f33
--- /dev/null
+++ b/model/renderers/lfm2.go
@@ -0,0 +1,144 @@
+package renderers
+
+import (
+ "encoding/json"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+type LFM2Renderer struct {
+ IsThinking bool
+}
+
+func (r *LFM2Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
+ var sb strings.Builder
+
+ // Note: BOS token is added by the tokenizer (add_bos_token: true), not the renderer
+
+ // Extract first system message if present (to combine with tools)
+ var firstSystemContent string
+ startIdx := 0
+ if len(messages) > 0 && messages[0].Role == "system" {
+ firstSystemContent = messages[0].Content
+ startIdx = 1
+ }
+
+ // Append tools to first system content
+ if len(tools) > 0 {
+ if firstSystemContent != "" {
+ firstSystemContent += "\n"
+ }
+ firstSystemContent += "List of tools: ["
+ for i, tool := range tools {
+ toolJSON, err := json.Marshal(tool)
+ if err != nil {
+ return "", err
+ }
+ firstSystemContent += string(toolJSON)
+ if i < len(tools)-1 {
+ firstSystemContent += ", "
+ }
+ }
+ firstSystemContent += "]"
+ }
+
+ // Output first system block if it has content
+ if firstSystemContent != "" {
+ sb.WriteString("<|im_start|>system\n")
+ sb.WriteString(firstSystemContent)
+ sb.WriteString("<|im_end|>\n")
+ }
+
+ // Find the index of the last assistant message for thinking stripping
+ lastAssistantIndex := -1
+ for i := len(messages) - 1; i >= startIdx; i-- {
+ if messages[i].Role == "assistant" {
+ lastAssistantIndex = i
+ break
+ }
+ }
+
+ // Track whether we need to add generation prompt
+ needsGenerationPrompt := len(messages) > 0
+
+ for i := startIdx; i < len(messages); i++ {
+ message := messages[i]
+ switch message.Role {
+ case "system":
+ // Additional system messages (after the first) are rendered normally
+ sb.WriteString("<|im_start|>system\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("<|im_end|>\n")
+
+ case "user":
+ sb.WriteString("<|im_start|>user\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("<|im_end|>\n")
+ needsGenerationPrompt = true
+
+ case "assistant":
+ sb.WriteString("<|im_start|>assistant\n")
+
+ // Check if this is the last assistant message
+ isLastAssistant := i == lastAssistantIndex
+
+ // Process content (may need thinking stripped)
+ content := message.Content
+
+ // Handle thinking tags in assistant content
+ keepPastThinking := r.IsThinking && (thinkValue != nil && thinkValue.Bool())
+ if strings.Contains(content, "") {
+ parts := strings.SplitN(content, "", 2)
+ if len(parts) > 1 {
+ if !isLastAssistant && !keepPastThinking {
+ // Strip thinking entirely for past assistant messages
+ content = strings.TrimSpace(parts[1])
+ } else {
+ // Preserve thinking but trim whitespace after
+ content = parts[0] + "" + strings.TrimLeft(parts[1], " \t\n\r")
+ }
+ }
+ }
+
+ if len(message.ToolCalls) > 0 {
+ // Assistant with tool calls - write content first (if any after stripping)
+ if content != "" {
+ sb.WriteString(content)
+ }
+
+ for _, toolCall := range message.ToolCalls {
+ sb.WriteString("<|tool_call_start|>")
+ toolCallJSON := map[string]any{
+ "name": toolCall.Function.Name,
+ "arguments": toolCall.Function.Arguments,
+ }
+ callJSON, _ := json.Marshal(toolCallJSON)
+ sb.WriteString(string(callJSON))
+ sb.WriteString("<|tool_call_end|>")
+ }
+ } else {
+ sb.WriteString(content)
+ }
+
+ sb.WriteString("<|im_end|>\n")
+ needsGenerationPrompt = true // Always add gen prompt after assistant when add_generation_prompt=true
+
+ case "tool":
+ // Tool responses are rendered as plain messages per the chat template
+ sb.WriteString("<|im_start|>tool\n")
+ sb.WriteString(message.Content)
+ sb.WriteString("<|im_end|>\n")
+ needsGenerationPrompt = true
+ }
+ }
+
+ // Add generation prompt
+ if needsGenerationPrompt {
+ sb.WriteString("<|im_start|>assistant\n")
+ // Note: Model is a "thinking-only" model - it will output itself
+ // We don't add tag to the prompt
+ }
+
+ return sb.String(), nil
+}
diff --git a/model/renderers/lfm2_test.go b/model/renderers/lfm2_test.go
new file mode 100644
index 00000000000..9eb07eea3d5
--- /dev/null
+++ b/model/renderers/lfm2_test.go
@@ -0,0 +1,427 @@
+package renderers
+
+import (
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestLFM2Renderer(t *testing.T) {
+ tests := []struct {
+ name string
+ messages []api.Message
+ tools []api.Tool
+ thinkValue *api.ThinkValue
+ expected string
+ }{
+ {
+ name: "basic user message",
+ messages: []api.Message{
+ {Role: "user", Content: "Hello!"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "basic with system message",
+ messages: []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Hello!"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "multiple system messages rendered separately",
+ messages: []api.Message{
+ {Role: "system", Content: "First instruction."},
+ {Role: "system", Content: "Second instruction."},
+ {Role: "user", Content: "Hello!"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>system\nFirst instruction.<|im_end|>\n<|im_start|>system\nSecond instruction.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "multi-turn conversation",
+ messages: []api.Message{
+ {Role: "user", Content: "What is 2+2?"},
+ {Role: "assistant", Content: "The answer is 4."},
+ {Role: "user", Content: "Thanks!"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nWhat is 2+2?<|im_end|>\n<|im_start|>assistant\nThe answer is 4.<|im_end|>\n<|im_start|>user\nThanks!<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "only system message",
+ messages: []api.Message{
+ {Role: "system", Content: "You are helpful."},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>system\nYou are helpful.<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // When assistant is the LAST assistant, thinking is preserved (even with keep_past_thinking=false)
+ name: "user-assistant-user: last assistant preserves thinking",
+ messages: []api.Message{
+ {Role: "user", Content: "Q1"},
+ {Role: "assistant", Content: "reasoningA1"},
+ {Role: "user", Content: "Q2"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nreasoningA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // With two assistants, first is stripped (not last), second preserved (is last)
+ name: "multi-turn thinking: first stripped, second preserved",
+ messages: []api.Message{
+ {Role: "user", Content: "Q1"},
+ {Role: "assistant", Content: "reason1A1"},
+ {Role: "user", Content: "Q2"},
+ {Role: "assistant", Content: "reason2A2"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nA1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\nreason2A2<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // With thinking enabled (keep_past_thinking=true), both preserved
+ name: "multi-turn thinking: both preserved when thinking enabled",
+ messages: []api.Message{
+ {Role: "user", Content: "Q1"},
+ {Role: "assistant", Content: "reason1A1"},
+ {Role: "user", Content: "Q2"},
+ {Role: "assistant", Content: "reason2A2"},
+ },
+ thinkValue: &api.ThinkValue{Value: true},
+ expected: "<|im_start|>user\nQ1<|im_end|>\n<|im_start|>assistant\nreason1A1<|im_end|>\n<|im_start|>user\nQ2<|im_end|>\n<|im_start|>assistant\nreason2A2<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "assistant with tool calls",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ {
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
+ },
+ {
+ name: "assistant with content and tool calls",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather in Paris?"},
+ {
+ Role: "assistant",
+ Content: "Let me check.",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: `<|im_start|>user` + "\n" + `What's the weather in Paris?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `Let me check.<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
+ },
+ {
+ name: "tool response",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ {Role: "assistant", Content: "Let me check."},
+ {Role: "tool", Content: "22C, Sunny"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.<|im_end|>\n<|im_start|>tool\n22C, Sunny<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "multiple tool calls",
+ messages: []api.Message{
+ {Role: "user", Content: "Get weather for Paris and London"},
+ {
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "London",
+ }),
+ },
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: `<|im_start|>user` + "\n" + `Get weather for Paris and London<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n" + `<|tool_call_start|>{"arguments":{"location":"Paris"},"name":"get_weather"}<|tool_call_end|><|tool_call_start|>{"arguments":{"location":"London"},"name":"get_weather"}<|tool_call_end|><|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
+ },
+ {
+ name: "tools definitions with system message",
+ messages: []api.Message{
+ {Role: "system", Content: "You are helpful."},
+ {Role: "user", Content: "What's the weather?"},
+ },
+ tools: []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get current weather",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: testPropsMap(map[string]api.ToolProperty{
+ "location": {
+ Type: api.PropertyType{"string"},
+ Description: "City name",
+ },
+ }),
+ Required: []string{"location"},
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: `<|im_start|>system` + "\n" + `You are helpful.` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
+ },
+ {
+ name: "tools definitions without system message",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ },
+ tools: []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get current weather",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: testPropsMap(map[string]api.ToolProperty{
+ "location": {
+ Type: api.PropertyType{"string"},
+ Description: "City name",
+ },
+ }),
+ Required: []string{"location"},
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: `<|im_start|>system` + "\n" + `List of tools: [{"type":"function","function":{"name":"get_weather","description":"Get current weather","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"City name"}}}}}]<|im_end|>` + "\n" + `<|im_start|>user` + "\n" + `What's the weather?<|im_end|>` + "\n" + `<|im_start|>assistant` + "\n",
+ },
+ {
+ name: "multiple tools without system message",
+ messages: []api.Message{
+ {Role: "user", Content: "Hello"},
+ },
+ tools: []api.Tool{
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get weather",
+ },
+ },
+ {
+ Type: "function",
+ Function: api.ToolFunction{
+ Name: "get_time",
+ Description: "Get time",
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>system\nList of tools: [{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather\",\"parameters\":{\"type\":\"\",\"properties\":null}}}, {\"type\":\"function\",\"function\":{\"name\":\"get_time\",\"description\":\"Get time\",\"parameters\":{\"type\":\"\",\"properties\":null}}}]<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "user-tool sequence",
+ messages: []api.Message{
+ {Role: "user", Content: "Check weather"},
+ {Role: "tool", Content: "22C"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "full tool call cycle",
+ messages: []api.Message{
+ {Role: "user", Content: "Check weather"},
+ {Role: "assistant", Content: "Let me check"},
+ {Role: "tool", Content: "22C"},
+ {Role: "assistant", Content: "It's 22C"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nCheck weather<|im_end|>\n<|im_start|>assistant\nLet me check<|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\nIt's 22C<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "unicode content",
+ messages: []api.Message{
+ {Role: "user", Content: "你好世界! مرحبا 🌍"},
+ {Role: "assistant", Content: "Hello! 👋"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\n你好世界! مرحبا 🌍<|im_end|>\n<|im_start|>assistant\nHello! 👋<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "newlines in content",
+ messages: []api.Message{
+ {Role: "user", Content: "Line 1\nLine 2\n\nLine 4"},
+ {Role: "assistant", Content: "Response with\nmultiple\nlines"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nLine 1\nLine 2\n\nLine 4<|im_end|>\n<|im_start|>assistant\nResponse with\nmultiple\nlines<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ name: "empty assistant content",
+ messages: []api.Message{
+ {Role: "user", Content: "Hello"},
+ {Role: "assistant", Content: ""},
+ {Role: "user", Content: "OK"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n<|im_end|>\n<|im_start|>user\nOK<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // Generation prompt does NOT include - model outputs it
+ name: "generation prompt has no think tag",
+ messages: []api.Message{
+ {Role: "user", Content: "Think hard"},
+ },
+ thinkValue: &api.ThinkValue{Value: true},
+ expected: "<|im_start|>user\nThink hard<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // Interleaved: thinking before tool call - last assistant preserves thinking
+ name: "thinking before tool call (last assistant)",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ {
+ Role: "assistant",
+ Content: "I need to check the weather",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nI need to check the weather<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // Two assistants with tool calls - first has thinking stripped
+ name: "two assistants with tools: first thinking stripped",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ {
+ Role: "assistant",
+ Content: "checking",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "22C"},
+ {Role: "assistant", Content: "got resultIt's 22C!"},
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\n<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\ngot resultIt's 22C!<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // Two assistants with tools - both preserved when thinking enabled
+ name: "two assistants with tools: both preserved when thinking enabled",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ {
+ Role: "assistant",
+ Content: "checking",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "22C"},
+ {Role: "assistant", Content: "got resultIt's 22C!"},
+ },
+ thinkValue: &api.ThinkValue{Value: true},
+ expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nchecking<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>tool\n22C<|im_end|>\n<|im_start|>assistant\ngot resultIt's 22C!<|im_end|>\n<|im_start|>assistant\n",
+ },
+ {
+ // Content before thinking before tool call
+ name: "content then thinking then tool call",
+ messages: []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ {
+ Role: "assistant",
+ Content: "Let me check.Using weather API",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: testArgs(map[string]any{
+ "location": "Paris",
+ }),
+ },
+ },
+ },
+ },
+ },
+ thinkValue: &api.ThinkValue{Value: false},
+ expected: "<|im_start|>user\nWhat's the weather?<|im_end|>\n<|im_start|>assistant\nLet me check.Using weather API<|tool_call_start|>{\"arguments\":{\"location\":\"Paris\"},\"name\":\"get_weather\"}<|tool_call_end|><|im_end|>\n<|im_start|>assistant\n",
+ },
+ }
+
+ renderer := &LFM2Renderer{IsThinking: true}
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ rendered, err := renderer.Render(tt.messages, tt.tools, tt.thinkValue)
+ if err != nil {
+ t.Fatalf("Render() error = %v", err)
+ }
+ if diff := cmp.Diff(tt.expected, rendered); diff != "" {
+ t.Errorf("Render() mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}
diff --git a/model/renderers/qwen3coder.go b/model/renderers/qwen3coder.go
index 2b5a5ae9598..33466ba8117 100644
--- a/model/renderers/qwen3coder.go
+++ b/model/renderers/qwen3coder.go
@@ -167,12 +167,12 @@ func (r *Qwen3CoderRenderer) Render(messages []api.Message, tools []api.Tool, _
// only start a new user block if this is the first tool response
if i == 0 || filteredMessages[i-1].Role != "tool" {
- sb.WriteString(imStartTag + "user\n")
+ sb.WriteString(imStartTag + "user")
}
- sb.WriteString("\n")
+ sb.WriteString("\n\n")
sb.WriteString(message.Content)
- sb.WriteString("\n\n")
+ sb.WriteString("\n")
// close the user block only if this is the last tool response
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
diff --git a/model/renderers/qwen3coder_test.go b/model/renderers/qwen3coder_test.go
index b6ca56e7577..9f91c1f67d6 100644
--- a/model/renderers/qwen3coder_test.go
+++ b/model/renderers/qwen3coder_test.go
@@ -1,6 +1,7 @@
package renderers
import (
+ "strings"
"testing"
"github.com/google/go-cmp/cmp"
@@ -127,8 +128,7 @@ fahrenheit
<|im_start|>user
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
-
-<|im_end|>
+<|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
@@ -233,8 +233,7 @@ I'll call double(1) and triple(2) for you.
{"number": 6}
-
-<|im_end|>
+<|im_end|>
<|im_start|>assistant
`,
},
@@ -280,8 +279,7 @@ call tool<|im_end|>
<|im_start|>user
{"payload": {"foo": "bar"}}
-
-<|im_end|>
+<|im_end|>
<|im_start|>assistant
`,
},
@@ -337,6 +335,31 @@ func TestFormatToolCallArgument(t *testing.T) {
}
}
+func TestQwen3CoderRendererToolResponseNoTrailingNewline(t *testing.T) {
+ msgs := []api.Message{
+ {Role: "user", Content: "call tool"},
+ {Role: "assistant", ToolCalls: []api.ToolCall{
+ {Function: api.ToolCallFunction{
+ Name: "echo",
+ Arguments: testArgs(map[string]any{"payload": "ok"}),
+ }},
+ }},
+ {Role: "tool", Content: "{\"payload\":\"ok\"}", ToolName: "echo"},
+ }
+
+ rendered, err := (&Qwen3CoderRenderer{}).Render(msgs, nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if strings.Contains(rendered, "\n<|im_end|>") {
+ t.Fatalf("expected no newline after , got:\n%s", rendered)
+ }
+ if !strings.Contains(rendered, "<|im_end|>") {
+ t.Fatalf("expected to be immediately followed by <|im_end|>, got:\n%s", rendered)
+ }
+}
+
func TestQwen3ToolDefinitionTypes(t *testing.T) {
tests := []struct {
name string
diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go
index dbb63b07c24..baa0bc8c4dd 100644
--- a/model/renderers/renderer.go
+++ b/model/renderers/renderer.go
@@ -82,6 +82,12 @@ func rendererForName(name string) Renderer {
return &FunctionGemmaRenderer{}
case "glm-4.7":
return &GLM47Renderer{}
+ case "glm-ocr":
+ return &GlmOcrRenderer{}
+ case "lfm2":
+ return &LFM2Renderer{IsThinking: false}
+ case "lfm2-thinking":
+ return &LFM2Renderer{IsThinking: true}
default:
return nil
}
diff --git a/model/wordpiece_test.go b/model/wordpiece_test.go
deleted file mode 100644
index c03bb17a725..00000000000
--- a/model/wordpiece_test.go
+++ /dev/null
@@ -1,53 +0,0 @@
-package model
-
-import (
- "slices"
- "testing"
-
- "github.com/google/go-cmp/cmp"
-)
-
-func TestWordPiece(t *testing.T) {
- wpm := NewWordPiece(
- &Vocabulary{
- Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
- AddBOS: true,
- AddEOS: true,
- BOS: []int32{1},
- EOS: []int32{2},
- },
- true, // lowercase
- )
-
- ids, err := wpm.Encode("Hello world!", true)
- if err != nil {
- t.Fatal(err)
- }
-
- if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
- t.Errorf("unexpected ids (-want +got):\n%s", diff)
- }
-
- words, err := wpm.Decode(ids)
- if err != nil {
- t.Fatal(err)
- }
-
- if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
- t.Errorf("unexpected words (-want +got):\n%s", diff)
- }
-}
-
-func TestWordPieceWords(t *testing.T) {
- var wpm WordPiece
-
- basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
- if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
- t.Errorf("unexpected words (-want +got):\n%s", diff)
- }
-
- chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
- if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
- t.Errorf("unexpected words (-want +got):\n%s", diff)
- }
-}
diff --git a/openai/openai.go b/openai/openai.go
index d1f75c4aaa0..acc75535449 100644
--- a/openai/openai.go
+++ b/openai/openai.go
@@ -794,3 +794,47 @@ func ToImageGenerationResponse(resp api.GenerateResponse) ImageGenerationRespons
Data: data,
}
}
+
+// ImageEditRequest is an OpenAI-compatible image edit request.
+type ImageEditRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ Image string `json:"image"` // Base64-encoded image data
+ Size string `json:"size,omitempty"` // e.g., "1024x1024"
+ Seed *int64 `json:"seed,omitempty"`
+}
+
+// FromImageEditRequest converts an OpenAI image edit request to an Ollama GenerateRequest.
+func FromImageEditRequest(r ImageEditRequest) (api.GenerateRequest, error) {
+ req := api.GenerateRequest{
+ Model: r.Model,
+ Prompt: r.Prompt,
+ }
+
+ // Decode the input image
+ if r.Image != "" {
+ imgData, err := decodeImageURL(r.Image)
+ if err != nil {
+ return api.GenerateRequest{}, fmt.Errorf("invalid image: %w", err)
+ }
+ req.Images = append(req.Images, imgData)
+ }
+
+ // Parse size if provided (e.g., "1024x768")
+ if r.Size != "" {
+ var w, h int32
+ if _, err := fmt.Sscanf(r.Size, "%dx%d", &w, &h); err == nil {
+ req.Width = w
+ req.Height = h
+ }
+ }
+
+ if r.Seed != nil {
+ if req.Options == nil {
+ req.Options = map[string]any{}
+ }
+ req.Options["seed"] = *r.Seed
+ }
+
+ return req, nil
+}
diff --git a/openai/openai_test.go b/openai/openai_test.go
index f76af7090f4..b2e98ead446 100644
--- a/openai/openai_test.go
+++ b/openai/openai_test.go
@@ -448,3 +448,86 @@ func TestFromChatRequest_TopLogprobsRange(t *testing.T) {
})
}
}
+
+func TestFromImageEditRequest_Basic(t *testing.T) {
+ req := ImageEditRequest{
+ Model: "test-model",
+ Prompt: "make it blue",
+ Image: prefix + image,
+ }
+
+ result, err := FromImageEditRequest(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if result.Model != "test-model" {
+ t.Errorf("expected model 'test-model', got %q", result.Model)
+ }
+
+ if result.Prompt != "make it blue" {
+ t.Errorf("expected prompt 'make it blue', got %q", result.Prompt)
+ }
+
+ if len(result.Images) != 1 {
+ t.Fatalf("expected 1 image, got %d", len(result.Images))
+ }
+}
+
+func TestFromImageEditRequest_WithSize(t *testing.T) {
+ req := ImageEditRequest{
+ Model: "test-model",
+ Prompt: "make it blue",
+ Image: prefix + image,
+ Size: "512x768",
+ }
+
+ result, err := FromImageEditRequest(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if result.Width != 512 {
+ t.Errorf("expected width 512, got %d", result.Width)
+ }
+
+ if result.Height != 768 {
+ t.Errorf("expected height 768, got %d", result.Height)
+ }
+}
+
+func TestFromImageEditRequest_WithSeed(t *testing.T) {
+ seed := int64(12345)
+ req := ImageEditRequest{
+ Model: "test-model",
+ Prompt: "make it blue",
+ Image: prefix + image,
+ Seed: &seed,
+ }
+
+ result, err := FromImageEditRequest(req)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if result.Options == nil {
+ t.Fatal("expected options to be set")
+ }
+
+ if result.Options["seed"] != seed {
+ t.Errorf("expected seed %d, got %v", seed, result.Options["seed"])
+ }
+}
+
+func TestFromImageEditRequest_InvalidImage(t *testing.T) {
+ req := ImageEditRequest{
+ Model: "test-model",
+ Prompt: "make it blue",
+ Image: "not-valid-base64",
+ }
+
+ _, err := FromImageEditRequest(req)
+ if err == nil {
+ t.Error("expected error for invalid image")
+ }
+}
diff --git a/readline/errors.go b/readline/errors.go
index bb3fbd4738f..1be5213e560 100644
--- a/readline/errors.go
+++ b/readline/errors.go
@@ -5,6 +5,7 @@ import (
)
var ErrInterrupt = errors.New("Interrupt")
+var ErrEditPrompt = errors.New("EditPrompt")
type InterruptError struct {
Line []rune
diff --git a/readline/readline.go b/readline/readline.go
index e18a256484e..0113aa2cef2 100644
--- a/readline/readline.go
+++ b/readline/readline.go
@@ -41,6 +41,7 @@ type Instance struct {
Terminal *Terminal
History *History
Pasting bool
+ Prefill string
pastedLines []string
}
@@ -89,13 +90,48 @@ func (i *Instance) Readline() (string, error) {
buf, _ := NewBuffer(i.Prompt)
+ // Prefill the buffer with any text that we received from an external editor
+ if i.Prefill != "" {
+ lines := strings.Split(i.Prefill, "\n")
+ i.Prefill = ""
+ for idx, l := range lines {
+ for _, r := range l {
+ buf.Add(r)
+ }
+ if idx < len(lines)-1 {
+ i.pastedLines = append(i.pastedLines, buf.String())
+ buf.Buf.Clear()
+ buf.Pos = 0
+ buf.DisplayPos = 0
+ buf.LineHasSpace.Clear()
+ fmt.Println()
+ fmt.Print(i.Prompt.AltPrompt)
+ i.Prompt.UseAlt = true
+ }
+ }
+ }
+
var esc bool
var escex bool
var metaDel bool
var currentLineBuf []rune
+ // draining tracks if we're processing buffered input from cooked mode.
+ // In cooked mode Enter sends \n, but in raw mode Ctrl+J sends \n.
+ // We treat \n from cooked mode as submit, not multiline.
+ // We check Buffered() after the first read since the bufio buffer is
+ // empty until then. This is compatible with """ multiline mode in
+ // interactive.go since each Readline() call is independent.
+ var draining, stopDraining bool
+
for {
+ // Apply deferred state change from previous iteration
+ if stopDraining {
+ draining = false
+ stopDraining = false
+ }
+
// don't show placeholder when pasting unless we're in multiline mode
showPlaceholder := !i.Pasting || i.Prompt.UseAlt
if buf.IsEmpty() && showPlaceholder {
@@ -105,6 +141,15 @@ func (i *Instance) Readline() (string, error) {
r, err := i.Terminal.Read()
+ // After reading, check if there's more buffered data. If so, we're
+ // processing cooked-mode input. Once buffer empties, the current
+ // char is the last buffered one (still drain it), then stop next iteration.
+ if i.Terminal.reader.Buffered() > 0 {
+ draining = true
+ } else if draining {
+ stopDraining = true
+ }
+
if buf.IsEmpty() {
fmt.Print(ClearToEOL)
}
@@ -228,19 +273,47 @@ func (i *Instance) Readline() (string, error) {
buf.ClearScreen()
case CharCtrlW:
buf.DeleteWord()
+ case CharBell:
+ output := buf.String()
+ numPastedLines := len(i.pastedLines)
+ if numPastedLines > 0 {
+ output = strings.Join(i.pastedLines, "\n") + "\n" + output
+ i.pastedLines = nil
+ }
+
+ // Move cursor to the last display line of the current buffer
+ currLine := buf.DisplayPos / buf.LineWidth
+ lastLine := buf.DisplaySize() / buf.LineWidth
+ if lastLine > currLine {
+ fmt.Print(CursorDownN(lastLine - currLine))
+ }
+
+ // Clear all lines from bottom to top: buffer wrapped lines + pasted lines
+ for range lastLine + numPastedLines {
+ fmt.Print(CursorBOL + ClearToEOL + CursorUp)
+ }
+ fmt.Print(CursorBOL + ClearToEOL)
+
+ i.Prompt.UseAlt = false
+ return output, ErrEditPrompt
case CharCtrlZ:
fd := os.Stdin.Fd()
return handleCharCtrlZ(fd, i.Terminal.termios)
case CharCtrlJ:
- i.pastedLines = append(i.pastedLines, buf.String())
- buf.Buf.Clear()
- buf.Pos = 0
- buf.DisplayPos = 0
- buf.LineHasSpace.Clear()
- fmt.Println()
- fmt.Print(i.Prompt.AltPrompt)
- i.Prompt.UseAlt = true
- continue
+ // If not draining cooked-mode input, treat as multiline
+ if !draining {
+ i.pastedLines = append(i.pastedLines, buf.String())
+ buf.Buf.Clear()
+ buf.Pos = 0
+ buf.DisplayPos = 0
+ buf.LineHasSpace.Clear()
+ fmt.Println()
+ fmt.Print(i.Prompt.AltPrompt)
+ i.Prompt.UseAlt = true
+ continue
+ }
+ // Draining cooked-mode input: treat \n as submit
+ fallthrough
case CharEnter:
output := buf.String()
if len(i.pastedLines) > 0 {
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index faab1b229ca..895a8fb77e9 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -124,8 +124,17 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
}
if c.cache != nil {
- if numPast > 0 && !c.cache.CanResume(slot.Id, numPast) {
- numPast = 0
+ if numPast > 0 {
+ // Recurrent caches use checkpoints to pick a safe resume position.
+ if cc, ok := c.cache.(kvcache.CheckpointCache); ok {
+ if restored, ok := cc.PrepareRestore(slot.Id, numPast); ok {
+ numPast = restored
+ } else {
+ numPast = 0
+ }
+ } else if !c.cache.CanResume(slot.Id, numPast) {
+ numPast = 0
+ }
}
err = c.cache.Remove(slot.Id, numPast, math.MaxInt32)
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 49f4ea902eb..d314eda3b62 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -37,6 +37,7 @@ import (
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
+ "github.com/ollama/ollama/tokenizer"
_ "github.com/ollama/ollama/model/models"
)
@@ -210,9 +211,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// calculateLogprobs converts raw logits to log probabilities and finds top K tokens
-func calculateLogprobs(logits []float32, selectedToken int32, topK int, textProcessor model.TextProcessor) []llm.Logprob {
+func calculateLogprobs(logits []float32, selectedToken int32, topK int, tok tokenizer.Tokenizer) []llm.Logprob {
decoder := func(tokenID int) string {
- text, _ := textProcessor.Decode([]int32{int32(tokenID)})
+ text, _ := tok.Decode([]int32{int32(tokenID)})
return text
}
return common.CalculateLogprobs(logits, int(selectedToken), topK, decoder)
@@ -242,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
for i, part := range parts {
// text - tokenize
- tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
+ tokens, err := s.model.(tokenizer.Tokenizer).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
}
@@ -516,13 +517,6 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
continue
}
- // if past the num predict limit
- if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
- s.removeSequence(seqIdx, llm.DoneReasonLength)
- nextBatch.seqs[seqIdx] = nil
- continue
- }
-
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
@@ -711,7 +705,6 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
- seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
@@ -742,8 +735,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
if seq == nil || nextBatchTokens[i] == nil {
continue
}
+ // If the sequence was replaced while this batch was computing, discard results.
+ if activeBatch.seqs[i] != seq {
+ logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
+ continue
+ }
seq.lastUpdatedAt = t
+ seq.numPredicted++
if seq.numPredicted == 1 {
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
@@ -768,7 +767,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
nextBatchTokens[i].Token = token
// if it's an end of sequence token, break
- if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
+ if s.model.(tokenizer.Tokenizer).Is(token, tokenizer.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
@@ -777,18 +776,25 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
- piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
+ piece, err := s.model.(tokenizer.Tokenizer).Decode([]int32{token})
if err != nil {
panic("failed to decode token")
}
// Calculate logprobs if requested (after EOS check to avoid logprobs for EOS tokens)
if seq.logprobs {
- logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(model.TextProcessor))
+ logprobs := calculateLogprobs(logits, token, seq.topLogprobs, s.model.(tokenizer.Tokenizer))
seq.pendingLogprobs = append(seq.pendingLogprobs, logprobs...)
}
seq.pendingResponses = append(seq.pendingResponses, piece)
+
+ // if past the num predict limit
+ if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
+ s.removeSequence(i, llm.DoneReasonLength)
+ continue
+ }
+
sequence := strings.Join(seq.pendingResponses, "")
if ok, stop := common.FindStop(sequence, seq.stop); ok {
@@ -875,7 +881,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
var grammar *sample.GrammarSampler
var err error
if req.Grammar != "" {
- grammar, err = sample.NewGrammarSampler(s.model.(model.TextProcessor), req.Grammar)
+ grammar, err = sample.NewGrammarSampler(s.model.(tokenizer.Tokenizer), req.Grammar)
if err != nil {
http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
return
@@ -1362,7 +1368,7 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
// Dummy load to get the backend wired up
f, err := os.CreateTemp("", "*.bin")
if err != nil {
- http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
+ http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
defer f.Close()
@@ -1372,13 +1378,13 @@ func (s *Server) info(w http.ResponseWriter, r *http.Request) {
"general.architecture": "llama",
"tokenizer.ggml.model": "gpt2",
}, nil); err != nil {
- http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
+ http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
m, err = model.New(f.Name(), make([]string, 0), ml.BackendParams{NumThreads: runtime.NumCPU(), AllocMemory: false, GPULayers: ml.GPULayersList{{}}})
if err != nil {
- http.Error(w, fmt.Sprintf("failed to initialize baackend: %v", err), http.StatusInternalServerError)
+ http.Error(w, fmt.Sprintf("failed to initialize backend: %v", err), http.StatusInternalServerError)
return
}
slog.Debug("dummy model load took", "duration", time.Since(startLoad))
diff --git a/runner/runner.go b/runner/runner.go
index 5434107980e..d2daddf6946 100644
--- a/runner/runner.go
+++ b/runner/runner.go
@@ -3,7 +3,8 @@ package runner
import (
"github.com/ollama/ollama/runner/llamarunner"
"github.com/ollama/ollama/runner/ollamarunner"
- imagerunner "github.com/ollama/ollama/x/imagegen/runner"
+ "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/mlxrunner"
)
func Execute(args []string) error {
@@ -11,22 +12,15 @@ func Execute(args []string) error {
args = args[1:]
}
- var newRunner bool
- var imageRunner bool
- if len(args) > 0 && args[0] == "--ollama-engine" {
- args = args[1:]
- newRunner = true
- }
- if len(args) > 0 && args[0] == "--image-engine" {
- args = args[1:]
- imageRunner = true
- }
-
- if imageRunner {
- return imagerunner.Execute(args)
- } else if newRunner {
- return ollamarunner.Execute(args)
- } else {
- return llamarunner.Execute(args)
+ if len(args) > 0 {
+ switch args[0] {
+ case "--ollama-engine":
+ return ollamarunner.Execute(args[1:])
+ case "--imagegen-engine":
+ return imagegen.Execute(args[1:])
+ case "--mlx-engine":
+ return mlxrunner.Execute(args[1:])
+ }
}
+ return llamarunner.Execute(args)
}
diff --git a/sample/samplers.go b/sample/samplers.go
index d395650d9ed..eb17992861e 100644
--- a/sample/samplers.go
+++ b/sample/samplers.go
@@ -7,7 +7,7 @@ import (
"slices"
"github.com/ollama/ollama/llama"
- "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/tokenizer"
)
// token represents information about a single token during sampling
@@ -168,15 +168,15 @@ type GrammarSampler struct {
grammar *llama.Grammar
}
-func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSampler, error) {
- vocabIds := make([]uint32, len(model.Vocabulary().Values))
- pieces := make([]string, len(model.Vocabulary().Values))
- for i := range model.Vocabulary().Values {
- pieces[i], _ = model.Decode([]int32{int32(i)})
+func NewGrammarSampler(tok tokenizer.Tokenizer, grammarStr string) (*GrammarSampler, error) {
+ vocabIds := make([]uint32, len(tok.Vocabulary().Values))
+ pieces := make([]string, len(tok.Vocabulary().Values))
+ for i := range tok.Vocabulary().Values {
+ pieces[i], _ = tok.Decode([]int32{int32(i)})
vocabIds[i] = uint32(i)
}
- grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
+ grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, tok.Vocabulary().EOS)
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}
diff --git a/sample/samplers_test.go b/sample/samplers_test.go
index eb10295d453..9850d6e449b 100644
--- a/sample/samplers_test.go
+++ b/sample/samplers_test.go
@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
- "github.com/ollama/ollama/model"
+ "github.com/ollama/ollama/tokenizer"
)
func TestWeighted(t *testing.T) {
@@ -60,10 +60,10 @@ func TestWeighted(t *testing.T) {
}
}
-func modelHelper(t testing.TB) model.BytePairEncoding {
+func modelHelper(t testing.TB) tokenizer.Tokenizer {
t.Helper()
- f, err := os.Open(filepath.Join("..", "model", "testdata", "llama3.2", "encoder.json"))
+ f, err := os.Open(filepath.FromSlash("../tokenizer/testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -81,8 +81,8 @@ func modelHelper(t testing.TB) model.BytePairEncoding {
merges := make([]string, 0, 1)
// Only need vocab for Grammar Test
- return model.NewBytePairEncoding(
- &model.Vocabulary{
+ return tokenizer.NewBytePairEncoding(
+ &tokenizer.Vocabulary{
Values: tokens,
Types: make([]int32, len(vocab)),
Merges: merges,
diff --git a/scripts/build_darwin.sh b/scripts/build_darwin.sh
index 3560520ff94..4325a978792 100755
--- a/scripts/build_darwin.sh
+++ b/scripts/build_darwin.sh
@@ -14,8 +14,8 @@
VOL_NAME=${VOL_NAME:-"Ollama"}
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${VERSION#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'"
-export CGO_CFLAGS="-mmacosx-version-min=14.0"
-export CGO_CXXFLAGS="-mmacosx-version-min=14.0"
+export CGO_CFLAGS="-O3 -mmacosx-version-min=14.0"
+export CGO_CXXFLAGS="-O3 -mmacosx-version-min=14.0"
export CGO_LDFLAGS="-mmacosx-version-min=14.0"
set -e
diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1
index 30fb9d09afe..21e6f3be0e0 100644
--- a/scripts/build_windows.ps1
+++ b/scripts/build_windows.ps1
@@ -56,6 +56,12 @@ function checkEnv {
$script:DIST_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$env:CGO_ENABLED="1"
+ if (-not $env:CGO_CFLAGS) {
+ $env:CGO_CFLAGS = "-O3"
+ }
+ if (-not $env:CGO_CXXFLAGS) {
+ $env:CGO_CXXFLAGS = "-O3"
+ }
Write-Output "Checking version"
if (!$env:VERSION) {
$data=(git describe --tags --first-parent --abbrev=7 --long --dirty --always)
@@ -296,12 +302,22 @@ function deps {
}
function sign {
+ # Copy install.ps1 to dist for release packaging
+ write-host "Copying install.ps1 to dist"
+ Copy-Item -Path "${script:SRC_DIR}\scripts\install.ps1" -Destination "${script:SRC_DIR}\dist\install.ps1"
+
if ("${env:KEY_CONTAINER}") {
write-host "Signing Ollama executables, scripts and libraries"
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
$(get-childitem -path "${script:SRC_DIR}\dist\windows-*" -r -include @('*.exe', '*.dll'))
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
+
+ write-host "Signing install.ps1"
+ & "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
+ /csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} `
+ "${script:SRC_DIR}\dist\install.ps1"
+ if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
} else {
write-host "Signing not enabled"
}
diff --git a/scripts/install.ps1 b/scripts/install.ps1
new file mode 100644
index 00000000000..b31cc887ccf
--- /dev/null
+++ b/scripts/install.ps1
@@ -0,0 +1,323 @@
+<#
+.SYNOPSIS
+ Install, upgrade, or uninstall Ollama on Windows.
+
+.DESCRIPTION
+ Downloads and installs Ollama.
+
+ Quick install:
+
+ irm https://ollama.com/install.ps1 | iex
+
+ Specific version:
+
+ $env:OLLAMA_VERSION="0.5.7"; irm https://ollama.com/install.ps1 | iex
+
+ Custom install directory:
+
+ $env:OLLAMA_INSTALL_DIR="D:\Ollama"; irm https://ollama.com/install.ps1 | iex
+
+ Uninstall:
+
+ $env:OLLAMA_UNINSTALL=1; irm https://ollama.com/install.ps1 | iex
+
+ Environment variables:
+
+ OLLAMA_VERSION Target version (default: latest stable)
+ OLLAMA_INSTALL_DIR Custom install directory
+ OLLAMA_UNINSTALL Set to 1 to uninstall Ollama
+ OLLAMA_DEBUG Enable verbose output
+
+.EXAMPLE
+ irm https://ollama.com/install.ps1 | iex
+
+.EXAMPLE
+ $env:OLLAMA_VERSION = "0.5.7"; irm https://ollama.com/install.ps1 | iex
+
+.LINK
+ https://ollama.com
+#>
+
+$ErrorActionPreference = "Stop"
+$ProgressPreference = "SilentlyContinue"
+
+# --------------------------------------------------------------------------
+# Configuration from environment variables
+# --------------------------------------------------------------------------
+
+$Version = if ($env:OLLAMA_VERSION) { $env:OLLAMA_VERSION } else { "" }
+$InstallDir = if ($env:OLLAMA_INSTALL_DIR) { $env:OLLAMA_INSTALL_DIR } else { "" }
+$Uninstall = $env:OLLAMA_UNINSTALL -eq "1"
+$DebugInstall = [bool]$env:OLLAMA_DEBUG
+
+# --------------------------------------------------------------------------
+# Constants
+# --------------------------------------------------------------------------
+
+# OLLAMA_DOWNLOAD_URL for developer testing only
+$DownloadBaseURL = if ($env:OLLAMA_DOWNLOAD_URL) { $env:OLLAMA_DOWNLOAD_URL.TrimEnd('/') } else { "https://ollama.com/download" }
+$InnoSetupUninstallGuid = "{44E83376-CE68-45EB-8FC1-393500EB558C}_is1"
+
+# --------------------------------------------------------------------------
+# Helpers
+# --------------------------------------------------------------------------
+
+function Write-Status {
+ param([string]$Message)
+ if ($DebugInstall) { Write-Host $Message }
+}
+
+function Write-Step {
+ param([string]$Message)
+ if ($DebugInstall) { Write-Host ">>> $Message" -ForegroundColor Cyan }
+}
+
+function Test-Signature {
+ param([string]$FilePath)
+
+ $sig = Get-AuthenticodeSignature -FilePath $FilePath
+ if ($sig.Status -ne "Valid") {
+ Write-Status " Signature status: $($sig.Status)"
+ return $false
+ }
+
+ # Verify it's signed by Ollama Inc. (check exact organization name)
+ # Anchor with comma/boundary to prevent "O=Not Ollama Inc." from matching
+ $subject = $sig.SignerCertificate.Subject
+ if ($subject -notmatch "(^|, )O=Ollama Inc\.(,|$)") {
+ Write-Status " Unexpected signer: $subject"
+ return $false
+ }
+
+ Write-Status " Signature valid: $subject"
+ return $true
+}
+
+function Find-InnoSetupInstall {
+ # Check both HKCU (per-user) and HKLM (per-machine) locations
+ $possibleKeys = @(
+ "HKCU:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
+ "HKLM:\Software\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid",
+ "HKLM:\Software\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\$InnoSetupUninstallGuid"
+ )
+
+ foreach ($key in $possibleKeys) {
+ if (Test-Path $key) {
+ Write-Status " Found install at: $key"
+ return $key
+ }
+ }
+ return $null
+}
+
+function Update-SessionPath {
+ # Update PATH in current session so 'ollama' works immediately
+ if ($InstallDir) {
+ $ollamaDir = $InstallDir
+ } else {
+ $ollamaDir = Join-Path $env:LOCALAPPDATA "Programs\Ollama"
+ }
+
+ # Add to PATH if not already present
+ if (Test-Path $ollamaDir) {
+ $currentPath = $env:PATH -split ';'
+ if ($ollamaDir -notin $currentPath) {
+ $env:PATH = "$ollamaDir;$env:PATH"
+ Write-Status " Added $ollamaDir to session PATH"
+ }
+ }
+}
+
+function Invoke-Download {
+ param(
+ [string]$Url,
+ [string]$OutFile
+ )
+
+ Write-Status " Downloading: $Url"
+ try {
+ $request = [System.Net.HttpWebRequest]::Create($Url)
+ $request.AllowAutoRedirect = $true
+ $response = $request.GetResponse()
+ $totalBytes = $response.ContentLength
+ $stream = $response.GetResponseStream()
+ $fileStream = [System.IO.FileStream]::new($OutFile, [System.IO.FileMode]::Create)
+ $buffer = [byte[]]::new(65536)
+ $totalRead = 0
+ $lastUpdate = [DateTime]::MinValue
+ $barWidth = 40
+
+ try {
+ while (($read = $stream.Read($buffer, 0, $buffer.Length)) -gt 0) {
+ $fileStream.Write($buffer, 0, $read)
+ $totalRead += $read
+
+ $now = [DateTime]::UtcNow
+ if (($now - $lastUpdate).TotalMilliseconds -ge 250) {
+ if ($totalBytes -gt 0) {
+ $pct = [math]::Min(100.0, ($totalRead / $totalBytes) * 100)
+ $filled = [math]::Floor($barWidth * $pct / 100)
+ $empty = $barWidth - $filled
+ $bar = ('#' * $filled) + (' ' * $empty)
+ $pctFmt = $pct.ToString("0.0")
+ Write-Host -NoNewline "`r$bar ${pctFmt}%"
+ } else {
+ $sizeMB = [math]::Round($totalRead / 1MB, 1)
+ Write-Host -NoNewline "`r${sizeMB} MB downloaded..."
+ }
+ $lastUpdate = $now
+ }
+ }
+
+ # Final progress update
+ if ($totalBytes -gt 0) {
+ $bar = '#' * $barWidth
+ Write-Host "`r$bar 100.0%"
+ } else {
+ $sizeMB = [math]::Round($totalRead / 1MB, 1)
+ Write-Host "`r${sizeMB} MB downloaded. "
+ }
+ } finally {
+ $fileStream.Close()
+ $stream.Close()
+ $response.Close()
+ }
+ } catch {
+ if ($_.Exception -is [System.Net.WebException]) {
+ $webEx = [System.Net.WebException]$_.Exception
+ if ($webEx.Response -and ([System.Net.HttpWebResponse]$webEx.Response).StatusCode -eq [System.Net.HttpStatusCode]::NotFound) {
+ throw "Download failed: not found at $Url"
+ }
+ }
+ if ($_.Exception.InnerException -is [System.Net.WebException]) {
+ $webEx = [System.Net.WebException]$_.Exception.InnerException
+ if ($webEx.Response -and ([System.Net.HttpWebResponse]$webEx.Response).StatusCode -eq [System.Net.HttpStatusCode]::NotFound) {
+ throw "Download failed: not found at $Url"
+ }
+ }
+ throw "Download failed for ${Url}: $($_.Exception.Message)"
+ }
+}
+
+# --------------------------------------------------------------------------
+# Uninstall
+# --------------------------------------------------------------------------
+
+function Invoke-Uninstall {
+ Write-Step "Uninstalling Ollama"
+
+ $regKey = Find-InnoSetupInstall
+ if (-not $regKey) {
+ Write-Host ">>> Ollama is not installed."
+ return
+ }
+
+ $uninstallString = (Get-ItemProperty -Path $regKey).UninstallString
+ if (-not $uninstallString) {
+ Write-Warning "No uninstall string found in registry"
+ return
+ }
+
+ # Strip quotes if present
+ $uninstallExe = $uninstallString -replace '"', ''
+ Write-Status " Uninstaller: $uninstallExe"
+
+ if (-not (Test-Path $uninstallExe)) {
+ Write-Warning "Uninstaller not found at: $uninstallExe"
+ return
+ }
+
+ Write-Host ">>> Launching uninstaller..."
+ # Run with GUI so user can choose whether to keep models
+ Start-Process -FilePath $uninstallExe -Wait
+
+ # Verify removal
+ if (Find-InnoSetupInstall) {
+ Write-Warning "Uninstall may not have completed"
+ } else {
+ Write-Host ">>> Ollama has been uninstalled."
+ }
+}
+
+# --------------------------------------------------------------------------
+# Install
+# --------------------------------------------------------------------------
+
+function Invoke-Install {
+ # Determine installer URL
+ if ($Version) {
+ $installerUrl = "$DownloadBaseURL/OllamaSetup.exe?version=$Version"
+ } else {
+ $installerUrl = "$DownloadBaseURL/OllamaSetup.exe"
+ }
+
+ # Download installer
+ Write-Step "Downloading Ollama"
+ if (-not $DebugInstall) {
+ Write-Host ">>> Downloading Ollama for Windows..."
+ }
+
+ $tempInstaller = Join-Path $env:TEMP "OllamaSetup.exe"
+ Invoke-Download -Url $installerUrl -OutFile $tempInstaller
+
+ # Verify signature
+ Write-Step "Verifying signature"
+ if (-not (Test-Signature -FilePath $tempInstaller)) {
+ Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
+ throw "Installer signature verification failed"
+ }
+
+ # Build installer arguments
+ $installerArgs = "/VERYSILENT /NORESTART /SUPPRESSMSGBOXES"
+ if ($InstallDir) {
+ $installerArgs += " /DIR=`"$InstallDir`""
+ }
+ Write-Status " Installer args: $installerArgs"
+
+ # Run installer
+ Write-Step "Installing Ollama"
+ if (-not $DebugInstall) {
+ Write-Host ">>> Installing Ollama..."
+ }
+
+ # Create upgrade marker so the app starts hidden
+ # The app checks for this file on startup and removes it after
+ $markerDir = Join-Path $env:LOCALAPPDATA "Ollama"
+ $markerFile = Join-Path $markerDir "upgraded"
+ if (-not (Test-Path $markerDir)) {
+ New-Item -ItemType Directory -Path $markerDir -Force | Out-Null
+ }
+ New-Item -ItemType File -Path $markerFile -Force | Out-Null
+ Write-Status " Created upgrade marker: $markerFile"
+
+ # Start installer and wait for just the installer process (not children)
+ # Using -Wait would wait for Ollama to exit too, which we don't want
+ $proc = Start-Process -FilePath $tempInstaller `
+ -ArgumentList $installerArgs `
+ -PassThru
+ $proc.WaitForExit()
+
+ if ($proc.ExitCode -ne 0) {
+ Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
+ throw "Installation failed with exit code $($proc.ExitCode)"
+ }
+
+ # Cleanup
+ Remove-Item $tempInstaller -Force -ErrorAction SilentlyContinue
+
+ # Update PATH in current session so 'ollama' works immediately
+ Write-Step "Updating session PATH"
+ Update-SessionPath
+
+ Write-Host ">>> Install complete. Run 'ollama' from the command line."
+}
+
+# --------------------------------------------------------------------------
+# Main
+# --------------------------------------------------------------------------
+
+if ($Uninstall) {
+ Invoke-Uninstall
+} else {
+ Invoke-Install
+}
diff --git a/scripts/install.sh b/scripts/install.sh
old mode 100644
new mode 100755
index 8e0b8ed8feb..8bff7e2f41a
--- a/scripts/install.sh
+++ b/scripts/install.sh
@@ -1,7 +1,11 @@
#!/bin/sh
-# This script installs Ollama on Linux.
+# This script installs Ollama on Linux and macOS.
# It detects the current operating system architecture and installs the appropriate version of Ollama.
+# Wrap script in main function so that a truncated partial download doesn't end
+# up executing half a script.
+main() {
+
set -eu
red="$( (/usr/bin/tput bold || :; /usr/bin/tput setaf 1 || :) 2>&-)"
@@ -27,8 +31,7 @@ require() {
echo $MISSING
}
-[ "$(uname -s)" = "Linux" ] || error 'This script is intended to run on Linux only.'
-
+OS="$(uname -s)"
ARCH=$(uname -m)
case "$ARCH" in
x86_64) ARCH="amd64" ;;
@@ -36,6 +39,65 @@ case "$ARCH" in
*) error "Unsupported architecture: $ARCH" ;;
esac
+VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
+
+###########################################
+# macOS
+###########################################
+
+if [ "$OS" = "Darwin" ]; then
+ NEEDS=$(require curl unzip)
+ if [ -n "$NEEDS" ]; then
+ status "ERROR: The following tools are required but missing:"
+ for NEED in $NEEDS; do
+ echo " - $NEED"
+ done
+ exit 1
+ fi
+
+ DOWNLOAD_URL="https://ollama.com/download/Ollama-darwin.zip${VER_PARAM}"
+
+ if pgrep -x Ollama >/dev/null 2>&1; then
+ status "Stopping running Ollama instance..."
+ pkill -x Ollama 2>/dev/null || true
+ sleep 2
+ fi
+
+ if [ -d "/Applications/Ollama.app" ]; then
+ status "Removing existing Ollama installation..."
+ rm -rf "/Applications/Ollama.app"
+ fi
+
+ status "Downloading Ollama for macOS..."
+ curl --fail --show-error --location --progress-bar \
+ -o "$TEMP_DIR/Ollama-darwin.zip" "$DOWNLOAD_URL"
+
+ status "Installing Ollama to /Applications..."
+ unzip -q "$TEMP_DIR/Ollama-darwin.zip" -d "$TEMP_DIR"
+ mv "$TEMP_DIR/Ollama.app" "/Applications/"
+
+ if [ ! -L "/usr/local/bin/ollama" ] || [ "$(readlink "/usr/local/bin/ollama")" != "/Applications/Ollama.app/Contents/Resources/ollama" ]; then
+ status "Adding 'ollama' command to PATH (may require password)..."
+ mkdir -p "/usr/local/bin" 2>/dev/null || sudo mkdir -p "/usr/local/bin"
+ ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama" 2>/dev/null || \
+ sudo ln -sf "/Applications/Ollama.app/Contents/Resources/ollama" "/usr/local/bin/ollama"
+ fi
+
+ if [ -z "${OLLAMA_NO_START:-}" ]; then
+ status "Starting Ollama..."
+ open -a Ollama --args hidden
+ fi
+
+ status "Install complete. You can now run 'ollama'."
+ exit 0
+fi
+
+###########################################
+# Linux
+###########################################
+
+[ "$OS" = "Linux" ] || error 'This script is intended to run on Linux and macOS only.'
+
IS_WSL2=false
KERN=$(uname -r)
@@ -45,8 +107,6 @@ case "$KERN" in
*) ;;
esac
-VER_PARAM="${OLLAMA_VERSION:+?version=$OLLAMA_VERSION}"
-
SUDO=
if [ "$(id -u)" -ne 0 ]; then
# Running as root, no need for sudo
@@ -390,3 +450,6 @@ fi
status "NVIDIA GPU ready."
install_success
+}
+
+main
diff --git a/server/aliases.go b/server/aliases.go
new file mode 100644
index 00000000000..18e9447e5cc
--- /dev/null
+++ b/server/aliases.go
@@ -0,0 +1,438 @@
+package server
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "sort"
+ "strings"
+ "sync"
+
+ "github.com/ollama/ollama/manifest"
+ "github.com/ollama/ollama/types/model"
+)
+
+const (
+ serverConfigFilename = "server.json"
+ serverConfigVersion = 1
+)
+
+var errAliasCycle = errors.New("alias cycle detected")
+
+type aliasEntry struct {
+ Alias string `json:"alias"`
+ Target string `json:"target"`
+ PrefixMatching bool `json:"prefix_matching,omitempty"`
+}
+
+type serverConfig struct {
+ Version int `json:"version"`
+ Aliases []aliasEntry `json:"aliases"`
+}
+
+type store struct {
+ mu sync.RWMutex
+ path string
+ entries map[string]aliasEntry // normalized alias -> entry (exact matches)
+ prefixEntries []aliasEntry // prefix matches, sorted longest-first
+}
+
+func createStore(path string) (*store, error) {
+ store := &store{
+ path: path,
+ entries: make(map[string]aliasEntry),
+ }
+ if err := store.load(); err != nil {
+ return nil, err
+ }
+ return store, nil
+}
+
+func (s *store) load() error {
+ data, err := os.ReadFile(s.path)
+ if err != nil {
+ if errors.Is(err, os.ErrNotExist) {
+ return nil
+ }
+ return err
+ }
+
+ var cfg serverConfig
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ return err
+ }
+
+ if cfg.Version != 0 && cfg.Version != serverConfigVersion {
+ return fmt.Errorf("unsupported router config version %d", cfg.Version)
+ }
+
+ for _, entry := range cfg.Aliases {
+ targetName := model.ParseName(entry.Target)
+ if !targetName.IsValid() {
+ slog.Warn("invalid alias target in router config", "target", entry.Target)
+ continue
+ }
+ canonicalTarget := displayAliasName(targetName)
+
+ if entry.PrefixMatching {
+ // Prefix aliases don't need to be valid model names
+ alias := strings.TrimSpace(entry.Alias)
+ if alias == "" {
+ slog.Warn("empty prefix alias in router config")
+ continue
+ }
+ s.prefixEntries = append(s.prefixEntries, aliasEntry{
+ Alias: alias,
+ Target: canonicalTarget,
+ PrefixMatching: true,
+ })
+ } else {
+ aliasName := model.ParseName(entry.Alias)
+ if !aliasName.IsValid() {
+ slog.Warn("invalid alias name in router config", "alias", entry.Alias)
+ continue
+ }
+ canonicalAlias := displayAliasName(aliasName)
+ s.entries[normalizeAliasKey(aliasName)] = aliasEntry{
+ Alias: canonicalAlias,
+ Target: canonicalTarget,
+ }
+ }
+ }
+
+ // Sort prefix entries by alias length descending (longest prefix wins)
+ s.sortPrefixEntriesLocked()
+
+ return nil
+}
+
+func (s *store) saveLocked() error {
+ dir := filepath.Dir(s.path)
+ if err := os.MkdirAll(dir, 0o755); err != nil {
+ return err
+ }
+
+ // Read existing file into a generic map to preserve unknown fields
+ // (e.g. disable_ollama_cloud) that aliasStore doesn't own.
+ existing := make(map[string]json.RawMessage)
+ if data, err := os.ReadFile(s.path); err == nil {
+ if err := json.Unmarshal(data, &existing); err != nil {
+ slog.Debug("failed to parse existing server config; preserving unknown fields skipped", "path", s.path, "error", err)
+ }
+ }
+
+ // Combine exact and prefix entries
+ entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
+ for _, entry := range s.entries {
+ entries = append(entries, entry)
+ }
+ entries = append(entries, s.prefixEntries...)
+
+ sort.Slice(entries, func(i, j int) bool {
+ return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
+ })
+
+ // Overwrite only the keys we own
+ versionJSON, err := json.Marshal(serverConfigVersion)
+ if err != nil {
+ return err
+ }
+ aliasesJSON, err := json.Marshal(entries)
+ if err != nil {
+ return err
+ }
+ existing["version"] = versionJSON
+ existing["aliases"] = aliasesJSON
+
+ f, err := os.CreateTemp(dir, "router-*.json")
+ if err != nil {
+ return err
+ }
+
+ enc := json.NewEncoder(f)
+ enc.SetIndent("", " ")
+ if err := enc.Encode(existing); err != nil {
+ _ = f.Close()
+ _ = os.Remove(f.Name())
+ return err
+ }
+
+ if err := f.Close(); err != nil {
+ _ = os.Remove(f.Name())
+ return err
+ }
+
+ if err := os.Chmod(f.Name(), 0o644); err != nil {
+ _ = os.Remove(f.Name())
+ return err
+ }
+
+ return os.Rename(f.Name(), s.path)
+}
+
+func (s *store) ResolveName(name model.Name) (model.Name, bool, error) {
+ // If a local model exists, do not allow alias shadowing (highest priority).
+ exists, err := localModelExists(name)
+ if err != nil {
+ return name, false, err
+ }
+ if exists {
+ return name, false, nil
+ }
+
+ key := normalizeAliasKey(name)
+
+ s.mu.RLock()
+ entry, exactMatch := s.entries[key]
+ var prefixMatch *aliasEntry
+ if !exactMatch {
+ // Try prefix matching - prefixEntries is sorted longest-first
+ nameStr := strings.ToLower(displayAliasName(name))
+ for i := range s.prefixEntries {
+ prefix := strings.ToLower(s.prefixEntries[i].Alias)
+ if strings.HasPrefix(nameStr, prefix) {
+ prefixMatch = &s.prefixEntries[i]
+ break // First match is longest due to sorting
+ }
+ }
+ }
+ s.mu.RUnlock()
+
+ if !exactMatch && prefixMatch == nil {
+ return name, false, nil
+ }
+
+ var current string
+ var visited map[string]struct{}
+
+ if exactMatch {
+ visited = map[string]struct{}{key: {}}
+ current = entry.Target
+ } else {
+ // For prefix match, use the target as-is
+ visited = map[string]struct{}{}
+ current = prefixMatch.Target
+ }
+
+ targetKey := normalizeAliasKeyString(current)
+
+ for {
+ targetName := model.ParseName(current)
+ if !targetName.IsValid() {
+ return name, false, fmt.Errorf("alias target %q is invalid", current)
+ }
+
+ if _, seen := visited[targetKey]; seen {
+ return name, false, errAliasCycle
+ }
+ visited[targetKey] = struct{}{}
+
+ s.mu.RLock()
+ next, ok := s.entries[targetKey]
+ s.mu.RUnlock()
+ if !ok {
+ return targetName, true, nil
+ }
+
+ current = next.Target
+ targetKey = normalizeAliasKeyString(current)
+ }
+}
+
+func (s *store) Set(alias, target model.Name, prefixMatching bool) error {
+ targetKey := normalizeAliasKey(target)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if prefixMatching {
+ // For prefix aliases, we skip cycle detection since prefix matching
+ // works differently and the target is a specific model
+ aliasStr := displayAliasName(alias)
+
+ // Remove any existing prefix entry with the same alias
+ for i, e := range s.prefixEntries {
+ if strings.EqualFold(e.Alias, aliasStr) {
+ s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
+ break
+ }
+ }
+
+ s.prefixEntries = append(s.prefixEntries, aliasEntry{
+ Alias: aliasStr,
+ Target: displayAliasName(target),
+ PrefixMatching: true,
+ })
+ s.sortPrefixEntriesLocked()
+ return s.saveLocked()
+ }
+
+ aliasKey := normalizeAliasKey(alias)
+
+ if aliasKey == targetKey {
+ return fmt.Errorf("alias cannot point to itself")
+ }
+
+ visited := map[string]struct{}{aliasKey: {}}
+ currentKey := targetKey
+ for {
+ if _, seen := visited[currentKey]; seen {
+ return errAliasCycle
+ }
+ visited[currentKey] = struct{}{}
+
+ next, ok := s.entries[currentKey]
+ if !ok {
+ break
+ }
+ currentKey = normalizeAliasKeyString(next.Target)
+ }
+
+ s.entries[aliasKey] = aliasEntry{
+ Alias: displayAliasName(alias),
+ Target: displayAliasName(target),
+ }
+
+ return s.saveLocked()
+}
+
+func (s *store) Delete(alias model.Name) (bool, error) {
+ aliasKey := normalizeAliasKey(alias)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Try exact match first
+ if _, ok := s.entries[aliasKey]; ok {
+ delete(s.entries, aliasKey)
+ return true, s.saveLocked()
+ }
+
+ // Try prefix entries
+ aliasStr := displayAliasName(alias)
+ for i, e := range s.prefixEntries {
+ if strings.EqualFold(e.Alias, aliasStr) {
+ s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
+ return true, s.saveLocked()
+ }
+ }
+
+ return false, nil
+}
+
+// DeleteByString deletes an alias by its raw string value, useful for prefix
+// aliases that may not be valid model names.
+func (s *store) DeleteByString(alias string) (bool, error) {
+ alias = strings.TrimSpace(alias)
+ aliasLower := strings.ToLower(alias)
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Try prefix entries first (since this is mainly for prefix aliases)
+ for i, e := range s.prefixEntries {
+ if strings.EqualFold(e.Alias, alias) {
+ s.prefixEntries = append(s.prefixEntries[:i], s.prefixEntries[i+1:]...)
+ return true, s.saveLocked()
+ }
+ }
+
+ // Also check exact entries by normalized key
+ if _, ok := s.entries[aliasLower]; ok {
+ delete(s.entries, aliasLower)
+ return true, s.saveLocked()
+ }
+
+ return false, nil
+}
+
+func (s *store) List() []aliasEntry {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ entries := make([]aliasEntry, 0, len(s.entries)+len(s.prefixEntries))
+ for _, entry := range s.entries {
+ entries = append(entries, entry)
+ }
+ entries = append(entries, s.prefixEntries...)
+
+ sort.Slice(entries, func(i, j int) bool {
+ return strings.Compare(entries[i].Alias, entries[j].Alias) < 0
+ })
+ return entries
+}
+
+func normalizeAliasKey(name model.Name) string {
+ return strings.ToLower(displayAliasName(name))
+}
+
+func (s *store) sortPrefixEntriesLocked() {
+ sort.Slice(s.prefixEntries, func(i, j int) bool {
+ // Sort by length descending (longest prefix first)
+ return len(s.prefixEntries[i].Alias) > len(s.prefixEntries[j].Alias)
+ })
+}
+
+func normalizeAliasKeyString(value string) string {
+ n := model.ParseName(value)
+ if !n.IsValid() {
+ return strings.ToLower(strings.TrimSpace(value))
+ }
+ return normalizeAliasKey(n)
+}
+
+func displayAliasName(n model.Name) string {
+ display := n.DisplayShortest()
+ if strings.EqualFold(n.Tag, "latest") {
+ if idx := strings.LastIndex(display, ":"); idx != -1 {
+ return display[:idx]
+ }
+ }
+ return display
+}
+
+func localModelExists(name model.Name) (bool, error) {
+ manifests, err := manifest.Manifests(true)
+ if err != nil {
+ return false, err
+ }
+ needle := name.String()
+ for existing := range manifests {
+ if strings.EqualFold(existing.String(), needle) {
+ return true, nil
+ }
+ }
+ return false, nil
+}
+
+func serverConfigPath() string {
+ home, err := os.UserHomeDir()
+ if err != nil {
+ return filepath.Join(".ollama", serverConfigFilename)
+ }
+ return filepath.Join(home, ".ollama", serverConfigFilename)
+}
+
+func (s *Server) aliasStore() (*store, error) {
+ s.aliasesOnce.Do(func() {
+ s.aliases, s.aliasesErr = createStore(serverConfigPath())
+ })
+
+ return s.aliases, s.aliasesErr
+}
+
+func (s *Server) resolveAlias(name model.Name) (model.Name, bool, error) {
+ store, err := s.aliasStore()
+ if err != nil {
+ return name, false, err
+ }
+
+ if store == nil {
+ return name, false, nil
+ }
+
+ return store.ResolveName(name)
+}
diff --git a/server/create.go b/server/create.go
index 6d39590f4fd..17ba9c91f84 100644
--- a/server/create.go
+++ b/server/create.go
@@ -28,6 +28,7 @@ import (
"github.com/ollama/ollama/format"
ofs "github.com/ollama/ollama/fs"
"github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -179,7 +180,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- resp
}
- oldManifest, _ := ParseNamedManifest(name)
+ oldManifest, _ := manifest.ParseNamedManifest(name)
var baseLayers []*layerGGML
var err error
@@ -212,9 +213,9 @@ func (s *Server) CreateHandler(c *gin.Context) {
}
if err == nil && !remote && (config.Renderer == "" || config.Parser == "" || config.Requires == "") {
- manifest, mErr := ParseNamedManifest(fromName)
- if mErr == nil && manifest.Config.Digest != "" {
- configPath, pErr := GetBlobsPath(manifest.Config.Digest)
+ mf, mErr := manifest.ParseNamedManifest(fromName)
+ if mErr == nil && mf.Config.Digest != "" {
+ configPath, pErr := manifest.BlobsPath(mf.Config.Digest)
if pErr == nil {
if cfgFile, fErr := os.Open(configPath); fErr == nil {
var baseConfig model.ConfigV2
@@ -439,7 +440,7 @@ func detectModelTypeFromFiles(files map[string]string) string {
return "gguf"
} else {
// try to see if we can find a gguf file even without the file extension
- blobPath, err := GetBlobsPath(files[fn])
+ blobPath, err := manifest.BlobsPath(files[fn])
if err != nil {
slog.Error("error getting blobs path", "file", fn)
return ""
@@ -491,7 +492,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
- blobPath, err := GetBlobsPath(digest)
+ blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -529,7 +530,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
- layer, err := NewLayer(t, mediaType)
+ layer, err := manifest.NewLayer(t, mediaType)
if err != nil {
return nil, err
}
@@ -562,7 +563,7 @@ func kvFromLayers(baseLayers []*layerGGML) (ofs.Config, error) {
}
func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, config *model.ConfigV2, fn func(resp api.ProgressResponse)) (err error) {
- var layers []Layer
+ var layers []manifest.Layer
for _, layer := range baseLayers {
if layer.GGML != nil {
quantType := strings.ToUpper(cmp.Or(r.Quantize, r.Quantization))
@@ -647,13 +648,13 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
}
for _, layer := range layers {
- if layer.status != "" {
- fn(api.ProgressResponse{Status: layer.status})
+ if layer.Status != "" {
+ fn(api.ProgressResponse{Status: layer.Status})
}
}
fn(api.ProgressResponse{Status: "writing manifest"})
- if err := WriteManifest(name, *configLayer, layers); err != nil {
+ if err := manifest.WriteManifest(name, *configLayer, layers); err != nil {
return err
}
@@ -674,7 +675,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
- blob, err := GetBlobsPath(layer.Digest)
+ blob, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -696,7 +697,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
}
temp.Seek(0, io.SeekStart)
fn(api.ProgressResponse{Status: "verifying conversion"})
- newLayer, err := NewLayer(temp, layer.MediaType)
+ newLayer, err := manifest.NewLayer(temp, layer.MediaType)
if err != nil {
return nil, err
}
@@ -716,7 +717,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
var layers []*layerGGML
fn(api.ProgressResponse{Status: "parsing GGUF"})
- blobPath, err := GetBlobsPath(digest)
+ blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -751,7 +752,7 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
mediatype = "application/vnd.ollama.image.projector"
}
- layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
+ layer, err := manifest.NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
@@ -762,8 +763,8 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return detectChatTemplate(layers)
}
-func removeLayer(layers []Layer, mediatype string) []Layer {
- return slices.DeleteFunc(layers, func(layer Layer) bool {
+func removeLayer(layers []manifest.Layer, mediatype string) []manifest.Layer {
+ return slices.DeleteFunc(layers, func(layer manifest.Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -777,7 +778,7 @@ func removeLayer(layers []Layer, mediatype string) []Layer {
})
}
-func setTemplate(layers []Layer, t string) ([]Layer, error) {
+func setTemplate(layers []manifest.Layer, t string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.template")
if _, err := template.Parse(t); err != nil {
return nil, fmt.Errorf("%w: %s", errBadTemplate, err)
@@ -787,7 +788,7 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
}
blob := strings.NewReader(t)
- layer, err := NewLayer(blob, "application/vnd.ollama.image.template")
+ layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
@@ -796,11 +797,11 @@ func setTemplate(layers []Layer, t string) ([]Layer, error) {
return layers, nil
}
-func setSystem(layers []Layer, s string) ([]Layer, error) {
+func setSystem(layers []manifest.Layer, s string) ([]manifest.Layer, error) {
layers = removeLayer(layers, "application/vnd.ollama.image.system")
if s != "" {
blob := strings.NewReader(s)
- layer, err := NewLayer(blob, "application/vnd.ollama.image.system")
+ layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.system")
if err != nil {
return nil, err
}
@@ -809,9 +810,9 @@ func setSystem(layers []Layer, s string) ([]Layer, error) {
return layers, nil
}
-func setLicense(layers []Layer, l string) ([]Layer, error) {
+func setLicense(layers []manifest.Layer, l string) ([]manifest.Layer, error) {
blob := strings.NewReader(l)
- layer, err := NewLayer(blob, "application/vnd.ollama.image.license")
+ layer, err := manifest.NewLayer(blob, "application/vnd.ollama.image.license")
if err != nil {
return nil, err
}
@@ -819,7 +820,7 @@ func setLicense(layers []Layer, l string) ([]Layer, error) {
return layers, nil
}
-func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
+func setParameters(layers []manifest.Layer, p map[string]any) ([]manifest.Layer, error) {
if p == nil {
p = make(map[string]any)
}
@@ -828,7 +829,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
continue
}
- digestPath, err := GetBlobsPath(layer.Digest)
+ digestPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -862,7 +863,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(p); err != nil {
return nil, err
}
- layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+ layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
@@ -870,7 +871,7 @@ func setParameters(layers []Layer, p map[string]any) ([]Layer, error) {
return layers, nil
}
-func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
+func setMessages(layers []manifest.Layer, m []api.Message) ([]manifest.Layer, error) {
// this leaves the old messages intact if no new messages were specified
// which may not be the correct behaviour
if len(m) == 0 {
@@ -883,7 +884,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
if err := json.NewEncoder(&b).Encode(m); err != nil {
return nil, err
}
- layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
+ layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return nil, err
}
@@ -891,7 +892,7 @@ func setMessages(layers []Layer, m []api.Message) ([]Layer, error) {
return layers, nil
}
-func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
+func createConfigLayer(layers []manifest.Layer, config model.ConfigV2) (*manifest.Layer, error) {
digests := make([]string, len(layers))
for i, layer := range layers {
digests[i] = layer.Digest
@@ -902,7 +903,7 @@ func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) {
if err := json.NewEncoder(&b).Encode(config); err != nil {
return nil, err
}
- layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+ layer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return nil, err
}
diff --git a/server/create_test.go b/server/create_test.go
index 061efb81aa6..0a9ac2d0a79 100644
--- a/server/create_test.go
+++ b/server/create_test.go
@@ -10,6 +10,7 @@ import (
"testing"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/manifest"
)
func TestConvertFromSafetensors(t *testing.T) {
@@ -17,7 +18,7 @@ func TestConvertFromSafetensors(t *testing.T) {
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
- l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
+ l, err := manifest.NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
diff --git a/server/download.go b/server/download.go
index a3281661958..b8899d42d53 100644
--- a/server/download.go
+++ b/server/download.go
@@ -25,6 +25,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
+ "github.com/ollama/ollama/manifest"
+ "github.com/ollama/ollama/types/model"
)
const maxRetries = 6
@@ -457,7 +459,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
}
type downloadOpts struct {
- mp ModelPath
+ n model.Name
digest string
regOpts *registryOptions
fn func(api.ProgressResponse)
@@ -469,7 +471,7 @@ var hfDigestMap sync.Map // map[string]string
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
if opts.digest == "" {
- return false, fmt.Errorf(("%s: %s"), opts.mp.GetNamespaceRepository(), "digest is empty")
+ return false, fmt.Errorf(("%s: %s"), opts.n.DisplayNamespaceModel(), "digest is empty")
}
// Check if this is a HuggingFace download (digest starts with "hf:")
@@ -477,7 +479,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
return downloadHuggingFaceBlob(ctx, opts)
}
- fp, err := GetBlobsPath(opts.digest)
+ fp, err := manifest.BlobsPath(opts.digest)
if err != nil {
return false, err
}
@@ -501,8 +503,8 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
download := data.(*blobDownload)
if !ok {
- requestURL := opts.mp.BaseURL()
- requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
+ requestURL := opts.n.BaseURL()
+ requestURL = requestURL.JoinPath("v2", opts.n.DisplayNamespaceModel(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest)
return false, err
@@ -698,7 +700,7 @@ func downloadHuggingFaceBlob(ctx context.Context, opts downloadOpts) (cacheHit b
hfDigestMap.Store(opts.digest, digest)
// Move the file to the blobs directory
- fp, err := GetBlobsPath(digest)
+ fp, err := manifest.BlobsPath(digest)
if err != nil {
return false, err
}
diff --git a/server/images.go b/server/images.go
index 6b8e77acf01..c69412f3846 100644
--- a/server/images.go
+++ b/server/images.go
@@ -24,6 +24,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/gguf"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template"
@@ -282,11 +283,8 @@ func (m *Model) String() string {
return modelfile.String()
}
-func GetManifest(mp ModelPath) (*Manifest, string, error) {
- fp, err := mp.GetManifestPath()
- if err != nil {
- return nil, "", err
- }
+func GetManifest(n model.Name) (*manifest.Manifest, string, error) {
+ fp := n.Filepath()
f, err := os.Open(fp)
if err != nil {
@@ -296,30 +294,30 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) {
sha256sum := sha256.New()
- var manifest Manifest
- if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifest); err != nil {
+ var manifestFile manifest.Manifest
+ if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&manifestFile); err != nil {
return nil, "", err
}
- return &manifest, hex.EncodeToString(sha256sum.Sum(nil)), nil
+ return &manifestFile, hex.EncodeToString(sha256sum.Sum(nil)), nil
}
func GetModel(name string) (*Model, error) {
- mp := ParseModelPath(name)
- manifest, digest, err := GetManifest(mp)
+ n := model.ParseName(name)
+ mf, err := manifest.ParseNamedManifest(n)
if err != nil {
return nil, err
}
- model := &Model{
- Name: mp.GetFullTagname(),
- ShortName: mp.GetShortTagname(),
- Digest: digest,
+ m := &Model{
+ Name: n.String(),
+ ShortName: n.DisplayShortest(),
+ Digest: mf.Digest(),
Template: template.DefaultTemplate,
}
- if manifest.Config.Digest != "" {
- filename, err := GetBlobsPath(manifest.Config.Digest)
+ if mf.Config.Digest != "" {
+ filename, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
return nil, err
}
@@ -330,15 +328,15 @@ func GetModel(name string) (*Model, error) {
}
defer configFile.Close()
- if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
+ if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
return nil, err
}
}
readMainModelFlag := false
- for _, layer := range manifest.Layers {
- filename, err := GetBlobsPath(layer.Digest)
+ for _, layer := range mf.Layers {
+ filename, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -346,20 +344,20 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType {
case "application/vnd.ollama.image.model":
if !readMainModelFlag {
- model.ModelPath = filename
- model.ParentModel = layer.From
+ m.ModelPath = filename
+ m.ParentModel = layer.From
readMainModelFlag = true
} else {
- model.ExtraModelPaths = append(model.ExtraModelPaths, filename)
+ m.ExtraModelPaths = append(m.ExtraModelPaths, filename)
}
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
- model.AdapterPaths = append(model.AdapterPaths, filename)
+ m.AdapterPaths = append(m.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
- model.ProjectorPaths = append(model.ProjectorPaths, filename)
+ m.ProjectorPaths = append(m.ProjectorPaths, filename)
case "application/vnd.ollama.image.prompt",
"application/vnd.ollama.image.template":
bts, err := os.ReadFile(filename)
@@ -367,7 +365,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
- model.Template, err = template.Parse(string(bts))
+ m.Template, err = template.Parse(string(bts))
if err != nil {
return nil, err
}
@@ -377,7 +375,7 @@ func GetModel(name string) (*Model, error) {
return nil, err
}
- model.System = string(bts)
+ m.System = string(bts)
case "application/vnd.ollama.image.params":
params, err := os.Open(filename)
if err != nil {
@@ -386,7 +384,7 @@ func GetModel(name string) (*Model, error) {
defer params.Close()
// parse model options parameters into a map so that we can see which fields have been specified explicitly
- if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
+ if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
@@ -396,7 +394,7 @@ func GetModel(name string) (*Model, error) {
}
defer msgs.Close()
- if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
+ if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
@@ -404,11 +402,11 @@ func GetModel(name string) (*Model, error) {
if err != nil {
return nil, err
}
- model.License = append(model.License, string(bts))
+ m.License = append(m.License, string(bts))
}
}
- return model, nil
+ return m, nil
}
func CopyModel(src, dst model.Name) error {
@@ -423,7 +421,7 @@ func CopyModel(src, dst model.Name) error {
return nil
}
- manifests, err := GetManifestPath()
+ manifests, err := manifest.Path()
if err != nil {
return err
}
@@ -452,7 +450,7 @@ func CopyModel(src, dst model.Name) error {
func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// Ignore corrupt manifests to avoid blocking deletion of layers that are freshly orphaned
- manifests, err := Manifests(true)
+ manifests, err := manifest.Manifests(true)
if err != nil {
return err
}
@@ -467,7 +465,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
// only delete the files which are still in the deleteMap
for k := range deleteMap {
- fp, err := GetBlobsPath(k)
+ fp, err := manifest.BlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
continue
@@ -483,7 +481,7 @@ func deleteUnusedLayers(deleteMap map[string]struct{}) error {
func PruneLayers() error {
deleteMap := make(map[string]struct{})
- p, err := GetBlobsPath("")
+ p, err := manifest.BlobsPath("")
if err != nil {
return err
}
@@ -498,9 +496,9 @@ func PruneLayers() error {
name := blob.Name()
name = strings.ReplaceAll(name, "-", ":")
- _, err := GetBlobsPath(name)
+ _, err := manifest.BlobsPath(name)
if err != nil {
- if errors.Is(err, ErrInvalidDigestFormat) {
+ if errors.Is(err, manifest.ErrInvalidDigestFormat) {
// remove invalid blobs (e.g. partial downloads)
if err := os.Remove(filepath.Join(p, blob.Name())); err != nil {
slog.Error("couldn't remove blob", "blob", blob.Name(), "error", err)
@@ -559,29 +557,29 @@ func PruneDirectory(path string) error {
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
- mp := ParseModelPath(name)
+ n := model.ParseName(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
- if mp.ProtocolScheme == "http" && !regOpts.Insecure {
+ if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
- manifest, _, err := GetManifest(mp)
+ mf, err := manifest.ParseNamedManifest(n)
if err != nil {
fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
return err
}
- var layers []Layer
- layers = append(layers, manifest.Layers...)
- if manifest.Config.Digest != "" {
- layers = append(layers, manifest.Config)
+ var layers []manifest.Layer
+ layers = append(layers, mf.Layers...)
+ if mf.Config.Digest != "" {
+ layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
// Read raw manifest JSON to preserve tensor metadata fields
- manifestPath, err := mp.GetManifestPath()
+ manifestPath, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -589,7 +587,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if err != nil {
return err
}
- if err := pushWithTransfer(ctx, mp, layers, manifestJSON, regOpts, fn); err != nil {
+ if err := pushWithTransfer(ctx, n, layers, manifestJSON, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -597,17 +595,17 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
for _, layer := range layers {
- if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
+ if err := uploadBlob(ctx, n, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
return err
}
}
fn(api.ProgressResponse{Status: "pushing manifest"})
- requestURL := mp.BaseURL()
- requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
+ requestURL := n.BaseURL()
+ requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
- manifestJSON, err := json.Marshal(manifest)
+ manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
@@ -626,44 +624,44 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
- mp := ParseModelPath(name)
+ n := model.ParseName(name)
// build deleteMap to prune unused layers
deleteMap := make(map[string]struct{})
- manifest, _, err := GetManifest(mp)
+ existingMf, err := manifest.ParseNamedManifest(n)
if errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
slog.Warn("pulling model with bad existing manifest", "name", name, "error", err)
} else {
- for _, l := range manifest.Layers {
+ for _, l := range existingMf.Layers {
deleteMap[l.Digest] = struct{}{}
}
- if manifest.Config.Digest != "" {
- deleteMap[manifest.Config.Digest] = struct{}{}
+ if existingMf.Config.Digest != "" {
+ deleteMap[existingMf.Config.Digest] = struct{}{}
}
}
- if mp.ProtocolScheme == "http" && !regOpts.Insecure {
+ if n.ProtocolScheme == "http" && !regOpts.Insecure {
return errInsecureProtocol
}
fn(api.ProgressResponse{Status: "pulling manifest"})
- manifest, err = pullModelManifest(ctx, mp, regOpts)
+ mf, err := pullModelManifest(ctx, n, regOpts)
if err != nil {
return fmt.Errorf("pull model manifest: %s", err)
}
- var layers []Layer
- layers = append(layers, manifest.Layers...)
- if manifest.Config.Digest != "" {
- layers = append(layers, manifest.Config)
+ var layers []manifest.Layer
+ layers = append(layers, mf.Layers...)
+ if mf.Config.Digest != "" {
+ layers = append(layers, mf.Config)
}
// Use fast transfer for models with tensor layers (many small blobs)
if hasTensorLayers(layers) {
- if err := pullWithTransfer(ctx, mp, layers, manifest, regOpts, fn); err != nil {
+ if err := pullWithTransfer(ctx, n, layers, mf, regOpts, fn); err != nil {
return err
}
fn(api.ProgressResponse{Status: "success"})
@@ -671,11 +669,11 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
skipVerify := make(map[string]bool)
- isHF := isHuggingFaceRegistry(mp.Registry)
+ isHF := isHuggingFaceRegistry(n.Host)
for i, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{
- mp: mp,
+ n: n,
digest: layer.Digest,
regOpts: regOpts,
fn: fn,
@@ -690,10 +688,10 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layer.Digest = realDigest.(string)
layers[i].Digest = realDigest.(string)
// Update the manifest layers
- for j := range manifest.Layers {
- if strings.HasPrefix(manifest.Layers[j].Digest, "hf:") {
- if rd, ok := hfDigestMap.Load(manifest.Layers[j].Digest); ok {
- manifest.Layers[j].Digest = rd.(string)
+ for j := range mf.Layers {
+ if strings.HasPrefix(mf.Layers[j].Digest, "hf:") {
+ if rd, ok := hfDigestMap.Load(mf.Layers[j].Digest); ok {
+ mf.Layers[j].Digest = rd.(string)
}
}
}
@@ -711,7 +709,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) {
- fp, err := GetBlobsPath(layer.Digest)
+ fp, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return err
}
@@ -726,16 +724,16 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
delete(deleteMap, layer.Digest)
}
- delete(deleteMap, manifest.Config.Digest)
+ delete(deleteMap, mf.Config.Digest)
fn(api.ProgressResponse{Status: "writing manifest"})
- manifestJSON, err := json.Marshal(manifest)
+ manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
- fp, err := mp.GetManifestPath()
+ fp, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -762,9 +760,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
// hasTensorLayers checks if any layer has tensor media type.
-func hasTensorLayers(layers []Layer) bool {
+func hasTensorLayers(layers []manifest.Layer) bool {
for _, layer := range layers {
- if layer.MediaType == MediaTypeImageTensor {
+ if layer.MediaType == manifest.MediaTypeImageTensor {
return true
}
}
@@ -772,7 +770,7 @@ func hasTensorLayers(layers []Layer) bool {
}
// pullWithTransfer uses the simplified x/transfer package for downloading blobs.
-func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifest *Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
+func pullWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, mf *manifest.Manifest, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -781,12 +779,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
- destDir, err := GetBlobsPath("")
+ destDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
- base := mp.BaseURL()
+ base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -818,7 +816,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
Blobs: blobs,
BaseURL: baseURL,
DestDir: destDir,
- Repository: mp.GetNamespaceRepository(),
+ Repository: n.DisplayNamespaceModel(),
Progress: progress,
Token: regOpts.Token,
GetToken: getToken,
@@ -829,12 +827,12 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
// Write manifest
fn(api.ProgressResponse{Status: "writing manifest"})
- manifestJSON, err := json.Marshal(manifest)
+ manifestJSON, err := json.Marshal(mf)
if err != nil {
return err
}
- fp, err := mp.GetManifestPath()
+ fp, err := manifest.PathForName(n)
if err != nil {
return err
}
@@ -846,7 +844,7 @@ func pullWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
// pushWithTransfer uses the simplified x/transfer package for uploading blobs and manifest.
-func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
+func pushWithTransfer(ctx context.Context, n model.Name, layers []manifest.Layer, manifestJSON []byte, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
blobs := make([]transfer.Blob, len(layers))
for i, layer := range layers {
blobs[i] = transfer.Blob{
@@ -856,12 +854,12 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
}
}
- srcDir, err := GetBlobsPath("")
+ srcDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
- base := mp.BaseURL()
+ base := n.BaseURL()
if base.Scheme != "http" && regOpts != nil && regOpts.Insecure {
base.Scheme = "http"
}
@@ -898,18 +896,18 @@ func pushWithTransfer(ctx context.Context, mp ModelPath, layers []Layer, manifes
GetToken: getToken,
Logger: slog.Default(),
Manifest: manifestJSON,
- ManifestRef: mp.Tag,
- Repository: mp.GetNamespaceRepository(),
+ ManifestRef: n.Tag,
+ Repository: n.DisplayNamespaceModel(),
})
}
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
+func pullModelManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
// Check if this is a HuggingFace registry
- if isHuggingFaceRegistry(mp.Registry) {
- return pullHuggingFaceManifest(ctx, mp, regOpts)
+ if isHuggingFaceRegistry(n.Host) {
+ return pullHuggingFaceManifest(ctx, n, regOpts)
}
- requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
+ requestURL := n.BaseURL().JoinPath("v2", n.DisplayNamespaceModel(), "manifests", n.Tag)
headers := make(http.Header)
headers.Set("Accept", "application/vnd.docker.distribution.manifest.v2+json")
@@ -919,7 +917,7 @@ func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptio
}
defer resp.Body.Close()
- var m Manifest
+ var m manifest.Manifest
if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
return nil, err
}
@@ -945,14 +943,14 @@ type HFFileInfo struct {
}
// pullHuggingFaceManifest pulls a model manifest from HuggingFace
-func pullHuggingFaceManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*Manifest, error) {
+func pullHuggingFaceManifest(ctx context.Context, n model.Name, regOpts *registryOptions) (*manifest.Manifest, error) {
// For HuggingFace, the tag might be "main" or could include a subdirectory like "BF16"
// We'll use "main" as the revision and the tag as the subdirectory filter
revision := "main"
- subdirFilter := mp.Tag
+ subdirFilter := n.Tag
// Query HuggingFace API for file tree (always use main revision, recursive)
- apiURL := fmt.Sprintf("https://huggingface.co/api/models/%s/tree/%s?recursive=true", mp.GetNamespaceRepository(), revision)
+ apiURL := fmt.Sprintf("https://huggingface.co/api/models/%s/tree/%s?recursive=true", n.DisplayNamespaceModel(), revision)
req, err := http.NewRequestWithContext(ctx, "GET", apiURL, nil)
if err != nil {
@@ -1010,9 +1008,9 @@ func pullHuggingFaceManifest(ctx context.Context, mp ModelPath, regOpts *registr
// Check if these are split GGUF files
shardSets, singles := parser.GroupGGUFShards(extractPaths(ggufFiles))
- var manifest Manifest
- manifest.SchemaVersion = 2
- manifest.MediaType = "application/vnd.docker.distribution.manifest.v2+json"
+ var mf manifest.Manifest
+ mf.SchemaVersion = 2
+ mf.MediaType = "application/vnd.docker.distribution.manifest.v2+json"
// Handle split GGUF files
if len(shardSets) > 0 {
@@ -1034,7 +1032,7 @@ func pullHuggingFaceManifest(ctx context.Context, mp ModelPath, regOpts *registr
}
// Create a layer for this shard
- layer := Layer{
+ layer := manifest.Layer{
MediaType: "application/vnd.ollama.image.model",
Size: fileInfo.Size,
Digest: "", // Will be computed during download
@@ -1042,9 +1040,9 @@ func pullHuggingFaceManifest(ctx context.Context, mp ModelPath, regOpts *registr
// Store the HuggingFace download URL in the layer
// We'll use the digest field temporarily to store the download path
- layer.Digest = fmt.Sprintf("hf:%s/%s/%s", mp.GetNamespaceRepository(), mp.Tag, fileInfo.Path)
+ layer.Digest = fmt.Sprintf("hf:%s/%s/%s", n.DisplayNamespaceModel(), n.Tag, fileInfo.Path)
- manifest.Layers = append(manifest.Layers, layer)
+ mf.Layers = append(mf.Layers, layer)
}
} else if len(singles) > 0 {
// Single GGUF file
@@ -1062,16 +1060,16 @@ func pullHuggingFaceManifest(ctx context.Context, mp ModelPath, regOpts *registr
return nil, fmt.Errorf("GGUF file info not found")
}
- layer := Layer{
+ layer := manifest.Layer{
MediaType: "application/vnd.ollama.image.model",
Size: fileInfo.Size,
- Digest: fmt.Sprintf("hf:%s/%s/%s", mp.GetNamespaceRepository(), mp.Tag, fileInfo.Path),
+ Digest: fmt.Sprintf("hf:%s/%s/%s", n.DisplayNamespaceModel(), n.Tag, fileInfo.Path),
}
- manifest.Layers = append(manifest.Layers, layer)
+ mf.Layers = append(mf.Layers, layer)
}
- return &manifest, nil
+ return &mf, nil
}
// extractPaths extracts file paths from HFFileInfo slice
@@ -1237,7 +1235,7 @@ func parseRegistryChallenge(authStr string) registryChallenge {
var errDigestMismatch = errors.New("digest mismatch, file must be downloaded again")
func verifyBlob(digest string) error {
- fp, err := GetBlobsPath(digest)
+ fp, err := manifest.BlobsPath(digest)
if err != nil {
return err
}
diff --git a/server/images_test.go b/server/images_test.go
index 9e581c8c35f..639cf866220 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -56,6 +56,15 @@ func TestModelCapabilities(t *testing.T) {
},
expectedCaps: []model.Capability{model.CapabilityImage},
},
+ {
+ name: "model with image and vision capability (image editing)",
+ model: Model{
+ Config: model.ConfigV2{
+ Capabilities: []string{"image", "vision"},
+ },
+ },
+ expectedCaps: []model.Capability{model.CapabilityImage, model.CapabilityVision},
+ },
{
name: "model with completion capability",
model: Model{
diff --git a/server/model.go b/server/model.go
index 401547e4ef9..57190ffe083 100644
--- a/server/model.go
+++ b/server/model.go
@@ -13,6 +13,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
@@ -20,19 +21,19 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
- Layer
+ manifest.Layer
*ggml.GGML
}
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
- m, err := ParseNamedManifest(name)
+ m, err := manifest.ParseNamedManifest(name)
switch {
case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
return nil, err
}
- m, err = ParseNamedManifest(name)
+ m, err = manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -41,7 +42,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
for _, layer := range m.Layers {
- layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
+ layer, err := manifest.NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
if err != nil {
return nil, err
}
@@ -50,7 +51,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
case "application/vnd.ollama.image.model",
"application/vnd.ollama.image.projector",
"application/vnd.ollama.image.adapter":
- blobpath, err := GetBlobsPath(layer.Digest)
+ blobpath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
return nil, err
}
@@ -81,12 +82,12 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
} else {
- layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
+ layer, err := manifest.NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return nil, err
}
- layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
+ layer.Status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil})
if t.Parameters != nil {
@@ -95,7 +96,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
return nil, err
}
- layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
+ layer, err := manifest.NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
diff --git a/server/modelpath.go b/server/modelpath.go
deleted file mode 100644
index af82b8b3b43..00000000000
--- a/server/modelpath.go
+++ /dev/null
@@ -1,146 +0,0 @@
-package server
-
-import (
- "errors"
- "fmt"
- "io/fs"
- "net/url"
- "os"
- "path/filepath"
- "regexp"
- "strings"
-
- "github.com/ollama/ollama/envconfig"
- "github.com/ollama/ollama/types/model"
-)
-
-type ModelPath struct {
- ProtocolScheme string
- Registry string
- Namespace string
- Repository string
- Tag string
-}
-
-const (
- DefaultRegistry = "registry.ollama.ai"
- DefaultNamespace = "library"
- DefaultTag = "latest"
- DefaultProtocolScheme = "https"
-)
-
-var (
- ErrInvalidImageFormat = errors.New("invalid image format")
- ErrInvalidDigestFormat = errors.New("invalid digest format")
- ErrInvalidProtocol = errors.New("invalid protocol scheme")
- ErrInsecureProtocol = errors.New("insecure protocol http")
- ErrModelPathInvalid = errors.New("invalid model path")
-)
-
-func ParseModelPath(name string) ModelPath {
- mp := ModelPath{
- ProtocolScheme: DefaultProtocolScheme,
- Registry: DefaultRegistry,
- Namespace: DefaultNamespace,
- Repository: "",
- Tag: DefaultTag,
- }
-
- before, after, found := strings.Cut(name, "://")
- if found {
- mp.ProtocolScheme = before
- name = after
- }
-
- name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
- parts := strings.Split(name, "/")
- switch len(parts) {
- case 3:
- mp.Registry = parts[0]
- mp.Namespace = parts[1]
- mp.Repository = parts[2]
- case 2:
- mp.Namespace = parts[0]
- mp.Repository = parts[1]
- case 1:
- mp.Repository = parts[0]
- }
-
- if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
- mp.Repository = repo
- mp.Tag = tag
- }
-
- return mp
-}
-
-func (mp ModelPath) GetNamespaceRepository() string {
- return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
-}
-
-func (mp ModelPath) GetFullTagname() string {
- return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
-}
-
-func (mp ModelPath) GetShortTagname() string {
- if mp.Registry == DefaultRegistry {
- if mp.Namespace == DefaultNamespace {
- return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
- }
- return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
- }
- return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
-}
-
-// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
-func (mp ModelPath) GetManifestPath() (string, error) {
- name := model.Name{
- Host: mp.Registry,
- Namespace: mp.Namespace,
- Model: mp.Repository,
- Tag: mp.Tag,
- }
- if !name.IsValid() {
- return "", fs.ErrNotExist
- }
- return filepath.Join(envconfig.Models(), "manifests", name.Filepath()), nil
-}
-
-func (mp ModelPath) BaseURL() *url.URL {
- return &url.URL{
- Scheme: mp.ProtocolScheme,
- Host: mp.Registry,
- }
-}
-
-func GetManifestPath() (string, error) {
- path := filepath.Join(envconfig.Models(), "manifests")
- if err := os.MkdirAll(path, 0o755); err != nil {
- return "", fmt.Errorf("%w: ensure path elements are traversable", err)
- }
-
- return path, nil
-}
-
-func GetBlobsPath(digest string) (string, error) {
- // only accept actual sha256 digests
- pattern := "^sha256[:-][0-9a-fA-F]{64}$"
- re := regexp.MustCompile(pattern)
-
- if digest != "" && !re.MatchString(digest) {
- return "", ErrInvalidDigestFormat
- }
-
- digest = strings.ReplaceAll(digest, ":", "-")
- path := filepath.Join(envconfig.Models(), "blobs", digest)
- dirPath := filepath.Dir(path)
- if digest == "" {
- dirPath = path
- }
-
- if err := os.MkdirAll(dirPath, 0o755); err != nil {
- return "", fmt.Errorf("%w: ensure path elements are traversable", err)
- }
-
- return path, nil
-}
diff --git a/server/modelpath_test.go b/server/modelpath_test.go
deleted file mode 100644
index 96429f958d9..00000000000
--- a/server/modelpath_test.go
+++ /dev/null
@@ -1,153 +0,0 @@
-package server
-
-import (
- "path/filepath"
- "testing"
-
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
-)
-
-func TestGetBlobsPath(t *testing.T) {
- // GetBlobsPath expects an actual directory to exist
- tempDir := t.TempDir()
-
- tests := []struct {
- name string
- digest string
- expected string
- err error
- }{
- {
- "empty digest",
- "",
- filepath.Join(tempDir, "blobs"),
- nil,
- },
- {
- "valid with colon",
- "sha256:456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
- filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
- nil,
- },
- {
- "valid with dash",
- "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9",
- filepath.Join(tempDir, "blobs", "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9"),
- nil,
- },
- {
- "digest too short",
- "sha256-45640291",
- "",
- ErrInvalidDigestFormat,
- },
- {
- "digest too long",
- "sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7aad9aaaaaaaaaa",
- "",
- ErrInvalidDigestFormat,
- },
- {
- "digest invalid chars",
- "../sha256-456402914e838a953e0cf80caa6adbe75383d9e63584a964f504a7bbb8f7a",
- "",
- ErrInvalidDigestFormat,
- },
- }
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- t.Setenv("OLLAMA_MODELS", tempDir)
-
- got, err := GetBlobsPath(tc.digest)
-
- require.ErrorIs(t, tc.err, err, tc.name)
- assert.Equal(t, tc.expected, got, tc.name)
- })
- }
-}
-
-func TestParseModelPath(t *testing.T) {
- tests := []struct {
- name string
- arg string
- want ModelPath
- }{
- {
- "full path https",
- "https://example.com/ns/repo:tag",
- ModelPath{
- ProtocolScheme: "https",
- Registry: "example.com",
- Namespace: "ns",
- Repository: "repo",
- Tag: "tag",
- },
- },
- {
- "full path http",
- "http://example.com/ns/repo:tag",
- ModelPath{
- ProtocolScheme: "http",
- Registry: "example.com",
- Namespace: "ns",
- Repository: "repo",
- Tag: "tag",
- },
- },
- {
- "no protocol",
- "example.com/ns/repo:tag",
- ModelPath{
- ProtocolScheme: "https",
- Registry: "example.com",
- Namespace: "ns",
- Repository: "repo",
- Tag: "tag",
- },
- },
- {
- "no registry",
- "ns/repo:tag",
- ModelPath{
- ProtocolScheme: "https",
- Registry: DefaultRegistry,
- Namespace: "ns",
- Repository: "repo",
- Tag: "tag",
- },
- },
- {
- "no namespace",
- "repo:tag",
- ModelPath{
- ProtocolScheme: "https",
- Registry: DefaultRegistry,
- Namespace: DefaultNamespace,
- Repository: "repo",
- Tag: "tag",
- },
- },
- {
- "no tag",
- "repo",
- ModelPath{
- ProtocolScheme: "https",
- Registry: DefaultRegistry,
- Namespace: DefaultNamespace,
- Repository: "repo",
- Tag: DefaultTag,
- },
- },
- }
-
- for _, tc := range tests {
- t.Run(tc.name, func(t *testing.T) {
- got := ParseModelPath(tc.arg)
-
- if got != tc.want {
- t.Errorf("got: %q want: %q", got, tc.want)
- }
- })
- }
-}
diff --git a/server/prompt.go b/server/prompt.go
index 21759198217..bc12f4d5d28 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -27,14 +27,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// Clip images are represented as 768 tokens, each an embedding
imageNumTokens := 768
- n := len(msgs) - 1
- // in reverse, find all messages that fit into context window
- for i := n; i >= 0; i-- {
- // always include the last message
- if i == n {
- continue
- }
+ lastMsgIdx := len(msgs) - 1
+ currMsgIdx := 0
+ // Start with all messages and remove from the front until it fits in context
+ for i := 0; i <= lastMsgIdx; i++ {
+ // Collect system messages from the portion we're about to skip
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
@@ -54,20 +52,26 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
ctxLen := len(s)
if m.ProjectorPaths != nil {
- for _, m := range msgs[i:] {
- ctxLen += imageNumTokens * len(m.Images)
+ for _, msg := range msgs[i:] {
+ ctxLen += imageNumTokens * len(msg.Images)
}
}
- if truncate && ctxLen > opts.NumCtx {
- slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
+ if !truncate || ctxLen <= opts.NumCtx {
+ currMsgIdx = i
+ break
+ }
+
+ // Must always include at least the last message
+ if i == lastMsgIdx {
+ currMsgIdx = lastMsgIdx
break
- } else {
- n = i
}
}
- currMsgIdx := n
+ if currMsgIdx > 0 {
+ slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[currMsgIdx:]))
+ }
for cnt, msg := range msgs[currMsgIdx:] {
if slices.Contains(m.Config.ModelFamilies, "mllama") && len(msg.Images) > 1 {
diff --git a/server/prompt_test.go b/server/prompt_test.go
index 3bd621152b8..082667b83c2 100644
--- a/server/prompt_test.go
+++ b/server/prompt_test.go
@@ -2,6 +2,7 @@ package server
import (
"bytes"
+ "context"
"testing"
"github.com/google/go-cmp/cmp"
@@ -264,3 +265,68 @@ func TestChatPrompt(t *testing.T) {
})
}
}
+
+func TestChatPromptTokenizeCalls(t *testing.T) {
+ tmpl, err := template.Parse(`
+{{- if .System }}{{ .System }} {{ end }}
+{{- if .Prompt }}{{ .Prompt }} {{ end }}
+{{- if .Response }}{{ .Response }} {{ end }}`)
+ if err != nil {
+ t.Fatal(err)
+ }
+ model := Model{Template: tmpl}
+
+ cases := []struct {
+ name string
+ limit int
+ msgs []api.Message
+ maxTokenizes int
+ }{
+ {
+ name: "all messages fit",
+ limit: 2048,
+ msgs: []api.Message{
+ {Role: "user", Content: "message 1"},
+ {Role: "assistant", Content: "response 1"},
+ {Role: "user", Content: "message 2"},
+ {Role: "assistant", Content: "response 2"},
+ {Role: "user", Content: "message 3"},
+ },
+ maxTokenizes: 1,
+ },
+ {
+ name: "truncate to last message",
+ limit: 5,
+ msgs: []api.Message{
+ {Role: "user", Content: "message 1"},
+ {Role: "assistant", Content: "response 1"},
+ {Role: "user", Content: "message 2"},
+ {Role: "assistant", Content: "response 2"},
+ {Role: "user", Content: "message 3"},
+ },
+ maxTokenizes: 5,
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ tokenizeCount := 0
+ countingTokenize := func(ctx context.Context, s string) ([]int, error) {
+ tokenizeCount++
+ tokens, err := mockRunner{}.Tokenize(ctx, s)
+ return tokens, err
+ }
+
+ opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
+ think := false
+ _, _, err := chatPrompt(t.Context(), &model, countingTokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, true)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if tokenizeCount > tt.maxTokenizes {
+ t.Errorf("tokenize called %d times, expected at most %d", tokenizeCount, tt.maxTokenizes)
+ }
+ })
+ }
+}
diff --git a/server/quantization.go b/server/quantization.go
index b15451d7e48..76c54d8f698 100644
--- a/server/quantization.go
+++ b/server/quantization.go
@@ -58,6 +58,48 @@ func useMoreBits(iLayer, nLayers int) bool {
return iLayer < (nLayers/8) || iLayer >= 7*nLayers/8 || (iLayer-nLayers/8)%3 == 2
}
+func qwen3nextQuantType(name string) (fsggml.TensorType, bool) {
+ switch {
+ // Full attention
+ case strings.HasSuffix(name, ".attn_q.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".attn_k.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".attn_v.weight"):
+ return fsggml.TensorTypeQ6_K, true
+ case strings.HasSuffix(name, ".attn_output.weight"):
+ return fsggml.TensorTypeQ4_K, true
+
+ // Linear attention (Gated Delta Net) after split
+ case strings.HasSuffix(name, ".attn_qkv.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".attn_gate.weight"):
+ return fsggml.TensorTypeQ4_K, true
+
+ // SSM
+ case strings.HasSuffix(name, ".ssm_ba.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".ssm_out.weight"):
+ return fsggml.TensorTypeQ4_K, true
+
+ // MoE experts + shared experts
+ case strings.HasSuffix(name, ".ffn_down_exps.weight"):
+ return fsggml.TensorTypeQ6_K, true
+ case strings.HasSuffix(name, ".ffn_down_shexp.weight"):
+ return fsggml.TensorTypeQ6_K, true
+ case strings.HasSuffix(name, ".ffn_gate_exps.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".ffn_gate_shexp.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".ffn_up_exps.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ case strings.HasSuffix(name, ".ffn_up_shexp.weight"):
+ return fsggml.TensorTypeQ4_K, true
+ }
+
+ return 0, false
+}
+
func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType, name string, shape []uint64, ftype fsggml.FileType) fsggml.TensorType {
// Ported from llama_tensor_get_type, removed unsupported quantization types
nExperts := max(1, kv.Uint("expert_count", 0))
@@ -95,6 +137,13 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
// for the 8-expert model, bumping this to Q8_0 trades just ~128MB
newType = fsggml.TensorTypeQ8_0
}
+ } else if strings.Contains(name, "attn_k_b.weight") ||
+ strings.Contains(name, "attn_v_b.weight") ||
+ strings.Contains(name, "attn_kv_a_mqa.weight") ||
+ strings.Contains(name, "attn_q_a.weight") ||
+ strings.Contains(name, "attn_q_b.weight") {
+ // MLA tensors need higher precision to avoid quality degradation
+ newType = fsggml.TensorTypeQ8_0
} else if strings.Contains(name, "ffn_down") {
iLayer := qs.iFfnDown
n_layer := qs.nFfnDown
@@ -198,8 +247,8 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
name := t.Name
quantize := strings.HasSuffix(name, "weight")
- // don't quantize vision stuff
- quantize = quantize && (!strings.Contains(name, "v.") || strings.Contains(name, "_v."))
+ // don't quantize vision encoder tensors (named with "v." prefix)
+ quantize = quantize && !strings.HasPrefix(name, "v.")
quantize = quantize && !strings.Contains(name, "mm.")
// quantize only 2D and 3D tensors (experts)
@@ -210,6 +259,7 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// do not quantize expert gating tensors
quantize = quantize && !strings.Contains(name, "ffn_gate_inp.weight")
+ quantize = quantize && !strings.Contains(name, "ffn_gate_inp_shexp.weight")
// do not quantize positional embeddings and token types (BERT)
quantize = quantize && (name != "position_embd.weight")
@@ -219,6 +269,9 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
// NOTE: can't use LLM_TN here because the layer number is not known
quantize = quantize && !strings.Contains(name, "ssm_conv1d.weight")
+ // do not quantize LFM2's shortconv kernel weights
+ quantize = quantize && !strings.Contains(name, "shortconv.conv.weight")
+
// do not quantize RWKV's time_mix_first tensors
quantize = quantize && !strings.Contains(name, "time_mix_first.weight")
quantize = quantize && !strings.Contains(name, "time_mix_w1.weight")
@@ -234,6 +287,12 @@ func newType(t *fsggml.Tensor, kv fsggml.KV, qs *quantizeState, ftype fsggml.Fil
newType := fsggml.TensorType(t.Kind)
if quantize {
+ if kv.Architecture() == "qwen3next" && (ftype == fsggml.FileTypeQ4_K_M || ftype == fsggml.FileTypeQ4_K_S) {
+ if qt, ok := qwen3nextQuantType(name); ok {
+ return qt
+ }
+ }
+
// get more optimal quantization type based on the tensor shape, layer, etc.
newType = getTensorNewType(kv, qs, defaultType, t.Name, t.Shape, ftype)
if newType != defaultType {
diff --git a/server/routes.go b/server/routes.go
index 605ff2cc3fb..17d6635dfe6 100644
--- a/server/routes.go
+++ b/server/routes.go
@@ -22,6 +22,7 @@ import (
"os/signal"
"slices"
"strings"
+ "sync"
"sync/atomic"
"syscall"
"time"
@@ -37,8 +38,10 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/model/renderers"
@@ -50,12 +53,17 @@ import (
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
- "github.com/ollama/ollama/x/imagegen"
+ imagegenmanifest "github.com/ollama/ollama/x/imagegen/manifest"
xserver "github.com/ollama/ollama/x/server"
)
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
+const (
+ cloudErrRemoteInferenceUnavailable = "remote model is unavailable"
+ cloudErrRemoteModelDetailsUnavailable = "remote model details are unavailable"
+)
+
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
@@ -74,16 +82,15 @@ func experimentEnabled(name string) bool {
var useClient2 = experimentEnabled("client2")
-// Low VRAM mode is based on the sum of total VRAM (not free) and triggers
-// reduced context length on some models
-var lowVRAMThreshold uint64 = 20 * format.GibiByte
-
var mode string = gin.DebugMode
type Server struct {
- addr net.Addr
- sched *Scheduler
- lowVRAM bool
+ addr net.Addr
+ sched *Scheduler
+ defaultNumCtx int
+ aliasesOnce sync.Once
+ aliases *store
+ aliasesErr error
}
func init() {
@@ -106,8 +113,12 @@ var (
errBadTemplate = errors.New("template error")
)
-func modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
+func (s *Server) modelOptions(model *Model, requestOpts map[string]any) (api.Options, error) {
opts := api.DefaultOptions()
+ if opts.NumCtx == 0 {
+ opts.NumCtx = s.defaultNumCtx
+ }
+
if err := opts.FromMap(model.Options); err != nil {
return api.Options{}, err
}
@@ -139,21 +150,15 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C
return nil, nil, nil, fmt.Errorf("%s %w", name, err)
}
- opts, err := modelOptions(model, requestOpts)
+ useImagegen, _ := requestOpts["use_imagegen_runner"].(bool)
+ delete(requestOpts, "use_imagegen_runner")
+
+ opts, err := s.modelOptions(model, requestOpts)
if err != nil {
return nil, nil, nil, err
}
- // This model is much more capable with a larger context, so set that
- // unless it would penalize performance too much
- if !s.lowVRAM && slices.Contains([]string{
- "gptoss", "gpt-oss",
- "qwen3vl", "qwen3vlmoe",
- }, model.Config.ModelFamily) {
- opts.NumCtx = max(opts.NumCtx, 8192)
- }
-
- runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive)
+ runnerCh, errCh := s.sched.GetRunner(ctx, model, opts, keepAlive, useImagegen)
var runner *runnerRef
select {
case runner = <-runnerCh:
@@ -199,9 +204,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
+ resolvedName, _, err := s.resolveAlias(name)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+ name = resolvedName
+
// We cannot currently consolidate this into GetModel because all we'll
// induce infinite recursion given the current code structure.
- name, err := getExistingName(name)
+ name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
return
@@ -226,6 +238,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+ if disabled, _ := internalcloud.Status(); disabled {
+ c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
+ return
+ }
+
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
@@ -974,7 +991,7 @@ func (s *Server) PushHandler(c *gin.Context) {
// is.
func getExistingName(n model.Name) (model.Name, error) {
var zero model.Name
- existing, err := Manifests(true)
+ existing, err := manifest.Manifests(true)
if err != nil {
return zero, err
}
@@ -1018,7 +1035,7 @@ func (s *Server) DeleteHandler(c *gin.Context) {
return
}
- m, err := ParseNamedManifest(n)
+ m, err := manifest.ParseNamedManifest(n)
if err != nil {
switch {
case os.IsNotExist(err):
@@ -1063,9 +1080,12 @@ func (s *Server) ShowHandler(c *gin.Context) {
resp, err := GetModelInfo(req)
if err != nil {
+ var statusErr api.StatusError
switch {
case os.IsNotExist(err):
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
+ case errors.As(err, &statusErr):
+ c.JSON(statusErr.StatusCode, gin.H{"error": statusErr.ErrorMessage})
case err.Error() == errtypes.InvalidModelNameErrMsg:
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
@@ -1080,7 +1100,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
- return nil, ErrModelPathInvalid
+ return nil, model.Unqualified(name)
}
name, err := getExistingName(name)
if err != nil {
@@ -1092,6 +1112,15 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
return nil, err
}
+ if m.Config.RemoteHost != "" {
+ if disabled, _ := internalcloud.Status(); disabled {
+ return nil, api.StatusError{
+ StatusCode: http.StatusForbidden,
+ ErrorMessage: internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable),
+ }
+ }
+ }
+
modelDetails := api.ModelDetails{
ParentModel: m.ParentModel,
Format: m.Config.ModelFormat,
@@ -1103,7 +1132,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For image generation models, populate details from imagegen package
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
- if info, err := imagegen.GetModelInfo(name.String()); err == nil {
+ if info, err := imagegenmanifest.GetModelInfo(name.String()); err == nil {
modelDetails.Family = info.Architecture
modelDetails.ParameterSize = format.HumanNumber(uint64(info.ParameterCount))
modelDetails.QuantizationLevel = info.Quantization
@@ -1112,7 +1141,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate details from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
- if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
+ if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
if arch, ok := info["general.architecture"].(string); ok && arch != "" {
modelDetails.Family = arch
}
@@ -1121,7 +1150,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
}
// Get torch_dtype directly from config.json for quantization level
- if dtype, err := xserver.GetSafetensorsDtype(name.String()); err == nil && dtype != "" {
+ if dtype, err := xserver.GetSafetensorsDtype(name); err == nil && dtype != "" {
modelDetails.QuantizationLevel = dtype
}
}
@@ -1135,7 +1164,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
}
- manifest, err := ParseNamedManifest(name)
+ mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, err
}
@@ -1147,7 +1176,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
Details: modelDetails,
Messages: msgs,
Capabilities: m.Capabilities(),
- ModifiedAt: manifest.fi.ModTime(),
+ ModifiedAt: mf.FileInfo().ModTime(),
Requires: m.Config.Requires,
// Several integrations crash on a nil/omitempty+empty ModelInfo, so by
// default we return an empty map.
@@ -1214,7 +1243,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
if slices.Contains(m.Capabilities(), model.CapabilityImage) {
// Populate tensor info if verbose
if req.Verbose {
- if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
+ if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1223,12 +1252,12 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
// For safetensors LLM models (experimental), populate ModelInfo from config.json
if m.Config.ModelFormat == "safetensors" && slices.Contains(m.Config.Capabilities, "completion") {
- if info, err := xserver.GetSafetensorsLLMInfo(name.String()); err == nil {
+ if info, err := xserver.GetSafetensorsLLMInfo(name); err == nil {
resp.ModelInfo = info
}
// Populate tensor info if verbose
if req.Verbose {
- if tensors, err := xserver.GetSafetensorsTensorInfo(name.String()); err == nil {
+ if tensors, err := xserver.GetSafetensorsTensorInfo(name); err == nil {
resp.Tensors = tensors
}
}
@@ -1285,7 +1314,7 @@ func getModelData(digest string, verbose bool) (ggml.KV, ggml.ForeignTensors, er
}
func (s *Server) ListHandler(c *gin.Context) {
- ms, err := Manifests(true)
+ ms, err := manifest.Manifests(true)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1316,8 +1345,8 @@ func (s *Server) ListHandler(c *gin.Context) {
RemoteModel: cf.RemoteModel,
RemoteHost: cf.RemoteHost,
Size: m.Size(),
- Digest: m.digest,
- ModifiedAt: m.fi.ModTime(),
+ Digest: m.Digest(),
+ ModifiedAt: m.FileInfo().ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
@@ -1376,7 +1405,7 @@ func (s *Server) CopyHandler(c *gin.Context) {
}
func (s *Server) HeadBlobHandler(c *gin.Context) {
- path, err := GetBlobsPath(c.Param("digest"))
+ path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1392,7 +1421,7 @@ func (s *Server) HeadBlobHandler(c *gin.Context) {
func (s *Server) CreateBlobHandler(c *gin.Context) {
if ib, ok := intermediateBlobs[c.Param("digest")]; ok {
- p, err := GetBlobsPath(ib)
+ p, err := manifest.BlobsPath(ib)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1410,7 +1439,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
}
}
- path, err := GetBlobsPath(c.Param("digest"))
+ path, err := manifest.BlobsPath(c.Param("digest"))
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -1428,7 +1457,7 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return
}
- layer, err := NewLayer(c.Request.Body, "")
+ layer, err := manifest.NewLayer(c.Request.Body, "")
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -1568,6 +1597,7 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
+ r.GET("/api/status", s.StatusHandler)
// Local model cache management (new implementation is at end of function)
r.POST("/api/pull", s.PullHandler)
@@ -1588,6 +1618,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
+ r.GET("/api/experimental/aliases", s.ListAliasesHandler)
+ r.POST("/api/experimental/aliases", s.CreateAliasHandler)
+ r.DELETE("/api/experimental/aliases", s.DeleteAliasHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
@@ -1603,8 +1636,9 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/responses", middleware.ResponsesMiddleware(), s.ChatHandler)
- // OpenAI-compatible image generation endpoint
+ // OpenAI-compatible image generation endpoints
r.POST("/v1/images/generations", middleware.ImageGenerationsMiddleware(), s.GenerateHandler)
+ r.POST("/v1/images/edits", middleware.ImageEditsMiddleware(), s.GenerateHandler)
// Inference (Anthropic compatibility)
r.POST("/v1/messages", middleware.AnthropicMessagesMiddleware(), s.ChatHandler)
@@ -1627,8 +1661,10 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
func Serve(ln net.Listener) error {
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
slog.Info("server config", "env", envconfig.Values())
+ cloudDisabled, _ := internalcloud.Status()
+ slog.Info(fmt.Sprintf("Ollama cloud disabled: %t", cloudDisabled))
- blobsDir, err := GetBlobsPath("")
+ blobsDir, err := manifest.BlobsPath("")
if err != nil {
return err
}
@@ -1637,7 +1673,7 @@ func Serve(ln net.Listener) error {
}
if !envconfig.NoPrune() {
- if _, err := Manifests(false); err != nil {
+ if _, err := manifest.Manifests(false); err != nil {
slog.Warn("corrupt manifests detected, skipping prune operation. Re-pull or delete to clear", "error", err)
} else {
// clean up unused layers and manifests
@@ -1645,12 +1681,12 @@ func Serve(ln net.Listener) error {
return err
}
- manifestsPath, err := GetManifestPath()
+ manifestsPath, err := manifest.Path()
if err != nil {
return err
}
- if err := PruneDirectory(manifestsPath); err != nil {
+ if err := manifest.PruneDirectory(manifestsPath); err != nil {
return err
}
}
@@ -1723,10 +1759,18 @@ func Serve(ln net.Listener) error {
for _, gpu := range gpus {
totalVRAM += gpu.TotalMemory - envconfig.GpuOverhead()
}
- if totalVRAM < lowVRAMThreshold {
- s.lowVRAM = true
- slog.Info("entering low vram mode", "total vram", format.HumanBytes2(totalVRAM), "threshold", format.HumanBytes2(lowVRAMThreshold))
+
+ // Set default context based on VRAM tier
+ // Use slightly lower thresholds (47/23 GiB vs. 48/24 GiB) to account for small differences in the exact value
+ switch {
+ case totalVRAM >= 47*format.GibiByte:
+ s.defaultNumCtx = 262144
+ case totalVRAM >= 23*format.GibiByte:
+ s.defaultNumCtx = 32768
+ default:
+ s.defaultNumCtx = 4096
}
+ slog.Info("vram-based default context", "total_vram", format.HumanBytes2(totalVRAM), "default_num_ctx", s.defaultNumCtx)
err = srvr.Serve(ln)
// If server is closed from the signal handler, wait for the ctx to be done
@@ -1814,6 +1858,16 @@ func streamResponse(c *gin.Context, ch chan any) {
})
}
+func (s *Server) StatusHandler(c *gin.Context) {
+ disabled, source := internalcloud.Status()
+ c.JSON(http.StatusOK, api.StatusResponse{
+ Cloud: api.CloudStatus{
+ Disabled: disabled,
+ Source: source,
+ },
+ })
+}
+
func (s *Server) WhoamiHandler(c *gin.Context) {
// todo allow other hosts
u, err := url.Parse("https://ollama.com")
@@ -1900,8 +1954,8 @@ func (s *Server) PsHandler(c *gin.Context) {
Details: modelDetails,
ExpiresAt: v.expiresAt,
}
- if v.Options != nil {
- mr.ContextLength = v.Options.NumCtx
+ if v.llama != nil {
+ mr.ContextLength = v.llama.ContextLength()
}
// The scheduler waits to set expiresAt, so if a model is loading it's
// possible that it will be set to the unix epoch. For those cases, just
@@ -1954,13 +2008,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
- name, err := getExistingName(name)
+ resolvedName, _, err := s.resolveAlias(name)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+ name = resolvedName
+
+ name, err = getExistingName(name)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
}
- m, err := GetModel(req.Model)
+ m, err := GetModel(name.String())
if err != nil {
switch {
case os.IsNotExist(err):
@@ -1993,6 +2054,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" {
+ if disabled, _ := internalcloud.Status(); disabled {
+ c.JSON(http.StatusForbidden, gin.H{"error": internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)})
+ return
+ }
+
origModel := req.Model
remoteURL, err := url.Parse(m.Config.RemoteHost)
@@ -2511,8 +2577,14 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
return
}
- // Set headers for streaming response
- c.Header("Content-Type", "application/x-ndjson")
+ // Check streaming preference
+ isStreaming := req.Stream == nil || *req.Stream
+
+ contentType := "application/x-ndjson"
+ if !isStreaming {
+ contentType = "application/json; charset=utf-8"
+ }
+ c.Header("Content-Type", contentType)
// Get seed from options if provided
var seed int64
@@ -2527,13 +2599,21 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
}
}
+ var images []llm.ImageData
+ for i, imgData := range req.Images {
+ images = append(images, llm.ImageData{ID: i, Data: imgData})
+ }
+
var streamStarted bool
+ var finalResponse api.GenerateResponse
+
if err := runner.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
Steps: req.Steps,
Seed: seed,
+ Images: images,
}, func(cr llm.CompletionResponse) {
streamStarted = true
res := api.GenerateResponse{
@@ -2557,6 +2637,11 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
res.Metrics.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
+ if !isStreaming {
+ finalResponse = res
+ return
+ }
+
data, _ := json.Marshal(res)
c.Writer.Write(append(data, '\n'))
c.Writer.Flush()
@@ -2566,5 +2651,10 @@ func (s *Server) handleImageGenerate(c *gin.Context, req api.GenerateRequest, mo
if !streamStarted {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
+ return
+ }
+
+ if !isStreaming {
+ c.JSON(http.StatusOK, finalResponse)
}
}
diff --git a/server/routes_aliases.go b/server/routes_aliases.go
new file mode 100644
index 00000000000..d68514e9c55
--- /dev/null
+++ b/server/routes_aliases.go
@@ -0,0 +1,159 @@
+package server
+
+import (
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+
+ "github.com/gin-gonic/gin"
+
+ "github.com/ollama/ollama/types/model"
+)
+
+type aliasListResponse struct {
+ Aliases []aliasEntry `json:"aliases"`
+}
+
+type aliasDeleteRequest struct {
+ Alias string `json:"alias"`
+}
+
+func (s *Server) ListAliasesHandler(c *gin.Context) {
+ store, err := s.aliasStore()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ var aliases []aliasEntry
+ if store != nil {
+ aliases = store.List()
+ }
+
+ c.JSON(http.StatusOK, aliasListResponse{Aliases: aliases})
+}
+
+func (s *Server) CreateAliasHandler(c *gin.Context) {
+ var req aliasEntry
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+ return
+ } else if err != nil {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+
+ req.Alias = strings.TrimSpace(req.Alias)
+ req.Target = strings.TrimSpace(req.Target)
+ if req.Alias == "" || req.Target == "" {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias and target are required"})
+ return
+ }
+
+ // Target must always be a valid model name
+ targetName := model.ParseName(req.Target)
+ if !targetName.IsValid() {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("target %q is invalid", req.Target)})
+ return
+ }
+
+ var aliasName model.Name
+ if req.PrefixMatching {
+ // For prefix aliases, we still parse the alias to normalize it,
+ // but we allow any non-empty string since prefix patterns may not be valid model names
+ aliasName = model.ParseName(req.Alias)
+ // Even if not valid as a model name, we accept it for prefix matching
+ } else {
+ aliasName = model.ParseName(req.Alias)
+ if !aliasName.IsValid() {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q is invalid", req.Alias)})
+ return
+ }
+
+ if normalizeAliasKey(aliasName) == normalizeAliasKey(targetName) {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias cannot point to itself"})
+ return
+ }
+
+ exists, err := localModelExists(aliasName)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+ if exists {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("alias %q conflicts with existing model", req.Alias)})
+ return
+ }
+ }
+
+ store, err := s.aliasStore()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ if err := store.Set(aliasName, targetName, req.PrefixMatching); err != nil {
+ status := http.StatusInternalServerError
+ if errors.Is(err, errAliasCycle) {
+ status = http.StatusBadRequest
+ }
+ c.AbortWithStatusJSON(status, gin.H{"error": err.Error()})
+ return
+ }
+
+ resp := aliasEntry{
+ Alias: displayAliasName(aliasName),
+ Target: displayAliasName(targetName),
+ PrefixMatching: req.PrefixMatching,
+ }
+ if req.PrefixMatching && !aliasName.IsValid() {
+ // For prefix aliases that aren't valid model names, use the raw alias
+ resp.Alias = req.Alias
+ }
+ c.JSON(http.StatusOK, resp)
+}
+
+func (s *Server) DeleteAliasHandler(c *gin.Context) {
+ var req aliasDeleteRequest
+ if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
+ return
+ } else if err != nil {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+
+ req.Alias = strings.TrimSpace(req.Alias)
+ if req.Alias == "" {
+ c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "alias is required"})
+ return
+ }
+
+ store, err := s.aliasStore()
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+
+ aliasName := model.ParseName(req.Alias)
+ var deleted bool
+ if aliasName.IsValid() {
+ deleted, err = store.Delete(aliasName)
+ } else {
+ // For invalid model names (like prefix aliases), try deleting by raw string
+ deleted, err = store.DeleteByString(req.Alias)
+ }
+
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
+ if !deleted {
+ c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("alias %q not found", req.Alias)})
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{"deleted": true})
+}
diff --git a/server/routes_aliases_test.go b/server/routes_aliases_test.go
new file mode 100644
index 00000000000..27d06229f35
--- /dev/null
+++ b/server/routes_aliases_test.go
@@ -0,0 +1,475 @@
+package server
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/types/model"
+)
+
+func TestAliasShadowingRejected(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ setTestHome(t, t.TempDir())
+
+ s := Server{}
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "shadowed-model",
+ RemoteHost: "example.com",
+ From: "test",
+ Info: map[string]any{
+ "capabilities": []string{"completion"},
+ },
+ Stream: &stream,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "shadowed-model", Target: "other-model"})
+ if w.Code != http.StatusBadRequest {
+ t.Fatalf("expected status 400, got %d", w.Code)
+ }
+}
+
+func TestAliasResolvesForChatRemote(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ setTestHome(t, t.TempDir())
+
+ var remoteModel string
+ rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var req api.ChatRequest
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ t.Fatal(err)
+ }
+ remoteModel = req.Model
+
+ w.Header().Set("Content-Type", "application/json")
+ resp := api.ChatResponse{
+ Model: req.Model,
+ Done: true,
+ DoneReason: "load",
+ }
+ if err := json.NewEncoder(w).Encode(&resp); err != nil {
+ t.Fatal(err)
+ }
+ }))
+ defer rs.Close()
+
+ p, err := url.Parse(rs.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ t.Setenv("OLLAMA_REMOTES", p.Hostname())
+
+ s := Server{}
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "target-model",
+ RemoteHost: rs.URL,
+ From: "test",
+ Info: map[string]any{
+ "capabilities": []string{"completion"},
+ },
+ Stream: &stream,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ w = createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ w = createRequest(t, s.ChatHandler, api.ChatRequest{
+ Model: "alias-model",
+ Messages: []api.Message{{Role: "user", Content: "hi"}},
+ Stream: &stream,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ var resp api.ChatResponse
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+ t.Fatal(err)
+ }
+
+ if resp.Model != "alias-model" {
+ t.Fatalf("expected response model to be alias-model, got %q", resp.Model)
+ }
+
+ if remoteModel != "test" {
+ t.Fatalf("expected remote model to be 'test', got %q", remoteModel)
+ }
+}
+
+func TestPrefixAliasBasicMatching(t *testing.T) {
+ tmpDir := t.TempDir()
+ store, err := createStore(filepath.Join(tmpDir, "server.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a prefix alias: "myprefix-" -> "targetmodel"
+ targetName := model.ParseName("targetmodel")
+
+ // Set a prefix alias (using "myprefix-" as the pattern)
+ store.mu.Lock()
+ store.prefixEntries = append(store.prefixEntries, aliasEntry{
+ Alias: "myprefix-",
+ Target: "targetmodel",
+ PrefixMatching: true,
+ })
+ store.mu.Unlock()
+
+ // Test that "myprefix-foo" resolves to "targetmodel"
+ testName := model.ParseName("myprefix-foo")
+ resolved, wasResolved, err := store.ResolveName(testName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved")
+ }
+ if resolved.DisplayShortest() != targetName.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q, got %q", targetName.DisplayShortest(), resolved.DisplayShortest())
+ }
+
+ // Test that "otherprefix-foo" does not resolve
+ otherName := model.ParseName("otherprefix-foo")
+ _, wasResolved, err = store.ResolveName(otherName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if wasResolved {
+ t.Fatal("expected name not to be resolved")
+ }
+
+ // Test that exact alias takes precedence
+ exactAlias := model.ParseName("myprefix-exact")
+ exactTarget := model.ParseName("exacttarget")
+ if err := store.Set(exactAlias, exactTarget, false); err != nil {
+ t.Fatal(err)
+ }
+
+ resolved, wasResolved, err = store.ResolveName(exactAlias)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved")
+ }
+ if resolved.DisplayShortest() != exactTarget.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q (exact match), got %q", exactTarget.DisplayShortest(), resolved.DisplayShortest())
+ }
+}
+
+func TestPrefixAliasLongestMatchWins(t *testing.T) {
+ tmpDir := t.TempDir()
+ store, err := createStore(filepath.Join(tmpDir, "server.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Add two prefix aliases with overlapping patterns
+ store.mu.Lock()
+ store.prefixEntries = []aliasEntry{
+ {Alias: "abc-", Target: "short-target", PrefixMatching: true},
+ {Alias: "abc-def-", Target: "long-target", PrefixMatching: true},
+ }
+ store.sortPrefixEntriesLocked()
+ store.mu.Unlock()
+
+ // "abc-def-ghi" should match the longer prefix "abc-def-"
+ testName := model.ParseName("abc-def-ghi")
+ resolved, wasResolved, err := store.ResolveName(testName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved")
+ }
+ expectedLongTarget := model.ParseName("long-target")
+ if resolved.DisplayShortest() != expectedLongTarget.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q (longest prefix match), got %q", expectedLongTarget.DisplayShortest(), resolved.DisplayShortest())
+ }
+
+ // "abc-xyz" should match the shorter prefix "abc-"
+ testName2 := model.ParseName("abc-xyz")
+ resolved, wasResolved, err = store.ResolveName(testName2)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved")
+ }
+ expectedShortTarget := model.ParseName("short-target")
+ if resolved.DisplayShortest() != expectedShortTarget.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q, got %q", expectedShortTarget.DisplayShortest(), resolved.DisplayShortest())
+ }
+}
+
+func TestPrefixAliasChain(t *testing.T) {
+ tmpDir := t.TempDir()
+ store, err := createStore(filepath.Join(tmpDir, "server.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a chain: prefix "test-" -> "intermediate" -> "final"
+ intermediate := model.ParseName("intermediate")
+ final := model.ParseName("final")
+
+ // Add prefix alias
+ store.mu.Lock()
+ store.prefixEntries = []aliasEntry{
+ {Alias: "test-", Target: "intermediate", PrefixMatching: true},
+ }
+ store.mu.Unlock()
+
+ // Add exact alias for the intermediate step
+ if err := store.Set(intermediate, final, false); err != nil {
+ t.Fatal(err)
+ }
+
+ // "test-foo" should resolve through the chain to "final"
+ testName := model.ParseName("test-foo")
+ resolved, wasResolved, err := store.ResolveName(testName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved")
+ }
+ if resolved.DisplayShortest() != final.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q, got %q", final.DisplayShortest(), resolved.DisplayShortest())
+ }
+}
+
+func TestPrefixAliasCRUD(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ setTestHome(t, t.TempDir())
+
+ s := Server{}
+
+ // Create a prefix alias via API
+ w := createRequest(t, s.CreateAliasHandler, aliasEntry{
+ Alias: "myprefix-",
+ Target: "llama2",
+ PrefixMatching: true,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ var createResp aliasEntry
+ if err := json.NewDecoder(w.Body).Decode(&createResp); err != nil {
+ t.Fatal(err)
+ }
+ if !createResp.PrefixMatching {
+ t.Fatal("expected prefix_matching to be true in response")
+ }
+
+ // List aliases and verify the prefix alias is included
+ w = createRequest(t, s.ListAliasesHandler, nil)
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ var listResp aliasListResponse
+ if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
+ t.Fatal(err)
+ }
+
+ found := false
+ for _, a := range listResp.Aliases {
+ if a.PrefixMatching && a.Target == "llama2" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Fatal("expected to find prefix alias in list")
+ }
+
+ // Delete the prefix alias
+ w = createRequest(t, s.DeleteAliasHandler, aliasDeleteRequest{Alias: "myprefix-"})
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Verify it's deleted
+ w = createRequest(t, s.ListAliasesHandler, nil)
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ if err := json.NewDecoder(w.Body).Decode(&listResp); err != nil {
+ t.Fatal(err)
+ }
+
+ for _, a := range listResp.Aliases {
+ if a.PrefixMatching {
+ t.Fatal("expected prefix alias to be deleted")
+ }
+ }
+}
+
+func TestPrefixAliasCaseInsensitive(t *testing.T) {
+ tmpDir := t.TempDir()
+ store, err := createStore(filepath.Join(tmpDir, "server.json"))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Add a prefix alias with mixed case
+ store.mu.Lock()
+ store.prefixEntries = []aliasEntry{
+ {Alias: "MyPrefix-", Target: "targetmodel", PrefixMatching: true},
+ }
+ store.mu.Unlock()
+
+ // Test that matching is case-insensitive
+ testName := model.ParseName("myprefix-foo")
+ resolved, wasResolved, err := store.ResolveName(testName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved (case-insensitive)")
+ }
+ expectedTarget := model.ParseName("targetmodel")
+ if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
+ }
+
+ // Test uppercase request
+ testName2 := model.ParseName("MYPREFIX-BAR")
+ _, wasResolved, err = store.ResolveName(testName2)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected name to be resolved (uppercase)")
+ }
+}
+
+func TestPrefixAliasLocalModelPrecedence(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ setTestHome(t, t.TempDir())
+
+ s := Server{}
+
+ // Create a local model that would match a prefix alias
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "myprefix-localmodel",
+ RemoteHost: "example.com",
+ From: "test",
+ Info: map[string]any{
+ "capabilities": []string{"completion"},
+ },
+ Stream: &stream,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Create a prefix alias that would match the local model name
+ w = createRequest(t, s.CreateAliasHandler, aliasEntry{
+ Alias: "myprefix-",
+ Target: "someothermodel",
+ PrefixMatching: true,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Verify that resolving "myprefix-localmodel" returns the local model, not the alias target
+ store, err := s.aliasStore()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ localModelName := model.ParseName("myprefix-localmodel")
+ resolved, wasResolved, err := store.ResolveName(localModelName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if wasResolved {
+ t.Fatalf("expected local model to take precedence (wasResolved should be false), but got resolved to %q", resolved.DisplayShortest())
+ }
+ if resolved.DisplayShortest() != localModelName.DisplayShortest() {
+ t.Fatalf("expected resolved name to be local model %q, got %q", localModelName.DisplayShortest(), resolved.DisplayShortest())
+ }
+
+ // Also verify that a non-local model matching the prefix DOES resolve to the alias target
+ nonLocalName := model.ParseName("myprefix-nonexistent")
+ resolved, wasResolved, err = store.ResolveName(nonLocalName)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !wasResolved {
+ t.Fatal("expected non-local model to resolve via prefix alias")
+ }
+ expectedTarget := model.ParseName("someothermodel")
+ if resolved.DisplayShortest() != expectedTarget.DisplayShortest() {
+ t.Fatalf("expected resolved name to be %q, got %q", expectedTarget.DisplayShortest(), resolved.DisplayShortest())
+ }
+}
+
+func TestAliasSavePreservesCloudDisable(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ tmpDir := t.TempDir()
+ setTestHome(t, tmpDir)
+
+ configPath := filepath.Join(tmpDir, ".ollama", "server.json")
+ if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil {
+ t.Fatal(err)
+ }
+
+ initial := map[string]any{
+ "version": serverConfigVersion,
+ "disable_ollama_cloud": true,
+ "aliases": []aliasEntry{},
+ }
+ data, err := json.Marshal(initial)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(configPath, data, 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ s := Server{}
+ w := createRequest(t, s.CreateAliasHandler, aliasEntry{Alias: "alias-model", Target: "target-model"})
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ updated, err := os.ReadFile(configPath)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var updatedCfg map[string]json.RawMessage
+ if err := json.Unmarshal(updated, &updatedCfg); err != nil {
+ t.Fatal(err)
+ }
+
+ raw, ok := updatedCfg["disable_ollama_cloud"]
+ if !ok {
+ t.Fatal("expected disable_ollama_cloud key to be preserved")
+ }
+ if string(raw) != "true" {
+ t.Fatalf("expected disable_ollama_cloud to remain true, got %s", string(raw))
+ }
+}
diff --git a/server/routes_cloud_test.go b/server/routes_cloud_test.go
new file mode 100644
index 00000000000..b0ee126ea4c
--- /dev/null
+++ b/server/routes_cloud_test.go
@@ -0,0 +1,94 @@
+package server
+
+import (
+ "encoding/json"
+ "net/http"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/ollama/ollama/api"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
+)
+
+func TestStatusHandler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ setTestHome(t, t.TempDir())
+ t.Setenv("OLLAMA_NO_CLOUD", "1")
+
+ s := Server{}
+ w := createRequest(t, s.StatusHandler, nil)
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ var resp api.StatusResponse
+ if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
+ t.Fatal(err)
+ }
+
+ if !resp.Cloud.Disabled {
+ t.Fatalf("expected cloud.disabled true, got false")
+ }
+ if resp.Cloud.Source != "env" {
+ t.Fatalf("expected cloud.source env, got %q", resp.Cloud.Source)
+ }
+}
+
+func TestCloudDisabledBlocksRemoteOperations(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ setTestHome(t, t.TempDir())
+ t.Setenv("OLLAMA_NO_CLOUD", "1")
+
+ s := Server{}
+
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "test-cloud",
+ RemoteHost: "example.com",
+ From: "test",
+ Info: map[string]any{
+ "capabilities": []string{"completion"},
+ },
+ Stream: &stream,
+ })
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ t.Run("chat remote blocked", func(t *testing.T) {
+ w := createRequest(t, s.ChatHandler, api.ChatRequest{
+ Model: "test-cloud",
+ Messages: []api.Message{{Role: "user", Content: "hi"}},
+ })
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("expected status 403, got %d", w.Code)
+ }
+ if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)+`"}` {
+ t.Fatalf("unexpected response: %s", got)
+ }
+ })
+
+ t.Run("generate remote blocked", func(t *testing.T) {
+ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+ Model: "test-cloud",
+ Prompt: "hi",
+ })
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("expected status 403, got %d", w.Code)
+ }
+ if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteInferenceUnavailable)+`"}` {
+ t.Fatalf("unexpected response: %s", got)
+ }
+ })
+
+ t.Run("show remote blocked", func(t *testing.T) {
+ w := createRequest(t, s.ShowHandler, api.ShowRequest{
+ Model: "test-cloud",
+ })
+ if w.Code != http.StatusForbidden {
+ t.Fatalf("expected status 403, got %d", w.Code)
+ }
+ if got := w.Body.String(); got != `{"error":"`+internalcloud.DisabledError(cloudErrRemoteModelDetailsUnavailable)+`"}` {
+ t.Fatalf("unexpected response: %s", got)
+ }
+ })
+}
diff --git a/server/routes_create_test.go b/server/routes_create_test.go
index ebe79e90219..5599ddd2cf8 100644
--- a/server/routes_create_test.go
+++ b/server/routes_create_test.go
@@ -25,6 +25,7 @@ import (
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/fs/ggml"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -223,15 +224,15 @@ func TestCreateFromModelInheritsRendererParser(t *testing.T) {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
- manifest, err := ParseNamedManifest(model.ParseName("child"))
+ mf, err := manifest.ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
- if manifest.Config.Digest == "" {
+ if mf.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}
- configPath, err := GetBlobsPath(manifest.Config.Digest)
+ configPath, err := manifest.BlobsPath(mf.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}
diff --git a/server/routes_debug_test.go b/server/routes_debug_test.go
index 8002c752c77..6cc1fca78c8 100644
--- a/server/routes_debug_test.go
+++ b/server/routes_debug_test.go
@@ -15,6 +15,7 @@ import (
)
func TestGenerateDebugRenderOnly(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -208,6 +209,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
}
func TestChatDebugRenderOnly(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
diff --git a/server/routes_delete_test.go b/server/routes_delete_test.go
index eb8c4432079..a1a5f542482 100644
--- a/server/routes_delete_test.go
+++ b/server/routes_delete_test.go
@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/types/model"
)
@@ -93,13 +94,13 @@ func TestDeleteDuplicateLayers(t *testing.T) {
t.Fatal(err)
}
- config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+ config, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
t.Fatal(err)
}
// create a manifest with duplicate layers
- if err := WriteManifest(n, config, []Layer{config}); err != nil {
+ if err := manifest.WriteManifest(n, config, []manifest.Layer{config}); err != nil {
t.Fatal(err)
}
diff --git a/server/routes_generate_renderer_test.go b/server/routes_generate_renderer_test.go
index 06336d02716..92a5205c4dc 100644
--- a/server/routes_generate_renderer_test.go
+++ b/server/routes_generate_renderer_test.go
@@ -20,6 +20,7 @@ import (
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
// when in chat-like flow (messages present, no suffix, no template)
func TestGenerateWithBuiltinRenderer(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -204,6 +205,7 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
func TestGenerateWithDebugRenderOnly(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go
index 6c6db256b72..d52d2955174 100644
--- a/server/routes_generate_test.go
+++ b/server/routes_generate_test.go
@@ -19,7 +19,9 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/types/model"
)
// testPropsMap creates a ToolPropertiesMap from a map (convenience function for tests)
@@ -71,6 +73,8 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}
+func (mockRunner) Ping(_ context.Context) error { return nil }
+
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, []string, *ggml.MetaGGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ []string, _ *ggml.MetaGGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
return mock, nil
@@ -158,6 +162,7 @@ func TestGenerateChatRemote(t *testing.T) {
}
func TestGenerateChat(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -874,6 +879,7 @@ func TestGenerateChat(t *testing.T) {
}
func TestGenerate(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
gin.SetMode(gin.TestMode)
mock := mockRunner{
@@ -2193,3 +2199,253 @@ func TestGenerateUnload(t *testing.T) {
}
})
}
+
+func TestGenerateWithImages(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ mock := mockRunner{
+ CompletionResponse: llm.CompletionResponse{
+ Done: true,
+ DoneReason: llm.DoneReasonStop,
+ PromptEvalCount: 1,
+ PromptEvalDuration: 1,
+ EvalCount: 1,
+ EvalDuration: 1,
+ },
+ }
+
+ s := Server{
+ sched: &Scheduler{
+ pendingReqCh: make(chan *LlmRequest, 1),
+ finishedReqCh: make(chan *LlmRequest, 1),
+ expiredCh: make(chan *runnerRef, 1),
+ unloadedCh: make(chan any, 1),
+ loaded: make(map[string]*runnerRef),
+ newServerFn: newMockServer(&mock),
+ getGpuFn: getGpuFn,
+ getSystemInfoFn: getSystemInfoFn,
+ waitForRecovery: 250 * time.Millisecond,
+ loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
+ time.Sleep(time.Millisecond)
+ req.successCh <- &runnerRef{
+ llama: &mock,
+ }
+ return false
+ },
+ },
+ }
+
+ go s.sched.Run(t.Context())
+
+ _, digest := createBinFile(t, ggml.KV{
+ "general.architecture": "llama",
+ "llama.block_count": uint32(1),
+ "llama.context_length": uint32(8192),
+ "llama.embedding_length": uint32(4096),
+ "llama.attention.head_count": uint32(32),
+ "llama.attention.head_count_kv": uint32(8),
+ "tokenizer.ggml.tokens": []string{""},
+ "tokenizer.ggml.scores": []float32{0},
+ "tokenizer.ggml.token_type": []int32{0},
+ }, []*ggml.Tensor{
+ {Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ {Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
+ })
+
+ w := createRequest(t, s.CreateHandler, api.CreateRequest{
+ Model: "test",
+ Files: map[string]string{"file.gguf": digest},
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d", w.Code)
+ }
+
+ t.Run("images passed to completion request", func(t *testing.T) {
+ testImage := []byte("test-image-data")
+
+ mock.CompletionResponse.Content = "Image processed"
+ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+ Model: "test",
+ Prompt: "Describe this image",
+ Images: []api.ImageData{testImage},
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Verify images were passed to the completion request
+ if len(mock.CompletionRequest.Images) != 1 {
+ t.Fatalf("expected 1 image in completion request, got %d", len(mock.CompletionRequest.Images))
+ }
+
+ if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage) {
+ t.Errorf("image data mismatch in completion request")
+ }
+
+ if mock.CompletionRequest.Images[0].ID != 0 {
+ t.Errorf("expected image ID 0, got %d", mock.CompletionRequest.Images[0].ID)
+ }
+ })
+
+ t.Run("multiple images passed to completion request", func(t *testing.T) {
+ testImage1 := []byte("test-image-1")
+ testImage2 := []byte("test-image-2")
+
+ mock.CompletionResponse.Content = "Images processed"
+ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+ Model: "test",
+ Prompt: "Compare these images",
+ Images: []api.ImageData{testImage1, testImage2},
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Verify both images were passed
+ if len(mock.CompletionRequest.Images) != 2 {
+ t.Fatalf("expected 2 images in completion request, got %d", len(mock.CompletionRequest.Images))
+ }
+
+ if !bytes.Equal(mock.CompletionRequest.Images[0].Data, testImage1) {
+ t.Errorf("first image data mismatch")
+ }
+
+ if !bytes.Equal(mock.CompletionRequest.Images[1].Data, testImage2) {
+ t.Errorf("second image data mismatch")
+ }
+
+ if mock.CompletionRequest.Images[0].ID != 0 || mock.CompletionRequest.Images[1].ID != 1 {
+ t.Errorf("expected image IDs 0 and 1, got %d and %d",
+ mock.CompletionRequest.Images[0].ID, mock.CompletionRequest.Images[1].ID)
+ }
+ })
+
+ t.Run("no images when none provided", func(t *testing.T) {
+ mock.CompletionResponse.Content = "No images"
+ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+ Model: "test",
+ Prompt: "Hello",
+ Stream: &stream,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Verify no images in completion request
+ if len(mock.CompletionRequest.Images) != 0 {
+ t.Fatalf("expected 0 images in completion request, got %d", len(mock.CompletionRequest.Images))
+ }
+ })
+}
+
+// TestImageGenerateStreamFalse tests that image generation respects stream=false
+// and returns a single JSON response instead of streaming ndjson.
+func TestImageGenerateStreamFalse(t *testing.T) {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", "4096")
+ gin.SetMode(gin.TestMode)
+
+ p := t.TempDir()
+ t.Setenv("OLLAMA_MODELS", p)
+
+ mock := mockRunner{}
+ mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
+ fn(llm.CompletionResponse{Step: 1, TotalSteps: 3, Done: false})
+ fn(llm.CompletionResponse{Step: 2, TotalSteps: 3, Done: false})
+ fn(llm.CompletionResponse{Step: 3, TotalSteps: 3, Done: true, DoneReason: llm.DoneReasonStop, Image: "base64image"})
+ return nil
+ }
+
+ // Create model manifest with image capability
+ n := model.ParseName("test-image")
+ cfg := model.ConfigV2{Capabilities: []string{"image"}}
+ var b bytes.Buffer
+ if err := json.NewEncoder(&b).Encode(&cfg); err != nil {
+ t.Fatal(err)
+ }
+ configLayer, err := manifest.NewLayer(&b, "application/vnd.docker.container.image.v1+json")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := manifest.WriteManifest(n, configLayer, nil); err != nil {
+ t.Fatal(err)
+ }
+
+ loadedModel, err := GetModel("test-image")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ opts := api.DefaultOptions()
+ s := Server{
+ sched: &Scheduler{
+ pendingReqCh: make(chan *LlmRequest, 1),
+ finishedReqCh: make(chan *LlmRequest, 1),
+ expiredCh: make(chan *runnerRef, 1),
+ unloadedCh: make(chan any, 1),
+ loaded: map[string]*runnerRef{
+ schedulerModelKey(loadedModel): {
+ llama: &mock,
+ Options: &opts,
+ model: loadedModel,
+ isImagegen: true,
+ numParallel: 1,
+ },
+ },
+ newServerFn: newMockServer(&mock),
+ getGpuFn: getGpuFn,
+ getSystemInfoFn: getSystemInfoFn,
+ },
+ }
+
+ go s.sched.Run(t.Context())
+
+ streamFalse := false
+ w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
+ Model: "test-image",
+ Prompt: "test prompt",
+ Stream: &streamFalse,
+ })
+
+ if w.Code != http.StatusOK {
+ t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
+ }
+
+ if ct := w.Header().Get("Content-Type"); ct != "application/json; charset=utf-8" {
+ t.Errorf("expected Content-Type 'application/json; charset=utf-8', got %q", ct)
+ }
+
+ body := w.Body.String()
+ lines := strings.Split(strings.TrimSpace(body), "\n")
+ if len(lines) != 1 {
+ t.Errorf("expected 1 response line, got %d:\n%s", len(lines), body)
+ }
+
+ var resp api.GenerateResponse
+ if err := json.Unmarshal([]byte(lines[0]), &resp); err != nil {
+ t.Fatalf("failed to parse response: %v", err)
+ }
+
+ if resp.Image != "base64image" {
+ t.Errorf("expected image 'base64image', got %q", resp.Image)
+ }
+
+ if !resp.Done {
+ t.Errorf("expected done=true")
+ }
+}
diff --git a/server/routes_options_test.go b/server/routes_options_test.go
new file mode 100644
index 00000000000..9634e7e1abb
--- /dev/null
+++ b/server/routes_options_test.go
@@ -0,0 +1,127 @@
+package server
+
+import (
+ "testing"
+)
+
+func TestModelOptionsNumCtxPriority(t *testing.T) {
+ tests := []struct {
+ name string
+ envContextLen string // empty means not set (uses 0 sentinel)
+ defaultNumCtx int // VRAM-based default
+ modelNumCtx int // 0 means not set in model
+ requestNumCtx int // 0 means not set in request
+ expectedNumCtx int
+ }{
+ {
+ name: "vram default when nothing else set",
+ envContextLen: "",
+ defaultNumCtx: 32768,
+ modelNumCtx: 0,
+ requestNumCtx: 0,
+ expectedNumCtx: 32768,
+ },
+ {
+ name: "env var overrides vram default",
+ envContextLen: "8192",
+ defaultNumCtx: 32768,
+ modelNumCtx: 0,
+ requestNumCtx: 0,
+ expectedNumCtx: 8192,
+ },
+ {
+ name: "model overrides vram default",
+ envContextLen: "",
+ defaultNumCtx: 32768,
+ modelNumCtx: 16384,
+ requestNumCtx: 0,
+ expectedNumCtx: 16384,
+ },
+ {
+ name: "model overrides env var",
+ envContextLen: "8192",
+ defaultNumCtx: 32768,
+ modelNumCtx: 16384,
+ requestNumCtx: 0,
+ expectedNumCtx: 16384,
+ },
+ {
+ name: "request overrides everything",
+ envContextLen: "8192",
+ defaultNumCtx: 32768,
+ modelNumCtx: 16384,
+ requestNumCtx: 4096,
+ expectedNumCtx: 4096,
+ },
+ {
+ name: "request overrides vram default",
+ envContextLen: "",
+ defaultNumCtx: 32768,
+ modelNumCtx: 0,
+ requestNumCtx: 4096,
+ expectedNumCtx: 4096,
+ },
+ {
+ name: "request overrides model",
+ envContextLen: "",
+ defaultNumCtx: 32768,
+ modelNumCtx: 16384,
+ requestNumCtx: 4096,
+ expectedNumCtx: 4096,
+ },
+ {
+ name: "low vram tier default",
+ envContextLen: "",
+ defaultNumCtx: 4096,
+ modelNumCtx: 0,
+ requestNumCtx: 0,
+ expectedNumCtx: 4096,
+ },
+ {
+ name: "high vram tier default",
+ envContextLen: "",
+ defaultNumCtx: 262144,
+ modelNumCtx: 0,
+ requestNumCtx: 0,
+ expectedNumCtx: 262144,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Set or clear environment variable
+ if tt.envContextLen != "" {
+ t.Setenv("OLLAMA_CONTEXT_LENGTH", tt.envContextLen)
+ }
+
+ // Create server with VRAM-based default
+ s := &Server{
+ defaultNumCtx: tt.defaultNumCtx,
+ }
+
+ // Create model options (use float64 as FromMap expects JSON-style numbers)
+ var modelOpts map[string]any
+ if tt.modelNumCtx != 0 {
+ modelOpts = map[string]any{"num_ctx": float64(tt.modelNumCtx)}
+ }
+ model := &Model{
+ Options: modelOpts,
+ }
+
+ // Create request options (use float64 as FromMap expects JSON-style numbers)
+ var requestOpts map[string]any
+ if tt.requestNumCtx != 0 {
+ requestOpts = map[string]any{"num_ctx": float64(tt.requestNumCtx)}
+ }
+
+ opts, err := s.modelOptions(model, requestOpts)
+ if err != nil {
+ t.Fatalf("modelOptions failed: %v", err)
+ }
+
+ if opts.NumCtx != tt.expectedNumCtx {
+ t.Errorf("NumCtx = %d, want %d", opts.NumCtx, tt.expectedNumCtx)
+ }
+ })
+ }
+}
diff --git a/server/sched.go b/server/sched.go
index ac219746c5e..278426fb53a 100644
--- a/server/sched.go
+++ b/server/sched.go
@@ -22,6 +22,7 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/mlxrunner"
)
type LlmRequest struct {
@@ -32,6 +33,7 @@ type LlmRequest struct {
successCh chan *runnerRef
errCh chan error
schedAttempts uint
+ useImagegen bool
}
type Scheduler struct {
@@ -81,8 +83,30 @@ func InitScheduler(ctx context.Context) *Scheduler {
return sched
}
+// schedulerModelKey returns the scheduler map key for a model.
+// GGUF-backed models use ModelPath; safetensors/image models without a
+// ModelPath use manifest digest so distinct models don't collide.
+func schedulerModelKey(m *Model) string {
+ if m == nil {
+ return ""
+ }
+ if m.ModelPath != "" {
+ return m.ModelPath
+ }
+ if m.Digest != "" {
+ return "digest:" + m.Digest
+ }
+ if m.Name != "" {
+ return "name:" + m.Name
+ }
+ if m.ShortName != "" {
+ return "short:" + m.ShortName
+ }
+ return ""
+}
+
// context must be canceled to decrement ref count and release the runner
-func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
+func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration, useImagegen bool) (chan *runnerRef, chan error) {
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
@@ -99,10 +123,12 @@ func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, ses
sessionDuration: sessionDuration,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
+ useImagegen: useImagegen,
}
+ key := schedulerModelKey(req.model)
s.loadedMu.Lock()
- runner := s.loaded[req.model.ModelPath]
+ runner := s.loaded[key]
s.loadedMu.Unlock()
if runner != nil && !runner.needsReload(c, req) {
req.useLoadedRunner(runner, s.finishedReqCh)
@@ -148,8 +174,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
for {
var runnerToExpire *runnerRef
+ pendingKey := schedulerModelKey(pending.model)
s.loadedMu.Lock()
- runner := s.loaded[pending.model.ModelPath]
+ runner := s.loaded[pendingKey]
loadedCount := len(s.loaded)
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
@@ -163,7 +190,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
runnerToExpire = runner
} else {
// Runner is usable, return it
- logutil.Trace("using existing loaded runner", "model", pending.model.ModelPath)
+ logutil.Trace("using existing loaded runner", "model", pendingKey)
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
@@ -198,14 +225,25 @@ func (s *Scheduler) processPending(ctx context.Context) {
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}
- // Check for image generation model before attempting GGML load
+ // Check for image generation models - all use MLX runner
if slices.Contains(pending.model.Config.Capabilities, "image") {
- if s.loadImageGen(pending) {
+ if s.loadMLX(pending) {
break
}
continue
}
+ // Check for experimental safetensors LLM models
+ if pending.model.Config.ModelFormat == "safetensors" {
+ if slices.Contains(pending.model.Config.Capabilities, "completion") {
+ // LLM model with safetensors format - use MLX runner
+ if s.loadMLX(pending) {
+ break
+ }
+ continue
+ }
+ }
+
// Load model for fitting
logutil.Trace("loading model metadata", "model", pending.model.ModelPath)
ggml, err := llm.LoadModel(pending.model.ModelPath, pending.model.ExtraModelPaths, 1024, false)
@@ -281,11 +319,12 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
slog.Debug("shutting down scheduler completed loop")
return
case finished := <-s.finishedReqCh:
+ finishedKey := schedulerModelKey(finished.model)
s.loadedMu.Lock()
- runner := s.loaded[finished.model.ModelPath]
+ runner := s.loaded[finishedKey]
s.loadedMu.Unlock()
if runner == nil {
- slog.Error("finished request signal received after model unloaded", "modelPath", finished.model.ModelPath)
+ slog.Error("finished request signal received after model unloaded", "modelPath", finishedKey)
continue
}
runner.refMu.Lock()
@@ -336,7 +375,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
s.loadedMu.Lock()
slog.Debug("got lock to unload expired event", "runner", runner)
- runnerToUnload := s.loaded[runner.modelPath]
+ runnerToUnload := s.loaded[runner.modelKey]
if runnerToUnload == nil {
// If runnerToUnload is nil, we already processed an event and
// unloaded it. This double unload can happen if the initial
@@ -365,7 +404,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
}
finished := s.waitForVRAMRecovery(runner, runnersSnapshot)
runner.unload()
- delete(s.loaded, runner.modelPath)
+ delete(s.loaded, runner.modelKey)
s.loadedMu.Unlock()
slog.Debug("runner terminated and removed from list, blocking for VRAM recovery", "runner", runner)
<-finished
@@ -409,9 +448,9 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.MetaGGML, systemInfo ml.System
numParallel = 1
}
- // `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and uses an encoder cache which cannot be used with num_parallel > 1
+ // Some architectures are not safe with num_parallel > 1.
// ref: https://github.com/ollama/ollama/issues/4165
- if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
+ if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe", "qwen3next", "lfm2", "lfm2moe"}, req.model.Config.ModelFamily) && numParallel != 1 {
numParallel = 1
slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
}
@@ -503,6 +542,7 @@ iGPUScan:
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
+ modelKey: schedulerModelKey(req.model),
llama: llama,
Options: &req.opts,
sessionDuration: sessionDuration,
@@ -517,7 +557,7 @@ iGPUScan:
runner.refMu.Lock() // hold lock until running or aborted
s.loadedMu.Lock()
- if oldRunner, ok := s.loaded[req.model.ModelPath]; ok {
+ if oldRunner, ok := s.loaded[runner.modelKey]; ok {
// Shouldn't happen, but safeguard against leaking a runner
slog.Warn("model was still loaded", "old_runner", oldRunner, "new_runner", runner)
oldRunner.refMu.Lock()
@@ -525,7 +565,7 @@ iGPUScan:
oldRunner.refMu.Unlock()
}
s.activeLoading = nil
- s.loaded[req.model.ModelPath] = runner
+ s.loaded[runner.modelKey] = runner
slog.Info("loaded runners", "count", len(s.loaded))
s.loadedMu.Unlock()
@@ -555,11 +595,23 @@ iGPUScan:
return false
}
-// loadImageGen loads an image generation model.
-func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
- // Use model name for imagegen (it resolves manifests by name, not file path)
+// loadMLX loads an experimental safetensors model using the unified MLX runner.
+// This supports both LLM (completion) and image generation models.
+func (s *Scheduler) loadMLX(req *LlmRequest) bool {
modelName := req.model.ShortName
- server, err := imagegen.NewServer(modelName)
+ var server llm.LlamaServer
+ var err error
+
+ isImagegen := false
+ if slices.Contains(req.model.Config.Capabilities, "image") {
+ server, err = imagegen.NewServer(modelName, imagegen.ModeImageGen)
+ isImagegen = true
+ } else if req.useImagegen {
+ server, err = imagegen.NewServer(modelName, imagegen.ModeLLM)
+ isImagegen = true
+ } else {
+ server, err = mlxrunner.NewClient(modelName)
+ }
if err != nil {
req.errCh <- err
return true
@@ -573,16 +625,18 @@ func (s *Scheduler) loadImageGen(req *LlmRequest) bool {
runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
+ modelKey: schedulerModelKey(req.model),
llama: server,
Options: &req.opts,
loading: false,
+ isImagegen: isImagegen,
sessionDuration: sessionDuration,
totalSize: server.TotalSize(),
vramSize: server.VRAMSize(),
}
s.loadedMu.Lock()
- s.loaded[req.model.ModelPath] = runner
+ s.loaded[runner.modelKey] = runner
s.loadedMu.Unlock()
// Set up expiration timer
@@ -650,6 +704,7 @@ type runnerRef struct {
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
+ isImagegen bool // True if loaded via imagegen runner (vs mlxrunner)
vramSize uint64
totalSize uint64
@@ -659,6 +714,7 @@ type runnerRef struct {
model *Model
modelPath string
+ modelKey string
numParallel int
*api.Options
}
@@ -678,10 +734,16 @@ func (runner *runnerRef) unload() {
}
func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool {
- slog.Debug("evaluating already loaded", "model", req.model.ModelPath)
+ slog.Debug("evaluating already loaded", "model", schedulerModelKey(req.model))
runner.refMu.Lock()
defer runner.refMu.Unlock()
+ // Check if runner type (imagegen vs mlxrunner) matches what's requested
+ wantImagegen := req.useImagegen || slices.Contains(req.model.Config.Capabilities, "image")
+ if runner.isImagegen != wantImagegen {
+ return true
+ }
+
timeout := 10 * time.Second
if runner.loading {
timeout = 2 * time.Minute // Initial load can take a long time for big models on slow systems...
@@ -795,6 +857,10 @@ func (runner *runnerRef) LogValue() slog.Value {
if runner == nil {
return slog.StringValue("nil")
}
+ modelID := runner.modelPath
+ if modelID == "" {
+ modelID = runner.modelKey
+ }
attrs := []slog.Attr{}
if runner.model != nil {
attrs = append(attrs, slog.String("name", runner.model.Name))
@@ -809,7 +875,7 @@ func (runner *runnerRef) LogValue() slog.Value {
slog.String("vram", format.HumanBytes2(runner.vramSize)),
slog.Int("parallel", runner.numParallel),
slog.Int("pid", runner.pid),
- slog.String("model", runner.modelPath),
+ slog.String("model", modelID),
)
if runner.Options != nil {
attrs = append(attrs, slog.Int("num_ctx", runner.Options.NumCtx))
@@ -854,8 +920,16 @@ func (a ByDurationAndName) Less(i, j int) bool {
if d1 != d2 {
return d1 < d2
}
- // Secondary sort by model path lex order
- return a[i].modelPath < a[j].modelPath
+ // Secondary sort by model key/path lex order
+ n1 := a[i].modelPath
+ if n1 == "" {
+ n1 = a[i].modelKey
+ }
+ n2 := a[j].modelPath
+ if n2 == "" {
+ n2 = a[j].modelKey
+ }
+ return n1 < n2
}
// TODO - future consideration to pick runners based on size
@@ -915,8 +989,9 @@ func (s *Scheduler) unloadAllRunners() {
}
func (s *Scheduler) expireRunner(model *Model) {
+ modelKey := schedulerModelKey(model)
s.loadedMu.Lock()
- runner, ok := s.loaded[model.ModelPath]
+ runner, ok := s.loaded[modelKey]
s.loadedMu.Unlock()
if ok {
runner.refMu.Lock()
diff --git a/server/sched_test.go b/server/sched_test.go
index aeb66f8b2ee..edab970fab3 100644
--- a/server/sched_test.go
+++ b/server/sched_test.go
@@ -408,10 +408,10 @@ func TestSchedGetRunner(t *testing.T) {
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = a.newServer
slog.Info("a")
- successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
+ successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
slog.Info("b")
- successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration)
+ successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1)
@@ -435,7 +435,7 @@ func TestSchedGetRunner(t *testing.T) {
c.req.model.ModelPath = "bad path"
slog.Info("c")
- successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration)
+ successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration, false)
// Starts in pending channel, then should be quickly processed to return an error
time.Sleep(50 * time.Millisecond) // Long enough for the "a" model to expire and unload
require.Empty(t, successCh1c)
@@ -448,6 +448,71 @@ func TestSchedGetRunner(t *testing.T) {
b.ctxDone()
}
+func TestSchedGetRunnerUsesDigestKeyWhenModelPathEmpty(t *testing.T) {
+ ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
+ defer done()
+
+ s := InitScheduler(ctx)
+ opts := api.DefaultOptions()
+ opts.NumCtx = 4
+
+ loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
+ loadedRunner := &runnerRef{
+ model: loadedModel,
+ modelKey: schedulerModelKey(loadedModel),
+ llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
+ Options: &opts,
+ numParallel: 1,
+ }
+
+ s.loadedMu.Lock()
+ s.loaded[loadedRunner.modelKey] = loadedRunner
+ s.loadedMu.Unlock()
+
+ reqModel := &Model{Name: "safetensors-b", Digest: "sha-b"}
+ successCh, errCh := s.GetRunner(ctx, reqModel, opts, nil, false)
+
+ require.Empty(t, successCh)
+ require.Empty(t, errCh)
+ require.Len(t, s.pendingReqCh, 1)
+}
+
+func TestSchedGetRunnerReusesSameDigestWhenModelPathEmpty(t *testing.T) {
+ ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
+ defer done()
+
+ s := InitScheduler(ctx)
+ opts := api.DefaultOptions()
+ opts.NumCtx = 4
+
+ loadedModel := &Model{Name: "safetensors-a", Digest: "sha-a"}
+ loadedRunner := &runnerRef{
+ model: loadedModel,
+ modelKey: schedulerModelKey(loadedModel),
+ llama: &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}},
+ Options: &opts,
+ numParallel: 1,
+ }
+
+ s.loadedMu.Lock()
+ s.loaded[loadedRunner.modelKey] = loadedRunner
+ s.loadedMu.Unlock()
+
+ reqCtx, cancelReq := context.WithCancel(ctx)
+ successCh, errCh := s.GetRunner(reqCtx, &Model{Name: "safetensors-a-copy", Digest: "sha-a"}, opts, nil, false)
+ cancelReq()
+
+ select {
+ case runner := <-successCh:
+ require.Equal(t, loadedRunner, runner)
+ default:
+ t.Fatal("expected existing runner to be reused")
+ }
+
+ require.Empty(t, errCh)
+ require.Empty(t, s.pendingReqCh)
+}
+
func TestSchedExpireRunner(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()
@@ -509,7 +574,7 @@ func TestSchedPrematureExpired(t *testing.T) {
s.getGpuFn = getGpuFn
s.getSystemInfoFn = getSystemInfoFn
s.newServerFn = scenario1a.newServer
- successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
+ successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration, false)
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
@@ -804,6 +869,7 @@ func (s *mockLlm) GetPort() int { return -
func (s *mockLlm) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
func (s *mockLlm) HasExited() bool { return false }
func (s *mockLlm) GetActiveDeviceIDs() []ml.DeviceID { return nil }
+func (s *mockLlm) ContextLength() int { return 0 }
// TestImageGenRunnerCanBeEvicted verifies that an image generation model
// loaded in the scheduler can be evicted when idle.
diff --git a/server/shard_metadata.go b/server/shard_metadata.go
index 86daaaa6796..aa842597932 100644
--- a/server/shard_metadata.go
+++ b/server/shard_metadata.go
@@ -4,6 +4,8 @@ import (
"encoding/json"
"errors"
"os"
+
+ "github.com/ollama/ollama/manifest"
)
// ShardMetadata stores information about sharded GGUF models
@@ -18,7 +20,7 @@ type ShardMetadata struct {
// WriteShardMetadata stores shard info for a model layer
func WriteShardMetadata(digest string, meta ShardMetadata) error {
- blobPath, err := GetBlobsPath(digest)
+ blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return err
}
@@ -36,7 +38,7 @@ func WriteShardMetadata(digest string, meta ShardMetadata) error {
// ReadShardMetadata loads shard info if it exists
func ReadShardMetadata(digest string) (*ShardMetadata, error) {
- blobPath, err := GetBlobsPath(digest)
+ blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return nil, err
}
@@ -62,7 +64,7 @@ func ReadShardMetadata(digest string) (*ShardMetadata, error) {
// DeleteShardMetadata removes shard metadata file
func DeleteShardMetadata(digest string) error {
- blobPath, err := GetBlobsPath(digest)
+ blobPath, err := manifest.BlobsPath(digest)
if err != nil {
return err
}
diff --git a/server/test_home_test.go b/server/test_home_test.go
new file mode 100644
index 00000000000..7a0393684ce
--- /dev/null
+++ b/server/test_home_test.go
@@ -0,0 +1,14 @@
+package server
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/envconfig"
+)
+
+func setTestHome(t *testing.T, home string) {
+ t.Helper()
+ t.Setenv("HOME", home)
+ t.Setenv("USERPROFILE", home)
+ envconfig.ReloadServerConfig()
+}
diff --git a/server/upload.go b/server/upload.go
index 2bd408d3ff0..35a32c6797e 100644
--- a/server/upload.go
+++ b/server/upload.go
@@ -21,12 +21,14 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/format"
+ "github.com/ollama/ollama/manifest"
+ "github.com/ollama/ollama/types/model"
)
var blobUploadManager sync.Map
type blobUpload struct {
- Layer
+ manifest.Layer
Total int64
Completed atomic.Int64
@@ -51,7 +53,7 @@ const (
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
- p, err := GetBlobsPath(b.Digest)
+ p, err := manifest.BlobsPath(b.Digest)
if err != nil {
return err
}
@@ -59,7 +61,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *reg
if b.From != "" {
values := requestURL.Query()
values.Add("mount", b.Digest)
- values.Add("from", ParseModelPath(b.From).GetNamespaceRepository())
+ values.Add("from", model.ParseName(b.From).DisplayNamespaceModel())
requestURL.RawQuery = values.Encode()
}
@@ -128,7 +130,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
- p, err := GetBlobsPath(b.Digest)
+ p, err := manifest.BlobsPath(b.Digest)
if err != nil {
b.err = err
return
@@ -364,9 +366,9 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
-func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
- requestURL := mp.BaseURL()
- requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
+func uploadBlob(ctx context.Context, n model.Name, layer manifest.Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
+ requestURL := n.BaseURL()
+ requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs", layer.Digest)
resp, err := makeRequestWithRetry(ctx, http.MethodHead, requestURL, nil, nil, opts)
switch {
@@ -388,8 +390,8 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOp
data, ok := blobUploadManager.LoadOrStore(layer.Digest, &blobUpload{Layer: layer})
upload := data.(*blobUpload)
if !ok {
- requestURL := mp.BaseURL()
- requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs/uploads/")
+ requestURL := n.BaseURL()
+ requestURL = requestURL.JoinPath("v2", n.DisplayNamespaceModel(), "blobs/uploads/")
if err := upload.Prepare(ctx, requestURL, opts); err != nil {
blobUploadManager.Delete(layer.Digest)
return err
diff --git a/x/model/bytepairencoding.go b/tokenizer/bytepairencoding.go
similarity index 93%
rename from x/model/bytepairencoding.go
rename to tokenizer/bytepairencoding.go
index acb58743b1e..b592aeedb55 100644
--- a/x/model/bytepairencoding.go
+++ b/tokenizer/bytepairencoding.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"cmp"
@@ -18,19 +18,19 @@ type BytePairEncoding struct {
regexps []*regexp2.Regexp
}
-var _ TextProcessor = (*BytePairEncoding)(nil)
+var _ Tokenizer = (*BytePairEncoding)(nil)
-func NewBytePairEncoding(vocab *Vocabulary, pretokenizers ...string) BytePairEncoding {
- if len(pretokenizers) == 0 {
+func NewBytePairEncoding(vocab *Vocabulary, pretokenizer ...string) BytePairEncoding {
+ if len(pretokenizer) == 0 {
// set default byte-level pretokenizer if none provided, e.g.
- // https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L44
- pretokenizers = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
+ // https://github.com/huggingface/tokenizer/blob/main/tokenizer/src/pre_tokenizer/byte_level.rs#L44
+ pretokenizer = []string{`'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`}
}
return BytePairEncoding{
vocab: vocab,
regexps: slices.Collect(func(yield func(*regexp2.Regexp) bool) {
- for _, p := range pretokenizers {
+ for _, p := range pretokenizer {
if !yield(regexp2.MustCompile(p, regexp2.RE2)) {
return
}
diff --git a/model/bytepairencoding_test.go b/tokenizer/bytepairencoding_test.go
similarity index 98%
rename from model/bytepairencoding_test.go
rename to tokenizer/bytepairencoding_test.go
index 15cb56ca978..9b9e901a7b0 100644
--- a/model/bytepairencoding_test.go
+++ b/tokenizer/bytepairencoding_test.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"bufio"
@@ -17,7 +17,7 @@ import (
func llama(t testing.TB) BytePairEncoding {
t.Helper()
- f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
+ f, err := os.Open(filepath.FromSlash("testdata/llama3.2/encoder.json"))
if err != nil {
t.Fatal(err)
}
@@ -43,7 +43,7 @@ func llama(t testing.TB) BytePairEncoding {
}
}
- f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
+ f, err = os.Open(filepath.FromSlash("testdata/llama3.2/vocab.bpe"))
if err != nil {
t.Fatal(err)
}
diff --git a/model/sentencepiece.go b/tokenizer/sentencepiece.go
similarity index 98%
rename from model/sentencepiece.go
rename to tokenizer/sentencepiece.go
index 2c178ec0c08..c41b9d9a62f 100644
--- a/model/sentencepiece.go
+++ b/tokenizer/sentencepiece.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"container/heap"
@@ -17,7 +17,7 @@ type SentencePiece struct {
vocab *Vocabulary
}
-var _ TextProcessor = (*SentencePiece)(nil)
+var _ Tokenizer = (*SentencePiece)(nil)
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
@@ -224,7 +224,7 @@ func (spm SentencePiece) Decode(ids []int32) (string, error) {
data := spm.vocab.Decode(id)
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
- // For tokenizers that use byte tokens like "<0xEA>"
+ // For tokenizer that use byte tokens like "<0xEA>"
// convert them to the partial unicode character
// so they are buffered correctly by the runner instead
// of being sent back to the api as "<0xEA>"
diff --git a/model/sentencepiece_test.go b/tokenizer/sentencepiece_test.go
similarity index 97%
rename from model/sentencepiece_test.go
rename to tokenizer/sentencepiece_test.go
index 8f4570c173f..dd60953f945 100644
--- a/model/sentencepiece_test.go
+++ b/tokenizer/sentencepiece_test.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"log/slog"
@@ -15,7 +15,7 @@ import (
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
- bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
+ bts, err := os.ReadFile(filepath.FromSlash("testdata/gemma2/tokenizer.model"))
if err != nil {
t.Fatal(err)
}
diff --git a/model/testdata/gemma2/tokenizer.model b/tokenizer/testdata/gemma2/tokenizer.model
similarity index 100%
rename from model/testdata/gemma2/tokenizer.model
rename to tokenizer/testdata/gemma2/tokenizer.model
diff --git a/model/testdata/llama3.2/encoder.json b/tokenizer/testdata/llama3.2/encoder.json
similarity index 100%
rename from model/testdata/llama3.2/encoder.json
rename to tokenizer/testdata/llama3.2/encoder.json
diff --git a/model/testdata/llama3.2/vocab.bpe b/tokenizer/testdata/llama3.2/vocab.bpe
similarity index 100%
rename from model/testdata/llama3.2/vocab.bpe
rename to tokenizer/testdata/llama3.2/vocab.bpe
diff --git a/model/testdata/war-and-peace.txt b/tokenizer/testdata/war-and-peace.txt
similarity index 100%
rename from model/testdata/war-and-peace.txt
rename to tokenizer/testdata/war-and-peace.txt
diff --git a/model/textprocessor.go b/tokenizer/tokenizer.go
similarity index 86%
rename from model/textprocessor.go
rename to tokenizer/tokenizer.go
index 4a36f235290..64b5c410245 100644
--- a/model/textprocessor.go
+++ b/tokenizer/tokenizer.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
const (
TOKEN_TYPE_NORMAL = iota + 1
@@ -9,7 +9,7 @@ const (
TOKEN_TYPE_BYTE
)
-type TextProcessor interface {
+type Tokenizer interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
diff --git a/model/vocabulary.go b/tokenizer/vocabulary.go
similarity index 99%
rename from model/vocabulary.go
rename to tokenizer/vocabulary.go
index d977c495781..f5d71ef69c5 100644
--- a/model/vocabulary.go
+++ b/tokenizer/vocabulary.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"log/slog"
diff --git a/model/vocabulary_test.go b/tokenizer/vocabulary_test.go
similarity index 99%
rename from model/vocabulary_test.go
rename to tokenizer/vocabulary_test.go
index ccfc39e6945..72f73203a04 100644
--- a/model/vocabulary_test.go
+++ b/tokenizer/vocabulary_test.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"testing"
diff --git a/model/wordpiece.go b/tokenizer/wordpiece.go
similarity index 94%
rename from model/wordpiece.go
rename to tokenizer/wordpiece.go
index e552bce0dd3..91569ca309e 100644
--- a/model/wordpiece.go
+++ b/tokenizer/wordpiece.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"fmt"
@@ -32,7 +32,7 @@ var wordPieceReplacer = strings.NewReplacer(
" 're", "'re",
)
-// Decode implements TextProcessor.
+// Decode implements Tokenizer.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
@@ -96,7 +96,7 @@ func (wpm WordPiece) words(s string) iter.Seq[string] {
}
}
-// Encode implements TextProcessor.
+// Encode implements Tokenizer.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
@@ -151,17 +151,17 @@ func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
return ids, nil
}
-// Is implements TextProcessor.
+// Is implements Tokenizer.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
-// Vocabulary implements TextProcessor.
+// Vocabulary implements Tokenizer.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
-var _ TextProcessor = (*WordPiece)(nil)
+var _ Tokenizer = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
return WordPiece{
diff --git a/x/model/wordpiece_test.go b/tokenizer/wordpiece_test.go
similarity index 98%
rename from x/model/wordpiece_test.go
rename to tokenizer/wordpiece_test.go
index c03bb17a725..cbc398c8ebf 100644
--- a/x/model/wordpiece_test.go
+++ b/tokenizer/wordpiece_test.go
@@ -1,4 +1,4 @@
-package model
+package tokenizer
import (
"slices"
diff --git a/types/model/capability.go b/types/model/capability.go
index 7ecfd848cc7..aeac37961de 100644
--- a/types/model/capability.go
+++ b/types/model/capability.go
@@ -3,13 +3,13 @@ package model
type Capability string
const (
- CapabilityCompletion = Capability("completion")
- CapabilityTools = Capability("tools")
- CapabilityInsert = Capability("insert")
- CapabilityVision = Capability("vision")
- CapabilityEmbedding = Capability("embedding")
- CapabilityThinking = Capability("thinking")
- CapabilityImage = Capability("image")
+ CapabilityCompletion = Capability("completion")
+ CapabilityTools = Capability("tools")
+ CapabilityInsert = Capability("insert")
+ CapabilityVision = Capability("vision")
+ CapabilityEmbedding = Capability("embedding")
+ CapabilityThinking = Capability("thinking")
+ CapabilityImage = Capability("image")
)
func (c Capability) String() string {
diff --git a/types/model/name.go b/types/model/name.go
index a46f3e28d86..311326d4cc9 100644
--- a/types/model/name.go
+++ b/types/model/name.go
@@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"log/slog"
+ "net/url"
"path/filepath"
"strings"
)
@@ -35,22 +36,25 @@ func Unqualified(n Name) error {
const MissingPart = "!MISSING!"
const (
- defaultHost = "registry.ollama.ai"
- defaultNamespace = "library"
- defaultTag = "latest"
+ defaultHost = "registry.ollama.ai"
+ defaultNamespace = "library"
+ defaultTag = "latest"
+ defaultProtocolScheme = "https"
)
// DefaultName returns a name with the default values for the host, namespace,
-// and tag parts. The model and digest parts are empty.
+// tag, and protocol scheme parts. The model and digest parts are empty.
//
// - The default host is ("registry.ollama.ai")
// - The default namespace is ("library")
// - The default tag is ("latest")
+// - The default protocol scheme is ("https")
func DefaultName() Name {
return Name{
- Host: defaultHost,
- Namespace: defaultNamespace,
- Tag: defaultTag,
+ Host: defaultHost,
+ Namespace: defaultNamespace,
+ Tag: defaultTag,
+ ProtocolScheme: defaultProtocolScheme,
}
}
@@ -87,10 +91,11 @@ func (k partKind) String() string {
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
type Name struct {
- Host string
- Namespace string
- Model string
- Tag string
+ Host string
+ Namespace string
+ Model string
+ Tag string
+ ProtocolScheme string
}
// ParseName parses and assembles a Name from a name string. The
@@ -160,7 +165,9 @@ func ParseNameBare(s string) Name {
}
scheme, host, ok := strings.Cut(s, "://")
- if !ok {
+ if ok {
+ n.ProtocolScheme = scheme
+ } else {
host = scheme
}
n.Host = host
@@ -189,12 +196,13 @@ func ParseNameFromFilepath(s string) (n Name) {
return n
}
-// Merge merges the host, namespace, and tag parts of the two names,
+// Merge merges the host, namespace, tag, and protocol scheme parts of the two names,
// preferring the non-empty parts of a.
func Merge(a, b Name) Name {
a.Host = cmp.Or(a.Host, b.Host)
a.Namespace = cmp.Or(a.Namespace, b.Namespace)
a.Tag = cmp.Or(a.Tag, b.Tag)
+ a.ProtocolScheme = cmp.Or(a.ProtocolScheme, b.ProtocolScheme)
return a
}
@@ -305,6 +313,23 @@ func (n Name) EqualFold(o Name) bool {
strings.EqualFold(n.Tag, o.Tag)
}
+// BaseURL returns the base URL for the registry.
+func (n Name) BaseURL() *url.URL {
+ return &url.URL{
+ Scheme: n.ProtocolScheme,
+ Host: n.Host,
+ }
+}
+
+// DisplayNamespaceModel returns the namespace and model joined by "/".
+func (n Name) DisplayNamespaceModel() string {
+ var b strings.Builder
+ b.WriteString(n.Namespace)
+ b.WriteByte('/')
+ b.WriteString(n.Model)
+ return b.String()
+}
+
func isValidLen(kind partKind, s string) bool {
switch kind {
case kindHost:
diff --git a/types/model/name_test.go b/types/model/name_test.go
index 794d14d798e..0569037289a 100644
--- a/types/model/name_test.go
+++ b/types/model/name_test.go
@@ -32,10 +32,11 @@ func TestParseNameParts(t *testing.T) {
{
in: "scheme://host:port/namespace/model:tag",
want: Name{
- Host: "host:port",
- Namespace: "namespace",
- Model: "model",
- Tag: "tag",
+ Host: "host:port",
+ Namespace: "namespace",
+ Model: "model",
+ Tag: "tag",
+ ProtocolScheme: "scheme",
},
wantFilepath: filepath.Join("host:port", "namespace", "model", "tag"),
},
diff --git a/x/cmd/run.go b/x/cmd/run.go
index 1bd452cd862..e5d7ea25e75 100644
--- a/x/cmd/run.go
+++ b/x/cmd/run.go
@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
+ "net/http"
"net/url"
"os"
"os/signal"
@@ -18,6 +19,7 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/model"
@@ -62,6 +64,18 @@ func isLocalServer() bool {
return hostname == "localhost" || hostname == "127.0.0.1" || strings.Contains(parsed.Host, ":11434")
}
+func cloudStatusDisabled(ctx context.Context, client *api.Client) (disabled bool, known bool) {
+ status, err := client.CloudStatusExperimental(ctx)
+ if err != nil {
+ var statusErr api.StatusError
+ if errors.As(err, &statusErr) && statusErr.StatusCode == http.StatusNotFound {
+ return false, false
+ }
+ return false, false
+ }
+ return status.Cloud.Disabled, true
+}
+
// truncateToolOutput truncates tool output to prevent context overflow.
// Uses a smaller limit (4k tokens) for local models, larger (10k) for cloud/remote.
func truncateToolOutput(output, modelName string) string {
@@ -86,6 +100,10 @@ func waitForOllamaSignin(ctx context.Context) error {
return err
}
+ if disabled, known := cloudStatusDisabled(ctx, client); known && disabled {
+ return errors.New(internalcloud.DisabledError("cloud account endpoints are unavailable"))
+ }
+
// Get signin URL from initial Whoami call
_, err = client.Whoami(ctx)
if err != nil {
@@ -664,6 +682,15 @@ func GenerateInteractive(cmd *cobra.Command, modelName string, wordWrap bool, op
supportsTools = false
}
+ if enableWebsearch {
+ if client, err := api.ClientFromEnvironment(); err == nil {
+ if disabled, known := cloudStatusDisabled(cmd.Context(), client); known && disabled {
+ fmt.Fprintf(os.Stderr, "%s\n", internalcloud.DisabledError("web search is unavailable"))
+ enableWebsearch = false
+ }
+ }
+ }
+
// Create tool registry only if model supports tools
var toolRegistry *tools.Registry
if supportsTools {
diff --git a/x/create/client/create.go b/x/create/client/create.go
index 7729c6e5f3b..b8062f3d44b 100644
--- a/x/create/client/create.go
+++ b/x/create/client/create.go
@@ -11,11 +11,15 @@ import (
"encoding/json"
"fmt"
"io"
+ "os"
+ "path/filepath"
+ "strings"
+ "github.com/ollama/ollama/manifest"
"github.com/ollama/ollama/progress"
- "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/x/create"
+ "github.com/ollama/ollama/x/imagegen/safetensors"
)
// MinOllamaVersion is the minimum Ollama version required for safetensors models.
@@ -26,14 +30,16 @@ type ModelfileConfig struct {
Template string
System string
License string
+ Parser string
+ Renderer string
}
// CreateOptions holds all options for model creation.
type CreateOptions struct {
ModelName string
ModelDir string
- Quantize string // "fp8" for quantization
- Modelfile *ModelfileConfig // template/system/license from Modelfile
+ Quantize string // "int4", "int8", "nvfp4", or "mxfp8" for quantization
+ Modelfile *ModelfileConfig // template/system/license/parser/renderer from Modelfile
}
// CreateModel imports a model from a local directory.
@@ -51,10 +57,20 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// Determine model type settings
var modelType, spinnerKey string
var capabilities []string
+ var parserName, rendererName string
if isSafetensors {
modelType = "safetensors model"
spinnerKey = "create"
capabilities = []string{"completion"}
+
+ // Check if model supports thinking based on architecture
+ if supportsThinking(opts.ModelDir) {
+ capabilities = append(capabilities, "thinking")
+ }
+
+ // Set parser and renderer name based on architecture
+ parserName = getParserName(opts.ModelDir)
+ rendererName = getRendererName(opts.ModelDir)
} else {
modelType = "image generation model"
spinnerKey = "imagegen"
@@ -79,14 +95,15 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
err = create.CreateSafetensorsModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
- newManifestWriter(opts, capabilities),
+ newManifestWriter(opts, capabilities, parserName, rendererName),
progressFn,
+ newPackedTensorLayerCreator(),
)
} else {
err = create.CreateImageGenModel(
opts.ModelName, opts.ModelDir, opts.Quantize,
newLayerCreator(), newTensorLayerCreator(),
- newManifestWriter(opts, capabilities),
+ newManifestWriter(opts, capabilities, "", ""),
progressFn,
)
}
@@ -103,7 +120,7 @@ func CreateModel(opts CreateOptions, p *progress.Progress) error {
// newLayerCreator returns a LayerCreator callback for creating config/JSON layers.
func newLayerCreator() create.LayerCreator {
return func(r io.Reader, mediaType, name string) (create.LayerInfo, error) {
- layer, err := server.NewLayer(r, mediaType)
+ layer, err := manifest.NewLayer(r, mediaType)
if err != nil {
return create.LayerInfo{}, err
}
@@ -128,65 +145,38 @@ func newTensorLayerCreator() create.QuantizingTensorLayerCreator {
}
}
-// createQuantizedLayers quantizes a tensor and returns the resulting layers.
+// createQuantizedLayers quantizes a tensor and returns a single combined layer.
+// The combined blob contains data, scale, and optional bias tensors with metadata.
func createQuantizedLayers(r io.Reader, name, dtype string, shape []int32, quantize string) ([]create.LayerInfo, error) {
if !QuantizeSupported() {
return nil, fmt.Errorf("quantization requires MLX support")
}
- // Quantize the tensor
- qweightData, scalesData, qbiasData, _, _, _, err := quantizeTensor(r, name, dtype, shape, quantize)
+ // Quantize the tensor into a single combined blob
+ blobData, err := quantizeTensor(r, name, dtype, shape, quantize)
if err != nil {
return nil, fmt.Errorf("failed to quantize %s: %w", name, err)
}
- // Create layer for quantized weight
- weightLayer, err := server.NewLayer(bytes.NewReader(qweightData), server.MediaTypeImageTensor)
- if err != nil {
- return nil, err
- }
-
- // Create layer for scales
- scalesLayer, err := server.NewLayer(bytes.NewReader(scalesData), server.MediaTypeImageTensor)
+ // Create single layer for the combined blob
+ layer, err := manifest.NewLayer(bytes.NewReader(blobData), manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
- layers := []create.LayerInfo{
+ return []create.LayerInfo{
{
- Digest: weightLayer.Digest,
- Size: weightLayer.Size,
- MediaType: weightLayer.MediaType,
+ Digest: layer.Digest,
+ Size: layer.Size,
+ MediaType: layer.MediaType,
Name: name,
},
- {
- Digest: scalesLayer.Digest,
- Size: scalesLayer.Size,
- MediaType: scalesLayer.MediaType,
- Name: name + "_scale",
- },
- }
-
- // Add qbiases layer if present (affine mode)
- if qbiasData != nil {
- qbiasLayer, err := server.NewLayer(bytes.NewReader(qbiasData), server.MediaTypeImageTensor)
- if err != nil {
- return nil, err
- }
- layers = append(layers, create.LayerInfo{
- Digest: qbiasLayer.Digest,
- Size: qbiasLayer.Size,
- MediaType: qbiasLayer.MediaType,
- Name: name + "_qbias",
- })
- }
-
- return layers, nil
+ }, nil
}
// createUnquantizedLayer creates a single tensor layer without quantization.
func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error) {
- layer, err := server.NewLayer(r, server.MediaTypeImageTensor)
+ layer, err := manifest.NewLayer(r, manifest.MediaTypeImageTensor)
if err != nil {
return nil, err
}
@@ -201,19 +191,86 @@ func createUnquantizedLayer(r io.Reader, name string) ([]create.LayerInfo, error
}, nil
}
+// newPackedTensorLayerCreator returns a PackedTensorLayerCreator callback for
+// creating packed multi-tensor blob layers (used for expert groups).
+func newPackedTensorLayerCreator() create.PackedTensorLayerCreator {
+ return func(groupName string, tensors []create.PackedTensorInput) (create.LayerInfo, error) {
+ // Check if any tensor in the group needs quantization
+ hasQuantize := false
+ for _, t := range tensors {
+ if t.Quantize != "" {
+ hasQuantize = true
+ break
+ }
+ }
+
+ var blobReader io.Reader
+ if hasQuantize {
+ if !QuantizeSupported() {
+ return create.LayerInfo{}, fmt.Errorf("quantization requires MLX support")
+ }
+ blobData, err := quantizePackedGroup(tensors)
+ if err != nil {
+ return create.LayerInfo{}, fmt.Errorf("failed to quantize packed group %s: %w", groupName, err)
+ }
+ blobReader = bytes.NewReader(blobData)
+ } else {
+ // Build unquantized packed blob using streaming reader
+ // Extract raw tensor data from safetensors-wrapped readers
+ var tds []*safetensors.TensorData
+ for _, t := range tensors {
+ rawData, err := safetensors.ExtractRawFromSafetensors(t.Reader)
+ if err != nil {
+ return create.LayerInfo{}, fmt.Errorf("failed to extract tensor %s: %w", t.Name, err)
+ }
+ td := safetensors.NewTensorDataFromBytes(t.Name, t.Dtype, t.Shape, rawData)
+ tds = append(tds, td)
+ }
+ blobReader = safetensors.BuildPackedSafetensorsReader(tds)
+ }
+
+ layer, err := manifest.NewLayer(blobReader, manifest.MediaTypeImageTensor)
+ if err != nil {
+ return create.LayerInfo{}, err
+ }
+
+ return create.LayerInfo{
+ Digest: layer.Digest,
+ Size: layer.Size,
+ MediaType: layer.MediaType,
+ Name: groupName,
+ }, nil
+ }
+}
+
// newManifestWriter returns a ManifestWriter callback for writing the model manifest.
-func newManifestWriter(opts CreateOptions, capabilities []string) create.ManifestWriter {
+func newManifestWriter(opts CreateOptions, capabilities []string, parserName, rendererName string) create.ManifestWriter {
return func(modelName string, config create.LayerInfo, layers []create.LayerInfo) error {
name := model.ParseName(modelName)
if !name.IsValid() {
return fmt.Errorf("invalid model name: %s", modelName)
}
+ // TODO: find a better way to detect image input support
+ // For now, hardcode Flux2KleinPipeline as supporting vision (image input)
+ caps := capabilities
+ modelIndex := filepath.Join(opts.ModelDir, "model_index.json")
+ if data, err := os.ReadFile(modelIndex); err == nil {
+ var cfg struct {
+ ClassName string `json:"_class_name"`
+ }
+ if json.Unmarshal(data, &cfg) == nil && cfg.ClassName == "Flux2KleinPipeline" {
+ caps = append(caps, "vision")
+ }
+ }
+
// Create config blob with version requirement
configData := model.ConfigV2{
ModelFormat: "safetensors",
- Capabilities: capabilities,
+ Capabilities: caps,
Requires: MinOllamaVersion,
+ Parser: resolveParserName(opts.Modelfile, parserName),
+ Renderer: resolveRendererName(opts.Modelfile, rendererName),
}
configJSON, err := json.Marshal(configData)
if err != nil {
@@ -221,15 +278,15 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
}
// Create config layer blob
- configLayer, err := server.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
+ configLayer, err := manifest.NewLayer(bytes.NewReader(configJSON), "application/vnd.docker.container.image.v1+json")
if err != nil {
return fmt.Errorf("failed to create config layer: %w", err)
}
- // Convert LayerInfo to server.Layer
- serverLayers := make([]server.Layer, 0, len(layers))
+ // Convert LayerInfo to manifest.Layer
+ manifestLayers := make([]manifest.Layer, 0, len(layers))
for _, l := range layers {
- serverLayers = append(serverLayers, server.Layer{
+ manifestLayers = append(manifestLayers, manifest.Layer{
MediaType: l.MediaType,
Digest: l.Digest,
Size: l.Size,
@@ -243,19 +300,35 @@ func newManifestWriter(opts CreateOptions, capabilities []string) create.Manifes
if err != nil {
return err
}
- serverLayers = append(serverLayers, modelfileLayers...)
+ manifestLayers = append(manifestLayers, modelfileLayers...)
}
- return server.WriteManifest(name, configLayer, serverLayers)
+ return manifest.WriteManifest(name, configLayer, manifestLayers)
}
}
+func resolveParserName(mf *ModelfileConfig, inferred string) string {
+ if mf != nil && mf.Parser != "" {
+ return mf.Parser
+ }
+
+ return inferred
+}
+
+func resolveRendererName(mf *ModelfileConfig, inferred string) string {
+ if mf != nil && mf.Renderer != "" {
+ return mf.Renderer
+ }
+
+ return inferred
+}
+
// createModelfileLayers creates layers for template, system, and license from Modelfile config.
-func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
- var layers []server.Layer
+func createModelfileLayers(mf *ModelfileConfig) ([]manifest.Layer, error) {
+ var layers []manifest.Layer
if mf.Template != "" {
- layer, err := server.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
+ layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.Template)), "application/vnd.ollama.image.template")
if err != nil {
return nil, fmt.Errorf("failed to create template layer: %w", err)
}
@@ -263,7 +336,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.System != "" {
- layer, err := server.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
+ layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.System)), "application/vnd.ollama.image.system")
if err != nil {
return nil, fmt.Errorf("failed to create system layer: %w", err)
}
@@ -271,7 +344,7 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
}
if mf.License != "" {
- layer, err := server.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
+ layer, err := manifest.NewLayer(bytes.NewReader([]byte(mf.License)), "application/vnd.ollama.image.license")
if err != nil {
return nil, fmt.Errorf("failed to create license layer: %w", err)
}
@@ -280,3 +353,146 @@ func createModelfileLayers(mf *ModelfileConfig) ([]server.Layer, error) {
return layers, nil
}
+
+// supportsThinking checks if the model supports thinking mode based on its architecture.
+// This reads the config.json from the model directory and checks the architectures field.
+func supportsThinking(modelDir string) bool {
+ configPath := filepath.Join(modelDir, "config.json")
+ data, err := os.ReadFile(configPath)
+ if err != nil {
+ return false
+ }
+
+ var cfg struct {
+ Architectures []string `json:"architectures"`
+ ModelType string `json:"model_type"`
+ }
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ return false
+ }
+
+ // Check architectures that support thinking
+ thinkingArchitectures := []string{
+ "glm4moe", // GLM-4 MoE models
+ "deepseek", // DeepSeek models
+ "qwen3", // Qwen3 models
+ }
+
+ // Check the architecture list
+ for _, arch := range cfg.Architectures {
+ archLower := strings.ToLower(arch)
+ for _, thinkArch := range thinkingArchitectures {
+ if strings.Contains(archLower, thinkArch) {
+ return true
+ }
+ }
+ }
+
+ // Also check model_type
+ if cfg.ModelType != "" {
+ typeLower := strings.ToLower(cfg.ModelType)
+ for _, thinkArch := range thinkingArchitectures {
+ if strings.Contains(typeLower, thinkArch) {
+ return true
+ }
+ }
+ }
+
+ return false
+}
+
+// getParserName returns the parser name for a model based on its architecture.
+// This reads the config.json from the model directory and determines the appropriate parser.
+func getParserName(modelDir string) string {
+ configPath := filepath.Join(modelDir, "config.json")
+ data, err := os.ReadFile(configPath)
+ if err != nil {
+ return ""
+ }
+
+ var cfg struct {
+ Architectures []string `json:"architectures"`
+ ModelType string `json:"model_type"`
+ }
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ return ""
+ }
+
+ // Check architectures for known parsers
+ for _, arch := range cfg.Architectures {
+ archLower := strings.ToLower(arch)
+ if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
+ return "glm-4.7"
+ }
+ if strings.Contains(archLower, "deepseek") {
+ return "deepseek3"
+ }
+ if strings.Contains(archLower, "qwen3") {
+ return "qwen3"
+ }
+ }
+
+ // Also check model_type
+ if cfg.ModelType != "" {
+ typeLower := strings.ToLower(cfg.ModelType)
+ if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
+ return "glm-4.7"
+ }
+ if strings.Contains(typeLower, "deepseek") {
+ return "deepseek3"
+ }
+ if strings.Contains(typeLower, "qwen3") {
+ return "qwen3"
+ }
+ }
+
+ return ""
+}
+
+// getRendererName returns the renderer name for a model based on its architecture.
+// This reads the config.json from the model directory and determines the appropriate renderer.
+func getRendererName(modelDir string) string {
+ configPath := filepath.Join(modelDir, "config.json")
+ data, err := os.ReadFile(configPath)
+ if err != nil {
+ return ""
+ }
+
+ var cfg struct {
+ Architectures []string `json:"architectures"`
+ ModelType string `json:"model_type"`
+ }
+ if err := json.Unmarshal(data, &cfg); err != nil {
+ return ""
+ }
+
+ // Check architectures for known renderers
+ for _, arch := range cfg.Architectures {
+ archLower := strings.ToLower(arch)
+ if strings.Contains(archLower, "glm4") || strings.Contains(archLower, "glm-4") {
+ return "glm-4.7"
+ }
+ if strings.Contains(archLower, "deepseek") {
+ return "deepseek3"
+ }
+ if strings.Contains(archLower, "qwen3") {
+ return "qwen3-coder"
+ }
+ }
+
+ // Also check model_type
+ if cfg.ModelType != "" {
+ typeLower := strings.ToLower(cfg.ModelType)
+ if strings.Contains(typeLower, "glm4") || strings.Contains(typeLower, "glm-4") {
+ return "glm-4.7"
+ }
+ if strings.Contains(typeLower, "deepseek") {
+ return "deepseek3"
+ }
+ if strings.Contains(typeLower, "qwen3") {
+ return "qwen3-coder"
+ }
+ }
+
+ return ""
+}
diff --git a/x/create/client/create_test.go b/x/create/client/create_test.go
index b41807279f7..1e7062237d5 100644
--- a/x/create/client/create_test.go
+++ b/x/create/client/create_test.go
@@ -10,6 +10,8 @@ func TestModelfileConfig(t *testing.T) {
Template: "{{ .Prompt }}",
System: "You are a helpful assistant.",
License: "MIT",
+ Parser: "qwen3",
+ Renderer: "qwen3",
}
if config.Template != "{{ .Prompt }}" {
@@ -21,6 +23,12 @@ func TestModelfileConfig(t *testing.T) {
if config.License != "MIT" {
t.Errorf("License = %q, want %q", config.License, "MIT")
}
+ if config.Parser != "qwen3" {
+ t.Errorf("Parser = %q, want %q", config.Parser, "qwen3")
+ }
+ if config.Renderer != "qwen3" {
+ t.Errorf("Renderer = %q, want %q", config.Renderer, "qwen3")
+ }
}
func TestModelfileConfig_Empty(t *testing.T) {
@@ -35,6 +43,12 @@ func TestModelfileConfig_Empty(t *testing.T) {
if config.License != "" {
t.Errorf("License should be empty, got %q", config.License)
}
+ if config.Parser != "" {
+ t.Errorf("Parser should be empty, got %q", config.Parser)
+ }
+ if config.Renderer != "" {
+ t.Errorf("Renderer should be empty, got %q", config.Renderer)
+ }
}
func TestModelfileConfig_PartialFields(t *testing.T) {
@@ -53,6 +67,12 @@ func TestModelfileConfig_PartialFields(t *testing.T) {
if config.License != "" {
t.Error("License should be empty")
}
+ if config.Parser != "" {
+ t.Error("Parser should be empty")
+ }
+ if config.Renderer != "" {
+ t.Error("Renderer should be empty")
+ }
}
func TestMinOllamaVersion(t *testing.T) {
@@ -98,6 +118,8 @@ func TestCreateOptions(t *testing.T) {
Template: "test",
System: "system",
License: "MIT",
+ Parser: "qwen3-thinking",
+ Renderer: "qwen3",
},
}
@@ -116,6 +138,92 @@ func TestCreateOptions(t *testing.T) {
if opts.Modelfile.Template != "test" {
t.Errorf("Modelfile.Template = %q, want %q", opts.Modelfile.Template, "test")
}
+ if opts.Modelfile.Parser != "qwen3-thinking" {
+ t.Errorf("Modelfile.Parser = %q, want %q", opts.Modelfile.Parser, "qwen3-thinking")
+ }
+ if opts.Modelfile.Renderer != "qwen3" {
+ t.Errorf("Modelfile.Renderer = %q, want %q", opts.Modelfile.Renderer, "qwen3")
+ }
+}
+
+func TestResolveParserName(t *testing.T) {
+ tests := []struct {
+ name string
+ mf *ModelfileConfig
+ inferred string
+ want string
+ }{
+ {
+ name: "nil modelfile uses inferred",
+ mf: nil,
+ inferred: "qwen3",
+ want: "qwen3",
+ },
+ {
+ name: "empty parser uses inferred",
+ mf: &ModelfileConfig{
+ Parser: "",
+ },
+ inferred: "qwen3",
+ want: "qwen3",
+ },
+ {
+ name: "explicit parser overrides inferred",
+ mf: &ModelfileConfig{
+ Parser: "qwen3-thinking",
+ },
+ inferred: "qwen3",
+ want: "qwen3-thinking",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := resolveParserName(tt.mf, tt.inferred); got != tt.want {
+ t.Fatalf("resolveParserName() = %q, want %q", got, tt.want)
+ }
+ })
+ }
+}
+
+func TestResolveRendererName(t *testing.T) {
+ tests := []struct {
+ name string
+ mf *ModelfileConfig
+ inferred string
+ want string
+ }{
+ {
+ name: "nil modelfile uses inferred",
+ mf: nil,
+ inferred: "qwen3-coder",
+ want: "qwen3-coder",
+ },
+ {
+ name: "empty renderer uses inferred",
+ mf: &ModelfileConfig{
+ Renderer: "",
+ },
+ inferred: "qwen3-coder",
+ want: "qwen3-coder",
+ },
+ {
+ name: "explicit renderer overrides inferred",
+ mf: &ModelfileConfig{
+ Renderer: "qwen3",
+ },
+ inferred: "qwen3-coder",
+ want: "qwen3",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got := resolveRendererName(tt.mf, tt.inferred); got != tt.want {
+ t.Fatalf("resolveRendererName() = %q, want %q", got, tt.want)
+ }
+ })
+ }
}
func TestCreateOptions_Defaults(t *testing.T) {
diff --git a/x/create/client/quantize.go b/x/create/client/quantize.go
index 3a9f37cfc86..e47f1664d5b 100644
--- a/x/create/client/quantize.go
+++ b/x/create/client/quantize.go
@@ -3,118 +3,195 @@
package client
import (
+ "encoding/binary"
+ "encoding/json"
"fmt"
"io"
"os"
"path/filepath"
+ "strconv"
+ "github.com/ollama/ollama/x/create"
"github.com/ollama/ollama/x/imagegen/mlx"
)
-// quantizeTensor loads a tensor from safetensors format, quantizes it,
-// and returns safetensors data for the quantized weights, scales, and biases.
-// Supported quantization types: "fp8" (affine 8-bit)
-// Uses MLX's native SaveSafetensors to ensure correct dtype handling (especially uint32 for quantized weights).
-func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
+// quantizeParams maps quantization type names to MLX quantize parameters.
+var quantizeParams = map[string]struct {
+ groupSize int
+ bits int
+ mode string
+}{
+ "int4": {32, 4, "affine"},
+ "nvfp4": {16, 4, "nvfp4"},
+ "int8": {64, 8, "affine"},
+ "mxfp8": {32, 8, "mxfp8"},
+}
+
+// loadAndQuantizeArray writes a safetensors reader to a temp file, loads it with MLX,
+// quantizes the tensor, and appends the resulting arrays (weight, scale, optional bias)
+// to the provided maps. If quantize is empty, the tensor is kept as-is.
+// Returns any temp file paths created (caller must clean up) and arrays needing eval.
+func loadAndQuantizeArray(r io.Reader, name, quantize string, arrays map[string]*mlx.Array) (tmpPath string, toEval []*mlx.Array, nativeHandle *mlx.SafetensorsFile, err error) {
tmpDir := ensureTempDir()
- // Read safetensors data to a temp file (LoadSafetensorsNative needs a path)
- tmpFile, err := os.CreateTemp(tmpDir, "quant-input-*.safetensors")
+ tmpFile, err := os.CreateTemp(tmpDir, "quant-*.safetensors")
if err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to create temp file: %w", err)
+ return "", nil, nil, fmt.Errorf("failed to create temp file: %w", err)
}
- tmpPath := tmpFile.Name()
- defer os.Remove(tmpPath)
+ tmpPath = tmpFile.Name()
if _, err := io.Copy(tmpFile, r); err != nil {
tmpFile.Close()
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to write temp file: %w", err)
+ return tmpPath, nil, nil, fmt.Errorf("failed to write temp file for %s: %w", name, err)
}
tmpFile.Close()
- // Load the tensor using MLX's native loader
st, err := mlx.LoadSafetensorsNative(tmpPath)
if err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to load safetensors: %w", err)
+ return tmpPath, nil, nil, fmt.Errorf("failed to load safetensors for %s: %w", name, err)
+ }
+
+ // Find the tensor key (may differ from name for single-tensor blobs)
+ inputKey, err := findSafetensorsKey(tmpPath)
+ if err != nil {
+ st.Free()
+ return tmpPath, nil, nil, fmt.Errorf("failed to read blob header for %s: %w", name, err)
}
- defer st.Free()
- // Get the tensor (it's stored as "data" in our minimal safetensors format)
- arr := st.Get("data")
+ arr := st.Get(inputKey)
if arr == nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("tensor 'data' not found in safetensors")
+ st.Free()
+ return tmpPath, nil, nil, fmt.Errorf("tensor %q not found in safetensors", inputKey)
}
- // Convert to BFloat16 if needed (quantize expects float type)
+ if quantize == "" {
+ arr = mlx.Contiguous(arr)
+ arrays[name] = arr
+ return tmpPath, []*mlx.Array{arr}, st, nil
+ }
+
+ // Convert to float type if needed (quantize expects float)
if arr.Dtype() != mlx.DtypeBFloat16 && arr.Dtype() != mlx.DtypeFloat32 && arr.Dtype() != mlx.DtypeFloat16 {
arr = mlx.AsType(arr, mlx.DtypeBFloat16)
mlx.Eval(arr)
}
- // Quantize based on quantization type
- var qweight, scales, qbiases *mlx.Array
- switch quantize {
- case "fp4":
- // affine mode: group_size=32, bits=4
- qweight, scales, qbiases = mlx.Quantize(arr, 32, 4, "affine")
- case "fp8":
- // affine mode: group_size=32, bits=8
- qweight, scales, qbiases = mlx.Quantize(arr, 32, 8, "affine")
- default:
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
+ params, ok := quantizeParams[quantize]
+ if !ok {
+ st.Free()
+ return tmpPath, nil, nil, fmt.Errorf("unsupported quantization type: %s", quantize)
}
- // Eval and make contiguous for data access
+ qweight, scales, qbiases := mlx.Quantize(arr, params.groupSize, params.bits, params.mode)
+
qweight = mlx.Contiguous(qweight)
scales = mlx.Contiguous(scales)
+ arrays[name] = qweight
+ arrays[name+".scale"] = scales
+ toEval = append(toEval, qweight, scales)
+
if qbiases != nil {
qbiases = mlx.Contiguous(qbiases)
- mlx.Eval(qweight, scales, qbiases)
- } else {
- mlx.Eval(qweight, scales)
+ arrays[name+".bias"] = qbiases
+ toEval = append(toEval, qbiases)
}
- // Get shapes
- qweightShape = qweight.Shape()
- scalesShape = scales.Shape()
+ return tmpPath, toEval, st, nil
+}
- // Save quantized weight using MLX's native safetensors (correctly handles uint32 dtype)
- qweightPath := filepath.Join(tmpDir, "qweight.safetensors")
- defer os.Remove(qweightPath)
- if err := mlx.SaveSafetensors(qweightPath, map[string]*mlx.Array{"data": qweight}); err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save quantized weight: %w", err)
+// quantizeTensor loads a tensor from safetensors format, quantizes it,
+// and returns a single combined safetensors blob with the quantized weight, scale, and optional bias.
+// Tensor keys use the original tensor name: name, name.scale, name.bias.
+// The blob includes __metadata__ with quant_type and group_size.
+// Supported quantization types: "int4", "nvfp4", "int8", "mxfp8".
+func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
+ arrays := make(map[string]*mlx.Array)
+ tmpPath, toEval, st, err := loadAndQuantizeArray(r, tensorName, quantize, arrays)
+ if tmpPath != "" {
+ defer os.Remove(tmpPath)
+ }
+ if st != nil {
+ defer st.Free()
}
- qweightData, err = os.ReadFile(qweightPath)
if err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read quantized weight: %w", err)
+ return nil, err
}
- // Save scales using MLX's native safetensors
- scalesPath := filepath.Join(tmpDir, "scales.safetensors")
- defer os.Remove(scalesPath)
- if err := mlx.SaveSafetensors(scalesPath, map[string]*mlx.Array{"data": scales}); err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save scales: %w", err)
+ mlx.Eval(toEval...)
+
+ // Build metadata for single-tensor blobs
+ params := quantizeParams[quantize]
+ metadata := map[string]string{
+ "quant_type": quantize,
+ "group_size": strconv.Itoa(params.groupSize),
}
- scalesData, err = os.ReadFile(scalesPath)
- if err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read scales: %w", err)
+
+ tmpDir := ensureTempDir()
+ outPath := filepath.Join(tmpDir, "combined.safetensors")
+ defer os.Remove(outPath)
+ if err := mlx.SaveSafetensorsWithMetadata(outPath, arrays, metadata); err != nil {
+ return nil, fmt.Errorf("failed to save combined blob: %w", err)
}
+ return os.ReadFile(outPath)
+}
- // Affine mode returns qbiases for zero-point offset
- if qbiases != nil {
- qbiasShape = qbiases.Shape()
- qbiasPath := filepath.Join(tmpDir, "qbias.safetensors")
- defer os.Remove(qbiasPath)
- if err := mlx.SaveSafetensors(qbiasPath, map[string]*mlx.Array{"data": qbiases}); err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to save qbiases: %w", err)
+// quantizePackedGroup quantizes multiple tensors and saves them all into a single
+// combined safetensors blob. Used for packing expert groups.
+// Each tensor may have a different quantization type (mixed-precision).
+// Returns the blob bytes. No __metadata__ is added because different tensors
+// may use different quantization types.
+func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
+ allArrays := make(map[string]*mlx.Array)
+ var allToEval []*mlx.Array
+ var tmpPaths []string
+ var handles []*mlx.SafetensorsFile
+
+ for _, input := range inputs {
+ tmpPath, toEval, st, err := loadAndQuantizeArray(input.Reader, input.Name, input.Quantize, allArrays)
+ if tmpPath != "" {
+ tmpPaths = append(tmpPaths, tmpPath)
+ }
+ if st != nil {
+ handles = append(handles, st)
}
- qbiasData, err = os.ReadFile(qbiasPath)
if err != nil {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("failed to read qbiases: %w", err)
+ // Cleanup on error
+ for _, h := range handles {
+ h.Free()
+ }
+ for _, p := range tmpPaths {
+ os.Remove(p)
+ }
+ return nil, err
}
+ allToEval = append(allToEval, toEval...)
}
- return qweightData, scalesData, qbiasData, qweightShape, scalesShape, qbiasShape, nil
+ mlx.Eval(allToEval...)
+
+ // Free native handles after eval
+ for _, h := range handles {
+ h.Free()
+ }
+
+ // Save combined blob (no global metadata for mixed-precision packed blobs)
+ tmpDir := ensureTempDir()
+ outPath := filepath.Join(tmpDir, "packed-combined.safetensors")
+ defer os.Remove(outPath)
+ if err := mlx.SaveSafetensorsWithMetadata(outPath, allArrays, nil); err != nil {
+ return nil, fmt.Errorf("failed to save packed blob: %w", err)
+ }
+
+ blobData, err := os.ReadFile(outPath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read packed blob: %w", err)
+ }
+
+ for _, p := range tmpPaths {
+ os.Remove(p)
+ }
+
+ return blobData, nil
}
// QuantizeSupported returns true if quantization is supported (MLX build)
@@ -128,3 +205,33 @@ func ensureTempDir() string {
os.MkdirAll(tmpDir, 0755)
return tmpDir
}
+
+// findSafetensorsKey reads the first non-metadata tensor key from a safetensors file.
+func findSafetensorsKey(path string) (string, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return "", err
+ }
+ defer f.Close()
+
+ var headerSize uint64
+ if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
+ return "", err
+ }
+ headerBytes := make([]byte, headerSize)
+ if _, err := io.ReadFull(f, headerBytes); err != nil {
+ return "", err
+ }
+
+ var header map[string]json.RawMessage
+ if err := json.Unmarshal(headerBytes, &header); err != nil {
+ return "", err
+ }
+
+ for k := range header {
+ if k != "__metadata__" {
+ return k, nil
+ }
+ }
+ return "", fmt.Errorf("no tensor found in safetensors header")
+}
diff --git a/x/create/client/quantize_stub.go b/x/create/client/quantize_stub.go
index 3a85afcc719..7a75671a034 100644
--- a/x/create/client/quantize_stub.go
+++ b/x/create/client/quantize_stub.go
@@ -5,11 +5,18 @@ package client
import (
"fmt"
"io"
+
+ "github.com/ollama/ollama/x/create"
)
// quantizeTensor is not available without MLX
-func quantizeTensor(r io.Reader, name, dtype string, shape []int32, quantize string) (qweightData, scalesData, qbiasData []byte, qweightShape, scalesShape, qbiasShape []int32, err error) {
- return nil, nil, nil, nil, nil, nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
+func quantizeTensor(r io.Reader, tensorName, dtype string, shape []int32, quantize string) (blobData []byte, err error) {
+ return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
+}
+
+// quantizePackedGroup is not available without MLX
+func quantizePackedGroup(inputs []create.PackedTensorInput) ([]byte, error) {
+ return nil, fmt.Errorf("quantization requires MLX support (build with mlx tag)")
}
// QuantizeSupported returns false when MLX is not available
diff --git a/x/create/create.go b/x/create/create.go
index 823d0f842b5..385efadab63 100644
--- a/x/create/create.go
+++ b/x/create/create.go
@@ -6,7 +6,9 @@ import (
"io"
"os"
"path/filepath"
+ "regexp"
"slices"
+ "sort"
"strings"
"github.com/ollama/ollama/envconfig"
@@ -228,7 +230,7 @@ type LayerCreator func(r io.Reader, mediaType, name string) (LayerInfo, error)
type TensorLayerCreator func(r io.Reader, name, dtype string, shape []int32) (LayerInfo, error)
// QuantizingTensorLayerCreator creates tensor layers with optional quantization.
-// When quantize is non-empty (e.g., "fp8"), returns multiple layers (weight + scales + biases).
+// When quantize is non-empty (e.g., "int8"), returns multiple layers (weight + scales + biases).
type QuantizingTensorLayerCreator func(r io.Reader, name, dtype string, shape []int32, quantize string) ([]LayerInfo, error)
// ManifestWriter writes the manifest file.
@@ -262,40 +264,176 @@ func ShouldQuantize(name, component string) bool {
return strings.HasSuffix(name, ".weight")
}
-// ShouldQuantizeTensor returns true if a tensor should be quantized based on name and shape.
+// ShouldQuantizeTensor returns true if a tensor should be quantized based on name, shape, and quantize type.
// This is a more detailed check that also considers tensor dimensions.
-func ShouldQuantizeTensor(name string, shape []int32) bool {
+// The quantize parameter specifies the quantization type (e.g., "int4", "nvfp4", "int8", "mxfp8").
+func ShouldQuantizeTensor(name string, shape []int32, quantize string) bool {
+ return GetTensorQuantization(name, shape, quantize) != ""
+}
+
+// normalizeQuantType converts various quantization type aliases to canonical forms.
+// Supports: q4/Q4/int4/INT4/fp4/FP4 -> int4, q8/Q8/int8/INT8/fp8/FP8 -> int8, nvfp4/NVFP4, mxfp8/MXFP8
+func normalizeQuantType(quantize string) string {
+ switch strings.ToUpper(quantize) {
+ case "Q4", "INT4", "FP4":
+ return "int4"
+ case "Q8", "INT8", "FP8":
+ return "int8"
+ case "NVFP4":
+ return "nvfp4"
+ case "MXFP8":
+ return "mxfp8"
+ default:
+ return quantize
+ }
+}
+
+// GetTensorQuantization returns the appropriate quantization type for a tensor.
+// Returns "" if the tensor should not be quantized.
+// This implements mixed-precision quantization:
+// - Attention MLA weights (q_a, q_b, kv_a, kv_b): unquantized (most sensitive)
+// - Output projection, gate/up weights: int4 (less sensitive)
+// - Down projection weights: int8 (more sensitive, would be Q6 in GGML but no MLX kernel)
+// - Norms, embeddings, biases, routing gates: no quantization
+func GetTensorQuantization(name string, shape []int32, quantize string) string {
// Use basic name-based check first
if !ShouldQuantize(name, "") {
- return false
+ return ""
}
// Only quantize 2D tensors (linear layers) - skip 1D (biases, norms) and higher-D (convolutions if any)
if len(shape) != 2 {
- return false
+ return ""
}
// Skip small tensors (less than 1024 elements) - not worth quantizing
if len(shape) >= 2 && int64(shape[0])*int64(shape[1]) < 1024 {
- return false
+ return ""
}
- // MLX quantization requires last dimension to be divisible by group size (32)
- if shape[len(shape)-1]%32 != 0 {
- return false
+ // Normalize quantization type to canonical form
+ quantNorm := normalizeQuantType(quantize)
+
+ // MLX quantization requires last dimension to be divisible by group size
+ // nvfp4: 16, int4/mxfp8: 32, int8: 64
+ groupSize := int32(32)
+ switch quantNorm {
+ case "nvfp4":
+ groupSize = 16
+ case "int8":
+ groupSize = 64
+ }
+ if shape[len(shape)-1]%groupSize != 0 {
+ return ""
+ }
+
+ // Skip routing gate weights (should stay high precision)
+ // In safetensors these are: mlp.gate.weight (not mlp.gate_proj.weight)
+ if strings.Contains(name, "mlp.gate.weight") && !strings.Contains(name, "_proj") {
+ return ""
+ }
+
+ // For NVFP4 or MXFP8, use the same quantization for all (no mixed precision)
+ if quantNorm == "nvfp4" || quantNorm == "mxfp8" {
+ return quantNorm
+ }
+
+ // Attention MLA weights - keep unquantized (bf16)
+ // These are highly sensitive: errors accumulate in the KV cache over time
+ // q_a_proj, q_b_proj, kv_a_proj_with_mqa, kv_b_proj
+ if strings.Contains(name, "q_a_proj") ||
+ strings.Contains(name, "q_b_proj") ||
+ strings.Contains(name, "kv_a_proj") ||
+ strings.Contains(name, "kv_b_proj") {
+ return "" // No quantization - keep bf16
+ }
+
+ // Down projection weights - use INT8 (would be Q6_K in GGML, but MLX has no Q6 kernel)
+ // mlp.down_proj, mlp.experts.X.down_proj, mlp.shared_experts.down_proj
+ if strings.Contains(name, "down_proj") {
+ return "int8"
+ }
+
+ // Output projection, gate/up weights - use requested quantization (INT4)
+ // o_proj, gate_proj, up_proj
+ if strings.Contains(name, "o_proj") ||
+ strings.Contains(name, "gate_proj") ||
+ strings.Contains(name, "up_proj") {
+ return quantNorm
+ }
+
+ // LM head - use requested quantization
+ if strings.Contains(name, "lm_head") {
+ return quantNorm
+ }
+
+ // Default to requested quantization for other weights
+ return quantNorm
+}
+
+// expertGroupRegexp matches expert tensor names and captures the group prefix.
+// Matches: model.layers.{L}.mlp.experts.{E}.{proj}.weight (and .scale, .bias suffixes)
+// Captures: model.layers.{L}.mlp.experts
+var expertGroupRegexp = regexp.MustCompile(`^(model\.layers\.\d+\.mlp\.(?:shared_)?experts)\..*\.weight`)
+
+// ExpertGroupPrefix returns the group prefix for expert tensors that should be packed together.
+// For example:
+// - "model.layers.1.mlp.experts.0.down_proj.weight" -> "model.layers.1.mlp.experts"
+// - "model.layers.1.mlp.shared_experts.down_proj.weight" -> "model.layers.1.mlp.shared_experts"
+// - "model.layers.0.mlp.down_proj.weight" -> "" (dense layer, no experts)
+// - "model.layers.1.mlp.gate.weight" -> "" (routing gate, not an expert)
+func ExpertGroupPrefix(tensorName string) string {
+ m := expertGroupRegexp.FindStringSubmatch(tensorName)
+ if m == nil {
+ return ""
}
+ return m[1]
+}
- return true
+// PackedTensorInput holds metadata for a tensor that will be packed into a multi-tensor blob.
+type PackedTensorInput struct {
+ Name string
+ Dtype string
+ Shape []int32
+ Quantize string // per-tensor quantization type (may differ within group)
+ Reader io.Reader // safetensors-wrapped tensor data
}
+// PackedTensorLayerCreator creates a single blob layer containing multiple packed tensors.
+// groupName is the group prefix (e.g., "model.layers.1.mlp.experts").
+type PackedTensorLayerCreator func(groupName string, tensors []PackedTensorInput) (LayerInfo, error)
+
// CreateSafetensorsModel imports a standard safetensors model from a directory.
// This handles Hugging Face style models with config.json and *.safetensors files.
// Stores each tensor as a separate blob for fine-grained deduplication.
-// If quantize is non-empty (e.g., "fp8"), eligible tensors will be quantized.
-func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
+// Expert tensors are packed into per-layer blobs when createPackedLayer is non-nil.
+// If quantize is non-empty (e.g., "int8"), eligible tensors will be quantized.
+func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string), createPackedLayer ...PackedTensorLayerCreator) error {
var layers []LayerInfo
var configLayer LayerInfo
+ // Resolve the optional packed layer creator
+ var packedCreator PackedTensorLayerCreator
+ if len(createPackedLayer) > 0 {
+ packedCreator = createPackedLayer[0]
+ }
+
+ // Accumulate expert tensors by group prefix for packing.
+ // Readers reference file-backed SectionReaders, so we keep extractors
+ // open until each group is flushed to avoid buffering tensor data in memory.
+ expertGroups := make(map[string][]PackedTensorInput)
+ var expertGroupOrder []string
+
+ // Track open extractors so we can close them after flushing groups
+ var openExtractors []*safetensors.TensorExtractor
+
+ closeExtractors := func() {
+ for _, ext := range openExtractors {
+ ext.Close()
+ }
+ openExtractors = nil
+ }
+
entries, err := os.ReadDir(modelDir)
if err != nil {
return fmt.Errorf("failed to read directory: %w", err)
@@ -312,6 +450,7 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
// Extract individual tensors from safetensors file
extractor, err := safetensors.OpenForExtraction(stPath)
if err != nil {
+ closeExtractors()
return fmt.Errorf("failed to open %s: %w", stPath, err)
}
@@ -322,32 +461,82 @@ func CreateSafetensorsModel(modelName, modelDir, quantize string, createLayer La
}
fn(fmt.Sprintf("importing %s (%d tensors%s)", entry.Name(), len(tensorNames), quantizeMsg))
+ // Track whether this extractor has expert tensors that need to stay open
+ hasExpertTensors := false
+
for _, tensorName := range tensorNames {
td, err := extractor.GetTensor(tensorName)
if err != nil {
extractor.Close()
+ closeExtractors()
return fmt.Errorf("failed to get tensor %s: %w", tensorName, err)
}
// Determine quantization type for this tensor (empty string if not quantizing)
+ // GetTensorQuantization handles mixed-precision (e.g., Q8 for attention, Q4 for FFN)
quantizeType := ""
- if quantize != "" && ShouldQuantizeTensor(tensorName, td.Shape) {
- quantizeType = quantize
+ if quantize != "" {
+ quantizeType = GetTensorQuantization(tensorName, td.Shape, quantize)
}
- // Store as minimal safetensors format (88 bytes header overhead)
- // This enables native mmap loading via mlx_load_safetensors
- // createTensorLayer returns multiple layers if quantizing (weight + scales)
- newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
- if err != nil {
- extractor.Close()
- return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
+ // Check if this tensor belongs to an expert group for packing
+ groupPrefix := ""
+ if packedCreator != nil {
+ groupPrefix = ExpertGroupPrefix(tensorName)
+ }
+
+ if groupPrefix != "" {
+ // Accumulate expert tensor for packed blob.
+ // The Reader uses a file-backed SectionReader, so we must
+ // keep the extractor open until this group is flushed.
+ hasExpertTensors = true
+ if _, exists := expertGroups[groupPrefix]; !exists {
+ expertGroupOrder = append(expertGroupOrder, groupPrefix)
+ }
+ expertGroups[groupPrefix] = append(expertGroups[groupPrefix], PackedTensorInput{
+ Name: tensorName,
+ Dtype: td.Dtype,
+ Shape: td.Shape,
+ Quantize: quantizeType,
+ Reader: td.SafetensorsReader(),
+ })
+ } else {
+ // Store as minimal safetensors format (88 bytes header overhead)
+ // This enables native mmap loading via mlx_load_safetensors
+ // createTensorLayer returns multiple layers if quantizing (weight + scales)
+ newLayers, err := createTensorLayer(td.SafetensorsReader(), tensorName, td.Dtype, td.Shape, quantizeType)
+ if err != nil {
+ extractor.Close()
+ closeExtractors()
+ return fmt.Errorf("failed to create layer for %s: %w", tensorName, err)
+ }
+ layers = append(layers, newLayers...)
}
- layers = append(layers, newLayers...)
}
- extractor.Close()
+ if hasExpertTensors {
+ // Keep extractor open - readers still reference its file handle
+ openExtractors = append(openExtractors, extractor)
+ } else {
+ extractor.Close()
+ }
+ }
+
+ // Process accumulated expert groups into packed blobs, then close extractors
+ if packedCreator != nil {
+ sort.Strings(expertGroupOrder)
+ for _, groupName := range expertGroupOrder {
+ tensors := expertGroups[groupName]
+ fn(fmt.Sprintf("packing %s (%d tensors)", groupName, len(tensors)))
+ layer, err := packedCreator(groupName, tensors)
+ if err != nil {
+ closeExtractors()
+ return fmt.Errorf("failed to create packed layer for %s: %w", groupName, err)
+ }
+ layers = append(layers, layer)
+ }
}
+ closeExtractors()
// Process all JSON config files
for _, entry := range entries {
diff --git a/x/create/create_test.go b/x/create/create_test.go
index c69bb10a806..fb48987d636 100644
--- a/x/create/create_test.go
+++ b/x/create/create_test.go
@@ -536,41 +536,84 @@ func TestShouldQuantize(t *testing.T) {
func TestShouldQuantizeTensor(t *testing.T) {
tests := []struct {
- name string
- tensor string
- shape []int32
- want bool
+ name string
+ tensor string
+ shape []int32
+ quantize string
+ want bool
}{
// 2D tensors with sufficient size should be quantized
- {"large 2D weight", "q_proj.weight", []int32{4096, 4096}, true},
- {"medium 2D weight", "small_proj.weight", []int32{128, 128}, true},
+ {"large 2D weight fp8", "q_proj.weight", []int32{4096, 4096}, "fp8", true},
+ {"medium 2D weight fp8", "small_proj.weight", []int32{128, 128}, "fp8", true},
+ {"large 2D weight nvfp4", "q_proj.weight", []int32{4096, 4096}, "nvfp4", true},
// Small tensors should not be quantized (< 1024 elements)
- {"tiny 2D weight", "tiny.weight", []int32{16, 16}, false},
- {"small 2D weight", "small.weight", []int32{31, 31}, false},
+ {"tiny 2D weight", "tiny.weight", []int32{16, 16}, "fp8", false},
+ {"small 2D weight", "small.weight", []int32{31, 31}, "fp8", false},
// 1D tensors should not be quantized
- {"1D tensor", "layer_norm.weight", []int32{4096}, false},
+ {"1D tensor", "layer_norm.weight", []int32{4096}, "fp8", false},
// 3D+ tensors should not be quantized
- {"3D tensor", "conv.weight", []int32{64, 64, 3}, false},
- {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, false},
+ {"3D tensor", "conv.weight", []int32{64, 64, 3}, "fp8", false},
+ {"4D tensor", "conv2d.weight", []int32{64, 64, 3, 3}, "fp8", false},
// Embeddings should not be quantized regardless of shape
- {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, false},
+ {"embedding 2D", "embed_tokens.weight", []int32{32000, 4096}, "fp8", false},
// Norms should not be quantized regardless of shape
- {"norm 2D", "layer_norm.weight", []int32{4096, 1}, false},
+ {"norm 2D", "layer_norm.weight", []int32{4096, 1}, "fp8", false},
// Biases should not be quantized
- {"bias 2D", "proj.bias", []int32{4096, 1}, false},
+ {"bias 2D", "proj.bias", []int32{4096, 1}, "fp8", false},
+
+ // Group size divisibility tests
+ // FP8/FP4 require divisible by 32
+ {"not divisible by 32 fp8", "proj.weight", []int32{128, 48}, "fp8", false},
+ {"divisible by 32 fp8", "proj.weight", []int32{128, 64}, "fp8", true},
+ // NVFP4 requires divisible by 16
+ {"not divisible by 16 nvfp4", "proj.weight", []int32{128, 24}, "nvfp4", false},
+ {"divisible by 16 nvfp4", "proj.weight", []int32{128, 48}, "nvfp4", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ got := ShouldQuantizeTensor(tt.tensor, tt.shape, tt.quantize)
+ if got != tt.want {
+ t.Errorf("ShouldQuantizeTensor(%q, %v, %q) = %v, want %v", tt.tensor, tt.shape, tt.quantize, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestExpertGroupPrefix(t *testing.T) {
+ tests := []struct {
+ name string
+ want string
+ }{
+ // Expert tensors should return the group prefix
+ {"model.layers.1.mlp.experts.0.down_proj.weight", "model.layers.1.mlp.experts"},
+ {"model.layers.1.mlp.experts.63.gate_proj.weight", "model.layers.1.mlp.experts"},
+ {"model.layers.0.mlp.experts.0.up_proj.weight", "model.layers.0.mlp.experts"},
+
+ // Shared expert tensors should return their own group prefix
+ {"model.layers.1.mlp.shared_experts.down_proj.weight", "model.layers.1.mlp.shared_experts"},
+ {"model.layers.2.mlp.shared_experts.gate_proj.weight", "model.layers.2.mlp.shared_experts"},
+
+ // Non-expert tensors should return empty string
+ {"model.layers.0.mlp.down_proj.weight", ""}, // dense layer, no experts
+ {"model.layers.1.mlp.gate.weight", ""}, // routing gate, not an expert
+ {"model.embed_tokens.weight", ""}, // embedding
+ {"model.layers.0.self_attn.q_proj.weight", ""}, // attention
+ {"model.norm.weight", ""}, // norm
+ {"lm_head.weight", ""}, // output head
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
- got := ShouldQuantizeTensor(tt.tensor, tt.shape)
+ got := ExpertGroupPrefix(tt.name)
if got != tt.want {
- t.Errorf("ShouldQuantizeTensor(%q, %v) = %v, want %v", tt.tensor, tt.shape, got, tt.want)
+ t.Errorf("ExpertGroupPrefix(%q) = %q, want %q", tt.name, got, tt.want)
}
})
}
@@ -741,7 +784,7 @@ func TestCreateImageGenModel_WithQuantize(t *testing.T) {
progressFn := func(status string) {}
- err := CreateImageGenModel("test-imagegen", dir, "fp8", createLayer, createTensorLayer, writeManifest, progressFn)
+ err := CreateImageGenModel("test-imagegen", dir, "int8", createLayer, createTensorLayer, writeManifest, progressFn)
if err != nil {
t.Fatalf("CreateImageGenModel failed: %v", err)
}
diff --git a/x/create/imagegen.go b/x/create/imagegen.go
index 595a404175b..6dbbcbfccfb 100644
--- a/x/create/imagegen.go
+++ b/x/create/imagegen.go
@@ -15,15 +15,15 @@ import (
// CreateImageGenModel imports an image generation model from a directory.
// Stores each tensor as a separate blob for fine-grained deduplication.
// If quantize is specified, linear weights in transformer/text_encoder are quantized.
-// Supported quantization types: fp8 (or empty for no quantization).
+// Supported quantization types: int4, int8, nvfp4, mxfp8 (or empty for no quantization).
// Layer creation and manifest writing are done via callbacks to avoid import cycles.
func CreateImageGenModel(modelName, modelDir, quantize string, createLayer LayerCreator, createTensorLayer QuantizingTensorLayerCreator, writeManifest ManifestWriter, fn func(status string)) error {
// Validate quantization type
switch quantize {
- case "", "fp4", "fp8":
+ case "", "int4", "int8", "nvfp4", "mxfp8":
// valid
default:
- return fmt.Errorf("unsupported quantization type %q: supported types are fp4, fp8", quantize)
+ return fmt.Errorf("unsupported quantization type %q: supported types are int4, int8, nvfp4, mxfp8", quantize)
}
var layers []LayerInfo
@@ -89,7 +89,7 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
// Determine quantization type for this tensor (empty string if not quantizing)
quantizeType := ""
- if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape) {
+ if quantize != "" && ShouldQuantize(tensorName, component) && canQuantizeShape(td.Shape, quantize) {
quantizeType = quantize
}
@@ -213,10 +213,18 @@ func CreateImageGenModel(modelName, modelDir, quantize string, createLayer Layer
}
// canQuantizeShape returns true if a tensor shape is compatible with MLX quantization.
-// MLX requires the last dimension to be divisible by the group size (32).
-func canQuantizeShape(shape []int32) bool {
+// MLX requires the last dimension to be divisible by the group size.
+// nvfp4: 16, int4/mxfp8: 32, int8: 64
+func canQuantizeShape(shape []int32, quantize string) bool {
if len(shape) < 2 {
return false
}
- return shape[len(shape)-1]%32 == 0
+ groupSize := int32(32)
+ switch strings.ToUpper(quantize) {
+ case "NVFP4":
+ groupSize = 16
+ case "INT8":
+ groupSize = 64
+ }
+ return shape[len(shape)-1]%groupSize == 0
}
diff --git a/x/imagegen/cache/cache.go b/x/imagegen/cache/cache.go
index 4faa2412ee8..8a25193cd58 100644
--- a/x/imagegen/cache/cache.go
+++ b/x/imagegen/cache/cache.go
@@ -9,6 +9,7 @@ type Cache interface {
Offset() int
Len() int
State() []*mlx.Array
+ Reset()
}
type KVCache struct {
@@ -63,6 +64,13 @@ func (c *KVCache) State() []*mlx.Array {
func (c *KVCache) Offset() int { return c.offset }
func (c *KVCache) Len() int { return c.offset }
+// Reset clears the cache state for a new generation session
+func (c *KVCache) Reset() {
+ c.keys = nil
+ c.values = nil
+ c.offset = 0
+}
+
// RotatingKVCache implements sliding window attention with bounded memory
type RotatingKVCache struct {
keys, values *mlx.Array
@@ -154,3 +162,11 @@ func (c *RotatingKVCache) State() []*mlx.Array {
func (c *RotatingKVCache) Offset() int { return c.offset }
func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
+
+// Reset clears the cache state for a new generation session
+func (c *RotatingKVCache) Reset() {
+ c.keys = nil
+ c.values = nil
+ c.offset = 0
+ c.idx = 0
+}
diff --git a/x/imagegen/cache/step.go b/x/imagegen/cache/step.go
index 830df447fb1..f91f22fa0f0 100644
--- a/x/imagegen/cache/step.go
+++ b/x/imagegen/cache/step.go
@@ -9,7 +9,7 @@ import "github.com/ollama/ollama/x/imagegen/mlx"
// shallow layers change little between consecutive steps, so we can
// cache their outputs and skip recomputation on non-refresh steps.
//
-// Supports both single-stream (Z-Image) and dual-stream (Qwen-Image) architectures:
+// Supports both single-stream and dual-stream architectures:
// - Single-stream: use Get/Set for the single output per layer
// - Dual-stream: use Get/Set for stream 1 (imgH), Get2/Set2 for stream 2 (txtH)
//
@@ -87,7 +87,7 @@ func (c *StepCache) Set(layer int, arr *mlx.Array) {
}
// Get2 returns the cached output for a layer (stream 2), or nil if not cached.
-// Used for dual-stream architectures like Qwen-Image.
+// Used for dual-stream architectures.
func (c *StepCache) Get2(layer int) *mlx.Array {
if layer < len(c.layers2) {
return c.layers2[layer]
@@ -96,7 +96,7 @@ func (c *StepCache) Get2(layer int) *mlx.Array {
}
// Set2 stores a layer output (stream 2), freeing any previous value.
-// Used for dual-stream architectures like Qwen-Image.
+// Used for dual-stream architectures.
func (c *StepCache) Set2(layer int, arr *mlx.Array) {
if layer < len(c.layers2) {
if c.layers2[layer] != nil {
diff --git a/x/imagegen/cli.go b/x/imagegen/cli.go
index a55a1b01698..e5de34efa7f 100644
--- a/x/imagegen/cli.go
+++ b/x/imagegen/cli.go
@@ -10,7 +10,10 @@ import (
"errors"
"fmt"
"io"
+ "net/http"
"os"
+ "regexp"
+ "slices"
"strconv"
"strings"
"time"
@@ -75,6 +78,7 @@ Image Generation Flags (experimental):
// RunCLI handles the CLI for image generation models.
// Returns true if it handled the request, false if the caller should continue with normal flow.
// Supports flags: --width, --height, --steps, --seed, --negative
+// Image paths can be included in the prompt and will be extracted automatically.
func RunCLI(cmd *cobra.Command, name string, prompt string, interactive bool, keepAlive *api.Duration) error {
// Get options from flags (with env var defaults)
opts := DefaultOptions()
@@ -111,9 +115,16 @@ func generateImageWithOptions(cmd *cobra.Command, modelName, prompt string, keep
return err
}
+ // Extract any image paths from the prompt
+ prompt, images, err := extractFileData(prompt)
+ if err != nil {
+ return err
+ }
+
req := &api.GenerateRequest{
Model: modelName,
Prompt: prompt,
+ Images: images,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
@@ -254,14 +265,33 @@ func runInteractive(cmd *cobra.Command, modelName string, keepAlive *api.Duratio
printCurrentSettings(opts)
continue
case strings.HasPrefix(line, "/"):
- fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", line)
+ // Check if it's a file path, not a command
+ args := strings.Fields(line)
+ isFile := false
+ for _, f := range extractFileNames(line) {
+ if strings.HasPrefix(f, args[0]) {
+ isFile = true
+ break
+ }
+ }
+ if !isFile {
+ fmt.Fprintf(os.Stderr, "Unknown command: %s (try /help)\n", args[0])
+ continue
+ }
+ }
+
+ // Extract any image paths from the input
+ prompt, images, err := extractFileData(line)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error: %v\n", err)
continue
}
// Generate image with current options
req := &api.GenerateRequest{
Model: modelName,
- Prompt: line,
+ Prompt: prompt,
+ Images: images,
Width: int32(opts.Width),
Height: int32(opts.Height),
Steps: int32(opts.Steps),
@@ -486,3 +516,61 @@ func displayImageInTerminal(imagePath string) bool {
return false
}
}
+
+// extractFileNames finds image file paths in the input string.
+func extractFileNames(input string) []string {
+ // Regex to match file paths with image extensions
+ regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
+ re := regexp.MustCompile(regexPattern)
+ return re.FindAllString(input, -1)
+}
+
+// extractFileData extracts image data from file paths found in the input.
+// Returns the cleaned prompt (with file paths removed) and the image data.
+func extractFileData(input string) (string, []api.ImageData, error) {
+ filePaths := extractFileNames(input)
+ var imgs []api.ImageData
+
+ for _, fp := range filePaths {
+ // Normalize shell escapes
+ nfp := strings.ReplaceAll(fp, "\\ ", " ")
+ nfp = strings.ReplaceAll(nfp, "\\(", "(")
+ nfp = strings.ReplaceAll(nfp, "\\)", ")")
+ nfp = strings.ReplaceAll(nfp, "%20", " ")
+
+ data, err := getImageData(nfp)
+ if errors.Is(err, os.ErrNotExist) {
+ continue
+ } else if err != nil {
+ return "", nil, err
+ }
+ fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
+ input = strings.ReplaceAll(input, fp, "")
+ imgs = append(imgs, data)
+ }
+ return strings.TrimSpace(input), imgs, nil
+}
+
+// getImageData reads and validates image data from a file.
+func getImageData(filePath string) ([]byte, error) {
+ file, err := os.Open(filePath)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+
+ buf := make([]byte, 512)
+ _, err = file.Read(buf)
+ if err != nil {
+ return nil, err
+ }
+
+ contentType := http.DetectContentType(buf)
+ allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
+ if !slices.Contains(allowedTypes, contentType) {
+ return nil, fmt.Errorf("invalid image type: %s", contentType)
+ }
+
+ // Re-read the full file
+ return os.ReadFile(filePath)
+}
diff --git a/x/imagegen/cmd/engine/main.go b/x/imagegen/cmd/engine/main.go
index 003be3a37cd..f0e705d1c52 100644
--- a/x/imagegen/cmd/engine/main.go
+++ b/x/imagegen/cmd/engine/main.go
@@ -21,8 +21,6 @@ import (
"github.com/ollama/ollama/x/imagegen/models/gemma3"
"github.com/ollama/ollama/x/imagegen/models/gpt_oss"
"github.com/ollama/ollama/x/imagegen/models/llama"
- "github.com/ollama/ollama/x/imagegen/models/qwen_image"
- "github.com/ollama/ollama/x/imagegen/models/qwen_image_edit"
"github.com/ollama/ollama/x/imagegen/models/zimage"
"github.com/ollama/ollama/x/imagegen/safetensors"
)
@@ -61,14 +59,11 @@ func main() {
listTensors := flag.Bool("list", false, "List tensors only")
cpuProfile := flag.String("cpuprofile", "", "Write CPU profile to file")
gpuCapture := flag.String("gpu-capture", "", "Capture GPU trace to .gputrace file (run with MTL_CAPTURE_ENABLED=1)")
- layerCache := flag.Bool("layer-cache", false, "Enable layer caching for faster diffusion (Z-Image, Qwen-Image). Not compatible with CFG/negative prompts.")
wiredLimitGB := flag.Int("wired-limit", 32, "Metal wired memory limit in GB")
// Legacy mode flags
zimageFlag := flag.Bool("zimage", false, "Z-Image generation")
flux2Flag := flag.Bool("flux2", false, "FLUX.2 Klein generation")
- qwenImage := flag.Bool("qwen-image", false, "Qwen-Image text-to-image generation")
- qwenImageEdit := flag.Bool("qwen-image-edit", false, "Qwen-Image-Edit image editing")
var inputImages stringSlice
flag.Var(&inputImages, "input-image", "Input image for image editing (can be specified multiple times)")
negativePrompt := flag.String("negative-prompt", "", "Negative prompt for CFG (empty = no CFG, matching Python)")
@@ -166,60 +161,6 @@ func main() {
if err == nil {
err = saveImageArray(img, *out)
}
- case *qwenImage:
- m, loadErr := qwen_image.LoadPersistent(*modelPath)
- if loadErr != nil {
- log.Fatal(loadErr)
- }
- var img *mlx.Array
- img, err = m.GenerateFromConfig(&qwen_image.GenerateConfig{
- Prompt: *prompt,
- NegativePrompt: *negativePrompt,
- CFGScale: float32(*cfgScale),
- Width: int32(*width),
- Height: int32(*height),
- Steps: *steps,
- Seed: *seed,
- LayerCache: *layerCache,
- })
- if err == nil {
- err = saveImageArray(img, *out)
- }
- case *qwenImageEdit:
- if len(inputImages) == 0 {
- log.Fatal("qwen-image-edit requires at least one -input-image")
- }
-
- m, loadErr := qwen_image_edit.LoadPersistent(*modelPath)
- if loadErr != nil {
- log.Fatal(loadErr)
- }
- // For image editing, use 0 for dimensions to auto-detect from input image
- // unless explicitly overridden from defaults
- editWidth := int32(0)
- editHeight := int32(0)
- if *width != 1024 {
- editWidth = int32(*width)
- }
- if *height != 1024 {
- editHeight = int32(*height)
- }
-
- cfg := &qwen_image_edit.GenerateConfig{
- Prompt: *prompt,
- NegativePrompt: *negativePrompt,
- CFGScale: float32(*cfgScale),
- Width: editWidth,
- Height: editHeight,
- Steps: *steps,
- Seed: *seed,
- }
-
- var img *mlx.Array
- img, err = m.EditFromConfig(inputImages, cfg)
- if err == nil {
- err = saveImageArray(img, *out)
- }
case *listTensors:
err = listModelTensors(*modelPath)
default:
diff --git a/x/imagegen/docs/blob-format.md b/x/imagegen/docs/blob-format.md
new file mode 100644
index 00000000000..768f1c2f9ec
--- /dev/null
+++ b/x/imagegen/docs/blob-format.md
@@ -0,0 +1,158 @@
+# Tensor Blob Format
+
+Ollama stores model tensors as individual blobs in the safetensors format. Each blob contains a logical tensor (or a combined quantized tensor with its scale/bias components), or a group of logical tensors (e.g. shared experts for a given layer along with the scale/bias components for that tensor).
+
+## Safetensors File Format
+
+Every blob follows the [safetensors](https://github.com/huggingface/safetensors) layout:
+
+```
+[8 bytes: header_size (uint64 LE)] [header_size bytes: JSON header] [tensor data region]
+```
+
+The JSON header maps tensor names to their dtype, shape, and byte offsets within the data region. A special `__metadata__` key holds string-to-string metadata.
+
+## Unquantized Blobs
+
+An unquantized blob stores a single tensor keyed by its name:
+
+```json
+{
+ "model.layers.0.self_attn.q_proj.weight": {
+ "dtype": "BF16",
+ "shape": [2560, 2560],
+ "data_offsets": [0, 13107200]
+ }
+}
+```
+
+The tensor key is the full tensor name. Dtype is typically `BF16` or `F32`.
+
+## Quantized Blobs (Combined Format)
+
+A quantized blob stores the packed weight, scaling factors, and optional zero-point biases in a single file. Tensor keys use the tensor name, with `.scale` and `.bias` suffixes for the auxiliary tensors:
+
+```json
+{
+ "__metadata__": {
+ "quant_type": "int4",
+ "group_size": "32"
+ },
+ "model.layers.0.mlp.up_proj.weight": {
+ "dtype": "U32",
+ "shape": [2560, 320],
+ "data_offsets": [0, 3276800]
+ },
+ "model.layers.0.mlp.up_proj.weight.scale": {
+ "dtype": "BF16",
+ "shape": [2560, 80],
+ "data_offsets": [3276800, 3686400]
+ },
+ "model.layers.0.mlp.up_proj.weight.bias": {
+ "dtype": "BF16",
+ "shape": [2560, 80],
+ "data_offsets": [3686400, 4096000]
+ }
+}
+```
+
+### Metadata Fields
+
+| Field | Description |
+|---|---|
+| `quant_type` | Quantization type: `int4`, `int8`, `nvfp4`, or `mxfp8` |
+| `group_size` | Number of elements per quantization group (e.g., `32`, `64`) |
+
+### Tensor Keys
+
+| Key | Description |
+|---|---|
+| `{name}` | Packed quantized weights (dtype `U32`) |
+| `{name}.scale` | Per-group scaling factors |
+| `{name}.bias` | Per-group zero-point offsets (affine modes only) |
+
+## Quantization Types
+
+| Type | Bits | Group Size | Mode | Has Bias |
+|---|---|---|---|---|
+| `int4` | 4 | 32 | affine | yes |
+| `int8` | 8 | 64 | affine | yes |
+| `nvfp4` | 4 | 16 | nvfp4 | no |
+| `mxfp8` | 8 | 32 | mxfp8 | no |
+
+**Affine modes** (`int4`, `int8`) use `scale + bias` for dequantization. The bias tensor provides the zero-point offset.
+
+**Non-affine modes** (`nvfp4`, `mxfp8`) use only `scale` with specialized E4M3 scale formats.
+
+### Packed Weight Shape
+
+Quantized weights are packed into `uint32` values:
+- **4-bit** (int4, nvfp4): 8 values per uint32, so `packed_cols = original_cols / 8`
+- **8-bit** (int8, mxfp8): 4 values per uint32, so `packed_cols = original_cols / 4`
+
+Scale shape: `[rows, original_cols / group_size]`
+
+## Manifest References
+
+Blobs are referenced from the model manifest as layers:
+
+```json
+{
+ "mediaType": "application/vnd.ollama.image.tensor",
+ "digest": "sha256:abc123...",
+ "size": 4096150,
+ "name": "model.layers.0.mlp.up_proj.weight"
+}
+```
+
+Each tensor (quantized or not) is one layer in the manifest. The layer name matches the tensor key in the blob header.
+
+## Packed Blobs (Expert Groups)
+
+For MoE (Mixture of Experts) models, expert tensors from the same layer are packed into a single blob to reduce blob count and improve loading efficiency. A packed blob is a standard safetensors file containing multiple tensor entries:
+
+```json
+{
+ "model.layers.1.mlp.experts.0.down_proj.weight": {
+ "dtype": "U32",
+ "shape": [2560, 640],
+ "data_offsets": [0, 6553600]
+ },
+ "model.layers.1.mlp.experts.0.down_proj.weight.scale": {
+ "dtype": "BF16",
+ "shape": [2560, 40],
+ "data_offsets": [6553600, 6963200]
+ },
+ "model.layers.1.mlp.experts.0.gate_proj.weight": {
+ "dtype": "U32",
+ "shape": [10240, 320],
+ "data_offsets": [6963200, 20070400]
+ },
+ "model.layers.1.mlp.experts.0.gate_proj.weight.scale": { "..." : "..." }
+}
+```
+
+### Grouping Rules
+
+- `model.layers.{L}.mlp.experts.*` tensors are packed into one blob per layer
+- `model.layers.{L}.mlp.shared_experts.*` tensors are packed into one blob per layer
+- All other tensors remain as individual blobs
+
+### Manifest Representation
+
+One manifest layer per packed group, using the group prefix as the layer name:
+
+```json
+{
+ "mediaType": "application/vnd.ollama.image.tensor",
+ "digest": "sha256:...",
+ "size": 123456789,
+ "name": "model.layers.1.mlp.experts"
+}
+```
+
+## Loading
+
+At load time, `mlx_load_safetensors` opens each blob via mmap for zero-copy access. For combined quantized blobs, the loader extracts `{name}`, `{name}.scale`, and `{name}.bias` tensors and caches them as `name`, `name + "_scale"`, and `name + "_qbias"` respectively, maintaining compatibility with the weight loading interface.
+
+For packed blobs, if the manifest layer name (group prefix) is not found as a tensor key, the loader parses the blob header to discover all tensor names and loads each individually.
diff --git a/x/imagegen/image.go b/x/imagegen/image.go
index db4d1a4c53e..2dca0ee1d36 100644
--- a/x/imagegen/image.go
+++ b/x/imagegen/image.go
@@ -7,6 +7,8 @@ import (
"encoding/base64"
"fmt"
"image"
+ "image/color"
+ "image/draw"
_ "image/jpeg"
"image/png"
"os"
@@ -111,6 +113,7 @@ func clampF(v, min, max float32) float32 {
}
// DecodeImage decodes image bytes with EXIF orientation applied.
+// Transparent images are composited onto a white background.
func DecodeImage(data []byte) (image.Image, error) {
orientation := readJPEGOrientation(data)
@@ -119,9 +122,33 @@ func DecodeImage(data []byte) (image.Image, error) {
return nil, err
}
+ img = flattenAlpha(img)
return applyOrientation(img, orientation), nil
}
+// flattenAlpha composites an image onto a white background,
+// removing any transparency. This is needed because image
+// generation models don't handle alpha channels well.
+func flattenAlpha(img image.Image) image.Image {
+ if _, ok := img.(*image.RGBA); !ok {
+ if _, ok := img.(*image.NRGBA); !ok {
+ // No alpha channel, return as-is
+ return img
+ }
+ }
+
+ bounds := img.Bounds()
+ dst := image.NewRGBA(bounds)
+
+ // Fill with white background
+ draw.Draw(dst, bounds, &image.Uniform{color.White}, image.Point{}, draw.Src)
+
+ // Composite the image on top
+ draw.Draw(dst, bounds, img, bounds.Min, draw.Over)
+
+ return dst
+}
+
// readJPEGOrientation extracts EXIF orientation from JPEG bytes.
// Returns 1 (normal) for non-JPEG or if orientation not found.
func readJPEGOrientation(data []byte) int {
diff --git a/x/imagegen/imagegen.go b/x/imagegen/imagegen.go
new file mode 100644
index 00000000000..d870bed9bcf
--- /dev/null
+++ b/x/imagegen/imagegen.go
@@ -0,0 +1,134 @@
+//go:build mlx
+
+package imagegen
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "sync"
+ "time"
+
+ "github.com/ollama/ollama/x/imagegen/manifest"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+ "github.com/ollama/ollama/x/imagegen/models/flux2"
+ "github.com/ollama/ollama/x/imagegen/models/zimage"
+)
+
+// ImageModel is the interface for image generation models.
+type ImageModel interface {
+ GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
+}
+
+var imageGenMu sync.Mutex
+
+// loadImageModel loads an image generation model.
+func (s *server) loadImageModel() error {
+ // Check memory requirements before loading
+ var requiredMemory uint64
+ if modelManifest, err := manifest.LoadManifest(s.modelName); err == nil {
+ requiredMemory = uint64(modelManifest.TotalTensorSize())
+ }
+ availableMemory := mlx.GetMemoryLimit()
+ if availableMemory > 0 && requiredMemory > 0 && availableMemory < requiredMemory {
+ return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
+ requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
+ }
+
+ // Detect model type and load appropriate model
+ modelType := DetectModelType(s.modelName)
+ slog.Info("detected image model type", "type", modelType)
+
+ var model ImageModel
+ switch modelType {
+ case "Flux2KleinPipeline":
+ m := &flux2.Model{}
+ if err := m.Load(s.modelName); err != nil {
+ return fmt.Errorf("failed to load flux2 model: %w", err)
+ }
+ model = m
+ default:
+ // Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
+ m := &zimage.Model{}
+ if err := m.Load(s.modelName); err != nil {
+ return fmt.Errorf("failed to load zimage model: %w", err)
+ }
+ model = m
+ }
+
+ s.imageModel = model
+ return nil
+}
+
+// handleImageCompletion handles image generation requests.
+func (s *server) handleImageCompletion(w http.ResponseWriter, r *http.Request, req Request) {
+ // Serialize generation requests - MLX model may not handle concurrent generation
+ imageGenMu.Lock()
+ defer imageGenMu.Unlock()
+
+ // Set seed if not provided
+ if req.Seed <= 0 {
+ req.Seed = time.Now().UnixNano()
+ }
+
+ // Set up streaming response
+ w.Header().Set("Content-Type", "application/x-ndjson")
+ w.Header().Set("Transfer-Encoding", "chunked")
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ http.Error(w, "streaming not supported", http.StatusInternalServerError)
+ return
+ }
+
+ ctx := r.Context()
+ enc := json.NewEncoder(w)
+
+ // Progress callback streams step updates
+ progress := func(step, total int) {
+ resp := Response{Step: step, Total: total}
+ enc.Encode(resp)
+ w.Write([]byte("\n"))
+ flusher.Flush()
+ }
+
+ // Generate image
+ img, err := s.imageModel.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
+ if err != nil {
+ // Don't send error for cancellation
+ if ctx.Err() != nil {
+ return
+ }
+ resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
+ data, _ := json.Marshal(resp)
+ w.Write(data)
+ w.Write([]byte("\n"))
+ return
+ }
+
+ // Encode image as base64 PNG
+ imageData, err := EncodeImageBase64(img)
+ if err != nil {
+ resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
+ data, _ := json.Marshal(resp)
+ w.Write(data)
+ w.Write([]byte("\n"))
+ return
+ }
+
+ // Free the generated image array and clean up MLX state
+ img.Free()
+ mlx.ClearCache()
+ mlx.MetalResetPeakMemory()
+
+ // Send final response with image data
+ resp := Response{
+ Image: imageData,
+ Done: true,
+ }
+ data, _ := json.Marshal(resp)
+ w.Write(data)
+ w.Write([]byte("\n"))
+ flusher.Flush()
+}
diff --git a/x/imagegen/llm.go b/x/imagegen/llm.go
new file mode 100644
index 00000000000..eda3b64b15a
--- /dev/null
+++ b/x/imagegen/llm.go
@@ -0,0 +1,420 @@
+//go:build mlx
+
+package imagegen
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ollama/ollama/x/imagegen/cache"
+ "github.com/ollama/ollama/x/imagegen/manifest"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+ "github.com/ollama/ollama/x/imagegen/models/glm4_moe_lite"
+ "github.com/ollama/ollama/x/imagegen/tokenizer"
+)
+
+// TextModel is the interface for LLM text generation models.
+type TextModel interface {
+ Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array
+ NewCache(maxSeqLen int32) []cache.Cache
+ Tokenizer() *tokenizer.Tokenizer
+ VocabSize() int32
+ MaxContextLength() int32
+ NumLayers() int
+}
+
+// llmState holds the state for LLM generation
+type llmState struct {
+ model TextModel
+}
+
+var llmMu sync.Mutex
+
+// Dedicated stream for generation (like mlx-lm's generation_stream)
+var generationStream *mlx.Stream
+
+// withStream runs fn with the generation stream as default
+func withStream(fn func()) {
+ // Lazy initialization of generationStream
+ if generationStream == nil {
+ generationStream = mlx.NewStream()
+ }
+ orig := mlx.GetDefaultStream()
+ mlx.SetDefaultStream(generationStream)
+ fn()
+ mlx.SetDefaultStream(orig)
+}
+
+// Decoder wraps model + cache for autoregressive generation.
+// This matches the pattern from cmd/engine/generate.go
+type Decoder struct {
+ model TextModel
+ caches []cache.Cache
+ vocabSize int32
+ temp float32
+ token *mlx.Array // Current token (kept across iterations)
+ oldCacheState []*mlx.Array // Preallocated slice for old cache state
+}
+
+func NewDecoder(m TextModel, temp float32) *Decoder {
+ caches := m.NewCache(0)
+ return &Decoder{
+ model: m,
+ caches: caches,
+ vocabSize: m.VocabSize(),
+ temp: temp,
+ oldCacheState: make([]*mlx.Array, 0, len(caches)*2),
+ }
+}
+
+func (d *Decoder) prefill(inputIDs []int32) int {
+ processed := 0
+
+ // Track old cache state to free after each chunk
+ var oldCacheState []*mlx.Array
+
+ // Process all-but-1 tokens in chunks, eval cache state for memory management
+ for len(inputIDs) > 1 {
+ chunkSize := min(2048, len(inputIDs)-1)
+ if chunkSize <= 0 {
+ break
+ }
+ chunk := inputIDs[:chunkSize]
+
+ // Save old cache state before forward
+ oldCacheState = oldCacheState[:0]
+ for _, c := range d.caches {
+ oldCacheState = append(oldCacheState, c.State()...)
+ }
+
+ var cacheState []*mlx.Array
+ withStream(func() {
+ x := mlx.NewArrayInt32(chunk, []int32{1, int32(len(chunk))})
+ d.model.Forward(x, d.caches)
+ for _, c := range d.caches {
+ cacheState = append(cacheState, c.State()...)
+ }
+ })
+ mlx.Eval(cacheState...)
+
+ // Free old cache state
+ for _, arr := range oldCacheState {
+ if arr != nil {
+ arr.Free()
+ }
+ }
+
+ inputIDs = inputIDs[chunkSize:]
+ processed += chunkSize
+ }
+
+ // Save old cache state before final step
+ oldCacheState = oldCacheState[:0]
+ for _, c := range d.caches {
+ oldCacheState = append(oldCacheState, c.State()...)
+ }
+
+ // Final token + sampling
+ withStream(func() {
+ x := mlx.NewArrayInt32(inputIDs, []int32{1, int32(len(inputIDs))})
+ mlx.Eval(x) // Materialize before any other evals
+ logits := d.model.Forward(x, d.caches)
+ d.token = sample(logits, d.temp, d.vocabSize)
+ })
+ // Keep cache state (token auto-kept by AsyncEval)
+ for _, c := range d.caches {
+ mlx.Keep(c.State()...)
+ }
+ mlx.AsyncEval(d.token)
+
+ // Free old cache state from before final step
+ for _, arr := range oldCacheState {
+ if arr != nil {
+ arr.Free()
+ }
+ }
+
+ mlx.ClearCache()
+
+ return processed + len(inputIDs)
+}
+
+func (d *Decoder) step() int32 {
+ prevToken := d.token
+
+ // Save old cache state (reuse preallocated slice)
+ d.oldCacheState = d.oldCacheState[:0]
+ for _, c := range d.caches {
+ d.oldCacheState = append(d.oldCacheState, c.State()...)
+ }
+
+ withStream(func() {
+ logits := d.model.Forward(mlx.Reshape(prevToken, 1, 1), d.caches)
+ d.token = sample(logits, d.temp, d.vocabSize)
+ })
+ // Keep token and new cache state so they survive cleanup
+ mlx.Keep(d.token)
+ for _, c := range d.caches {
+ mlx.Keep(c.State()...)
+ }
+ mlx.AsyncEval(d.token)
+
+ // Sync on previous token (GPU already working on next step)
+ val := prevToken.ItemInt32()
+
+ // Free old token and old cache state
+ prevToken.Free()
+ for _, arr := range d.oldCacheState {
+ arr.Free()
+ }
+ return val
+}
+
+// sample samples from logits using temperature scaling
+func sample(logits *mlx.Array, temp float32, vocabSize int32) *mlx.Array {
+ // Get last position logits: [1, L, vocab] -> [vocab]
+ shape := logits.Shape()
+ seqLen := shape[1]
+ lastLogits := mlx.Slice(logits, []int32{0, seqLen - 1, 0}, []int32{1, seqLen, vocabSize})
+ lastLogits = mlx.Reshape(lastLogits, vocabSize)
+
+ if temp <= 0 || temp < 0.01 {
+ // Greedy decoding
+ return mlx.Argmax(lastLogits, -1, false)
+ }
+
+ // Apply temperature scaling
+ scaled := mlx.DivScalar(lastLogits, temp)
+ return mlx.RandomCategorical(scaled, -1, 1)
+}
+
+// loadLLMModel loads a safetensors LLM model and its tokenizer from manifest storage.
+func (s *server) loadLLMModel() error {
+ // Load the manifest to get model information
+ modelManifest, err := manifest.LoadManifest(s.modelName)
+ if err != nil {
+ return fmt.Errorf("failed to load manifest: %w", err)
+ }
+
+ // Detect model architecture from config.json
+ configData, err := modelManifest.ReadConfig("config.json")
+ if err != nil {
+ return fmt.Errorf("failed to read config.json: %w", err)
+ }
+
+ var modelConfig struct {
+ Architectures []string `json:"architectures"`
+ ModelType string `json:"model_type"`
+ }
+ if err := json.Unmarshal(configData, &modelConfig); err != nil {
+ return fmt.Errorf("failed to parse config.json: %w", err)
+ }
+
+ arch := ""
+ if len(modelConfig.Architectures) > 0 {
+ arch = modelConfig.Architectures[0]
+ }
+ if arch == "" {
+ arch = modelConfig.ModelType
+ }
+
+ slog.Info("detected LLM architecture", "architecture", arch, "model_type", modelConfig.ModelType)
+
+ // Load the appropriate model based on architecture
+ var model TextModel
+ archLower := strings.ToLower(arch)
+
+ switch {
+ case strings.Contains(archLower, "glm4moelite"):
+ m, err := glm4_moe_lite.LoadFromManifest(modelManifest)
+ if err != nil {
+ return fmt.Errorf("failed to load glm4-moe-lite model: %w", err)
+ }
+ model = m
+ slog.Info("loaded glm4-moe-lite model", "vocab_size", m.VocabSize(), "layers", m.NumLayers())
+
+ default:
+ return fmt.Errorf("LLM architecture %q is not yet supported. "+
+ "Supported architectures: glm4-moe-lite. "+
+ "Please convert your model to GGUF format or use a supported architecture", arch)
+ }
+
+ s.llmModel = &llmState{
+ model: model,
+ }
+
+ return nil
+}
+
+// handleLLMCompletion handles LLM text generation requests.
+func (s *server) handleLLMCompletion(w http.ResponseWriter, r *http.Request, req Request) {
+ if s.llmModel == nil {
+ http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
+ return
+ }
+
+ // Serialize generation requests
+ llmMu.Lock()
+ defer llmMu.Unlock()
+
+ if err := s.llmGenerate(w, r, req); err != nil {
+ slog.Error("LLM generation failed", "error", err)
+ // Don't send error if we've already started streaming
+ }
+}
+
+// llmGenerate runs the generation loop using the Decoder pattern from cmd/engine
+func (s *server) llmGenerate(w http.ResponseWriter, r *http.Request, req Request) error {
+ state := s.llmModel
+
+ // Set up streaming response
+ w.Header().Set("Content-Type", "application/x-ndjson")
+ w.Header().Set("Transfer-Encoding", "chunked")
+ flusher, ok := w.(http.Flusher)
+ if !ok {
+ return errors.New("streaming not supported")
+ }
+
+ tok := state.model.Tokenizer()
+
+ // The prompt is already formatted by the server using the model's renderer
+ // (see server/prompt.go renderPrompt), so we don't apply FormatPrompt here.
+ prompt := req.Prompt
+
+ // Tokenize the prompt
+ inputIDs := tok.Encode(prompt, true)
+ slog.Debug("tokenized prompt", "num_tokens", len(inputIDs))
+
+ // Generation parameters
+ maxTokens := int(state.model.MaxContextLength())
+ if maxTokens <= 0 {
+ maxTokens = 4096
+ }
+ if req.Options != nil && req.Options.NumPredict > 0 {
+ maxTokens = req.Options.NumPredict
+ }
+
+ temperature := float32(0.7)
+ if req.Options != nil && req.Options.Temperature > 0 {
+ temperature = float32(req.Options.Temperature)
+ }
+
+ // Enable MLX compilation for better performance
+ mlx.EnableCompile()
+
+ // Create decoder with fresh caches
+ dec := NewDecoder(state.model, temperature)
+
+ prefillStart := time.Now()
+ prefillTokens := dec.prefill(inputIDs)
+ // Prefill measurement includes time to first token
+ firstToken := dec.step()
+ prefillDuration := time.Since(prefillStart)
+ promptEvalDuration := prefillDuration
+
+ enc := json.NewEncoder(w)
+ ctx := r.Context()
+ generated := 0
+ stopReason := "max_tokens"
+
+ // Handle first token
+ generated++
+ if tok.IsEOS(firstToken) {
+ resp := Response{
+ Done: true,
+ StopReason: fmt.Sprintf("first_token_eos:%d", firstToken),
+ PromptEvalCount: prefillTokens,
+ PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
+ }
+ enc.Encode(resp)
+ flusher.Flush()
+ return nil
+ }
+
+ text := tok.Decode([]int32{firstToken})
+ resp := Response{Content: text}
+ enc.Encode(resp)
+ flusher.Flush()
+
+ genStart := time.Now()
+
+ // Generation loop
+ for n := 1; n < maxTokens; n++ {
+ // Check for cancellation
+ select {
+ case <-ctx.Done():
+ stopReason = fmt.Sprintf("context_cancelled:%d", generated)
+ break
+ default:
+ }
+ if stopReason != "max_tokens" {
+ break
+ }
+
+ token := dec.step()
+ generated++
+
+ if tok.IsEOS(token) {
+ stopReason = fmt.Sprintf("eos_token:%d", token)
+ break
+ }
+
+ text := tok.Decode([]int32{token})
+
+ // Check for stop sequences
+ if req.Options != nil && len(req.Options.Stop) > 0 {
+ shouldStop := false
+ var matchedStop string
+ for _, stop := range req.Options.Stop {
+ if strings.Contains(text, stop) {
+ text = strings.Split(text, stop)[0]
+ shouldStop = true
+ matchedStop = stop
+ break
+ }
+ }
+ if shouldStop {
+ if text != "" {
+ resp := Response{Content: text}
+ enc.Encode(resp)
+ flusher.Flush()
+ }
+ stopReason = fmt.Sprintf("stop_sequence:%s", matchedStop)
+ break
+ }
+ }
+
+ resp := Response{Content: text}
+ enc.Encode(resp)
+ flusher.Flush()
+
+ // Periodically clear MLX cache
+ if n%256 == 0 {
+ mlx.ClearCache()
+ }
+ }
+
+ // Clean up
+ mlx.ClearCache()
+
+ // Send final response with stats
+ evalDuration := time.Since(genStart)
+ resp = Response{
+ Done: true,
+ StopReason: fmt.Sprintf("%s:generated=%d", stopReason, generated),
+ PromptEvalCount: prefillTokens,
+ PromptEvalDuration: int(promptEvalDuration.Nanoseconds()),
+ EvalCount: generated,
+ EvalDuration: int(evalDuration.Nanoseconds()),
+ }
+ enc.Encode(resp)
+ flusher.Flush()
+
+ return nil
+}
diff --git a/x/imagegen/manifest.go b/x/imagegen/manifest/manifest.go
similarity index 62%
rename from x/imagegen/manifest.go
rename to x/imagegen/manifest/manifest.go
index 3b30067791e..4de66644c24 100644
--- a/x/imagegen/manifest.go
+++ b/x/imagegen/manifest/manifest.go
@@ -1,13 +1,16 @@
-package imagegen
+package manifest
import (
+ "encoding/binary"
"encoding/json"
"fmt"
"io"
"os"
"path/filepath"
- "runtime"
+ "sort"
"strings"
+
+ "github.com/ollama/ollama/envconfig"
)
// ManifestLayer represents a layer in the manifest.
@@ -32,31 +35,15 @@ type ModelManifest struct {
BlobDir string
}
-// DefaultBlobDir returns the default blob storage directory.
func DefaultBlobDir() string {
- home, err := os.UserHomeDir()
- if err != nil {
- home = "."
- }
- switch runtime.GOOS {
- case "darwin":
- return filepath.Join(home, ".ollama", "models", "blobs")
- case "linux":
- return filepath.Join(home, ".ollama", "models", "blobs")
- case "windows":
- return filepath.Join(home, ".ollama", "models", "blobs")
- default:
- return filepath.Join(home, ".ollama", "models", "blobs")
- }
+ return filepath.Join(envconfig.Models(), "blobs")
}
-// DefaultManifestDir returns the default manifest storage directory.
+// DefaultManifestDir returns the manifest storage directory.
+// Respects OLLAMA_MODELS.
+
func DefaultManifestDir() string {
- home, err := os.UserHomeDir()
- if err != nil {
- home = "."
- }
- return filepath.Join(home, ".ollama", "models", "manifests")
+ return filepath.Join(envconfig.Models(), "manifests")
}
// LoadManifest loads a manifest for the given model name.
@@ -117,14 +104,17 @@ func (m *ModelManifest) BlobPath(digest string) string {
return filepath.Join(m.BlobDir, blobName)
}
-// GetTensorLayers returns all tensor layers for a given component.
-// Component should be "text_encoder", "transformer", or "vae".
-// Tensor names are path-style: "component/tensor_name" (e.g., "text_encoder/model.embed_tokens.weight").
+// GetTensorLayers returns tensor layers, optionally filtered by component.
+// If component is empty, returns all tensor layers (for LLM models).
+// If component is specified (e.g., "text_encoder", "transformer", "vae"),
+// returns only layers with that prefix.
func (m *ModelManifest) GetTensorLayers(component string) []ManifestLayer {
- prefix := component + "/"
var layers []ManifestLayer
for _, layer := range m.Manifest.Layers {
- if layer.MediaType == "application/vnd.ollama.image.tensor" && strings.HasPrefix(layer.Name, prefix) {
+ if layer.MediaType != "application/vnd.ollama.image.tensor" {
+ continue
+ }
+ if component == "" || strings.HasPrefix(layer.Name, component+"/") {
layers = append(layers, layer)
}
}
@@ -176,6 +166,17 @@ func (m *ModelManifest) HasTensorLayers() bool {
return false
}
+// TotalTensorSize returns the total size in bytes of all tensor layers.
+func (m *ModelManifest) TotalTensorSize() int64 {
+ var total int64
+ for _, layer := range m.Manifest.Layers {
+ if layer.MediaType == "application/vnd.ollama.image.tensor" {
+ total += layer.Size
+ }
+ }
+ return total
+}
+
// ModelInfo contains metadata about an image generation model.
type ModelInfo struct {
Architecture string
@@ -206,17 +207,12 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
}
}
- // Fallback: detect quantization from tensor names if not in config
+ // Fallback: detect quantization from first tensor blob's __metadata__
if info.Quantization == "" {
- for _, layer := range manifest.Manifest.Layers {
- if strings.HasSuffix(layer.Name, ".weight_scale") {
- info.Quantization = "FP8"
- break
- }
- }
- if info.Quantization == "" {
- info.Quantization = "BF16"
- }
+ info.Quantization = detectQuantizationFromBlobs(manifest)
+ }
+ if info.Quantization == "" {
+ info.Quantization = "BF16"
}
// Fallback: estimate parameter count if not in config
@@ -224,9 +220,7 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
var totalSize int64
for _, layer := range manifest.Manifest.Layers {
if layer.MediaType == "application/vnd.ollama.image.tensor" {
- if !strings.HasSuffix(layer.Name, "_scale") && !strings.HasSuffix(layer.Name, "_qbias") {
- totalSize += layer.Size
- }
+ totalSize += layer.Size
}
}
// Assume BF16 (2 bytes/param) as rough estimate
@@ -235,3 +229,79 @@ func GetModelInfo(modelName string) (*ModelInfo, error) {
return info, nil
}
+
+// detectQuantizationFromBlobs reads __metadata__ from the first tensor blob
+// to detect quantization type.
+func detectQuantizationFromBlobs(manifest *ModelManifest) string {
+ for _, layer := range manifest.Manifest.Layers {
+ if layer.MediaType != "application/vnd.ollama.image.tensor" {
+ continue
+ }
+ data, err := readBlobHeader(manifest.BlobPath(layer.Digest))
+ if err != nil {
+ continue
+ }
+ var header map[string]json.RawMessage
+ if json.Unmarshal(data, &header) != nil {
+ continue
+ }
+ if metaRaw, ok := header["__metadata__"]; ok {
+ var meta map[string]string
+ if json.Unmarshal(metaRaw, &meta) == nil {
+ if qt, ok := meta["quant_type"]; ok && qt != "" {
+ return strings.ToUpper(qt)
+ }
+ }
+ }
+ // Only check the first tensor blob
+ break
+ }
+ return ""
+}
+
+// ParseBlobTensorNames reads a safetensors blob and returns all "main" tensor names.
+// Filters out __metadata__, .scale, and .bias entries to return only primary weight tensors.
+func ParseBlobTensorNames(path string) ([]string, error) {
+ data, err := readBlobHeader(path)
+ if err != nil {
+ return nil, err
+ }
+
+ var header map[string]json.RawMessage
+ if err := json.Unmarshal(data, &header); err != nil {
+ return nil, err
+ }
+
+ var names []string
+ for k := range header {
+ if k == "__metadata__" || strings.HasSuffix(k, ".scale") || strings.HasSuffix(k, ".bias") {
+ continue
+ }
+ names = append(names, k)
+ }
+
+ sort.Strings(names)
+ return names, nil
+}
+
+// readBlobHeader reads the JSON header bytes from a safetensors blob file.
+func readBlobHeader(path string) ([]byte, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+
+ var headerSize uint64
+ if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
+ return nil, err
+ }
+ if headerSize > 1024*1024 {
+ return nil, fmt.Errorf("header too large: %d", headerSize)
+ }
+ data := make([]byte, headerSize)
+ if _, err := io.ReadFull(f, data); err != nil {
+ return nil, err
+ }
+ return data, nil
+}
diff --git a/x/imagegen/manifest/manifest_test.go b/x/imagegen/manifest/manifest_test.go
new file mode 100644
index 00000000000..03361c6dfbd
--- /dev/null
+++ b/x/imagegen/manifest/manifest_test.go
@@ -0,0 +1,57 @@
+package manifest
+
+import (
+ "path/filepath"
+ "testing"
+)
+
+func TestTotalTensorSize(t *testing.T) {
+ m := &ModelManifest{
+ Manifest: &Manifest{
+ Layers: []ManifestLayer{
+ {MediaType: "application/vnd.ollama.image.tensor", Size: 1000},
+ {MediaType: "application/vnd.ollama.image.tensor", Size: 2000},
+ {MediaType: "application/vnd.ollama.image.json", Size: 500}, // not a tensor
+ {MediaType: "application/vnd.ollama.image.tensor", Size: 3000},
+ },
+ },
+ }
+
+ got := m.TotalTensorSize()
+ want := int64(6000)
+ if got != want {
+ t.Errorf("TotalTensorSize() = %d, want %d", got, want)
+ }
+}
+
+func TestTotalTensorSizeEmpty(t *testing.T) {
+ m := &ModelManifest{
+ Manifest: &Manifest{
+ Layers: []ManifestLayer{},
+ },
+ }
+
+ if got := m.TotalTensorSize(); got != 0 {
+ t.Errorf("TotalTensorSize() = %d, want 0", got)
+ }
+}
+
+func TestManifestAndBlobDirsRespectOLLAMAModels(t *testing.T) {
+ modelsDir := filepath.Join(t.TempDir(), "models")
+
+ // Simulate packaged/systemd environment
+ t.Setenv("OLLAMA_MODELS", modelsDir)
+ t.Setenv("HOME", "/usr/share/ollama")
+
+ // Manifest dir must respect OLLAMA_MODELS
+ wantManifest := filepath.Join(modelsDir, "manifests")
+ if got := DefaultManifestDir(); got != wantManifest {
+ t.Fatalf("DefaultManifestDir() = %q, want %q", got, wantManifest)
+ }
+
+ // Blob dir must respect OLLAMA_MODELS
+ wantBlobs := filepath.Join(modelsDir, "blobs")
+ if got := DefaultBlobDir(); got != wantBlobs {
+ t.Fatalf("DefaultBlobDir() = %q, want %q", got, wantBlobs)
+ }
+}
diff --git a/x/imagegen/manifest/weights.go b/x/imagegen/manifest/weights.go
new file mode 100644
index 00000000000..e1209c9db5c
--- /dev/null
+++ b/x/imagegen/manifest/weights.go
@@ -0,0 +1,298 @@
+//go:build mlx
+
+package manifest
+
+import (
+ "fmt"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/ollama/ollama/x/imagegen/mlx"
+)
+
+// ManifestWeights provides fast weight loading from tensor blobs.
+// Uses native mmap loading with synthetic safetensors headers for zero-copy.
+type ManifestWeights struct {
+ manifest *ModelManifest
+ component string
+ tensors map[string]ManifestLayer // name -> layer
+ cache map[string]*mlx.Array // name -> loaded array
+ nativeCache []*mlx.SafetensorsFile // keep native handles alive
+ quantType string // quantization type from blob metadata (e.g., "int4", "int8")
+ groupSize int // quantization group size from blob metadata
+}
+
+// LoadWeightsFromManifest creates a weight loader from manifest storage.
+// If component is empty, loads all tensors (for LLM models).
+// If component is specified, loads only tensors for that component and strips the prefix.
+func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
+ layers := manifest.GetTensorLayers(component)
+ if len(layers) == 0 {
+ if component == "" {
+ return nil, fmt.Errorf("no tensor layers found in manifest")
+ }
+ return nil, fmt.Errorf("no tensor layers found for component %q", component)
+ }
+
+ // Strip component prefix from tensor names for model loading
+ // e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
+ tensors := make(map[string]ManifestLayer, len(layers))
+ for _, layer := range layers {
+ if component == "" {
+ tensors[layer.Name] = layer
+ } else {
+ tensorName := strings.TrimPrefix(layer.Name, component+"/")
+ tensors[tensorName] = layer
+ }
+ }
+
+ return &ManifestWeights{
+ manifest: manifest,
+ component: component,
+ tensors: tensors,
+ cache: make(map[string]*mlx.Array),
+ }, nil
+}
+
+// Load loads all tensor blobs using native mmap (zero-copy).
+// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
+// Combined quantized blobs contain tensors keyed by name, name+".scale", and optional name+".bias"
+// with quantization metadata. Scale and bias are stored in cache as name+"_scale"
+// and name+"_qbias" for compatibility with downstream loading code.
+// Packed blobs (e.g., for expert groups) contain multiple tensors; the manifest name
+// is a group prefix and individual tensors are loaded by their actual names from the blob.
+// If dtype is non-zero, non-quantized tensors are converted to the specified dtype.
+func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
+ // Track native handles to free after batch eval
+ nativeHandles := make([]*mlx.SafetensorsFile, 0, len(mw.tensors))
+ arrays := make([]*mlx.Array, 0, len(mw.tensors))
+
+ // Group tensors by digest to avoid loading the same blob multiple times
+ type blobEntry struct {
+ name string
+ layer ManifestLayer
+ }
+ blobGroups := make(map[string][]blobEntry)
+ for name, layer := range mw.tensors {
+ blobGroups[layer.Digest] = append(blobGroups[layer.Digest], blobEntry{name, layer})
+ }
+
+ for digest, entries := range blobGroups {
+ path := mw.manifest.BlobPath(digest)
+
+ // Load blob as safetensors (native mmap, zero-copy)
+ sf, err := mlx.LoadSafetensorsNative(path)
+ if err != nil {
+ for _, h := range nativeHandles {
+ h.Free()
+ }
+ return fmt.Errorf("load %s: %w", entries[0].name, err)
+ }
+ nativeHandles = append(nativeHandles, sf)
+
+ // Read quantization metadata from blob
+ if qt := sf.GetMetadata("quant_type"); qt != "" && mw.quantType == "" {
+ mw.quantType = qt
+ if gs := sf.GetMetadata("group_size"); gs != "" {
+ mw.groupSize, _ = strconv.Atoi(gs)
+ }
+ }
+
+ for _, entry := range entries {
+ name := entry.name
+
+ // Try to get tensor by stripped name first, then with component prefix,
+ // then fall back to "data" for legacy blobs created by older versions
+ // that stored all tensors with the generic key "data".
+ lookupName := name
+ arr := sf.Get(lookupName)
+ if arr == nil && mw.component != "" {
+ lookupName = mw.component + "/" + name
+ arr = sf.Get(lookupName)
+ }
+ if arr == nil {
+ // Legacy blob format: tensor stored as "data"
+ lookupName = "data"
+ arr = sf.Get(lookupName)
+ }
+ if arr != nil {
+ // Single-tensor blob or tensor found by name
+ if dtype != 0 && arr.Dtype() != dtype {
+ arr = mlx.AsType(arr, dtype)
+ }
+ arr = mlx.Contiguous(arr)
+ mw.cache[name] = arr
+ arrays = append(arrays, arr)
+
+ // Check for scale tensor
+ if scale := sf.Get(lookupName + ".scale"); scale != nil {
+ scale = mlx.Contiguous(scale)
+ mw.cache[name+"_scale"] = scale
+ arrays = append(arrays, scale)
+ }
+
+ // Check for bias tensor
+ if bias := sf.Get(lookupName + ".bias"); bias != nil {
+ bias = mlx.Contiguous(bias)
+ mw.cache[name+"_qbias"] = bias
+ arrays = append(arrays, bias)
+ }
+ } else {
+ // Packed blob: manifest name is a group prefix, not a tensor name.
+ // Load all individual tensors from the blob.
+ tensorNames, err := ParseBlobTensorNames(path)
+ if err != nil {
+ for _, h := range nativeHandles {
+ h.Free()
+ }
+ return fmt.Errorf("parse packed blob for %s: %w", name, err)
+ }
+
+ for _, tensorName := range tensorNames {
+ tArr := sf.Get(tensorName)
+ if tArr == nil {
+ continue
+ }
+
+ if dtype != 0 && tArr.Dtype() != dtype {
+ tArr = mlx.AsType(tArr, dtype)
+ }
+ tArr = mlx.Contiguous(tArr)
+
+ // Strip component prefix from blob-internal names so cache keys
+ // match the stripped names used by LoadModule.
+ cacheName := tensorName
+ if mw.component != "" {
+ cacheName = strings.TrimPrefix(tensorName, mw.component+"/")
+ }
+ mw.cache[cacheName] = tArr
+ arrays = append(arrays, tArr)
+
+ // Check for scale tensor
+ if scale := sf.Get(tensorName + ".scale"); scale != nil {
+ scale = mlx.Contiguous(scale)
+ mw.cache[cacheName+"_scale"] = scale
+ arrays = append(arrays, scale)
+ }
+
+ // Check for bias tensor
+ if bias := sf.Get(tensorName + ".bias"); bias != nil {
+ bias = mlx.Contiguous(bias)
+ mw.cache[cacheName+"_qbias"] = bias
+ arrays = append(arrays, bias)
+ }
+ }
+ }
+ }
+ }
+
+ // Batch evaluate all tensors at once (much faster than one at a time)
+ mlx.Eval(arrays...)
+
+ // Now safe to free all native handles
+ for _, sf := range nativeHandles {
+ sf.Free()
+ }
+
+ return nil
+}
+
+// GetTensor returns a tensor from cache. Call Load() first.
+func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
+ if mw.cache == nil {
+ return nil, fmt.Errorf("cache not initialized: call Load() first")
+ }
+ arr, ok := mw.cache[name]
+ if !ok {
+ return nil, fmt.Errorf("tensor %q not found", name)
+ }
+ return arr, nil
+}
+
+// ListTensors returns all tensor names in sorted order.
+// Includes both manifest tensor names and scale/bias entries from combined blobs.
+func (mw *ManifestWeights) ListTensors() []string {
+ seen := make(map[string]bool, len(mw.tensors)+len(mw.cache))
+ for name := range mw.tensors {
+ seen[name] = true
+ }
+ // Also include cache entries (scale/bias from combined blobs)
+ for name := range mw.cache {
+ seen[name] = true
+ }
+ names := make([]string, 0, len(seen))
+ for name := range seen {
+ names = append(names, name)
+ }
+ sort.Strings(names)
+ return names
+}
+
+// HasTensor checks if a tensor exists in the manifest or cache.
+func (mw *ManifestWeights) HasTensor(name string) bool {
+ if _, ok := mw.tensors[name]; ok {
+ return true
+ }
+ // Also check cache for scale/bias entries from combined blobs
+ if _, ok := mw.cache[name]; ok {
+ return true
+ }
+ return false
+}
+
+// Quantization returns the model's quantization type.
+// Returns the quant_type from blob metadata (e.g., "int4", "int8", "nvfp4", "mxfp8").
+// Returns empty string if not quantized.
+// Falls back to model_index.json for image gen models.
+func (mw *ManifestWeights) Quantization() string {
+ if mw.quantType != "" {
+ return strings.ToUpper(mw.quantType)
+ }
+
+ if mw.manifest == nil {
+ return ""
+ }
+
+ // Fallback: read from model_index.json (for image gen models)
+ var index struct {
+ Quantization string `json:"quantization"`
+ }
+ if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.Quantization != "" {
+ return index.Quantization
+ }
+
+ return ""
+}
+
+// GroupSize returns the quantization group size.
+// Returns the group_size from blob metadata.
+// Returns 0 if not specified (caller should use default based on quantization type).
+func (mw *ManifestWeights) GroupSize() int {
+ if mw.groupSize > 0 {
+ return mw.groupSize
+ }
+
+ if mw.manifest == nil {
+ return 0
+ }
+
+ // Fallback: read from model_index.json (for image gen models)
+ var index struct {
+ GroupSize int `json:"group_size"`
+ }
+ if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err == nil && index.GroupSize > 0 {
+ return index.GroupSize
+ }
+
+ return 0
+}
+
+// ReleaseAll frees all native handles and clears the tensor cache.
+func (mw *ManifestWeights) ReleaseAll() {
+ for _, sf := range mw.nativeCache {
+ sf.Free()
+ }
+ mw.nativeCache = nil
+ mw.cache = nil
+}
diff --git a/x/imagegen/memory.go b/x/imagegen/memory.go
index 57dc4667cd1..39a428c69e3 100644
--- a/x/imagegen/memory.go
+++ b/x/imagegen/memory.go
@@ -14,20 +14,13 @@ import (
"encoding/json"
"fmt"
"runtime"
-)
-// GB is a convenience constant for gigabytes.
-const GB = 1024 * 1024 * 1024
+ "github.com/ollama/ollama/x/imagegen/manifest"
+)
// SupportedBackends lists the backends that support image generation.
var SupportedBackends = []string{"metal", "cuda", "cpu"}
-// modelVRAMEstimates maps pipeline class names to their estimated VRAM requirements.
-var modelVRAMEstimates = map[string]uint64{
- "ZImagePipeline": 21 * GB, // ~21GB for Z-Image (text encoder + transformer + VAE)
- "FluxPipeline": 20 * GB, // ~20GB for Flux
-}
-
// CheckPlatformSupport validates that image generation is supported on the current platform.
// Returns nil if supported, or an error describing why it's not supported.
func CheckPlatformSupport() error {
@@ -47,47 +40,26 @@ func CheckPlatformSupport() error {
}
}
-// CheckMemoryRequirements validates that there's enough memory for image generation.
-// Returns nil if memory is sufficient, or an error if not.
-func CheckMemoryRequirements(modelName string, availableMemory uint64) error {
- required := EstimateVRAM(modelName)
- if availableMemory < required {
- return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
- required/GB, availableMemory/GB)
- }
- return nil
-}
-
// ResolveModelName checks if a model name is a known image generation model.
// Returns the normalized model name if found, empty string otherwise.
func ResolveModelName(modelName string) string {
- manifest, err := LoadManifest(modelName)
- if err == nil && manifest.HasTensorLayers() {
+ modelManifest, err := manifest.LoadManifest(modelName)
+ if err == nil && modelManifest.HasTensorLayers() {
return modelName
}
return ""
}
-// EstimateVRAM returns the estimated VRAM needed for an image generation model.
-// Returns a conservative default of 21GB if the model type cannot be determined.
-func EstimateVRAM(modelName string) uint64 {
- className := DetectModelType(modelName)
- if estimate, ok := modelVRAMEstimates[className]; ok {
- return estimate
- }
- return 21 * GB
-}
-
// DetectModelType reads model_index.json and returns the model type.
// Checks both "architecture" (Ollama format) and "_class_name" (diffusers format).
// Returns empty string if detection fails.
func DetectModelType(modelName string) string {
- manifest, err := LoadManifest(modelName)
+ modelManifest, err := manifest.LoadManifest(modelName)
if err != nil {
return ""
}
- data, err := manifest.ReadConfig("model_index.json")
+ data, err := modelManifest.ReadConfig("model_index.json")
if err != nil {
return ""
}
diff --git a/x/imagegen/memory_test.go b/x/imagegen/memory_test.go
index 180021f6bca..531cffda27a 100644
--- a/x/imagegen/memory_test.go
+++ b/x/imagegen/memory_test.go
@@ -30,69 +30,6 @@ func TestCheckPlatformSupport(t *testing.T) {
}
}
-func TestCheckMemoryRequirements(t *testing.T) {
- tests := []struct {
- name string
- availableMemory uint64
- wantErr bool
- }{
- {
- name: "sufficient memory",
- availableMemory: 32 * GB,
- wantErr: false,
- },
- {
- name: "exactly enough memory",
- availableMemory: 21 * GB,
- wantErr: false,
- },
- {
- name: "insufficient memory",
- availableMemory: 16 * GB,
- wantErr: true,
- },
- {
- name: "zero memory",
- availableMemory: 0,
- wantErr: true,
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- // Use a non-existent model name which will default to 21GB estimate
- err := CheckMemoryRequirements("nonexistent-model", tt.availableMemory)
- if (err != nil) != tt.wantErr {
- t.Errorf("CheckMemoryRequirements() error = %v, wantErr %v", err, tt.wantErr)
- }
- })
- }
-}
-
-func TestModelVRAMEstimates(t *testing.T) {
- // Verify the VRAM estimates map has expected entries
- expected := map[string]uint64{
- "ZImagePipeline": 21 * GB,
- "FluxPipeline": 20 * GB,
- }
-
- for name, expectedVRAM := range expected {
- if actual, ok := modelVRAMEstimates[name]; !ok {
- t.Errorf("Missing VRAM estimate for %s", name)
- } else if actual != expectedVRAM {
- t.Errorf("VRAM estimate for %s = %d GB, want %d GB", name, actual/GB, expectedVRAM/GB)
- }
- }
-}
-
-func TestEstimateVRAMDefault(t *testing.T) {
- // Non-existent model should return default 21GB
- vram := EstimateVRAM("nonexistent-model-that-does-not-exist")
- if vram != 21*GB {
- t.Errorf("EstimateVRAM() = %d GB, want 21 GB", vram/GB)
- }
-}
-
func TestResolveModelName(t *testing.T) {
// Non-existent model should return empty string
result := ResolveModelName("nonexistent-model")
diff --git a/x/ml/backend/mlx/CMakeLists.txt b/x/imagegen/mlx/CMakeLists.txt
similarity index 100%
rename from x/ml/backend/mlx/CMakeLists.txt
rename to x/imagegen/mlx/CMakeLists.txt
diff --git a/x/imagegen/mlx/mlx.go b/x/imagegen/mlx/mlx.go
index 2b31aadfb39..cf3e5157239 100644
--- a/x/imagegen/mlx/mlx.go
+++ b/x/imagegen/mlx/mlx.go
@@ -991,6 +991,19 @@ func Concat(a, b *Array, axis int) *Array {
return Concatenate([]*Array{a, b}, axis)
}
+// Stack stacks arrays along a new axis (axis 0 by default)
+func Stack(arrays []*Array, axis int) *Array {
+ handles := make([]C.mlx_array, len(arrays))
+ for i, arr := range arrays {
+ handles[i] = arr.c
+ }
+ vec := C.mlx_vector_array_new_data(&handles[0], C.size_t(len(handles)))
+ res := C.mlx_array_new()
+ C.mlx_stack_axis(&res, vec, C.int(axis), C.default_stream())
+ C.mlx_vector_array_free(vec)
+ return newArray(res)
+}
+
// Slice slices the array
func Slice(a *Array, start, stop []int32) *Array {
n := len(start)
@@ -1531,6 +1544,18 @@ func (s *SafetensorsFile) Count() int {
return 0
}
+// GetMetadata retrieves a metadata value by key from the safetensors file
+func (s *SafetensorsFile) GetMetadata(key string) string {
+ cKey := C.CString(key)
+ defer C.free(unsafe.Pointer(cKey))
+
+ var cValue *C.char
+ if C.mlx_map_string_to_string_get(&cValue, s.metadata, cKey) != 0 {
+ return ""
+ }
+ return C.GoString(cValue)
+}
+
// Free releases the safetensors file
func (s *SafetensorsFile) Free() {
C.mlx_map_string_to_array_free(s.arrays)
@@ -1565,6 +1590,41 @@ func SaveSafetensors(path string, arrays map[string]*Array) error {
return nil
}
+// SaveSafetensorsWithMetadata saves arrays to a safetensors file with metadata key/value pairs.
+// This is like SaveSafetensors but inserts metadata into the __metadata__ section.
+func SaveSafetensorsWithMetadata(path string, arrays map[string]*Array, metadata map[string]string) error {
+ cPath := C.CString(path)
+ defer C.free(unsafe.Pointer(cPath))
+
+ // Create the array map
+ cArrays := C.mlx_map_string_to_array_new()
+ defer C.mlx_map_string_to_array_free(cArrays)
+
+ for name, arr := range arrays {
+ cName := C.CString(name)
+ C.mlx_map_string_to_array_insert(cArrays, cName, arr.c)
+ C.free(unsafe.Pointer(cName))
+ }
+
+ // Create metadata map
+ cMeta := C.mlx_map_string_to_string_new()
+ defer C.mlx_map_string_to_string_free(cMeta)
+
+ for key, value := range metadata {
+ cKey := C.CString(key)
+ cValue := C.CString(value)
+ C.mlx_map_string_to_string_insert(cMeta, cKey, cValue)
+ C.free(unsafe.Pointer(cKey))
+ C.free(unsafe.Pointer(cValue))
+ }
+
+ // Save
+ if C.mlx_save_safetensors(cPath, cArrays, cMeta) != 0 {
+ return fmt.Errorf("failed to save safetensors: %s", path)
+ }
+ return nil
+}
+
// ============ NPY Loading ============
// LoadNpy loads a numpy array from an npy file
diff --git a/x/imagegen/models/flux2/flux2.go b/x/imagegen/models/flux2/flux2.go
index 348490ba734..894af41f80c 100644
--- a/x/imagegen/models/flux2/flux2.go
+++ b/x/imagegen/models/flux2/flux2.go
@@ -12,7 +12,7 @@ import (
"math"
"time"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/models/qwen3"
"github.com/ollama/ollama/x/imagegen/tokenizer"
@@ -61,7 +61,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
- manifest, err := imagegen.LoadManifest(modelName)
+ manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
@@ -177,6 +177,20 @@ func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height
})
}
+// GenerateImageWithInputs implements runner.ImageEditModel interface.
+// It generates an image conditioned on the provided input images for image editing.
+func (m *Model) GenerateImageWithInputs(ctx context.Context, prompt string, width, height int32, steps int, seed int64, inputImages []image.Image, progress func(step, total int)) (*mlx.Array, error) {
+ return m.GenerateFromConfig(ctx, &GenerateConfig{
+ Prompt: prompt,
+ Width: width,
+ Height: height,
+ Steps: steps,
+ Seed: seed,
+ InputImages: inputImages,
+ Progress: progress,
+ })
+}
+
// MaxOutputPixels is the maximum output resolution (4 megapixels, ~2048x2048)
const MaxOutputPixels = 2048 * 2048
diff --git a/x/imagegen/models/flux2/transformer.go b/x/imagegen/models/flux2/transformer.go
index 3a9524ea771..93771a661b0 100644
--- a/x/imagegen/models/flux2/transformer.go
+++ b/x/imagegen/models/flux2/transformer.go
@@ -6,7 +6,7 @@ import (
"fmt"
"math"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -14,19 +14,19 @@ import (
// TransformerConfig holds Flux2 transformer configuration
type TransformerConfig struct {
- AttentionHeadDim int32 `json:"attention_head_dim"` // 128
- AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
- Eps float32 `json:"eps"` // 1e-6
- GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
- InChannels int32 `json:"in_channels"` // 128
- JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
- MLPRatio float32 `json:"mlp_ratio"` // 3.0
- NumAttentionHeads int32 `json:"num_attention_heads"` // 24
- NumLayers int32 `json:"num_layers"` // 5
- NumSingleLayers int32 `json:"num_single_layers"` // 20
- PatchSize int32 `json:"patch_size"` // 1
- RopeTheta int32 `json:"rope_theta"` // 2000
- TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
+ AttentionHeadDim int32 `json:"attention_head_dim"` // 128
+ AxesDimsRoPE []int32 `json:"axes_dims_rope"` // [32, 32, 32, 32]
+ Eps float32 `json:"eps"` // 1e-6
+ GuidanceEmbeds bool `json:"guidance_embeds"` // false for Klein
+ InChannels int32 `json:"in_channels"` // 128
+ JointAttentionDim int32 `json:"joint_attention_dim"` // 7680
+ MLPRatio float32 `json:"mlp_ratio"` // 3.0
+ NumAttentionHeads int32 `json:"num_attention_heads"` // 24
+ NumLayers int32 `json:"num_layers"` // 5
+ NumSingleLayers int32 `json:"num_single_layers"` // 20
+ PatchSize int32 `json:"patch_size"` // 1
+ RopeTheta int32 `json:"rope_theta"` // 2000
+ TimestepGuidanceChannels int32 `json:"timestep_guidance_channels"` // 256
}
// Computed dimensions
@@ -392,12 +392,12 @@ type Flux2Transformer2DModel struct {
}
// Load loads the Flux2 transformer from ollama blob storage.
-func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
+func (m *Flux2Transformer2DModel) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
- if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
+ if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.TransformerConfig = &cfg
@@ -412,7 +412,7 @@ func (m *Flux2Transformer2DModel) Load(manifest *imagegen.ModelManifest) error {
}
// Load weights from tensor blobs
- weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
+ weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
diff --git a/x/imagegen/models/flux2/vae.go b/x/imagegen/models/flux2/vae.go
index 9736b08c7dc..4b09b1ba4f1 100644
--- a/x/imagegen/models/flux2/vae.go
+++ b/x/imagegen/models/flux2/vae.go
@@ -6,7 +6,7 @@ import (
"fmt"
"math"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -15,21 +15,21 @@ import (
// VAEConfig holds AutoencoderKLFlux2 configuration
type VAEConfig struct {
- ActFn string `json:"act_fn"` // "silu"
- BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
- BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
- BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
- ForceUpcast bool `json:"force_upcast"` // true
- InChannels int32 `json:"in_channels"` // 3
- LatentChannels int32 `json:"latent_channels"` // 32
- LayersPerBlock int32 `json:"layers_per_block"` // 2
+ ActFn string `json:"act_fn"` // "silu"
+ BatchNormEps float32 `json:"batch_norm_eps"` // 0.0001
+ BatchNormMomentum float32 `json:"batch_norm_momentum"` // 0.1
+ BlockOutChannels []int32 `json:"block_out_channels"` // [128, 256, 512, 512]
+ ForceUpcast bool `json:"force_upcast"` // true
+ InChannels int32 `json:"in_channels"` // 3
+ LatentChannels int32 `json:"latent_channels"` // 32
+ LayersPerBlock int32 `json:"layers_per_block"` // 2
MidBlockAddAttn bool `json:"mid_block_add_attention"` // true
- NormNumGroups int32 `json:"norm_num_groups"` // 32
- OutChannels int32 `json:"out_channels"` // 3
- PatchSize []int32 `json:"patch_size"` // [2, 2]
- SampleSize int32 `json:"sample_size"` // 1024
- UsePostQuantConv bool `json:"use_post_quant_conv"` // true
- UseQuantConv bool `json:"use_quant_conv"` // true
+ NormNumGroups int32 `json:"norm_num_groups"` // 32
+ OutChannels int32 `json:"out_channels"` // 3
+ PatchSize []int32 `json:"patch_size"` // [2, 2]
+ SampleSize int32 `json:"sample_size"` // 1024
+ UsePostQuantConv bool `json:"use_post_quant_conv"` // true
+ UseQuantConv bool `json:"use_quant_conv"` // true
}
// BatchNorm2D implements 2D batch normalization with running statistics
@@ -356,18 +356,18 @@ func (db *DownEncoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
}
// Load loads the Flux2 VAE from ollama blob storage.
-func (m *AutoencoderKLFlux2) Load(manifest *imagegen.ModelManifest) error {
+func (m *AutoencoderKLFlux2) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading VAE... ")
// Load config from blob
var cfg VAEConfig
- if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
+ if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
- weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
+ weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
diff --git a/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
new file mode 100644
index 00000000000..3931693b895
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/glm4_moe_lite.go
@@ -0,0 +1,840 @@
+//go:build mlx
+
+// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
+// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
+package glm4_moe_lite
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/x/imagegen/cache"
+ "github.com/ollama/ollama/x/imagegen/manifest"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+ "github.com/ollama/ollama/x/imagegen/nn"
+ "github.com/ollama/ollama/x/imagegen/safetensors"
+ "github.com/ollama/ollama/x/imagegen/tokenizer"
+)
+
+// RopeScaling holds RoPE scaling configuration
+type RopeScaling struct {
+ Factor float32 `json:"factor"`
+ MscaleAllDim float32 `json:"mscale_all_dim"`
+}
+
+// Config holds GLM4-MoE-Lite model configuration
+type Config struct {
+ HiddenSize int32 `json:"hidden_size"`
+ NumHiddenLayers int32 `json:"num_hidden_layers"`
+ IntermediateSize int32 `json:"intermediate_size"`
+ MoEIntermediateSize int32 `json:"moe_intermediate_size"`
+ NumAttentionHeads int32 `json:"num_attention_heads"`
+ NumKeyValueHeads int32 `json:"num_key_value_heads"`
+ VocabSize int32 `json:"vocab_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ RopeTheta float32 `json:"rope_theta"`
+ MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
+ AttentionBias bool `json:"attention_bias"`
+
+ // MLA (Multi-head Latent Attention) parameters
+ QLoraRank int32 `json:"q_lora_rank"`
+ KVLoraRank int32 `json:"kv_lora_rank"`
+ QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
+ QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
+ VHeadDim int32 `json:"v_head_dim"`
+
+ // MoE parameters
+ NRoutedExperts int32 `json:"n_routed_experts"`
+ NSharedExperts int32 `json:"n_shared_experts"`
+ NumExpertsPerTok int32 `json:"num_experts_per_tok"`
+ RoutedScalingFactor float32 `json:"routed_scaling_factor"`
+ NormTopKProb bool `json:"norm_topk_prob"`
+ FirstKDenseReplace int32 `json:"first_k_dense_replace"`
+ NGroup int32 `json:"n_group"`
+ TopKGroup int32 `json:"topk_group"`
+
+ // RoPE scaling
+ RopeScaling *RopeScaling `json:"rope_scaling"`
+
+ // Quantization parameters (set during load based on model quantization)
+ QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
+ QuantBits int `json:"-"` // Bits per weight (4 or 8)
+ QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
+
+ // Computed fields
+ QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
+ Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
+}
+
+// MLAAttention implements Multi-head Latent Attention with absorption.
+// This uses absorbed MLA which operates in latent space for reduced KV cache.
+type MLAAttention struct {
+ // Low-rank query projections
+ QAProj nn.LinearLayer `weight:"self_attn.q_a_proj"`
+ QALayerNorm *nn.RMSNorm `weight:"self_attn.q_a_layernorm"`
+ QBProj nn.LinearLayer `weight:"self_attn.q_b_proj"`
+
+ // Low-rank KV projections (with shared rope component)
+ KVAProjWithMQA nn.LinearLayer `weight:"self_attn.kv_a_proj_with_mqa"`
+ KVALayerNorm *nn.RMSNorm `weight:"self_attn.kv_a_layernorm"`
+
+ // Absorbed MLA projections (derived from kv_b_proj)
+ // EmbedQ: projects q_nope to latent space [num_heads, kv_lora_rank, qk_nope_head_dim]
+ // UnembedOut: projects attention output from latent space [num_heads, v_head_dim, kv_lora_rank]
+ EmbedQ *nn.MultiLinear `weight:"-"`
+ UnembedOut *nn.MultiLinear `weight:"-"`
+
+ // Output projection
+ OProj nn.LinearLayer `weight:"self_attn.o_proj"`
+}
+
+// Forward computes absorbed MLA attention output.
+// This operates in latent space for reduced KV cache memory.
+func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ // Query path: q_a_proj -> layernorm -> q_b_proj
+ q := a.QAProj.Forward(x)
+ q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
+ q = a.QBProj.Forward(q)
+
+ // Reshape Q: [B, L, num_heads * q_head_dim] -> [B, num_heads, L, q_head_dim]
+ q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
+ q = mlx.Transpose(q, 0, 2, 1, 3)
+
+ // Split Q into nope and rope parts
+ qNope := mlx.Slice(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
+ qPE := mlx.Slice(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
+
+ // KV path: get compressed KV and k_pe
+ compressedKV := a.KVAProjWithMQA.Forward(x)
+
+ // Split into compressed_kv and k_pe (shared rope component)
+ kvCompressed := mlx.Slice(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
+ kPE := mlx.Slice(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
+
+ // k_pe is shared across heads (MQA-style): [B, L, rope_dim] -> [B, 1, L, rope_dim]
+ kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
+ kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
+
+ // Apply layernorm to get kv latent representation
+ kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
+ // kvLatent: [B, L, kv_lora_rank] -> [B, 1, L, kv_lora_rank] for broadcasting
+ kvLatent = mlx.ExpandDims(kvLatent, 1)
+
+ // Apply RoPE to the rope parts
+ offset := 0
+ if c != nil {
+ offset = c.Offset()
+ }
+ qPE = mlx.RoPE(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
+ kPE = mlx.RoPE(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
+
+ // ABSORBED MLA: project q_nope to latent space
+ // qNope: [B, num_heads, L, qk_nope_head_dim]
+ // EmbedQ: [num_heads, kv_lora_rank, qk_nope_head_dim]
+ // Result: [B, num_heads, L, kv_lora_rank]
+ qLatent := a.EmbedQ.Forward(qNope)
+
+ // Keys = concat(kvLatent, kPE)
+ // kvLatent: [B, 1, L, kv_lora_rank]
+ // kPE: [B, 1, L, qk_rope_head_dim]
+ // keys: [B, 1, L, kv_lora_rank + qk_rope_head_dim]
+ keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
+
+ // Cache the smaller latent representation
+ // We cache keys (latent + rope) and use empty values since values are derived from keys
+ cachedL := L
+ if c != nil {
+ // Create placeholder values with 0 dims for cache (we don't actually use cached values)
+ placeholderValues := mlx.Zeros([]int32{B, 1, L, 0}, mlx.DtypeFloat32)
+ keys, _ = c.Update(keys, placeholderValues, int(L))
+ cachedL = int32(keys.Shape()[2])
+ }
+
+ // Values are the first kv_lora_rank dims of keys (slice off rope part)
+ values := mlx.Slice(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
+
+ // Queries = concat(qLatent, qPE)
+ // qLatent: [B, num_heads, L, kv_lora_rank]
+ // qPE: [B, num_heads, L, qk_rope_head_dim]
+ // queries: [B, num_heads, L, kv_lora_rank + qk_rope_head_dim]
+ queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
+
+ // Attention in latent space
+ // queries: [B, num_heads, L, kv_lora_rank + rope_dim]
+ // keys: [B, 1, cachedL, kv_lora_rank + rope_dim]
+ // values: [B, 1, cachedL, kv_lora_rank]
+ out := mlx.ScaledDotProductAttention(queries, keys, values, cfg.Scale, L > 1)
+
+ // ABSORBED MLA: unembed from latent space
+ // out: [B, num_heads, L, kv_lora_rank]
+ // UnembedOut: [num_heads, v_head_dim, kv_lora_rank]
+ // Result: [B, num_heads, L, v_head_dim]
+ out = a.UnembedOut.Forward(out)
+
+ // Reshape back: [B, num_heads, L, v_head_dim] -> [B, L, num_heads * v_head_dim]
+ out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
+
+ return a.OProj.Forward(out)
+}
+
+// DenseMLP implements the standard SwiGLU MLP for dense layers
+type DenseMLP struct {
+ GateProj nn.LinearLayer `weight:"mlp.gate_proj"`
+ UpProj nn.LinearLayer `weight:"mlp.up_proj"`
+ DownProj nn.LinearLayer `weight:"mlp.down_proj"`
+}
+
+// Forward applies the SwiGLU MLP
+func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
+ gate := mlx.SiLU(m.GateProj.Forward(x))
+ up := m.UpProj.Forward(x)
+ return m.DownProj.Forward(mlx.Mul(gate, up))
+}
+
+// MoEGate implements the expert gating mechanism
+type MoEGate struct {
+ Gate nn.LinearLayer `weight:"mlp.gate"`
+ EScoreCorrectionBias *mlx.Array `weight:"mlp.gate.e_score_correction_bias,optional"`
+}
+
+// Forward computes expert selection indices and scores
+func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
+ // Compute gate logits through linear layer (handles both quantized and non-quantized)
+ gates := g.Gate.Forward(x)
+
+ // Sigmoid scoring
+ scores := mlx.Sigmoid(gates)
+ origScores := scores
+
+ // Add correction bias if present
+ if g.EScoreCorrectionBias != nil {
+ scores = mlx.Add(scores, g.EScoreCorrectionBias)
+ }
+
+ // Group-wise expert selection (simplified for n_group=1)
+ // Select top-k experts
+ topK := cfg.NumExpertsPerTok
+ negScores := mlx.Neg(scores)
+ inds := mlx.Argpartition(negScores, int(topK)-1, -1)
+
+ shape := inds.Shape()
+ inds = mlx.Slice(inds, []int32{0, 0, 0}, []int32{shape[0], shape[1], topK})
+
+ // Get scores for selected experts
+ scores = mlx.TakeAlongAxis(origScores, inds, -1)
+
+ // Normalize if configured
+ if topK > 1 && cfg.NormTopKProb {
+ sumScores := mlx.Sum(scores, -1, true)
+ scores = mlx.Div(scores, sumScores)
+ }
+
+ // Apply routing scaling factor
+ scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
+
+ return inds, scores
+}
+
+// SwitchMLP implements the MoE expert computation using stacked weights
+// Note: No weight tags - these are populated manually by stacking expert weights
+type SwitchMLP struct {
+ // Dequantized weights (used when GatherQMM not available)
+ GateWeight *mlx.Array
+ UpWeight *mlx.Array
+ DownWeight *mlx.Array
+
+ // Quantized weights (used with GatherQMM for 4/8-bit affine)
+ GateWeightQ, GateScales, GateBiases *mlx.Array
+ UpWeightQ, UpScales, UpBiases *mlx.Array
+ DownWeightQ, DownScales, DownBiases *mlx.Array
+
+ // Quantization bits per projection (supports mixed precision Q4/Q8)
+ GateBits int
+ UpBits int
+ DownBits int
+
+ // Quantization group size per projection (detected from tensor shapes)
+ GateGroupSize int
+ UpGroupSize int
+ DownGroupSize int
+
+ // If true, use GatherQMM with quantized weights
+ UseQuantized bool
+}
+
+// Forward applies the switched expert MLP
+func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
+ shape := x.Shape()
+ B, L := shape[0], shape[1]
+ topK := cfg.NumExpertsPerTok
+
+ // Expand x for expert computation: [B, L, D] -> [B, L, 1, 1, D]
+ xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
+
+ // Flatten for gather_mm: [B*L, 1, 1, D]
+ xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
+
+ // Flatten indices: [B, L, topK] -> [B*L, topK]
+ idxFlat := mlx.Reshape(indices, B*L, topK)
+
+ // Sort for efficient gather (when we have many tokens)
+ doSort := B*L >= 64
+ var invOrder *mlx.Array
+ n := B * L * topK
+
+ if doSort {
+ idxAll := mlx.Flatten(idxFlat)
+ order := mlx.Argsort(idxAll, 0)
+ invOrder = mlx.Argsort(order, 0)
+ // Reorder x based on sorted indices
+ xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
+ idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
+ }
+
+ var gate, up, hidden, down *mlx.Array
+
+ if s.UseQuantized {
+ // Use GatherQMM for quantized weights (faster, keeps weights quantized)
+ // Each projection may have different bits and group sizes (mixed precision: Q4 for gate/up, Q8 for down)
+ gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
+ nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
+ up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
+ nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
+
+ hidden = mlx.Mul(mlx.SiLU(gate), up)
+
+ down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
+ nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
+ } else {
+ // Use GatherMM for dequantized/non-quantized weights
+ gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
+ up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
+
+ hidden = mlx.Mul(mlx.SiLU(gate), up)
+
+ down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
+ }
+
+ // Unsort if we sorted
+ if doSort {
+ down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
+ } else {
+ down = mlx.Squeeze(down, 2)
+ }
+
+ return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
+}
+
+// SharedExperts implements the shared expert MLP
+type SharedExperts struct {
+ GateProj nn.LinearLayer `weight:"mlp.shared_experts.gate_proj"`
+ UpProj nn.LinearLayer `weight:"mlp.shared_experts.up_proj"`
+ DownProj nn.LinearLayer `weight:"mlp.shared_experts.down_proj"`
+}
+
+// Forward applies the shared expert MLP
+func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
+ gate := mlx.SiLU(s.GateProj.Forward(x))
+ up := s.UpProj.Forward(x)
+ return s.DownProj.Forward(mlx.Mul(gate, up))
+}
+
+// MoE implements the full Mixture of Experts layer
+type MoE struct {
+ Gate *MoEGate
+ SwitchMLP *SwitchMLP
+ SharedExperts *SharedExperts
+}
+
+// Forward applies the MoE layer
+func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
+ shape := x.Shape()
+ B, L := shape[0], shape[1]
+
+ // Get expert indices and scores
+ inds, scores := m.Gate.Forward(x, cfg)
+
+ // Apply routed experts
+ expertOut := m.SwitchMLP.Forward(x, inds, cfg)
+
+ // Weight by scores: [B, L, topK, D] * [B, L, topK, 1] -> sum over topK
+ scoresExpanded := mlx.ExpandDims(scores, -1)
+ y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
+
+ // Add shared experts if present
+ if m.SharedExperts != nil {
+ y = mlx.Add(y, m.SharedExperts.Forward(x))
+ }
+
+ return mlx.Reshape(y, B, L, cfg.HiddenSize)
+}
+
+// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
+type DenseBlock struct {
+ Attention *MLAAttention
+ MLP *DenseMLP
+ InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
+ PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
+}
+
+// Forward applies the dense block
+func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ // Pre-norm attention with residual
+ r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
+ h := mlx.Add(x, r)
+
+ // Pre-norm MLP with residual
+ r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
+ return mlx.Add(h, r)
+}
+
+// MoEBlock represents a MoE transformer block
+type MoEBlock struct {
+ Attention *MLAAttention
+ MoE *MoE
+ InputLayerNorm *nn.RMSNorm `weight:"input_layernorm"`
+ PostAttentionLayerNorm *nn.RMSNorm `weight:"post_attention_layernorm"`
+}
+
+// Forward applies the MoE block
+func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ // Pre-norm attention with residual
+ r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
+ h := mlx.Add(x, r)
+
+ // Pre-norm MoE with residual
+ r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
+ return mlx.Add(h, r)
+}
+
+// Block interface for both dense and MoE blocks
+type Block interface {
+ Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
+}
+
+// Model represents the complete GLM4-MoE-Lite model
+type Model struct {
+ EmbedTokens *nn.Embedding `weight:"model.embed_tokens"`
+ Layers []Block `weight:"-"` // Loaded manually due to different block types
+ Norm *nn.RMSNorm `weight:"model.norm"`
+ LMHead nn.LinearLayer `weight:"lm_head"`
+
+ tok *tokenizer.Tokenizer
+ *Config
+}
+
+// computeScale computes the attention scale.
+// Uses the full key head dimension (qkNopeHeadDim + qkRopeHeadDim) to match the Ollama runner.
+func computeScale(cfg *Config) float32 {
+ keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
+ scale := float32(1.0 / math.Sqrt(float64(keyLength)))
+ if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
+ s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
+ scale *= s * s
+ }
+ return scale
+}
+
+// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
+// Currently only 4-bit and 8-bit affine quantization are supported.
+func supportsGatherQMM(mode string, bits int) bool {
+ return mode == "affine" && (bits == 4 || bits == 8)
+}
+
+// ExpertWeight holds a single expert's weight with optional quantization components.
+type ExpertWeight struct {
+ Weight *mlx.Array // Quantized weight (if quantized) or dequantized weight
+ Scales *mlx.Array // Quantization scales (nil if not quantized)
+ Biases *mlx.Array // Quantization biases (nil if not quantized or mode doesn't use biases)
+ Bits int // Quantization bits (4 or 8), 0 if not quantized
+ GroupSize int // Quantization group size, 0 if not quantized
+}
+
+// getQuantParams returns quantization parameters from model metadata.
+// Returns groupSize, bits, and mode for the model's quantization type.
+func getQuantParams(weights safetensors.WeightSource) (groupSize, bits int, mode string) {
+ groupSize, bits, mode = safetensors.QuantizationParams(weights.Quantization())
+ // Use metadata group_size if available (overrides default)
+ if gs := weights.GroupSize(); gs > 0 {
+ groupSize = gs
+ }
+ return groupSize, bits, mode
+}
+
+// loadExpertWeight loads an expert weight.
+// If useQuantized is true and the weight is quantized with a supported mode, returns quantized components.
+// Otherwise dequantizes and returns only the weight.
+func loadExpertWeight(weights safetensors.WeightSource, path string, useQuantized bool, cfg *Config) *ExpertWeight {
+ w, _ := weights.GetTensor(path + ".weight")
+ if w == nil {
+ return nil
+ }
+
+ // Check if this is a quantized weight by looking for scales
+ scalePath := path + ".weight_scale"
+ if weights.HasTensor(scalePath) {
+ scales, _ := weights.GetTensor(scalePath)
+ var qbiases *mlx.Array
+ qbiasPath := path + ".weight_qbias"
+ if weights.HasTensor(qbiasPath) {
+ qbiases, _ = weights.GetTensor(qbiasPath)
+ }
+
+ // Get quantization params from metadata
+ groupSize, bits, mode := getQuantParams(weights)
+
+ // Update config with group size (for GatherQMM calls)
+ if cfg.QuantGroupSize == 0 {
+ cfg.QuantGroupSize = groupSize
+ }
+
+ // If GatherQMM is supported and requested, return quantized components
+ if useQuantized && supportsGatherQMM(mode, bits) {
+ return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
+ }
+
+ // Otherwise dequantize
+ return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
+ }
+
+ return &ExpertWeight{Weight: w}
+}
+
+// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
+// Returns embed_q and unembed_out weights for per-head projections.
+//
+// kv_b_proj.weight shape: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
+// Output:
+// - embed_q: [num_heads, kv_lora_rank, qk_nope_head_dim] - projects q_nope to latent
+// - unembed_out: [num_heads, v_head_dim, kv_lora_rank] - projects latent to output
+func sanitizeMLAWeights(weights safetensors.WeightSource, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
+ path := prefix + ".self_attn.kv_b_proj"
+ w, err := weights.GetTensor(path + ".weight")
+ if err != nil || w == nil {
+ return nil, nil
+ }
+
+ // Check if quantized and dequantize
+ scalePath := path + ".weight_scale"
+ if weights.HasTensor(scalePath) {
+ scales, _ := weights.GetTensor(scalePath)
+ var qbiases *mlx.Array
+ qbiasPath := path + ".weight_qbias"
+ if weights.HasTensor(qbiasPath) {
+ qbiases, _ = weights.GetTensor(qbiasPath)
+ }
+
+ groupSize, bits, mode := getQuantParams(weights)
+ w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
+ }
+
+ // w: [num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank]
+ // Reshape to [num_heads, qk_nope_head_dim + v_head_dim, kv_lora_rank]
+ headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
+ w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
+
+ // Split into wk and wv
+ // wk: [num_heads, qk_nope_head_dim, kv_lora_rank]
+ // wv: [num_heads, v_head_dim, kv_lora_rank]
+ wk := mlx.Slice(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
+ wv := mlx.Slice(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
+
+ // Transform for absorbed MLA:
+ // embed_q: transpose(wk) -> [num_heads, kv_lora_rank, qk_nope_head_dim]
+ // This allows: q_nope @ embed_q.T = q_nope @ wk (absorbed key projection)
+ embedQ := mlx.Transpose(wk, 0, 2, 1)
+
+ // unembed_out: wv stays [num_heads, v_head_dim, kv_lora_rank]
+ // This allows: latent_out @ unembed_out.T = latent_out @ wv.T (absorbed value projection)
+ unembedOut := wv
+
+ return embedQ, unembedOut
+}
+
+// StackedExpertWeights holds stacked weights for all experts.
+type StackedExpertWeights struct {
+ Weight *mlx.Array // Stacked weights [num_experts, out, in] or [num_experts, out, in_packed]
+ Scales *mlx.Array // Stacked scales (nil if not quantized)
+ Biases *mlx.Array // Stacked biases (nil if not quantized)
+ Bits int // Quantization bits (4 or 8), 0 if not quantized
+ GroupSize int // Quantization group size, 0 if not quantized
+}
+
+// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
+func collectAndStackExpertWeights(
+ weights safetensors.WeightSource,
+ prefix string,
+ projName string,
+ numExperts int32,
+ useQuantized bool,
+ cfg *Config,
+) *StackedExpertWeights {
+ var w, s, b []*mlx.Array
+ var bits, groupSize int
+
+ for e := int32(0); e < numExperts; e++ {
+ path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
+ ew := loadExpertWeight(weights, path, useQuantized, cfg)
+ if ew == nil {
+ continue
+ }
+ w = append(w, ew.Weight)
+ if ew.Scales != nil {
+ s = append(s, ew.Scales)
+ }
+ if ew.Biases != nil {
+ b = append(b, ew.Biases)
+ }
+ if e == 0 {
+ bits = ew.Bits
+ groupSize = ew.GroupSize
+ }
+ }
+
+ result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
+ if len(w) > 0 {
+ result.Weight = mlx.Stack(w, 0)
+ if len(s) > 0 {
+ result.Scales = mlx.Stack(s, 0)
+ }
+ if len(b) > 0 {
+ result.Biases = mlx.Stack(b, 0)
+ }
+ }
+ return result
+}
+
+// sanitizeExpertWeights stacks individual expert weights into tensors.
+// If useQuantized is true and weights support GatherQMM, returns quantized components.
+// Otherwise returns dequantized weights with nil scales/biases.
+// Bits and GroupSize are detected per-weight to support mixed-precision (Q4 for gate/up, Q8 for down).
+func sanitizeExpertWeights(weights safetensors.WeightSource, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
+ gate = collectAndStackExpertWeights(weights, prefix, "gate_proj", numExperts, useQuantized, cfg)
+ up = collectAndStackExpertWeights(weights, prefix, "up_proj", numExperts, useQuantized, cfg)
+ down = collectAndStackExpertWeights(weights, prefix, "down_proj", numExperts, useQuantized, cfg)
+ return gate, up, down
+}
+
+// LoadFromManifest loads a GLM4-MoE-Lite model from a manifest (Ollama blob storage).
+func LoadFromManifest(modelManifest *manifest.ModelManifest) (*Model, error) {
+ // Read config from manifest
+ configData, err := modelManifest.ReadConfig("config.json")
+ if err != nil {
+ return nil, fmt.Errorf("load config: %w", err)
+ }
+
+ var cfg Config
+ if err := json.Unmarshal(configData, &cfg); err != nil {
+ return nil, fmt.Errorf("parse config: %w", err)
+ }
+
+ // Compute derived fields
+ cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
+ cfg.Scale = computeScale(&cfg)
+
+ // Load weights from manifest blobs
+ weights, err := manifest.LoadWeightsFromManifest(modelManifest, "")
+ if err != nil {
+ return nil, fmt.Errorf("load weights: %w", err)
+ }
+
+ if err := weights.Load(0); err != nil {
+ return nil, fmt.Errorf("load weight data: %w", err)
+ }
+
+ // Set up quantization parameters (only if model is actually quantized)
+ // Note: QuantGroupSize will be detected dynamically from tensor shapes during weight loading
+ quantization := weights.Quantization()
+ useQuantized := false
+ if quantization != "" {
+ _, cfg.QuantBits, cfg.QuantMode = safetensors.QuantizationParams(quantization)
+ useQuantized = supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
+ }
+
+ // Load tokenizer from manifest with config files for EOS token detection
+ tokData, err := modelManifest.ReadConfig("tokenizer.json")
+ if err != nil {
+ return nil, fmt.Errorf("load tokenizer config: %w", err)
+ }
+
+ // Build tokenizer config with companion files for EOS/BOS token loading
+ tokConfig := &tokenizer.TokenizerConfig{
+ ConfigJSON: configData, // Already loaded above, contains eos_token_id
+ }
+
+ // Try to load generation_config.json if available (preferred source for EOS)
+ if genConfigData, err := modelManifest.ReadConfig("generation_config.json"); err == nil {
+ tokConfig.GenerationConfigJSON = genConfigData
+ }
+
+ // Try to load tokenizer_config.json if available
+ if tokConfigData, err := modelManifest.ReadConfig("tokenizer_config.json"); err == nil {
+ tokConfig.TokenizerConfigJSON = tokConfigData
+ }
+
+ tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
+ if err != nil {
+ return nil, fmt.Errorf("parse tokenizer: %w", err)
+ }
+
+ m := &Model{
+ Layers: make([]Block, cfg.NumHiddenLayers),
+ Config: &cfg,
+ tok: tok,
+ }
+
+ // Load embedding, norm, and lm_head
+ if err := safetensors.LoadModule(m, weights, ""); err != nil {
+ return nil, err
+ }
+
+ // Load layers manually due to different block types
+ for i := int32(0); i < cfg.NumHiddenLayers; i++ {
+ prefix := fmt.Sprintf("model.layers.%d", i)
+
+ // Load attention (same for both block types)
+ attn := &MLAAttention{}
+ if err := safetensors.LoadModule(attn, weights, prefix); err != nil {
+ return nil, fmt.Errorf("layer %d attention: %w", i, err)
+ }
+
+ // Sanitize MLA weights for absorbed attention
+ embedQ, unembedOut := sanitizeMLAWeights(weights, prefix, &cfg)
+ attn.EmbedQ = nn.NewMultiLinear(embedQ)
+ attn.UnembedOut = nn.NewMultiLinear(unembedOut)
+
+ if i < cfg.FirstKDenseReplace {
+ // Dense block
+ block := &DenseBlock{Attention: attn}
+ if err := safetensors.LoadModule(block, weights, prefix); err != nil {
+ return nil, fmt.Errorf("layer %d dense: %w", i, err)
+ }
+ m.Layers[i] = block
+ } else {
+ // MoE block
+ block := &MoEBlock{Attention: attn}
+ if err := safetensors.LoadModule(block, weights, prefix); err != nil {
+ return nil, fmt.Errorf("layer %d moe block: %w", i, err)
+ }
+
+ // Stack expert weights (pass cfg so group sizes can be detected)
+ gate, up, down := sanitizeExpertWeights(weights, prefix, cfg.NRoutedExperts, useQuantized, &cfg)
+
+ switchMLP := &SwitchMLP{UseQuantized: useQuantized}
+ if useQuantized {
+ switchMLP.GateWeightQ = gate.Weight
+ switchMLP.GateScales = gate.Scales
+ switchMLP.GateBiases = gate.Biases
+ switchMLP.GateBits = gate.Bits
+ switchMLP.GateGroupSize = gate.GroupSize
+ switchMLP.UpWeightQ = up.Weight
+ switchMLP.UpScales = up.Scales
+ switchMLP.UpBiases = up.Biases
+ switchMLP.UpBits = up.Bits
+ switchMLP.UpGroupSize = up.GroupSize
+ switchMLP.DownWeightQ = down.Weight
+ switchMLP.DownScales = down.Scales
+ switchMLP.DownBiases = down.Biases
+ switchMLP.DownBits = down.Bits
+ switchMLP.DownGroupSize = down.GroupSize
+ } else {
+ switchMLP.GateWeight = gate.Weight
+ switchMLP.UpWeight = up.Weight
+ switchMLP.DownWeight = down.Weight
+ }
+
+ block.MoE = &MoE{
+ Gate: &MoEGate{},
+ SwitchMLP: switchMLP,
+ }
+
+ // Load gate weights
+ if err := safetensors.LoadModule(block.MoE.Gate, weights, prefix); err != nil {
+ return nil, fmt.Errorf("layer %d gate: %w", i, err)
+ }
+
+ // Load shared experts if present
+ if cfg.NSharedExperts > 0 {
+ block.MoE.SharedExperts = &SharedExperts{}
+ if err := safetensors.LoadModule(block.MoE.SharedExperts, weights, prefix); err != nil {
+ return nil, fmt.Errorf("layer %d shared experts: %w", i, err)
+ }
+ }
+
+ m.Layers[i] = block
+ }
+ }
+
+ mlx.Eval(mlx.Collect(m)...)
+ weights.ReleaseAll()
+
+ return m, nil
+}
+
+// Forward computes the forward pass of the model
+func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
+ B, L := tokens.Shape()[0], tokens.Shape()[1]
+
+ h := m.EmbedTokens.Forward(tokens)
+
+ for i, layer := range m.Layers {
+ var c cache.Cache
+ if caches != nil {
+ c = caches[i]
+ }
+ h = layer.Forward(h, c, B, L, m.Config)
+ }
+
+ h = m.Norm.Forward(h, m.RMSNormEps)
+ return m.LMHead.Forward(h)
+}
+
+// Interface methods
+
+// NumLayers returns the number of transformer layers
+func (m *Model) NumLayers() int { return len(m.Layers) }
+
+// MaxContextLength returns the maximum context length
+func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
+
+// VocabSize returns the vocabulary size
+func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
+
+// Tokenizer returns the model's tokenizer
+func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
+
+// NewCache creates a new KV cache for the model
+func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
+ caches := make([]cache.Cache, len(m.Layers))
+ for i := range caches {
+ caches[i] = cache.NewKVCache()
+ }
+ return caches
+}
+
+// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
+// This follows the GLM-4.7 format with tag for reasoning mode.
+func (m *Model) FormatPrompt(prompt string) string {
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+}
+
+// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
+// When think is true, the prompt ends with to enable reasoning mode.
+// When think is false, the prompt ends with to skip reasoning.
+func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
+ if think {
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+ }
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+}
+
+// NewRenderer returns a new Renderer for formatting multi-turn conversations.
+func (m *Model) NewRenderer() *Renderer {
+ return &Renderer{}
+}
+
+// NewParser returns a new Parser for extracting thinking and tool calls from output.
+func (m *Model) NewParser() *Parser {
+ return &Parser{}
+}
diff --git a/x/imagegen/models/glm4_moe_lite/parser.go b/x/imagegen/models/glm4_moe_lite/parser.go
new file mode 100644
index 00000000000..c81ec5a4043
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/parser.go
@@ -0,0 +1,479 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "log/slog"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/logutil"
+)
+
+type parserState int
+
+const (
+ parserState_LookingForThinkingOpen parserState = iota
+ parserState_ThinkingStartedEatingWhitespace
+ parserState_CollectingThinking
+ parserState_ThinkingDoneEatingWhitespace
+ parserState_CollectingContent
+ parserState_ToolStartedEatingWhitespace
+ parserState_CollectingToolContent
+)
+
+const (
+ thinkingOpenTag = ""
+ thinkingCloseTag = ""
+ toolOpenTag = ""
+ toolCloseTag = ""
+)
+
+// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
+// GLM-4's prompt ends with when thinking is enabled, so the parser
+// must start in CollectingThinking state (the model outputs thinking content directly).
+type Parser struct {
+ state parserState
+ buffer strings.Builder
+ tools []api.Tool
+}
+
+// HasToolSupport returns true as GLM4 supports tool calling.
+func (p *Parser) HasToolSupport() bool {
+ return true
+}
+
+// HasThinkingSupport returns true as GLM4 supports thinking mode.
+func (p *Parser) HasThinkingSupport() bool {
+ return true
+}
+
+// Init initializes the parser with tools and thinking configuration.
+func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+ p.tools = tools
+ // When thinking is enabled (nil or true), the prompt ends with ,
+ // so model output starts directly with thinking content (no opening tag).
+ if thinkValue == nil || thinkValue.Bool() {
+ p.state = parserState_CollectingThinking
+ }
+ return tools
+}
+
+type parserEvent interface {
+ isParserEvent()
+}
+
+type eventContent struct {
+ content string
+}
+
+func (eventContent) isParserEvent() {}
+
+type eventRawToolCall struct {
+ raw string
+}
+
+func (eventRawToolCall) isParserEvent() {}
+
+type eventThinkingContent struct {
+ content string
+}
+
+func (eventThinkingContent) isParserEvent() {}
+
+// Add processes new output text and returns parsed content, thinking, and tool calls.
+func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.buffer.WriteString(s)
+ events := p.parseEvents()
+
+ var toolCalls []api.ToolCall
+ var contentSb strings.Builder
+ var thinkingSb strings.Builder
+
+ for _, event := range events {
+ switch event := event.(type) {
+ case eventRawToolCall:
+ toolCall, err := parseToolCall(event, p.tools)
+ if err != nil {
+ slog.Warn("glm-4 tool call parsing failed", "error", err)
+ return "", "", nil, err
+ }
+ toolCalls = append(toolCalls, toolCall)
+ case eventThinkingContent:
+ thinkingSb.WriteString(event.content)
+ case eventContent:
+ contentSb.WriteString(event.content)
+ }
+ }
+
+ return contentSb.String(), thinkingSb.String(), toolCalls, nil
+}
+
+func (p *Parser) parseEvents() []parserEvent {
+ var all []parserEvent
+
+ keepLooping := true
+ for keepLooping {
+ var events []parserEvent
+ events, keepLooping = p.eat()
+ if len(events) > 0 {
+ all = append(all, events...)
+ }
+ }
+
+ if len(all) > 0 {
+ slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
+ }
+
+ return all
+}
+
+// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
+// and transitions to the next state. Returns (nil, false) if only whitespace remains
+// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
+func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ p.buffer.Reset()
+ if trimmed == "" {
+ return nil, false // Still only whitespace, keep waiting for more input
+ }
+ p.state = nextState
+ p.buffer.WriteString(trimmed)
+ return nil, true // Successfully transitioned
+}
+
+// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
+// the content after (optionally trimmed of leading whitespace), and updates the buffer
+func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
+ split := strings.SplitN(p.buffer.String(), tag, 2)
+ before := split[0]
+ before = strings.TrimRightFunc(before, unicode.IsSpace)
+ after := split[1]
+ if trimAfter {
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ }
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ return before, after
+}
+
+func (p *Parser) eat() ([]parserEvent, bool) {
+ var events []parserEvent
+
+ switch p.state {
+ case parserState_LookingForThinkingOpen:
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ if strings.HasPrefix(trimmed, thinkingOpenTag) {
+ // Found opening tag
+ after := strings.TrimPrefix(trimmed, thinkingOpenTag)
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ if after == "" {
+ p.state = parserState_ThinkingStartedEatingWhitespace
+ } else {
+ p.state = parserState_CollectingThinking
+ }
+ return events, true
+ } else if strings.HasPrefix(thinkingOpenTag, trimmed) {
+ // Partial opening tag seen, keep accumulating
+ return events, false
+ } else if trimmed == "" {
+ // Only whitespace, keep accumulating
+ return events, false
+ } else {
+ // No thinking tag found, skip to content collection
+ p.state = parserState_CollectingContent
+ // Don't trim - we want to keep the original content
+ return events, true
+ }
+
+ case parserState_ThinkingStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
+
+ case parserState_CollectingThinking:
+ acc := p.buffer.String()
+ if strings.Contains(acc, thinkingCloseTag) {
+ thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
+ if len(thinking) > 0 {
+ events = append(events, eventThinkingContent{content: thinking})
+ }
+ if remaining == "" {
+ p.state = parserState_ThinkingDoneEatingWhitespace
+ } else {
+ p.state = parserState_CollectingContent
+ }
+ return events, true
+ } else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
+ // Partial closing tag - withhold it along with any trailing whitespace before it
+ beforePartialTag := acc[:len(acc)-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ } else {
+ // Pure thinking content - withhold trailing whitespace (might precede closing tag)
+ whitespaceLen := trailingWhitespaceLen(acc)
+ ambiguousStart := len(acc) - whitespaceLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case parserState_ThinkingDoneEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
+
+ case parserState_CollectingContent:
+ if strings.Contains(p.buffer.String(), toolOpenTag) {
+ before, after := p.splitAtTag(toolOpenTag, true)
+ if len(before) > 0 {
+ events = append(events, eventContent{content: before})
+ }
+ if after == "" {
+ p.state = parserState_ToolStartedEatingWhitespace
+ } else {
+ p.state = parserState_CollectingToolContent
+ }
+ return events, true
+ } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
+ beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := p.buffer.String()[:ambiguousStart]
+ ambiguous := p.buffer.String()[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventContent{content: unambiguous})
+ }
+ return events, false
+ } else {
+ whitespaceLen := trailingWhitespaceLen(p.buffer.String())
+ ambiguousStart := len(p.buffer.String()) - whitespaceLen
+
+ unambiguous := p.buffer.String()[:ambiguousStart]
+ ambiguous := p.buffer.String()[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case parserState_ToolStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
+
+ case parserState_CollectingToolContent:
+ acc := p.buffer.String()
+ if strings.Contains(acc, toolCloseTag) {
+ toolContent, _ := p.splitAtTag(toolCloseTag, true)
+ if len(toolContent) == 0 {
+ slog.Warn("glm4 tool call closing tag found but no content before it")
+ }
+ events = append(events, eventRawToolCall{raw: toolContent})
+ p.state = parserState_CollectingContent
+ return events, true
+ } else {
+ // Keep accumulating - tool calls are not streamed
+ // We just wait for the closing tag
+ return events, false
+ }
+
+ default:
+ panic("unreachable")
+ }
+}
+
+// overlap returns the length of the overlap between the end of s and the start of tag.
+func overlap(s, tag string) int {
+ for i := 1; i <= len(tag) && i <= len(s); i++ {
+ if strings.HasSuffix(s, tag[:i]) {
+ return i
+ }
+ }
+ return 0
+}
+
+// trailingWhitespaceLen returns the length of trailing whitespace in s.
+func trailingWhitespaceLen(s string) int {
+ trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
+ return len(s) - len(trimmed)
+}
+
+// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
+type ToolCallXML struct {
+ XMLName xml.Name `xml:"tool_call"`
+ Content string `xml:",chardata"` // Function name (text nodes between tags)
+ Keys []string `xml:"arg_key"` // All arg_key elements in document order
+ Values []string `xml:"arg_value"` // All arg_value elements in document order
+}
+
+// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
+func escapeContent(s string) string {
+ var result strings.Builder
+ inTag := false
+
+ for i := range len(s) {
+ ch := s[i]
+
+ if ch == '<' {
+ // Check if this is a known tag
+ if strings.HasPrefix(s[i:], "") ||
+ strings.HasPrefix(s[i:], "") ||
+ strings.HasPrefix(s[i:], "") ||
+ strings.HasPrefix(s[i:], "") {
+ inTag = true
+ }
+ }
+
+ if inTag {
+ result.WriteByte(ch)
+ if ch == '>' {
+ inTag = false
+ }
+ } else {
+ // Escape special characters in text content
+ switch ch {
+ case '&':
+ result.WriteString("&")
+ case '<':
+ result.WriteString("<")
+ case '>':
+ result.WriteString(">")
+ default:
+ result.WriteByte(ch)
+ }
+ }
+ }
+
+ return result.String()
+}
+
+func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+ // Escape any unescaped entities in text content
+ escaped := escapeContent(raw.raw)
+
+ // Wrap the content in a root element to make it valid XML
+ xmlString := "" + escaped + ""
+
+ // Parse XML into struct
+ var parsed ToolCallXML
+ if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
+ return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
+ }
+
+ // Extract and trim function name
+ functionName := strings.TrimSpace(parsed.Content)
+ if functionName == "" {
+ return api.ToolCall{}, fmt.Errorf("empty function name")
+ }
+
+ // Verify keys and values are paired correctly
+ if len(parsed.Keys) != len(parsed.Values) {
+ return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
+ }
+
+ // Find the matching tool to get parameter types
+ var matchedTool *api.Tool
+ for i := range tools {
+ if tools[i].Function.Name == functionName {
+ matchedTool = &tools[i]
+ break
+ }
+ }
+
+ // Build arguments map by pairing keys and values
+ toolCall := api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: functionName,
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ }
+
+ for i := range parsed.Keys {
+ key := strings.TrimSpace(parsed.Keys[i])
+ value := parsed.Values[i] // Don't trim here - parseValue handles it
+
+ // Look up parameter type
+ var paramType api.PropertyType
+ if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+ if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
+ // Handle anyOf by collecting all types from the union
+ if len(prop.AnyOf) > 0 {
+ for _, anyOfProp := range prop.AnyOf {
+ paramType = append(paramType, anyOfProp.Type...)
+ }
+ } else {
+ paramType = prop.Type
+ }
+ }
+ }
+
+ // Parse value with type coercion
+ toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
+ }
+
+ return toolCall, nil
+}
+
+// parseValue parses a string value and coerces it to the appropriate type based on paramType.
+func parseValue(value string, paramType api.PropertyType) any {
+ value = strings.TrimSpace(value)
+
+ // If no type specified, return as string
+ if len(paramType) == 0 {
+ return value
+ }
+
+ // Try to parse based on specified types
+ for _, t := range paramType {
+ switch t {
+ case "boolean":
+ if value == "true" {
+ return true
+ }
+ if value == "false" {
+ return false
+ }
+ case "integer":
+ var i int64
+ if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
+ return i
+ }
+ case "number":
+ var f float64
+ if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
+ return f
+ }
+ case "array", "object":
+ // Try to parse as JSON
+ var result any
+ if err := json.Unmarshal([]byte(value), &result); err == nil {
+ return result
+ }
+ }
+ }
+
+ // Default to string
+ return value
+}
diff --git a/x/imagegen/models/glm4_moe_lite/parser_test.go b/x/imagegen/models/glm4_moe_lite/parser_test.go
new file mode 100644
index 00000000000..0ce3827098b
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/parser_test.go
@@ -0,0 +1,192 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestParserThinking(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ thinkEnabled bool
+ wantContent string
+ wantThinking string
+ wantToolCalls int
+ }{
+ {
+ name: "thinking enabled - simple thinking then content",
+ input: "Let me think about this...Here is my answer.",
+ thinkEnabled: true,
+ wantThinking: "Let me think about this...",
+ wantContent: "Here is my answer.",
+ },
+ {
+ name: "thinking enabled - only thinking",
+ input: "I need to consider multiple factors...",
+ thinkEnabled: true,
+ wantThinking: "I need to consider multiple factors...",
+ wantContent: "",
+ },
+ {
+ name: "thinking disabled - direct content",
+ input: "Here is my direct answer.",
+ thinkEnabled: false,
+ wantThinking: "",
+ wantContent: "Here is my direct answer.",
+ },
+ {
+ name: "thinking with tool call",
+ input: "Let me search for that...I'll use a tool.searchquerytest",
+ thinkEnabled: true,
+ wantThinking: "Let me search for that...",
+ wantContent: "I'll use a tool.",
+ wantToolCalls: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ p := &Parser{}
+
+ var thinkValue *api.ThinkValue
+ if tt.thinkEnabled {
+ thinkValue = &api.ThinkValue{Value: true}
+ } else {
+ thinkValue = &api.ThinkValue{Value: false}
+ }
+
+ // Define tools for tool call tests
+ props := api.NewToolPropertiesMap()
+ props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "search",
+ Parameters: api.ToolFunctionParameters{
+ Properties: props,
+ },
+ },
+ },
+ }
+
+ p.Init(tools, nil, thinkValue)
+
+ content, thinking, calls, err := p.Add(tt.input, true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if thinking != tt.wantThinking {
+ t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
+ }
+ if content != tt.wantContent {
+ t.Errorf("content = %q, want %q", content, tt.wantContent)
+ }
+ if len(calls) != tt.wantToolCalls {
+ t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
+ }
+ })
+ }
+}
+
+func TestParserToolCall(t *testing.T) {
+ p := &Parser{}
+
+ props := api.NewToolPropertiesMap()
+ props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
+ props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Parameters: api.ToolFunctionParameters{
+ Properties: props,
+ },
+ },
+ },
+ }
+
+ // Initialize with thinking disabled
+ tv := &api.ThinkValue{Value: false}
+ p.Init(tools, nil, tv)
+
+ input := "get_weatherlocationSan Franciscounitcelsius"
+
+ _, _, calls, err := p.Add(input, true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(calls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
+ }
+
+ call := calls[0]
+ if call.Function.Name != "get_weather" {
+ t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
+ }
+
+ location, ok := call.Function.Arguments.Get("location")
+ if !ok || location != "San Francisco" {
+ t.Errorf("location = %v, want %q", location, "San Francisco")
+ }
+
+ unit, ok := call.Function.Arguments.Get("unit")
+ if !ok || unit != "celsius" {
+ t.Errorf("unit = %v, want %q", unit, "celsius")
+ }
+}
+
+func TestOverlap(t *testing.T) {
+ tests := []struct {
+ s string
+ tag string
+ want int
+ }{
+ {"hello<", "", 1},
+ {"hello", "", 2},
+ {"hello", 3},
+ {"hello", 4},
+ {"hello", 5},
+ {"hello", 6},
+ {"hello", 7},
+ {"hello", "", 8}, // Complete tag at end returns full length
+ {"hello", "", 0},
+ {"", "", 0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
+ got := overlap(tt.s, tt.tag)
+ if got != tt.want {
+ t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestTrailingWhitespaceLen(t *testing.T) {
+ tests := []struct {
+ s string
+ want int
+ }{
+ {"hello ", 3},
+ {"hello\n\t ", 3},
+ {"hello", 0},
+ {"", 0},
+ {" ", 3},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.s, func(t *testing.T) {
+ got := trailingWhitespaceLen(tt.s)
+ if got != tt.want {
+ t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/x/imagegen/models/glm4_moe_lite/render.go b/x/imagegen/models/glm4_moe_lite/render.go
new file mode 100644
index 00000000000..4998604bf39
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/render.go
@@ -0,0 +1,175 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+// Renderer renders messages for GLM4-MoE-Lite models.
+//
+// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
+//
+// 1. INTERLEAVED THINKING
+// The model thinks between tool calls and after receiving tool results.
+// This enables complex step-by-step reasoning: interpreting each tool output
+// before deciding what to do next. Thinking blocks are preserved and returned
+// with tool results to maintain reasoning continuity.
+//
+// 2. PRESERVED THINKING
+// The model retains reasoning content from previous assistant turns in context.
+// This preserves reasoning continuity across multi-turn conversations. The
+// upstream API has a "clear_thinking" parameter to control this:
+// - clear_thinking=true: clears reasoning from previous turns (outputs )
+// - clear_thinking=false: preserves ... blocks from previous turns
+//
+// 3. TURN-LEVEL THINKING
+// Controls whether the model should reason on each turn. The upstream API
+// uses "enable_thinking" parameter:
+// - enable_thinking=true: outputs to start reasoning
+// - enable_thinking=false: outputs to skip reasoning
+//
+// OLLAMA DEFAULTS:
+// - Thinking is ENABLED by default (thinkValue=nil or true outputs )
+// - Thinking is PRESERVED by default (reasoning content from previous turns is always
+// included in ... blocks, equivalent to clear_thinking=false)
+// - Users can disable thinking per-turn via thinkValue=false
+type Renderer struct{}
+
+// Render renders messages into the GLM4 chat format.
+func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
+ var sb strings.Builder
+
+ sb.WriteString("[gMASK]")
+
+ if len(tools) > 0 {
+ sb.WriteString("<|system|>\n")
+ sb.WriteString("# Tools\n\n")
+ sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
+ sb.WriteString("You are provided with function signatures within XML tags:\n")
+ sb.WriteString("\n")
+ for _, tool := range tools {
+ d, _ := json.Marshal(tool)
+ sb.WriteString(formatToolJSON(d))
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\n\n")
+ sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
+ sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...")
+ }
+
+ think := true
+ if thinkValue != nil && !thinkValue.Bool() {
+ think = false
+ }
+
+ for i, message := range messages {
+ switch message.Role {
+ case "user":
+ sb.WriteString("<|user|>")
+ sb.WriteString(message.Content)
+ case "assistant":
+ sb.WriteString("<|assistant|>")
+ if message.Thinking != "" {
+ sb.WriteString("" + message.Thinking + "")
+ } else {
+ sb.WriteString("")
+ }
+ if message.Content != "" {
+ sb.WriteString(message.Content)
+ }
+ if len(message.ToolCalls) > 0 {
+ for _, toolCall := range message.ToolCalls {
+ sb.WriteString("" + toolCall.Function.Name)
+ sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
+ sb.WriteString("")
+ }
+ }
+ case "tool":
+ if i == 0 || messages[i-1].Role != "tool" {
+ sb.WriteString("<|observation|>")
+ }
+ sb.WriteString("")
+ sb.WriteString(message.Content)
+ sb.WriteString("")
+ case "system":
+ sb.WriteString("<|system|>")
+ sb.WriteString(message.Content)
+ }
+ }
+
+ sb.WriteString("<|assistant|>")
+ if think {
+ sb.WriteString("")
+ } else {
+ sb.WriteString("")
+ }
+
+ return sb.String(), nil
+}
+
+// renderToolArguments converts tool call arguments to GLM4 XML format.
+func renderToolArguments(args api.ToolCallFunctionArguments) string {
+ var sb strings.Builder
+ for key, value := range args.All() {
+ sb.WriteString("" + key + "")
+ var valueStr string
+ if str, ok := value.(string); ok {
+ valueStr = str
+ } else {
+ jsonBytes, err := json.Marshal(value)
+ if err != nil {
+ valueStr = fmt.Sprintf("%v", value)
+ } else {
+ valueStr = string(jsonBytes)
+ }
+ }
+
+ sb.WriteString("" + valueStr + "")
+ }
+
+ return sb.String()
+}
+
+// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
+func formatToolJSON(raw []byte) string {
+ var sb strings.Builder
+ sb.Grow(len(raw) + len(raw)/10)
+
+ inString := false
+ escaped := false
+ for i := range raw {
+ ch := raw[i]
+ sb.WriteByte(ch)
+
+ if inString {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == '"' {
+ inString = false
+ }
+ continue
+ }
+
+ if ch == '"' {
+ inString = true
+ continue
+ }
+
+ if ch == ':' || ch == ',' {
+ sb.WriteByte(' ')
+ }
+ }
+
+ return sb.String()
+}
diff --git a/x/imagegen/models/glm4_moe_lite/render_test.go b/x/imagegen/models/glm4_moe_lite/render_test.go
new file mode 100644
index 00000000000..f0d576bec85
--- /dev/null
+++ b/x/imagegen/models/glm4_moe_lite/render_test.go
@@ -0,0 +1,205 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestRendererSimple(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "Hello"},
+ }
+
+ // Thinking enabled (default)
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := "[gMASK]<|user|>Hello<|assistant|>"
+ if result != expected {
+ t.Errorf("result = %q, want %q", result, expected)
+ }
+}
+
+func TestRendererThinkingDisabled(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "Hello"},
+ }
+
+ tv := &api.ThinkValue{Value: false}
+
+ result, err := r.Render(messages, nil, tv)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := "[gMASK]<|user|>Hello<|assistant|>"
+ if result != expected {
+ t.Errorf("result = %q, want %q", result, expected)
+ }
+}
+
+func TestRendererMultiTurn(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "What is 2+2?"},
+ {Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
+ {Role: "user", Content: "And 3+3?"},
+ }
+
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Check key parts
+ if !strings.Contains(result, "[gMASK]") {
+ t.Error("missing [gMASK] prefix")
+ }
+ if !strings.Contains(result, "<|user|>What is 2+2?") {
+ t.Error("missing first user message")
+ }
+ if !strings.Contains(result, "<|assistant|>Let me calculate: 2+2=44") {
+ t.Error("missing assistant message with thinking")
+ }
+ if !strings.Contains(result, "<|user|>And 3+3?") {
+ t.Error("missing second user message")
+ }
+ if !strings.HasSuffix(result, "<|assistant|>") {
+ t.Errorf("should end with <|assistant|>, got suffix: %q", result[len(result)-30:])
+ }
+}
+
+func TestRendererWithSystem(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Hello"},
+ }
+
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
+ t.Error("missing system message")
+ }
+}
+
+func TestRendererWithTools(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ }
+
+ props := api.NewToolPropertiesMap()
+ props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get the weather for a location",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: props,
+ Required: []string{"location"},
+ },
+ },
+ },
+ }
+
+ result, err := r.Render(messages, tools, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Check for tool system prompt
+ if !strings.Contains(result, "<|system|>") {
+ t.Error("missing system tag for tools")
+ }
+ if !strings.Contains(result, "# Tools") {
+ t.Error("missing tools header")
+ }
+ if !strings.Contains(result, "") {
+ t.Error("missing tools tag")
+ }
+ if !strings.Contains(result, "get_weather") {
+ t.Error("missing tool name")
+ }
+ if !strings.Contains(result, "") {
+ t.Error("missing closing tools tag")
+ }
+}
+
+func TestRendererWithToolCalls(t *testing.T) {
+ r := &Renderer{}
+
+ args := api.NewToolCallFunctionArguments()
+ args.Set("location", "San Francisco")
+
+ messages := []api.Message{
+ {Role: "user", Content: "What's the weather in SF?"},
+ {
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: args,
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "Sunny, 72F"},
+ }
+
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if !strings.Contains(result, "get_weather") {
+ t.Error("missing tool call")
+ }
+ if !strings.Contains(result, "location") {
+ t.Error("missing arg_key")
+ }
+ if !strings.Contains(result, "San Francisco") {
+ t.Error("missing arg_value")
+ }
+ if !strings.Contains(result, "") {
+ t.Error("missing tool call closing tag")
+ }
+ if !strings.Contains(result, "<|observation|>") {
+ t.Error("missing observation tag")
+ }
+ if !strings.Contains(result, "Sunny, 72F") {
+ t.Error("missing tool response")
+ }
+}
+
+func TestFormatToolJSON(t *testing.T) {
+ input := []byte(`{"name":"test","value":123}`)
+ result := formatToolJSON(input)
+
+ // Should add spaces after : and ,
+ if !strings.Contains(result, ": ") {
+ t.Error("should add space after colon")
+ }
+ if !strings.Contains(result, ", ") {
+ t.Error("should add space after comma")
+ }
+}
diff --git a/x/imagegen/models/qwen3/text_encoder.go b/x/imagegen/models/qwen3/text_encoder.go
index 59d66ca4997..de32bd347f3 100644
--- a/x/imagegen/models/qwen3/text_encoder.go
+++ b/x/imagegen/models/qwen3/text_encoder.go
@@ -7,7 +7,7 @@ import (
"fmt"
"math"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -181,19 +181,19 @@ type TextEncoder struct {
}
// Load loads the Qwen3 text encoder from ollama blob storage.
-func (m *TextEncoder) Load(manifest *imagegen.ModelManifest, configPath string) error {
+func (m *TextEncoder) Load(modelManifest *manifest.ModelManifest, configPath string) error {
fmt.Print(" Loading text encoder... ")
// Load config from blob
var cfg Config
- if err := manifest.ReadConfigJSON(configPath, &cfg); err != nil {
+ if err := modelManifest.ReadConfigJSON(configPath, &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
m.Layers = make([]*Block, cfg.NumHiddenLayers)
// Load weights from tensor blobs
- weights, err := imagegen.LoadWeightsFromManifest(manifest, "text_encoder")
+ weights, err := manifest.LoadWeightsFromManifest(modelManifest, "text_encoder")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
diff --git a/x/imagegen/models/qwen_image/pipeline_test.go b/x/imagegen/models/qwen_image/pipeline_test.go
deleted file mode 100644
index 4a0ad7135cd..00000000000
--- a/x/imagegen/models/qwen_image/pipeline_test.go
+++ /dev/null
@@ -1,87 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "fmt"
- "os"
- "path/filepath"
- "runtime"
- "testing"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
-)
-
-// TestMain initializes MLX before running tests.
-// If MLX libraries are not available, tests are skipped.
-func TestMain(m *testing.M) {
- // Change to repo root so ./build/lib/ollama/ path works
- _, thisFile, _, _ := runtime.Caller(0)
- repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
- if err := os.Chdir(repoRoot); err != nil {
- fmt.Printf("Failed to change to repo root: %v\n", err)
- os.Exit(1)
- }
-
- if err := mlx.InitMLX(); err != nil {
- fmt.Printf("Skipping qwen_image tests: %v\n", err)
- os.Exit(0)
- }
- os.Exit(m.Run())
-}
-
-// TestPipelineOutput runs the full pipeline (integration test).
-// Skips if model weights not found. Requires ~50GB VRAM.
-func TestPipelineOutput(t *testing.T) {
- modelPath := "../../../weights/Qwen-Image-2512"
- if _, err := os.Stat(modelPath); os.IsNotExist(err) {
- t.Skip("Skipping: model weights not found at " + modelPath)
- }
-
- // Load model
- pm, err := LoadPersistent(modelPath)
- if err != nil {
- t.Skipf("Skipping: failed to load model: %v", err)
- }
-
- // Run 2-step pipeline (minimum for stable scheduler)
- cfg := &GenerateConfig{
- Prompt: "a cat",
- Width: 256,
- Height: 256,
- Steps: 2,
- Seed: 42,
- }
-
- output, err := pm.GenerateFromConfig(cfg)
- if err != nil {
- t.Fatalf("Pipeline failed: %v", err)
- }
- mlx.Eval(output)
-
- // Verify output shape [1, C, H, W]
- shape := output.Shape()
- if len(shape) != 4 {
- t.Errorf("Expected 4D output, got %v", shape)
- }
- if shape[0] != 1 || shape[1] != 3 || shape[2] != cfg.Height || shape[3] != cfg.Width {
- t.Errorf("Shape mismatch: got %v, expected [1, 3, %d, %d]", shape, cfg.Height, cfg.Width)
- }
-
- // Verify values in expected range [0, 1]
- data := output.Data()
- minVal, maxVal := float32(1.0), float32(0.0)
- for _, v := range data {
- if v < minVal {
- minVal = v
- }
- if v > maxVal {
- maxVal = v
- }
- }
- t.Logf("Output range: [%.4f, %.4f]", minVal, maxVal)
-
- if minVal < -0.1 || maxVal > 1.1 {
- t.Errorf("Output values out of range: [%.4f, %.4f]", minVal, maxVal)
- }
-}
diff --git a/x/imagegen/models/qwen_image/qwen25vl.go b/x/imagegen/models/qwen_image/qwen25vl.go
deleted file mode 100644
index af519ee7dfb..00000000000
--- a/x/imagegen/models/qwen_image/qwen25vl.go
+++ /dev/null
@@ -1,1802 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "errors"
- "fmt"
- "math"
- "path/filepath"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/safetensors"
- "github.com/ollama/ollama/x/imagegen/tokenizer"
-)
-
-// Qwen25VLConfig holds Qwen2.5-VL configuration
-type Qwen25VLConfig struct {
- // Text model config
- HiddenSize int32 `json:"hidden_size"` // 3584
- NumHiddenLayers int32 `json:"num_hidden_layers"` // 28
- IntermediateSize int32 `json:"intermediate_size"` // 18944
- NumAttentionHeads int32 `json:"num_attention_heads"` // 28
- NumKeyValueHeads int32 `json:"num_key_value_heads"` // 4
- VocabSize int32 `json:"vocab_size"` // 152064
- RMSNormEps float32 `json:"rms_norm_eps"` // 1e-6
- RopeTheta float32 `json:"rope_theta"` // 1000000
- HeadDim int32 // Calculated: HiddenSize / NumAttentionHeads
- MRoPESection []int32 // [16, 24, 24] for temporal, height, width
-
- // Vision config
- VisionHiddenSize int32 `json:"vision_hidden_size"` // 1280
- VisionNumLayers int32 `json:"vision_num_layers"` // 32
- VisionNumHeads int32 `json:"vision_num_heads"` // 16
- VisionIntermSize int32 `json:"vision_intermediate"` // 3420
- VisionPatchSize int32 `json:"vision_patch_size"` // 14
- VisionOutHiddenSize int32 `json:"vision_out_hidden"` // 3584
- VisionSpatialMerge int32 `json:"vision_spatial_merge"` // 2
- VisionWindowSize int32 `json:"vision_window_size"` // 112
- VisionFullAttIdx []int32 // [7, 15, 23, 31]
-
- // Special tokens
- ImageTokenID int32 // 151655
- VisionStartTokenID int32 // 151652
- VisionEndTokenID int32 // 151653
-}
-
-// defaultQwen25VLConfig returns default config
-func defaultQwen25VLConfig() *Qwen25VLConfig {
- cfg := &Qwen25VLConfig{
- // Text
- HiddenSize: 3584,
- NumHiddenLayers: 28,
- IntermediateSize: 18944,
- NumAttentionHeads: 28,
- NumKeyValueHeads: 4,
- VocabSize: 152064,
- RMSNormEps: 1e-6,
- RopeTheta: 1000000,
- MRoPESection: []int32{16, 24, 24},
-
- // Vision
- VisionHiddenSize: 1280,
- VisionNumLayers: 32,
- VisionNumHeads: 16,
- VisionIntermSize: 3420,
- VisionPatchSize: 14,
- VisionOutHiddenSize: 3584,
- VisionSpatialMerge: 2,
- VisionWindowSize: 112,
- VisionFullAttIdx: []int32{7, 15, 23, 31},
-
- // Special tokens
- ImageTokenID: 151655,
- VisionStartTokenID: 151652,
- VisionEndTokenID: 153653,
- }
- cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
- return cfg
-}
-
-// Qwen25VL is the Qwen2.5-VL vision-language encoder
-type Qwen25VL struct {
- Config *Qwen25VLConfig
-
- // Text model
- Embedding *mlx.Array
- Blocks []*VLTextBlock
- FinalNorm *mlx.Array
-
- // Vision tower (optional - nil for text-only models)
- VisionPatchEmbed *VisionPatchEmbed
- VisionBlocks []*VisionBlock
- VisionMerger *VisionMerger
- HasVision bool // True if vision tower is loaded
-}
-
-// LoadTextOnly loads only the text encoder components (skips vision tower)
-// Use this for text-to-image generation where vision components are not needed
-func (m *Qwen25VL) LoadTextOnly(path string) error {
- return m.load(path, false)
-}
-
-// Load loads the vision-language encoder from a directory
-// Vision components are loaded if weights exist
-func (m *Qwen25VL) Load(path string) error {
- return m.load(path, true)
-}
-
-// load is the internal loading function
-func (m *Qwen25VL) load(path string, loadVision bool) error {
- fmt.Println("Loading Qwen2.5-VL encoder...")
-
- cfg := defaultQwen25VLConfig()
- m.Config = cfg
-
- weights, err := safetensors.LoadModelWeights(path)
- if err != nil {
- return fmt.Errorf("weights: %w", err)
- }
-
- // Bulk load all weights as bf16
- fmt.Print(" Loading weights as bf16... ")
- if err := weights.Load(mlx.DtypeBFloat16); err != nil {
- return fmt.Errorf("failed to load weights: %w", err)
- }
- fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
-
- // Load text embedding
- fmt.Print(" Loading text embeddings... ")
- embedding, err := weights.Get("model.embed_tokens.weight")
- if err != nil {
- return err
- }
- m.Embedding = embedding
- fmt.Printf("✓ [%v]\n", embedding.Shape())
-
- // Load text blocks
- m.Blocks = make([]*VLTextBlock, cfg.NumHiddenLayers)
- for i := int32(0); i < cfg.NumHiddenLayers; i++ {
- fmt.Printf("\r Loading text blocks... %d/%d", i+1, cfg.NumHiddenLayers)
- block, err := newVLTextBlock(weights, int(i), cfg)
- if err != nil {
- return fmt.Errorf("failed to load text block %d: %w", i, err)
- }
- m.Blocks[i] = block
- }
- fmt.Printf("\r Loading text blocks... ✓ [%d blocks] \n", cfg.NumHiddenLayers)
-
- // Load final norm
- fmt.Print(" Loading final norm... ")
- finalNorm, err := weights.Get("model.norm.weight")
- if err != nil {
- return err
- }
- m.FinalNorm = finalNorm
- fmt.Println("✓")
-
- // Try to load vision tower (optional)
- m.HasVision = false
- if loadVision {
- if _, err := weights.Get("visual.patch_embed.proj.weight"); err == nil {
- fmt.Print(" Loading vision patch embed... ")
- m.VisionPatchEmbed, err = newVisionPatchEmbed(weights, cfg)
- if err != nil {
- return fmt.Errorf("vision patch embed: %w", err)
- }
- fmt.Println("✓")
-
- m.VisionBlocks = make([]*VisionBlock, cfg.VisionNumLayers)
- for i := int32(0); i < cfg.VisionNumLayers; i++ {
- fmt.Printf("\r Loading vision blocks... %d/%d", i+1, cfg.VisionNumLayers)
- block, err := newVisionBlock(weights, int(i), cfg)
- if err != nil {
- return fmt.Errorf("failed to load vision block %d: %w", i, err)
- }
- m.VisionBlocks[i] = block
- }
- fmt.Printf("\r Loading vision blocks... ✓ [%d blocks] \n", cfg.VisionNumLayers)
-
- fmt.Print(" Loading vision merger... ")
- m.VisionMerger, err = newVisionMerger(weights, cfg)
- if err != nil {
- return fmt.Errorf("vision merger: %w", err)
- }
- fmt.Println("✓")
-
- m.HasVision = true
- } else {
- fmt.Println(" (No vision tower - text-only mode)")
- }
- } else {
- fmt.Println(" (Skipping vision tower)")
- }
-
- weights.ReleaseAll()
- return nil
-}
-
-// EncodePrompt encodes a text prompt for image generation (text-only mode)
-// Uses the Qwen-Image template and drops the first 34 tokens (system prefix)
-func (m *Qwen25VL) EncodePrompt(tok *tokenizer.Tokenizer, prompt string) *mlx.Array {
- cfg := m.Config
-
- // Template from Python: prompt_template_encode (for image generation)
- template := "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n%s<|im_end|>\n<|im_start|>assistant\n"
- formattedPrompt := fmt.Sprintf(template, prompt)
-
- // Tokenize
- tokens := tok.Encode(formattedPrompt, false)
-
- // Create token array
- seqLen := int32(len(tokens))
- tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen})
-
- // Get text embeddings
- textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr)
-
- // Compute RoPE
- cossin := m.computeTextRoPE(seqLen, 1)
-
- // Forward through ALL text blocks
- x := textEmbed
- for _, block := range m.Blocks {
- x = block.Forward(x, cossin)
- }
-
- // Apply final norm
- x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)
-
- // Drop first 34 tokens (system prefix)
- // prompt_template_encode_start_idx = 34
- dropIdx := int32(34)
- if x.Shape()[1] > dropIdx {
- x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
- }
-
- return x
-}
-
-// EncodePromptWithImage encodes a text prompt with an image
-// Returns: embeddings [B, L, hidden_size], mask [B, L], error
-func (m *Qwen25VL) EncodePromptWithImage(tok *tokenizer.Tokenizer, prompt string, image *mlx.Array) (*mlx.Array, *mlx.Array, error) {
- if !m.HasVision {
- return nil, nil, errors.New("EncodePromptWithImage called on text-only model")
- }
-
- cfg := m.Config
-
- // Template from Python diffusers pipeline: prompt_template_encode
- // Python's _get_qwen_prompt_embeds adds "Picture 1: " before vision tokens
- template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\nPicture 1: <|vision_start|><|image_pad|><|vision_end|>%s<|im_end|>\n<|im_start|>assistant\n"
- formattedPrompt := fmt.Sprintf(template, prompt)
-
- // Tokenize
- tokens := tok.Encode(formattedPrompt, false)
-
- // Process vision if image provided
- var visionEmbeddings *mlx.Array
- var numImageTokens int32
- var visionH, visionW int32 // Grid dims in patches (before spatial merge)
- if image != nil {
- visionEmbeddings = m.encodeVision(image)
- numImageTokens = visionEmbeddings.Shape()[1]
- // Get original grid dimensions from image shape
- imgShape := image.Shape()
- visionH = imgShape[2] / cfg.VisionPatchSize // Height in patches
- visionW = imgShape[3] / cfg.VisionPatchSize // Width in patches
- }
-
- // Find image token position and expand
- expandedTokens := make([]int32, 0, len(tokens)+int(numImageTokens))
- imageTokenPos := int32(-1)
- textAfterCount := int32(0)
- for i, t := range tokens {
- if t == cfg.ImageTokenID {
- imageTokenPos = int32(len(expandedTokens))
- // Insert placeholder tokens for image
- for j := int32(0); j < numImageTokens; j++ {
- expandedTokens = append(expandedTokens, cfg.ImageTokenID)
- }
- // Count remaining tokens after image
- textAfterCount = int32(len(tokens) - i - 1)
- } else {
- expandedTokens = append(expandedTokens, t)
- }
- }
-
- // Create token array
- seqLen := int32(len(expandedTokens))
- tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen})
-
- // Get text embeddings
- textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]
-
- // Replace image token embeddings with vision embeddings
- if visionEmbeddings != nil && imageTokenPos >= 0 {
- // Split, replace, concat
- before := mlx.Slice(textEmbed, []int32{0, 0, 0}, []int32{1, imageTokenPos, cfg.HiddenSize})
- after := mlx.Slice(textEmbed, []int32{0, imageTokenPos + numImageTokens, 0}, []int32{1, seqLen, cfg.HiddenSize})
- textEmbed = mlx.Concatenate([]*mlx.Array{before, visionEmbeddings, after}, 1)
- }
-
- // Compute RoPE - use multimodal RoPE when image is present
- var cossin [2]*mlx.Array
- if image != nil && imageTokenPos >= 0 {
- cossin = m.ComputeMultimodalRoPE(imageTokenPos, visionH, visionW, textAfterCount, cfg.VisionSpatialMerge)
- } else {
- cossin = m.computeTextRoPE(seqLen, 1)
- }
-
- // Forward through ALL text blocks
- // Python uses hidden_states[-1] (LAST layer output, not second-to-last!)
- x := textEmbed
- for _, block := range m.Blocks {
- x = block.Forward(x, cossin)
- }
-
- // Apply final norm (Python DOES apply this for the output)
- x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)
-
- // Drop first N tokens (system prefix)
- // prompt_template_encode_start_idx = 64
- dropIdx := int32(64)
- if x.Shape()[1] > dropIdx {
- x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
- }
-
- // Create attention mask (all ones for now)
- mask := mlx.Ones(1, x.Shape()[1])
-
- return x, mask, nil
-}
-
-// EncodeVision encodes an image through the vision tower (exported for testing)
-// image: [B, C, H, W] normalized image tensor
-// Returns: [B, num_tokens, hidden_size] vision embeddings
-func (m *Qwen25VL) EncodeVision(image *mlx.Array) *mlx.Array {
- return m.encodeVision(image)
-}
-
-// VisionRegion describes where vision embeddings are inserted in the sequence
-type VisionRegion struct {
- StartPos int32 // Position in sequence where vision tokens start
- NumTokens int32 // Number of vision tokens
- GridH int32 // Vision grid height (in patches, after spatial merge)
- GridW int32 // Vision grid width (in patches, after spatial merge)
-}
-
-// EncodePromptWithImages encodes a text prompt with multiple images
-// Returns: embeddings [B, L, hidden_size], mask [B, L], regions []VisionRegion, error
-func (m *Qwen25VL) EncodePromptWithImages(tok *tokenizer.Tokenizer, prompt string, images []*mlx.Array) (*mlx.Array, *mlx.Array, []VisionRegion, error) {
- if !m.HasVision {
- return nil, nil, nil, errors.New("EncodePromptWithImages called on text-only model")
- }
- if len(images) == 0 {
- return nil, nil, nil, errors.New("EncodePromptWithImages called with no images")
- }
-
- cfg := m.Config
-
- // Build image prompt prefix: "Picture 1: ...Picture N: ..."
- imgPromptTemplate := "Picture %d: <|vision_start|><|image_pad|><|vision_end|>"
- imgPrompt := ""
- for i := range images {
- imgPrompt += fmt.Sprintf(imgPromptTemplate, i+1)
- }
-
- // Template from Python diffusers pipeline: prompt_template_encode
- template := "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n%s%s<|im_end|>\n<|im_start|>assistant\n"
- formattedPrompt := fmt.Sprintf(template, imgPrompt, prompt)
-
- // Tokenize
- tokens := tok.Encode(formattedPrompt, false)
-
- // Process each image through vision tower
- visionEmbeddings := make([]*mlx.Array, len(images))
- numImageTokens := make([]int32, len(images))
- visionGridH := make([]int32, len(images))
- visionGridW := make([]int32, len(images))
-
- for i, image := range images {
- visionEmbeddings[i] = m.encodeVision(image)
- numImageTokens[i] = visionEmbeddings[i].Shape()[1]
- // Get original grid dimensions from image shape
- imgShape := image.Shape()
- visionH := imgShape[2] / cfg.VisionPatchSize // Height in patches
- visionW := imgShape[3] / cfg.VisionPatchSize // Width in patches
- // After spatial merge, grid is halved
- visionGridH[i] = visionH / cfg.VisionSpatialMerge
- visionGridW[i] = visionW / cfg.VisionSpatialMerge
- }
-
- // Find all image token positions and expand tokens
- expandedTokens := make([]int32, 0, len(tokens)+int(sum(numImageTokens)))
- imagePositions := make([]int32, 0, len(images)) // Start position for each image's tokens
- imageIdx := 0
-
- for _, t := range tokens {
- if t == cfg.ImageTokenID {
- if imageIdx < len(images) {
- imagePositions = append(imagePositions, int32(len(expandedTokens)))
- // Insert placeholder tokens for this image
- for j := int32(0); j < numImageTokens[imageIdx]; j++ {
- expandedTokens = append(expandedTokens, cfg.ImageTokenID)
- }
- imageIdx++
- }
- } else {
- expandedTokens = append(expandedTokens, t)
- }
- }
-
- // Create token array
- seqLen := int32(len(expandedTokens))
- tokenArr := mlx.NewArrayInt32(expandedTokens, []int32{1, seqLen})
-
- // Get text embeddings
- textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]
-
- // Replace image token embeddings with vision embeddings
- // Build list of segments to concatenate
- segments := make([]*mlx.Array, 0, len(images)*2+1)
- regions := make([]VisionRegion, len(images))
- lastEnd := int32(0)
-
- for i, imgPos := range imagePositions {
- // Text segment before this image
- if imgPos > lastEnd {
- segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, imgPos, cfg.HiddenSize}))
- }
- // Vision embeddings for this image
- segments = append(segments, visionEmbeddings[i])
- regions[i] = VisionRegion{
- StartPos: imgPos,
- NumTokens: numImageTokens[i],
- GridH: visionGridH[i],
- GridW: visionGridW[i],
- }
- lastEnd = imgPos + numImageTokens[i]
- }
- // Remaining text after last image
- if lastEnd < seqLen {
- segments = append(segments, mlx.Slice(textEmbed, []int32{0, lastEnd, 0}, []int32{1, seqLen, cfg.HiddenSize}))
- }
-
- // Concatenate all segments
- textEmbed = mlx.Concatenate(segments, 1)
-
- // Compute RoPE - use multimodal RoPE for multiple images
- cossin, err := m.ComputeMultiImageRoPE(imagePositions, visionGridH, visionGridW, numImageTokens, seqLen)
- if err != nil {
- return nil, nil, nil, fmt.Errorf("computing RoPE: %w", err)
- }
-
- // Forward through ALL text blocks
- x := textEmbed
- for _, block := range m.Blocks {
- x = block.Forward(x, cossin)
- }
-
- // Apply final norm
- x = mlx.RMSNorm(x, m.FinalNorm, cfg.RMSNormEps)
-
- // Drop first N tokens (system prefix)
- // prompt_template_encode_start_idx = 64
- dropIdx := int32(64)
- if x.Shape()[1] > dropIdx {
- x = mlx.Slice(x, []int32{0, dropIdx, 0}, []int32{1, x.Shape()[1], cfg.HiddenSize})
- // Adjust region positions
- for i := range regions {
- regions[i].StartPos -= dropIdx
- }
- }
-
- // Create attention mask (all ones)
- mask := mlx.Ones(1, x.Shape()[1])
-
- return x, mask, regions, nil
-}
-
-// sum returns the sum of int32 slice
-func sum(arr []int32) int32 {
- var s int32
- for _, v := range arr {
- s += v
- }
- return s
-}
-
-// EncodeTextOnly encodes text tokens through all text blocks (exported for testing)
-// tokens: array of token IDs
-// Returns: [B, L, hidden_size] text embeddings after all blocks
-func (m *Qwen25VL) EncodeTextOnly(tokens []int32) *mlx.Array {
- seqLen := int32(len(tokens))
- tokenArr := mlx.NewArrayInt32(tokens, []int32{1, seqLen})
-
- // Get text embeddings
- textEmbed := mlx.EmbeddingLookup(m.Embedding, tokenArr) // [1, L, hidden]
-
- // Compute RoPE
- cossin := m.computeTextRoPE(seqLen, 1)
-
- // Forward through ALL text blocks (unlike Encode which stops at second-to-last)
- x := textEmbed
- for _, block := range m.Blocks {
- x = block.Forward(x, cossin)
- }
-
- // Apply final norm
- x = mlx.RMSNorm(x, m.FinalNorm, m.Config.RMSNormEps)
-
- return x
-}
-
-// encodeVision encodes an image through the vision tower
-// image: [B, C, H, W] normalized image tensor
-// Returns: [B, num_tokens, hidden_size] vision embeddings
-func (m *Qwen25VL) encodeVision(image *mlx.Array) *mlx.Array {
- cfg := m.Config
-
- // Calculate grid dimensions from image
- imgShape := image.Shape()
- imgH := imgShape[2]
- imgW := imgShape[3]
- pH := imgH / cfg.VisionPatchSize // grid height in patches
- pW := imgW / cfg.VisionPatchSize // grid width in patches
-
- // Patch embed
- x := m.VisionPatchEmbed.Forward(image)
- mlx.Eval(x)
-
- // Get window reordering info
- winInfo := m.getWindowInfo(pH, pW)
-
- // Compute vision RoPE embeddings (already in 2x2-block order)
- posEmb := m.computeVisionRoPE(pH, pW)
-
- shape := x.Shape()
- B := shape[0]
- L := shape[1] // num patches = pH * pW
- D := shape[2]
- spatialMergeUnit := winInfo.SpatialMergeUnit
- spatialMerge := cfg.VisionSpatialMerge
-
- // Convert patch embed from row-major to 2x2-block order
- // Row-major: (0,0), (0,1), (0,2), ..., (1,0), (1,1), ...
- // 2x2-block: (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), (1,2), (1,3), ...
- llmGridH := pH / spatialMerge
- llmGridW := pW / spatialMerge
- blockReorderIdx := make([]int32, L)
- idx := int32(0)
- for hBlock := int32(0); hBlock < llmGridH; hBlock++ {
- for wBlock := int32(0); wBlock < llmGridW; wBlock++ {
- for dh := int32(0); dh < spatialMerge; dh++ {
- for dw := int32(0); dw < spatialMerge; dw++ {
- h := hBlock*spatialMerge + dh
- w := wBlock*spatialMerge + dw
- rowMajorIdx := h*pW + w
- blockReorderIdx[idx] = rowMajorIdx
- idx++
- }
- }
- }
- }
- blockIdxArr := mlx.NewArrayInt32(blockReorderIdx, []int32{L})
- x = mlx.Take(x, blockIdxArr, 1) // Reorder patches to 2x2-block order
-
- // Window reorder hidden states and RoPE before blocks
- // Python: reshape to [L/4, 4, D], reorder dim 0, reshape back
- // Reshape x: [B, L, D] -> [B, L/4, 4, D]
- x = mlx.Reshape(x, B, L/spatialMergeUnit, spatialMergeUnit, D)
- // Reorder using window index
- winIdxArr := mlx.NewArrayInt32(winInfo.WindowIndex, []int32{int32(len(winInfo.WindowIndex))})
- x = mlx.Take(x, winIdxArr, 1) // Take along axis 1
- // Reshape back: [B, L/4, 4, D] -> [B, L, D]
- x = mlx.Reshape(x, B, L, D)
-
- // Similarly reorder RoPE: [L, headDim] -> [L/4, 4, headDim] -> reorder -> [L, headDim]
- cosShape := posEmb[0].Shape()
- ropeL := cosShape[0]
- ropeD := cosShape[1]
- cos := mlx.Reshape(posEmb[0], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD)
- sin := mlx.Reshape(posEmb[1], ropeL/spatialMergeUnit, spatialMergeUnit, ropeD)
- cos = mlx.Take(cos, winIdxArr, 0)
- sin = mlx.Take(sin, winIdxArr, 0)
- cos = mlx.Reshape(cos, ropeL, ropeD)
- sin = mlx.Reshape(sin, ropeL, ropeD)
- posEmb = [2]*mlx.Array{cos, sin}
-
- // Materialize to prevent freeing during block evaluations
- mlx.Eval(x, posEmb[0], posEmb[1])
-
- // Full sequence cu_seqlens for full attention blocks
- cuSeqlensFull := []int32{0, L}
-
- // Vision blocks - use window attention except at full attention indices
- for i, block := range m.VisionBlocks {
- useFullAttention := false
- for _, idx := range cfg.VisionFullAttIdx {
- if int32(i) == idx {
- useFullAttention = true
- break
- }
- }
-
- var cuSeqlens []int32
- if useFullAttention {
- cuSeqlens = cuSeqlensFull
- } else {
- cuSeqlens = winInfo.CuWindowSeqlens
- }
-
- x = block.Forward(x, posEmb, cuSeqlens)
- }
-
- // Spatial merge (2x2 -> 1)
- x = m.VisionMerger.ForwardWithDims(x, pH, pW)
-
- // Reverse window reorder after merger
- revIdxArr := mlx.NewArrayInt32(winInfo.ReverseIndex, []int32{int32(len(winInfo.ReverseIndex))})
- x = mlx.Take(x, revIdxArr, 1)
-
- return x
-}
-
-// WindowInfo holds window reordering and attention boundary info
-type WindowInfo struct {
- WindowIndex []int32 // Reordering indices for merged tokens
- ReverseIndex []int32 // Reverse reordering indices
- CuWindowSeqlens []int32 // Cumulative window boundaries in UNMERGED sequence
- SpatialMergeUnit int32 // Number of patches per merged token (4 = 2x2)
-}
-
-// getWindowInfo computes window reordering indices and attention boundaries
-// pH, pW: patch grid dimensions before 2x2 merge
-func (m *Qwen25VL) getWindowInfo(pH, pW int32) *WindowInfo {
- cfg := m.Config
- spatialMergeUnit := cfg.VisionSpatialMerge * cfg.VisionSpatialMerge // 4
-
- // After 2x2 merge
- llmGridH := pH / cfg.VisionSpatialMerge
- llmGridW := pW / cfg.VisionSpatialMerge
- numTokens := llmGridH * llmGridW
-
- // Window size in merged tokens
- // window_size=112, spatial_merge_size=2, patch_size=14
- // vit_merger_window_size = 112 / 2 / 14 = 4
- vitMergerWindowSize := cfg.VisionWindowSize / cfg.VisionSpatialMerge / cfg.VisionPatchSize
-
- // Calculate padding and number of windows
- padH := vitMergerWindowSize - llmGridH%vitMergerWindowSize
- if padH == vitMergerWindowSize {
- padH = 0
- }
- padW := vitMergerWindowSize - llmGridW%vitMergerWindowSize
- if padW == vitMergerWindowSize {
- padW = 0
- }
-
- numWindowsH := (llmGridH + padH) / vitMergerWindowSize
- numWindowsW := (llmGridW + padW) / vitMergerWindowSize
-
- // Create padded grid with -1 for padding
- paddedH := llmGridH + padH
- paddedW := llmGridW + padW
- grid := make([]int32, paddedH*paddedW)
- for i := range grid {
- grid[i] = -1
- }
- for h := int32(0); h < llmGridH; h++ {
- for w := int32(0); w < llmGridW; w++ {
- grid[h*paddedW+w] = h*llmGridW + w
- }
- }
-
- // Reorder into windows and track window sizes
- windowIndex := make([]int32, 0, numTokens)
- windowSizes := make([]int32, 0, numWindowsH*numWindowsW)
- ws := vitMergerWindowSize
-
- for wh := int32(0); wh < numWindowsH; wh++ {
- for ww := int32(0); ww < numWindowsW; ww++ {
- windowStart := len(windowIndex)
- // Extract window
- for h := int32(0); h < ws; h++ {
- for w := int32(0); w < ws; w++ {
- idx := (wh*ws+h)*paddedW + (ww*ws + w)
- if grid[idx] >= 0 {
- windowIndex = append(windowIndex, grid[idx])
- }
- }
- }
- windowSize := int32(len(windowIndex) - windowStart)
- windowSizes = append(windowSizes, windowSize)
- }
- }
-
- // Create reverse index (argsort of windowIndex)
- reverseIndex := make([]int32, numTokens)
- for i, idx := range windowIndex {
- reverseIndex[idx] = int32(i)
- }
-
- // Compute cumulative sequence lengths in UNMERGED sequence
- // Each merged token corresponds to spatialMergeUnit patches
- cuWindowSeqlens := make([]int32, len(windowSizes)+1)
- cuWindowSeqlens[0] = 0
- for i, size := range windowSizes {
- cuWindowSeqlens[i+1] = cuWindowSeqlens[i] + size*spatialMergeUnit
- }
-
- return &WindowInfo{
- WindowIndex: windowIndex,
- ReverseIndex: reverseIndex,
- CuWindowSeqlens: cuWindowSeqlens,
- SpatialMergeUnit: spatialMergeUnit,
- }
-}
-
-// ComputeMultiImageRoPE computes M-RoPE for combined text + multiple vision regions + text sequences
-// This extends ComputeMultimodalRoPE to handle N images instead of just one.
-//
-// Parameters:
-// - imagePositions: starting position of each image's tokens in the sequence
-// - visionGridH, visionGridW: grid dimensions for each image (after spatial merge)
-// - numImageTokens: number of tokens for each image
-// - totalLen: total sequence length
-func (m *Qwen25VL) ComputeMultiImageRoPE(imagePositions []int32, visionGridH, visionGridW, numImageTokens []int32, totalLen int32) ([2]*mlx.Array, error) {
- numImages := len(imagePositions)
-
- // Build 3D position IDs: [3, 1, totalLen]
- // Dimension 0: temporal, Dimension 1: height, Dimension 2: width
- posIDs := make([]float32, 3*totalLen)
-
- // Process sequence in order
- stIdx := int32(0) // Running text position counter
- seqIdx := int32(0)
-
- for i := 0; i < numImages; i++ {
- imgPos := imagePositions[i]
- gridH := visionGridH[i]
- gridW := visionGridW[i]
- numTokens := numImageTokens[i]
-
- // Text segment before this image
- for seqIdx < imgPos {
- posIDs[0*totalLen+seqIdx] = float32(stIdx)
- posIDs[1*totalLen+seqIdx] = float32(stIdx)
- posIDs[2*totalLen+seqIdx] = float32(stIdx)
- stIdx++
- seqIdx++
- }
-
- // Vision tokens for this image
- // Python uses stIdx as base offset for all position dimensions
- for h := int32(0); h < gridH; h++ {
- for w := int32(0); w < gridW; w++ {
- posIDs[0*totalLen+seqIdx] = float32(stIdx) // temporal: constant = stIdx
- posIDs[1*totalLen+seqIdx] = float32(stIdx + h) // height: stIdx + row_index
- posIDs[2*totalLen+seqIdx] = float32(stIdx + w) // width: stIdx + col_index
- seqIdx++
- }
- }
-
- // Verify we processed the expected number of tokens
- if seqIdx != imgPos+numTokens {
- return [2]*mlx.Array{}, fmt.Errorf("mismatch: processed %d but expected %d tokens for image %d", seqIdx-imgPos, numTokens, i)
- }
-
- // Update stIdx for next text segment: max(temporal, height, width) + 1
- maxVisionPos := stIdx // temporal max
- if stIdx+gridH-1 > maxVisionPos {
- maxVisionPos = stIdx + gridH - 1
- }
- if stIdx+gridW-1 > maxVisionPos {
- maxVisionPos = stIdx + gridW - 1
- }
- stIdx = maxVisionPos + 1
- }
-
- // Text after last image
- for seqIdx < totalLen {
- posIDs[0*totalLen+seqIdx] = float32(stIdx)
- posIDs[1*totalLen+seqIdx] = float32(stIdx)
- posIDs[2*totalLen+seqIdx] = float32(stIdx)
- stIdx++
- seqIdx++
- }
-
- posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen})
- return m.computeRoPEFromPositions(posIDsArr, totalLen, 1), nil
-}
-
-// computeTextRoPE computes M-RoPE for text-only sequences
-func (m *Qwen25VL) computeTextRoPE(L, B int32) [2]*mlx.Array {
- // For text-only, all 3 dims use same positions [0, 1, 2, ..., L-1]
- posArr := make([]float32, L*3)
- for d := 0; d < 3; d++ {
- for i := int32(0); i < L; i++ {
- posArr[int32(d)*L+i] = float32(i)
- }
- }
- posIDs := mlx.NewArray(posArr, []int32{3, 1, L})
- posIDs = mlx.Tile(posIDs, []int32{1, B, 1})
- return m.computeRoPEFromPositions(posIDs, L, B)
-}
-
-// ComputeMultimodalRoPE computes M-RoPE for combined text + vision + text sequences
-// This matches Python's get_rope_index behavior exactly.
-// Exported for testing.
-//
-// Python pattern discovered from testing:
-//
-// Vision row 1: temporal=stIdx, height=stIdx, width=[stIdx, stIdx+1, ..., stIdx+gridW-1]
-// Vision row 2: temporal=stIdx, height=stIdx+1, width=[stIdx, stIdx+1, ..., stIdx+gridW-1]
-// Text after: temporal=stIdx+1+i, height=stIdx+gridH+i, width=stIdx+gridW+i
-func (m *Qwen25VL) ComputeMultimodalRoPE(textBefore, visionH, visionW, textAfter int32, spatialMerge int32) [2]*mlx.Array {
- // Vision grid after spatial merge
- llmGridH := visionH / spatialMerge
- llmGridW := visionW / spatialMerge
- visionLen := llmGridH * llmGridW
- totalLen := textBefore + visionLen + textAfter
-
- // Build 3D position IDs: [3, 1, totalLen]
- // Dimension 0: temporal, Dimension 1: height, Dimension 2: width
- posIDs := make([]float32, 3*totalLen)
-
- // Text before vision: all dims same [0, 1, 2, ..., textBefore-1]
- for d := 0; d < 3; d++ {
- for i := int32(0); i < textBefore; i++ {
- posIDs[int32(d)*totalLen+i] = float32(i)
- }
- }
-
- // Vision tokens: 3D grid positions
- // Python uses stIdx (textBefore) as base offset for all position dimensions
- stIdx := textBefore
- for h := int32(0); h < llmGridH; h++ {
- for w := int32(0); w < llmGridW; w++ {
- idx := stIdx + h*llmGridW + w
- posIDs[0*totalLen+idx] = float32(stIdx) // temporal: constant = stIdx
- posIDs[1*totalLen+idx] = float32(stIdx + h) // height: stIdx + row_index
- posIDs[2*totalLen+idx] = float32(stIdx + w) // width: stIdx + col_index
- }
- }
-
- // Text after vision: ALL dimensions continue from max(temporal, height, width) + 1
- // max is max(stIdx, stIdx+llmGridH-1, stIdx+llmGridW-1) = stIdx + max(0, llmGridH-1, llmGridW-1)
- // Then st_idx = max + 1
- maxVisionPos := stIdx // temporal max
- if stIdx+llmGridH-1 > maxVisionPos {
- maxVisionPos = stIdx + llmGridH - 1
- }
- if stIdx+llmGridW-1 > maxVisionPos {
- maxVisionPos = stIdx + llmGridW - 1
- }
- textAfterStart := maxVisionPos + 1
- for i := int32(0); i < textAfter; i++ {
- seqIdx := textBefore + visionLen + i
- posIDs[0*totalLen+seqIdx] = float32(textAfterStart + i) // temporal
- posIDs[1*totalLen+seqIdx] = float32(textAfterStart + i) // height
- posIDs[2*totalLen+seqIdx] = float32(textAfterStart + i) // width
- }
-
- posIDsArr := mlx.NewArray(posIDs, []int32{3, 1, totalLen})
- return m.computeRoPEFromPositions(posIDsArr, totalLen, 1)
-}
-
-// computeRoPEFromPositions computes cos/sin from 3D position IDs
-// posIDs: [3, B, L] where dim 0 is temporal, 1 is height, 2 is width
-func (m *Qwen25VL) computeRoPEFromPositions(posIDs *mlx.Array, L, B int32) [2]*mlx.Array {
- cfg := m.Config
- half := cfg.HeadDim / 2
-
- // Compute inv_freq
- invFreqArr := make([]float32, half)
- for i := int32(0); i < half; i++ {
- invFreqArr[i] = float32(1.0 / math.Pow(float64(cfg.RopeTheta), 2.0*float64(i)/float64(cfg.HeadDim)))
- }
- invFreq := mlx.NewArray(invFreqArr, []int32{half})
-
- // Process each position dimension
- var cosAll, sinAll []*mlx.Array
- for d := int32(0); d < 3; d++ {
- // Get positions for this dimension: [B, L]
- pos := mlx.Slice(posIDs, []int32{d, 0, 0}, []int32{d + 1, B, L})
- pos = mlx.Squeeze(pos, 0) // [B, L]
-
- posExp := mlx.ExpandDims(pos, 2) // [B, L, 1]
- invFreqExp := mlx.Reshape(invFreq, 1, 1, half) // [1, 1, half]
- freqs := mlx.Mul(posExp, invFreqExp) // [B, L, half]
- emb := mlx.Tile(freqs, []int32{1, 1, 2}) // [B, L, D]
-
- cosAll = append(cosAll, mlx.ExpandDims(mlx.Cos(emb), 0))
- sinAll = append(sinAll, mlx.ExpandDims(mlx.Sin(emb), 0))
- }
-
- cos := mlx.Concatenate(cosAll, 0) // [3, B, L, D]
- sin := mlx.Concatenate(sinAll, 0)
-
- return [2]*mlx.Array{cos, sin}
-}
-
-// computeVisionRoPE computes RoPE embeddings for vision patches
-// pH, pW: grid dimensions in patches
-// Returns: [2]*mlx.Array containing (cos, sin) each of shape [numPatches, headDim]
-func (m *Qwen25VL) computeVisionRoPE(pH, pW int32) [2]*mlx.Array {
- cfg := m.Config
- headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads // 80 for 1280/16
- halfDim := headDim / 2 // 40
- quarterDim := halfDim / 2 // 20
- spatialMerge := cfg.VisionSpatialMerge // 2
-
- // Python Qwen2_5_VisionRotaryEmbedding uses dim=head_dim/2=40
- // inv_freq = 1.0 / (theta ** (arange(0, dim, 2) / dim)) -> 20 elements
- theta := float64(10000.0)
- invFreqArr := make([]float32, quarterDim)
- for i := int32(0); i < quarterDim; i++ {
- invFreqArr[i] = float32(1.0 / math.Pow(theta, float64(2*i)/float64(halfDim)))
- }
- invFreq := mlx.NewArray(invFreqArr, []int32{quarterDim})
-
- // Create position IDs matching Python's 2x2 block ordering:
- // Python does: reshape(h//2, 2, w//2, 2), permute(0, 2, 1, 3), flatten
- // This groups patches by 2x2 merged token blocks
- numPatches := pH * pW
- hPosArr := make([]float32, numPatches)
- wPosArr := make([]float32, numPatches)
-
- // Number of merged token blocks
- llmGridH := pH / spatialMerge
- llmGridW := pW / spatialMerge
-
- idx := int32(0)
- for hBlock := int32(0); hBlock < llmGridH; hBlock++ {
- for wBlock := int32(0); wBlock < llmGridW; wBlock++ {
- // Within each 2x2 block: (0,0), (0,1), (1,0), (1,1)
- for dh := int32(0); dh < spatialMerge; dh++ {
- for dw := int32(0); dw < spatialMerge; dw++ {
- h := hBlock*spatialMerge + dh
- w := wBlock*spatialMerge + dw
- hPosArr[idx] = float32(h)
- wPosArr[idx] = float32(w)
- idx++
- }
- }
- }
- }
-
- hPos := mlx.NewArray(hPosArr, []int32{numPatches, 1})
- wPos := mlx.NewArray(wPosArr, []int32{numPatches, 1})
- invFreqExp := mlx.Reshape(invFreq, 1, quarterDim)
-
- // Compute freqs: [numPatches, quarterDim] for each of h and w
- hFreqs := mlx.Mul(hPos, invFreqExp) // [L, 20]
- wFreqs := mlx.Mul(wPos, invFreqExp) // [L, 20]
-
- // Concatenate h and w freqs: [numPatches, halfDim] = [L, 40]
- freqs := mlx.Concatenate([]*mlx.Array{hFreqs, wFreqs}, 1)
-
- // Double for cos/sin application: [L, 40] -> [L, 80] = [L, headDim]
- emb := mlx.Concatenate([]*mlx.Array{freqs, freqs}, 1)
-
- cos := mlx.Cos(emb)
- sin := mlx.Sin(emb)
-
- return [2]*mlx.Array{cos, sin}
-}
-
-// VLTextBlock is a single Qwen2.5 transformer block (for VL model)
-type VLTextBlock struct {
- Attention *VLTextAttention
- MLP *VLTextMLP
- InputLayerNorm *mlx.Array
- PostAttnLayerNorm *mlx.Array
- NormEps float32
-}
-
-// newVLTextBlock creates a text block
-func newVLTextBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VLTextBlock, error) {
- prefix := fmt.Sprintf("model.layers.%d", layerIdx)
-
- inputNorm, err := weights.Get(prefix + ".input_layernorm.weight")
- if err != nil {
- return nil, err
- }
- postAttnNorm, err := weights.Get(prefix + ".post_attention_layernorm.weight")
- if err != nil {
- return nil, err
- }
-
- attention, err := newVLTextAttention(weights, prefix, cfg)
- if err != nil {
- return nil, err
- }
-
- mlpLayer, err := newVLTextMLP(weights, prefix)
- if err != nil {
- return nil, err
- }
-
- return &VLTextBlock{
- Attention: attention,
- MLP: mlpLayer,
- InputLayerNorm: inputNorm,
- PostAttnLayerNorm: postAttnNorm,
- NormEps: cfg.RMSNormEps,
- }, nil
-}
-
-// Forward applies the block
-func (tb *VLTextBlock) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array {
- h := mlx.RMSNorm(x, tb.InputLayerNorm, tb.NormEps)
- attnOut := tb.Attention.Forward(h, cossin)
- x = mlx.Add(x, attnOut)
-
- h = mlx.RMSNorm(x, tb.PostAttnLayerNorm, tb.NormEps)
- mlpOut := tb.MLP.Forward(h)
- x = mlx.Add(x, mlpOut)
-
- return x
-}
-
-// VLTextAttention implements Qwen2.5 attention with M-RoPE
-type VLTextAttention struct {
- QProj *mlx.Array
- KProj *mlx.Array
- VProj *mlx.Array
- OProj *mlx.Array
- QBias *mlx.Array
- KBias *mlx.Array
- VBias *mlx.Array
- NHeads int32
- NKVHeads int32
- HeadDim int32
- Scale float32
- MRoPESection []int32
-}
-
-// newVLTextAttention creates a text attention layer
-func newVLTextAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VLTextAttention, error) {
- qProj, err := weights.Get(prefix + ".self_attn.q_proj.weight")
- if err != nil {
- return nil, err
- }
- kProj, err := weights.Get(prefix + ".self_attn.k_proj.weight")
- if err != nil {
- return nil, err
- }
- vProj, err := weights.Get(prefix + ".self_attn.v_proj.weight")
- if err != nil {
- return nil, err
- }
- oProj, err := weights.Get(prefix + ".self_attn.o_proj.weight")
- if err != nil {
- return nil, err
- }
-
- qBias, _ := weights.Get(prefix + ".self_attn.q_proj.bias")
- kBias, _ := weights.Get(prefix + ".self_attn.k_proj.bias")
- vBias, _ := weights.Get(prefix + ".self_attn.v_proj.bias")
-
- return &VLTextAttention{
- QProj: mlx.Transpose(qProj, 1, 0),
- KProj: mlx.Transpose(kProj, 1, 0),
- VProj: mlx.Transpose(vProj, 1, 0),
- OProj: mlx.Transpose(oProj, 1, 0),
- QBias: qBias,
- KBias: kBias,
- VBias: vBias,
- NHeads: cfg.NumAttentionHeads,
- NKVHeads: cfg.NumKeyValueHeads,
- HeadDim: cfg.HeadDim,
- Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
- MRoPESection: cfg.MRoPESection,
- }, nil
-}
-
-// Forward computes attention
-func (attn *VLTextAttention) Forward(x *mlx.Array, cossin [2]*mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- L := shape[1]
-
- q := mlx.Linear(x, attn.QProj)
- if attn.QBias != nil {
- q = mlx.Add(q, attn.QBias)
- }
- k := mlx.Linear(x, attn.KProj)
- if attn.KBias != nil {
- k = mlx.Add(k, attn.KBias)
- }
- v := mlx.Linear(x, attn.VProj)
- if attn.VBias != nil {
- v = mlx.Add(v, attn.VBias)
- }
-
- q = mlx.Reshape(q, B, L, attn.NHeads, attn.HeadDim)
- k = mlx.Reshape(k, B, L, attn.NKVHeads, attn.HeadDim)
- v = mlx.Reshape(v, B, L, attn.NKVHeads, attn.HeadDim)
-
- q = mlx.Transpose(q, 0, 2, 1, 3)
- k = mlx.Transpose(k, 0, 2, 1, 3)
- v = mlx.Transpose(v, 0, 2, 1, 3)
-
- // Apply M-RoPE
- if cossin[0] != nil && cossin[1] != nil {
- q = applyMRoPE(q, cossin[0], cossin[1], attn.MRoPESection)
- k = applyMRoPE(k, cossin[0], cossin[1], attn.MRoPESection)
- }
-
- // Repeat KV for GQA
- if attn.NKVHeads < attn.NHeads {
- repeats := attn.NHeads / attn.NKVHeads
- k = repeatKV(k, repeats)
- v = repeatKV(v, repeats)
- }
-
- out := mlx.ScaledDotProductAttention(q, k, v, attn.Scale, true)
-
- out = mlx.Transpose(out, 0, 2, 1, 3)
- out = mlx.Reshape(out, B, L, attn.NHeads*attn.HeadDim)
-
- return mlx.Linear(out, attn.OProj)
-}
-
-// applyMRoPE applies Multi-Resolution RoPE
-func applyMRoPE(x *mlx.Array, cos, sin *mlx.Array, section []int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- L := shape[2]
- D := shape[3]
- half := D / 2
-
- fullSection := make([]int32, len(section))
- for i, s := range section {
- fullSection[i] = s * 2
- }
-
- var cosParts, sinParts []*mlx.Array
- offset := int32(0)
- for i, size := range fullSection {
- posDim := int32(i % 3)
- cosSection := mlx.Slice(cos, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size})
- sinSection := mlx.Slice(sin, []int32{posDim, 0, 0, offset}, []int32{posDim + 1, B, L, offset + size})
- cosSection = mlx.Squeeze(cosSection, 0)
- sinSection = mlx.Squeeze(sinSection, 0)
- cosParts = append(cosParts, cosSection)
- sinParts = append(sinParts, sinSection)
- offset += size
- }
-
- cosFlat := mlx.Concatenate(cosParts, 2)
- sinFlat := mlx.Concatenate(sinParts, 2)
-
- cosFlat = mlx.Reshape(cosFlat, B, 1, L, D)
- sinFlat = mlx.Reshape(sinFlat, B, 1, L, D)
-
- x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{B, H, L, half})
- x2 := mlx.Slice(x, []int32{0, 0, 0, half}, []int32{B, H, L, D})
- negX2 := mlx.MulScalar(x2, -1)
- rotatedX := mlx.Concatenate([]*mlx.Array{negX2, x1}, 3)
-
- return mlx.Add(mlx.Mul(x, cosFlat), mlx.Mul(rotatedX, sinFlat))
-}
-
-// repeatKV repeats key/value heads for GQA
-func repeatKV(x *mlx.Array, repeats int32) *mlx.Array {
- if repeats == 1 {
- return x
- }
- shape := x.Shape()
- x = mlx.ExpandDims(x, 2)
- x = mlx.Tile(x, []int32{1, 1, repeats, 1, 1})
- return mlx.Reshape(x, shape[0], shape[1]*repeats, shape[2], shape[3])
-}
-
-// VLTextMLP implements Qwen2.5 SwiGLU MLP
-type VLTextMLP struct {
- GateProj *mlx.Array
- UpProj *mlx.Array
- DownProj *mlx.Array
-}
-
-// newVLTextMLP creates a text MLP layer
-func newVLTextMLP(weights *safetensors.ModelWeights, prefix string) (*VLTextMLP, error) {
- gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight")
- if err != nil {
- return nil, err
- }
- upProj, err := weights.Get(prefix + ".mlp.up_proj.weight")
- if err != nil {
- return nil, err
- }
- downProj, err := weights.Get(prefix + ".mlp.down_proj.weight")
- if err != nil {
- return nil, err
- }
-
- return &VLTextMLP{
- GateProj: mlx.Transpose(gateProj, 1, 0),
- UpProj: mlx.Transpose(upProj, 1, 0),
- DownProj: mlx.Transpose(downProj, 1, 0),
- }, nil
-}
-
-// Forward applies the SwiGLU MLP
-func (mlp *VLTextMLP) Forward(x *mlx.Array) *mlx.Array {
- gate := mlx.Linear(x, mlp.GateProj)
- gate = mlx.SiLU(gate)
- up := mlx.Linear(x, mlp.UpProj)
- h := mlx.Mul(gate, up)
- return mlx.Linear(h, mlp.DownProj)
-}
-
-// VisionPatchEmbed embeds image patches
-type VisionPatchEmbed struct {
- ProjWeight *mlx.Array
- ProjBias *mlx.Array
- PatchSize int32
-}
-
-// newVisionPatchEmbed creates a vision patch embed layer
-func newVisionPatchEmbed(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionPatchEmbed, error) {
- projWeight, err := weights.Get("visual.patch_embed.proj.weight")
- if err != nil {
- return nil, err
- }
- projBias, _ := weights.Get("visual.patch_embed.proj.bias")
-
- return &VisionPatchEmbed{
- ProjWeight: projWeight,
- ProjBias: projBias,
- PatchSize: cfg.VisionPatchSize,
- }, nil
-}
-
-// Forward embeds patches from an image
-// image: [B, C, H, W]
-// Returns: [B, num_patches, hidden_size]
-func (pe *VisionPatchEmbed) Forward(image *mlx.Array) *mlx.Array {
- // Qwen2.5-VL uses 3D conv for patch embedding to support video
- // Weight shape is [O, I, kT, kH, kW] e.g. [1280, 3, 2, 14, 14]
- // For single image, we duplicate the frame to match temporal_patch_size
-
- wShape := pe.ProjWeight.Shape()
- if len(wShape) == 5 {
- // 3D convolution case
- temporalPatchSize := wShape[2] // kT from weight shape
-
- // Add temporal dimension: [B, C, H, W] -> [B, C, 1, H, W]
- image = mlx.ExpandDims(image, 2)
-
- // Duplicate frame to match temporal_patch_size (Python does this for single images)
- // [B, C, 1, H, W] -> [B, C, T, H, W] where T = temporal_patch_size
- if temporalPatchSize > 1 {
- image = mlx.Tile(image, []int32{1, 1, temporalPatchSize, 1, 1})
- }
-
- // Convert to channels-last: [B, C, T, H, W] -> [B, T, H, W, C]
- image = mlx.Transpose(image, 0, 2, 3, 4, 1)
-
- // Weight is [O, I, kT, kH, kW] - keep as-is since patches are now in [I, kT, kH, kW] order
- // (extractPatches3DStrided transposes each patch to [C, T, H, W] to match Python)
-
- // Apply 3D conv using manual patch extraction
- // Strides: (temporal_patch_size, patch_size, patch_size)
- x := conv3DStrided(image, pe.ProjWeight, temporalPatchSize, pe.PatchSize, pe.PatchSize)
-
- if pe.ProjBias != nil {
- outC := pe.ProjBias.Dim(0)
- bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, 1, outC)
- x = mlx.Add(x, bias)
- }
-
- // x is [B, T', H', W', C], squeeze T' and flatten spatial
- shape := x.Shape()
- // T' should be 1 for single image (since we used stride=temporal_patch_size)
- x = mlx.Reshape(x, shape[0], shape[2]*shape[3], shape[4])
-
- return x
- }
-
- // Original 2D case (fallback)
- // Convert to channels-last for Conv2d
- image = mlx.Transpose(image, 0, 2, 3, 1) // [B, H, W, C]
-
- // Apply conv with stride=patch_size using manual strided convolution
- weight := mlx.Transpose(pe.ProjWeight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
- x := conv2DStrided(image, weight, pe.PatchSize)
- if pe.ProjBias != nil {
- bias := mlx.Reshape(pe.ProjBias, 1, 1, 1, pe.ProjBias.Dim(0))
- x = mlx.Add(x, bias)
- }
-
- // Flatten patches: [B, pH, pW, C] -> [B, pH*pW, C]
- shape := x.Shape()
- x = mlx.Reshape(x, shape[0], shape[1]*shape[2], shape[3])
-
- return x
-}
-
-// VisionBlock is a single vision transformer block
-type VisionBlock struct {
- Norm1 *mlx.Array
- Norm2 *mlx.Array
- Attention *VisionAttention
- MLP *VisionMLP
-}
-
-// newVisionBlock creates a vision block
-func newVisionBlock(weights *safetensors.ModelWeights, layerIdx int, cfg *Qwen25VLConfig) (*VisionBlock, error) {
- prefix := fmt.Sprintf("visual.blocks.%d", layerIdx)
-
- norm1, err := weights.Get(prefix + ".norm1.weight")
- if err != nil {
- return nil, err
- }
- norm2, err := weights.Get(prefix + ".norm2.weight")
- if err != nil {
- return nil, err
- }
-
- attention, err := newVisionAttention(weights, prefix, cfg)
- if err != nil {
- return nil, err
- }
-
- mlpLayer, err := newVisionMLP(weights, prefix, cfg)
- if err != nil {
- return nil, err
- }
-
- return &VisionBlock{
- Norm1: norm1,
- Norm2: norm2,
- Attention: attention,
- MLP: mlpLayer,
- }, nil
-}
-
-// Forward applies the vision block
-// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil
-// cuSeqlens: cumulative sequence lengths for window attention
-func (vb *VisionBlock) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array {
- // Python uses RMSNorm, not LayerNorm!
- h := mlx.RMSNormNoWeight(x, 1e-6)
- h = mlx.Mul(h, vb.Norm1)
- attnOut := vb.Attention.Forward(h, posEmb, cuSeqlens)
- x = mlx.Add(x, attnOut)
-
- h = mlx.RMSNormNoWeight(x, 1e-6)
- h = mlx.Mul(h, vb.Norm2)
- mlpOut := vb.MLP.Forward(h)
- x = mlx.Add(x, mlpOut)
-
- return x
-}
-
-// VisionAttention implements vision attention
-type VisionAttention struct {
- QKVProj *mlx.Array
- QKVBias *mlx.Array
- OutProj *mlx.Array
- OutBias *mlx.Array
- NHeads int32
- HeadDim int32
- Scale float32
-}
-
-// newVisionAttention creates a vision attention layer
-func newVisionAttention(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionAttention, error) {
- qkvProj, err := weights.Get(prefix + ".attn.qkv.weight")
- if err != nil {
- return nil, err
- }
- qkvBias, _ := weights.Get(prefix + ".attn.qkv.bias")
- outProj, err := weights.Get(prefix + ".attn.proj.weight")
- if err != nil {
- return nil, err
- }
- outBias, _ := weights.Get(prefix + ".attn.proj.bias")
-
- headDim := cfg.VisionHiddenSize / cfg.VisionNumHeads
-
- return &VisionAttention{
- QKVProj: mlx.Transpose(qkvProj, 1, 0),
- QKVBias: qkvBias,
- OutProj: mlx.Transpose(outProj, 1, 0),
- OutBias: outBias,
- NHeads: cfg.VisionNumHeads,
- HeadDim: headDim,
- Scale: float32(1.0 / math.Sqrt(float64(headDim))),
- }, nil
-}
-
-// Forward applies vision attention with optional RoPE and window attention
-// posEmb: [2]*mlx.Array containing (cos, sin) for RoPE, can be nil
-// cuSeqlens: cumulative sequence lengths for window boundaries
-func (attn *VisionAttention) Forward(x *mlx.Array, posEmb [2]*mlx.Array, cuSeqlens []int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- L := shape[1]
- D := shape[2]
-
- qkv := mlx.Linear(x, attn.QKVProj)
- if attn.QKVBias != nil {
- qkv = mlx.Add(qkv, attn.QKVBias)
- }
-
- // Split into Q, K, V
- qkv = mlx.Reshape(qkv, B, L, 3, attn.NHeads, attn.HeadDim)
- q := mlx.Slice(qkv, []int32{0, 0, 0, 0, 0}, []int32{B, L, 1, attn.NHeads, attn.HeadDim})
- k := mlx.Slice(qkv, []int32{0, 0, 1, 0, 0}, []int32{B, L, 2, attn.NHeads, attn.HeadDim})
- v := mlx.Slice(qkv, []int32{0, 0, 2, 0, 0}, []int32{B, L, 3, attn.NHeads, attn.HeadDim})
-
- q = mlx.Squeeze(q, 2) // [B, L, H, D]
- k = mlx.Squeeze(k, 2)
- v = mlx.Squeeze(v, 2)
-
- // Apply RoPE if position embeddings provided
- if posEmb[0] != nil && posEmb[1] != nil {
- q, k = applyVisionRoPE(q, k, posEmb[0], posEmb[1])
- }
-
- q = mlx.Transpose(q, 0, 2, 1, 3) // [B, H, L, D]
- k = mlx.Transpose(k, 0, 2, 1, 3)
- v = mlx.Transpose(v, 0, 2, 1, 3)
-
- var out *mlx.Array
-
- // Check if we need window attention (more than 1 window)
- numWindows := len(cuSeqlens) - 1
- if numWindows <= 1 {
- // Full attention - single window covering entire sequence
- out = mlx.ScaledDotProductAttention(q, k, v, attn.Scale, false)
- } else {
- // Window attention - process each window separately
- attnOutputs := make([]*mlx.Array, numWindows)
-
- for w := 0; w < numWindows; w++ {
- start := cuSeqlens[w]
- end := cuSeqlens[w+1]
-
- // Slice Q, K, V for this window: [B, H, winLen, D]
- qWin := mlx.Slice(q, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
- kWin := mlx.Slice(k, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
- vWin := mlx.Slice(v, []int32{0, 0, start, 0}, []int32{B, attn.NHeads, end, attn.HeadDim})
-
- // Compute attention for this window
- attnWin := mlx.ScaledDotProductAttention(qWin, kWin, vWin, attn.Scale, false)
- attnOutputs[w] = attnWin
- }
-
- // Concatenate all window outputs along sequence dimension
- out = mlx.Concatenate(attnOutputs, 2)
- }
-
- out = mlx.Transpose(out, 0, 2, 1, 3) // [B, L, H, D]
- out = mlx.Reshape(out, B, L, D)
-
- out = mlx.Linear(out, attn.OutProj)
- if attn.OutBias != nil {
- out = mlx.Add(out, attn.OutBias)
- }
-
- return out
-}
-
-// applyVisionRoPE applies rotary position embedding to Q and K for vision
-// q, k: [B, L, H, D], cos, sin: [L, D] (already doubled: D = head_dim)
-// Returns: rotated q, k with same shape
-// Note: Python does this computation in float32 for numerical stability
-func applyVisionRoPE(q, k, cos, sin *mlx.Array) (*mlx.Array, *mlx.Array) {
- // Convert to float32 for numerical stability (matches Python)
- origDtype := q.Dtype()
- q = mlx.AsType(q, mlx.DtypeFloat32)
- k = mlx.AsType(k, mlx.DtypeFloat32)
- cos = mlx.AsType(cos, mlx.DtypeFloat32)
- sin = mlx.AsType(sin, mlx.DtypeFloat32)
-
- // Expand cos/sin to match q/k shape: [L, D] -> [1, L, 1, D]
- cos = mlx.ExpandDims(cos, 0)
- cos = mlx.ExpandDims(cos, 2)
- sin = mlx.ExpandDims(sin, 0)
- sin = mlx.ExpandDims(sin, 2)
-
- // rotate_half: split last dim in half and swap with negation
- // q_rot = q * cos + rotate_half(q) * sin
- qRotated := rotateHalf(q)
- kRotated := rotateHalf(k)
-
- qOut := mlx.Add(mlx.Mul(q, cos), mlx.Mul(qRotated, sin))
- kOut := mlx.Add(mlx.Mul(k, cos), mlx.Mul(kRotated, sin))
-
- // Convert back to original dtype
- qOut = mlx.AsType(qOut, origDtype)
- kOut = mlx.AsType(kOut, origDtype)
-
- return qOut, kOut
-}
-
-// rotateHalf rotates the last dimension by splitting in half and swapping with negation
-// x: [..., D] -> split to [..., D/2] and [..., D/2], then concat(-x2, x1)
-func rotateHalf(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- lastDim := shape[len(shape)-1]
- halfDim := lastDim / 2
-
- // Split into two halves
- x1 := mlx.Slice(x, []int32{0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], halfDim})
- x2 := mlx.Slice(x, []int32{0, 0, 0, halfDim}, []int32{shape[0], shape[1], shape[2], lastDim})
-
- // Negate x2 and concatenate
- x2Neg := mlx.MulScalar(x2, -1.0)
- return mlx.Concatenate([]*mlx.Array{x2Neg, x1}, 3)
-}
-
-// VisionMLP implements vision SwiGLU MLP
-type VisionMLP struct {
- GateProj *mlx.Array
- GateProjBias *mlx.Array
- UpProj *mlx.Array
- UpProjBias *mlx.Array
- DownProj *mlx.Array
- DownProjBias *mlx.Array
-}
-
-// newVisionMLP creates a vision MLP layer
-func newVisionMLP(weights *safetensors.ModelWeights, prefix string, cfg *Qwen25VLConfig) (*VisionMLP, error) {
- gateProj, err := weights.Get(prefix + ".mlp.gate_proj.weight")
- if err != nil {
- return nil, err
- }
- gateProjBias, _ := weights.Get(prefix + ".mlp.gate_proj.bias")
- upProj, err := weights.Get(prefix + ".mlp.up_proj.weight")
- if err != nil {
- return nil, err
- }
- upProjBias, _ := weights.Get(prefix + ".mlp.up_proj.bias")
- downProj, err := weights.Get(prefix + ".mlp.down_proj.weight")
- if err != nil {
- return nil, err
- }
- downProjBias, _ := weights.Get(prefix + ".mlp.down_proj.bias")
-
- return &VisionMLP{
- GateProj: mlx.Transpose(gateProj, 1, 0),
- GateProjBias: gateProjBias,
- UpProj: mlx.Transpose(upProj, 1, 0),
- UpProjBias: upProjBias,
- DownProj: mlx.Transpose(downProj, 1, 0),
- DownProjBias: downProjBias,
- }, nil
-}
-
-// Forward applies the vision SwiGLU MLP
-func (m *VisionMLP) Forward(x *mlx.Array) *mlx.Array {
- gate := mlx.Linear(x, m.GateProj)
- if m.GateProjBias != nil {
- gate = mlx.Add(gate, m.GateProjBias)
- }
- gate = mlx.SiLU(gate)
-
- up := mlx.Linear(x, m.UpProj)
- if m.UpProjBias != nil {
- up = mlx.Add(up, m.UpProjBias)
- }
-
- h := mlx.Mul(gate, up)
- h = mlx.Linear(h, m.DownProj)
- if m.DownProjBias != nil {
- h = mlx.Add(h, m.DownProjBias)
- }
- return h
-}
-
-// VisionMerger merges spatial patches (2x2 -> 1)
-type VisionMerger struct {
- MLP0Weight *mlx.Array
- MLP0Bias *mlx.Array
- MLP2Weight *mlx.Array
- MLP2Bias *mlx.Array
- LNWeight *mlx.Array
-}
-
-// newVisionMerger creates a vision merger
-func newVisionMerger(weights *safetensors.ModelWeights, cfg *Qwen25VLConfig) (*VisionMerger, error) {
- mlp0Weight, err := weights.Get("visual.merger.mlp.0.weight")
- if err != nil {
- return nil, err
- }
- mlp0Bias, _ := weights.Get("visual.merger.mlp.0.bias")
- mlp2Weight, err := weights.Get("visual.merger.mlp.2.weight")
- if err != nil {
- return nil, err
- }
- mlp2Bias, _ := weights.Get("visual.merger.mlp.2.bias")
- lnWeight, _ := weights.Get("visual.merger.ln_q.weight")
-
- return &VisionMerger{
- MLP0Weight: mlx.Transpose(mlp0Weight, 1, 0),
- MLP0Bias: mlp0Bias,
- MLP2Weight: mlx.Transpose(mlp2Weight, 1, 0),
- MLP2Bias: mlp2Bias,
- LNWeight: lnWeight,
- }, nil
-}
-
-// Forward merges 2x2 patches into 1 (assumes square grid - use ForwardWithDims for non-square)
-func (m *VisionMerger) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- L := shape[1]
- side := int32(math.Sqrt(float64(L)))
- return m.ForwardWithDims(x, side, side)
-}
-
-// ForwardWithDims merges 2x2 patches into 1 with explicit grid dimensions
-// After window reordering, consecutive 4 patches form a 2x2 block, so we just
-// reshape [B, L, D] -> [B, L/4, 4*D] without 2D spatial rearrangement.
-func (m *VisionMerger) ForwardWithDims(x *mlx.Array, pH, pW int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- L := shape[1]
- D := shape[2]
-
- // RMSNorm BEFORE merge (applied to each token with D dimensions)
- // Python: ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
- if m.LNWeight != nil {
- x = mlx.RMSNormNoWeight(x, 1e-6)
- x = mlx.Mul(x, m.LNWeight)
- }
-
- // After window reordering, consecutive 4 patches belong to a 2x2 block
- // Just reshape to [B, L/4, 4*D] - no spatial rearrangement needed
- newL := L / 4
- x = mlx.Reshape(x, B, newL, 4*D)
-
- // MLP
- h := mlx.Linear(x, m.MLP0Weight)
- if m.MLP0Bias != nil {
- h = mlx.Add(h, m.MLP0Bias)
- }
- h = mlx.GELU(h)
- h = mlx.Linear(h, m.MLP2Weight)
- if m.MLP2Bias != nil {
- h = mlx.Add(h, m.MLP2Bias)
- }
-
- return h
-}
-
-// LoadQwen25VLFromPath loads the encoder from path
-func LoadQwen25VLFromPath(path string) (*Qwen25VL, error) {
- m := &Qwen25VL{}
- if err := m.Load(filepath.Join(path, "text_encoder")); err != nil {
- return nil, err
- }
- return m, nil
-}
-
-// conv2DStrided applies conv with stride > 1 using manual patch extraction
-// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
-func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- W := shape[2]
-
- wShape := weight.Shape()
- Cout := wShape[0]
- kH := wShape[1]
- kW := wShape[2]
-
- outH := (H - kH) / stride + 1
- outW := (W - kW) / stride + 1
-
- patches := extractPatches2DStrided(x, kH, kW, stride)
- wFlat := mlx.Reshape(weight, Cout, -1)
- patches = mlx.Reshape(patches, B*outH*outW, -1)
- out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
- return mlx.Reshape(out, B, outH, outW, Cout)
-}
-
-// conv3DStrided applies 3D conv with strides using manual patch extraction
-// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
-// strideT, strideH, strideW are the strides for each dimension
-// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
-func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
-
- wShape := weight.Shape()
- Cout := wShape[0]
- // I := wShape[1]
- kT := wShape[2]
- kH := wShape[3]
- kW := wShape[4]
-
- // For temporal: if T < kT, we need to repeat frames temporally
- // For single image with T=1 and kT=2, we duplicate the frame to T=kT
- // Python Qwen2.5-VL duplicates the frame, not zero-pads
- if T < kT {
- // Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
- x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
- T = kT
- }
-
- outT := (T - kT) / strideT + 1
- outH := (H - kH) / strideH + 1
- outW := (W - kW) / strideW + 1
-
- // Extract 3D patches in [C, T, H, W] order to match Python
- patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
- // patches shape: [B, outT, outH, outW, C*kT*kH*kW]
-
- // Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
- wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
- patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
- out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
- return mlx.Reshape(out, B, outT, outH, outW, Cout)
-}
-
-// extractPatches3DStrided extracts 3D patches with given strides
-// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
-func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
-
- outT := (T - kT) / strideT + 1
- outH := (H - kH) / strideH + 1
- outW := (W - kW) / strideW + 1
-
- numPatches := outT * outH * outW
- patches := make([]*mlx.Array, numPatches)
- idx := 0
- for t := int32(0); t < outT; t++ {
- for i := int32(0); i < outH; i++ {
- for j := int32(0); j < outW; j++ {
- startT := t * strideT
- startH := i * strideH
- startW := j * strideW
- // Extract patch: [B, kT, kH, kW, C]
- patch := mlx.Slice(x,
- []int32{0, startT, startH, startW, 0},
- []int32{B, startT + kT, startH + kH, startW + kW, C})
- // Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
- patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
- // Flatten to [B, C*T*H*W]
- patch = mlx.Reshape(patch, B, C*kT*kH*kW)
- patches[idx] = patch
- idx++
- }
- }
- }
-
- for i := range patches {
- patches[i] = mlx.ExpandDims(patches[i], 1)
- }
- stacked := mlx.Concatenate(patches, 1)
- return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
-}
-
-// extractPatches2DStrided extracts patches with given stride
-func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- W := shape[2]
- C := shape[3]
-
- outH := (H - kH) / stride + 1
- outW := (W - kW) / stride + 1
-
- patches := make([]*mlx.Array, outH*outW)
- idx := 0
- for i := int32(0); i < outH; i++ {
- for j := int32(0); j < outW; j++ {
- startH := i * stride
- startW := j * stride
- patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
- patch = mlx.Reshape(patch, B, kH*kW*C)
- patches[idx] = patch
- idx++
- }
- }
-
- for i := range patches {
- patches[i] = mlx.ExpandDims(patches[i], 1)
- }
- stacked := mlx.Concatenate(patches, 1)
- return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
-}
diff --git a/x/imagegen/models/qwen_image/qwen_image.go b/x/imagegen/models/qwen_image/qwen_image.go
deleted file mode 100644
index a7e554623c6..00000000000
--- a/x/imagegen/models/qwen_image/qwen_image.go
+++ /dev/null
@@ -1,367 +0,0 @@
-//go:build mlx
-
-// Package qwen_image implements the Qwen-Image diffusion transformer model.
-package qwen_image
-
-import (
- "context"
- "fmt"
- "path/filepath"
- "time"
-
- "github.com/ollama/ollama/x/imagegen/cache"
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/tokenizer"
-)
-
-// GenerateConfig holds all options for image generation.
-type GenerateConfig struct {
- Prompt string
- NegativePrompt string // Empty = no CFG
- CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
- Width int32 // Image width (default: 1024)
- Height int32 // Image height (default: 1024)
- Steps int // Denoising steps (default: 30)
- Seed int64 // Random seed
- Progress func(step, totalSteps int) // Optional progress callback
-
- // Layer caching (DeepCache/Learning-to-Cache speedup)
- LayerCache bool // Enable layer caching (default: false)
- CacheInterval int // Refresh cache every N steps (default: 3)
- CacheLayers int // Number of shallow layers to cache (default: 25)
-}
-
-// Model represents a Qwen-Image diffusion model.
-type Model struct {
- ModelPath string
- Tokenizer *tokenizer.Tokenizer
- TextEncoder *Qwen25VL
- Transformer *Transformer
- VAEDecoder *VAEDecoder
-}
-
-// Load loads the Qwen-Image model from a directory.
-func (m *Model) Load(modelPath string) error {
- fmt.Println("Loading Qwen-Image model...")
- start := time.Now()
-
- if mlx.GPUIsAvailable() {
- mlx.SetDefaultDeviceGPU()
- mlx.EnableCompile()
- }
-
- m.ModelPath = modelPath
-
- // Load tokenizer
- fmt.Print(" Loading tokenizer... ")
- tokenizerPath := filepath.Join(modelPath, "tokenizer")
- tok, err := tokenizer.Load(tokenizerPath)
- if err != nil {
- return fmt.Errorf("tokenizer: %w", err)
- }
- m.Tokenizer = tok
- fmt.Println("✓")
-
- // Load text encoder (Qwen2.5-VL in text-only mode - skip vision tower for efficiency)
- m.TextEncoder = &Qwen25VL{}
- if err := m.TextEncoder.LoadTextOnly(filepath.Join(modelPath, "text_encoder")); err != nil {
- return fmt.Errorf("text encoder: %w", err)
- }
- mlx.Eval(mlx.Collect(m.TextEncoder)...)
- fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
- float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
- float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- // Load transformer
- m.Transformer = &Transformer{}
- if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
- return fmt.Errorf("transformer: %w", err)
- }
- mlx.Eval(mlx.Collect(m.Transformer)...)
- fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
- float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
- float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- // Load VAE decoder
- m.VAEDecoder = &VAEDecoder{}
- if err := m.VAEDecoder.Load(filepath.Join(modelPath, "vae")); err != nil {
- return fmt.Errorf("VAE decoder: %w", err)
- }
- mlx.Eval(mlx.Collect(m.VAEDecoder)...)
- fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
- float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
- float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- mem := mlx.MetalGetActiveMemory()
- peak := mlx.MetalGetPeakMemory()
- fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
- time.Since(start).Seconds(),
- float64(mem)/(1024*1024*1024),
- float64(peak)/(1024*1024*1024))
-
- return nil
-}
-
-// Generate creates an image from a prompt.
-func (m *Model) Generate(prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
- return m.GenerateFromConfig(&GenerateConfig{
- Prompt: prompt,
- Width: width,
- Height: height,
- Steps: steps,
- Seed: seed,
- })
-}
-
-// GenerateWithProgress creates an image with progress callback.
-func (m *Model) GenerateWithProgress(prompt string, width, height int32, steps int, seed int64, progress func(step, totalSteps int)) (*mlx.Array, error) {
- return m.GenerateFromConfig(&GenerateConfig{
- Prompt: prompt,
- Width: width,
- Height: height,
- Steps: steps,
- Seed: seed,
- Progress: progress,
- })
-}
-
-// GenerateWithCFG creates an image with classifier-free guidance.
-func (m *Model) GenerateWithCFG(prompt, negativePrompt string, width, height int32, steps int, seed int64, cfgScale float32, progress func(step, totalSteps int)) (*mlx.Array, error) {
- return m.GenerateFromConfig(&GenerateConfig{
- Prompt: prompt,
- NegativePrompt: negativePrompt,
- CFGScale: cfgScale,
- Width: width,
- Height: height,
- Steps: steps,
- Seed: seed,
- Progress: progress,
- })
-}
-
-// GenerateFromConfig generates an image using the unified config struct.
-func (m *Model) GenerateFromConfig(cfg *GenerateConfig) (*mlx.Array, error) {
- start := time.Now()
- result, err := m.generate(cfg)
- if err != nil {
- return nil, err
- }
- if cfg.NegativePrompt != "" {
- fmt.Printf("Generated with CFG (scale=%.1f) in %.2fs (%d steps)\n", cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
- } else {
- fmt.Printf("Generated in %.2fs (%d steps)\n", time.Since(start).Seconds(), cfg.Steps)
- }
- return result, nil
-}
-
-// GenerateImage implements model.ImageModel interface.
-func (m *Model) GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
- return m.Generate(prompt, width, height, steps, seed)
-}
-
-// generate is the internal denoising pipeline.
-func (m *Model) generate(cfg *GenerateConfig) (*mlx.Array, error) {
- // Apply defaults
- if cfg.Width <= 0 {
- cfg.Width = 1024
- }
- if cfg.Height <= 0 {
- cfg.Height = 1024
- }
- if cfg.Steps <= 0 {
- cfg.Steps = 50
- }
- if cfg.CFGScale <= 0 {
- cfg.CFGScale = 4.0
- }
- if cfg.CacheInterval <= 0 {
- cfg.CacheInterval = 3
- }
- if cfg.CacheLayers <= 0 {
- cfg.CacheLayers = 25 // ~42% of 60 layers (similar ratio to Z-Image's 15/38)
- }
-
- useCFG := cfg.NegativePrompt != ""
- tcfg := m.Transformer.Config
- latentH := cfg.Height / 8
- latentW := cfg.Width / 8
- pH := latentH / tcfg.PatchSize
- pW := latentW / tcfg.PatchSize
- imgSeqLen := pH * pW
-
- // Text encoding
- var posEmb, negEmb *mlx.Array
- {
- posEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.Prompt)
- if useCFG {
- negEmb = m.TextEncoder.EncodePrompt(m.Tokenizer, cfg.NegativePrompt)
- mlx.Keep(posEmb, negEmb)
- mlx.Eval(posEmb, negEmb)
- } else {
- mlx.Keep(posEmb)
- mlx.Eval(posEmb)
- }
- }
-
- // Pad sequences to same length for CFG
- txtLen := posEmb.Shape()[1]
- if useCFG {
- negLen := negEmb.Shape()[1]
- if negLen > txtLen {
- txtLen = negLen
- }
- if posEmb.Shape()[1] < txtLen {
- posEmb = padSequence(posEmb, txtLen)
- }
- if negEmb.Shape()[1] < txtLen {
- negEmb = padSequence(negEmb, txtLen)
- }
- mlx.Keep(posEmb, negEmb)
- }
-
- // Pre-compute batched embeddings for CFG (single forward pass optimization)
- var batchedEmb *mlx.Array
- if useCFG {
- batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
- mlx.Keep(batchedEmb)
- mlx.Eval(batchedEmb)
- }
-
- // Scheduler
- scheduler := NewFlowMatchScheduler(DefaultSchedulerConfig())
- scheduler.SetTimesteps(cfg.Steps, imgSeqLen)
-
- // Init latents [B, C, T, H, W]
- var latents *mlx.Array
- {
- latents = scheduler.InitNoise([]int32{1, tcfg.OutChannels, 1, latentH, latentW}, cfg.Seed)
- mlx.Eval(latents)
- }
-
- // RoPE cache
- var ropeCache *RoPECache
- {
- ropeCache = PrepareRoPE(pH, pW, txtLen, tcfg.AxesDimsRope)
- mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
- mlx.Eval(ropeCache.ImgFreqs)
- }
-
- // Layer cache for DeepCache/Learning-to-Cache speedup
- var stepCache *cache.StepCache
- if cfg.LayerCache {
- stepCache = cache.NewStepCache(cfg.CacheLayers)
- fmt.Printf(" Layer caching: %d layers, refresh every %d steps\n", cfg.CacheLayers, cfg.CacheInterval)
- }
-
- // Denoising loop
- for i := 0; i < cfg.Steps; i++ {
- stepStart := time.Now()
- if cfg.Progress != nil {
- cfg.Progress(i+1, cfg.Steps)
- }
-
- t := scheduler.Timesteps[i]
- timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
-
- // Squeeze temporal dim: [B, C, T, H, W] -> [B, C, H, W]
- latents2D := mlx.Squeeze(latents, 2)
- patches := PackLatents(latents2D, tcfg.PatchSize)
-
- var output *mlx.Array
- if useCFG {
- // CFG Batching: single forward pass with batch=2
- // Note: layer caching with CFG is not supported yet (would need 2 caches)
- batchedPatches := mlx.Tile(patches, []int32{2, 1, 1})
- batchedTimestep := mlx.Tile(timestep, []int32{2})
-
- // Single batched forward pass
- batchedOutput := m.Transformer.Forward(batchedPatches, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-
- // Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
- L := batchedOutput.Shape()[1]
- D := batchedOutput.Shape()[2]
- posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, L, D})
- negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, L, D})
-
- diff := mlx.Sub(posOutput, negOutput)
- scaledDiff := mlx.MulScalar(diff, cfg.CFGScale)
- combPred := mlx.Add(negOutput, scaledDiff)
-
- // Norm rescaling: rescale combined prediction to match conditional prediction's norm
- condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posOutput), -1, true))
- combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
- output = mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
- } else if stepCache != nil {
- output = m.Transformer.ForwardWithCache(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs,
- stepCache, i, cfg.CacheInterval, cfg.CacheLayers)
- } else {
- output = m.Transformer.Forward(patches, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
- }
-
- noisePred := UnpackLatents(output, latentH, latentW, tcfg.PatchSize)
- oldLatents := latents
- latents = scheduler.Step(noisePred, latents, i)
-
- // Keep cached arrays alive across cleanup
- if stepCache != nil {
- mlx.Keep(stepCache.Arrays()...)
- }
- mlx.Eval(latents)
- oldLatents.Free()
-
- activeMem := float64(mlx.MetalGetActiveMemory()) / (1024 * 1024 * 1024)
- peakMem := float64(mlx.MetalGetPeakMemory()) / (1024 * 1024 * 1024)
- fmt.Printf(" Step %d/%d: t=%.4f (%.2fs) [%.1f GB active, %.1f GB peak]\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds(), activeMem, peakMem)
- }
-
- // Free denoising temporaries before VAE decode
- posEmb.Free()
- if negEmb != nil {
- negEmb.Free()
- }
- if batchedEmb != nil {
- batchedEmb.Free()
- }
- ropeCache.ImgFreqs.Free()
- ropeCache.TxtFreqs.Free()
- if stepCache != nil {
- stepCache.Free()
- }
-
- // VAE decode (Decode manages its own pools for staged memory)
- decoded := m.VAEDecoder.Decode(latents)
- latents.Free()
- // Post-process: squeeze temporal dim and rescale to [0, 1]
- {
- decoded = mlx.Squeeze(decoded, 2)
- decoded = mlx.AddScalar(decoded, 1.0)
- decoded = mlx.DivScalar(decoded, 2.0)
- mlx.Eval(decoded)
- }
-
- fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- return decoded, nil
-}
-
-// padSequence pads a sequence tensor to the target length with zeros
-func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
- shape := x.Shape()
- currentLen := shape[1]
- if currentLen >= targetLen {
- return x
- }
- padLen := targetLen - currentLen
- // Pad on sequence dimension (axis 1)
- return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
-}
-
-// LoadPersistent is an alias for backward compatibility.
-// Use m := &Model{}; m.Load(path) instead.
-func LoadPersistent(modelPath string) (*Model, error) {
- m := &Model{}
- if err := m.Load(modelPath); err != nil {
- return nil, err
- }
- return m, nil
-}
diff --git a/x/imagegen/models/qwen_image/scheduler.go b/x/imagegen/models/qwen_image/scheduler.go
deleted file mode 100644
index d1f0da049ed..00000000000
--- a/x/imagegen/models/qwen_image/scheduler.go
+++ /dev/null
@@ -1,218 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "math"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
-)
-
-// SchedulerConfig holds FlowMatchEulerDiscreteScheduler configuration
-type SchedulerConfig struct {
- NumTrainTimesteps int32 `json:"num_train_timesteps"` // 1000
- BaseShift float32 `json:"base_shift"` // 0.5
- MaxShift float32 `json:"max_shift"` // 0.9
- BaseImageSeqLen int32 `json:"base_image_seq_len"` // 256
- MaxImageSeqLen int32 `json:"max_image_seq_len"` // 8192
- ShiftTerminal float32 `json:"shift_terminal"` // 0.02
- UseDynamicShift bool `json:"use_dynamic_shifting"` // true
-}
-
-// DefaultSchedulerConfig returns config for FlowMatchEulerDiscreteScheduler
-func DefaultSchedulerConfig() *SchedulerConfig {
- return &SchedulerConfig{
- NumTrainTimesteps: 1000,
- BaseShift: 0.5,
- MaxShift: 0.9, // Matches scheduler_config.json
- BaseImageSeqLen: 256,
- MaxImageSeqLen: 8192,
- ShiftTerminal: 0.02,
- UseDynamicShift: true,
- }
-}
-
-// FlowMatchScheduler implements the Flow Match Euler discrete scheduler
-type FlowMatchScheduler struct {
- Config *SchedulerConfig
- Timesteps []float32
- Sigmas []float32
- NumSteps int
-}
-
-// NewFlowMatchScheduler creates a new scheduler
-func NewFlowMatchScheduler(cfg *SchedulerConfig) *FlowMatchScheduler {
- return &FlowMatchScheduler{
- Config: cfg,
- }
-}
-
-// CalculateShift computes the dynamic shift based on image sequence length
-// This matches Python's calculate_shift function
-func CalculateShift(imageSeqLen int32, baseSeqLen int32, maxSeqLen int32, baseShift float32, maxShift float32) float32 {
- m := (maxShift - baseShift) / float32(maxSeqLen-baseSeqLen)
- b := baseShift - m*float32(baseSeqLen)
- mu := float32(imageSeqLen)*m + b
- return mu
-}
-
-// SetTimesteps sets up the scheduler for the given number of inference steps
-// Matches Python diffusers FlowMatchEulerDiscreteScheduler behavior:
-// 1. Create sigmas from sigma_max to sigma_min (linspace)
-// 2. Apply time_shift with mu (if dynamic shifting)
-// 3. Apply stretch_shift_to_terminal to make final value = shift_terminal
-func (s *FlowMatchScheduler) SetTimesteps(numSteps int, imageSeqLen int32) {
- s.NumSteps = numSteps
-
- // Calculate mu for dynamic shifting
- var mu float32
- if s.Config.UseDynamicShift {
- mu = CalculateShift(
- imageSeqLen,
- s.Config.BaseImageSeqLen,
- s.Config.MaxImageSeqLen,
- s.Config.BaseShift,
- s.Config.MaxShift,
- )
- }
-
- // Step 1: Create sigmas from 1.0 to 1/num_steps
- // Python (pipeline_qwenimage.py:639):
- // sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
- // This gives sigmas from 1.0 to 1/30 = 0.033 for 30 steps
- sigmas := make([]float32, numSteps)
- sigmaMax := float32(1.0)
- sigmaMin := 1.0 / float32(numSteps) // 1/30 = 0.033 for 30 steps
- if numSteps == 1 {
- sigmas[0] = sigmaMax
- } else {
- for i := 0; i < numSteps; i++ {
- sigmas[i] = sigmaMax + float32(i)*(sigmaMin-sigmaMax)/float32(numSteps-1)
- }
- }
-
- // Step 2: Apply time shift if using dynamic shifting
- if s.Config.UseDynamicShift && mu != 0 {
- for i := range sigmas {
- sigmas[i] = s.timeShift(mu, sigmas[i])
- }
- }
-
- // Step 3: Apply stretch_shift_to_terminal
- if s.Config.ShiftTerminal > 0 {
- sigmas = s.stretchShiftToTerminal(sigmas)
- }
-
- // Step 4: Append terminal sigma (0) and store
- // Note: Python's scheduler.timesteps are sigmas*1000, but the pipeline divides by 1000
- // before passing to transformer. We skip both steps and just use sigmas directly.
- s.Sigmas = make([]float32, numSteps+1)
- s.Timesteps = make([]float32, numSteps+1)
- for i := 0; i < numSteps; i++ {
- s.Sigmas[i] = sigmas[i]
- s.Timesteps[i] = sigmas[i]
- }
- s.Sigmas[numSteps] = 0.0
- s.Timesteps[numSteps] = 0.0
-}
-
-// stretchShiftToTerminal stretches and shifts the timestep schedule
-// so the final value equals shift_terminal (matches Python behavior)
-func (s *FlowMatchScheduler) stretchShiftToTerminal(sigmas []float32) []float32 {
- if len(sigmas) == 0 {
- return sigmas
- }
-
- // one_minus_z = 1 - t
- // scale_factor = one_minus_z[-1] / (1 - shift_terminal)
- // stretched_t = 1 - (one_minus_z / scale_factor)
- lastSigma := sigmas[len(sigmas)-1]
- scaleFactor := (1.0 - lastSigma) / (1.0 - s.Config.ShiftTerminal)
-
- // Handle edge case: if scaleFactor is 0 or near 0, skip stretch
- // This happens when lastSigma ≈ 1.0 (e.g., single step with timeshift)
- if scaleFactor < 1e-6 {
- return sigmas
- }
-
- result := make([]float32, len(sigmas))
- for i, t := range sigmas {
- oneMinusZ := 1.0 - t
- result[i] = 1.0 - (oneMinusZ / scaleFactor)
- }
- return result
-}
-
-// timeShift applies the dynamic time shift (exponential)
-// exp(mu) / (exp(mu) + (1/t - 1))
-func (s *FlowMatchScheduler) timeShift(mu float32, t float32) float32 {
- if t <= 0 {
- return 0
- }
- expMu := float32(math.Exp(float64(mu)))
- return expMu / (expMu + (1.0/t - 1.0))
-}
-
-// Step performs one denoising step
-// modelOutput: predicted velocity from the transformer
-// sample: current noisy sample
-// timestepIdx: current timestep index
-func (s *FlowMatchScheduler) Step(modelOutput, sample *mlx.Array, timestepIdx int) *mlx.Array {
- // Get current and next sigma
- sigma := s.Sigmas[timestepIdx]
- sigmaNext := s.Sigmas[timestepIdx+1]
-
- // Euler step: x_{t-dt} = x_t + (sigma_next - sigma) * v_t
- dt := sigmaNext - sigma
-
- // Upcast to float32 to avoid precision issues (matches Python diffusers)
- sampleF32 := mlx.AsType(sample, mlx.DtypeFloat32)
- modelOutputF32 := mlx.AsType(modelOutput, mlx.DtypeFloat32)
-
- scaledOutput := mlx.MulScalar(modelOutputF32, dt)
- result := mlx.Add(sampleF32, scaledOutput)
-
- // Cast back to original dtype
- return mlx.ToBFloat16(result)
-}
-
-// GetTimestep returns the timestep value at the given index
-func (s *FlowMatchScheduler) GetTimestep(idx int) float32 {
- if idx < len(s.Timesteps) {
- return s.Timesteps[idx]
- }
- return 0.0
-}
-
-// InitNoise creates initial noise for sampling in unpacked format [B, C, T, H, W]
-func (s *FlowMatchScheduler) InitNoise(shape []int32, seed int64) *mlx.Array {
- return mlx.RandomNormal(shape, uint64(seed))
-}
-
-// InitNoisePacked creates initial noise directly in packed format [B, L, C*4]
-// This matches how Python diffusers generates noise - directly in packed space.
-// Generating in unpacked format and then packing produces different spatial
-// correlation structure, which affects model output quality.
-func (s *FlowMatchScheduler) InitNoisePacked(batchSize, seqLen, channels int32, seed int64) *mlx.Array {
- shape := []int32{batchSize, seqLen, channels}
- return mlx.RandomNormal(shape, uint64(seed))
-}
-
-// GetLatentShape returns the latent shape for a given image size
-// For qwen_image: VAE downscale is 8x (spatial), latent has 16 channels
-func GetLatentShape(batchSize, height, width int32) []int32 {
- latentH := height / 8
- latentW := width / 8
- return []int32{batchSize, 16, 1, latentH, latentW} // [B, C, T, H, W]
-}
-
-// GetPatchedLatentShape returns the patchified latent shape
-// After patchification: [B, L, C*patch_size^2] where L = H/2 * W/2
-func GetPatchedLatentShape(batchSize, height, width, patchSize int32) []int32 {
- latentH := height / 8
- latentW := width / 8
- pH := latentH / patchSize
- pW := latentW / patchSize
- inChannels := int32(64) // 16 * patch_size^2
- return []int32{batchSize, pH * pW, inChannels}
-}
diff --git a/x/imagegen/models/qwen_image/scheduler_test.go b/x/imagegen/models/qwen_image/scheduler_test.go
deleted file mode 100644
index 46adeb99a26..00000000000
--- a/x/imagegen/models/qwen_image/scheduler_test.go
+++ /dev/null
@@ -1,135 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "math"
- "testing"
-)
-
-// TestSchedulerSetTimesteps verifies scheduler sigmas match Python diffusers reference.
-// Golden values generated via:
-//
-// python3 -c "
-// from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
-// import numpy as np
-// s = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, base_shift=0.5, max_shift=0.9,
-// base_image_seq_len=256, max_image_seq_len=8192, shift_terminal=0.02, use_dynamic_shifting=True)
-// mu = 4096 * (0.9-0.5)/(8192-256) + 0.5 - (0.9-0.5)/(8192-256)*256
-// sigmas = np.linspace(1.0, 1.0/30, 30)
-// s.set_timesteps(sigmas=sigmas, mu=mu)
-// print(s.sigmas.numpy())"
-func TestSchedulerSetTimesteps(t *testing.T) {
- cfg := DefaultSchedulerConfig()
- scheduler := NewFlowMatchScheduler(cfg)
- scheduler.SetTimesteps(30, 4096)
-
- // Golden values from Python diffusers (first 3, last 3 before terminal)
- wantFirst := []float32{1.000000, 0.982251, 0.963889}
- wantLast := []float32{0.142924, 0.083384, 0.020000}
-
- // Check first 3
- for i, want := range wantFirst {
- got := scheduler.Sigmas[i]
- if abs32(got-want) > 1e-4 {
- t.Errorf("sigma[%d]: got %v, want %v", i, got, want)
- }
- }
-
- // Check last 3 (indices 27, 28, 29)
- for i, want := range wantLast {
- idx := 27 + i
- got := scheduler.Sigmas[idx]
- if abs32(got-want) > 1e-4 {
- t.Errorf("sigma[%d]: got %v, want %v", idx, got, want)
- }
- }
-
- // Check terminal is 0
- if scheduler.Sigmas[30] != 0.0 {
- t.Errorf("terminal sigma: got %v, want 0", scheduler.Sigmas[30])
- }
-
- // Check length
- if len(scheduler.Sigmas) != 31 {
- t.Errorf("sigmas length: got %d, want 31", len(scheduler.Sigmas))
- }
-}
-
-// TestSchedulerProperties tests mathematical invariants of the scheduler.
-func TestSchedulerProperties(t *testing.T) {
- cfg := DefaultSchedulerConfig()
- scheduler := NewFlowMatchScheduler(cfg)
- scheduler.SetTimesteps(30, 4096)
-
- // Property: sigmas monotonically decreasing
- for i := 1; i < len(scheduler.Sigmas); i++ {
- if scheduler.Sigmas[i] > scheduler.Sigmas[i-1] {
- t.Errorf("sigmas not monotonically decreasing at %d: %v > %v",
- i, scheduler.Sigmas[i], scheduler.Sigmas[i-1])
- }
- }
-
- // Property: first sigma should be ~1.0 (with time shift)
- if scheduler.Sigmas[0] < 0.9 || scheduler.Sigmas[0] > 1.01 {
- t.Errorf("first sigma out of expected range [0.9, 1.01]: %v", scheduler.Sigmas[0])
- }
-
- // Property: terminal sigma should be exactly 0
- if scheduler.Sigmas[len(scheduler.Sigmas)-1] != 0.0 {
- t.Errorf("terminal sigma should be 0, got %v", scheduler.Sigmas[len(scheduler.Sigmas)-1])
- }
-
- // Property: last non-terminal sigma should be shift_terminal (0.02)
- lastNonTerminal := scheduler.Sigmas[len(scheduler.Sigmas)-2]
- if abs32(lastNonTerminal-0.02) > 1e-5 {
- t.Errorf("last non-terminal sigma should be 0.02, got %v", lastNonTerminal)
- }
-
- // Property: length = steps + 1
- if len(scheduler.Sigmas) != scheduler.NumSteps+1 {
- t.Errorf("sigmas length should be steps+1: got %d, want %d",
- len(scheduler.Sigmas), scheduler.NumSteps+1)
- }
-}
-
-// TestCalculateShift verifies the mu calculation against Python reference.
-// Golden values from: mu = img_seq_len * m + b where m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
-func TestCalculateShift(t *testing.T) {
- cases := []struct {
- imgSeqLen int32
- want float32
- }{
- {256, 0.5}, // base case
- {8192, 0.9}, // max case
- {4096, 0.6935}, // middle case (rounded)
- }
-
- for _, c := range cases {
- got := CalculateShift(c.imgSeqLen, 256, 8192, 0.5, 0.9)
- if abs32(got-c.want) > 0.001 {
- t.Errorf("CalculateShift(%d): got %v, want %v", c.imgSeqLen, got, c.want)
- }
- }
-}
-
-// TestSchedulerStep verifies the Euler step formula.
-func TestSchedulerStep(t *testing.T) {
- cfg := DefaultSchedulerConfig()
- scheduler := NewFlowMatchScheduler(cfg)
- scheduler.SetTimesteps(30, 4096)
-
- // Verify dt calculation for first step
- sigma0 := scheduler.Sigmas[0]
- sigma1 := scheduler.Sigmas[1]
- expectedDt := sigma1 - sigma0
-
- // dt should be negative (sigmas decrease)
- if expectedDt >= 0 {
- t.Errorf("expected negative dt, got %v (sigma0=%v, sigma1=%v)", expectedDt, sigma0, sigma1)
- }
-}
-
-func abs32(x float32) float32 {
- return float32(math.Abs(float64(x)))
-}
diff --git a/x/imagegen/models/qwen_image/text_encoder_test.go b/x/imagegen/models/qwen_image/text_encoder_test.go
deleted file mode 100644
index 7704513c856..00000000000
--- a/x/imagegen/models/qwen_image/text_encoder_test.go
+++ /dev/null
@@ -1,174 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "encoding/json"
- "math"
- "os"
- "path/filepath"
- "slices"
- "testing"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/safetensors"
-)
-
-// TinyTextEncoderConfig holds config for the tiny test text encoder
-type TinyTextEncoderConfig struct {
- HiddenSize int32 `json:"hidden_size"`
- NumHiddenLayers int32 `json:"num_hidden_layers"`
- IntermediateSize int32 `json:"intermediate_size"`
- NumAttentionHeads int32 `json:"num_attention_heads"`
- NumKeyValueHeads int32 `json:"num_key_value_heads"`
- VocabSize int32 `json:"vocab_size"`
- RMSNormEps float32 `json:"rms_norm_eps"`
- RopeTheta float32 `json:"rope_theta"`
- HeadDim int32 `json:"head_dim"`
- MRoPESection []int32 `json:"mrope_section"`
-}
-
-// loadTinyTextEncoder loads the tiny text encoder from testdata
-func loadTinyTextEncoder(t *testing.T) (*Qwen25VL, *TinyTextEncoderConfig) {
- t.Helper()
-
- testdataDir := filepath.Join("testdata", "tiny_text_encoder")
-
- // Load config
- configData, err := os.ReadFile(filepath.Join(testdataDir, "config.json"))
- if err != nil {
- t.Skipf("Skipping: tiny weights not found. Regenerate with Python (see models/CLAUDE.md)")
- }
-
- var tinyCfg TinyTextEncoderConfig
- if err := json.Unmarshal(configData, &tinyCfg); err != nil {
- t.Fatalf("Failed to parse config: %v", err)
- }
-
- // Create encoder config (using Qwen25VLConfig)
- cfg := &Qwen25VLConfig{
- HiddenSize: tinyCfg.HiddenSize,
- NumHiddenLayers: tinyCfg.NumHiddenLayers,
- IntermediateSize: tinyCfg.IntermediateSize,
- NumAttentionHeads: tinyCfg.NumAttentionHeads,
- NumKeyValueHeads: tinyCfg.NumKeyValueHeads,
- VocabSize: tinyCfg.VocabSize,
- RMSNormEps: tinyCfg.RMSNormEps,
- RopeTheta: tinyCfg.RopeTheta,
- HeadDim: tinyCfg.HeadDim,
- MRoPESection: tinyCfg.MRoPESection,
- }
-
- // Load weights
- weights, err := safetensors.LoadModelWeights(testdataDir)
- if err != nil {
- t.Fatalf("Failed to load weights: %v", err)
- }
-
- if err := weights.Load(mlx.DtypeBFloat16); err != nil {
- t.Fatalf("Failed to bulk load weights: %v", err)
- }
-
- // Build encoder
- embedding, err := weights.Get("model.embed_tokens.weight")
- if err != nil {
- t.Fatalf("Failed to get embedding: %v", err)
- }
-
- blocks := make([]*VLTextBlock, cfg.NumHiddenLayers)
- for i := int32(0); i < cfg.NumHiddenLayers; i++ {
- block, err := newVLTextBlock(weights, int(i), cfg)
- if err != nil {
- t.Fatalf("Failed to load block %d: %v", i, err)
- }
- blocks[i] = block
- }
-
- finalNorm, err := weights.Get("model.norm.weight")
- if err != nil {
- t.Fatalf("Failed to get final norm: %v", err)
- }
-
- encoder := &Qwen25VL{
- Config: cfg,
- Embedding: embedding,
- Blocks: blocks,
- FinalNorm: finalNorm,
- HasVision: false, // Text-only mode
- }
-
- return encoder, &tinyCfg
-}
-
-// TestTextEncoderForward verifies the text encoder forward pass with tiny weights.
-func TestTextEncoderForward(t *testing.T) {
- encoder, cfg := loadTinyTextEncoder(t)
-
- // Create test tokens (within vocab range)
- tokens := []int32{1, 2, 3, 4, 5}
-
- // Forward pass using EncodeTextOnly
- out := encoder.EncodeTextOnly(tokens)
- mlx.Eval(out)
-
- // Verify output shape: [batch, seq_len, hidden_size]
- wantShape := []int32{1, 5, cfg.HiddenSize}
- if !slices.Equal(out.Shape(), wantShape) {
- t.Errorf("output shape: got %v, want %v", out.Shape(), wantShape)
- }
-
- // Verify output is finite (not NaN or Inf)
- data := out.Data()
- for i, v := range data {
- if math.IsNaN(float64(v)) || math.IsInf(float64(v), 0) {
- t.Errorf("output[%d] is not finite: %v", i, v)
- break
- }
- }
-}
-
-// TestTextEncoderBatch tests batch processing.
-func TestTextEncoderBatch(t *testing.T) {
- encoder, cfg := loadTinyTextEncoder(t)
-
- // For batch test, we'll use EncodeTextOnly with a single sequence
- // (EncodeTextOnly doesn't support batch, but we can verify single sequence works)
- tokens := []int32{1, 2, 3}
-
- out := encoder.EncodeTextOnly(tokens)
- mlx.Eval(out)
-
- wantShape := []int32{1, 3, cfg.HiddenSize}
- if !slices.Equal(out.Shape(), wantShape) {
- t.Errorf("shape: got %v, want %v", out.Shape(), wantShape)
- }
-}
-
-// TestMRoPEComputation verifies M-RoPE frequency computation produces valid values.
-func TestMRoPEComputation(t *testing.T) {
- encoder, cfg := loadTinyTextEncoder(t)
-
- cossin := encoder.computeTextRoPE(10, 1)
- mlx.Eval(cossin[0], cossin[1])
-
- // Verify shapes: [3, B, L, head_dim]
- wantShape := []int32{3, 1, 10, cfg.HeadDim}
- if !slices.Equal(cossin[0].Shape(), wantShape) {
- t.Errorf("cos shape: got %v, want %v", cossin[0].Shape(), wantShape)
- }
- if !slices.Equal(cossin[1].Shape(), wantShape) {
- t.Errorf("sin shape: got %v, want %v", cossin[1].Shape(), wantShape)
- }
-
- // Verify cos/sin values are in valid range [-1, 1]
- cosData := cossin[0].Data()
- sinData := cossin[1].Data()
- for i := 0; i < min(100, len(cosData)); i++ {
- if cosData[i] < -1.01 || cosData[i] > 1.01 {
- t.Errorf("cos[%d] out of range: %v", i, cosData[i])
- }
- if sinData[i] < -1.01 || sinData[i] > 1.01 {
- t.Errorf("sin[%d] out of range: %v", i, sinData[i])
- }
- }
-}
diff --git a/x/imagegen/models/qwen_image/transformer.go b/x/imagegen/models/qwen_image/transformer.go
deleted file mode 100644
index 06e677619e1..00000000000
--- a/x/imagegen/models/qwen_image/transformer.go
+++ /dev/null
@@ -1,868 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "fmt"
- "math"
- "path/filepath"
-
- "github.com/ollama/ollama/x/imagegen/cache"
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/safetensors"
-)
-
-// TransformerConfig holds Qwen-Image transformer configuration
-type TransformerConfig struct {
- HiddenDim int32 `json:"hidden_dim"` // 3072 (24 * 128)
- NHeads int32 `json:"num_attention_heads"` // 24
- HeadDim int32 `json:"attention_head_dim"` // 128
- NLayers int32 `json:"num_layers"` // 60
- InChannels int32 `json:"in_channels"` // 64
- OutChannels int32 `json:"out_channels"` // 16
- PatchSize int32 `json:"patch_size"` // 2
- JointAttentionDim int32 `json:"joint_attention_dim"` // 3584 (text encoder dim)
- NormEps float32 `json:"norm_eps"` // 1e-6
- AxesDimsRope []int32 `json:"axes_dims_rope"` // [16, 56, 56]
- GuidanceEmbeds bool `json:"guidance_embeds"` // false
-}
-
-// defaultTransformerConfig returns config for Qwen-Image transformer
-func defaultTransformerConfig() *TransformerConfig {
- return &TransformerConfig{
- HiddenDim: 3072, // 24 * 128
- NHeads: 24,
- HeadDim: 128,
- NLayers: 60,
- InChannels: 64,
- OutChannels: 16,
- PatchSize: 2,
- JointAttentionDim: 3584,
- NormEps: 1e-6,
- AxesDimsRope: []int32{16, 56, 56},
- GuidanceEmbeds: false,
- }
-}
-
-// TimestepEmbedder creates timestep embeddings
-type TimestepEmbedder struct {
- Linear1Weight *mlx.Array // [256, hidden_dim]
- Linear1Bias *mlx.Array
- Linear2Weight *mlx.Array // [hidden_dim, hidden_dim]
- Linear2Bias *mlx.Array
-}
-
-// newTimestepEmbedder creates a timestep embedder from weights
-func newTimestepEmbedder(weights *safetensors.ModelWeights) (*TimestepEmbedder, error) {
- linear1Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_1.weight")
- if err != nil {
- return nil, err
- }
- linear1Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_1.bias")
- if err != nil {
- return nil, err
- }
- linear2Weight, err := weights.Get("time_text_embed.timestep_embedder.linear_2.weight")
- if err != nil {
- return nil, err
- }
- linear2Bias, err := weights.Get("time_text_embed.timestep_embedder.linear_2.bias")
- if err != nil {
- return nil, err
- }
-
- return &TimestepEmbedder{
- Linear1Weight: mlx.Transpose(linear1Weight, 1, 0),
- Linear1Bias: linear1Bias,
- Linear2Weight: mlx.Transpose(linear2Weight, 1, 0),
- Linear2Bias: linear2Bias,
- }, nil
-}
-
-// Forward computes timestep embeddings
-// t: [B] timesteps (normalized 0-1, will be scaled by 1000 internally)
-func (te *TimestepEmbedder) Forward(t *mlx.Array) *mlx.Array {
- half := int32(128) // embedding_dim / 2
-
- // Sinusoidal embedding with flip_sin_to_cos=True, scale=1000
- freqs := make([]float32, half)
- for i := int32(0); i < half; i++ {
- freqs[i] = float32(math.Exp(-math.Log(10000.0) * float64(i) / float64(half)))
- }
- freqsArr := mlx.NewArray(freqs, []int32{1, half})
-
- tExpanded := mlx.ExpandDims(t, 1)
- args := mlx.Mul(tExpanded, freqsArr)
- args = mlx.MulScalar(args, 1000.0) // scale
-
- // [cos, sin] (flip_sin_to_cos=True)
- sinArgs := mlx.Sin(args)
- cosArgs := mlx.Cos(args)
- embedding := mlx.Concatenate([]*mlx.Array{cosArgs, sinArgs}, 1) // [B, 256]
-
- // MLP: linear1 -> silu -> linear2
- h := mlx.Linear(embedding, te.Linear1Weight)
- h = mlx.Add(h, te.Linear1Bias)
- h = mlx.SiLU(h)
- h = mlx.Linear(h, te.Linear2Weight)
- h = mlx.Add(h, te.Linear2Bias)
-
- return h
-}
-
-// JointAttention implements dual-stream joint attention
-type JointAttention struct {
- // Image projections
- ToQ *mlx.Array
- ToQB *mlx.Array
- ToK *mlx.Array
- ToKB *mlx.Array
- ToV *mlx.Array
- ToVB *mlx.Array
- ToOut *mlx.Array
- ToOutB *mlx.Array
- NormQ *mlx.Array
- NormK *mlx.Array
-
- // Text (added) projections
- AddQProj *mlx.Array
- AddQProjB *mlx.Array
- AddKProj *mlx.Array
- AddKProjB *mlx.Array
- AddVProj *mlx.Array
- AddVProjB *mlx.Array
- ToAddOut *mlx.Array
- ToAddOutB *mlx.Array
- NormAddQ *mlx.Array
- NormAddK *mlx.Array
-
- NHeads int32
- HeadDim int32
- Scale float32
-}
-
-// newJointAttention creates a joint attention layer
-func newJointAttention(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*JointAttention, error) {
- toQ, _ := weights.Get(prefix + ".attn.to_q.weight")
- toQB, _ := weights.Get(prefix + ".attn.to_q.bias")
- toK, _ := weights.Get(prefix + ".attn.to_k.weight")
- toKB, _ := weights.Get(prefix + ".attn.to_k.bias")
- toV, _ := weights.Get(prefix + ".attn.to_v.weight")
- toVB, _ := weights.Get(prefix + ".attn.to_v.bias")
- toOut, _ := weights.Get(prefix + ".attn.to_out.0.weight")
- toOutB, _ := weights.Get(prefix + ".attn.to_out.0.bias")
- normQ, _ := weights.Get(prefix + ".attn.norm_q.weight")
- normK, _ := weights.Get(prefix + ".attn.norm_k.weight")
-
- addQProj, _ := weights.Get(prefix + ".attn.add_q_proj.weight")
- addQProjB, _ := weights.Get(prefix + ".attn.add_q_proj.bias")
- addKProj, _ := weights.Get(prefix + ".attn.add_k_proj.weight")
- addKProjB, _ := weights.Get(prefix + ".attn.add_k_proj.bias")
- addVProj, _ := weights.Get(prefix + ".attn.add_v_proj.weight")
- addVProjB, _ := weights.Get(prefix + ".attn.add_v_proj.bias")
- toAddOut, _ := weights.Get(prefix + ".attn.to_add_out.weight")
- toAddOutB, _ := weights.Get(prefix + ".attn.to_add_out.bias")
- normAddQ, _ := weights.Get(prefix + ".attn.norm_added_q.weight")
- normAddK, _ := weights.Get(prefix + ".attn.norm_added_k.weight")
-
- return &JointAttention{
- ToQ: mlx.Transpose(toQ, 1, 0),
- ToQB: toQB,
- ToK: mlx.Transpose(toK, 1, 0),
- ToKB: toKB,
- ToV: mlx.Transpose(toV, 1, 0),
- ToVB: toVB,
- ToOut: mlx.Transpose(toOut, 1, 0),
- ToOutB: toOutB,
- NormQ: normQ,
- NormK: normK,
- AddQProj: mlx.Transpose(addQProj, 1, 0),
- AddQProjB: addQProjB,
- AddKProj: mlx.Transpose(addKProj, 1, 0),
- AddKProjB: addKProjB,
- AddVProj: mlx.Transpose(addVProj, 1, 0),
- AddVProjB: addVProjB,
- ToAddOut: mlx.Transpose(toAddOut, 1, 0),
- ToAddOutB: toAddOutB,
- NormAddQ: normAddQ,
- NormAddK: normAddK,
- NHeads: cfg.NHeads,
- HeadDim: cfg.HeadDim,
- Scale: float32(1.0 / math.Sqrt(float64(cfg.HeadDim))),
- }, nil
-}
-
-// Forward computes joint attention
-// img: [B, L_img, D], txt: [B, L_txt, D]
-// imgFreqs, txtFreqs: complex RoPE frequencies [L, head_dim/2] as interleaved real/imag
-func (attn *JointAttention) Forward(img, txt *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
- imgShape := img.Shape()
- B := imgShape[0]
- Limg := imgShape[1]
- D := imgShape[2]
-
- txtShape := txt.Shape()
- Ltxt := txtShape[1]
-
- // === Image Q/K/V ===
- imgFlat := mlx.Reshape(img, B*Limg, D)
- qImg := mlx.Add(mlx.Linear(imgFlat, attn.ToQ), attn.ToQB)
- kImg := mlx.Add(mlx.Linear(imgFlat, attn.ToK), attn.ToKB)
- vImg := mlx.Add(mlx.Linear(imgFlat, attn.ToV), attn.ToVB)
-
- qImg = mlx.Reshape(qImg, B, Limg, attn.NHeads, attn.HeadDim)
- kImg = mlx.Reshape(kImg, B, Limg, attn.NHeads, attn.HeadDim)
- vImg = mlx.Reshape(vImg, B, Limg, attn.NHeads, attn.HeadDim)
-
- // QK norm (RMSNorm per head)
- qImg = mlx.RMSNorm(qImg, attn.NormQ, 1e-6)
- kImg = mlx.RMSNorm(kImg, attn.NormK, 1e-6)
-
- // Apply RoPE
- if imgFreqs != nil {
- qImg = applyRoPE(qImg, imgFreqs)
- kImg = applyRoPE(kImg, imgFreqs)
- }
-
- // === Text Q/K/V ===
- txtFlat := mlx.Reshape(txt, B*Ltxt, D)
- qTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddQProj), attn.AddQProjB)
- kTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddKProj), attn.AddKProjB)
- vTxt := mlx.Add(mlx.Linear(txtFlat, attn.AddVProj), attn.AddVProjB)
-
- qTxt = mlx.Reshape(qTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
- kTxt = mlx.Reshape(kTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
- vTxt = mlx.Reshape(vTxt, B, Ltxt, attn.NHeads, attn.HeadDim)
-
- qTxt = mlx.RMSNorm(qTxt, attn.NormAddQ, 1e-6)
- kTxt = mlx.RMSNorm(kTxt, attn.NormAddK, 1e-6)
-
- if txtFreqs != nil {
- qTxt = applyRoPE(qTxt, txtFreqs)
- kTxt = applyRoPE(kTxt, txtFreqs)
- }
-
- // Concatenate for joint attention: [txt, img] order
- qJoint := mlx.Concatenate([]*mlx.Array{qTxt, qImg}, 1)
- kJoint := mlx.Concatenate([]*mlx.Array{kTxt, kImg}, 1)
- vJoint := mlx.Concatenate([]*mlx.Array{vTxt, vImg}, 1)
-
- // Transpose to [B, nheads, L, head_dim]
- qJoint = mlx.Transpose(qJoint, 0, 2, 1, 3)
- kJoint = mlx.Transpose(kJoint, 0, 2, 1, 3)
- vJoint = mlx.Transpose(vJoint, 0, 2, 1, 3)
-
- // SDPA
- outJoint := mlx.ScaledDotProductAttention(qJoint, kJoint, vJoint, attn.Scale, false)
-
- // Transpose back and split
- outJoint = mlx.Transpose(outJoint, 0, 2, 1, 3) // [B, L, nheads, head_dim]
- outJoint = mlx.Reshape(outJoint, B, Ltxt+Limg, D)
-
- outTxt := mlx.Slice(outJoint, []int32{0, 0, 0}, []int32{B, Ltxt, D})
- outImg := mlx.Slice(outJoint, []int32{0, Ltxt, 0}, []int32{B, Ltxt + Limg, D})
-
- // Output projections
- outImg = mlx.Reshape(outImg, B*Limg, D)
- outImg = mlx.Add(mlx.Linear(outImg, attn.ToOut), attn.ToOutB)
- outImg = mlx.Reshape(outImg, B, Limg, D)
-
- outTxt = mlx.Reshape(outTxt, B*Ltxt, D)
- outTxt = mlx.Add(mlx.Linear(outTxt, attn.ToAddOut), attn.ToAddOutB)
- outTxt = mlx.Reshape(outTxt, B, Ltxt, D)
-
- return outImg, outTxt
-}
-
-// applyRoPE applies rotary embeddings using complex multiplication
-// x: [B, L, nheads, head_dim]
-// freqs: [L, head_dim] as complex (interleaved real/imag pairs)
-func applyRoPE(x *mlx.Array, freqs *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- L := shape[1]
- nheads := shape[2]
- headDim := shape[3]
- halfDim := headDim / 2
-
- // Reshape x to pairs: [B, L, nheads, half, 2]
- xPairs := mlx.Reshape(x, B, L, nheads, halfDim, 2)
-
- // freqs: [L, head_dim] -> [1, L, 1, half, 2]
- freqsExp := mlx.Reshape(freqs, 1, L, 1, halfDim, 2)
-
- // Extract real/imag parts
- xReal := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 0}, []int32{B, L, nheads, halfDim, 1}, []int32{1, 1, 1, 1, 1})
- xImag := mlx.SliceStride(xPairs, []int32{0, 0, 0, 0, 1}, []int32{B, L, nheads, halfDim, 2}, []int32{1, 1, 1, 1, 1})
- xReal = mlx.Squeeze(xReal, 4)
- xImag = mlx.Squeeze(xImag, 4)
-
- freqReal := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 0}, []int32{1, L, 1, halfDim, 1}, []int32{1, 1, 1, 1, 1})
- freqImag := mlx.SliceStride(freqsExp, []int32{0, 0, 0, 0, 1}, []int32{1, L, 1, halfDim, 2}, []int32{1, 1, 1, 1, 1})
- freqReal = mlx.Squeeze(freqReal, 4)
- freqImag = mlx.Squeeze(freqImag, 4)
-
- // Complex multiplication: (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
- outReal := mlx.Sub(mlx.Mul(xReal, freqReal), mlx.Mul(xImag, freqImag))
- outImag := mlx.Add(mlx.Mul(xReal, freqImag), mlx.Mul(xImag, freqReal))
-
- // Interleave back
- outReal = mlx.ExpandDims(outReal, 4)
- outImag = mlx.ExpandDims(outImag, 4)
- out := mlx.Concatenate([]*mlx.Array{outReal, outImag}, 4)
-
- return mlx.Reshape(out, B, L, nheads, headDim)
-}
-
-// MLP implements GELU MLP (not GEGLU)
-type MLP struct {
- ProjWeight *mlx.Array
- ProjBias *mlx.Array
- OutWeight *mlx.Array
- OutBias *mlx.Array
-}
-
-// newMLP creates a GELU MLP
-func newMLP(weights *safetensors.ModelWeights, prefix string) (*MLP, error) {
- projWeight, _ := weights.Get(prefix + ".net.0.proj.weight")
- projBias, _ := weights.Get(prefix + ".net.0.proj.bias")
- outWeight, _ := weights.Get(prefix + ".net.2.weight")
- outBias, _ := weights.Get(prefix + ".net.2.bias")
-
- return &MLP{
- ProjWeight: mlx.Transpose(projWeight, 1, 0),
- ProjBias: projBias,
- OutWeight: mlx.Transpose(outWeight, 1, 0),
- OutBias: outBias,
- }, nil
-}
-
-// Forward applies GELU MLP
-func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- L := shape[1]
- D := shape[2]
-
- xFlat := mlx.Reshape(x, B*L, D)
- h := mlx.Add(mlx.Linear(xFlat, m.ProjWeight), m.ProjBias)
- h = geluApprox(h)
- h = mlx.Add(mlx.Linear(h, m.OutWeight), m.OutBias)
- return mlx.Reshape(h, B, L, m.OutBias.Dim(0))
-}
-
-// geluApprox implements approximate GELU
-func geluApprox(x *mlx.Array) *mlx.Array {
- sqrt2OverPi := float32(math.Sqrt(2.0 / math.Pi))
- x3 := mlx.Mul(mlx.Mul(x, x), x)
- inner := mlx.Add(x, mlx.MulScalar(x3, 0.044715))
- inner = mlx.MulScalar(inner, sqrt2OverPi)
- return mlx.Mul(mlx.MulScalar(x, 0.5), mlx.AddScalar(mlx.Tanh(inner), 1.0))
-}
-
-// TransformerBlock is a single dual-stream transformer block
-type TransformerBlock struct {
- Attention *JointAttention
- ImgMLP *MLP
- TxtMLP *MLP
-
- ImgModWeight *mlx.Array
- ImgModBias *mlx.Array
- TxtModWeight *mlx.Array
- TxtModBias *mlx.Array
-
- HiddenDim int32
- NormEps float32
-}
-
-// newTransformerBlock creates a transformer block
-func newTransformerBlock(weights *safetensors.ModelWeights, prefix string, cfg *TransformerConfig) (*TransformerBlock, error) {
- attn, err := newJointAttention(weights, prefix, cfg)
- if err != nil {
- return nil, err
- }
-
- imgMLP, _ := newMLP(weights, prefix+".img_mlp")
- txtMLP, _ := newMLP(weights, prefix+".txt_mlp")
-
- imgModWeight, _ := weights.Get(prefix + ".img_mod.1.weight")
- imgModBias, _ := weights.Get(prefix + ".img_mod.1.bias")
- txtModWeight, _ := weights.Get(prefix + ".txt_mod.1.weight")
- txtModBias, _ := weights.Get(prefix + ".txt_mod.1.bias")
-
- return &TransformerBlock{
- Attention: attn,
- ImgMLP: imgMLP,
- TxtMLP: txtMLP,
- ImgModWeight: mlx.Transpose(imgModWeight, 1, 0),
- ImgModBias: imgModBias,
- TxtModWeight: mlx.Transpose(txtModWeight, 1, 0),
- TxtModBias: txtModBias,
- HiddenDim: cfg.HiddenDim,
- NormEps: cfg.NormEps,
- }, nil
-}
-
-// Forward applies the transformer block
-func (tb *TransformerBlock) Forward(img, txt, temb *mlx.Array, imgFreqs, txtFreqs *mlx.Array) (*mlx.Array, *mlx.Array) {
- // Compute modulation: silu(temb) -> linear -> [B, 6*D]
- siluT := mlx.SiLU(temb)
- imgMod := mlx.Add(mlx.Linear(siluT, tb.ImgModWeight), tb.ImgModBias)
- txtMod := mlx.Add(mlx.Linear(siluT, tb.TxtModWeight), tb.TxtModBias)
-
- // Split into 6 parts: shift1, scale1, gate1, shift2, scale2, gate2
- imgModParts := splitMod6(imgMod, tb.HiddenDim)
- txtModParts := splitMod6(txtMod, tb.HiddenDim)
-
- // Pre-attention: norm + modulate
- imgNorm := layerNormNoAffine(img, tb.NormEps)
- imgNorm = mlx.Add(mlx.Mul(imgNorm, mlx.AddScalar(imgModParts[1], 1.0)), imgModParts[0])
-
- txtNorm := layerNormNoAffine(txt, tb.NormEps)
- txtNorm = mlx.Add(mlx.Mul(txtNorm, mlx.AddScalar(txtModParts[1], 1.0)), txtModParts[0])
-
- // Joint attention
- attnImg, attnTxt := tb.Attention.Forward(imgNorm, txtNorm, imgFreqs, txtFreqs)
-
- // Residual with gate
- img = mlx.Add(img, mlx.Mul(imgModParts[2], attnImg))
- txt = mlx.Add(txt, mlx.Mul(txtModParts[2], attnTxt))
-
- // Pre-MLP: norm + modulate
- imgNorm2 := layerNormNoAffine(img, tb.NormEps)
- imgNorm2 = mlx.Add(mlx.Mul(imgNorm2, mlx.AddScalar(imgModParts[4], 1.0)), imgModParts[3])
-
- txtNorm2 := layerNormNoAffine(txt, tb.NormEps)
- txtNorm2 = mlx.Add(mlx.Mul(txtNorm2, mlx.AddScalar(txtModParts[4], 1.0)), txtModParts[3])
-
- // MLP
- mlpImg := tb.ImgMLP.Forward(imgNorm2)
- mlpTxt := tb.TxtMLP.Forward(txtNorm2)
-
- // Residual with gate
- img = mlx.Add(img, mlx.Mul(imgModParts[5], mlpImg))
- txt = mlx.Add(txt, mlx.Mul(txtModParts[5], mlpTxt))
-
- return img, txt
-}
-
-// splitMod6 splits modulation into 6 parts each [B, 1, D]
-func splitMod6(mod *mlx.Array, hiddenDim int32) []*mlx.Array {
- shape := mod.Shape()
- B := shape[0]
- parts := make([]*mlx.Array, 6)
- for i := int32(0); i < 6; i++ {
- part := mlx.Slice(mod, []int32{0, i * hiddenDim}, []int32{B, (i + 1) * hiddenDim})
- parts[i] = mlx.ExpandDims(part, 1)
- }
- return parts
-}
-
-// layerNormNoAffine applies layer norm without learnable parameters
-func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
- ndim := x.Ndim()
- lastAxis := ndim - 1
- mean := mlx.Mean(x, lastAxis, true)
- xCentered := mlx.Sub(x, mean)
- variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
- return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
-}
-
-// Transformer is the full Qwen-Image transformer model
-type Transformer struct {
- Config *TransformerConfig
-
- ImgIn *mlx.Array
- ImgInBias *mlx.Array
- TxtIn *mlx.Array
- TxtInBias *mlx.Array
- TxtNorm *mlx.Array
-
- TEmbed *TimestepEmbedder
- Layers []*TransformerBlock
-
- NormOutWeight *mlx.Array
- NormOutBias *mlx.Array
- ProjOut *mlx.Array
- ProjOutBias *mlx.Array
-}
-
-// Load loads the transformer from a directory
-func (m *Transformer) Load(path string) error {
- fmt.Println("Loading Qwen-Image transformer...")
-
- cfg := defaultTransformerConfig()
- m.Config = cfg
-
- weights, err := safetensors.LoadModelWeights(path)
- if err != nil {
- return fmt.Errorf("weights: %w", err)
- }
-
- // Bulk load all weights as bf16
- fmt.Print(" Loading weights as bf16... ")
- if err := weights.Load(mlx.DtypeBFloat16); err != nil {
- return fmt.Errorf("load weights: %w", err)
- }
- fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
-
- fmt.Print(" Loading input projections... ")
- imgIn, _ := weights.Get("img_in.weight")
- imgInBias, _ := weights.Get("img_in.bias")
- txtIn, _ := weights.Get("txt_in.weight")
- txtInBias, _ := weights.Get("txt_in.bias")
- txtNorm, _ := weights.Get("txt_norm.weight")
- m.ImgIn = mlx.Transpose(imgIn, 1, 0)
- m.ImgInBias = imgInBias
- m.TxtIn = mlx.Transpose(txtIn, 1, 0)
- m.TxtInBias = txtInBias
- m.TxtNorm = txtNorm
- fmt.Println("✓")
-
- fmt.Print(" Loading timestep embedder... ")
- m.TEmbed, err = newTimestepEmbedder(weights)
- if err != nil {
- return fmt.Errorf("timestep embedder: %w", err)
- }
- fmt.Println("✓")
-
- m.Layers = make([]*TransformerBlock, cfg.NLayers)
- for i := int32(0); i < cfg.NLayers; i++ {
- fmt.Printf("\r Loading transformer layers... %d/%d", i+1, cfg.NLayers)
- prefix := fmt.Sprintf("transformer_blocks.%d", i)
- m.Layers[i], err = newTransformerBlock(weights, prefix, cfg)
- if err != nil {
- return fmt.Errorf("layer %d: %w", i, err)
- }
- }
- fmt.Printf("\r Loading transformer layers... ✓ [%d blocks] \n", cfg.NLayers)
-
- fmt.Print(" Loading output layers... ")
- normOutWeight, _ := weights.Get("norm_out.linear.weight")
- normOutBias, _ := weights.Get("norm_out.linear.bias")
- projOut, _ := weights.Get("proj_out.weight")
- projOutBias, _ := weights.Get("proj_out.bias")
- m.NormOutWeight = mlx.Transpose(normOutWeight, 1, 0)
- m.NormOutBias = normOutBias
- m.ProjOut = mlx.Transpose(projOut, 1, 0)
- m.ProjOutBias = projOutBias
- fmt.Println("✓")
-
- weights.ReleaseAll()
- return nil
-}
-
-// LoadFromPath is a convenience function to load transformer from path
-func LoadTransformerFromPath(path string) (*Transformer, error) {
- m := &Transformer{}
- if err := m.Load(filepath.Join(path, "transformer")); err != nil {
- return nil, err
- }
- return m, nil
-}
-
-// Forward runs the transformer
-// img: [B, L_img, in_channels] patchified latents
-// txt: [B, L_txt, joint_attention_dim] text embeddings
-// t: [B] timesteps (0-1)
-// imgFreqs, txtFreqs: RoPE frequencies
-func (tr *Transformer) Forward(img, txt, t *mlx.Array, imgFreqs, txtFreqs *mlx.Array) *mlx.Array {
- imgShape := img.Shape()
- B := imgShape[0]
- Limg := imgShape[1]
-
- txtShape := txt.Shape()
- Ltxt := txtShape[1]
-
- // Timestep embedding
- temb := tr.TEmbed.Forward(t)
-
- // Project image: [B, L, in_channels] -> [B, L, hidden_dim]
- imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
- imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
- imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
-
- // Project text: RMSNorm then linear
- txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
- txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
- txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
- txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
-
- for _, layer := range tr.Layers {
- imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
- }
-
- // Final norm with modulation (AdaLayerNormContinuous)
- // Python: scale, shift = torch.chunk(emb, 2, dim=1)
- finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
- modShape := finalMod.Shape()
- halfDim := modShape[1] / 2
- scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
- shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
-
- imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
- imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
-
- // Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
- imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
- out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
-
- outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
- return mlx.Reshape(out, B, Limg, outChannels)
-}
-
-// ForwardWithCache runs the transformer with layer caching for speedup.
-// Based on DeepCache (CVPR 2024) / Learning-to-Cache (NeurIPS 2024):
-// shallow layers change little between denoising steps, so we cache their
-// outputs and reuse them on non-refresh steps.
-//
-// stepCache: cache for layer outputs (use cache.NewStepCache(cacheLayers))
-// step: current denoising step (0-indexed)
-// cacheInterval: refresh cache every N steps (e.g., 3)
-// cacheLayers: number of shallow layers to cache (e.g., 15)
-func (tr *Transformer) ForwardWithCache(
- img, txt, t *mlx.Array,
- imgFreqs, txtFreqs *mlx.Array,
- stepCache *cache.StepCache,
- step, cacheInterval, cacheLayers int,
-) *mlx.Array {
- imgShape := img.Shape()
- B := imgShape[0]
- Limg := imgShape[1]
-
- txtShape := txt.Shape()
- Ltxt := txtShape[1]
-
- // Timestep embedding
- temb := tr.TEmbed.Forward(t)
-
- // Project image: [B, L, in_channels] -> [B, L, hidden_dim]
- imgFlat := mlx.Reshape(img, B*Limg, tr.Config.InChannels)
- imgH := mlx.Add(mlx.Linear(imgFlat, tr.ImgIn), tr.ImgInBias)
- imgH = mlx.Reshape(imgH, B, Limg, tr.Config.HiddenDim)
-
- // Project text: RMSNorm then linear
- txtFlat := mlx.Reshape(txt, B*Ltxt, tr.Config.JointAttentionDim)
- txtNormed := mlx.RMSNorm(txtFlat, tr.TxtNorm, 1e-6)
- txtH := mlx.Add(mlx.Linear(txtNormed, tr.TxtIn), tr.TxtInBias)
- txtH = mlx.Reshape(txtH, B, Ltxt, tr.Config.HiddenDim)
-
- // Check if we should refresh the cache
- refreshCache := stepCache.ShouldRefresh(step, cacheInterval)
-
- for i, layer := range tr.Layers {
- if i < cacheLayers && !refreshCache && stepCache.Get(i) != nil {
- // Use cached outputs for shallow layers
- imgH = stepCache.Get(i)
- txtH = stepCache.Get2(i)
- } else {
- // Compute layer
- imgH, txtH = layer.Forward(imgH, txtH, temb, imgFreqs, txtFreqs)
- // Cache shallow layers on refresh steps
- if i < cacheLayers && refreshCache {
- stepCache.Set(i, imgH)
- stepCache.Set2(i, txtH)
- }
- }
- }
-
- // Final norm with modulation (AdaLayerNormContinuous)
- finalMod := mlx.Add(mlx.Linear(mlx.SiLU(temb), tr.NormOutWeight), tr.NormOutBias)
- modShape := finalMod.Shape()
- halfDim := modShape[1] / 2
- scale := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, 0}, []int32{B, halfDim}), 1)
- shift := mlx.ExpandDims(mlx.Slice(finalMod, []int32{0, halfDim}, []int32{B, modShape[1]}), 1)
-
- imgH = layerNormNoAffine(imgH, tr.Config.NormEps)
- imgH = mlx.Add(mlx.Mul(imgH, mlx.AddScalar(scale, 1.0)), shift)
-
- // Final projection: [B, L, hidden_dim] -> [B, L, patch_size^2 * out_channels]
- imgFlat = mlx.Reshape(imgH, B*Limg, tr.Config.HiddenDim)
- out := mlx.Add(mlx.Linear(imgFlat, tr.ProjOut), tr.ProjOutBias)
-
- outChannels := tr.Config.PatchSize * tr.Config.PatchSize * tr.Config.OutChannels
- return mlx.Reshape(out, B, Limg, outChannels)
-}
-
-// RoPECache holds precomputed RoPE frequencies
-type RoPECache struct {
- ImgFreqs *mlx.Array // [L_img, head_dim]
- TxtFreqs *mlx.Array // [L_txt, head_dim]
-}
-
-// PrepareRoPE computes RoPE for image and text sequences
-// This matches Python's QwenEmbedRope with scale_rope=True
-func PrepareRoPE(imgH, imgW int32, txtLen int32, axesDims []int32) *RoPECache {
- theta := float64(10000)
- maxIdx := int32(4096)
-
- // Compute base frequencies for each axis dimension
- freqsT := ComputeAxisFreqs(axesDims[0], theta)
- freqsH := ComputeAxisFreqs(axesDims[1], theta)
- freqsW := ComputeAxisFreqs(axesDims[2], theta)
-
- // Build frequency lookup tables
- posFreqsT := MakeFreqTable(maxIdx, freqsT, false)
- posFreqsH := MakeFreqTable(maxIdx, freqsH, false)
- posFreqsW := MakeFreqTable(maxIdx, freqsW, false)
- negFreqsH := MakeFreqTable(maxIdx, freqsH, true)
- negFreqsW := MakeFreqTable(maxIdx, freqsW, true)
-
- // Image frequencies with scale_rope=True
- imgLen := imgH * imgW
- headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
- imgFreqsData := make([]float32, imgLen*headDim)
-
- hHalf := imgH / 2
- wHalf := imgW / 2
-
- idx := int32(0)
- for y := int32(0); y < imgH; y++ {
- for x := int32(0); x < imgW; x++ {
- // Frame = 0
- for i := 0; i < len(freqsT)*2; i++ {
- imgFreqsData[idx+int32(i)] = posFreqsT[0][i]
- }
- idx += int32(len(freqsT) * 2)
-
- // Height: scale_rope pattern
- hNegCount := imgH - hHalf
- if y < hNegCount {
- negTableIdx := maxIdx - hNegCount + y
- for i := 0; i < len(freqsH)*2; i++ {
- imgFreqsData[idx+int32(i)] = negFreqsH[negTableIdx][i]
- }
- } else {
- posIdx := y - hNegCount
- for i := 0; i < len(freqsH)*2; i++ {
- imgFreqsData[idx+int32(i)] = posFreqsH[posIdx][i]
- }
- }
- idx += int32(len(freqsH) * 2)
-
- // Width: scale_rope pattern
- wNegCount := imgW - wHalf
- if x < wNegCount {
- negTableIdx := maxIdx - wNegCount + x
- for i := 0; i < len(freqsW)*2; i++ {
- imgFreqsData[idx+int32(i)] = negFreqsW[negTableIdx][i]
- }
- } else {
- posIdx := x - wNegCount
- for i := 0; i < len(freqsW)*2; i++ {
- imgFreqsData[idx+int32(i)] = posFreqsW[posIdx][i]
- }
- }
- idx += int32(len(freqsW) * 2)
- }
- }
-
- imgFreqs := mlx.NewArray(imgFreqsData, []int32{imgLen, headDim})
- imgFreqs = mlx.ToBFloat16(imgFreqs)
-
- // Text frequencies
- maxVidIdx := max(hHalf, wHalf)
- txtFreqsData := make([]float32, txtLen*headDim)
-
- idx = 0
- for t := int32(0); t < txtLen; t++ {
- pos := maxVidIdx + t
- for i := 0; i < len(freqsT)*2; i++ {
- txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
- }
- idx += int32(len(freqsT) * 2)
- for i := 0; i < len(freqsH)*2; i++ {
- txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
- }
- idx += int32(len(freqsH) * 2)
- for i := 0; i < len(freqsW)*2; i++ {
- txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
- }
- idx += int32(len(freqsW) * 2)
- }
-
- txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
- txtFreqs = mlx.ToBFloat16(txtFreqs)
-
- return &RoPECache{
- ImgFreqs: imgFreqs,
- TxtFreqs: txtFreqs,
- }
-}
-
-// ComputeAxisFreqs computes RoPE base frequencies for a given dimension.
-func ComputeAxisFreqs(dim int32, theta float64) []float64 {
- halfDim := dim / 2
- freqs := make([]float64, halfDim)
- for i := int32(0); i < halfDim; i++ {
- freqs[i] = 1.0 / math.Pow(theta, float64(i)/float64(halfDim))
- }
- return freqs
-}
-
-// MakeFreqTable builds a table of cos/sin values for RoPE positions.
-func MakeFreqTable(maxIdx int32, baseFreqs []float64, negative bool) [][]float32 {
- table := make([][]float32, maxIdx)
- for idx := int32(0); idx < maxIdx; idx++ {
- var pos float64
- if negative {
- pos = float64(-maxIdx + int32(idx))
- } else {
- pos = float64(idx)
- }
-
- row := make([]float32, len(baseFreqs)*2)
- for i, f := range baseFreqs {
- angle := pos * f
- row[i*2] = float32(math.Cos(angle))
- row[i*2+1] = float32(math.Sin(angle))
- }
- table[idx] = row
- }
- return table
-}
-
-func max(a, b int32) int32 {
- if a > b {
- return a
- }
- return b
-}
-
-// PackLatents converts [B, C, H, W] to [B, L, C*4] patches
-func PackLatents(latents *mlx.Array, patchSize int32) *mlx.Array {
- shape := latents.Shape()
- B := shape[0]
- C := shape[1]
- H := shape[2]
- W := shape[3]
-
- pH := H / patchSize
- pW := W / patchSize
-
- // [B, C, H, W] -> [B, C, pH, 2, pW, 2]
- x := mlx.Reshape(latents, B, C, pH, patchSize, pW, patchSize)
- // -> [B, pH, pW, C, 2, 2]
- x = mlx.Transpose(x, 0, 2, 4, 1, 3, 5)
- // -> [B, pH*pW, C*4]
- return mlx.Reshape(x, B, pH*pW, C*patchSize*patchSize)
-}
-
-// UnpackLatents converts [B, L, C*4] back to [B, C, 1, H, W] (5D for VAE)
-func UnpackLatents(patches *mlx.Array, H, W, patchSize int32) *mlx.Array {
- shape := patches.Shape()
- B := shape[0]
- channels := shape[2] / (patchSize * patchSize)
-
- pH := H / patchSize
- pW := W / patchSize
-
- // [B, L, C*4] -> [B, pH, pW, C, 2, 2]
- x := mlx.Reshape(patches, B, pH, pW, channels, patchSize, patchSize)
- // -> [B, C, pH, 2, pW, 2]
- x = mlx.Transpose(x, 0, 3, 1, 4, 2, 5)
- // -> [B, C, H, W]
- x = mlx.Reshape(x, B, channels, pH*patchSize, pW*patchSize)
- // Add temporal dimension for VAE: [B, C, 1, H, W]
- return mlx.ExpandDims(x, 2)
-}
diff --git a/x/imagegen/models/qwen_image/transformer_test.go b/x/imagegen/models/qwen_image/transformer_test.go
deleted file mode 100644
index 5eef53b1d94..00000000000
--- a/x/imagegen/models/qwen_image/transformer_test.go
+++ /dev/null
@@ -1,119 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "math"
- "os"
- "testing"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
-)
-
-// TestTransformerConfig tests configuration invariants.
-func TestTransformerConfig(t *testing.T) {
- cfg := defaultTransformerConfig()
-
- // Property: hidden_dim = n_heads * head_dim
- if cfg.HiddenDim != cfg.NHeads*cfg.HeadDim {
- t.Errorf("hidden_dim != n_heads * head_dim: %d != %d * %d",
- cfg.HiddenDim, cfg.NHeads, cfg.HeadDim)
- }
-
- // Property: axes_dims_rope sums to head_dim
- var ropeSum int32
- for _, d := range cfg.AxesDimsRope {
- ropeSum += d
- }
- if ropeSum != cfg.HeadDim {
- t.Errorf("axes_dims_rope sum != head_dim: %d != %d", ropeSum, cfg.HeadDim)
- }
-
- // Property: in_channels = out_channels * patch_size^2
- expectedIn := cfg.OutChannels * cfg.PatchSize * cfg.PatchSize
- if cfg.InChannels != expectedIn {
- t.Errorf("in_channels != out_channels * patch_size^2: %d != %d", cfg.InChannels, expectedIn)
- }
-}
-
-// TestTransformerRoPE tests RoPE frequency computation produces valid values.
-func TestTransformerRoPE(t *testing.T) {
- cfg := defaultTransformerConfig()
-
- // Test with small image dimensions
- imgH, imgW := int32(4), int32(4) // 4x4 latent = 16 patches
- txtLen := int32(5)
-
- ropeCache := PrepareRoPE(imgH, imgW, txtLen, cfg.AxesDimsRope)
- mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-
- // Verify shapes: [seq_len, head_dim]
- imgSeqLen := imgH * imgW
- if ropeCache.ImgFreqs.Shape()[0] != imgSeqLen {
- t.Errorf("ImgFreqs seq_len: got %d, want %d", ropeCache.ImgFreqs.Shape()[0], imgSeqLen)
- }
- if ropeCache.ImgFreqs.Shape()[1] != cfg.HeadDim {
- t.Errorf("ImgFreqs head_dim: got %d, want %d", ropeCache.ImgFreqs.Shape()[1], cfg.HeadDim)
- }
-
- if ropeCache.TxtFreqs.Shape()[0] != txtLen {
- t.Errorf("TxtFreqs seq_len: got %d, want %d", ropeCache.TxtFreqs.Shape()[0], txtLen)
- }
-
- // Verify values are finite
- imgData := ropeCache.ImgFreqs.Data()
- for i := 0; i < min(100, len(imgData)); i++ {
- if math.IsNaN(float64(imgData[i])) || math.IsInf(float64(imgData[i]), 0) {
- t.Errorf("ImgFreqs[%d] not finite: %v", i, imgData[i])
- break
- }
- }
-}
-
-// TestTransformerForward tests full forward pass (integration test).
-// Skips if model weights are not available.
-func TestTransformerForward(t *testing.T) {
- weightsPath := "../../../weights/Qwen-Image-2512/transformer"
- if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
- t.Skip("Skipping: model weights not found at " + weightsPath)
- }
-
- transformer := &Transformer{}
- if err := transformer.Load(weightsPath); err != nil {
- t.Fatalf("Failed to load transformer: %v", err)
- }
- mlx.Keep(mlx.Collect(transformer)...)
- cfg := transformer.Config
-
- // Small test inputs
- batchSize := int32(1)
- imgH, imgW := int32(4), int32(4)
- imgSeqLen := imgH * imgW
- txtSeqLen := int32(5)
-
- hiddenStates := mlx.RandomNormal([]int32{batchSize, imgSeqLen, cfg.InChannels}, 0)
- encoderHiddenStates := mlx.RandomNormal([]int32{batchSize, txtSeqLen, cfg.JointAttentionDim}, 0)
- timestep := mlx.NewArray([]float32{0.5}, []int32{batchSize})
-
- ropeCache := PrepareRoPE(imgH, imgW, txtSeqLen, cfg.AxesDimsRope)
-
- // Forward pass
- out := transformer.Forward(hiddenStates, encoderHiddenStates, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
- mlx.Eval(out)
-
- // Verify output shape: [batch, img_seq_len, in_channels]
- wantShape := []int32{batchSize, imgSeqLen, cfg.InChannels}
- gotShape := out.Shape()
- if gotShape[0] != wantShape[0] || gotShape[1] != wantShape[1] || gotShape[2] != wantShape[2] {
- t.Errorf("output shape: got %v, want %v", gotShape, wantShape)
- }
-
- // Verify output is finite
- outData := out.Data()
- for i := 0; i < min(100, len(outData)); i++ {
- if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
- t.Errorf("output[%d] not finite: %v", i, outData[i])
- break
- }
- }
-}
diff --git a/x/imagegen/models/qwen_image/vae.go b/x/imagegen/models/qwen_image/vae.go
deleted file mode 100644
index e1c7f5255b1..00000000000
--- a/x/imagegen/models/qwen_image/vae.go
+++ /dev/null
@@ -1,854 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "fmt"
- "math"
- "path/filepath"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/safetensors"
-)
-
-// VAEConfig holds Qwen-Image VAE configuration
-type VAEConfig struct {
- ZDim int32 `json:"z_dim"` // 16
- BaseDim int32 `json:"base_dim"` // 96
- DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
- NumResBlocks int32 `json:"num_res_blocks"` // 2
- LatentsMean []float32 `json:"latents_mean"` // 16 values
- LatentsStd []float32 `json:"latents_std"` // 16 values
- TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
-}
-
-// defaultVAEConfig returns config for Qwen-Image VAE
-func defaultVAEConfig() *VAEConfig {
- return &VAEConfig{
- ZDim: 16,
- BaseDim: 96,
- DimMult: []int32{1, 2, 4, 4},
- NumResBlocks: 2,
- LatentsMean: []float32{
- -0.7571, -0.7089, -0.9113, 0.1075,
- -0.1745, 0.9653, -0.1517, 1.5508,
- 0.4134, -0.0715, 0.5517, -0.3632,
- -0.1922, -0.9497, 0.2503, -0.2921,
- },
- LatentsStd: []float32{
- 2.8184, 1.4541, 2.3275, 2.6558,
- 1.2196, 1.7708, 2.6052, 2.0743,
- 3.2687, 2.1526, 2.8652, 1.5579,
- 1.6382, 1.1253, 2.8251, 1.916,
- },
- TemperalDownsample: []bool{false, true, true},
- }
-}
-
-// CausalConv3d is a causal 3D convolution (for temporal causality)
-type CausalConv3d struct {
- Weight *mlx.Array
- Bias *mlx.Array
- BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
- KernelT int32
-}
-
-// newCausalConv3d creates a 3D causal conv
-func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
- weight, err := weights.Get(prefix + ".weight")
- if err != nil {
- return nil, fmt.Errorf("weight not found: %s", prefix)
- }
- bias, _ := weights.Get(prefix + ".bias")
-
- kernelT := weight.Shape()[2]
- outC := weight.Shape()[0]
-
- var biasReshaped *mlx.Array
- if bias != nil {
- biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
- }
-
- return &CausalConv3d{
- Weight: weight,
- Bias: bias,
- BiasReshaped: biasReshaped,
- KernelT: kernelT,
- }, nil
-}
-
-// Forward applies causal 3D convolution
-// x: [B, T, H, W, C] (channels-last, MLX format)
-func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
- shape := c.Weight.Shape() // PyTorch format: [O, I, kT, kH, kW]
- kernelT := shape[2]
- kernelH := shape[3]
- kernelW := shape[4]
-
- // Causal temporal padding, same spatial padding
- // Input is channels-last: [B, T, H, W, C]
- padT := kernelT - 1
- padH := kernelH / 2
- padW := kernelW / 2
-
- // Stage 1: Pad
- {
- x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
- mlx.Eval(x)
- }
-
- // Stage 2: Conv + bias
- var out *mlx.Array
- {
- prev := x
- weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
- out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
- if c.Bias != nil {
- bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
- out = mlx.Add(out, bias)
- }
- prev.Free()
- mlx.Eval(out)
- }
-
- return out
-}
-
-// RMSNorm3D applies RMS normalization over channels
-// Works with channels-last [B, T, H, W, C] format
-type RMSNorm3D struct {
- Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
-}
-
-// newRMSNorm3D creates an RMS norm
-func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
- gamma, err := weights.Get(prefix + ".gamma")
- if err != nil {
- return nil, err
- }
- // Reshape for channels-last broadcasting: [1, 1, 1, 1, C]
- gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
- return &RMSNorm3D{Gamma: gamma}, nil
-}
-
-// Forward applies RMS norm to channels-last input [B, T, H, W, C]
-func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
- // RMSNorm: x * rsqrt(mean(x^2) + eps) * gamma
- normalized := mlx.RMSNormNoWeight(x, 1e-6)
- return mlx.Mul(normalized, n.Gamma)
-}
-
-// ResBlock is a residual block with RMS norm and causal convs
-type ResBlock struct {
- Norm1 *RMSNorm3D
- Conv1 *CausalConv3d
- Norm2 *RMSNorm3D
- Conv2 *CausalConv3d
- Shortcut *CausalConv3d
-}
-
-// newResBlock creates a residual block
-func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
- norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
- if err != nil {
- return nil, err
- }
- conv1, err := newCausalConv3d(weights, prefix+".conv1")
- if err != nil {
- return nil, err
- }
- norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
- if err != nil {
- return nil, err
- }
- conv2, err := newCausalConv3d(weights, prefix+".conv2")
- if err != nil {
- return nil, err
- }
-
- var shortcut *CausalConv3d
- if inDim != outDim {
- shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
- if err != nil {
- return nil, err
- }
- }
-
- return &ResBlock{
- Norm1: norm1,
- Conv1: conv1,
- Norm2: norm2,
- Conv2: conv2,
- Shortcut: shortcut,
- }, nil
-}
-
-// Forward applies the residual block
-func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
- // Use h as working variable, keep x intact for residual (caller will free x)
- // Conv handles its own pools, so we just need pools for non-conv operations
- var h *mlx.Array
-
- // Keep x so it survives Eval() cleanup - needed for residual connection
- mlx.Keep(x)
-
- // Stage 1: norm1 + silu
- {
- h = r.Norm1.Forward(x)
- h = silu3D(h)
- mlx.Eval(h)
- }
-
- // Stage 2: conv1 (handles its own pools)
- {
- prev := h
- h = r.Conv1.Forward(h)
- prev.Free()
- }
-
- // Stage 3: norm2 + silu
- {
- prev := h
- h = r.Norm2.Forward(h)
- h = silu3D(h)
- prev.Free()
- mlx.Eval(h)
- }
-
- // Stage 4: conv2 (handles its own pools)
- {
- prev := h
- h = r.Conv2.Forward(h)
- prev.Free()
- }
-
- // Residual connection (shortcut handles its own pools if present)
- if r.Shortcut != nil {
- shortcut := r.Shortcut.Forward(x)
- h = mlx.Add(h, shortcut)
- mlx.Eval(h)
- } else {
- h = mlx.Add(h, x)
- mlx.Eval(h)
- }
-
- return h
-}
-
-// AttentionBlock is a 2D attention block
-type AttentionBlock struct {
- Norm *RMSNorm3D
- ToQKV *mlx.Array
- ToQKVBias *mlx.Array
- Proj *mlx.Array
- ProjBias *mlx.Array
- Dim int32
-}
-
-// newAttentionBlock creates an attention block
-func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
- norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
- if err != nil {
- return nil, err
- }
- toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
- toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
- proj, _ := weights.Get(prefix + ".proj.weight")
- projBias, _ := weights.Get(prefix + ".proj.bias")
-
- return &AttentionBlock{
- Norm: norm,
- ToQKV: toQKV,
- ToQKVBias: toQKVBias,
- Proj: proj,
- ProjBias: projBias,
- Dim: dim,
- }, nil
-}
-
-// Forward applies 2D attention
-// Input: [B, T, H, W, C] (channels-last)
-func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
-
- identity := x
-
- // Flatten to [B*T, 1, H, W, C] for norm
- x = mlx.Reshape(x, B*T, 1, H, W, C)
- x = a.Norm.Forward(x)
- x = mlx.Reshape(x, B*T, H, W, C)
-
- // Flatten spatial to [B*T, H*W, C]
- x = mlx.Reshape(x, B*T, H*W, C)
-
- // Linear to get Q, K, V: [B*T, H*W, 3*C]
- // Weight is [outC, inC] or [outC, inC, 1, 1]
- wShape := a.ToQKV.Shape()
- var w *mlx.Array
- if len(wShape) == 4 {
- w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
- } else {
- w = a.ToQKV
- }
- w = mlx.Transpose(w, 1, 0) // [inC, outC]
-
- qkv := mlx.Linear(x, w) // [B*T, H*W, 3*C]
- if a.ToQKVBias != nil {
- qkv = mlx.Add(qkv, a.ToQKVBias)
- }
- qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
-
- q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
- k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
- v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
-
- scale := float32(1.0 / math.Sqrt(float64(C)))
- out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
-
- // out: [B*T, 1, H*W, C]
- out = mlx.Reshape(out, B*T, H*W, C)
-
- // Project back
- pShape := a.Proj.Shape()
- var p *mlx.Array
- if len(pShape) == 4 {
- p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
- } else {
- p = a.Proj
- }
- p = mlx.Transpose(p, 1, 0) // [inC, outC]
- out = mlx.Linear(out, p) // [B*T, H*W, C]
- if a.ProjBias != nil {
- out = mlx.Add(out, a.ProjBias)
- }
-
- out = mlx.Reshape(out, B, T, H, W, C)
- return mlx.Add(out, identity)
-}
-
-// UpBlock handles upsampling in decoder
-type UpBlock struct {
- ResBlocks []*ResBlock
- Upsampler *Upsample
-}
-
-// newUpBlock creates an up block
-func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
- resBlocks := make([]*ResBlock, numBlocks+1)
-
- currentDim := inDim
- for i := int32(0); i <= numBlocks; i++ {
- resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
- block, err := newResBlock(weights, resPrefix, currentDim, outDim)
- if err != nil {
- return nil, err
- }
- resBlocks[i] = block
- currentDim = outDim
- }
-
- var upsampler *Upsample
- if upsampleMode != "" {
- upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
- }
-
- return &UpBlock{
- ResBlocks: resBlocks,
- Upsampler: upsampler,
- }, nil
-}
-
-// Forward applies up block with staged memory management
-func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
- // ResBlocks handle their own pools
- for _, block := range u.ResBlocks {
- prev := x
- x = block.Forward(x)
- prev.Free()
- }
-
- // Upsampler handles its own pools
- if u.Upsampler != nil {
- prev := x
- x = u.Upsampler.Forward(x)
- prev.Free()
- }
- return x
-}
-
-// Upsample handles spatial upsampling
-type Upsample struct {
- Conv *mlx.Array
- Bias *mlx.Array
- Mode string
-}
-
-// newUpsample creates an upsampler
-func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
- conv, _ := weights.Get(prefix + ".resample.1.weight")
- bias, _ := weights.Get(prefix + ".resample.1.bias")
- return &Upsample{
- Conv: conv,
- Bias: bias,
- Mode: mode,
- }
-}
-
-// Forward applies upsampling to channels-last input [B, T, H, W, C]
-// Uses staged pools to reduce peak memory during 2x upsampling
-func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
- outC := u.Conv.Shape()[0]
-
- // Stage 1: 2x nearest neighbor upsample
- {
- x = mlx.Reshape(x, B*T, H, W, C)
- x = upsample2xChannelsLast(x)
- mlx.Eval(x)
- }
-
- // Stage 2: Conv + bias
- {
- prev := x
- weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
- x = conv2D3x3PaddedChannelsLast(x, weight)
- if u.Bias != nil {
- bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
- x = mlx.Add(x, bias)
- }
- x = mlx.Reshape(x, B, T, H*2, W*2, outC)
- prev.Free()
- mlx.Eval(x)
- }
-
- return x
-}
-
-// MidBlock is the middle block of decoder
-type MidBlock struct {
- ResBlock1 *ResBlock
- Attention *AttentionBlock
- ResBlock2 *ResBlock
-}
-
-// newMidBlock creates a mid block
-func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
- res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
- if err != nil {
- return nil, err
- }
- attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
- if err != nil {
- return nil, err
- }
- res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
- if err != nil {
- return nil, err
- }
-
- return &MidBlock{
- ResBlock1: res1,
- Attention: attn,
- ResBlock2: res2,
- }, nil
-}
-
-// Forward applies mid block
-func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
- // Each component handles its own pools; we just free inputs
- prev := x
- x = m.ResBlock1.Forward(x)
- prev.Free()
-
- prev = x
- x = m.Attention.Forward(x)
- prev.Free()
-
- prev = x
- x = m.ResBlock2.Forward(x)
- prev.Free()
-
- return x
-}
-
-// VAEDecoder is the full VAE decoder
-type VAEDecoder struct {
- Config *VAEConfig
-
- PostQuantConv *CausalConv3d
- ConvIn *CausalConv3d
- MidBlock *MidBlock
- UpBlocks []*UpBlock
- NormOut *RMSNorm3D
- ConvOut *CausalConv3d
-}
-
-// Load loads the VAE decoder from a directory
-func (m *VAEDecoder) Load(path string) error {
- fmt.Println("Loading Qwen-Image VAE decoder...")
-
- cfg := defaultVAEConfig()
- m.Config = cfg
-
- weights, err := safetensors.LoadModelWeights(path)
- if err != nil {
- return fmt.Errorf("weights: %w", err)
- }
-
- // Bulk load all weights as bf16
- fmt.Print(" Loading weights as bf16... ")
- if err := weights.Load(mlx.DtypeBFloat16); err != nil {
- return fmt.Errorf("failed to load weights: %w", err)
- }
- fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
-
- fmt.Print(" Loading post_quant_conv... ")
- postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
- if err != nil {
- return err
- }
- m.PostQuantConv = postQuantConv
- fmt.Println("✓")
-
- fmt.Print(" Loading conv_in... ")
- convIn, err := newCausalConv3d(weights, "decoder.conv_in")
- if err != nil {
- return err
- }
- m.ConvIn = convIn
- fmt.Println("✓")
-
- // Mid block (dim = base_dim * dim_mult[-1] = 96 * 4 = 384)
- fmt.Print(" Loading mid_block... ")
- midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
- midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
- if err != nil {
- return err
- }
- m.MidBlock = midBlock
- fmt.Println("✓")
-
- // Up blocks (reversed dim_mult)
- fmt.Print(" Loading up_blocks... ")
- numUpBlocks := len(cfg.DimMult)
- m.UpBlocks = make([]*UpBlock, numUpBlocks)
-
- dimsMult := make([]int32, numUpBlocks+1)
- dimsMult[0] = cfg.DimMult[numUpBlocks-1]
- for i := 0; i < numUpBlocks; i++ {
- dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
- }
-
- temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
- for i := range cfg.TemperalDownsample {
- temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
- }
-
- for i := 0; i < numUpBlocks; i++ {
- inDim := cfg.BaseDim * dimsMult[i]
- outDim := cfg.BaseDim * dimsMult[i+1]
-
- if i > 0 {
- inDim = inDim / 2
- }
-
- upsampleMode := ""
- if i < numUpBlocks-1 {
- if temporalUpsample[i] {
- upsampleMode = "upsample3d"
- } else {
- upsampleMode = "upsample2d"
- }
- }
-
- prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
- upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
- if err != nil {
- return err
- }
- m.UpBlocks[i] = upBlock
- }
- fmt.Printf("✓ [%d blocks]\n", numUpBlocks)
-
- fmt.Print(" Loading output layers... ")
- normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
- if err != nil {
- return err
- }
- m.NormOut = normOut
- convOut, err := newCausalConv3d(weights, "decoder.conv_out")
- if err != nil {
- return err
- }
- m.ConvOut = convOut
- fmt.Println("✓")
-
- weights.ReleaseAll()
- return nil
-}
-
-// LoadVAEDecoderFromPath is a convenience function to load VAE from path
-func LoadVAEDecoderFromPath(path string) (*VAEDecoder, error) {
- m := &VAEDecoder{}
- if err := m.Load(filepath.Join(path, "vae")); err != nil {
- return nil, err
- }
- return m, nil
-}
-
-// Decode converts latents to image
-// z: [B, C, T, H, W] normalized latents
-// Uses staged pools to free intermediate arrays and reduce peak memory.
-func (vae *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
- var x *mlx.Array
-
- // Stage 1a: Denormalize and transpose
- {
- z = vae.Denormalize(z)
- // Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
- z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
- mlx.Eval(z)
- }
-
- // Stage 1b: PostQuantConv (handles its own pools)
- x = vae.PostQuantConv.Forward(z)
- z.Free()
-
- // Stage 1c: ConvIn (handles its own pools)
- {
- prev := x
- x = vae.ConvIn.Forward(x)
- prev.Free()
- }
-
- // Stage 2: Mid block (handles its own pools)
- x = vae.MidBlock.Forward(x)
-
- // Stage 3: Up blocks (each handles its own pools)
- for _, upBlock := range vae.UpBlocks {
- x = upBlock.Forward(x)
- }
-
- // Stage 4a: NormOut + silu
- {
- prev := x
- x = vae.NormOut.Forward(x)
- x = silu3D(x)
- prev.Free()
- mlx.Eval(x)
- }
-
- // Stage 4b: ConvOut (handles its own pools)
- {
- prev := x
- x = vae.ConvOut.Forward(x)
- prev.Free()
- }
-
- // Stage 4c: Post-processing
- {
- prev := x
- // Clamp to [-1, 1]
- x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
- // Convert back from channels-last to channels-first
- x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
- prev.Free()
- mlx.Eval(x)
- }
-
- return x
-}
-
-// Denormalize reverses the normalization applied during encoding
-func (vae *VAEDecoder) Denormalize(z *mlx.Array) *mlx.Array {
- shape := z.Shape()
- C := shape[1]
-
- mean := mlx.NewArray(vae.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
- std := mlx.NewArray(vae.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
-
- mean = mlx.ToBFloat16(mean)
- std = mlx.ToBFloat16(std)
-
- return mlx.Add(mlx.Mul(z, std), mean)
-}
-
-// Helper functions
-
-func silu3D(x *mlx.Array) *mlx.Array {
- return mlx.Mul(x, mlx.Sigmoid(x))
-}
-
-// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
-func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
- if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
- return x
- }
- // Pad dims: [B before, B after, T before, T after, H before, H after, W before, W after, C before, C after]
- return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
-}
-
-func pad2D(x *mlx.Array, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
- if hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
- return x
- }
- return mlx.Pad(x, []int32{0, 0, 0, 0, hBefore, hAfter, wBefore, wAfter})
-}
-
-func conv2D1x1(x, weight *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[2]
- W := shape[3]
-
- x = mlx.Transpose(x, 0, 2, 3, 1)
- x = mlx.Reshape(x, B*H*W, shape[1])
-
- wShape := weight.Shape()
- var w *mlx.Array
- if len(wShape) == 4 {
- w = mlx.Reshape(weight, wShape[0], wShape[1])
- } else {
- w = weight
- }
- w = mlx.Transpose(w, 1, 0)
-
- out := mlx.Linear(x, w)
- outC := w.Dim(1)
- out = mlx.Reshape(out, B, H, W, outC)
- return mlx.Transpose(out, 0, 3, 1, 2)
-}
-
-func conv2D3x3Padded(x, weight *mlx.Array) *mlx.Array {
- x = pad2D(x, 1, 1, 1, 1)
- return conv2D(x, weight, 1, 1)
-}
-
-func conv2D(x, w *mlx.Array, strideH, strideW int32) *mlx.Array {
- x = mlx.Transpose(x, 0, 2, 3, 1)
- w = mlx.Transpose(w, 0, 2, 3, 1)
-
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- W := shape[2]
-
- wShape := w.Shape()
- Cout := wShape[0]
- kH := wShape[1]
- kW := wShape[2]
-
- outH := (H-kH)/strideH + 1
- outW := (W-kW)/strideW + 1
-
- patches := extractPatches2D(x, kH, kW, strideH, strideW)
- wFlat := mlx.Reshape(w, Cout, -1)
- patches = mlx.Reshape(patches, B*outH*outW, -1)
- out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
- out = mlx.Reshape(out, B, outH, outW, Cout)
- return mlx.Transpose(out, 0, 3, 1, 2)
-}
-
-func extractPatches2D(x *mlx.Array, kH, kW, strideH, strideW int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- W := shape[2]
- C := shape[3]
-
- outH := (H-kH)/strideH + 1
- outW := (W-kW)/strideW + 1
-
- patches := make([]*mlx.Array, outH*outW)
- idx := 0
- for i := int32(0); i < outH; i++ {
- for j := int32(0); j < outW; j++ {
- startH := i * strideH
- startW := j * strideW
- patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
- patch = mlx.Reshape(patch, B, kH*kW*C)
- patches[idx] = patch
- idx++
- }
- }
-
- for i := range patches {
- patches[i] = mlx.ExpandDims(patches[i], 1)
- }
- stacked := mlx.Concatenate(patches, 1)
- return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
-}
-
-func upsample2x(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- H := shape[2]
- W := shape[3]
-
- rowIdxData := make([]int32, H*2)
- for i := int32(0); i < H; i++ {
- rowIdxData[i*2] = i
- rowIdxData[i*2+1] = i
- }
- rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
-
- colIdxData := make([]int32, W*2)
- for i := int32(0); i < W; i++ {
- colIdxData[i*2] = i
- colIdxData[i*2+1] = i
- }
- colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
-
- x = mlx.Take(x, rowIdx, 2)
- x = mlx.Take(x, colIdx, 3)
-
- return x
-}
-
-// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
-func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- H := shape[1]
- W := shape[2]
-
- // Create repeat indices for rows
- rowIdxData := make([]int32, H*2)
- for i := int32(0); i < H; i++ {
- rowIdxData[i*2] = i
- rowIdxData[i*2+1] = i
- }
- rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
-
- // Create repeat indices for columns
- colIdxData := make([]int32, W*2)
- for i := int32(0); i < W; i++ {
- colIdxData[i*2] = i
- colIdxData[i*2+1] = i
- }
- colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
-
- // Take along H (axis 1) then W (axis 2)
- x = mlx.Take(x, rowIdx, 1)
- x = mlx.Take(x, colIdx, 2)
-
- return x
-}
-
-// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
-// weight: [outC, kH, kW, inC] (MLX channels-last format)
-func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
- // Pad spatial dims: [B, H, W, C] -> pad H and W by 1 each side
- x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
- // Conv2d expects: input [B, H, W, inC], weight [outC, kH, kW, inC]
- // stride=1, padding=0 (we already padded manually)
- return mlx.Conv2d(x, weight, 1, 0)
-}
diff --git a/x/imagegen/models/qwen_image/vae_test.go b/x/imagegen/models/qwen_image/vae_test.go
deleted file mode 100644
index f15a1134bde..00000000000
--- a/x/imagegen/models/qwen_image/vae_test.go
+++ /dev/null
@@ -1,114 +0,0 @@
-//go:build mlx
-
-package qwen_image
-
-import (
- "math"
- "os"
- "testing"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
-)
-
-// TestVAEConfig tests configuration invariants.
-func TestVAEConfig(t *testing.T) {
- cfg := defaultVAEConfig()
-
- // Property: latents_mean and latents_std have z_dim elements
- if int32(len(cfg.LatentsMean)) != cfg.ZDim {
- t.Errorf("latents_mean length != z_dim: %d != %d", len(cfg.LatentsMean), cfg.ZDim)
- }
- if int32(len(cfg.LatentsStd)) != cfg.ZDim {
- t.Errorf("latents_std length != z_dim: %d != %d", len(cfg.LatentsStd), cfg.ZDim)
- }
-
- // Property: dim_mult defines 4 stages
- if len(cfg.DimMult) != 4 {
- t.Errorf("dim_mult should have 4 stages: got %d", len(cfg.DimMult))
- }
-
- // Property: temperal_downsample has 3 elements (for 3 transitions)
- if len(cfg.TemperalDownsample) != 3 {
- t.Errorf("temperal_downsample should have 3 elements: got %d", len(cfg.TemperalDownsample))
- }
-}
-
-// TestVAELatentsNormalization tests the latent denormalization values.
-func TestVAELatentsNormalization(t *testing.T) {
- cfg := defaultVAEConfig()
-
- // Verify latents_std values are all positive
- for i, std := range cfg.LatentsStd {
- if std <= 0 {
- t.Errorf("latents_std[%d] should be positive: %v", i, std)
- }
- }
-
- // Verify values are in reasonable range (from actual model)
- for i, mean := range cfg.LatentsMean {
- if math.Abs(float64(mean)) > 5 {
- t.Errorf("latents_mean[%d] seems too large: %v", i, mean)
- }
- }
- for i, std := range cfg.LatentsStd {
- if std > 10 {
- t.Errorf("latents_std[%d] seems too large: %v", i, std)
- }
- }
-}
-
-// TestVAEDecoderForward tests full forward pass (integration test).
-// Skips if model weights are not available.
-func TestVAEDecoderForward(t *testing.T) {
- weightsPath := "../../../weights/Qwen-Image-2512/vae"
- if _, err := os.Stat(weightsPath); os.IsNotExist(err) {
- t.Skip("Skipping: model weights not found at " + weightsPath)
- }
-
- vae := &VAEDecoder{}
- if err := vae.Load(weightsPath); err != nil {
- t.Fatalf("Failed to load VAE decoder: %v", err)
- }
- mlx.Keep(mlx.Collect(vae)...)
-
- // Small test input: [B, C, T, H, W]
- // After 4 upsampling stages (2x each), H/W multiply by 16
- batchSize := int32(1)
- channels := int32(16)
- frames := int32(1)
- latentH := int32(4)
- latentW := int32(4)
-
- latents := mlx.RandomNormal([]int32{batchSize, channels, frames, latentH, latentW}, 0)
-
- // Decode
- out := vae.Decode(latents)
- mlx.Eval(out)
-
- // Verify output shape: [B, 3, T, H*16, W*16]
- outShape := out.Shape()
- if outShape[0] != batchSize {
- t.Errorf("batch size: got %d, want %d", outShape[0], batchSize)
- }
- if outShape[1] != 3 {
- t.Errorf("channels: got %d, want 3", outShape[1])
- }
- if outShape[2] != frames {
- t.Errorf("frames: got %d, want %d", outShape[2], frames)
- }
- expectedH := latentH * 16 // 4 stages of 2x upsampling
- expectedW := latentW * 16
- if outShape[3] != expectedH || outShape[4] != expectedW {
- t.Errorf("spatial dims: got [%d, %d], want [%d, %d]",
- outShape[3], outShape[4], expectedH, expectedW)
- }
-
- // Verify output is in valid range (should be clamped to [0, 1] by decode)
- outData := out.Data()
- for i := 0; i < min(100, len(outData)); i++ {
- if math.IsNaN(float64(outData[i])) || math.IsInf(float64(outData[i]), 0) {
- t.Errorf("output[%d] not finite: %v", i, outData[i])
- break
- }
- }
-}
diff --git a/x/imagegen/models/qwen_image_edit/layers.go b/x/imagegen/models/qwen_image_edit/layers.go
deleted file mode 100644
index 04c19207796..00000000000
--- a/x/imagegen/models/qwen_image_edit/layers.go
+++ /dev/null
@@ -1,682 +0,0 @@
-//go:build mlx
-
-package qwen_image_edit
-
-import (
- "fmt"
- "math"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/safetensors"
-)
-
-// CausalConv3d is a causal 3D convolution (for temporal causality)
-type CausalConv3d struct {
- Weight *mlx.Array
- Bias *mlx.Array
- BiasReshaped *mlx.Array // [1, C, 1, 1, 1]
- KernelT int32
-}
-
-// newCausalConv3d creates a 3D causal conv
-func newCausalConv3d(weights *safetensors.ModelWeights, prefix string) (*CausalConv3d, error) {
- weight, err := weights.Get(prefix + ".weight")
- if err != nil {
- return nil, fmt.Errorf("weight not found: %s", prefix)
- }
- bias, _ := weights.Get(prefix + ".bias")
-
- kernelT := weight.Shape()[2]
- outC := weight.Shape()[0]
-
- var biasReshaped *mlx.Array
- if bias != nil {
- biasReshaped = mlx.Reshape(bias, 1, outC, 1, 1, 1)
- }
-
- return &CausalConv3d{
- Weight: weight,
- Bias: bias,
- BiasReshaped: biasReshaped,
- KernelT: kernelT,
- }, nil
-}
-
-// Forward applies causal 3D convolution (or 2D if weight is 4D)
-// x: [B, T, H, W, C] (channels-last, MLX format)
-func (c *CausalConv3d) Forward(x *mlx.Array) *mlx.Array {
- shape := c.Weight.Shape()
-
- // Handle both 5D (3D conv) and 4D (2D conv) weights
- if len(shape) == 4 {
- // 2D conv: [O, I, kH, kW] - need to apply per-frame
- return c.forward2D(x)
- }
-
- // 3D conv: [O, I, kT, kH, kW]
- kernelT := shape[2]
- kernelH := shape[3]
- kernelW := shape[4]
-
- // Causal temporal padding, same spatial padding
- padT := kernelT - 1
- padH := kernelH / 2
- padW := kernelW / 2
-
- // Stage 1: Pad
- {
- x = pad3DChannelsLast(x, padT, 0, padH, padH, padW, padW)
- mlx.Eval(x)
- }
-
- // Stage 2: Conv + bias
- var out *mlx.Array
- {
- prev := x
- weight := mlx.Transpose(c.Weight, 0, 2, 3, 4, 1)
- out = mlx.Conv3d(x, weight, 1, 1, 1, 0, 0, 0)
- if c.Bias != nil {
- bias := mlx.Reshape(c.Bias, 1, 1, 1, 1, c.Bias.Dim(0))
- out = mlx.Add(out, bias)
- }
- prev.Free()
- mlx.Eval(out)
- }
-
- return out
-}
-
-// forward2D applies 2D conv per-frame for [B, T, H, W, C] input
-func (c *CausalConv3d) forward2D(x *mlx.Array) *mlx.Array {
- xShape := x.Shape()
- B := xShape[0]
- T := xShape[1]
- H := xShape[2]
- W := xShape[3]
- C := xShape[4]
-
- wShape := c.Weight.Shape() // [O, I, kH, kW]
- kernelH := wShape[2]
- kernelW := wShape[3]
- outC := wShape[0]
-
- padH := kernelH / 2
- padW := kernelW / 2
-
- // Reshape to [B*T, H, W, C] for 2D conv
- x = mlx.Reshape(x, B*T, H, W, C)
-
- // Pad spatially
- x = mlx.Pad(x, []int32{0, 0, padH, padH, padW, padW, 0, 0})
-
- // Apply 2D conv
- weight := mlx.Transpose(c.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
- x = mlx.Conv2d(x, weight, 1, 0)
-
- if c.Bias != nil {
- bias := mlx.Reshape(c.Bias, 1, 1, 1, outC)
- x = mlx.Add(x, bias)
- }
-
- // Get output spatial dims
- outH := H
- outW := W
-
- // Reshape back to [B, T, H, W, C]
- x = mlx.Reshape(x, B, T, outH, outW, outC)
- mlx.Eval(x)
-
- return x
-}
-
-// RMSNorm3D applies RMS normalization over channels
-type RMSNorm3D struct {
- Gamma *mlx.Array // [1, 1, 1, 1, C] for broadcasting
-}
-
-// newRMSNorm3D creates an RMS norm
-func newRMSNorm3D(weights *safetensors.ModelWeights, prefix string, dim int32) (*RMSNorm3D, error) {
- gamma, err := weights.Get(prefix + ".gamma")
- if err != nil {
- return nil, err
- }
- gamma = mlx.Reshape(gamma, 1, 1, 1, 1, gamma.Dim(0))
- return &RMSNorm3D{Gamma: gamma}, nil
-}
-
-// Forward applies RMS norm to channels-last input [B, T, H, W, C]
-func (n *RMSNorm3D) Forward(x *mlx.Array) *mlx.Array {
- normalized := mlx.RMSNormNoWeight(x, 1e-6)
- return mlx.Mul(normalized, n.Gamma)
-}
-
-// ResBlock is a residual block with RMS norm and causal convs
-type ResBlock struct {
- Norm1 *RMSNorm3D
- Conv1 *CausalConv3d
- Norm2 *RMSNorm3D
- Conv2 *CausalConv3d
- Shortcut *CausalConv3d
-}
-
-// newResBlock creates a residual block
-func newResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*ResBlock, error) {
- norm1, err := newRMSNorm3D(weights, prefix+".norm1", inDim)
- if err != nil {
- return nil, err
- }
- conv1, err := newCausalConv3d(weights, prefix+".conv1")
- if err != nil {
- return nil, err
- }
- norm2, err := newRMSNorm3D(weights, prefix+".norm2", outDim)
- if err != nil {
- return nil, err
- }
- conv2, err := newCausalConv3d(weights, prefix+".conv2")
- if err != nil {
- return nil, err
- }
-
- var shortcut *CausalConv3d
- if inDim != outDim {
- shortcut, err = newCausalConv3d(weights, prefix+".conv_shortcut")
- if err != nil {
- return nil, err
- }
- }
-
- return &ResBlock{
- Norm1: norm1,
- Conv1: conv1,
- Norm2: norm2,
- Conv2: conv2,
- Shortcut: shortcut,
- }, nil
-}
-
-// Forward applies the residual block
-func (r *ResBlock) Forward(x *mlx.Array) *mlx.Array {
- var h *mlx.Array
-
- mlx.Keep(x)
-
- // Stage 1: norm1 + silu
- {
- h = r.Norm1.Forward(x)
- h = silu3D(h)
- mlx.Eval(h)
- }
-
- // Stage 2: conv1
- {
- prev := h
- h = r.Conv1.Forward(h)
- prev.Free()
- }
-
- // Stage 3: norm2 + silu
- {
- prev := h
- h = r.Norm2.Forward(h)
- h = silu3D(h)
- prev.Free()
- mlx.Eval(h)
- }
-
- // Stage 4: conv2
- {
- prev := h
- h = r.Conv2.Forward(h)
- prev.Free()
- }
-
- // Residual connection
- if r.Shortcut != nil {
- shortcut := r.Shortcut.Forward(x)
- h = mlx.Add(h, shortcut)
- mlx.Eval(h)
- } else {
- h = mlx.Add(h, x)
- mlx.Eval(h)
- }
-
- return h
-}
-
-// AttentionBlock is a 2D attention block
-type AttentionBlock struct {
- Norm *RMSNorm3D
- ToQKV *mlx.Array
- ToQKVBias *mlx.Array
- Proj *mlx.Array
- ProjBias *mlx.Array
- Dim int32
-}
-
-// newAttentionBlock creates an attention block
-func newAttentionBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*AttentionBlock, error) {
- norm, err := newRMSNorm3D(weights, prefix+".norm", dim)
- if err != nil {
- return nil, err
- }
- toQKV, _ := weights.Get(prefix + ".to_qkv.weight")
- toQKVBias, _ := weights.Get(prefix + ".to_qkv.bias")
- proj, _ := weights.Get(prefix + ".proj.weight")
- projBias, _ := weights.Get(prefix + ".proj.bias")
-
- return &AttentionBlock{
- Norm: norm,
- ToQKV: toQKV,
- ToQKVBias: toQKVBias,
- Proj: proj,
- ProjBias: projBias,
- Dim: dim,
- }, nil
-}
-
-// Forward applies 2D attention
-// Input: [B, T, H, W, C] (channels-last)
-func (a *AttentionBlock) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
-
- identity := x
-
- // Flatten to [B*T, 1, H, W, C] for norm
- x = mlx.Reshape(x, B*T, 1, H, W, C)
- x = a.Norm.Forward(x)
- x = mlx.Reshape(x, B*T, H, W, C)
-
- // Flatten spatial to [B*T, H*W, C]
- x = mlx.Reshape(x, B*T, H*W, C)
-
- // Linear to get Q, K, V
- wShape := a.ToQKV.Shape()
- var w *mlx.Array
- if len(wShape) == 4 {
- w = mlx.Reshape(a.ToQKV, wShape[0], wShape[1])
- } else {
- w = a.ToQKV
- }
- w = mlx.Transpose(w, 1, 0)
-
- qkv := mlx.Linear(x, w)
- if a.ToQKVBias != nil {
- qkv = mlx.Add(qkv, a.ToQKVBias)
- }
- qkv = mlx.Reshape(qkv, B*T, 1, H*W, 3*C)
-
- q := mlx.Slice(qkv, []int32{0, 0, 0, 0}, []int32{B * T, 1, H * W, C})
- k := mlx.Slice(qkv, []int32{0, 0, 0, C}, []int32{B * T, 1, H * W, 2 * C})
- v := mlx.Slice(qkv, []int32{0, 0, 0, 2 * C}, []int32{B * T, 1, H * W, 3 * C})
-
- scale := float32(1.0 / math.Sqrt(float64(C)))
- out := mlx.ScaledDotProductAttention(q, k, v, scale, false)
-
- out = mlx.Reshape(out, B*T, H*W, C)
-
- // Project back
- pShape := a.Proj.Shape()
- var p *mlx.Array
- if len(pShape) == 4 {
- p = mlx.Reshape(a.Proj, pShape[0], pShape[1])
- } else {
- p = a.Proj
- }
- p = mlx.Transpose(p, 1, 0)
- out = mlx.Linear(out, p)
- if a.ProjBias != nil {
- out = mlx.Add(out, a.ProjBias)
- }
-
- out = mlx.Reshape(out, B, T, H, W, C)
- return mlx.Add(out, identity)
-}
-
-// UpBlock handles upsampling in decoder
-type UpBlock struct {
- ResBlocks []*ResBlock
- Upsampler *Upsample
-}
-
-// newUpBlock creates an up block
-func newUpBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, upsampleMode string) (*UpBlock, error) {
- resBlocks := make([]*ResBlock, numBlocks+1)
-
- currentDim := inDim
- for i := int32(0); i <= numBlocks; i++ {
- resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
- block, err := newResBlock(weights, resPrefix, currentDim, outDim)
- if err != nil {
- return nil, err
- }
- resBlocks[i] = block
- currentDim = outDim
- }
-
- var upsampler *Upsample
- if upsampleMode != "" {
- upsampler = newUpsample(weights, prefix+".upsamplers.0", outDim, upsampleMode)
- }
-
- return &UpBlock{
- ResBlocks: resBlocks,
- Upsampler: upsampler,
- }, nil
-}
-
-// Forward applies up block
-func (u *UpBlock) Forward(x *mlx.Array) *mlx.Array {
- for _, block := range u.ResBlocks {
- prev := x
- x = block.Forward(x)
- prev.Free()
- }
-
- if u.Upsampler != nil {
- prev := x
- x = u.Upsampler.Forward(x)
- prev.Free()
- }
- return x
-}
-
-// Upsample handles spatial upsampling
-type Upsample struct {
- Conv *mlx.Array
- Bias *mlx.Array
- Mode string
-}
-
-// newUpsample creates an upsampler
-func newUpsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Upsample {
- conv, _ := weights.Get(prefix + ".resample.1.weight")
- bias, _ := weights.Get(prefix + ".resample.1.bias")
- return &Upsample{
- Conv: conv,
- Bias: bias,
- Mode: mode,
- }
-}
-
-// Forward applies upsampling to channels-last input [B, T, H, W, C]
-func (u *Upsample) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
- outC := u.Conv.Shape()[0]
-
- // Stage 1: 2x nearest neighbor upsample
- {
- x = mlx.Reshape(x, B*T, H, W, C)
- x = upsample2xChannelsLast(x)
- mlx.Eval(x)
- }
-
- // Stage 2: Conv + bias
- {
- prev := x
- weight := mlx.Transpose(u.Conv, 0, 2, 3, 1)
- x = conv2D3x3PaddedChannelsLast(x, weight)
- if u.Bias != nil {
- bias := mlx.Reshape(u.Bias, 1, 1, 1, outC)
- x = mlx.Add(x, bias)
- }
- x = mlx.Reshape(x, B, T, H*2, W*2, outC)
- prev.Free()
- mlx.Eval(x)
- }
-
- return x
-}
-
-// MidBlock is the middle block
-type MidBlock struct {
- ResBlock1 *ResBlock
- Attention *AttentionBlock
- ResBlock2 *ResBlock
-}
-
-// newMidBlock creates a mid block
-func newMidBlock(weights *safetensors.ModelWeights, prefix string, dim int32) (*MidBlock, error) {
- res1, err := newResBlock(weights, prefix+".resnets.0", dim, dim)
- if err != nil {
- return nil, err
- }
- attn, err := newAttentionBlock(weights, prefix+".attentions.0", dim)
- if err != nil {
- return nil, err
- }
- res2, err := newResBlock(weights, prefix+".resnets.1", dim, dim)
- if err != nil {
- return nil, err
- }
-
- return &MidBlock{
- ResBlock1: res1,
- Attention: attn,
- ResBlock2: res2,
- }, nil
-}
-
-// Forward applies mid block
-func (m *MidBlock) Forward(x *mlx.Array) *mlx.Array {
- prev := x
- x = m.ResBlock1.Forward(x)
- prev.Free()
-
- prev = x
- x = m.Attention.Forward(x)
- prev.Free()
-
- prev = x
- x = m.ResBlock2.Forward(x)
- prev.Free()
-
- return x
-}
-
-// Helper functions
-
-func silu3D(x *mlx.Array) *mlx.Array {
- return mlx.Mul(x, mlx.Sigmoid(x))
-}
-
-// pad3DChannelsLast pads a channels-last [B, T, H, W, C] tensor
-func pad3DChannelsLast(x *mlx.Array, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter int32) *mlx.Array {
- if tBefore == 0 && tAfter == 0 && hBefore == 0 && hAfter == 0 && wBefore == 0 && wAfter == 0 {
- return x
- }
- return mlx.Pad(x, []int32{0, 0, tBefore, tAfter, hBefore, hAfter, wBefore, wAfter, 0, 0})
-}
-
-// upsample2xChannelsLast upsamples channels-last input [B, H, W, C] by 2x
-func upsample2xChannelsLast(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- H := shape[1]
- W := shape[2]
-
- rowIdxData := make([]int32, H*2)
- for i := int32(0); i < H; i++ {
- rowIdxData[i*2] = i
- rowIdxData[i*2+1] = i
- }
- rowIdx := mlx.NewArrayInt32(rowIdxData, []int32{H * 2})
-
- colIdxData := make([]int32, W*2)
- for i := int32(0); i < W; i++ {
- colIdxData[i*2] = i
- colIdxData[i*2+1] = i
- }
- colIdx := mlx.NewArrayInt32(colIdxData, []int32{W * 2})
-
- x = mlx.Take(x, rowIdx, 1)
- x = mlx.Take(x, colIdx, 2)
-
- return x
-}
-
-// conv2D3x3PaddedChannelsLast applies 3x3 conv with padding to channels-last input [B, H, W, C]
-func conv2D3x3PaddedChannelsLast(x, weight *mlx.Array) *mlx.Array {
- x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
- return mlx.Conv2d(x, weight, 1, 0)
-}
-
-// conv2DStrided applies conv with stride > 1 using manual patch extraction
-// x: [B, H, W, C] (channels-last), weight: [O, kH, kW, I]
-func conv2DStrided(x, weight *mlx.Array, stride int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- W := shape[2]
-
- wShape := weight.Shape()
- Cout := wShape[0]
- kH := wShape[1]
- kW := wShape[2]
-
- outH := (H - kH) / stride + 1
- outW := (W - kW) / stride + 1
-
- patches := extractPatches2DStrided(x, kH, kW, stride)
- wFlat := mlx.Reshape(weight, Cout, -1)
- patches = mlx.Reshape(patches, B*outH*outW, -1)
- out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
- return mlx.Reshape(out, B, outH, outW, Cout)
-}
-
-// conv3DStrided applies 3D conv with strides using manual patch extraction
-// x: [B, T, H, W, C] (channels-last), weight: [O, I, kT, kH, kW] (PyTorch format)
-// strideT, strideH, strideW are the strides for each dimension
-// Patches are extracted in [C, T, H, W] order to match Python's preprocessing
-func conv3DStrided(x, weight *mlx.Array, strideT, strideH, strideW int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
-
- wShape := weight.Shape()
- Cout := wShape[0]
- // I := wShape[1]
- kT := wShape[2]
- kH := wShape[3]
- kW := wShape[4]
-
- // For temporal: if T < kT, we need to repeat frames temporally
- // For single image with T=1 and kT=2, we duplicate the frame to T=kT
- // Python Qwen2.5-VL duplicates the frame, not zero-pads
- if T < kT {
- // Tile along T dimension: [B, T, H, W, C] -> [B, kT, H, W, C]
- x = mlx.Tile(x, []int32{1, kT, 1, 1, 1})
- T = kT
- }
-
- outT := (T - kT) / strideT + 1
- outH := (H - kH) / strideH + 1
- outW := (W - kW) / strideW + 1
-
- // Extract 3D patches in [C, T, H, W] order to match Python
- patches := extractPatches3DStrided(x, kT, kH, kW, strideT, strideH, strideW)
- // patches shape: [B, outT, outH, outW, C*kT*kH*kW]
-
- // Weight is [O, I, kT, kH, kW] - flatten to [O, I*kT*kH*kW] to match patch order [C, T, H, W]
- wFlat := mlx.Reshape(weight, Cout, -1) // [Cout, I*kT*kH*kW]
- patches = mlx.Reshape(patches, B*outT*outH*outW, C*kT*kH*kW)
- out := mlx.Linear(patches, mlx.Transpose(wFlat, 1, 0))
- return mlx.Reshape(out, B, outT, outH, outW, Cout)
-}
-
-// extractPatches3DStrided extracts 3D patches with given strides
-// Returns patches with values in [C, T, H, W] order to match Python's preprocessing
-func extractPatches3DStrided(x *mlx.Array, kT, kH, kW, strideT, strideH, strideW int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
-
- outT := (T - kT) / strideT + 1
- outH := (H - kH) / strideH + 1
- outW := (W - kW) / strideW + 1
-
- numPatches := outT * outH * outW
- patches := make([]*mlx.Array, numPatches)
- idx := 0
- for t := int32(0); t < outT; t++ {
- for i := int32(0); i < outH; i++ {
- for j := int32(0); j < outW; j++ {
- startT := t * strideT
- startH := i * strideH
- startW := j * strideW
- // Extract patch: [B, kT, kH, kW, C]
- patch := mlx.Slice(x,
- []int32{0, startT, startH, startW, 0},
- []int32{B, startT + kT, startH + kH, startW + kW, C})
- // Transpose from [B, T, H, W, C] to [B, C, T, H, W] to match Python's order
- patch = mlx.Transpose(patch, 0, 4, 1, 2, 3)
- // Flatten to [B, C*T*H*W]
- patch = mlx.Reshape(patch, B, C*kT*kH*kW)
- patches[idx] = patch
- idx++
- }
- }
- }
-
- for i := range patches {
- patches[i] = mlx.ExpandDims(patches[i], 1)
- }
- stacked := mlx.Concatenate(patches, 1)
- return mlx.Reshape(stacked, B, outT, outH, outW, C*kT*kH*kW)
-}
-
-// extractPatches2DStrided extracts patches with given stride
-func extractPatches2DStrided(x *mlx.Array, kH, kW, stride int32) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- H := shape[1]
- W := shape[2]
- C := shape[3]
-
- outH := (H - kH) / stride + 1
- outW := (W - kW) / stride + 1
-
- patches := make([]*mlx.Array, outH*outW)
- idx := 0
- for i := int32(0); i < outH; i++ {
- for j := int32(0); j < outW; j++ {
- startH := i * stride
- startW := j * stride
- patch := mlx.Slice(x, []int32{0, startH, startW, 0}, []int32{B, startH + kH, startW + kW, C})
- patch = mlx.Reshape(patch, B, kH*kW*C)
- patches[idx] = patch
- idx++
- }
- }
-
- for i := range patches {
- patches[i] = mlx.ExpandDims(patches[i], 1)
- }
- stacked := mlx.Concatenate(patches, 1)
- return mlx.Reshape(stacked, B, outH, outW, kH*kW*C)
-}
-
-// layerNormNoAffine applies layer norm without learnable parameters
-func layerNormNoAffine(x *mlx.Array, eps float32) *mlx.Array {
- ndim := x.Ndim()
- lastAxis := ndim - 1
- mean := mlx.Mean(x, lastAxis, true)
- xCentered := mlx.Sub(x, mean)
- variance := mlx.Mean(mlx.Square(xCentered), lastAxis, true)
- return mlx.Div(xCentered, mlx.Sqrt(mlx.AddScalar(variance, eps)))
-}
diff --git a/x/imagegen/models/qwen_image_edit/processor.go b/x/imagegen/models/qwen_image_edit/processor.go
deleted file mode 100644
index c80f5a3b1eb..00000000000
--- a/x/imagegen/models/qwen_image_edit/processor.go
+++ /dev/null
@@ -1,475 +0,0 @@
-//go:build mlx
-
-package qwen_image_edit
-
-import (
- "fmt"
- "image"
- "image/color"
- _ "image/jpeg"
- _ "image/png"
- "math"
- "os"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "golang.org/x/image/draw"
- _ "golang.org/x/image/webp"
-)
-
-// loadImageFile loads an image from disk
-func loadImageFile(path string) (image.Image, error) {
- f, err := os.Open(path)
- if err != nil {
- return nil, fmt.Errorf("open image: %w", err)
- }
- defer f.Close()
-
- img, _, err := image.Decode(f)
- if err != nil {
- return nil, fmt.Errorf("decode image: %w", err)
- }
- return img, nil
-}
-
-// imageToFloat32Pixels converts an image to a float32 pixel array [H, W, C] in [0, 1] range
-func imageToFloat32Pixels(img image.Image, width, height int) []float32 {
- pixels := make([]float32, width*height*3)
- idx := 0
- for y := 0; y < height; y++ {
- for x := 0; x < width; x++ {
- r, g, b, _ := img.At(x, y).RGBA()
- pixels[idx] = float32(r) / 65535.0
- pixels[idx+1] = float32(g) / 65535.0
- pixels[idx+2] = float32(b) / 65535.0
- idx += 3
- }
- }
- return pixels
-}
-
-// normalizeImageNet applies ImageNet normalization to an image tensor
-func (p *Processor) normalizeImageNet(arr *mlx.Array) *mlx.Array {
- mean := mlx.NewArray(p.Config.ImageMean, []int32{1, 1, 3})
- std := mlx.NewArray(p.Config.ImageStd, []int32{1, 1, 3})
- return mlx.Div(mlx.Sub(arr, mean), std)
-}
-
-// prepareImageTensor transforms [H, W, C] to [B, C, H, W] and converts to bf16
-func prepareImageTensor(arr *mlx.Array) *mlx.Array {
- // Transpose to [C, H, W] and make contiguous
- arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
- // Add batch dimension [1, C, H, W]
- arr = mlx.ExpandDims(arr, 0)
- // Convert to bf16
- arr = mlx.ToBFloat16(arr)
- mlx.Eval(arr)
- return arr
-}
-
-// clampFloat clamps a value to [0, 255] and returns uint8
-func clampFloat(v, weightSum float64) uint8 {
- v /= weightSum
- if v < 0 {
- v = 0
- }
- if v > 255 {
- v = 255
- }
- return uint8(math.Round(v))
-}
-
-// ImageDims holds dimensions for a preprocessed image
-type ImageDims struct {
- // Original image dimensions
- OrigW, OrigH int32
- // Condition image dimensions (for vision encoder)
- CondW, CondH int32
- // VAE image dimensions
- VaeW, VaeH int32
- // Latent dimensions (VAE dims / vae_scale_factor)
- LatentW, LatentH int32
- // Patch dimensions (latent dims / patch_size)
- PatchW, PatchH int32
-}
-
-// ProcessorConfig holds image processor configuration
-type ProcessorConfig struct {
- // Condition image size (target pixel area for vision encoder input)
- // Python: CONDITION_IMAGE_SIZE = 384 * 384 = 147456
- // Pipeline resizes image to this area before passing to encode_prompt
- ConditionImageSize int32
-
- // VAE image size (target pixel area)
- // Python: VAE_IMAGE_SIZE = 1024 * 1024 = 1048576
- VAEImageSize int32
-
- // Image normalization (ImageNet stats for vision encoder)
- ImageMean []float32
- ImageStd []float32
-}
-
-// defaultProcessorConfig returns default processor config
-func defaultProcessorConfig() *ProcessorConfig {
- return &ProcessorConfig{
- ConditionImageSize: 384 * 384, // 147456 - matches Python CONDITION_IMAGE_SIZE
- VAEImageSize: 1024 * 1024, // 1048576 - matches Python VAE_IMAGE_SIZE
- ImageMean: []float32{0.48145466, 0.4578275, 0.40821073},
- ImageStd: []float32{0.26862954, 0.26130258, 0.27577711},
- }
-}
-
-// Processor handles image preprocessing for Qwen-Image-Edit
-type Processor struct {
- Config *ProcessorConfig
-}
-
-// Load loads the processor config
-func (p *Processor) Load(path string) error {
- p.Config = defaultProcessorConfig()
- return nil
-}
-
-// LoadAndPreprocess loads an image and preprocesses it for both paths
-// Returns: condImage (for vision encoder), vaeImage (for VAE encoding)
-func (p *Processor) LoadAndPreprocess(imagePath string) (*mlx.Array, *mlx.Array, error) {
- img, err := loadImageFile(imagePath)
- if err != nil {
- return nil, nil, err
- }
-
- bounds := img.Bounds()
- origW := bounds.Dx()
- origH := bounds.Dy()
- ratio := float64(origW) / float64(origH)
-
- // Calculate dimensions for condition image (vision encoder)
- // Python pipeline does TWO resizes:
- // 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
- // 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
- intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
- finalH, finalW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
-
- // Calculate dimensions for VAE image (1024x1024 area)
- // Use multiple of 32 (vae_scale_factor * patch_size * 2 = 8 * 2 * 2 = 32)
- vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
-
- // Preprocess for condition (vision encoder) - two-step resize
- condImage := p.preprocessImageTwoStep(img, intermediateW, intermediateH, finalW, finalH)
-
- // Preprocess for VAE ([-1, 1] range, 5D tensor)
- vaeImage := p.preprocessImageForVAE(img, vaeW, vaeH)
-
- return condImage, vaeImage, nil
-}
-
-// preprocessImageLanczos does single-step Lanczos resize for vision encoder
-// Matches Python VaeImageProcessor.resize with resample='lanczos' (the default)
-// Used by edit_plus pipeline for multi-image input
-// Returns: [B, C, H, W] normalized tensor
-func (p *Processor) preprocessImageLanczos(img image.Image, width, height int32) *mlx.Array {
- resized := resizeImageLanczos(img, int(width), int(height))
- pixels := imageToFloat32Pixels(resized, int(width), int(height))
- arr := mlx.NewArray(pixels, []int32{height, width, 3})
- arr = p.normalizeImageNet(arr)
- return prepareImageTensor(arr)
-}
-
-// preprocessImageTwoStep does two-step resize for vision encoder to match Python pipeline
-// Step 1: Lanczos resize from original to intermediate size (VaeImageProcessor.resize)
-// Step 2: Bicubic resize from intermediate to final size (Qwen2VLProcessor smart_resize)
-// Returns: [B, C, H, W] normalized tensor
-func (p *Processor) preprocessImageTwoStep(img image.Image, intermediateW, intermediateH, finalW, finalH int32) *mlx.Array {
- intermediate := resizeImageLanczos(img, int(intermediateW), int(intermediateH))
- resized := resizeImageBicubic(intermediate, int(finalW), int(finalH))
- pixels := imageToFloat32Pixels(resized, int(finalW), int(finalH))
- arr := mlx.NewArray(pixels, []int32{finalH, finalW, 3})
- arr = p.normalizeImageNet(arr)
- return prepareImageTensor(arr)
-}
-
-// preprocessImage converts image to tensor for vision encoder
-// Returns: [B, C, H, W] normalized tensor
-func (p *Processor) preprocessImage(img image.Image, width, height int32, normalize bool) *mlx.Array {
- resized := resizeImageBicubic(img, int(width), int(height))
- pixels := imageToFloat32Pixels(resized, int(width), int(height))
- arr := mlx.NewArray(pixels, []int32{height, width, 3})
- if normalize {
- arr = p.normalizeImageNet(arr)
- }
- return prepareImageTensor(arr)
-}
-
-// preprocessImageForVAE converts image to tensor for VAE encoding
-// Returns: [B, C, T, H, W] tensor in [-1, 1] range
-func (p *Processor) preprocessImageForVAE(img image.Image, width, height int32) *mlx.Array {
- resized := resizeImageLanczos(img, int(width), int(height))
- pixels := imageToFloat32Pixels(resized, int(width), int(height))
- arr := mlx.NewArray(pixels, []int32{height, width, 3})
-
- // Scale to [-1, 1]: arr * 2 - 1
- arr = mlx.MulScalar(arr, 2.0)
- arr = mlx.AddScalar(arr, -1.0)
-
- // Transpose to [C, H, W] and make contiguous
- arr = mlx.Contiguous(mlx.Transpose(arr, 2, 0, 1))
-
- // Add batch and temporal dimensions [1, C, 1, H, W]
- arr = mlx.ExpandDims(arr, 0) // [1, C, H, W]
- arr = mlx.ExpandDims(arr, 2) // [1, C, 1, H, W]
-
- arr = mlx.ToBFloat16(arr)
- mlx.Eval(arr)
- return arr
-}
-
-// smartResize implements Python Qwen2VL processor's smart_resize logic
-// Returns (resizedHeight, resizedWidth) that fit within min/max pixel constraints
-func smartResize(height, width, factor, minPixels, maxPixels int32) (int32, int32) {
- // Round to factor
- hBar := int32(math.Round(float64(height)/float64(factor))) * factor
- wBar := int32(math.Round(float64(width)/float64(factor))) * factor
-
- // Ensure minimum factor size
- if hBar < factor {
- hBar = factor
- }
- if wBar < factor {
- wBar = factor
- }
-
- // Check pixel constraints
- total := hBar * wBar
- if total > maxPixels {
- // Scale down
- beta := math.Sqrt(float64(maxPixels) / float64(total))
- hBar = int32(math.Floor(float64(height)*beta/float64(factor))) * factor
- wBar = int32(math.Floor(float64(width)*beta/float64(factor))) * factor
- } else if total < minPixels {
- // Scale up
- beta := math.Sqrt(float64(minPixels) / float64(total))
- hBar = int32(math.Ceil(float64(height)*beta/float64(factor))) * factor
- wBar = int32(math.Ceil(float64(width)*beta/float64(factor))) * factor
- }
-
- return hBar, wBar
-}
-
-// calculateDimensions calculates width and height for a target area while maintaining ratio
-// multiple: the value to round dimensions to (e.g., 28 for vision encoder with patch 14 and 2x2 merge)
-func calculateDimensions(targetArea int32, ratio float64, multiple int32) (int32, int32) {
- width := math.Sqrt(float64(targetArea) * ratio)
- height := width / ratio
-
- m := float64(multiple)
- width = math.Round(width/m) * m
- height = math.Round(height/m) * m
-
- // Ensure minimum dimensions
- if width < m {
- width = m
- }
- if height < m {
- height = m
- }
-
- return int32(width), int32(height)
-}
-
-// resizeImageLanczos resizes an image using Lanczos3 interpolation (matches PIL.LANCZOS)
-func resizeImageLanczos(img image.Image, width, height int) image.Image {
- bounds := img.Bounds()
- dst := image.NewRGBA(image.Rect(0, 0, width, height))
-
- // Lanczos3 kernel (a=3) to match PIL.LANCZOS
- lanczos3 := &draw.Kernel{
- Support: 3.0,
- At: func(t float64) float64 {
- if t == 0 {
- return 1.0
- }
- if t < 0 {
- t = -t
- }
- if t >= 3.0 {
- return 0.0
- }
- // sinc(t) * sinc(t/3)
- piT := math.Pi * t
- return (math.Sin(piT) / piT) * (math.Sin(piT/3) / (piT / 3))
- },
- }
- lanczos3.Scale(dst, dst.Bounds(), img, bounds, draw.Over, nil)
-
- return dst
-}
-
-// resizeImageBicubic resizes an image using bicubic interpolation (matches PIL.BICUBIC)
-// Uses separable interpolation with PIL's coordinate mapping for exact match
-func resizeImageBicubic(img image.Image, width, height int) image.Image {
- bounds := img.Bounds()
- srcW := bounds.Dx()
- srcH := bounds.Dy()
-
- // Convert to RGBA if needed
- var src *image.RGBA
- if rgba, ok := img.(*image.RGBA); ok {
- src = rgba
- } else {
- src = image.NewRGBA(bounds)
- for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
- for x := bounds.Min.X; x < bounds.Max.X; x++ {
- src.Set(x, y, img.At(x, y))
- }
- }
- }
-
- // Keys cubic with a=-0.5 (PIL BICUBIC)
- cubic := func(x float64) float64 {
- if x < 0 {
- x = -x
- }
- if x < 1 {
- return 1.5*x*x*x - 2.5*x*x + 1
- }
- if x < 2 {
- return -0.5*x*x*x + 2.5*x*x - 4*x + 2
- }
- return 0
- }
-
- // Horizontal pass: srcW -> width, keep srcH rows
- temp := image.NewRGBA(image.Rect(0, 0, width, srcH))
- for y := 0; y < srcH; y++ {
- for dstX := 0; dstX < width; dstX++ {
- // PIL coordinate mapping: center-to-center
- srcXf := (float64(dstX)+0.5)*(float64(srcW)/float64(width)) - 0.5
- baseX := int(math.Floor(srcXf))
-
- var sumR, sumG, sumB, sumA, weightSum float64
- for i := -1; i <= 2; i++ {
- sx := baseX + i
- if sx < 0 {
- sx = 0
- }
- if sx >= srcW {
- sx = srcW - 1
- }
-
- w := cubic(math.Abs(srcXf - float64(baseX+i)))
- c := src.RGBAAt(sx, y)
- sumR += float64(c.R) * w
- sumG += float64(c.G) * w
- sumB += float64(c.B) * w
- sumA += float64(c.A) * w
- weightSum += w
- }
-
- temp.SetRGBA(dstX, y, color.RGBA{
- clampFloat(sumR, weightSum),
- clampFloat(sumG, weightSum),
- clampFloat(sumB, weightSum),
- clampFloat(sumA, weightSum),
- })
- }
- }
-
- // Vertical pass: srcH -> height
- dst := image.NewRGBA(image.Rect(0, 0, width, height))
- for x := 0; x < width; x++ {
- for dstY := 0; dstY < height; dstY++ {
- srcYf := (float64(dstY)+0.5)*(float64(srcH)/float64(height)) - 0.5
- baseY := int(math.Floor(srcYf))
-
- var sumR, sumG, sumB, sumA, weightSum float64
- for j := -1; j <= 2; j++ {
- sy := baseY + j
- if sy < 0 {
- sy = 0
- }
- if sy >= srcH {
- sy = srcH - 1
- }
-
- w := cubic(math.Abs(srcYf - float64(baseY+j)))
- c := temp.RGBAAt(x, sy)
- sumR += float64(c.R) * w
- sumG += float64(c.G) * w
- sumB += float64(c.B) * w
- sumA += float64(c.A) * w
- weightSum += w
- }
-
- dst.SetRGBA(x, dstY, color.RGBA{
- clampFloat(sumR, weightSum),
- clampFloat(sumG, weightSum),
- clampFloat(sumB, weightSum),
- clampFloat(sumA, weightSum),
- })
- }
- }
-
- return dst
-}
-
-// LoadAndPreprocessMultiple loads multiple images and preprocesses them
-// Returns: condImages (for vision encoder), vaeImages (for VAE encoding), dims (per-image dimensions)
-func (p *Processor) LoadAndPreprocessMultiple(imagePaths []string) ([]*mlx.Array, []*mlx.Array, []ImageDims, error) {
- const vaeScaleFactor int32 = 8
- const patchSize int32 = 2
-
- condImages := make([]*mlx.Array, len(imagePaths))
- vaeImages := make([]*mlx.Array, len(imagePaths))
- dims := make([]ImageDims, len(imagePaths))
-
- for i, imagePath := range imagePaths {
- img, err := loadImageFile(imagePath)
- if err != nil {
- return nil, nil, nil, fmt.Errorf("image %d: %w", i, err)
- }
-
- bounds := img.Bounds()
- origW := int32(bounds.Dx())
- origH := int32(bounds.Dy())
- ratio := float64(origW) / float64(origH)
-
- // Calculate dimensions for condition image (vision encoder)
- // Python pipeline does TWO resizes:
- // 1. VaeImageProcessor.resize with Lanczos to CONDITION_IMAGE_SIZE (384x384 area)
- // 2. Qwen2VLProcessor's smart_resize with Bicubic to multiple of 28
- intermediateW, intermediateH := calculateDimensions(p.Config.ConditionImageSize, ratio, 32)
- condH, condW := smartResize(intermediateH, intermediateW, 28, 56*56, 28*28*1280)
-
- // Calculate dimensions for VAE image (1024x1024 area)
- vaeW, vaeH := calculateDimensions(p.Config.VAEImageSize, ratio, 32)
-
- // Calculate derived dimensions
- latentW := vaeW / vaeScaleFactor
- latentH := vaeH / vaeScaleFactor
- patchW := latentW / patchSize
- patchH := latentH / patchSize
-
- dims[i] = ImageDims{
- OrigW: origW,
- OrigH: origH,
- CondW: condW,
- CondH: condH,
- VaeW: vaeW,
- VaeH: vaeH,
- LatentW: latentW,
- LatentH: latentH,
- PatchW: patchW,
- PatchH: patchH,
- }
-
- fmt.Printf(" Image %d: orig=%dx%d, cond=%dx%d, vae=%dx%d, latent=%dx%d, patch=%dx%d\n",
- i+1, origW, origH, condW, condH, vaeW, vaeH, latentW, latentH, patchW, patchH)
-
- // Preprocess for condition (vision encoder) - two-step resize to match Python pipeline
- condImages[i] = p.preprocessImageTwoStep(img, intermediateW, intermediateH, condW, condH)
-
- // Preprocess for VAE ([-1, 1] range, 5D tensor)
- vaeImages[i] = p.preprocessImageForVAE(img, vaeW, vaeH)
- }
-
- return condImages, vaeImages, dims, nil
-}
diff --git a/x/imagegen/models/qwen_image_edit/qwen_image_edit.go b/x/imagegen/models/qwen_image_edit/qwen_image_edit.go
deleted file mode 100644
index d1e39498613..00000000000
--- a/x/imagegen/models/qwen_image_edit/qwen_image_edit.go
+++ /dev/null
@@ -1,625 +0,0 @@
-//go:build mlx
-
-// Package qwen_image_edit implements the Qwen-Image-Edit diffusion model for image editing.
-// It reuses components from qwen_image where possible.
-package qwen_image_edit
-
-import (
- "context"
- "fmt"
- "path/filepath"
- "time"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/models/qwen_image"
- "github.com/ollama/ollama/x/imagegen/tokenizer"
-)
-
-// GenerateConfig holds all options for image editing.
-type GenerateConfig struct {
- Prompt string
- NegativePrompt string // Unconditional prompt for CFG (empty string "" is valid)
- CFGScale float32 // CFG enabled when > 1.0 (default: 4.0)
- Width int32 // Output width (default: from input image)
- Height int32 // Output height (default: from input image)
- Steps int // Denoising steps (default: 50)
- Seed int64 // Random seed
- Progress func(step, totalSteps int) // Optional progress callback
-}
-
-// Model represents a Qwen-Image-Edit diffusion model.
-type Model struct {
- ModelPath string
- Tokenizer *tokenizer.Tokenizer
- Processor *Processor // Image processor for vision encoder
- TextEncoder *qwen_image.Qwen25VL // Qwen2.5-VL vision-language encoder (from qwen_image)
- Transformer *qwen_image.Transformer // Reuse qwen_image transformer
- VAE *VAE // Combined encoder + decoder
-}
-
-// Load loads the Qwen-Image-Edit model from a directory.
-func (m *Model) Load(modelPath string) error {
- fmt.Println("Loading Qwen-Image-Edit model...")
- start := time.Now()
-
- if mlx.GPUIsAvailable() {
- mlx.SetDefaultDeviceGPU()
- mlx.EnableCompile()
- }
-
- m.ModelPath = modelPath
-
- // Load tokenizer from processor directory
- fmt.Print(" Loading tokenizer... ")
- processorPath := filepath.Join(modelPath, "processor")
- tok, err := tokenizer.Load(processorPath)
- if err != nil {
- // Fallback to tokenizer directory
- tokenizerPath := filepath.Join(modelPath, "tokenizer")
- tok, err = tokenizer.Load(tokenizerPath)
- if err != nil {
- return fmt.Errorf("tokenizer: %w", err)
- }
- }
- m.Tokenizer = tok
- fmt.Println("✓")
-
- // Load processor (image preprocessing config)
- fmt.Print(" Loading processor... ")
- m.Processor = &Processor{}
- if err := m.Processor.Load(processorPath); err != nil {
- return fmt.Errorf("processor: %w", err)
- }
- fmt.Println("✓")
-
- // Load vision-language text encoder (Qwen2.5-VL from qwen_image package)
- m.TextEncoder = &qwen_image.Qwen25VL{}
- if err := m.TextEncoder.Load(filepath.Join(modelPath, "text_encoder")); err != nil {
- return fmt.Errorf("text encoder: %w", err)
- }
- mlx.Eval(mlx.Collect(m.TextEncoder)...)
- fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
- float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
- float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- // Load transformer (reuse qwen_image)
- m.Transformer = &qwen_image.Transformer{}
- if err := m.Transformer.Load(filepath.Join(modelPath, "transformer")); err != nil {
- return fmt.Errorf("transformer: %w", err)
- }
- mlx.Eval(mlx.Collect(m.Transformer)...)
- fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
- float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
- float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- // Load VAE (encoder + decoder)
- m.VAE = &VAE{}
- if err := m.VAE.Load(filepath.Join(modelPath, "vae")); err != nil {
- return fmt.Errorf("VAE: %w", err)
- }
- mlx.Eval(mlx.Collect(m.VAE)...)
- fmt.Printf(" (%.1f GB, peak %.1f GB)\n",
- float64(mlx.MetalGetActiveMemory())/(1024*1024*1024),
- float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
-
- mem := mlx.MetalGetActiveMemory()
- peak := mlx.MetalGetPeakMemory()
- fmt.Printf(" Loaded in %.2fs (%.1f GB active, %.1f GB peak)\n",
- time.Since(start).Seconds(),
- float64(mem)/(1024*1024*1024),
- float64(peak)/(1024*1024*1024))
-
- return nil
-}
-
-// Edit edits an image based on a text prompt.
-// inputImagePath: path to input image
-// prompt: text description of desired edit
-func (m *Model) Edit(inputImagePath string, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
- return m.EditFromConfig([]string{inputImagePath}, &GenerateConfig{
- Prompt: prompt,
- Width: width,
- Height: height,
- Steps: steps,
- Seed: seed,
- })
-}
-
-// EditFromConfig edits images using the unified config struct.
-// Accepts one or more input images.
-func (m *Model) EditFromConfig(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
- if len(inputImagePaths) == 0 {
- return nil, fmt.Errorf("no input images provided")
- }
-
- start := time.Now()
- result, err := m.edit(inputImagePaths, cfg)
- if err != nil {
- return nil, err
- }
-
- if cfg.NegativePrompt != "" {
- fmt.Printf("Edited %d image(s) with CFG (scale=%.1f) in %.2fs (%d steps)\n",
- len(inputImagePaths), cfg.CFGScale, time.Since(start).Seconds(), cfg.Steps)
- } else {
- fmt.Printf("Edited %d image(s) in %.2fs (%d steps)\n",
- len(inputImagePaths), time.Since(start).Seconds(), cfg.Steps)
- }
- return result, nil
-}
-
-// EditImage implements model.ImageEditModel interface.
-func (m *Model) EditImage(ctx context.Context, inputImagePath, prompt string, width, height int32, steps int, seed int64) (*mlx.Array, error) {
- return m.Edit(inputImagePath, prompt, width, height, steps, seed)
-}
-
-// EditMultiImage edits using multiple source images.
-// This matches diffusers' QwenImageEditPlusPipeline behavior.
-func (m *Model) EditMultiImage(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
- return m.EditFromConfig(inputImagePaths, cfg)
-}
-
-// edit is the internal editing pipeline that handles one or more images.
-func (m *Model) edit(inputImagePaths []string, cfg *GenerateConfig) (*mlx.Array, error) {
- // Apply defaults
- if cfg.Steps <= 0 {
- cfg.Steps = 50
- }
- if cfg.CFGScale <= 0 {
- cfg.CFGScale = 4.0
- }
-
- // Load and preprocess all input images
- fmt.Printf("Loading %d image(s)...\n", len(inputImagePaths))
- condImages, vaeImages, inputDims, err := m.Processor.LoadAndPreprocessMultiple(inputImagePaths)
- if err != nil {
- return nil, fmt.Errorf("preprocess images: %w", err)
- }
- for _, img := range condImages {
- mlx.Keep(img)
- }
- for _, img := range vaeImages {
- mlx.Keep(img)
- }
- mlx.Eval(append(condImages, vaeImages...)...)
-
- useCFG := cfg.NegativePrompt != ""
- tcfg := m.Transformer.Config
- vaeScaleFactor := int32(8)
-
- // Output dimensions - if not specified, use first input image dimensions
- if cfg.Width <= 0 {
- cfg.Width = inputDims[0].VaeW
- }
- if cfg.Height <= 0 {
- cfg.Height = inputDims[0].VaeH
- }
-
- // Output (noise) latent dimensions
- outLatentH := cfg.Height / vaeScaleFactor
- outLatentW := cfg.Width / vaeScaleFactor
- outPH := outLatentH / tcfg.PatchSize
- outPW := outLatentW / tcfg.PatchSize
- noiseSeqLen := outPH * outPW
- imgSeqLen := noiseSeqLen
-
- // Encode prompt with all images for conditioning
- posEmb, _, _, err := m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.Prompt, condImages)
- if err != nil {
- return nil, fmt.Errorf("encoding prompt: %w", err)
- }
- mlx.Keep(posEmb)
- mlx.Eval(posEmb)
-
- var negEmb *mlx.Array
- if useCFG {
- negEmb, _, _, err = m.TextEncoder.EncodePromptWithImages(m.Tokenizer, cfg.NegativePrompt, condImages)
- if err != nil {
- return nil, fmt.Errorf("encoding negative prompt: %w", err)
- }
- mlx.Keep(negEmb)
- mlx.Eval(negEmb)
- }
-
- // Pad sequences to same length for CFG
- txtLen := posEmb.Shape()[1]
- if useCFG {
- negLen := negEmb.Shape()[1]
- if negLen > txtLen {
- txtLen = negLen
- }
- if posEmb.Shape()[1] < txtLen {
- posEmb = padSequence(posEmb, txtLen)
- }
- if negEmb.Shape()[1] < txtLen {
- negEmb = padSequence(negEmb, txtLen)
- }
- mlx.Keep(posEmb, negEmb)
- mlx.Eval(posEmb, negEmb)
- }
-
- // Pre-compute batched embeddings for CFG (single forward pass optimization)
- var batchedEmb *mlx.Array
- if useCFG {
- batchedEmb = mlx.Concatenate([]*mlx.Array{posEmb, negEmb}, 0)
- mlx.Keep(batchedEmb)
- mlx.Eval(batchedEmb)
- }
-
- // Encode all input images to latents and concatenate
- fmt.Println("Encoding images to latents...")
- allImageLatentsPacked := make([]*mlx.Array, len(vaeImages))
- for i, vaeImage := range vaeImages {
- imageLatents := m.VAE.Encode(vaeImage)
- imageLatents = m.VAE.Normalize(imageLatents)
- imageLatents2D := mlx.Squeeze(imageLatents, 2)
- packed := qwen_image.PackLatents(imageLatents2D, tcfg.PatchSize)
- mlx.Keep(packed)
- mlx.Eval(packed)
- allImageLatentsPacked[i] = packed
- }
-
- imageLatentsPacked := mlx.Concatenate(allImageLatentsPacked, 1)
- mlx.Keep(imageLatentsPacked)
- mlx.Eval(imageLatentsPacked)
-
- // Scheduler
- scheduler := qwen_image.NewFlowMatchScheduler(qwen_image.DefaultSchedulerConfig())
- scheduler.SetTimesteps(cfg.Steps, noiseSeqLen)
-
- // Init noise latents in packed format
- packedChannels := tcfg.OutChannels * tcfg.PatchSize * tcfg.PatchSize
- packedNoise := scheduler.InitNoisePacked(1, noiseSeqLen, packedChannels, cfg.Seed)
- latents := qwen_image.UnpackLatents(packedNoise, outLatentH, outLatentW, tcfg.PatchSize)
- mlx.Eval(latents)
-
- // RoPE cache
- ropeCache := PrepareRoPEMultiImage(outPH, outPW, inputDims, txtLen, tcfg.AxesDimsRope)
- mlx.Keep(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
- mlx.Eval(ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-
- // Denoising loop
- fmt.Printf("Running denoising (%d steps)...\n", cfg.Steps)
- for i := 0; i < cfg.Steps; i++ {
- stepStart := time.Now()
- if cfg.Progress != nil {
- cfg.Progress(i+1, cfg.Steps)
- }
-
- t := scheduler.Timesteps[i]
- timestep := mlx.ToBFloat16(mlx.NewArray([]float32{t}, []int32{1}))
- mlx.Eval(timestep)
-
- latents2D := mlx.Squeeze(latents, 2)
- patches := qwen_image.PackLatents(latents2D, tcfg.PatchSize)
- latentInput := mlx.Concatenate([]*mlx.Array{patches, imageLatentsPacked}, 1)
-
- var output *mlx.Array
- if useCFG {
- // CFG Batching: single forward pass with batch=2
- // Tile inputs: [1, L, D] -> [2, L, D]
- batchedLatentInput := mlx.Tile(latentInput, []int32{2, 1, 1})
- batchedTimestep := mlx.Tile(timestep, []int32{2})
-
- // Single batched forward pass
- batchedOutput := m.Transformer.Forward(batchedLatentInput, batchedEmb, batchedTimestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
-
- // Split output: [2, L, D] -> pos [1, L, D], neg [1, L, D]
- D := batchedOutput.Shape()[2]
- posOutput := mlx.Slice(batchedOutput, []int32{0, 0, 0}, []int32{1, imgSeqLen, D})
- negOutput := mlx.Slice(batchedOutput, []int32{1, 0, 0}, []int32{2, imgSeqLen, D})
-
- output = applyCFGWithNormRescale(posOutput, negOutput, cfg.CFGScale)
- } else {
- output = m.Transformer.Forward(latentInput, posEmb, timestep, ropeCache.ImgFreqs, ropeCache.TxtFreqs)
- output = mlx.Slice(output, []int32{0, 0, 0}, []int32{1, imgSeqLen, output.Shape()[2]})
- }
-
- noisePred := qwen_image.UnpackLatents(output, outLatentH, outLatentW, tcfg.PatchSize)
- oldLatents := latents
- latents = scheduler.Step(noisePred, latents, i)
- mlx.Eval(latents)
- oldLatents.Free()
-
- fmt.Printf(" Step %d/%d: t=%.4f (%.2fs)\n", i+1, cfg.Steps, t, time.Since(stepStart).Seconds())
- }
-
- // Free denoising temporaries
- posEmb.Free()
- if negEmb != nil {
- negEmb.Free()
- }
- if batchedEmb != nil {
- batchedEmb.Free()
- }
- ropeCache.ImgFreqs.Free()
- ropeCache.TxtFreqs.Free()
- imageLatentsPacked.Free()
-
- // Decode latents
- decoded := m.decodeAndPostprocess(latents)
- latents.Free()
-
- fmt.Printf(" Peak memory: %.2f GB\n", float64(mlx.MetalGetPeakMemory())/(1024*1024*1024))
- return decoded, nil
-}
-
-// applyCFGWithNormRescale applies classifier-free guidance with norm rescaling.
-// This prevents CFG from inflating magnitude too much.
-func applyCFGWithNormRescale(posOutput, negOutput *mlx.Array, scale float32) *mlx.Array {
- // Upcast to float32 for precision
- posF32 := mlx.AsType(posOutput, mlx.DtypeFloat32)
- negF32 := mlx.AsType(negOutput, mlx.DtypeFloat32)
-
- // CFG: pred = neg + scale * (pos - neg)
- diff := mlx.Sub(posF32, negF32)
- scaledDiff := mlx.MulScalar(diff, scale)
- combPred := mlx.Add(negF32, scaledDiff)
-
- // Norm rescaling: rescale combined prediction to match conditional norm
- condNorm := mlx.Sqrt(mlx.Sum(mlx.Square(posF32), -1, true))
- combNorm := mlx.Sqrt(mlx.Sum(mlx.Square(combPred), -1, true))
- output := mlx.Mul(combPred, mlx.Div(condNorm, combNorm))
-
- mlx.Eval(output)
- return mlx.ToBFloat16(output)
-}
-
-// decodeAndPostprocess denormalizes latents, decodes through VAE, and scales to [0,1].
-func (m *Model) decodeAndPostprocess(latents *mlx.Array) *mlx.Array {
- latents = m.VAE.Denormalize(latents)
- decoded := m.VAE.Decode(latents)
-
- // Post-process: squeeze temporal dim and rescale to [0, 1]
- decoded = mlx.Squeeze(decoded, 2)
- decoded = mlx.AddScalar(decoded, 1.0)
- decoded = mlx.DivScalar(decoded, 2.0)
- decoded = mlx.ClipScalar(decoded, 0.0, 1.0, true, true)
- mlx.Eval(decoded)
- return decoded
-}
-
-// padSequence pads a sequence tensor to the target length with zeros
-func padSequence(x *mlx.Array, targetLen int32) *mlx.Array {
- shape := x.Shape()
- currentLen := shape[1]
- if currentLen >= targetLen {
- return x
- }
- padLen := targetLen - currentLen
- // Pad on sequence dimension (axis 1)
- return mlx.Pad(x, []int32{0, 0, 0, padLen, 0, 0})
-}
-
-// LoadPersistent is an alias for backward compatibility.
-func LoadPersistent(modelPath string) (*Model, error) {
- m := &Model{}
- if err := m.Load(modelPath); err != nil {
- return nil, err
- }
- return m, nil
-}
-
-// PrepareRoPEMultiImage computes RoPE with interpolation for image editing.
-// Handles single or multiple input images with different resolutions.
-//
-// Parameters:
-// - outPH, outPW: output patch dimensions (noise latent resolution)
-// - inputDims: patch dimensions for each input image [(pH1, pW1), (pH2, pW2), ...]
-// - txtLen: text sequence length
-// - axesDims: RoPE axis dimensions [16, 56, 56]
-//
-// Returns RoPE cache where:
-// - ImgFreqs has (outPH*outPW + sum(inPH*inPW for each image)) positions
-// - First outPH*outPW positions are for noise latents (standard RoPE at output res)
-// - Following positions are for each input image (interpolated from output res)
-func PrepareRoPEMultiImage(outPH, outPW int32, inputDims []ImageDims, txtLen int32, axesDims []int32) *qwen_image.RoPECache {
- theta := float64(10000)
- maxIdx := int32(4096)
-
- // Compute base frequencies for each axis dimension
- freqsT := qwen_image.ComputeAxisFreqs(axesDims[0], theta)
- freqsH := qwen_image.ComputeAxisFreqs(axesDims[1], theta)
- freqsW := qwen_image.ComputeAxisFreqs(axesDims[2], theta)
-
- // Build frequency lookup tables
- posFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
- posFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, false)
- posFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, false)
- negFreqsT := qwen_image.MakeFreqTable(maxIdx, freqsT, true) // For frame -1 on last condition image
- negFreqsH := qwen_image.MakeFreqTable(maxIdx, freqsH, true)
- negFreqsW := qwen_image.MakeFreqTable(maxIdx, freqsW, true)
-
- headDim := int32(len(freqsT)+len(freqsH)+len(freqsW)) * 2
-
- // Helper to compute RoPE for a single position at output resolution with scale_rope
- computePosFreqs := func(framePos, y, x int32) []float32 {
- row := make([]float32, headDim)
- idx := 0
-
- // Frame position
- for i := 0; i < len(freqsT)*2; i++ {
- row[idx+i] = posFreqsT[framePos][i]
- }
- idx += len(freqsT) * 2
-
- // Height with scale_rope centering (using OUTPUT dimensions)
- outHHalf := outPH / 2
- hNegCount := outPH - outHHalf
- if y < hNegCount {
- negTableIdx := maxIdx - hNegCount + y
- for i := 0; i < len(freqsH)*2; i++ {
- row[idx+i] = negFreqsH[negTableIdx][i]
- }
- } else {
- posIdx := y - hNegCount
- for i := 0; i < len(freqsH)*2; i++ {
- row[idx+i] = posFreqsH[posIdx][i]
- }
- }
- idx += len(freqsH) * 2
-
- // Width with scale_rope centering (using OUTPUT dimensions)
- outWHalf := outPW / 2
- wNegCount := outPW - outWHalf
- if x < wNegCount {
- negTableIdx := maxIdx - wNegCount + x
- for i := 0; i < len(freqsW)*2; i++ {
- row[idx+i] = negFreqsW[negTableIdx][i]
- }
- } else {
- posIdx := x - wNegCount
- for i := 0; i < len(freqsW)*2; i++ {
- row[idx+i] = posFreqsW[posIdx][i]
- }
- }
-
- return row
- }
-
- // Helper to compute RoPE for frame -1 (used for last condition image)
- // This matches Python's _compute_condition_freqs which uses freqs_neg[0][-1:]
- computeNegFrameFreqs := func(y, x int32) []float32 {
- row := make([]float32, headDim)
- idx := 0
-
- // Frame -1: use last row of negative frame frequencies
- negFrameIdx := maxIdx - 1
- for i := 0; i < len(freqsT)*2; i++ {
- row[idx+i] = negFreqsT[negFrameIdx][i]
- }
- idx += len(freqsT) * 2
-
- // Height with scale_rope centering (using OUTPUT dimensions)
- outHHalf := outPH / 2
- hNegCount := outPH - outHHalf
- if y < hNegCount {
- negTableIdx := maxIdx - hNegCount + y
- for i := 0; i < len(freqsH)*2; i++ {
- row[idx+i] = negFreqsH[negTableIdx][i]
- }
- } else {
- posIdx := y - hNegCount
- for i := 0; i < len(freqsH)*2; i++ {
- row[idx+i] = posFreqsH[posIdx][i]
- }
- }
- idx += len(freqsH) * 2
-
- // Width with scale_rope centering (using OUTPUT dimensions)
- outWHalf := outPW / 2
- wNegCount := outPW - outWHalf
- if x < wNegCount {
- negTableIdx := maxIdx - wNegCount + x
- for i := 0; i < len(freqsW)*2; i++ {
- row[idx+i] = negFreqsW[negTableIdx][i]
- }
- } else {
- posIdx := x - wNegCount
- for i := 0; i < len(freqsW)*2; i++ {
- row[idx+i] = posFreqsW[posIdx][i]
- }
- }
-
- return row
- }
-
- // Total image sequence length: noise + all input images
- noiseSeqLen := outPH * outPW
- totalImgLen := noiseSeqLen
- for _, dims := range inputDims {
- totalImgLen += dims.PatchH * dims.PatchW
- }
-
- imgFreqsData := make([]float32, totalImgLen*headDim)
- idx := int32(0)
-
- // Segment 0: Noise latents - standard RoPE at output resolution (frame 0)
- for y := int32(0); y < outPH; y++ {
- for x := int32(0); x < outPW; x++ {
- row := computePosFreqs(0, y, x)
- copy(imgFreqsData[idx:], row)
- idx += headDim
- }
- }
-
- // Segments 1..N: Edit image latents - INTERPOLATED RoPE
- // For single image: use frame 1 (matches original PrepareRoPEInterpolated)
- // For multiple images: Python uses frame -1 for the LAST condition image
- // (_compute_condition_freqs), positive indices for others.
- numImages := len(inputDims)
- lastImgIdx := numImages - 1
- for imgIdx, dims := range inputDims {
- inPH := dims.PatchH
- inPW := dims.PatchW
-
- // Determine frame index for this image
- // Single image case: use frame 1 (like original PrepareRoPEInterpolated)
- // Multi-image case: last image uses frame -1, others use frame 1, 2, etc.
- useNegFrame := numImages > 1 && imgIdx == lastImgIdx
-
- // Map each input position to an output position using linear interpolation
- for y := int32(0); y < inPH; y++ {
- for x := int32(0); x < inPW; x++ {
- // Interpolate: map input (y, x) to output grid position
- // This is the key fix from DiffSynth's forward_sampling
- var yOut, xOut int32
- if inPH == 1 {
- yOut = 0
- } else {
- // Linear interpolation: y_out = y * (outPH - 1) / (inPH - 1)
- yOut = y * (outPH - 1) / (inPH - 1)
- }
- if inPW == 1 {
- xOut = 0
- } else {
- xOut = x * (outPW - 1) / (inPW - 1)
- }
-
- var row []float32
- if useNegFrame {
- // Last image in multi-image uses frame -1
- row = computeNegFrameFreqs(yOut, xOut)
- } else {
- // Single image uses frame 1, multi-image uses frame 1, 2, etc.
- frameIdx := int32(imgIdx + 1)
- row = computePosFreqs(frameIdx, yOut, xOut)
- }
- copy(imgFreqsData[idx:], row)
- idx += headDim
- }
- }
- }
-
- imgFreqs := mlx.NewArray(imgFreqsData, []int32{totalImgLen, headDim})
- imgFreqs = mlx.ToBFloat16(imgFreqs)
-
- // Text frequencies - start after max video index
- maxVidIdx := max(outPH/2, outPW/2)
-
- txtFreqsData := make([]float32, txtLen*headDim)
- idx = 0
- for t := int32(0); t < txtLen; t++ {
- pos := maxVidIdx + t
- for i := 0; i < len(freqsT)*2; i++ {
- txtFreqsData[idx+int32(i)] = posFreqsT[pos][i]
- }
- idx += int32(len(freqsT) * 2)
- for i := 0; i < len(freqsH)*2; i++ {
- txtFreqsData[idx+int32(i)] = posFreqsH[pos][i]
- }
- idx += int32(len(freqsH) * 2)
- for i := 0; i < len(freqsW)*2; i++ {
- txtFreqsData[idx+int32(i)] = posFreqsW[pos][i]
- }
- idx += int32(len(freqsW) * 2)
- }
-
- txtFreqs := mlx.NewArray(txtFreqsData, []int32{txtLen, headDim})
- txtFreqs = mlx.ToBFloat16(txtFreqs)
-
- return &qwen_image.RoPECache{
- ImgFreqs: imgFreqs,
- TxtFreqs: txtFreqs,
- }
-}
diff --git a/x/imagegen/models/qwen_image_edit/rope_test.go b/x/imagegen/models/qwen_image_edit/rope_test.go
deleted file mode 100644
index 200940fbe6f..00000000000
--- a/x/imagegen/models/qwen_image_edit/rope_test.go
+++ /dev/null
@@ -1,249 +0,0 @@
-//go:build mlx
-
-package qwen_image_edit
-
-import (
- "fmt"
- "math"
- "os"
- "path/filepath"
- "runtime"
- "testing"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/models/qwen_image"
-)
-
-// TestMain initializes MLX before running tests.
-// If MLX libraries are not available, tests are skipped.
-func TestMain(m *testing.M) {
- // Change to repo root so ./build/lib/ollama/ path works
- _, thisFile, _, _ := runtime.Caller(0)
- repoRoot := filepath.Join(filepath.Dir(thisFile), "..", "..", "..", "..")
- if err := os.Chdir(repoRoot); err != nil {
- fmt.Printf("Failed to change to repo root: %v\n", err)
- os.Exit(1)
- }
-
- if err := mlx.InitMLX(); err != nil {
- fmt.Printf("Skipping qwen_image_edit tests: %v\n", err)
- os.Exit(0)
- }
- os.Exit(m.Run())
-}
-
-// TestComputeAxisFreqs verifies frequency computation matches Python reference
-func TestComputeAxisFreqs(t *testing.T) {
- theta := float64(10000)
-
- // Expected values from Python:
- // freqs = 1.0 / (theta ** (np.arange(0, half_dim) / half_dim))
- expectedFreqsT := []float64{
- 1.000000000000000, 0.316227766016838, 0.100000000000000, 0.031622776601684,
- 0.010000000000000, 0.003162277660168, 0.001000000000000, 0.000316227766017,
- }
-
- expectedFreqsH_first4 := []float64{
- 1.000000000000000, 0.719685673001152, 0.517947467923121, 0.372759372031494,
- }
-
- expectedFreqsH_last4 := []float64{
- 0.000372759372031, 0.000268269579528, 0.000193069772888, 0.000138949549437,
- }
-
- // Test temporal frequencies (dim=16)
- freqsT := qwen_image.ComputeAxisFreqs(16, theta)
- if len(freqsT) != 8 {
- t.Fatalf("expected 8 temporal frequencies, got %d", len(freqsT))
- }
- for i, expected := range expectedFreqsT {
- if diff := math.Abs(freqsT[i] - expected); diff > 1e-10 {
- t.Errorf("freqsT[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsT[i], diff)
- }
- }
-
- // Test height/width frequencies (dim=56)
- freqsH := qwen_image.ComputeAxisFreqs(56, theta)
- if len(freqsH) != 28 {
- t.Fatalf("expected 28 height frequencies, got %d", len(freqsH))
- }
- for i, expected := range expectedFreqsH_first4 {
- if diff := math.Abs(freqsH[i] - expected); diff > 1e-10 {
- t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", i, expected, freqsH[i], diff)
- }
- }
- for i, expected := range expectedFreqsH_last4 {
- idx := 24 + i // last 4 of 28
- if diff := math.Abs(freqsH[idx] - expected); diff > 1e-10 {
- t.Errorf("freqsH[%d]: expected %.15f, got %.15f, diff %.2e", idx, expected, freqsH[idx], diff)
- }
- }
-}
-
-// TestMakeFreqTable verifies the frequency lookup table for both positive and negative positions
-func TestMakeFreqTable(t *testing.T) {
- theta := float64(10000)
- freqsT := qwen_image.ComputeAxisFreqs(16, theta)
- maxIdx := int32(4096)
-
- // Test positive table
- posTable := qwen_image.MakeFreqTable(maxIdx, freqsT, false)
-
- // Position 0 should give cos=1, sin=0 for all frequencies
- for i := 0; i < len(freqsT)*2; i += 2 {
- if posTable[0][i] != 1.0 {
- t.Errorf("posTable[0][%d] (cos): expected 1.0, got %f", i, posTable[0][i])
- }
- if posTable[0][i+1] != 0.0 {
- t.Errorf("posTable[0][%d] (sin): expected 0.0, got %f", i+1, posTable[0][i+1])
- }
- }
-
- // Position 1, first frequency (1.0): angle = 1*1 = 1
- // cos(1) = 0.5403, sin(1) = 0.8415
- if diff := math.Abs(float64(posTable[1][0]) - 0.5403023058681398); diff > 1e-6 {
- t.Errorf("posTable[1][0] (cos): expected 0.5403, got %f", posTable[1][0])
- }
- if diff := math.Abs(float64(posTable[1][1]) - 0.8414709848078965); diff > 1e-6 {
- t.Errorf("posTable[1][1] (sin): expected 0.8415, got %f", posTable[1][1])
- }
-
- // Test negative table
- negTable := qwen_image.MakeFreqTable(maxIdx, freqsT, true)
-
- // negTable[4095] corresponds to position -1
- // cos(-1) = cos(1), sin(-1) = -sin(1)
- if diff := math.Abs(float64(negTable[4095][0]) - 0.5403023058681398); diff > 1e-6 {
- t.Errorf("negTable[4095][0] (cos(-1)): expected 0.5403, got %f", negTable[4095][0])
- }
- if diff := math.Abs(float64(negTable[4095][1]) - (-0.8414709848078965)); diff > 1e-6 {
- t.Errorf("negTable[4095][1] (sin(-1)): expected -0.8415, got %f", negTable[4095][1])
- }
-
- // negTable[4094] corresponds to position -2
- // cos(-2) = cos(2), sin(-2) = -sin(2)
- cos2 := math.Cos(2.0)
- sin2 := math.Sin(2.0)
- if diff := math.Abs(float64(negTable[4094][0]) - cos2); diff > 1e-6 {
- t.Errorf("negTable[4094][0] (cos(-2)): expected %f, got %f", cos2, negTable[4094][0])
- }
- if diff := math.Abs(float64(negTable[4094][1]) - (-sin2)); diff > 1e-6 {
- t.Errorf("negTable[4094][1] (sin(-2)): expected %f, got %f", -sin2, negTable[4094][1])
- }
-}
-
-// TestPrepareRoPE_QwenImage verifies qwen_image.PrepareRoPE for single-segment case
-func TestPrepareRoPE_QwenImage(t *testing.T) {
- if !mlx.GPUIsAvailable() {
- t.Skip("GPU not available")
- }
-
- mlx.SetDefaultDeviceCPU()
-
- // 4x4 patch grid, single image
- imgH, imgW := int32(4), int32(4)
- txtLen := int32(5)
- axesDims := []int32{16, 56, 56}
-
- cache := qwen_image.PrepareRoPE(imgH, imgW, txtLen, axesDims)
- mlx.Eval(cache.ImgFreqs, cache.TxtFreqs)
-
- // Check shapes
- imgShape := cache.ImgFreqs.Shape()
- if imgShape[0] != 16 { // 4*4 patches
- t.Errorf("ImgFreqs seq len: expected 16, got %d", imgShape[0])
- }
-
- // For single image (frame=0), all temporal values should be cos=1, sin=0
- imgFreqsCPU := mlx.AsType(cache.ImgFreqs, mlx.DtypeFloat32)
- mlx.Eval(imgFreqsCPU)
- imgData := imgFreqsCPU.Data()
-
- // Check first 16 values of patch 0 (temporal cos/sin pairs)
- for i := 0; i < 16; i += 2 {
- cosVal := imgData[i]
- sinVal := imgData[i+1]
- if diff := math.Abs(float64(cosVal - 1.0)); diff > 1e-5 {
- t.Errorf("ImgFreqs[0][%d] (cos): expected 1.0, got %f", i, cosVal)
- }
- if diff := math.Abs(float64(sinVal - 0.0)); diff > 1e-5 {
- t.Errorf("ImgFreqs[0][%d] (sin): expected 0.0, got %f", i+1, sinVal)
- }
- }
-
- cache.ImgFreqs.Free()
- cache.TxtFreqs.Free()
-}
-
-// TestScaleRopePositions verifies the centered position calculation for scale_rope=True
-func TestScaleRopePositions(t *testing.T) {
- // For a 4x4 grid with scale_rope=True:
- // hHalf = 2, wHalf = 2
- // hNegCount = 4 - 2 = 2 (positions 0,1 are negative)
- // wNegCount = 4 - 2 = 2 (positions 0,1 are negative)
- //
- // Height positions:
- // y=0: -(4-2) + 0 = -2
- // y=1: -(4-2) + 1 = -1
- // y=2: 2 - 2 = 0
- // y=3: 3 - 2 = 1
- //
- // Same for width
-
- pH, pW := int32(4), int32(4)
- hHalf := pH / 2
- wHalf := pW / 2
- hNegCount := pH - hHalf
- wNegCount := pW - wHalf
-
- expectedH := []int32{-2, -1, 0, 1}
- expectedW := []int32{-2, -1, 0, 1}
-
- for y := int32(0); y < pH; y++ {
- var hPos int32
- if y < hNegCount {
- hPos = -(pH - hHalf) + y
- } else {
- hPos = y - hNegCount
- }
- if hPos != expectedH[y] {
- t.Errorf("y=%d: expected h_pos=%d, got %d", y, expectedH[y], hPos)
- }
- }
-
- for x := int32(0); x < pW; x++ {
- var wPos int32
- if x < wNegCount {
- wPos = -(pW - wHalf) + x
- } else {
- wPos = x - wNegCount
- }
- if wPos != expectedW[x] {
- t.Errorf("x=%d: expected w_pos=%d, got %d", x, expectedW[x], wPos)
- }
- }
-}
-
-// TestRoPEHeadDimensions verifies the head dimension breakdown
-func TestRoPEHeadDimensions(t *testing.T) {
- // axes_dims_rope = [16, 56, 56]
- // Each dimension uses half the values for frequencies
- // So we get: 8 + 28 + 28 = 64 frequency values
- // Each frequency produces cos + sin, so: 64 * 2 = 128 total values per position
-
- axesDims := []int32{16, 56, 56}
- expectedFreqs := (axesDims[0]/2 + axesDims[1]/2 + axesDims[2]/2)
- expectedHeadDim := expectedFreqs * 2
-
- if expectedFreqs != 64 {
- t.Errorf("expected 64 frequency values, got %d", expectedFreqs)
- }
- if expectedHeadDim != 128 {
- t.Errorf("expected head_dim=128, got %d", expectedHeadDim)
- }
-
- // This should match the transformer's attention head dimension
- // hidden_size = 3072, num_heads = 24
- // head_dim = 3072 / 24 = 128
-}
-
diff --git a/x/imagegen/models/qwen_image_edit/vae.go b/x/imagegen/models/qwen_image_edit/vae.go
deleted file mode 100644
index 3dbe7ef3cd4..00000000000
--- a/x/imagegen/models/qwen_image_edit/vae.go
+++ /dev/null
@@ -1,642 +0,0 @@
-//go:build mlx
-
-package qwen_image_edit
-
-import (
- "fmt"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/safetensors"
-)
-
-// VAEConfig holds Qwen-Image VAE configuration
-type VAEConfig struct {
- ZDim int32 `json:"z_dim"` // 16
- BaseDim int32 `json:"base_dim"` // 96
- DimMult []int32 `json:"dim_mult"` // [1, 2, 4, 4]
- NumResBlocks int32 `json:"num_res_blocks"` // 2
- LatentsMean []float32 `json:"latents_mean"` // 16 values
- LatentsStd []float32 `json:"latents_std"` // 16 values
- TemperalDownsample []bool `json:"temperal_downsample"` // [false, true, true]
-}
-
-// defaultVAEConfig returns config for Qwen-Image VAE
-func defaultVAEConfig() *VAEConfig {
- return &VAEConfig{
- ZDim: 16,
- BaseDim: 96,
- DimMult: []int32{1, 2, 4, 4},
- NumResBlocks: 2,
- LatentsMean: []float32{
- -0.7571, -0.7089, -0.9113, 0.1075,
- -0.1745, 0.9653, -0.1517, 1.5508,
- 0.4134, -0.0715, 0.5517, -0.3632,
- -0.1922, -0.9497, 0.2503, -0.2921,
- },
- LatentsStd: []float32{
- 2.8184, 1.4541, 2.3275, 2.6558,
- 1.2196, 1.7708, 2.6052, 2.0743,
- 3.2687, 2.1526, 2.8652, 1.5579,
- 1.6382, 1.1253, 2.8251, 1.916,
- },
- TemperalDownsample: []bool{false, true, true},
- }
-}
-
-// VAE is the full VAE with encoder and decoder
-type VAE struct {
- Config *VAEConfig
- Encoder *VAEEncoder
- Decoder *VAEDecoder
-}
-
-// Load loads the VAE from a directory
-func (m *VAE) Load(path string) error {
- fmt.Println("Loading Qwen-Image-Edit VAE (encoder + decoder)...")
-
- cfg := defaultVAEConfig()
- m.Config = cfg
-
- weights, err := safetensors.LoadModelWeights(path)
- if err != nil {
- return fmt.Errorf("weights: %w", err)
- }
-
- // Load weights as f32 for quality (matches Python default behavior)
- // VAE decoder precision is critical for final image quality
- fmt.Print(" Loading weights as f32... ")
- if err := weights.Load(mlx.DtypeFloat32); err != nil {
- return fmt.Errorf("failed to load weights: %w", err)
- }
- fmt.Printf("✓ (%.1f GB)\n", float64(mlx.MetalGetActiveMemory())/(1024*1024*1024))
-
- // Load encoder
- fmt.Print(" Loading encoder... ")
- m.Encoder = &VAEEncoder{}
- if err := m.Encoder.loadFromWeights(weights, cfg); err != nil {
- return fmt.Errorf("encoder: %w", err)
- }
- fmt.Println("✓")
-
- // Load decoder
- fmt.Print(" Loading decoder... ")
- m.Decoder = &VAEDecoder{}
- if err := m.Decoder.loadFromWeights(weights, cfg); err != nil {
- return fmt.Errorf("decoder: %w", err)
- }
- fmt.Println("✓")
-
- weights.ReleaseAll()
- return nil
-}
-
-// Encode encodes an image to latents
-// x: [B, C, T, H, W] image tensor in [-1, 1] range
-// Returns: [B, C, T, H/8, W/8] latents (unnormalized)
-func (m *VAE) Encode(x *mlx.Array) *mlx.Array {
- return m.Encoder.Encode(x)
-}
-
-// Decode decodes latents to image
-// z: [B, C, T, H, W] latents (denormalized)
-// Returns: [B, C, T, H*8, W*8] image in [-1, 1]
-func (m *VAE) Decode(z *mlx.Array) *mlx.Array {
- return m.Decoder.Decode(z)
-}
-
-// Normalize applies latent normalization
-// Input z should be f32 (from VAE encoder), output is f32 for transformer
-func (m *VAE) Normalize(z *mlx.Array) *mlx.Array {
- shape := z.Shape()
- C := shape[1]
-
- mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
- std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
-
- // Mean/std are f32, will match z dtype through broadcasting
- return mlx.Div(mlx.Sub(z, mean), std)
-}
-
-// Denormalize reverses latent normalization
-// Input z is bf16 (from transformer), output converted to f32 for VAE decoder
-func (m *VAE) Denormalize(z *mlx.Array) *mlx.Array {
- shape := z.Shape()
- C := shape[1]
-
- // Convert latents to f32 for VAE decoder quality
- z = mlx.AsType(z, mlx.DtypeFloat32)
-
- mean := mlx.NewArray(m.Config.LatentsMean[:C], []int32{1, C, 1, 1, 1})
- std := mlx.NewArray(m.Config.LatentsStd[:C], []int32{1, C, 1, 1, 1})
-
- return mlx.Add(mlx.Mul(z, std), mean)
-}
-
-// VAEEncoder is the encoder part of the VAE
-// The encoder uses a flat structure where down_blocks contains a mix of ResBlocks and Downsamplers:
-// - Blocks 0,1: ResBlocks (base_dim)
-// - Block 2: Downsample
-// - Blocks 3,4: ResBlocks (base_dim*2)
-// - Block 5: Downsample + temporal
-// - Blocks 6,7: ResBlocks (base_dim*4)
-// - Block 8: Downsample + temporal
-// - Blocks 9,10: ResBlocks (base_dim*4)
-type VAEEncoder struct {
- Config *VAEConfig
-
- ConvIn *CausalConv3d
- Blocks []EncoderBlock // Flat list of ResBlocks and Downsamplers
- MidBlock *MidBlock
- NormOut *RMSNorm3D
- ConvOut *CausalConv3d
- QuantConv *CausalConv3d
-}
-
-// EncoderBlock is either a ResBlock or a Downsample
-type EncoderBlock interface {
- Forward(x *mlx.Array) *mlx.Array
- IsDownsample() bool
-}
-
-// EncoderResBlock wraps ResBlock
-type EncoderResBlock struct {
- *ResBlock
-}
-
-func (b *EncoderResBlock) IsDownsample() bool { return false }
-
-// EncoderDownsample is a downsample layer
-type EncoderDownsample struct {
- Resample *CausalConv3d
- TimeConv *CausalConv3d // Optional temporal downsample
-}
-
-func (d *EncoderDownsample) IsDownsample() bool { return true }
-
-func (d *EncoderDownsample) Forward(x *mlx.Array) *mlx.Array {
- // Spatial downsample with stride 2
- // WAN VAE uses: ZeroPad2d(0,1,0,1) + Conv2d(3x3, stride=2)
- x = d.forwardSpatialDownsample(x)
-
- // NOTE: In WAN VAE, time_conv is ONLY used in streaming/chunked mode
- // with feat_cache. For single-frame encoding (T=1), time_conv is skipped.
- // The Python forward checks: if feat_cache is not None ... then use time_conv
- // Since we don't support streaming, we skip time_conv entirely.
- return x
-}
-
-// forwardSpatialDownsample applies 2D conv with stride 2 for spatial downsampling
-func (d *EncoderDownsample) forwardSpatialDownsample(x *mlx.Array) *mlx.Array {
- xShape := x.Shape()
- B := xShape[0]
- T := xShape[1]
- H := xShape[2]
- W := xShape[3]
- C := xShape[4]
-
- wShape := d.Resample.Weight.Shape()
- outC := wShape[0]
-
- // Reshape to [B*T, H, W, C] for 2D conv
- x = mlx.Reshape(x, B*T, H, W, C)
-
- // Asymmetric padding: pad right and bottom by 1 (WAN VAE style)
- // ZeroPad2d(0, 1, 0, 1) means (left=0, right=1, top=0, bottom=1)
- x = mlx.Pad(x, []int32{0, 0, 0, 1, 0, 1, 0, 0}) // [B, H, W, C] -> pad H and W
-
- // Apply 2D conv with stride 2
- weight := mlx.Transpose(d.Resample.Weight, 0, 2, 3, 1) // [O, I, kH, kW] -> [O, kH, kW, I]
- x = conv2DStrided(x, weight, 2)
-
- if d.Resample.Bias != nil {
- bias := mlx.Reshape(d.Resample.Bias, 1, 1, 1, outC)
- x = mlx.Add(x, bias)
- }
-
- // Output dims after stride 2: (H+1)/2, (W+1)/2
- outH := (H + 1) / 2
- outW := (W + 1) / 2
-
- // Reshape back to [B, T, H', W', C]
- x = mlx.Reshape(x, B, T, outH, outW, outC)
- mlx.Eval(x)
-
- return x
-}
-
-// loadFromWeights loads the encoder from pre-loaded weights
-func (e *VAEEncoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
- e.Config = cfg
-
- // Conv in
- convIn, err := newCausalConv3d(weights, "encoder.conv_in")
- if err != nil {
- return err
- }
- e.ConvIn = convIn
-
- // Encoder uses flat block structure:
- // dim_mult = [1, 2, 4, 4], num_res_blocks = 2, temporal_downsample = [false, true, true]
- // Block layout: res,res,down, res,res,down+t, res,res,down+t, res,res
- // That's 11 blocks: 0,1=res, 2=down, 3,4=res, 5=down+t, 6,7=res, 8=down+t, 9,10=res
- e.Blocks = make([]EncoderBlock, 0, 11)
-
- // Track dimensions
- dims := []int32{cfg.BaseDim, cfg.BaseDim * 2, cfg.BaseDim * 4, cfg.BaseDim * 4}
- blockIdx := 0
-
- for stage := 0; stage < len(cfg.DimMult); stage++ {
- inDim := cfg.BaseDim
- if stage > 0 {
- inDim = dims[stage-1]
- }
- outDim := dims[stage]
-
- // ResBlocks for this stage (num_res_blocks per stage)
- for r := int32(0); r < cfg.NumResBlocks; r++ {
- prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
- currentInDim := inDim
- if r > 0 {
- currentInDim = outDim
- }
- block, err := newEncoderResBlock(weights, prefix, currentInDim, outDim)
- if err != nil {
- return fmt.Errorf("encoder res block %d: %w", blockIdx, err)
- }
- e.Blocks = append(e.Blocks, block)
- blockIdx++
- }
-
- // Downsample after each stage except the last
- if stage < len(cfg.DimMult)-1 {
- prefix := fmt.Sprintf("encoder.down_blocks.%d", blockIdx)
- down, err := newEncoderDownsample(weights, prefix, cfg.TemperalDownsample[stage])
- if err != nil {
- return fmt.Errorf("encoder downsample %d: %w", blockIdx, err)
- }
- e.Blocks = append(e.Blocks, down)
- blockIdx++
- }
- }
-
- // Mid block
- midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
- midBlock, err := newMidBlock(weights, "encoder.mid_block", midDim)
- if err != nil {
- return err
- }
- e.MidBlock = midBlock
-
- // Norm out
- normOut, err := newRMSNorm3D(weights, "encoder.norm_out", midDim)
- if err != nil {
- return err
- }
- e.NormOut = normOut
-
- // Conv out
- convOut, err := newCausalConv3d(weights, "encoder.conv_out")
- if err != nil {
- return err
- }
- e.ConvOut = convOut
-
- // Quant conv
- quantConv, err := newCausalConv3d(weights, "quant_conv")
- if err != nil {
- return err
- }
- e.QuantConv = quantConv
-
- return nil
-}
-
-// newEncoderResBlock creates a ResBlock for the encoder (flat structure)
-func newEncoderResBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32) (*EncoderResBlock, error) {
- block, err := newResBlock(weights, prefix, inDim, outDim)
- if err != nil {
- return nil, err
- }
- return &EncoderResBlock{block}, nil
-}
-
-// newEncoderDownsample creates a downsample layer for the encoder
-func newEncoderDownsample(weights *safetensors.ModelWeights, prefix string, temporal bool) (*EncoderDownsample, error) {
- resample, err := newCausalConv3d(weights, prefix+".resample.1")
- if err != nil {
- return nil, err
- }
-
- var timeConv *CausalConv3d
- if temporal {
- timeConv, _ = newCausalConv3d(weights, prefix+".time_conv")
- }
-
- return &EncoderDownsample{
- Resample: resample,
- TimeConv: timeConv,
- }, nil
-}
-
-// Encode encodes an image to latents
-// x: [B, C, T, H, W] image tensor (channels-first)
-// Returns: [B, latent_C, T, H/8, W/8] latent distribution mode
-func (e *VAEEncoder) Encode(x *mlx.Array) *mlx.Array {
- // Convert from channels-first [N, C, T, H, W] to channels-last [N, T, H, W, C]
- x = mlx.Contiguous(mlx.Transpose(x, 0, 2, 3, 4, 1))
- mlx.Eval(x)
-
- // Conv in
- x = e.ConvIn.Forward(x)
-
- // Encoder blocks (mix of ResBlocks and Downsamplers)
- for _, block := range e.Blocks {
- prev := x
- x = block.Forward(x)
- prev.Free()
- }
-
- // Mid block
- x = e.MidBlock.Forward(x)
-
- // Norm + silu
- {
- prev := x
- x = e.NormOut.Forward(x)
- x = silu3D(x)
- prev.Free()
- mlx.Eval(x)
- }
-
- // Conv out
- {
- prev := x
- x = e.ConvOut.Forward(x)
- prev.Free()
- }
-
- // Quant conv
- {
- prev := x
- x = e.QuantConv.Forward(x)
- prev.Free()
- }
-
- // Get mode from distribution (first half of channels = mean)
- // Output is [B, T, H, W, 2*latent_C], we take first latent_C channels
- shape := x.Shape()
- latentC := shape[4] / 2
- x = mlx.Slice(x, []int32{0, 0, 0, 0, 0}, []int32{shape[0], shape[1], shape[2], shape[3], latentC})
-
- // Convert back to channels-first [N, C, T, H, W]
- x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
- mlx.Eval(x)
-
- return x
-}
-
-// VAEDecoder is the decoder part of the VAE
-type VAEDecoder struct {
- Config *VAEConfig
-
- PostQuantConv *CausalConv3d
- ConvIn *CausalConv3d
- MidBlock *MidBlock
- UpBlocks []*UpBlock
- NormOut *RMSNorm3D
- ConvOut *CausalConv3d
-}
-
-// loadFromWeights loads the decoder from pre-loaded weights
-func (d *VAEDecoder) loadFromWeights(weights *safetensors.ModelWeights, cfg *VAEConfig) error {
- d.Config = cfg
-
- postQuantConv, err := newCausalConv3d(weights, "post_quant_conv")
- if err != nil {
- return err
- }
- d.PostQuantConv = postQuantConv
-
- convIn, err := newCausalConv3d(weights, "decoder.conv_in")
- if err != nil {
- return err
- }
- d.ConvIn = convIn
-
- // Mid block
- midDim := cfg.BaseDim * cfg.DimMult[len(cfg.DimMult)-1]
- midBlock, err := newMidBlock(weights, "decoder.mid_block", midDim)
- if err != nil {
- return err
- }
- d.MidBlock = midBlock
-
- // Up blocks (reversed dim_mult)
- numUpBlocks := len(cfg.DimMult)
- d.UpBlocks = make([]*UpBlock, numUpBlocks)
-
- dimsMult := make([]int32, numUpBlocks+1)
- dimsMult[0] = cfg.DimMult[numUpBlocks-1]
- for i := 0; i < numUpBlocks; i++ {
- dimsMult[i+1] = cfg.DimMult[numUpBlocks-1-i]
- }
-
- temporalUpsample := make([]bool, len(cfg.TemperalDownsample))
- for i := range cfg.TemperalDownsample {
- temporalUpsample[i] = cfg.TemperalDownsample[len(cfg.TemperalDownsample)-1-i]
- }
-
- for i := 0; i < numUpBlocks; i++ {
- inDim := cfg.BaseDim * dimsMult[i]
- outDim := cfg.BaseDim * dimsMult[i+1]
-
- if i > 0 {
- inDim = inDim / 2
- }
-
- upsampleMode := ""
- if i < numUpBlocks-1 {
- if temporalUpsample[i] {
- upsampleMode = "upsample3d"
- } else {
- upsampleMode = "upsample2d"
- }
- }
-
- prefix := fmt.Sprintf("decoder.up_blocks.%d", i)
- upBlock, err := newUpBlock(weights, prefix, inDim, outDim, cfg.NumResBlocks, upsampleMode)
- if err != nil {
- return err
- }
- d.UpBlocks[i] = upBlock
- }
-
- normOut, err := newRMSNorm3D(weights, "decoder.norm_out", cfg.BaseDim)
- if err != nil {
- return err
- }
- d.NormOut = normOut
-
- convOut, err := newCausalConv3d(weights, "decoder.conv_out")
- if err != nil {
- return err
- }
- d.ConvOut = convOut
-
- return nil
-}
-
-// Decode converts latents to image
-// z: [B, C, T, H, W] denormalized latents
-func (d *VAEDecoder) Decode(z *mlx.Array) *mlx.Array {
- var x *mlx.Array
-
- // Convert from channels-first to channels-last
- {
- z = mlx.Contiguous(mlx.Transpose(z, 0, 2, 3, 4, 1))
- mlx.Eval(z)
- }
-
- // PostQuantConv
- x = d.PostQuantConv.Forward(z)
- z.Free()
-
- // ConvIn
- {
- prev := x
- x = d.ConvIn.Forward(x)
- prev.Free()
- }
-
- // Mid block
- x = d.MidBlock.Forward(x)
-
- // Up blocks
- for _, upBlock := range d.UpBlocks {
- x = upBlock.Forward(x)
- }
-
- // NormOut + silu
- {
- prev := x
- x = d.NormOut.Forward(x)
- x = silu3D(x)
- prev.Free()
- mlx.Eval(x)
- }
-
- // ConvOut
- {
- prev := x
- x = d.ConvOut.Forward(x)
- prev.Free()
- }
-
- // Post-processing: clamp and convert back to channels-first
- {
- prev := x
- x = mlx.ClipScalar(x, -1.0, 1.0, true, true)
- x = mlx.Contiguous(mlx.Transpose(x, 0, 4, 1, 2, 3))
- prev.Free()
- mlx.Eval(x)
- }
-
- return x
-}
-
-// DownBlock handles downsampling in encoder
-type DownBlock struct {
- ResBlocks []*ResBlock
- Downsampler *Downsample
-}
-
-// newDownBlock creates a down block
-func newDownBlock(weights *safetensors.ModelWeights, prefix string, inDim, outDim int32, numBlocks int32, downsampleMode string) (*DownBlock, error) {
- resBlocks := make([]*ResBlock, numBlocks+1)
-
- currentDim := inDim
- for i := int32(0); i <= numBlocks; i++ {
- resPrefix := fmt.Sprintf("%s.resnets.%d", prefix, i)
- block, err := newResBlock(weights, resPrefix, currentDim, outDim)
- if err != nil {
- return nil, err
- }
- resBlocks[i] = block
- currentDim = outDim
- }
-
- var downsampler *Downsample
- if downsampleMode != "" {
- downsampler = newDownsample(weights, prefix+".downsamplers.0", outDim, downsampleMode)
- }
-
- return &DownBlock{
- ResBlocks: resBlocks,
- Downsampler: downsampler,
- }, nil
-}
-
-// Forward applies down block
-func (d *DownBlock) Forward(x *mlx.Array) *mlx.Array {
- for _, block := range d.ResBlocks {
- prev := x
- x = block.Forward(x)
- prev.Free()
- }
-
- if d.Downsampler != nil {
- prev := x
- x = d.Downsampler.Forward(x)
- prev.Free()
- }
- return x
-}
-
-// Downsample handles spatial downsampling
-type Downsample struct {
- Conv *mlx.Array
- Bias *mlx.Array
- Mode string
-}
-
-// newDownsample creates a downsampler
-func newDownsample(weights *safetensors.ModelWeights, prefix string, dim int32, mode string) *Downsample {
- conv, _ := weights.Get(prefix + ".resample.1.weight")
- bias, _ := weights.Get(prefix + ".resample.1.bias")
- return &Downsample{
- Conv: conv,
- Bias: bias,
- Mode: mode,
- }
-}
-
-// Forward applies downsampling to channels-last input [B, T, H, W, C]
-func (d *Downsample) Forward(x *mlx.Array) *mlx.Array {
- shape := x.Shape()
- B := shape[0]
- T := shape[1]
- H := shape[2]
- W := shape[3]
- C := shape[4]
- outC := d.Conv.Shape()[0]
-
- // Reshape to [B*T, H, W, C] for 2D conv
- x = mlx.Reshape(x, B*T, H, W, C)
-
- // Pad for stride-2 conv: need (3-1)/2 = 1 on each side, but for stride 2 we need specific padding
- // For 3x3 stride 2: pad 1 on all sides
- x = mlx.Pad(x, []int32{0, 0, 1, 1, 1, 1, 0, 0})
-
- // Conv with stride 2 using manual strided patching
- weight := mlx.Transpose(d.Conv, 0, 2, 3, 1)
- x = conv2DStrided(x, weight, 2)
- if d.Bias != nil {
- bias := mlx.Reshape(d.Bias, 1, 1, 1, outC)
- x = mlx.Add(x, bias)
- }
-
- x = mlx.Reshape(x, B, T, H/2, W/2, outC)
- mlx.Eval(x)
-
- return x
-}
diff --git a/x/imagegen/models/zimage/transformer.go b/x/imagegen/models/zimage/transformer.go
index 4164fed8c68..2c42d8c254b 100644
--- a/x/imagegen/models/zimage/transformer.go
+++ b/x/imagegen/models/zimage/transformer.go
@@ -7,8 +7,8 @@ import (
"fmt"
"math"
- "github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/nn"
"github.com/ollama/ollama/x/imagegen/safetensors"
@@ -38,7 +38,7 @@ type TransformerConfig struct {
type TimestepEmbedder struct {
Linear1 nn.LinearLayer `weight:"mlp.0"`
Linear2 nn.LinearLayer `weight:"mlp.2"`
- FreqEmbedSize int32 // 256 (computed)
+ FreqEmbedSize int32 // 256 (computed)
}
// Forward computes timestep embeddings -> [B, 256]
@@ -85,9 +85,9 @@ func (xe *XEmbedder) Forward(x *mlx.Array) *mlx.Array {
// CapEmbedder projects caption features to model dimension
type CapEmbedder struct {
- Norm *nn.RMSNorm `weight:"0"`
- Linear nn.LinearLayer `weight:"1"`
- PadToken *mlx.Array // loaded separately at root level
+ Norm *nn.RMSNorm `weight:"0"`
+ Linear nn.LinearLayer `weight:"1"`
+ PadToken *mlx.Array // loaded separately at root level
}
// Forward projects caption embeddings: [B, L, cap_feat_dim] -> [B, L, dim]
@@ -103,10 +103,9 @@ type FeedForward struct {
W1 nn.LinearLayer `weight:"w1"` // gate projection
W2 nn.LinearLayer `weight:"w2"` // down projection
W3 nn.LinearLayer `weight:"w3"` // up projection
- OutDim int32 // computed from W2
+ OutDim int32 // computed from W2
}
-
// Forward applies SwiGLU: silu(W1(x)) * W3(x), then W2
func (ff *FeedForward) Forward(x *mlx.Array) *mlx.Array {
shape := x.Shape()
@@ -132,11 +131,11 @@ type Attention struct {
ToK nn.LinearLayer `weight:"to_k"`
ToV nn.LinearLayer `weight:"to_v"`
ToOut nn.LinearLayer `weight:"to_out.0"`
- NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
- NormK *mlx.Array `weight:"norm_k.weight"`
+ NormQ *mlx.Array `weight:"norm_q.weight"` // [head_dim] for per-head RMSNorm
+ NormK *mlx.Array `weight:"norm_k.weight"`
// Fused QKV (computed at init time for efficiency, not loaded from weights)
ToQKV nn.LinearLayer `weight:"-"` // Fused Q+K+V projection (created by FuseQKV)
- Fused bool `weight:"-"` // Whether to use fused QKV path
+ Fused bool `weight:"-"` // Whether to use fused QKV path
// Computed fields (not loaded from weights)
NHeads int32 `weight:"-"`
HeadDim int32 `weight:"-"`
@@ -288,13 +287,13 @@ func applyRoPE3D(x *mlx.Array, cos, sin *mlx.Array) *mlx.Array {
// TransformerBlock is a single transformer block with optional AdaLN modulation
type TransformerBlock struct {
- Attention *Attention `weight:"attention"`
- FeedForward *FeedForward `weight:"feed_forward"`
- AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
- AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
- FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
- FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
- AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
+ Attention *Attention `weight:"attention"`
+ FeedForward *FeedForward `weight:"feed_forward"`
+ AttentionNorm1 *nn.RMSNorm `weight:"attention_norm1"`
+ AttentionNorm2 *nn.RMSNorm `weight:"attention_norm2"`
+ FFNNorm1 *nn.RMSNorm `weight:"ffn_norm1"`
+ FFNNorm2 *nn.RMSNorm `weight:"ffn_norm2"`
+ AdaLN nn.LinearLayer `weight:"adaLN_modulation.0,optional"` // only if modulation
// Computed fields
HasModulation bool
Dim int32
@@ -350,7 +349,7 @@ func (tb *TransformerBlock) Forward(x *mlx.Array, adaln *mlx.Array, cos, sin *ml
type FinalLayer struct {
AdaLN nn.LinearLayer `weight:"adaLN_modulation.1"` // [256] -> [dim]
Output nn.LinearLayer `weight:"linear"` // [dim] -> [out_channels]
- OutDim int32 // computed from Output
+ OutDim int32 // computed from Output
}
// Forward computes final output
@@ -401,12 +400,12 @@ type Transformer struct {
}
// Load loads the Z-Image transformer from ollama blob storage.
-func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
+func (m *Transformer) Load(modelManifest *manifest.ModelManifest) error {
fmt.Print(" Loading transformer... ")
// Load config from blob
var cfg TransformerConfig
- if err := manifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
+ if err := modelManifest.ReadConfigJSON("transformer/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
if len(cfg.AllPatchSize) > 0 {
@@ -417,7 +416,7 @@ func (m *Transformer) Load(manifest *imagegen.ModelManifest) error {
m.ContextRefiners = make([]*TransformerBlock, cfg.NRefinerLayers)
m.Layers = make([]*TransformerBlock, cfg.NLayers)
- weights, err := imagegen.LoadWeightsFromManifest(manifest, "transformer")
+ weights, err := manifest.LoadWeightsFromManifest(modelManifest, "transformer")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
diff --git a/x/imagegen/models/zimage/vae.go b/x/imagegen/models/zimage/vae.go
index 663365fc232..aca2b1bfc05 100644
--- a/x/imagegen/models/zimage/vae.go
+++ b/x/imagegen/models/zimage/vae.go
@@ -6,7 +6,7 @@ import (
"fmt"
"math"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/safetensors"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -562,7 +562,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
if ub.Upsample != nil {
// Stage 1: Upsample2x (nearest neighbor)
{
- prev := x
+ prev := x
x = Upsample2x(x)
prev.Free()
mlx.Eval(x)
@@ -570,7 +570,7 @@ func (ub *UpDecoderBlock2D) Forward(x *mlx.Array) *mlx.Array {
// Stage 2: Upsample conv
{
- prev := x
+ prev := x
x = ub.Upsample.Forward(x)
prev.Free()
mlx.Eval(x)
@@ -643,16 +643,16 @@ type VAEDecoder struct {
}
// Load loads the VAE decoder from ollama blob storage.
-func (m *VAEDecoder) Load(manifest *imagegen.ModelManifest) error {
+func (m *VAEDecoder) Load(modelManifest *manifest.ModelManifest) error {
// Load config from blob
var cfg VAEConfig
- if err := manifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
+ if err := modelManifest.ReadConfigJSON("vae/config.json", &cfg); err != nil {
return fmt.Errorf("config: %w", err)
}
m.Config = &cfg
// Load weights from tensor blobs
- weights, err := imagegen.LoadWeightsFromManifest(manifest, "vae")
+ weights, err := manifest.LoadWeightsFromManifest(modelManifest, "vae")
if err != nil {
return fmt.Errorf("weights: %w", err)
}
diff --git a/x/imagegen/models/zimage/zimage.go b/x/imagegen/models/zimage/zimage.go
index f076935ee31..e7ce8436dd8 100644
--- a/x/imagegen/models/zimage/zimage.go
+++ b/x/imagegen/models/zimage/zimage.go
@@ -8,8 +8,8 @@ import (
"fmt"
"time"
- "github.com/ollama/ollama/x/imagegen"
"github.com/ollama/ollama/x/imagegen/cache"
+ "github.com/ollama/ollama/x/imagegen/manifest"
"github.com/ollama/ollama/x/imagegen/mlx"
"github.com/ollama/ollama/x/imagegen/tokenizer"
"github.com/ollama/ollama/x/imagegen/vae"
@@ -18,14 +18,14 @@ import (
// GenerateConfig holds all options for image generation.
type GenerateConfig struct {
Prompt string
- NegativePrompt string // Empty = no CFG
- CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
- Width int32 // Image width (default: 1024)
- Height int32 // Image height (default: 1024)
- Steps int // Denoising steps (default: 9 for turbo)
- Seed int64 // Random seed
+ NegativePrompt string // Empty = no CFG
+ CFGScale float32 // Only used if NegativePrompt is set (default: 4.0)
+ Width int32 // Image width (default: 1024)
+ Height int32 // Image height (default: 1024)
+ Steps int // Denoising steps (default: 9 for turbo)
+ Seed int64 // Random seed
Progress func(step, totalSteps int) // Optional progress callback
- CapturePath string // GPU capture path (debug)
+ CapturePath string // GPU capture path (debug)
// TeaCache options (timestep embedding aware caching)
TeaCache bool // TeaCache is always enabled for faster inference
@@ -58,7 +58,7 @@ func (m *Model) Load(modelName string) error {
m.ModelName = modelName
// Load manifest
- manifest, err := imagegen.LoadManifest(modelName)
+ manifest, err := manifest.LoadManifest(modelName)
if err != nil {
return fmt.Errorf("load manifest: %w", err)
}
diff --git a/x/imagegen/nn/nn.go b/x/imagegen/nn/nn.go
index 65bf7fa22bc..d7247435857 100644
--- a/x/imagegen/nn/nn.go
+++ b/x/imagegen/nn/nn.go
@@ -32,10 +32,16 @@ func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
// NewQuantizedLinear creates a quantized linear layer directly from bf16 weights.
// Quantizes the weight immediately and evaluates to break lazy dependencies.
+// Note: For modes like "nvfp4", qbiases will be nil.
func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
// Eval immediately so bf16 weight can be freed
- mlx.Eval(qw, scales, qbiases)
+ // Handle modes that don't return qbiases (e.g., nvfp4)
+ if qbiases != nil {
+ mlx.Eval(qw, scales, qbiases)
+ } else {
+ mlx.Eval(qw, scales)
+ }
return &QuantizedLinear{
Weight: qw,
Scales: scales,
@@ -77,10 +83,13 @@ func (l *Linear) ToQuantized(groupSize, bits int, mode string) *QuantizedLinear
// QuantizedLinear applies an affine transformation using quantized weights.
// Equivalent to mlx.nn.QuantizedLinear.
+// Supports multiple quantization modes:
+// - "affine": scale + zero-point bias (QBiases required)
+// - "nvfp4": NVIDIA FP4 with E4M3 scales (QBiases nil)
type QuantizedLinear struct {
Weight *mlx.Array // Quantized weight data
Scales *mlx.Array // Scale factors for dequantization
- QBiases *mlx.Array // Quantization biases (NOT layer bias)
+ QBiases *mlx.Array // Quantization biases (NOT layer bias), nil for nvfp4
Bias *mlx.Array // Layer bias [output_dims] or nil
GroupSize int
Bits int
@@ -220,3 +229,32 @@ func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
}
return out
}
+
+// MultiLinearLayer is an interface for per-head linear layers.
+// This allows swapping between MultiLinear (bf16) and pre-dequantized weights.
+type MultiLinearLayer interface {
+ Forward(x *mlx.Array) *mlx.Array
+}
+
+// MultiLinear performs per-head linear projections.
+// Weight shape: [num_heads, output_dims, input_dims]
+// Input shape: [B, num_heads, L, input_dims]
+// Output shape: [B, num_heads, L, output_dims]
+type MultiLinear struct {
+ Weight *mlx.Array `weight:"weight"`
+}
+
+// NewMultiLinear creates a MultiLinear layer with the given weight.
+func NewMultiLinear(weight *mlx.Array) *MultiLinear {
+ return &MultiLinear{Weight: weight}
+}
+
+// Forward applies per-head linear transformation: x @ weight.T per head via broadcasting.
+func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
+ // Weight: [num_heads, output_dims, input_dims]
+ // x: [B, num_heads, L, input_dims]
+ // wT: [num_heads, input_dims, output_dims]
+ // Result: [B, num_heads, L, output_dims]
+ wT := mlx.Transpose(ml.Weight, 0, 2, 1)
+ return mlx.Matmul(x, wT)
+}
diff --git a/x/imagegen/runner.go b/x/imagegen/runner.go
new file mode 100644
index 00000000000..e24383ad504
--- /dev/null
+++ b/x/imagegen/runner.go
@@ -0,0 +1,203 @@
+//go:build mlx
+
+// Package imagegen provides a unified MLX runner for both LLM and image generation models.
+package imagegen
+
+import (
+ "context"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+ "os/signal"
+ "syscall"
+ "time"
+
+ "github.com/ollama/ollama/envconfig"
+ "github.com/ollama/ollama/x/imagegen/mlx"
+)
+
+// Execute is the entry point for the unified MLX runner subprocess.
+func Execute(args []string) error {
+ // Set up logging with appropriate level from environment
+ slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: envconfig.LogLevel()})))
+
+ fs := flag.NewFlagSet("mlx-runner", flag.ExitOnError)
+ modelName := fs.String("model", "", "path to model")
+ port := fs.Int("port", 0, "port to listen on")
+
+ if err := fs.Parse(args); err != nil {
+ return err
+ }
+
+ if *modelName == "" {
+ return fmt.Errorf("--model is required")
+ }
+ if *port == 0 {
+ return fmt.Errorf("--port is required")
+ }
+
+ // Initialize MLX
+ if err := mlx.InitMLX(); err != nil {
+ slog.Error("unable to initialize MLX", "error", err)
+ return err
+ }
+ slog.Info("MLX library initialized")
+
+ // Detect model type from capabilities
+ mode := detectModelMode(*modelName)
+ slog.Info("starting mlx runner", "model", *modelName, "port", *port, "mode", mode)
+
+ // Create and start server
+ server, err := newServer(*modelName, *port, mode)
+ if err != nil {
+ return fmt.Errorf("failed to create server: %w", err)
+ }
+
+ // Set up HTTP handlers
+ mux := http.NewServeMux()
+ mux.HandleFunc("/health", server.healthHandler)
+ mux.HandleFunc("/completion", server.completionHandler)
+
+ // LLM-specific endpoints
+ if mode == ModeLLM {
+ mux.HandleFunc("/tokenize", server.tokenizeHandler)
+ mux.HandleFunc("/embedding", server.embeddingHandler)
+ }
+
+ httpServer := &http.Server{
+ Addr: fmt.Sprintf("127.0.0.1:%d", *port),
+ Handler: mux,
+ }
+
+ // Handle shutdown
+ done := make(chan struct{})
+ go func() {
+ sigCh := make(chan os.Signal, 1)
+ signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
+ <-sigCh
+ slog.Info("shutting down mlx runner")
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ httpServer.Shutdown(ctx)
+ close(done)
+ }()
+
+ slog.Info("mlx runner listening", "addr", httpServer.Addr)
+ if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
+ return err
+ }
+
+ <-done
+ return nil
+}
+
+// detectModelMode determines whether a model is an LLM or image generation model.
+func detectModelMode(modelName string) ModelMode {
+ // Check for image generation model by looking at model_index.json
+ modelType := DetectModelType(modelName)
+ if modelType != "" {
+ // Known image generation model types
+ switch modelType {
+ case "ZImagePipeline", "FluxPipeline", "Flux2KleinPipeline":
+ return ModeImageGen
+ }
+ }
+
+ // Default to LLM mode for safetensors models without known image gen types
+ return ModeLLM
+}
+
+// server holds the model and handles HTTP requests.
+type server struct {
+ mode ModelMode
+ modelName string
+ port int
+
+ // Image generation model (when mode == ModeImageGen)
+ imageModel ImageModel
+
+ // LLM model (when mode == ModeLLM)
+ llmModel *llmState
+}
+
+// newServer creates a new server instance and loads the appropriate model.
+func newServer(modelName string, port int, mode ModelMode) (*server, error) {
+ s := &server{
+ mode: mode,
+ modelName: modelName,
+ port: port,
+ }
+
+ switch mode {
+ case ModeImageGen:
+ if err := s.loadImageModel(); err != nil {
+ return nil, fmt.Errorf("failed to load image model: %w", err)
+ }
+ case ModeLLM:
+ if err := s.loadLLMModel(); err != nil {
+ return nil, fmt.Errorf("failed to load LLM model: %w", err)
+ }
+ }
+
+ return s, nil
+}
+
+func (s *server) healthHandler(w http.ResponseWriter, r *http.Request) {
+ resp := HealthResponse{Status: "ok"}
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(resp)
+}
+
+func (s *server) completionHandler(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
+ return
+ }
+
+ var req Request
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ switch s.mode {
+ case ModeImageGen:
+ s.handleImageCompletion(w, r, req)
+ case ModeLLM:
+ s.handleLLMCompletion(w, r, req)
+ }
+}
+
+func (s *server) tokenizeHandler(w http.ResponseWriter, r *http.Request) {
+ if s.llmModel == nil {
+ http.Error(w, "LLM model not loaded", http.StatusInternalServerError)
+ return
+ }
+
+ var req struct {
+ Content string `json:"content"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ tok := s.llmModel.model.Tokenizer()
+ tokens := tok.Encode(req.Content, false)
+
+ // Convert int32 to int for JSON response
+ intTokens := make([]int, len(tokens))
+ for i, t := range tokens {
+ intTokens[i] = int(t)
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ json.NewEncoder(w).Encode(map[string][]int{"tokens": intTokens})
+}
+
+func (s *server) embeddingHandler(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "embeddings not yet implemented for MLX models", http.StatusNotImplemented)
+}
diff --git a/x/imagegen/runner/runner.go b/x/imagegen/runner/runner.go
deleted file mode 100644
index baa0eb4bf7b..00000000000
--- a/x/imagegen/runner/runner.go
+++ /dev/null
@@ -1,233 +0,0 @@
-//go:build mlx
-
-// Package runner provides a subprocess server for image generation.
-// It listens on a port and handles HTTP requests for image generation.
-package runner
-
-import (
- "context"
- "encoding/json"
- "flag"
- "fmt"
- "log/slog"
- "net/http"
- "os"
- "os/signal"
- "sync"
- "syscall"
- "time"
-
- "github.com/ollama/ollama/x/imagegen"
- "github.com/ollama/ollama/x/imagegen/mlx"
- "github.com/ollama/ollama/x/imagegen/models/flux2"
- "github.com/ollama/ollama/x/imagegen/models/zimage"
-)
-
-// Request is the image generation request format
-type Request struct {
- Prompt string `json:"prompt"`
- Width int32 `json:"width,omitempty"`
- Height int32 `json:"height,omitempty"`
- Steps int `json:"steps,omitempty"`
- Seed int64 `json:"seed,omitempty"`
-}
-
-// Response is streamed back for each progress update
-type Response struct {
- Content string `json:"content,omitempty"`
- Image string `json:"image,omitempty"` // Base64-encoded PNG
- Done bool `json:"done"`
- Step int `json:"step,omitempty"`
- Total int `json:"total,omitempty"`
-}
-
-// ImageModel is the interface for image generation models
-type ImageModel interface {
- GenerateImage(ctx context.Context, prompt string, width, height int32, steps int, seed int64, progress func(step, total int)) (*mlx.Array, error)
-}
-
-// Server holds the model and handles requests
-type Server struct {
- mu sync.Mutex
- model ImageModel
- modelName string
-}
-
-// Execute is the entry point for the image runner subprocess
-func Execute(args []string) error {
- fs := flag.NewFlagSet("image-runner", flag.ExitOnError)
- modelName := fs.String("model", "", "path to image model")
- port := fs.Int("port", 0, "port to listen on")
-
- if err := fs.Parse(args); err != nil {
- return err
- }
-
- if *modelName == "" {
- return fmt.Errorf("--model is required")
- }
- if *port == 0 {
- return fmt.Errorf("--port is required")
- }
-
- err := mlx.InitMLX()
- if err != nil {
- slog.Error("unable to initialize MLX", "error", err)
- return err
- }
- slog.Info("MLX library initialized")
- slog.Info("starting image runner", "model", *modelName, "port", *port)
-
- // Check memory requirements before loading
- requiredMemory := imagegen.EstimateVRAM(*modelName)
- availableMemory := mlx.GetMemoryLimit()
- if availableMemory > 0 && availableMemory < requiredMemory {
- return fmt.Errorf("insufficient memory for image generation: need %d GB, have %d GB",
- requiredMemory/(1024*1024*1024), availableMemory/(1024*1024*1024))
- }
-
- // Detect model type and load appropriate model
- modelType := imagegen.DetectModelType(*modelName)
- slog.Info("detected model type", "type", modelType)
-
- var model ImageModel
- switch modelType {
- case "Flux2KleinPipeline":
- m := &flux2.Model{}
- if err := m.Load(*modelName); err != nil {
- return fmt.Errorf("failed to load model: %w", err)
- }
- model = m
- default:
- // Default to Z-Image for ZImagePipeline, FluxPipeline, etc.
- m := &zimage.Model{}
- if err := m.Load(*modelName); err != nil {
- return fmt.Errorf("failed to load model: %w", err)
- }
- model = m
- }
-
- server := &Server{
- model: model,
- modelName: *modelName,
- }
-
- // Set up HTTP handlers
- mux := http.NewServeMux()
- mux.HandleFunc("/health", server.healthHandler)
- mux.HandleFunc("/completion", server.completionHandler)
-
- httpServer := &http.Server{
- Addr: fmt.Sprintf("127.0.0.1:%d", *port),
- Handler: mux,
- }
-
- // Handle shutdown
- done := make(chan struct{})
- go func() {
- sigCh := make(chan os.Signal, 1)
- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
- <-sigCh
- slog.Info("shutting down image runner")
- ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
- defer cancel()
- httpServer.Shutdown(ctx)
- close(done)
- }()
-
- slog.Info("image runner listening", "addr", httpServer.Addr)
- if err := httpServer.ListenAndServe(); err != http.ErrServerClosed {
- return err
- }
-
- <-done
- return nil
-}
-
-func (s *Server) healthHandler(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusOK)
- json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
-}
-
-func (s *Server) completionHandler(w http.ResponseWriter, r *http.Request) {
- if r.Method != http.MethodPost {
- http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
- return
- }
-
- var req Request
- if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- // Serialize generation requests - MLX model may not handle concurrent generation
- s.mu.Lock()
- defer s.mu.Unlock()
-
- // Model applies its own defaults for width/height/steps
- // Only seed needs to be set here if not provided
- if req.Seed <= 0 {
- req.Seed = time.Now().UnixNano()
- }
-
- // Set up streaming response
- w.Header().Set("Content-Type", "application/x-ndjson")
- w.Header().Set("Transfer-Encoding", "chunked")
- flusher, ok := w.(http.Flusher)
- if !ok {
- http.Error(w, "streaming not supported", http.StatusInternalServerError)
- return
- }
-
- // Generate image using the common interface
- ctx := r.Context()
- enc := json.NewEncoder(w)
-
- // Progress callback streams step updates
- progress := func(step, total int) {
- resp := Response{Step: step, Total: total}
- enc.Encode(resp)
- w.Write([]byte("\n"))
- flusher.Flush()
- }
-
- img, err := s.model.GenerateImage(ctx, req.Prompt, req.Width, req.Height, req.Steps, req.Seed, progress)
-
- if err != nil {
- // Don't send error for cancellation
- if ctx.Err() != nil {
- return
- }
- resp := Response{Content: fmt.Sprintf("error: %v", err), Done: true}
- data, _ := json.Marshal(resp)
- w.Write(data)
- w.Write([]byte("\n"))
- return
- }
-
- // Encode image as base64 PNG
- imageData, err := imagegen.EncodeImageBase64(img)
- if err != nil {
- resp := Response{Content: fmt.Sprintf("error encoding: %v", err), Done: true}
- data, _ := json.Marshal(resp)
- w.Write(data)
- w.Write([]byte("\n"))
- return
- }
-
- // Free the generated image array and clean up MLX state
- img.Free()
- mlx.ClearCache()
- mlx.MetalResetPeakMemory()
-
- // Send final response with image data
- resp := Response{
- Image: imageData,
- Done: true,
- }
- data, _ := json.Marshal(resp)
- w.Write(data)
- w.Write([]byte("\n"))
- flusher.Flush()
-}
diff --git a/x/imagegen/runner/runner_stub.go b/x/imagegen/runner_stub.go
similarity index 60%
rename from x/imagegen/runner/runner_stub.go
rename to x/imagegen/runner_stub.go
index eafad7bca1e..866a4408c79 100644
--- a/x/imagegen/runner/runner_stub.go
+++ b/x/imagegen/runner_stub.go
@@ -1,10 +1,10 @@
//go:build !mlx
-package runner
+package imagegen
import "errors"
// Execute returns an error when not built with MLX support.
func Execute(args []string) error {
- return errors.New("image generation not available: build with mlx tag")
+ return errors.New("MLX runner not available: build with mlx tag")
}
diff --git a/x/imagegen/safetensors/extractor.go b/x/imagegen/safetensors/extractor.go
index b14dfe96530..65a4f6da0ab 100644
--- a/x/imagegen/safetensors/extractor.go
+++ b/x/imagegen/safetensors/extractor.go
@@ -41,13 +41,11 @@ func (td *TensorData) Reader() io.Reader {
return td.reader
}
-// SafetensorsReader returns a reader that outputs the tensor wrapped in
-// minimal safetensors format. This allows using mlx_load_safetensors on
-// individual tensor blobs for native zero-copy loading.
-func (td *TensorData) SafetensorsReader() io.Reader {
- // Build minimal safetensors header with tensor named "data"
- header := map[string]tensorInfo{
- "data": {
+// safetensorsHeader builds the JSON header for a minimal safetensors blob
+// containing a single tensor keyed by its name.
+func (td *TensorData) safetensorsHeader() []byte {
+ header := map[string]any{
+ td.Name: tensorInfo{
Dtype: td.Dtype,
Shape: td.Shape,
DataOffsets: [2]int{0, int(td.Size)},
@@ -58,6 +56,15 @@ func (td *TensorData) SafetensorsReader() io.Reader {
// Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
+ return headerJSON
+}
+
+// SafetensorsReader returns a reader that outputs the tensor wrapped in
+// minimal safetensors format. This allows using mlx_load_safetensors on
+// individual tensor blobs for native zero-copy loading.
+// The tensor is keyed by its name in the safetensors header.
+func (td *TensorData) SafetensorsReader() io.Reader {
+ headerJSON := td.safetensorsHeader()
// Build header with size prefix
headerBuf := new(bytes.Buffer)
@@ -71,16 +78,77 @@ func (td *TensorData) SafetensorsReader() io.Reader {
// SafetensorsSize returns the total size of the safetensors-wrapped tensor.
func (td *TensorData) SafetensorsSize() int64 {
- header := map[string]tensorInfo{
- "data": {
+ headerJSON := td.safetensorsHeader()
+ return 8 + int64(len(headerJSON)) + td.Size
+}
+
+// NewTensorDataFromBytes creates a TensorData from raw tensor bytes.
+// This is useful for constructing packed blobs from already-extracted data.
+func NewTensorDataFromBytes(name, dtype string, shape []int32, rawData []byte) *TensorData {
+ return &TensorData{
+ Name: name,
+ Dtype: dtype,
+ Shape: shape,
+ Size: int64(len(rawData)),
+ reader: io.NewSectionReader(bytes.NewReader(rawData), 0, int64(len(rawData))),
+ }
+}
+
+// ExtractRawFromSafetensors reads a safetensors-wrapped reader and extracts
+// the raw tensor data bytes (stripping the header).
+func ExtractRawFromSafetensors(r io.Reader) ([]byte, error) {
+ // Read header size (8 bytes, little endian)
+ var headerSize uint64
+ if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
+ return nil, fmt.Errorf("failed to read header size: %w", err)
+ }
+
+ // Skip header
+ if _, err := io.CopyN(io.Discard, r, int64(headerSize)); err != nil {
+ return nil, fmt.Errorf("failed to skip header: %w", err)
+ }
+
+ // Read remaining bytes (the raw tensor data)
+ return io.ReadAll(r)
+}
+
+// BuildPackedSafetensorsReader builds a streaming io.Reader that outputs a valid
+// safetensors file containing multiple tensors. Used for packing expert tensors
+// into a single blob without loading all data into memory.
+// Each TensorData must have been obtained from GetTensor.
+func BuildPackedSafetensorsReader(tensors []*TensorData) io.Reader {
+ // Build the header with sequential data offsets
+ header := make(map[string]tensorInfo, len(tensors))
+ var offset int
+ for _, td := range tensors {
+ header[td.Name] = tensorInfo{
Dtype: td.Dtype,
Shape: td.Shape,
- DataOffsets: [2]int{0, int(td.Size)},
- },
+ DataOffsets: [2]int{offset, offset + int(td.Size)},
+ }
+ offset += int(td.Size)
}
+
headerJSON, _ := json.Marshal(header)
+
+ // Pad header to 8-byte alignment
padding := (8 - len(headerJSON)%8) % 8
- return 8 + int64(len(headerJSON)) + int64(padding) + td.Size
+ headerJSON = append(headerJSON, bytes.Repeat([]byte(" "), padding)...)
+
+ // Build header with size prefix
+ headerBuf := new(bytes.Buffer)
+ binary.Write(headerBuf, binary.LittleEndian, uint64(len(headerJSON)))
+ headerBuf.Write(headerJSON)
+
+ // Build multi-reader: header + all tensor data readers
+ readers := make([]io.Reader, 0, 1+len(tensors))
+ readers = append(readers, headerBuf)
+ for _, td := range tensors {
+ td.reader.Seek(0, io.SeekStart)
+ readers = append(readers, td.reader)
+ }
+
+ return io.MultiReader(readers...)
}
// OpenForExtraction opens a safetensors file for tensor extraction.
diff --git a/x/imagegen/safetensors/loader.go b/x/imagegen/safetensors/loader.go
index 7f8860b0632..d0426a2ef63 100644
--- a/x/imagegen/safetensors/loader.go
+++ b/x/imagegen/safetensors/loader.go
@@ -17,17 +17,31 @@ type WeightSource interface {
GetTensor(name string) (*mlx.Array, error)
ListTensors() []string
HasTensor(name string) bool
- Quantization() string // Returns "FP4", "FP8", or ""
+ Quantization() string // Returns "NVFP4", "INT4", "INT8", or ""
+ GroupSize() int // Returns quantization group size, or 0 if not specified
}
-// quantizationParams returns groupSize, bits, mode for a quantization type.
-// Returns defaults (32, 8, "affine") for unknown types (backward compatibility).
-func quantizationParams(quantization string) (groupSize, bits int, mode string) {
+// QuantizationParams returns groupSize, bits, mode for a quantization type.
+// MLX quantization modes:
+// - "affine": scale + zero-point bias, group_size=32/64/128
+// - "nvfp4": NVIDIA FP4 with E4M3 scales, group_size=16 (no bias)
+// - "mxfp8": Microsoft MX FP8 with E4M3 scales, group_size=32 (no bias)
+func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
switch strings.ToUpper(quantization) {
- case "FP4":
+ case "NVFP4":
+ // NVIDIA FP4: group_size=16, bits=4, E4M3 scales (no qbias)
+ return 16, 4, "nvfp4"
+ case "FP4", "Q4", "INT4":
+ // 4-bit quantization with affine mode (scale + qbias)
return 32, 4, "affine"
+ case "MXFP8":
+ // Microsoft MX FP8: group_size=32, bits=8, E4M3 scales (no qbias)
+ return 32, 8, "mxfp8"
+ case "FP8", "Q8", "INT8", "":
+ // 8-bit quantization with affine mode (default for quantized models)
+ return 64, 8, "affine"
default:
- return 32, 8, "affine" // FP8 or unknown
+ return 32, 8, "affine" // Default to affine
}
}
@@ -122,7 +136,8 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
}
// Handle nn.LinearLayer interface fields specially
- if field.Type == reflect.TypeOf((*nn.LinearLayer)(nil)).Elem() {
+ linearLayerType := reflect.TypeOf((*nn.LinearLayer)(nil)).Elem()
+ if field.Type == linearLayerType {
if !hasTag {
continue // no tag = skip
}
@@ -137,6 +152,23 @@ func loadStruct(v reflect.Value, weights WeightSource, prefix string, errs *[]st
continue
}
+ // Handle nn.MultiLinearLayer interface fields specially
+ multiLinearLayerType := reflect.TypeOf((*nn.MultiLinearLayer)(nil)).Elem()
+ if field.Type == multiLinearLayerType {
+ if !hasTag {
+ continue // no tag = skip
+ }
+ layer, err := LoadMultiLinearLayer(weights, fullPath)
+ if err != nil {
+ if !optional {
+ *errs = append(*errs, fullPath+": "+err.Error())
+ }
+ continue
+ }
+ fieldVal.Set(reflect.ValueOf(layer))
+ continue
+ }
+
// Handle by kind
switch fieldVal.Kind() {
case reflect.Ptr:
@@ -216,12 +248,86 @@ func joinPath(prefix, suffix string) string {
return prefix + "." + suffix
}
+// LoadMultiLinearLayer loads a per-head linear layer from weights.
+// Weight shape should be [num_heads, output_dims, input_dims].
+// If quantized, always dequantizes since batched quantized matmul isn't supported.
+func LoadMultiLinearLayer(weights WeightSource, path string) (nn.MultiLinearLayer, error) {
+ // Check if this is a quantized layer by looking for scale tensor
+ scalePath := path + ".weight_scale"
+ hasScale := weights.HasTensor(scalePath)
+
+ weight, err := weights.GetTensor(path + ".weight")
+ if err != nil {
+ return nil, fmt.Errorf("failed to load weight %s: %w", path, err)
+ }
+
+ if hasScale {
+ scales, err := weights.GetTensor(scalePath)
+ if err != nil {
+ return nil, fmt.Errorf("failed to load scales %s: %w", scalePath, err)
+ }
+
+ var qbiases *mlx.Array
+ qbiasPath := path + ".weight_qbias"
+ if weights.HasTensor(qbiasPath) {
+ qbiases, _ = weights.GetTensor(qbiasPath)
+ }
+
+ // Always dequantize for MultiLinear - no batched quantized matmul support
+ // Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
+ weightShape := weight.Shape()
+ scalesShape := scales.Shape()
+ weightCols := int(weightShape[len(weightShape)-1])
+ scalesCols := int(scalesShape[len(scalesShape)-1])
+
+ // Detect quantization from tensor shapes
+ // groupSize = weightCols * packFactor / scalesCols
+ // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
+ groupSize4 := weightCols * 8 / scalesCols
+ groupSize8 := weightCols * 4 / scalesCols
+
+ var bits, groupSize int
+ // Use metadata to help disambiguate when shapes are ambiguous
+ // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
+ quantType := strings.ToUpper(weights.Quantization())
+ isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
+
+ if groupSize4 == 32 {
+ // Unambiguous: Q4 with group_size=32
+ bits = 4
+ groupSize = 32
+ } else if groupSize8 == 64 {
+ // Unambiguous: Q8 with group_size=64
+ bits = 8
+ groupSize = 64
+ } else if groupSize4 == 64 && groupSize8 == 32 {
+ // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
+ if isQ8Type {
+ bits = 8
+ groupSize = 32
+ } else {
+ bits = 4
+ groupSize = 64
+ }
+ } else {
+ // Fallback: use global quantization params
+ _, bits, _ = QuantizationParams(weights.Quantization())
+ packFactor := 32 / bits
+ groupSize = weightCols * packFactor / scalesCols
+ }
+ weight = mlx.Dequantize(weight, scales, qbiases, groupSize, bits, "affine")
+ }
+
+ return nn.NewMultiLinear(weight), nil
+}
+
// LoadLinearLayer loads a linear layer from weights, automatically detecting if it's quantized.
-// If {path}.weight_scale exists, dequantizes the weights.
+// If {path}.weight_scale exists, creates a QuantizedLinear layer (or dequantizes if no kernel support).
func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error) {
// Check if this is a quantized layer by looking for scale tensor
scalePath := path + ".weight_scale"
- if weights.HasTensor(scalePath) {
+ hasScale := weights.HasTensor(scalePath)
+ if hasScale {
weight, err := weights.GetTensor(path + ".weight")
if err != nil {
return nil, fmt.Errorf("failed to load quantized weight %s: %w", path, err)
@@ -245,9 +351,52 @@ func LoadLinearLayer(weights WeightSource, path string) (nn.LinearLayer, error)
qbiases, _ = weights.GetTensor(qbiasPath)
}
- groupSize, bits, mode := quantizationParams(weights.Quantization())
+ // Detect bits from tensor shapes (supports mixed-precision Q4/Q8)
+ weightShape := weight.Shape()
+ scalesShape := scales.Shape()
+ weightCols := int(weightShape[len(weightShape)-1])
+ scalesCols := int(scalesShape[len(scalesShape)-1])
+
+ // Detect quantization from tensor shapes
+ // groupSize = weightCols * packFactor / scalesCols
+ // Note: groupSize4 = 2 * groupSize8 always, so ambiguous cases need metadata
+ groupSize4 := weightCols * 8 / scalesCols
+ groupSize8 := weightCols * 4 / scalesCols
+
+ var bits, groupSize int
+ mode := "affine"
+ // Use metadata to help disambiguate when shapes are ambiguous
+ // (e.g., Q4 with group_size=64 has same shapes as Q8 with group_size=32)
+ quantType := strings.ToUpper(weights.Quantization())
+ isQ8Type := quantType == "Q8" || quantType == "FP8" || quantType == "INT8"
+
+ if groupSize4 == 32 {
+ // Unambiguous: Q4 with group_size=32
+ bits = 4
+ groupSize = 32
+ } else if groupSize8 == 64 {
+ // Unambiguous: Q8 with group_size=64
+ bits = 8
+ groupSize = 64
+ } else if groupSize4 == 64 && groupSize8 == 32 {
+ // Ambiguous: could be Q4/gs=64 or Q8/gs=32, use metadata
+ if isQ8Type {
+ bits = 8
+ groupSize = 32
+ } else {
+ bits = 4
+ groupSize = 64
+ }
+ } else {
+ // Fallback: use global quantization params
+ _, bits, mode = QuantizationParams(weights.Quantization())
+ packFactor := 32 / bits
+ groupSize = weightCols * packFactor / scalesCols
+ }
- if mlx.MetalIsAvailable() {
+ // NVFP4 and MXFP8 don't have native quantized matmul kernels in MLX,
+ // so we always dequantize at load time. Affine modes (FP4, FP8) have kernel support.
+ if mlx.MetalIsAvailable() && mode != "nvfp4" && mode != "mxfp8" {
return &nn.QuantizedLinear{
Weight: weight,
Scales: scales,
diff --git a/x/imagegen/safetensors/safetensors.go b/x/imagegen/safetensors/safetensors.go
index a36052fcef9..4dbcf59a35d 100644
--- a/x/imagegen/safetensors/safetensors.go
+++ b/x/imagegen/safetensors/safetensors.go
@@ -303,6 +303,11 @@ func (mw *ModelWeights) Quantization() string {
return ""
}
+// GroupSize returns 0 for directory-based weights (use default).
+func (mw *ModelWeights) GroupSize() int {
+ return 0
+}
+
// ReleaseAll releases all cached native file handles.
func (mw *ModelWeights) ReleaseAll() {
for path, native := range mw.nativeCache {
diff --git a/x/imagegen/server.go b/x/imagegen/server.go
index d7d282d8e44..f79b8d3e94b 100644
--- a/x/imagegen/server.go
+++ b/x/imagegen/server.go
@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
+ "io"
"log/slog"
"math/rand"
"net"
@@ -20,21 +21,22 @@ import (
"sync"
"time"
+ "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/x/imagegen/manifest"
)
-// Server wraps an image generation subprocess to implement llm.LlamaServer.
+// Server wraps an MLX runner subprocess to implement llm.LlamaServer.
//
// This implementation is compatible with Ollama's scheduler and can be loaded/unloaded
-// like any other model. The plan is to eventually bring this into the llm/ package
-// and evolve llm/ to support MLX and multimodal models. For now, keeping the code
-// separate allows for independent iteration on image generation support.
+// like any other model. It supports both LLM (safetensors) and image generation models.
type Server struct {
mu sync.Mutex
cmd *exec.Cmd
port int
modelName string
+ mode ModelMode
vramSize uint64
done chan error
client *http.Client
@@ -42,8 +44,8 @@ type Server struct {
lastErrLock sync.Mutex
}
-// NewServer spawns a new image generation subprocess and waits until it's ready.
-func NewServer(modelName string) (*Server, error) {
+// NewServer spawns a new MLX runner subprocess and waits until it's ready.
+func NewServer(modelName string, mode ModelMode) (*Server, error) {
// Validate platform support before attempting to start
if err := CheckPlatformSupport(); err != nil {
return nil, err
@@ -70,8 +72,8 @@ func NewServer(modelName string) (*Server, error) {
exe = eval
}
- // Spawn subprocess: ollama runner --image-engine --model --port
- cmd := exec.Command(exe, "runner", "--image-engine", "--model", modelName, "--port", strconv.Itoa(port))
+ // Spawn subprocess: ollama runner --imagegen-engine --model --port
+ cmd := exec.Command(exe, "runner", "--imagegen-engine", "--model", modelName, "--port", strconv.Itoa(port))
cmd.Env = os.Environ()
// On Linux, set LD_LIBRARY_PATH to include MLX library directories
@@ -104,11 +106,21 @@ func NewServer(modelName string) (*Server, error) {
slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
}
+ // Estimate VRAM based on tensor size from manifest
+ var vramSize uint64
+ if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
+ vramSize = uint64(modelManifest.TotalTensorSize())
+ } else {
+ // Fallback: default to 8GB if manifest can't be loaded
+ vramSize = 8 * 1024 * 1024 * 1024
+ }
+
s := &Server{
cmd: cmd,
port: port,
modelName: modelName,
- vramSize: EstimateVRAM(modelName),
+ mode: mode,
+ vramSize: vramSize,
done: make(chan error, 1),
client: &http.Client{Timeout: 10 * time.Minute},
}
@@ -119,23 +131,23 @@ func NewServer(modelName string) (*Server, error) {
go func() {
scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
- slog.Info("image-runner", "msg", scanner.Text())
+ slog.Info("mlx-runner", "msg", scanner.Text())
}
}()
go func() {
scanner := bufio.NewScanner(stderr)
for scanner.Scan() {
line := scanner.Text()
- slog.Warn("image-runner", "msg", line)
+ slog.Warn("mlx-runner", "msg", line)
s.lastErrLock.Lock()
s.lastErr = line
s.lastErrLock.Unlock()
}
}()
- slog.Info("starting image runner subprocess", "exe", exe, "model", modelName, "port", port)
+ slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port, "mode", mode)
if err := cmd.Start(); err != nil {
- return nil, fmt.Errorf("failed to start image runner: %w", err)
+ return nil, fmt.Errorf("failed to start mlx runner: %w", err)
}
// Reap subprocess when it exits
@@ -158,6 +170,7 @@ func (s *Server) ModelPath() string {
return s.modelName
}
+// Load satisfies the LlamaServer interface. MLX models don't need GPU layer assignment.
func (s *Server) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
return nil, nil
}
@@ -183,7 +196,7 @@ func (s *Server) Ping(ctx context.Context) error {
// waitUntilRunning waits for the subprocess to be ready.
func (s *Server) waitUntilRunning() error {
ctx := context.Background()
- timeout := time.After(2 * time.Minute)
+ timeout := time.After(envconfig.LoadTimeout())
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
@@ -193,18 +206,18 @@ func (s *Server) waitUntilRunning() error {
// Include recent stderr lines for better error context
errMsg := s.getLastErr()
if errMsg != "" {
- return fmt.Errorf("image runner failed: %s (exit: %v)", errMsg, err)
+ return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
}
- return fmt.Errorf("image runner exited unexpectedly: %w", err)
+ return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
case <-timeout:
errMsg := s.getLastErr()
if errMsg != "" {
- return fmt.Errorf("timeout waiting for image runner: %s", errMsg)
+ return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
}
- return errors.New("timeout waiting for image runner to start")
+ return errors.New("timeout waiting for mlx runner to start")
case <-ticker.C:
if err := s.Ping(ctx); err == nil {
- slog.Info("image runner is ready", "port", s.port)
+ slog.Info("mlx runner is ready", "port", s.port)
return nil
}
}
@@ -218,27 +231,43 @@ func (s *Server) getLastErr() string {
return s.lastErr
}
-func (s *Server) WaitUntilRunning(ctx context.Context) error { return nil }
+// WaitUntilRunning satisfies the LlamaServer interface.
+func (s *Server) WaitUntilRunning(ctx context.Context) error {
+ return nil
+}
+// Completion handles both text and image generation requests.
func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
seed := req.Seed
if seed == 0 {
seed = time.Now().UnixNano()
}
+ // Extract raw image bytes from llm.ImageData slice
+ var images [][]byte
+ for _, img := range req.Images {
+ images = append(images, img.Data)
+ }
+
// Build request for subprocess
- creq := struct {
- Prompt string `json:"prompt"`
- Width int32 `json:"width,omitempty"`
- Height int32 `json:"height,omitempty"`
- Steps int32 `json:"steps,omitempty"`
- Seed int64 `json:"seed,omitempty"`
- }{
+ creq := Request{
Prompt: req.Prompt,
Width: req.Width,
Height: req.Height,
- Steps: req.Steps,
+ Steps: int(req.Steps),
Seed: seed,
+ Images: images,
+ }
+
+ // Pass LLM options if present
+ if req.Options != nil {
+ creq.Options = &RequestOptions{
+ NumPredict: req.Options.NumPredict,
+ Temperature: float64(req.Options.Temperature),
+ TopP: float64(req.Options.TopP),
+ TopK: req.Options.TopK,
+ Stop: req.Options.Stop,
+ }
}
body, err := json.Marshal(creq)
@@ -260,31 +289,47 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- return fmt.Errorf("request failed: %d", resp.StatusCode)
+ body, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("%s", strings.TrimSpace(string(body)))
}
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(make([]byte, 1024*1024), 16*1024*1024) // 16MB max
for scanner.Scan() {
- // Parse subprocess response (has singular "image" field)
+ // Parse subprocess response
var raw struct {
- Image string `json:"image,omitempty"`
- Content string `json:"content,omitempty"`
- Done bool `json:"done"`
- Step int `json:"step,omitempty"`
- Total int `json:"total,omitempty"`
+ Image string `json:"image,omitempty"`
+ Content string `json:"content,omitempty"`
+ Done bool `json:"done"`
+ Step int `json:"step,omitempty"`
+ Total int `json:"total,omitempty"`
+ StopReason string `json:"stop_reason,omitempty"`
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
+ PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
+ EvalCount int `json:"eval_count,omitempty"`
+ EvalDuration int `json:"eval_duration,omitempty"`
}
if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
+ slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
continue
}
+ // Log stop reason when generation completes
+ if raw.Done && raw.StopReason != "" {
+ slog.Info("mlx generation completed", "stop_reason", raw.StopReason)
+ }
+
// Convert to llm.CompletionResponse
cresp := llm.CompletionResponse{
- Content: raw.Content,
- Done: raw.Done,
- Step: raw.Step,
- TotalSteps: raw.Total,
- Image: raw.Image,
+ Content: raw.Content,
+ Done: raw.Done,
+ Step: raw.Step,
+ TotalSteps: raw.Total,
+ Image: raw.Image,
+ PromptEvalCount: raw.PromptEvalCount,
+ PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
+ EvalCount: raw.EvalCount,
+ EvalDuration: time.Duration(raw.EvalDuration),
}
fn(cresp)
@@ -293,7 +338,20 @@ func (s *Server) Completion(ctx context.Context, req llm.CompletionRequest, fn f
}
}
- return scanner.Err()
+ // Scanner exited without receiving Done - connection was likely closed
+ scanErr := scanner.Err()
+ if scanErr != nil {
+ slog.Error("mlx scanner error", "error", scanErr)
+ } else {
+ slog.Warn("mlx scanner EOF without Done response - subprocess may have crashed")
+ }
+
+ // Check if subprocess is still alive
+ if s.HasExited() {
+ slog.Error("mlx subprocess has exited unexpectedly")
+ }
+
+ return scanErr
}
// Close terminates the subprocess.
@@ -302,7 +360,7 @@ func (s *Server) Close() error {
defer s.mu.Unlock()
if s.cmd != nil && s.cmd.Process != nil {
- slog.Info("stopping image runner subprocess", "pid", s.cmd.Process.Pid)
+ slog.Info("stopping mlx runner subprocess", "pid", s.cmd.Process.Pid)
s.cmd.Process.Signal(os.Interrupt)
// Wait briefly for graceful shutdown
@@ -331,18 +389,56 @@ func (s *Server) VRAMByGPU(id ml.DeviceID) uint64 {
return s.vramSize
}
+// ContextLength returns the context length (not applicable for image generation).
+func (s *Server) ContextLength() int {
+ return 0
+}
+
+// Embedding returns embeddings for the input.
func (s *Server) Embedding(ctx context.Context, input string) ([]float32, int, error) {
- return nil, 0, errors.New("not supported")
+ return nil, 0, errors.New("embeddings not supported for MLX models")
}
+// Tokenize tokenizes the input content.
func (s *Server) Tokenize(ctx context.Context, content string) ([]int, error) {
- return nil, errors.New("not supported")
+ body, err := json.Marshal(map[string]string{"content": content})
+ if err != nil {
+ return nil, err
+ }
+
+ url := fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port)
+ req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := s.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ return nil, fmt.Errorf("tokenize failed: %d", resp.StatusCode)
+ }
+
+ var result struct {
+ Tokens []int `json:"tokens"`
+ }
+ if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
+ return nil, err
+ }
+
+ return result.Tokens, nil
}
+// Detokenize converts tokens back to text.
func (s *Server) Detokenize(ctx context.Context, tokens []int) (string, error) {
- return "", errors.New("not supported")
+ return "", errors.New("detokenization not supported for MLX models")
}
+// Pid returns the process ID of the subprocess.
func (s *Server) Pid() int {
s.mu.Lock()
defer s.mu.Unlock()
@@ -352,9 +448,17 @@ func (s *Server) Pid() int {
return -1
}
-func (s *Server) GetPort() int { return s.port }
-func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo { return nil }
+// GetPort returns the port the subprocess is listening on.
+func (s *Server) GetPort() int {
+ return s.port
+}
+
+// GetDeviceInfos returns device information.
+func (s *Server) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+ return nil
+}
+// HasExited returns whether the subprocess has exited.
func (s *Server) HasExited() bool {
select {
case <-s.done:
diff --git a/x/imagegen/server_test.go b/x/imagegen/server_test.go
deleted file mode 100644
index c3957236a89..00000000000
--- a/x/imagegen/server_test.go
+++ /dev/null
@@ -1,82 +0,0 @@
-package imagegen
-
-import (
- "runtime"
- "testing"
-)
-
-// TestPlatformSupport verifies platform validation works correctly.
-func TestPlatformSupport(t *testing.T) {
- err := CheckPlatformSupport()
-
- switch runtime.GOOS {
- case "darwin":
- if runtime.GOARCH == "arm64" {
- // Apple Silicon should be supported
- if err != nil {
- t.Errorf("Expected nil error on darwin/arm64, got: %v", err)
- }
- } else {
- // Intel Mac should fail
- if err == nil {
- t.Error("Expected error on darwin/amd64 (Intel), got nil")
- }
- if err != nil && err.Error() == "" {
- t.Error("Expected meaningful error message for unsupported platform")
- }
- }
- case "linux", "windows":
- // Linux/Windows are allowed (CUDA support checked at runtime)
- if err != nil {
- t.Errorf("Expected nil error on %s, got: %v", runtime.GOOS, err)
- }
- default:
- // Other platforms should fail
- if err == nil {
- t.Errorf("Expected error on unsupported platform %s, got nil", runtime.GOOS)
- }
- }
-}
-
-// TestMemoryRequirementsError verifies memory check returns clear error.
-func TestMemoryRequirementsError(t *testing.T) {
- // Test with insufficient memory
- err := CheckMemoryRequirements("test-model", 8*GB)
- if err == nil {
- t.Error("Expected error for insufficient memory (8GB < 21GB default)")
- }
-
- // Test with sufficient memory
- err = CheckMemoryRequirements("test-model", 32*GB)
- if err != nil {
- t.Errorf("Expected no error for sufficient memory (32GB), got: %v", err)
- }
-}
-
-// TestEstimateVRAMReturnsReasonableDefaults verifies VRAM estimates are sensible.
-func TestEstimateVRAMReturnsReasonableDefaults(t *testing.T) {
- // Unknown model should return default (21GB)
- vram := EstimateVRAM("unknown-model")
- if vram < 10*GB || vram > 100*GB {
- t.Errorf("VRAM estimate %d GB is outside reasonable range (10-100 GB)", vram/GB)
- }
-
- // Verify known pipeline estimates exist and are reasonable
- for name, estimate := range modelVRAMEstimates {
- if estimate < 10*GB {
- t.Errorf("VRAM estimate for %s (%d GB) is suspiciously low", name, estimate/GB)
- }
- if estimate > 200*GB {
- t.Errorf("VRAM estimate for %s (%d GB) is suspiciously high", name, estimate/GB)
- }
- }
-}
-
-// TestServerInterfaceCompliance verifies Server implements llm.LlamaServer.
-// This is a compile-time check but we document it as a test.
-func TestServerInterfaceCompliance(t *testing.T) {
- // The var _ llm.LlamaServer = (*Server)(nil) line in server.go
- // ensures compile-time interface compliance.
- // This test documents that requirement.
- t.Log("Server implements llm.LlamaServer interface (compile-time checked)")
-}
diff --git a/x/imagegen/types.go b/x/imagegen/types.go
new file mode 100644
index 00000000000..7c837a02169
--- /dev/null
+++ b/x/imagegen/types.go
@@ -0,0 +1,81 @@
+// Package imagegen provides a unified MLX runner for both LLM and image generation models.
+//
+// This package handles safetensors models created with `ollama create --experimental`,
+// supporting both text generation (LLM) and image generation (diffusion) models
+// through a single unified interface.
+package imagegen
+
+// Request is the request format for completion requests.
+type Request struct {
+ Prompt string `json:"prompt"`
+
+ // LLM-specific fields
+ Options *RequestOptions `json:"options,omitempty"`
+
+ // Image generation fields
+ Width int32 `json:"width,omitempty"`
+ Height int32 `json:"height,omitempty"`
+ Steps int `json:"steps,omitempty"`
+ Seed int64 `json:"seed,omitempty"`
+ Images [][]byte `json:"images,omitempty"` // Input images for image editing/conditioning
+}
+
+// RequestOptions contains LLM-specific generation options.
+type RequestOptions struct {
+ NumPredict int `json:"num_predict,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Stop []string `json:"stop,omitempty"`
+}
+
+// Response is streamed back for each progress update.
+type Response struct {
+ // Text generation response
+ Content string `json:"content,omitempty"`
+
+ // Image generation response
+ Image string `json:"image,omitempty"` // Base64-encoded PNG
+
+ // Common fields
+ Done bool `json:"done"`
+ DoneReason int `json:"done_reason,omitempty"`
+ StopReason string `json:"stop_reason,omitempty"` // Debug: why generation stopped
+
+ // Progress fields
+ Step int `json:"step,omitempty"`
+ Total int `json:"total,omitempty"`
+
+ // Statistics
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
+ PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
+ EvalCount int `json:"eval_count,omitempty"`
+ EvalDuration int `json:"eval_duration,omitempty"`
+}
+
+// HealthResponse is returned by the health endpoint.
+type HealthResponse struct {
+ Status string `json:"status"`
+ Progress float32 `json:"progress,omitempty"`
+}
+
+// ModelMode represents the type of model being run.
+type ModelMode int
+
+const (
+ // ModeLLM indicates a text generation model.
+ ModeLLM ModelMode = iota
+ // ModeImageGen indicates an image generation model.
+ ModeImageGen
+)
+
+func (m ModelMode) String() string {
+ switch m {
+ case ModeLLM:
+ return "llm"
+ case ModeImageGen:
+ return "imagegen"
+ default:
+ return "unknown"
+ }
+}
diff --git a/x/imagegen/weights.go b/x/imagegen/weights.go
deleted file mode 100644
index f49c7e77e2f..00000000000
--- a/x/imagegen/weights.go
+++ /dev/null
@@ -1,131 +0,0 @@
-//go:build mlx
-
-package imagegen
-
-import (
- "fmt"
- "sort"
- "strings"
-
- "github.com/ollama/ollama/x/imagegen/mlx"
-)
-
-// ManifestWeights provides fast weight loading from tensor blobs.
-// Uses native mmap loading with synthetic safetensors headers for zero-copy.
-type ManifestWeights struct {
- manifest *ModelManifest
- component string
- tensors map[string]ManifestLayer // name -> layer
- cache map[string]*mlx.Array // name -> loaded array
- nativeCache []*mlx.SafetensorsFile // keep native handles alive
-}
-
-// LoadWeightsFromManifest creates a weight loader for a component from manifest storage.
-func LoadWeightsFromManifest(manifest *ModelManifest, component string) (*ManifestWeights, error) {
- layers := manifest.GetTensorLayers(component)
- if len(layers) == 0 {
- return nil, fmt.Errorf("no tensor layers found for component %q", component)
- }
-
- // Strip component prefix from tensor names for model loading
- // e.g., "text_encoder/model.embed_tokens.weight" -> "model.embed_tokens.weight"
- prefix := component + "/"
- tensors := make(map[string]ManifestLayer, len(layers))
- for _, layer := range layers {
- tensorName := strings.TrimPrefix(layer.Name, prefix)
- tensors[tensorName] = layer
- }
-
- return &ManifestWeights{
- manifest: manifest,
- component: component,
- tensors: tensors,
- cache: make(map[string]*mlx.Array),
- }, nil
-}
-
-// Load loads all tensor blobs using native mmap (zero-copy).
-// Blobs are stored in safetensors format for native mlx_load_safetensors mmap.
-// If dtype is non-zero, tensors are converted to the specified dtype.
-func (mw *ManifestWeights) Load(dtype mlx.Dtype) error {
- for name, layer := range mw.tensors {
- path := mw.manifest.BlobPath(layer.Digest)
-
- // Load blob as safetensors (native mmap, zero-copy)
- sf, err := mlx.LoadSafetensorsNative(path)
- if err != nil {
- return fmt.Errorf("load %s: %w", name, err)
- }
-
- // Blob contains single tensor named "data"
- arr := sf.Get("data")
- if arr == nil {
- sf.Free()
- return fmt.Errorf("tensor 'data' not found in blob for %s", name)
- }
-
- // Convert dtype if needed
- if dtype != 0 && arr.Dtype() != dtype {
- arr = mlx.AsType(arr, dtype)
- }
- // ALWAYS make a contiguous copy to ensure independence from mmap
- arr = mlx.Contiguous(arr)
- mlx.Eval(arr)
- mw.cache[name] = arr
- sf.Free() // Safe to free - arr is now an independent copy
- }
-
- return nil
-}
-
-// GetTensor returns a tensor from cache. Call Load() first.
-func (mw *ManifestWeights) GetTensor(name string) (*mlx.Array, error) {
- if mw.cache == nil {
- return nil, fmt.Errorf("cache not initialized: call Load() first")
- }
- arr, ok := mw.cache[name]
- if !ok {
- return nil, fmt.Errorf("tensor %q not found", name)
- }
- return arr, nil
-}
-
-// ListTensors returns all tensor names in sorted order.
-func (mw *ManifestWeights) ListTensors() []string {
- names := make([]string, 0, len(mw.tensors))
- for name := range mw.tensors {
- names = append(names, name)
- }
- sort.Strings(names)
- return names
-}
-
-// HasTensor checks if a tensor exists.
-func (mw *ManifestWeights) HasTensor(name string) bool {
- _, ok := mw.tensors[name]
- return ok
-}
-
-// Quantization returns the model's quantization type from model_index.json.
-// Returns empty string if not quantized or unknown.
-func (mw *ManifestWeights) Quantization() string {
- if mw.manifest == nil {
- return ""
- }
- var index struct {
- Quantization string `json:"quantization"`
- }
- if err := mw.manifest.ReadConfigJSON("model_index.json", &index); err != nil {
- return ""
- }
- return index.Quantization
-}
-
-// ReleaseAll frees all native handles and clears the tensor cache.
-func (mw *ManifestWeights) ReleaseAll() {
- for _, sf := range mw.nativeCache {
- sf.Free()
- }
- mw.nativeCache = nil
- mw.cache = nil
-}
diff --git a/x/kvcache/cache.go b/x/kvcache/cache.go
deleted file mode 100644
index f0627584aaa..00000000000
--- a/x/kvcache/cache.go
+++ /dev/null
@@ -1,77 +0,0 @@
-package kvcache
-
-import (
- "errors"
-
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/model/input"
-)
-
-var (
- ErrKvCacheFull = errors.New("could not find a kv cache slot")
- ErrNotSupported = errors.New("model does not support operation")
-)
-
-type Cache interface {
- // ** used by model implementations **
-
- // SetLayer sets the active layer of the cache
- SetLayer(layer int)
-
- // Get returns the history of key and value tensors plus a mask
- //
- // The shape of the tensors is documented in the specific
- // cache implementation used.
- Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
-
- // Put stores a batch of key and value in the cache
- //
- // The shape of the tensors is documented in the specific
- // cache implementation used.
- Put(ctx ml.Context, key, value ml.Tensor)
-
- // SetConfig controls optimizations (mostly backend-specific) that may transform
- // the output of the cache to work better with specific kernels. If not called,
- // the backend settings will be used. This works well when calling Attention.
- //
- // The config can be overridden by models, especially if they require vanilla
- // output when implementing their own version of attention. To do this, pass
- // an empty ml.CacheConfig.
- //
- // Most models will not need to use this.
- SetConfig(ml.CacheConfig)
-
- // ** cache management **
-
- // Init sets up runtime parameters.
- // backend: Used to allocate cache data storage and execute management operations (such as defrag)
- // dtype: The data type for storing cache entries
- // maxSequences: The maximum number of sequences stored in the cache - across all batches
- // capacity: The number of cache entries to store, per sequence
- // maxBatch: The maximum number of tokens that can occur in a single batch
- Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
-
- // Close closes the cache and frees resources associated with it
- Close()
-
- // StartForward is called before the start of the model's forward pass.
- // For each token in the coming batch, there must be a corresponding
- // entry in positions and seqs. reserve is to preallocate memory
- // without actually storing data in the cache.
- StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
-
- // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
- CopyPrefix(srcSeq, dstSeq int, len int32)
-
- // CanResume returns true if the cache can continue with the next token at
- // the given position and sequence. Assumes that the caller has already
- // verified the contents of the cache.
- CanResume(seq int, pos int32) bool
-
- // Remove deletes tokens in the range [beginIndex, endIndex) from seq. Set
- // endIndex to math.MaxInt32 to remove everything starting at beginIndex.
- //
- // If an error occurs, the entire context for the sequence should be
- // removed by calling Remove(seq, 0, math.MaxInt32)
- Remove(seq int, beginIndex, endIndex int32) error
-}
diff --git a/x/kvcache/causal.go b/x/kvcache/causal.go
deleted file mode 100644
index 967fed6744c..00000000000
--- a/x/kvcache/causal.go
+++ /dev/null
@@ -1,797 +0,0 @@
-package kvcache
-
-// import (
-// "errors"
-// "fmt"
-// "log/slog"
-// "math"
-// "slices"
-
-// "github.com/ollama/ollama/ml"
-// "github.com/ollama/ollama/model/input"
-// )
-
-// type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
-
-// // Causal cache stores K and V tensors according to their position in the
-// // sequence. Returns the history and a mask for attending to past tokens
-// //
-// // The tensors are of shape embed dim, kv heads, batch size
-// // The mask is of shape history size, batch size
-// type Causal struct {
-// DType ml.DType
-
-// // swaWindowSize is the number of tokens that will be included in the mask
-// // during attention operations. swaMemorySize is the number of tokens that
-// // will be retained in memory for partial prefix caching. Set to math.MaxInt32
-// // for unlimited or if sliding window attention is not being used.
-// swaWindowSize int32
-// swaMemorySize int32
-
-// chunkSize int32
-
-// opts CausalOptions
-
-// // maxBatch is the largest batch that we might receive
-// maxBatch int
-
-// // config controls mostly backend-specific optimizations
-// config *ml.CacheConfig
-
-// // ** current forward pass **
-
-// // size of the current batch
-// curBatchSize int
-
-// // locations for data storage for this batch
-// curLoc ml.Tensor
-
-// // mask of the cache as used by this batch
-// curMask ml.Tensor
-
-// // the active layer for Get and Put
-// curLayer int
-
-// // locations in the cache that are needed for this batch
-// curCellRange cellRange
-
-// // curSequences is the sequences corresponding to this pass's entries in the cache
-// curSequences []int
-
-// // curPositions is the positions corresponding to this pass's entries in the cache
-// curPositions []int32
-
-// // ** cache metadata **
-
-// // for each possible location in the cache, stores the position and set of sequences
-// // that reference the data there
-// cells []cacheCell
-
-// // maps from sequence to the range of locations where it is stored in the cache
-// cellRanges map[int]cellRange
-
-// // ** cache data storage **
-
-// shiftFn shiftFn
-// backend ml.Backend
-// ctxs map[int]ml.Context
-// keys, values map[int]ml.Tensor
-
-// kHeadDims, vHeadDims, numKVHeads map[int]int
-// }
-
-// type cacheCell struct {
-// pos int32
-// sequences []int
-// }
-
-// type cellRange struct {
-// min int
-// max int
-// }
-
-// func NewCausalCache(shift shiftFn) *Causal {
-// return &Causal{
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func NewSWACache(windowSize int32, shift shiftFn) *Causal {
-// return &Causal{
-// swaWindowSize: windowSize,
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func NewSWAMemCache(windowSize int32, memorySize int32, shift shiftFn) *Causal {
-// return &Causal{
-// swaWindowSize: windowSize,
-// swaMemorySize: memorySize,
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func NewChunkedAttentionCache(chunkSize int32, shift shiftFn) *Causal {
-// return &Causal{
-// chunkSize: chunkSize,
-// shiftFn: shift,
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// kHeadDims: make(map[int]int),
-// vHeadDims: make(map[int]int),
-// numKVHeads: make(map[int]int),
-// }
-// }
-
-// func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
-// if c.config == nil {
-// var config ml.CacheConfig
-// if cc, ok := backend.(ml.BackendCacheConfig); ok {
-// config = cc.CacheConfig()
-// }
-// c.config = &config
-// }
-
-// if c.config.CachePadding == 0 {
-// c.config.CachePadding = 1
-// }
-
-// if c.config.MaskBatchPadding == 0 {
-// c.config.MaskBatchPadding = 1
-// }
-
-// // TODO what types do we handle here?
-// // if c.config.MaskDType == ml.DTypeOther {
-// // c.config.MaskDType = ml.DTypeFloat32
-// // }
-
-// if c.swaWindowSize == 0 {
-// c.swaWindowSize = math.MaxInt32
-// }
-// if c.swaMemorySize == 0 {
-// c.swaMemorySize = c.swaWindowSize
-// }
-// // We will allocate space in the cache for the stop token, which won't be part of a follow on
-// // sequence, so allocate an extra token of storage to ensure that we can jump back without
-// // causing a cache break. As an optimization, only do this when we have parallel sequences
-// // because the extra token will live in the batch buffer and won't get overwritten if we
-// // only have a single sequence.
-// if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
-// c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
-// }
-// if int(c.swaMemorySize) >= capacity {
-// c.swaMemorySize = math.MaxInt32
-// }
-
-// if c.swaMemorySize < c.swaWindowSize {
-// panic(fmt.Errorf("sliding window memory (%v) must be at least as large as the window (%v)", c.swaMemorySize, c.swaWindowSize))
-// }
-
-// var cacheSize int
-// if c.swaMemorySize == math.MaxInt32 {
-// cacheSize = maxSequences * capacity
-// } else {
-// cacheSize = (maxSequences * int(c.swaMemorySize)) + maxBatch
-// }
-// cacheSize = roundUp(cacheSize, c.config.CachePadding)
-// c.cells = make([]cacheCell, cacheSize)
-
-// c.DType = dtype
-// c.cellRanges = make(map[int]cellRange)
-// c.backend = backend
-// c.maxBatch = maxBatch
-// }
-
-// func (c *Causal) SetConfig(config ml.CacheConfig) {
-// if c.config != nil {
-// panic("config cannot be changed after being previously set, either by the model or backend")
-// }
-
-// c.config = &config
-// }
-
-// func (c *Causal) Close() {
-// slog.Info("XXX Causal.Close called", "number of contexts", len(c.ctxs))
-// for _, ctx := range c.ctxs {
-// ctx.Close()
-// }
-// }
-
-// func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
-// slog.Info("XXX Causal.StartForward", "cell count", len(c.cells), "prior batch size", c.curBatchSize, "positions", len(batch.Positions), "reserve", reserve, "batch", batch)
-// // panic("XXX Causal.StartForward")
-// c.curBatchSize = len(batch.Positions)
-// c.curSequences = batch.Sequences
-// c.curPositions = batch.Positions
-// c.opts.Except = nil
-
-// var locs []int32
-// if !reserve {
-// c.updateSlidingWindow()
-
-// var err error
-// locs, err = c.findLocs()
-// if err != nil {
-// return err
-// }
-// slog.Info("XXX Causal.StartForward", "findLocs len", len(locs))
-
-// for i, pos := range batch.Positions {
-// seq := batch.Sequences[i]
-// loc := int(locs[i])
-
-// c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
-
-// seqRange, ok := c.cellRanges[seq]
-// if !ok {
-// seqRange = newRange()
-// }
-
-// seqRange.min = min(seqRange.min, loc)
-// c.curCellRange.min = min(c.curCellRange.min, loc)
-
-// seqRange.max = max(seqRange.max, loc)
-// c.curCellRange.max = max(c.curCellRange.max, loc)
-
-// c.cellRanges[seq] = seqRange
-// }
-// } else {
-// // If we are reserving memory, don't update any of the cache metadata but set the size
-// // to the worst case.
-// locs = make([]int32, c.curBatchSize)
-// for i := range locs {
-// locs[i] = int32(i)
-// }
-// c.curCellRange.min = 0
-// c.curCellRange.max = len(c.cells) - 1
-// }
-
-// // XXX Building up the locs for what's already processed (if any)
-// dummyLocs := []int{}
-// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
-// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
-
-// for i := range c.curBatchSize {
-// enabled := !slices.Contains(c.opts.Except, i)
-// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-// (enabled && c.cells[j].pos > c.curPositions[i]) ||
-// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
-// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
-// // mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
-// } else {
-// if len(dummyLocs) == 0 || dummyLocs[len(dummyLocs)-1] != i {
-// dummyLocs = append(dummyLocs, i)
-// }
-// }
-// }
-// }
-// slog.Info("XXX Causa.StartForward calculated locations", "locs", dummyLocs)
-
-// slog.Info("XXX Causal.StartForward", "locs", locs)
-// c.curLoc = ctx.Input().FromInts(locs, len(locs))
-// c.curMask = c.buildMask(ctx)
-
-// return nil
-// }
-
-// func newRange() cellRange {
-// return cellRange{
-// min: math.MaxInt,
-// max: 0,
-// }
-// }
-
-// // Returns a slice of locations where each token in the batch should be stored
-// func (c *Causal) findLocs() ([]int32, error) {
-// loc := make([]int32, 0, c.curBatchSize)
-
-// for i := range c.cells {
-// if len(c.cells[i].sequences) == 0 {
-// loc = append(loc, int32(i))
-// if len(loc) >= c.curBatchSize {
-// return loc, nil
-// }
-// }
-// }
-
-// return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
-// }
-
-// func (c *Causal) updateSlidingWindow() {
-// c.curCellRange = newRange()
-
-// if c.swaMemorySize == math.MaxInt32 {
-// for _, seq := range c.curSequences {
-// if seqRange, ok := c.cellRanges[seq]; ok {
-// c.curCellRange.min = min(c.curCellRange.min, seqRange.min)
-// c.curCellRange.max = max(c.curCellRange.max, seqRange.max)
-// }
-// }
-
-// return
-// }
-
-// type lowestPosition struct {
-// pos int32
-// curBatch bool
-// }
-
-// // create a map of unique sequences to the lowest position in that sequence
-// lowestPos := make(map[int]lowestPosition)
-// for i := range c.curPositions {
-// seq := c.curSequences[i]
-
-// lowest, ok := lowestPos[seq]
-// if !ok {
-// lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
-// } else if c.curPositions[i] < lowest.pos {
-// lowest.pos = c.curPositions[i]
-// }
-
-// lowestPos[seq] = lowest
-// }
-
-// // for any sequences are not part of this batch, clean up any tokens
-// // that are no longer needed after the processing of the previous
-// // batch
-// for seq, seqRange := range c.cellRanges {
-// if _, ok := lowestPos[seq]; !ok {
-// var last int32
-// for i := seqRange.min; i <= seqRange.max; i++ {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// last = max(last, c.cells[i].pos)
-// }
-// }
-
-// lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
-// }
-// }
-
-// // delete any entries that are beyond the window of the oldest position in the sequence
-// for seq, lowest := range lowestPos {
-// oldRange, ok := c.cellRanges[seq]
-// if !ok {
-// continue
-// }
-
-// newRange := newRange()
-
-// for i := oldRange.min; i <= oldRange.max; i++ {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// if c.cells[i].pos < lowest.pos-c.swaMemorySize {
-// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
-// } else {
-// newRange.min = min(newRange.min, i)
-// newRange.max = max(newRange.max, i)
-// }
-// if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
-// c.curCellRange.min = min(c.curCellRange.min, i)
-// c.curCellRange.max = max(c.curCellRange.max, i)
-// }
-// }
-// }
-
-// c.cellRanges[seq] = newRange
-// }
-// }
-
-// func roundDown(length, pad int) int {
-// return (length / pad) * pad
-// }
-
-// func roundUp(length, pad int) int {
-// return ((length + pad - 1) / pad) * pad
-// }
-
-// // Builds a mask of history x batch indicating whether for each token in the batch the
-// // token in the history should apply. This is based on both the sequence and causality (the
-// // position of the history is not ahead of the token in the batch).
-// func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
-// // Align and pad the two dimensions as required by the backend
-// batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
-
-// c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
-// c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
-
-// length := c.curCellRange.max - c.curCellRange.min + 1
-
-// mask := make([]float32, batchSize*length)
-
-// for i := range c.curBatchSize {
-// enabled := !slices.Contains(c.opts.Except, i)
-// for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
-// if !slices.Contains(c.cells[j].sequences, c.curSequences[i]) ||
-// (enabled && c.cells[j].pos > c.curPositions[i]) ||
-// c.chunkSize > 0 && c.cells[j].pos < c.curPositions[i]-c.curPositions[i]%c.chunkSize ||
-// c.cells[j].pos < c.curPositions[i]-c.swaWindowSize {
-// mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
-// }
-// }
-// }
-
-// // Mask out any padding tokens we added. For padding that we added to the cache history, this
-// // has already been masked out because the sequence doesn't match.
-// for i := c.curBatchSize * length; i < len(mask); i++ {
-// mask[i] = float32(math.Inf(-1))
-// }
-
-// maskTensor := ctx.Input().FromFloats(mask, batchSize, length)
-
-// // if c.config.MaskDType != ml.DTypeFloat32 {
-// // maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
-// // }
-
-// slog.Info("XXX Causal.buildMask", "c.curBatchSize", c.curBatchSize, "c.config.MaskBatchPadding", c.config.MaskBatchPadding, "c.curCellRange.min", c.curCellRange.min, "c.curCellRange.max", c.curCellRange.max, "size", len(mask), "shape", []int{1, batchSize, length})
-
-// return maskTensor
-// }
-
-// func (c *Causal) SetLayer(layer int) {
-// c.curLayer = layer
-// }
-
-// type CausalOptions struct {
-// // Enabled controls whether the causal mask is generated for a particular index in a batch
-// Except []int
-// }
-
-// // SetCausal disables causal mask generation for a particular range of indicies in
-// // the current batch for subsequent calls to Get. The state resets for the next forward pass.
-// func (c *Causal) SetCausal(ctx ml.Context, opts CausalOptions) {
-// if !slices.Equal(c.opts.Except, opts.Except) {
-// c.opts = opts
-// if ctx != nil {
-// c.curMask = c.buildMask(ctx)
-// }
-// }
-// }
-
-// func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
-// key := c.keys[c.curLayer]
-// value := c.values[c.curLayer]
-
-// kHeadDim := c.kHeadDims[c.curLayer]
-// vHeadDim := c.vHeadDims[c.curLayer]
-// numKVHeads := c.numKVHeads[c.curLayer]
-// // rowSize := numKVHeads * c.curBatchSize
-// // cachedSize := c.curMask.Dim(1)
-// cachedSize := c.curLoc.Dim(0)
-// // kCellSize := kHeadDim * numKVHeads
-// // vCellSize := vHeadDim * numKVHeads
-
-// slog.Info("XXX Causal.Get full cache", "key", key)
-// slog.Info("XXX Causal.Get full cache", "value", value)
-// slog.Info("XXX Causal.Get full cache", "curloc", c.curLoc)
-// slog.Info("XXX Causal.Get", "curMask", c.curMask)
-// slog.Info("XXX Causal.Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "cachedSize", cachedSize, "kHeadDim", kHeadDim)
-// // panic("XXX")
-
-// // fmt.Fprintln(os.Stderr, key.ToString())
-// // panic("full cache value")
-
-// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
-// key = key.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
-// // key = key.AsStrided(ctx, []int{1, numKVHeads, cachedSize, kHeadDim}, []int{}, rowSize*c.curCellRange.min)
-
-// // slog.Info("XXX Causal.Get after AsStrided", "key", key)
-// // panic("XXX")
-
-// // if c.config.PermutedV {
-// // panic("permuted")
-// // // TODO not converted
-// // vHeadDim := value.Dim(1)
-// // elemSize := value.Stride(2)
-
-// // value = value.AsStrided(ctx,
-// // []int{numKVHeads, vHeadDim, cachedSize},
-// // []int{value.Stride(0), value.Stride(1)},
-// // elemSize*c.curCellRange.min,
-// // )
-// // } else {
-// // vHeadDim := c.vHeadDims[c.curLayer]
-// // rowSize := value.Stride(2)
-// // slog.Info("XXX Causal.Get before AsStrided", "vHeadDim", vHeadDim, "rowSize", rowSize)
-// // panic("XXX")
-
-// // TODO we should use TakeAxes to gather the cells from curLoc, but for now to be consistent with GGML, just grab a larger chunk and mask
-// value = value.TakeAxes(ctx, c.curLoc, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
-// // value = value.AsStrided(ctx, []int{1, numKVHeads, cachedSize, vHeadDim}, []int{}, rowSize*c.curCellRange.min)
-
-// // slog.Info("XXX Causal.Get after AsStrided", "value", value)
-// // panic("XXX")
-
-// // }
-
-// // // TODO The mask changes from X,X to 1,X, and with the Row-order change
-// // // the 1 becomes trailing and messes up later operations
-// // // This isn't the right solution, but works around it...
-// // if c.curMask.Dim(1) == 1 {
-// // return key, value, c.curMask.Transpose(ctx, 1, 0, 2, 3)
-// // }
-// // fmt.Fprintln(os.Stderr, key.ToString())
-// // fmt.Fprintln(os.Stderr, value.ToString())
-// // panic("XXX")
-// slog.Info("XXX Mask", "curLayer", c.curLayer, "shape", c.curMask.Shape())
-
-// return key, value, c.curMask
-// }
-
-// func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
-// kHeadDim := key.Dim(3)
-// vHeadDim := value.Dim(3)
-// numKVHeads := key.Dim(1)
-// batchSize := key.Dim(2)
-// kCellSize := kHeadDim * numKVHeads
-// vCellSize := vHeadDim * numKVHeads
-
-// // slog.Info("XXX Causal.Put", "key", key, "value", value)
-// slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize)
-// // panic("XXX")
-
-// if c.curBatchSize != batchSize {
-// panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
-// }
-
-// // slog.Info("XXX", "c.ctxs", c.ctxs, "c.curLayer", c.curLayer, "backend", c.backend)
-// if _, ok := c.ctxs[c.curLayer]; !ok {
-// slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
-// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
-// }
-
-// if _, ok := c.keys[c.curLayer]; !ok {
-// slog.Info("XXX Causal.Put allocating keys", "c.curLayer", c.curLayer, "shape", []int{len(c.cells), kCellSize})
-
-// c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), kCellSize)
-// c.kHeadDims[c.curLayer] = kHeadDim
-// c.vHeadDims[c.curLayer] = vHeadDim
-// c.numKVHeads[c.curLayer] = numKVHeads
-// }
-
-// if _, ok := c.values[c.curLayer]; !ok {
-// // if c.config.PermutedV {
-// // c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, numKVHeads, vHeadDim, len(c.cells))
-// // } else {
-// c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vCellSize)
-// // }
-// }
-
-// key = key.Reshape(ctx, batchSize, 1, kCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
-
-// // slog.Info("XXX Causal.Put after reshape", "keyCache", keyCache)
-// // panic("XXX")
-// // curLoc := 0 // TODO c.curLoc is now a tensor
-// // kSize := numKVHeads * kHeadDim
-// // vSize := numKVHeads * vHeadDim
-// // start := []int{int(curLoc), 0}
-// // kStop := []int{int(curLoc + batchSize), int(kSize)}
-// // vStop := []int{int(curLoc + batchSize), int(vSize)}
-// // strides := []int{1, 1}
-
-// // slog.Info("XXX Causal.Put Key SliceUpdate", "keyCache", keyCache)
-// // slog.Info("XXX Causal.Put Key SliceUpdate", "key", key)
-
-// // slog.Info("XXX Causal.Put Key SliceUpdate", "start", start, "kStop", kStop, "strides", strides)
-
-// // ctx.Forward(c.keys[c.curLayer].SliceUpdate(ctx, key, start, kStop, strides))
-// ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, key, []int{0}))
-// // fmt.Fprintln(os.Stderr, keyCache.ToString())
-// // panic("input value")
-
-// // fmt.Fprintln(os.Stderr, t.ToString())
-// // panic("XXX")
-
-// // if c.config.PermutedV {
-// // panic("permuted")
-// // // TODO not adjusted
-// // value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
-// // value = value.Transpose(ctx, 2, 0, 1, 3)
-
-// // valueCache := c.values[c.curLayer]
-// // valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
-
-// // ctx.Forward(valueCache.SliceUpdate(ctx, value, start, vStop, strides))
-// // } else {
-// value = value.Reshape(ctx, batchSize, 1, vCellSize) //.Contiguous(ctx, false) // TODO contiguous may not be needed
-// // slog.Info("XXX Causal.Put Value SliceUpdate", "valueCache", valueCache)
-// // slog.Info("XXX Causal.Put Value SliceUpdate", "value", value)
-// // slog.Info("XXX Causal.Put Value SliceUpdate", "start", start, "vStop", vStop, "strides", strides)
-
-// ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLoc}, value, []int{0}))
-// // }
-// // fmt.Fprintln(os.Stderr, c.keys[c.curLayer].ToString())
-// // fmt.Fprintln(os.Stderr, c.values[c.curLayer].ToString())
-// // panic("XXX")
-
-// }
-
-// func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
-// seqRange := newRange()
-
-// for i := range c.cells {
-// // Remove the contents of dstSeq so that we only have the copied prefix, metadata will be reset at the end
-// if slices.Contains(c.cells[i].sequences, dstSeq) {
-// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == dstSeq })
-// }
-
-// if slices.Contains(c.cells[i].sequences, srcSeq) && c.cells[i].pos < len {
-// c.cells[i].sequences = append(c.cells[i].sequences, dstSeq)
-// if i < seqRange.min {
-// seqRange.min = i
-// }
-// if i > seqRange.max {
-// seqRange.max = i
-// }
-// }
-// }
-
-// c.cellRanges[dstSeq] = seqRange
-// }
-
-// func (c *Causal) CanResume(seq int, pos int32) bool {
-// if c.swaMemorySize == math.MaxInt32 {
-// return true
-// }
-
-// seqRange, ok := c.cellRanges[seq]
-// if !ok {
-// return false
-// }
-
-// // for sliding window, check that the window of the new sequence is contained in
-// // the window of what we are storing
-// var first int32 = math.MaxInt32
-// var last int32 = -1
-// for i := seqRange.min; i <= seqRange.max; i++ {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// first = min(first, c.cells[i].pos)
-// last = max(last, c.cells[i].pos)
-// }
-// }
-
-// if last == -1 {
-// return false
-// }
-
-// posWindowStart := max(0, pos-c.swaWindowSize)
-// return posWindowStart >= first && pos <= last+1
-// }
-
-// func (c *Causal) shift(seq int, beginIndex, offset int32) error {
-// if c.shiftFn == nil {
-// return ErrNotSupported
-// }
-
-// seqRange := c.cellRanges[seq]
-
-// for start := seqRange.min; start <= seqRange.max; start += c.maxBatch {
-// size := min(seqRange.max-start+1, c.maxBatch)
-// offsets := make([]int32, size)
-
-// var batchFirst, batchLast int
-
-// batchFirst = -1
-// for i := range offsets {
-// cell := c.cells[start+i]
-
-// if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex {
-// offsets[i] = offset
-// if batchFirst < 0 {
-// batchFirst = i
-// }
-// batchLast = i
-// }
-// }
-
-// if batchFirst < 0 {
-// continue
-// }
-
-// offsets = offsets[batchFirst : batchLast+1]
-
-// slog.Info("XXX Causal.shift creating new temporary context")
-// ctx := c.backend.NewContext()
-// kShift := ctx.Input().FromInts(offsets, len(offsets))
-
-// for i, key := range c.keys {
-// if key == nil {
-// continue
-// }
-
-// kHeadDim := key.Dim(2)
-// numKVHeads := key.Dim(1)
-// rowSize := key.Stride(0)
-
-// key = key.AsStrided(ctx,
-// []int{len(offsets), numKVHeads, kHeadDim},
-// []int{key.Stride(0), key.Stride(1)},
-// rowSize*(start+batchFirst),
-// )
-
-// roped, err := c.shiftFn(ctx, i, key, kShift)
-// if err != nil {
-// ctx.Close()
-// return err
-// }
-
-// ctx.Forward(roped.Copy(ctx, key))
-// }
-
-// ctx.Compute()
-// ctx.Close()
-// }
-
-// return nil
-// }
-
-// func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
-// // TODO(jessegross): We should check to see if removing the middle of the sequence will
-// // cause the sliding window to encompass tokens that we no longer have. If so, then we
-// // should return an error, which will trigger the runner to evaluate the full history and
-// // rebuild the window. However, if we have multimodal inputs in our history, this reuse
-// // results in use after free, so we don't do it for now.
-
-// var offset int32
-// if endIndex != math.MaxInt32 {
-// offset = beginIndex - endIndex
-// }
-
-// seqRange := newRange()
-
-// for i := range c.cells {
-// if slices.Contains(c.cells[i].sequences, seq) {
-// if c.cells[i].pos >= beginIndex && c.cells[i].pos < endIndex {
-// c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
-// } else {
-// if c.cells[i].pos >= endIndex {
-// if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
-// return errors.New("shifting cells shared by multiple sequences not supported")
-// }
-
-// c.cells[i].pos += offset
-// }
-// if i < seqRange.min {
-// seqRange.min = i
-// }
-// if i > seqRange.max {
-// seqRange.max = i
-// }
-// }
-// }
-// }
-
-// if seqRange == newRange() {
-// delete(c.cellRanges, seq)
-// return nil
-// }
-
-// c.cellRanges[seq] = seqRange
-
-// if endIndex != math.MaxInt32 {
-// err := c.shift(seq, endIndex+offset, offset)
-// if err != nil {
-// return err
-// }
-// }
-
-// return nil
-// }
diff --git a/x/kvcache/causal_test.go b/x/kvcache/causal_test.go
deleted file mode 100644
index d7ac430b1d8..00000000000
--- a/x/kvcache/causal_test.go
+++ /dev/null
@@ -1,973 +0,0 @@
-package kvcache
-
-// import (
-// "fmt"
-// "math"
-// "slices"
-// "testing"
-
-// "github.com/ollama/ollama/ml"
-// "github.com/ollama/ollama/model/input"
-// )
-
-// type testCase struct {
-// name string
-// in []float32
-// inShape []int
-// seqs []int
-// pos []int32
-// expected []float32
-// expectedShape []int
-// expectedMask []float32
-// }
-
-// func runPermutedVariants(t *testing.T, fn func(t *testing.T, backend *testBackend)) {
-// t.Helper()
-// for _, permuted := range []bool{false, true} {
-// t.Run(fmt.Sprintf("PermutedV=%t", permuted), func(t *testing.T) {
-// fn(t, &testBackend{permutedV: permuted})
-// })
-// }
-// }
-
-// func TestStore(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
-// inShape: []int{2, 3, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234},
-// expectedShape: []int{2, 3, 4},
-// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{115, 215, 125, 225, 135, 235},
-// inShape: []int{2, 3, 1},
-// seqs: []int{0},
-// pos: []int32{4},
-// expected: []float32{111, 211, 121, 221, 131, 231, 112, 212, 122, 222, 132, 232, 113, 213, 123, 223, 133, 233, 114, 214, 124, 224, 134, 234, 115, 215, 125, 225, 135, 235},
-// expectedShape: []int{2, 3, 5},
-// expectedMask: []float32{0, 0, 0, 0, 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestSWA(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewSWACache(1, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, 0, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{4, 5},
-// expected: []float32{5, 6, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, 0,
-// 0, 0, x, x,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestSWASeparateBatches(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewSWACache(1, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 2, 16, 2)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "First seq 0",
-// in: []float32{1, 2},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{0, 1},
-// expected: []float32{1, 2},
-// expectedShape: []int{1, 1, 2},
-// expectedMask: []float32{
-// 0, x,
-// 0, 0,
-// },
-// },
-// {
-// name: "Second seq 0",
-// in: []float32{3, 4},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{2, 3},
-// expected: []float32{2, 3, 4},
-// expectedShape: []int{1, 1, 3},
-// expectedMask: []float32{
-// 0, 0, x,
-// x, 0, 0,
-// },
-// },
-// {
-// name: "First seq 1",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{1, 1},
-// pos: []int32{0, 1},
-// expected: []float32{5, 6},
-// expectedShape: []int{1, 1, 2},
-// expectedMask: []float32{
-// 0, x,
-// 0, 0,
-// },
-// },
-// {
-// name: "Second seq 1",
-// in: []float32{7, 8},
-// inShape: []int{1, 1, 2},
-// seqs: []int{1, 1},
-// pos: []int32{2, 3},
-// expected: []float32{6, 3, 4, 7, 8},
-// expectedShape: []int{1, 1, 5},
-// expectedMask: []float32{
-// 0, x, x, 0, x,
-// x, x, x, 0, 0,
-// },
-// },
-// {
-// name: "Third seq 0",
-// in: []float32{9, 10},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{4, 5},
-// expected: []float32{9, 10, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, 0,
-// 0, 0, x, x,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestSWAMem(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewSWAMemCache(1, 3, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, 0, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{4, 5},
-// expected: []float32{5, 2, 3, 4, 6},
-// expectedShape: []int{1, 1, 5},
-// expectedMask: []float32{
-// 0, x, x, 0, x,
-// 0, x, x, x, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestChunkedAttention(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewChunkedAttentionCache(2, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// testCache(
-// t, backend, cache,
-// []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, x, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6, 7},
-// inShape: []int{1, 1, 3},
-// seqs: []int{0, 0, 0},
-// pos: []int32{4, 5, 6},
-// expected: []float32{1, 2, 3, 4, 5, 6, 7},
-// expectedShape: []int{1, 1, 7},
-// expectedMask: []float32{
-// x, x, x, x, 0, x, x,
-// x, x, x, x, 0, 0, x,
-// x, x, x, x, x, x, 0,
-// },
-// },
-// {
-// name: "ThirdBatch",
-// in: []float32{8, 9},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{7, 8},
-// expected: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
-// expectedShape: []int{1, 1, 9},
-// expectedMask: []float32{
-// x, x, x, x, x, x, 0, 0, x,
-// x, x, x, x, x, x, x, x, 0,
-// },
-// },
-// },
-// )
-// })
-// }
-
-// func TestSequences(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 1, 1},
-// pos: []int32{0, 1, 0, 1},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
-// },
-// {
-// name: "SecondBatch",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 1},
-// pos: []int32{2, 2},
-// expected: []float32{1, 2, 3, 4, 5, 6},
-// expectedShape: []int{1, 1, 6},
-// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestRemove(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
-// return key.Add(ctx, shift), nil
-// })
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// x := float32(math.Inf(-1))
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 1, 1},
-// pos: []int32{0, 1, 0, 1},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{
-// 0, x, x, x,
-// 0, 0, x, x,
-// x, x, 0, x,
-// x, x, 0, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-
-// err := cache.Remove(0, 1, math.MaxInt32)
-// if err != nil {
-// panic(err)
-// }
-
-// tests = []testCase{
-// {
-// name: "RemoveEnd",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 1},
-// pos: []int32{1, 2},
-// expected: []float32{1, 5, 3, 4, 6},
-// expectedShape: []int{1, 1, 5},
-// expectedMask: []float32{
-// 0, 0, x, x, x,
-// x, x, 0, 0, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-
-// err = cache.Remove(0, 0, 1)
-// if err != nil {
-// panic(err)
-// }
-
-// tests = []testCase{
-// {
-// name: "RemoveMiddle",
-// in: []float32{7, 8},
-// inShape: []int{1, 1, 2},
-// seqs: []int{0, 0},
-// pos: []int32{1, 2},
-// expected: []float32{7, 4, 3, 4, 6, 8},
-// expectedShape: []int{1, 1, 6},
-// expectedMask: []float32{
-// 0, 0, x, x, x, x,
-// 0, 0, x, x, x, 0,
-// },
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func TestCopy(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// tests := []testCase{
-// {
-// name: "FirstBatch",
-// in: []float32{1, 2, 3, 4},
-// inShape: []int{1, 1, 4},
-// seqs: []int{0, 0, 0, 0},
-// pos: []int32{0, 1, 2, 3},
-// expected: []float32{1, 2, 3, 4},
-// expectedShape: []int{1, 1, 4},
-// expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), 0, 0, 0, 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-
-// cache.CopyPrefix(0, 1, 2)
-
-// tests = []testCase{
-// {
-// name: "Copy",
-// in: []float32{5, 6},
-// inShape: []int{1, 1, 2},
-// seqs: []int{1, 1},
-// pos: []int32{3, 4},
-// expected: []float32{1, 2, 3, 4, 5, 6},
-// expectedShape: []int{1, 1, 6},
-// expectedMask: []float32{0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
-// },
-// }
-
-// testCache(t, backend, cache, tests)
-// })
-// }
-
-// func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) {
-// for _, test := range tests {
-// t.Run(test.name, func(t *testing.T) {
-// context := backend.NewContext()
-// defer context.Close()
-
-// err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
-// if err != nil {
-// panic(err)
-// }
-
-// cache.SetLayer(0)
-// tensor := context.FromFloats(test.in, test.inShape...)
-// cache.Put(context, tensor, tensor)
-
-// out, _, mask := cache.Get(context)
-
-// context.Forward(out, mask).Compute(out, mask)
-
-// if !slices.Equal(out.Floats(), test.expected) {
-// t.Errorf("TestCache: have %v; want %v", out.Floats(), test.expected)
-// }
-
-// if !slices.Equal(out.Shape(), test.expectedShape) {
-// t.Errorf("TestCache: has shape %v; want %v", out.Shape(), test.expectedShape)
-// }
-
-// if !slices.Equal(mask.Floats(), test.expectedMask) {
-// t.Errorf("TestCache: have mask: have %v want %v", mask.Floats(), test.expectedMask)
-// }
-// })
-// }
-// }
-
-// func TestCanResume(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// windowSize := int32(4)
-// cache := NewSWACache(windowSize, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// context := backend.NewContext()
-// defer context.Close()
-
-// err := cache.StartForward(context, input.Batch{
-// Positions: []int32{0, 1, 2, 3, 4},
-// Sequences: []int{0, 0, 0, 0, 0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
-// cache.Put(context, tensor, tensor)
-
-// // with window size 4, nothing has slid out of the window yet
-// if !cache.CanResume(0, 0) {
-// t.Errorf("CanResume(0, 0) = false, want true (within window)")
-// }
-// if !cache.CanResume(0, 1) {
-// t.Errorf("CanResume(0, 1) = false, want true (within window)")
-// }
-// if !cache.CanResume(0, 2) {
-// t.Errorf("CanResume(0, 2) = false, want true (within window)")
-// }
-// if !cache.CanResume(0, 3) {
-// t.Errorf("CanResume(0, 3) = false, want true (latest position)")
-// }
-// if !cache.CanResume(0, 4) {
-// t.Errorf("CanResume(0, 4) = false, want true (latest position)")
-// }
-
-// // shift window by adding position 5
-// err = cache.StartForward(context, input.Batch{
-// Positions: []int32{5},
-// Sequences: []int{0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor = context.FromFloats([]float32{6}, 1, 1, 1)
-// cache.Put(context, tensor, tensor)
-
-// // only the latest position has overlapping windows
-// if cache.CanResume(0, 0) {
-// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 1) {
-// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 2) {
-// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 3) {
-// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 4) {
-// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
-// }
-// if !cache.CanResume(0, 5) {
-// t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
-// }
-// })
-// }
-
-// func TestCanResumeSWAMem(t *testing.T) {
-// runPermutedVariants(t, func(t *testing.T, backend *testBackend) {
-// windowSize := int32(4)
-// memSize := int32(5)
-// cache := NewSWAMemCache(windowSize, memSize, nil)
-// defer cache.Close()
-
-// cache.Init(backend, ml.DTypeF16, 1, 16, 16)
-
-// context := backend.NewContext()
-// defer context.Close()
-
-// err := cache.StartForward(context, input.Batch{
-// Positions: []int32{0, 1, 2, 3, 4, 5, 6},
-// Sequences: []int{0, 0, 0, 0, 0, 0, 0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor := context.FromFloats([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
-// cache.Put(context, tensor, tensor)
-
-// // shift window by adding position 7
-// err = cache.StartForward(context, input.Batch{
-// Positions: []int32{7},
-// Sequences: []int{0},
-// }, false)
-// if err != nil {
-// t.Fatalf("StartForward failed: %v", err)
-// }
-
-// cache.SetLayer(0)
-// tensor = context.FromFloats([]float32{8}, 1, 1, 1)
-// cache.Put(context, tensor, tensor)
-
-// // only the latest position has overlapping windows
-// if cache.CanResume(0, 0) {
-// t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 1) {
-// t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 2) {
-// t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 3) {
-// t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 4) {
-// t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
-// }
-// if cache.CanResume(0, 5) {
-// t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
-// }
-// if !cache.CanResume(0, 6) {
-// t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
-// }
-// if !cache.CanResume(0, 7) {
-// t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
-// }
-// })
-// }
-
-// type testBackend struct {
-// ml.Backend
-// permutedV bool
-// }
-
-// func (b *testBackend) NewContext() ml.Context {
-// return &testContext{}
-// }
-
-// func (b *testBackend) NewContextSize(int) ml.Context {
-// return &testContext{}
-// }
-
-// func (b *testBackend) CacheConfig() ml.CacheConfig {
-// return ml.CacheConfig{PermutedV: b.permutedV}
-// }
-
-// type testContext struct {
-// ml.Context
-// }
-
-// func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
-// total := 0
-
-// if len(shape) > 0 {
-// total = 1
-// for _, s := range shape {
-// total *= s
-// }
-// }
-
-// return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
-// }
-
-// func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
-// return c.Empty(dtype, shape...)
-// }
-
-// func (c *testContext) FromFloats(s []float32, shape ...int) ml.Tensor {
-// t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
-
-// copy(t.data, s)
-
-// return t
-// }
-
-// func (c *testContext) FromInts(s []int32, shape ...int) ml.Tensor {
-// f := make([]float32, len(s))
-// for i := range f {
-// f[i] = float32(s[i])
-// }
-
-// out := c.FromFloats(f, shape...)
-// out.(*testTensor).dtype = ml.DTypeI32
-
-// return out
-// }
-
-// func (c *testContext) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
-// s := make([]float32, 0, int((stop-start)/step))
-// for i := start; i < stop; i += step {
-// s = append(s, i)
-// }
-
-// out := c.FromFloats(s, len(s))
-// out.(*testTensor).dtype = dtype
-// return out
-// }
-
-// func (c *testContext) Input() ml.Context { return c }
-// func (c *testContext) Layer(int) ml.Context { return c }
-
-// func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
-
-// func (c *testContext) Compute(...ml.Tensor) {}
-
-// func (c *testContext) Reserve() {}
-
-// func (c *testContext) MaxGraphNodes() int {
-// return 10
-// }
-
-// func (c *testContext) Close() {}
-
-// type testTensor struct {
-// ml.Tensor
-
-// dtype ml.DType
-// elementSize int
-// data []float32
-// shape []int
-// }
-
-// func (t *testTensor) Dim(n int) int {
-// return t.shape[n]
-// }
-
-// func (t *testTensor) Stride(n int) int {
-// stride := t.elementSize
-// for i := range n {
-// stride *= t.shape[i]
-// }
-
-// return stride
-// }
-
-// func (t *testTensor) Shape() []int {
-// return t.shape
-// }
-
-// func (t *testTensor) DType() ml.DType {
-// return t.dtype
-// }
-
-// func (t *testTensor) Floats() []float32 {
-// out := make([]float32, len(t.data))
-// copy(out, t.data)
-// return out
-// }
-
-// func (t *testTensor) Neg(ctx ml.Context) ml.Tensor {
-// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
-// for i := range out.data {
-// out.data[i] = -t.data[i]
-// }
-// return out
-// }
-
-// func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-// out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
-
-// for i := range out.data {
-// out.data[i] = t.data[i] + t2.(*testTensor).data[i]
-// }
-
-// return out
-// }
-
-// func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
-// return &testTensor{
-// dtype: t.dtype,
-// elementSize: t.elementSize,
-// data: t.data,
-// shape: shape,
-// }
-// }
-
-// func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
-// offset /= t.elementSize
-
-// var s []int
-
-// switch len(shape) {
-// case 1:
-// s = []int{shape[0]}
-// case 3:
-// s = []int{shape[0], shape[2]}
-// case 5:
-// s = []int{shape[0], shape[2], shape[4]}
-// default:
-// panic("unsupported number of dimensions")
-// }
-
-// context := &testContext{}
-
-// view := context.Empty(t.dtype, s...).(*testTensor)
-// view.data = t.data[offset : offset+len(view.data)]
-
-// return view
-// }
-
-// func (t *testTensor) Permute(ctx ml.Context, order ...int) ml.Tensor {
-// if len(t.shape) > 4 || len(order) > 4 {
-// panic("permute only supports up to 4 dimensions")
-// }
-
-// if len(order) != len(t.shape) && len(order) != 4 {
-// panic("invalid number of dimensions for permute")
-// }
-
-// // ggml_permute expects 4 axes, so fill in any missing dimensions.
-// orderFull := append(make([]int, 0, 4), order...)
-// for len(orderFull) < 4 {
-// orderFull = append(orderFull, len(orderFull))
-// }
-
-// seen := [4]bool{}
-
-// shape4 := [4]int{1, 1, 1, 1}
-// for i := 0; i < len(t.shape) && i < 4; i++ {
-// shape4[i] = t.shape[i]
-// }
-
-// newShape4 := [4]int{1, 1, 1, 1}
-// for axis := range 4 {
-// dst := orderFull[axis]
-// if dst < 0 || dst >= 4 {
-// panic("invalid axis for permute")
-// }
-// if seen[dst] {
-// panic("duplicate axis for permute")
-// }
-// seen[dst] = true
-// newShape4[dst] = shape4[axis]
-// }
-
-// total := len(t.data)
-// newData := make([]float32, total)
-
-// if total > 0 {
-// oldDims := shape4
-// newDims := newShape4
-
-// oldStride := [4]int{1, 1, 1, 1}
-// newStride := [4]int{1, 1, 1, 1}
-// for i := 1; i < 4; i++ {
-// oldStride[i] = oldStride[i-1] * oldDims[i-1]
-// newStride[i] = newStride[i-1] * newDims[i-1]
-// }
-
-// var coords [4]int
-// var newCoords [4]int
-
-// for idx := range total {
-// remainder := idx
-// for axis := range 4 {
-// dim := oldDims[axis]
-// if dim == 0 {
-// coords[axis] = 0
-// continue
-// }
-// coords[axis] = remainder % dim
-// remainder /= dim
-// }
-
-// for axis := range 4 {
-// newCoords[orderFull[axis]] = coords[axis]
-// }
-
-// newIndex := 0
-// for axis := range 4 {
-// if newDims[axis] == 0 {
-// continue
-// }
-// newIndex += newCoords[axis] * newStride[axis]
-// }
-
-// newData[newIndex] = t.data[idx]
-// }
-// }
-
-// numDims := 4
-// for numDims > 1 && newShape4[numDims-1] <= 1 {
-// numDims--
-// }
-
-// newShape := make([]int, numDims)
-// copy(newShape, newShape4[:numDims])
-
-// return &testTensor{
-// dtype: t.dtype,
-// elementSize: t.elementSize,
-// data: newData,
-// shape: newShape,
-// }
-// }
-
-// func (t *testTensor) SetRows(ctx ml.Context, src ml.Tensor, idxs ml.Tensor) ml.Tensor {
-// dst := t
-// srcTensor := src.(*testTensor)
-// idxTensor := idxs.(*testTensor)
-
-// shapeTo4D := func(shape []int) [4]int {
-// out := [4]int{1, 1, 1, 1}
-// for i := 0; i < len(shape) && i < 4; i++ {
-// out[i] = shape[i]
-// }
-// return out
-// }
-
-// computeStrides := func(shape [4]int) [4]int {
-// out := [4]int{1, 1, 1, 1}
-// for i := 1; i < 4; i++ {
-// out[i] = out[i-1] * shape[i-1]
-// }
-// return out
-// }
-
-// dstShape4D := shapeTo4D(dst.shape)
-// srcShape4D := shapeTo4D(srcTensor.shape)
-// idxShape4D := shapeTo4D(idxTensor.shape)
-
-// if dstShape4D[0] != srcShape4D[0] || dstShape4D[2] != srcShape4D[2] || dstShape4D[3] != srcShape4D[3] {
-// panic("SetRows requires matching tensor shapes")
-// }
-
-// if srcShape4D[1] != idxShape4D[0] {
-// panic("SetRows rows/index mismatch")
-// }
-
-// if srcShape4D[2]%idxShape4D[1] != 0 || srcShape4D[3]%idxShape4D[2] != 0 {
-// panic("SetRows cannot broadcast indices")
-// }
-
-// if idxShape4D[3] != 1 {
-// panic("SetRows expects 1D or 2D index tensors")
-// }
-
-// dstStride := computeStrides(dstShape4D)
-// srcStride := computeStrides(srcShape4D)
-// idxStride := computeStrides(idxShape4D)
-
-// numColumns := srcShape4D[0]
-// numRows := srcShape4D[1]
-
-// for dim3Index := range dstShape4D[3] {
-// for dim2Index := range dstShape4D[2] {
-// idxDim2 := 0
-// idxDim3 := 0
-// if idxShape4D[1] > 0 {
-// idxDim2 = dim2Index % idxShape4D[1]
-// }
-// if idxShape4D[2] > 0 {
-// idxDim3 = dim3Index % idxShape4D[2]
-// }
-
-// idxBase := idxDim3*idxStride[2] + idxDim2*idxStride[1]
-// srcBase := dim3Index*srcStride[3] + dim2Index*srcStride[2]
-// dstBase := dim3Index*dstStride[3] + dim2Index*dstStride[2]
-
-// for row := range numRows {
-// idx := int(idxTensor.data[idxBase+row*idxStride[0]])
-// if idx < 0 || idx >= dstShape4D[1] {
-// panic("SetRows index out of range")
-// }
-
-// srcOffset := srcBase + row*srcStride[1]
-// dstOffset := dstBase + idx*dstStride[1]
-
-// copy(dst.data[dstOffset:dstOffset+numColumns], srcTensor.data[srcOffset:srcOffset+numColumns])
-// }
-// }
-// }
-
-// return dst
-// }
-
-// func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
-// copy(t2.(*testTensor).data, t.data)
-// return nil
-// }
diff --git a/x/kvcache/encoder.go b/x/kvcache/encoder.go
deleted file mode 100644
index 19a3839ce73..00000000000
--- a/x/kvcache/encoder.go
+++ /dev/null
@@ -1,156 +0,0 @@
-package kvcache
-
-// import (
-// "fmt"
-
-// "github.com/ollama/ollama/ml"
-// "github.com/ollama/ollama/model/input"
-// )
-
-// // Encoder cache stores K and V tensors that are position independent
-// //
-// // The tensors can be of any shape and will be returned as they were stored
-// // The mask is currently always nil
-// //
-// // Not currently safe for multiple sequences
-// type EncoderCache struct {
-// // config controls mostly backend-specific optimizations
-// config *ml.CacheConfig
-
-// // ** current forward pass **
-
-// // the active layer for Get and Put
-// curLayer int
-
-// // if something is stored during this pass, this
-// // will be the position (but there is no guarantee
-// // anything will be stored)
-// curPos int32
-
-// // curReserve indicates that this forward pass is only for
-// // memory reservation and we should not update our metadata
-// // based on it.
-// curReserve bool
-
-// // ** cache metadata **
-
-// // was something stored in the cache?
-// encoderCached bool
-
-// // position of the cached data
-// encoderPos int32
-
-// // ** cache data storage **
-// backend ml.Backend
-// ctxs map[int]ml.Context
-// keys, values map[int]ml.Tensor
-// }
-
-// func NewEncoderCache() *EncoderCache {
-// return &EncoderCache{
-// ctxs: make(map[int]ml.Context),
-// keys: make(map[int]ml.Tensor),
-// values: make(map[int]ml.Tensor),
-// }
-// }
-
-// func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
-// if c.config == nil {
-// var config ml.CacheConfig
-// if cc, ok := backend.(ml.BackendCacheConfig); ok {
-// config = cc.CacheConfig()
-// }
-// c.config = &config
-// }
-
-// if maxSequences > 1 {
-// panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
-// }
-
-// if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
-// panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
-// }
-
-// c.backend = backend
-// }
-
-// func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
-// if c.config != nil {
-// panic("config cannot be changed after being previously set, either by the model or backend")
-// }
-
-// c.config = &config
-// }
-
-// func (c *EncoderCache) Close() {
-// for _, ctx := range c.ctxs {
-// ctx.Close()
-// }
-// }
-
-// func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
-// // We work with the most recent image
-// if len(batch.Multimodal) > 0 {
-// c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
-// }
-
-// c.curReserve = reserve
-
-// return nil
-// }
-
-// func (c *EncoderCache) SetLayer(layer int) {
-// c.curLayer = layer
-// }
-
-// func (c *EncoderCache) EncoderCached() bool {
-// return c.encoderCached
-// }
-
-// func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
-// return c.keys[c.curLayer], c.values[c.curLayer], nil
-// }
-
-// func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
-// if !c.curReserve {
-// c.encoderPos = c.curPos
-// c.encoderCached = true
-// }
-
-// if c.config.PermutedV {
-// value = value.Transpose(ctx, 1, 2, 0, 3)
-// }
-
-// if _, ok := c.ctxs[c.curLayer]; !ok {
-// c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
-// }
-
-// if _, ok := c.keys[c.curLayer]; !ok {
-// c.keys[c.curLayer] = c.ctxs[c.curLayer].Empty(key.DType(), key.Shape()...)
-// }
-
-// if _, ok := c.values[c.curLayer]; !ok {
-// c.values[c.curLayer] = c.ctxs[c.curLayer].Empty(value.DType(), value.Shape()...)
-// }
-
-// ctx.Forward(
-// key.Copy(ctx, c.keys[c.curLayer]),
-// value.Copy(ctx, c.values[c.curLayer]),
-// )
-// }
-
-// func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
-// panic("encoder cache does not support multiple sequences")
-// }
-
-// func (c *EncoderCache) CanResume(seq int, pos int32) bool {
-// return true
-// }
-
-// func (c *EncoderCache) Remove(seq int, beginIndex, endIndex int32) error {
-// if c.encoderPos >= beginIndex && c.encoderPos < endIndex {
-// c.encoderCached = false
-// }
-
-// return nil
-// }
diff --git a/x/kvcache/mlx.go b/x/kvcache/mlx.go
deleted file mode 100644
index fa38651043a..00000000000
--- a/x/kvcache/mlx.go
+++ /dev/null
@@ -1,144 +0,0 @@
-//go:build mlx
-
-package kvcache
-
-import (
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/model/input"
-)
-
-// Causal cache stores K and V tensors according to their position in the
-// sequence. Returns the history and a mask for attending to past tokens
-type MLXCausal struct {
- DType ml.DType
-
- // locations for data storage for this batch
- curLocPut ml.Tensor
-
- // locations for data storage for this batch
- curLocGet ml.Tensor
-
- // the active layer for Get and Put
- curLayer int
-
- capacity int
-
- offset int
-
- backend ml.Backend
- ctxs map[int]ml.Context
- keys, values map[int]ml.Tensor
-
- // TODO is this needed per layer, or will it always be consistent?
- kHeadDims, vHeadDims, numKVHeads map[int]int
-}
-
-func NewMLXCausalCache() *MLXCausal {
- return &MLXCausal{
- ctxs: make(map[int]ml.Context),
- keys: make(map[int]ml.Tensor),
- values: make(map[int]ml.Tensor),
- kHeadDims: make(map[int]int),
- vHeadDims: make(map[int]int),
- numKVHeads: make(map[int]int),
- }
-}
-
-func (c *MLXCausal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
- c.DType = dtype
- c.capacity = capacity
- c.backend = backend
-}
-
-func (c *MLXCausal) SetConfig(config ml.CacheConfig) {}
-
-func (c *MLXCausal) SetLayer(layer int) {
- c.curLayer = layer
-}
-
-func (c *MLXCausal) Close() {
- // slog.Info("XXX MLXCausal.Close called", "number of contexts", len(c.ctxs))
- for _, ctx := range c.ctxs {
- ctx.Close()
- }
-}
-
-func (c *MLXCausal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
- locsPut := make([]int32, len(batch.Positions))
- for i := c.offset; i < len(batch.Positions); i++ {
- locsPut[i-c.offset] = int32(i)
- }
- c.offset += len(batch.Positions)
- locsGet := make([]int32, c.offset)
- for i := range c.offset {
- locsGet[i] = int32(i)
- }
- c.curLocGet = ctx.Input().FromInts(locsGet, len(locsGet))
- c.curLocPut = ctx.Input().FromInts(locsPut, len(locsPut))
- // slog.Info("XXX MLXCausal.StartForward", "offset", c.offset, "put", locsPut, "get", locsGet)
-
- return nil
-}
-func (c *MLXCausal) Put(ctx ml.Context, key, value ml.Tensor) {
- kHeadDim := key.Dim(3)
- vHeadDim := value.Dim(3)
- numKVHeads := key.Dim(1)
- batchSize := key.Dim(2)
- kCellSize := kHeadDim * numKVHeads
- vCellSize := vHeadDim * numKVHeads
- // slog.Info("XXX Causal.Put", "kHeadDim", kHeadDim, "vHeadDim", vHeadDim, "numKVHeads", numKVHeads, "batchSize", batchSize, "kCellSize", kCellSize, "vCellSize", vCellSize)
-
- if _, ok := c.ctxs[c.curLayer]; !ok {
- // slog.Info("XXX Causal.Put creating new context", "c.curLayer", c.curLayer)
- c.ctxs[c.curLayer] = c.backend.NewContext().Layer(c.curLayer)
- }
-
- if _, ok := c.keys[c.curLayer]; !ok {
- // slog.Info("XXX MLXCausal.Put allocating keys and values", "c.curLayer", c.curLayer, "shape", []int{c.capacity, kCellSize})
- c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, kCellSize)
- c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, c.capacity, vCellSize)
- c.kHeadDims[c.curLayer] = kHeadDim
- c.vHeadDims[c.curLayer] = vHeadDim
- c.numKVHeads[c.curLayer] = numKVHeads
- }
- key = key.Reshape(ctx, batchSize, 1, kCellSize)
-
- // slog.Info("XXX MLXCausal.Put ", "c.keys[c.curLayer]", c.keys[c.curLayer])
- // slog.Info("XXX MLXCausal.Put ", "c.curLocPut", c.curLocPut)
- // slog.Info("XXX MLXCausal.Put ", "key", key)
- ctx.Forward(c.keys[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, key, []int{0}))
- value = value.Reshape(ctx, batchSize, 1, vCellSize)
- ctx.Forward(c.values[c.curLayer].Scatter(ctx, []ml.Tensor{c.curLocPut}, value, []int{0}))
-
-}
-
-func (c *MLXCausal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
- key := c.keys[c.curLayer]
- value := c.values[c.curLayer]
-
- kHeadDim := c.kHeadDims[c.curLayer]
- vHeadDim := c.vHeadDims[c.curLayer]
- numKVHeads := c.numKVHeads[c.curLayer]
- // rowSize := numKVHeads * c.curBatchSize
- // cachedSize := c.curMask.Dim(1)
- cachedSize := c.curLocGet.Dim(0)
- // kCellSize := kHeadDim * numKVHeads
- // vCellSize := vHeadDim * numKVHeads
- // slog.Info("XXX MLXCausal.Get", "shape", []int{1, numKVHeads, cachedSize, kHeadDim})
-
- key = key.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, kHeadDim)
- value = value.TakeAxes(ctx, c.curLocGet, 0).Reshape(ctx, 1, numKVHeads, cachedSize, vHeadDim)
- return key, value, nil
-}
-
-func (c *MLXCausal) CopyPrefix(srcSeq, dstSeq int, len int32) {
- panic("not implemented")
-}
-
-func (c *MLXCausal) CanResume(seq int, pos int32) bool {
- panic("not implemented")
-}
-
-func (c *MLXCausal) Remove(seq int, beginIndex, endIndex int32) error {
- panic("not implemented")
-}
diff --git a/x/kvcache/wrapper.go b/x/kvcache/wrapper.go
deleted file mode 100644
index 69e07dc9620..00000000000
--- a/x/kvcache/wrapper.go
+++ /dev/null
@@ -1,110 +0,0 @@
-package kvcache
-
-// import (
-// "math"
-
-// "github.com/ollama/ollama/ml"
-// "github.com/ollama/ollama/model/input"
-// )
-
-// // Wrapper cache is a container for multiple types of caches,
-// // such as for the encoding and decoding portions of a model.
-// type WrapperCache struct {
-// // caches we are wrapping
-// caches []Cache
-
-// // cache to be used for this layer
-// curType int
-// }
-
-// func NewWrapperCache(caches ...Cache) *WrapperCache {
-// return &WrapperCache{
-// caches: caches,
-// }
-// }
-
-// func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
-// for _, cache := range c.caches {
-// cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
-// }
-// }
-
-// func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
-// for _, cache := range c.caches {
-// cache.SetConfig(config)
-// }
-// }
-
-// func (c *WrapperCache) Close() {
-// for _, cache := range c.caches {
-// cache.Close()
-// }
-// }
-
-// func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
-// for i, cache := range c.caches {
-// err := cache.StartForward(ctx, batch, reserve)
-// if err != nil {
-// // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
-// for j := i - 1; j >= 0; j-- {
-// for k := range batch.Positions {
-// _ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
-// }
-// }
-// return err
-// }
-// }
-
-// c.curType = 0
-// return nil
-// }
-
-// func (c *WrapperCache) SetLayer(layer int) {
-// for _, cache := range c.caches {
-// cache.SetLayer(layer)
-// }
-// }
-
-// func (c *WrapperCache) SetLayerType(layerType int) {
-// c.curType = layerType
-// }
-
-// func (c *WrapperCache) UnderlyingCache() Cache {
-// return c.caches[c.curType]
-// }
-
-// func (c *WrapperCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
-// return c.caches[c.curType].Get(ctx)
-// }
-
-// func (c *WrapperCache) Put(ctx ml.Context, key, value ml.Tensor) {
-// c.caches[c.curType].Put(ctx, key, value)
-// }
-
-// func (c *WrapperCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
-// for _, cache := range c.caches {
-// cache.CopyPrefix(srcSeq, dstSeq, len)
-// }
-// }
-
-// func (c *WrapperCache) CanResume(seq int, pos int32) bool {
-// for _, cache := range c.caches {
-// if !cache.CanResume(seq, pos) {
-// return false
-// }
-// }
-
-// return true
-// }
-
-// func (c *WrapperCache) Remove(seq int, beginIndex, endIndex int32) error {
-// // If the one of these fails, the caller is supposed to retry with endIndex set to math.MaxInt32, which should not fail
-// for _, cache := range c.caches {
-// err := cache.Remove(seq, beginIndex, endIndex)
-// if err != nil {
-// return err
-// }
-// }
-
-// return nil
-// }
diff --git a/x/ml/backend.go b/x/ml/backend.go
deleted file mode 100644
index 31ff3541e76..00000000000
--- a/x/ml/backend.go
+++ /dev/null
@@ -1,433 +0,0 @@
-package ml
-
-import (
- "fmt"
- "log/slog"
- "os"
-
- "github.com/ollama/ollama/fs"
-)
-
-type Backend interface {
- // Close frees all memory associated with this backend
- // Close()
-
- // Load(ctx context.Context, progress func(float32)) error
-
- // BackendMemory returns the memory allocations that were made for this model
- // BackendMemory() BackendMemory
-
- Config() fs.Config
- Get(name string) Tensor
- NewContext() Context
- // NewContextSize(size int) Context
-
- // Enumerate the devices available for inference via this backend
- // BackendDevices() []DeviceInfo
-}
-
-// BackendCacheConfig should be implemented by backends that need special output
-// from the cache to meet specific requirements. It is frequently implemented in
-// conjunction with ScaledDotProductAttention.
-type BackendCacheConfig interface {
- CacheConfig() CacheConfig
-}
-
-// CacheConfig controls optimizations (mostly backend-specific) that may transform
-// the output the cache to work better with specific kernels.
-type CacheConfig struct {
- // CachePadding specifies the multiple for the number of tokens of cache history
- // that will be returned from cache Get for k, v and mask. The capacity of the
- // cache itself will also be increased to a multiple of this size if needed.
- CachePadding int
-
- // PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
- // and return the permuted version via Get. This uses the cache copy operation
- // to avoid a Contiguous call on the permuted tensor.
- PermutedV bool
-
- // MaskDType specifies the data type for generating the mask. If unset it will
- // default to DTypeF32.
- MaskDType DType
-
- // MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
- // Any position that does not correspond to an actual token will be filled with -Inf.
- MaskBatchPadding int
-}
-
-// BackendParams controls how the backend loads and executes models
-type BackendParams struct {
- // AllocMemory causes the backend to allocate memory for the model. If
- // false, this is only being used for discovering the required amount of
- // memory and cannot load the model for running.
- AllocMemory bool
-
- // NumThreads sets the number of threads to use if running on the CPU
- NumThreads int
-
- // GPULayers is the set of layers to offload to GPUs
- GPULayers GPULayersList
-
- // FlashAttention indicates that we should use a fused flash attention kernel
- FlashAttention bool
-}
-
-var backends = make(map[string]func(string, BackendParams) (Backend, error))
-
-func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
- if _, ok := backends[name]; ok {
- panic("backend: backend already registered")
- }
-
- backends[name] = f
-}
-
-func NewBackend(modelPath string, params BackendParams) (Backend, error) {
- be := os.Getenv("OLLAMA_BACKEND")
- if be == "" {
- be = "mlx"
- slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
- }
- slog.Info("Loading new engine", "backend", be)
- if backend, ok := backends[be]; ok {
- return backend(modelPath, params)
- }
-
- return nil, fmt.Errorf("unsupported backend")
-}
-
-type Context interface {
- Empty(dtype DType, shape ...int) Tensor
- Zeros(dtype DType, shape ...int) Tensor
- // FromBytes(dtype DType, s []byte, shape ...int) Tensor
- FromFloats(s []float32, shape ...int) Tensor
- FromInts(s []int32, shape ...int) Tensor
- RandomNormal(shape []int, dtype DType, loc, scale float32, key Tensor) Tensor
-
- // Arange creates a 1D tensor with values within an interval (start, stop] increased by step.
- Arange(start, stop, step float32, dtype DType) Tensor
-
- Forward(...Tensor) Context
-
- // SetBatchSize provides a hint on the batch size to optimize processing
- // Uses heuristics if not set
- // SetBatchSize(int)
-
- Compute(...Tensor)
- // ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
-
- // Reserve is analogous to Compute but rather than executing a
- // graph, simply preallocates memory. Typically called with a
- // worst case graph to ensure all resources are available for
- // for future inference.
- // Reserve()
-
- // MaxGraphNodes() int
- Close()
-
- // Input returns a context appropriate for creating tensors that are
- // inputs to the model (which includes things like output locations)
- Input() Context
-
- // Layer returns a context appropriate for creating intermediate tensors
- Layer(int) Context
-
- // Load a tensor from "filename" safetensors file, and compare with the input tensor
- // Returns error if the shape is inconsistent, or similarity measures are below 99%
- CompareWith(filename string, tensors map[string]Tensor, abortOnError bool) error
-}
-
-type RoPEOptions struct {
- Base *float32
- Freqs Tensor
-}
-
-func WithRoPEBase(base float32) func(*RoPEOptions) {
- return func(opts *RoPEOptions) {
- opts.Base = &base
- }
-}
-
-func WithRoPEFreqs(freqs Tensor) func(*RoPEOptions) {
- return func(opts *RoPEOptions) {
- opts.Freqs = freqs
- }
-}
-
-type Tensor interface {
- ToString() string
- RoPE(ctx Context, dims int, traditional bool, scale float32, offset int, options ...func(*RoPEOptions)) Tensor
- ScaledDotProductAttention(ctx Context, keys, values Tensor, scale float64, maskMode string, mask Tensor, sinks Tensor) Tensor
- TakeAxes(ctx Context, indicies Tensor, axes int) Tensor
- // TakeAxes(ctx Context, axes int, indicies ...int) Tensor
-
- Dim(n int) int
- Stride(n int) int
-
- Shape() []int
- DType() DType
- // Cast(ctx Context, dtype DType) Tensor
-
- // Bytes() []byte
- Floats() []float32
- Ints() []int32
-
- // FromBytes([]byte)
- // FromFloats([]float32)
- // FromInts([]int32)
-
- Add(ctx Context, t2 Tensor) Tensor
- Sub(ctx Context, t2 Tensor) Tensor
- // Mul(ctx Context, t2 Tensor) Tensor
- // Div(ctx Context, t2 Tensor) Tensor
-
- Max(ctx Context, axes []int, keepDims bool) Tensor
- Min(ctx Context, axes []int, keepDims bool) Tensor
-
- Matmul(ctx Context, a2 Tensor) Tensor
- // Mulmat(ctx Context, t2 Tensor) Tensor
- // MulmatFullPrec(ctx Context, t2 Tensor) Tensor
- // MulmatID(ctx Context, t2, ids Tensor) Tensor
- // AddID(ctx Context, t2, ids Tensor) Tensor
-
- Softmax(ctx Context) Tensor
- L2Norm(ctx Context, eps float32) Tensor
- LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
- RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
- Scale(ctx Context, s float64) Tensor
- // SumRows(ctx Context) Tensor
-
- AvgPool2D(ctx Context, k, s int, p float32) Tensor
- Conv2D(ctx Context, weight Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) Tensor
- Conv3D(ctx Context, weight Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) Tensor
-
- // IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
-
- // Sin(ctx Context) Tensor
- // Cos(ctx Context) Tensor
- // Tanh(ctx Context) Tensor
- GELU(ctx Context, up ...Tensor) Tensor
- // QuickGELU(ctx Context, up ...Tensor) Tensor
- // SILU(ctx Context, up ...Tensor) Tensor
- // RELU(ctx Context, up ...Tensor) Tensor
- // Sigmoid(ctx Context) Tensor
-
- // AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
- // SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
-
- Reshape(ctx Context, shape ...int) Tensor
- AsStrided(ctx Context, shape, strides []int, offset int) Tensor
- Transpose(ctx Context, shape ...int) Tensor
- Contiguous(ctx Context, allowColMajor bool) Tensor
-
- // Pad(ctx Context, shape ...int) Tensor
-
- // Stack(ctx Context, dim int, s ...Tensor) Tensor
-
- // Repeat repeats the tensor n times along dimension dim
- // Repeat(ctx Context, dim, n int) Tensor
- // Concat(ctx Context, t2 Tensor, dim int) Tensor
- // Rows(ctx Context, t2 Tensor) Tensor
-
- // TODO these probably aren't actually needed - false starts on trying to wire up cache
- // SliceUpdate(ctx Context, update Tensor, start, stop, strides []int) Tensor
- // SliceUpdateDynamic(ctx Context, update, start Tensor, axes []int) Tensor
- // PutAlongAxis(ctx Context, indicies, values Tensor, axis int) Tensor
-
- Scatter(ctx Context, indicies []Tensor, updates Tensor, axes []int) Tensor
-
- Copy(ctx Context, t2 Tensor) Tensor
- // Duplicate(ctx Context) Tensor
-
- // Slice(ctx Context, dim, low, high, step int) Tensor
- // Chunk(ctx Context, dim int, size int) []Tensor
- // ChunkSections(ctx Context, dim int, sections ...int) []Tensor
-
- // TopK(ctx Context, k int) Tensor
- // Argsort(ctx Context) Tensor
- // Mean(ctx Context) Tensor
- // Variance(ctx Context) Tensor
- // Stddev(ctx Context) Tensor
- // Sqr(ctx Context) Tensor
- // Sqrt(ctx Context) Tensor
-
- // Interpolate(ctx Context, dims [4]int, samplingMode SamplingMode) Tensor
-}
-
-// ScaledDotProductAttention implements a fused attention
-// operation equivalent to following code on a tensor named
-// query:
-//
-// query = query.Permute(ctx, 0, 2, 1, 3)
-// key = key.Permute(ctx, 0, 2, 1, 3)
-// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
-//
-// kq := key.MulmatFullPrec(ctx, query)
-//
-// kq = kq.Scale(ctx, scale)
-//
-// if mask != nil {
-// kq = kq.Add(ctx, mask)
-// }
-//
-// kq = kq.Softmax(ctx)
-//
-// kqv := value.Mulmat(ctx, kq)
-// return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-// type ScaledDotProductAttention interface {
-// ScaledDotProductAttention(ctx Context, key, value, mask, sinks Tensor, vmla Tensor, scale float64) Tensor
-// }
-
-// type number interface {
-// ~int | ~int8 | ~int16 | ~int32 | ~int64 |
-// ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
-// ~float32 | ~float64 |
-// ~complex64 | ~complex128
-// }
-
-// func mul[T number](s ...T) T {
-// p := T(1)
-// for _, v := range s {
-// p *= v
-// }
-
-// return p
-// }
-
-// type DumpOptions func(*dumpOptions)
-
-// // DumpWithPrecision sets the number of decimal places to print. Applies to float32 and float64.
-// func DumpWithPrecision(n int) DumpOptions {
-// return func(opts *dumpOptions) {
-// opts.Precision = n
-// }
-// }
-
-// // DumpWithThreshold sets the threshold for printing the entire tensor. If the number of elements
-// // is less than or equal to this value, the entire tensor will be printed. Otherwise, only the
-// // beginning and end of each dimension will be printed.
-// func DumpWithThreshold(n int) DumpOptions {
-// return func(opts *dumpOptions) {
-// opts.Threshold = n
-// }
-// }
-
-// // DumpWithEdgeItems sets the number of elements to print at the beginning and end of each dimension.
-// func DumpWithEdgeItems(n int) DumpOptions {
-// return func(opts *dumpOptions) {
-// opts.EdgeItems = n
-// }
-// }
-
-// type dumpOptions struct {
-// Precision, Threshold, EdgeItems int
-// }
-
-// func Dump(ctx Context, t Tensor, optsFuncs ...DumpOptions) string {
-// opts := dumpOptions{Precision: 4, Threshold: 1000, EdgeItems: 3}
-// for _, optsFunc := range optsFuncs {
-// optsFunc(&opts)
-// }
-
-// if mul(t.Shape()...) <= opts.Threshold {
-// opts.EdgeItems = math.MaxInt
-// }
-
-// switch t.DType() {
-// case DTypeFloat32:
-// return dump[[]float32](ctx, t, opts.EdgeItems, func(f float32) string {
-// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
-// })
-// case DTypeFloat16: // TODO other types...
-// f32 := ctx.Input().Empty(DTypeFloat32, t.Shape()...)
-// f32 = t.Copy(ctx, f32)
-// return dump[[]float32](ctx, f32, opts.EdgeItems, func(f float32) string {
-// return strconv.FormatFloat(float64(f), 'f', opts.Precision, 32)
-// })
-// case DTypeInt32:
-// return dump[[]int32](ctx, t, opts.EdgeItems, func(i int32) string {
-// return strconv.FormatInt(int64(i), 10)
-// })
-// default:
-// return ""
-// }
-// }
-
-// func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
-// if t.Bytes() == nil {
-// ctx.Compute(t)
-// }
-
-// s := make(S, mul(t.Shape()...))
-// if err := binary.Read(bytes.NewBuffer(t.Bytes()), binary.LittleEndian, &s); err != nil {
-// panic(err)
-// }
-
-// shape := t.Shape()
-// slices.Reverse(shape)
-
-// var sb strings.Builder
-// var f func([]int, int)
-// f = func(dims []int, stride int) {
-// prefix := strings.Repeat(" ", len(shape)-len(dims)+1)
-// sb.WriteString("[")
-// defer func() { sb.WriteString("]") }()
-// for i := 0; i < dims[0]; i++ {
-// if i >= items && i < dims[0]-items {
-// sb.WriteString("..., ")
-// // skip to next printable element
-// skip := dims[0] - 2*items
-// if len(dims) > 1 {
-// stride += mul(append(dims[1:], skip)...)
-// fmt.Fprint(&sb, strings.Repeat("\n", len(dims)-1), prefix)
-// }
-// i += skip - 1
-// } else if len(dims) > 1 {
-// f(dims[1:], stride)
-// stride += mul(dims[1:]...)
-// if i < dims[0]-1 {
-// fmt.Fprint(&sb, ",", strings.Repeat("\n", len(dims)-1), prefix)
-// }
-// } else {
-// text := fn(s[stride+i])
-// if len(text) > 0 && text[0] != '-' {
-// sb.WriteString(" ")
-// }
-
-// sb.WriteString(text)
-// if i < dims[0]-1 {
-// sb.WriteString(", ")
-// }
-// }
-// }
-// }
-// f(shape, 0)
-
-// return sb.String()
-// }
-
-type DType int
-
-const (
- DTypeBool DType = iota
- DTypeUint8
- DTypeUint16
- DTypeUint32
- DTypeUint64
- DTypeInt8
- DTypeInt16
- DTypeInt32
- DTypeInt64
- DTypeFloat16
- DTypeFloat32
- DTypeFloat64
- DTypeBfloat16
- DTypeComplex64
-)
-
-type SamplingMode int
-
-const (
- SamplingModeNearest SamplingMode = iota
- SamplingModeBilinear
-)
diff --git a/x/ml/backend/backend.go b/x/ml/backend/backend.go
deleted file mode 100644
index b9dd4a13bfe..00000000000
--- a/x/ml/backend/backend.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package backend
-
-// _ "github.com/ollama/ollama/x/ml/backend/mlx"
diff --git a/x/ml/backend/mlx/mlx.go b/x/ml/backend/mlx/mlx.go
deleted file mode 100644
index 1b647685e27..00000000000
--- a/x/ml/backend/mlx/mlx.go
+++ /dev/null
@@ -1,1278 +0,0 @@
-//go:build mlx
-
-package mlx
-
-/*
-#cgo CPPFLAGS: -I${SRCDIR}/../../../../build/_deps/mlx-c-src
-#cgo LDFLAGS: -L${SRCDIR}/../../../../build/lib/ollama/ -lmlxc -lmlx
-#cgo LDFLAGS: -framework Accelerate
-#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../../build/lib/ollama/
-#include
-#include "mlx/c/mlx.h"
-static inline size_t stride(const mlx_array a, int i) {return mlx_array_strides(a)[i];}
-
-extern void goStackTrace();
-static void error_handler(const char *msg, void* data) {
- fprintf(stderr, "MLX error: %s\n", msg);
- goStackTrace();
- exit(-1); // TODO adjust so this can become a return code on the current thread instead of exit
-}
-static void set_error_handler() {mlx_set_error_handler(&error_handler, NULL, NULL);}
-static void* mlx_array_data_float16_asvoid(const mlx_array a) {return (void*)mlx_array_data_float16(a);}
-typedef const char cchar_t;
-*/
-import "C"
-
-import (
- "encoding/json"
- "fmt"
- "log/slog"
- "math"
- "os"
- "path/filepath"
- "reflect"
- "runtime"
- "runtime/debug"
- "sync"
- "unsafe"
-
- "github.com/ollama/ollama/convert"
- "github.com/ollama/ollama/fs"
- "github.com/ollama/ollama/x/ml"
- "github.com/x448/float16"
-)
-
-func init() {
- ml.RegisterBackend("mlx", New)
- C.set_error_handler()
-}
-
-//export goStackTrace
-func goStackTrace() {
- debug.PrintStack()
-}
-
-type SafetensorsIndexMetadata struct {
- TotalSize uint64 `json:"total_size"`
-}
-type SafetensorsIndex struct {
- Metadata SafetensorsIndexMetadata `json:"metadata"`
- WeightMap map[string]string `json:"weight_map"`
-}
-
-type Backend struct {
- meta fs.Config
- tensors map[string]*Array
-}
-
-func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
- // TODO assumes modelPath is actually a directory for now...
- kv, tokenizer, err := convert.LoadModelMetadata(os.DirFS(modelPath))
- if err != nil {
- return nil, fmt.Errorf("unable to load model: %w", err)
- }
-
- b := &Backend{
- meta: kv.KV(tokenizer),
- }
-
- err = b.LoadSafeTensors(modelPath)
- if err != nil {
- return nil, fmt.Errorf("safetensors load failed: %w", err)
- }
- return b, nil
-}
-
-func (b *Backend) LoadSafeTensors(dir string) error {
- if _, err := os.Stat(dir); err != nil {
- return fmt.Errorf("failed to stat dir: %w", err)
- }
- // other variations to try?
- stFilename := filepath.Join(dir, "model.safetensors.index.json")
- if _, err := os.Stat(stFilename); err != nil {
- return fmt.Errorf("failed to stat %s: %w", stFilename, err)
- }
-
- fp, err := os.Open(stFilename)
- if err != nil {
- return fmt.Errorf("failed to open safetensor index: %s: %w", stFilename, err)
- }
- decoder := json.NewDecoder(fp)
- var index SafetensorsIndex
- if err := decoder.Decode(&index); err != nil {
- return fmt.Errorf("decode error: %s: %w", stFilename, err)
- }
- slog.Info("XXX parsed metadata", "size", index.Metadata.TotalSize, "weights", len(index.WeightMap))
- filenames := map[string]struct{}{}
- for _, filename := range index.WeightMap {
- filenames[filename] = struct{}{}
- }
- stream := C.mlx_default_cpu_stream_new()
-
- b.tensors = map[string]*Array{}
-
- for filename := range filenames {
- filepath := filepath.Join(dir, filename)
- if _, err := os.Stat(filepath); err != nil {
- return fmt.Errorf("failed to stat %s: %w", filepath, err)
- }
- slog.Info("Loading tensors from", "filename", filename)
- cFilename := C.CString(filepath)
- defer C.free(unsafe.Pointer(cFilename))
- data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it?
- metadata := C.mlx_map_string_to_string_new()
- defer C.mlx_map_string_to_array_free(data)
- defer C.mlx_map_string_to_string_free(metadata)
-
- if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 {
- // TODO with the current error handling, this will never happen
- return fmt.Errorf("load failed")
- }
-
- it := C.mlx_map_string_to_array_iterator_new(data)
- // defer C.mlx_array_free(shaped)
- // TODO confusing how memory management works with this...
- for {
- var key *C.cchar_t
- var value C.mlx_array
- if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
- break
- }
- k := C.GoString((*C.char)(key))
- b.tensors[k] = &Array{
- name: k,
- a: value,
- }
- // slog.Info("XXX read", "tensor", b.tensors[k], "type", b.tensors[k].TypeString())
- }
- }
-
- return nil
-}
-
-func (b *Backend) Get(name string) ml.Tensor {
- var t ml.Tensor
- var ok bool
- if t, ok = b.tensors[name]; !ok {
- // slog.Warn("unable to locate", "tensor", name)
- return nil
- }
- // slog.Info("Fetching", "tensor", name, "type", b.tensors[name].TypeString())
- return t
-}
-
-func (b *Backend) NewContext() ml.Context {
- // slog.Info("MLX.NewContext")
- return &Context{
- stream: C.mlx_default_gpu_stream_new(),
- }
-}
-
-func (b *Backend) Config() fs.Config {
- return b.meta
-}
-
-type Context struct {
- stream C.mlx_stream
-
- mu sync.Mutex
- arrays []C.mlx_array // TODO should we do some bookkeeping to ensure none of these Arrays are still lingering?
-}
-
-func (c *Context) Close() {
- // C.mlx_synchronize(c.stream) // ???
- C.mlx_stream_free(c.stream)
-
- c.mu.Lock()
- defer c.mu.Unlock()
- for _, a := range c.arrays {
- slog.Info("XXX freeing", "array", a)
- C.mlx_array_free(a)
- }
-}
-
-func (c *Context) Compute(tensors ...ml.Tensor) {
- // TODO - for the zero tensor case this feels like it might not be correct...
- needSync := true
- sync := func() {
- if needSync {
- C.mlx_synchronize(c.stream)
- needSync = false
- }
- }
-
- vec := C.mlx_vector_array_new()
- defer C.mlx_vector_array_free(vec)
- for _, t := range tensors {
- C.mlx_vector_array_append_value(vec, t.(*Array).a)
- t.(*Array).sync = sync
- }
- C.mlx_async_eval(vec)
-}
-
-func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
- vec := C.mlx_vector_array_new()
- defer C.mlx_vector_array_free(vec)
- needSync := true
- sync := func() {
- if needSync {
- C.mlx_synchronize(c.stream)
- needSync = false
- }
- }
-
- for _, t := range tensors {
- t.(*Array).sync = sync
- C.mlx_vector_array_append_value(vec, t.(*Array).a)
- }
- C.mlx_async_eval(vec)
- return c
-}
-
-func (c *Context) Input() ml.Context {
- return c
-}
-
-// func (c *Context) Output() ml.Context {
-// return c
-// }
-
-func (c *Context) Layer(_ int) ml.Context {
- return c
-}
-
-func (c *Context) RandomNormal(shape []int, dtype ml.DType, loc, scale float32, key ml.Tensor) ml.Tensor {
- var r C.mlx_array
- var k C.mlx_array
- if key != nil {
- k = key.(*Array).a
- }
- sh := make([]C.int, len(shape))
- for i := range shape {
- sh[i] = C.int(shape[i])
- }
- C.mlx_random_normal(
- &r,
- &sh[0],
- C.size_t(len(shape)),
- C.mlx_dtype(dtype),
- C.float(loc),
- C.float(scale),
- k,
- c.stream,
- )
- return newArray(c, r)
-}
-
-func (c *Context) CompareWith(filepath string, tensors map[string]ml.Tensor, abortOnError bool) (err error) {
- minCosine := float32(0.96) // TODO too low...
- fileTensors := map[string]*Array{}
- defer func() {
- if err != nil {
- for k, v := range tensors {
- fmt.Fprintln(os.Stderr, "input tensor "+k+"\n"+v.ToString())
- if fv, ok := fileTensors[k]; ok {
- fmt.Fprintln(os.Stderr, " file tensor "+k+"\n"+fv.ToString())
- } else {
- fmt.Fprintln(os.Stderr, " file tensor "+k+" missing!\n")
- }
- }
- }
- if abortOnError {
- if err != nil {
- panic(fmt.Sprintf("%s", err))
- }
- }
- }()
- if _, err = os.Stat(filepath); err != nil {
- filepath += ".safetensors"
- if _, err = os.Stat(filepath); err != nil {
- err = fmt.Errorf("failed to stat %s: %w", filepath, err)
- return
- }
- err = nil
- }
- // slog.Info("Loading tensors from", "filename", filepath)
- cFilename := C.CString(filepath)
- defer C.free(unsafe.Pointer(cFilename))
- data := C.mlx_map_string_to_array_new() // TODO is this needed or just var it?
- metadata := C.mlx_map_string_to_string_new()
- defer C.mlx_map_string_to_array_free(data)
- defer C.mlx_map_string_to_string_free(metadata)
-
- stream := C.mlx_default_cpu_stream_new()
-
- if C.mlx_load_safetensors(&data, &metadata, cFilename, stream) != 0 {
- // TODO with the current error handling, this will never happen
- err = fmt.Errorf("load failed")
- return
- }
-
- it := C.mlx_map_string_to_array_iterator_new(data)
- allTensors := []ml.Tensor{}
- for _, t := range tensors {
- allTensors = append(allTensors, t)
- }
-
- for {
- var key *C.cchar_t
- var value C.mlx_array
- defer C.mlx_array_free(value)
- if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
- break
- }
- k := C.GoString((*C.char)(key))
- var r C.mlx_array
- defer C.mlx_array_free(r)
- C.mlx_astype(
- &r,
- value,
- C.MLX_FLOAT32,
- stream,
- )
-
- fileTensors[k] = &Array{
- name: k,
- a: r,
- }
- // slog.Info("XXX read", "tensor", t, "type", t.TypeString())
- allTensors = append(allTensors, fileTensors[k])
- }
- c.Forward(allTensors...)
- for k, t := range tensors {
- a, ok := fileTensors[k]
- if !ok {
- err = fmt.Errorf("tensor named %s not found in file", k)
- return
- }
- if !reflect.DeepEqual(a.Shape(), t.Shape()) {
- err = fmt.Errorf("mismatched shapes: file: %v vs. input %v", a.Shape(), t.Shape())
- return
- }
- // slog.Info("XXX shapes match", "shape", t.Shape())
- // TODO handle int types...
- tDType := t.DType()
- if tDType != ml.DTypeFloat16 && tDType != ml.DTypeFloat32 {
- var r C.mlx_array
- defer C.mlx_array_free(r)
- C.mlx_astype(
- &r,
- t.(*Array).a,
- C.MLX_FLOAT32,
- stream,
- )
- t = &Array{
- a: r,
- }
- c.Forward(t)
- }
-
- af := a.Floats()
- tf := t.Floats()
- cos := cosineSimilarity(af, tf)
- diff := a.Sub(c, t)
- min := diff.Min(c, nil, true)
- max := diff.Max(c, nil, true)
- c.Forward(min, max)
- minf := min.Floats()
- maxf := max.Floats()
- if cos < minCosine {
- err = fmt.Errorf("%s shapes match, but not similar enough: %v min_difference=%v max_difference=%v", k, cos, minf, maxf)
- return
- }
-
- slog.Info("XXX tensors are similar", k, cos, "shape", t.Shape(), "min_difference", minf, "max_difference", maxf)
- }
- err = nil
-
- return
-}
-
-func dotProduct[V float32 | float64](v1, v2 []V) V {
- var result V = 0
- if len(v1) != len(v2) {
- return result
- }
-
- for i := 0; i < len(v1); i++ {
- result += v1[i] * v2[i]
- }
- return result
-}
-
-func magnitude[V float32 | float64](v []V) V {
- var result V = 0
- for _, val := range v {
- result += val * val
- }
- return V(math.Sqrt(float64(result)))
-}
-
-func cosineSimilarity[V float32 | float64](v1, v2 []V) V {
- mag1 := magnitude(v1)
- mag2 := magnitude(v2)
-
- if mag1 == 0 || mag2 == 0 {
- return 0
- }
-
- return dotProduct(v1, v2) / (magnitude(v1) * magnitude(v2))
-}
-
-func euclideanDistance[V float32 | float64](v1, v2 []V) V {
- if len(v1) != len(v2) {
- return V(math.Inf(1))
- }
-
- var sum V = 0
- for i := 0; i < len(v1); i++ {
- diff := v1[i] - v2[i]
- sum += diff * diff
- }
-
- return V(math.Sqrt(float64(sum)))
-}
-
-func manhattanDistance[V float32 | float64](v1, v2 []V) V {
- if len(v1) != len(v2) {
- return V(math.Inf(1))
- }
-
- var sum V = 0
- for i := 0; i < len(v1); i++ {
- sum += V(math.Abs(float64(v1[i] - v2[i])))
- }
-
- return sum
-}
-
-type Array struct {
- name string
- a C.mlx_array
- c *Context
-
- sync func()
-}
-
-func newArray(ctx *Context, a C.mlx_array) *Array {
- // TODO measure impact and if this slows things down, make it conditional on some debugging flag at load time
- var name string
- _, f, l, ok := runtime.Caller(2)
- if ok {
- name = fmt.Sprintf("%s:%d", f, l)
- }
-
- t := &Array{
- name: name,
- a: a,
- c: ctx,
- }
- // DEBUG memory allocation problems...
- // slog.Info("XXX Allocated", "array", t, "a", a)
- ctx.mu.Lock()
- defer ctx.mu.Unlock()
- ctx.arrays = append(ctx.arrays, a)
- return t
-}
-
-// FromFloats implements ml.Context.
-func (c *Context) FromFloats(s []float32, shape ...int) ml.Tensor {
- u16s := make([]float16.Float16, len(s))
- for i := range u16s {
- u16s[i] = float16.Fromfloat32(s[i])
- }
- cshape := make([]C.int, len(shape))
- for i, dim := range shape {
- cshape[i] = C.int(dim)
- }
- return newArray(c,
- C.mlx_array_new_data(
- unsafe.Pointer(&u16s[0]),
- &cshape[0],
- C.int(len(cshape)),
- C.MLX_FLOAT16,
- ),
- )
-}
-
-func (a *Array) Floats() []float32 {
- if a.sync != nil {
- a.sync()
- }
- l := (int)(C.mlx_array_size(a.a))
-
- switch C.mlx_array_dtype(a.a) {
- case C.MLX_BFLOAT16:
- panic("bfloat16 not yet implemented")
- case C.MLX_FLOAT16:
- data := C.mlx_array_data_float16_asvoid(a.a)
- if data == nil {
- panic("nil data, wasn't eval'd")
- }
- u16s := unsafe.Slice((*uint16)(data), l)
- f32s := make([]float32, len(u16s))
- for i := range u16s {
- f32s[i] = float16.Frombits(u16s[i]).Float32()
- }
- return f32s
- case C.MLX_FLOAT32:
- data := C.mlx_array_data_float32(a.a)
- if data == nil {
- panic("nil data, wasn't eval'd")
- }
- f32s := unsafe.Slice((*float32)(data), l)
- return f32s
- default:
- panic(fmt.Sprintf("unsupported dtype for Floats: %d", C.mlx_array_dtype(a.a)))
- }
-}
-
-// FromInts implements ml.Context.
-func (c *Context) FromInts(s []int32, shape ...int) ml.Tensor {
- cshape := make([]C.int, len(shape))
- for i, dim := range shape {
- cshape[i] = C.int(dim)
- }
- return newArray(c,
- C.mlx_array_new_data(
- unsafe.Pointer(&s[0]),
- &cshape[0],
- C.int(len(cshape)),
- C.MLX_INT32,
- ),
- )
-}
-
-func (a *Array) Ints() []int32 {
- if a.sync != nil {
- a.sync()
- }
- l := (int)(C.mlx_array_size(a.a))
-
- switch C.mlx_array_dtype(a.a) {
- case C.MLX_INT32:
- data := C.mlx_array_data_int32(a.a)
- if data == nil {
- panic("nil data, wasn't eval'd")
- }
- i32s := unsafe.Slice((*int32)(data), l)
- return i32s
-
- // TODO other types via conversion?
- default:
- panic(fmt.Sprintf("unsupported dtype for Ints: %d", C.mlx_array_dtype(a.a)))
- }
-}
-
-func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
- sh := make([]C.int, len(shape))
- for i, s := range shape {
- sh[i] = (C.int)(s)
- }
-
- var r C.mlx_array
- C.mlx_zeros(
- &r,
- &sh[0],
- (C.size_t)(len(sh)),
- C.mlx_dtype(dtype),
- c.stream,
- )
- return newArray(c, r)
-}
-
-func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
- // TODO more efficient impl?
- return c.Zeros(dtype, shape...)
-}
-
-func (a *Array) DType() ml.DType {
- return (ml.DType)(C.mlx_array_dtype(a.a))
-}
-
-func (a *Array) Dim(n int) int {
- return int(C.mlx_array_dim(a.a, C.int(n)))
-}
-
-func (a *Array) Stride(n int) int {
- return (int)(C.stride(a.a, (C.int)(n)))
-}
-
-func (c *Context) Arange(start, stop, step float32, dtype ml.DType) ml.Tensor {
- var r C.mlx_array
- C.mlx_arange(
- &r,
- C.double(start),
- C.double(stop),
- C.double(step),
- (C.mlx_dtype)(dtype),
- c.stream,
- )
-
- return newArray(c, r)
-}
-
-// Scale implements ml.Tensor.
-func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor {
- scale := C.mlx_array_new_float(C.float(s))
- var r C.mlx_array
- C.mlx_multiply(
- &r,
- a.a,
- scale,
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Softmax(ctx ml.Context) ml.Tensor {
- var r C.mlx_array
- C.mlx_softmax(
- &r,
- a.a,
- false, // TODO - precise?
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) SliceUpdate(ctx ml.Context, update ml.Tensor, start, stop, strides []int) ml.Tensor {
- cStart := make([]C.int, len(start))
- for i := range start {
- cStart[i] = C.int(start[i])
- }
- cStop := make([]C.int, len(stop))
- for i := range stop {
- cStop[i] = C.int(stop[i])
- }
- cStrides := make([]C.int, len(strides))
- for i := range strides {
- cStrides[i] = C.int(strides[i])
- }
- var r C.mlx_array
- C.mlx_slice_update(
- &r,
- a.a,
- update.(*Array).a,
- (*C.int)(unsafe.Pointer(&cStart[0])),
- C.size_t(len(cStart)),
- (*C.int)(unsafe.Pointer(&cStop[0])),
- C.size_t(len(cStop)),
- (*C.int)(unsafe.Pointer(&cStrides[0])),
- C.size_t(len(cStrides)),
- ctx.(*Context).stream,
- )
- // Release the old array and replace with the new one to ensure the same underlying buffer is used
- a.c.mu.Lock()
- defer a.c.mu.Unlock()
- for i := range a.c.arrays {
- if a.c.arrays[i] == a.a {
- C.mlx_array_free(a.a)
- a.a = r
- a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
- return a
- }
- }
- panic("unable to locate array in context")
-}
-
-func (a *Array) SliceUpdateDynamic(ctx ml.Context, update, start ml.Tensor, axes []int) ml.Tensor {
- cAxes := make([]C.int, len(axes))
- for i := range axes {
- cAxes[i] = C.int(axes[i])
- }
-
- var r C.mlx_array
- C.mlx_slice_update_dynamic(
- &r,
- a.a,
- update.(*Array).a,
- start.(*Array).a,
- (*C.int)(unsafe.Pointer(&cAxes[0])),
- C.size_t(len(cAxes)),
- ctx.(*Context).stream,
- )
- // Release the old array and replace with the new one to ensure the same underlying buffer is used
- a.c.mu.Lock()
- defer a.c.mu.Unlock()
- for i := range a.c.arrays {
- if a.c.arrays[i] == a.a {
- C.mlx_array_free(a.a)
- a.a = r
- a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
- return a
- }
- }
- panic("unable to locate array in context")
-
-}
-
-func (a *Array) PutAlongAxis(ctx ml.Context, indicies, values ml.Tensor, axis int) ml.Tensor {
- var r C.mlx_array
- C.mlx_put_along_axis(
- &r,
- a.a,
- indicies.(*Array).a,
- values.(*Array).a,
- C.int(axis),
- ctx.(*Context).stream,
- )
- // Release the old array and replace with the new one to ensure the same underlying buffer is used
- a.c.mu.Lock()
- defer a.c.mu.Unlock()
- for i := range a.c.arrays {
- if a.c.arrays[i] == a.a {
- C.mlx_array_free(a.a)
- a.a = r
- a.c.arrays = append(a.c.arrays[:i], a.c.arrays[i+1:]...)
- return a
- }
- }
- panic("unable to locate array in context")
-}
-
-func (a *Array) Scatter(ctx ml.Context, indicies []ml.Tensor, updates ml.Tensor, axes []int) ml.Tensor {
-
- cAxes := make([]C.int, len(axes))
- for i := range axes {
- cAxes[i] = C.int(axes[i])
- }
- var cAxes0 *C.int
- if len(cAxes) > 0 {
- cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
- }
- indiciesVec := C.mlx_vector_array_new()
- defer C.mlx_vector_array_free(indiciesVec)
- for _, ind := range indicies {
- C.mlx_vector_array_append_value(indiciesVec, ind.(*Array).a)
- }
-
- var r C.mlx_array
- C.mlx_scatter(
- &r,
- a.a,
- indiciesVec,
- updates.(*Array).a,
- cAxes0,
- C.size_t(len(cAxes)),
- ctx.(*Context).stream,
- )
- // Release the old array and replace with the new one to ensure the same underlying buffer is used
- a.c.mu.Lock()
- defer a.c.mu.Unlock()
- for i := range a.c.arrays {
- if a.c.arrays[i] == a.a {
- C.mlx_array_free(a.a)
- a.a = r
- a.c.arrays[i] = r
- return a
- }
- }
- panic("unable to locate array in context")
-
-}
-
-func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
- C.mlx_copy(
- &a2.(*Array).a,
- a.a,
- ctx.(*Context).stream,
- )
- // TODO - view?
- return newArray(ctx.(*Context), a2.(*Array).a)
-}
-
-func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
- var r C.mlx_array
- C.mlx_add(
- &r,
- a.a,
- a2.(*Array).a,
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Sub(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
- var r C.mlx_array
- C.mlx_subtract(
- &r,
- a.a,
- a2.(*Array).a,
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Max(ctx ml.Context, axes []int, keepDims bool) ml.Tensor {
- var r C.mlx_array
- cAxes := make([]C.int, len(axes))
- for i := range axes {
- cAxes[i] = C.int(axes[i])
- }
- var cAxes0 *C.int
- if len(cAxes) > 0 {
- cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
- C.mlx_max_axes(
- &r,
- a.a,
- cAxes0,
- C.size_t(len(cAxes)),
- C._Bool(keepDims),
- ctx.(*Context).stream,
- )
- } else {
- C.mlx_max(
- &r,
- a.a,
- C._Bool(keepDims),
- ctx.(*Context).stream,
- )
-
- }
-
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Min(ctx ml.Context, axes []int, keepDims bool) ml.Tensor {
- var r C.mlx_array
- cAxes := make([]C.int, len(axes))
- for i := range axes {
- cAxes[i] = C.int(axes[i])
- }
- var cAxes0 *C.int
- if len(cAxes) > 0 {
- cAxes0 = (*C.int)(unsafe.Pointer(&cAxes[0]))
- C.mlx_min_axes(
- &r,
- a.a,
- cAxes0,
- C.size_t(len(cAxes)),
- C._Bool(keepDims),
- ctx.(*Context).stream,
- )
- } else {
- C.mlx_min(
- &r,
- a.a,
- C._Bool(keepDims),
- ctx.(*Context).stream,
- )
- }
-
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Matmul(ctx ml.Context, a2 ml.Tensor) ml.Tensor {
- var r C.mlx_array
- C.mlx_matmul(
- &r,
- a.a,
- a2.(*Array).a,
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
- // slog.Info("MLX.RMSNorm", "a", a, "w", w)
- var r C.mlx_array
- C.mlx_fast_rms_norm(
- &r,
- a.a,
- w.(*Array).a,
- C.float(eps),
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
- var r C.mlx_array
- C.mlx_fast_layer_norm(
- &r,
- a.a,
- w.(*Array).a,
- b.(*Array).a,
- C.float(eps),
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
- // TODO implement
- panic("NOT YET IMPLEMENTED")
-}
-
-func (t Array) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
- panic("NOT YET IMPLEMENTED")
-}
-
-// RoPE implements Rotary Positional Encoding
-//
-// dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged.
-// traditional (bool) – If set to True choose the traditional implementation which rotates consecutive dimensions.
-// scale (float) – The scale used to scale the positions.
-// offset (int) – The position offset to start at. TODO MLX-C does not yet expose Offset as an Array
-// WithBase (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None.
-// WithFreqs (array, optional) – Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None.
-func (a *Array) RoPE(ctx ml.Context, dims int, traditional bool, scale float32, offset int, options ...func(*ml.RoPEOptions)) ml.Tensor {
- opts := ml.RoPEOptions{}
-
- // Apply any provided options
- for _, option := range options {
- option(&opts)
- }
- var r C.mlx_array
- var base C.mlx_optional_float
- var freqs C.mlx_array
-
- if opts.Base != nil {
- base.value = C.float(*opts.Base)
- base.has_value = true
- }
- if opts.Freqs != nil {
- freqs = opts.Freqs.(*Array).a
- }
- C.mlx_fast_rope(
- &r,
- a.a,
- C.int(dims),
- C._Bool(traditional),
- base,
- C.float(scale),
- C.int(offset),
- freqs,
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-// A fast implementation of multi-head attention: O = softmax(Q @ K.T, dim=-1) @ V.
-//
-// Supports:
-// - Multi-Head Attention
-// - Grouped Query Attention
-// - Multi-Query Attention
-//
-// Note:
-// - The softmax operation is performed in float32 regardless of the input precision.
-// - For Grouped Query Attention and Multi-Query Attention, the k and v inputs should not be pre-tiled to match q.
-//
-// In the following the dimensions are given by:
-// - B: The batch size.
-// - N_q: The number of query heads.
-// - N_kv: The number of key and value heads.
-// - T_q: The number of queries per example.
-// - T_kv: The number of keys and values per example.
-// - D: The per-head dimension.
-//
-// Parameters:
-// - [subject array] queries (array) – Queries with shape [B, N_q, T_q, D].
-// - keys (array) – with shape [B, N_kv, T_kv, D].
-// - values (array) – with shape [B, N_kv, T_kv, D].
-// - scale (float) – Scale for queries (typically 1.0 / sqrt(q.shape(-1)).
-// - mask (str or array, optional) – The mask to apply to the query-key scores.
-// The mask can be an array or a string indicating the mask type. The only supported string type is "causal".
-// If the mask is an array it can be a boolean or additive mask. The mask can have at most 4 dimensions and
-// must be broadcast-compatible with the shape [B, N, T_q, T_kv]. If an additive mask is given its type must
-// promote to the promoted type of q, k, and v.
-// - sinks (array, optional) – An optional array of attention sinks. Default: None.
-
-func (queries *Array) ScaledDotProductAttention(ctx ml.Context, keys, values ml.Tensor, scale float64, maskMode string, mask ml.Tensor, sinks ml.Tensor) ml.Tensor {
- var r C.mlx_array
- var s C.mlx_array
- if sinks != nil {
- s = sinks.(*Array).a
- }
- maskModeC := C.CString(maskMode)
- defer C.free(unsafe.Pointer(maskModeC))
- var maskArr C.mlx_array
- if mask != nil {
- maskArr = mask.(*Array).a
- }
-
- C.mlx_fast_scaled_dot_product_attention(
- &r,
- queries.a,
- keys.(*Array).a,
- values.(*Array).a,
- C.float(scale),
- maskModeC,
- maskArr,
- s,
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) TakeAxes(ctx ml.Context, indicies ml.Tensor, axes int) ml.Tensor {
- var r C.mlx_array
-
- C.mlx_take_axis(&r, a.a, indicies.(*Array).a, C.int(axes), ctx.(*Context).stream)
- return newArray(ctx.(*Context), r)
-
-}
-
-// TODO not sure if we'll want this variation taking raw ints instead of a tensor...
-// func (a *Array) TakeAxes(ctx ml.Context, axes int, indicies ...int) ml.Tensor {
-// var i C.mlx_array
-// var r C.mlx_array
-
-// if indicies != nil {
-// shape := []C.int{C.int(len(indicies))}
-// cindicies := make([]int32, len(indicies))
-// for i, v := range indicies {
-// cindicies[i] = int32(v)
-// }
-// i = C.mlx_array_new_data(
-// unsafe.Pointer(&cindicies[0]),
-// &shape[0],
-// C.int(len(shape)),
-// C.MLX_INT32,
-// )
-// }
-// C.mlx_take_axis(&r, a.a, i, C.int(axes), ctx.(*Context).stream)
-// return newArray(ctx.(*Context), r)
-
-// }
-
-func (a *Array) GELU(ctx ml.Context, up ...ml.Tensor) ml.Tensor {
- // TODO precise vs fast, and compile
- // x * mx.sigmoid(1.702 * x)
- u16s := []float16.Float16{float16.Fromfloat32(1.702)}
- cshape := []C.int{1}
- f := C.mlx_array_new_data(unsafe.Pointer(&u16s[0]), &cshape[0], 1, C.MLX_FLOAT16)
- defer C.mlx_array_free(f)
- var r1, r2, r3 C.mlx_array
- C.mlx_multiply(&r1, a.a, f, ctx.(*Context).stream)
- defer C.mlx_array_free(r1)
- C.mlx_sigmoid(&r2, r1, ctx.(*Context).stream)
- defer C.mlx_array_free(r2)
- C.mlx_multiply(&r3, a.a, r2, ctx.(*Context).stream)
-
- if len(up) > 0 {
- var r4 C.mlx_array
- defer C.mlx_array_free(r3)
- C.mlx_multiply(&r4, r3, up[0].(*Array).a, ctx.(*Context).stream)
- return newArray(ctx.(*Context), r4)
- }
-
- return newArray(ctx.(*Context), r3)
-}
-
-// Create a view into the array with the given shape and strides.
-//
-// The resulting array will always be as if the provided array was row
-// contiguous regardless of the provided arrays storage order and current
-// strides.
-//
-// Note that this function should be used with caution as it changes the shape
-// and strides of the array directly. This can lead to the resulting array
-// pointing to invalid memory locations which can result into crashes.
-//
-// Parameters:
-// - shape (list(int), optional) – The shape of the resulting array. If None it defaults to a.shape().
-// - strides (list(int), optional) – The strides of the resulting array. If None it defaults to the
-// reverse exclusive cumulative product of a.shape().
-// - offset (int) – Skip that many elements from the beginning of the input array.
-func (a *Array) AsStrided(ctx ml.Context, shape, strides []int, offset int) ml.Tensor {
- var r C.mlx_array
- sh := make([]C.int, len(shape))
- st := make([]C.int64_t, len(strides))
- var sh0 *C.int
- var st0 *C.int64_t
- for i, s := range shape {
- sh[i] = C.int(s)
- }
- for i, s := range strides {
- st[i] = C.int64_t(s)
- }
- if len(sh) > 0 {
- sh0 = (*C.int)(unsafe.Pointer(&sh[0]))
- }
- if len(st) > 0 {
- st0 = (*C.int64_t)(unsafe.Pointer(&st[0]))
- }
-
- C.mlx_as_strided(
- &r,
- a.a,
- sh0,
- C.size_t(len(sh)),
- st0,
- C.size_t(len(st)),
- C.size_t(offset),
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-
-}
-
-func (a *Array) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
- cshape := make([]C.int, len(shape))
- for i, dim := range shape {
- cshape[i] = C.int(dim)
- }
- var r C.mlx_array
- C.mlx_reshape(&r, a.a, &cshape[0], C.size_t(len(cshape)), ctx.(*Context).stream)
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Transpose(ctx ml.Context, shape ...int) ml.Tensor {
- ndim := min(C.mlx_array_ndim(a.a), C.size_t(len(shape)))
- var r C.mlx_array
- sh := make([]C.int, ndim)
- for i := range ndim {
- sh[i] = (C.int)(shape[i])
- if int(sh[i]) >= int(ndim) {
- slog.Error("Permute error", "tensor", a, "shape", shape)
- panic("invalid pemute call")
- }
- }
- if len(sh) > 0 {
- C.mlx_transpose_axes(
- &r,
- a.a,
- &sh[0],
- ndim,
- ctx.(*Context).stream,
- )
- } else {
- C.mlx_transpose(
- &r,
- a.a,
- ctx.(*Context).stream,
- )
- }
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) Contiguous(ctx ml.Context, allowColMajor bool) ml.Tensor {
- var r C.mlx_array
- C.mlx_contiguous(
- &r,
- a.a,
- (C._Bool)(allowColMajor),
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-// Conv2D implements ml.Tensor.
-// GGML API
-// input: [N, IC, IH, IW]
-// weight: [OC,IC, KH, KW]
-// result: [N, OC, OH, OW]
-//
-// MLX:
-// input: (N, KH, KW, C_in)
-// weight: (C_out, IH, IW, C_in)
-// result: XXX
-
-func (input *Array) Conv2D(ctx ml.Context, weight ml.Tensor, stride0, stride1, padding0, padding1, dilation0, dilation1, groups int) ml.Tensor {
- var r C.mlx_array
- C.mlx_conv2d(
- &r,
- input.a,
- weight.(*Array).a,
- C.int(stride0),
- C.int(stride1),
- C.int(padding0),
- C.int(padding1),
- C.int(dilation0),
- C.int(dilation1),
- C.int(groups),
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (input *Array) Conv3D(ctx ml.Context, weight ml.Tensor, stride0, stride1, stride2, padding0, padding1, padding2, dilation0, dilation1, dilation2, groups int) ml.Tensor {
- var r C.mlx_array
- C.mlx_conv3d(
- &r,
- input.a,
- weight.(*Array).a,
- C.int(stride0),
- C.int(stride1),
- C.int(stride2),
- C.int(padding0),
- C.int(padding1),
- C.int(padding2),
- C.int(dilation0),
- C.int(dilation1),
- C.int(dilation2),
- C.int(groups),
- ctx.(*Context).stream,
- )
- return newArray(ctx.(*Context), r)
-}
-
-func (a *Array) ToString() string {
- str := C.mlx_string_new()
- C.mlx_array_tostring(&str, a.a)
- s := C.mlx_string_data(str)
- defer C.mlx_string_free(str)
- return C.GoString(s)
-}
-
-func (a *Array) LogValue() slog.Value {
-
- dims := int(C.mlx_array_ndim(a.a))
- strides := make([]int, dims)
- for i := range strides {
- strides[i] = int(C.stride(a.a, (C.int)(i)))
- }
-
- return slog.GroupValue(
- slog.String("name", a.name),
- slog.String("type", a.TypeString()),
- slog.Any("shape", a.Shape()),
- slog.Any("strides", strides),
- // slog.String("values", C.GoString(s)),
- )
-}
-
-func (a *Array) Shape() []int {
- shape := make([]int, C.mlx_array_ndim(a.a))
- for i := range shape {
- shape[i] = int(C.mlx_array_dim(a.a, C.int(i)))
- }
-
- return shape
-}
-
-func (a *Array) TypeString() string {
- switch C.mlx_array_dtype(a.a) {
- case C.MLX_BOOL:
- return "bool"
- case C.MLX_UINT8:
- return "uint8"
- case C.MLX_UINT16:
- return "uint16"
- case C.MLX_UINT32:
- return "uint32"
- case C.MLX_UINT64:
- return "uint64"
- case C.MLX_INT8:
- return "int8"
- case C.MLX_INT16:
- return "int16"
- case C.MLX_INT32:
- return "int32"
- case C.MLX_INT64:
- return "int64"
- case C.MLX_FLOAT16:
- return "float16"
- case C.MLX_FLOAT32:
- return "float32"
- case C.MLX_BFLOAT16:
- return "bfloat16"
- case C.MLX_COMPLEX64:
- return "complex64"
- default:
- return "unknown"
- }
-}
diff --git a/x/ml/backend/mlx/mlx_dynamic.c b/x/ml/backend/mlx/mlx_dynamic.c
deleted file mode 100644
index 0038355aedb..00000000000
--- a/x/ml/backend/mlx/mlx_dynamic.c
+++ /dev/null
@@ -1,92 +0,0 @@
-// mlx_dynamic.c - Dynamic loading wrapper for MLX-C library
-// This file provides runtime dynamic loading of libmlxc instead of link-time binding
-
-#include "mlx_dynamic.h"
-#include
-#include
-#include
-
-#ifdef _WIN32
-#include
-typedef HMODULE lib_handle_t;
-#define LOAD_LIB(path) LoadLibraryA(path)
-#define GET_SYMBOL(handle, name) GetProcAddress(handle, name)
-#define CLOSE_LIB(handle) FreeLibrary(handle)
-#define LIB_ERROR() "LoadLibrary failed"
-static const char* LIB_NAMES[] = {"libmlxc.dll", NULL};
-#else
-#include
-typedef void* lib_handle_t;
-#define LOAD_LIB(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
-#define GET_SYMBOL(handle, name) dlsym(handle, name)
-#define CLOSE_LIB(handle) dlclose(handle)
-#define LIB_ERROR() dlerror()
-#ifdef __APPLE__
-static const char* LIB_NAMES[] = {
- "libmlxc.dylib",
- "@loader_path/../build/lib/ollama/libmlxc.dylib",
- "@executable_path/../build/lib/ollama/libmlxc.dylib",
- "build/lib/ollama/libmlxc.dylib",
- "../build/lib/ollama/libmlxc.dylib",
- NULL
-};
-#else
-static const char* LIB_NAMES[] = {
- "libmlxc.so",
- "$ORIGIN/../build/lib/ollama/libmlxc.so",
- "build/lib/ollama/libmlxc.so",
- "../build/lib/ollama/libmlxc.so",
- NULL
-};
-#endif
-#endif
-
-static lib_handle_t mlx_handle = NULL;
-static int mlx_initialized = 0;
-static char mlx_error_buffer[512] = {0};
-
-// Initialize MLX dynamic library
-// Returns 0 on success, -1 on failure
-// On failure, call mlx_dynamic_error() to get error message
-int mlx_dynamic_init(void) {
- if (mlx_initialized) {
- return 0; // Already initialized
- }
-
- // Try each possible library path
- for (int i = 0; LIB_NAMES[i] != NULL; i++) {
- mlx_handle = LOAD_LIB(LIB_NAMES[i]);
- if (mlx_handle != NULL) {
- mlx_initialized = 1;
- snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
- "MLX: Successfully loaded %s", LIB_NAMES[i]);
- return 0;
- }
- }
-
- // Failed to load library
- const char* err = LIB_ERROR();
- snprintf(mlx_error_buffer, sizeof(mlx_error_buffer),
- "MLX: Failed to load libmlxc library. %s",
- err ? err : "Unknown error");
- return -1;
-}
-
-// Get the last error message
-const char* mlx_dynamic_error(void) {
- return mlx_error_buffer;
-}
-
-// Check if MLX is initialized
-int mlx_dynamic_is_initialized(void) {
- return mlx_initialized;
-}
-
-// Cleanup (optional, called at program exit)
-void mlx_dynamic_cleanup(void) {
- if (mlx_handle != NULL) {
- CLOSE_LIB(mlx_handle);
- mlx_handle = NULL;
- mlx_initialized = 0;
- }
-}
diff --git a/x/ml/backend/mlx/mlx_dynamic.h b/x/ml/backend/mlx/mlx_dynamic.h
deleted file mode 100644
index 2ae162a9ae0..00000000000
--- a/x/ml/backend/mlx/mlx_dynamic.h
+++ /dev/null
@@ -1,26 +0,0 @@
-// mlx_dynamic.h - Dynamic loading interface for MLX-C library
-#ifndef MLX_DYNAMIC_H
-#define MLX_DYNAMIC_H
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-// Initialize the MLX dynamic library
-// Returns 0 on success, -1 on failure
-int mlx_dynamic_init(void);
-
-// Get the last error message from dynamic loading
-const char* mlx_dynamic_error(void);
-
-// Check if MLX is initialized
-int mlx_dynamic_is_initialized(void);
-
-// Cleanup resources (optional, for clean shutdown)
-void mlx_dynamic_cleanup(void);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // MLX_DYNAMIC_H
diff --git a/x/ml/backend/mlx/mlx_test.go b/x/ml/backend/mlx/mlx_test.go
deleted file mode 100644
index 7699c1524dc..00000000000
--- a/x/ml/backend/mlx/mlx_test.go
+++ /dev/null
@@ -1,314 +0,0 @@
-//go:build mlx
-
-package mlx
-
-import (
- "log/slog"
- "os"
- "reflect"
- "strings"
- "testing"
-
- "github.com/ollama/ollama/api"
- "github.com/ollama/ollama/runner/common"
- "github.com/ollama/ollama/sample"
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/model"
- "github.com/ollama/ollama/x/model/input"
- _ "github.com/ollama/ollama/x/model/models/gemma3"
-)
-
-func init() {
- logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
- slog.SetDefault(logger)
-}
-
-func TestLoadModel(t *testing.T) {
- dir := "/Users/daniel/Models/gemma-3-4b-it/"
- b := &Backend{}
- err := b.LoadSafeTensors(dir)
- if err != nil {
- t.Fatalf("load failed: %s", err)
- }
-}
-
-func TestFromInts(t *testing.T) {
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- data := []int32{1, 2, 3, 4, 5, 6}
- a := c.FromInts(data, 2, 3)
- slog.Info("", "array", a)
- t.Log(a.ToString())
- if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
- t.Fatalf("incorrect shape: %v", a.Shape())
- }
-}
-
-func TestFromFloats(t *testing.T) {
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- data := []float32{1, 2, 3, 4, 5, 6}
- a := c.FromFloats(data, 2, 3)
- slog.Info("", "array", a)
- t.Log(a.ToString())
- if !reflect.DeepEqual(a.Shape(), []int{2, 3}) {
- t.Fatalf("incorrect shape: %v", a.Shape())
- }
- res := a.Floats()
- if !reflect.DeepEqual(res, data) {
- t.Fatalf("incorrect results: %v", res)
- }
-}
-
-func TestAdd(t *testing.T) {
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- t1 := c.Arange(0, 24, 1, ml.DTypeFloat16)
- t2 := c.Arange(0, 24, 1, ml.DTypeFloat16)
- exp := c.Arange(0, 48, 2, ml.DTypeFloat16)
- t3 := t1.Add(c, t2)
- c.Compute(t3, exp)
- t3f := t3.Floats()
- if !reflect.DeepEqual(t3f, exp.Floats()) {
- t.Fatalf("incorrect result: %v", t3f)
- }
-}
-
-func TestReshapeTranspose(t *testing.T) {
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- t1 := c.Arange(0, 24, 1, ml.DTypeFloat16).Reshape(c, 2, 3, 4).Transpose(c, 0, 2, 1).Contiguous(c, false)
- c.Compute(t1)
- t1f := t1.Floats()
- exp := []float32{
- 0, 4, 8,
- 1, 5, 9,
- 2, 6, 10,
- 3, 7, 11,
- 12, 16, 20,
- 13, 17, 21,
- 14, 18, 22,
- 15, 19, 23,
- }
- if !reflect.DeepEqual(t1f, exp) {
- t.Fatalf("incorrect results: %v", t1f)
- }
-}
-
-func prod(vals ...int) int {
- r := 1
- for _, v := range vals {
- r *= v
- }
- return r
-}
-func TestMatmul(t *testing.T) {
- // TODO create scenarios...
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- s1 := []int{1, 3, 2, 4}
- t1 := c.Arange(0, float32(prod(s1...)), 1, ml.DTypeFloat16).Reshape(c, s1...)
- s2 := []int{4, 2}
- t2 := c.Arange(0, float32(prod(s2...)), 1, ml.DTypeFloat16).Reshape(c, s2...)
- t3 := t1.Matmul(c, t2)
- exp := []float32{
- 28, 34,
- 76, 98,
-
- 124, 162,
- 172, 226,
-
- 220, 290,
- 268, 354,
- }
- c.Compute(t3)
- t3f := t3.Floats()
- if !reflect.DeepEqual(t3f, exp) {
- t.Fatalf("incorrect result: %v", t3f)
- }
-}
-
-func TestRows(t *testing.T) {
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- t1 := c.Arange(0, 12, 1, ml.DTypeFloat32).Reshape(c, 1, 4, 3)
- outputs := c.Zeros(ml.DTypeInt32, 1)
- t2 := t1.TakeAxes(c, outputs, 1)
- c.Forward(t1, t2).Compute(t1, t2)
- t.Log(t1.ToString())
- t.Log(t2.ToString())
- f := t2.Floats()
- t.Logf("Result: %v", f)
-}
-
-func TestCaching(t *testing.T) {
- // Validate the caching algorithm
- b := &Backend{}
- c := b.NewContext()
- defer c.Close()
- batchSize := 3
- headDim := 4
- numKVHeads := 2
- // Make cache twice the size of one test batch
- cells := batchSize * 2
- cellSize := numKVHeads * headDim
- shape := []int{1, numKVHeads, batchSize, headDim}
- stop := float32(1)
- for _, x := range shape {
- stop *= float32(x)
- }
- // Create the cache
- cache := c.Zeros(ml.DTypeFloat16, cells, cellSize)
- t.Logf("Empty Cache shape%v\n"+cache.ToString(), []int{cells, cellSize})
-
- // Input tensor
- t1 := c.Arange(0, stop, 1, ml.DTypeFloat16).Reshape(c, shape...)
- t.Logf("Initial Data shape%v\n"+t1.ToString(), shape)
-
- // Reshape to copy into the cache
- /*
- From MLX python/src/indexing.cpp mlx_scatter_args_array
- // The update shape must broadcast with indices.shape + [1] + src.shape[1:]
- auto up_shape = indices.shape();
- up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
- up = broadcast_to(up, up_shape);
- up_shape.insert(up_shape.begin() + indices.ndim(), 1);
- up = reshape(up, up_shape);
- */
- numRows := 3
- up := t1.Reshape(c, numRows, 1, cellSize) // The shape has to look like this for scatter to work properly
- t.Logf("Data reshaped for cache input shape%v\n"+up.ToString(), []int{batchSize, numKVHeads * headDim})
-
- // Simulate cells 1,3,5 are available
- indicies := []ml.Tensor{c.FromInts([]int32{1, 3, 5}, numRows)}
- t.Logf("Indicies shape%v\n"+indicies[0].ToString(), []int{numRows})
- axis := []int{0} // The 1,3,5 of the indicies are in reference to axis 0 in the cache shape
- cache.Scatter(c, indicies, up, axis)
-
- c.Forward(cache)
- // Cache should contain the data now
- t.Log("Cache after put\n" + cache.ToString())
-
- // Retrieve cache content and verify it matches
- out := cache.TakeAxes(c, indicies[0], 0).Reshape(c, shape...)
- t.Logf("Output shape%v\n"+out.ToString(), out.Shape())
-
- t1f := t1.Floats()
- outf := out.Floats()
- if !reflect.DeepEqual(t1f, outf) {
- t.Fatalf("mismatched in->out\n%v\n ->\n%v", t1f, outf)
- }
-}
-
-func TestGemma3(t *testing.T) {
- // Why is the sky blue
- inputs := []int32{2, 105, 2364, 107, 36425, 563, 506, 7217, 3730, 106, 107, 105, 4368}
- limit := 50
-
- // TODO generalize this
- dir := "/Users/daniel/Models/gemma-3-4b-it/"
-
- m, err := model.New(dir, ml.BackendParams{})
- if err != nil {
- t.Fatalf("unable to load model: %s", err)
- }
- b := m.Backend()
- ctx := b.NewContext()
- defer ctx.Close()
-
- batch := input.Batch{
- Inputs: ctx.FromInts(inputs[:], 1, len(inputs)),
- Positions: make([]int32, len(inputs)),
- Sequences: make([]int, len(inputs)),
- Outputs: ctx.FromInts([]int32{int32(len(inputs) - 1)}, 1),
- Offset: 0,
- }
- for i := range len(inputs) {
- batch.Positions[i] = int32(i)
- }
- offset := len(inputs)
-
- cache := m.Config().Cache
- if cache != nil {
- numSlots := 1
- batchSize := 512
- numCtx := 4096
-
- // Note: this is inconsistent with mlx-py, but trying to be consistent with the GGML cache impl to get things working
- // cache.SetConfig(ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 64})
- cache.SetConfig(ml.CacheConfig{CachePadding: 0, MaskDType: ml.DTypeBfloat16, MaskBatchPadding: 0})
-
- cache.Init(b, ml.DTypeBfloat16, numSlots, int(numCtx), batchSize)
- err := cache.StartForward(ctx, batch, false)
- if err != nil {
- t.Fatalf("failed cache.StartForward: %s", err)
- }
- }
- opts := api.DefaultOptions()
- var grammar *sample.GrammarSampler
- sampler := sample.NewSampler(
- opts.Temperature,
- opts.TopK,
- opts.TopP,
- opts.MinP,
- opts.Seed,
- grammar,
- )
-
- t.Log("Starting Forward pass loop")
- pendingResponses := []string{}
- for {
- out, err := m.Forward(ctx, batch)
- if err != nil {
- t.Fatalf("failed forward pass: %s", err)
- }
- ctx.Forward(out)
- outputs := out.Floats()
- t.Logf("finished forward pass! length:%d", len(outputs))
- // sample a token
- logits := outputs
- token, err := sampler.Sample(logits)
- if err != nil {
- t.Fatalf("unable to sample token: %s", err)
- }
- t.Logf("Sampled token: %v", token)
- if m.(model.TextProcessor).Is(token, model.SpecialEOS) {
- t.Log("hit EOS")
- break
- }
- piece, err := m.(model.TextProcessor).Decode([]int32{token})
- if err != nil {
- t.Fatalf("unable to decode token: %s", err)
- }
-
- pendingResponses = append(pendingResponses, piece)
- sequence := strings.Join(pendingResponses, "")
- if ok, stop := common.FindStop(sequence, opts.Stop); ok {
- t.Logf("hit stop token: %v", stop)
- break
- }
- t.Logf("RESULTS: %s", sequence)
- batch = input.Batch{
- Inputs: ctx.FromInts([]int32{token}, 1, 1),
- Positions: make([]int32, 1),
- Sequences: make([]int, 1),
- Outputs: ctx.FromInts([]int32{0}, 1),
- Offset: offset,
- }
- offset++
- batch.Positions[0] = 0
- err = cache.StartForward(ctx, batch, false)
- if err != nil {
- t.Fatalf("failed cache.StartForward: %s", err)
- }
- if offset > limit {
- break
- }
- }
-}
diff --git a/x/ml/backend/mlx/quant.go b/x/ml/backend/mlx/quant.go
deleted file mode 100644
index 724f4325389..00000000000
--- a/x/ml/backend/mlx/quant.go
+++ /dev/null
@@ -1,335 +0,0 @@
-//go:build mlx
-
-package mlx
-
-/*
-#include
-#include
-
-#include "mlx/c/array.h"
-#include "mlx/c/ops.h"
-
-// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp
-
-void unpack_32_4(uint8_t* data, int8_t* dst) {
- memset(dst, 0, 16);
- for (int j = 0; j < 16; ++j) {
- uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
- if (j % 2 != 0) {
- x <<= 4;
- }
- dst[j / 2] += x;
- }
- // Last 16 weights are in the higher bits
- for (int j = 0; j < 16; ++j) {
- uint8_t x = (data[j + 2] >> 4);
- if (j % 2 != 0) {
- x <<= 4;
- }
- dst[8 + j / 2] += x;
- }
-}
-
-// Extracts (weight, scales, biases) from Q4_0 tensors.
-// Data layout is: |16 bit scale|32 x 4bit weights|.
-void extract_q4_0_data(
- uint8_t* data,
- mlx_array* weights_arr,
- mlx_array* scales_arr,
- mlx_array* biases_arr) {
- const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
- uint8_t* weights = mlx_array_data_uint8(*weights_arr);
- float16_t* scales = mlx_array_data_float16(*scales_arr);
- float16_t* biases = mlx_array_data_float16(*biases_arr);
- for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
- scales[i] = *((float16_t*)data);
- biases[i] = -8 * scales[i];
- unpack_32_4(data, weights);
- weights += 16;
- data += bytes_per_block;
- }
-}
-
-// Extracts (weight, scales, biases) from Q4_1 tensors.
-// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
-void extract_q4_1_data(
- uint8_t* data,
- mlx_array* weights_arr,
- mlx_array* scales_arr,
- mlx_array* biases_arr) {
- const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
- uint8_t* weights = mlx_array_data_uint8(*weights_arr);
- float16_t* scales = mlx_array_data_float16(*scales_arr);
- float16_t* biases = mlx_array_data_float16(*biases_arr);
- for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
- scales[i] = *((float16_t*)data);
- biases[i] = *((float16_t*)(data) + 1);
- unpack_32_4(data, weights);
- weights += 16;
- data += bytes_per_block;
- }
-}
-
-// Extracts (weight, scales, biases) from Q8_0 tensors.
-// Data layout is: |16 bit scale|32 x 8bit weights|.
-void extract_q8_0_data(
- uint8_t* data,
- mlx_array* weights_arr,
- mlx_array* scales_arr,
- mlx_array* biases_arr) {
- const uint64_t weights_per_block = 32;
- const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
- uint8_t* weights = mlx_array_data_uint8(*weights_arr);
- float16_t* scales = mlx_array_data_float16(*scales_arr);
- float16_t* biases = mlx_array_data_float16(*biases_arr);
- for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
- uint8_t* block_data = data + i * bytes_per_block;
- scales[i] = *((float16_t*)block_data);
- biases[i] = -128 * scales[i];
- for (int64_t j = 0; j < weights_per_block; ++j) {
- uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
- // Original data is in int8_t, so we add a bias of -128 and invert the
- // first bit.
- x ^= 1 << 7;
- weights[i * weights_per_block + j] = x;
- }
- }
-}
-
-// Drived from ggml-quants.c
-
-#define QK_K 256
-
-// 6-bit quantization
-// weight is represented as x = a * q
-// 16 blocks of 16 elements each
-// Effectively 6.5625 bits per weight
-typedef struct {
- uint8_t ql[QK_K/2]; // quants, lower 4 bits
- uint8_t qh[QK_K/4]; // quants, upper 2 bits
- int8_t scales[QK_K/16]; // scales, quantized with 8 bits
- uint16_t d; // super-block scale
-} block_q6_K;
-
-void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
- const int64_t nb = k / QK_K;
- block_q6_K *x = (block_q6_K *)vx;
- float16_t* y = (float16_t *)vy;
-
- for (int i = 0; i < nb; i++) {
- float16_t d = 0.0;
- memcpy(&d, &x[i].d, sizeof(d));
-
- const uint8_t * restrict ql = x[i].ql;
- const uint8_t * restrict qh = x[i].qh;
- const int8_t * restrict sc = x[i].scales;
-
- for (int n = 0; n < QK_K; n += 128) {
- for (int l = 0; l < 32; ++l) {
- int is = l/16;
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- y[l + 0] = d * sc[is + 0] * q1;
- y[l + 32] = d * sc[is + 2] * q2;
- y[l + 64] = d * sc[is + 4] * q3;
- y[l + 96] = d * sc[is + 6] * q4;
- }
- y += 128;
- ql += 64;
- qh += 32;
- sc += 8;
- }
- }
-}
-
-#define K_SCALE_SIZE 12
-#define GGML_COMMON_AGGR_U
-#define GGML_COMMON_AGGR_S
-
-// 4-bit quantization
-// 8 blocks of 32 elements each
-// weight is represented as x = a * q + b
-// Effectively 4.5 bits per weight
-typedef struct {
- union {
- struct {
- uint16_t d; // super-block scale for quantized scales
- uint16_t dmin; // super-block scale for quantized mins
- } GGML_COMMON_AGGR_S;
- uint16_t dm;
- } GGML_COMMON_AGGR_U;
- uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
- uint8_t qs[QK_K/2]; // 4--bit quants
-} block_q4_K;
-
-static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
- if (j < 4) {
- *d = q[j] & 63; *m = q[j + 4] & 63;
- } else {
- *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
- *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
- }
-}
-
-void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
- block_q4_K *x = (block_q4_K *)vx;
- float16_t* y = (float16_t *)vy;
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
- const uint8_t * q = x[i].qs;
- float16_t d = 0.0;
- memcpy(&d, &x[i].d, sizeof(d));
- float16_t min = 0.0;
- memcpy(&min, &x[i].dmin, sizeof(d));
-
- int is = 0;
- uint8_t sc, m;
- for (int j = 0; j < QK_K; j += 64) {
- get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
- const float16_t d1 = d * sc; const float16_t m1 = min * m;
- get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
- const float16_t d2 = d * sc; const float16_t m2 = min * m;
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
- q += 32; is += 2;
- }
- }
-}
-
-
-
-*/
-import "C"
-
-import (
- "fmt"
- "unsafe"
-
- "github.com/x448/float16"
-)
-
-func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
- shape := append([]C.int{}, final_shape...)
- var weights_per_byte C.int
- if dtype == 2 || dtype == 3 {
- weights_per_byte = 2
- } else if dtype == 8 {
- weights_per_byte = 1
- } else {
- return r, fmt.Errorf("unsupported tensor type %d", dtype)
- }
-
- weights_per_block := C.int(32)
- if shape[len(shape)-1]%weights_per_block != 0 {
- return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
- }
-
- weights_shape := append([]C.int{}, shape...)
- weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
- w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
- for i := range weights_shape {
- w_nbytes *= weights_shape[i]
- }
- w_data := make([]byte, w_nbytes)
- cbytes := C.CBytes(w_data)
- defer C.free(cbytes)
- weights := C.mlx_array_new_data(
- cbytes,
- &weights_shape[0],
- C.int(len(weights_shape)),
- C.MLX_UINT32,
- )
-
- // For scales and bias
- shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
- sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
- for i := range shape {
- sb_nbytes *= shape[i]
- }
-
- s_data := make([]byte, sb_nbytes)
- cbytes = C.CBytes(s_data)
- defer C.free(cbytes)
- scales := C.mlx_array_new_data(
- cbytes,
- &shape[0],
- C.int(len(shape)),
- C.MLX_FLOAT16,
- )
- b_data := make([]byte, sb_nbytes)
- cbytes = C.CBytes(b_data)
- defer C.free(cbytes)
- biases := C.mlx_array_new_data(
- cbytes,
- &shape[0],
- C.int(len(shape)),
- C.MLX_FLOAT16,
- )
- var bits C.int
- switch dtype {
- case 2:
- C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
- bits = 4
- case 3:
- C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
- bits = 4
- case 8:
- C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
- bits = 8
- }
- groupSize := C.mlx_optional_int{value: 32, has_value: true}
- bitsOpt := C.mlx_optional_int{value: bits, has_value: true}
- var dtypeOpt C.mlx_optional_dtype // has_value defaults to false
- C.mlx_dequantize(
- &r,
- weights,
- scales,
- biases,
- groupSize,
- bitsOpt,
- nil, // TODO mode
- dtypeOpt,
- stream,
- )
- C.mlx_array_free(weights)
- C.mlx_array_free(scales)
- C.mlx_array_free(biases)
-
- return r, nil
-}
-
-func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
- size := 1
- for _, d := range shape {
- size *= int(d)
- }
- fdata := make([]float16.Float16, size)
- switch dtype {
- case 14:
- C.dequant_row_q6_K(
- data,
- unsafe.Pointer(&fdata[0]),
- C.int(size),
- )
-
- case 12:
- C.dequant_row_q4_K(
- data,
- unsafe.Pointer(&fdata[0]),
- C.int(size),
- )
- default:
- return r, fmt.Errorf("unsupported K quant")
- }
-
- r = C.mlx_array_new_data(
- unsafe.Pointer(&fdata[0]),
- &shape[0],
- C.int(len(shape)),
- C.MLX_FLOAT16,
- )
- return r, nil
-}
diff --git a/x/ml/device.go b/x/ml/device.go
deleted file mode 100644
index f892b512d33..00000000000
--- a/x/ml/device.go
+++ /dev/null
@@ -1,643 +0,0 @@
-package ml
-
-import (
- "context"
- "encoding/binary"
- "encoding/json"
- "fmt"
- "hash/maphash"
- "io"
- "log/slog"
- "math"
- "net/http"
- "runtime"
- "slices"
- "sort"
- "strconv"
- "strings"
- "time"
-
- "github.com/ollama/ollama/format"
- "github.com/ollama/ollama/logutil"
-)
-
-// GPULayers is a set of layers to be allocated on a single GPU
-type GPULayers struct {
- DeviceID
-
- // Layers is a set of layer indicies to load
- Layers []int
-}
-
-// FirstLayer returns the smallest layer index scheduled on this GPU, or MaxInt when empty.
-func (g GPULayers) FirstLayer() int {
- if len(g.Layers) == 0 {
- return math.MaxInt
- }
-
- first := g.Layers[0]
- for i := 1; i < len(g.Layers); i++ {
- if g.Layers[i] < first {
- first = g.Layers[i]
- }
- }
-
- return first
-}
-
-func (g GPULayers) String() string {
- if len(g.Layers) == 0 {
- return ""
- }
-
- slices.Sort(g.Layers)
-
- contiguous := true
- base := g.Layers[0]
- for i := range g.Layers {
- if g.Layers[i] != base+i {
- contiguous = false
- break
- }
- }
-
- if contiguous {
- return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
- } else {
- return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
- }
-}
-
-// GPULayersList is a set of layer allocations across multiple GPUs
-type GPULayersList []GPULayers
-
-func (l GPULayersList) Len() int { return len(l) }
-func (l GPULayersList) Swap(i, j int) { l[i], l[j] = l[j], l[i] }
-
-// Sort by the ordering of the layers offloaded
-func (l GPULayersList) Less(i, j int) bool {
- li := l[i].FirstLayer()
- lj := l[j].FirstLayer()
-
- return li < lj
-}
-
-func (l GPULayersList) String() string {
- if l.Sum() > 0 {
- return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
- } else {
- return fmt.Sprintf("%v", []GPULayers(l))
- }
-}
-
-// Sum is the total number of layers assigned across all GPUs
-func (l GPULayersList) Sum() int {
- var sum int
-
- for _, g := range l {
- sum += len(g.Layers)
- }
-
- return sum
-}
-
-var h maphash.Hash
-
-// Hash is an identifier of this layer assignment
-func (l GPULayersList) Hash() uint64 {
- h.Reset()
- for _, g := range l {
- if len(g.Layers) > 0 {
- h.WriteString(g.ID + g.Library)
- for _, l := range g.Layers {
- binary.Write(&h, binary.NativeEndian, int64(l))
- }
- }
- }
-
- return h.Sum64()
-}
-
-// ErrNoMem is returned when panicing due to insufficient memory. It includes
-// the attempted memory allocation.
-type ErrNoMem struct {
- BackendMemory
-}
-
-func (e ErrNoMem) Error() string {
- return fmt.Sprintf("insufficient memory - required allocations: %+v", e.BackendMemory)
-}
-
-// Minimal unique device identification
-type DeviceID struct {
- // ID is an identifier for the device for matching with system
- // management libraries. The ID is only unique for other devices
- // using the same Library.
- // This ID represents a "post filtered" view of the enumerated devices
- // if the ID is numeric
- ID string `json:"id"`
-
- // Library identifies which library is used for the device (e.g. CUDA, ROCm, etc.)
- Library string `json:"backend,omitempty"`
-}
-
-// DeviceMemory provides a breakdown of the memory needed
-// per device, such as a CPU or GPU.
-type DeviceMemory struct {
- DeviceID
-
- // Name is the name of the device as labeled by the backend. It
- // may not be persistent across instances of the runner.
- Name string
-
- // Weights is the per-layer memory needed for the model weights.
- Weights []uint64
-
- // Cache is the per-layer memory needed for the KV cache.
- Cache []uint64
-
- // Graph is the size of the compute graph. It is not per-layer.
- Graph uint64
-}
-
-func sumMemory(mem []uint64) uint64 {
- var sum uint64
-
- for _, m := range mem {
- sum += m
- }
-
- return sum
-}
-
-// Size returns the total size of the memory required by this device
-func (m DeviceMemory) Size() uint64 {
- return sumMemory(m.Weights) + sumMemory(m.Cache) + m.Graph
-}
-
-func memoryPresent(mem []uint64) bool {
- return slices.ContainsFunc(mem, func(m uint64) bool { return m != 0 })
-}
-
-func (m DeviceMemory) LogValue() slog.Value {
- var attrs []slog.Attr
- if memoryPresent(m.Weights) {
- attrs = append(attrs, slog.Any("Weights", m.Weights))
- }
-
- if memoryPresent(m.Cache) {
- attrs = append(attrs, slog.Any("Cache", m.Cache))
- }
-
- if m.Graph != 0 {
- attrs = append(attrs, slog.Any("Graph", m.Graph))
- }
-
- if len(attrs) > 0 && m.ID != "" {
- attrs = append([]slog.Attr{slog.String("ID", m.ID)}, attrs...)
- }
-
- return slog.GroupValue(attrs...)
-}
-
-// BackendMemory provides the amount of memory required to load the model
-// per device based on the BackendParams. In some cases, not all required
-// allocations will be known at this point. However, the size of the most recent
-// allocation is guaranteed to be provided so that if it failed, the caller can
-// accommodate that to make forward progress.
-type BackendMemory struct {
- // InputWeights are always located on the CPU and cannot be moved
- InputWeights uint64
-
- // CPU model components are located in system memory. This does not
- // include unified memory allocated through the GPU.
- CPU DeviceMemory
-
- // GPU model components are located on one or more GPUs.
- GPUs []DeviceMemory
-}
-
-func (m BackendMemory) LogValue() slog.Value {
- var attrs []slog.Attr
- if m.InputWeights != 0 {
- attrs = append(attrs, slog.Any("InputWeights", m.InputWeights))
- }
-
- attrs = append(attrs, slog.Any(m.CPU.Name, m.CPU))
- for _, g := range m.GPUs {
- attrs = append(attrs, slog.Any(g.Name, g))
- }
-
- return slog.GroupValue(attrs...)
-}
-
-// Log prints a high level summary of the memory
-func (m BackendMemory) Log(level slog.Level) {
- var total uint64
-
- for _, gpu := range m.GPUs {
- if sum := sumMemory(gpu.Weights); sum > 0 {
- slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
- total += sum
- }
- }
- if sum := m.InputWeights + sumMemory(m.CPU.Weights); sum > 0 {
- slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
- total += sum
- }
-
- for _, gpu := range m.GPUs {
- if sum := sumMemory(gpu.Cache); sum > 0 {
- slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
- total += sum
- }
- }
- if sum := sumMemory(m.CPU.Cache); sum > 0 {
- slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
- total += sum
- }
-
- for _, gpu := range m.GPUs {
- if sum := gpu.Graph; sum > 0 {
- slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
- total += sum
- }
- }
- if sum := m.CPU.Graph; sum > 0 {
- slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
- total += sum
- }
-
- if total > 0 {
- slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
- }
-}
-
-type DeviceInfo struct {
- DeviceID
-
- // Name is the name of the device as labeled by the backend. It
- // may not be persistent across instances of the runner.
- Name string `json:"name"`
-
- // Description is the longer user-friendly identification of the device
- Description string `json:"description"`
-
- // FilterID is populated with the unfiltered device ID if a numeric ID is used
- // so the device can be included.
- FilterID string `json:"filter_id,omitempty"`
-
- // Integrated is set true for integrated GPUs, false for Discrete GPUs
- Integrated bool `json:"integration,omitempty"`
-
- // PCIID is the bus, device and domain ID of the device for deduplication
- // when discovered by multiple backends
- PCIID string `json:"pci_id,omitempty"`
-
- // TotalMemory is the total amount of memory the device can use for loading models
- TotalMemory uint64 `json:"total_memory"`
-
- // FreeMemory is the amount of memory currently available on the device for loading models
- FreeMemory uint64 `json:"free_memory,omitempty"`
-
- // ComputeMajor is the major version of capabilities of the device
- // if unsupported by the backend, -1 will be returned
- ComputeMajor int
-
- // ComputeMinor is the minor version of capabilities of the device
- // if unsupported by the backend, -1 will be returned
- ComputeMinor int
-
- // Driver Information
- DriverMajor int `json:"driver_major,omitempty"`
- DriverMinor int `json:"driver_minor,omitempty"`
-
- // Where backends were loaded from
- LibraryPath []string
-}
-
-type SystemInfo struct {
- // ThreadCount is the optimal number of threads to use for inference
- ThreadCount int `json:"threads,omitempty"`
-
- // TotalMemory is the total amount of system memory
- TotalMemory uint64 `json:"total_memory,omitempty"`
-
- // FreeMemory is the amount of memory currently available on the system for loading models
- FreeMemory uint64 `json:"free_memory,omitempty"`
-
- // FreeSwap is the amount of system swap space reported as available
- FreeSwap uint64 `json:"free_swap,omitempty"`
-}
-
-func (d DeviceInfo) Compute() string {
- // AMD gfx is encoded into the major minor in hex form
- if strings.EqualFold(d.Library, "ROCm") {
- return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
- }
- return strconv.Itoa(d.ComputeMajor) + "." + strconv.Itoa(d.ComputeMinor)
-}
-
-func (d DeviceInfo) Driver() string {
- return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
-}
-
-// MinimumMemory reports the amount of memory that should be set aside
-// on the device for overhead (e.g. VRAM consumed by context structures independent
-// of model allocations)
-func (d DeviceInfo) MinimumMemory() uint64 {
- if d.Library == "Metal" {
- return 512 * format.MebiByte
- }
- return 457 * format.MebiByte
-}
-
-// Sort by Free Space.
-// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
-type ByFreeMemory []DeviceInfo
-
-func (a ByFreeMemory) Len() int { return len(a) }
-func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
-func (a ByFreeMemory) Less(i, j int) bool {
- if a[i].Integrated && !a[j].Integrated {
- return true
- } else if !a[i].Integrated && a[j].Integrated {
- return false
- }
- return a[i].FreeMemory < a[j].FreeMemory
-}
-
-// ByPerformance groups devices by similar speed
-func ByPerformance(l []DeviceInfo) [][]DeviceInfo {
- resp := [][]DeviceInfo{}
- scores := []bool{}
- for _, info := range l {
- found := false
- requested := info.Integrated
- for i, score := range scores {
- if score == requested {
- resp[i] = append(resp[i], info)
- found = true
- break
- }
- }
- if !found {
- scores = append(scores, requested)
- resp = append(resp, []DeviceInfo{info})
- }
- }
- return resp
-}
-
-func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
- resp := [][]DeviceInfo{}
- libs := []string{}
- for _, info := range l {
- found := false
- requested := info.Library
- for i, lib := range libs {
- if lib == requested {
- resp[i] = append(resp[i], info)
- found = true
- break
- }
- }
- if !found {
- libs = append(libs, requested)
- resp = append(resp, []DeviceInfo{info})
- }
- }
- return resp
-}
-
-func LibraryPaths(l []DeviceInfo) []string {
- gpuLibs := []string{LibOllamaPath}
- for _, gpu := range l {
- for _, dir := range gpu.LibraryPath {
- needed := true
- for _, existing := range gpuLibs {
- if dir == existing {
- needed = false
- break
- }
- }
- if needed {
- gpuLibs = append(gpuLibs, dir)
- }
- }
- }
- return gpuLibs
-}
-
-type DeviceComparison int
-
-const (
- UniqueDevice DeviceComparison = iota
- SameBackendDevice // The device is the same, and the library/backend is the same
- DuplicateDevice // The same physical device but different library/backend (overlapping device)
-)
-
-func (a DeviceInfo) Compare(b DeviceInfo) DeviceComparison {
- if a.PCIID != b.PCIID {
- return UniqueDevice
- }
- // If PCIID is empty, we have to use ID + library for uniqueness
- if a.PCIID == "" && a.DeviceID != b.DeviceID {
- return UniqueDevice
- }
- if a.Library == b.Library {
- return SameBackendDevice
- }
- return DuplicateDevice
-}
-
-// For a SameBackendDevice, return true if b is better than a
-// e.g. newer GPU library version
-func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
- aLib := a.LibraryPath[len(a.LibraryPath)-1]
- bLib := b.LibraryPath[len(b.LibraryPath)-1]
- if aLib == bLib {
- return false
- }
- aLibSplit := strings.SplitN(aLib, "_", 2)
- bLibSplit := strings.SplitN(bLib, "_", 2)
- if len(aLibSplit) < 2 || len(bLibSplit) < 2 {
- return false
- }
- if aLibSplit[0] != bLibSplit[0] {
- slog.Debug("unexpected libraries", "a", aLib, "b", bLib)
- return false
- }
- if aLibSplit[1] == bLibSplit[1] {
- return false
- }
- cmp := []string{aLibSplit[1], bLibSplit[1]}
- sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
- return cmp[0] == bLibSplit[1]
-}
-
-// For each GPU, check if it does NOT support flash attention
-func FlashAttentionSupported(l []DeviceInfo) bool {
- for _, gpu := range l {
- supportsFA := gpu.Library == "cpu" ||
- gpu.Name == "Metal" || gpu.Library == "Metal" ||
- (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
- gpu.Library == "ROCm" ||
- gpu.Library == "Vulkan"
-
- if !supportsFA {
- return false
- }
- }
- return true
-}
-
-// Given the list of GPUs this instantiation is targeted for,
-// figure out the visible devices environment variables
-// Set mustFilter true to enable filtering of CUDA devices
-func GetVisibleDevicesEnv(l []DeviceInfo, mustFilter bool) map[string]string {
- if len(l) == 0 {
- return nil
- }
- env := map[string]string{}
- for _, d := range l {
- d.updateVisibleDevicesEnv(env, mustFilter)
- }
- return env
-}
-
-// NeedsInitValidation returns true if the device in question has the potential
-// to crash at inference time and requires deeper validation before we include
-// it in the supported devices list.
-func (d DeviceInfo) NeedsInitValidation() bool {
- // ROCm: rocblas will crash on unsupported devices.
- // CUDA: verify CC is supported by the version of the library
- return d.Library == "ROCm" || d.Library == "CUDA"
-}
-
-// Set the init validation environment variable
-func (d DeviceInfo) AddInitValidation(env map[string]string) {
- env["GGML_CUDA_INIT"] = "1" // force deep initialization to trigger crash on unsupported GPUs
-}
-
-// PreferredLibrary returns true if this library is preferred over the other input
-// library
-// Used to filter out Vulkan in favor of CUDA or ROCm
-func (d DeviceInfo) PreferredLibrary(other DeviceInfo) bool {
- // TODO in the future if we find Vulkan is better than ROCm on some devices
- // that implementation can live here.
-
- if d.Library == "CUDA" || d.Library == "ROCm" {
- return true
- }
- return false
-}
-
-func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string, mustFilter bool) {
- var envVar string
- switch d.Library {
- case "ROCm":
- // ROCm must be filtered as it can crash the runner on unsupported devices
- envVar = "ROCR_VISIBLE_DEVICES"
- if runtime.GOOS != "linux" {
- envVar = "HIP_VISIBLE_DEVICES"
- }
- case "CUDA":
- if !mustFilter {
- // By default we try to avoid filtering CUDA devices because ROCm also
- // looks at the CUDA env var, and gets confused in mixed vendor environments.
- return
- }
- envVar = "CUDA_VISIBLE_DEVICES"
- default:
- // Vulkan is not filtered via env var, but via scheduling decisions
- return
- }
- v, existing := env[envVar]
- if existing {
- v = v + ","
- }
- if d.FilterID != "" {
- v = v + d.FilterID
- } else {
- v = v + d.ID
- }
- env[envVar] = v
-}
-
-type BaseRunner interface {
- // GetPort returns the localhost port number the runner is running on
- GetPort() int
-
- // HasExited indicates if the runner is no longer running. This can be used during
- // bootstrap to detect if a given filtered device is incompatible and triggered an assert
- HasExited() bool
-}
-
-type RunnerDiscovery interface {
- BaseRunner
-
- // GetDeviceInfos will perform a query of the underlying device libraries
- // for device identification and free VRAM information
- // During bootstrap scenarios, this routine may take seconds to complete
- GetDeviceInfos(ctx context.Context) []DeviceInfo
-}
-
-type FilteredRunnerDiscovery interface {
- RunnerDiscovery
-
- // GetActiveDeviceIDs returns the filtered set of devices actively in
- // use by this runner for running models. If the runner is a bootstrap runner, no devices
- // will be active yet so no device IDs are returned.
- // This routine will not query the underlying device and will return immediately
- GetActiveDeviceIDs() []DeviceID
-}
-
-func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
- var moreDevices []DeviceInfo
- port := runner.GetPort()
- tick := time.Tick(10 * time.Millisecond)
- for {
- select {
- case <-ctx.Done():
- return nil, fmt.Errorf("failed to finish discovery before timeout")
- case <-tick:
- r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
- if err != nil {
- return nil, fmt.Errorf("failed to create request: %w", err)
- }
- r.Header.Set("Content-Type", "application/json")
-
- resp, err := http.DefaultClient.Do(r)
- if err != nil {
- // slog.Warn("failed to send request", "error", err)
- if runner.HasExited() {
- return nil, fmt.Errorf("runner crashed")
- }
- continue
- }
- defer resp.Body.Close()
-
- if resp.StatusCode == http.StatusNotFound {
- // old runner, fall back to bootstrapping model
- return nil, fmt.Errorf("llamarunner free vram reporting not supported")
- }
-
- body, err := io.ReadAll(resp.Body)
- if err != nil {
- slog.Warn("failed to read response", "error", err)
- continue
- }
- if resp.StatusCode != 200 {
- logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
- return nil, fmt.Errorf("runner error: %s", string(body))
- }
-
- if err := json.Unmarshal(body, &moreDevices); err != nil {
- slog.Warn("unmarshal encode response", "error", err)
- continue
- }
- return moreDevices, nil
- }
- }
-}
diff --git a/x/ml/nn/attention.go b/x/ml/nn/attention.go
deleted file mode 100644
index c4a16a3028b..00000000000
--- a/x/ml/nn/attention.go
+++ /dev/null
@@ -1,103 +0,0 @@
-package nn
-
-import (
- "fmt"
-
- "github.com/ollama/ollama/x/kvcache"
- "github.com/ollama/ollama/x/ml"
-)
-
-// Attention implements scaled dot-product attention for transformer models:
-// Attention(Q, K, V) = softmax(QK^T/√d_k)V
-//
-// Parameters:
-// - ctx: Context for tensor operations
-// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
-// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
-// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
-// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
-// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
-//
-// Returns:
-//
-// Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
- return AttentionWithVMLA(ctx, query, key, value, nil, nil, scale, cache)
-}
-
-func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
- return AttentionWithVMLA(ctx, query, key, value, sinks, nil, scale, cache)
-}
-
-func AttentionWithVMLA(ctx ml.Context, query, key, value, sinks ml.Tensor, vmla ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
- ctx.Forward(query)
-
- if key != nil && value != nil {
- if query.Dim(0) != key.Dim(0) {
- panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
- }
-
- if key.Dim(1) != value.Dim(1) {
- panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
- }
-
- if key.Dim(2) != value.Dim(2) {
- panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
- }
-
- ctx.Forward(key, value)
- if cache != nil {
- cache.Put(ctx, key, value)
- }
- } else if cache == nil {
- panic("key & value tensors must be provided if cache is nil")
- }
-
- // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query, "k": key, "v": value}, true)
- // panic("after cache get") //
- // 2025/12/10 16:02:33 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
- // 2025/12/10 16:02:33 INFO XXX tensors are similar k=0.9999891519546509 shape="[1 4 13 256]" min_difference=[-0.21365738] max_difference=[0.19916534]
- // 2025/12/10 16:02:33 INFO XXX tensors are similar v=0.9999960660934448 shape="[1 4 13 256]" min_difference=[-0.32923126] max_difference=[0.32646942]
-
- // var mask ml.Tensor
- if cache != nil {
- key, value, _ = cache.Get(ctx)
- }
- // ctx.CompareWith("/tmp/test", map[string]ml.Tensor{"q": query.Contiguous(ctx, false), "k": key.Contiguous(ctx, false), "v": value.Contiguous(ctx, false)}, true)
- // panic("after cache get") //
- // 2025/12/10 15:34:03 INFO XXX tensors are similar q=0.9999869465827942 shape="[1 8 13 256]" min_difference=[-0.07926178] max_difference=[0.07012844]
- // 2025/12/10 15:34:03 INFO XXX tensors are similar k=0.9999881982803345 shape="[1 4 13 256]" min_difference=[-0.25] max_difference=[0.25]
- // 2025/12/10 15:34:03 INFO XXX tensors are similar v=0.9999913573265076 shape="[1 4 13 256]" min_difference=[-0.5] max_difference=[0.5]
-
- // Only use the fast SDPA implementation if we have a cache, since that's what
- // will do any expected backend-specific transformations for us
-
- if cache != nil {
- // TODO what to do with vmla?
- // return query.Transpose(ctx, 0, 2, 1, 3).ScaledDotProductAttention(ctx, key.Transpose(ctx, 0, 2, 1, 3), value.Transpose(ctx, 0, 2, 1, 3), scale, "array", mask, sinks)
- return query.ScaledDotProductAttention(ctx, key, value, scale, "causal", nil, sinks)
-
- // TODO these two produce identical output, but not similar enough - 92.9% - should be 99.999%
- } else {
- panic("else case not supported")
- // TODO transpose shapes are wrong
- // key = key.Transpose(ctx, 0, 2, 1, 3)
- // value = value.Transpose(ctx, 1, 2, 0, 3).Contiguous(ctx, false)
-
- // kq := query.Matmul(ctx, key)
-
- // kq = kq.Scale(ctx, scale)
- // if mask != nil {
- // kq = kq.Add(ctx, mask)
- // }
- // kq = kq.Softmax(ctx)
-
- // kqv := kq.Matmul(ctx, value)
-
- // if vmla != nil {
- // kqv = kqv.Matmul(ctx, vmla)
- // }
-
- // return kqv.Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
- }
-}
diff --git a/x/ml/nn/convolution.go b/x/ml/nn/convolution.go
deleted file mode 100644
index 7c4b5a52003..00000000000
--- a/x/ml/nn/convolution.go
+++ /dev/null
@@ -1,30 +0,0 @@
-package nn
-
-import "github.com/ollama/ollama/x/ml"
-
-type Conv2D struct {
- Weight ml.Tensor `gguf:"weight"`
- Bias ml.Tensor `gguf:"bias"`
-}
-
-func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
- t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1, 1)
- if m.Bias != nil {
- // Bias shape is (out_channels,) while t shape is (width, height, out_channels, batch)
- t = t.Add(ctx, m.Bias.Reshape(ctx, 1, 1, -1))
- }
- return t
-}
-
-type Conv3D struct {
- Weight ml.Tensor `gguf:"weight"`
- Bias ml.Tensor `gguf:"bias"`
-}
-
-func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, s2, p0, p1, p2, d0, d1, d2, g int) ml.Tensor {
- t = m.Weight.Conv3D(ctx, t, s0, s1, s2, p0, p1, p2, d0, d1, d2, g)
- if m.Bias != nil {
- t = t.Add(ctx, m.Bias)
- }
- return t
-}
diff --git a/x/ml/nn/embedding.go b/x/ml/nn/embedding.go
deleted file mode 100644
index b00aa2ff1cd..00000000000
--- a/x/ml/nn/embedding.go
+++ /dev/null
@@ -1,11 +0,0 @@
-package nn
-
-import "github.com/ollama/ollama/x/ml"
-
-type Embedding struct {
- Weight ml.Tensor `gguf:"weight"`
-}
-
-func (m *Embedding) Forward(ctx ml.Context, hiddenState ml.Tensor) ml.Tensor {
- return m.Weight.TakeAxes(ctx, hiddenState, 0)
-}
diff --git a/x/ml/nn/linear.go b/x/ml/nn/linear.go
deleted file mode 100644
index 6d108e0950a..00000000000
--- a/x/ml/nn/linear.go
+++ /dev/null
@@ -1,32 +0,0 @@
-package nn
-
-import "github.com/ollama/ollama/x/ml"
-
-type Linear struct {
- Weight ml.Tensor `gguf:"weight"`
- Bias ml.Tensor `gguf:"bias"`
-}
-
-func (m *Linear) Forward(ctx ml.Context, t ml.Tensor) ml.Tensor {
- t = t.Matmul(ctx, m.Weight.Transpose(ctx))
- if m.Bias != nil {
- t = t.Add(ctx, m.Bias)
- }
-
- return t
-}
-
-type LinearBatch struct {
- Weight ml.Tensor `gguf:"weight"`
- Bias ml.Tensor `gguf:"bias"`
-}
-
-func (m *LinearBatch) Forward(ctx ml.Context, t, indices ml.Tensor) ml.Tensor {
- panic("not yet ported")
- // t = m.Weight.MulmatID(ctx, t, indices)
- // if m.Bias != nil {
- // t = t.AddID(ctx, m.Bias, indices)
- // }
-
- // return t
-}
diff --git a/x/ml/nn/normalization.go b/x/ml/nn/normalization.go
deleted file mode 100644
index 621245ab46a..00000000000
--- a/x/ml/nn/normalization.go
+++ /dev/null
@@ -1,29 +0,0 @@
-package nn
-
-import (
- "github.com/ollama/ollama/x/ml"
-)
-
-type LayerNorm struct {
- Weight ml.Tensor `gguf:"weight"`
- Bias ml.Tensor `gguf:"bias"`
-}
-
-func (m *LayerNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
- return t.LayerNorm(ctx, m.Weight, m.Bias, eps)
-}
-
-type RMSNorm struct {
- Weight ml.Tensor `gguf:"weight"`
-}
-
-func (m *RMSNorm) Forward(ctx ml.Context, t ml.Tensor, eps float32) ml.Tensor {
- // slog.Info("RMSNorm", "eps", eps)
- // fmt.Fprintln(os.Stderr, t.ToString())
- // fmt.Fprintln(os.Stderr, m.Weight.ToString())
-
- // TODO this is probably model specific, not generalized...
- w := m.Weight.Add(ctx, ctx.FromFloats([]float32{1.0}, 1))
-
- return t.RMSNorm(ctx, w, eps)
-}
diff --git a/x/ml/nn/pooling/pooling.go b/x/ml/nn/pooling/pooling.go
deleted file mode 100644
index 2dae6dc4381..00000000000
--- a/x/ml/nn/pooling/pooling.go
+++ /dev/null
@@ -1,41 +0,0 @@
-package pooling
-
-import (
- "github.com/ollama/ollama/x/ml"
-)
-
-type Type uint32
-
-const (
- TypeNone Type = iota
- TypeMean
- TypeCLS
- TypeLast
-)
-
-func (t Type) String() string {
- switch t {
- case TypeMean:
- return "Mean"
- case TypeCLS:
- return "CLS"
- case TypeLast:
- return "Last"
- default:
- return "Unknown"
- }
-}
-
-func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
- switch t {
- // case TypeMean:
- // hiddenStates = hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false).Mean(ctx)
- // return hiddenStates.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
- // case TypeCLS:
- // return hiddenStates.Slice(ctx, 1, 0, 1, 1)
- // case TypeLast:
- // return hiddenStates.Slice(ctx, 1, hiddenStates.Dim(1)-1, hiddenStates.Dim(1), 1)
- default:
- panic("unknown pooling type")
- }
-}
diff --git a/x/ml/nn/rope/rope.go b/x/ml/nn/rope/rope.go
deleted file mode 100644
index e868aa61498..00000000000
--- a/x/ml/nn/rope/rope.go
+++ /dev/null
@@ -1,72 +0,0 @@
-package rope
-
-import "github.com/ollama/ollama/x/ml"
-
-// Options contains optional parameters for RoPE function
-type Options struct {
- Type int
- Factors ml.Tensor
-
- // YaRN options
- YaRN struct {
- OriginalContextLength int
- ExtrapolationFactor,
- AttentionFactor,
- BetaFast,
- BetaSlow float32
- }
-
- // MRoPE options
- MRoPE struct {
- Sections []int
- }
-}
-
-// WithTypeNeoX sets RoPE type to NeoX
-func WithTypeNeoX() func(*Options) {
- return func(opts *Options) {
- opts.Type = 2
- }
-}
-
-// WithFactors sets custom rope factors
-func WithFactors(factors ml.Tensor) func(*Options) {
- return func(opts *Options) {
- if factors != nil {
- opts.Factors = factors
- }
- }
-}
-
-// WithOriginalContextLength sets a custom context length
-func WithOriginalContextLength(n int) func(*Options) {
- return func(opts *Options) {
- opts.YaRN.OriginalContextLength = n
- }
-}
-
-func WithExtrapolationFactor(extrapolationFactor float32) func(*Options) {
- return func(opts *Options) {
- opts.YaRN.ExtrapolationFactor = extrapolationFactor
- }
-}
-
-func WithAttentionFactor(attentionFactor float32) func(*Options) {
- return func(opts *Options) {
- opts.YaRN.AttentionFactor = attentionFactor
- }
-}
-
-func WithMRoPE(sections []int) func(*Options) {
- return func(opts *Options) {
- opts.Type |= 1 << 3
- opts.MRoPE.Sections = sections
- }
-}
-
-func WithInterleaveMRoPE(sections []int) func(*Options) {
- return func(opts *Options) {
- opts.Type |= 1<<3 | 1<<5
- opts.MRoPE.Sections = sections
- }
-}
diff --git a/x/ml/path.go b/x/ml/path.go
deleted file mode 100644
index ac93af403f2..00000000000
--- a/x/ml/path.go
+++ /dev/null
@@ -1,56 +0,0 @@
-package ml
-
-import (
- "os"
- "path/filepath"
- "runtime"
-)
-
-// LibPath is a path to lookup dynamic libraries
-// in development it's usually 'build/lib/ollama'
-// in distribution builds it's 'lib/ollama' on Windows
-// '../lib/ollama' on Linux and the executable's directory on macOS
-// note: distribution builds, additional GPU-specific libraries are
-// found in subdirectories of the returned path, such as
-// 'cuda_v12', 'rocm', etc.
-var LibOllamaPath string = func() string {
- exe, err := os.Executable()
- if err != nil {
- return ""
- }
-
- if eval, err := filepath.EvalSymlinks(exe); err == nil {
- exe = eval
- }
-
- var libPath string
- switch runtime.GOOS {
- case "windows":
- libPath = filepath.Join(filepath.Dir(exe), "lib", "ollama")
- case "linux":
- libPath = filepath.Join(filepath.Dir(exe), "..", "lib", "ollama")
- case "darwin":
- libPath = filepath.Dir(exe)
- }
-
- cwd, err := os.Getwd()
- if err != nil {
- return ""
- }
-
- paths := []string{
- libPath,
-
- // build paths for development
- filepath.Join(filepath.Dir(exe), "build", "lib", "ollama"),
- filepath.Join(cwd, "build", "lib", "ollama"),
- }
-
- for _, p := range paths {
- if _, err := os.Stat(p); err == nil {
- return p
- }
- }
-
- return filepath.Dir(exe)
-}()
diff --git a/x/mlxrunner/cache.go b/x/mlxrunner/cache.go
new file mode 100644
index 00000000000..49ddd04b6cc
--- /dev/null
+++ b/x/mlxrunner/cache.go
@@ -0,0 +1,96 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ "log/slog"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+)
+
+type CacheEntry struct {
+ Caches []cache.Cache
+ Count int
+ Entries map[int32]*CacheEntry
+}
+
+func (s Runner) FindNearestCache(tokens []int32) ([]cache.Cache, []int32) {
+ current := &CacheEntry{Entries: s.CacheEntries}
+ index, cacheIndex := 0, -1
+ for _, token := range tokens {
+ if _, ok := current.Entries[token]; !ok {
+ break
+ }
+
+ current = current.Entries[token]
+ if len(current.Caches) > 0 {
+ cacheIndex = index
+ }
+
+ index += 1
+ }
+
+ if cacheIndex == len(tokens)-1 {
+ slog.Info("Cache hit", "type", "exact", "total", len(tokens), "cached", len(tokens), "left", len(tokens))
+ return current.Caches, []int32{}
+ } else if cacheIndex > 1 {
+ slog.Info("Cache hit", "type", "partial", "total", len(tokens), "cached", cacheIndex+1, "left", len(tokens[cacheIndex+1:]))
+ return current.Caches, tokens[cacheIndex+1:]
+ } else if index > 0 && cacheIndex < 0 {
+ type stackItem struct {
+ entry *CacheEntry
+ tokens []int32
+ }
+
+ var best, item stackItem
+ stack := []stackItem{{entry: current, tokens: []int32{}}}
+ for len(stack) > 0 {
+ item, stack = stack[len(stack)-1], stack[:len(stack)-1]
+ if len(item.entry.Caches) > 0 {
+ if len(best.tokens) == 0 || len(item.tokens) < len(best.tokens) {
+ best = item
+ }
+ } else {
+ for token, entry := range item.entry.Entries {
+ stack = append(stack, stackItem{
+ entry: entry,
+ tokens: append(item.tokens, token),
+ })
+ }
+ }
+ }
+
+ prefix := min(len(tokens)-1, index)
+ caches := make([]cache.Cache, len(best.entry.Caches))
+ trim := len(best.tokens)+1
+ for i := range caches {
+ caches[i] = best.entry.Caches[i].Clone()
+ caches[i].Trim(trim)
+ }
+
+ slog.Info("Cache hit", "type", "prefix", "total", len(tokens), "cached", prefix, "left", len(tokens[prefix:]), "trimmed", trim)
+ return caches, tokens[prefix:]
+ }
+
+ slog.Info("Cache miss", "left", len(tokens))
+ return nil, tokens
+}
+
+func (s *Runner) InsertCache(tokens []int32, caches []cache.Cache) {
+ current := &CacheEntry{Entries: s.CacheEntries}
+ for _, token := range tokens {
+ if _, ok := current.Entries[token]; !ok {
+ current.Entries[token] = &CacheEntry{
+ Entries: make(map[int32]*CacheEntry),
+ }
+ }
+
+ current = current.Entries[token]
+ }
+
+ if len(current.Caches) > 0 {
+ current.Count += 1
+ } else {
+ current.Caches = caches
+ }
+}
diff --git a/x/mlxrunner/cache/cache.go b/x/mlxrunner/cache/cache.go
new file mode 100644
index 00000000000..05cffbf5ee6
--- /dev/null
+++ b/x/mlxrunner/cache/cache.go
@@ -0,0 +1,198 @@
+//go:build mlx
+
+package cache
+
+import (
+ "log/slog"
+
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+)
+
+type Cache interface {
+ Update(keys, values *mlx.Array) (newKeys, newValues *mlx.Array)
+ State() (keys, values *mlx.Array)
+ Trim(int) int
+ Clone() Cache
+ Offset() int
+ Len() int
+}
+
+type KVCache struct {
+ keys, values *mlx.Array
+ offset int
+ step int
+}
+
+func NewKVCache() *KVCache {
+ return &KVCache{step: 256}
+}
+
+func (c *KVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
+ B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
+
+ prev := c.offset
+
+ // Grow buffer if needed
+ if c.keys == nil || (prev+L) > c.keys.Dim(2) {
+ steps := (c.step + L - 1) / c.step
+ newKeys := mlx.Zeros(keys.DType(), B, H, steps*c.step, Dk)
+ newValues := mlx.Zeros(values.DType(), B, H, steps*c.step, Dv)
+
+ if c.keys != nil {
+ if prev%c.step != 0 {
+ c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
+ c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, prev), mlx.Slice()))
+ }
+ c.keys.Set(c.keys.Concatenate(2, newKeys))
+ c.values.Set(c.values.Concatenate(2, newValues))
+ } else {
+ c.keys, c.values = newKeys, newValues
+ }
+ }
+
+ c.offset += L
+ c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
+ c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(prev, c.offset), mlx.Slice()))
+
+ return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
+ c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
+}
+
+func (c *KVCache) State() (*mlx.Array, *mlx.Array) {
+ if c.offset == c.keys.Dim(2) {
+ return c.keys, c.values
+ }
+ return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
+ c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
+}
+
+func (c *KVCache) Trim(n int) int {
+ n = min(c.offset, n)
+ c.offset -= n
+ return n
+}
+
+func (c *KVCache) Clone() Cache {
+ return &KVCache{
+ keys: c.keys.Clone(),
+ values: c.values.Clone(),
+ offset: c.offset,
+ step: c.step,
+ }
+}
+
+func (c *KVCache) Offset() int { return c.offset }
+func (c *KVCache) Len() int { return c.offset }
+
+// RotatingKVCache implements sliding window attention with bounded memory
+type RotatingKVCache struct {
+ maxSize int
+ idx int
+
+ *KVCache
+}
+
+func NewRotatingKVCache(maxSize int) *RotatingKVCache {
+ return &RotatingKVCache{maxSize: maxSize, KVCache: NewKVCache()}
+}
+
+func (c *RotatingKVCache) Update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
+ if keys.Dim(2) > 1 {
+ return c.concat(keys, values)
+ }
+ return c.update(keys, values)
+}
+
+func (c *RotatingKVCache) concat(keys, values *mlx.Array) (newK *mlx.Array, newV *mlx.Array) {
+ slog.Debug("(*RotatingKVCache).concat", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
+ if c.keys == nil {
+ c.keys, c.values = keys, values
+ } else {
+ if c.idx < c.keys.Dim(2) {
+ c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
+ c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.idx), mlx.Slice()))
+ }
+
+ // Trim to max_size to maintain sliding window
+ if trim := c.idx - c.maxSize + 1; trim > 0 {
+ c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
+ c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
+ }
+
+ c.keys.Set(c.keys.Concatenate(2, keys))
+ c.values.Set(c.values.Concatenate(2, values))
+ c.idx = c.keys.Dim(2)
+ }
+
+ c.offset += keys.Dim(2)
+ c.idx = c.keys.Dim(2)
+ return c.keys, c.values
+}
+
+func (c *RotatingKVCache) update(keys, values *mlx.Array) (*mlx.Array, *mlx.Array) {
+ slog.Debug("(*RotatingKVCache).update", "keys_dim", keys.Dims(), "values_dim", values.Dims(), "offset", c.offset, "idx", c.idx, "max_size", c.maxSize)
+ B, H, L, Dk, Dv := keys.Dim(0), keys.Dim(1), keys.Dim(2), keys.Dim(3), values.Dim(3)
+
+ prev := c.offset
+
+ // Grow buffer if not yet at max
+ if c.keys == nil || (prev >= c.keys.Dim(2) && c.keys.Dim(2) < c.maxSize) {
+ newSize := min(c.step, c.maxSize-prev)
+ newKeys := mlx.Zeros(keys.DType(), B, H, newSize, Dk)
+ newValues := mlx.Zeros(values.DType(), B, H, newSize, Dv)
+ if c.keys != nil {
+ c.keys.Set(c.keys.Concatenate(2, newKeys))
+ c.values.Set(c.values.Concatenate(2, newValues))
+ } else {
+ c.keys, c.values = newKeys, newValues
+ }
+ c.idx = prev
+ }
+
+ // Trim to max_size to maintain sliding window
+ if trim := c.keys.Dim(2) - c.maxSize; trim > 0 {
+ c.keys.Set(c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.keys.Dim(2)), mlx.Slice()))
+ c.values.Set(c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(trim, c.values.Dim(2)), mlx.Slice()))
+ c.idx = c.maxSize
+ }
+
+ // Rotate when hitting max
+ if c.idx >= c.maxSize {
+ c.idx = 0
+ }
+
+ c.keys.Set(c.keys.SliceUpdate(keys, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
+ c.values.Set(c.values.SliceUpdate(values, mlx.Slice(), mlx.Slice(), mlx.Slice(c.idx, c.idx+L), mlx.Slice()))
+
+ c.offset += L
+ c.idx += L
+
+ validLen := min(c.offset, c.maxSize)
+ return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice()),
+ c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, validLen), mlx.Slice())
+}
+
+func (c *RotatingKVCache) State() (*mlx.Array, *mlx.Array) {
+ if c.offset < c.keys.Dim(2) {
+ return c.keys.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice()),
+ c.values.Slice(mlx.Slice(), mlx.Slice(), mlx.Slice(0, c.offset), mlx.Slice())
+ }
+ return c.keys, c.values
+}
+
+func (c *RotatingKVCache) Trim(n int) int {
+ n = min(c.offset, n)
+ c.offset -= n
+ c.idx -= n
+ return n
+}
+
+func (c *RotatingKVCache) Clone() Cache {
+ return &RotatingKVCache{
+ maxSize: c.maxSize,
+ idx: c.idx,
+ KVCache: c.KVCache.Clone().(*KVCache),
+ }
+}
+
+func (c *RotatingKVCache) Len() int { return min(c.offset, c.maxSize) }
diff --git a/x/mlxrunner/client.go b/x/mlxrunner/client.go
new file mode 100644
index 00000000000..19e987736fb
--- /dev/null
+++ b/x/mlxrunner/client.go
@@ -0,0 +1,414 @@
+package mlxrunner
+
+import (
+ "bufio"
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "log/slog"
+ "math"
+ "math/rand"
+ "net"
+ "net/http"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "runtime"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/ollama/ollama/llm"
+ "github.com/ollama/ollama/ml"
+ "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/x/imagegen/manifest"
+)
+
+// Client wraps an MLX runner subprocess to implement llm.LlamaServer for LLM models.
+type Client struct {
+ port int
+ modelName string
+ vramSize uint64
+ done chan error
+ client *http.Client
+ lastErr string
+ lastErrLock sync.Mutex
+ mu sync.Mutex
+ cmd *exec.Cmd
+}
+
+// NewClient spawns a new MLX runner subprocess for LLM models and waits until it's ready.
+func NewClient(modelName string) (*Client, error) {
+ if err := imagegen.CheckPlatformSupport(); err != nil {
+ return nil, err
+ }
+
+ // Find a free port
+ port := 0
+ if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
+ if l, err := net.ListenTCP("tcp", a); err == nil {
+ port = l.Addr().(*net.TCPAddr).Port
+ l.Close()
+ }
+ }
+ if port == 0 {
+ port = rand.Intn(65535-49152) + 49152
+ }
+
+ // Get the current executable path
+ exe, err := os.Executable()
+ if err != nil {
+ return nil, fmt.Errorf("unable to lookup executable path: %w", err)
+ }
+ if eval, err := filepath.EvalSymlinks(exe); err == nil {
+ exe = eval
+ }
+
+ // Spawn subprocess: ollama runner --mlx-engine --model --port
+ cmd := exec.Command(exe, "runner", "--mlx-engine", "--model", modelName, "--port", strconv.Itoa(port))
+ cmd.Env = os.Environ()
+
+ // On Linux, set LD_LIBRARY_PATH to include MLX library directories
+ if runtime.GOOS == "linux" {
+ libraryPaths := []string{ml.LibOllamaPath}
+ if mlxDirs, err := filepath.Glob(filepath.Join(ml.LibOllamaPath, "mlx_*")); err == nil {
+ libraryPaths = append(libraryPaths, mlxDirs...)
+ }
+
+ if existingPath, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
+ libraryPaths = append(libraryPaths, filepath.SplitList(existingPath)...)
+ }
+
+ pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
+
+ found := false
+ for i := range cmd.Env {
+ if strings.HasPrefix(cmd.Env[i], "LD_LIBRARY_PATH=") {
+ cmd.Env[i] = "LD_LIBRARY_PATH=" + pathEnvVal
+ found = true
+ break
+ }
+ }
+ if !found {
+ cmd.Env = append(cmd.Env, "LD_LIBRARY_PATH="+pathEnvVal)
+ }
+ slog.Debug("mlx subprocess library path", "LD_LIBRARY_PATH", pathEnvVal)
+ }
+
+ // Estimate VRAM based on tensor size from manifest
+ var vramSize uint64
+ if modelManifest, err := manifest.LoadManifest(modelName); err == nil {
+ vramSize = uint64(modelManifest.TotalTensorSize())
+ } else {
+ vramSize = 8 * 1024 * 1024 * 1024
+ }
+
+ c := &Client{
+ port: port,
+ modelName: modelName,
+ vramSize: vramSize,
+ done: make(chan error, 1),
+ client: &http.Client{Timeout: 10 * time.Minute},
+ cmd: cmd,
+ }
+
+ // Forward subprocess stdout/stderr to server logs
+ stdout, _ := cmd.StdoutPipe()
+ stderr, _ := cmd.StderrPipe()
+ go func() {
+ scanner := bufio.NewScanner(stdout)
+ for scanner.Scan() {
+ slog.Info("mlx-runner", "msg", scanner.Text())
+ }
+ }()
+ go func() {
+ scanner := bufio.NewScanner(stderr)
+ for scanner.Scan() {
+ line := scanner.Text()
+ slog.Warn("mlx-runner", "msg", line)
+ c.lastErrLock.Lock()
+ c.lastErr = line
+ c.lastErrLock.Unlock()
+ }
+ }()
+
+ slog.Info("starting mlx runner subprocess", "exe", exe, "model", modelName, "port", port)
+ if err := cmd.Start(); err != nil {
+ return nil, fmt.Errorf("failed to start mlx runner: %w", err)
+ }
+
+ // Reap subprocess when it exits
+ go func() {
+ err := cmd.Wait()
+ c.done <- err
+ }()
+
+ // Wait for subprocess to be ready
+ if err := c.waitUntilRunning(); err != nil {
+ c.Close()
+ return nil, err
+ }
+
+ return c, nil
+}
+
+func (c *Client) getLastErr() string {
+ c.lastErrLock.Lock()
+ defer c.lastErrLock.Unlock()
+ return c.lastErr
+}
+
+func (c *Client) waitUntilRunning() error {
+ ctx := context.Background()
+ timeout := time.After(2 * time.Minute)
+ ticker := time.NewTicker(100 * time.Millisecond)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case err := <-c.done:
+ errMsg := c.getLastErr()
+ if errMsg != "" {
+ return fmt.Errorf("mlx runner failed: %s (exit: %v)", errMsg, err)
+ }
+ return fmt.Errorf("mlx runner exited unexpectedly: %w", err)
+ case <-timeout:
+ errMsg := c.getLastErr()
+ if errMsg != "" {
+ return fmt.Errorf("timeout waiting for mlx runner: %s", errMsg)
+ }
+ return errors.New("timeout waiting for mlx runner to start")
+ case <-ticker.C:
+ if err := c.Ping(ctx); err == nil {
+ slog.Info("mlx runner is ready", "port", c.port)
+ return nil
+ }
+ }
+ }
+}
+
+// completionRequest is a properly-tagged version of llm.CompletionRequest for JSON serialization.
+type completionRequest struct {
+ Prompt string `json:"prompt"`
+ Options *completionOpts `json:"options,omitempty"`
+}
+
+type completionOpts struct {
+ Temperature float32 `json:"temperature,omitempty"`
+ TopP float32 `json:"top_p,omitempty"`
+ MinP float32 `json:"min_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ NumPredict int `json:"num_predict,omitempty"`
+}
+
+// Close terminates the subprocess.
+func (c *Client) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if c.cmd != nil && c.cmd.Process != nil {
+ slog.Info("stopping mlx runner subprocess", "pid", c.cmd.Process.Pid)
+ c.cmd.Process.Signal(os.Interrupt)
+
+ select {
+ case <-c.done:
+ case <-time.After(5 * time.Second):
+ c.cmd.Process.Kill()
+ }
+ c.cmd = nil
+ }
+ return nil
+}
+
+// Completion implements llm.LlamaServer.
+func (c *Client) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
+ creq := completionRequest{
+ Prompt: req.Prompt,
+ }
+ if req.Options != nil {
+ creq.Options = &completionOpts{
+ Temperature: req.Options.Temperature,
+ TopP: req.Options.TopP,
+ MinP: req.Options.MinP,
+ TopK: req.Options.TopK,
+ NumPredict: req.Options.NumPredict,
+ }
+ }
+
+ body, err := json.Marshal(creq)
+ if err != nil {
+ return err
+ }
+
+ httpURL := fmt.Sprintf("http://127.0.0.1:%d/completion", c.port)
+ httpReq, err := http.NewRequestWithContext(ctx, "POST", httpURL, strings.NewReader(string(body)))
+ if err != nil {
+ return err
+ }
+ httpReq.Header.Set("Content-Type", "application/json")
+
+ resp, err := c.client.Do(httpReq)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode != http.StatusOK {
+ respBody, _ := io.ReadAll(resp.Body)
+ return fmt.Errorf("%s", strings.TrimSpace(string(respBody)))
+ }
+
+ scanner := bufio.NewScanner(resp.Body)
+ for scanner.Scan() {
+ var raw struct {
+ Content string `json:"content,omitempty"`
+ Done bool `json:"done"`
+ DoneReason int `json:"done_reason,omitempty"`
+ PromptEvalCount int `json:"prompt_eval_count,omitempty"`
+ PromptEvalDuration int `json:"prompt_eval_duration,omitempty"`
+ EvalCount int `json:"eval_count,omitempty"`
+ EvalDuration int `json:"eval_duration,omitempty"`
+ }
+ if err := json.Unmarshal(scanner.Bytes(), &raw); err != nil {
+ slog.Debug("mlx response parse error", "error", err, "line", string(scanner.Bytes()))
+ continue
+ }
+
+ cresp := llm.CompletionResponse{
+ Content: raw.Content,
+ Done: raw.Done,
+ DoneReason: llm.DoneReason(raw.DoneReason),
+ PromptEvalCount: raw.PromptEvalCount,
+ PromptEvalDuration: time.Duration(raw.PromptEvalDuration),
+ EvalCount: raw.EvalCount,
+ EvalDuration: time.Duration(raw.EvalDuration),
+ }
+
+ fn(cresp)
+ if cresp.Done {
+ return nil
+ }
+ }
+
+ return scanner.Err()
+}
+
+func (c *Client) ContextLength() int {
+ return math.MaxInt
+}
+
+// Detokenize implements llm.LlamaServer.
+func (c *Client) Detokenize(ctx context.Context, tokens []int) (string, error) {
+ return "", errors.New("not supported")
+}
+
+// Embedding implements llm.LlamaServer.
+func (c *Client) Embedding(ctx context.Context, input string) ([]float32, int, error) {
+ return nil, 0, errors.New("not supported")
+}
+
+// GetDeviceInfos implements llm.LlamaServer.
+func (c *Client) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
+ return nil
+}
+
+// GetPort implements llm.LlamaServer.
+func (c *Client) GetPort() int {
+ return c.port
+}
+
+// HasExited implements llm.LlamaServer.
+func (c *Client) HasExited() bool {
+ select {
+ case <-c.done:
+ return true
+ default:
+ return false
+ }
+}
+
+// Load implements llm.LlamaServer.
+func (c *Client) Load(ctx context.Context, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) ([]ml.DeviceID, error) {
+ return nil, nil
+}
+
+// ModelPath implements llm.LlamaServer.
+func (c *Client) ModelPath() string {
+ return c.modelName
+}
+
+// Pid implements llm.LlamaServer.
+func (c *Client) Pid() int {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.cmd != nil && c.cmd.Process != nil {
+ return c.cmd.Process.Pid
+ }
+ return -1
+}
+
+// Ping implements llm.LlamaServer.
+func (c *Client) Ping(ctx context.Context) error {
+ reqURL := fmt.Sprintf("http://127.0.0.1:%d/health", c.port)
+ req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil)
+ if err != nil {
+ return err
+ }
+ resp, err := c.client.Do(req)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ if resp.StatusCode != http.StatusOK {
+ return fmt.Errorf("health check failed: %d", resp.StatusCode)
+ }
+ return nil
+}
+
+// Tokenize implements llm.LlamaServer.
+func (c *Client) Tokenize(ctx context.Context, content string) ([]int, error) {
+ reqURL := fmt.Sprintf("http://127.0.0.1:%d/v1/tokenize", c.port)
+ req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(content))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "text/plain")
+
+ resp, err := c.client.Do(req)
+ if err != nil {
+ return nil, err
+ }
+ defer resp.Body.Close()
+
+ var tokens []int
+ if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
+ return nil, err
+ }
+
+ return tokens, nil
+}
+
+// TotalSize implements llm.LlamaServer.
+func (c *Client) TotalSize() uint64 {
+ return c.vramSize
+}
+
+// VRAMByGPU implements llm.LlamaServer.
+func (c *Client) VRAMByGPU(id ml.DeviceID) uint64 {
+ return c.vramSize
+}
+
+// VRAMSize implements llm.LlamaServer.
+func (c *Client) VRAMSize() uint64 {
+ return c.vramSize
+}
+
+// WaitUntilRunning implements llm.LlamaServer.
+func (c *Client) WaitUntilRunning(ctx context.Context) error {
+ return nil
+}
+
+var _ llm.LlamaServer = (*Client)(nil)
diff --git a/x/mlxrunner/imports.go b/x/mlxrunner/imports.go
new file mode 100644
index 00000000000..5e0cfe86d2f
--- /dev/null
+++ b/x/mlxrunner/imports.go
@@ -0,0 +1,10 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ _ "github.com/ollama/ollama/x/models/gemma3"
+ _ "github.com/ollama/ollama/x/models/glm4_moe_lite"
+ _ "github.com/ollama/ollama/x/models/llama"
+ _ "github.com/ollama/ollama/x/models/qwen3"
+)
diff --git a/x/mlxrunner/mlx/.gitignore b/x/mlxrunner/mlx/.gitignore
new file mode 100644
index 00000000000..b3ccd18fcee
--- /dev/null
+++ b/x/mlxrunner/mlx/.gitignore
@@ -0,0 +1,3 @@
+_deps
+build
+dist
diff --git a/x/mlxrunner/mlx/CMakeLists.txt b/x/mlxrunner/mlx/CMakeLists.txt
new file mode 100644
index 00000000000..c41ce46f79b
--- /dev/null
+++ b/x/mlxrunner/mlx/CMakeLists.txt
@@ -0,0 +1,26 @@
+cmake_minimum_required(VERSION 3.5)
+
+project(mlx)
+
+if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
+ set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_SOURCE_DIR}/dist" CACHE PATH "" FORCE)
+endif()
+
+set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
+set(MLX_BUILD_SAFETENSORS ON CACHE BOOL "" FORCE)
+set(MLX_C_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
+set(BUILD_SHARED_LIBS ON CACHE BOOL "" FORCE)
+
+set(CMAKE_INSTALL_RPATH "@loader_path")
+
+include(FetchContent)
+
+set(MLX_C_GIT_TAG "v0.4.1" CACHE STRING "")
+
+FetchContent_Declare(
+ mlx-c
+ GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
+ GIT_TAG ${MLX_C_GIT_TAG}
+)
+
+FetchContent_MakeAvailable(mlx-c)
diff --git a/x/mlxrunner/mlx/act.go b/x/mlxrunner/mlx/act.go
new file mode 100644
index 00000000000..3134a127a48
--- /dev/null
+++ b/x/mlxrunner/mlx/act.go
@@ -0,0 +1,23 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+import "math"
+
+func GELUApprox(t *Array) *Array {
+ return t.Multiply(
+ FromValue[float32](0.5),
+ ).Multiply(
+ t.Add(
+ t.Power(FromValue[float32](3.0)).Multiply(FromValue[float32](0.044715)),
+ ).Multiply(
+ FromValue(float32(math.Sqrt(2 / math.Pi))),
+ ).Tanh().Add(FromValue[float32](1.0)),
+ ).AsType(t.DType())
+}
+
+func SILU(t *Array) *Array {
+ return t.Multiply(t.Sigmoid()).AsType(t.DType())
+}
diff --git a/x/mlxrunner/mlx/array.go b/x/mlxrunner/mlx/array.go
new file mode 100644
index 00000000000..43254d230bc
--- /dev/null
+++ b/x/mlxrunner/mlx/array.go
@@ -0,0 +1,274 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "encoding/binary"
+ "log/slog"
+ "reflect"
+ "strings"
+ "time"
+ "unsafe"
+
+ "github.com/ollama/ollama/logutil"
+)
+
+type tensorDesc struct {
+ name string
+ inputs []*Array
+ numRefs int
+}
+
+func (d tensorDesc) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.String("name", d.name),
+ slog.Int("inputs", len(d.inputs)),
+ slog.Int("num_refs", d.numRefs),
+ )
+}
+
+type Array struct {
+ ctx C.mlx_array
+ desc tensorDesc
+}
+
+// constructor utilities
+
+func New(name string, inputs ...*Array) *Array {
+ t := &Array{
+ desc: tensorDesc{
+ name: name,
+ inputs: inputs,
+ },
+ }
+
+ for _, input := range inputs {
+ input.desc.numRefs++
+ }
+ logutil.Trace("New", "t", t)
+ return t
+}
+
+type scalarTypes interface {
+ ~bool | ~int | ~float32 | ~float64 | ~complex64
+}
+
+func FromValue[T scalarTypes](t T) *Array {
+ tt := New("")
+ switch v := any(t).(type) {
+ case bool:
+ tt.ctx = C.mlx_array_new_bool(C.bool(v))
+ case int:
+ tt.ctx = C.mlx_array_new_int(C.int(v))
+ case float32:
+ tt.ctx = C.mlx_array_new_float32(C.float(v))
+ case float64:
+ tt.ctx = C.mlx_array_new_float64(C.double(v))
+ case complex64:
+ tt.ctx = C.mlx_array_new_complex(C.float(real(v)), C.float(imag(v)))
+ default:
+ panic("unsupported type")
+ }
+ return tt
+}
+
+type arrayTypes interface {
+ ~bool | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
+ ~int8 | ~int16 | ~int32 | ~int64 |
+ ~float32 | ~float64 |
+ ~complex64
+}
+
+func FromValues[S ~[]E, E arrayTypes](s S, shape ...int) *Array {
+ if len(shape) == 0 {
+ panic("shape must be provided for non-scalar tensors")
+ }
+
+ cShape := make([]C.int, len(shape))
+ for i := range shape {
+ cShape[i] = C.int(shape[i])
+ }
+
+ var dtype DType
+ switch reflect.TypeOf(s).Elem().Kind() {
+ case reflect.Bool:
+ dtype = DTypeBool
+ case reflect.Uint8:
+ dtype = DTypeUint8
+ case reflect.Uint16:
+ dtype = DTypeUint16
+ case reflect.Uint32:
+ dtype = DTypeUint32
+ case reflect.Uint64:
+ dtype = DTypeUint64
+ case reflect.Int8:
+ dtype = DTypeInt8
+ case reflect.Int16:
+ dtype = DTypeInt16
+ case reflect.Int32:
+ dtype = DTypeInt32
+ case reflect.Int64:
+ dtype = DTypeInt64
+ case reflect.Float32:
+ dtype = DTypeFloat32
+ case reflect.Float64:
+ dtype = DTypeFloat64
+ case reflect.Complex64:
+ dtype = DTypeComplex64
+ default:
+ panic("unsupported type")
+ }
+
+ bts := make([]byte, binary.Size(s))
+ if _, err := binary.Encode(bts, binary.LittleEndian, s); err != nil {
+ panic(err)
+ }
+
+ tt := New("")
+ tt.ctx = C.mlx_array_new_data(unsafe.Pointer(&bts[0]), unsafe.SliceData(cShape), C.int(len(cShape)), C.mlx_dtype(dtype))
+ return tt
+}
+
+func (t *Array) Set(other *Array) {
+ Free(t.desc.inputs...)
+ other.desc.numRefs++
+ t.desc.inputs = []*Array{other}
+ C.mlx_array_set(&t.ctx, other.ctx)
+}
+
+func (t *Array) Clone() *Array {
+ tt := New(t.desc.name, t.desc.inputs...)
+ C.mlx_array_set(&tt.ctx, t.ctx)
+ return tt
+}
+
+// misc. utilities
+
+func (t *Array) Valid() bool {
+ return t.ctx.ctx != nil
+}
+
+func (t *Array) String() string {
+ str := C.mlx_string_new()
+ defer C.mlx_string_free(str)
+ C.mlx_array_tostring(&str, t.ctx)
+ return strings.TrimSpace(C.GoString(C.mlx_string_data(str)))
+}
+
+func (t *Array) LogValue() slog.Value {
+ attrs := []slog.Attr{slog.Any("", t.desc)}
+ if t.Valid() {
+ attrs = append(attrs,
+ slog.Any("dtype", t.DType()),
+ slog.Any("shape", t.Dims()),
+ slog.Int("num_bytes", t.NumBytes()),
+ )
+ }
+ return slog.GroupValue(attrs...)
+}
+
+// shape utilities
+
+func (t Array) Size() int {
+ return int(C.mlx_array_size(t.ctx))
+}
+
+func (t Array) NumBytes() int {
+ return int(C.mlx_array_nbytes(t.ctx))
+}
+
+func (t Array) NumDims() int {
+ return int(C.mlx_array_ndim(t.ctx))
+}
+
+func (t Array) Dims() []int {
+ dims := make([]int, t.NumDims())
+ for i := range dims {
+ dims[i] = t.Dim(i)
+ }
+
+ return dims
+}
+
+func (t Array) Dim(dim int) int {
+ return int(C.mlx_array_dim(t.ctx, C.int(dim)))
+}
+
+func (t Array) DType() DType {
+ return DType(C.mlx_array_dtype(t.ctx))
+}
+
+// data utilities
+
+func (t Array) Int() int {
+ var item C.int64_t
+ C.mlx_array_item_int64(&item, t.ctx)
+ return int(item)
+}
+
+func (t Array) Float() float64 {
+ var item C.double
+ C.mlx_array_item_float64(&item, t.ctx)
+ return float64(item)
+}
+
+func (t Array) Ints() []int {
+ ints := make([]int, t.Size())
+ for i, f := range unsafe.Slice(C.mlx_array_data_int32(t.ctx), len(ints)) {
+ ints[i] = int(f)
+ }
+ return ints
+}
+
+func (t Array) Floats() []float32 {
+ floats := make([]float32, t.Size())
+ for i, f := range unsafe.Slice(C.mlx_array_data_float32(t.ctx), len(floats)) {
+ floats[i] = float32(f)
+ }
+ return floats
+}
+
+func (t Array) Save(name string) error {
+ cName := C.CString(name)
+ defer C.free(unsafe.Pointer(cName))
+ C.mlx_save(cName, t.ctx)
+ return nil
+}
+
+func Free(s ...*Array) (n int) {
+ now := time.Now()
+ defer func() {
+ if n > 0 {
+ logutil.Trace("Freed tensors", "num_bytes", PrettyBytes(n), "took", time.Since(now))
+ }
+ }()
+
+ free := make([]*Array, 0, 8192)
+ fn := func(t *Array) {
+ if t.Valid() {
+ t.desc.numRefs--
+ if t.desc.numRefs <= 0 {
+ free = append(free, t.desc.inputs...)
+ logutil.Trace("Free", "t", t)
+ n += t.NumBytes()
+ C.mlx_array_free(t.ctx)
+ t.ctx.ctx = nil
+ }
+ }
+ }
+
+ for _, t := range s {
+ fn(t)
+ }
+
+ for len(free) > 0 {
+ tail := free[len(free)-1]
+ free = free[:len(free)-1]
+ fn(tail)
+ }
+
+ return n
+}
diff --git a/x/mlxrunner/mlx/array_test.go b/x/mlxrunner/mlx/array_test.go
new file mode 100644
index 00000000000..aab5db7ba87
--- /dev/null
+++ b/x/mlxrunner/mlx/array_test.go
@@ -0,0 +1,45 @@
+//go:build mlx
+
+package mlx
+
+import "testing"
+
+func TestFromValue(t *testing.T) {
+ for got, want := range map[*Array]DType{
+ FromValue(true): DTypeBool,
+ FromValue(false): DTypeBool,
+ FromValue(int(7)): DTypeInt32,
+ FromValue(float32(3.14)): DTypeFloat32,
+ FromValue(float64(2.71)): DTypeFloat64,
+ FromValue(complex64(1 + 2i)): DTypeComplex64,
+ } {
+ t.Run(want.String(), func(t *testing.T) {
+ if got.DType() != want {
+ t.Errorf("want %v, got %v", want, got)
+ }
+ })
+ }
+}
+
+func TestFromValues(t *testing.T) {
+ for got, want := range map[*Array]DType{
+ FromValues([]bool{true, false, true}, 3): DTypeBool,
+ FromValues([]uint8{1, 2, 3}, 3): DTypeUint8,
+ FromValues([]uint16{1, 2, 3}, 3): DTypeUint16,
+ FromValues([]uint32{1, 2, 3}, 3): DTypeUint32,
+ FromValues([]uint64{1, 2, 3}, 3): DTypeUint64,
+ FromValues([]int8{-1, -2, -3}, 3): DTypeInt8,
+ FromValues([]int16{-1, -2, -3}, 3): DTypeInt16,
+ FromValues([]int32{-1, -2, -3}, 3): DTypeInt32,
+ FromValues([]int64{-1, -2, -3}, 3): DTypeInt64,
+ FromValues([]float32{3.14, 2.71, 1.61}, 3): DTypeFloat32,
+ FromValues([]float64{3.14, 2.71, 1.61}, 3): DTypeFloat64,
+ FromValues([]complex64{1 + 2i, 3 + 4i, 5 + 6i}, 3): DTypeComplex64,
+ } {
+ t.Run(want.String(), func(t *testing.T) {
+ if got.DType() != want {
+ t.Errorf("want %v, got %v", want, got)
+ }
+ })
+ }
+}
diff --git a/x/mlxrunner/mlx/dtype.go b/x/mlxrunner/mlx/dtype.go
new file mode 100644
index 00000000000..95237c7924d
--- /dev/null
+++ b/x/mlxrunner/mlx/dtype.go
@@ -0,0 +1,96 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+type DType int
+
+func (t DType) String() string {
+ switch t {
+ case DTypeBool:
+ return "BOOL"
+ case DTypeUint8:
+ return "U8"
+ case DTypeUint16:
+ return "U16"
+ case DTypeUint32:
+ return "U32"
+ case DTypeUint64:
+ return "U64"
+ case DTypeInt8:
+ return "I8"
+ case DTypeInt16:
+ return "I16"
+ case DTypeInt32:
+ return "I32"
+ case DTypeInt64:
+ return "I64"
+ case DTypeFloat16:
+ return "F16"
+ case DTypeFloat32:
+ return "F32"
+ case DTypeFloat64:
+ return "F64"
+ case DTypeBFloat16:
+ return "BF16"
+ case DTypeComplex64:
+ return "C64"
+ default:
+ return "Unknown"
+ }
+}
+
+func (t *DType) UnmarshalJSON(b []byte) error {
+ switch string(b) {
+ case `"BOOL"`:
+ *t = DTypeBool
+ case `"U8"`:
+ *t = DTypeUint8
+ case `"U16"`:
+ *t = DTypeUint16
+ case `"U32"`:
+ *t = DTypeUint32
+ case `"U64"`:
+ *t = DTypeUint64
+ case `"I8"`:
+ *t = DTypeInt8
+ case `"I16"`:
+ *t = DTypeInt16
+ case `"I32"`:
+ *t = DTypeInt32
+ case `"I64"`:
+ *t = DTypeInt64
+ case `"F16"`:
+ *t = DTypeFloat16
+ case `"F64"`:
+ *t = DTypeFloat64
+ case `"F32"`:
+ *t = DTypeFloat32
+ case `"BF16"`:
+ *t = DTypeBFloat16
+ case `"C64"`:
+ *t = DTypeComplex64
+ default:
+ return nil
+ }
+ return nil
+}
+
+const (
+ DTypeBool DType = C.MLX_BOOL
+ DTypeUint8 DType = C.MLX_UINT8
+ DTypeUint16 DType = C.MLX_UINT16
+ DTypeUint32 DType = C.MLX_UINT32
+ DTypeUint64 DType = C.MLX_UINT64
+ DTypeInt8 DType = C.MLX_INT8
+ DTypeInt16 DType = C.MLX_INT16
+ DTypeInt32 DType = C.MLX_INT32
+ DTypeInt64 DType = C.MLX_INT64
+ DTypeFloat16 DType = C.MLX_FLOAT16
+ DTypeFloat32 DType = C.MLX_FLOAT32
+ DTypeFloat64 DType = C.MLX_FLOAT64
+ DTypeBFloat16 DType = C.MLX_BFLOAT16
+ DTypeComplex64 DType = C.MLX_COMPLEX64
+)
diff --git a/x/mlxrunner/mlx/dynamic.c b/x/mlxrunner/mlx/dynamic.c
new file mode 100644
index 00000000000..d3c4e6e6ce3
--- /dev/null
+++ b/x/mlxrunner/mlx/dynamic.c
@@ -0,0 +1,34 @@
+#include "dynamic.h"
+
+#include
+
+#ifdef _WIN32
+#include
+#define DLOPEN(path) LoadLibraryA(path)
+#define DLCLOSE(handle) FreeLibrary((HMODULE)(handle))
+#else
+#ifdef __APPLE__
+#include
+#include
+#endif
+#include
+#define DLOPEN(path) dlopen(path, RTLD_LAZY | RTLD_GLOBAL)
+#define DLCLOSE(handle) dlclose(handle)
+#endif
+
+static int mlx_dynamic_open(mlx_dynamic_handle* handle, const char* path) {
+ handle->ctx = (void*) DLOPEN(path);
+ CHECK(handle->ctx != NULL);
+ return 0;
+}
+
+int mlx_dynamic_load(mlx_dynamic_handle* handle, const char *path) {
+ return mlx_dynamic_open(handle, path);
+}
+
+void mlx_dynamic_unload(mlx_dynamic_handle* handle) {
+ if (handle->ctx) {
+ DLCLOSE(handle->ctx);
+ handle->ctx = NULL;
+ }
+}
diff --git a/x/mlxrunner/mlx/dynamic.go b/x/mlxrunner/mlx/dynamic.go
new file mode 100644
index 00000000000..a1286da5924
--- /dev/null
+++ b/x/mlxrunner/mlx/dynamic.go
@@ -0,0 +1,126 @@
+//go:build mlx
+
+package mlx
+
+// #include "dynamic.h"
+// #include "generated.h"
+// #include
+import "C"
+
+import (
+ "fmt"
+ "io/fs"
+ "log/slog"
+ "os"
+ "path/filepath"
+ "runtime"
+ "unsafe"
+)
+
+var initError error
+
+// CheckInit returns any error that occurred during MLX dynamic library initialization.
+func CheckInit() error {
+ return initError
+}
+
+// tryLoadFromDir searches a directory for libmlxc.* and tries to load it.
+// Returns true if the library was successfully loaded.
+func tryLoadFromDir(dir string) bool {
+ matches, err := fs.Glob(os.DirFS(dir), "libmlxc.*")
+ if err != nil || len(matches) == 0 {
+ return false
+ }
+
+ for _, match := range matches {
+ path := filepath.Join(dir, match)
+
+ cPath := C.CString(path)
+ defer C.free(unsafe.Pointer(cPath))
+
+ var handle C.mlx_dynamic_handle
+ if C.mlx_dynamic_load(&handle, cPath) != 0 {
+ slog.Error("Failed to load MLX dynamic library", "path", path)
+ continue
+ }
+
+ if C.mlx_dynamic_load_symbols(handle) != 0 {
+ slog.Error("Failed to load MLX dynamic library symbols", "path", path)
+ C.mlx_dynamic_unload(&handle)
+ continue
+ }
+
+ return true
+ }
+ return false
+}
+
+// tryLoadByName attempts to load the library using just its name,
+// allowing the system to use rpath, LD_LIBRARY_PATH, or standard search paths.
+// Returns true if the library was successfully loaded.
+func tryLoadByName() bool {
+ libraryName := "libmlxc.dylib"
+ if runtime.GOOS == "linux" {
+ libraryName = "libmlxc.so"
+ }
+
+ cPath := C.CString(libraryName)
+ defer C.free(unsafe.Pointer(cPath))
+
+ var handle C.mlx_dynamic_handle
+ if C.mlx_dynamic_load(&handle, cPath) != 0 {
+ return false
+ }
+ if C.mlx_dynamic_load_symbols(handle) != 0 {
+ C.mlx_dynamic_unload(&handle)
+ return false
+ }
+
+ return true
+}
+
+func init() {
+ switch runtime.GOOS {
+ case "darwin":
+
+ case "windows":
+ default:
+ return
+ }
+
+ // Try OLLAMA_LIBRARY_PATH first
+ if paths, ok := os.LookupEnv("OLLAMA_LIBRARY_PATH"); ok {
+ for _, dir := range filepath.SplitList(paths) {
+ if tryLoadFromDir(dir) {
+ return
+ }
+ }
+ }
+
+ // Try loading via rpath/standard library search
+ if tryLoadByName() {
+ return
+ }
+
+ // Build search paths: executable directory, then build directories
+ var searchDirs []string
+ if exe, err := os.Executable(); err == nil {
+ if eval, err := filepath.EvalSymlinks(exe); err == nil {
+ exe = eval
+ }
+ searchDirs = append(searchDirs, filepath.Dir(exe))
+ }
+
+ if cwd, err := os.Getwd(); err == nil {
+ searchDirs = append(searchDirs, filepath.Join(cwd, "build", "lib", "ollama"))
+ }
+
+ for _, dir := range searchDirs {
+ if tryLoadFromDir(dir) {
+ return
+ }
+ }
+
+ initError = fmt.Errorf("failed to load MLX dynamic library (searched: %v)", searchDirs)
+ slog.Warn("MLX dynamic library not available", "error", initError)
+}
diff --git a/x/mlxrunner/mlx/dynamic.h b/x/mlxrunner/mlx/dynamic.h
new file mode 100644
index 00000000000..f93d8fab790
--- /dev/null
+++ b/x/mlxrunner/mlx/dynamic.h
@@ -0,0 +1,41 @@
+#ifndef MLX_DYNAMIC_H
+#define MLX_DYNAMIC_H
+
+#ifdef _WIN32
+#include
+#define DLSYM(handle, symbol) GetProcAddress((HMODULE)(handle), symbol)
+#else
+#include
+#define DLSYM(handle, symbol) dlsym(handle.ctx, symbol)
+#endif
+
+#include
+
+// Provide fallback typedefs for float16_t and bfloat16_t on non-ARM64
+// platforms where arm_fp16.h and arm_bf16.h are not available. These are
+// only used as function pointer signature placeholders since MLX requires
+// Apple Silicon at runtime.
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC)
+typedef uint16_t float16_t;
+#endif
+
+#if !defined(__aarch64__) && !defined(__ARM_FEATURE_BF16)
+typedef uint16_t bfloat16_t;
+#endif
+
+#define ERROR(fmt, ...) fprintf(stderr, "%s %s - ERROR - %s:%d - " fmt "\n", __DATE__, __TIME__, __FILE__, __LINE__, ##__VA_ARGS__); return 1
+#define CHECK(x) if (!(x)) { ERROR("CHECK failed: " #x); }
+#define CHECK_LOAD(handle, x) x##_ = DLSYM(handle, #x); CHECK(x##_)
+
+typedef struct {
+ void* ctx;
+} mlx_dynamic_handle;
+
+int mlx_dynamic_load(
+ mlx_dynamic_handle* handle,
+ const char *path);
+
+void mlx_dynamic_unload(
+ mlx_dynamic_handle* handle);
+
+#endif // MLX_DYNAMIC_H
diff --git a/x/mlxrunner/mlx/fast.go b/x/mlxrunner/mlx/fast.go
new file mode 100644
index 00000000000..250d42dc8ce
--- /dev/null
+++ b/x/mlxrunner/mlx/fast.go
@@ -0,0 +1,74 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "unsafe"
+)
+
+func ScaledDotProductAttention(query, key, value, mask *Array, scale float32) *Array {
+ if mask == nil {
+ mask = New("")
+ }
+
+ sinks := New("")
+
+ mode := "causal"
+ cMode := C.CString(mode)
+ defer C.free(unsafe.Pointer(cMode))
+
+ out := New("FAST_SDPA", query, key, value, mask, sinks)
+ C.mlx_fast_scaled_dot_product_attention(&out.ctx, query.ctx, key.ctx, value.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+ return out
+}
+
+type LayerNorm struct {
+ Weight Array `weight:"weight"`
+ Bias Array `weight:"bias"`
+}
+
+func (r *LayerNorm) Forward(x *Array, eps float32) *Array {
+ out := New("FAST_LAYERNORM", x)
+ C.mlx_fast_layer_norm(&out.ctx, x.ctx, r.Weight.ctx, r.Bias.ctx, C.float(eps), DefaultStream().ctx)
+ return out
+}
+
+type RMSNorm struct {
+ Weight Array `weight:"weight"`
+}
+
+func (r RMSNorm) Forward(x *Array, eps float32) *Array {
+ out := New("FAST_RMSNORM", x)
+ C.mlx_fast_rms_norm(&out.ctx, x.ctx, r.Weight.ctx, C.float(eps), DefaultStream().ctx)
+ return out
+}
+
+type RoPE struct {
+ Dims int
+ Traditional bool
+ Base float32 `json:"rope_theta"`
+ Scale float32
+}
+
+func (r RoPE) Forward(t *Array, offset int) *Array {
+ freqs := New("")
+ out := New("FAST_ROPE", t, freqs)
+ C.mlx_fast_rope(
+ &out.ctx,
+ t.ctx,
+ C.int(r.Dims),
+ C._Bool(r.Traditional),
+ C.mlx_optional_float{
+ value: C.float(r.Base),
+ has_value: C._Bool(func() bool { return r.Base != 0 }()),
+ },
+ C.float(r.Scale),
+ C.int(offset),
+ freqs.ctx,
+ DefaultStream().ctx,
+ )
+ return out
+}
diff --git a/x/mlxrunner/mlx/generated.c b/x/mlxrunner/mlx/generated.c
new file mode 100644
index 00000000000..af99b631e73
--- /dev/null
+++ b/x/mlxrunner/mlx/generated.c
@@ -0,0 +1,2724 @@
+// This code is auto-generated; DO NOT EDIT.
+
+#include "generated.h"
+
+#include
+#include
+#include
+
+size_t (*mlx_dtype_size_)(mlx_dtype dtype) = NULL;
+int (*mlx_array_tostring_)(mlx_string* str, const mlx_array arr) = NULL;
+mlx_array (*mlx_array_new_)(void) = NULL;
+int (*mlx_array_free_)(mlx_array arr) = NULL;
+mlx_array (*mlx_array_new_bool_)(bool val) = NULL;
+mlx_array (*mlx_array_new_int_)(int val) = NULL;
+mlx_array (*mlx_array_new_float32_)(float val) = NULL;
+mlx_array (*mlx_array_new_float_)(float val) = NULL;
+mlx_array (*mlx_array_new_float64_)(double val) = NULL;
+mlx_array (*mlx_array_new_double_)(double val) = NULL;
+mlx_array (*mlx_array_new_complex_)(float real_val, float imag_val) = NULL;
+mlx_array (*mlx_array_new_data_)(
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype) = NULL;
+int (*mlx_array_set_)(mlx_array* arr, const mlx_array src) = NULL;
+int (*mlx_array_set_bool_)(mlx_array* arr, bool val) = NULL;
+int (*mlx_array_set_int_)(mlx_array* arr, int val) = NULL;
+int (*mlx_array_set_float32_)(mlx_array* arr, float val) = NULL;
+int (*mlx_array_set_float_)(mlx_array* arr, float val) = NULL;
+int (*mlx_array_set_float64_)(mlx_array* arr, double val) = NULL;
+int (*mlx_array_set_double_)(mlx_array* arr, double val) = NULL;
+int (*mlx_array_set_complex_)(mlx_array* arr, float real_val, float imag_val) = NULL;
+int (*mlx_array_set_data_)(
+ mlx_array* arr,
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype) = NULL;
+size_t (*mlx_array_itemsize_)(const mlx_array arr) = NULL;
+size_t (*mlx_array_size_)(const mlx_array arr) = NULL;
+size_t (*mlx_array_nbytes_)(const mlx_array arr) = NULL;
+size_t (*mlx_array_ndim_)(const mlx_array arr) = NULL;
+const int * (*mlx_array_shape_)(const mlx_array arr) = NULL;
+const size_t * (*mlx_array_strides_)(const mlx_array arr) = NULL;
+int (*mlx_array_dim_)(const mlx_array arr, int dim) = NULL;
+mlx_dtype (*mlx_array_dtype_)(const mlx_array arr) = NULL;
+int (*mlx_array_eval_)(mlx_array arr) = NULL;
+int (*mlx_array_item_bool_)(bool* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_uint8_)(uint8_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_uint16_)(uint16_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_uint32_)(uint32_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_uint64_)(uint64_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_int8_)(int8_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_int16_)(int16_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_float32_)(float* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_float64_)(double* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr) = NULL;
+int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr) = NULL;
+const bool * (*mlx_array_data_bool_)(const mlx_array arr) = NULL;
+const uint8_t * (*mlx_array_data_uint8_)(const mlx_array arr) = NULL;
+const uint16_t * (*mlx_array_data_uint16_)(const mlx_array arr) = NULL;
+const uint32_t * (*mlx_array_data_uint32_)(const mlx_array arr) = NULL;
+const uint64_t * (*mlx_array_data_uint64_)(const mlx_array arr) = NULL;
+const int8_t * (*mlx_array_data_int8_)(const mlx_array arr) = NULL;
+const int16_t * (*mlx_array_data_int16_)(const mlx_array arr) = NULL;
+const int32_t * (*mlx_array_data_int32_)(const mlx_array arr) = NULL;
+const int64_t * (*mlx_array_data_int64_)(const mlx_array arr) = NULL;
+const float * (*mlx_array_data_float32_)(const mlx_array arr) = NULL;
+const double * (*mlx_array_data_float64_)(const mlx_array arr) = NULL;
+const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr) = NULL;
+const float16_t * (*mlx_array_data_float16_)(const mlx_array arr) = NULL;
+const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr) = NULL;
+int (*_mlx_array_is_available_)(bool* res, const mlx_array arr) = NULL;
+int (*_mlx_array_wait_)(const mlx_array arr) = NULL;
+int (*_mlx_array_is_contiguous_)(bool* res, const mlx_array arr) = NULL;
+int (*_mlx_array_is_row_contiguous_)(bool* res, const mlx_array arr) = NULL;
+int (*_mlx_array_is_col_contiguous_)(bool* res, const mlx_array arr) = NULL;
+mlx_closure (*mlx_closure_new_)(void) = NULL;
+int (*mlx_closure_free_)(mlx_closure cls) = NULL;
+mlx_closure (*mlx_closure_new_func_)(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array)) = NULL;
+mlx_closure (*mlx_closure_new_func_payload_)(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
+ void* payload,
+ void (*dtor)(void*)) = NULL;
+int (*mlx_closure_set_)(mlx_closure* cls, const mlx_closure src) = NULL;
+int (*mlx_closure_apply_)(
+ mlx_vector_array* res,
+ mlx_closure cls,
+ const mlx_vector_array input) = NULL;
+mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array)) = NULL;
+mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void) = NULL;
+int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls) = NULL;
+mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array)) = NULL;
+mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) = NULL;
+int (*mlx_closure_kwargs_set_)(
+ mlx_closure_kwargs* cls,
+ const mlx_closure_kwargs src) = NULL;
+int (*mlx_closure_kwargs_apply_)(
+ mlx_vector_array* res,
+ mlx_closure_kwargs cls,
+ const mlx_vector_array input_0,
+ const mlx_map_string_to_array input_1) = NULL;
+mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_)(void) = NULL;
+int (*mlx_closure_value_and_grad_free_)(mlx_closure_value_and_grad cls) = NULL;
+mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_)(
+ int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) = NULL;
+mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_array*,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) = NULL;
+int (*mlx_closure_value_and_grad_set_)(
+ mlx_closure_value_and_grad* cls,
+ const mlx_closure_value_and_grad src) = NULL;
+int (*mlx_closure_value_and_grad_apply_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ mlx_closure_value_and_grad cls,
+ const mlx_vector_array input) = NULL;
+mlx_closure_custom (*mlx_closure_custom_new_)(void) = NULL;
+int (*mlx_closure_custom_free_)(mlx_closure_custom cls) = NULL;
+mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array)) = NULL;
+mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) = NULL;
+int (*mlx_closure_custom_set_)(
+ mlx_closure_custom* cls,
+ const mlx_closure_custom src) = NULL;
+int (*mlx_closure_custom_apply_)(
+ mlx_vector_array* res,
+ mlx_closure_custom cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const mlx_vector_array input_2) = NULL;
+mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void) = NULL;
+int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls) = NULL;
+mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num)) = NULL;
+mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) = NULL;
+int (*mlx_closure_custom_jvp_set_)(
+ mlx_closure_custom_jvp* cls,
+ const mlx_closure_custom_jvp src) = NULL;
+int (*mlx_closure_custom_jvp_apply_)(
+ mlx_vector_array* res,
+ mlx_closure_custom_jvp cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const int* input_2,
+ size_t input_2_num) = NULL;
+mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void) = NULL;
+int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls) = NULL;
+mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num)) = NULL;
+mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) = NULL;
+int (*mlx_closure_custom_vmap_set_)(
+ mlx_closure_custom_vmap* cls,
+ const mlx_closure_custom_vmap src) = NULL;
+int (*mlx_closure_custom_vmap_apply_)(
+ mlx_vector_array* res_0,
+ mlx_vector_int* res_1,
+ mlx_closure_custom_vmap cls,
+ const mlx_vector_array input_0,
+ const int* input_1,
+ size_t input_1_num) = NULL;
+int (*mlx_compile_)(mlx_closure* res, const mlx_closure fun, bool shapeless) = NULL;
+int (*mlx_detail_compile_)(
+ mlx_closure* res,
+ const mlx_closure fun,
+ uintptr_t fun_id,
+ bool shapeless,
+ const uint64_t* constants,
+ size_t constants_num) = NULL;
+int (*mlx_detail_compile_clear_cache_)(void) = NULL;
+int (*mlx_detail_compile_erase_)(uintptr_t fun_id) = NULL;
+int (*mlx_disable_compile_)(void) = NULL;
+int (*mlx_enable_compile_)(void) = NULL;
+int (*mlx_set_compile_mode_)(mlx_compile_mode mode) = NULL;
+mlx_device (*mlx_device_new_)(void) = NULL;
+mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index) = NULL;
+int (*mlx_device_free_)(mlx_device dev) = NULL;
+int (*mlx_device_set_)(mlx_device* dev, const mlx_device src) = NULL;
+int (*mlx_device_tostring_)(mlx_string* str, mlx_device dev) = NULL;
+bool (*mlx_device_equal_)(mlx_device lhs, mlx_device rhs) = NULL;
+int (*mlx_device_get_index_)(int* index, mlx_device dev) = NULL;
+int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev) = NULL;
+int (*mlx_get_default_device_)(mlx_device* dev) = NULL;
+int (*mlx_set_default_device_)(mlx_device dev) = NULL;
+int (*mlx_distributed_group_rank_)(mlx_distributed_group group) = NULL;
+int (*mlx_distributed_group_size_)(mlx_distributed_group group) = NULL;
+mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key) = NULL;
+bool (*mlx_distributed_is_available_)(void) = NULL;
+mlx_distributed_group (*mlx_distributed_init_)(bool strict) = NULL;
+int (*mlx_distributed_all_gather_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream S) = NULL;
+int (*mlx_distributed_all_max_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_distributed_all_min_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_distributed_all_sum_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_distributed_recv_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_distributed_recv_like_)(
+ mlx_array* res,
+ const mlx_array x,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_distributed_send_)(
+ mlx_array* res,
+ const mlx_array x,
+ int dst,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_distributed_sum_scatter_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) = NULL;
+void (*mlx_set_error_handler_)(
+ mlx_error_handler_func handler,
+ void* data,
+ void (*dtor)(void*)) = NULL;
+void (*_mlx_error_)(const char* file, const int line, const char* fmt, ...) = NULL;
+int (*mlx_export_function_)(
+ const char* file,
+ const mlx_closure fun,
+ const mlx_vector_array args,
+ bool shapeless) = NULL;
+int (*mlx_export_function_kwargs_)(
+ const char* file,
+ const mlx_closure_kwargs fun,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs,
+ bool shapeless) = NULL;
+mlx_function_exporter (*mlx_function_exporter_new_)(
+ const char* file,
+ const mlx_closure fun,
+ bool shapeless) = NULL;
+int (*mlx_function_exporter_free_)(mlx_function_exporter xfunc) = NULL;
+int (*mlx_function_exporter_apply_)(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args) = NULL;
+int (*mlx_function_exporter_apply_kwargs_)(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs) = NULL;
+mlx_imported_function (*mlx_imported_function_new_)(const char* file) = NULL;
+int (*mlx_imported_function_free_)(mlx_imported_function xfunc) = NULL;
+int (*mlx_imported_function_apply_)(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args) = NULL;
+int (*mlx_imported_function_apply_kwargs_)(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs) = NULL;
+mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_)(void) = NULL;
+void (*mlx_fast_cuda_kernel_config_free_)(mlx_fast_cuda_kernel_config cls) = NULL;
+int (*mlx_fast_cuda_kernel_config_add_output_arg_)(
+ mlx_fast_cuda_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype) = NULL;
+int (*mlx_fast_cuda_kernel_config_set_grid_)(
+ mlx_fast_cuda_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3) = NULL;
+int (*mlx_fast_cuda_kernel_config_set_thread_group_)(
+ mlx_fast_cuda_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3) = NULL;
+int (*mlx_fast_cuda_kernel_config_set_init_value_)(
+ mlx_fast_cuda_kernel_config cls,
+ float value) = NULL;
+int (*mlx_fast_cuda_kernel_config_set_verbose_)(
+ mlx_fast_cuda_kernel_config cls,
+ bool verbose) = NULL;
+int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_)(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype) = NULL;
+int (*mlx_fast_cuda_kernel_config_add_template_arg_int_)(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ int value) = NULL;
+int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_)(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ bool value) = NULL;
+mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_)(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ int shared_memory) = NULL;
+void (*mlx_fast_cuda_kernel_free_)(mlx_fast_cuda_kernel cls) = NULL;
+int (*mlx_fast_cuda_kernel_apply_)(
+ mlx_vector_array* outputs,
+ mlx_fast_cuda_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_cuda_kernel_config config,
+ const mlx_stream stream) = NULL;
+int (*mlx_fast_layer_norm_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ const mlx_array bias /* may be null */,
+ float eps,
+ const mlx_stream s) = NULL;
+mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_)(void) = NULL;
+void (*mlx_fast_metal_kernel_config_free_)(mlx_fast_metal_kernel_config cls) = NULL;
+int (*mlx_fast_metal_kernel_config_add_output_arg_)(
+ mlx_fast_metal_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype) = NULL;
+int (*mlx_fast_metal_kernel_config_set_grid_)(
+ mlx_fast_metal_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3) = NULL;
+int (*mlx_fast_metal_kernel_config_set_thread_group_)(
+ mlx_fast_metal_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3) = NULL;
+int (*mlx_fast_metal_kernel_config_set_init_value_)(
+ mlx_fast_metal_kernel_config cls,
+ float value) = NULL;
+int (*mlx_fast_metal_kernel_config_set_verbose_)(
+ mlx_fast_metal_kernel_config cls,
+ bool verbose) = NULL;
+int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_)(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype) = NULL;
+int (*mlx_fast_metal_kernel_config_add_template_arg_int_)(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ int value) = NULL;
+int (*mlx_fast_metal_kernel_config_add_template_arg_bool_)(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ bool value) = NULL;
+mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_)(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ bool atomic_outputs) = NULL;
+void (*mlx_fast_metal_kernel_free_)(mlx_fast_metal_kernel cls) = NULL;
+int (*mlx_fast_metal_kernel_apply_)(
+ mlx_vector_array* outputs,
+ mlx_fast_metal_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_metal_kernel_config config,
+ const mlx_stream stream) = NULL;
+int (*mlx_fast_rms_norm_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ float eps,
+ const mlx_stream s) = NULL;
+int (*mlx_fast_rope_)(
+ mlx_array* res,
+ const mlx_array x,
+ int dims,
+ bool traditional,
+ mlx_optional_float base,
+ float scale,
+ int offset,
+ const mlx_array freqs /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_fast_scaled_dot_product_attention_)(
+ mlx_array* res,
+ const mlx_array queries,
+ const mlx_array keys,
+ const mlx_array values,
+ float scale,
+ const char* mask_mode,
+ const mlx_array mask_arr /* may be null */,
+ const mlx_array sinks /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_fft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_fft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_fftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_fftshift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_ifft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_ifft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_ifftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_ifftshift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_irfft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_irfft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_irfftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_rfft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_rfft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_fft_rfftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable) = NULL;
+int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io) = NULL;
+int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io) = NULL;
+int (*mlx_io_reader_free_)(mlx_io_reader io) = NULL;
+mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable) = NULL;
+int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io) = NULL;
+int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io) = NULL;
+int (*mlx_io_writer_free_)(mlx_io_writer io) = NULL;
+int (*mlx_load_reader_)(
+ mlx_array* res,
+ mlx_io_reader in_stream,
+ const mlx_stream s) = NULL;
+int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s) = NULL;
+int (*mlx_load_safetensors_reader_)(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ mlx_io_reader in_stream,
+ const mlx_stream s) = NULL;
+int (*mlx_load_safetensors_)(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ const char* file,
+ const mlx_stream s) = NULL;
+int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a) = NULL;
+int (*mlx_save_)(const char* file, const mlx_array a) = NULL;
+int (*mlx_save_safetensors_writer_)(
+ mlx_io_writer in_stream,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata) = NULL;
+int (*mlx_save_safetensors_)(
+ const char* file,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata) = NULL;
+int (*mlx_linalg_cholesky_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_cholesky_inv_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_cross_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_eig_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_eigh_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_eigvals_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_linalg_eigvalsh_)(
+ mlx_array* res,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_inv_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_linalg_lu_)(mlx_vector_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_linalg_lu_factor_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_norm_)(
+ mlx_array* res,
+ const mlx_array a,
+ double ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_norm_matrix_)(
+ mlx_array* res,
+ const mlx_array a,
+ const char* ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_norm_l2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_pinv_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_linalg_qr_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_solve_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_solve_triangular_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool upper,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_svd_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ bool compute_uv,
+ const mlx_stream s) = NULL;
+int (*mlx_linalg_tri_inv_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s) = NULL;
+mlx_map_string_to_array (*mlx_map_string_to_array_new_)(void) = NULL;
+int (*mlx_map_string_to_array_set_)(
+ mlx_map_string_to_array* map,
+ const mlx_map_string_to_array src) = NULL;
+int (*mlx_map_string_to_array_free_)(mlx_map_string_to_array map) = NULL;
+int (*mlx_map_string_to_array_insert_)(
+ mlx_map_string_to_array map,
+ const char* key,
+ const mlx_array value) = NULL;
+int (*mlx_map_string_to_array_get_)(
+ mlx_array* value,
+ const mlx_map_string_to_array map,
+ const char* key) = NULL;
+mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_)(
+ mlx_map_string_to_array map) = NULL;
+int (*mlx_map_string_to_array_iterator_free_)(mlx_map_string_to_array_iterator it) = NULL;
+int (*mlx_map_string_to_array_iterator_next_)(
+ const char** key,
+ mlx_array* value,
+ mlx_map_string_to_array_iterator it) = NULL;
+mlx_map_string_to_string (*mlx_map_string_to_string_new_)(void) = NULL;
+int (*mlx_map_string_to_string_set_)(
+ mlx_map_string_to_string* map,
+ const mlx_map_string_to_string src) = NULL;
+int (*mlx_map_string_to_string_free_)(mlx_map_string_to_string map) = NULL;
+int (*mlx_map_string_to_string_insert_)(
+ mlx_map_string_to_string map,
+ const char* key,
+ const char* value) = NULL;
+int (*mlx_map_string_to_string_get_)(
+ const char** value,
+ const mlx_map_string_to_string map,
+ const char* key) = NULL;
+mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_)(
+ mlx_map_string_to_string map) = NULL;
+int (*mlx_map_string_to_string_iterator_free_)(
+ mlx_map_string_to_string_iterator it) = NULL;
+int (*mlx_map_string_to_string_iterator_next_)(
+ const char** key,
+ const char** value,
+ mlx_map_string_to_string_iterator it) = NULL;
+int (*mlx_clear_cache_)(void) = NULL;
+int (*mlx_get_active_memory_)(size_t* res) = NULL;
+int (*mlx_get_cache_memory_)(size_t* res) = NULL;
+int (*mlx_get_memory_limit_)(size_t* res) = NULL;
+int (*mlx_get_peak_memory_)(size_t* res) = NULL;
+int (*mlx_reset_peak_memory_)(void) = NULL;
+int (*mlx_set_cache_limit_)(size_t* res, size_t limit) = NULL;
+int (*mlx_set_memory_limit_)(size_t* res, size_t limit) = NULL;
+int (*mlx_set_wired_limit_)(size_t* res, size_t limit) = NULL;
+mlx_metal_device_info_t (*mlx_metal_device_info_)(void) = NULL;
+int (*mlx_metal_is_available_)(bool* res) = NULL;
+int (*mlx_metal_start_capture_)(const char* path) = NULL;
+int (*mlx_metal_stop_capture_)(void) = NULL;
+int (*mlx_abs_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_add_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_addmm_)(
+ mlx_array* res,
+ const mlx_array c,
+ const mlx_array a,
+ const mlx_array b,
+ float alpha,
+ float beta,
+ const mlx_stream s) = NULL;
+int (*mlx_all_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_all_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_all_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_allclose_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s) = NULL;
+int (*mlx_any_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_any_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_any_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_arange_)(
+ mlx_array* res,
+ double start,
+ double stop,
+ double step,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_arccos_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_arccosh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_arcsin_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_arcsinh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_arctan_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_arctan2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_arctanh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_argmax_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_argmax_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_argmin_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_argmin_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_argpartition_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_argpartition_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s) = NULL;
+int (*mlx_argsort_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_argsort_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_array_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool equal_nan,
+ const mlx_stream s) = NULL;
+int (*mlx_as_strided_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const int64_t* strides,
+ size_t strides_num,
+ size_t offset,
+ const mlx_stream s) = NULL;
+int (*mlx_astype_)(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_bitwise_and_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_bitwise_invert_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_bitwise_or_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_bitwise_xor_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_block_masked_mm_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int block_size,
+ const mlx_array mask_out /* may be null */,
+ const mlx_array mask_lhs /* may be null */,
+ const mlx_array mask_rhs /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_broadcast_arrays_)(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_stream s) = NULL;
+int (*mlx_broadcast_to_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s) = NULL;
+int (*mlx_ceil_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_clip_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array a_min /* may be null */,
+ const mlx_array a_max /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_concatenate_axis_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_concatenate_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s) = NULL;
+int (*mlx_conjugate_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_contiguous_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool allow_col_major,
+ const mlx_stream s) = NULL;
+int (*mlx_conv1d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int groups,
+ const mlx_stream s) = NULL;
+int (*mlx_conv2d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int groups,
+ const mlx_stream s) = NULL;
+int (*mlx_conv3d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int groups,
+ const mlx_stream s) = NULL;
+int (*mlx_conv_general_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ const int* stride,
+ size_t stride_num,
+ const int* padding_lo,
+ size_t padding_lo_num,
+ const int* padding_hi,
+ size_t padding_hi_num,
+ const int* kernel_dilation,
+ size_t kernel_dilation_num,
+ const int* input_dilation,
+ size_t input_dilation_num,
+ int groups,
+ bool flip,
+ const mlx_stream s) = NULL;
+int (*mlx_conv_transpose1d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int output_padding,
+ int groups,
+ const mlx_stream s) = NULL;
+int (*mlx_conv_transpose2d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int output_padding_0,
+ int output_padding_1,
+ int groups,
+ const mlx_stream s) = NULL;
+int (*mlx_conv_transpose3d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int output_padding_0,
+ int output_padding_1,
+ int output_padding_2,
+ int groups,
+ const mlx_stream s) = NULL;
+int (*mlx_copy_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_cos_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_cosh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_cummax_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) = NULL;
+int (*mlx_cummin_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) = NULL;
+int (*mlx_cumprod_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) = NULL;
+int (*mlx_cumsum_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) = NULL;
+int (*mlx_degrees_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_depends_)(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array dependencies) = NULL;
+int (*mlx_dequantize_)(
+ mlx_array* res,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ mlx_optional_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
+int (*mlx_diagonal_)(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ const mlx_stream s) = NULL;
+int (*mlx_divide_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_divmod_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_einsum_)(
+ mlx_array* res,
+ const char* subscripts,
+ const mlx_vector_array operands,
+ const mlx_stream s) = NULL;
+int (*mlx_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_erf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_erfinv_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_exp_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_expand_dims_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_expand_dims_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_expm1_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_eye_)(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_flatten_)(
+ mlx_array* res,
+ const mlx_array a,
+ int start_axis,
+ int end_axis,
+ const mlx_stream s) = NULL;
+int (*mlx_floor_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_floor_divide_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_from_fp8_)(
+ mlx_array* res,
+ const mlx_array x,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_full_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_full_like_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_gather_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_sizes,
+ size_t slice_sizes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_gather_mm_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool sorted_indices,
+ const mlx_stream s) = NULL;
+int (*mlx_gather_qmm_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ bool sorted_indices,
+ const mlx_stream s) = NULL;
+int (*mlx_greater_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_greater_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_hadamard_transform_)(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_optional_float scale,
+ const mlx_stream s) = NULL;
+int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) = NULL;
+int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_inner_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_isclose_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s) = NULL;
+int (*mlx_isfinite_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_isinf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_isnan_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_isneginf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_isposinf_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_kron_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_left_shift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_less_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_less_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_linspace_)(
+ mlx_array* res,
+ double start,
+ double stop,
+ int num,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_log_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_log10_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_log1p_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_log2_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_logaddexp_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_logcumsumexp_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) = NULL;
+int (*mlx_logical_and_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_logical_not_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_logical_or_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_logsumexp_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_logsumexp_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_logsumexp_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_masked_scatter_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array mask,
+ const mlx_array src,
+ const mlx_stream s) = NULL;
+int (*mlx_matmul_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_max_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_max_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_max_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_maximum_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_mean_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_mean_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_mean_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_median_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_meshgrid_)(
+ mlx_vector_array* res,
+ const mlx_vector_array arrays,
+ bool sparse,
+ const char* indexing,
+ const mlx_stream s) = NULL;
+int (*mlx_min_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_min_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_min_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_minimum_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_moveaxis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int source,
+ int destination,
+ const mlx_stream s) = NULL;
+int (*mlx_multiply_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_nan_to_num_)(
+ mlx_array* res,
+ const mlx_array a,
+ float nan,
+ mlx_optional_float posinf,
+ mlx_optional_float neginf,
+ const mlx_stream s) = NULL;
+int (*mlx_negative_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_not_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_number_of_elements_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool inverted,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_ones_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_ones_like_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_outer_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_pad_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const int* low_pad_size,
+ size_t low_pad_size_num,
+ const int* high_pad_size,
+ size_t high_pad_size_num,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s) = NULL;
+int (*mlx_pad_symmetric_)(
+ mlx_array* res,
+ const mlx_array a,
+ int pad_width,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s) = NULL;
+int (*mlx_partition_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_partition_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s) = NULL;
+int (*mlx_power_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_prod_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_prod_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_prod_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_put_along_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_quantize_)(
+ mlx_vector_array* res,
+ const mlx_array w,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s) = NULL;
+int (*mlx_quantized_matmul_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s) = NULL;
+int (*mlx_radians_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_real_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_reciprocal_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_remainder_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_repeat_axis_)(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_repeat_)(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ const mlx_stream s) = NULL;
+int (*mlx_reshape_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s) = NULL;
+int (*mlx_right_shift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_roll_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_roll_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_roll_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const mlx_stream s) = NULL;
+int (*mlx_round_)(
+ mlx_array* res,
+ const mlx_array a,
+ int decimals,
+ const mlx_stream s) = NULL;
+int (*mlx_rsqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_scatter_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_scatter_add_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_scatter_add_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_scatter_max_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_scatter_min_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_scatter_prod_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_segmented_mm_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array segments,
+ const mlx_stream s) = NULL;
+int (*mlx_sigmoid_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_sign_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_sin_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_sinh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_slice_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s) = NULL;
+int (*mlx_slice_dynamic_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_size,
+ size_t slice_size_num,
+ const mlx_stream s) = NULL;
+int (*mlx_slice_update_)(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s) = NULL;
+int (*mlx_slice_update_dynamic_)(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_softmax_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool precise,
+ const mlx_stream s) = NULL;
+int (*mlx_softmax_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool precise,
+ const mlx_stream s) = NULL;
+int (*mlx_softmax_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool precise,
+ const mlx_stream s) = NULL;
+int (*mlx_sort_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_sort_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_split_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ int num_splits,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_split_sections_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const int* indices,
+ size_t indices_num,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_sqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_square_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_squeeze_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_squeeze_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_squeeze_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_stack_axis_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_stack_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s) = NULL;
+int (*mlx_std_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) = NULL;
+int (*mlx_std_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) = NULL;
+int (*mlx_std_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) = NULL;
+int (*mlx_stop_gradient_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_subtract_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) = NULL;
+int (*mlx_sum_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_sum_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_sum_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) = NULL;
+int (*mlx_swapaxes_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis1,
+ int axis2,
+ const mlx_stream s) = NULL;
+int (*mlx_take_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_take_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_stream s) = NULL;
+int (*mlx_take_along_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_tan_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_tanh_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_tensordot_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const int* axes_a,
+ size_t axes_a_num,
+ const int* axes_b,
+ size_t axes_b_num,
+ const mlx_stream s) = NULL;
+int (*mlx_tensordot_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_tile_)(
+ mlx_array* res,
+ const mlx_array arr,
+ const int* reps,
+ size_t reps_num,
+ const mlx_stream s) = NULL;
+int (*mlx_to_fp8_)(mlx_array* res, const mlx_array x, const mlx_stream s) = NULL;
+int (*mlx_topk_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int k,
+ int axis,
+ const mlx_stream s) = NULL;
+int (*mlx_topk_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s) = NULL;
+int (*mlx_trace_)(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_transpose_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) = NULL;
+int (*mlx_transpose_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_tri_)(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype type,
+ const mlx_stream s) = NULL;
+int (*mlx_tril_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL;
+int (*mlx_triu_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s) = NULL;
+int (*mlx_unflatten_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s) = NULL;
+int (*mlx_var_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) = NULL;
+int (*mlx_var_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) = NULL;
+int (*mlx_var_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) = NULL;
+int (*mlx_view_)(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_where_)(
+ mlx_array* res,
+ const mlx_array condition,
+ const mlx_array x,
+ const mlx_array y,
+ const mlx_stream s) = NULL;
+int (*mlx_zeros_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s) = NULL;
+int (*mlx_zeros_like_)(mlx_array* res, const mlx_array a, const mlx_stream s) = NULL;
+int (*mlx_random_bernoulli_)(
+ mlx_array* res,
+ const mlx_array p,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_bits_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ int width,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_categorical_shape_)(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_categorical_num_samples_)(
+ mlx_array* res,
+ const mlx_array logits_,
+ int axis,
+ int num_samples,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_categorical_)(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_gumbel_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_key_)(mlx_array* res, uint64_t seed) = NULL;
+int (*mlx_random_laplace_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_multivariate_normal_)(
+ mlx_array* res,
+ const mlx_array mean,
+ const mlx_array cov,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_normal_broadcast_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array loc /* may be null */,
+ const mlx_array scale /* may be null */,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_normal_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_permutation_)(
+ mlx_array* res,
+ const mlx_array x,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_permutation_arange_)(
+ mlx_array* res,
+ int x,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_randint_)(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_seed_)(uint64_t seed) = NULL;
+int (*mlx_random_split_num_)(
+ mlx_array* res,
+ const mlx_array key,
+ int num,
+ const mlx_stream s) = NULL;
+int (*mlx_random_split_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array key,
+ const mlx_stream s) = NULL;
+int (*mlx_random_truncated_normal_)(
+ mlx_array* res,
+ const mlx_array lower,
+ const mlx_array upper,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+int (*mlx_random_uniform_)(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) = NULL;
+mlx_stream (*mlx_stream_new_)(void) = NULL;
+mlx_stream (*mlx_stream_new_device_)(mlx_device dev) = NULL;
+int (*mlx_stream_set_)(mlx_stream* stream, const mlx_stream src) = NULL;
+int (*mlx_stream_free_)(mlx_stream stream) = NULL;
+int (*mlx_stream_tostring_)(mlx_string* str, mlx_stream stream) = NULL;
+bool (*mlx_stream_equal_)(mlx_stream lhs, mlx_stream rhs) = NULL;
+int (*mlx_stream_get_device_)(mlx_device* dev, mlx_stream stream) = NULL;
+int (*mlx_stream_get_index_)(int* index, mlx_stream stream) = NULL;
+int (*mlx_synchronize_)(mlx_stream stream) = NULL;
+int (*mlx_get_default_stream_)(mlx_stream* stream, mlx_device dev) = NULL;
+int (*mlx_set_default_stream_)(mlx_stream stream) = NULL;
+mlx_stream (*mlx_default_cpu_stream_new_)(void) = NULL;
+mlx_stream (*mlx_default_gpu_stream_new_)(void) = NULL;
+mlx_string (*mlx_string_new_)(void) = NULL;
+mlx_string (*mlx_string_new_data_)(const char* str) = NULL;
+int (*mlx_string_set_)(mlx_string* str, const mlx_string src) = NULL;
+const char * (*mlx_string_data_)(mlx_string str) = NULL;
+int (*mlx_string_free_)(mlx_string str) = NULL;
+int (*mlx_detail_vmap_replace_)(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array s_inputs,
+ const mlx_vector_array s_outputs,
+ const int* in_axes,
+ size_t in_axes_num,
+ const int* out_axes,
+ size_t out_axes_num) = NULL;
+int (*mlx_detail_vmap_trace_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array inputs,
+ const int* in_axes,
+ size_t in_axes_num) = NULL;
+int (*mlx_async_eval_)(const mlx_vector_array outputs) = NULL;
+int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun) = NULL;
+int (*mlx_custom_function_)(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp /* may be null */,
+ const mlx_closure_custom_jvp fun_jvp /* may be null */,
+ const mlx_closure_custom_vmap fun_vmap /* may be null */) = NULL;
+int (*mlx_custom_vjp_)(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp) = NULL;
+int (*mlx_eval_)(const mlx_vector_array outputs) = NULL;
+int (*mlx_jvp_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array tangents) = NULL;
+int (*mlx_value_and_grad_)(
+ mlx_closure_value_and_grad* res,
+ const mlx_closure fun,
+ const int* argnums,
+ size_t argnums_num) = NULL;
+int (*mlx_vjp_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array cotangents) = NULL;
+mlx_vector_array (*mlx_vector_array_new_)(void) = NULL;
+int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src) = NULL;
+int (*mlx_vector_array_free_)(mlx_vector_array vec) = NULL;
+mlx_vector_array (*mlx_vector_array_new_data_)(const mlx_array* data, size_t size) = NULL;
+mlx_vector_array (*mlx_vector_array_new_value_)(const mlx_array val) = NULL;
+int (*mlx_vector_array_set_data_)(
+ mlx_vector_array* vec,
+ const mlx_array* data,
+ size_t size) = NULL;
+int (*mlx_vector_array_set_value_)(mlx_vector_array* vec, const mlx_array val) = NULL;
+int (*mlx_vector_array_append_data_)(
+ mlx_vector_array vec,
+ const mlx_array* data,
+ size_t size) = NULL;
+int (*mlx_vector_array_append_value_)(mlx_vector_array vec, const mlx_array val) = NULL;
+size_t (*mlx_vector_array_size_)(mlx_vector_array vec) = NULL;
+int (*mlx_vector_array_get_)(
+ mlx_array* res,
+ const mlx_vector_array vec,
+ size_t idx) = NULL;
+mlx_vector_vector_array (*mlx_vector_vector_array_new_)(void) = NULL;
+int (*mlx_vector_vector_array_set_)(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_vector_array src) = NULL;
+int (*mlx_vector_vector_array_free_)(mlx_vector_vector_array vec) = NULL;
+mlx_vector_vector_array (*mlx_vector_vector_array_new_data_)(
+ const mlx_vector_array* data,
+ size_t size) = NULL;
+mlx_vector_vector_array (*mlx_vector_vector_array_new_value_)(
+ const mlx_vector_array val) = NULL;
+int (*mlx_vector_vector_array_set_data_)(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array* data,
+ size_t size) = NULL;
+int (*mlx_vector_vector_array_set_value_)(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array val) = NULL;
+int (*mlx_vector_vector_array_append_data_)(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array* data,
+ size_t size) = NULL;
+int (*mlx_vector_vector_array_append_value_)(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array val) = NULL;
+size_t (*mlx_vector_vector_array_size_)(mlx_vector_vector_array vec) = NULL;
+int (*mlx_vector_vector_array_get_)(
+ mlx_vector_array* res,
+ const mlx_vector_vector_array vec,
+ size_t idx) = NULL;
+mlx_vector_int (*mlx_vector_int_new_)(void) = NULL;
+int (*mlx_vector_int_set_)(mlx_vector_int* vec, const mlx_vector_int src) = NULL;
+int (*mlx_vector_int_free_)(mlx_vector_int vec) = NULL;
+mlx_vector_int (*mlx_vector_int_new_data_)(int* data, size_t size) = NULL;
+mlx_vector_int (*mlx_vector_int_new_value_)(int val) = NULL;
+int (*mlx_vector_int_set_data_)(mlx_vector_int* vec, int* data, size_t size) = NULL;
+int (*mlx_vector_int_set_value_)(mlx_vector_int* vec, int val) = NULL;
+int (*mlx_vector_int_append_data_)(mlx_vector_int vec, int* data, size_t size) = NULL;
+int (*mlx_vector_int_append_value_)(mlx_vector_int vec, int val) = NULL;
+size_t (*mlx_vector_int_size_)(mlx_vector_int vec) = NULL;
+int (*mlx_vector_int_get_)(int* res, const mlx_vector_int vec, size_t idx) = NULL;
+mlx_vector_string (*mlx_vector_string_new_)(void) = NULL;
+int (*mlx_vector_string_set_)(mlx_vector_string* vec, const mlx_vector_string src) = NULL;
+int (*mlx_vector_string_free_)(mlx_vector_string vec) = NULL;
+mlx_vector_string (*mlx_vector_string_new_data_)(const char** data, size_t size) = NULL;
+mlx_vector_string (*mlx_vector_string_new_value_)(const char* val) = NULL;
+int (*mlx_vector_string_set_data_)(
+ mlx_vector_string* vec,
+ const char** data,
+ size_t size) = NULL;
+int (*mlx_vector_string_set_value_)(mlx_vector_string* vec, const char* val) = NULL;
+int (*mlx_vector_string_append_data_)(
+ mlx_vector_string vec,
+ const char** data,
+ size_t size) = NULL;
+int (*mlx_vector_string_append_value_)(mlx_vector_string vec, const char* val) = NULL;
+size_t (*mlx_vector_string_size_)(mlx_vector_string vec) = NULL;
+int (*mlx_vector_string_get_)(char** res, const mlx_vector_string vec, size_t idx) = NULL;
+int (*mlx_version_)(mlx_string* str_) = NULL;
+
+int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
+ CHECK_LOAD(handle, mlx_dtype_size);
+ CHECK_LOAD(handle, mlx_array_tostring);
+ CHECK_LOAD(handle, mlx_array_new);
+ CHECK_LOAD(handle, mlx_array_free);
+ CHECK_LOAD(handle, mlx_array_new_bool);
+ CHECK_LOAD(handle, mlx_array_new_int);
+ CHECK_LOAD(handle, mlx_array_new_float32);
+ CHECK_LOAD(handle, mlx_array_new_float);
+ CHECK_LOAD(handle, mlx_array_new_float64);
+ CHECK_LOAD(handle, mlx_array_new_double);
+ CHECK_LOAD(handle, mlx_array_new_complex);
+ CHECK_LOAD(handle, mlx_array_new_data);
+ CHECK_LOAD(handle, mlx_array_set);
+ CHECK_LOAD(handle, mlx_array_set_bool);
+ CHECK_LOAD(handle, mlx_array_set_int);
+ CHECK_LOAD(handle, mlx_array_set_float32);
+ CHECK_LOAD(handle, mlx_array_set_float);
+ CHECK_LOAD(handle, mlx_array_set_float64);
+ CHECK_LOAD(handle, mlx_array_set_double);
+ CHECK_LOAD(handle, mlx_array_set_complex);
+ CHECK_LOAD(handle, mlx_array_set_data);
+ CHECK_LOAD(handle, mlx_array_itemsize);
+ CHECK_LOAD(handle, mlx_array_size);
+ CHECK_LOAD(handle, mlx_array_nbytes);
+ CHECK_LOAD(handle, mlx_array_ndim);
+ CHECK_LOAD(handle, mlx_array_shape);
+ CHECK_LOAD(handle, mlx_array_strides);
+ CHECK_LOAD(handle, mlx_array_dim);
+ CHECK_LOAD(handle, mlx_array_dtype);
+ CHECK_LOAD(handle, mlx_array_eval);
+ CHECK_LOAD(handle, mlx_array_item_bool);
+ CHECK_LOAD(handle, mlx_array_item_uint8);
+ CHECK_LOAD(handle, mlx_array_item_uint16);
+ CHECK_LOAD(handle, mlx_array_item_uint32);
+ CHECK_LOAD(handle, mlx_array_item_uint64);
+ CHECK_LOAD(handle, mlx_array_item_int8);
+ CHECK_LOAD(handle, mlx_array_item_int16);
+ CHECK_LOAD(handle, mlx_array_item_int32);
+ CHECK_LOAD(handle, mlx_array_item_int64);
+ CHECK_LOAD(handle, mlx_array_item_float32);
+ CHECK_LOAD(handle, mlx_array_item_float64);
+ CHECK_LOAD(handle, mlx_array_item_complex64);
+ CHECK_LOAD(handle, mlx_array_item_float16);
+ CHECK_LOAD(handle, mlx_array_item_bfloat16);
+ CHECK_LOAD(handle, mlx_array_data_bool);
+ CHECK_LOAD(handle, mlx_array_data_uint8);
+ CHECK_LOAD(handle, mlx_array_data_uint16);
+ CHECK_LOAD(handle, mlx_array_data_uint32);
+ CHECK_LOAD(handle, mlx_array_data_uint64);
+ CHECK_LOAD(handle, mlx_array_data_int8);
+ CHECK_LOAD(handle, mlx_array_data_int16);
+ CHECK_LOAD(handle, mlx_array_data_int32);
+ CHECK_LOAD(handle, mlx_array_data_int64);
+ CHECK_LOAD(handle, mlx_array_data_float32);
+ CHECK_LOAD(handle, mlx_array_data_float64);
+ CHECK_LOAD(handle, mlx_array_data_complex64);
+ CHECK_LOAD(handle, mlx_array_data_float16);
+ CHECK_LOAD(handle, mlx_array_data_bfloat16);
+ CHECK_LOAD(handle, _mlx_array_is_available);
+ CHECK_LOAD(handle, _mlx_array_wait);
+ CHECK_LOAD(handle, _mlx_array_is_contiguous);
+ CHECK_LOAD(handle, _mlx_array_is_row_contiguous);
+ CHECK_LOAD(handle, _mlx_array_is_col_contiguous);
+ CHECK_LOAD(handle, mlx_closure_new);
+ CHECK_LOAD(handle, mlx_closure_free);
+ CHECK_LOAD(handle, mlx_closure_new_func);
+ CHECK_LOAD(handle, mlx_closure_new_func_payload);
+ CHECK_LOAD(handle, mlx_closure_set);
+ CHECK_LOAD(handle, mlx_closure_apply);
+ CHECK_LOAD(handle, mlx_closure_new_unary);
+ CHECK_LOAD(handle, mlx_closure_kwargs_new);
+ CHECK_LOAD(handle, mlx_closure_kwargs_free);
+ CHECK_LOAD(handle, mlx_closure_kwargs_new_func);
+ CHECK_LOAD(handle, mlx_closure_kwargs_new_func_payload);
+ CHECK_LOAD(handle, mlx_closure_kwargs_set);
+ CHECK_LOAD(handle, mlx_closure_kwargs_apply);
+ CHECK_LOAD(handle, mlx_closure_value_and_grad_new);
+ CHECK_LOAD(handle, mlx_closure_value_and_grad_free);
+ CHECK_LOAD(handle, mlx_closure_value_and_grad_new_func);
+ CHECK_LOAD(handle, mlx_closure_value_and_grad_new_func_payload);
+ CHECK_LOAD(handle, mlx_closure_value_and_grad_set);
+ CHECK_LOAD(handle, mlx_closure_value_and_grad_apply);
+ CHECK_LOAD(handle, mlx_closure_custom_new);
+ CHECK_LOAD(handle, mlx_closure_custom_free);
+ CHECK_LOAD(handle, mlx_closure_custom_new_func);
+ CHECK_LOAD(handle, mlx_closure_custom_new_func_payload);
+ CHECK_LOAD(handle, mlx_closure_custom_set);
+ CHECK_LOAD(handle, mlx_closure_custom_apply);
+ CHECK_LOAD(handle, mlx_closure_custom_jvp_new);
+ CHECK_LOAD(handle, mlx_closure_custom_jvp_free);
+ CHECK_LOAD(handle, mlx_closure_custom_jvp_new_func);
+ CHECK_LOAD(handle, mlx_closure_custom_jvp_new_func_payload);
+ CHECK_LOAD(handle, mlx_closure_custom_jvp_set);
+ CHECK_LOAD(handle, mlx_closure_custom_jvp_apply);
+ CHECK_LOAD(handle, mlx_closure_custom_vmap_new);
+ CHECK_LOAD(handle, mlx_closure_custom_vmap_free);
+ CHECK_LOAD(handle, mlx_closure_custom_vmap_new_func);
+ CHECK_LOAD(handle, mlx_closure_custom_vmap_new_func_payload);
+ CHECK_LOAD(handle, mlx_closure_custom_vmap_set);
+ CHECK_LOAD(handle, mlx_closure_custom_vmap_apply);
+ CHECK_LOAD(handle, mlx_compile);
+ CHECK_LOAD(handle, mlx_detail_compile);
+ CHECK_LOAD(handle, mlx_detail_compile_clear_cache);
+ CHECK_LOAD(handle, mlx_detail_compile_erase);
+ CHECK_LOAD(handle, mlx_disable_compile);
+ CHECK_LOAD(handle, mlx_enable_compile);
+ CHECK_LOAD(handle, mlx_set_compile_mode);
+ CHECK_LOAD(handle, mlx_device_new);
+ CHECK_LOAD(handle, mlx_device_new_type);
+ CHECK_LOAD(handle, mlx_device_free);
+ CHECK_LOAD(handle, mlx_device_set);
+ CHECK_LOAD(handle, mlx_device_tostring);
+ CHECK_LOAD(handle, mlx_device_equal);
+ CHECK_LOAD(handle, mlx_device_get_index);
+ CHECK_LOAD(handle, mlx_device_get_type);
+ CHECK_LOAD(handle, mlx_get_default_device);
+ CHECK_LOAD(handle, mlx_set_default_device);
+ CHECK_LOAD(handle, mlx_distributed_group_rank);
+ CHECK_LOAD(handle, mlx_distributed_group_size);
+ CHECK_LOAD(handle, mlx_distributed_group_split);
+ CHECK_LOAD(handle, mlx_distributed_is_available);
+ CHECK_LOAD(handle, mlx_distributed_init);
+ CHECK_LOAD(handle, mlx_distributed_all_gather);
+ CHECK_LOAD(handle, mlx_distributed_all_max);
+ CHECK_LOAD(handle, mlx_distributed_all_min);
+ CHECK_LOAD(handle, mlx_distributed_all_sum);
+ CHECK_LOAD(handle, mlx_distributed_recv);
+ CHECK_LOAD(handle, mlx_distributed_recv_like);
+ CHECK_LOAD(handle, mlx_distributed_send);
+ CHECK_LOAD(handle, mlx_distributed_sum_scatter);
+ CHECK_LOAD(handle, mlx_set_error_handler);
+ CHECK_LOAD(handle, _mlx_error);
+ CHECK_LOAD(handle, mlx_export_function);
+ CHECK_LOAD(handle, mlx_export_function_kwargs);
+ CHECK_LOAD(handle, mlx_function_exporter_new);
+ CHECK_LOAD(handle, mlx_function_exporter_free);
+ CHECK_LOAD(handle, mlx_function_exporter_apply);
+ CHECK_LOAD(handle, mlx_function_exporter_apply_kwargs);
+ CHECK_LOAD(handle, mlx_imported_function_new);
+ CHECK_LOAD(handle, mlx_imported_function_free);
+ CHECK_LOAD(handle, mlx_imported_function_apply);
+ CHECK_LOAD(handle, mlx_imported_function_apply_kwargs);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_new);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_free);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_output_arg);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_grid);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_thread_group);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_init_value);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_set_verbose);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_template_arg_dtype);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_template_arg_int);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_config_add_template_arg_bool);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_new);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_free);
+ CHECK_LOAD(handle, mlx_fast_cuda_kernel_apply);
+ CHECK_LOAD(handle, mlx_fast_layer_norm);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_new);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_free);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_output_arg);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_grid);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_thread_group);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_init_value);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_set_verbose);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_template_arg_dtype);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_template_arg_int);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_config_add_template_arg_bool);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_new);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_free);
+ CHECK_LOAD(handle, mlx_fast_metal_kernel_apply);
+ CHECK_LOAD(handle, mlx_fast_rms_norm);
+ CHECK_LOAD(handle, mlx_fast_rope);
+ CHECK_LOAD(handle, mlx_fast_scaled_dot_product_attention);
+ CHECK_LOAD(handle, mlx_fft_fft);
+ CHECK_LOAD(handle, mlx_fft_fft2);
+ CHECK_LOAD(handle, mlx_fft_fftn);
+ CHECK_LOAD(handle, mlx_fft_fftshift);
+ CHECK_LOAD(handle, mlx_fft_ifft);
+ CHECK_LOAD(handle, mlx_fft_ifft2);
+ CHECK_LOAD(handle, mlx_fft_ifftn);
+ CHECK_LOAD(handle, mlx_fft_ifftshift);
+ CHECK_LOAD(handle, mlx_fft_irfft);
+ CHECK_LOAD(handle, mlx_fft_irfft2);
+ CHECK_LOAD(handle, mlx_fft_irfftn);
+ CHECK_LOAD(handle, mlx_fft_rfft);
+ CHECK_LOAD(handle, mlx_fft_rfft2);
+ CHECK_LOAD(handle, mlx_fft_rfftn);
+ CHECK_LOAD(handle, mlx_io_reader_new);
+ CHECK_LOAD(handle, mlx_io_reader_descriptor);
+ CHECK_LOAD(handle, mlx_io_reader_tostring);
+ CHECK_LOAD(handle, mlx_io_reader_free);
+ CHECK_LOAD(handle, mlx_io_writer_new);
+ CHECK_LOAD(handle, mlx_io_writer_descriptor);
+ CHECK_LOAD(handle, mlx_io_writer_tostring);
+ CHECK_LOAD(handle, mlx_io_writer_free);
+ CHECK_LOAD(handle, mlx_load_reader);
+ CHECK_LOAD(handle, mlx_load);
+ CHECK_LOAD(handle, mlx_load_safetensors_reader);
+ CHECK_LOAD(handle, mlx_load_safetensors);
+ CHECK_LOAD(handle, mlx_save_writer);
+ CHECK_LOAD(handle, mlx_save);
+ CHECK_LOAD(handle, mlx_save_safetensors_writer);
+ CHECK_LOAD(handle, mlx_save_safetensors);
+ CHECK_LOAD(handle, mlx_linalg_cholesky);
+ CHECK_LOAD(handle, mlx_linalg_cholesky_inv);
+ CHECK_LOAD(handle, mlx_linalg_cross);
+ CHECK_LOAD(handle, mlx_linalg_eig);
+ CHECK_LOAD(handle, mlx_linalg_eigh);
+ CHECK_LOAD(handle, mlx_linalg_eigvals);
+ CHECK_LOAD(handle, mlx_linalg_eigvalsh);
+ CHECK_LOAD(handle, mlx_linalg_inv);
+ CHECK_LOAD(handle, mlx_linalg_lu);
+ CHECK_LOAD(handle, mlx_linalg_lu_factor);
+ CHECK_LOAD(handle, mlx_linalg_norm);
+ CHECK_LOAD(handle, mlx_linalg_norm_matrix);
+ CHECK_LOAD(handle, mlx_linalg_norm_l2);
+ CHECK_LOAD(handle, mlx_linalg_pinv);
+ CHECK_LOAD(handle, mlx_linalg_qr);
+ CHECK_LOAD(handle, mlx_linalg_solve);
+ CHECK_LOAD(handle, mlx_linalg_solve_triangular);
+ CHECK_LOAD(handle, mlx_linalg_svd);
+ CHECK_LOAD(handle, mlx_linalg_tri_inv);
+ CHECK_LOAD(handle, mlx_map_string_to_array_new);
+ CHECK_LOAD(handle, mlx_map_string_to_array_set);
+ CHECK_LOAD(handle, mlx_map_string_to_array_free);
+ CHECK_LOAD(handle, mlx_map_string_to_array_insert);
+ CHECK_LOAD(handle, mlx_map_string_to_array_get);
+ CHECK_LOAD(handle, mlx_map_string_to_array_iterator_new);
+ CHECK_LOAD(handle, mlx_map_string_to_array_iterator_free);
+ CHECK_LOAD(handle, mlx_map_string_to_array_iterator_next);
+ CHECK_LOAD(handle, mlx_map_string_to_string_new);
+ CHECK_LOAD(handle, mlx_map_string_to_string_set);
+ CHECK_LOAD(handle, mlx_map_string_to_string_free);
+ CHECK_LOAD(handle, mlx_map_string_to_string_insert);
+ CHECK_LOAD(handle, mlx_map_string_to_string_get);
+ CHECK_LOAD(handle, mlx_map_string_to_string_iterator_new);
+ CHECK_LOAD(handle, mlx_map_string_to_string_iterator_free);
+ CHECK_LOAD(handle, mlx_map_string_to_string_iterator_next);
+ CHECK_LOAD(handle, mlx_clear_cache);
+ CHECK_LOAD(handle, mlx_get_active_memory);
+ CHECK_LOAD(handle, mlx_get_cache_memory);
+ CHECK_LOAD(handle, mlx_get_memory_limit);
+ CHECK_LOAD(handle, mlx_get_peak_memory);
+ CHECK_LOAD(handle, mlx_reset_peak_memory);
+ CHECK_LOAD(handle, mlx_set_cache_limit);
+ CHECK_LOAD(handle, mlx_set_memory_limit);
+ CHECK_LOAD(handle, mlx_set_wired_limit);
+ CHECK_LOAD(handle, mlx_metal_device_info);
+ CHECK_LOAD(handle, mlx_metal_is_available);
+ CHECK_LOAD(handle, mlx_metal_start_capture);
+ CHECK_LOAD(handle, mlx_metal_stop_capture);
+ CHECK_LOAD(handle, mlx_abs);
+ CHECK_LOAD(handle, mlx_add);
+ CHECK_LOAD(handle, mlx_addmm);
+ CHECK_LOAD(handle, mlx_all_axes);
+ CHECK_LOAD(handle, mlx_all_axis);
+ CHECK_LOAD(handle, mlx_all);
+ CHECK_LOAD(handle, mlx_allclose);
+ CHECK_LOAD(handle, mlx_any_axes);
+ CHECK_LOAD(handle, mlx_any_axis);
+ CHECK_LOAD(handle, mlx_any);
+ CHECK_LOAD(handle, mlx_arange);
+ CHECK_LOAD(handle, mlx_arccos);
+ CHECK_LOAD(handle, mlx_arccosh);
+ CHECK_LOAD(handle, mlx_arcsin);
+ CHECK_LOAD(handle, mlx_arcsinh);
+ CHECK_LOAD(handle, mlx_arctan);
+ CHECK_LOAD(handle, mlx_arctan2);
+ CHECK_LOAD(handle, mlx_arctanh);
+ CHECK_LOAD(handle, mlx_argmax_axis);
+ CHECK_LOAD(handle, mlx_argmax);
+ CHECK_LOAD(handle, mlx_argmin_axis);
+ CHECK_LOAD(handle, mlx_argmin);
+ CHECK_LOAD(handle, mlx_argpartition_axis);
+ CHECK_LOAD(handle, mlx_argpartition);
+ CHECK_LOAD(handle, mlx_argsort_axis);
+ CHECK_LOAD(handle, mlx_argsort);
+ CHECK_LOAD(handle, mlx_array_equal);
+ CHECK_LOAD(handle, mlx_as_strided);
+ CHECK_LOAD(handle, mlx_astype);
+ CHECK_LOAD(handle, mlx_atleast_1d);
+ CHECK_LOAD(handle, mlx_atleast_2d);
+ CHECK_LOAD(handle, mlx_atleast_3d);
+ CHECK_LOAD(handle, mlx_bitwise_and);
+ CHECK_LOAD(handle, mlx_bitwise_invert);
+ CHECK_LOAD(handle, mlx_bitwise_or);
+ CHECK_LOAD(handle, mlx_bitwise_xor);
+ CHECK_LOAD(handle, mlx_block_masked_mm);
+ CHECK_LOAD(handle, mlx_broadcast_arrays);
+ CHECK_LOAD(handle, mlx_broadcast_to);
+ CHECK_LOAD(handle, mlx_ceil);
+ CHECK_LOAD(handle, mlx_clip);
+ CHECK_LOAD(handle, mlx_concatenate_axis);
+ CHECK_LOAD(handle, mlx_concatenate);
+ CHECK_LOAD(handle, mlx_conjugate);
+ CHECK_LOAD(handle, mlx_contiguous);
+ CHECK_LOAD(handle, mlx_conv1d);
+ CHECK_LOAD(handle, mlx_conv2d);
+ CHECK_LOAD(handle, mlx_conv3d);
+ CHECK_LOAD(handle, mlx_conv_general);
+ CHECK_LOAD(handle, mlx_conv_transpose1d);
+ CHECK_LOAD(handle, mlx_conv_transpose2d);
+ CHECK_LOAD(handle, mlx_conv_transpose3d);
+ CHECK_LOAD(handle, mlx_copy);
+ CHECK_LOAD(handle, mlx_cos);
+ CHECK_LOAD(handle, mlx_cosh);
+ CHECK_LOAD(handle, mlx_cummax);
+ CHECK_LOAD(handle, mlx_cummin);
+ CHECK_LOAD(handle, mlx_cumprod);
+ CHECK_LOAD(handle, mlx_cumsum);
+ CHECK_LOAD(handle, mlx_degrees);
+ CHECK_LOAD(handle, mlx_depends);
+ CHECK_LOAD(handle, mlx_dequantize);
+ CHECK_LOAD(handle, mlx_diag);
+ CHECK_LOAD(handle, mlx_diagonal);
+ CHECK_LOAD(handle, mlx_divide);
+ CHECK_LOAD(handle, mlx_divmod);
+ CHECK_LOAD(handle, mlx_einsum);
+ CHECK_LOAD(handle, mlx_equal);
+ CHECK_LOAD(handle, mlx_erf);
+ CHECK_LOAD(handle, mlx_erfinv);
+ CHECK_LOAD(handle, mlx_exp);
+ CHECK_LOAD(handle, mlx_expand_dims_axes);
+ CHECK_LOAD(handle, mlx_expand_dims);
+ CHECK_LOAD(handle, mlx_expm1);
+ CHECK_LOAD(handle, mlx_eye);
+ CHECK_LOAD(handle, mlx_flatten);
+ CHECK_LOAD(handle, mlx_floor);
+ CHECK_LOAD(handle, mlx_floor_divide);
+ CHECK_LOAD(handle, mlx_from_fp8);
+ CHECK_LOAD(handle, mlx_full);
+ CHECK_LOAD(handle, mlx_full_like);
+ CHECK_LOAD(handle, mlx_gather);
+ CHECK_LOAD(handle, mlx_gather_mm);
+ CHECK_LOAD(handle, mlx_gather_qmm);
+ CHECK_LOAD(handle, mlx_greater);
+ CHECK_LOAD(handle, mlx_greater_equal);
+ CHECK_LOAD(handle, mlx_hadamard_transform);
+ CHECK_LOAD(handle, mlx_identity);
+ CHECK_LOAD(handle, mlx_imag);
+ CHECK_LOAD(handle, mlx_inner);
+ CHECK_LOAD(handle, mlx_isclose);
+ CHECK_LOAD(handle, mlx_isfinite);
+ CHECK_LOAD(handle, mlx_isinf);
+ CHECK_LOAD(handle, mlx_isnan);
+ CHECK_LOAD(handle, mlx_isneginf);
+ CHECK_LOAD(handle, mlx_isposinf);
+ CHECK_LOAD(handle, mlx_kron);
+ CHECK_LOAD(handle, mlx_left_shift);
+ CHECK_LOAD(handle, mlx_less);
+ CHECK_LOAD(handle, mlx_less_equal);
+ CHECK_LOAD(handle, mlx_linspace);
+ CHECK_LOAD(handle, mlx_log);
+ CHECK_LOAD(handle, mlx_log10);
+ CHECK_LOAD(handle, mlx_log1p);
+ CHECK_LOAD(handle, mlx_log2);
+ CHECK_LOAD(handle, mlx_logaddexp);
+ CHECK_LOAD(handle, mlx_logcumsumexp);
+ CHECK_LOAD(handle, mlx_logical_and);
+ CHECK_LOAD(handle, mlx_logical_not);
+ CHECK_LOAD(handle, mlx_logical_or);
+ CHECK_LOAD(handle, mlx_logsumexp_axes);
+ CHECK_LOAD(handle, mlx_logsumexp_axis);
+ CHECK_LOAD(handle, mlx_logsumexp);
+ CHECK_LOAD(handle, mlx_masked_scatter);
+ CHECK_LOAD(handle, mlx_matmul);
+ CHECK_LOAD(handle, mlx_max_axes);
+ CHECK_LOAD(handle, mlx_max_axis);
+ CHECK_LOAD(handle, mlx_max);
+ CHECK_LOAD(handle, mlx_maximum);
+ CHECK_LOAD(handle, mlx_mean_axes);
+ CHECK_LOAD(handle, mlx_mean_axis);
+ CHECK_LOAD(handle, mlx_mean);
+ CHECK_LOAD(handle, mlx_median);
+ CHECK_LOAD(handle, mlx_meshgrid);
+ CHECK_LOAD(handle, mlx_min_axes);
+ CHECK_LOAD(handle, mlx_min_axis);
+ CHECK_LOAD(handle, mlx_min);
+ CHECK_LOAD(handle, mlx_minimum);
+ CHECK_LOAD(handle, mlx_moveaxis);
+ CHECK_LOAD(handle, mlx_multiply);
+ CHECK_LOAD(handle, mlx_nan_to_num);
+ CHECK_LOAD(handle, mlx_negative);
+ CHECK_LOAD(handle, mlx_not_equal);
+ CHECK_LOAD(handle, mlx_number_of_elements);
+ CHECK_LOAD(handle, mlx_ones);
+ CHECK_LOAD(handle, mlx_ones_like);
+ CHECK_LOAD(handle, mlx_outer);
+ CHECK_LOAD(handle, mlx_pad);
+ CHECK_LOAD(handle, mlx_pad_symmetric);
+ CHECK_LOAD(handle, mlx_partition_axis);
+ CHECK_LOAD(handle, mlx_partition);
+ CHECK_LOAD(handle, mlx_power);
+ CHECK_LOAD(handle, mlx_prod_axes);
+ CHECK_LOAD(handle, mlx_prod_axis);
+ CHECK_LOAD(handle, mlx_prod);
+ CHECK_LOAD(handle, mlx_put_along_axis);
+ CHECK_LOAD(handle, mlx_quantize);
+ CHECK_LOAD(handle, mlx_quantized_matmul);
+ CHECK_LOAD(handle, mlx_radians);
+ CHECK_LOAD(handle, mlx_real);
+ CHECK_LOAD(handle, mlx_reciprocal);
+ CHECK_LOAD(handle, mlx_remainder);
+ CHECK_LOAD(handle, mlx_repeat_axis);
+ CHECK_LOAD(handle, mlx_repeat);
+ CHECK_LOAD(handle, mlx_reshape);
+ CHECK_LOAD(handle, mlx_right_shift);
+ CHECK_LOAD(handle, mlx_roll_axis);
+ CHECK_LOAD(handle, mlx_roll_axes);
+ CHECK_LOAD(handle, mlx_roll);
+ CHECK_LOAD(handle, mlx_round);
+ CHECK_LOAD(handle, mlx_rsqrt);
+ CHECK_LOAD(handle, mlx_scatter);
+ CHECK_LOAD(handle, mlx_scatter_add);
+ CHECK_LOAD(handle, mlx_scatter_add_axis);
+ CHECK_LOAD(handle, mlx_scatter_max);
+ CHECK_LOAD(handle, mlx_scatter_min);
+ CHECK_LOAD(handle, mlx_scatter_prod);
+ CHECK_LOAD(handle, mlx_segmented_mm);
+ CHECK_LOAD(handle, mlx_sigmoid);
+ CHECK_LOAD(handle, mlx_sign);
+ CHECK_LOAD(handle, mlx_sin);
+ CHECK_LOAD(handle, mlx_sinh);
+ CHECK_LOAD(handle, mlx_slice);
+ CHECK_LOAD(handle, mlx_slice_dynamic);
+ CHECK_LOAD(handle, mlx_slice_update);
+ CHECK_LOAD(handle, mlx_slice_update_dynamic);
+ CHECK_LOAD(handle, mlx_softmax_axes);
+ CHECK_LOAD(handle, mlx_softmax_axis);
+ CHECK_LOAD(handle, mlx_softmax);
+ CHECK_LOAD(handle, mlx_sort_axis);
+ CHECK_LOAD(handle, mlx_sort);
+ CHECK_LOAD(handle, mlx_split);
+ CHECK_LOAD(handle, mlx_split_sections);
+ CHECK_LOAD(handle, mlx_sqrt);
+ CHECK_LOAD(handle, mlx_square);
+ CHECK_LOAD(handle, mlx_squeeze_axes);
+ CHECK_LOAD(handle, mlx_squeeze_axis);
+ CHECK_LOAD(handle, mlx_squeeze);
+ CHECK_LOAD(handle, mlx_stack_axis);
+ CHECK_LOAD(handle, mlx_stack);
+ CHECK_LOAD(handle, mlx_std_axes);
+ CHECK_LOAD(handle, mlx_std_axis);
+ CHECK_LOAD(handle, mlx_std);
+ CHECK_LOAD(handle, mlx_stop_gradient);
+ CHECK_LOAD(handle, mlx_subtract);
+ CHECK_LOAD(handle, mlx_sum_axes);
+ CHECK_LOAD(handle, mlx_sum_axis);
+ CHECK_LOAD(handle, mlx_sum);
+ CHECK_LOAD(handle, mlx_swapaxes);
+ CHECK_LOAD(handle, mlx_take_axis);
+ CHECK_LOAD(handle, mlx_take);
+ CHECK_LOAD(handle, mlx_take_along_axis);
+ CHECK_LOAD(handle, mlx_tan);
+ CHECK_LOAD(handle, mlx_tanh);
+ CHECK_LOAD(handle, mlx_tensordot);
+ CHECK_LOAD(handle, mlx_tensordot_axis);
+ CHECK_LOAD(handle, mlx_tile);
+ CHECK_LOAD(handle, mlx_to_fp8);
+ CHECK_LOAD(handle, mlx_topk_axis);
+ CHECK_LOAD(handle, mlx_topk);
+ CHECK_LOAD(handle, mlx_trace);
+ CHECK_LOAD(handle, mlx_transpose_axes);
+ CHECK_LOAD(handle, mlx_transpose);
+ CHECK_LOAD(handle, mlx_tri);
+ CHECK_LOAD(handle, mlx_tril);
+ CHECK_LOAD(handle, mlx_triu);
+ CHECK_LOAD(handle, mlx_unflatten);
+ CHECK_LOAD(handle, mlx_var_axes);
+ CHECK_LOAD(handle, mlx_var_axis);
+ CHECK_LOAD(handle, mlx_var);
+ CHECK_LOAD(handle, mlx_view);
+ CHECK_LOAD(handle, mlx_where);
+ CHECK_LOAD(handle, mlx_zeros);
+ CHECK_LOAD(handle, mlx_zeros_like);
+ CHECK_LOAD(handle, mlx_random_bernoulli);
+ CHECK_LOAD(handle, mlx_random_bits);
+ CHECK_LOAD(handle, mlx_random_categorical_shape);
+ CHECK_LOAD(handle, mlx_random_categorical_num_samples);
+ CHECK_LOAD(handle, mlx_random_categorical);
+ CHECK_LOAD(handle, mlx_random_gumbel);
+ CHECK_LOAD(handle, mlx_random_key);
+ CHECK_LOAD(handle, mlx_random_laplace);
+ CHECK_LOAD(handle, mlx_random_multivariate_normal);
+ CHECK_LOAD(handle, mlx_random_normal_broadcast);
+ CHECK_LOAD(handle, mlx_random_normal);
+ CHECK_LOAD(handle, mlx_random_permutation);
+ CHECK_LOAD(handle, mlx_random_permutation_arange);
+ CHECK_LOAD(handle, mlx_random_randint);
+ CHECK_LOAD(handle, mlx_random_seed);
+ CHECK_LOAD(handle, mlx_random_split_num);
+ CHECK_LOAD(handle, mlx_random_split);
+ CHECK_LOAD(handle, mlx_random_truncated_normal);
+ CHECK_LOAD(handle, mlx_random_uniform);
+ CHECK_LOAD(handle, mlx_stream_new);
+ CHECK_LOAD(handle, mlx_stream_new_device);
+ CHECK_LOAD(handle, mlx_stream_set);
+ CHECK_LOAD(handle, mlx_stream_free);
+ CHECK_LOAD(handle, mlx_stream_tostring);
+ CHECK_LOAD(handle, mlx_stream_equal);
+ CHECK_LOAD(handle, mlx_stream_get_device);
+ CHECK_LOAD(handle, mlx_stream_get_index);
+ CHECK_LOAD(handle, mlx_synchronize);
+ CHECK_LOAD(handle, mlx_get_default_stream);
+ CHECK_LOAD(handle, mlx_set_default_stream);
+ CHECK_LOAD(handle, mlx_default_cpu_stream_new);
+ CHECK_LOAD(handle, mlx_default_gpu_stream_new);
+ CHECK_LOAD(handle, mlx_string_new);
+ CHECK_LOAD(handle, mlx_string_new_data);
+ CHECK_LOAD(handle, mlx_string_set);
+ CHECK_LOAD(handle, mlx_string_data);
+ CHECK_LOAD(handle, mlx_string_free);
+ CHECK_LOAD(handle, mlx_detail_vmap_replace);
+ CHECK_LOAD(handle, mlx_detail_vmap_trace);
+ CHECK_LOAD(handle, mlx_async_eval);
+ CHECK_LOAD(handle, mlx_checkpoint);
+ CHECK_LOAD(handle, mlx_custom_function);
+ CHECK_LOAD(handle, mlx_custom_vjp);
+ CHECK_LOAD(handle, mlx_eval);
+ CHECK_LOAD(handle, mlx_jvp);
+ CHECK_LOAD(handle, mlx_value_and_grad);
+ CHECK_LOAD(handle, mlx_vjp);
+ CHECK_LOAD(handle, mlx_vector_array_new);
+ CHECK_LOAD(handle, mlx_vector_array_set);
+ CHECK_LOAD(handle, mlx_vector_array_free);
+ CHECK_LOAD(handle, mlx_vector_array_new_data);
+ CHECK_LOAD(handle, mlx_vector_array_new_value);
+ CHECK_LOAD(handle, mlx_vector_array_set_data);
+ CHECK_LOAD(handle, mlx_vector_array_set_value);
+ CHECK_LOAD(handle, mlx_vector_array_append_data);
+ CHECK_LOAD(handle, mlx_vector_array_append_value);
+ CHECK_LOAD(handle, mlx_vector_array_size);
+ CHECK_LOAD(handle, mlx_vector_array_get);
+ CHECK_LOAD(handle, mlx_vector_vector_array_new);
+ CHECK_LOAD(handle, mlx_vector_vector_array_set);
+ CHECK_LOAD(handle, mlx_vector_vector_array_free);
+ CHECK_LOAD(handle, mlx_vector_vector_array_new_data);
+ CHECK_LOAD(handle, mlx_vector_vector_array_new_value);
+ CHECK_LOAD(handle, mlx_vector_vector_array_set_data);
+ CHECK_LOAD(handle, mlx_vector_vector_array_set_value);
+ CHECK_LOAD(handle, mlx_vector_vector_array_append_data);
+ CHECK_LOAD(handle, mlx_vector_vector_array_append_value);
+ CHECK_LOAD(handle, mlx_vector_vector_array_size);
+ CHECK_LOAD(handle, mlx_vector_vector_array_get);
+ CHECK_LOAD(handle, mlx_vector_int_new);
+ CHECK_LOAD(handle, mlx_vector_int_set);
+ CHECK_LOAD(handle, mlx_vector_int_free);
+ CHECK_LOAD(handle, mlx_vector_int_new_data);
+ CHECK_LOAD(handle, mlx_vector_int_new_value);
+ CHECK_LOAD(handle, mlx_vector_int_set_data);
+ CHECK_LOAD(handle, mlx_vector_int_set_value);
+ CHECK_LOAD(handle, mlx_vector_int_append_data);
+ CHECK_LOAD(handle, mlx_vector_int_append_value);
+ CHECK_LOAD(handle, mlx_vector_int_size);
+ CHECK_LOAD(handle, mlx_vector_int_get);
+ CHECK_LOAD(handle, mlx_vector_string_new);
+ CHECK_LOAD(handle, mlx_vector_string_set);
+ CHECK_LOAD(handle, mlx_vector_string_free);
+ CHECK_LOAD(handle, mlx_vector_string_new_data);
+ CHECK_LOAD(handle, mlx_vector_string_new_value);
+ CHECK_LOAD(handle, mlx_vector_string_set_data);
+ CHECK_LOAD(handle, mlx_vector_string_set_value);
+ CHECK_LOAD(handle, mlx_vector_string_append_data);
+ CHECK_LOAD(handle, mlx_vector_string_append_value);
+ CHECK_LOAD(handle, mlx_vector_string_size);
+ CHECK_LOAD(handle, mlx_vector_string_get);
+ CHECK_LOAD(handle, mlx_version);
+ return 0;
+}
diff --git a/x/mlxrunner/mlx/generated.h b/x/mlxrunner/mlx/generated.h
new file mode 100644
index 00000000000..c88946d9f73
--- /dev/null
+++ b/x/mlxrunner/mlx/generated.h
@@ -0,0 +1,7135 @@
+// This code is auto-generated; DO NOT EDIT.
+
+#ifndef MLX_GENERATED_H
+#define MLX_GENERATED_H
+
+#include "dynamic.h"
+
+#define mlx_dtype_size mlx_dtype_size_mlx_gen_orig_
+#define mlx_array_tostring mlx_array_tostring_mlx_gen_orig_
+#define mlx_array_new mlx_array_new_mlx_gen_orig_
+#define mlx_array_free mlx_array_free_mlx_gen_orig_
+#define mlx_array_new_bool mlx_array_new_bool_mlx_gen_orig_
+#define mlx_array_new_int mlx_array_new_int_mlx_gen_orig_
+#define mlx_array_new_float32 mlx_array_new_float32_mlx_gen_orig_
+#define mlx_array_new_float mlx_array_new_float_mlx_gen_orig_
+#define mlx_array_new_float64 mlx_array_new_float64_mlx_gen_orig_
+#define mlx_array_new_double mlx_array_new_double_mlx_gen_orig_
+#define mlx_array_new_complex mlx_array_new_complex_mlx_gen_orig_
+#define mlx_array_new_data mlx_array_new_data_mlx_gen_orig_
+#define mlx_array_set mlx_array_set_mlx_gen_orig_
+#define mlx_array_set_bool mlx_array_set_bool_mlx_gen_orig_
+#define mlx_array_set_int mlx_array_set_int_mlx_gen_orig_
+#define mlx_array_set_float32 mlx_array_set_float32_mlx_gen_orig_
+#define mlx_array_set_float mlx_array_set_float_mlx_gen_orig_
+#define mlx_array_set_float64 mlx_array_set_float64_mlx_gen_orig_
+#define mlx_array_set_double mlx_array_set_double_mlx_gen_orig_
+#define mlx_array_set_complex mlx_array_set_complex_mlx_gen_orig_
+#define mlx_array_set_data mlx_array_set_data_mlx_gen_orig_
+#define mlx_array_itemsize mlx_array_itemsize_mlx_gen_orig_
+#define mlx_array_size mlx_array_size_mlx_gen_orig_
+#define mlx_array_nbytes mlx_array_nbytes_mlx_gen_orig_
+#define mlx_array_ndim mlx_array_ndim_mlx_gen_orig_
+#define mlx_array_shape mlx_array_shape_mlx_gen_orig_
+#define mlx_array_strides mlx_array_strides_mlx_gen_orig_
+#define mlx_array_dim mlx_array_dim_mlx_gen_orig_
+#define mlx_array_dtype mlx_array_dtype_mlx_gen_orig_
+#define mlx_array_eval mlx_array_eval_mlx_gen_orig_
+#define mlx_array_item_bool mlx_array_item_bool_mlx_gen_orig_
+#define mlx_array_item_uint8 mlx_array_item_uint8_mlx_gen_orig_
+#define mlx_array_item_uint16 mlx_array_item_uint16_mlx_gen_orig_
+#define mlx_array_item_uint32 mlx_array_item_uint32_mlx_gen_orig_
+#define mlx_array_item_uint64 mlx_array_item_uint64_mlx_gen_orig_
+#define mlx_array_item_int8 mlx_array_item_int8_mlx_gen_orig_
+#define mlx_array_item_int16 mlx_array_item_int16_mlx_gen_orig_
+#define mlx_array_item_int32 mlx_array_item_int32_mlx_gen_orig_
+#define mlx_array_item_int64 mlx_array_item_int64_mlx_gen_orig_
+#define mlx_array_item_float32 mlx_array_item_float32_mlx_gen_orig_
+#define mlx_array_item_float64 mlx_array_item_float64_mlx_gen_orig_
+#define mlx_array_item_complex64 mlx_array_item_complex64_mlx_gen_orig_
+#define mlx_array_item_float16 mlx_array_item_float16_mlx_gen_orig_
+#define mlx_array_item_bfloat16 mlx_array_item_bfloat16_mlx_gen_orig_
+#define mlx_array_data_bool mlx_array_data_bool_mlx_gen_orig_
+#define mlx_array_data_uint8 mlx_array_data_uint8_mlx_gen_orig_
+#define mlx_array_data_uint16 mlx_array_data_uint16_mlx_gen_orig_
+#define mlx_array_data_uint32 mlx_array_data_uint32_mlx_gen_orig_
+#define mlx_array_data_uint64 mlx_array_data_uint64_mlx_gen_orig_
+#define mlx_array_data_int8 mlx_array_data_int8_mlx_gen_orig_
+#define mlx_array_data_int16 mlx_array_data_int16_mlx_gen_orig_
+#define mlx_array_data_int32 mlx_array_data_int32_mlx_gen_orig_
+#define mlx_array_data_int64 mlx_array_data_int64_mlx_gen_orig_
+#define mlx_array_data_float32 mlx_array_data_float32_mlx_gen_orig_
+#define mlx_array_data_float64 mlx_array_data_float64_mlx_gen_orig_
+#define mlx_array_data_complex64 mlx_array_data_complex64_mlx_gen_orig_
+#define mlx_array_data_float16 mlx_array_data_float16_mlx_gen_orig_
+#define mlx_array_data_bfloat16 mlx_array_data_bfloat16_mlx_gen_orig_
+#define _mlx_array_is_available _mlx_array_is_available_mlx_gen_orig_
+#define _mlx_array_wait _mlx_array_wait_mlx_gen_orig_
+#define _mlx_array_is_contiguous _mlx_array_is_contiguous_mlx_gen_orig_
+#define _mlx_array_is_row_contiguous _mlx_array_is_row_contiguous_mlx_gen_orig_
+#define _mlx_array_is_col_contiguous _mlx_array_is_col_contiguous_mlx_gen_orig_
+#define mlx_closure_new mlx_closure_new_mlx_gen_orig_
+#define mlx_closure_free mlx_closure_free_mlx_gen_orig_
+#define mlx_closure_new_func mlx_closure_new_func_mlx_gen_orig_
+#define mlx_closure_new_func_payload mlx_closure_new_func_payload_mlx_gen_orig_
+#define mlx_closure_set mlx_closure_set_mlx_gen_orig_
+#define mlx_closure_apply mlx_closure_apply_mlx_gen_orig_
+#define mlx_closure_new_unary mlx_closure_new_unary_mlx_gen_orig_
+#define mlx_closure_kwargs_new mlx_closure_kwargs_new_mlx_gen_orig_
+#define mlx_closure_kwargs_free mlx_closure_kwargs_free_mlx_gen_orig_
+#define mlx_closure_kwargs_new_func mlx_closure_kwargs_new_func_mlx_gen_orig_
+#define mlx_closure_kwargs_new_func_payload mlx_closure_kwargs_new_func_payload_mlx_gen_orig_
+#define mlx_closure_kwargs_set mlx_closure_kwargs_set_mlx_gen_orig_
+#define mlx_closure_kwargs_apply mlx_closure_kwargs_apply_mlx_gen_orig_
+#define mlx_closure_value_and_grad_new mlx_closure_value_and_grad_new_mlx_gen_orig_
+#define mlx_closure_value_and_grad_free mlx_closure_value_and_grad_free_mlx_gen_orig_
+#define mlx_closure_value_and_grad_new_func mlx_closure_value_and_grad_new_func_mlx_gen_orig_
+#define mlx_closure_value_and_grad_new_func_payload mlx_closure_value_and_grad_new_func_payload_mlx_gen_orig_
+#define mlx_closure_value_and_grad_set mlx_closure_value_and_grad_set_mlx_gen_orig_
+#define mlx_closure_value_and_grad_apply mlx_closure_value_and_grad_apply_mlx_gen_orig_
+#define mlx_closure_custom_new mlx_closure_custom_new_mlx_gen_orig_
+#define mlx_closure_custom_free mlx_closure_custom_free_mlx_gen_orig_
+#define mlx_closure_custom_new_func mlx_closure_custom_new_func_mlx_gen_orig_
+#define mlx_closure_custom_new_func_payload mlx_closure_custom_new_func_payload_mlx_gen_orig_
+#define mlx_closure_custom_set mlx_closure_custom_set_mlx_gen_orig_
+#define mlx_closure_custom_apply mlx_closure_custom_apply_mlx_gen_orig_
+#define mlx_closure_custom_jvp_new mlx_closure_custom_jvp_new_mlx_gen_orig_
+#define mlx_closure_custom_jvp_free mlx_closure_custom_jvp_free_mlx_gen_orig_
+#define mlx_closure_custom_jvp_new_func mlx_closure_custom_jvp_new_func_mlx_gen_orig_
+#define mlx_closure_custom_jvp_new_func_payload mlx_closure_custom_jvp_new_func_payload_mlx_gen_orig_
+#define mlx_closure_custom_jvp_set mlx_closure_custom_jvp_set_mlx_gen_orig_
+#define mlx_closure_custom_jvp_apply mlx_closure_custom_jvp_apply_mlx_gen_orig_
+#define mlx_closure_custom_vmap_new mlx_closure_custom_vmap_new_mlx_gen_orig_
+#define mlx_closure_custom_vmap_free mlx_closure_custom_vmap_free_mlx_gen_orig_
+#define mlx_closure_custom_vmap_new_func mlx_closure_custom_vmap_new_func_mlx_gen_orig_
+#define mlx_closure_custom_vmap_new_func_payload mlx_closure_custom_vmap_new_func_payload_mlx_gen_orig_
+#define mlx_closure_custom_vmap_set mlx_closure_custom_vmap_set_mlx_gen_orig_
+#define mlx_closure_custom_vmap_apply mlx_closure_custom_vmap_apply_mlx_gen_orig_
+#define mlx_compile mlx_compile_mlx_gen_orig_
+#define mlx_detail_compile mlx_detail_compile_mlx_gen_orig_
+#define mlx_detail_compile_clear_cache mlx_detail_compile_clear_cache_mlx_gen_orig_
+#define mlx_detail_compile_erase mlx_detail_compile_erase_mlx_gen_orig_
+#define mlx_disable_compile mlx_disable_compile_mlx_gen_orig_
+#define mlx_enable_compile mlx_enable_compile_mlx_gen_orig_
+#define mlx_set_compile_mode mlx_set_compile_mode_mlx_gen_orig_
+#define mlx_device_new mlx_device_new_mlx_gen_orig_
+#define mlx_device_new_type mlx_device_new_type_mlx_gen_orig_
+#define mlx_device_free mlx_device_free_mlx_gen_orig_
+#define mlx_device_set mlx_device_set_mlx_gen_orig_
+#define mlx_device_tostring mlx_device_tostring_mlx_gen_orig_
+#define mlx_device_equal mlx_device_equal_mlx_gen_orig_
+#define mlx_device_get_index mlx_device_get_index_mlx_gen_orig_
+#define mlx_device_get_type mlx_device_get_type_mlx_gen_orig_
+#define mlx_get_default_device mlx_get_default_device_mlx_gen_orig_
+#define mlx_set_default_device mlx_set_default_device_mlx_gen_orig_
+#define mlx_distributed_group_rank mlx_distributed_group_rank_mlx_gen_orig_
+#define mlx_distributed_group_size mlx_distributed_group_size_mlx_gen_orig_
+#define mlx_distributed_group_split mlx_distributed_group_split_mlx_gen_orig_
+#define mlx_distributed_is_available mlx_distributed_is_available_mlx_gen_orig_
+#define mlx_distributed_init mlx_distributed_init_mlx_gen_orig_
+#define mlx_distributed_all_gather mlx_distributed_all_gather_mlx_gen_orig_
+#define mlx_distributed_all_max mlx_distributed_all_max_mlx_gen_orig_
+#define mlx_distributed_all_min mlx_distributed_all_min_mlx_gen_orig_
+#define mlx_distributed_all_sum mlx_distributed_all_sum_mlx_gen_orig_
+#define mlx_distributed_recv mlx_distributed_recv_mlx_gen_orig_
+#define mlx_distributed_recv_like mlx_distributed_recv_like_mlx_gen_orig_
+#define mlx_distributed_send mlx_distributed_send_mlx_gen_orig_
+#define mlx_distributed_sum_scatter mlx_distributed_sum_scatter_mlx_gen_orig_
+#define mlx_set_error_handler mlx_set_error_handler_mlx_gen_orig_
+#define _mlx_error _mlx_error_mlx_gen_orig_
+#define mlx_export_function mlx_export_function_mlx_gen_orig_
+#define mlx_export_function_kwargs mlx_export_function_kwargs_mlx_gen_orig_
+#define mlx_function_exporter_new mlx_function_exporter_new_mlx_gen_orig_
+#define mlx_function_exporter_free mlx_function_exporter_free_mlx_gen_orig_
+#define mlx_function_exporter_apply mlx_function_exporter_apply_mlx_gen_orig_
+#define mlx_function_exporter_apply_kwargs mlx_function_exporter_apply_kwargs_mlx_gen_orig_
+#define mlx_imported_function_new mlx_imported_function_new_mlx_gen_orig_
+#define mlx_imported_function_free mlx_imported_function_free_mlx_gen_orig_
+#define mlx_imported_function_apply mlx_imported_function_apply_mlx_gen_orig_
+#define mlx_imported_function_apply_kwargs mlx_imported_function_apply_kwargs_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_new mlx_fast_cuda_kernel_config_new_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_free mlx_fast_cuda_kernel_config_free_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_add_output_arg mlx_fast_cuda_kernel_config_add_output_arg_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_set_grid mlx_fast_cuda_kernel_config_set_grid_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_set_thread_group mlx_fast_cuda_kernel_config_set_thread_group_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_set_init_value mlx_fast_cuda_kernel_config_set_init_value_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_set_verbose mlx_fast_cuda_kernel_config_set_verbose_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_add_template_arg_dtype mlx_fast_cuda_kernel_config_add_template_arg_dtype_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_add_template_arg_int mlx_fast_cuda_kernel_config_add_template_arg_int_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_config_add_template_arg_bool mlx_fast_cuda_kernel_config_add_template_arg_bool_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_new mlx_fast_cuda_kernel_new_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_free mlx_fast_cuda_kernel_free_mlx_gen_orig_
+#define mlx_fast_cuda_kernel_apply mlx_fast_cuda_kernel_apply_mlx_gen_orig_
+#define mlx_fast_layer_norm mlx_fast_layer_norm_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_new mlx_fast_metal_kernel_config_new_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_free mlx_fast_metal_kernel_config_free_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_add_output_arg mlx_fast_metal_kernel_config_add_output_arg_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_set_grid mlx_fast_metal_kernel_config_set_grid_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_set_thread_group mlx_fast_metal_kernel_config_set_thread_group_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_set_init_value mlx_fast_metal_kernel_config_set_init_value_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_set_verbose mlx_fast_metal_kernel_config_set_verbose_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_add_template_arg_dtype mlx_fast_metal_kernel_config_add_template_arg_dtype_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_add_template_arg_int mlx_fast_metal_kernel_config_add_template_arg_int_mlx_gen_orig_
+#define mlx_fast_metal_kernel_config_add_template_arg_bool mlx_fast_metal_kernel_config_add_template_arg_bool_mlx_gen_orig_
+#define mlx_fast_metal_kernel_new mlx_fast_metal_kernel_new_mlx_gen_orig_
+#define mlx_fast_metal_kernel_free mlx_fast_metal_kernel_free_mlx_gen_orig_
+#define mlx_fast_metal_kernel_apply mlx_fast_metal_kernel_apply_mlx_gen_orig_
+#define mlx_fast_rms_norm mlx_fast_rms_norm_mlx_gen_orig_
+#define mlx_fast_rope mlx_fast_rope_mlx_gen_orig_
+#define mlx_fast_scaled_dot_product_attention mlx_fast_scaled_dot_product_attention_mlx_gen_orig_
+#define mlx_fft_fft mlx_fft_fft_mlx_gen_orig_
+#define mlx_fft_fft2 mlx_fft_fft2_mlx_gen_orig_
+#define mlx_fft_fftn mlx_fft_fftn_mlx_gen_orig_
+#define mlx_fft_fftshift mlx_fft_fftshift_mlx_gen_orig_
+#define mlx_fft_ifft mlx_fft_ifft_mlx_gen_orig_
+#define mlx_fft_ifft2 mlx_fft_ifft2_mlx_gen_orig_
+#define mlx_fft_ifftn mlx_fft_ifftn_mlx_gen_orig_
+#define mlx_fft_ifftshift mlx_fft_ifftshift_mlx_gen_orig_
+#define mlx_fft_irfft mlx_fft_irfft_mlx_gen_orig_
+#define mlx_fft_irfft2 mlx_fft_irfft2_mlx_gen_orig_
+#define mlx_fft_irfftn mlx_fft_irfftn_mlx_gen_orig_
+#define mlx_fft_rfft mlx_fft_rfft_mlx_gen_orig_
+#define mlx_fft_rfft2 mlx_fft_rfft2_mlx_gen_orig_
+#define mlx_fft_rfftn mlx_fft_rfftn_mlx_gen_orig_
+#define mlx_io_reader_new mlx_io_reader_new_mlx_gen_orig_
+#define mlx_io_reader_descriptor mlx_io_reader_descriptor_mlx_gen_orig_
+#define mlx_io_reader_tostring mlx_io_reader_tostring_mlx_gen_orig_
+#define mlx_io_reader_free mlx_io_reader_free_mlx_gen_orig_
+#define mlx_io_writer_new mlx_io_writer_new_mlx_gen_orig_
+#define mlx_io_writer_descriptor mlx_io_writer_descriptor_mlx_gen_orig_
+#define mlx_io_writer_tostring mlx_io_writer_tostring_mlx_gen_orig_
+#define mlx_io_writer_free mlx_io_writer_free_mlx_gen_orig_
+#define mlx_load_reader mlx_load_reader_mlx_gen_orig_
+#define mlx_load mlx_load_mlx_gen_orig_
+#define mlx_load_safetensors_reader mlx_load_safetensors_reader_mlx_gen_orig_
+#define mlx_load_safetensors mlx_load_safetensors_mlx_gen_orig_
+#define mlx_save_writer mlx_save_writer_mlx_gen_orig_
+#define mlx_save mlx_save_mlx_gen_orig_
+#define mlx_save_safetensors_writer mlx_save_safetensors_writer_mlx_gen_orig_
+#define mlx_save_safetensors mlx_save_safetensors_mlx_gen_orig_
+#define mlx_linalg_cholesky mlx_linalg_cholesky_mlx_gen_orig_
+#define mlx_linalg_cholesky_inv mlx_linalg_cholesky_inv_mlx_gen_orig_
+#define mlx_linalg_cross mlx_linalg_cross_mlx_gen_orig_
+#define mlx_linalg_eig mlx_linalg_eig_mlx_gen_orig_
+#define mlx_linalg_eigh mlx_linalg_eigh_mlx_gen_orig_
+#define mlx_linalg_eigvals mlx_linalg_eigvals_mlx_gen_orig_
+#define mlx_linalg_eigvalsh mlx_linalg_eigvalsh_mlx_gen_orig_
+#define mlx_linalg_inv mlx_linalg_inv_mlx_gen_orig_
+#define mlx_linalg_lu mlx_linalg_lu_mlx_gen_orig_
+#define mlx_linalg_lu_factor mlx_linalg_lu_factor_mlx_gen_orig_
+#define mlx_linalg_norm mlx_linalg_norm_mlx_gen_orig_
+#define mlx_linalg_norm_matrix mlx_linalg_norm_matrix_mlx_gen_orig_
+#define mlx_linalg_norm_l2 mlx_linalg_norm_l2_mlx_gen_orig_
+#define mlx_linalg_pinv mlx_linalg_pinv_mlx_gen_orig_
+#define mlx_linalg_qr mlx_linalg_qr_mlx_gen_orig_
+#define mlx_linalg_solve mlx_linalg_solve_mlx_gen_orig_
+#define mlx_linalg_solve_triangular mlx_linalg_solve_triangular_mlx_gen_orig_
+#define mlx_linalg_svd mlx_linalg_svd_mlx_gen_orig_
+#define mlx_linalg_tri_inv mlx_linalg_tri_inv_mlx_gen_orig_
+#define mlx_map_string_to_array_new mlx_map_string_to_array_new_mlx_gen_orig_
+#define mlx_map_string_to_array_set mlx_map_string_to_array_set_mlx_gen_orig_
+#define mlx_map_string_to_array_free mlx_map_string_to_array_free_mlx_gen_orig_
+#define mlx_map_string_to_array_insert mlx_map_string_to_array_insert_mlx_gen_orig_
+#define mlx_map_string_to_array_get mlx_map_string_to_array_get_mlx_gen_orig_
+#define mlx_map_string_to_array_iterator_new mlx_map_string_to_array_iterator_new_mlx_gen_orig_
+#define mlx_map_string_to_array_iterator_free mlx_map_string_to_array_iterator_free_mlx_gen_orig_
+#define mlx_map_string_to_array_iterator_next mlx_map_string_to_array_iterator_next_mlx_gen_orig_
+#define mlx_map_string_to_string_new mlx_map_string_to_string_new_mlx_gen_orig_
+#define mlx_map_string_to_string_set mlx_map_string_to_string_set_mlx_gen_orig_
+#define mlx_map_string_to_string_free mlx_map_string_to_string_free_mlx_gen_orig_
+#define mlx_map_string_to_string_insert mlx_map_string_to_string_insert_mlx_gen_orig_
+#define mlx_map_string_to_string_get mlx_map_string_to_string_get_mlx_gen_orig_
+#define mlx_map_string_to_string_iterator_new mlx_map_string_to_string_iterator_new_mlx_gen_orig_
+#define mlx_map_string_to_string_iterator_free mlx_map_string_to_string_iterator_free_mlx_gen_orig_
+#define mlx_map_string_to_string_iterator_next mlx_map_string_to_string_iterator_next_mlx_gen_orig_
+#define mlx_clear_cache mlx_clear_cache_mlx_gen_orig_
+#define mlx_get_active_memory mlx_get_active_memory_mlx_gen_orig_
+#define mlx_get_cache_memory mlx_get_cache_memory_mlx_gen_orig_
+#define mlx_get_memory_limit mlx_get_memory_limit_mlx_gen_orig_
+#define mlx_get_peak_memory mlx_get_peak_memory_mlx_gen_orig_
+#define mlx_reset_peak_memory mlx_reset_peak_memory_mlx_gen_orig_
+#define mlx_set_cache_limit mlx_set_cache_limit_mlx_gen_orig_
+#define mlx_set_memory_limit mlx_set_memory_limit_mlx_gen_orig_
+#define mlx_set_wired_limit mlx_set_wired_limit_mlx_gen_orig_
+#define mlx_metal_device_info mlx_metal_device_info_mlx_gen_orig_
+#define mlx_metal_is_available mlx_metal_is_available_mlx_gen_orig_
+#define mlx_metal_start_capture mlx_metal_start_capture_mlx_gen_orig_
+#define mlx_metal_stop_capture mlx_metal_stop_capture_mlx_gen_orig_
+#define mlx_abs mlx_abs_mlx_gen_orig_
+#define mlx_add mlx_add_mlx_gen_orig_
+#define mlx_addmm mlx_addmm_mlx_gen_orig_
+#define mlx_all_axes mlx_all_axes_mlx_gen_orig_
+#define mlx_all_axis mlx_all_axis_mlx_gen_orig_
+#define mlx_all mlx_all_mlx_gen_orig_
+#define mlx_allclose mlx_allclose_mlx_gen_orig_
+#define mlx_any_axes mlx_any_axes_mlx_gen_orig_
+#define mlx_any_axis mlx_any_axis_mlx_gen_orig_
+#define mlx_any mlx_any_mlx_gen_orig_
+#define mlx_arange mlx_arange_mlx_gen_orig_
+#define mlx_arccos mlx_arccos_mlx_gen_orig_
+#define mlx_arccosh mlx_arccosh_mlx_gen_orig_
+#define mlx_arcsin mlx_arcsin_mlx_gen_orig_
+#define mlx_arcsinh mlx_arcsinh_mlx_gen_orig_
+#define mlx_arctan mlx_arctan_mlx_gen_orig_
+#define mlx_arctan2 mlx_arctan2_mlx_gen_orig_
+#define mlx_arctanh mlx_arctanh_mlx_gen_orig_
+#define mlx_argmax_axis mlx_argmax_axis_mlx_gen_orig_
+#define mlx_argmax mlx_argmax_mlx_gen_orig_
+#define mlx_argmin_axis mlx_argmin_axis_mlx_gen_orig_
+#define mlx_argmin mlx_argmin_mlx_gen_orig_
+#define mlx_argpartition_axis mlx_argpartition_axis_mlx_gen_orig_
+#define mlx_argpartition mlx_argpartition_mlx_gen_orig_
+#define mlx_argsort_axis mlx_argsort_axis_mlx_gen_orig_
+#define mlx_argsort mlx_argsort_mlx_gen_orig_
+#define mlx_array_equal mlx_array_equal_mlx_gen_orig_
+#define mlx_as_strided mlx_as_strided_mlx_gen_orig_
+#define mlx_astype mlx_astype_mlx_gen_orig_
+#define mlx_atleast_1d mlx_atleast_1d_mlx_gen_orig_
+#define mlx_atleast_2d mlx_atleast_2d_mlx_gen_orig_
+#define mlx_atleast_3d mlx_atleast_3d_mlx_gen_orig_
+#define mlx_bitwise_and mlx_bitwise_and_mlx_gen_orig_
+#define mlx_bitwise_invert mlx_bitwise_invert_mlx_gen_orig_
+#define mlx_bitwise_or mlx_bitwise_or_mlx_gen_orig_
+#define mlx_bitwise_xor mlx_bitwise_xor_mlx_gen_orig_
+#define mlx_block_masked_mm mlx_block_masked_mm_mlx_gen_orig_
+#define mlx_broadcast_arrays mlx_broadcast_arrays_mlx_gen_orig_
+#define mlx_broadcast_to mlx_broadcast_to_mlx_gen_orig_
+#define mlx_ceil mlx_ceil_mlx_gen_orig_
+#define mlx_clip mlx_clip_mlx_gen_orig_
+#define mlx_concatenate_axis mlx_concatenate_axis_mlx_gen_orig_
+#define mlx_concatenate mlx_concatenate_mlx_gen_orig_
+#define mlx_conjugate mlx_conjugate_mlx_gen_orig_
+#define mlx_contiguous mlx_contiguous_mlx_gen_orig_
+#define mlx_conv1d mlx_conv1d_mlx_gen_orig_
+#define mlx_conv2d mlx_conv2d_mlx_gen_orig_
+#define mlx_conv3d mlx_conv3d_mlx_gen_orig_
+#define mlx_conv_general mlx_conv_general_mlx_gen_orig_
+#define mlx_conv_transpose1d mlx_conv_transpose1d_mlx_gen_orig_
+#define mlx_conv_transpose2d mlx_conv_transpose2d_mlx_gen_orig_
+#define mlx_conv_transpose3d mlx_conv_transpose3d_mlx_gen_orig_
+#define mlx_copy mlx_copy_mlx_gen_orig_
+#define mlx_cos mlx_cos_mlx_gen_orig_
+#define mlx_cosh mlx_cosh_mlx_gen_orig_
+#define mlx_cummax mlx_cummax_mlx_gen_orig_
+#define mlx_cummin mlx_cummin_mlx_gen_orig_
+#define mlx_cumprod mlx_cumprod_mlx_gen_orig_
+#define mlx_cumsum mlx_cumsum_mlx_gen_orig_
+#define mlx_degrees mlx_degrees_mlx_gen_orig_
+#define mlx_depends mlx_depends_mlx_gen_orig_
+#define mlx_dequantize mlx_dequantize_mlx_gen_orig_
+#define mlx_diag mlx_diag_mlx_gen_orig_
+#define mlx_diagonal mlx_diagonal_mlx_gen_orig_
+#define mlx_divide mlx_divide_mlx_gen_orig_
+#define mlx_divmod mlx_divmod_mlx_gen_orig_
+#define mlx_einsum mlx_einsum_mlx_gen_orig_
+#define mlx_equal mlx_equal_mlx_gen_orig_
+#define mlx_erf mlx_erf_mlx_gen_orig_
+#define mlx_erfinv mlx_erfinv_mlx_gen_orig_
+#define mlx_exp mlx_exp_mlx_gen_orig_
+#define mlx_expand_dims_axes mlx_expand_dims_axes_mlx_gen_orig_
+#define mlx_expand_dims mlx_expand_dims_mlx_gen_orig_
+#define mlx_expm1 mlx_expm1_mlx_gen_orig_
+#define mlx_eye mlx_eye_mlx_gen_orig_
+#define mlx_flatten mlx_flatten_mlx_gen_orig_
+#define mlx_floor mlx_floor_mlx_gen_orig_
+#define mlx_floor_divide mlx_floor_divide_mlx_gen_orig_
+#define mlx_from_fp8 mlx_from_fp8_mlx_gen_orig_
+#define mlx_full mlx_full_mlx_gen_orig_
+#define mlx_full_like mlx_full_like_mlx_gen_orig_
+#define mlx_gather mlx_gather_mlx_gen_orig_
+#define mlx_gather_mm mlx_gather_mm_mlx_gen_orig_
+#define mlx_gather_qmm mlx_gather_qmm_mlx_gen_orig_
+#define mlx_greater mlx_greater_mlx_gen_orig_
+#define mlx_greater_equal mlx_greater_equal_mlx_gen_orig_
+#define mlx_hadamard_transform mlx_hadamard_transform_mlx_gen_orig_
+#define mlx_identity mlx_identity_mlx_gen_orig_
+#define mlx_imag mlx_imag_mlx_gen_orig_
+#define mlx_inner mlx_inner_mlx_gen_orig_
+#define mlx_isclose mlx_isclose_mlx_gen_orig_
+#define mlx_isfinite mlx_isfinite_mlx_gen_orig_
+#define mlx_isinf mlx_isinf_mlx_gen_orig_
+#define mlx_isnan mlx_isnan_mlx_gen_orig_
+#define mlx_isneginf mlx_isneginf_mlx_gen_orig_
+#define mlx_isposinf mlx_isposinf_mlx_gen_orig_
+#define mlx_kron mlx_kron_mlx_gen_orig_
+#define mlx_left_shift mlx_left_shift_mlx_gen_orig_
+#define mlx_less mlx_less_mlx_gen_orig_
+#define mlx_less_equal mlx_less_equal_mlx_gen_orig_
+#define mlx_linspace mlx_linspace_mlx_gen_orig_
+#define mlx_log mlx_log_mlx_gen_orig_
+#define mlx_log10 mlx_log10_mlx_gen_orig_
+#define mlx_log1p mlx_log1p_mlx_gen_orig_
+#define mlx_log2 mlx_log2_mlx_gen_orig_
+#define mlx_logaddexp mlx_logaddexp_mlx_gen_orig_
+#define mlx_logcumsumexp mlx_logcumsumexp_mlx_gen_orig_
+#define mlx_logical_and mlx_logical_and_mlx_gen_orig_
+#define mlx_logical_not mlx_logical_not_mlx_gen_orig_
+#define mlx_logical_or mlx_logical_or_mlx_gen_orig_
+#define mlx_logsumexp_axes mlx_logsumexp_axes_mlx_gen_orig_
+#define mlx_logsumexp_axis mlx_logsumexp_axis_mlx_gen_orig_
+#define mlx_logsumexp mlx_logsumexp_mlx_gen_orig_
+#define mlx_masked_scatter mlx_masked_scatter_mlx_gen_orig_
+#define mlx_matmul mlx_matmul_mlx_gen_orig_
+#define mlx_max_axes mlx_max_axes_mlx_gen_orig_
+#define mlx_max_axis mlx_max_axis_mlx_gen_orig_
+#define mlx_max mlx_max_mlx_gen_orig_
+#define mlx_maximum mlx_maximum_mlx_gen_orig_
+#define mlx_mean_axes mlx_mean_axes_mlx_gen_orig_
+#define mlx_mean_axis mlx_mean_axis_mlx_gen_orig_
+#define mlx_mean mlx_mean_mlx_gen_orig_
+#define mlx_median mlx_median_mlx_gen_orig_
+#define mlx_meshgrid mlx_meshgrid_mlx_gen_orig_
+#define mlx_min_axes mlx_min_axes_mlx_gen_orig_
+#define mlx_min_axis mlx_min_axis_mlx_gen_orig_
+#define mlx_min mlx_min_mlx_gen_orig_
+#define mlx_minimum mlx_minimum_mlx_gen_orig_
+#define mlx_moveaxis mlx_moveaxis_mlx_gen_orig_
+#define mlx_multiply mlx_multiply_mlx_gen_orig_
+#define mlx_nan_to_num mlx_nan_to_num_mlx_gen_orig_
+#define mlx_negative mlx_negative_mlx_gen_orig_
+#define mlx_not_equal mlx_not_equal_mlx_gen_orig_
+#define mlx_number_of_elements mlx_number_of_elements_mlx_gen_orig_
+#define mlx_ones mlx_ones_mlx_gen_orig_
+#define mlx_ones_like mlx_ones_like_mlx_gen_orig_
+#define mlx_outer mlx_outer_mlx_gen_orig_
+#define mlx_pad mlx_pad_mlx_gen_orig_
+#define mlx_pad_symmetric mlx_pad_symmetric_mlx_gen_orig_
+#define mlx_partition_axis mlx_partition_axis_mlx_gen_orig_
+#define mlx_partition mlx_partition_mlx_gen_orig_
+#define mlx_power mlx_power_mlx_gen_orig_
+#define mlx_prod_axes mlx_prod_axes_mlx_gen_orig_
+#define mlx_prod_axis mlx_prod_axis_mlx_gen_orig_
+#define mlx_prod mlx_prod_mlx_gen_orig_
+#define mlx_put_along_axis mlx_put_along_axis_mlx_gen_orig_
+#define mlx_quantize mlx_quantize_mlx_gen_orig_
+#define mlx_quantized_matmul mlx_quantized_matmul_mlx_gen_orig_
+#define mlx_radians mlx_radians_mlx_gen_orig_
+#define mlx_real mlx_real_mlx_gen_orig_
+#define mlx_reciprocal mlx_reciprocal_mlx_gen_orig_
+#define mlx_remainder mlx_remainder_mlx_gen_orig_
+#define mlx_repeat_axis mlx_repeat_axis_mlx_gen_orig_
+#define mlx_repeat mlx_repeat_mlx_gen_orig_
+#define mlx_reshape mlx_reshape_mlx_gen_orig_
+#define mlx_right_shift mlx_right_shift_mlx_gen_orig_
+#define mlx_roll_axis mlx_roll_axis_mlx_gen_orig_
+#define mlx_roll_axes mlx_roll_axes_mlx_gen_orig_
+#define mlx_roll mlx_roll_mlx_gen_orig_
+#define mlx_round mlx_round_mlx_gen_orig_
+#define mlx_rsqrt mlx_rsqrt_mlx_gen_orig_
+#define mlx_scatter mlx_scatter_mlx_gen_orig_
+#define mlx_scatter_add mlx_scatter_add_mlx_gen_orig_
+#define mlx_scatter_add_axis mlx_scatter_add_axis_mlx_gen_orig_
+#define mlx_scatter_max mlx_scatter_max_mlx_gen_orig_
+#define mlx_scatter_min mlx_scatter_min_mlx_gen_orig_
+#define mlx_scatter_prod mlx_scatter_prod_mlx_gen_orig_
+#define mlx_segmented_mm mlx_segmented_mm_mlx_gen_orig_
+#define mlx_sigmoid mlx_sigmoid_mlx_gen_orig_
+#define mlx_sign mlx_sign_mlx_gen_orig_
+#define mlx_sin mlx_sin_mlx_gen_orig_
+#define mlx_sinh mlx_sinh_mlx_gen_orig_
+#define mlx_slice mlx_slice_mlx_gen_orig_
+#define mlx_slice_dynamic mlx_slice_dynamic_mlx_gen_orig_
+#define mlx_slice_update mlx_slice_update_mlx_gen_orig_
+#define mlx_slice_update_dynamic mlx_slice_update_dynamic_mlx_gen_orig_
+#define mlx_softmax_axes mlx_softmax_axes_mlx_gen_orig_
+#define mlx_softmax_axis mlx_softmax_axis_mlx_gen_orig_
+#define mlx_softmax mlx_softmax_mlx_gen_orig_
+#define mlx_sort_axis mlx_sort_axis_mlx_gen_orig_
+#define mlx_sort mlx_sort_mlx_gen_orig_
+#define mlx_split mlx_split_mlx_gen_orig_
+#define mlx_split_sections mlx_split_sections_mlx_gen_orig_
+#define mlx_sqrt mlx_sqrt_mlx_gen_orig_
+#define mlx_square mlx_square_mlx_gen_orig_
+#define mlx_squeeze_axes mlx_squeeze_axes_mlx_gen_orig_
+#define mlx_squeeze_axis mlx_squeeze_axis_mlx_gen_orig_
+#define mlx_squeeze mlx_squeeze_mlx_gen_orig_
+#define mlx_stack_axis mlx_stack_axis_mlx_gen_orig_
+#define mlx_stack mlx_stack_mlx_gen_orig_
+#define mlx_std_axes mlx_std_axes_mlx_gen_orig_
+#define mlx_std_axis mlx_std_axis_mlx_gen_orig_
+#define mlx_std mlx_std_mlx_gen_orig_
+#define mlx_stop_gradient mlx_stop_gradient_mlx_gen_orig_
+#define mlx_subtract mlx_subtract_mlx_gen_orig_
+#define mlx_sum_axes mlx_sum_axes_mlx_gen_orig_
+#define mlx_sum_axis mlx_sum_axis_mlx_gen_orig_
+#define mlx_sum mlx_sum_mlx_gen_orig_
+#define mlx_swapaxes mlx_swapaxes_mlx_gen_orig_
+#define mlx_take_axis mlx_take_axis_mlx_gen_orig_
+#define mlx_take mlx_take_mlx_gen_orig_
+#define mlx_take_along_axis mlx_take_along_axis_mlx_gen_orig_
+#define mlx_tan mlx_tan_mlx_gen_orig_
+#define mlx_tanh mlx_tanh_mlx_gen_orig_
+#define mlx_tensordot mlx_tensordot_mlx_gen_orig_
+#define mlx_tensordot_axis mlx_tensordot_axis_mlx_gen_orig_
+#define mlx_tile mlx_tile_mlx_gen_orig_
+#define mlx_to_fp8 mlx_to_fp8_mlx_gen_orig_
+#define mlx_topk_axis mlx_topk_axis_mlx_gen_orig_
+#define mlx_topk mlx_topk_mlx_gen_orig_
+#define mlx_trace mlx_trace_mlx_gen_orig_
+#define mlx_transpose_axes mlx_transpose_axes_mlx_gen_orig_
+#define mlx_transpose mlx_transpose_mlx_gen_orig_
+#define mlx_tri mlx_tri_mlx_gen_orig_
+#define mlx_tril mlx_tril_mlx_gen_orig_
+#define mlx_triu mlx_triu_mlx_gen_orig_
+#define mlx_unflatten mlx_unflatten_mlx_gen_orig_
+#define mlx_var_axes mlx_var_axes_mlx_gen_orig_
+#define mlx_var_axis mlx_var_axis_mlx_gen_orig_
+#define mlx_var mlx_var_mlx_gen_orig_
+#define mlx_view mlx_view_mlx_gen_orig_
+#define mlx_where mlx_where_mlx_gen_orig_
+#define mlx_zeros mlx_zeros_mlx_gen_orig_
+#define mlx_zeros_like mlx_zeros_like_mlx_gen_orig_
+#define mlx_random_bernoulli mlx_random_bernoulli_mlx_gen_orig_
+#define mlx_random_bits mlx_random_bits_mlx_gen_orig_
+#define mlx_random_categorical_shape mlx_random_categorical_shape_mlx_gen_orig_
+#define mlx_random_categorical_num_samples mlx_random_categorical_num_samples_mlx_gen_orig_
+#define mlx_random_categorical mlx_random_categorical_mlx_gen_orig_
+#define mlx_random_gumbel mlx_random_gumbel_mlx_gen_orig_
+#define mlx_random_key mlx_random_key_mlx_gen_orig_
+#define mlx_random_laplace mlx_random_laplace_mlx_gen_orig_
+#define mlx_random_multivariate_normal mlx_random_multivariate_normal_mlx_gen_orig_
+#define mlx_random_normal_broadcast mlx_random_normal_broadcast_mlx_gen_orig_
+#define mlx_random_normal mlx_random_normal_mlx_gen_orig_
+#define mlx_random_permutation mlx_random_permutation_mlx_gen_orig_
+#define mlx_random_permutation_arange mlx_random_permutation_arange_mlx_gen_orig_
+#define mlx_random_randint mlx_random_randint_mlx_gen_orig_
+#define mlx_random_seed mlx_random_seed_mlx_gen_orig_
+#define mlx_random_split_num mlx_random_split_num_mlx_gen_orig_
+#define mlx_random_split mlx_random_split_mlx_gen_orig_
+#define mlx_random_truncated_normal mlx_random_truncated_normal_mlx_gen_orig_
+#define mlx_random_uniform mlx_random_uniform_mlx_gen_orig_
+#define mlx_stream_new mlx_stream_new_mlx_gen_orig_
+#define mlx_stream_new_device mlx_stream_new_device_mlx_gen_orig_
+#define mlx_stream_set mlx_stream_set_mlx_gen_orig_
+#define mlx_stream_free mlx_stream_free_mlx_gen_orig_
+#define mlx_stream_tostring mlx_stream_tostring_mlx_gen_orig_
+#define mlx_stream_equal mlx_stream_equal_mlx_gen_orig_
+#define mlx_stream_get_device mlx_stream_get_device_mlx_gen_orig_
+#define mlx_stream_get_index mlx_stream_get_index_mlx_gen_orig_
+#define mlx_synchronize mlx_synchronize_mlx_gen_orig_
+#define mlx_get_default_stream mlx_get_default_stream_mlx_gen_orig_
+#define mlx_set_default_stream mlx_set_default_stream_mlx_gen_orig_
+#define mlx_default_cpu_stream_new mlx_default_cpu_stream_new_mlx_gen_orig_
+#define mlx_default_gpu_stream_new mlx_default_gpu_stream_new_mlx_gen_orig_
+#define mlx_string_new mlx_string_new_mlx_gen_orig_
+#define mlx_string_new_data mlx_string_new_data_mlx_gen_orig_
+#define mlx_string_set mlx_string_set_mlx_gen_orig_
+#define mlx_string_data mlx_string_data_mlx_gen_orig_
+#define mlx_string_free mlx_string_free_mlx_gen_orig_
+#define mlx_detail_vmap_replace mlx_detail_vmap_replace_mlx_gen_orig_
+#define mlx_detail_vmap_trace mlx_detail_vmap_trace_mlx_gen_orig_
+#define mlx_async_eval mlx_async_eval_mlx_gen_orig_
+#define mlx_checkpoint mlx_checkpoint_mlx_gen_orig_
+#define mlx_custom_function mlx_custom_function_mlx_gen_orig_
+#define mlx_custom_vjp mlx_custom_vjp_mlx_gen_orig_
+#define mlx_eval mlx_eval_mlx_gen_orig_
+#define mlx_jvp mlx_jvp_mlx_gen_orig_
+#define mlx_value_and_grad mlx_value_and_grad_mlx_gen_orig_
+#define mlx_vjp mlx_vjp_mlx_gen_orig_
+#define mlx_vector_array_new mlx_vector_array_new_mlx_gen_orig_
+#define mlx_vector_array_set mlx_vector_array_set_mlx_gen_orig_
+#define mlx_vector_array_free mlx_vector_array_free_mlx_gen_orig_
+#define mlx_vector_array_new_data mlx_vector_array_new_data_mlx_gen_orig_
+#define mlx_vector_array_new_value mlx_vector_array_new_value_mlx_gen_orig_
+#define mlx_vector_array_set_data mlx_vector_array_set_data_mlx_gen_orig_
+#define mlx_vector_array_set_value mlx_vector_array_set_value_mlx_gen_orig_
+#define mlx_vector_array_append_data mlx_vector_array_append_data_mlx_gen_orig_
+#define mlx_vector_array_append_value mlx_vector_array_append_value_mlx_gen_orig_
+#define mlx_vector_array_size mlx_vector_array_size_mlx_gen_orig_
+#define mlx_vector_array_get mlx_vector_array_get_mlx_gen_orig_
+#define mlx_vector_vector_array_new mlx_vector_vector_array_new_mlx_gen_orig_
+#define mlx_vector_vector_array_set mlx_vector_vector_array_set_mlx_gen_orig_
+#define mlx_vector_vector_array_free mlx_vector_vector_array_free_mlx_gen_orig_
+#define mlx_vector_vector_array_new_data mlx_vector_vector_array_new_data_mlx_gen_orig_
+#define mlx_vector_vector_array_new_value mlx_vector_vector_array_new_value_mlx_gen_orig_
+#define mlx_vector_vector_array_set_data mlx_vector_vector_array_set_data_mlx_gen_orig_
+#define mlx_vector_vector_array_set_value mlx_vector_vector_array_set_value_mlx_gen_orig_
+#define mlx_vector_vector_array_append_data mlx_vector_vector_array_append_data_mlx_gen_orig_
+#define mlx_vector_vector_array_append_value mlx_vector_vector_array_append_value_mlx_gen_orig_
+#define mlx_vector_vector_array_size mlx_vector_vector_array_size_mlx_gen_orig_
+#define mlx_vector_vector_array_get mlx_vector_vector_array_get_mlx_gen_orig_
+#define mlx_vector_int_new mlx_vector_int_new_mlx_gen_orig_
+#define mlx_vector_int_set mlx_vector_int_set_mlx_gen_orig_
+#define mlx_vector_int_free mlx_vector_int_free_mlx_gen_orig_
+#define mlx_vector_int_new_data mlx_vector_int_new_data_mlx_gen_orig_
+#define mlx_vector_int_new_value mlx_vector_int_new_value_mlx_gen_orig_
+#define mlx_vector_int_set_data mlx_vector_int_set_data_mlx_gen_orig_
+#define mlx_vector_int_set_value mlx_vector_int_set_value_mlx_gen_orig_
+#define mlx_vector_int_append_data mlx_vector_int_append_data_mlx_gen_orig_
+#define mlx_vector_int_append_value mlx_vector_int_append_value_mlx_gen_orig_
+#define mlx_vector_int_size mlx_vector_int_size_mlx_gen_orig_
+#define mlx_vector_int_get mlx_vector_int_get_mlx_gen_orig_
+#define mlx_vector_string_new mlx_vector_string_new_mlx_gen_orig_
+#define mlx_vector_string_set mlx_vector_string_set_mlx_gen_orig_
+#define mlx_vector_string_free mlx_vector_string_free_mlx_gen_orig_
+#define mlx_vector_string_new_data mlx_vector_string_new_data_mlx_gen_orig_
+#define mlx_vector_string_new_value mlx_vector_string_new_value_mlx_gen_orig_
+#define mlx_vector_string_set_data mlx_vector_string_set_data_mlx_gen_orig_
+#define mlx_vector_string_set_value mlx_vector_string_set_value_mlx_gen_orig_
+#define mlx_vector_string_append_data mlx_vector_string_append_data_mlx_gen_orig_
+#define mlx_vector_string_append_value mlx_vector_string_append_value_mlx_gen_orig_
+#define mlx_vector_string_size mlx_vector_string_size_mlx_gen_orig_
+#define mlx_vector_string_get mlx_vector_string_get_mlx_gen_orig_
+#define mlx_version mlx_version_mlx_gen_orig_
+
+#include "mlx/c/mlx.h"
+
+#undef mlx_dtype_size
+#undef mlx_array_tostring
+#undef mlx_array_new
+#undef mlx_array_free
+#undef mlx_array_new_bool
+#undef mlx_array_new_int
+#undef mlx_array_new_float32
+#undef mlx_array_new_float
+#undef mlx_array_new_float64
+#undef mlx_array_new_double
+#undef mlx_array_new_complex
+#undef mlx_array_new_data
+#undef mlx_array_set
+#undef mlx_array_set_bool
+#undef mlx_array_set_int
+#undef mlx_array_set_float32
+#undef mlx_array_set_float
+#undef mlx_array_set_float64
+#undef mlx_array_set_double
+#undef mlx_array_set_complex
+#undef mlx_array_set_data
+#undef mlx_array_itemsize
+#undef mlx_array_size
+#undef mlx_array_nbytes
+#undef mlx_array_ndim
+#undef mlx_array_shape
+#undef mlx_array_strides
+#undef mlx_array_dim
+#undef mlx_array_dtype
+#undef mlx_array_eval
+#undef mlx_array_item_bool
+#undef mlx_array_item_uint8
+#undef mlx_array_item_uint16
+#undef mlx_array_item_uint32
+#undef mlx_array_item_uint64
+#undef mlx_array_item_int8
+#undef mlx_array_item_int16
+#undef mlx_array_item_int32
+#undef mlx_array_item_int64
+#undef mlx_array_item_float32
+#undef mlx_array_item_float64
+#undef mlx_array_item_complex64
+#undef mlx_array_item_float16
+#undef mlx_array_item_bfloat16
+#undef mlx_array_data_bool
+#undef mlx_array_data_uint8
+#undef mlx_array_data_uint16
+#undef mlx_array_data_uint32
+#undef mlx_array_data_uint64
+#undef mlx_array_data_int8
+#undef mlx_array_data_int16
+#undef mlx_array_data_int32
+#undef mlx_array_data_int64
+#undef mlx_array_data_float32
+#undef mlx_array_data_float64
+#undef mlx_array_data_complex64
+#undef mlx_array_data_float16
+#undef mlx_array_data_bfloat16
+#undef _mlx_array_is_available
+#undef _mlx_array_wait
+#undef _mlx_array_is_contiguous
+#undef _mlx_array_is_row_contiguous
+#undef _mlx_array_is_col_contiguous
+#undef mlx_closure_new
+#undef mlx_closure_free
+#undef mlx_closure_new_func
+#undef mlx_closure_new_func_payload
+#undef mlx_closure_set
+#undef mlx_closure_apply
+#undef mlx_closure_new_unary
+#undef mlx_closure_kwargs_new
+#undef mlx_closure_kwargs_free
+#undef mlx_closure_kwargs_new_func
+#undef mlx_closure_kwargs_new_func_payload
+#undef mlx_closure_kwargs_set
+#undef mlx_closure_kwargs_apply
+#undef mlx_closure_value_and_grad_new
+#undef mlx_closure_value_and_grad_free
+#undef mlx_closure_value_and_grad_new_func
+#undef mlx_closure_value_and_grad_new_func_payload
+#undef mlx_closure_value_and_grad_set
+#undef mlx_closure_value_and_grad_apply
+#undef mlx_closure_custom_new
+#undef mlx_closure_custom_free
+#undef mlx_closure_custom_new_func
+#undef mlx_closure_custom_new_func_payload
+#undef mlx_closure_custom_set
+#undef mlx_closure_custom_apply
+#undef mlx_closure_custom_jvp_new
+#undef mlx_closure_custom_jvp_free
+#undef mlx_closure_custom_jvp_new_func
+#undef mlx_closure_custom_jvp_new_func_payload
+#undef mlx_closure_custom_jvp_set
+#undef mlx_closure_custom_jvp_apply
+#undef mlx_closure_custom_vmap_new
+#undef mlx_closure_custom_vmap_free
+#undef mlx_closure_custom_vmap_new_func
+#undef mlx_closure_custom_vmap_new_func_payload
+#undef mlx_closure_custom_vmap_set
+#undef mlx_closure_custom_vmap_apply
+#undef mlx_compile
+#undef mlx_detail_compile
+#undef mlx_detail_compile_clear_cache
+#undef mlx_detail_compile_erase
+#undef mlx_disable_compile
+#undef mlx_enable_compile
+#undef mlx_set_compile_mode
+#undef mlx_device_new
+#undef mlx_device_new_type
+#undef mlx_device_free
+#undef mlx_device_set
+#undef mlx_device_tostring
+#undef mlx_device_equal
+#undef mlx_device_get_index
+#undef mlx_device_get_type
+#undef mlx_get_default_device
+#undef mlx_set_default_device
+#undef mlx_distributed_group_rank
+#undef mlx_distributed_group_size
+#undef mlx_distributed_group_split
+#undef mlx_distributed_is_available
+#undef mlx_distributed_init
+#undef mlx_distributed_all_gather
+#undef mlx_distributed_all_max
+#undef mlx_distributed_all_min
+#undef mlx_distributed_all_sum
+#undef mlx_distributed_recv
+#undef mlx_distributed_recv_like
+#undef mlx_distributed_send
+#undef mlx_distributed_sum_scatter
+#undef mlx_set_error_handler
+#undef _mlx_error
+#undef mlx_export_function
+#undef mlx_export_function_kwargs
+#undef mlx_function_exporter_new
+#undef mlx_function_exporter_free
+#undef mlx_function_exporter_apply
+#undef mlx_function_exporter_apply_kwargs
+#undef mlx_imported_function_new
+#undef mlx_imported_function_free
+#undef mlx_imported_function_apply
+#undef mlx_imported_function_apply_kwargs
+#undef mlx_fast_cuda_kernel_config_new
+#undef mlx_fast_cuda_kernel_config_free
+#undef mlx_fast_cuda_kernel_config_add_output_arg
+#undef mlx_fast_cuda_kernel_config_set_grid
+#undef mlx_fast_cuda_kernel_config_set_thread_group
+#undef mlx_fast_cuda_kernel_config_set_init_value
+#undef mlx_fast_cuda_kernel_config_set_verbose
+#undef mlx_fast_cuda_kernel_config_add_template_arg_dtype
+#undef mlx_fast_cuda_kernel_config_add_template_arg_int
+#undef mlx_fast_cuda_kernel_config_add_template_arg_bool
+#undef mlx_fast_cuda_kernel_new
+#undef mlx_fast_cuda_kernel_free
+#undef mlx_fast_cuda_kernel_apply
+#undef mlx_fast_layer_norm
+#undef mlx_fast_metal_kernel_config_new
+#undef mlx_fast_metal_kernel_config_free
+#undef mlx_fast_metal_kernel_config_add_output_arg
+#undef mlx_fast_metal_kernel_config_set_grid
+#undef mlx_fast_metal_kernel_config_set_thread_group
+#undef mlx_fast_metal_kernel_config_set_init_value
+#undef mlx_fast_metal_kernel_config_set_verbose
+#undef mlx_fast_metal_kernel_config_add_template_arg_dtype
+#undef mlx_fast_metal_kernel_config_add_template_arg_int
+#undef mlx_fast_metal_kernel_config_add_template_arg_bool
+#undef mlx_fast_metal_kernel_new
+#undef mlx_fast_metal_kernel_free
+#undef mlx_fast_metal_kernel_apply
+#undef mlx_fast_rms_norm
+#undef mlx_fast_rope
+#undef mlx_fast_scaled_dot_product_attention
+#undef mlx_fft_fft
+#undef mlx_fft_fft2
+#undef mlx_fft_fftn
+#undef mlx_fft_fftshift
+#undef mlx_fft_ifft
+#undef mlx_fft_ifft2
+#undef mlx_fft_ifftn
+#undef mlx_fft_ifftshift
+#undef mlx_fft_irfft
+#undef mlx_fft_irfft2
+#undef mlx_fft_irfftn
+#undef mlx_fft_rfft
+#undef mlx_fft_rfft2
+#undef mlx_fft_rfftn
+#undef mlx_io_reader_new
+#undef mlx_io_reader_descriptor
+#undef mlx_io_reader_tostring
+#undef mlx_io_reader_free
+#undef mlx_io_writer_new
+#undef mlx_io_writer_descriptor
+#undef mlx_io_writer_tostring
+#undef mlx_io_writer_free
+#undef mlx_load_reader
+#undef mlx_load
+#undef mlx_load_safetensors_reader
+#undef mlx_load_safetensors
+#undef mlx_save_writer
+#undef mlx_save
+#undef mlx_save_safetensors_writer
+#undef mlx_save_safetensors
+#undef mlx_linalg_cholesky
+#undef mlx_linalg_cholesky_inv
+#undef mlx_linalg_cross
+#undef mlx_linalg_eig
+#undef mlx_linalg_eigh
+#undef mlx_linalg_eigvals
+#undef mlx_linalg_eigvalsh
+#undef mlx_linalg_inv
+#undef mlx_linalg_lu
+#undef mlx_linalg_lu_factor
+#undef mlx_linalg_norm
+#undef mlx_linalg_norm_matrix
+#undef mlx_linalg_norm_l2
+#undef mlx_linalg_pinv
+#undef mlx_linalg_qr
+#undef mlx_linalg_solve
+#undef mlx_linalg_solve_triangular
+#undef mlx_linalg_svd
+#undef mlx_linalg_tri_inv
+#undef mlx_map_string_to_array_new
+#undef mlx_map_string_to_array_set
+#undef mlx_map_string_to_array_free
+#undef mlx_map_string_to_array_insert
+#undef mlx_map_string_to_array_get
+#undef mlx_map_string_to_array_iterator_new
+#undef mlx_map_string_to_array_iterator_free
+#undef mlx_map_string_to_array_iterator_next
+#undef mlx_map_string_to_string_new
+#undef mlx_map_string_to_string_set
+#undef mlx_map_string_to_string_free
+#undef mlx_map_string_to_string_insert
+#undef mlx_map_string_to_string_get
+#undef mlx_map_string_to_string_iterator_new
+#undef mlx_map_string_to_string_iterator_free
+#undef mlx_map_string_to_string_iterator_next
+#undef mlx_clear_cache
+#undef mlx_get_active_memory
+#undef mlx_get_cache_memory
+#undef mlx_get_memory_limit
+#undef mlx_get_peak_memory
+#undef mlx_reset_peak_memory
+#undef mlx_set_cache_limit
+#undef mlx_set_memory_limit
+#undef mlx_set_wired_limit
+#undef mlx_metal_device_info
+#undef mlx_metal_is_available
+#undef mlx_metal_start_capture
+#undef mlx_metal_stop_capture
+#undef mlx_abs
+#undef mlx_add
+#undef mlx_addmm
+#undef mlx_all_axes
+#undef mlx_all_axis
+#undef mlx_all
+#undef mlx_allclose
+#undef mlx_any_axes
+#undef mlx_any_axis
+#undef mlx_any
+#undef mlx_arange
+#undef mlx_arccos
+#undef mlx_arccosh
+#undef mlx_arcsin
+#undef mlx_arcsinh
+#undef mlx_arctan
+#undef mlx_arctan2
+#undef mlx_arctanh
+#undef mlx_argmax_axis
+#undef mlx_argmax
+#undef mlx_argmin_axis
+#undef mlx_argmin
+#undef mlx_argpartition_axis
+#undef mlx_argpartition
+#undef mlx_argsort_axis
+#undef mlx_argsort
+#undef mlx_array_equal
+#undef mlx_as_strided
+#undef mlx_astype
+#undef mlx_atleast_1d
+#undef mlx_atleast_2d
+#undef mlx_atleast_3d
+#undef mlx_bitwise_and
+#undef mlx_bitwise_invert
+#undef mlx_bitwise_or
+#undef mlx_bitwise_xor
+#undef mlx_block_masked_mm
+#undef mlx_broadcast_arrays
+#undef mlx_broadcast_to
+#undef mlx_ceil
+#undef mlx_clip
+#undef mlx_concatenate_axis
+#undef mlx_concatenate
+#undef mlx_conjugate
+#undef mlx_contiguous
+#undef mlx_conv1d
+#undef mlx_conv2d
+#undef mlx_conv3d
+#undef mlx_conv_general
+#undef mlx_conv_transpose1d
+#undef mlx_conv_transpose2d
+#undef mlx_conv_transpose3d
+#undef mlx_copy
+#undef mlx_cos
+#undef mlx_cosh
+#undef mlx_cummax
+#undef mlx_cummin
+#undef mlx_cumprod
+#undef mlx_cumsum
+#undef mlx_degrees
+#undef mlx_depends
+#undef mlx_dequantize
+#undef mlx_diag
+#undef mlx_diagonal
+#undef mlx_divide
+#undef mlx_divmod
+#undef mlx_einsum
+#undef mlx_equal
+#undef mlx_erf
+#undef mlx_erfinv
+#undef mlx_exp
+#undef mlx_expand_dims_axes
+#undef mlx_expand_dims
+#undef mlx_expm1
+#undef mlx_eye
+#undef mlx_flatten
+#undef mlx_floor
+#undef mlx_floor_divide
+#undef mlx_from_fp8
+#undef mlx_full
+#undef mlx_full_like
+#undef mlx_gather
+#undef mlx_gather_mm
+#undef mlx_gather_qmm
+#undef mlx_greater
+#undef mlx_greater_equal
+#undef mlx_hadamard_transform
+#undef mlx_identity
+#undef mlx_imag
+#undef mlx_inner
+#undef mlx_isclose
+#undef mlx_isfinite
+#undef mlx_isinf
+#undef mlx_isnan
+#undef mlx_isneginf
+#undef mlx_isposinf
+#undef mlx_kron
+#undef mlx_left_shift
+#undef mlx_less
+#undef mlx_less_equal
+#undef mlx_linspace
+#undef mlx_log
+#undef mlx_log10
+#undef mlx_log1p
+#undef mlx_log2
+#undef mlx_logaddexp
+#undef mlx_logcumsumexp
+#undef mlx_logical_and
+#undef mlx_logical_not
+#undef mlx_logical_or
+#undef mlx_logsumexp_axes
+#undef mlx_logsumexp_axis
+#undef mlx_logsumexp
+#undef mlx_masked_scatter
+#undef mlx_matmul
+#undef mlx_max_axes
+#undef mlx_max_axis
+#undef mlx_max
+#undef mlx_maximum
+#undef mlx_mean_axes
+#undef mlx_mean_axis
+#undef mlx_mean
+#undef mlx_median
+#undef mlx_meshgrid
+#undef mlx_min_axes
+#undef mlx_min_axis
+#undef mlx_min
+#undef mlx_minimum
+#undef mlx_moveaxis
+#undef mlx_multiply
+#undef mlx_nan_to_num
+#undef mlx_negative
+#undef mlx_not_equal
+#undef mlx_number_of_elements
+#undef mlx_ones
+#undef mlx_ones_like
+#undef mlx_outer
+#undef mlx_pad
+#undef mlx_pad_symmetric
+#undef mlx_partition_axis
+#undef mlx_partition
+#undef mlx_power
+#undef mlx_prod_axes
+#undef mlx_prod_axis
+#undef mlx_prod
+#undef mlx_put_along_axis
+#undef mlx_quantize
+#undef mlx_quantized_matmul
+#undef mlx_radians
+#undef mlx_real
+#undef mlx_reciprocal
+#undef mlx_remainder
+#undef mlx_repeat_axis
+#undef mlx_repeat
+#undef mlx_reshape
+#undef mlx_right_shift
+#undef mlx_roll_axis
+#undef mlx_roll_axes
+#undef mlx_roll
+#undef mlx_round
+#undef mlx_rsqrt
+#undef mlx_scatter
+#undef mlx_scatter_add
+#undef mlx_scatter_add_axis
+#undef mlx_scatter_max
+#undef mlx_scatter_min
+#undef mlx_scatter_prod
+#undef mlx_segmented_mm
+#undef mlx_sigmoid
+#undef mlx_sign
+#undef mlx_sin
+#undef mlx_sinh
+#undef mlx_slice
+#undef mlx_slice_dynamic
+#undef mlx_slice_update
+#undef mlx_slice_update_dynamic
+#undef mlx_softmax_axes
+#undef mlx_softmax_axis
+#undef mlx_softmax
+#undef mlx_sort_axis
+#undef mlx_sort
+#undef mlx_split
+#undef mlx_split_sections
+#undef mlx_sqrt
+#undef mlx_square
+#undef mlx_squeeze_axes
+#undef mlx_squeeze_axis
+#undef mlx_squeeze
+#undef mlx_stack_axis
+#undef mlx_stack
+#undef mlx_std_axes
+#undef mlx_std_axis
+#undef mlx_std
+#undef mlx_stop_gradient
+#undef mlx_subtract
+#undef mlx_sum_axes
+#undef mlx_sum_axis
+#undef mlx_sum
+#undef mlx_swapaxes
+#undef mlx_take_axis
+#undef mlx_take
+#undef mlx_take_along_axis
+#undef mlx_tan
+#undef mlx_tanh
+#undef mlx_tensordot
+#undef mlx_tensordot_axis
+#undef mlx_tile
+#undef mlx_to_fp8
+#undef mlx_topk_axis
+#undef mlx_topk
+#undef mlx_trace
+#undef mlx_transpose_axes
+#undef mlx_transpose
+#undef mlx_tri
+#undef mlx_tril
+#undef mlx_triu
+#undef mlx_unflatten
+#undef mlx_var_axes
+#undef mlx_var_axis
+#undef mlx_var
+#undef mlx_view
+#undef mlx_where
+#undef mlx_zeros
+#undef mlx_zeros_like
+#undef mlx_random_bernoulli
+#undef mlx_random_bits
+#undef mlx_random_categorical_shape
+#undef mlx_random_categorical_num_samples
+#undef mlx_random_categorical
+#undef mlx_random_gumbel
+#undef mlx_random_key
+#undef mlx_random_laplace
+#undef mlx_random_multivariate_normal
+#undef mlx_random_normal_broadcast
+#undef mlx_random_normal
+#undef mlx_random_permutation
+#undef mlx_random_permutation_arange
+#undef mlx_random_randint
+#undef mlx_random_seed
+#undef mlx_random_split_num
+#undef mlx_random_split
+#undef mlx_random_truncated_normal
+#undef mlx_random_uniform
+#undef mlx_stream_new
+#undef mlx_stream_new_device
+#undef mlx_stream_set
+#undef mlx_stream_free
+#undef mlx_stream_tostring
+#undef mlx_stream_equal
+#undef mlx_stream_get_device
+#undef mlx_stream_get_index
+#undef mlx_synchronize
+#undef mlx_get_default_stream
+#undef mlx_set_default_stream
+#undef mlx_default_cpu_stream_new
+#undef mlx_default_gpu_stream_new
+#undef mlx_string_new
+#undef mlx_string_new_data
+#undef mlx_string_set
+#undef mlx_string_data
+#undef mlx_string_free
+#undef mlx_detail_vmap_replace
+#undef mlx_detail_vmap_trace
+#undef mlx_async_eval
+#undef mlx_checkpoint
+#undef mlx_custom_function
+#undef mlx_custom_vjp
+#undef mlx_eval
+#undef mlx_jvp
+#undef mlx_value_and_grad
+#undef mlx_vjp
+#undef mlx_vector_array_new
+#undef mlx_vector_array_set
+#undef mlx_vector_array_free
+#undef mlx_vector_array_new_data
+#undef mlx_vector_array_new_value
+#undef mlx_vector_array_set_data
+#undef mlx_vector_array_set_value
+#undef mlx_vector_array_append_data
+#undef mlx_vector_array_append_value
+#undef mlx_vector_array_size
+#undef mlx_vector_array_get
+#undef mlx_vector_vector_array_new
+#undef mlx_vector_vector_array_set
+#undef mlx_vector_vector_array_free
+#undef mlx_vector_vector_array_new_data
+#undef mlx_vector_vector_array_new_value
+#undef mlx_vector_vector_array_set_data
+#undef mlx_vector_vector_array_set_value
+#undef mlx_vector_vector_array_append_data
+#undef mlx_vector_vector_array_append_value
+#undef mlx_vector_vector_array_size
+#undef mlx_vector_vector_array_get
+#undef mlx_vector_int_new
+#undef mlx_vector_int_set
+#undef mlx_vector_int_free
+#undef mlx_vector_int_new_data
+#undef mlx_vector_int_new_value
+#undef mlx_vector_int_set_data
+#undef mlx_vector_int_set_value
+#undef mlx_vector_int_append_data
+#undef mlx_vector_int_append_value
+#undef mlx_vector_int_size
+#undef mlx_vector_int_get
+#undef mlx_vector_string_new
+#undef mlx_vector_string_set
+#undef mlx_vector_string_free
+#undef mlx_vector_string_new_data
+#undef mlx_vector_string_new_value
+#undef mlx_vector_string_set_data
+#undef mlx_vector_string_set_value
+#undef mlx_vector_string_append_data
+#undef mlx_vector_string_append_value
+#undef mlx_vector_string_size
+#undef mlx_vector_string_get
+#undef mlx_version
+
+extern size_t (*mlx_dtype_size_)(mlx_dtype dtype);
+extern int (*mlx_array_tostring_)(mlx_string* str, const mlx_array arr);
+extern mlx_array (*mlx_array_new_)(void);
+extern int (*mlx_array_free_)(mlx_array arr);
+extern mlx_array (*mlx_array_new_bool_)(bool val);
+extern mlx_array (*mlx_array_new_int_)(int val);
+extern mlx_array (*mlx_array_new_float32_)(float val);
+extern mlx_array (*mlx_array_new_float_)(float val);
+extern mlx_array (*mlx_array_new_float64_)(double val);
+extern mlx_array (*mlx_array_new_double_)(double val);
+extern mlx_array (*mlx_array_new_complex_)(float real_val, float imag_val);
+extern mlx_array (*mlx_array_new_data_)(
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype);
+extern int (*mlx_array_set_)(mlx_array* arr, const mlx_array src);
+extern int (*mlx_array_set_bool_)(mlx_array* arr, bool val);
+extern int (*mlx_array_set_int_)(mlx_array* arr, int val);
+extern int (*mlx_array_set_float32_)(mlx_array* arr, float val);
+extern int (*mlx_array_set_float_)(mlx_array* arr, float val);
+extern int (*mlx_array_set_float64_)(mlx_array* arr, double val);
+extern int (*mlx_array_set_double_)(mlx_array* arr, double val);
+extern int (*mlx_array_set_complex_)(mlx_array* arr, float real_val, float imag_val);
+extern int (*mlx_array_set_data_)(
+ mlx_array* arr,
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype);
+extern size_t (*mlx_array_itemsize_)(const mlx_array arr);
+extern size_t (*mlx_array_size_)(const mlx_array arr);
+extern size_t (*mlx_array_nbytes_)(const mlx_array arr);
+extern size_t (*mlx_array_ndim_)(const mlx_array arr);
+extern const int * (*mlx_array_shape_)(const mlx_array arr);
+extern const size_t * (*mlx_array_strides_)(const mlx_array arr);
+extern int (*mlx_array_dim_)(const mlx_array arr, int dim);
+extern mlx_dtype (*mlx_array_dtype_)(const mlx_array arr);
+extern int (*mlx_array_eval_)(mlx_array arr);
+extern int (*mlx_array_item_bool_)(bool* res, const mlx_array arr);
+extern int (*mlx_array_item_uint8_)(uint8_t* res, const mlx_array arr);
+extern int (*mlx_array_item_uint16_)(uint16_t* res, const mlx_array arr);
+extern int (*mlx_array_item_uint32_)(uint32_t* res, const mlx_array arr);
+extern int (*mlx_array_item_uint64_)(uint64_t* res, const mlx_array arr);
+extern int (*mlx_array_item_int8_)(int8_t* res, const mlx_array arr);
+extern int (*mlx_array_item_int16_)(int16_t* res, const mlx_array arr);
+extern int (*mlx_array_item_int32_)(int32_t* res, const mlx_array arr);
+extern int (*mlx_array_item_int64_)(int64_t* res, const mlx_array arr);
+extern int (*mlx_array_item_float32_)(float* res, const mlx_array arr);
+extern int (*mlx_array_item_float64_)(double* res, const mlx_array arr);
+extern int (*mlx_array_item_complex64_)(float _Complex* res, const mlx_array arr);
+extern int (*mlx_array_item_float16_)(float16_t* res, const mlx_array arr);
+extern int (*mlx_array_item_bfloat16_)(bfloat16_t* res, const mlx_array arr);
+extern const bool * (*mlx_array_data_bool_)(const mlx_array arr);
+extern const uint8_t * (*mlx_array_data_uint8_)(const mlx_array arr);
+extern const uint16_t * (*mlx_array_data_uint16_)(const mlx_array arr);
+extern const uint32_t * (*mlx_array_data_uint32_)(const mlx_array arr);
+extern const uint64_t * (*mlx_array_data_uint64_)(const mlx_array arr);
+extern const int8_t * (*mlx_array_data_int8_)(const mlx_array arr);
+extern const int16_t * (*mlx_array_data_int16_)(const mlx_array arr);
+extern const int32_t * (*mlx_array_data_int32_)(const mlx_array arr);
+extern const int64_t * (*mlx_array_data_int64_)(const mlx_array arr);
+extern const float * (*mlx_array_data_float32_)(const mlx_array arr);
+extern const double * (*mlx_array_data_float64_)(const mlx_array arr);
+extern const float _Complex * (*mlx_array_data_complex64_)(const mlx_array arr);
+extern const float16_t * (*mlx_array_data_float16_)(const mlx_array arr);
+extern const bfloat16_t * (*mlx_array_data_bfloat16_)(const mlx_array arr);
+extern int (*_mlx_array_is_available_)(bool* res, const mlx_array arr);
+extern int (*_mlx_array_wait_)(const mlx_array arr);
+extern int (*_mlx_array_is_contiguous_)(bool* res, const mlx_array arr);
+extern int (*_mlx_array_is_row_contiguous_)(bool* res, const mlx_array arr);
+extern int (*_mlx_array_is_col_contiguous_)(bool* res, const mlx_array arr);
+extern mlx_closure (*mlx_closure_new_)(void);
+extern int (*mlx_closure_free_)(mlx_closure cls);
+extern mlx_closure (*mlx_closure_new_func_)(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array));
+extern mlx_closure (*mlx_closure_new_func_payload_)(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
+ void* payload,
+ void (*dtor)(void*));
+extern int (*mlx_closure_set_)(mlx_closure* cls, const mlx_closure src);
+extern int (*mlx_closure_apply_)(
+ mlx_vector_array* res,
+ mlx_closure cls,
+ const mlx_vector_array input);
+extern mlx_closure (*mlx_closure_new_unary_)(int (*fun)(mlx_array*, const mlx_array));
+extern mlx_closure_kwargs (*mlx_closure_kwargs_new_)(void);
+extern int (*mlx_closure_kwargs_free_)(mlx_closure_kwargs cls);
+extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array));
+extern mlx_closure_kwargs (*mlx_closure_kwargs_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+extern int (*mlx_closure_kwargs_set_)(
+ mlx_closure_kwargs* cls,
+ const mlx_closure_kwargs src);
+extern int (*mlx_closure_kwargs_apply_)(
+ mlx_vector_array* res,
+ mlx_closure_kwargs cls,
+ const mlx_vector_array input_0,
+ const mlx_map_string_to_array input_1);
+extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_)(void);
+extern int (*mlx_closure_value_and_grad_free_)(mlx_closure_value_and_grad cls);
+extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_)(
+ int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array));
+extern mlx_closure_value_and_grad (*mlx_closure_value_and_grad_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_array*,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+extern int (*mlx_closure_value_and_grad_set_)(
+ mlx_closure_value_and_grad* cls,
+ const mlx_closure_value_and_grad src);
+extern int (*mlx_closure_value_and_grad_apply_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ mlx_closure_value_and_grad cls,
+ const mlx_vector_array input);
+extern mlx_closure_custom (*mlx_closure_custom_new_)(void);
+extern int (*mlx_closure_custom_free_)(mlx_closure_custom cls);
+extern mlx_closure_custom (*mlx_closure_custom_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array));
+extern mlx_closure_custom (*mlx_closure_custom_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+extern int (*mlx_closure_custom_set_)(
+ mlx_closure_custom* cls,
+ const mlx_closure_custom src);
+extern int (*mlx_closure_custom_apply_)(
+ mlx_vector_array* res,
+ mlx_closure_custom cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const mlx_vector_array input_2);
+extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_)(void);
+extern int (*mlx_closure_custom_jvp_free_)(mlx_closure_custom_jvp cls);
+extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num));
+extern mlx_closure_custom_jvp (*mlx_closure_custom_jvp_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+extern int (*mlx_closure_custom_jvp_set_)(
+ mlx_closure_custom_jvp* cls,
+ const mlx_closure_custom_jvp src);
+extern int (*mlx_closure_custom_jvp_apply_)(
+ mlx_vector_array* res,
+ mlx_closure_custom_jvp cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const int* input_2,
+ size_t input_2_num);
+extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_)(void);
+extern int (*mlx_closure_custom_vmap_free_)(mlx_closure_custom_vmap cls);
+extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_)(int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num));
+extern mlx_closure_custom_vmap (*mlx_closure_custom_vmap_new_func_payload_)(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*));
+extern int (*mlx_closure_custom_vmap_set_)(
+ mlx_closure_custom_vmap* cls,
+ const mlx_closure_custom_vmap src);
+extern int (*mlx_closure_custom_vmap_apply_)(
+ mlx_vector_array* res_0,
+ mlx_vector_int* res_1,
+ mlx_closure_custom_vmap cls,
+ const mlx_vector_array input_0,
+ const int* input_1,
+ size_t input_1_num);
+extern int (*mlx_compile_)(mlx_closure* res, const mlx_closure fun, bool shapeless);
+extern int (*mlx_detail_compile_)(
+ mlx_closure* res,
+ const mlx_closure fun,
+ uintptr_t fun_id,
+ bool shapeless,
+ const uint64_t* constants,
+ size_t constants_num);
+extern int (*mlx_detail_compile_clear_cache_)(void);
+extern int (*mlx_detail_compile_erase_)(uintptr_t fun_id);
+extern int (*mlx_disable_compile_)(void);
+extern int (*mlx_enable_compile_)(void);
+extern int (*mlx_set_compile_mode_)(mlx_compile_mode mode);
+extern mlx_device (*mlx_device_new_)(void);
+extern mlx_device (*mlx_device_new_type_)(mlx_device_type type, int index);
+extern int (*mlx_device_free_)(mlx_device dev);
+extern int (*mlx_device_set_)(mlx_device* dev, const mlx_device src);
+extern int (*mlx_device_tostring_)(mlx_string* str, mlx_device dev);
+extern bool (*mlx_device_equal_)(mlx_device lhs, mlx_device rhs);
+extern int (*mlx_device_get_index_)(int* index, mlx_device dev);
+extern int (*mlx_device_get_type_)(mlx_device_type* type, mlx_device dev);
+extern int (*mlx_get_default_device_)(mlx_device* dev);
+extern int (*mlx_set_default_device_)(mlx_device dev);
+extern int (*mlx_distributed_group_rank_)(mlx_distributed_group group);
+extern int (*mlx_distributed_group_size_)(mlx_distributed_group group);
+extern mlx_distributed_group (*mlx_distributed_group_split_)(mlx_distributed_group group, int color, int key);
+extern bool (*mlx_distributed_is_available_)(void);
+extern mlx_distributed_group (*mlx_distributed_init_)(bool strict);
+extern int (*mlx_distributed_all_gather_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream S);
+extern int (*mlx_distributed_all_max_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_distributed_all_min_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_distributed_all_sum_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_distributed_recv_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_distributed_recv_like_)(
+ mlx_array* res,
+ const mlx_array x,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_distributed_send_)(
+ mlx_array* res,
+ const mlx_array x,
+ int dst,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_distributed_sum_scatter_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s);
+extern void (*mlx_set_error_handler_)(
+ mlx_error_handler_func handler,
+ void* data,
+ void (*dtor)(void*));
+extern void (*_mlx_error_)(const char* file, const int line, const char* fmt, ...);
+extern int (*mlx_export_function_)(
+ const char* file,
+ const mlx_closure fun,
+ const mlx_vector_array args,
+ bool shapeless);
+extern int (*mlx_export_function_kwargs_)(
+ const char* file,
+ const mlx_closure_kwargs fun,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs,
+ bool shapeless);
+extern mlx_function_exporter (*mlx_function_exporter_new_)(
+ const char* file,
+ const mlx_closure fun,
+ bool shapeless);
+extern int (*mlx_function_exporter_free_)(mlx_function_exporter xfunc);
+extern int (*mlx_function_exporter_apply_)(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args);
+extern int (*mlx_function_exporter_apply_kwargs_)(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs);
+extern mlx_imported_function (*mlx_imported_function_new_)(const char* file);
+extern int (*mlx_imported_function_free_)(mlx_imported_function xfunc);
+extern int (*mlx_imported_function_apply_)(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args);
+extern int (*mlx_imported_function_apply_kwargs_)(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs);
+extern mlx_fast_cuda_kernel_config (*mlx_fast_cuda_kernel_config_new_)(void);
+extern void (*mlx_fast_cuda_kernel_config_free_)(mlx_fast_cuda_kernel_config cls);
+extern int (*mlx_fast_cuda_kernel_config_add_output_arg_)(
+ mlx_fast_cuda_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype);
+extern int (*mlx_fast_cuda_kernel_config_set_grid_)(
+ mlx_fast_cuda_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3);
+extern int (*mlx_fast_cuda_kernel_config_set_thread_group_)(
+ mlx_fast_cuda_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3);
+extern int (*mlx_fast_cuda_kernel_config_set_init_value_)(
+ mlx_fast_cuda_kernel_config cls,
+ float value);
+extern int (*mlx_fast_cuda_kernel_config_set_verbose_)(
+ mlx_fast_cuda_kernel_config cls,
+ bool verbose);
+extern int (*mlx_fast_cuda_kernel_config_add_template_arg_dtype_)(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype);
+extern int (*mlx_fast_cuda_kernel_config_add_template_arg_int_)(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ int value);
+extern int (*mlx_fast_cuda_kernel_config_add_template_arg_bool_)(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ bool value);
+extern mlx_fast_cuda_kernel (*mlx_fast_cuda_kernel_new_)(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ int shared_memory);
+extern void (*mlx_fast_cuda_kernel_free_)(mlx_fast_cuda_kernel cls);
+extern int (*mlx_fast_cuda_kernel_apply_)(
+ mlx_vector_array* outputs,
+ mlx_fast_cuda_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_cuda_kernel_config config,
+ const mlx_stream stream);
+extern int (*mlx_fast_layer_norm_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ const mlx_array bias /* may be null */,
+ float eps,
+ const mlx_stream s);
+extern mlx_fast_metal_kernel_config (*mlx_fast_metal_kernel_config_new_)(void);
+extern void (*mlx_fast_metal_kernel_config_free_)(mlx_fast_metal_kernel_config cls);
+extern int (*mlx_fast_metal_kernel_config_add_output_arg_)(
+ mlx_fast_metal_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype);
+extern int (*mlx_fast_metal_kernel_config_set_grid_)(
+ mlx_fast_metal_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3);
+extern int (*mlx_fast_metal_kernel_config_set_thread_group_)(
+ mlx_fast_metal_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3);
+extern int (*mlx_fast_metal_kernel_config_set_init_value_)(
+ mlx_fast_metal_kernel_config cls,
+ float value);
+extern int (*mlx_fast_metal_kernel_config_set_verbose_)(
+ mlx_fast_metal_kernel_config cls,
+ bool verbose);
+extern int (*mlx_fast_metal_kernel_config_add_template_arg_dtype_)(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype);
+extern int (*mlx_fast_metal_kernel_config_add_template_arg_int_)(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ int value);
+extern int (*mlx_fast_metal_kernel_config_add_template_arg_bool_)(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ bool value);
+extern mlx_fast_metal_kernel (*mlx_fast_metal_kernel_new_)(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ bool atomic_outputs);
+extern void (*mlx_fast_metal_kernel_free_)(mlx_fast_metal_kernel cls);
+extern int (*mlx_fast_metal_kernel_apply_)(
+ mlx_vector_array* outputs,
+ mlx_fast_metal_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_metal_kernel_config config,
+ const mlx_stream stream);
+extern int (*mlx_fast_rms_norm_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ float eps,
+ const mlx_stream s);
+extern int (*mlx_fast_rope_)(
+ mlx_array* res,
+ const mlx_array x,
+ int dims,
+ bool traditional,
+ mlx_optional_float base,
+ float scale,
+ int offset,
+ const mlx_array freqs /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_fast_scaled_dot_product_attention_)(
+ mlx_array* res,
+ const mlx_array queries,
+ const mlx_array keys,
+ const mlx_array values,
+ float scale,
+ const char* mask_mode,
+ const mlx_array mask_arr /* may be null */,
+ const mlx_array sinks /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_fft_fft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_fft_fft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_fftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_fftshift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_ifft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_fft_ifft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_ifftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_ifftshift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_irfft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_fft_irfft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_irfftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_rfft_)(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_fft_rfft2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_fft_rfftn_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern mlx_io_reader (*mlx_io_reader_new_)(void* desc, mlx_io_vtable vtable);
+extern int (*mlx_io_reader_descriptor_)(void** desc_, mlx_io_reader io);
+extern int (*mlx_io_reader_tostring_)(mlx_string* str_, mlx_io_reader io);
+extern int (*mlx_io_reader_free_)(mlx_io_reader io);
+extern mlx_io_writer (*mlx_io_writer_new_)(void* desc, mlx_io_vtable vtable);
+extern int (*mlx_io_writer_descriptor_)(void** desc_, mlx_io_writer io);
+extern int (*mlx_io_writer_tostring_)(mlx_string* str_, mlx_io_writer io);
+extern int (*mlx_io_writer_free_)(mlx_io_writer io);
+extern int (*mlx_load_reader_)(
+ mlx_array* res,
+ mlx_io_reader in_stream,
+ const mlx_stream s);
+extern int (*mlx_load_)(mlx_array* res, const char* file, const mlx_stream s);
+extern int (*mlx_load_safetensors_reader_)(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ mlx_io_reader in_stream,
+ const mlx_stream s);
+extern int (*mlx_load_safetensors_)(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ const char* file,
+ const mlx_stream s);
+extern int (*mlx_save_writer_)(mlx_io_writer out_stream, const mlx_array a);
+extern int (*mlx_save_)(const char* file, const mlx_array a);
+extern int (*mlx_save_safetensors_writer_)(
+ mlx_io_writer in_stream,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata);
+extern int (*mlx_save_safetensors_)(
+ const char* file,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata);
+extern int (*mlx_linalg_cholesky_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s);
+extern int (*mlx_linalg_cholesky_inv_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s);
+extern int (*mlx_linalg_cross_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_linalg_eig_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s);
+extern int (*mlx_linalg_eigh_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s);
+extern int (*mlx_linalg_eigvals_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_linalg_eigvalsh_)(
+ mlx_array* res,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s);
+extern int (*mlx_linalg_inv_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_linalg_lu_)(mlx_vector_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_linalg_lu_factor_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s);
+extern int (*mlx_linalg_norm_)(
+ mlx_array* res,
+ const mlx_array a,
+ double ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_linalg_norm_matrix_)(
+ mlx_array* res,
+ const mlx_array a,
+ const char* ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_linalg_norm_l2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_linalg_pinv_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_linalg_qr_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s);
+extern int (*mlx_linalg_solve_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_linalg_solve_triangular_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool upper,
+ const mlx_stream s);
+extern int (*mlx_linalg_svd_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ bool compute_uv,
+ const mlx_stream s);
+extern int (*mlx_linalg_tri_inv_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s);
+extern mlx_map_string_to_array (*mlx_map_string_to_array_new_)(void);
+extern int (*mlx_map_string_to_array_set_)(
+ mlx_map_string_to_array* map,
+ const mlx_map_string_to_array src);
+extern int (*mlx_map_string_to_array_free_)(mlx_map_string_to_array map);
+extern int (*mlx_map_string_to_array_insert_)(
+ mlx_map_string_to_array map,
+ const char* key,
+ const mlx_array value);
+extern int (*mlx_map_string_to_array_get_)(
+ mlx_array* value,
+ const mlx_map_string_to_array map,
+ const char* key);
+extern mlx_map_string_to_array_iterator (*mlx_map_string_to_array_iterator_new_)(
+ mlx_map_string_to_array map);
+extern int (*mlx_map_string_to_array_iterator_free_)(mlx_map_string_to_array_iterator it);
+extern int (*mlx_map_string_to_array_iterator_next_)(
+ const char** key,
+ mlx_array* value,
+ mlx_map_string_to_array_iterator it);
+extern mlx_map_string_to_string (*mlx_map_string_to_string_new_)(void);
+extern int (*mlx_map_string_to_string_set_)(
+ mlx_map_string_to_string* map,
+ const mlx_map_string_to_string src);
+extern int (*mlx_map_string_to_string_free_)(mlx_map_string_to_string map);
+extern int (*mlx_map_string_to_string_insert_)(
+ mlx_map_string_to_string map,
+ const char* key,
+ const char* value);
+extern int (*mlx_map_string_to_string_get_)(
+ const char** value,
+ const mlx_map_string_to_string map,
+ const char* key);
+extern mlx_map_string_to_string_iterator (*mlx_map_string_to_string_iterator_new_)(
+ mlx_map_string_to_string map);
+extern int (*mlx_map_string_to_string_iterator_free_)(
+ mlx_map_string_to_string_iterator it);
+extern int (*mlx_map_string_to_string_iterator_next_)(
+ const char** key,
+ const char** value,
+ mlx_map_string_to_string_iterator it);
+extern int (*mlx_clear_cache_)(void);
+extern int (*mlx_get_active_memory_)(size_t* res);
+extern int (*mlx_get_cache_memory_)(size_t* res);
+extern int (*mlx_get_memory_limit_)(size_t* res);
+extern int (*mlx_get_peak_memory_)(size_t* res);
+extern int (*mlx_reset_peak_memory_)(void);
+extern int (*mlx_set_cache_limit_)(size_t* res, size_t limit);
+extern int (*mlx_set_memory_limit_)(size_t* res, size_t limit);
+extern int (*mlx_set_wired_limit_)(size_t* res, size_t limit);
+extern mlx_metal_device_info_t (*mlx_metal_device_info_)(void);
+extern int (*mlx_metal_is_available_)(bool* res);
+extern int (*mlx_metal_start_capture_)(const char* path);
+extern int (*mlx_metal_stop_capture_)(void);
+extern int (*mlx_abs_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_add_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_addmm_)(
+ mlx_array* res,
+ const mlx_array c,
+ const mlx_array a,
+ const mlx_array b,
+ float alpha,
+ float beta,
+ const mlx_stream s);
+extern int (*mlx_all_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_all_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_all_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_allclose_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s);
+extern int (*mlx_any_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_any_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_any_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_arange_)(
+ mlx_array* res,
+ double start,
+ double stop,
+ double step,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_arccos_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_arccosh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_arcsin_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_arcsinh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_arctan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_arctan2_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_arctanh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_argmax_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_argmax_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_argmin_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_argmin_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_argpartition_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_argpartition_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s);
+extern int (*mlx_argsort_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_argsort_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_array_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool equal_nan,
+ const mlx_stream s);
+extern int (*mlx_as_strided_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const int64_t* strides,
+ size_t strides_num,
+ size_t offset,
+ const mlx_stream s);
+extern int (*mlx_astype_)(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_atleast_1d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_atleast_2d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_atleast_3d_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_bitwise_and_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_bitwise_invert_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_bitwise_or_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_bitwise_xor_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_block_masked_mm_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int block_size,
+ const mlx_array mask_out /* may be null */,
+ const mlx_array mask_lhs /* may be null */,
+ const mlx_array mask_rhs /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_broadcast_arrays_)(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_stream s);
+extern int (*mlx_broadcast_to_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s);
+extern int (*mlx_ceil_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_clip_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array a_min /* may be null */,
+ const mlx_array a_max /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_concatenate_axis_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_concatenate_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s);
+extern int (*mlx_conjugate_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_contiguous_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool allow_col_major,
+ const mlx_stream s);
+extern int (*mlx_conv1d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int groups,
+ const mlx_stream s);
+extern int (*mlx_conv2d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int groups,
+ const mlx_stream s);
+extern int (*mlx_conv3d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int groups,
+ const mlx_stream s);
+extern int (*mlx_conv_general_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ const int* stride,
+ size_t stride_num,
+ const int* padding_lo,
+ size_t padding_lo_num,
+ const int* padding_hi,
+ size_t padding_hi_num,
+ const int* kernel_dilation,
+ size_t kernel_dilation_num,
+ const int* input_dilation,
+ size_t input_dilation_num,
+ int groups,
+ bool flip,
+ const mlx_stream s);
+extern int (*mlx_conv_transpose1d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int output_padding,
+ int groups,
+ const mlx_stream s);
+extern int (*mlx_conv_transpose2d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int output_padding_0,
+ int output_padding_1,
+ int groups,
+ const mlx_stream s);
+extern int (*mlx_conv_transpose3d_)(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int output_padding_0,
+ int output_padding_1,
+ int output_padding_2,
+ int groups,
+ const mlx_stream s);
+extern int (*mlx_copy_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_cos_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_cosh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_cummax_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+extern int (*mlx_cummin_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+extern int (*mlx_cumprod_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+extern int (*mlx_cumsum_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+extern int (*mlx_degrees_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_depends_)(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array dependencies);
+extern int (*mlx_dequantize_)(
+ mlx_array* res,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ mlx_optional_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_diag_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
+extern int (*mlx_diagonal_)(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ const mlx_stream s);
+extern int (*mlx_divide_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_divmod_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_einsum_)(
+ mlx_array* res,
+ const char* subscripts,
+ const mlx_vector_array operands,
+ const mlx_stream s);
+extern int (*mlx_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_erf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_erfinv_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_exp_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_expand_dims_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_expand_dims_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_expm1_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_eye_)(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_flatten_)(
+ mlx_array* res,
+ const mlx_array a,
+ int start_axis,
+ int end_axis,
+ const mlx_stream s);
+extern int (*mlx_floor_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_floor_divide_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_from_fp8_)(
+ mlx_array* res,
+ const mlx_array x,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_full_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_full_like_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_gather_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_sizes,
+ size_t slice_sizes_num,
+ const mlx_stream s);
+extern int (*mlx_gather_mm_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool sorted_indices,
+ const mlx_stream s);
+extern int (*mlx_gather_qmm_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ bool sorted_indices,
+ const mlx_stream s);
+extern int (*mlx_greater_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_greater_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_hadamard_transform_)(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_optional_float scale,
+ const mlx_stream s);
+extern int (*mlx_identity_)(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s);
+extern int (*mlx_imag_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_inner_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_isclose_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s);
+extern int (*mlx_isfinite_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_isinf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_isnan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_isneginf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_isposinf_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_kron_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_left_shift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_less_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_less_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_linspace_)(
+ mlx_array* res,
+ double start,
+ double stop,
+ int num,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_log_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_log10_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_log1p_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_log2_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_logaddexp_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_logcumsumexp_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s);
+extern int (*mlx_logical_and_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_logical_not_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_logical_or_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_logsumexp_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_logsumexp_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_logsumexp_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_masked_scatter_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array mask,
+ const mlx_array src,
+ const mlx_stream s);
+extern int (*mlx_matmul_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_max_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_max_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_max_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_maximum_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_mean_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_mean_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_mean_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_median_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_meshgrid_)(
+ mlx_vector_array* res,
+ const mlx_vector_array arrays,
+ bool sparse,
+ const char* indexing,
+ const mlx_stream s);
+extern int (*mlx_min_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_min_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_min_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_minimum_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_moveaxis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int source,
+ int destination,
+ const mlx_stream s);
+extern int (*mlx_multiply_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_nan_to_num_)(
+ mlx_array* res,
+ const mlx_array a,
+ float nan,
+ mlx_optional_float posinf,
+ mlx_optional_float neginf,
+ const mlx_stream s);
+extern int (*mlx_negative_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_not_equal_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_number_of_elements_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool inverted,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_ones_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_ones_like_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_outer_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_pad_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const int* low_pad_size,
+ size_t low_pad_size_num,
+ const int* high_pad_size,
+ size_t high_pad_size_num,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s);
+extern int (*mlx_pad_symmetric_)(
+ mlx_array* res,
+ const mlx_array a,
+ int pad_width,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s);
+extern int (*mlx_partition_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_partition_)(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s);
+extern int (*mlx_power_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_prod_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_prod_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_prod_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_put_along_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_quantize_)(
+ mlx_vector_array* res,
+ const mlx_array w,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s);
+extern int (*mlx_quantized_matmul_)(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s);
+extern int (*mlx_radians_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_real_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_reciprocal_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_remainder_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_repeat_axis_)(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_repeat_)(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ const mlx_stream s);
+extern int (*mlx_reshape_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s);
+extern int (*mlx_right_shift_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_roll_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_roll_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_roll_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const mlx_stream s);
+extern int (*mlx_round_)(
+ mlx_array* res,
+ const mlx_array a,
+ int decimals,
+ const mlx_stream s);
+extern int (*mlx_rsqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_scatter_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_scatter_add_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_scatter_add_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_scatter_max_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_scatter_min_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_scatter_prod_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_segmented_mm_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array segments,
+ const mlx_stream s);
+extern int (*mlx_sigmoid_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_sign_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_sin_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_sinh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_slice_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s);
+extern int (*mlx_slice_dynamic_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_size,
+ size_t slice_size_num,
+ const mlx_stream s);
+extern int (*mlx_slice_update_)(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s);
+extern int (*mlx_slice_update_dynamic_)(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_softmax_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool precise,
+ const mlx_stream s);
+extern int (*mlx_softmax_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool precise,
+ const mlx_stream s);
+extern int (*mlx_softmax_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool precise,
+ const mlx_stream s);
+extern int (*mlx_sort_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_sort_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_split_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ int num_splits,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_split_sections_)(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const int* indices,
+ size_t indices_num,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_sqrt_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_square_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_squeeze_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_squeeze_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_squeeze_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_stack_axis_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_stack_)(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s);
+extern int (*mlx_std_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+extern int (*mlx_std_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+extern int (*mlx_std_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+extern int (*mlx_stop_gradient_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_subtract_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s);
+extern int (*mlx_sum_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_sum_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_sum_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s);
+extern int (*mlx_swapaxes_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis1,
+ int axis2,
+ const mlx_stream s);
+extern int (*mlx_take_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_take_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_stream s);
+extern int (*mlx_take_along_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_tan_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_tanh_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_tensordot_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const int* axes_a,
+ size_t axes_a_num,
+ const int* axes_b,
+ size_t axes_b_num,
+ const mlx_stream s);
+extern int (*mlx_tensordot_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_tile_)(
+ mlx_array* res,
+ const mlx_array arr,
+ const int* reps,
+ size_t reps_num,
+ const mlx_stream s);
+extern int (*mlx_to_fp8_)(mlx_array* res, const mlx_array x, const mlx_stream s);
+extern int (*mlx_topk_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int k,
+ int axis,
+ const mlx_stream s);
+extern int (*mlx_topk_)(mlx_array* res, const mlx_array a, int k, const mlx_stream s);
+extern int (*mlx_trace_)(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_transpose_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s);
+extern int (*mlx_transpose_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_tri_)(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype type,
+ const mlx_stream s);
+extern int (*mlx_tril_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
+extern int (*mlx_triu_)(mlx_array* res, const mlx_array x, int k, const mlx_stream s);
+extern int (*mlx_unflatten_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s);
+extern int (*mlx_var_axes_)(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+extern int (*mlx_var_axis_)(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+extern int (*mlx_var_)(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s);
+extern int (*mlx_view_)(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_where_)(
+ mlx_array* res,
+ const mlx_array condition,
+ const mlx_array x,
+ const mlx_array y,
+ const mlx_stream s);
+extern int (*mlx_zeros_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s);
+extern int (*mlx_zeros_like_)(mlx_array* res, const mlx_array a, const mlx_stream s);
+extern int (*mlx_random_bernoulli_)(
+ mlx_array* res,
+ const mlx_array p,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_bits_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ int width,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_categorical_shape_)(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_categorical_num_samples_)(
+ mlx_array* res,
+ const mlx_array logits_,
+ int axis,
+ int num_samples,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_categorical_)(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_gumbel_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_key_)(mlx_array* res, uint64_t seed);
+extern int (*mlx_random_laplace_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_multivariate_normal_)(
+ mlx_array* res,
+ const mlx_array mean,
+ const mlx_array cov,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_normal_broadcast_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array loc /* may be null */,
+ const mlx_array scale /* may be null */,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_normal_)(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_permutation_)(
+ mlx_array* res,
+ const mlx_array x,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_permutation_arange_)(
+ mlx_array* res,
+ int x,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_randint_)(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_seed_)(uint64_t seed);
+extern int (*mlx_random_split_num_)(
+ mlx_array* res,
+ const mlx_array key,
+ int num,
+ const mlx_stream s);
+extern int (*mlx_random_split_)(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array key,
+ const mlx_stream s);
+extern int (*mlx_random_truncated_normal_)(
+ mlx_array* res,
+ const mlx_array lower,
+ const mlx_array upper,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern int (*mlx_random_uniform_)(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s);
+extern mlx_stream (*mlx_stream_new_)(void);
+extern mlx_stream (*mlx_stream_new_device_)(mlx_device dev);
+extern int (*mlx_stream_set_)(mlx_stream* stream, const mlx_stream src);
+extern int (*mlx_stream_free_)(mlx_stream stream);
+extern int (*mlx_stream_tostring_)(mlx_string* str, mlx_stream stream);
+extern bool (*mlx_stream_equal_)(mlx_stream lhs, mlx_stream rhs);
+extern int (*mlx_stream_get_device_)(mlx_device* dev, mlx_stream stream);
+extern int (*mlx_stream_get_index_)(int* index, mlx_stream stream);
+extern int (*mlx_synchronize_)(mlx_stream stream);
+extern int (*mlx_get_default_stream_)(mlx_stream* stream, mlx_device dev);
+extern int (*mlx_set_default_stream_)(mlx_stream stream);
+extern mlx_stream (*mlx_default_cpu_stream_new_)(void);
+extern mlx_stream (*mlx_default_gpu_stream_new_)(void);
+extern mlx_string (*mlx_string_new_)(void);
+extern mlx_string (*mlx_string_new_data_)(const char* str);
+extern int (*mlx_string_set_)(mlx_string* str, const mlx_string src);
+extern const char * (*mlx_string_data_)(mlx_string str);
+extern int (*mlx_string_free_)(mlx_string str);
+extern int (*mlx_detail_vmap_replace_)(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array s_inputs,
+ const mlx_vector_array s_outputs,
+ const int* in_axes,
+ size_t in_axes_num,
+ const int* out_axes,
+ size_t out_axes_num);
+extern int (*mlx_detail_vmap_trace_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array inputs,
+ const int* in_axes,
+ size_t in_axes_num);
+extern int (*mlx_async_eval_)(const mlx_vector_array outputs);
+extern int (*mlx_checkpoint_)(mlx_closure* res, const mlx_closure fun);
+extern int (*mlx_custom_function_)(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp /* may be null */,
+ const mlx_closure_custom_jvp fun_jvp /* may be null */,
+ const mlx_closure_custom_vmap fun_vmap /* may be null */);
+extern int (*mlx_custom_vjp_)(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp);
+extern int (*mlx_eval_)(const mlx_vector_array outputs);
+extern int (*mlx_jvp_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array tangents);
+extern int (*mlx_value_and_grad_)(
+ mlx_closure_value_and_grad* res,
+ const mlx_closure fun,
+ const int* argnums,
+ size_t argnums_num);
+extern int (*mlx_vjp_)(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array cotangents);
+extern mlx_vector_array (*mlx_vector_array_new_)(void);
+extern int (*mlx_vector_array_set_)(mlx_vector_array* vec, const mlx_vector_array src);
+extern int (*mlx_vector_array_free_)(mlx_vector_array vec);
+extern mlx_vector_array (*mlx_vector_array_new_data_)(const mlx_array* data, size_t size);
+extern mlx_vector_array (*mlx_vector_array_new_value_)(const mlx_array val);
+extern int (*mlx_vector_array_set_data_)(
+ mlx_vector_array* vec,
+ const mlx_array* data,
+ size_t size);
+extern int (*mlx_vector_array_set_value_)(mlx_vector_array* vec, const mlx_array val);
+extern int (*mlx_vector_array_append_data_)(
+ mlx_vector_array vec,
+ const mlx_array* data,
+ size_t size);
+extern int (*mlx_vector_array_append_value_)(mlx_vector_array vec, const mlx_array val);
+extern size_t (*mlx_vector_array_size_)(mlx_vector_array vec);
+extern int (*mlx_vector_array_get_)(
+ mlx_array* res,
+ const mlx_vector_array vec,
+ size_t idx);
+extern mlx_vector_vector_array (*mlx_vector_vector_array_new_)(void);
+extern int (*mlx_vector_vector_array_set_)(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_vector_array src);
+extern int (*mlx_vector_vector_array_free_)(mlx_vector_vector_array vec);
+extern mlx_vector_vector_array (*mlx_vector_vector_array_new_data_)(
+ const mlx_vector_array* data,
+ size_t size);
+extern mlx_vector_vector_array (*mlx_vector_vector_array_new_value_)(
+ const mlx_vector_array val);
+extern int (*mlx_vector_vector_array_set_data_)(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array* data,
+ size_t size);
+extern int (*mlx_vector_vector_array_set_value_)(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array val);
+extern int (*mlx_vector_vector_array_append_data_)(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array* data,
+ size_t size);
+extern int (*mlx_vector_vector_array_append_value_)(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array val);
+extern size_t (*mlx_vector_vector_array_size_)(mlx_vector_vector_array vec);
+extern int (*mlx_vector_vector_array_get_)(
+ mlx_vector_array* res,
+ const mlx_vector_vector_array vec,
+ size_t idx);
+extern mlx_vector_int (*mlx_vector_int_new_)(void);
+extern int (*mlx_vector_int_set_)(mlx_vector_int* vec, const mlx_vector_int src);
+extern int (*mlx_vector_int_free_)(mlx_vector_int vec);
+extern mlx_vector_int (*mlx_vector_int_new_data_)(int* data, size_t size);
+extern mlx_vector_int (*mlx_vector_int_new_value_)(int val);
+extern int (*mlx_vector_int_set_data_)(mlx_vector_int* vec, int* data, size_t size);
+extern int (*mlx_vector_int_set_value_)(mlx_vector_int* vec, int val);
+extern int (*mlx_vector_int_append_data_)(mlx_vector_int vec, int* data, size_t size);
+extern int (*mlx_vector_int_append_value_)(mlx_vector_int vec, int val);
+extern size_t (*mlx_vector_int_size_)(mlx_vector_int vec);
+extern int (*mlx_vector_int_get_)(int* res, const mlx_vector_int vec, size_t idx);
+extern mlx_vector_string (*mlx_vector_string_new_)(void);
+extern int (*mlx_vector_string_set_)(mlx_vector_string* vec, const mlx_vector_string src);
+extern int (*mlx_vector_string_free_)(mlx_vector_string vec);
+extern mlx_vector_string (*mlx_vector_string_new_data_)(const char** data, size_t size);
+extern mlx_vector_string (*mlx_vector_string_new_value_)(const char* val);
+extern int (*mlx_vector_string_set_data_)(
+ mlx_vector_string* vec,
+ const char** data,
+ size_t size);
+extern int (*mlx_vector_string_set_value_)(mlx_vector_string* vec, const char* val);
+extern int (*mlx_vector_string_append_data_)(
+ mlx_vector_string vec,
+ const char** data,
+ size_t size);
+extern int (*mlx_vector_string_append_value_)(mlx_vector_string vec, const char* val);
+extern size_t (*mlx_vector_string_size_)(mlx_vector_string vec);
+extern int (*mlx_vector_string_get_)(char** res, const mlx_vector_string vec, size_t idx);
+extern int (*mlx_version_)(mlx_string* str_);
+
+int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
+
+static inline size_t mlx_dtype_size(mlx_dtype dtype) {
+ return mlx_dtype_size_(dtype);
+}
+
+static inline int mlx_array_tostring(mlx_string* str, const mlx_array arr) {
+ return mlx_array_tostring_(str, arr);
+}
+
+static inline mlx_array mlx_array_new(void) {
+ return mlx_array_new_();
+}
+
+static inline int mlx_array_free(mlx_array arr) {
+ return mlx_array_free_(arr);
+}
+
+static inline mlx_array mlx_array_new_bool(bool val) {
+ return mlx_array_new_bool_(val);
+}
+
+static inline mlx_array mlx_array_new_int(int val) {
+ return mlx_array_new_int_(val);
+}
+
+static inline mlx_array mlx_array_new_float32(float val) {
+ return mlx_array_new_float32_(val);
+}
+
+static inline mlx_array mlx_array_new_float(float val) {
+ return mlx_array_new_float_(val);
+}
+
+static inline mlx_array mlx_array_new_float64(double val) {
+ return mlx_array_new_float64_(val);
+}
+
+static inline mlx_array mlx_array_new_double(double val) {
+ return mlx_array_new_double_(val);
+}
+
+static inline mlx_array mlx_array_new_complex(float real_val, float imag_val) {
+ return mlx_array_new_complex_(real_val, imag_val);
+}
+
+static inline mlx_array mlx_array_new_data(
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype) {
+ return mlx_array_new_data_(data, shape, dim, dtype);
+}
+
+static inline int mlx_array_set(mlx_array* arr, const mlx_array src) {
+ return mlx_array_set_(arr, src);
+}
+
+static inline int mlx_array_set_bool(mlx_array* arr, bool val) {
+ return mlx_array_set_bool_(arr, val);
+}
+
+static inline int mlx_array_set_int(mlx_array* arr, int val) {
+ return mlx_array_set_int_(arr, val);
+}
+
+static inline int mlx_array_set_float32(mlx_array* arr, float val) {
+ return mlx_array_set_float32_(arr, val);
+}
+
+static inline int mlx_array_set_float(mlx_array* arr, float val) {
+ return mlx_array_set_float_(arr, val);
+}
+
+static inline int mlx_array_set_float64(mlx_array* arr, double val) {
+ return mlx_array_set_float64_(arr, val);
+}
+
+static inline int mlx_array_set_double(mlx_array* arr, double val) {
+ return mlx_array_set_double_(arr, val);
+}
+
+static inline int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val) {
+ return mlx_array_set_complex_(arr, real_val, imag_val);
+}
+
+static inline int mlx_array_set_data(
+ mlx_array* arr,
+ const void* data,
+ const int* shape,
+ int dim,
+ mlx_dtype dtype) {
+ return mlx_array_set_data_(arr, data, shape, dim, dtype);
+}
+
+static inline size_t mlx_array_itemsize(const mlx_array arr) {
+ return mlx_array_itemsize_(arr);
+}
+
+static inline size_t mlx_array_size(const mlx_array arr) {
+ return mlx_array_size_(arr);
+}
+
+static inline size_t mlx_array_nbytes(const mlx_array arr) {
+ return mlx_array_nbytes_(arr);
+}
+
+static inline size_t mlx_array_ndim(const mlx_array arr) {
+ return mlx_array_ndim_(arr);
+}
+
+static inline const int * mlx_array_shape(const mlx_array arr) {
+ return mlx_array_shape_(arr);
+}
+
+static inline const size_t * mlx_array_strides(const mlx_array arr) {
+ return mlx_array_strides_(arr);
+}
+
+static inline int mlx_array_dim(const mlx_array arr, int dim) {
+ return mlx_array_dim_(arr, dim);
+}
+
+static inline mlx_dtype mlx_array_dtype(const mlx_array arr) {
+ return mlx_array_dtype_(arr);
+}
+
+static inline int mlx_array_eval(mlx_array arr) {
+ return mlx_array_eval_(arr);
+}
+
+static inline int mlx_array_item_bool(bool* res, const mlx_array arr) {
+ return mlx_array_item_bool_(res, arr);
+}
+
+static inline int mlx_array_item_uint8(uint8_t* res, const mlx_array arr) {
+ return mlx_array_item_uint8_(res, arr);
+}
+
+static inline int mlx_array_item_uint16(uint16_t* res, const mlx_array arr) {
+ return mlx_array_item_uint16_(res, arr);
+}
+
+static inline int mlx_array_item_uint32(uint32_t* res, const mlx_array arr) {
+ return mlx_array_item_uint32_(res, arr);
+}
+
+static inline int mlx_array_item_uint64(uint64_t* res, const mlx_array arr) {
+ return mlx_array_item_uint64_(res, arr);
+}
+
+static inline int mlx_array_item_int8(int8_t* res, const mlx_array arr) {
+ return mlx_array_item_int8_(res, arr);
+}
+
+static inline int mlx_array_item_int16(int16_t* res, const mlx_array arr) {
+ return mlx_array_item_int16_(res, arr);
+}
+
+static inline int mlx_array_item_int32(int32_t* res, const mlx_array arr) {
+ return mlx_array_item_int32_(res, arr);
+}
+
+static inline int mlx_array_item_int64(int64_t* res, const mlx_array arr) {
+ return mlx_array_item_int64_(res, arr);
+}
+
+static inline int mlx_array_item_float32(float* res, const mlx_array arr) {
+ return mlx_array_item_float32_(res, arr);
+}
+
+static inline int mlx_array_item_float64(double* res, const mlx_array arr) {
+ return mlx_array_item_float64_(res, arr);
+}
+
+static inline int mlx_array_item_complex64(float _Complex* res, const mlx_array arr) {
+ return mlx_array_item_complex64_(res, arr);
+}
+
+static inline int mlx_array_item_float16(float16_t* res, const mlx_array arr) {
+ return mlx_array_item_float16_(res, arr);
+}
+
+static inline int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr) {
+ return mlx_array_item_bfloat16_(res, arr);
+}
+
+static inline const bool * mlx_array_data_bool(const mlx_array arr) {
+ return mlx_array_data_bool_(arr);
+}
+
+static inline const uint8_t * mlx_array_data_uint8(const mlx_array arr) {
+ return mlx_array_data_uint8_(arr);
+}
+
+static inline const uint16_t * mlx_array_data_uint16(const mlx_array arr) {
+ return mlx_array_data_uint16_(arr);
+}
+
+static inline const uint32_t * mlx_array_data_uint32(const mlx_array arr) {
+ return mlx_array_data_uint32_(arr);
+}
+
+static inline const uint64_t * mlx_array_data_uint64(const mlx_array arr) {
+ return mlx_array_data_uint64_(arr);
+}
+
+static inline const int8_t * mlx_array_data_int8(const mlx_array arr) {
+ return mlx_array_data_int8_(arr);
+}
+
+static inline const int16_t * mlx_array_data_int16(const mlx_array arr) {
+ return mlx_array_data_int16_(arr);
+}
+
+static inline const int32_t * mlx_array_data_int32(const mlx_array arr) {
+ return mlx_array_data_int32_(arr);
+}
+
+static inline const int64_t * mlx_array_data_int64(const mlx_array arr) {
+ return mlx_array_data_int64_(arr);
+}
+
+static inline const float * mlx_array_data_float32(const mlx_array arr) {
+ return mlx_array_data_float32_(arr);
+}
+
+static inline const double * mlx_array_data_float64(const mlx_array arr) {
+ return mlx_array_data_float64_(arr);
+}
+
+static inline const float _Complex * mlx_array_data_complex64(const mlx_array arr) {
+ return mlx_array_data_complex64_(arr);
+}
+
+static inline const float16_t * mlx_array_data_float16(const mlx_array arr) {
+ return mlx_array_data_float16_(arr);
+}
+
+static inline const bfloat16_t * mlx_array_data_bfloat16(const mlx_array arr) {
+ return mlx_array_data_bfloat16_(arr);
+}
+
+static inline int _mlx_array_is_available(bool* res, const mlx_array arr) {
+ return _mlx_array_is_available_(res, arr);
+}
+
+static inline int _mlx_array_wait(const mlx_array arr) {
+ return _mlx_array_wait_(arr);
+}
+
+static inline int _mlx_array_is_contiguous(bool* res, const mlx_array arr) {
+ return _mlx_array_is_contiguous_(res, arr);
+}
+
+static inline int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr) {
+ return _mlx_array_is_row_contiguous_(res, arr);
+}
+
+static inline int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr) {
+ return _mlx_array_is_col_contiguous_(res, arr);
+}
+
+static inline mlx_closure mlx_closure_new(void) {
+ return mlx_closure_new_();
+}
+
+static inline int mlx_closure_free(mlx_closure cls) {
+ return mlx_closure_free_(cls);
+}
+
+static inline mlx_closure mlx_closure_new_func(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array)) {
+ return mlx_closure_new_func_(fun);
+}
+
+static inline mlx_closure mlx_closure_new_func_payload(
+ int (*fun)(mlx_vector_array*, const mlx_vector_array, void*),
+ void* payload,
+ void (*dtor)(void*)) {
+ return mlx_closure_new_func_payload_(fun, payload, dtor);
+}
+
+static inline int mlx_closure_set(mlx_closure* cls, const mlx_closure src) {
+ return mlx_closure_set_(cls, src);
+}
+
+static inline int mlx_closure_apply(
+ mlx_vector_array* res,
+ mlx_closure cls,
+ const mlx_vector_array input) {
+ return mlx_closure_apply_(res, cls, input);
+}
+
+static inline mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)) {
+ return mlx_closure_new_unary_(fun);
+}
+
+static inline mlx_closure_kwargs mlx_closure_kwargs_new(void) {
+ return mlx_closure_kwargs_new_();
+}
+
+static inline int mlx_closure_kwargs_free(mlx_closure_kwargs cls) {
+ return mlx_closure_kwargs_free_(cls);
+}
+
+static inline mlx_closure_kwargs mlx_closure_kwargs_new_func(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array)) {
+ return mlx_closure_kwargs_new_func_(fun);
+}
+
+static inline mlx_closure_kwargs mlx_closure_kwargs_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_map_string_to_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) {
+ return mlx_closure_kwargs_new_func_payload_(fun, payload, dtor);
+}
+
+static inline int mlx_closure_kwargs_set(
+ mlx_closure_kwargs* cls,
+ const mlx_closure_kwargs src) {
+ return mlx_closure_kwargs_set_(cls, src);
+}
+
+static inline int mlx_closure_kwargs_apply(
+ mlx_vector_array* res,
+ mlx_closure_kwargs cls,
+ const mlx_vector_array input_0,
+ const mlx_map_string_to_array input_1) {
+ return mlx_closure_kwargs_apply_(res, cls, input_0, input_1);
+}
+
+static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void) {
+ return mlx_closure_value_and_grad_new_();
+}
+
+static inline int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls) {
+ return mlx_closure_value_and_grad_free_(cls);
+}
+
+static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func(
+ int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)) {
+ return mlx_closure_value_and_grad_new_func_(fun);
+}
+
+static inline mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_array*,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) {
+ return mlx_closure_value_and_grad_new_func_payload_(fun, payload, dtor);
+}
+
+static inline int mlx_closure_value_and_grad_set(
+ mlx_closure_value_and_grad* cls,
+ const mlx_closure_value_and_grad src) {
+ return mlx_closure_value_and_grad_set_(cls, src);
+}
+
+static inline int mlx_closure_value_and_grad_apply(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ mlx_closure_value_and_grad cls,
+ const mlx_vector_array input) {
+ return mlx_closure_value_and_grad_apply_(res_0, res_1, cls, input);
+}
+
+static inline mlx_closure_custom mlx_closure_custom_new(void) {
+ return mlx_closure_custom_new_();
+}
+
+static inline int mlx_closure_custom_free(mlx_closure_custom cls) {
+ return mlx_closure_custom_free_(cls);
+}
+
+static inline mlx_closure_custom mlx_closure_custom_new_func(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array)) {
+ return mlx_closure_custom_new_func_(fun);
+}
+
+static inline mlx_closure_custom mlx_closure_custom_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) {
+ return mlx_closure_custom_new_func_payload_(fun, payload, dtor);
+}
+
+static inline int mlx_closure_custom_set(
+ mlx_closure_custom* cls,
+ const mlx_closure_custom src) {
+ return mlx_closure_custom_set_(cls, src);
+}
+
+static inline int mlx_closure_custom_apply(
+ mlx_vector_array* res,
+ mlx_closure_custom cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const mlx_vector_array input_2) {
+ return mlx_closure_custom_apply_(res, cls, input_0, input_1, input_2);
+}
+
+static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void) {
+ return mlx_closure_custom_jvp_new_();
+}
+
+static inline int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls) {
+ return mlx_closure_custom_jvp_free_(cls);
+}
+
+static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func(int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num)) {
+ return mlx_closure_custom_jvp_new_func_(fun);
+}
+
+static inline mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ const mlx_vector_array,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) {
+ return mlx_closure_custom_jvp_new_func_payload_(fun, payload, dtor);
+}
+
+static inline int mlx_closure_custom_jvp_set(
+ mlx_closure_custom_jvp* cls,
+ const mlx_closure_custom_jvp src) {
+ return mlx_closure_custom_jvp_set_(cls, src);
+}
+
+static inline int mlx_closure_custom_jvp_apply(
+ mlx_vector_array* res,
+ mlx_closure_custom_jvp cls,
+ const mlx_vector_array input_0,
+ const mlx_vector_array input_1,
+ const int* input_2,
+ size_t input_2_num) {
+ return mlx_closure_custom_jvp_apply_(res, cls, input_0, input_1, input_2, input_2_num);
+}
+
+static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void) {
+ return mlx_closure_custom_vmap_new_();
+}
+
+static inline int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls) {
+ return mlx_closure_custom_vmap_free_(cls);
+}
+
+static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func(int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num)) {
+ return mlx_closure_custom_vmap_new_func_(fun);
+}
+
+static inline mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload(
+ int (*fun)(
+ mlx_vector_array*,
+ mlx_vector_int*,
+ const mlx_vector_array,
+ const int*,
+ size_t _num,
+ void*),
+ void* payload,
+ void (*dtor)(void*)) {
+ return mlx_closure_custom_vmap_new_func_payload_(fun, payload, dtor);
+}
+
+static inline int mlx_closure_custom_vmap_set(
+ mlx_closure_custom_vmap* cls,
+ const mlx_closure_custom_vmap src) {
+ return mlx_closure_custom_vmap_set_(cls, src);
+}
+
+static inline int mlx_closure_custom_vmap_apply(
+ mlx_vector_array* res_0,
+ mlx_vector_int* res_1,
+ mlx_closure_custom_vmap cls,
+ const mlx_vector_array input_0,
+ const int* input_1,
+ size_t input_1_num) {
+ return mlx_closure_custom_vmap_apply_(res_0, res_1, cls, input_0, input_1, input_1_num);
+}
+
+static inline int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless) {
+ return mlx_compile_(res, fun, shapeless);
+}
+
+static inline int mlx_detail_compile(
+ mlx_closure* res,
+ const mlx_closure fun,
+ uintptr_t fun_id,
+ bool shapeless,
+ const uint64_t* constants,
+ size_t constants_num) {
+ return mlx_detail_compile_(res, fun, fun_id, shapeless, constants, constants_num);
+}
+
+static inline int mlx_detail_compile_clear_cache(void) {
+ return mlx_detail_compile_clear_cache_();
+}
+
+static inline int mlx_detail_compile_erase(uintptr_t fun_id) {
+ return mlx_detail_compile_erase_(fun_id);
+}
+
+static inline int mlx_disable_compile(void) {
+ return mlx_disable_compile_();
+}
+
+static inline int mlx_enable_compile(void) {
+ return mlx_enable_compile_();
+}
+
+static inline int mlx_set_compile_mode(mlx_compile_mode mode) {
+ return mlx_set_compile_mode_(mode);
+}
+
+static inline mlx_device mlx_device_new(void) {
+ return mlx_device_new_();
+}
+
+static inline mlx_device mlx_device_new_type(mlx_device_type type, int index) {
+ return mlx_device_new_type_(type, index);
+}
+
+static inline int mlx_device_free(mlx_device dev) {
+ return mlx_device_free_(dev);
+}
+
+static inline int mlx_device_set(mlx_device* dev, const mlx_device src) {
+ return mlx_device_set_(dev, src);
+}
+
+static inline int mlx_device_tostring(mlx_string* str, mlx_device dev) {
+ return mlx_device_tostring_(str, dev);
+}
+
+static inline bool mlx_device_equal(mlx_device lhs, mlx_device rhs) {
+ return mlx_device_equal_(lhs, rhs);
+}
+
+static inline int mlx_device_get_index(int* index, mlx_device dev) {
+ return mlx_device_get_index_(index, dev);
+}
+
+static inline int mlx_device_get_type(mlx_device_type* type, mlx_device dev) {
+ return mlx_device_get_type_(type, dev);
+}
+
+static inline int mlx_get_default_device(mlx_device* dev) {
+ return mlx_get_default_device_(dev);
+}
+
+static inline int mlx_set_default_device(mlx_device dev) {
+ return mlx_set_default_device_(dev);
+}
+
+static inline int mlx_distributed_group_rank(mlx_distributed_group group) {
+ return mlx_distributed_group_rank_(group);
+}
+
+static inline int mlx_distributed_group_size(mlx_distributed_group group) {
+ return mlx_distributed_group_size_(group);
+}
+
+static inline mlx_distributed_group mlx_distributed_group_split(mlx_distributed_group group, int color, int key) {
+ return mlx_distributed_group_split_(group, color, key);
+}
+
+static inline bool mlx_distributed_is_available(void) {
+ return mlx_distributed_is_available_();
+}
+
+static inline mlx_distributed_group mlx_distributed_init(bool strict) {
+ return mlx_distributed_init_(strict);
+}
+
+static inline int mlx_distributed_all_gather(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream S) {
+ return mlx_distributed_all_gather_(res, x, group, S);
+}
+
+static inline int mlx_distributed_all_max(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_all_max_(res, x, group, s);
+}
+
+static inline int mlx_distributed_all_min(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_all_min_(res, x, group, s);
+}
+
+static inline int mlx_distributed_all_sum(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_all_sum_(res, x, group, s);
+}
+
+static inline int mlx_distributed_recv(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_recv_(res, shape, shape_num, dtype, src, group, s);
+}
+
+static inline int mlx_distributed_recv_like(
+ mlx_array* res,
+ const mlx_array x,
+ int src,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_recv_like_(res, x, src, group, s);
+}
+
+static inline int mlx_distributed_send(
+ mlx_array* res,
+ const mlx_array x,
+ int dst,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_send_(res, x, dst, group, s);
+}
+
+static inline int mlx_distributed_sum_scatter(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_distributed_group group /* may be null */,
+ const mlx_stream s) {
+ return mlx_distributed_sum_scatter_(res, x, group, s);
+}
+
+static inline void mlx_set_error_handler(
+ mlx_error_handler_func handler,
+ void* data,
+ void (*dtor)(void*)) {
+ mlx_set_error_handler_(handler, data, dtor);
+}
+
+#define _mlx_error(file, line, fmt, ...) _mlx_error_(file, line, fmt, __VA_ARGS__)
+
+static inline int mlx_export_function(
+ const char* file,
+ const mlx_closure fun,
+ const mlx_vector_array args,
+ bool shapeless) {
+ return mlx_export_function_(file, fun, args, shapeless);
+}
+
+static inline int mlx_export_function_kwargs(
+ const char* file,
+ const mlx_closure_kwargs fun,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs,
+ bool shapeless) {
+ return mlx_export_function_kwargs_(file, fun, args, kwargs, shapeless);
+}
+
+static inline mlx_function_exporter mlx_function_exporter_new(
+ const char* file,
+ const mlx_closure fun,
+ bool shapeless) {
+ return mlx_function_exporter_new_(file, fun, shapeless);
+}
+
+static inline int mlx_function_exporter_free(mlx_function_exporter xfunc) {
+ return mlx_function_exporter_free_(xfunc);
+}
+
+static inline int mlx_function_exporter_apply(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args) {
+ return mlx_function_exporter_apply_(xfunc, args);
+}
+
+static inline int mlx_function_exporter_apply_kwargs(
+ const mlx_function_exporter xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs) {
+ return mlx_function_exporter_apply_kwargs_(xfunc, args, kwargs);
+}
+
+static inline mlx_imported_function mlx_imported_function_new(const char* file) {
+ return mlx_imported_function_new_(file);
+}
+
+static inline int mlx_imported_function_free(mlx_imported_function xfunc) {
+ return mlx_imported_function_free_(xfunc);
+}
+
+static inline int mlx_imported_function_apply(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args) {
+ return mlx_imported_function_apply_(res, xfunc, args);
+}
+
+static inline int mlx_imported_function_apply_kwargs(
+ mlx_vector_array* res,
+ const mlx_imported_function xfunc,
+ const mlx_vector_array args,
+ const mlx_map_string_to_array kwargs) {
+ return mlx_imported_function_apply_kwargs_(res, xfunc, args, kwargs);
+}
+
+static inline mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void) {
+ return mlx_fast_cuda_kernel_config_new_();
+}
+
+static inline void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls) {
+ mlx_fast_cuda_kernel_config_free_(cls);
+}
+
+static inline int mlx_fast_cuda_kernel_config_add_output_arg(
+ mlx_fast_cuda_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype) {
+ return mlx_fast_cuda_kernel_config_add_output_arg_(cls, shape, size, dtype);
+}
+
+static inline int mlx_fast_cuda_kernel_config_set_grid(
+ mlx_fast_cuda_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3) {
+ return mlx_fast_cuda_kernel_config_set_grid_(cls, grid1, grid2, grid3);
+}
+
+static inline int mlx_fast_cuda_kernel_config_set_thread_group(
+ mlx_fast_cuda_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3) {
+ return mlx_fast_cuda_kernel_config_set_thread_group_(cls, thread1, thread2, thread3);
+}
+
+static inline int mlx_fast_cuda_kernel_config_set_init_value(
+ mlx_fast_cuda_kernel_config cls,
+ float value) {
+ return mlx_fast_cuda_kernel_config_set_init_value_(cls, value);
+}
+
+static inline int mlx_fast_cuda_kernel_config_set_verbose(
+ mlx_fast_cuda_kernel_config cls,
+ bool verbose) {
+ return mlx_fast_cuda_kernel_config_set_verbose_(cls, verbose);
+}
+
+static inline int mlx_fast_cuda_kernel_config_add_template_arg_dtype(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype) {
+ return mlx_fast_cuda_kernel_config_add_template_arg_dtype_(cls, name, dtype);
+}
+
+static inline int mlx_fast_cuda_kernel_config_add_template_arg_int(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ int value) {
+ return mlx_fast_cuda_kernel_config_add_template_arg_int_(cls, name, value);
+}
+
+static inline int mlx_fast_cuda_kernel_config_add_template_arg_bool(
+ mlx_fast_cuda_kernel_config cls,
+ const char* name,
+ bool value) {
+ return mlx_fast_cuda_kernel_config_add_template_arg_bool_(cls, name, value);
+}
+
+static inline mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ int shared_memory) {
+ return mlx_fast_cuda_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, shared_memory);
+}
+
+static inline void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls) {
+ mlx_fast_cuda_kernel_free_(cls);
+}
+
+static inline int mlx_fast_cuda_kernel_apply(
+ mlx_vector_array* outputs,
+ mlx_fast_cuda_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_cuda_kernel_config config,
+ const mlx_stream stream) {
+ return mlx_fast_cuda_kernel_apply_(outputs, cls, inputs, config, stream);
+}
+
+static inline int mlx_fast_layer_norm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ const mlx_array bias /* may be null */,
+ float eps,
+ const mlx_stream s) {
+ return mlx_fast_layer_norm_(res, x, weight, bias, eps, s);
+}
+
+static inline mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void) {
+ return mlx_fast_metal_kernel_config_new_();
+}
+
+static inline void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls) {
+ mlx_fast_metal_kernel_config_free_(cls);
+}
+
+static inline int mlx_fast_metal_kernel_config_add_output_arg(
+ mlx_fast_metal_kernel_config cls,
+ const int* shape,
+ size_t size,
+ mlx_dtype dtype) {
+ return mlx_fast_metal_kernel_config_add_output_arg_(cls, shape, size, dtype);
+}
+
+static inline int mlx_fast_metal_kernel_config_set_grid(
+ mlx_fast_metal_kernel_config cls,
+ int grid1,
+ int grid2,
+ int grid3) {
+ return mlx_fast_metal_kernel_config_set_grid_(cls, grid1, grid2, grid3);
+}
+
+static inline int mlx_fast_metal_kernel_config_set_thread_group(
+ mlx_fast_metal_kernel_config cls,
+ int thread1,
+ int thread2,
+ int thread3) {
+ return mlx_fast_metal_kernel_config_set_thread_group_(cls, thread1, thread2, thread3);
+}
+
+static inline int mlx_fast_metal_kernel_config_set_init_value(
+ mlx_fast_metal_kernel_config cls,
+ float value) {
+ return mlx_fast_metal_kernel_config_set_init_value_(cls, value);
+}
+
+static inline int mlx_fast_metal_kernel_config_set_verbose(
+ mlx_fast_metal_kernel_config cls,
+ bool verbose) {
+ return mlx_fast_metal_kernel_config_set_verbose_(cls, verbose);
+}
+
+static inline int mlx_fast_metal_kernel_config_add_template_arg_dtype(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ mlx_dtype dtype) {
+ return mlx_fast_metal_kernel_config_add_template_arg_dtype_(cls, name, dtype);
+}
+
+static inline int mlx_fast_metal_kernel_config_add_template_arg_int(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ int value) {
+ return mlx_fast_metal_kernel_config_add_template_arg_int_(cls, name, value);
+}
+
+static inline int mlx_fast_metal_kernel_config_add_template_arg_bool(
+ mlx_fast_metal_kernel_config cls,
+ const char* name,
+ bool value) {
+ return mlx_fast_metal_kernel_config_add_template_arg_bool_(cls, name, value);
+}
+
+static inline mlx_fast_metal_kernel mlx_fast_metal_kernel_new(
+ const char* name,
+ const mlx_vector_string input_names,
+ const mlx_vector_string output_names,
+ const char* source,
+ const char* header,
+ bool ensure_row_contiguous,
+ bool atomic_outputs) {
+ return mlx_fast_metal_kernel_new_(name, input_names, output_names, source, header, ensure_row_contiguous, atomic_outputs);
+}
+
+static inline void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls) {
+ mlx_fast_metal_kernel_free_(cls);
+}
+
+static inline int mlx_fast_metal_kernel_apply(
+ mlx_vector_array* outputs,
+ mlx_fast_metal_kernel cls,
+ const mlx_vector_array inputs,
+ const mlx_fast_metal_kernel_config config,
+ const mlx_stream stream) {
+ return mlx_fast_metal_kernel_apply_(outputs, cls, inputs, config, stream);
+}
+
+static inline int mlx_fast_rms_norm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array weight /* may be null */,
+ float eps,
+ const mlx_stream s) {
+ return mlx_fast_rms_norm_(res, x, weight, eps, s);
+}
+
+static inline int mlx_fast_rope(
+ mlx_array* res,
+ const mlx_array x,
+ int dims,
+ bool traditional,
+ mlx_optional_float base,
+ float scale,
+ int offset,
+ const mlx_array freqs /* may be null */,
+ const mlx_stream s) {
+ return mlx_fast_rope_(res, x, dims, traditional, base, scale, offset, freqs, s);
+}
+
+static inline int mlx_fast_scaled_dot_product_attention(
+ mlx_array* res,
+ const mlx_array queries,
+ const mlx_array keys,
+ const mlx_array values,
+ float scale,
+ const char* mask_mode,
+ const mlx_array mask_arr /* may be null */,
+ const mlx_array sinks /* may be null */,
+ const mlx_stream s) {
+ return mlx_fast_scaled_dot_product_attention_(res, queries, keys, values, scale, mask_mode, mask_arr, sinks, s);
+}
+
+static inline int mlx_fft_fft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) {
+ return mlx_fft_fft_(res, a, n, axis, s);
+}
+
+static inline int mlx_fft_fft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_fft2_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_fftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_fftn_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_fftshift(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_fftshift_(res, a, axes, axes_num, s);
+}
+
+static inline int mlx_fft_ifft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) {
+ return mlx_fft_ifft_(res, a, n, axis, s);
+}
+
+static inline int mlx_fft_ifft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_ifft2_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_ifftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_ifftn_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_ifftshift(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_ifftshift_(res, a, axes, axes_num, s);
+}
+
+static inline int mlx_fft_irfft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) {
+ return mlx_fft_irfft_(res, a, n, axis, s);
+}
+
+static inline int mlx_fft_irfft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_irfft2_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_irfftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_irfftn_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_rfft(
+ mlx_array* res,
+ const mlx_array a,
+ int n,
+ int axis,
+ const mlx_stream s) {
+ return mlx_fft_rfft_(res, a, n, axis, s);
+}
+
+static inline int mlx_fft_rfft2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_rfft2_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline int mlx_fft_rfftn(
+ mlx_array* res,
+ const mlx_array a,
+ const int* n,
+ size_t n_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_fft_rfftn_(res, a, n, n_num, axes, axes_num, s);
+}
+
+static inline mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable) {
+ return mlx_io_reader_new_(desc, vtable);
+}
+
+static inline int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io) {
+ return mlx_io_reader_descriptor_(desc_, io);
+}
+
+static inline int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io) {
+ return mlx_io_reader_tostring_(str_, io);
+}
+
+static inline int mlx_io_reader_free(mlx_io_reader io) {
+ return mlx_io_reader_free_(io);
+}
+
+static inline mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable) {
+ return mlx_io_writer_new_(desc, vtable);
+}
+
+static inline int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io) {
+ return mlx_io_writer_descriptor_(desc_, io);
+}
+
+static inline int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io) {
+ return mlx_io_writer_tostring_(str_, io);
+}
+
+static inline int mlx_io_writer_free(mlx_io_writer io) {
+ return mlx_io_writer_free_(io);
+}
+
+static inline int mlx_load_reader(
+ mlx_array* res,
+ mlx_io_reader in_stream,
+ const mlx_stream s) {
+ return mlx_load_reader_(res, in_stream, s);
+}
+
+static inline int mlx_load(mlx_array* res, const char* file, const mlx_stream s) {
+ return mlx_load_(res, file, s);
+}
+
+static inline int mlx_load_safetensors_reader(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ mlx_io_reader in_stream,
+ const mlx_stream s) {
+ return mlx_load_safetensors_reader_(res_0, res_1, in_stream, s);
+}
+
+static inline int mlx_load_safetensors(
+ mlx_map_string_to_array* res_0,
+ mlx_map_string_to_string* res_1,
+ const char* file,
+ const mlx_stream s) {
+ return mlx_load_safetensors_(res_0, res_1, file, s);
+}
+
+static inline int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a) {
+ return mlx_save_writer_(out_stream, a);
+}
+
+static inline int mlx_save(const char* file, const mlx_array a) {
+ return mlx_save_(file, a);
+}
+
+static inline int mlx_save_safetensors_writer(
+ mlx_io_writer in_stream,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata) {
+ return mlx_save_safetensors_writer_(in_stream, param, metadata);
+}
+
+static inline int mlx_save_safetensors(
+ const char* file,
+ const mlx_map_string_to_array param,
+ const mlx_map_string_to_string metadata) {
+ return mlx_save_safetensors_(file, param, metadata);
+}
+
+static inline int mlx_linalg_cholesky(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s) {
+ return mlx_linalg_cholesky_(res, a, upper, s);
+}
+
+static inline int mlx_linalg_cholesky_inv(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s) {
+ return mlx_linalg_cholesky_inv_(res, a, upper, s);
+}
+
+static inline int mlx_linalg_cross(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s) {
+ return mlx_linalg_cross_(res, a, b, axis, s);
+}
+
+static inline int mlx_linalg_eig(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s) {
+ return mlx_linalg_eig_(res_0, res_1, a, s);
+}
+
+static inline int mlx_linalg_eigh(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s) {
+ return mlx_linalg_eigh_(res_0, res_1, a, UPLO, s);
+}
+
+static inline int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_linalg_eigvals_(res, a, s);
+}
+
+static inline int mlx_linalg_eigvalsh(
+ mlx_array* res,
+ const mlx_array a,
+ const char* UPLO,
+ const mlx_stream s) {
+ return mlx_linalg_eigvalsh_(res, a, UPLO, s);
+}
+
+static inline int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_linalg_inv_(res, a, s);
+}
+
+static inline int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_linalg_lu_(res, a, s);
+}
+
+static inline int mlx_linalg_lu_factor(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s) {
+ return mlx_linalg_lu_factor_(res_0, res_1, a, s);
+}
+
+static inline int mlx_linalg_norm(
+ mlx_array* res,
+ const mlx_array a,
+ double ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_linalg_norm_(res, a, ord, axis, axis_num, keepdims, s);
+}
+
+static inline int mlx_linalg_norm_matrix(
+ mlx_array* res,
+ const mlx_array a,
+ const char* ord,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_linalg_norm_matrix_(res, a, ord, axis, axis_num, keepdims, s);
+}
+
+static inline int mlx_linalg_norm_l2(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axis /* may be null */,
+ size_t axis_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_linalg_norm_l2_(res, a, axis, axis_num, keepdims, s);
+}
+
+static inline int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_linalg_pinv_(res, a, s);
+}
+
+static inline int mlx_linalg_qr(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array a,
+ const mlx_stream s) {
+ return mlx_linalg_qr_(res_0, res_1, a, s);
+}
+
+static inline int mlx_linalg_solve(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_linalg_solve_(res, a, b, s);
+}
+
+static inline int mlx_linalg_solve_triangular(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool upper,
+ const mlx_stream s) {
+ return mlx_linalg_solve_triangular_(res, a, b, upper, s);
+}
+
+static inline int mlx_linalg_svd(
+ mlx_vector_array* res,
+ const mlx_array a,
+ bool compute_uv,
+ const mlx_stream s) {
+ return mlx_linalg_svd_(res, a, compute_uv, s);
+}
+
+static inline int mlx_linalg_tri_inv(
+ mlx_array* res,
+ const mlx_array a,
+ bool upper,
+ const mlx_stream s) {
+ return mlx_linalg_tri_inv_(res, a, upper, s);
+}
+
+static inline mlx_map_string_to_array mlx_map_string_to_array_new(void) {
+ return mlx_map_string_to_array_new_();
+}
+
+static inline int mlx_map_string_to_array_set(
+ mlx_map_string_to_array* map,
+ const mlx_map_string_to_array src) {
+ return mlx_map_string_to_array_set_(map, src);
+}
+
+static inline int mlx_map_string_to_array_free(mlx_map_string_to_array map) {
+ return mlx_map_string_to_array_free_(map);
+}
+
+static inline int mlx_map_string_to_array_insert(
+ mlx_map_string_to_array map,
+ const char* key,
+ const mlx_array value) {
+ return mlx_map_string_to_array_insert_(map, key, value);
+}
+
+static inline int mlx_map_string_to_array_get(
+ mlx_array* value,
+ const mlx_map_string_to_array map,
+ const char* key) {
+ return mlx_map_string_to_array_get_(value, map, key);
+}
+
+static inline mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new(
+ mlx_map_string_to_array map) {
+ return mlx_map_string_to_array_iterator_new_(map);
+}
+
+static inline int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it) {
+ return mlx_map_string_to_array_iterator_free_(it);
+}
+
+static inline int mlx_map_string_to_array_iterator_next(
+ const char** key,
+ mlx_array* value,
+ mlx_map_string_to_array_iterator it) {
+ return mlx_map_string_to_array_iterator_next_(key, value, it);
+}
+
+static inline mlx_map_string_to_string mlx_map_string_to_string_new(void) {
+ return mlx_map_string_to_string_new_();
+}
+
+static inline int mlx_map_string_to_string_set(
+ mlx_map_string_to_string* map,
+ const mlx_map_string_to_string src) {
+ return mlx_map_string_to_string_set_(map, src);
+}
+
+static inline int mlx_map_string_to_string_free(mlx_map_string_to_string map) {
+ return mlx_map_string_to_string_free_(map);
+}
+
+static inline int mlx_map_string_to_string_insert(
+ mlx_map_string_to_string map,
+ const char* key,
+ const char* value) {
+ return mlx_map_string_to_string_insert_(map, key, value);
+}
+
+static inline int mlx_map_string_to_string_get(
+ const char** value,
+ const mlx_map_string_to_string map,
+ const char* key) {
+ return mlx_map_string_to_string_get_(value, map, key);
+}
+
+static inline mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new(
+ mlx_map_string_to_string map) {
+ return mlx_map_string_to_string_iterator_new_(map);
+}
+
+static inline int mlx_map_string_to_string_iterator_free(
+ mlx_map_string_to_string_iterator it) {
+ return mlx_map_string_to_string_iterator_free_(it);
+}
+
+static inline int mlx_map_string_to_string_iterator_next(
+ const char** key,
+ const char** value,
+ mlx_map_string_to_string_iterator it) {
+ return mlx_map_string_to_string_iterator_next_(key, value, it);
+}
+
+static inline int mlx_clear_cache(void) {
+ return mlx_clear_cache_();
+}
+
+static inline int mlx_get_active_memory(size_t* res) {
+ return mlx_get_active_memory_(res);
+}
+
+static inline int mlx_get_cache_memory(size_t* res) {
+ return mlx_get_cache_memory_(res);
+}
+
+static inline int mlx_get_memory_limit(size_t* res) {
+ return mlx_get_memory_limit_(res);
+}
+
+static inline int mlx_get_peak_memory(size_t* res) {
+ return mlx_get_peak_memory_(res);
+}
+
+static inline int mlx_reset_peak_memory(void) {
+ return mlx_reset_peak_memory_();
+}
+
+static inline int mlx_set_cache_limit(size_t* res, size_t limit) {
+ return mlx_set_cache_limit_(res, limit);
+}
+
+static inline int mlx_set_memory_limit(size_t* res, size_t limit) {
+ return mlx_set_memory_limit_(res, limit);
+}
+
+static inline int mlx_set_wired_limit(size_t* res, size_t limit) {
+ return mlx_set_wired_limit_(res, limit);
+}
+
+static inline mlx_metal_device_info_t mlx_metal_device_info(void) {
+ return mlx_metal_device_info_();
+}
+
+static inline int mlx_metal_is_available(bool* res) {
+ return mlx_metal_is_available_(res);
+}
+
+static inline int mlx_metal_start_capture(const char* path) {
+ return mlx_metal_start_capture_(path);
+}
+
+static inline int mlx_metal_stop_capture(void) {
+ return mlx_metal_stop_capture_();
+}
+
+static inline int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_abs_(res, a, s);
+}
+
+static inline int mlx_add(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_add_(res, a, b, s);
+}
+
+static inline int mlx_addmm(
+ mlx_array* res,
+ const mlx_array c,
+ const mlx_array a,
+ const mlx_array b,
+ float alpha,
+ float beta,
+ const mlx_stream s) {
+ return mlx_addmm_(res, c, a, b, alpha, beta, s);
+}
+
+static inline int mlx_all_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_all_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_all_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_all_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_all(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_all_(res, a, keepdims, s);
+}
+
+static inline int mlx_allclose(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s) {
+ return mlx_allclose_(res, a, b, rtol, atol, equal_nan, s);
+}
+
+static inline int mlx_any_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_any_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_any_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_any_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_any(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_any_(res, a, keepdims, s);
+}
+
+static inline int mlx_arange(
+ mlx_array* res,
+ double start,
+ double stop,
+ double step,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_arange_(res, start, stop, step, dtype, s);
+}
+
+static inline int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_arccos_(res, a, s);
+}
+
+static inline int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_arccosh_(res, a, s);
+}
+
+static inline int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_arcsin_(res, a, s);
+}
+
+static inline int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_arcsinh_(res, a, s);
+}
+
+static inline int mlx_arctan(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_arctan_(res, a, s);
+}
+
+static inline int mlx_arctan2(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_arctan2_(res, a, b, s);
+}
+
+static inline int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_arctanh_(res, a, s);
+}
+
+static inline int mlx_argmax_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_argmax_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_argmax(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_argmax_(res, a, keepdims, s);
+}
+
+static inline int mlx_argmin_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_argmin_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_argmin(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_argmin_(res, a, keepdims, s);
+}
+
+static inline int mlx_argpartition_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s) {
+ return mlx_argpartition_axis_(res, a, kth, axis, s);
+}
+
+static inline int mlx_argpartition(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s) {
+ return mlx_argpartition_(res, a, kth, s);
+}
+
+static inline int mlx_argsort_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) {
+ return mlx_argsort_axis_(res, a, axis, s);
+}
+
+static inline int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_argsort_(res, a, s);
+}
+
+static inline int mlx_array_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ bool equal_nan,
+ const mlx_stream s) {
+ return mlx_array_equal_(res, a, b, equal_nan, s);
+}
+
+static inline int mlx_as_strided(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const int64_t* strides,
+ size_t strides_num,
+ size_t offset,
+ const mlx_stream s) {
+ return mlx_as_strided_(res, a, shape, shape_num, strides, strides_num, offset, s);
+}
+
+static inline int mlx_astype(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_astype_(res, a, dtype, s);
+}
+
+static inline int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_atleast_1d_(res, a, s);
+}
+
+static inline int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_atleast_2d_(res, a, s);
+}
+
+static inline int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_atleast_3d_(res, a, s);
+}
+
+static inline int mlx_bitwise_and(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_bitwise_and_(res, a, b, s);
+}
+
+static inline int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_bitwise_invert_(res, a, s);
+}
+
+static inline int mlx_bitwise_or(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_bitwise_or_(res, a, b, s);
+}
+
+static inline int mlx_bitwise_xor(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_bitwise_xor_(res, a, b, s);
+}
+
+static inline int mlx_block_masked_mm(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int block_size,
+ const mlx_array mask_out /* may be null */,
+ const mlx_array mask_lhs /* may be null */,
+ const mlx_array mask_rhs /* may be null */,
+ const mlx_stream s) {
+ return mlx_block_masked_mm_(res, a, b, block_size, mask_out, mask_lhs, mask_rhs, s);
+}
+
+static inline int mlx_broadcast_arrays(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_stream s) {
+ return mlx_broadcast_arrays_(res, inputs, s);
+}
+
+static inline int mlx_broadcast_to(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s) {
+ return mlx_broadcast_to_(res, a, shape, shape_num, s);
+}
+
+static inline int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_ceil_(res, a, s);
+}
+
+static inline int mlx_clip(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array a_min /* may be null */,
+ const mlx_array a_max /* may be null */,
+ const mlx_stream s) {
+ return mlx_clip_(res, a, a_min, a_max, s);
+}
+
+static inline int mlx_concatenate_axis(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s) {
+ return mlx_concatenate_axis_(res, arrays, axis, s);
+}
+
+static inline int mlx_concatenate(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s) {
+ return mlx_concatenate_(res, arrays, s);
+}
+
+static inline int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_conjugate_(res, a, s);
+}
+
+static inline int mlx_contiguous(
+ mlx_array* res,
+ const mlx_array a,
+ bool allow_col_major,
+ const mlx_stream s) {
+ return mlx_contiguous_(res, a, allow_col_major, s);
+}
+
+static inline int mlx_conv1d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int groups,
+ const mlx_stream s) {
+ return mlx_conv1d_(res, input, weight, stride, padding, dilation, groups, s);
+}
+
+static inline int mlx_conv2d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int groups,
+ const mlx_stream s) {
+ return mlx_conv2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, groups, s);
+}
+
+static inline int mlx_conv3d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int groups,
+ const mlx_stream s) {
+ return mlx_conv3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, groups, s);
+}
+
+static inline int mlx_conv_general(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ const int* stride,
+ size_t stride_num,
+ const int* padding_lo,
+ size_t padding_lo_num,
+ const int* padding_hi,
+ size_t padding_hi_num,
+ const int* kernel_dilation,
+ size_t kernel_dilation_num,
+ const int* input_dilation,
+ size_t input_dilation_num,
+ int groups,
+ bool flip,
+ const mlx_stream s) {
+ return mlx_conv_general_(res, input, weight, stride, stride_num, padding_lo, padding_lo_num, padding_hi, padding_hi_num, kernel_dilation, kernel_dilation_num, input_dilation, input_dilation_num, groups, flip, s);
+}
+
+static inline int mlx_conv_transpose1d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride,
+ int padding,
+ int dilation,
+ int output_padding,
+ int groups,
+ const mlx_stream s) {
+ return mlx_conv_transpose1d_(res, input, weight, stride, padding, dilation, output_padding, groups, s);
+}
+
+static inline int mlx_conv_transpose2d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int padding_0,
+ int padding_1,
+ int dilation_0,
+ int dilation_1,
+ int output_padding_0,
+ int output_padding_1,
+ int groups,
+ const mlx_stream s) {
+ return mlx_conv_transpose2d_(res, input, weight, stride_0, stride_1, padding_0, padding_1, dilation_0, dilation_1, output_padding_0, output_padding_1, groups, s);
+}
+
+static inline int mlx_conv_transpose3d(
+ mlx_array* res,
+ const mlx_array input,
+ const mlx_array weight,
+ int stride_0,
+ int stride_1,
+ int stride_2,
+ int padding_0,
+ int padding_1,
+ int padding_2,
+ int dilation_0,
+ int dilation_1,
+ int dilation_2,
+ int output_padding_0,
+ int output_padding_1,
+ int output_padding_2,
+ int groups,
+ const mlx_stream s) {
+ return mlx_conv_transpose3d_(res, input, weight, stride_0, stride_1, stride_2, padding_0, padding_1, padding_2, dilation_0, dilation_1, dilation_2, output_padding_0, output_padding_1, output_padding_2, groups, s);
+}
+
+static inline int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_copy_(res, a, s);
+}
+
+static inline int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_cos_(res, a, s);
+}
+
+static inline int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_cosh_(res, a, s);
+}
+
+static inline int mlx_cummax(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) {
+ return mlx_cummax_(res, a, axis, reverse, inclusive, s);
+}
+
+static inline int mlx_cummin(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) {
+ return mlx_cummin_(res, a, axis, reverse, inclusive, s);
+}
+
+static inline int mlx_cumprod(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) {
+ return mlx_cumprod_(res, a, axis, reverse, inclusive, s);
+}
+
+static inline int mlx_cumsum(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) {
+ return mlx_cumsum_(res, a, axis, reverse, inclusive, s);
+}
+
+static inline int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_degrees_(res, a, s);
+}
+
+static inline int mlx_depends(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array dependencies) {
+ return mlx_depends_(res, inputs, dependencies);
+}
+
+static inline int mlx_dequantize(
+ mlx_array* res,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ mlx_optional_dtype dtype,
+ const mlx_stream s) {
+ return mlx_dequantize_(res, w, scales, biases, group_size, bits, mode, dtype, s);
+}
+
+static inline int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
+ return mlx_diag_(res, a, k, s);
+}
+
+static inline int mlx_diagonal(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ const mlx_stream s) {
+ return mlx_diagonal_(res, a, offset, axis1, axis2, s);
+}
+
+static inline int mlx_divide(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_divide_(res, a, b, s);
+}
+
+static inline int mlx_divmod(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_divmod_(res, a, b, s);
+}
+
+static inline int mlx_einsum(
+ mlx_array* res,
+ const char* subscripts,
+ const mlx_vector_array operands,
+ const mlx_stream s) {
+ return mlx_einsum_(res, subscripts, operands, s);
+}
+
+static inline int mlx_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_equal_(res, a, b, s);
+}
+
+static inline int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_erf_(res, a, s);
+}
+
+static inline int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_erfinv_(res, a, s);
+}
+
+static inline int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_exp_(res, a, s);
+}
+
+static inline int mlx_expand_dims_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_expand_dims_axes_(res, a, axes, axes_num, s);
+}
+
+static inline int mlx_expand_dims(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) {
+ return mlx_expand_dims_(res, a, axis, s);
+}
+
+static inline int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_expm1_(res, a, s);
+}
+
+static inline int mlx_eye(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_eye_(res, n, m, k, dtype, s);
+}
+
+static inline int mlx_flatten(
+ mlx_array* res,
+ const mlx_array a,
+ int start_axis,
+ int end_axis,
+ const mlx_stream s) {
+ return mlx_flatten_(res, a, start_axis, end_axis, s);
+}
+
+static inline int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_floor_(res, a, s);
+}
+
+static inline int mlx_floor_divide(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_floor_divide_(res, a, b, s);
+}
+
+static inline int mlx_from_fp8(
+ mlx_array* res,
+ const mlx_array x,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_from_fp8_(res, x, dtype, s);
+}
+
+static inline int mlx_full(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_full_(res, shape, shape_num, vals, dtype, s);
+}
+
+static inline int mlx_full_like(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array vals,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_full_like_(res, a, vals, dtype, s);
+}
+
+static inline int mlx_gather(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_sizes,
+ size_t slice_sizes_num,
+ const mlx_stream s) {
+ return mlx_gather_(res, a, indices, axes, axes_num, slice_sizes, slice_sizes_num, s);
+}
+
+static inline int mlx_gather_mm(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool sorted_indices,
+ const mlx_stream s) {
+ return mlx_gather_mm_(res, a, b, lhs_indices, rhs_indices, sorted_indices, s);
+}
+
+static inline int mlx_gather_qmm(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ const mlx_array lhs_indices /* may be null */,
+ const mlx_array rhs_indices /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ bool sorted_indices,
+ const mlx_stream s) {
+ return mlx_gather_qmm_(res, x, w, scales, biases, lhs_indices, rhs_indices, transpose, group_size, bits, mode, sorted_indices, s);
+}
+
+static inline int mlx_greater(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_greater_(res, a, b, s);
+}
+
+static inline int mlx_greater_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_greater_equal_(res, a, b, s);
+}
+
+static inline int mlx_hadamard_transform(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_optional_float scale,
+ const mlx_stream s) {
+ return mlx_hadamard_transform_(res, a, scale, s);
+}
+
+static inline int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s) {
+ return mlx_identity_(res, n, dtype, s);
+}
+
+static inline int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_imag_(res, a, s);
+}
+
+static inline int mlx_inner(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_inner_(res, a, b, s);
+}
+
+static inline int mlx_isclose(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ double rtol,
+ double atol,
+ bool equal_nan,
+ const mlx_stream s) {
+ return mlx_isclose_(res, a, b, rtol, atol, equal_nan, s);
+}
+
+static inline int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_isfinite_(res, a, s);
+}
+
+static inline int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_isinf_(res, a, s);
+}
+
+static inline int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_isnan_(res, a, s);
+}
+
+static inline int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_isneginf_(res, a, s);
+}
+
+static inline int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_isposinf_(res, a, s);
+}
+
+static inline int mlx_kron(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_kron_(res, a, b, s);
+}
+
+static inline int mlx_left_shift(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_left_shift_(res, a, b, s);
+}
+
+static inline int mlx_less(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_less_(res, a, b, s);
+}
+
+static inline int mlx_less_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_less_equal_(res, a, b, s);
+}
+
+static inline int mlx_linspace(
+ mlx_array* res,
+ double start,
+ double stop,
+ int num,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_linspace_(res, start, stop, num, dtype, s);
+}
+
+static inline int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_log_(res, a, s);
+}
+
+static inline int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_log10_(res, a, s);
+}
+
+static inline int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_log1p_(res, a, s);
+}
+
+static inline int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_log2_(res, a, s);
+}
+
+static inline int mlx_logaddexp(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_logaddexp_(res, a, b, s);
+}
+
+static inline int mlx_logcumsumexp(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool reverse,
+ bool inclusive,
+ const mlx_stream s) {
+ return mlx_logcumsumexp_(res, a, axis, reverse, inclusive, s);
+}
+
+static inline int mlx_logical_and(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_logical_and_(res, a, b, s);
+}
+
+static inline int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_logical_not_(res, a, s);
+}
+
+static inline int mlx_logical_or(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_logical_or_(res, a, b, s);
+}
+
+static inline int mlx_logsumexp_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_logsumexp_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_logsumexp_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_logsumexp_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_logsumexp(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_logsumexp_(res, a, keepdims, s);
+}
+
+static inline int mlx_masked_scatter(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array mask,
+ const mlx_array src,
+ const mlx_stream s) {
+ return mlx_masked_scatter_(res, a, mask, src, s);
+}
+
+static inline int mlx_matmul(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_matmul_(res, a, b, s);
+}
+
+static inline int mlx_max_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_max_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_max_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_max_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_max(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_max_(res, a, keepdims, s);
+}
+
+static inline int mlx_maximum(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_maximum_(res, a, b, s);
+}
+
+static inline int mlx_mean_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_mean_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_mean_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_mean_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_mean(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_mean_(res, a, keepdims, s);
+}
+
+static inline int mlx_median(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_median_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_meshgrid(
+ mlx_vector_array* res,
+ const mlx_vector_array arrays,
+ bool sparse,
+ const char* indexing,
+ const mlx_stream s) {
+ return mlx_meshgrid_(res, arrays, sparse, indexing, s);
+}
+
+static inline int mlx_min_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_min_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_min_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_min_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_min(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_min_(res, a, keepdims, s);
+}
+
+static inline int mlx_minimum(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_minimum_(res, a, b, s);
+}
+
+static inline int mlx_moveaxis(
+ mlx_array* res,
+ const mlx_array a,
+ int source,
+ int destination,
+ const mlx_stream s) {
+ return mlx_moveaxis_(res, a, source, destination, s);
+}
+
+static inline int mlx_multiply(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_multiply_(res, a, b, s);
+}
+
+static inline int mlx_nan_to_num(
+ mlx_array* res,
+ const mlx_array a,
+ float nan,
+ mlx_optional_float posinf,
+ mlx_optional_float neginf,
+ const mlx_stream s) {
+ return mlx_nan_to_num_(res, a, nan, posinf, neginf, s);
+}
+
+static inline int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_negative_(res, a, s);
+}
+
+static inline int mlx_not_equal(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_not_equal_(res, a, b, s);
+}
+
+static inline int mlx_number_of_elements(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool inverted,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_number_of_elements_(res, a, axes, axes_num, inverted, dtype, s);
+}
+
+static inline int mlx_ones(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_ones_(res, shape, shape_num, dtype, s);
+}
+
+static inline int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_ones_like_(res, a, s);
+}
+
+static inline int mlx_outer(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_outer_(res, a, b, s);
+}
+
+static inline int mlx_pad(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const int* low_pad_size,
+ size_t low_pad_size_num,
+ const int* high_pad_size,
+ size_t high_pad_size_num,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s) {
+ return mlx_pad_(res, a, axes, axes_num, low_pad_size, low_pad_size_num, high_pad_size, high_pad_size_num, pad_value, mode, s);
+}
+
+static inline int mlx_pad_symmetric(
+ mlx_array* res,
+ const mlx_array a,
+ int pad_width,
+ const mlx_array pad_value,
+ const char* mode,
+ const mlx_stream s) {
+ return mlx_pad_symmetric_(res, a, pad_width, pad_value, mode, s);
+}
+
+static inline int mlx_partition_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ int axis,
+ const mlx_stream s) {
+ return mlx_partition_axis_(res, a, kth, axis, s);
+}
+
+static inline int mlx_partition(
+ mlx_array* res,
+ const mlx_array a,
+ int kth,
+ const mlx_stream s) {
+ return mlx_partition_(res, a, kth, s);
+}
+
+static inline int mlx_power(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_power_(res, a, b, s);
+}
+
+static inline int mlx_prod_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_prod_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_prod_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_prod_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_prod(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_prod_(res, a, keepdims, s);
+}
+
+static inline int mlx_put_along_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s) {
+ return mlx_put_along_axis_(res, a, indices, values, axis, s);
+}
+
+static inline int mlx_quantize(
+ mlx_vector_array* res,
+ const mlx_array w,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s) {
+ return mlx_quantize_(res, w, group_size, bits, mode, s);
+}
+
+static inline int mlx_quantized_matmul(
+ mlx_array* res,
+ const mlx_array x,
+ const mlx_array w,
+ const mlx_array scales,
+ const mlx_array biases /* may be null */,
+ bool transpose,
+ mlx_optional_int group_size,
+ mlx_optional_int bits,
+ const char* mode,
+ const mlx_stream s) {
+ return mlx_quantized_matmul_(res, x, w, scales, biases, transpose, group_size, bits, mode, s);
+}
+
+static inline int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_radians_(res, a, s);
+}
+
+static inline int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_real_(res, a, s);
+}
+
+static inline int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_reciprocal_(res, a, s);
+}
+
+static inline int mlx_remainder(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_remainder_(res, a, b, s);
+}
+
+static inline int mlx_repeat_axis(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ int axis,
+ const mlx_stream s) {
+ return mlx_repeat_axis_(res, arr, repeats, axis, s);
+}
+
+static inline int mlx_repeat(
+ mlx_array* res,
+ const mlx_array arr,
+ int repeats,
+ const mlx_stream s) {
+ return mlx_repeat_(res, arr, repeats, s);
+}
+
+static inline int mlx_reshape(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s) {
+ return mlx_reshape_(res, a, shape, shape_num, s);
+}
+
+static inline int mlx_right_shift(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_right_shift_(res, a, b, s);
+}
+
+static inline int mlx_roll_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ int axis,
+ const mlx_stream s) {
+ return mlx_roll_axis_(res, a, shift, shift_num, axis, s);
+}
+
+static inline int mlx_roll_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_roll_axes_(res, a, shift, shift_num, axes, axes_num, s);
+}
+
+static inline int mlx_roll(
+ mlx_array* res,
+ const mlx_array a,
+ const int* shift,
+ size_t shift_num,
+ const mlx_stream s) {
+ return mlx_roll_(res, a, shift, shift_num, s);
+}
+
+static inline int mlx_round(
+ mlx_array* res,
+ const mlx_array a,
+ int decimals,
+ const mlx_stream s) {
+ return mlx_round_(res, a, decimals, s);
+}
+
+static inline int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_rsqrt_(res, a, s);
+}
+
+static inline int mlx_scatter(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_scatter_(res, a, indices, updates, axes, axes_num, s);
+}
+
+static inline int mlx_scatter_add(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_scatter_add_(res, a, indices, updates, axes, axes_num, s);
+}
+
+static inline int mlx_scatter_add_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_array values,
+ int axis,
+ const mlx_stream s) {
+ return mlx_scatter_add_axis_(res, a, indices, values, axis, s);
+}
+
+static inline int mlx_scatter_max(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_scatter_max_(res, a, indices, updates, axes, axes_num, s);
+}
+
+static inline int mlx_scatter_min(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_scatter_min_(res, a, indices, updates, axes, axes_num, s);
+}
+
+static inline int mlx_scatter_prod(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_vector_array indices,
+ const mlx_array updates,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_scatter_prod_(res, a, indices, updates, axes, axes_num, s);
+}
+
+static inline int mlx_segmented_mm(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_array segments,
+ const mlx_stream s) {
+ return mlx_segmented_mm_(res, a, b, segments, s);
+}
+
+static inline int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_sigmoid_(res, a, s);
+}
+
+static inline int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_sign_(res, a, s);
+}
+
+static inline int mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_sin_(res, a, s);
+}
+
+static inline int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_sinh_(res, a, s);
+}
+
+static inline int mlx_slice(
+ mlx_array* res,
+ const mlx_array a,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s) {
+ return mlx_slice_(res, a, start, start_num, stop, stop_num, strides, strides_num, s);
+}
+
+static inline int mlx_slice_dynamic(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const int* slice_size,
+ size_t slice_size_num,
+ const mlx_stream s) {
+ return mlx_slice_dynamic_(res, a, start, axes, axes_num, slice_size, slice_size_num, s);
+}
+
+static inline int mlx_slice_update(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const int* start,
+ size_t start_num,
+ const int* stop,
+ size_t stop_num,
+ const int* strides,
+ size_t strides_num,
+ const mlx_stream s) {
+ return mlx_slice_update_(res, src, update, start, start_num, stop, stop_num, strides, strides_num, s);
+}
+
+static inline int mlx_slice_update_dynamic(
+ mlx_array* res,
+ const mlx_array src,
+ const mlx_array update,
+ const mlx_array start,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_slice_update_dynamic_(res, src, update, start, axes, axes_num, s);
+}
+
+static inline int mlx_softmax_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool precise,
+ const mlx_stream s) {
+ return mlx_softmax_axes_(res, a, axes, axes_num, precise, s);
+}
+
+static inline int mlx_softmax_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool precise,
+ const mlx_stream s) {
+ return mlx_softmax_axis_(res, a, axis, precise, s);
+}
+
+static inline int mlx_softmax(
+ mlx_array* res,
+ const mlx_array a,
+ bool precise,
+ const mlx_stream s) {
+ return mlx_softmax_(res, a, precise, s);
+}
+
+static inline int mlx_sort_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) {
+ return mlx_sort_axis_(res, a, axis, s);
+}
+
+static inline int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_sort_(res, a, s);
+}
+
+static inline int mlx_split(
+ mlx_vector_array* res,
+ const mlx_array a,
+ int num_splits,
+ int axis,
+ const mlx_stream s) {
+ return mlx_split_(res, a, num_splits, axis, s);
+}
+
+static inline int mlx_split_sections(
+ mlx_vector_array* res,
+ const mlx_array a,
+ const int* indices,
+ size_t indices_num,
+ int axis,
+ const mlx_stream s) {
+ return mlx_split_sections_(res, a, indices, indices_num, axis, s);
+}
+
+static inline int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_sqrt_(res, a, s);
+}
+
+static inline int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_square_(res, a, s);
+}
+
+static inline int mlx_squeeze_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_squeeze_axes_(res, a, axes, axes_num, s);
+}
+
+static inline int mlx_squeeze_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const mlx_stream s) {
+ return mlx_squeeze_axis_(res, a, axis, s);
+}
+
+static inline int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_squeeze_(res, a, s);
+}
+
+static inline int mlx_stack_axis(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ int axis,
+ const mlx_stream s) {
+ return mlx_stack_axis_(res, arrays, axis, s);
+}
+
+static inline int mlx_stack(
+ mlx_array* res,
+ const mlx_vector_array arrays,
+ const mlx_stream s) {
+ return mlx_stack_(res, arrays, s);
+}
+
+static inline int mlx_std_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) {
+ return mlx_std_axes_(res, a, axes, axes_num, keepdims, ddof, s);
+}
+
+static inline int mlx_std_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) {
+ return mlx_std_axis_(res, a, axis, keepdims, ddof, s);
+}
+
+static inline int mlx_std(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) {
+ return mlx_std_(res, a, keepdims, ddof, s);
+}
+
+static inline int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_stop_gradient_(res, a, s);
+}
+
+static inline int mlx_subtract(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const mlx_stream s) {
+ return mlx_subtract_(res, a, b, s);
+}
+
+static inline int mlx_sum_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_sum_axes_(res, a, axes, axes_num, keepdims, s);
+}
+
+static inline int mlx_sum_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_sum_axis_(res, a, axis, keepdims, s);
+}
+
+static inline int mlx_sum(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ const mlx_stream s) {
+ return mlx_sum_(res, a, keepdims, s);
+}
+
+static inline int mlx_swapaxes(
+ mlx_array* res,
+ const mlx_array a,
+ int axis1,
+ int axis2,
+ const mlx_stream s) {
+ return mlx_swapaxes_(res, a, axis1, axis2, s);
+}
+
+static inline int mlx_take_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s) {
+ return mlx_take_axis_(res, a, indices, axis, s);
+}
+
+static inline int mlx_take(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ const mlx_stream s) {
+ return mlx_take_(res, a, indices, s);
+}
+
+static inline int mlx_take_along_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array indices,
+ int axis,
+ const mlx_stream s) {
+ return mlx_take_along_axis_(res, a, indices, axis, s);
+}
+
+static inline int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_tan_(res, a, s);
+}
+
+static inline int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_tanh_(res, a, s);
+}
+
+static inline int mlx_tensordot(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ const int* axes_a,
+ size_t axes_a_num,
+ const int* axes_b,
+ size_t axes_b_num,
+ const mlx_stream s) {
+ return mlx_tensordot_(res, a, b, axes_a, axes_a_num, axes_b, axes_b_num, s);
+}
+
+static inline int mlx_tensordot_axis(
+ mlx_array* res,
+ const mlx_array a,
+ const mlx_array b,
+ int axis,
+ const mlx_stream s) {
+ return mlx_tensordot_axis_(res, a, b, axis, s);
+}
+
+static inline int mlx_tile(
+ mlx_array* res,
+ const mlx_array arr,
+ const int* reps,
+ size_t reps_num,
+ const mlx_stream s) {
+ return mlx_tile_(res, arr, reps, reps_num, s);
+}
+
+static inline int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s) {
+ return mlx_to_fp8_(res, x, s);
+}
+
+static inline int mlx_topk_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int k,
+ int axis,
+ const mlx_stream s) {
+ return mlx_topk_axis_(res, a, k, axis, s);
+}
+
+static inline int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s) {
+ return mlx_topk_(res, a, k, s);
+}
+
+static inline int mlx_trace(
+ mlx_array* res,
+ const mlx_array a,
+ int offset,
+ int axis1,
+ int axis2,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_trace_(res, a, offset, axis1, axis2, dtype, s);
+}
+
+static inline int mlx_transpose_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ const mlx_stream s) {
+ return mlx_transpose_axes_(res, a, axes, axes_num, s);
+}
+
+static inline int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_transpose_(res, a, s);
+}
+
+static inline int mlx_tri(
+ mlx_array* res,
+ int n,
+ int m,
+ int k,
+ mlx_dtype type,
+ const mlx_stream s) {
+ return mlx_tri_(res, n, m, k, type, s);
+}
+
+static inline int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s) {
+ return mlx_tril_(res, x, k, s);
+}
+
+static inline int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s) {
+ return mlx_triu_(res, x, k, s);
+}
+
+static inline int mlx_unflatten(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_stream s) {
+ return mlx_unflatten_(res, a, axis, shape, shape_num, s);
+}
+
+static inline int mlx_var_axes(
+ mlx_array* res,
+ const mlx_array a,
+ const int* axes,
+ size_t axes_num,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) {
+ return mlx_var_axes_(res, a, axes, axes_num, keepdims, ddof, s);
+}
+
+static inline int mlx_var_axis(
+ mlx_array* res,
+ const mlx_array a,
+ int axis,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) {
+ return mlx_var_axis_(res, a, axis, keepdims, ddof, s);
+}
+
+static inline int mlx_var(
+ mlx_array* res,
+ const mlx_array a,
+ bool keepdims,
+ int ddof,
+ const mlx_stream s) {
+ return mlx_var_(res, a, keepdims, ddof, s);
+}
+
+static inline int mlx_view(
+ mlx_array* res,
+ const mlx_array a,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_view_(res, a, dtype, s);
+}
+
+static inline int mlx_where(
+ mlx_array* res,
+ const mlx_array condition,
+ const mlx_array x,
+ const mlx_array y,
+ const mlx_stream s) {
+ return mlx_where_(res, condition, x, y, s);
+}
+
+static inline int mlx_zeros(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_stream s) {
+ return mlx_zeros_(res, shape, shape_num, dtype, s);
+}
+
+static inline int mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s) {
+ return mlx_zeros_like_(res, a, s);
+}
+
+static inline int mlx_random_bernoulli(
+ mlx_array* res,
+ const mlx_array p,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_bernoulli_(res, p, shape, shape_num, key, s);
+}
+
+static inline int mlx_random_bits(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ int width,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_bits_(res, shape, shape_num, width, key, s);
+}
+
+static inline int mlx_random_categorical_shape(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const int* shape,
+ size_t shape_num,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_categorical_shape_(res, logits, axis, shape, shape_num, key, s);
+}
+
+static inline int mlx_random_categorical_num_samples(
+ mlx_array* res,
+ const mlx_array logits_,
+ int axis,
+ int num_samples,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_categorical_num_samples_(res, logits_, axis, num_samples, key, s);
+}
+
+static inline int mlx_random_categorical(
+ mlx_array* res,
+ const mlx_array logits,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_categorical_(res, logits, axis, key, s);
+}
+
+static inline int mlx_random_gumbel(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_gumbel_(res, shape, shape_num, dtype, key, s);
+}
+
+static inline int mlx_random_key(mlx_array* res, uint64_t seed) {
+ return mlx_random_key_(res, seed);
+}
+
+static inline int mlx_random_laplace(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_laplace_(res, shape, shape_num, dtype, loc, scale, key, s);
+}
+
+static inline int mlx_random_multivariate_normal(
+ mlx_array* res,
+ const mlx_array mean,
+ const mlx_array cov,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_multivariate_normal_(res, mean, cov, shape, shape_num, dtype, key, s);
+}
+
+static inline int mlx_random_normal_broadcast(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array loc /* may be null */,
+ const mlx_array scale /* may be null */,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_normal_broadcast_(res, shape, shape_num, dtype, loc, scale, key, s);
+}
+
+static inline int mlx_random_normal(
+ mlx_array* res,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ float loc,
+ float scale,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_normal_(res, shape, shape_num, dtype, loc, scale, key, s);
+}
+
+static inline int mlx_random_permutation(
+ mlx_array* res,
+ const mlx_array x,
+ int axis,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_permutation_(res, x, axis, key, s);
+}
+
+static inline int mlx_random_permutation_arange(
+ mlx_array* res,
+ int x,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_permutation_arange_(res, x, key, s);
+}
+
+static inline int mlx_random_randint(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_randint_(res, low, high, shape, shape_num, dtype, key, s);
+}
+
+static inline int mlx_random_seed(uint64_t seed) {
+ return mlx_random_seed_(seed);
+}
+
+static inline int mlx_random_split_num(
+ mlx_array* res,
+ const mlx_array key,
+ int num,
+ const mlx_stream s) {
+ return mlx_random_split_num_(res, key, num, s);
+}
+
+static inline int mlx_random_split(
+ mlx_array* res_0,
+ mlx_array* res_1,
+ const mlx_array key,
+ const mlx_stream s) {
+ return mlx_random_split_(res_0, res_1, key, s);
+}
+
+static inline int mlx_random_truncated_normal(
+ mlx_array* res,
+ const mlx_array lower,
+ const mlx_array upper,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_truncated_normal_(res, lower, upper, shape, shape_num, dtype, key, s);
+}
+
+static inline int mlx_random_uniform(
+ mlx_array* res,
+ const mlx_array low,
+ const mlx_array high,
+ const int* shape,
+ size_t shape_num,
+ mlx_dtype dtype,
+ const mlx_array key /* may be null */,
+ const mlx_stream s) {
+ return mlx_random_uniform_(res, low, high, shape, shape_num, dtype, key, s);
+}
+
+static inline mlx_stream mlx_stream_new(void) {
+ return mlx_stream_new_();
+}
+
+static inline mlx_stream mlx_stream_new_device(mlx_device dev) {
+ return mlx_stream_new_device_(dev);
+}
+
+static inline int mlx_stream_set(mlx_stream* stream, const mlx_stream src) {
+ return mlx_stream_set_(stream, src);
+}
+
+static inline int mlx_stream_free(mlx_stream stream) {
+ return mlx_stream_free_(stream);
+}
+
+static inline int mlx_stream_tostring(mlx_string* str, mlx_stream stream) {
+ return mlx_stream_tostring_(str, stream);
+}
+
+static inline bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs) {
+ return mlx_stream_equal_(lhs, rhs);
+}
+
+static inline int mlx_stream_get_device(mlx_device* dev, mlx_stream stream) {
+ return mlx_stream_get_device_(dev, stream);
+}
+
+static inline int mlx_stream_get_index(int* index, mlx_stream stream) {
+ return mlx_stream_get_index_(index, stream);
+}
+
+static inline int mlx_synchronize(mlx_stream stream) {
+ return mlx_synchronize_(stream);
+}
+
+static inline int mlx_get_default_stream(mlx_stream* stream, mlx_device dev) {
+ return mlx_get_default_stream_(stream, dev);
+}
+
+static inline int mlx_set_default_stream(mlx_stream stream) {
+ return mlx_set_default_stream_(stream);
+}
+
+static inline mlx_stream mlx_default_cpu_stream_new(void) {
+ return mlx_default_cpu_stream_new_();
+}
+
+static inline mlx_stream mlx_default_gpu_stream_new(void) {
+ return mlx_default_gpu_stream_new_();
+}
+
+static inline mlx_string mlx_string_new(void) {
+ return mlx_string_new_();
+}
+
+static inline mlx_string mlx_string_new_data(const char* str) {
+ return mlx_string_new_data_(str);
+}
+
+static inline int mlx_string_set(mlx_string* str, const mlx_string src) {
+ return mlx_string_set_(str, src);
+}
+
+static inline const char * mlx_string_data(mlx_string str) {
+ return mlx_string_data_(str);
+}
+
+static inline int mlx_string_free(mlx_string str) {
+ return mlx_string_free_(str);
+}
+
+static inline int mlx_detail_vmap_replace(
+ mlx_vector_array* res,
+ const mlx_vector_array inputs,
+ const mlx_vector_array s_inputs,
+ const mlx_vector_array s_outputs,
+ const int* in_axes,
+ size_t in_axes_num,
+ const int* out_axes,
+ size_t out_axes_num) {
+ return mlx_detail_vmap_replace_(res, inputs, s_inputs, s_outputs, in_axes, in_axes_num, out_axes, out_axes_num);
+}
+
+static inline int mlx_detail_vmap_trace(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array inputs,
+ const int* in_axes,
+ size_t in_axes_num) {
+ return mlx_detail_vmap_trace_(res_0, res_1, fun, inputs, in_axes, in_axes_num);
+}
+
+static inline int mlx_async_eval(const mlx_vector_array outputs) {
+ return mlx_async_eval_(outputs);
+}
+
+static inline int mlx_checkpoint(mlx_closure* res, const mlx_closure fun) {
+ return mlx_checkpoint_(res, fun);
+}
+
+static inline int mlx_custom_function(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp /* may be null */,
+ const mlx_closure_custom_jvp fun_jvp /* may be null */,
+ const mlx_closure_custom_vmap fun_vmap /* may be null */) {
+ return mlx_custom_function_(res, fun, fun_vjp, fun_jvp, fun_vmap);
+}
+
+static inline int mlx_custom_vjp(
+ mlx_closure* res,
+ const mlx_closure fun,
+ const mlx_closure_custom fun_vjp) {
+ return mlx_custom_vjp_(res, fun, fun_vjp);
+}
+
+static inline int mlx_eval(const mlx_vector_array outputs) {
+ return mlx_eval_(outputs);
+}
+
+static inline int mlx_jvp(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array tangents) {
+ return mlx_jvp_(res_0, res_1, fun, primals, tangents);
+}
+
+static inline int mlx_value_and_grad(
+ mlx_closure_value_and_grad* res,
+ const mlx_closure fun,
+ const int* argnums,
+ size_t argnums_num) {
+ return mlx_value_and_grad_(res, fun, argnums, argnums_num);
+}
+
+static inline int mlx_vjp(
+ mlx_vector_array* res_0,
+ mlx_vector_array* res_1,
+ const mlx_closure fun,
+ const mlx_vector_array primals,
+ const mlx_vector_array cotangents) {
+ return mlx_vjp_(res_0, res_1, fun, primals, cotangents);
+}
+
+static inline mlx_vector_array mlx_vector_array_new(void) {
+ return mlx_vector_array_new_();
+}
+
+static inline int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src) {
+ return mlx_vector_array_set_(vec, src);
+}
+
+static inline int mlx_vector_array_free(mlx_vector_array vec) {
+ return mlx_vector_array_free_(vec);
+}
+
+static inline mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size) {
+ return mlx_vector_array_new_data_(data, size);
+}
+
+static inline mlx_vector_array mlx_vector_array_new_value(const mlx_array val) {
+ return mlx_vector_array_new_value_(val);
+}
+
+static inline int mlx_vector_array_set_data(
+ mlx_vector_array* vec,
+ const mlx_array* data,
+ size_t size) {
+ return mlx_vector_array_set_data_(vec, data, size);
+}
+
+static inline int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val) {
+ return mlx_vector_array_set_value_(vec, val);
+}
+
+static inline int mlx_vector_array_append_data(
+ mlx_vector_array vec,
+ const mlx_array* data,
+ size_t size) {
+ return mlx_vector_array_append_data_(vec, data, size);
+}
+
+static inline int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val) {
+ return mlx_vector_array_append_value_(vec, val);
+}
+
+static inline size_t mlx_vector_array_size(mlx_vector_array vec) {
+ return mlx_vector_array_size_(vec);
+}
+
+static inline int mlx_vector_array_get(
+ mlx_array* res,
+ const mlx_vector_array vec,
+ size_t idx) {
+ return mlx_vector_array_get_(res, vec, idx);
+}
+
+static inline mlx_vector_vector_array mlx_vector_vector_array_new(void) {
+ return mlx_vector_vector_array_new_();
+}
+
+static inline int mlx_vector_vector_array_set(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_vector_array src) {
+ return mlx_vector_vector_array_set_(vec, src);
+}
+
+static inline int mlx_vector_vector_array_free(mlx_vector_vector_array vec) {
+ return mlx_vector_vector_array_free_(vec);
+}
+
+static inline mlx_vector_vector_array mlx_vector_vector_array_new_data(
+ const mlx_vector_array* data,
+ size_t size) {
+ return mlx_vector_vector_array_new_data_(data, size);
+}
+
+static inline mlx_vector_vector_array mlx_vector_vector_array_new_value(
+ const mlx_vector_array val) {
+ return mlx_vector_vector_array_new_value_(val);
+}
+
+static inline int mlx_vector_vector_array_set_data(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array* data,
+ size_t size) {
+ return mlx_vector_vector_array_set_data_(vec, data, size);
+}
+
+static inline int mlx_vector_vector_array_set_value(
+ mlx_vector_vector_array* vec,
+ const mlx_vector_array val) {
+ return mlx_vector_vector_array_set_value_(vec, val);
+}
+
+static inline int mlx_vector_vector_array_append_data(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array* data,
+ size_t size) {
+ return mlx_vector_vector_array_append_data_(vec, data, size);
+}
+
+static inline int mlx_vector_vector_array_append_value(
+ mlx_vector_vector_array vec,
+ const mlx_vector_array val) {
+ return mlx_vector_vector_array_append_value_(vec, val);
+}
+
+static inline size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec) {
+ return mlx_vector_vector_array_size_(vec);
+}
+
+static inline int mlx_vector_vector_array_get(
+ mlx_vector_array* res,
+ const mlx_vector_vector_array vec,
+ size_t idx) {
+ return mlx_vector_vector_array_get_(res, vec, idx);
+}
+
+static inline mlx_vector_int mlx_vector_int_new(void) {
+ return mlx_vector_int_new_();
+}
+
+static inline int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src) {
+ return mlx_vector_int_set_(vec, src);
+}
+
+static inline int mlx_vector_int_free(mlx_vector_int vec) {
+ return mlx_vector_int_free_(vec);
+}
+
+static inline mlx_vector_int mlx_vector_int_new_data(int* data, size_t size) {
+ return mlx_vector_int_new_data_(data, size);
+}
+
+static inline mlx_vector_int mlx_vector_int_new_value(int val) {
+ return mlx_vector_int_new_value_(val);
+}
+
+static inline int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size) {
+ return mlx_vector_int_set_data_(vec, data, size);
+}
+
+static inline int mlx_vector_int_set_value(mlx_vector_int* vec, int val) {
+ return mlx_vector_int_set_value_(vec, val);
+}
+
+static inline int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size) {
+ return mlx_vector_int_append_data_(vec, data, size);
+}
+
+static inline int mlx_vector_int_append_value(mlx_vector_int vec, int val) {
+ return mlx_vector_int_append_value_(vec, val);
+}
+
+static inline size_t mlx_vector_int_size(mlx_vector_int vec) {
+ return mlx_vector_int_size_(vec);
+}
+
+static inline int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx) {
+ return mlx_vector_int_get_(res, vec, idx);
+}
+
+static inline mlx_vector_string mlx_vector_string_new(void) {
+ return mlx_vector_string_new_();
+}
+
+static inline int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src) {
+ return mlx_vector_string_set_(vec, src);
+}
+
+static inline int mlx_vector_string_free(mlx_vector_string vec) {
+ return mlx_vector_string_free_(vec);
+}
+
+static inline mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size) {
+ return mlx_vector_string_new_data_(data, size);
+}
+
+static inline mlx_vector_string mlx_vector_string_new_value(const char* val) {
+ return mlx_vector_string_new_value_(val);
+}
+
+static inline int mlx_vector_string_set_data(
+ mlx_vector_string* vec,
+ const char** data,
+ size_t size) {
+ return mlx_vector_string_set_data_(vec, data, size);
+}
+
+static inline int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val) {
+ return mlx_vector_string_set_value_(vec, val);
+}
+
+static inline int mlx_vector_string_append_data(
+ mlx_vector_string vec,
+ const char** data,
+ size_t size) {
+ return mlx_vector_string_append_data_(vec, data, size);
+}
+
+static inline int mlx_vector_string_append_value(mlx_vector_string vec, const char* val) {
+ return mlx_vector_string_append_value_(vec, val);
+}
+
+static inline size_t mlx_vector_string_size(mlx_vector_string vec) {
+ return mlx_vector_string_size_(vec);
+}
+
+static inline int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx) {
+ return mlx_vector_string_get_(res, vec, idx);
+}
+
+static inline int mlx_version(mlx_string* str_) {
+ return mlx_version_(str_);
+}
+
+#endif // MLX_GENERATED_H
\ No newline at end of file
diff --git a/x/mlxrunner/mlx/generator/generated.c.gotmpl b/x/mlxrunner/mlx/generator/generated.c.gotmpl
new file mode 100644
index 00000000000..c31b34a769f
--- /dev/null
+++ b/x/mlxrunner/mlx/generator/generated.c.gotmpl
@@ -0,0 +1,17 @@
+// This code is auto-generated; DO NOT EDIT.
+
+#include "generated.h"
+
+#include
+#include
+#include
+{{ range .Functions }}
+{{ .Type }} (*{{ .Name }}_){{ .Parameters }} = NULL;
+{{- end }}
+
+int mlx_dynamic_load_symbols(mlx_dynamic_handle handle) {
+{{- range .Functions }}
+ CHECK_LOAD(handle, {{ .Name }});
+{{- end }}
+ return 0;
+}
diff --git a/x/mlxrunner/mlx/generator/generated.h.gotmpl b/x/mlxrunner/mlx/generator/generated.h.gotmpl
new file mode 100644
index 00000000000..8f043573bfa
--- /dev/null
+++ b/x/mlxrunner/mlx/generator/generated.h.gotmpl
@@ -0,0 +1,22 @@
+// This code is auto-generated; DO NOT EDIT.
+
+#ifndef MLX_GENERATED_H
+#define MLX_GENERATED_H
+
+#include "dynamic.h"
+#include "mlx/c/mlx.h"
+{{ range .Functions }}
+#undef {{ .Name }}
+{{- end }}
+{{ range .Functions }}
+extern {{ .Type }} (*{{ .Name }}_){{ .Parameters }};
+{{- end }}
+
+int mlx_dynamic_load_symbols(mlx_dynamic_handle handle);
+{{ range .Functions }}
+static inline {{ .Type }} {{ .Name }}{{ .Parameters }} {{ "{" }}
+ return {{ .Name }}_({{ .Args }});
+{{ "}" }}
+{{- end }}
+
+#endif // MLX_GENERATED_H
diff --git a/x/mlxrunner/mlx/generator/main.go b/x/mlxrunner/mlx/generator/main.go
new file mode 100644
index 00000000000..a98046a2fab
--- /dev/null
+++ b/x/mlxrunner/mlx/generator/main.go
@@ -0,0 +1,135 @@
+package main
+
+import (
+ "embed"
+ "flag"
+ "fmt"
+ "os"
+ "path/filepath"
+ "slices"
+ "strings"
+ "text/template"
+
+ tree_sitter "github.com/tree-sitter/go-tree-sitter"
+ tree_sitter_cpp "github.com/tree-sitter/tree-sitter-cpp/bindings/go"
+)
+
+//go:embed *.gotmpl
+var fsys embed.FS
+
+type Function struct {
+ Type,
+ Name,
+ Parameters,
+ Args string
+}
+
+func ParseFunction(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) Function {
+ var fn Function
+ fn.Name = node.ChildByFieldName("declarator").Utf8Text(source)
+ if params := node.ChildByFieldName("parameters"); params != nil {
+ fn.Parameters = params.Utf8Text(source)
+ fn.Args = ParseParameters(params, tc, source)
+ }
+
+ var types []string
+ for node.Parent() != nil && node.Parent().Kind() != "declaration" {
+ if node.Parent().Kind() == "pointer_declarator" {
+ types = append(types, "*")
+ }
+ node = node.Parent()
+ }
+
+ for sibling := node.PrevSibling(); sibling != nil; sibling = sibling.PrevSibling() {
+ types = append(types, sibling.Utf8Text(source))
+ }
+
+ slices.Reverse(types)
+ fn.Type = strings.Join(types, " ")
+ return fn
+}
+
+func ParseParameters(node *tree_sitter.Node, tc *tree_sitter.TreeCursor, source []byte) string {
+ var s []string
+ for _, child := range node.Children(tc) {
+ if child.IsNamed() {
+ child := child.ChildByFieldName("declarator")
+ for child != nil && child.Kind() != "identifier" {
+ if child.Kind() == "parenthesized_declarator" {
+ child = child.Child(1)
+ } else {
+ child = child.ChildByFieldName("declarator")
+ }
+ }
+
+ if child != nil {
+ s = append(s, child.Utf8Text(source))
+ }
+ }
+ }
+ return strings.Join(s, ", ")
+}
+
+func main() {
+ var output string
+ flag.StringVar(&output, "output", ".", "Output directory for generated files")
+ flag.Parse()
+
+ parser := tree_sitter.NewParser()
+ defer parser.Close()
+
+ language := tree_sitter.NewLanguage(tree_sitter_cpp.Language())
+ parser.SetLanguage(language)
+
+ query, _ := tree_sitter.NewQuery(language, `(function_declarator declarator: (identifier)) @func`)
+ defer query.Close()
+
+ qc := tree_sitter.NewQueryCursor()
+ defer qc.Close()
+
+ var funs []Function
+ for _, arg := range flag.Args() {
+ bts, err := os.ReadFile(arg)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error reading file %s: %v\n", arg, err)
+ continue
+ }
+
+ tree := parser.Parse(bts, nil)
+ defer tree.Close()
+
+ tc := tree.Walk()
+ defer tc.Close()
+
+ matches := qc.Matches(query, tree.RootNode(), bts)
+ for match := matches.Next(); match != nil; match = matches.Next() {
+ for _, capture := range match.Captures {
+ funs = append(funs, ParseFunction(&capture.Node, tc, bts))
+ }
+ }
+ }
+
+ tmpl, err := template.New("").ParseFS(fsys, "*.gotmpl")
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error parsing template: %v\n", err)
+ return
+ }
+
+ for _, tmpl := range tmpl.Templates() {
+ name := filepath.Join(output, strings.TrimSuffix(tmpl.Name(), ".gotmpl"))
+
+ fmt.Println("Generating", name)
+ f, err := os.Create(name)
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Error creating file %s: %v\n", name, err)
+ continue
+ }
+ defer f.Close()
+
+ if err := tmpl.Execute(f, map[string]any{
+ "Functions": funs,
+ }); err != nil {
+ fmt.Fprintf(os.Stderr, "Error executing template %s: %v\n", tmpl.Name(), err)
+ }
+ }
+}
diff --git a/x/mlxrunner/mlx/io.go b/x/mlxrunner/mlx/io.go
new file mode 100644
index 00000000000..304cfcd2c4a
--- /dev/null
+++ b/x/mlxrunner/mlx/io.go
@@ -0,0 +1,45 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "iter"
+ "unsafe"
+)
+
+func Load(path string) iter.Seq2[string, *Array] {
+ return func(yield func(string, *Array) bool) {
+ string2array := C.mlx_map_string_to_array_new()
+ defer C.mlx_map_string_to_array_free(string2array)
+
+ string2string := C.mlx_map_string_to_string_new()
+ defer C.mlx_map_string_to_string_free(string2string)
+
+ cPath := C.CString(path)
+ defer C.free(unsafe.Pointer(cPath))
+
+ cpu := C.mlx_default_cpu_stream_new()
+ defer C.mlx_stream_free(cpu)
+
+ C.mlx_load_safetensors(&string2array, &string2string, cPath, cpu)
+
+ it := C.mlx_map_string_to_array_iterator_new(string2array)
+ defer C.mlx_map_string_to_array_iterator_free(it)
+
+ for {
+ var key *C.char
+ value := C.mlx_array_new()
+ if C.mlx_map_string_to_array_iterator_next(&key, &value, it) != 0 {
+ break
+ }
+
+ name := C.GoString(key)
+ if !yield(name, &Array{ctx: value, desc: tensorDesc{name: name, numRefs: 1000}}) {
+ break
+ }
+ }
+ }
+}
diff --git a/x/mlxrunner/mlx/memory.go b/x/mlxrunner/mlx/memory.go
new file mode 100644
index 00000000000..e9a174b1ef0
--- /dev/null
+++ b/x/mlxrunner/mlx/memory.go
@@ -0,0 +1,87 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "fmt"
+ "log/slog"
+ "strconv"
+)
+
+func (b Byte) String() string {
+ return strconv.FormatInt(int64(b), 10) + " B"
+}
+
+func (b KibiByte) String() string {
+ return strconv.FormatFloat(float64(b)/(1<<10), 'f', 2, 64) + " KiB"
+}
+
+func (b MebiByte) String() string {
+ return strconv.FormatFloat(float64(b)/(1<<(2*10)), 'f', 2, 64) + " MiB"
+}
+
+func (b GibiByte) String() string {
+ return strconv.FormatFloat(float64(b)/(1<<(3*10)), 'f', 2, 64) + " GiB"
+}
+
+func (b TebiByte) String() string {
+ return strconv.FormatFloat(float64(b)/(1<<(4*10)), 'f', 2, 64) + " TiB"
+}
+
+func PrettyBytes(n int) fmt.Stringer {
+ switch {
+ case n < 1<<10:
+ return Byte(n)
+ case n < 1<<(2*10):
+ return KibiByte(n)
+ case n < 1<<(3*10):
+ return MebiByte(n)
+ case n < 1<<(4*10):
+ return GibiByte(n)
+ default:
+ return TebiByte(n)
+ }
+}
+
+func ActiveMemory() int {
+ var active C.size_t
+ C.mlx_get_active_memory(&active)
+ return int(active)
+}
+
+func CacheMemory() int {
+ var cache C.size_t
+ C.mlx_get_cache_memory(&cache)
+ return int(cache)
+}
+
+func PeakMemory() int {
+ var peak C.size_t
+ C.mlx_get_peak_memory(&peak)
+ return int(peak)
+}
+
+type Memory struct{}
+
+func (Memory) LogValue() slog.Value {
+ return slog.GroupValue(
+ slog.Any("active", PrettyBytes(ActiveMemory())),
+ slog.Any("cache", PrettyBytes(CacheMemory())),
+ slog.Any("peak", PrettyBytes(PeakMemory())),
+ )
+}
+
+type (
+ Byte int
+ KibiByte int
+ MebiByte int
+ GibiByte int
+ TebiByte int
+)
+
+func ClearCache() {
+ C.mlx_clear_cache()
+}
diff --git a/x/mlxrunner/mlx/mlx.go b/x/mlxrunner/mlx/mlx.go
new file mode 100644
index 00000000000..0bf43830c8f
--- /dev/null
+++ b/x/mlxrunner/mlx/mlx.go
@@ -0,0 +1,40 @@
+//go:build mlx
+
+package mlx
+
+//go:generate cmake -S . -B build -DCMAKE_INSTALL_PREFIX=dist -DCMAKE_BUILD_TYPE=Release
+//go:generate cmake --build build --parallel
+//go:generate cmake --install build
+//go:generate sh -c "go run generator/main.go -output=. ./dist/include/mlx/c/*.h"
+
+// #cgo CXXFLAGS: -std=c++17
+// #cgo CPPFLAGS: -I${SRCDIR}/dist/include
+// #cgo LDFLAGS: -L${SRCDIR}/dist/lib -lstdc++
+// #cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework Accelerate
+// #include "generated.h"
+import "C"
+
+func doEval(outputs []*Array, async bool) {
+ vector := C.mlx_vector_array_new()
+ defer C.mlx_vector_array_free(vector)
+
+ for _, output := range outputs {
+ if output.Valid() {
+ C.mlx_vector_array_append_value(vector, output.ctx)
+ }
+ }
+
+ if async {
+ C.mlx_async_eval(vector)
+ } else {
+ C.mlx_eval(vector)
+ }
+}
+
+func AsyncEval(outputs ...*Array) {
+ doEval(outputs, true)
+}
+
+func Eval(outputs ...*Array) {
+ doEval(outputs, false)
+}
diff --git a/x/mlxrunner/mlx/nn.go b/x/mlxrunner/mlx/nn.go
new file mode 100644
index 00000000000..3d5691368d9
--- /dev/null
+++ b/x/mlxrunner/mlx/nn.go
@@ -0,0 +1,38 @@
+//go:build mlx
+
+package mlx
+
+type Linear struct {
+ Weight Array `weight:"weight"`
+ Bias Array `weight:"bias"`
+}
+
+// Forward computes the linear transformation: x @ Weight.T + Bias
+func (m Linear) Forward(x *Array) *Array {
+ w := m.Weight.Transpose(1, 0)
+ if m.Bias.Valid() {
+ return m.Bias.Addmm(x, w, 1.0, 1.0)
+ }
+
+ return x.Matmul(w)
+}
+
+func (m Linear) Gather(x, lhs, rhs *Array, sorted bool) *Array {
+ w := m.Weight.Transpose(0, 2, 1)
+ // TODO: bias
+ return x.GatherMM(w, lhs, rhs, sorted)
+}
+
+type Embedding struct {
+ Weight Array `weight:"weight"`
+}
+
+func (e *Embedding) Forward(indices *Array) *Array {
+ return e.Weight.TakeAxis(indices, 0)
+}
+
+func (e *Embedding) AsLinear() Linear {
+ return Linear{
+ Weight: e.Weight,
+ }
+}
diff --git a/x/mlxrunner/mlx/ops.go b/x/mlxrunner/mlx/ops.go
new file mode 100644
index 00000000000..01a7f4835fc
--- /dev/null
+++ b/x/mlxrunner/mlx/ops.go
@@ -0,0 +1,256 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "unsafe"
+)
+
+func (t *Array) Abs() *Array {
+ out := New("ABS", t)
+ C.mlx_abs(&out.ctx, t.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Add(other *Array) *Array {
+ out := New("ADD", t, other)
+ C.mlx_add(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Addmm(a, b *Array, alpha, beta float32) *Array {
+ out := New("ADDMM", t, a, b)
+ C.mlx_addmm(&out.ctx, t.ctx, a.ctx, b.ctx, C.float(alpha), C.float(beta), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Argmax(axis int, keepDims bool) *Array {
+ out := New("ARGMAX", t)
+ C.mlx_argmax_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) ArgpartitionAxis(kth int, axis int) *Array {
+ out := New("ARGPARTITION", t)
+ C.mlx_argpartition_axis(&out.ctx, t.ctx, C.int(kth), C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) ArgsortAxis(axis int) *Array {
+ out := New("ARGSORT_AXIS", t)
+ C.mlx_argsort_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) AsType(dtype DType) *Array {
+ out := New("AS_TYPE", t)
+ C.mlx_astype(&out.ctx, t.ctx, C.mlx_dtype(dtype), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) AsStrided(shape []int, strides []int, offset int) *Array {
+ cShape := make([]C.int, len(shape))
+ for i, s := range shape {
+ cShape[i] = C.int(s)
+ }
+
+ cStrides := make([]C.int64_t, len(strides))
+ for i, s := range strides {
+ cStrides[i] = C.int64_t(s)
+ }
+
+ out := New("AS_STRIDED", t)
+ C.mlx_as_strided(
+ &out.ctx, t.ctx,
+ unsafe.SliceData(cShape), C.size_t(len(shape)),
+ unsafe.SliceData(cStrides), C.size_t(len(strides)),
+ C.size_t(offset),
+ DefaultStream().ctx,
+ )
+ return out
+}
+
+func (t *Array) Concatenate(axis int, others ...*Array) *Array {
+ vector := C.mlx_vector_array_new()
+ defer C.mlx_vector_array_free(vector)
+
+ s := append([]*Array{t}, others...)
+ for _, other := range s {
+ C.mlx_vector_array_append_value(vector, other.ctx)
+ }
+
+ out := New("CONCATENATE", s...)
+ C.mlx_concatenate_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Divide(other *Array) *Array {
+ out := New("DIVIDE", t, other)
+ C.mlx_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) ExpandDims(axis int) *Array {
+ out := New("EXPAND_DIMS", t)
+ C.mlx_expand_dims(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Flatten(startAxis, endAxis int) *Array {
+ out := New("FLATTEN", t)
+ C.mlx_flatten(&out.ctx, t.ctx, C.int(startAxis), C.int(endAxis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) FloorDivide(other *Array) *Array {
+ out := New("FLOOR_DIVIDE", t, other)
+ C.mlx_floor_divide(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) GatherMM(other, lhs, rhs *Array, sorted bool) *Array {
+ if lhs == nil {
+ lhs = New("")
+ }
+ if rhs == nil {
+ rhs = New("")
+ }
+ out := New("GATHER_MM", t, other, lhs, rhs)
+ C.mlx_gather_mm(&out.ctx, t.ctx, other.ctx, lhs.ctx, rhs.ctx, C.bool(sorted), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Logsumexp(keepDims bool) *Array {
+ out := New("LOGSUMEXP", t)
+ C.mlx_logsumexp(&out.ctx, t.ctx, C.bool(keepDims), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Matmul(other *Array) *Array {
+ out := New("MATMUL", t, other)
+ C.mlx_matmul(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Multiply(other *Array) *Array {
+ out := New("MULTIPLY", t, other)
+ C.mlx_multiply(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Negative() *Array {
+ out := New("NEGATIVE", t)
+ C.mlx_negative(&out.ctx, t.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Power(exponent *Array) *Array {
+ out := New("POWER", t, exponent)
+ C.mlx_power(&out.ctx, t.ctx, exponent.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) PutAlongAxis(indices, values *Array, axis int) *Array {
+ out := New("PUT_ALONG_AXIS", t, indices, values)
+ C.mlx_put_along_axis(&out.ctx, t.ctx, indices.ctx, values.ctx, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Reshape(axes ...int) *Array {
+ cAxes := make([]C.int, len(axes))
+ for i := range axes {
+ cAxes[i] = C.int(axes[i])
+ }
+
+ out := New("RESHAPE", t)
+ C.mlx_reshape(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Sigmoid() *Array {
+ out := New("SIGMOID", t)
+ C.mlx_sigmoid(&out.ctx, t.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Sqrt() *Array {
+ out := New("SQRT", t)
+ C.mlx_sqrt(&out.ctx, t.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Squeeze(axis int) *Array {
+ out := New("SQUEEZE", t)
+ C.mlx_squeeze_axis(&out.ctx, t.ctx, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) StackAxis(axis int, others ...*Array) *Array {
+ vectorData := make([]C.mlx_array, len(others)+1)
+ vectorData[0] = t.ctx
+ for i := range others {
+ vectorData[i+1] = others[i].ctx
+ }
+
+ vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
+ defer C.mlx_vector_array_free(vector)
+
+ out := New("STACK_AXIS", append(others, t)...)
+ C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Subtract(other *Array) *Array {
+ out := New("SUBTRACT", t, other)
+ C.mlx_subtract(&out.ctx, t.ctx, other.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) SumAxis(axis int, keepDims bool) *Array {
+ out := New("SUM_AXIS", t)
+ C.mlx_sum_axis(&out.ctx, t.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) TakeAxis(indices *Array, axis int) *Array {
+ out := New("TAKE_AXIS", t, indices)
+ C.mlx_take_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) TakeAlongAxis(indices *Array, axis int) *Array {
+ out := New("TAKE_ALONG_AXIS", t, indices)
+ C.mlx_take_along_axis(&out.ctx, t.ctx, indices.ctx, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Tanh() *Array {
+ out := New("TANH", t)
+ C.mlx_tanh(&out.ctx, t.ctx, DefaultStream().ctx)
+ return out
+}
+
+func (t *Array) Transpose(axes ...int) *Array {
+ cAxes := make([]C.int, len(axes))
+ for i, axis := range axes {
+ cAxes[i] = C.int(axis)
+ }
+
+ out := New("TRANSPOSE", t)
+ C.mlx_transpose_axes(&out.ctx, t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), DefaultStream().ctx)
+ return out
+}
+
+func Zeros(dtype DType, shape ...int) *Array {
+ cAxes := make([]C.int, len(shape))
+ for i := range shape {
+ cAxes[i] = C.int(shape[i])
+ }
+
+ t := New("ZEROS")
+ C.mlx_zeros(&t.ctx, unsafe.SliceData(cAxes), C.size_t(len(cAxes)), C.mlx_dtype(dtype), DefaultStream().ctx)
+ return t
+}
diff --git a/x/mlxrunner/mlx/ops_extra.go b/x/mlxrunner/mlx/ops_extra.go
new file mode 100644
index 00000000000..f2882e9892e
--- /dev/null
+++ b/x/mlxrunner/mlx/ops_extra.go
@@ -0,0 +1,450 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// Quantization operations
+
+func Quantize(w *Array, groupSize, bits int, mode string) (weights, scales, biases *Array) {
+ cMode := C.CString(mode)
+ defer C.free(unsafe.Pointer(cMode))
+ optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
+ optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
+ res := C.mlx_vector_array_new()
+ defer C.mlx_vector_array_free(res)
+ C.mlx_quantize(&res, w.ctx, optGroupSize, optBits, cMode, DefaultStream().ctx)
+
+ vecSize := int(C.mlx_vector_array_size(res))
+ w0 := New("QUANTIZE_W")
+ C.mlx_vector_array_get(&w0.ctx, res, 0)
+ w1 := New("QUANTIZE_S")
+ C.mlx_vector_array_get(&w1.ctx, res, 1)
+ if vecSize >= 3 {
+ w2 := New("QUANTIZE_B")
+ C.mlx_vector_array_get(&w2.ctx, res, 2)
+ return w0, w1, w2
+ }
+ return w0, w1, nil
+}
+
+func Dequantize(w, scales, biases *Array, groupSize, bits int, mode string) *Array {
+ cMode := C.CString(mode)
+ defer C.free(unsafe.Pointer(cMode))
+ optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
+ optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
+ optDtype := C.mlx_optional_dtype{has_value: false}
+
+ inputs := []*Array{w, scales}
+ var b C.mlx_array
+ if biases != nil {
+ b = biases.ctx
+ inputs = append(inputs, biases)
+ }
+
+ out := New("DEQUANTIZE", inputs...)
+ C.mlx_dequantize(&out.ctx, w.ctx, scales.ctx, b, optGroupSize, optBits, cMode, optDtype, DefaultStream().ctx)
+ return out
+}
+
+func QuantizedMatmul(x, w, scales, biases *Array, transpose bool, groupSize, bits int, mode string) *Array {
+ cMode := C.CString(mode)
+ defer C.free(unsafe.Pointer(cMode))
+ optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
+ optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
+
+ inputs := []*Array{x, w, scales}
+ var b C.mlx_array
+ if biases != nil {
+ b = biases.ctx
+ inputs = append(inputs, biases)
+ }
+
+ out := New("QUANTIZED_MATMUL", inputs...)
+ C.mlx_quantized_matmul(&out.ctx, x.ctx, w.ctx, scales.ctx, b, C.bool(transpose), optGroupSize, optBits, cMode, DefaultStream().ctx)
+ return out
+}
+
+func GatherQMM(x, w, scales *Array, biases, lhsIndices, rhsIndices *Array, transpose bool, groupSize, bits int, mode string, sortedIndices bool) *Array {
+ cMode := C.CString(mode)
+ defer C.free(unsafe.Pointer(cMode))
+ optGroupSize := C.mlx_optional_int{value: C.int(groupSize), has_value: true}
+ optBits := C.mlx_optional_int{value: C.int(bits), has_value: true}
+
+ inputs := []*Array{x, w, scales}
+ var b, lhs, rhs C.mlx_array
+ if biases != nil {
+ b = biases.ctx
+ inputs = append(inputs, biases)
+ }
+ if lhsIndices != nil {
+ lhs = lhsIndices.ctx
+ inputs = append(inputs, lhsIndices)
+ }
+ if rhsIndices != nil {
+ rhs = rhsIndices.ctx
+ inputs = append(inputs, rhsIndices)
+ }
+
+ out := New("GATHER_QMM", inputs...)
+ C.mlx_gather_qmm(&out.ctx, x.ctx, w.ctx, scales.ctx, b, lhs, rhs, C.bool(transpose), optGroupSize, optBits, cMode, C.bool(sortedIndices), DefaultStream().ctx)
+ return out
+}
+
+// Missing tensor ops
+
+func Tile(a *Array, reps []int32) *Array {
+ cReps := make([]C.int, len(reps))
+ for i, r := range reps {
+ cReps[i] = C.int(r)
+ }
+ out := New("TILE", a)
+ C.mlx_tile(&out.ctx, a.ctx, unsafe.SliceData(cReps), C.size_t(len(reps)), DefaultStream().ctx)
+ return out
+}
+
+func Tri(n, m int32, k int) *Array {
+ out := New("TRI")
+ C.mlx_tri(&out.ctx, C.int(n), C.int(m), C.int(k), C.mlx_dtype(DTypeFloat32), DefaultStream().ctx)
+ return out
+}
+
+func Where(condition, a, b *Array) *Array {
+ out := New("WHERE", condition, a, b)
+ C.mlx_where(&out.ctx, condition.ctx, a.ctx, b.ctx, DefaultStream().ctx)
+ return out
+}
+
+// Convenience wrappers (function-style for the model code)
+
+func Stack(arrays []*Array, axis int) *Array {
+ vectorData := make([]C.mlx_array, len(arrays))
+ for i := range arrays {
+ vectorData[i] = arrays[i].ctx
+ }
+ vector := C.mlx_vector_array_new_data(unsafe.SliceData(vectorData), C.size_t(len(vectorData)))
+ defer C.mlx_vector_array_free(vector)
+
+ out := New("STACK", arrays...)
+ C.mlx_stack_axis(&out.ctx, vector, C.int(axis), DefaultStream().ctx)
+ return out
+}
+
+func Neg(a *Array) *Array {
+ return a.Negative()
+}
+
+func Sum(a *Array, axis int, keepDims bool) *Array {
+ return a.SumAxis(axis, keepDims)
+}
+
+func Argsort(a *Array, axis int) *Array {
+ return a.ArgsortAxis(axis)
+}
+
+func Take(a *Array, indices *Array, axis int) *Array {
+ return a.TakeAxis(indices, axis)
+}
+
+func RSqrt(a *Array) *Array {
+ out := New("RSQRT", a)
+ C.mlx_rsqrt(&out.ctx, a.ctx, DefaultStream().ctx)
+ return out
+}
+
+func Mean(a *Array, axis int, keepDims bool) *Array {
+ out := New("MEAN_AXIS", a)
+ C.mlx_mean_axis(&out.ctx, a.ctx, C.int(axis), C.bool(keepDims), DefaultStream().ctx)
+ return out
+}
+
+func Argpartition(a *Array, kth int, axis int) *Array {
+ return a.ArgpartitionAxis(kth, axis)
+}
+
+func TakeAlongAxis(a, indices *Array, axis int) *Array {
+ return a.TakeAlongAxis(indices, axis)
+}
+
+// Function-style wrappers matching imagegen API
+
+func Add(a, b *Array) *Array {
+ return a.Add(b)
+}
+
+func Sub(a, b *Array) *Array {
+ return a.Subtract(b)
+}
+
+func Mul(a, b *Array) *Array {
+ return a.Multiply(b)
+}
+
+func Div(a, b *Array) *Array {
+ return a.Divide(b)
+}
+
+func Matmul(a, b *Array) *Array {
+ return a.Matmul(b)
+}
+
+func Reshape(a *Array, shape ...int32) *Array {
+ axes := make([]int, len(shape))
+ for i, s := range shape {
+ axes[i] = int(s)
+ }
+ return a.Reshape(axes...)
+}
+
+func Transpose(a *Array, axes ...int) *Array {
+ return a.Transpose(axes...)
+}
+
+func ExpandDims(a *Array, axis int) *Array {
+ return a.ExpandDims(axis)
+}
+
+func Squeeze(a *Array, axis int) *Array {
+ return a.Squeeze(axis)
+}
+
+func Flatten(a *Array) *Array {
+ return a.Flatten(0, -1)
+}
+
+func Concatenate(arrays []*Array, axis int) *Array {
+ if len(arrays) == 0 {
+ return nil
+ }
+ return arrays[0].Concatenate(axis, arrays[1:]...)
+}
+
+func SliceStartStop(a *Array, start, stop []int32) *Array {
+ n := len(start)
+ cStart := make([]C.int, n)
+ cStop := make([]C.int, n)
+ cStrides := make([]C.int, n)
+ for i := 0; i < n; i++ {
+ cStart[i] = C.int(start[i])
+ cStop[i] = C.int(stop[i])
+ cStrides[i] = 1
+ }
+ out := New("SLICE", a)
+ C.mlx_slice(&out.ctx, a.ctx, unsafe.SliceData(cStart), C.size_t(n), unsafe.SliceData(cStop), C.size_t(n), unsafe.SliceData(cStrides), C.size_t(n), DefaultStream().ctx)
+ return out
+}
+
+func GatherMM(a, b *Array, lhsIndices, rhsIndices *Array, sortedIndices bool) *Array {
+ if lhsIndices == nil {
+ lhsIndices = New("")
+ }
+ if rhsIndices == nil {
+ rhsIndices = New("")
+ }
+ return a.GatherMM(b, lhsIndices, rhsIndices, sortedIndices)
+}
+
+func SiLU(a *Array) *Array {
+ sig := a.Sigmoid()
+ return a.Multiply(sig)
+}
+
+func RoPEWithBase(x *Array, dims int, traditional bool, base, scale float32, offset int) *Array {
+ freqs := New("")
+ out := New("FAST_ROPE", x, freqs)
+ C.mlx_fast_rope(
+ &out.ctx,
+ x.ctx,
+ C.int(dims),
+ C.bool(traditional),
+ C.mlx_optional_float{
+ value: C.float(base),
+ has_value: C.bool(func() bool { return base != 0 }()),
+ },
+ C.float(scale),
+ C.int(offset),
+ freqs.ctx,
+ DefaultStream().ctx,
+ )
+ return out
+}
+
+func Sigmoid(a *Array) *Array {
+ return a.Sigmoid()
+}
+
+func ScaledDotProductAttentionCausal(q, k, v *Array, scale float32, causalMask bool) *Array {
+ mask := New("")
+ sinks := New("")
+ mode := ""
+ if causalMask {
+ mode = "causal"
+ }
+ cMode := C.CString(mode)
+ defer C.free(unsafe.Pointer(cMode))
+
+ out := New("FAST_SDPA", q, k, v, mask, sinks)
+ C.mlx_fast_scaled_dot_product_attention(&out.ctx, q.ctx, k.ctx, v.ctx, C.float(scale), cMode, mask.ctx, sinks.ctx, DefaultStream().ctx)
+ return out
+}
+
+func RMSNormFn(x, weight *Array, eps float32) *Array {
+ out := New("FAST_RMSNORM", x)
+ C.mlx_fast_rms_norm(&out.ctx, x.ctx, weight.ctx, C.float(eps), DefaultStream().ctx)
+ return out
+}
+
+func AddMM(c, a, b *Array, alpha, beta float32) *Array {
+ return c.Addmm(a, b, alpha, beta)
+}
+
+// Scalar helpers
+
+// scalarWithDtype creates a scalar array matching the dtype of a.
+// Matching dtype is important for graph fusion and avoiding implicit casts.
+func scalarWithDtype(s float32, a *Array) C.mlx_array {
+ f32 := C.mlx_array_new_float(C.float(s))
+ dtype := a.DType()
+ if dtype == DTypeFloat32 {
+ return f32
+ }
+ casted := C.mlx_array_new()
+ C.mlx_astype(&casted, f32, C.mlx_dtype(dtype), DefaultStream().ctx)
+ C.mlx_array_free(f32)
+ return casted
+}
+
+func AddScalar(a *Array, s float32) *Array {
+ scalar := scalarWithDtype(s, a)
+ out := New("ADD_SCALAR", a)
+ C.mlx_add(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
+ C.mlx_array_free(scalar)
+ return out
+}
+
+func MulScalar(a *Array, s float32) *Array {
+ scalar := scalarWithDtype(s, a)
+ out := New("MUL_SCALAR", a)
+ C.mlx_multiply(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
+ C.mlx_array_free(scalar)
+ return out
+}
+
+func DivScalar(a *Array, s float32) *Array {
+ scalar := scalarWithDtype(s, a)
+ out := New("DIV_SCALAR", a)
+ C.mlx_divide(&out.ctx, a.ctx, scalar, DefaultStream().ctx)
+ C.mlx_array_free(scalar)
+ return out
+}
+
+func FloorDivideScalar(a *Array, s int32) *Array {
+ scalar := FromValue(int(s))
+ return a.FloorDivide(scalar)
+}
+
+// Array constructors
+
+func NewArrayInt32(data []int32, shape []int32) *Array {
+ cShape := make([]C.int, len(shape))
+ for i, s := range shape {
+ cShape[i] = C.int(s)
+ }
+ out := New("NEW_ARRAY_INT32")
+ out.ctx = C.mlx_array_new_data(unsafe.Pointer(&data[0]), unsafe.SliceData(cShape), C.int(len(shape)), C.mlx_dtype(DTypeInt32))
+ return out
+}
+
+func NewScalarArray(value float32) *Array {
+ out := New("SCALAR")
+ out.ctx = C.mlx_array_new_float32(C.float(value))
+ return out
+}
+
+func ZerosF32(shape []int32) *Array {
+ return Zeros(DTypeFloat32, func() []int {
+ ints := make([]int, len(shape))
+ for i, s := range shape {
+ ints[i] = int(s)
+ }
+ return ints
+ }()...)
+}
+
+// Utility
+
+func Collect(v any) []*Array {
+ var arrays []*Array
+ seen := make(map[uintptr]bool)
+ collect(reflect.ValueOf(v), &arrays, seen)
+ return arrays
+}
+
+func collect(v reflect.Value, arrays *[]*Array, seen map[uintptr]bool) {
+ if !v.IsValid() {
+ return
+ }
+
+ if v.Kind() == reflect.Ptr {
+ if v.IsNil() {
+ return
+ }
+ ptr := v.Pointer()
+ if seen[ptr] {
+ return
+ }
+ seen[ptr] = true
+
+ if arr, ok := v.Interface().(*Array); ok {
+ if arr != nil && arr.Valid() {
+ *arrays = append(*arrays, arr)
+ }
+ return
+ }
+ collect(v.Elem(), arrays, seen)
+ return
+ }
+
+ switch v.Kind() {
+ case reflect.Struct:
+ // Check if this struct IS an Array (not a pointer to one)
+ if arr, ok := v.Addr().Interface().(*Array); ok {
+ if arr != nil && arr.Valid() {
+ *arrays = append(*arrays, arr)
+ }
+ return
+ }
+ for i := 0; i < v.NumField(); i++ {
+ field := v.Field(i)
+ if field.CanInterface() {
+ collect(field, arrays, seen)
+ }
+ }
+ case reflect.Slice:
+ for i := 0; i < v.Len(); i++ {
+ collect(v.Index(i), arrays, seen)
+ }
+ case reflect.Map:
+ for _, key := range v.MapKeys() {
+ collect(v.MapIndex(key), arrays, seen)
+ }
+ case reflect.Interface:
+ if !v.IsNil() {
+ collect(v.Elem(), arrays, seen)
+ }
+ }
+}
+
+func EnableCompile() {
+ C.mlx_enable_compile()
+}
+
+func DisableCompile() {
+ C.mlx_disable_compile()
+}
diff --git a/x/mlxrunner/mlx/random.go b/x/mlxrunner/mlx/random.go
new file mode 100644
index 00000000000..805308b4a4e
--- /dev/null
+++ b/x/mlxrunner/mlx/random.go
@@ -0,0 +1,13 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+func (t *Array) Categorical(axis int) *Array {
+ key := New("")
+ out := New("", t, key)
+ C.mlx_random_categorical(&out.ctx, t.ctx, C.int(axis), key.ctx, DefaultStream().ctx)
+ return out
+}
diff --git a/x/mlxrunner/mlx/slice.go b/x/mlxrunner/mlx/slice.go
new file mode 100644
index 00000000000..7ab7e203185
--- /dev/null
+++ b/x/mlxrunner/mlx/slice.go
@@ -0,0 +1,86 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "cmp"
+ "unsafe"
+)
+
+type slice struct {
+ args []int
+}
+
+func Slice(args ...int) slice {
+ return slice{args: args}
+}
+
+func makeSlices(dims []int, slices ...slice) (starts, stops, strides []C.int) {
+ if len(slices) != len(dims) {
+ panic("number of slice arguments must match number of tensor dimensions")
+ }
+
+ args := [3][]C.int{
+ make([]C.int, len(slices)),
+ make([]C.int, len(slices)),
+ make([]C.int, len(slices)),
+ }
+
+ for i, s := range slices {
+ switch len(s.args) {
+ case 0:
+ // slice[:]
+ args[0][i] = C.int(0)
+ args[1][i] = C.int(dims[i])
+ args[2][i] = C.int(1)
+ case 1:
+ // slice[i]
+ args[0][i] = C.int(s.args[0])
+ args[1][i] = C.int(s.args[0] + 1)
+ args[2][i] = C.int(1)
+ case 2:
+ // slice[i:j]
+ args[0][i] = C.int(s.args[0])
+ args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
+ args[2][i] = C.int(1)
+ case 3:
+ // slice[i:j:k]
+ args[0][i] = C.int(s.args[0])
+ args[1][i] = cmp.Or(C.int(s.args[1]), C.int(dims[i]))
+ args[2][i] = C.int(s.args[2])
+ default:
+ panic("invalid slice arguments")
+ }
+ }
+
+ return args[0], args[1], args[2]
+}
+
+func (t *Array) Slice(slices ...slice) *Array {
+ starts, stops, strides := makeSlices(t.Dims(), slices...)
+ out := New("SLICE", t)
+ C.mlx_slice(
+ &out.ctx, t.ctx,
+ unsafe.SliceData(starts), C.size_t(len(starts)),
+ unsafe.SliceData(stops), C.size_t(len(stops)),
+ unsafe.SliceData(strides), C.size_t(len(strides)),
+ DefaultStream().ctx,
+ )
+ return out
+}
+
+func (t *Array) SliceUpdate(other *Array, slices ...slice) *Array {
+ starts, stops, strides := makeSlices(t.Dims(), slices...)
+ out := New("SLICE_UPDATE", t, other)
+ C.mlx_slice_update(
+ &out.ctx, t.ctx, other.ctx,
+ unsafe.SliceData(starts), C.size_t(len(starts)),
+ unsafe.SliceData(stops), C.size_t(len(stops)),
+ unsafe.SliceData(strides), C.size_t(len(strides)),
+ DefaultStream().ctx,
+ )
+ return out
+}
diff --git a/x/mlxrunner/mlx/stream.go b/x/mlxrunner/mlx/stream.go
new file mode 100644
index 00000000000..83a3eeffdce
--- /dev/null
+++ b/x/mlxrunner/mlx/stream.go
@@ -0,0 +1,45 @@
+//go:build mlx
+
+package mlx
+
+// #include "generated.h"
+import "C"
+
+import (
+ "log/slog"
+ "sync"
+)
+
+type Device struct {
+ ctx C.mlx_device
+}
+
+func (d Device) LogValue() slog.Value {
+ str := C.mlx_string_new()
+ defer C.mlx_string_free(str)
+ C.mlx_device_tostring(&str, d.ctx)
+ return slog.StringValue(C.GoString(C.mlx_string_data(str)))
+}
+
+var DefaultDevice = sync.OnceValue(func() Device {
+ d := C.mlx_device_new()
+ C.mlx_get_default_device(&d)
+ return Device{d}
+})
+
+type Stream struct {
+ ctx C.mlx_stream
+}
+
+func (s Stream) LogValue() slog.Value {
+ str := C.mlx_string_new()
+ defer C.mlx_string_free(str)
+ C.mlx_stream_tostring(&str, s.ctx)
+ return slog.StringValue(C.GoString(C.mlx_string_data(str)))
+}
+
+var DefaultStream = sync.OnceValue(func() Stream {
+ s := C.mlx_stream_new()
+ C.mlx_get_default_stream(&s, DefaultDevice().ctx)
+ return Stream{s}
+})
diff --git a/x/mlxrunner/model/base/base.go b/x/mlxrunner/model/base/base.go
new file mode 100644
index 00000000000..6d3a25798c8
--- /dev/null
+++ b/x/mlxrunner/model/base/base.go
@@ -0,0 +1,85 @@
+//go:build mlx
+
+package base
+
+import (
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "sync"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/model"
+ "github.com/ollama/ollama/x/tokenizer"
+)
+
+// Model is the interface that model implementations must satisfy.
+type Model interface {
+ Forward(inputs *mlx.Array, cache []cache.Cache) *mlx.Array
+ Unembed(x *mlx.Array) *mlx.Array
+ NumLayers() int
+ Tokenizer() *tokenizer.Tokenizer
+
+ // LoadWeights receives all tensors loaded from the manifest and assigns
+ // them to model fields. Model-specific logic (MLA absorption, expert
+ // stacking, quantized layer creation) happens here.
+ LoadWeights(tensors map[string]*mlx.Array) error
+}
+
+var (
+ mu sync.Mutex
+ registry = make(map[string]func(root *model.Root) (Model, error))
+)
+
+// Register registers a model constructor by architecture name.
+// Called from init() in model packages. Panics on duplicate registration.
+func Register(arch string, fn func(root *model.Root) (Model, error)) {
+ mu.Lock()
+ defer mu.Unlock()
+
+ if _, exists := registry[arch]; exists {
+ panic(fmt.Sprintf("model architecture %q already registered", arch))
+ }
+ registry[arch] = fn
+}
+
+// New reads config.json from the manifest, detects the architecture, looks up
+// the registered constructor, and calls it to create the model (with config
+// parsed and struct created, but weights not yet loaded).
+func New(root *model.Root) (Model, error) {
+ configData, err := root.Manifest.ReadConfig("config.json")
+ if err != nil {
+ return nil, fmt.Errorf("failed to read config.json: %w", err)
+ }
+
+ var archConfig struct {
+ Architectures []string `json:"architectures"`
+ }
+ if err := json.Unmarshal(configData, &archConfig); err != nil {
+ return nil, fmt.Errorf("failed to parse config.json: %w", err)
+ }
+
+ if len(archConfig.Architectures) == 0 {
+ return nil, fmt.Errorf("no architectures found in config.json")
+ }
+
+ arch := archConfig.Architectures[0]
+ slog.Info("Model architecture", "arch", arch)
+
+ mu.Lock()
+ fn, ok := registry[arch]
+ mu.Unlock()
+
+ if !ok {
+ return nil, fmt.Errorf("unsupported architecture: %s", arch)
+ }
+
+ return fn(root)
+}
+
+// Weights returns the model's LoadWeights method, which encapsulates all
+// weight assignment and post-processing (MLA absorption, expert stacking).
+func Weights(m Model) func(map[string]*mlx.Array) error {
+ return m.LoadWeights
+}
diff --git a/x/mlxrunner/model/base/base_stub.go b/x/mlxrunner/model/base/base_stub.go
new file mode 100644
index 00000000000..318d8f91154
--- /dev/null
+++ b/x/mlxrunner/model/base/base_stub.go
@@ -0,0 +1,3 @@
+//go:build !mlx
+
+package base
diff --git a/x/mlxrunner/model/linear.go b/x/mlxrunner/model/linear.go
new file mode 100644
index 00000000000..fffdbdb2942
--- /dev/null
+++ b/x/mlxrunner/model/linear.go
@@ -0,0 +1,92 @@
+//go:build mlx
+
+package model
+
+import (
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/models/nn"
+)
+
+// LinearFactory builds linear layers using shared tensor maps and quant defaults.
+type LinearFactory struct {
+ tensors map[string]*mlx.Array
+ defaultGroupSize int
+ defaultBits int
+ defaultMode string
+ tensorQuant map[string]*TensorQuantInfo
+}
+
+// NewLinearFactory creates a reusable constructor for model linear layers.
+func NewLinearFactory(
+ tensors map[string]*mlx.Array,
+ defaultGroupSize, defaultBits int,
+ defaultMode string,
+ tensorQuant map[string]*TensorQuantInfo,
+) LinearFactory {
+ return LinearFactory{
+ tensors: tensors,
+ defaultGroupSize: defaultGroupSize,
+ defaultBits: defaultBits,
+ defaultMode: defaultMode,
+ tensorQuant: tensorQuant,
+ }
+}
+
+// Make constructs a linear layer at path.
+func (f LinearFactory) Make(path string) nn.LinearLayer {
+ return MakeLinearLayer(
+ f.tensors,
+ path,
+ f.defaultGroupSize,
+ f.defaultBits,
+ f.defaultMode,
+ f.tensorQuant,
+ )
+}
+
+// MakeLinearLayer constructs a linear layer from a tensor map.
+//
+// For quantized tensors (path.weight + path.weight_scale), it resolves per-tensor
+// quant params via TensorQuant metadata (with shape-based affine fallback).
+// For non-quantized tensors, it returns a standard nn.Linear.
+func MakeLinearLayer(
+ tensors map[string]*mlx.Array,
+ path string,
+ defaultGroupSize, defaultBits int,
+ defaultMode string,
+ tensorQuant map[string]*TensorQuantInfo,
+) nn.LinearLayer {
+ w := tensors[path+".weight"]
+ if w == nil {
+ return nil
+ }
+
+ scales := tensors[path+".weight_scale"]
+ if scales != nil {
+ qbiases := tensors[path+".weight_qbias"]
+ bias := tensors[path+".bias"]
+
+ groupSize, bits, mode := ResolveLinearQuantParams(
+ defaultGroupSize,
+ defaultBits,
+ defaultMode,
+ tensorQuant,
+ path+".weight",
+ w,
+ scales,
+ )
+
+ return &nn.QuantizedLinear{
+ Weight: w,
+ Scales: scales,
+ QBiases: qbiases,
+ Bias: bias,
+ GroupSize: groupSize,
+ Bits: bits,
+ Mode: mode,
+ }
+ }
+
+ bias := tensors[path+".bias"]
+ return nn.NewLinear(w, bias)
+}
diff --git a/x/mlxrunner/model/quant.go b/x/mlxrunner/model/quant.go
new file mode 100644
index 00000000000..3a17ab4856b
--- /dev/null
+++ b/x/mlxrunner/model/quant.go
@@ -0,0 +1,130 @@
+//go:build mlx
+
+package model
+
+import (
+ "strings"
+
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+)
+
+// QuantizationParams returns default groupSize, bits, and mode for a quantization type.
+func QuantizationParams(quantization string) (groupSize, bits int, mode string) {
+ switch strings.ToUpper(quantization) {
+ case "NVFP4":
+ return 16, 4, "nvfp4"
+ case "FP4", "Q4", "INT4":
+ return 32, 4, "affine"
+ case "MXFP8":
+ return 32, 8, "mxfp8"
+ case "FP8", "Q8", "INT8", "":
+ return 64, 8, "affine"
+ default:
+ return 32, 8, "affine"
+ }
+}
+
+// TensorQuantParams resolves quant params for a tensor using per-tensor metadata
+// when available, otherwise falling back to the provided model defaults.
+func TensorQuantParams(
+ defaultGroupSize, defaultBits int,
+ defaultMode string,
+ tensorQuant map[string]*TensorQuantInfo,
+ tensorName string,
+) (groupSize, bits int, mode string, fromTensor bool) {
+ if tensorQuant != nil {
+ if tq := tensorQuant[tensorName]; tq != nil {
+ groupSize, bits, mode = QuantizationParams(tq.QuantType)
+ if tq.GroupSize > 0 {
+ groupSize = tq.GroupSize
+ }
+ return groupSize, bits, mode, true
+ }
+ }
+ return defaultGroupSize, defaultBits, defaultMode, false
+}
+
+// ResolveLinearQuantParams resolves quantization params for a quantized linear
+// tensor, preferring per-tensor metadata and falling back to shape-based
+// inference for affine packed tensors.
+func ResolveLinearQuantParams(
+ defaultGroupSize, defaultBits int,
+ defaultMode string,
+ tensorQuant map[string]*TensorQuantInfo,
+ tensorName string,
+ weight, scales *mlx.Array,
+) (groupSize, bits int, mode string) {
+ groupSize, bits, mode, fromTensor := TensorQuantParams(
+ defaultGroupSize,
+ defaultBits,
+ defaultMode,
+ tensorQuant,
+ tensorName,
+ )
+
+ if mode == "affine" {
+ if inferredGroupSize, inferredBits, ok := InferAffineQuantParamsFromShapes(weight, scales, bits); ok {
+ if !fromTensor || groupSize == 0 || bits == 0 {
+ groupSize = inferredGroupSize
+ bits = inferredBits
+ }
+ }
+ }
+
+ return groupSize, bits, mode
+}
+
+// InferAffineQuantParamsFromShapes infers (groupSize,bits) for affine quantized
+// tensors from packed weight and scale shapes.
+func InferAffineQuantParamsFromShapes(weight, scales *mlx.Array, hintBits int) (groupSize, bits int, ok bool) {
+ if weight == nil || scales == nil {
+ return 0, 0, false
+ }
+
+ weightShape := weight.Dims()
+ scaleShape := scales.Dims()
+ if len(weightShape) == 0 || len(scaleShape) == 0 {
+ return 0, 0, false
+ }
+
+ weightCols := weightShape[len(weightShape)-1]
+ scalesCols := scaleShape[len(scaleShape)-1]
+ if weightCols <= 0 || scalesCols <= 0 {
+ return 0, 0, false
+ }
+
+ groupSize4 := weightCols * 8 / scalesCols
+ groupSize8 := weightCols * 4 / scalesCols
+
+ switch {
+ case groupSize4 == 32:
+ return 32, 4, true
+ case groupSize8 == 64:
+ return 64, 8, true
+ case groupSize4 == 64 && groupSize8 == 32:
+ if hintBits == 8 {
+ return 32, 8, true
+ }
+ if hintBits == 4 {
+ return 64, 4, true
+ }
+ }
+
+ if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
+ return groupSize4, 4, true
+ }
+ if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
+ return groupSize8, 8, true
+ }
+
+ return 0, 0, false
+}
+
+func isCommonGroupSize(v int) bool {
+ switch v {
+ case 16, 32, 64, 128:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/x/mlxrunner/model/root.go b/x/mlxrunner/model/root.go
new file mode 100644
index 00000000000..c912f7f4c28
--- /dev/null
+++ b/x/mlxrunner/model/root.go
@@ -0,0 +1,252 @@
+//go:build mlx
+
+package model
+
+import (
+ "encoding/binary"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "sort"
+ "strconv"
+ "strings"
+
+ "github.com/ollama/ollama/x/imagegen/manifest"
+)
+
+// TensorQuantInfo describes per-tensor quantization metadata.
+type TensorQuantInfo struct {
+ QuantType string
+ GroupSize int
+}
+
+// Root wraps a ModelManifest with pre-scanned quantization metadata.
+type Root struct {
+ Manifest *manifest.ModelManifest
+
+ // Backwards-compatible model-level quant metadata (first tensor blob).
+ quantType string
+ groupSize int
+
+ // Per-tensor quantization metadata.
+ tensorQuant map[string]*TensorQuantInfo
+}
+
+// Open loads a manifest for the given model name and scans tensor blobs for
+// quantization metadata.
+func Open(modelName string) (*Root, error) {
+ m, err := manifest.LoadManifest(modelName)
+ if err != nil {
+ return nil, err
+ }
+
+ root := &Root{
+ Manifest: m,
+ tensorQuant: make(map[string]*TensorQuantInfo),
+ }
+
+ for _, layer := range m.GetTensorLayers("") {
+ blobPath := m.BlobPath(layer.Digest)
+
+ infos, blobQuantType, blobGroupSize, err := readBlobTensorQuantInfo(blobPath)
+ if err != nil {
+ continue
+ }
+
+ for name, info := range infos {
+ root.tensorQuant[name] = info
+ }
+
+ if root.quantType == "" && blobQuantType != "" {
+ root.quantType = strings.ToUpper(blobQuantType)
+ root.groupSize = blobGroupSize
+ if root.groupSize == 0 {
+ root.groupSize = defaultGroupSize(root.quantType)
+ }
+ }
+ }
+
+ return root, nil
+}
+
+// Close is a no-op for now (future: release resources).
+func (r *Root) Close() {}
+
+// QuantType returns the quantization type detected from the first tensor blob metadata.
+func (r *Root) QuantType() string { return r.quantType }
+
+// GroupSize returns the quantization group size detected from the first tensor blob metadata.
+func (r *Root) GroupSize() int { return r.groupSize }
+
+// TensorQuant returns per-tensor quantization metadata if available.
+func (r *Root) TensorQuant(name string) *TensorQuantInfo {
+ if r == nil {
+ return nil
+ }
+ return r.tensorQuant[name]
+}
+
+// AllTensorQuant returns a copy of the per-tensor quantization metadata.
+func (r *Root) AllTensorQuant() map[string]*TensorQuantInfo {
+ out := make(map[string]*TensorQuantInfo, len(r.tensorQuant))
+ for k, v := range r.tensorQuant {
+ if v == nil {
+ continue
+ }
+ copy := *v
+ out[k] = ©
+ }
+ return out
+}
+
+func defaultGroupSize(quantType string) int {
+ groupSize, _, _ := QuantizationParams(quantType)
+ return groupSize
+}
+
+func readBlobTensorQuantInfo(path string) (map[string]*TensorQuantInfo, string, int, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, "", 0, err
+ }
+ defer f.Close()
+
+ var headerSize uint64
+ if err := binary.Read(f, binary.LittleEndian, &headerSize); err != nil {
+ return nil, "", 0, err
+ }
+ if headerSize > 100*1024*1024 {
+ return nil, "", 0, fmt.Errorf("header too large: %d", headerSize)
+ }
+
+ data := make([]byte, headerSize)
+ if _, err := io.ReadFull(f, data); err != nil {
+ return nil, "", 0, err
+ }
+
+ var header map[string]json.RawMessage
+ if err := json.Unmarshal(data, &header); err != nil {
+ return nil, "", 0, err
+ }
+
+ globalQuantType, globalGroupSize := parseGlobalQuantMetadata(header)
+ globalQuantType = strings.ToUpper(globalQuantType)
+
+ mainNames := mainTensorNames(header)
+ infos := make(map[string]*TensorQuantInfo)
+ for _, name := range mainNames {
+ if _, ok := header[name+".scale"]; !ok {
+ continue
+ }
+
+ quantType := globalQuantType
+ groupSize := globalGroupSize
+
+ inferredType, inferredGroup := inferQuantTypeFromShapes(header, name, quantType)
+ if quantType == "" {
+ quantType = inferredType
+ }
+ if groupSize == 0 {
+ groupSize = inferredGroup
+ }
+ if quantType == "" {
+ continue
+ }
+ if groupSize == 0 {
+ groupSize = defaultGroupSize(quantType)
+ }
+
+ infos[name] = &TensorQuantInfo{QuantType: quantType, GroupSize: groupSize}
+ }
+
+ return infos, globalQuantType, globalGroupSize, nil
+}
+
+func parseGlobalQuantMetadata(header map[string]json.RawMessage) (quantType string, groupSize int) {
+ metaRaw, ok := header["__metadata__"]
+ if !ok {
+ return "", 0
+ }
+
+ var meta map[string]string
+ if err := json.Unmarshal(metaRaw, &meta); err != nil {
+ return "", 0
+ }
+
+ quantType = meta["quant_type"]
+ if gs := meta["group_size"]; gs != "" {
+ groupSize, _ = strconv.Atoi(gs)
+ }
+ return quantType, groupSize
+}
+
+func mainTensorNames(header map[string]json.RawMessage) []string {
+ names := make([]string, 0, len(header))
+ for name := range header {
+ if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
+ continue
+ }
+ names = append(names, name)
+ }
+ sort.Strings(names)
+ return names
+}
+
+func inferQuantTypeFromShapes(header map[string]json.RawMessage, tensorName string, hintQuantType string) (string, int) {
+ type tensorShape struct {
+ Shape []int64 `json:"shape"`
+ }
+
+ mainRaw, ok := header[tensorName]
+ if !ok {
+ return "", 0
+ }
+ scaleRaw, ok := header[tensorName+".scale"]
+ if !ok {
+ return "", 0
+ }
+
+ var mainInfo tensorShape
+ if err := json.Unmarshal(mainRaw, &mainInfo); err != nil || len(mainInfo.Shape) == 0 {
+ return "", 0
+ }
+
+ var scaleInfo tensorShape
+ if err := json.Unmarshal(scaleRaw, &scaleInfo); err != nil || len(scaleInfo.Shape) == 0 {
+ return "", 0
+ }
+
+ weightCols := int(mainInfo.Shape[len(mainInfo.Shape)-1])
+ scalesCols := int(scaleInfo.Shape[len(scaleInfo.Shape)-1])
+ if weightCols <= 0 || scalesCols <= 0 {
+ return "", 0
+ }
+
+ groupSize4 := weightCols * 8 / scalesCols
+ groupSize8 := weightCols * 4 / scalesCols
+
+ switch {
+ case groupSize4 == 32:
+ return "INT4", 32
+ case groupSize8 == 64:
+ return "INT8", 64
+ case groupSize4 == 64 && groupSize8 == 32:
+ h := strings.ToUpper(hintQuantType)
+ if strings.Contains(h, "8") {
+ return "INT8", 32
+ }
+ if strings.Contains(h, "4") {
+ return "INT4", 64
+ }
+ }
+
+ if isCommonGroupSize(groupSize4) && !isCommonGroupSize(groupSize8) {
+ return "INT4", groupSize4
+ }
+ if isCommonGroupSize(groupSize8) && !isCommonGroupSize(groupSize4) {
+ return "INT8", groupSize8
+ }
+
+ return "", 0
+}
diff --git a/x/mlxrunner/model/root_stub.go b/x/mlxrunner/model/root_stub.go
new file mode 100644
index 00000000000..3fcda9c25dd
--- /dev/null
+++ b/x/mlxrunner/model/root_stub.go
@@ -0,0 +1,3 @@
+//go:build !mlx
+
+package model
diff --git a/x/mlxrunner/pipeline.go b/x/mlxrunner/pipeline.go
new file mode 100644
index 00000000000..274fc9be636
--- /dev/null
+++ b/x/mlxrunner/pipeline.go
@@ -0,0 +1,129 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ "bytes"
+ "errors"
+ "log/slog"
+ "time"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+)
+
+func (r *Runner) TextGenerationPipeline(request Request) error {
+ if r.Model == nil {
+ return errors.New("model not loaded")
+ }
+
+ enableCompile := true
+ if modelCompile, ok := r.Model.(interface{ EnableCompile() bool }); ok {
+ enableCompile = modelCompile.EnableCompile()
+ }
+ if enableCompile {
+ mlx.EnableCompile()
+ } else {
+ mlx.DisableCompile()
+ }
+
+ inputs := r.Tokenizer.Encode(request.Prompt, true)
+
+ caches, tokens := r.FindNearestCache(inputs)
+ if len(caches) == 0 {
+ if cacheFactory, ok := r.Model.(interface{ NewCaches() []cache.Cache }); ok {
+ caches = cacheFactory.NewCaches()
+ } else {
+ caches = make([]cache.Cache, r.Model.NumLayers())
+ for i := range caches {
+ caches[i] = cache.NewKVCache()
+ }
+ }
+ }
+
+ total, processed := len(tokens), 0
+ slog.Info("Prompt processing progress", "processed", processed, "total", total)
+ for total-processed > 1 {
+ n := min(2<<10, total-processed-1)
+ temp := r.Model.Forward(mlx.FromValues(tokens[processed:processed+n], n).ExpandDims(0), caches)
+ defer mlx.Free(temp)
+ mlx.Eval(func() []*mlx.Array {
+ s := make([]*mlx.Array, 2*len(caches))
+ for i, c := range caches {
+ s[2*i], s[2*i+1] = c.State()
+ }
+ return s
+ }()...)
+ processed += n
+ slog.Info("Prompt processing progress", "processed", processed, "total", total)
+ mlx.ClearCache()
+ }
+
+ step := func(token *mlx.Array) (*mlx.Array, *mlx.Array) {
+ fwd := r.Model.Forward(token.ExpandDims(0), caches)
+ logits := r.Model.Unembed(fwd)
+ logits = logits.Slice(mlx.Slice(), mlx.Slice(logits.Dim(1)-1), mlx.Slice()).Squeeze(1)
+
+ logprobs := logits.Subtract(logits.Logsumexp(true))
+ return request.Sample(logprobs), logprobs
+ }
+
+ sample, logprobs := step(mlx.FromValues(tokens[processed:], total-processed))
+ mlx.AsyncEval(sample, logprobs)
+
+ var b bytes.Buffer
+
+ now := time.Now()
+ final := Response{Done: true, PromptTokens: total, CompletionTokens: request.Options.MaxTokens, DoneReason: 1}
+ outputs := make([]int32, 0, request.Options.MaxTokens)
+ for i := range request.Options.MaxTokens {
+ nextSample, nextLogprobs := step(sample)
+ mlx.AsyncEval(nextSample, nextLogprobs)
+
+ if i == 0 {
+ slog.Info("Prompt processing progress", "processed", total, "total", total)
+ mlx.Eval(sample)
+ final.PromptTokensDuration = time.Since(now)
+ now = time.Now()
+ }
+
+ output := int32(sample.Int())
+ outputs = append(outputs, output)
+
+ if r.Tokenizer.IsEOS(output) {
+ final.Token = int(output)
+ final.DoneReason = 0
+ final.CompletionTokens = i
+ break
+ }
+
+ request.Responses <- Response{
+ Text: r.Decode(output, &b),
+ Token: int(output),
+ }
+
+ mlx.Free(sample, logprobs)
+ if i%256 == 0 {
+ mlx.ClearCache()
+ }
+
+ sample, logprobs = nextSample, nextLogprobs
+ }
+
+ mlx.Free(sample, logprobs)
+ final.CompletionTokensDuration = time.Since(now)
+ request.Responses <- final
+ r.InsertCache(append(inputs, outputs...), caches)
+ return nil
+}
+
+func (r Runner) Decode(sample int32, b *bytes.Buffer) string {
+ token := r.Tokenizer.Decode([]int32{sample})
+
+ if _, err := b.WriteString(token); err != nil {
+ slog.Error("Failed to write token to buffer", "error", err)
+ return ""
+ }
+
+ return flushValidUTF8Prefix(b)
+}
diff --git a/x/mlxrunner/runner.go b/x/mlxrunner/runner.go
new file mode 100644
index 00000000000..0b24fdb3dbf
--- /dev/null
+++ b/x/mlxrunner/runner.go
@@ -0,0 +1,174 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ "context"
+ "log/slog"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ "golang.org/x/sync/errgroup"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/model"
+ "github.com/ollama/ollama/x/mlxrunner/model/base"
+ "github.com/ollama/ollama/x/mlxrunner/sample"
+ "github.com/ollama/ollama/x/tokenizer"
+)
+
+type Request struct {
+ TextCompletionsRequest
+ Responses chan Response
+ Pipeline func(Request) error
+
+ sample.Sampler
+ caches []cache.Cache
+}
+
+type TextCompletionsRequest struct {
+ Prompt string `json:"prompt"`
+ Options struct {
+ Temperature float32 `json:"temperature"`
+ TopP float32 `json:"top_p"`
+ MinP float32 `json:"min_p"`
+ TopK int `json:"top_k"`
+ MaxTokens int `json:"max_tokens"`
+
+ // Deprecated: use MaxTokens instead
+ NumPredict int `json:"num_predict"`
+ } `json:"options"`
+}
+
+type Response struct {
+ Text string `json:"content,omitempty"`
+ Token int `json:"token,omitempty"`
+ Logprobs []float32 `json:"logprobs,omitempty"`
+ Done bool `json:"done,omitempty"`
+ DoneReason int `json:"done_reason,omitempty"`
+
+ PromptTokens int `json:"prompt_eval_count,omitempty"`
+ PromptTokensDuration time.Duration `json:"prompt_eval_duration,omitempty"`
+ CompletionTokens int `json:"eval_count,omitempty"`
+ CompletionTokensDuration time.Duration `json:"eval_duration,omitempty"`
+ TotalTokens int `json:"total_tokens,omitempty"`
+}
+
+type Runner struct {
+ Model base.Model
+ Tokenizer *tokenizer.Tokenizer
+ Requests chan Request
+ CacheEntries map[int32]*CacheEntry
+}
+
+func (r *Runner) Load(modelName string) error {
+ root, err := model.Open(modelName)
+ if err != nil {
+ return err
+ }
+ defer root.Close()
+
+ m, err := base.New(root)
+ if err != nil {
+ return err
+ }
+
+ // Load all tensor blobs from manifest
+ tensors, err := loadTensorsFromManifest(root)
+ if err != nil {
+ return err
+ }
+
+ // Assign weights to model (model-specific logic)
+ loadWeights := base.Weights(m)
+ if err := loadWeights(tensors); err != nil {
+ return err
+ }
+
+ r.Model = m
+ r.Tokenizer = m.Tokenizer()
+ return nil
+}
+
+// loadTensorsFromManifest loads all tensor blobs from the manifest into a
+// flat map, deduplicating by digest and remapping safetensors key suffixes.
+//
+// Uses a two-phase approach: first loads all raw tensors, then remaps
+// .bias → _qbias with complete knowledge of which base names have .scale
+// entries. This avoids a race condition where Go map iteration order could
+// cause .bias to be processed before .scale within the same blob.
+func loadTensorsFromManifest(root *model.Root) (map[string]*mlx.Array, error) {
+ // Phase 1: Load all tensors raw from all blobs
+ rawTensors := make(map[string]*mlx.Array)
+ seen := make(map[string]bool)
+ for _, layer := range root.Manifest.GetTensorLayers("") {
+ if seen[layer.Digest] {
+ continue
+ }
+ seen[layer.Digest] = true
+ blobPath := root.Manifest.BlobPath(layer.Digest)
+ for name, arr := range mlx.Load(blobPath) {
+ rawTensors[name] = arr
+ }
+ }
+
+ // Phase 2: Identify all base names that have .scale tensors and remap them
+ scaleBaseNames := make(map[string]bool)
+ allTensors := make(map[string]*mlx.Array, len(rawTensors))
+ for name, arr := range rawTensors {
+ if strings.HasSuffix(name, ".scale") {
+ baseName := strings.TrimSuffix(name, ".scale")
+ allTensors[baseName+"_scale"] = arr
+ scaleBaseNames[baseName] = true
+ }
+ }
+
+ // Phase 3: Process remaining tensors with complete scale knowledge
+ for name, arr := range rawTensors {
+ if strings.HasSuffix(name, ".scale") {
+ continue // already handled
+ }
+ if strings.HasSuffix(name, ".bias") && !strings.HasSuffix(name, ".weight_qbias") {
+ baseName := strings.TrimSuffix(name, ".bias")
+ if scaleBaseNames[baseName] {
+ allTensors[baseName+"_qbias"] = arr
+ } else {
+ allTensors[name] = arr
+ }
+ } else {
+ allTensors[name] = arr
+ }
+ }
+
+ slog.Info("Loaded tensors from manifest", "count", len(allTensors))
+ return allTensors, nil
+}
+
+func (r *Runner) Run(host, port string, mux http.Handler) error {
+ g, ctx := errgroup.WithContext(context.Background())
+
+ g.Go(func() error {
+ for {
+ select {
+ case <-ctx.Done():
+ return nil
+ case request := <-r.Requests:
+ if err := request.Pipeline(request); err != nil {
+ break
+ }
+
+ close(request.Responses)
+ }
+ }
+ })
+
+ g.Go(func() error {
+ slog.Info("Starting HTTP server", "host", host, "port", port)
+ return http.ListenAndServe(net.JoinHostPort(host, port), mux)
+ })
+
+ return g.Wait()
+}
diff --git a/x/mlxrunner/sample/sample.go b/x/mlxrunner/sample/sample.go
new file mode 100644
index 00000000000..b0656973ff0
--- /dev/null
+++ b/x/mlxrunner/sample/sample.go
@@ -0,0 +1,77 @@
+//go:build mlx
+
+package sample
+
+import (
+ "math"
+
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+)
+
+type Sampler interface {
+ Sample(*mlx.Array) *mlx.Array
+}
+
+func New(temp, top_p, min_p float32, top_k int) Sampler {
+ if temp == 0 {
+ return greedy{}
+ }
+
+ var samplers []Sampler
+ if top_p > 0 && top_p < 1 {
+ samplers = append(samplers, TopP(top_p))
+ }
+
+ if min_p != 0 {
+ samplers = append(samplers, MinP(min_p))
+ }
+
+ if top_k > 0 {
+ samplers = append(samplers, TopK(top_k))
+ }
+
+ samplers = append(samplers, Temperature(temp))
+ return chain(samplers)
+}
+
+type greedy struct{}
+
+func (greedy) Sample(logits *mlx.Array) *mlx.Array {
+ return logits.Argmax(-1, false)
+}
+
+type chain []Sampler
+
+func (c chain) Sample(logits *mlx.Array) *mlx.Array {
+ for _, sampler := range c {
+ logits = sampler.Sample(logits)
+ }
+ return logits
+}
+
+type Temperature float32
+
+func (t Temperature) Sample(logits *mlx.Array) *mlx.Array {
+ return mlx.DivScalar(logits, float32(t)).Categorical(-1)
+}
+
+type TopP float32
+
+func (p TopP) Sample(logprobs *mlx.Array) *mlx.Array {
+ // TODO: implement
+ return logprobs
+}
+
+type MinP float32
+
+func (p MinP) Sample(logprobs *mlx.Array) *mlx.Array {
+ // TODO: implement
+ return logprobs
+}
+
+type TopK int
+
+func (k TopK) Sample(logprobs *mlx.Array) *mlx.Array {
+ mask := logprobs.Negative().ArgpartitionAxis(int(k)-1, -1).Slice(mlx.Slice(), mlx.Slice(int(k), 0))
+ return logprobs.PutAlongAxis(mask, mlx.FromValue(float32(math.Inf(-1))), -1)
+}
diff --git a/x/mlxrunner/server.go b/x/mlxrunner/server.go
new file mode 100644
index 00000000000..ef1e0dd1c90
--- /dev/null
+++ b/x/mlxrunner/server.go
@@ -0,0 +1,182 @@
+//go:build mlx
+
+package mlxrunner
+
+import (
+ "bytes"
+ "cmp"
+ "encoding/json"
+ "flag"
+ "fmt"
+ "io"
+ "log/slog"
+ "net/http"
+ "os"
+ "strconv"
+ "time"
+
+ "github.com/ollama/ollama/envconfig"
+ "github.com/ollama/ollama/logutil"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/sample"
+)
+
+func Execute(args []string) error {
+ slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
+
+ if err := mlx.CheckInit(); err != nil {
+ return fmt.Errorf("MLX not available: %w", err)
+ }
+
+ var (
+ modelName string
+ port int
+ )
+
+ flagSet := flag.NewFlagSet("mlxrunner", flag.ExitOnError)
+ flagSet.StringVar(&modelName, "model", "", "Model name")
+ flagSet.IntVar(&port, "port", 0, "Port to listen on")
+ _ = flagSet.Bool("verbose", false, "Enable debug logging")
+ flagSet.Parse(args)
+
+ runner := Runner{
+ Requests: make(chan Request),
+ CacheEntries: make(map[int32]*CacheEntry),
+ }
+
+ if err := runner.Load(modelName); err != nil {
+ return err
+ }
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("GET /v1/status", func(w http.ResponseWriter, r *http.Request) {
+ if err := json.NewEncoder(w).Encode(map[string]any{
+ "status": 0,
+ "progress": 100,
+ }); err != nil {
+ slog.Error("Failed to encode response", "error", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ })
+
+ mux.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
+ switch r.Method {
+ case "POST":
+ fallthrough
+ case "GET":
+ if err := json.NewEncoder(w).Encode(map[string]any{
+ "Success": true,
+ }); err != nil {
+ slog.Error("Failed to encode response", "error", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ case "DELETE":
+ // TODO: cleanup model and cache
+ }
+ })
+
+ mux.HandleFunc("POST /v1/completions", func(w http.ResponseWriter, r *http.Request) {
+ request := Request{Responses: make(chan Response)}
+
+ if err := json.NewDecoder(r.Body).Decode(&request.TextCompletionsRequest); err != nil {
+ slog.Error("Failed to decode request", "error", err)
+ http.Error(w, "Bad Request", http.StatusBadRequest)
+ return
+ }
+
+ request.Options.MaxTokens = cmp.Or(request.Options.MaxTokens, request.Options.NumPredict)
+ if request.Options.MaxTokens < 1 {
+ request.Options.MaxTokens = 16 << 10
+ }
+
+ request.Pipeline = runner.TextGenerationPipeline
+ request.Sampler = sample.New(
+ request.Options.Temperature,
+ request.Options.TopP,
+ request.Options.MinP,
+ request.Options.TopK,
+ )
+
+ runner.Requests <- request
+
+ w.Header().Set("Content-Type", "application/jsonl")
+ w.WriteHeader(http.StatusOK)
+ enc := json.NewEncoder(w)
+ for response := range request.Responses {
+ if err := enc.Encode(response); err != nil {
+ slog.Error("Failed to encode response", "error", err)
+ return
+ }
+
+ if f, ok := w.(http.Flusher); ok {
+ f.Flush()
+ }
+ }
+ })
+
+ mux.HandleFunc("POST /v1/tokenize", func(w http.ResponseWriter, r *http.Request) {
+ var b bytes.Buffer
+ if _, err := io.Copy(&b, r.Body); err != nil {
+ slog.Error("Failed to read request body", "error", err)
+ http.Error(w, "Bad Request", http.StatusBadRequest)
+ return
+ }
+
+ tokens := runner.Tokenizer.Encode(b.String(), true)
+
+ if err := json.NewEncoder(w).Encode(tokens); err != nil {
+ slog.Error("Failed to encode response", "error", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ return
+ }
+ })
+
+ for source, target := range map[string]string{
+ "GET /health": "/v1/status",
+ "POST /load": "/v1/models",
+ "POST /completion": "/v1/completions",
+ } {
+ mux.Handle(source, http.RedirectHandler(target, http.StatusPermanentRedirect))
+ }
+
+ return runner.Run("127.0.0.1", strconv.Itoa(port), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ recorder := &statusRecorder{ResponseWriter: w, code: http.StatusOK}
+ t := time.Now()
+ mux.ServeHTTP(recorder, r)
+
+ var level slog.Level
+ switch {
+ case recorder.code >= 500:
+ level = slog.LevelError
+ case recorder.code >= 400:
+ level = slog.LevelWarn
+ case recorder.code >= 300:
+ return
+ }
+
+ slog.Log(r.Context(), level, "ServeHTTP", "method", r.Method, "path", r.URL.Path, "took", time.Since(t), "status", recorder.Status())
+ }))
+}
+
+type statusRecorder struct {
+ http.ResponseWriter
+ code int
+}
+
+func (w *statusRecorder) WriteHeader(code int) {
+ w.code = code
+ w.ResponseWriter.WriteHeader(code)
+}
+
+func (w *statusRecorder) Status() string {
+ return strconv.Itoa(w.code) + " " + http.StatusText(w.code)
+}
+
+func (w *statusRecorder) Flush() {
+ if f, ok := w.ResponseWriter.(http.Flusher); ok {
+ f.Flush()
+ }
+}
diff --git a/x/mlxrunner/server_stub.go b/x/mlxrunner/server_stub.go
new file mode 100644
index 00000000000..3b0f35500ce
--- /dev/null
+++ b/x/mlxrunner/server_stub.go
@@ -0,0 +1,10 @@
+//go:build !mlx
+
+package mlxrunner
+
+import "errors"
+
+// Execute returns an error when not built with MLX support.
+func Execute(args []string) error {
+ return errors.New("MLX runner not available: build with mlx tag")
+}
diff --git a/x/mlxrunner/utf8_buffer.go b/x/mlxrunner/utf8_buffer.go
new file mode 100644
index 00000000000..5d155b47877
--- /dev/null
+++ b/x/mlxrunner/utf8_buffer.go
@@ -0,0 +1,47 @@
+package mlxrunner
+
+import (
+ "bytes"
+ "unicode/utf8"
+)
+
+// flushValidUTF8Prefix returns and consumes the longest valid UTF-8 prefix
+// currently buffered, leaving any incomplete trailing bytes in place.
+func flushValidUTF8Prefix(b *bytes.Buffer) string {
+ data := b.Bytes()
+ if len(data) == 0 {
+ return ""
+ }
+
+ prefix := validUTF8PrefixLen(data)
+ if prefix == 0 {
+ return ""
+ }
+
+ text := string(data[:prefix])
+ b.Next(prefix)
+ return text
+}
+
+func validUTF8PrefixLen(data []byte) int {
+ i := 0
+ prefix := 0
+ for i < len(data) {
+ r, size := utf8.DecodeRune(data[i:])
+ if r == utf8.RuneError && size == 1 {
+ if !utf8.FullRune(data[i:]) {
+ break
+ }
+
+ // Invalid UTF-8 byte; consume one byte to guarantee forward progress.
+ i++
+ prefix = i
+ continue
+ }
+
+ i += size
+ prefix = i
+ }
+
+ return prefix
+}
diff --git a/x/mlxrunner/utf8_buffer_test.go b/x/mlxrunner/utf8_buffer_test.go
new file mode 100644
index 00000000000..aaaf77b6311
--- /dev/null
+++ b/x/mlxrunner/utf8_buffer_test.go
@@ -0,0 +1,46 @@
+package mlxrunner
+
+import (
+ "bytes"
+ "testing"
+)
+
+func TestFlushValidUTF8Prefix_PreservesIncompleteRune(t *testing.T) {
+ var b bytes.Buffer
+
+ b.Write([]byte{0xE3, 0x81})
+ if got := flushValidUTF8Prefix(&b); got != "" {
+ t.Fatalf("first flush = %q, want empty", got)
+ }
+
+ b.Write([]byte{0x93, 0xE3})
+ if got := flushValidUTF8Prefix(&b); got != "こ" {
+ t.Fatalf("second flush = %q, want %q", got, "こ")
+ }
+
+ if got := b.Bytes(); !bytes.Equal(got, []byte{0xE3}) {
+ t.Fatalf("buffer after second flush = %v, want %v", got, []byte{0xE3})
+ }
+
+ b.Write([]byte{0x82, 0x93})
+ if got := flushValidUTF8Prefix(&b); got != "ん" {
+ t.Fatalf("third flush = %q, want %q", got, "ん")
+ }
+
+ if b.Len() != 0 {
+ t.Fatalf("buffer not empty after third flush: %d", b.Len())
+ }
+}
+
+func TestFlushValidUTF8Prefix_ValidText(t *testing.T) {
+ var b bytes.Buffer
+ b.WriteString("hello 世界")
+
+ if got := flushValidUTF8Prefix(&b); got != "hello 世界" {
+ t.Fatalf("flush = %q, want %q", got, "hello 世界")
+ }
+
+ if b.Len() != 0 {
+ t.Fatalf("buffer not empty after flush: %d", b.Len())
+ }
+}
diff --git a/x/model/bytepairencoding_test.go b/x/model/bytepairencoding_test.go
deleted file mode 100644
index 2a7041284a2..00000000000
--- a/x/model/bytepairencoding_test.go
+++ /dev/null
@@ -1,322 +0,0 @@
-package model
-
-import (
- "bufio"
- "encoding/json"
- "math"
- "os"
- "path/filepath"
- "slices"
- "strconv"
- "strings"
- "testing"
-
- "github.com/google/go-cmp/cmp"
-)
-
-func llama(t testing.TB) BytePairEncoding {
- t.Helper()
-
- f, err := os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "encoder.json"))
- if err != nil {
- t.Fatal(err)
- }
- defer f.Close()
-
- vocab := make(map[string]int32)
- if err := json.NewDecoder(f).Decode(&vocab); err != nil {
- t.Fatal(err)
- }
-
- types := make([]int32, len(vocab))
- tokens := make([]string, len(vocab))
- for token, id := range vocab {
- tokens[id] = token
- types[id] = 1
- }
-
- for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
- if _, ok := vocab[token]; !ok {
- tokens = append(tokens, token) //nolint:makezero
- types = append(types, 3) //nolint:makezero
- vocab[token] = int32(len(vocab))
- }
- }
-
- f, err = os.Open(filepath.Join("..", "..", "model", "testdata", "llama3.2", "vocab.bpe"))
- if err != nil {
- t.Fatal(err)
- }
- defer f.Close()
-
- merges := make([]string, 0, 50000)
-
- scanner := bufio.NewScanner(f)
- for scanner.Scan() {
- if !strings.HasPrefix(scanner.Text(), "#") {
- merges = append(merges, scanner.Text())
- }
- }
-
- return NewBytePairEncoding(
- &Vocabulary{
- Values: tokens,
- Types: types,
- Merges: merges,
- },
- "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
- )
-}
-
-func TestLlama(t *testing.T) {
- tokenizer := llama(t)
-
- t.Run("simple", func(t *testing.T) {
- t.Parallel()
-
- ids, err := tokenizer.Encode("hello world", true)
- if err != nil {
- t.Error(err)
- }
-
- if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
- t.Errorf("no match (-theirs +ours):\n%s", diff)
- }
-
- s, err := tokenizer.Decode([]int32{15339, 1917})
- if err != nil {
- t.Fatal(err)
- }
-
- if s != "hello world" {
- t.Errorf("got %q, want hello world", s)
- }
-
- ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
- if err != nil {
- t.Error(err)
- }
-
- if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
- t.Errorf("no match (-theirs +ours):\n%s", diff)
- }
- })
-
- t.Run("simple repeated", func(t *testing.T) {
- t.Parallel()
-
- cases := map[string][]int32{
- strings.Repeat("0", 1): {15},
- strings.Repeat("0", 2): {410},
- strings.Repeat("0", 3): {931},
- strings.Repeat("0", 4): {931, 15},
- strings.Repeat("0", 5): {931, 410},
- strings.Repeat("0", 6): {931, 931},
- strings.Repeat("0", 7): {931, 931, 15},
- strings.Repeat("0", 8): {931, 931, 410},
- strings.Repeat("0", 9): {931, 931, 931},
- strings.Repeat("0", 10): {931, 931, 931, 15},
- strings.Repeat("0", 11): {931, 931, 931, 410},
- strings.Repeat("0", 12): {931, 931, 931, 931},
- strings.Repeat("0", 13): {931, 931, 931, 931, 15},
- strings.Repeat("0", 14): {931, 931, 931, 931, 410},
- strings.Repeat("0", 15): {931, 931, 931, 931, 931},
- strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
- strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
- }
-
- for s, want := range cases {
- ids, err := tokenizer.Encode(s, true)
- if err != nil {
- t.Error(err)
- }
-
- if diff := cmp.Diff(want, ids); diff != "" {
- t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
- }
- }
- })
-
- t.Run("basic roundtrip", func(t *testing.T) {
- t.Parallel()
-
- cases := []string{
- "hello",
- "hello ",
- "hello ",
- " hello",
- " hello ",
- " hello ",
- "hello world",
- "请考试我的软件!12345",
- }
-
- for _, want := range cases {
- ids, err := tokenizer.Encode(want, true)
- if err != nil {
- t.Error(err)
- }
-
- if got, err := tokenizer.Decode(ids); err != nil {
- t.Fatal(err)
- } else if got != want {
- t.Errorf("got %q, want %q", got, want)
- }
- }
- })
-
- t.Run("special", func(t *testing.T) {
- t.Parallel()
-
- cases := map[string][]int32{
- "<|begin_of_text|>A B!": {128000, 32, 426, 0},
- "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
- "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
- "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
- }
-
- for s, want := range cases {
- ids, err := tokenizer.Encode(s, true)
- if err != nil {
- t.Fatal(err)
- }
-
- if diff := cmp.Diff(want, ids); diff != "" {
- t.Errorf("no match (-theirs +ours):\n%s", diff)
- }
- }
- })
-
- t.Run("split", func(t *testing.T) {
- t.Parallel()
-
- cases := map[string][]string{
- "Hello World!": {"Hello", " World", "!"},
- "I'm don't won't": {"I", "'m", " don", "'t", " won", "'t"},
- "In 2024 there are 366 days": {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
- "Hello!! ...world": {"Hello", "!!", " ...", "world"},
- "Hello World": {"Hello", " ", " World"},
- "Hello\nWorld": {"Hello", "\n", "World"},
- "Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
- }
-
- for s, want := range cases {
- got := slices.Collect(tokenizer.split(s))
- if diff := cmp.Diff(want, got); diff != "" {
- t.Errorf("no match (-theirs +ours):\n%s", diff)
- }
- }
- })
-
- t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
- t.Parallel()
-
- for b := 0x00; b <= 0xFF; b++ {
- input := string(rune(b))
- ids, err := tokenizer.Encode(input, false)
- if err != nil {
- t.Errorf("failed to encode rune 0x%02X: %v", b, err)
- continue
- }
-
- decoded, err := tokenizer.Decode(ids)
- if err != nil {
- t.Errorf("failed to decode rune 0x%02X: %v", b, err)
- continue
- }
-
- if b == 0x00 {
- if len(decoded) != 0 {
- t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
- }
- continue
- }
-
- if decoded != input {
- t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
- }
- }
- })
-}
-
-func BenchmarkBytePairEncoding(b *testing.B) {
- tokenizer := llama(b)
- bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
- if err != nil {
- b.Fatal(err)
- }
-
- for i := range 8 {
- n := min(int(math.Pow10(i)), len(bts))
- bts := bts[:n]
- b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
- b.ResetTimer()
- for b.Loop() {
- _, err := tokenizer.Encode(string(bts), true)
- if err != nil {
- b.Fatal(err)
- }
- }
- })
-
- b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
- ids, err := tokenizer.Encode(string(bts), true)
- if err != nil {
- b.Fatal(err)
- }
-
- b.ResetTimer()
- for b.Loop() {
- _, err := tokenizer.Decode(ids)
- if err != nil {
- b.Fatal(err)
- }
- }
- })
-
- b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
- b.ResetTimer()
- for b.Loop() {
- slices.Collect(tokenizer.split(string(bts)))
- }
- })
- }
-}
-
-func TestSplit(t *testing.T) {
- cases := []struct {
- name string
- patterns,
- want []string
- }{
- {
- name: "default",
- want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " 123", " 一二三"},
- },
- {
- name: "unicode",
- patterns: []string{
- "\\p{N}{1,3}",
- `[一-龥-ゟ゠-ヿ]+`,
- "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+",
- },
- want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "123", " ", "一二三"},
- },
- {
- name: "individual digits",
- patterns: []string{
- "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
- },
- want: []string{"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?", " ", "1", "2", "3", " 一二三"},
- },
- }
-
- for _, tt := range cases {
- t.Run(tt.name, func(t *testing.T) {
- tokenizer := NewBytePairEncoding(nil, tt.patterns...)
- if diff := cmp.Diff(tt.want, slices.Collect(tokenizer.split("Hello, WORLD!! How's it going? 123 一二三"))); diff != "" {
- t.Errorf("no match (-theirs +ours):\n%s", diff)
- }
- })
- }
-}
diff --git a/x/model/input/input.go b/x/model/input/input.go
deleted file mode 100644
index 05857e20a95..00000000000
--- a/x/model/input/input.go
+++ /dev/null
@@ -1,76 +0,0 @@
-package input
-
-import "github.com/ollama/ollama/x/ml"
-
-// Multimodal is a multimodal embedding or a component of one.
-// For example, it could be a row of an image that can be processed
-// independently.
-type Multimodal struct {
- // Tensor is the embedding data. Implementations may chose what to
- // store here or it may be nil if not needed. However, any ml.Tensor
- // objects must be stored here and not in Data.
- Tensor ml.Tensor
-
- // Data is implementation-specific opaque data, such as metadata on how
- // to layout Tensor. It may be nil if not needed. It may also store larger
- // objects such as complete images if they are to be processed later.
- Data any
-}
-
-// Input represents one token in the input stream
-type Input struct {
- // Token is a single element of text.
- Token int32
-
- // Multimodal is represents a non-text element such as an
- // image (or part of one if the image can be processed in pieces).
- // It may be used either together with Token or on its own.
- Multimodal []Multimodal
-
- // MultimodalHash is a unique representation of the data
- // stored in Multimodal, used for caching and comparing
- // equality.
- MultimodalHash uint64
-
- // SameBatch forces the following number of tokens to be processed
- // in a single batch, breaking and extending batches as needed.
- // Useful for things like images that must be processed in one
- // shot.
- SameBatch int
-}
-
-// MultimodalIndex is a multimodal element (such as an image)
-// together with an index into the slice of Inputs with the
-// corresponding token. Note that the index is not the same
-// as the position - to find that use the index with the
-// Positions slice.
-type MultimodalIndex struct {
- Index int
- Multimodal []Multimodal
-}
-
-// Batch contains the inputs for a model forward pass
-type Batch struct {
- // Inputs is the input tokens, including placeholders for multimodal inputs.
- Inputs ml.Tensor
-
- // Outputs are the set of indicies into Inputs for which output data should
- // be returned.
- Outputs ml.Tensor
-
- // TODO maybe not the optimal way to handle this
- // Offset of final tensor in the final batch
- Offset int
-
- // Positions is the position for each Input, relative to its sequence. Equal
- // in length to Inputs.
- Positions []int32
-
- // Sequences is the sequence for each Input. Equal in length to Inputs.
- Sequences []int
-
- // Multimodal is a set of multimodal embeddings previously created by
- // EncodeMultimodal, along with an index into Inputs. Unused for text-only
- // models or for batches without multimodal elements.
- Multimodal []MultimodalIndex
-}
diff --git a/x/model/model.go b/x/model/model.go
deleted file mode 100644
index 60c3d1487ed..00000000000
--- a/x/model/model.go
+++ /dev/null
@@ -1,333 +0,0 @@
-package model
-
-import (
- "errors"
- "fmt"
- _ "image/jpeg"
- _ "image/png"
- "log/slog"
- "os"
- "reflect"
- "strconv"
- "strings"
-
- _ "golang.org/x/image/bmp"
- _ "golang.org/x/image/tiff"
- _ "golang.org/x/image/webp"
-
- "github.com/ollama/ollama/fs"
- fsggml "github.com/ollama/ollama/fs/ggml"
- "github.com/ollama/ollama/logutil"
- "github.com/ollama/ollama/x/kvcache"
- "github.com/ollama/ollama/x/ml"
- _ "github.com/ollama/ollama/x/ml/backend"
- "github.com/ollama/ollama/x/ml/nn/pooling"
- "github.com/ollama/ollama/x/model/input"
-)
-
-var (
- ErrNoVisionModel = errors.New("this model is missing data required for image input")
- ErrUnsupportedModel = errors.New("model not supported")
- ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
-)
-
-// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
-type Model interface {
- Forward(ml.Context, input.Batch) (ml.Tensor, error)
-
- Backend() ml.Backend
- Config() config
-}
-
-// MultimodalProcessor must be implemented by multimodal models.
-type MultimodalProcessor interface {
- // EncodeMultimodal processes a single input (such as an image) and
- // generates an output (typically an embedding) that can be used by the model.
- //
- // The return value is one or more tensors, each with optional model-specific
- // opaque metadata. Typically, the tensors might be views into an embedding
- // with each view representing a chunk of data that can be processed independently
- // in different batches.
- //
- // The result may be cached by the runner.
- EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
-
- // PostTokenize is called after tokenization to allow the model to edit the
- // input stream to correctly arrange multimodal elements.
- //
- // The input is a slice of tokens with the results of EncodeMultimodal interleaved
- // in the order that the user provided them. Each element of the slice will be
- // either a single token or single multimodal object.
- //
- // The model must ensure that inputs are stored according to how they will be
- // processed and stored in the cache. For example, Llava-style models should insert
- // placeholder tokens equal to the feature size of the corresponding image with
- // the image itself attached to and split across these tokens. When Forward is called
- // a partial subset of these tokens may be submitted according to the batch size.
- //
- // This function is also responsible for updating MultimodalHash for any Multimodal
- // that is modified to ensure that there is a unique hash value that accurately
- // represents the contents.
- PostTokenize([]*input.Input) ([]*input.Input, error)
-}
-
-// Base implements the common fields and methods for all models
-type Base struct {
- b ml.Backend
- config
-}
-
-type config struct {
- Cache kvcache.Cache
-}
-
-// Backend returns the underlying backend that will run the model
-func (m *Base) Backend() ml.Backend {
- return m.b
-}
-
-func (m *Base) Config() config {
- return m.config
-}
-
-var models = make(map[string]func(fs.Config) (Model, error))
-
-// Register registers a model constructor for the given architecture
-func Register(name string, f func(fs.Config) (Model, error)) {
- if _, ok := models[name]; ok {
- panic("model: model already registered")
- }
-
- models[name] = f
-}
-
-// New initializes a new model instance with the provided configuration based on the metadata in the model file
-func New(modelPath string, params ml.BackendParams) (Model, error) {
- b, err := ml.NewBackend(modelPath, params)
- if err != nil {
- return nil, err
- }
-
- m, err := modelForArch(b.Config())
- if err != nil {
- return nil, err
- }
-
- base := Base{b: b, config: m.Config()}
- v := reflect.ValueOf(m)
- v.Elem().Set(populateFields(base, v.Elem()))
- return m, nil
-}
-
-func NewTextProcessor(s string) (TextProcessor, error) {
- r, err := os.Open(s)
- if err != nil {
- return nil, err
- }
- defer r.Close()
-
- meta, err := fsggml.Decode(r, -1)
- if err != nil {
- return nil, err
- }
-
- m, err := modelForArch(meta.KV())
- if err != nil {
- return nil, err
- }
-
- tp, ok := m.(TextProcessor)
- if !ok {
- return nil, ErrUnsupportedTokenizer
- }
- return tp, nil
-}
-
-func modelForArch(c fs.Config) (Model, error) {
- arch := c.Architecture()
- if pooling.Type(c.Uint("pooling_type")) != pooling.TypeNone {
- arch = arch + "_embed"
- }
-
- f, ok := models[arch]
- if !ok {
- return nil, ErrUnsupportedModel
- }
-
- return f(c)
-}
-
-func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
- t := v.Type()
-
- if t.Kind() == reflect.Struct {
- allNil := true
- for i := range t.NumField() {
- tt := t.Field(i).Type
- vv := v.Field(i)
- if !vv.CanSet() {
- continue
- }
-
- // make a copy
- tagsCopy := tags
- if tag := t.Field(i).Tag.Get("gguf"); tag != "" {
- tagsCopy = append(tagsCopy, parseTag(tag))
- }
-
- if tt == reflect.TypeOf((*Base)(nil)).Elem() {
- vv.Set(reflect.ValueOf(base))
- } else if tt == reflect.TypeOf((*ml.Tensor)(nil)).Elem() {
- var fn func([]Tag, string, string) [][]string
- fn = func(tags []Tag, prefix, suffix string) (fullNames [][]string) {
- if len(tags) > 0 {
- var names []string
- if tags[0].name != "" {
- for _, n := range append([]string{tags[0].name}, tags[0].alternatives...) {
- names = append(names, prefix+n+suffix)
- }
- }
- childNames := fn(tags[1:], tags[0].prefix, tags[0].suffix)
- if len(names) == 0 {
- // current tag has no name, use child names only
- fullNames = append(fullNames, childNames...)
- } else if len(childNames) == 0 {
- // current tag has names but no children, create branches for each name
- for _, name := range names {
- fullNames = append(fullNames, []string{name})
- }
- } else {
- // merge each name with each child
- for _, name := range names {
- for _, childName := range childNames {
- fullNames = append(fullNames, append([]string{name}, childName...))
- }
- }
- }
- }
-
- return fullNames
- }
-
- names := fn(tagsCopy, "", "")
- for _, name := range names {
- if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
- logutil.Trace("found tensor", "", tensor)
- vv.Set(reflect.ValueOf(tensor))
- break
- }
- }
- } else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
- setPointer(base, vv, tagsCopy)
- } else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
- for i := range vv.Len() {
- vvv := vv.Index(i)
- if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
- setPointer(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)}))
- } else {
- vvv.Set(populateFields(base, vvv, append(tagsCopy, Tag{name: strconv.Itoa(i)})...))
- }
- }
- }
-
- if !canNil(tt) || !vv.IsNil() {
- allNil = false
- }
- }
-
- if allNil {
- return reflect.Zero(t)
- }
- }
-
- return v
-}
-
-func setPointer(base Base, v reflect.Value, tags []Tag) {
- vv := v
- if v.Kind() == reflect.Interface {
- if v.IsNil() {
- return
- }
-
- vv = vv.Elem()
- }
-
- vv = reflect.Indirect(vv)
- if v.IsNil() {
- vv = reflect.New(v.Type().Elem()).Elem()
- }
-
- if f := populateFields(base, vv, tags...); f.CanAddr() {
- v.Set(f.Addr())
- }
-}
-
-type Tag struct {
- name,
- // prefix and suffix are applied to child tags
- prefix,
- suffix string
- alternatives []string
-}
-
-func parseTag(s string) (tag Tag) {
- parts := strings.Split(s, ",")
- if len(parts) > 0 {
- tag.name = parts[0]
-
- for _, part := range parts[1:] {
- if value, ok := strings.CutPrefix(part, "alt:"); ok && tag.name == "" {
- // elevate alternative to primary if no primary given
- tag.name = value
- slog.Warn("gguf tag has alt: but no primary name", "tag", s)
- } else if ok {
- tag.alternatives = append(tag.alternatives, value)
- }
- if value, ok := strings.CutPrefix(part, "pre:"); ok {
- tag.prefix = value
- }
- if value, ok := strings.CutPrefix(part, "suf:"); ok {
- tag.suffix = value
- }
- }
- }
-
- return
-}
-
-func canNil(t reflect.Type) bool {
- return t.Kind() == reflect.Chan ||
- t.Kind() == reflect.Func ||
- t.Kind() == reflect.Interface ||
- t.Kind() == reflect.Map ||
- t.Kind() == reflect.Pointer ||
- t.Kind() == reflect.Slice
-}
-
-func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
- if len(batch.Positions) != len(batch.Sequences) {
- return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
- }
-
- if len(batch.Positions) < 1 {
- return nil, errors.New("batch size cannot be less than 1")
- }
-
- cache := m.Config().Cache
- if cache != nil {
- err := cache.StartForward(ctx, batch, false)
- if err != nil {
- return nil, err
- }
- }
-
- t, err := m.Forward(ctx, batch)
- if err != nil {
- return nil, err
- }
-
- ctx.Forward(t)
-
- return t, nil
-}
diff --git a/x/model/models/gemma3/embed.go b/x/model/models/gemma3/embed.go
deleted file mode 100644
index 229cbcb50cc..00000000000
--- a/x/model/models/gemma3/embed.go
+++ /dev/null
@@ -1,58 +0,0 @@
-//go:build mlx
-
-package gemma3
-
-import (
- "github.com/ollama/ollama/fs"
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/ml/nn"
- "github.com/ollama/ollama/x/ml/nn/pooling"
- "github.com/ollama/ollama/x/model"
- "github.com/ollama/ollama/x/model/input"
-)
-
-type embedModel struct {
- model.Base
- model.SentencePiece
-
- *TextModel
- poolingType pooling.Type
-
- Dense [2]*nn.Linear `gguf:"dense"`
-}
-
-func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
- hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
- for _, dense := range m.Dense {
- hiddenStates = dense.Forward(ctx, hiddenStates)
- }
- hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
- return hiddenStates, nil
-}
-
-func newEmbedModel(c fs.Config) (model.Model, error) {
- m := &embedModel{
- SentencePiece: model.NewSentencePiece(
- &model.Vocabulary{
- Values: c.Strings("tokenizer.ggml.tokens"),
- Scores: c.Floats("tokenizer.ggml.scores"),
- Types: c.Ints("tokenizer.ggml.token_type"),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
- BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
- AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
- EOS: append(
- []int32{
- int32(c.Uint("tokenizer.ggml.eos_token_id")),
- int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
- },
- c.Ints("tokenizer.ggml.eos_token_ids")...,
- ),
- },
- ),
- TextModel: newTextModel(c),
- poolingType: pooling.Type(c.Uint("pooling_type", 0)),
- }
-
- return m, nil
-}
diff --git a/x/model/models/gemma3/model.go b/x/model/models/gemma3/model.go
deleted file mode 100644
index 23f78f20740..00000000000
--- a/x/model/models/gemma3/model.go
+++ /dev/null
@@ -1,157 +0,0 @@
-//go:build mlx
-
-package gemma3
-
-import (
- "bytes"
- "image"
- "math"
- "slices"
-
- "github.com/ollama/ollama/fs"
- "github.com/ollama/ollama/x/kvcache"
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/ml/nn"
- "github.com/ollama/ollama/x/model"
- "github.com/ollama/ollama/x/model/input"
-)
-
-type Model struct {
- model.Base
- model.SentencePiece
-
- *VisionModel `gguf:"vision_tower.vision_model"`
- *TextModel `gguf:"language_model.model"`
-
- *MultiModalProjector `gguf:"multi_modal_projector"`
-
- ImageProcessor
-}
-
-var _ model.MultimodalProcessor = (*Model)(nil)
-
-type MultiModalProjector struct {
- SoftEmbNorm *nn.RMSNorm `gguf:"mm_soft_emb_norm"`
- InputProjection *nn.Linear `gguf:"mm_input_projection_weight"` // TODO .weight vs _weight
-
- tokensPerImage int
-}
-
-func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, imageSize, patchSize int, eps float32) ml.Tensor {
- l := visionOutputs.Dim(0)
-
- visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
- patchesPerImage := imageSize / patchSize
- visionOutputs = visionOutputs.Reshape(ctx, patchesPerImage, patchesPerImage, l)
-
- kernelSize := patchesPerImage / int(math.Sqrt(float64(p.tokensPerImage)))
- visionOutputs = visionOutputs.AvgPool2D(ctx, kernelSize, kernelSize, 0)
- visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0)*visionOutputs.Dim(1), l)
- visionOutputs = visionOutputs.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
- visionOutputs = p.SoftEmbNorm.Forward(ctx, visionOutputs, eps)
-
- // TODO: inputProjection must be transposed since they're incompatible with visionOutputs
- visionOutputs = visionOutputs.Matmul(ctx, p.InputProjection.Weight.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false))
- return visionOutputs
-}
-
-func New(c fs.Config) (model.Model, error) {
- // slog.Info("XXX Config", "c", c)
- m := Model{
- SentencePiece: model.NewSentencePiece(
- &model.Vocabulary{
- Values: c.Strings("tokenizer.ggml.tokens"),
- Scores: c.Floats("tokenizer.ggml.scores"),
- Types: c.Ints("tokenizer.ggml.token_type"),
- AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
- BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
- AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
- EOS: append(
- []int32{
- int32(c.Uint("tokenizer.ggml.eos_token_id")),
- int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
- },
- c.Ints("tokenizer.ggml.eos_token_ids")...,
- ),
- },
- ),
- ImageProcessor: newImageProcessor(c),
- VisionModel: newVisionModel(c),
- TextModel: newTextModel(c),
- MultiModalProjector: &MultiModalProjector{
- tokensPerImage: int(c.Uint("mm_tokens_per_image", 256)),
- },
- }
-
- // slidingWindowLen := int32(c.Uint("attention.sliding_window"))
- // m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(slidingWindowLen, m.Shift), kvcache.NewCausalCache(m.Shift))
-
- // TODO need to implement sliding window...
- m.Cache = kvcache.NewMLXCausalCache()
-
- return &m, nil
-}
-
-func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
- if len(m.VisionModel.Layers) == 0 {
- return nil, model.ErrNoVisionModel
- }
-
- image, _, err := image.Decode(bytes.NewReader(multimodalData))
- if err != nil {
- return nil, err
- }
-
- f32s, err := m.ImageProcessor.ProcessImage(image)
- if err != nil {
- return nil, err
- }
-
- pixelValues := ctx.Input().FromFloats(f32s,
- m.ImageProcessor.imageSize,
- m.ImageProcessor.imageSize,
- m.ImageProcessor.numChannels,
- )
-
- visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
- visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
- return []input.Multimodal{{Tensor: visionOutputs}}, nil
-}
-
-func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
- var result []*input.Input
-
- for _, inp := range inputs {
- if len(inp.Multimodal) == 0 {
- result = append(result, inp)
- } else {
- inputMultimodal := inp.Multimodal[0].Tensor
-
- result = append(result,
- &input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
- &input.Input{Token: 255999}, // """
- &input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
- )
-
- // add image token placeholders
- result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
-
- result = append(result,
- &input.Input{Token: 256000}, //
- &input.Input{Token: 108}, // "\n\n"
- )
- }
- }
-
- return result, nil
-}
-
-func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
- hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
- return m.Output.Forward(ctx, hiddenStates), nil
-}
-
-func init() {
- model.Register("gemma3", New)
- model.Register("gemma3_embed", newEmbedModel)
-}
diff --git a/x/model/models/gemma3/model_text.go b/x/model/models/gemma3/model_text.go
deleted file mode 100644
index d7686542a46..00000000000
--- a/x/model/models/gemma3/model_text.go
+++ /dev/null
@@ -1,211 +0,0 @@
-//go:build mlx
-
-package gemma3
-
-import (
- "math"
-
- "github.com/ollama/ollama/fs"
- "github.com/ollama/ollama/x/kvcache"
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/ml/nn"
- "github.com/ollama/ollama/x/model/input"
-)
-
-type TextConfig struct {
- hiddenSize, numHeads, numKVHeads int
- attnKeyLen int
- eps, ropeScale float32
- ropeLocalBase, ropeGlobalBase float32
- largeModelScaling bool
-}
-
-type TextModel struct {
- TokenEmbedding *nn.Embedding `gguf:"embed_tokens"`
- Layers []TextLayer `gguf:"layers"`
- OutputNorm *nn.RMSNorm `gguf:"norm"`
- Output *nn.Linear `gguf:"embed_tokens"`
-
- *TextConfig
-}
-
-const (
- gemmaGlobalCacheCount = 6
- gemma27BLayerCount = 62
-)
-
-// const (
-// cacheTypeSWA = iota
-// cacheTypeCausal
-// )
-
-func newTextModel(c fs.Config) *TextModel {
- numBlocks := int(c.Uint("block_count"))
-
- m := TextModel{
- Layers: make([]TextLayer, numBlocks),
- TextConfig: &TextConfig{
- hiddenSize: int(c.Uint("embedding_length")), // 2560 -- config.json: text_config.hidden_size
- numHeads: int(c.Uint("attention.head_count")), // 8 -- hard coded in python implementation for the model, 4 in some places, then overridden as 8
- numKVHeads: int(c.Uint("attention.head_count_kv")), // 4 -- same as above
- attnKeyLen: int(c.Uint("attention.key_length", 256)), //256 -- rope settings, hardcoded in model definition python
- eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), // 1e-06 - hardcoded in model definition python
- ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), // 10000 - hardcoded in python
- ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), // 1e+06 - hardcoded in python
- ropeScale: 1, // 1 - default is 1, implied in python code
- // vocabSize: vocabSize, // 262144
- // attnValLen: int(c.Uint("attention.value_length", 256)), //256
- // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
- // (8 instead of 1)
- // ropeScale: c.Float("rope.scaling.factor", 1.0),
- },
- }
- if numBlocks == gemma27BLayerCount {
- m.largeModelScaling = true
- }
-
- return &m
-}
-
-type TextSelfAttention struct {
- Query *nn.Linear `gguf:"q_proj"`
- QueryNorm *nn.RMSNorm `gguf:"q_norm"`
- Key *nn.Linear `gguf:"k_proj"`
- KeyNorm *nn.RMSNorm `gguf:"k_norm"`
- Value *nn.Linear `gguf:"v_proj"`
- Output *nn.Linear `gguf:"o_proj"`
-}
-
-func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
- B := hiddenState.Dim(0)
- L := hiddenState.Dim(1)
- ropeBase := opts.ropeLocalBase
- if (layer+1)%gemmaGlobalCacheCount == 0 {
- ropeBase = opts.ropeGlobalBase
- }
-
- q := sa.Query.Forward(ctx, hiddenState)
- k := sa.Key.Forward(ctx, hiddenState)
- v := sa.Value.Forward(ctx, hiddenState)
- q = q.Reshape(ctx, B, L, opts.numHeads, -1).Transpose(ctx, 0, 2, 1, 3)
- k = k.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3)
- v = v.Reshape(ctx, B, L, opts.numKVHeads, -1).Transpose(ctx, 0, 2, 1, 3).Contiguous(ctx, false)
- q = sa.QueryNorm.Forward(ctx, q, opts.eps)
- k = sa.KeyNorm.Forward(ctx, k, opts.eps)
- traditional := false
- q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
- k = k.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
-
- // TODO - this is wrong somehow so commenting out
- // if opts.largeModelScaling {
- // q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
- // } else {
- // q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.attnKeyLen)))
- // }
-
- scaleFactor := math.Pow(256, -0.5)
-
- kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
- kqv = kqv.Transpose(ctx, 0, 2, 1, 3).Reshape(ctx, B, L, -1)
- return sa.Output.Forward(ctx, kqv)
-}
-
-func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
- // ropeBase := m.TextConfig.ropeLocalBase
- // if (layer+1)%gemmaGlobalCacheCount == 0 {
- // ropeBase = m.TextConfig.ropeGlobalBase
- // }
- // q = q.RoPE(ctx, opts.attnKeyLen, traditional, opts.ropeScale, offset, ml.WithRoPEBase(ropeBase))
- panic("not yet implemented")
- // return key.RoPE(ctx, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
-}
-
-type TextMLP struct {
- Up *nn.Linear `gguf:"up_proj"`
- Down *nn.Linear `gguf:"down_proj"`
- Gate *nn.Linear `gguf:"gate_proj"`
-}
-
-func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
- hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
- return mlp.Down.Forward(ctx, hiddenState)
-}
-
-type TextLayer struct {
- AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
- SelfAttention *TextSelfAttention `gguf:"self_attn"`
- PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
- MLPNorm *nn.RMSNorm `gguf:"pre_feedforward_layernorm"`
- MLP *TextMLP `gguf:"mlp"`
- PostMLPNorm *nn.RMSNorm `gguf:"post_feedforward_layernorm"`
-}
-
-func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, outputs ml.Tensor, offset int, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
- residual := hiddenState
- hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
- hiddenState = l.SelfAttention.Forward(ctx, layer, hiddenState, offset, cache, opts)
- hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
-
- // In the final layer (outputs != nil), optimize by pruning to just the token positions
- // we need logits for.
- if outputs != nil {
- hiddenState = hiddenState.TakeAxes(ctx, outputs, 1)
- residual = residual.TakeAxes(ctx, outputs, 1)
- }
-
- hiddenState = hiddenState.Add(ctx, residual)
- residual = hiddenState
- hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
- hiddenState = l.MLP.Forward(ctx, hiddenState, opts) // TODO this is where it goes bad most likely...
- hiddenState = l.PostMLPNorm.Forward(ctx, hiddenState, opts.eps)
- return hiddenState.Add(ctx, residual)
-}
-
-func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
- hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
- hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
-
- // set image embeddings
- // var except []int
- // for _, image := range batch.Multimodal {
- // visionOutputs := image.Multimodal[0].Tensor
- // ctx.Forward(visionOutputs.Copy(ctx, hiddenState.AsStrided(ctx,
- // []int{visionOutputs.Dim(0) * visionOutputs.Dim(1)},
- // []int{image.Index * hiddenState.Stride(1)}, 0)))
-
- // for i := range visionOutputs.Dim(1) {
- // except = append(except, image.Index+i)
- // }
- // }
-
- for i, layer := range m.Layers {
- // gemma alternates between the sliding window (local) and causal (global)
- // kv cache every 6 layers
- if cache != nil {
- // cacheType := cacheTypeSWA
- // if (i+1)%gemmaGlobalCacheCount == 0 {
- // cacheType = cacheTypeCausal
- // }
- cache.SetLayer(i)
-
- // TODO this needs to come back
- // wc := cache.(*kvcache.WrapperCache)
- // wc.SetLayerType(cacheType)
-
- // if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
- // causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
- // }
- }
-
- var offset int
- var lastLayerOutputs ml.Tensor
- if i == len(m.Layers)-1 {
- offset = batch.Offset
- lastLayerOutputs = batch.Outputs
- }
-
- hiddenState = layer.Forward(ctx, i, hiddenState, lastLayerOutputs, offset, cache, m.TextConfig)
- }
- hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
- return hiddenState
-}
diff --git a/x/model/models/gemma3/model_vision.go b/x/model/models/gemma3/model_vision.go
deleted file mode 100644
index bffb3cb58dd..00000000000
--- a/x/model/models/gemma3/model_vision.go
+++ /dev/null
@@ -1,121 +0,0 @@
-//go:build mlx
-
-package gemma3
-
-import (
- "math"
-
- "github.com/ollama/ollama/fs"
- "github.com/ollama/ollama/x/ml"
- "github.com/ollama/ollama/x/ml/nn"
-)
-
-var batchSize int = 1
-
-type VisionSelfAttention struct {
- Query *nn.Linear `gguf:"self_attn.q_proj"`
- Key *nn.Linear `gguf:"self_attn.k_proj"`
- Value *nn.Linear `gguf:"self_attn.v_proj"`
- Output *nn.Linear `gguf:"self_attn.out_proj"`
-}
-
-func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
- headDim := opts.hiddenSize / opts.numHeads
-
- query := sa.Query.Forward(ctx, hiddenState)
- key := sa.Key.Forward(ctx, hiddenState)
- value := sa.Value.Forward(ctx, hiddenState)
-
- query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
- key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
- value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
-
- attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
- attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
-
- hiddenState = sa.Output.Forward(ctx, attention)
- return hiddenState
-}
-
-type VisionMLP struct {
- FC1 *nn.Linear `gguf:"fc1"`
- FC2 *nn.Linear `gguf:"fc2"`
-}
-
-func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
- hiddenState = mlp.FC1.Forward(ctx, hiddenState).GELU(ctx)
- hiddenState = mlp.FC2.Forward(ctx, hiddenState)
- return hiddenState
-}
-
-type VisionEncoderLayer struct {
- LayerNorm1 *nn.LayerNorm `gguf:"layer_norm1"`
- SelfAttention *VisionSelfAttention
-
- LayerNorm2 *nn.LayerNorm `gguf:"layer_norm2"`
- MLP *VisionMLP `gguf:"mlp"`
-}
-
-func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
- residual := hiddenState
-
- // self attention
- hiddenState = e.LayerNorm1.Forward(ctx, hiddenState, opts.eps)
- hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
- hiddenState = hiddenState.Add(ctx, residual)
- residual = hiddenState
-
- // feed forward
- hiddenState = e.LayerNorm2.Forward(ctx, hiddenState, opts.eps)
- hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
- return hiddenState.Add(ctx, residual)
-}
-
-type VisionModelOptions struct {
- hiddenSize, numHeads int
- imageSize, patchSize int
- eps float32
-}
-
-type VisionModel struct {
- PatchEmbedding *nn.Conv2D `gguf:"embeddings.patch_embedding"`
- PositionEmbedding *nn.Embedding `gguf:"embeddings.position_embedding"`
- PostLayerNorm *nn.LayerNorm `gguf:"post_layernorm"`
-
- Layers []VisionEncoderLayer `gguf:"encoder.layers"`
-
- *VisionModelOptions
-}
-
-func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
- numPatches := (m.imageSize / m.patchSize) * (m.imageSize / m.patchSize)
-
- hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
- hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
- hiddenState = hiddenState.Transpose(ctx, 1, 0, 2, 3).Contiguous(ctx, false)
-
- positionIDs := ctx.Arange(0, float32(numPatches), 1, ml.DTypeInt32)
- hiddenState = hiddenState.Add(ctx, m.PositionEmbedding.Forward(ctx, positionIDs))
-
- for _, layer := range m.Layers {
- hiddenState = layer.Forward(ctx, hiddenState, m.VisionModelOptions)
- }
-
- hiddenState = m.PostLayerNorm.Forward(ctx, hiddenState, m.eps)
- return hiddenState
-}
-
-func newVisionModel(c fs.Config) *VisionModel {
- return &VisionModel{
- Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
- VisionModelOptions: &VisionModelOptions{
- hiddenSize: int(c.Uint("vision.embedding_length")),
- numHeads: int(c.Uint("vision.attention.head_count")),
-
- imageSize: int(c.Uint("vision.image_size")),
- patchSize: int(c.Uint("vision.patch_size")),
-
- eps: c.Float("vision.attention.layer_norm_epsilon"),
- },
- }
-}
diff --git a/x/model/models/gemma3/process_image.go b/x/model/models/gemma3/process_image.go
deleted file mode 100644
index 09d0727d0eb..00000000000
--- a/x/model/models/gemma3/process_image.go
+++ /dev/null
@@ -1,60 +0,0 @@
-//go:build mlx
-
-package gemma3
-
-import (
- "image"
-
- "github.com/ollama/ollama/fs"
- "github.com/ollama/ollama/model/imageproc"
-)
-
-type ImageProcessor struct {
- imageSize, patchSize, numChannels int
-}
-
-func newImageProcessor(c fs.Config) ImageProcessor {
- return ImageProcessor{
- imageSize: int(c.Uint("vision.image_size")),
- patchSize: int(c.Uint("vision.patch_size")),
- numChannels: int(c.Uint("vision.num_channels")),
- }
-}
-
-func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
- var pixelVals, rVals, gVals, bVals []float32
-
- bounds := img.Bounds()
- for y := bounds.Min.Y; y < bounds.Max.Y; y++ {
- for x := bounds.Min.X; x < bounds.Max.X; x++ {
- c := img.At(x, y)
- r, g, b, _ := c.RGBA()
- rVal := float32(r>>8) / 255.0
- gVal := float32(g>>8) / 255.0
- bVal := float32(b>>8) / 255.0
-
- rVal = (rVal - mean[0]) / std[0]
- gVal = (gVal - mean[1]) / std[1]
- bVal = (bVal - mean[2]) / std[2]
-
- rVals = append(rVals, rVal)
- gVals = append(gVals, gVal)
- bVals = append(bVals, bVal)
- }
- }
-
- pixelVals = append(pixelVals, rVals...)
- pixelVals = append(pixelVals, gVals...)
- pixelVals = append(pixelVals, bVals...)
-
- return pixelVals
-}
-
-func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
- outputSize := image.Point{p.imageSize, p.imageSize}
- newImage := imageproc.Composite(img)
- newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
-
- data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
- return data, nil
-}
diff --git a/x/model/models/models.go b/x/model/models/models.go
deleted file mode 100644
index a2542707fba..00000000000
--- a/x/model/models/models.go
+++ /dev/null
@@ -1,3 +0,0 @@
-package models
-
-// _ "github.com/ollama/ollama/x/model/models/gemma3"
diff --git a/x/model/sentencepiece.go b/x/model/sentencepiece.go
deleted file mode 100644
index 2c178ec0c08..00000000000
--- a/x/model/sentencepiece.go
+++ /dev/null
@@ -1,249 +0,0 @@
-package model
-
-import (
- "container/heap"
- "fmt"
- "log/slog"
- "strconv"
- "strings"
-
- "github.com/ollama/ollama/logutil"
-)
-
-const spmWhitespaceSep = "▁"
-
-type SentencePiece struct {
- maxTokenLen int
- vocab *Vocabulary
-}
-
-var _ TextProcessor = (*SentencePiece)(nil)
-
-func (spm SentencePiece) Vocabulary() *Vocabulary {
- return spm.vocab
-}
-
-func NewSentencePiece(vocab *Vocabulary) SentencePiece {
- logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
-
- counter := map[int]int{}
- var maxTokenLen int
- for cnt := range vocab.Types {
- switch vocab.Types[cnt] {
- case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
- maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
- fallthrough
- default:
- counter[int(vocab.Types[cnt])] += 1
- }
- }
-
- logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
- "user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
- "max token len", maxTokenLen)
-
- return SentencePiece{
- maxTokenLen: maxTokenLen,
- vocab: vocab,
- }
-}
-
-func (spm SentencePiece) Is(id int32, special Special) bool {
- return spm.vocab.Is(id, special)
-}
-
-func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
- fragments := []fragment{{value: s}}
- for _, special := range spm.vocab.SpecialVocabulary() {
- id := spm.vocab.Encode(special)
- for i := 0; i < len(fragments); i++ {
- frag := fragments[i]
- if len(frag.ids) > 0 {
- continue
- }
-
- var middle []fragment
- switch i := strings.Index(frag.value, special); {
- case i < 0:
- middle = append(middle, frag)
- case i > 0:
- middle = append(middle, fragment{value: frag.value[:i]})
- fallthrough
- default:
- middle = append(middle, fragment{value: special, ids: []int32{id}})
- if rest := frag.value[i+len(special):]; rest != "" {
- middle = append(middle, fragment{value: rest})
- }
- }
-
- fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
- }
- }
-
- var ids []int32
- for _, frag := range fragments {
- if len(frag.ids) > 0 {
- ids = append(ids, frag.ids...)
- continue
- }
-
- text := strings.ReplaceAll(frag.value, " ", spmWhitespaceSep)
-
- if id := spm.vocab.Encode(text); id >= 0 {
- ids = append(ids, id)
- continue
- }
-
- q := &queue{}
- heap.Init(q)
-
- runes := []rune(text)
- merges := make([]merge, len(runes))
- for r := range runes {
- merges[r] = merge{
- p: r - 1,
- n: r + 1,
- runes: []rune{runes[r]},
- }
- }
-
- pairwise := func(a, b int) *candidate {
- if a < 0 || b >= len(runes) {
- return nil
- }
-
- left, right := string(merges[a].runes), string(merges[b].runes)
- if id := spm.vocab.Encode(left + right); id >= 0 {
- return &candidate{
- a: a,
- b: b,
- score: spm.vocab.Scores[id],
- size: len(left) + len(right),
- }
- }
-
- return nil
- }
-
- for i := range len(runes) - 1 {
- if pair := pairwise(i, i+1); pair != nil {
- heap.Push(q, pair)
- }
- }
-
- for q.Len() > 0 {
- pair := heap.Pop(q).(*candidate)
- left, right := merges[pair.a], merges[pair.b]
-
- if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size {
- continue
- }
-
- merges[pair.a].runes = append(left.runes, right.runes...)
- merges[pair.b].runes = nil
- merges[pair.a].n = right.n
- if right.n < len(merges) {
- merges[right.n].p = pair.a
- }
-
- if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
- heap.Push(q, pair)
- }
-
- if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
- heap.Push(q, pair)
- }
- }
-
- for _, merge := range merges {
- if token := string(merge.runes); token != "" {
- id := spm.vocab.Encode(token)
-
- if id >= 0 {
- ids = append(ids, id)
- continue
- }
-
- // Fallback to byte tokenization
- var result []int32
- for _, b := range []byte(token) {
- byteToken := fmt.Sprintf("<0x%02X>", b)
- unknownID := spm.vocab.Encode(byteToken)
- if unknownID >= 0 {
- result = append(result, unknownID)
- } else {
- slog.Debug("unknown byte token", "byte", b, "token", byteToken)
- }
- }
-
- ids = append(ids, result...)
- }
- }
- }
-
- if addSpecial {
- ids = spm.vocab.addSpecials(ids)
- }
-
- logutil.Trace("encoded", "string", s, "ids", ids)
- return ids, nil
-}
-
-type candidate struct {
- a, b int
- score float32
- size int
-}
-
-type queue []*candidate
-
-func (q queue) Len() int { return len(q) }
-
-func (q queue) Less(i, j int) bool {
- return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a)
-}
-
-func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] }
-
-func (q *queue) Push(x interface{}) {
- item := x.(*candidate)
- *q = append(*q, item)
-}
-
-func (q *queue) Pop() interface{} {
- old := *q
- n := len(old)
- item := old[n-1]
- *q = old[0 : n-1]
- return item
-}
-
-func (spm SentencePiece) Decode(ids []int32) (string, error) {
- var sb strings.Builder
- for _, id := range ids {
- data := spm.vocab.Decode(id)
- data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
-
- // For tokenizers that use byte tokens like "<0xEA>"
- // convert them to the partial unicode character
- // so they are buffered correctly by the runner instead
- // of being sent back to the api as "<0xEA>"
- if len(data) == 6 && strings.HasPrefix(data, "<0x") && strings.HasSuffix(data, ">") {
- byteVal, err := strconv.ParseUint(data[1:5], 0, 8)
- if err != nil {
- return "", fmt.Errorf("failed to parse hex byte: %v", err)
- }
-
- if err := sb.WriteByte(byte(byteVal)); err != nil {
- return "", err
- }
- } else {
- if _, err := sb.WriteString(data); err != nil {
- return "", err
- }
- }
- }
-
- logutil.Trace("decoded", "ids", ids, "string", sb.String())
- return sb.String(), nil
-}
diff --git a/x/model/sentencepiece_test.go b/x/model/sentencepiece_test.go
deleted file mode 100644
index 7ab158af770..00000000000
--- a/x/model/sentencepiece_test.go
+++ /dev/null
@@ -1,172 +0,0 @@
-package model
-
-import (
- "log/slog"
- "os"
- "path/filepath"
- "slices"
- "testing"
-
- "google.golang.org/protobuf/proto"
-
- "github.com/ollama/ollama/convert/sentencepiece"
-)
-
-func loadSentencePieceVocab(t *testing.T) SentencePiece {
- t.Helper()
-
- bts, err := os.ReadFile(filepath.Join("..", "..", "model", "testdata", "gemma2", "tokenizer.model"))
- if err != nil {
- t.Fatal(err)
- }
-
- var spm sentencepiece.ModelProto
- if err := proto.Unmarshal(bts, &spm); err != nil {
- t.Fatal(err)
- }
-
- var v Vocabulary
-
- for _, piece := range spm.GetPieces() {
- v.Values = append(v.Values, piece.GetPiece())
- v.Scores = append(v.Scores, piece.GetScore())
- switch t := piece.GetType(); t {
- case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
- sentencepiece.ModelProto_SentencePiece_CONTROL,
- sentencepiece.ModelProto_SentencePiece_UNUSED,
- sentencepiece.ModelProto_SentencePiece_BYTE:
- v.Types = append(v.Types, int32(t))
- default:
- tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
- // todo parse the special tokens file
- // - this will roundtrip correctly but the and
- // tokens aren't processed
- v.Types = append(v.Types, tt)
- }
- }
-
- return NewSentencePiece(&v)
-}
-
-func TestSentencePieceEncode(t *testing.T) {
- logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
- slog.SetDefault(logger)
-
- tokenizer := loadSentencePieceVocab(t)
-
- t.Run("basic roundtrip", func(t *testing.T) {
- t.Parallel()
-
- cases := []string{
- "hello",
- "hello ",
- "hello ",
- " hello",
- " hello ",
- " hello ",
- "hello world",
- "请考试我的软件!12345",
- "你好",
- "Hello 你好 world!",
- "Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
- "Multilingual: 你好 こんにちは Привет Hola مرحبا",
- "Numbers and symbols: 123456789 +- */",
- "Special tokens: text ",
- "Code snippets: func main() { fmt.Println(\"Hello World\") }",
- "Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
- "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
- "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
- }
-
- for _, want := range cases {
- ids, err := tokenizer.Encode(want, true)
- if err != nil {
- t.Fatal(err)
- }
-
- if got, err := tokenizer.Decode(ids); err != nil {
- t.Fatal(err)
- } else if got != want {
- t.Errorf("got %q, want %q [%#v]", got, want, ids)
- }
- }
- })
-
- t.Run("special tokens", func(t *testing.T) {
- type candidate struct {
- token string
- ids []int32
- }
-
- cases := []candidate{
- {"", []int32{2}},
- {"", []int32{1}},
- }
-
- for _, want := range cases {
- ids, err := tokenizer.Encode(want.token, true)
- if err != nil {
- t.Fatal(err)
- }
- if !slices.Equal(ids, want.ids) {
- t.Errorf("got %#v, want %#v", ids, want.ids)
- }
- }
- })
-}
-
-func TestSentencePieceDecodeByteTokens(t *testing.T) {
- vocab := &Vocabulary{
- Values: []string{
- "normal",
- "<0xEA>",
- "<0x41>",
- "<0xC3>",
- "<0xA3>",
- },
- Types: []int32{
- TOKEN_TYPE_NORMAL,
- TOKEN_TYPE_BYTE,
- TOKEN_TYPE_BYTE,
- TOKEN_TYPE_BYTE,
- TOKEN_TYPE_BYTE,
- },
- Scores: []float32{0, 0, 0, 0, 0},
- }
-
- spm := NewSentencePiece(vocab)
-
- tests := []struct {
- name string
- ids []int32
- expected string
- }{
- {
- name: "single byte token",
- ids: []int32{1},
- expected: "\xea",
- },
- {
- name: "ASCII byte token",
- ids: []int32{2},
- expected: "A",
- },
- {
- name: "multiple byte tokens forming UTF-8 character",
- ids: []int32{3, 4},
- expected: "ã",
- },
- }
-
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- result, err := spm.Decode(tt.ids)
- if err != nil {
- t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
- }
- if result != tt.expected {
- t.Errorf("got %q, want %q", result, tt.expected)
- }
- })
- }
-}
diff --git a/x/model/textprocessor.go b/x/model/textprocessor.go
deleted file mode 100644
index 4a36f235290..00000000000
--- a/x/model/textprocessor.go
+++ /dev/null
@@ -1,17 +0,0 @@
-package model
-
-const (
- TOKEN_TYPE_NORMAL = iota + 1
- TOKEN_TYPE_UNKNOWN
- TOKEN_TYPE_CONTROL
- TOKEN_TYPE_USER_DEFINED
- TOKEN_TYPE_UNUSED
- TOKEN_TYPE_BYTE
-)
-
-type TextProcessor interface {
- Encode(s string, addSpecial bool) ([]int32, error)
- Decode([]int32) (string, error)
- Is(int32, Special) bool
- Vocabulary() *Vocabulary
-}
diff --git a/x/model/vocabulary.go b/x/model/vocabulary.go
deleted file mode 100644
index d977c495781..00000000000
--- a/x/model/vocabulary.go
+++ /dev/null
@@ -1,112 +0,0 @@
-package model
-
-import (
- "log/slog"
- "slices"
- "sync"
-)
-
-type Special int32
-
-const (
- SpecialBOS Special = iota
- SpecialEOS
-)
-
-type Vocabulary struct {
- Values []string
- Types []int32
- Scores []float32
- Merges []string
-
- BOS, EOS []int32
- AddBOS, AddEOS bool
-
- specialOnce sync.Once
- special []string
-
- valuesOnce sync.Once
- values map[string]int32
-
- mergeOnce sync.Once
- merge map[string]int32
-}
-
-func (v *Vocabulary) Is(id int32, special Special) bool {
- switch special {
- case SpecialBOS:
- return slices.Contains(v.BOS, id)
- case SpecialEOS:
- return slices.Contains(v.EOS, id)
- default:
- return false
- }
-}
-
-func (v *Vocabulary) addSpecials(ids []int32) []int32 {
- if v.AddBOS && len(v.BOS) > 0 {
- if len(ids) > 0 && slices.Contains(v.BOS, ids[0]) {
- slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
- }
-
- slog.Debug("adding bos token to prompt", "id", v.BOS[0])
- ids = append([]int32{v.BOS[0]}, ids...)
- }
-
- if v.AddEOS && len(v.EOS) > 0 {
- if len(ids) > 0 && slices.Contains(v.BOS, ids[len(ids)-1]) {
- slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
- }
-
- slog.Debug("adding eos token to prompt", "id", v.EOS[0])
- ids = append(ids, v.EOS[0])
- }
-
- return ids
-}
-
-func (v *Vocabulary) Encode(s string) int32 {
- v.valuesOnce.Do(func() {
- v.values = make(map[string]int32, len(v.Values))
- for i, value := range v.Values {
- v.values[value] = int32(i)
- }
- })
-
- if id, ok := v.values[s]; ok {
- return id
- }
-
- return -1
-}
-
-func (v *Vocabulary) Decode(id int32) string {
- return v.Values[id]
-}
-
-func (v *Vocabulary) SpecialVocabulary() []string {
- v.specialOnce.Do(func() {
- for i := range v.Values {
- if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
- v.special = append(v.special, v.Values[i])
- }
- }
- })
-
- return v.special
-}
-
-func (v *Vocabulary) Merge(left, right string) int {
- v.mergeOnce.Do(func() {
- v.merge = make(map[string]int32, len(v.Merges))
- for i, merge := range v.Merges {
- v.merge[merge] = int32(i)
- }
- })
-
- if id, ok := v.merge[left+" "+right]; ok {
- return int(id)
- }
-
- return -1
-}
diff --git a/x/model/vocabulary_test.go b/x/model/vocabulary_test.go
deleted file mode 100644
index ccfc39e6945..00000000000
--- a/x/model/vocabulary_test.go
+++ /dev/null
@@ -1,107 +0,0 @@
-package model
-
-import (
- "testing"
-
- "github.com/google/go-cmp/cmp"
-)
-
-func TestSpecialVocabulary(t *testing.T) {
- vocab := &Vocabulary{
- Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
- Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
- }
-
- specialVocab := vocab.SpecialVocabulary()
-
- if len(specialVocab) != 4 {
- t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
- }
-}
-
-func TestAddSpecialVocabulary(t *testing.T) {
- cases := []struct {
- name string
- vocab *Vocabulary
- input []int32
- want []int32
- }{
- {
- name: "add bos",
- vocab: &Vocabulary{
- BOS: []int32{0},
- EOS: []int32{1},
- AddBOS: true,
- AddEOS: false,
- },
- input: []int32{2, 3, 4},
- want: []int32{0, 2, 3, 4},
- },
- {
- // TODO(mxyng): this is to match previous behaviour
- name: "add bos when already present",
- vocab: &Vocabulary{
- BOS: []int32{0},
- EOS: []int32{1},
- AddBOS: true,
- AddEOS: false,
- },
- input: []int32{0, 2, 3, 4},
- want: []int32{0, 0, 2, 3, 4},
- },
- {
- name: "add eos",
- vocab: &Vocabulary{
- BOS: []int32{0},
- EOS: []int32{1},
- AddBOS: false,
- AddEOS: true,
- },
- input: []int32{2, 3, 4},
- want: []int32{2, 3, 4, 1},
- },
- {
- // TODO(mxyng): this is to match previous behaviour
- name: "add eos when already present",
- vocab: &Vocabulary{
- BOS: []int32{0},
- EOS: []int32{1},
- AddBOS: false,
- AddEOS: true,
- },
- input: []int32{2, 3, 4, 1},
- want: []int32{2, 3, 4, 1, 1},
- },
- {
- name: "add both",
- vocab: &Vocabulary{
- BOS: []int32{0},
- EOS: []int32{1},
- AddBOS: true,
- AddEOS: true,
- },
- input: []int32{2, 3, 4},
- want: []int32{0, 2, 3, 4, 1},
- },
- {
- name: "add bos to empty inputs",
- vocab: &Vocabulary{
- BOS: []int32{0},
- EOS: []int32{1},
- AddBOS: true,
- AddEOS: false,
- },
- input: []int32{},
- want: []int32{0},
- },
- }
-
- for _, tt := range cases {
- t.Run(tt.name, func(t *testing.T) {
- got := tt.vocab.addSpecials(tt.input)
- if diff := cmp.Diff(tt.want, got); diff != "" {
- t.Errorf("no match (-want +got):\n%s", diff)
- }
- })
- }
-}
diff --git a/x/model/wordpiece.go b/x/model/wordpiece.go
deleted file mode 100644
index e552bce0dd3..00000000000
--- a/x/model/wordpiece.go
+++ /dev/null
@@ -1,171 +0,0 @@
-package model
-
-import (
- "fmt"
- "iter"
- "strings"
- "unicode"
-
- "github.com/ollama/ollama/logutil"
-)
-
-type WordPiece struct {
- vocab *Vocabulary
- lowercase bool
-}
-
-// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
-// this differs from original word piece which uses "##" to indicate subwords.
-const ggmlPrefix = "▁"
-
-var wordPieceReplacer = strings.NewReplacer(
- " .", ".",
- " ?", "?",
- " !", "!",
- " ,", ",",
- " ' ", "'",
- " n't", "n't",
- " 'm", "'m",
- " do not", " don't",
- " 's", "'s",
- " 've", "'ve",
- " 're", "'re",
-)
-
-// Decode implements TextProcessor.
-func (wpm WordPiece) Decode(ids []int32) (string, error) {
- var sb strings.Builder
- for i, id := range ids {
- if id < 0 || int(id) >= len(wpm.vocab.Values) {
- return "", fmt.Errorf("invalid token id: %d", id)
- }
-
- var separator string
- piece := wpm.vocab.Values[id]
- if i > 0 &&
- (strings.HasPrefix(piece, ggmlPrefix) ||
- (strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
- separator = " "
- }
-
- sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
- }
-
- return sb.String(), nil
-}
-
-// words splits a string into words, treating CJK characters as separate words.
-// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
-func (wpm WordPiece) words(s string) iter.Seq[string] {
- return func(yield func(string) bool) {
- runes := make([]rune, 0, len(s)*3)
- for _, r := range s {
- switch {
- case r >= 0x4E00 && r <= 0x9FFF,
- r >= 0x3400 && r <= 0x4DBF,
- r >= 0x20000 && r <= 0x2A6DF,
- r >= 0x2A700 && r <= 0x2B73F,
- r >= 0x2B740 && r <= 0x2B81F,
- r >= 0x2B820 && r <= 0x2CEAF,
- r >= 0xF900 && r <= 0xFAFF,
- r >= 0x2F800 && r <= 0x2FA1F:
- runes = append(runes, ' ', r, ' ')
- default:
- runes = append(runes, r)
- }
- }
-
- for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
- // split on but keep punctuation
- var start int
- for start < len(w) {
- end := strings.IndexFunc(w[start:], unicode.IsPunct)
- if end < 0 {
- end = len(w) - start
- } else if end == 0 {
- end = 1
- }
-
- if !yield(w[start : start+end]) {
- return
- }
-
- start += end
- }
- }
- }
-}
-
-// Encode implements TextProcessor.
-func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
- var ids []int32
-
- // TODO: use [UNK] from config
- unk := wpm.vocab.Encode("[UNK]")
- for word := range wpm.words(s) {
- var start int
- var pieces []int32
- for start < len(word) {
- end := len(word)
-
- var piece int32
- for start < end {
- subword := word[start:end]
- if start == 0 {
- subword = ggmlPrefix + subword
- }
-
- if wpm.lowercase {
- subword = strings.ToLower(subword)
- }
- piece = wpm.vocab.Encode(subword)
- if piece >= 0 {
- break
- }
-
- end--
- }
-
- if piece < 0 {
- // Unknown token
- pieces = pieces[:0]
- break
- }
-
- pieces = append(pieces, piece)
- start = end
- }
-
- if len(pieces) > 0 {
- ids = append(ids, pieces...)
- } else {
- ids = append(ids, unk)
- }
- }
-
- if addSpecial {
- ids = wpm.vocab.addSpecials(ids)
- }
-
- logutil.Trace("encoded", "string", s, "ids", ids)
- return ids, nil
-}
-
-// Is implements TextProcessor.
-func (wpm WordPiece) Is(id int32, special Special) bool {
- return wpm.vocab.Is(id, special)
-}
-
-// Vocabulary implements TextProcessor.
-func (wpm WordPiece) Vocabulary() *Vocabulary {
- return wpm.vocab
-}
-
-var _ TextProcessor = (*WordPiece)(nil)
-
-func NewWordPiece(vocab *Vocabulary, lowercase bool) WordPiece {
- return WordPiece{
- vocab: vocab,
- lowercase: lowercase,
- }
-}
diff --git a/x/models/gemma3/gemma3.go b/x/models/gemma3/gemma3.go
new file mode 100644
index 00000000000..7ba24d29490
--- /dev/null
+++ b/x/models/gemma3/gemma3.go
@@ -0,0 +1,521 @@
+//go:build mlx
+
+// Package gemma3 provides the Gemma 3 text model implementation for MLX.
+package gemma3
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/model"
+ "github.com/ollama/ollama/x/mlxrunner/model/base"
+ "github.com/ollama/ollama/x/models/nn"
+ "github.com/ollama/ollama/x/tokenizer"
+)
+
+func init() {
+ base.Register("Gemma3ForCausalLM", newModel)
+ base.Register("Gemma3ForConditionalGeneration", newModel)
+}
+
+// TextConfig holds configuration for the Gemma 3 text model.
+type TextConfig struct {
+ HiddenSize int32 `json:"hidden_size"`
+ NumHiddenLayers int32 `json:"num_hidden_layers"`
+ IntermediateSize int32 `json:"intermediate_size"`
+ NumAttentionHeads int32 `json:"num_attention_heads"`
+ NumKeyValueHeads int32 `json:"num_key_value_heads"`
+ HeadDim int32 `json:"head_dim"`
+ VocabSize int32 `json:"vocab_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ RopeTheta float32 `json:"rope_theta"`
+ RopeLocalBaseFreq float32 `json:"rope_local_base_freq"`
+ MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
+ SlidingWindow int32 `json:"sliding_window"`
+ SlidingWindowPattern int32 `json:"sliding_window_pattern"`
+ LayerTypes []string `json:"layer_types"`
+ TieWordEmbeddings bool `json:"tie_word_embeddings"`
+
+ // Quantization parameters (set during load based on model quantization).
+ QuantGroupSize int `json:"-"`
+ QuantBits int `json:"-"`
+ QuantMode string `json:"-"`
+ TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
+
+ // Computed fields.
+ Scale float32 `json:"-"`
+}
+
+// Attention implements Gemma 3 attention with Q/K normalization.
+type Attention struct {
+ QProj nn.LinearLayer
+ KProj nn.LinearLayer
+ VProj nn.LinearLayer
+ OProj nn.LinearLayer
+
+ QNorm *nn.RMSNorm
+ KNorm *nn.RMSNorm
+
+ // Precomputed (1 + weight) for Gemma-style RMSNorm.
+ QNormScaled *mlx.Array
+ KNormScaled *mlx.Array
+}
+
+// MLP is the feed-forward network with GELU activation.
+type MLP struct {
+ GateProj nn.LinearLayer
+ UpProj nn.LinearLayer
+ DownProj nn.LinearLayer
+}
+
+// DecoderLayer is a single transformer block.
+type DecoderLayer struct {
+ InputNorm *nn.RMSNorm
+ Attention *Attention
+ PostAttnNorm *nn.RMSNorm
+ PreFFNorm *nn.RMSNorm
+ MLP *MLP
+ PostFFNorm *nn.RMSNorm
+
+ // Precomputed (1 + weight) for Gemma-style RMSNorm.
+ InputNormScaled *mlx.Array
+ PostAttnNormScaled *mlx.Array
+ PreFFNormScaled *mlx.Array
+ PostFFNormScaled *mlx.Array
+
+ // Layer metadata.
+ IsSliding bool
+ LayerIdx int32
+}
+
+// Model is the Gemma 3 text-only model.
+type Model struct {
+ EmbedTokens *nn.Embedding
+ Layers []*DecoderLayer
+ Norm *nn.RMSNorm
+ LMHead nn.LinearLayer
+
+ // Precomputed (1 + weight) for Gemma-style RMSNorm.
+ NormScaled *mlx.Array
+
+ tok *tokenizer.Tokenizer
+ *TextConfig
+
+ weightPrefix string
+}
+
+func defaultHeads(numLayers int32) (numHeads, numKVHeads int32) {
+ switch numLayers {
+ case 34:
+ return 8, 4
+ case 48:
+ return 16, 8
+ case 62:
+ return 32, 16
+ default:
+ return 8, 4
+ }
+}
+
+func parseTextConfig(configData []byte) (TextConfig, bool, error) {
+ var cfg TextConfig
+ if err := json.Unmarshal(configData, &cfg); err != nil {
+ return TextConfig{}, false, fmt.Errorf("parse config: %w", err)
+ }
+
+ var wrapped struct {
+ TextConfig *TextConfig `json:"text_config"`
+ }
+ if err := json.Unmarshal(configData, &wrapped); err != nil {
+ return TextConfig{}, false, fmt.Errorf("parse nested text config: %w", err)
+ }
+
+ fromConditional := wrapped.TextConfig != nil
+ if fromConditional {
+ cfg = *wrapped.TextConfig
+
+ if cfg.HeadDim == 0 {
+ cfg.HeadDim = 256
+ }
+ if cfg.NumAttentionHeads == 0 {
+ cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
+ }
+ if cfg.NumKeyValueHeads == 0 {
+ _, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
+ }
+ if cfg.VocabSize == 0 {
+ cfg.VocabSize = 262208
+ }
+ if cfg.SlidingWindowPattern == 0 && len(cfg.LayerTypes) == 0 {
+ cfg.SlidingWindowPattern = 6
+ }
+ if cfg.MaxPositionEmbeddings == 0 {
+ cfg.MaxPositionEmbeddings = 131072
+ }
+ }
+
+ if cfg.HeadDim == 0 {
+ cfg.HeadDim = 256
+ }
+ if cfg.NumAttentionHeads == 0 {
+ cfg.NumAttentionHeads, cfg.NumKeyValueHeads = defaultHeads(cfg.NumHiddenLayers)
+ }
+ if cfg.NumKeyValueHeads == 0 {
+ cfg.NumKeyValueHeads = max(1, cfg.NumAttentionHeads/2)
+ }
+ if cfg.RopeTheta == 0 {
+ cfg.RopeTheta = 1000000
+ }
+ if cfg.RopeLocalBaseFreq == 0 {
+ cfg.RopeLocalBaseFreq = 10000
+ }
+ if cfg.RMSNormEps == 0 {
+ cfg.RMSNormEps = 1e-6
+ }
+ if cfg.VocabSize == 0 {
+ cfg.VocabSize = 262208
+ }
+
+ cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
+
+ return cfg, fromConditional, nil
+}
+
+func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
+ for _, prefix := range []string{"", "language_model."} {
+ if tensors[prefix+"model.embed_tokens.weight"] != nil {
+ return prefix
+ }
+ }
+ return ""
+}
+
+func isLayerSliding(layerIdx int32, cfg *TextConfig) bool {
+ if len(cfg.LayerTypes) > 0 && int(layerIdx) < len(cfg.LayerTypes) {
+ return cfg.LayerTypes[layerIdx] == "sliding_attention"
+ }
+ if cfg.SlidingWindowPattern <= 0 {
+ return false
+ }
+ return (layerIdx+1)%cfg.SlidingWindowPattern != 0
+}
+
+func precomputeGemmaScaledWeights(m *Model) {
+ if m.Norm != nil {
+ m.NormScaled = mlx.AddScalar(m.Norm.Weight, 1.0)
+ }
+
+ var scaled []*mlx.Array
+ if m.NormScaled != nil {
+ scaled = append(scaled, m.NormScaled)
+ }
+
+ for _, layer := range m.Layers {
+ if layer == nil || layer.Attention == nil {
+ continue
+ }
+
+ if layer.InputNorm != nil {
+ layer.InputNormScaled = mlx.AddScalar(layer.InputNorm.Weight, 1.0)
+ scaled = append(scaled, layer.InputNormScaled)
+ }
+ if layer.PostAttnNorm != nil {
+ layer.PostAttnNormScaled = mlx.AddScalar(layer.PostAttnNorm.Weight, 1.0)
+ scaled = append(scaled, layer.PostAttnNormScaled)
+ }
+ if layer.PreFFNorm != nil {
+ layer.PreFFNormScaled = mlx.AddScalar(layer.PreFFNorm.Weight, 1.0)
+ scaled = append(scaled, layer.PreFFNormScaled)
+ }
+ if layer.PostFFNorm != nil {
+ layer.PostFFNormScaled = mlx.AddScalar(layer.PostFFNorm.Weight, 1.0)
+ scaled = append(scaled, layer.PostFFNormScaled)
+ }
+
+ if layer.Attention.QNorm != nil {
+ layer.Attention.QNormScaled = mlx.AddScalar(layer.Attention.QNorm.Weight, 1.0)
+ scaled = append(scaled, layer.Attention.QNormScaled)
+ }
+ if layer.Attention.KNorm != nil {
+ layer.Attention.KNormScaled = mlx.AddScalar(layer.Attention.KNorm.Weight, 1.0)
+ scaled = append(scaled, layer.Attention.KNormScaled)
+ }
+ }
+
+ if len(scaled) > 0 {
+ mlx.Eval(scaled...)
+ }
+}
+
+func newModel(root *model.Root) (base.Model, error) {
+ configData, err := root.Manifest.ReadConfig("config.json")
+ if err != nil {
+ return nil, fmt.Errorf("load config: %w", err)
+ }
+
+ cfg, _, err := parseTextConfig(configData)
+ if err != nil {
+ return nil, err
+ }
+
+ if qt := root.QuantType(); qt != "" {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
+ if gs := root.GroupSize(); gs > 0 {
+ cfg.QuantGroupSize = gs
+ }
+ } else {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
+ }
+ cfg.TensorQuant = root.AllTensorQuant()
+
+ tokData, err := root.Manifest.ReadConfig("tokenizer.json")
+ if err != nil {
+ return nil, fmt.Errorf("load tokenizer config: %w", err)
+ }
+
+ tokConfig := &tokenizer.TokenizerConfig{ConfigJSON: configData}
+ if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
+ tokConfig.GenerationConfigJSON = genConfigData
+ }
+ if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
+ tokConfig.TokenizerConfigJSON = tokConfigData
+ }
+
+ tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
+ if err != nil {
+ return nil, fmt.Errorf("parse tokenizer: %w", err)
+ }
+
+ m := &Model{
+ Layers: make([]*DecoderLayer, cfg.NumHiddenLayers),
+ TextConfig: &cfg,
+ tok: tok,
+ }
+
+ for i := range m.Layers {
+ m.Layers[i] = &DecoderLayer{
+ LayerIdx: int32(i),
+ IsSliding: isLayerSliding(int32(i), m.TextConfig),
+ }
+ }
+
+ return m, nil
+}
+
+// LoadWeights receives all tensors loaded from the manifest and assigns them
+// to model fields.
+func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
+ m.weightPrefix = resolveWeightPrefix(tensors)
+ prefix := m.weightPrefix
+ linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
+
+ embedWeight := tensors[prefix+"model.embed_tokens.weight"]
+ if embedWeight == nil {
+ return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
+ }
+ m.EmbedTokens = nn.NewEmbedding(embedWeight)
+
+ normWeight := tensors[prefix+"model.norm.weight"]
+ if normWeight == nil {
+ return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
+ }
+ m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
+
+ if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
+ m.LMHead = lmHead
+ } else if lmHead := linears.Make("lm_head"); lmHead != nil {
+ m.LMHead = lmHead
+ } else {
+ // Gemma usually ties output projection to embeddings.
+ m.LMHead = nn.NewLinear(embedWeight, nil)
+ }
+
+ for i := int32(0); i < m.NumHiddenLayers; i++ {
+ layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
+
+ layer := &DecoderLayer{
+ LayerIdx: i,
+ IsSliding: isLayerSliding(i, m.TextConfig),
+ Attention: &Attention{},
+ MLP: &MLP{},
+ }
+
+ if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
+ layer.InputNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+ if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
+ layer.PostAttnNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+ if w := tensors[layerPrefix+".pre_feedforward_layernorm.weight"]; w != nil {
+ layer.PreFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+ if w := tensors[layerPrefix+".post_feedforward_layernorm.weight"]; w != nil {
+ layer.PostFFNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+
+ layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
+ layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
+ layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
+ layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
+
+ if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
+ layer.Attention.QNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+ if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
+ layer.Attention.KNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+
+ layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
+ layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
+ layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
+
+ if layer.InputNorm == nil {
+ return fmt.Errorf("layer %d: missing input_layernorm", i)
+ }
+ if layer.PostAttnNorm == nil {
+ return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
+ }
+ if layer.PreFFNorm == nil {
+ return fmt.Errorf("layer %d: missing pre_feedforward_layernorm", i)
+ }
+ if layer.PostFFNorm == nil {
+ return fmt.Errorf("layer %d: missing post_feedforward_layernorm", i)
+ }
+ if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
+ return fmt.Errorf("layer %d: missing attention projections", i)
+ }
+ if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
+ return fmt.Errorf("layer %d: missing attention q/k norms", i)
+ }
+ if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
+ return fmt.Errorf("layer %d: missing mlp projections", i)
+ }
+
+ m.Layers[i] = layer
+ }
+
+ precomputeGemmaScaledWeights(m)
+ if m.NormScaled == nil {
+ return fmt.Errorf("missing precomputed final norm weight")
+ }
+ collected := mlx.Collect(m)
+ mlx.Eval(collected...)
+
+ return nil
+}
+
+func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
+ dims := tokens.Dims()
+ B, L := int32(dims[0]), int32(dims[1])
+
+ h := m.EmbedTokens.Forward(tokens)
+ h = mlx.MulScalar(h, float32(math.Sqrt(float64(m.HiddenSize))))
+
+ for i, layer := range m.Layers {
+ var c cache.Cache
+ if caches != nil && i < len(caches) {
+ c = caches[i]
+ }
+ h = layer.Forward(h, c, B, L, m.TextConfig)
+ }
+
+ return mlx.RMSNormFn(h, m.NormScaled, m.RMSNormEps)
+}
+
+func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
+ return m.LMHead.Forward(x)
+}
+
+func (m *Model) NumLayers() int {
+ return len(m.Layers)
+}
+
+func (m *Model) Tokenizer() *tokenizer.Tokenizer {
+ return m.tok
+}
+
+// NewCaches creates cache objects for all layers.
+func (m *Model) NewCaches() []cache.Cache {
+ caches := make([]cache.Cache, len(m.Layers))
+ for i, layer := range m.Layers {
+ if m.SlidingWindow > 0 && layer.IsSliding {
+ caches[i] = cache.NewRotatingKVCache(int(m.SlidingWindow))
+ } else {
+ caches[i] = cache.NewKVCache()
+ }
+ }
+ return caches
+}
+
+// FormatPrompt applies the Gemma 3 chat template.
+func (m *Model) FormatPrompt(prompt string) string {
+ return fmt.Sprintf("user\n%s\nmodel\n", prompt)
+}
+
+func (l *DecoderLayer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *TextConfig) *mlx.Array {
+ normed := mlx.RMSNormFn(x, l.InputNormScaled, cfg.RMSNormEps)
+
+ attnOut := l.Attention.Forward(normed, c, B, L, l.IsSliding, cfg)
+ attnOut = mlx.RMSNormFn(attnOut, l.PostAttnNormScaled, cfg.RMSNormEps)
+ h := mlx.Add(x, attnOut)
+
+ normed = mlx.RMSNormFn(h, l.PreFFNormScaled, cfg.RMSNormEps)
+
+ mlpOut := l.MLP.Forward(normed)
+ mlpOut = mlx.RMSNormFn(mlpOut, l.PostFFNormScaled, cfg.RMSNormEps)
+
+ return mlx.Add(h, mlpOut)
+}
+
+func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, isSliding bool, cfg *TextConfig) *mlx.Array {
+ q := a.QProj.Forward(x)
+ k := a.KProj.Forward(x)
+ v := a.VProj.Forward(x)
+
+ q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
+ q = mlx.Transpose(q, 0, 2, 1, 3)
+
+ k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
+ k = mlx.Transpose(k, 0, 2, 1, 3)
+
+ v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
+ v = mlx.Transpose(v, 0, 2, 1, 3)
+
+ q = mlx.RMSNormFn(q, a.QNormScaled, cfg.RMSNormEps)
+ k = mlx.RMSNormFn(k, a.KNormScaled, cfg.RMSNormEps)
+
+ ropeTheta := cfg.RopeTheta
+ if isSliding {
+ ropeTheta = cfg.RopeLocalBaseFreq
+ }
+
+ offset := 0
+ if c != nil {
+ offset = c.Offset()
+ }
+ q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
+ k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, ropeTheta, 1.0, offset)
+
+ if c != nil {
+ k, v = c.Update(k, v)
+ }
+
+ repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
+ if repeatFactor > 1 {
+ k = nn.RepeatKV(k, repeatFactor)
+ v = nn.RepeatKV(v, repeatFactor)
+ }
+
+ out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
+ out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
+ return a.OProj.Forward(out)
+}
+
+func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
+ gate := mlx.GELUApprox(m.GateProj.Forward(x))
+ up := m.UpProj.Forward(x)
+ return m.DownProj.Forward(mlx.Mul(gate, up))
+}
diff --git a/x/models/glm4_moe_lite/glm4_moe_lite.go b/x/models/glm4_moe_lite/glm4_moe_lite.go
new file mode 100644
index 00000000000..a1ec559725c
--- /dev/null
+++ b/x/models/glm4_moe_lite/glm4_moe_lite.go
@@ -0,0 +1,777 @@
+//go:build mlx
+
+// Package glm4_moe_lite provides the GLM4-MoE-Lite implementation for MLX.
+// This model uses Multi-head Latent Attention (MLA) and Mixture of Experts (MoE).
+package glm4_moe_lite
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/model"
+ "github.com/ollama/ollama/x/mlxrunner/model/base"
+ "github.com/ollama/ollama/x/models/nn"
+ "github.com/ollama/ollama/x/tokenizer"
+)
+
+func init() {
+ base.Register("Glm4MoeLiteForCausalLM", newModel)
+ base.Register("GLM4MoeLite", newModel)
+}
+
+// RopeScaling holds RoPE scaling configuration
+type RopeScaling struct {
+ Factor float32 `json:"factor"`
+ MscaleAllDim float32 `json:"mscale_all_dim"`
+}
+
+// Config holds GLM4-MoE-Lite model configuration
+type Config struct {
+ HiddenSize int32 `json:"hidden_size"`
+ NumHiddenLayers int32 `json:"num_hidden_layers"`
+ IntermediateSize int32 `json:"intermediate_size"`
+ MoEIntermediateSize int32 `json:"moe_intermediate_size"`
+ NumAttentionHeads int32 `json:"num_attention_heads"`
+ NumKeyValueHeads int32 `json:"num_key_value_heads"`
+ VocabSize int32 `json:"vocab_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ RopeTheta float32 `json:"rope_theta"`
+ MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
+ AttentionBias bool `json:"attention_bias"`
+
+ // MLA (Multi-head Latent Attention) parameters
+ QLoraRank int32 `json:"q_lora_rank"`
+ KVLoraRank int32 `json:"kv_lora_rank"`
+ QKRopeHeadDim int32 `json:"qk_rope_head_dim"`
+ QKNopeHeadDim int32 `json:"qk_nope_head_dim"`
+ VHeadDim int32 `json:"v_head_dim"`
+
+ // MoE parameters
+ NRoutedExperts int32 `json:"n_routed_experts"`
+ NSharedExperts int32 `json:"n_shared_experts"`
+ NumExpertsPerTok int32 `json:"num_experts_per_tok"`
+ RoutedScalingFactor float32 `json:"routed_scaling_factor"`
+ NormTopKProb bool `json:"norm_topk_prob"`
+ FirstKDenseReplace int32 `json:"first_k_dense_replace"`
+ NGroup int32 `json:"n_group"`
+ TopKGroup int32 `json:"topk_group"`
+
+ // RoPE scaling
+ RopeScaling *RopeScaling `json:"rope_scaling"`
+
+ // Quantization parameters (set during load based on model quantization)
+ QuantGroupSize int `json:"-"` // Group size for quantization (default 64)
+ QuantBits int `json:"-"` // Bits per weight (4 or 8)
+ QuantMode string `json:"-"` // Quantization mode ("affine", etc.)
+ TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
+
+ // Computed fields
+ QHeadDim int32 `json:"-"` // qk_nope_head_dim + qk_rope_head_dim
+ Scale float32 `json:"-"` // 1/sqrt(QHeadDim) with mscale adjustment
+}
+
+// MLAAttention implements Multi-head Latent Attention with absorption.
+type MLAAttention struct {
+ QAProj nn.LinearLayer
+ QALayerNorm *nn.RMSNorm
+ QBProj nn.LinearLayer
+
+ KVAProjWithMQA nn.LinearLayer
+ KVALayerNorm *nn.RMSNorm
+
+ EmbedQ *nn.MultiLinear
+ UnembedOut *nn.MultiLinear
+
+ OProj nn.LinearLayer
+}
+
+// Forward computes absorbed MLA attention output.
+func (a *MLAAttention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ q := a.QAProj.Forward(x)
+ q = a.QALayerNorm.Forward(q, cfg.RMSNormEps)
+ q = a.QBProj.Forward(q)
+
+ q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.QHeadDim)
+ q = mlx.Transpose(q, 0, 2, 1, 3)
+
+ qNope := mlx.SliceStartStop(q, []int32{0, 0, 0, 0}, []int32{B, cfg.NumAttentionHeads, L, cfg.QKNopeHeadDim})
+ qPE := mlx.SliceStartStop(q, []int32{0, 0, 0, cfg.QKNopeHeadDim}, []int32{B, cfg.NumAttentionHeads, L, cfg.QHeadDim})
+
+ compressedKV := a.KVAProjWithMQA.Forward(x)
+
+ kvCompressed := mlx.SliceStartStop(compressedKV, []int32{0, 0, 0}, []int32{B, L, cfg.KVLoraRank})
+ kPE := mlx.SliceStartStop(compressedKV, []int32{0, 0, cfg.KVLoraRank}, []int32{B, L, cfg.KVLoraRank + cfg.QKRopeHeadDim})
+
+ kPE = mlx.Reshape(kPE, B, L, 1, cfg.QKRopeHeadDim)
+ kPE = mlx.Transpose(kPE, 0, 2, 1, 3)
+
+ kvLatent := a.KVALayerNorm.Forward(kvCompressed, cfg.RMSNormEps)
+ kvLatent = mlx.ExpandDims(kvLatent, 1)
+
+ offset := 0
+ if c != nil {
+ offset = c.Offset()
+ }
+ qPE = mlx.RoPEWithBase(qPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
+ kPE = mlx.RoPEWithBase(kPE, int(cfg.QKRopeHeadDim), true, cfg.RopeTheta, 1.0, offset)
+
+ qLatent := a.EmbedQ.Forward(qNope)
+
+ keys := mlx.Concatenate([]*mlx.Array{kvLatent, kPE}, 3)
+
+ cachedL := L
+ if c != nil {
+ placeholderValues := mlx.ZerosF32([]int32{B, 1, L, 0})
+ keys, _ = c.Update(keys, placeholderValues)
+ cachedL = int32(keys.Dim(2))
+ }
+
+ values := mlx.SliceStartStop(keys, []int32{0, 0, 0, 0}, []int32{B, 1, cachedL, cfg.KVLoraRank})
+
+ queries := mlx.Concatenate([]*mlx.Array{qLatent, qPE}, 3)
+
+ out := mlx.ScaledDotProductAttentionCausal(queries, keys, values, cfg.Scale, L > 1)
+ out = a.UnembedOut.Forward(out)
+
+ out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.VHeadDim)
+
+ return a.OProj.Forward(out)
+}
+
+// DenseMLP implements the standard SwiGLU MLP for dense layers
+type DenseMLP struct {
+ GateProj nn.LinearLayer
+ UpProj nn.LinearLayer
+ DownProj nn.LinearLayer
+}
+
+// Forward applies the SwiGLU MLP
+func (m *DenseMLP) Forward(x *mlx.Array) *mlx.Array {
+ gate := mlx.SiLU(m.GateProj.Forward(x))
+ up := m.UpProj.Forward(x)
+ return m.DownProj.Forward(mlx.Mul(gate, up))
+}
+
+// MoEGate implements the expert gating mechanism
+type MoEGate struct {
+ Gate nn.LinearLayer
+ EScoreCorrectionBias *mlx.Array
+}
+
+// Forward computes expert selection indices and scores
+func (g *MoEGate) Forward(x *mlx.Array, cfg *Config) (*mlx.Array, *mlx.Array) {
+ gates := g.Gate.Forward(x)
+
+ scores := mlx.Sigmoid(gates)
+ origScores := scores
+
+ if g.EScoreCorrectionBias != nil {
+ scores = mlx.Add(scores, g.EScoreCorrectionBias)
+ }
+
+ topK := cfg.NumExpertsPerTok
+ negScores := mlx.Neg(scores)
+ inds := mlx.Argpartition(negScores, int(topK)-1, -1)
+
+ dims := inds.Dims()
+ inds = mlx.SliceStartStop(inds, []int32{0, 0, 0}, []int32{int32(dims[0]), int32(dims[1]), topK})
+
+ scores = mlx.TakeAlongAxis(origScores, inds, -1)
+
+ if topK > 1 && cfg.NormTopKProb {
+ sumScores := mlx.Sum(scores, -1, true)
+ scores = mlx.Div(scores, sumScores)
+ }
+
+ scores = mlx.MulScalar(scores, cfg.RoutedScalingFactor)
+
+ return inds, scores
+}
+
+// SwitchMLP implements the MoE expert computation using stacked weights
+type SwitchMLP struct {
+ GateWeight *mlx.Array
+ UpWeight *mlx.Array
+ DownWeight *mlx.Array
+
+ GateWeightQ, GateScales, GateBiases *mlx.Array
+ UpWeightQ, UpScales, UpBiases *mlx.Array
+ DownWeightQ, DownScales, DownBiases *mlx.Array
+
+ GateBits int
+ UpBits int
+ DownBits int
+
+ GateGroupSize int
+ UpGroupSize int
+ DownGroupSize int
+
+ UseQuantized bool
+}
+
+// Forward applies the switched expert MLP
+func (s *SwitchMLP) Forward(x *mlx.Array, indices *mlx.Array, cfg *Config) *mlx.Array {
+ dims := x.Dims()
+ B, L := int32(dims[0]), int32(dims[1])
+ topK := cfg.NumExpertsPerTok
+
+ xExpanded := mlx.ExpandDims(mlx.ExpandDims(x, -2), -2)
+
+ xFlat := mlx.Reshape(xExpanded, B*L, 1, 1, cfg.HiddenSize)
+
+ idxFlat := mlx.Reshape(indices, B*L, topK)
+
+ doSort := B*L >= 64
+ var invOrder *mlx.Array
+ n := B * L * topK
+
+ if doSort {
+ idxAll := mlx.Flatten(idxFlat)
+ order := mlx.Argsort(idxAll, 0)
+ invOrder = mlx.Argsort(order, 0)
+ xFlat = mlx.ExpandDims(mlx.Take(mlx.Squeeze(xFlat, 1), mlx.FloorDivideScalar(order, topK), 0), 1)
+ idxFlat = mlx.Reshape(mlx.Take(idxAll, order, 0), n, 1)
+ }
+
+ var gate, up, hidden, down *mlx.Array
+
+ if s.UseQuantized {
+ gate = mlx.GatherQMM(xFlat, s.GateWeightQ, s.GateScales, s.GateBiases,
+ nil, idxFlat, true, s.GateGroupSize, s.GateBits, cfg.QuantMode, doSort)
+ up = mlx.GatherQMM(xFlat, s.UpWeightQ, s.UpScales, s.UpBiases,
+ nil, idxFlat, true, s.UpGroupSize, s.UpBits, cfg.QuantMode, doSort)
+
+ hidden = mlx.Mul(mlx.SiLU(gate), up)
+
+ down = mlx.GatherQMM(hidden, s.DownWeightQ, s.DownScales, s.DownBiases,
+ nil, idxFlat, true, s.DownGroupSize, s.DownBits, cfg.QuantMode, doSort)
+ } else {
+ gate = mlx.GatherMM(xFlat, mlx.Transpose(s.GateWeight, 0, 2, 1), nil, idxFlat, doSort)
+ up = mlx.GatherMM(xFlat, mlx.Transpose(s.UpWeight, 0, 2, 1), nil, idxFlat, doSort)
+
+ hidden = mlx.Mul(mlx.SiLU(gate), up)
+
+ down = mlx.GatherMM(hidden, mlx.Transpose(s.DownWeight, 0, 2, 1), nil, idxFlat, doSort)
+ }
+
+ if doSort {
+ down = mlx.Reshape(mlx.Take(mlx.Squeeze(mlx.Squeeze(down, 2), 1), invOrder, 0), B*L, topK, cfg.HiddenSize)
+ } else {
+ down = mlx.Squeeze(down, 2)
+ }
+
+ return mlx.Reshape(down, B, L, topK, cfg.HiddenSize)
+}
+
+// SharedExperts implements the shared expert MLP
+type SharedExperts struct {
+ GateProj nn.LinearLayer
+ UpProj nn.LinearLayer
+ DownProj nn.LinearLayer
+}
+
+// Forward applies the shared expert MLP
+func (s *SharedExperts) Forward(x *mlx.Array) *mlx.Array {
+ gate := mlx.SiLU(s.GateProj.Forward(x))
+ up := s.UpProj.Forward(x)
+ return s.DownProj.Forward(mlx.Mul(gate, up))
+}
+
+// MoE implements the full Mixture of Experts layer
+type MoE struct {
+ Gate *MoEGate
+ SwitchMLP *SwitchMLP
+ SharedExperts *SharedExperts
+}
+
+// Forward applies the MoE layer
+func (m *MoE) Forward(x *mlx.Array, cfg *Config) *mlx.Array {
+ dims := x.Dims()
+ B, L := int32(dims[0]), int32(dims[1])
+
+ inds, scores := m.Gate.Forward(x, cfg)
+
+ expertOut := m.SwitchMLP.Forward(x, inds, cfg)
+
+ scoresExpanded := mlx.ExpandDims(scores, -1)
+ y := mlx.Sum(mlx.Mul(expertOut, scoresExpanded), 2, false)
+
+ if m.SharedExperts != nil {
+ y = mlx.Add(y, m.SharedExperts.Forward(x))
+ }
+
+ return mlx.Reshape(y, B, L, cfg.HiddenSize)
+}
+
+// DenseBlock represents a dense transformer block (for first_k_dense_replace layers)
+type DenseBlock struct {
+ Attention *MLAAttention
+ MLP *DenseMLP
+ InputLayerNorm *nn.RMSNorm
+ PostAttentionLayerNorm *nn.RMSNorm
+}
+
+// Forward applies the dense block
+func (b *DenseBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
+ h := mlx.Add(x, r)
+
+ r = b.MLP.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps))
+ return mlx.Add(h, r)
+}
+
+// MoEBlock represents a MoE transformer block
+type MoEBlock struct {
+ Attention *MLAAttention
+ MoE *MoE
+ InputLayerNorm *nn.RMSNorm
+ PostAttentionLayerNorm *nn.RMSNorm
+}
+
+// Forward applies the MoE block
+func (b *MoEBlock) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ r := b.Attention.Forward(b.InputLayerNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg)
+ h := mlx.Add(x, r)
+
+ r = b.MoE.Forward(b.PostAttentionLayerNorm.Forward(h, cfg.RMSNormEps), cfg)
+ return mlx.Add(h, r)
+}
+
+// Block interface for both dense and MoE blocks
+type Block interface {
+ Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array
+}
+
+// Model represents the complete GLM4-MoE-Lite model
+type Model struct {
+ EmbedTokens *nn.Embedding
+ Layers []Block
+ Norm *nn.RMSNorm
+ LMHead nn.LinearLayer
+
+ tok *tokenizer.Tokenizer
+ *Config
+}
+
+// computeScale computes the attention scale.
+func computeScale(cfg *Config) float32 {
+ keyLength := cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
+ scale := float32(1.0 / math.Sqrt(float64(keyLength)))
+ if cfg.RopeScaling != nil && cfg.RopeScaling.MscaleAllDim > 0 && cfg.RopeScaling.Factor > 1 {
+ s := 0.1*cfg.RopeScaling.MscaleAllDim*float32(math.Log(float64(cfg.RopeScaling.Factor))) + 1.0
+ scale *= s * s
+ }
+ return scale
+}
+
+// supportsGatherQMM returns true if the quantization mode has GatherQMM kernel support.
+func supportsGatherQMM(mode string, bits int) bool {
+ return mode == "affine" && (bits == 4 || bits == 8)
+}
+
+// ExpertWeight holds a single expert's weight with optional quantization components.
+type ExpertWeight struct {
+ Weight *mlx.Array
+ Scales *mlx.Array
+ Biases *mlx.Array
+ Bits int
+ GroupSize int
+}
+
+// loadExpertWeight loads an expert weight from the tensor map.
+func loadExpertWeight(tensors map[string]*mlx.Array, path string, useQuantized bool, cfg *Config) *ExpertWeight {
+ w := tensors[path+".weight"]
+ if w == nil {
+ return nil
+ }
+
+ scales := tensors[path+".weight_scale"]
+ if scales != nil {
+ qbiases := tensors[path+".weight_qbias"]
+
+ groupSize, bits, mode := model.ResolveLinearQuantParams(
+ cfg.QuantGroupSize,
+ cfg.QuantBits,
+ cfg.QuantMode,
+ cfg.TensorQuant,
+ path+".weight",
+ w,
+ scales,
+ )
+
+ if useQuantized && supportsGatherQMM(mode, bits) {
+ return &ExpertWeight{Weight: w, Scales: scales, Biases: qbiases, Bits: bits, GroupSize: groupSize}
+ }
+
+ return &ExpertWeight{Weight: mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)}
+ }
+
+ return &ExpertWeight{Weight: w}
+}
+
+// StackedExpertWeights holds stacked weights for all experts.
+type StackedExpertWeights struct {
+ Weight *mlx.Array
+ Scales *mlx.Array
+ Biases *mlx.Array
+ Bits int
+ GroupSize int
+}
+
+// collectAndStackExpertWeights loads and stacks expert weights for one projection type.
+func collectAndStackExpertWeights(
+ tensors map[string]*mlx.Array,
+ prefix string,
+ projName string,
+ numExperts int32,
+ useQuantized bool,
+ cfg *Config,
+) *StackedExpertWeights {
+ var w, s, b []*mlx.Array
+ var bits, groupSize int
+
+ for e := int32(0); e < numExperts; e++ {
+ path := fmt.Sprintf("%s.mlp.experts.%d.%s", prefix, e, projName)
+ ew := loadExpertWeight(tensors, path, useQuantized, cfg)
+ if ew == nil {
+ continue
+ }
+ w = append(w, ew.Weight)
+ if ew.Scales != nil {
+ s = append(s, ew.Scales)
+ }
+ if ew.Biases != nil {
+ b = append(b, ew.Biases)
+ }
+ if e == 0 {
+ bits = ew.Bits
+ groupSize = ew.GroupSize
+ }
+ }
+
+ result := &StackedExpertWeights{Bits: bits, GroupSize: groupSize}
+ if len(w) > 0 {
+ result.Weight = mlx.Stack(w, 0)
+ if len(s) > 0 {
+ result.Scales = mlx.Stack(s, 0)
+ }
+ if len(b) > 0 {
+ result.Biases = mlx.Stack(b, 0)
+ }
+ }
+ return result
+}
+
+// sanitizeExpertWeights stacks individual expert weights into tensors.
+func sanitizeExpertWeights(tensors map[string]*mlx.Array, prefix string, numExperts int32, useQuantized bool, cfg *Config) (gate, up, down *StackedExpertWeights) {
+ gate = collectAndStackExpertWeights(tensors, prefix, "gate_proj", numExperts, useQuantized, cfg)
+ up = collectAndStackExpertWeights(tensors, prefix, "up_proj", numExperts, useQuantized, cfg)
+ down = collectAndStackExpertWeights(tensors, prefix, "down_proj", numExperts, useQuantized, cfg)
+ return gate, up, down
+}
+
+// sanitizeMLAWeights transforms kv_b_proj weights into absorbed MLA format.
+func sanitizeMLAWeights(tensors map[string]*mlx.Array, prefix string, cfg *Config) (*mlx.Array, *mlx.Array) {
+ path := prefix + ".self_attn.kv_b_proj"
+ w := tensors[path+".weight"]
+ if w == nil {
+ return nil, nil
+ }
+
+ // Check if quantized and dequantize
+ if scales := tensors[path+".weight_scale"]; scales != nil {
+ qbiases := tensors[path+".weight_qbias"]
+ groupSize, bits, mode := model.ResolveLinearQuantParams(
+ cfg.QuantGroupSize,
+ cfg.QuantBits,
+ cfg.QuantMode,
+ cfg.TensorQuant,
+ path+".weight",
+ w,
+ scales,
+ )
+ w = mlx.Dequantize(w, scales, qbiases, groupSize, bits, mode)
+ }
+
+ headDim := cfg.QKNopeHeadDim + cfg.VHeadDim
+ w = mlx.Reshape(w, cfg.NumAttentionHeads, headDim, cfg.KVLoraRank)
+
+ wk := mlx.SliceStartStop(w, []int32{0, 0, 0}, []int32{cfg.NumAttentionHeads, cfg.QKNopeHeadDim, cfg.KVLoraRank})
+ wv := mlx.SliceStartStop(w, []int32{0, cfg.QKNopeHeadDim, 0}, []int32{cfg.NumAttentionHeads, headDim, cfg.KVLoraRank})
+
+ embedQ := mlx.Transpose(wk, 0, 2, 1)
+ unembedOut := wv
+
+ return embedQ, unembedOut
+}
+
+// newModel creates a new GLM4-MoE-Lite model from a Root (config + tokenizer,
+// no weights loaded yet). Called by the registry via base.New().
+func newModel(root *model.Root) (base.Model, error) {
+ configData, err := root.Manifest.ReadConfig("config.json")
+ if err != nil {
+ return nil, fmt.Errorf("load config: %w", err)
+ }
+
+ var cfg Config
+ if err := json.Unmarshal(configData, &cfg); err != nil {
+ return nil, fmt.Errorf("parse config: %w", err)
+ }
+
+ cfg.QHeadDim = cfg.QKNopeHeadDim + cfg.QKRopeHeadDim
+ cfg.Scale = computeScale(&cfg)
+
+ // Set up quantization parameters from pre-scanned metadata
+ if qt := root.QuantType(); qt != "" {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
+ if gs := root.GroupSize(); gs > 0 {
+ cfg.QuantGroupSize = gs
+ }
+ } else {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
+ }
+ cfg.TensorQuant = root.AllTensorQuant()
+
+ // Load tokenizer
+ tokData, err := root.Manifest.ReadConfig("tokenizer.json")
+ if err != nil {
+ return nil, fmt.Errorf("load tokenizer config: %w", err)
+ }
+
+ tokConfig := &tokenizer.TokenizerConfig{
+ ConfigJSON: configData,
+ }
+
+ if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
+ tokConfig.GenerationConfigJSON = genConfigData
+ }
+
+ if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
+ tokConfig.TokenizerConfigJSON = tokConfigData
+ }
+
+ tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
+ if err != nil {
+ return nil, fmt.Errorf("parse tokenizer: %w", err)
+ }
+
+ m := &Model{
+ Layers: make([]Block, cfg.NumHiddenLayers),
+ Config: &cfg,
+ tok: tok,
+ }
+
+ return m, nil
+}
+
+// LoadWeights receives all tensors loaded from the manifest and assigns them
+// to model fields. Handles MLA absorption, expert stacking, and quantized
+// layer creation.
+func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
+ cfg := m.Config
+ linears := model.NewLinearFactory(tensors, cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode, cfg.TensorQuant)
+ useQuantized := supportsGatherQMM(cfg.QuantMode, cfg.QuantBits)
+ if !useQuantized && cfg.TensorQuant != nil {
+ for _, tq := range cfg.TensorQuant {
+ if tq == nil {
+ continue
+ }
+ _, bits, mode := model.QuantizationParams(tq.QuantType)
+ if supportsGatherQMM(mode, bits) {
+ useQuantized = true
+ break
+ }
+ }
+ }
+
+ // Load embedding
+ if w := tensors["model.embed_tokens.weight"]; w != nil {
+ m.EmbedTokens = nn.NewEmbedding(w)
+ }
+
+ // Load final norm
+ if w := tensors["model.norm.weight"]; w != nil {
+ m.Norm = nn.NewRMSNorm(w, cfg.RMSNormEps)
+ }
+
+ // Load LM head
+ m.LMHead = linears.Make("lm_head")
+
+ // Load layers
+ for i := int32(0); i < cfg.NumHiddenLayers; i++ {
+ prefix := fmt.Sprintf("model.layers.%d", i)
+
+ // Load attention (same for both block types)
+ attn := &MLAAttention{}
+ attn.QAProj = linears.Make(prefix + ".self_attn.q_a_proj")
+ if w := tensors[prefix+".self_attn.q_a_layernorm.weight"]; w != nil {
+ attn.QALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
+ }
+ attn.QBProj = linears.Make(prefix + ".self_attn.q_b_proj")
+ attn.KVAProjWithMQA = linears.Make(prefix + ".self_attn.kv_a_proj_with_mqa")
+ if w := tensors[prefix+".self_attn.kv_a_layernorm.weight"]; w != nil {
+ attn.KVALayerNorm = nn.NewRMSNorm(w, cfg.RMSNormEps)
+ }
+ attn.OProj = linears.Make(prefix + ".self_attn.o_proj")
+
+ // Sanitize MLA weights for absorbed attention
+ embedQ, unembedOut := sanitizeMLAWeights(tensors, prefix, cfg)
+ attn.EmbedQ = nn.NewMultiLinear(embedQ)
+ attn.UnembedOut = nn.NewMultiLinear(unembedOut)
+
+ inputLN := tensors[prefix+".input_layernorm.weight"]
+ postAttnLN := tensors[prefix+".post_attention_layernorm.weight"]
+
+ if i < cfg.FirstKDenseReplace {
+ // Dense block
+ block := &DenseBlock{Attention: attn}
+ if inputLN != nil {
+ block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
+ }
+ if postAttnLN != nil {
+ block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
+ }
+
+ block.MLP = &DenseMLP{
+ GateProj: linears.Make(prefix + ".mlp.gate_proj"),
+ UpProj: linears.Make(prefix + ".mlp.up_proj"),
+ DownProj: linears.Make(prefix + ".mlp.down_proj"),
+ }
+
+ m.Layers[i] = block
+ } else {
+ // MoE block
+ block := &MoEBlock{Attention: attn}
+ if inputLN != nil {
+ block.InputLayerNorm = nn.NewRMSNorm(inputLN, cfg.RMSNormEps)
+ }
+ if postAttnLN != nil {
+ block.PostAttentionLayerNorm = nn.NewRMSNorm(postAttnLN, cfg.RMSNormEps)
+ }
+
+ // Stack expert weights
+ gate, up, down := sanitizeExpertWeights(tensors, prefix, cfg.NRoutedExperts, useQuantized, cfg)
+
+ switchMLP := &SwitchMLP{UseQuantized: useQuantized}
+ if useQuantized {
+ switchMLP.GateWeightQ = gate.Weight
+ switchMLP.GateScales = gate.Scales
+ switchMLP.GateBiases = gate.Biases
+ switchMLP.GateBits = gate.Bits
+ switchMLP.GateGroupSize = gate.GroupSize
+ switchMLP.UpWeightQ = up.Weight
+ switchMLP.UpScales = up.Scales
+ switchMLP.UpBiases = up.Biases
+ switchMLP.UpBits = up.Bits
+ switchMLP.UpGroupSize = up.GroupSize
+ switchMLP.DownWeightQ = down.Weight
+ switchMLP.DownScales = down.Scales
+ switchMLP.DownBiases = down.Biases
+ switchMLP.DownBits = down.Bits
+ switchMLP.DownGroupSize = down.GroupSize
+ } else {
+ switchMLP.GateWeight = gate.Weight
+ switchMLP.UpWeight = up.Weight
+ switchMLP.DownWeight = down.Weight
+ }
+
+ moeGate := &MoEGate{}
+ moeGate.Gate = linears.Make(prefix + ".mlp.gate")
+ if bias := tensors[prefix+".mlp.gate.e_score_correction_bias"]; bias != nil {
+ moeGate.EScoreCorrectionBias = bias
+ }
+
+ block.MoE = &MoE{
+ Gate: moeGate,
+ SwitchMLP: switchMLP,
+ }
+
+ // Load shared experts if present
+ if cfg.NSharedExperts > 0 {
+ block.MoE.SharedExperts = &SharedExperts{
+ GateProj: linears.Make(prefix + ".mlp.shared_experts.gate_proj"),
+ UpProj: linears.Make(prefix + ".mlp.shared_experts.up_proj"),
+ DownProj: linears.Make(prefix + ".mlp.shared_experts.down_proj"),
+ }
+ }
+
+ m.Layers[i] = block
+ }
+ }
+
+ collected := mlx.Collect(m)
+ mlx.Eval(collected...)
+
+ return nil
+}
+
+// Forward computes the forward pass of the model
+func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
+ dims := tokens.Dims()
+ B, L := int32(dims[0]), int32(dims[1])
+
+ h := m.EmbedTokens.Forward(tokens)
+
+ for i, layer := range m.Layers {
+ var c cache.Cache
+ if caches != nil {
+ c = caches[i]
+ }
+ h = layer.Forward(h, c, B, L, m.Config)
+ }
+
+ h = m.Norm.Forward(h, m.RMSNormEps)
+ return h
+}
+
+// Unembed applies the LM head to get logits.
+func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
+ return m.LMHead.Forward(x)
+}
+
+// NumLayers returns the number of transformer layers
+func (m *Model) NumLayers() int { return len(m.Layers) }
+
+// MaxContextLength returns the maximum context length
+func (m *Model) MaxContextLength() int32 { return m.MaxPositionEmbeddings }
+
+// VocabSize returns the vocabulary size
+func (m *Model) VocabSize() int32 { return m.Config.VocabSize }
+
+// Tokenizer returns the model's tokenizer
+func (m *Model) Tokenizer() *tokenizer.Tokenizer { return m.tok }
+
+// NewCache creates a new KV cache for the model
+func (m *Model) NewCache(maxSeqLen int32) []cache.Cache {
+ caches := make([]cache.Cache, len(m.Layers))
+ for i := range caches {
+ caches[i] = cache.NewKVCache()
+ }
+ return caches
+}
+
+// FormatPrompt applies the GLM-4 chat template with thinking enabled by default.
+func (m *Model) FormatPrompt(prompt string) string {
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+}
+
+// FormatPromptWithThinking applies the GLM-4 chat template with explicit thinking control.
+func (m *Model) FormatPromptWithThinking(prompt string, think bool) string {
+ if think {
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+ }
+ return "[gMASK]<|user|>" + prompt + "<|assistant|>"
+}
+
+// NewRenderer returns a new Renderer for formatting multi-turn conversations.
+func (m *Model) NewRenderer() *Renderer {
+ return &Renderer{}
+}
+
+// NewParser returns a new Parser for extracting thinking and tool calls from output.
+func (m *Model) NewParser() *Parser {
+ return &Parser{}
+}
diff --git a/x/models/glm4_moe_lite/parser.go b/x/models/glm4_moe_lite/parser.go
new file mode 100644
index 00000000000..c81ec5a4043
--- /dev/null
+++ b/x/models/glm4_moe_lite/parser.go
@@ -0,0 +1,479 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "context"
+ "encoding/json"
+ "encoding/xml"
+ "fmt"
+ "log/slog"
+ "strings"
+ "unicode"
+
+ "github.com/ollama/ollama/api"
+ "github.com/ollama/ollama/logutil"
+)
+
+type parserState int
+
+const (
+ parserState_LookingForThinkingOpen parserState = iota
+ parserState_ThinkingStartedEatingWhitespace
+ parserState_CollectingThinking
+ parserState_ThinkingDoneEatingWhitespace
+ parserState_CollectingContent
+ parserState_ToolStartedEatingWhitespace
+ parserState_CollectingToolContent
+)
+
+const (
+ thinkingOpenTag = ""
+ thinkingCloseTag = ""
+ toolOpenTag = ""
+ toolCloseTag = ""
+)
+
+// Parser parses GLM4-MoE-Lite model output to extract thinking and tool calls.
+// GLM-4's prompt ends with when thinking is enabled, so the parser
+// must start in CollectingThinking state (the model outputs thinking content directly).
+type Parser struct {
+ state parserState
+ buffer strings.Builder
+ tools []api.Tool
+}
+
+// HasToolSupport returns true as GLM4 supports tool calling.
+func (p *Parser) HasToolSupport() bool {
+ return true
+}
+
+// HasThinkingSupport returns true as GLM4 supports thinking mode.
+func (p *Parser) HasThinkingSupport() bool {
+ return true
+}
+
+// Init initializes the parser with tools and thinking configuration.
+func (p *Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
+ p.tools = tools
+ // When thinking is enabled (nil or true), the prompt ends with ,
+ // so model output starts directly with thinking content (no opening tag).
+ if thinkValue == nil || thinkValue.Bool() {
+ p.state = parserState_CollectingThinking
+ }
+ return tools
+}
+
+type parserEvent interface {
+ isParserEvent()
+}
+
+type eventContent struct {
+ content string
+}
+
+func (eventContent) isParserEvent() {}
+
+type eventRawToolCall struct {
+ raw string
+}
+
+func (eventRawToolCall) isParserEvent() {}
+
+type eventThinkingContent struct {
+ content string
+}
+
+func (eventThinkingContent) isParserEvent() {}
+
+// Add processes new output text and returns parsed content, thinking, and tool calls.
+func (p *Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
+ p.buffer.WriteString(s)
+ events := p.parseEvents()
+
+ var toolCalls []api.ToolCall
+ var contentSb strings.Builder
+ var thinkingSb strings.Builder
+
+ for _, event := range events {
+ switch event := event.(type) {
+ case eventRawToolCall:
+ toolCall, err := parseToolCall(event, p.tools)
+ if err != nil {
+ slog.Warn("glm-4 tool call parsing failed", "error", err)
+ return "", "", nil, err
+ }
+ toolCalls = append(toolCalls, toolCall)
+ case eventThinkingContent:
+ thinkingSb.WriteString(event.content)
+ case eventContent:
+ contentSb.WriteString(event.content)
+ }
+ }
+
+ return contentSb.String(), thinkingSb.String(), toolCalls, nil
+}
+
+func (p *Parser) parseEvents() []parserEvent {
+ var all []parserEvent
+
+ keepLooping := true
+ for keepLooping {
+ var events []parserEvent
+ events, keepLooping = p.eat()
+ if len(events) > 0 {
+ all = append(all, events...)
+ }
+ }
+
+ if len(all) > 0 {
+ slog.Log(context.TODO(), logutil.LevelTrace, "glm-4 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
+ }
+
+ return all
+}
+
+// eatLeadingWhitespaceAndTransitionTo consumes leading whitespace from the buffer
+// and transitions to the next state. Returns (nil, false) if only whitespace remains
+// in the buffer (needs more input), or (nil, true) if we successfully transitioned.
+func (p *Parser) eatLeadingWhitespaceAndTransitionTo(nextState parserState) ([]parserEvent, bool) {
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ p.buffer.Reset()
+ if trimmed == "" {
+ return nil, false // Still only whitespace, keep waiting for more input
+ }
+ p.state = nextState
+ p.buffer.WriteString(trimmed)
+ return nil, true // Successfully transitioned
+}
+
+// splitAtTag splits the buffer at the given tag, returns the content before (trimmed of trailing whitespace),
+// the content after (optionally trimmed of leading whitespace), and updates the buffer
+func (p *Parser) splitAtTag(tag string, trimAfter bool) (string, string) {
+ split := strings.SplitN(p.buffer.String(), tag, 2)
+ before := split[0]
+ before = strings.TrimRightFunc(before, unicode.IsSpace)
+ after := split[1]
+ if trimAfter {
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ }
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ return before, after
+}
+
+func (p *Parser) eat() ([]parserEvent, bool) {
+ var events []parserEvent
+
+ switch p.state {
+ case parserState_LookingForThinkingOpen:
+ trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
+ if strings.HasPrefix(trimmed, thinkingOpenTag) {
+ // Found opening tag
+ after := strings.TrimPrefix(trimmed, thinkingOpenTag)
+ after = strings.TrimLeftFunc(after, unicode.IsSpace)
+ p.buffer.Reset()
+ p.buffer.WriteString(after)
+ if after == "" {
+ p.state = parserState_ThinkingStartedEatingWhitespace
+ } else {
+ p.state = parserState_CollectingThinking
+ }
+ return events, true
+ } else if strings.HasPrefix(thinkingOpenTag, trimmed) {
+ // Partial opening tag seen, keep accumulating
+ return events, false
+ } else if trimmed == "" {
+ // Only whitespace, keep accumulating
+ return events, false
+ } else {
+ // No thinking tag found, skip to content collection
+ p.state = parserState_CollectingContent
+ // Don't trim - we want to keep the original content
+ return events, true
+ }
+
+ case parserState_ThinkingStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingThinking)
+
+ case parserState_CollectingThinking:
+ acc := p.buffer.String()
+ if strings.Contains(acc, thinkingCloseTag) {
+ thinking, remaining := p.splitAtTag(thinkingCloseTag, true)
+ if len(thinking) > 0 {
+ events = append(events, eventThinkingContent{content: thinking})
+ }
+ if remaining == "" {
+ p.state = parserState_ThinkingDoneEatingWhitespace
+ } else {
+ p.state = parserState_CollectingContent
+ }
+ return events, true
+ } else if overlapLen := overlap(acc, thinkingCloseTag); overlapLen > 0 {
+ // Partial closing tag - withhold it along with any trailing whitespace before it
+ beforePartialTag := acc[:len(acc)-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ } else {
+ // Pure thinking content - withhold trailing whitespace (might precede closing tag)
+ whitespaceLen := trailingWhitespaceLen(acc)
+ ambiguousStart := len(acc) - whitespaceLen
+
+ unambiguous := acc[:ambiguousStart]
+ ambiguous := acc[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventThinkingContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case parserState_ThinkingDoneEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingContent)
+
+ case parserState_CollectingContent:
+ if strings.Contains(p.buffer.String(), toolOpenTag) {
+ before, after := p.splitAtTag(toolOpenTag, true)
+ if len(before) > 0 {
+ events = append(events, eventContent{content: before})
+ }
+ if after == "" {
+ p.state = parserState_ToolStartedEatingWhitespace
+ } else {
+ p.state = parserState_CollectingToolContent
+ }
+ return events, true
+ } else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
+ beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
+ trailingWsLen := trailingWhitespaceLen(beforePartialTag)
+ ambiguousStart := len(beforePartialTag) - trailingWsLen
+
+ unambiguous := p.buffer.String()[:ambiguousStart]
+ ambiguous := p.buffer.String()[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventContent{content: unambiguous})
+ }
+ return events, false
+ } else {
+ whitespaceLen := trailingWhitespaceLen(p.buffer.String())
+ ambiguousStart := len(p.buffer.String()) - whitespaceLen
+
+ unambiguous := p.buffer.String()[:ambiguousStart]
+ ambiguous := p.buffer.String()[ambiguousStart:]
+ p.buffer.Reset()
+ p.buffer.WriteString(ambiguous)
+ if len(unambiguous) > 0 {
+ events = append(events, eventContent{content: unambiguous})
+ }
+ return events, false
+ }
+
+ case parserState_ToolStartedEatingWhitespace:
+ return p.eatLeadingWhitespaceAndTransitionTo(parserState_CollectingToolContent)
+
+ case parserState_CollectingToolContent:
+ acc := p.buffer.String()
+ if strings.Contains(acc, toolCloseTag) {
+ toolContent, _ := p.splitAtTag(toolCloseTag, true)
+ if len(toolContent) == 0 {
+ slog.Warn("glm4 tool call closing tag found but no content before it")
+ }
+ events = append(events, eventRawToolCall{raw: toolContent})
+ p.state = parserState_CollectingContent
+ return events, true
+ } else {
+ // Keep accumulating - tool calls are not streamed
+ // We just wait for the closing tag
+ return events, false
+ }
+
+ default:
+ panic("unreachable")
+ }
+}
+
+// overlap returns the length of the overlap between the end of s and the start of tag.
+func overlap(s, tag string) int {
+ for i := 1; i <= len(tag) && i <= len(s); i++ {
+ if strings.HasSuffix(s, tag[:i]) {
+ return i
+ }
+ }
+ return 0
+}
+
+// trailingWhitespaceLen returns the length of trailing whitespace in s.
+func trailingWhitespaceLen(s string) int {
+ trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
+ return len(s) - len(trimmed)
+}
+
+// ToolCallXML represents the structure of a GLM-4 tool call for XML parsing
+type ToolCallXML struct {
+ XMLName xml.Name `xml:"tool_call"`
+ Content string `xml:",chardata"` // Function name (text nodes between tags)
+ Keys []string `xml:"arg_key"` // All arg_key elements in document order
+ Values []string `xml:"arg_value"` // All arg_value elements in document order
+}
+
+// escapeContent escapes XML entities in text content while preserving arg_key/arg_value tags
+func escapeContent(s string) string {
+ var result strings.Builder
+ inTag := false
+
+ for i := range len(s) {
+ ch := s[i]
+
+ if ch == '<' {
+ // Check if this is a known tag
+ if strings.HasPrefix(s[i:], "") ||
+ strings.HasPrefix(s[i:], "") ||
+ strings.HasPrefix(s[i:], "") ||
+ strings.HasPrefix(s[i:], "") {
+ inTag = true
+ }
+ }
+
+ if inTag {
+ result.WriteByte(ch)
+ if ch == '>' {
+ inTag = false
+ }
+ } else {
+ // Escape special characters in text content
+ switch ch {
+ case '&':
+ result.WriteString("&")
+ case '<':
+ result.WriteString("<")
+ case '>':
+ result.WriteString(">")
+ default:
+ result.WriteByte(ch)
+ }
+ }
+ }
+
+ return result.String()
+}
+
+func parseToolCall(raw eventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
+ // Escape any unescaped entities in text content
+ escaped := escapeContent(raw.raw)
+
+ // Wrap the content in a root element to make it valid XML
+ xmlString := "" + escaped + ""
+
+ // Parse XML into struct
+ var parsed ToolCallXML
+ if err := xml.Unmarshal([]byte(xmlString), &parsed); err != nil {
+ return api.ToolCall{}, fmt.Errorf("failed to parse XML: %w", err)
+ }
+
+ // Extract and trim function name
+ functionName := strings.TrimSpace(parsed.Content)
+ if functionName == "" {
+ return api.ToolCall{}, fmt.Errorf("empty function name")
+ }
+
+ // Verify keys and values are paired correctly
+ if len(parsed.Keys) != len(parsed.Values) {
+ return api.ToolCall{}, fmt.Errorf("mismatched arg_key and arg_value counts: %d keys, %d values", len(parsed.Keys), len(parsed.Values))
+ }
+
+ // Find the matching tool to get parameter types
+ var matchedTool *api.Tool
+ for i := range tools {
+ if tools[i].Function.Name == functionName {
+ matchedTool = &tools[i]
+ break
+ }
+ }
+
+ // Build arguments map by pairing keys and values
+ toolCall := api.ToolCall{
+ Function: api.ToolCallFunction{
+ Name: functionName,
+ Arguments: api.NewToolCallFunctionArguments(),
+ },
+ }
+
+ for i := range parsed.Keys {
+ key := strings.TrimSpace(parsed.Keys[i])
+ value := parsed.Values[i] // Don't trim here - parseValue handles it
+
+ // Look up parameter type
+ var paramType api.PropertyType
+ if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
+ if prop, ok := matchedTool.Function.Parameters.Properties.Get(key); ok {
+ // Handle anyOf by collecting all types from the union
+ if len(prop.AnyOf) > 0 {
+ for _, anyOfProp := range prop.AnyOf {
+ paramType = append(paramType, anyOfProp.Type...)
+ }
+ } else {
+ paramType = prop.Type
+ }
+ }
+ }
+
+ // Parse value with type coercion
+ toolCall.Function.Arguments.Set(key, parseValue(value, paramType))
+ }
+
+ return toolCall, nil
+}
+
+// parseValue parses a string value and coerces it to the appropriate type based on paramType.
+func parseValue(value string, paramType api.PropertyType) any {
+ value = strings.TrimSpace(value)
+
+ // If no type specified, return as string
+ if len(paramType) == 0 {
+ return value
+ }
+
+ // Try to parse based on specified types
+ for _, t := range paramType {
+ switch t {
+ case "boolean":
+ if value == "true" {
+ return true
+ }
+ if value == "false" {
+ return false
+ }
+ case "integer":
+ var i int64
+ if _, err := fmt.Sscanf(value, "%d", &i); err == nil {
+ return i
+ }
+ case "number":
+ var f float64
+ if _, err := fmt.Sscanf(value, "%f", &f); err == nil {
+ return f
+ }
+ case "array", "object":
+ // Try to parse as JSON
+ var result any
+ if err := json.Unmarshal([]byte(value), &result); err == nil {
+ return result
+ }
+ }
+ }
+
+ // Default to string
+ return value
+}
diff --git a/x/models/glm4_moe_lite/parser_test.go b/x/models/glm4_moe_lite/parser_test.go
new file mode 100644
index 00000000000..0ce3827098b
--- /dev/null
+++ b/x/models/glm4_moe_lite/parser_test.go
@@ -0,0 +1,192 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestParserThinking(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ thinkEnabled bool
+ wantContent string
+ wantThinking string
+ wantToolCalls int
+ }{
+ {
+ name: "thinking enabled - simple thinking then content",
+ input: "Let me think about this...Here is my answer.",
+ thinkEnabled: true,
+ wantThinking: "Let me think about this...",
+ wantContent: "Here is my answer.",
+ },
+ {
+ name: "thinking enabled - only thinking",
+ input: "I need to consider multiple factors...",
+ thinkEnabled: true,
+ wantThinking: "I need to consider multiple factors...",
+ wantContent: "",
+ },
+ {
+ name: "thinking disabled - direct content",
+ input: "Here is my direct answer.",
+ thinkEnabled: false,
+ wantThinking: "",
+ wantContent: "Here is my direct answer.",
+ },
+ {
+ name: "thinking with tool call",
+ input: "Let me search for that...I'll use a tool.searchquerytest",
+ thinkEnabled: true,
+ wantThinking: "Let me search for that...",
+ wantContent: "I'll use a tool.",
+ wantToolCalls: 1,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ p := &Parser{}
+
+ var thinkValue *api.ThinkValue
+ if tt.thinkEnabled {
+ thinkValue = &api.ThinkValue{Value: true}
+ } else {
+ thinkValue = &api.ThinkValue{Value: false}
+ }
+
+ // Define tools for tool call tests
+ props := api.NewToolPropertiesMap()
+ props.Set("query", api.ToolProperty{Type: api.PropertyType{"string"}})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "search",
+ Parameters: api.ToolFunctionParameters{
+ Properties: props,
+ },
+ },
+ },
+ }
+
+ p.Init(tools, nil, thinkValue)
+
+ content, thinking, calls, err := p.Add(tt.input, true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if thinking != tt.wantThinking {
+ t.Errorf("thinking = %q, want %q", thinking, tt.wantThinking)
+ }
+ if content != tt.wantContent {
+ t.Errorf("content = %q, want %q", content, tt.wantContent)
+ }
+ if len(calls) != tt.wantToolCalls {
+ t.Errorf("len(calls) = %d, want %d", len(calls), tt.wantToolCalls)
+ }
+ })
+ }
+}
+
+func TestParserToolCall(t *testing.T) {
+ p := &Parser{}
+
+ props := api.NewToolPropertiesMap()
+ props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}})
+ props.Set("unit", api.ToolProperty{Type: api.PropertyType{"string"}})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Parameters: api.ToolFunctionParameters{
+ Properties: props,
+ },
+ },
+ },
+ }
+
+ // Initialize with thinking disabled
+ tv := &api.ThinkValue{Value: false}
+ p.Init(tools, nil, tv)
+
+ input := "get_weatherlocationSan Franciscounitcelsius"
+
+ _, _, calls, err := p.Add(input, true)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if len(calls) != 1 {
+ t.Fatalf("expected 1 tool call, got %d", len(calls))
+ }
+
+ call := calls[0]
+ if call.Function.Name != "get_weather" {
+ t.Errorf("function name = %q, want %q", call.Function.Name, "get_weather")
+ }
+
+ location, ok := call.Function.Arguments.Get("location")
+ if !ok || location != "San Francisco" {
+ t.Errorf("location = %v, want %q", location, "San Francisco")
+ }
+
+ unit, ok := call.Function.Arguments.Get("unit")
+ if !ok || unit != "celsius" {
+ t.Errorf("unit = %v, want %q", unit, "celsius")
+ }
+}
+
+func TestOverlap(t *testing.T) {
+ tests := []struct {
+ s string
+ tag string
+ want int
+ }{
+ {"hello<", "", 1},
+ {"hello", "", 2},
+ {"hello", 3},
+ {"hello", 4},
+ {"hello", 5},
+ {"hello", 6},
+ {"hello", 7},
+ {"hello", "", 8}, // Complete tag at end returns full length
+ {"hello", "", 0},
+ {"", "", 0},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.s+"_"+tt.tag, func(t *testing.T) {
+ got := overlap(tt.s, tt.tag)
+ if got != tt.want {
+ t.Errorf("overlap(%q, %q) = %d, want %d", tt.s, tt.tag, got, tt.want)
+ }
+ })
+ }
+}
+
+func TestTrailingWhitespaceLen(t *testing.T) {
+ tests := []struct {
+ s string
+ want int
+ }{
+ {"hello ", 3},
+ {"hello\n\t ", 3},
+ {"hello", 0},
+ {"", 0},
+ {" ", 3},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.s, func(t *testing.T) {
+ got := trailingWhitespaceLen(tt.s)
+ if got != tt.want {
+ t.Errorf("trailingWhitespaceLen(%q) = %d, want %d", tt.s, got, tt.want)
+ }
+ })
+ }
+}
diff --git a/x/models/glm4_moe_lite/render.go b/x/models/glm4_moe_lite/render.go
new file mode 100644
index 00000000000..4998604bf39
--- /dev/null
+++ b/x/models/glm4_moe_lite/render.go
@@ -0,0 +1,175 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/ollama/ollama/api"
+)
+
+// Renderer renders messages for GLM4-MoE-Lite models.
+//
+// GLM-4 Thinking Modes (ref: https://docs.z.ai/guides/capabilities/thinking-mode):
+//
+// 1. INTERLEAVED THINKING
+// The model thinks between tool calls and after receiving tool results.
+// This enables complex step-by-step reasoning: interpreting each tool output
+// before deciding what to do next. Thinking blocks are preserved and returned
+// with tool results to maintain reasoning continuity.
+//
+// 2. PRESERVED THINKING
+// The model retains reasoning content from previous assistant turns in context.
+// This preserves reasoning continuity across multi-turn conversations. The
+// upstream API has a "clear_thinking" parameter to control this:
+// - clear_thinking=true: clears reasoning from previous turns (outputs )
+// - clear_thinking=false: preserves ... blocks from previous turns
+//
+// 3. TURN-LEVEL THINKING
+// Controls whether the model should reason on each turn. The upstream API
+// uses "enable_thinking" parameter:
+// - enable_thinking=true: outputs to start reasoning
+// - enable_thinking=false: outputs to skip reasoning
+//
+// OLLAMA DEFAULTS:
+// - Thinking is ENABLED by default (thinkValue=nil or true outputs )
+// - Thinking is PRESERVED by default (reasoning content from previous turns is always
+// included in ... blocks, equivalent to clear_thinking=false)
+// - Users can disable thinking per-turn via thinkValue=false
+type Renderer struct{}
+
+// Render renders messages into the GLM4 chat format.
+func (r *Renderer) Render(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
+ var sb strings.Builder
+
+ sb.WriteString("[gMASK]")
+
+ if len(tools) > 0 {
+ sb.WriteString("<|system|>\n")
+ sb.WriteString("# Tools\n\n")
+ sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
+ sb.WriteString("You are provided with function signatures within XML tags:\n")
+ sb.WriteString("\n")
+ for _, tool := range tools {
+ d, _ := json.Marshal(tool)
+ sb.WriteString(formatToolJSON(d))
+ sb.WriteString("\n")
+ }
+ sb.WriteString("\n\n")
+ sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
+ sb.WriteString("{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...")
+ }
+
+ think := true
+ if thinkValue != nil && !thinkValue.Bool() {
+ think = false
+ }
+
+ for i, message := range messages {
+ switch message.Role {
+ case "user":
+ sb.WriteString("<|user|>")
+ sb.WriteString(message.Content)
+ case "assistant":
+ sb.WriteString("<|assistant|>")
+ if message.Thinking != "" {
+ sb.WriteString("" + message.Thinking + "")
+ } else {
+ sb.WriteString("")
+ }
+ if message.Content != "" {
+ sb.WriteString(message.Content)
+ }
+ if len(message.ToolCalls) > 0 {
+ for _, toolCall := range message.ToolCalls {
+ sb.WriteString("" + toolCall.Function.Name)
+ sb.WriteString(renderToolArguments(toolCall.Function.Arguments))
+ sb.WriteString("")
+ }
+ }
+ case "tool":
+ if i == 0 || messages[i-1].Role != "tool" {
+ sb.WriteString("<|observation|>")
+ }
+ sb.WriteString("")
+ sb.WriteString(message.Content)
+ sb.WriteString("")
+ case "system":
+ sb.WriteString("<|system|>")
+ sb.WriteString(message.Content)
+ }
+ }
+
+ sb.WriteString("<|assistant|>")
+ if think {
+ sb.WriteString("")
+ } else {
+ sb.WriteString("")
+ }
+
+ return sb.String(), nil
+}
+
+// renderToolArguments converts tool call arguments to GLM4 XML format.
+func renderToolArguments(args api.ToolCallFunctionArguments) string {
+ var sb strings.Builder
+ for key, value := range args.All() {
+ sb.WriteString("" + key + "")
+ var valueStr string
+ if str, ok := value.(string); ok {
+ valueStr = str
+ } else {
+ jsonBytes, err := json.Marshal(value)
+ if err != nil {
+ valueStr = fmt.Sprintf("%v", value)
+ } else {
+ valueStr = string(jsonBytes)
+ }
+ }
+
+ sb.WriteString("" + valueStr + "")
+ }
+
+ return sb.String()
+}
+
+// formatToolJSON formats JSON for GLM4 tool definitions by adding spaces after : and ,
+func formatToolJSON(raw []byte) string {
+ var sb strings.Builder
+ sb.Grow(len(raw) + len(raw)/10)
+
+ inString := false
+ escaped := false
+ for i := range raw {
+ ch := raw[i]
+ sb.WriteByte(ch)
+
+ if inString {
+ if escaped {
+ escaped = false
+ continue
+ }
+ if ch == '\\' {
+ escaped = true
+ continue
+ }
+ if ch == '"' {
+ inString = false
+ }
+ continue
+ }
+
+ if ch == '"' {
+ inString = true
+ continue
+ }
+
+ if ch == ':' || ch == ',' {
+ sb.WriteByte(' ')
+ }
+ }
+
+ return sb.String()
+}
diff --git a/x/models/glm4_moe_lite/render_test.go b/x/models/glm4_moe_lite/render_test.go
new file mode 100644
index 00000000000..f0d576bec85
--- /dev/null
+++ b/x/models/glm4_moe_lite/render_test.go
@@ -0,0 +1,205 @@
+//go:build mlx
+
+package glm4_moe_lite
+
+import (
+ "strings"
+ "testing"
+
+ "github.com/ollama/ollama/api"
+)
+
+func TestRendererSimple(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "Hello"},
+ }
+
+ // Thinking enabled (default)
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := "[gMASK]<|user|>Hello<|assistant|>"
+ if result != expected {
+ t.Errorf("result = %q, want %q", result, expected)
+ }
+}
+
+func TestRendererThinkingDisabled(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "Hello"},
+ }
+
+ tv := &api.ThinkValue{Value: false}
+
+ result, err := r.Render(messages, nil, tv)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ expected := "[gMASK]<|user|>Hello<|assistant|>"
+ if result != expected {
+ t.Errorf("result = %q, want %q", result, expected)
+ }
+}
+
+func TestRendererMultiTurn(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "What is 2+2?"},
+ {Role: "assistant", Content: "4", Thinking: "Let me calculate: 2+2=4"},
+ {Role: "user", Content: "And 3+3?"},
+ }
+
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Check key parts
+ if !strings.Contains(result, "[gMASK]") {
+ t.Error("missing [gMASK] prefix")
+ }
+ if !strings.Contains(result, "<|user|>What is 2+2?") {
+ t.Error("missing first user message")
+ }
+ if !strings.Contains(result, "<|assistant|>Let me calculate: 2+2=44") {
+ t.Error("missing assistant message with thinking")
+ }
+ if !strings.Contains(result, "<|user|>And 3+3?") {
+ t.Error("missing second user message")
+ }
+ if !strings.HasSuffix(result, "<|assistant|>") {
+ t.Errorf("should end with <|assistant|>, got suffix: %q", result[len(result)-30:])
+ }
+}
+
+func TestRendererWithSystem(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "system", Content: "You are a helpful assistant."},
+ {Role: "user", Content: "Hello"},
+ }
+
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if !strings.Contains(result, "<|system|>You are a helpful assistant.") {
+ t.Error("missing system message")
+ }
+}
+
+func TestRendererWithTools(t *testing.T) {
+ r := &Renderer{}
+
+ messages := []api.Message{
+ {Role: "user", Content: "What's the weather?"},
+ }
+
+ props := api.NewToolPropertiesMap()
+ props.Set("location", api.ToolProperty{Type: api.PropertyType{"string"}, Description: "The city"})
+ tools := []api.Tool{
+ {
+ Function: api.ToolFunction{
+ Name: "get_weather",
+ Description: "Get the weather for a location",
+ Parameters: api.ToolFunctionParameters{
+ Type: "object",
+ Properties: props,
+ Required: []string{"location"},
+ },
+ },
+ },
+ }
+
+ result, err := r.Render(messages, tools, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ // Check for tool system prompt
+ if !strings.Contains(result, "<|system|>") {
+ t.Error("missing system tag for tools")
+ }
+ if !strings.Contains(result, "# Tools") {
+ t.Error("missing tools header")
+ }
+ if !strings.Contains(result, "") {
+ t.Error("missing tools tag")
+ }
+ if !strings.Contains(result, "get_weather") {
+ t.Error("missing tool name")
+ }
+ if !strings.Contains(result, "") {
+ t.Error("missing closing tools tag")
+ }
+}
+
+func TestRendererWithToolCalls(t *testing.T) {
+ r := &Renderer{}
+
+ args := api.NewToolCallFunctionArguments()
+ args.Set("location", "San Francisco")
+
+ messages := []api.Message{
+ {Role: "user", Content: "What's the weather in SF?"},
+ {
+ Role: "assistant",
+ ToolCalls: []api.ToolCall{
+ {
+ Function: api.ToolCallFunction{
+ Name: "get_weather",
+ Arguments: args,
+ },
+ },
+ },
+ },
+ {Role: "tool", Content: "Sunny, 72F"},
+ }
+
+ result, err := r.Render(messages, nil, nil)
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+
+ if !strings.Contains(result, "get_weather") {
+ t.Error("missing tool call")
+ }
+ if !strings.Contains(result, "location") {
+ t.Error("missing arg_key")
+ }
+ if !strings.Contains(result, "San Francisco") {
+ t.Error("missing arg_value")
+ }
+ if !strings.Contains(result, "") {
+ t.Error("missing tool call closing tag")
+ }
+ if !strings.Contains(result, "<|observation|>") {
+ t.Error("missing observation tag")
+ }
+ if !strings.Contains(result, "Sunny, 72F") {
+ t.Error("missing tool response")
+ }
+}
+
+func TestFormatToolJSON(t *testing.T) {
+ input := []byte(`{"name":"test","value":123}`)
+ result := formatToolJSON(input)
+
+ // Should add spaces after : and ,
+ if !strings.Contains(result, ": ") {
+ t.Error("should add space after colon")
+ }
+ if !strings.Contains(result, ", ") {
+ t.Error("should add space after comma")
+ }
+}
diff --git a/x/models/llama/llama.go b/x/models/llama/llama.go
new file mode 100644
index 00000000000..61e51b35c94
--- /dev/null
+++ b/x/models/llama/llama.go
@@ -0,0 +1,323 @@
+//go:build mlx
+
+// Package llama provides a Llama-style decoder-only transformer for MLX.
+package llama
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/model"
+ "github.com/ollama/ollama/x/mlxrunner/model/base"
+ "github.com/ollama/ollama/x/models/nn"
+ "github.com/ollama/ollama/x/tokenizer"
+)
+
+func init() {
+ base.Register("LlamaForCausalLM", newModel)
+}
+
+// Config holds Llama model configuration.
+type Config struct {
+ HiddenSize int32 `json:"hidden_size"`
+ NumHiddenLayers int32 `json:"num_hidden_layers"`
+ IntermediateSize int32 `json:"intermediate_size"`
+ NumAttentionHeads int32 `json:"num_attention_heads"`
+ NumKeyValueHeads int32 `json:"num_key_value_heads"`
+ VocabSize int32 `json:"vocab_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ RopeTheta float32 `json:"rope_theta"`
+ MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
+ TieWordEmbeddings bool `json:"tie_word_embeddings"`
+
+ // Quantization parameters (set during load based on model quantization).
+ QuantGroupSize int `json:"-"`
+ QuantBits int `json:"-"`
+ QuantMode string `json:"-"`
+ TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
+
+ // Computed fields.
+ HeadDim int32 `json:"-"`
+ Scale float32 `json:"-"`
+}
+
+// Model is a Llama text model.
+type Model struct {
+ EmbedTokens *nn.Embedding
+ Layers []*Layer
+ Norm *nn.RMSNorm
+ LMHead nn.LinearLayer
+
+ tok *tokenizer.Tokenizer
+ *Config
+
+ weightPrefix string
+}
+
+type Layer struct {
+ Attention *Attention
+ MLP *MLP
+ AttentionNorm *nn.RMSNorm
+ MLPNorm *nn.RMSNorm
+}
+
+type Attention struct {
+ QProj nn.LinearLayer
+ KProj nn.LinearLayer
+ VProj nn.LinearLayer
+ OProj nn.LinearLayer
+}
+
+type MLP struct {
+ GateProj nn.LinearLayer
+ UpProj nn.LinearLayer
+ DownProj nn.LinearLayer
+}
+
+func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
+ for _, prefix := range []string{"", "language_model."} {
+ if tensors[prefix+"model.embed_tokens.weight"] != nil {
+ return prefix
+ }
+ }
+ return ""
+}
+
+func newModel(root *model.Root) (base.Model, error) {
+ configData, err := root.Manifest.ReadConfig("config.json")
+ if err != nil {
+ return nil, fmt.Errorf("load config: %w", err)
+ }
+
+ var cfg Config
+ if err := json.Unmarshal(configData, &cfg); err != nil {
+ return nil, fmt.Errorf("parse config: %w", err)
+ }
+
+ if cfg.HiddenSize <= 0 {
+ return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
+ }
+ if cfg.NumAttentionHeads <= 0 {
+ return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
+ }
+ if cfg.NumKeyValueHeads <= 0 {
+ cfg.NumKeyValueHeads = cfg.NumAttentionHeads
+ }
+ if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
+ return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
+ }
+ if cfg.HeadDim == 0 {
+ cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
+ }
+ if cfg.HeadDim <= 0 {
+ return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
+ }
+ if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
+ return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
+ }
+ if cfg.RopeTheta == 0 {
+ cfg.RopeTheta = 10000
+ }
+ if cfg.RMSNormEps == 0 {
+ cfg.RMSNormEps = 1e-5
+ }
+ cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
+
+ if qt := root.QuantType(); qt != "" {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
+ if gs := root.GroupSize(); gs > 0 {
+ cfg.QuantGroupSize = gs
+ }
+ } else {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
+ }
+ cfg.TensorQuant = root.AllTensorQuant()
+
+ tokData, err := root.Manifest.ReadConfig("tokenizer.json")
+ if err != nil {
+ return nil, fmt.Errorf("load tokenizer config: %w", err)
+ }
+
+ tokConfig := &tokenizer.TokenizerConfig{
+ ConfigJSON: configData,
+ }
+ if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
+ tokConfig.GenerationConfigJSON = genConfigData
+ }
+ if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
+ tokConfig.TokenizerConfigJSON = tokConfigData
+ }
+
+ tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
+ if err != nil {
+ return nil, fmt.Errorf("parse tokenizer: %w", err)
+ }
+
+ m := &Model{
+ Layers: make([]*Layer, cfg.NumHiddenLayers),
+ Config: &cfg,
+ tok: tok,
+ }
+
+ return m, nil
+}
+
+// LoadWeights receives all tensors loaded from the manifest and assigns them
+// to model fields.
+func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
+ m.weightPrefix = resolveWeightPrefix(tensors)
+ prefix := m.weightPrefix
+ linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
+
+ embedWeight := tensors[prefix+"model.embed_tokens.weight"]
+ if embedWeight == nil {
+ return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
+ }
+ m.EmbedTokens = nn.NewEmbedding(embedWeight)
+
+ normWeight := tensors[prefix+"model.norm.weight"]
+ if normWeight == nil {
+ return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
+ }
+ m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
+
+ if m.TieWordEmbeddings {
+ m.LMHead = nn.NewLinear(embedWeight, nil)
+ } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
+ m.LMHead = lmHead
+ } else if lmHead := linears.Make("lm_head"); lmHead != nil {
+ m.LMHead = lmHead
+ } else {
+ // Fallback used by many Llama checkpoints where output is tied.
+ m.LMHead = nn.NewLinear(embedWeight, nil)
+ }
+
+ for i := int32(0); i < m.NumHiddenLayers; i++ {
+ layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
+
+ layer := &Layer{
+ Attention: &Attention{},
+ MLP: &MLP{},
+ }
+
+ if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
+ layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+ if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
+ layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+
+ layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
+ layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
+ layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
+ layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
+
+ layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
+ layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
+ layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
+
+ if layer.AttentionNorm == nil {
+ return fmt.Errorf("layer %d: missing input_layernorm", i)
+ }
+ if layer.MLPNorm == nil {
+ return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
+ }
+ if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
+ return fmt.Errorf("layer %d: missing attention projections", i)
+ }
+ if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
+ return fmt.Errorf("layer %d: missing mlp projections", i)
+ }
+
+ m.Layers[i] = layer
+ }
+
+ collected := mlx.Collect(m)
+ mlx.Eval(collected...)
+
+ return nil
+}
+
+func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
+ dims := tokens.Dims()
+ B, L := int32(dims[0]), int32(dims[1])
+
+ h := m.EmbedTokens.Forward(tokens)
+ for i, layer := range m.Layers {
+ var c cache.Cache
+ if caches != nil && i < len(caches) {
+ c = caches[i]
+ }
+ h = layer.Forward(h, c, B, L, m.Config)
+ }
+
+ return m.Norm.Forward(h, m.RMSNormEps)
+}
+
+func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
+ return m.LMHead.Forward(x)
+}
+
+func (m *Model) NumLayers() int {
+ return len(m.Layers)
+}
+
+func (m *Model) Tokenizer() *tokenizer.Tokenizer {
+ return m.tok
+}
+
+func (m *Model) NewCaches() []cache.Cache {
+ caches := make([]cache.Cache, len(m.Layers))
+ for i := range caches {
+ caches[i] = cache.NewKVCache()
+ }
+ return caches
+}
+
+func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
+ return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
+}
+
+func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ q := a.QProj.Forward(x)
+ k := a.KProj.Forward(x)
+ v := a.VProj.Forward(x)
+
+ q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
+ q = mlx.Transpose(q, 0, 2, 1, 3)
+
+ k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
+ k = mlx.Transpose(k, 0, 2, 1, 3)
+
+ v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
+ v = mlx.Transpose(v, 0, 2, 1, 3)
+
+ offset := 0
+ if c != nil {
+ offset = c.Offset()
+ }
+ q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
+ k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
+
+ if c != nil {
+ k, v = c.Update(k, v)
+ }
+
+ repeatFactor := cfg.NumAttentionHeads / cfg.NumKeyValueHeads
+ if repeatFactor > 1 {
+ k = nn.RepeatKV(k, repeatFactor)
+ v = nn.RepeatKV(v, repeatFactor)
+ }
+
+ out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
+ out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
+ return a.OProj.Forward(out)
+}
+
+func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
+ return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
+}
diff --git a/x/models/nn/nn.go b/x/models/nn/nn.go
new file mode 100644
index 00000000000..3f57d483a12
--- /dev/null
+++ b/x/models/nn/nn.go
@@ -0,0 +1,188 @@
+//go:build mlx
+
+package nn
+
+import "github.com/ollama/ollama/x/mlxrunner/mlx"
+
+// Layer is the interface for neural network layers with a Forward method.
+type Layer interface {
+ Forward(x *mlx.Array) *mlx.Array
+}
+
+// LinearLayer is an interface for linear layers (both regular and quantized).
+type LinearLayer interface {
+ Forward(x *mlx.Array) *mlx.Array
+ OutputDim() int32
+}
+
+// Linear applies an affine transformation: y = x @ W.T + b
+type Linear struct {
+ Weight *mlx.Array
+ Bias *mlx.Array
+}
+
+func NewLinear(weight *mlx.Array, bias *mlx.Array) *Linear {
+ return &Linear{Weight: weight, Bias: bias}
+}
+
+func (l *Linear) Forward(x *mlx.Array) *mlx.Array {
+ w := l.Weight.Transpose(1, 0)
+ if l.Bias != nil && l.Bias.Valid() {
+ return l.Bias.Addmm(x, w, 1.0, 1.0)
+ }
+ return x.Matmul(w)
+}
+
+func (l *Linear) OutputDim() int32 {
+ return int32(l.Weight.Dim(0))
+}
+
+// QuantizedLinear applies an affine transformation using quantized weights.
+type QuantizedLinear struct {
+ Weight *mlx.Array // Quantized weight data
+ Scales *mlx.Array // Scale factors for dequantization
+ QBiases *mlx.Array // Quantization biases (nil for nvfp4)
+ Bias *mlx.Array // Layer bias [output_dims] or nil
+ GroupSize int
+ Bits int
+ Mode string
+}
+
+func NewQuantizedLinear(weight *mlx.Array, bias *mlx.Array, groupSize, bits int, mode string) *QuantizedLinear {
+ qw, scales, qbiases := mlx.Quantize(weight, groupSize, bits, mode)
+ if qbiases != nil {
+ mlx.Eval(qw, scales, qbiases)
+ } else {
+ mlx.Eval(qw, scales)
+ }
+ return &QuantizedLinear{
+ Weight: qw,
+ Scales: scales,
+ QBiases: qbiases,
+ Bias: bias,
+ GroupSize: groupSize,
+ Bits: bits,
+ Mode: mode,
+ }
+}
+
+func (ql *QuantizedLinear) Forward(x *mlx.Array) *mlx.Array {
+ out := mlx.QuantizedMatmul(x, ql.Weight, ql.Scales, ql.QBiases, true, ql.GroupSize, ql.Bits, ql.Mode)
+ if ql.Bias != nil && ql.Bias.Valid() {
+ out = out.Add(ql.Bias)
+ }
+ return out
+}
+
+func (ql *QuantizedLinear) OutputDim() int32 {
+ return int32(ql.Weight.Dim(0))
+}
+
+// RMSNorm represents an RMS normalization layer.
+type RMSNorm struct {
+ Weight *mlx.Array
+ Eps float32
+}
+
+func NewRMSNorm(weight *mlx.Array, eps float32) *RMSNorm {
+ return &RMSNorm{Weight: weight, Eps: eps}
+}
+
+func (rn *RMSNorm) Forward(x *mlx.Array, eps float32) *mlx.Array {
+ if eps == 0 {
+ eps = rn.Eps
+ }
+ return mlx.RMSNormFn(x, rn.Weight, eps)
+}
+
+// Embedding represents an embedding layer.
+type Embedding struct {
+ Weight *mlx.Array
+}
+
+func NewEmbedding(weight *mlx.Array) *Embedding {
+ return &Embedding{Weight: weight}
+}
+
+func (e *Embedding) Forward(indices *mlx.Array) *mlx.Array {
+ return e.Weight.TakeAxis(indices, 0)
+}
+
+// LayerNorm represents a standard layer normalization layer (with bias).
+type LayerNorm struct {
+ Weight *mlx.Array
+ Bias *mlx.Array
+ Eps float32
+}
+
+func (ln *LayerNorm) Forward(x *mlx.Array) *mlx.Array {
+ eps := ln.Eps
+ if eps == 0 {
+ eps = 1e-5
+ }
+ mean := mlx.Mean(x, -1, true)
+ centered := x.Subtract(mean)
+ variance := mlx.Mean(centered.Multiply(centered), -1, true)
+ normalized := centered.Multiply(mlx.RSqrt(mlx.AddScalar(variance, eps)))
+ out := normalized.Multiply(ln.Weight)
+ if ln.Bias != nil && ln.Bias.Valid() {
+ out = out.Add(ln.Bias)
+ }
+ return out
+}
+
+// MultiLinearLayer is an interface for per-head linear layers.
+type MultiLinearLayer interface {
+ Forward(x *mlx.Array) *mlx.Array
+}
+
+// MultiLinear performs per-head linear projections.
+// Weight shape: [num_heads, output_dims, input_dims]
+type MultiLinear struct {
+ Weight *mlx.Array
+}
+
+func NewMultiLinear(weight *mlx.Array) *MultiLinear {
+ return &MultiLinear{Weight: weight}
+}
+
+func (ml *MultiLinear) Forward(x *mlx.Array) *mlx.Array {
+ wT := ml.Weight.Transpose(0, 2, 1)
+ return x.Matmul(wT)
+}
+
+// RepeatKV repeats K/V tensors for grouped query attention.
+func RepeatKV(x *mlx.Array, repeatFactor int32) *mlx.Array {
+ if repeatFactor == 1 {
+ return x
+ }
+ shape := x.Dims()
+ x = x.ExpandDims(2)
+ reps := []int32{1, 1, repeatFactor, 1, 1}
+ x = mlx.Tile(x, reps)
+ return mlx.Reshape(x, int32(shape[0]), int32(shape[1])*repeatFactor, int32(shape[2]), int32(shape[3]))
+}
+
+// ApplyCausalMask applies causal (lower triangular) mask to attention scores.
+func ApplyCausalMask(scores *mlx.Array) *mlx.Array {
+ shape := scores.Dims()
+ seqLen := int32(shape[2])
+ mask := mlx.Tri(seqLen, seqLen, 0)
+ negInf := mlx.NewScalarArray(float32(-1e9))
+ mask = mask.ExpandDims(0).ExpandDims(0)
+ return mlx.Where(mask, scores, negInf)
+}
+
+// ApplyCausalMaskWithOffset applies causal mask for cached attention.
+func ApplyCausalMaskWithOffset(scores *mlx.Array, offset int32) *mlx.Array {
+ if offset == 0 {
+ return ApplyCausalMask(scores)
+ }
+ shape := scores.Dims()
+ queryLen := int32(shape[2])
+ keyLen := int32(shape[3])
+ mask := mlx.Tri(queryLen, keyLen, int(offset))
+ negInf := mlx.NewScalarArray(float32(-1e9))
+ mask = mask.ExpandDims(0).ExpandDims(0)
+ return mlx.Where(mask, scores, negInf)
+}
diff --git a/x/models/qwen3/qwen3.go b/x/models/qwen3/qwen3.go
new file mode 100644
index 00000000000..76170881a86
--- /dev/null
+++ b/x/models/qwen3/qwen3.go
@@ -0,0 +1,338 @@
+//go:build mlx
+
+// Package qwen3 provides the Qwen3 text model implementation for MLX.
+package qwen3
+
+import (
+ "encoding/json"
+ "fmt"
+ "math"
+
+ "github.com/ollama/ollama/x/mlxrunner/cache"
+ "github.com/ollama/ollama/x/mlxrunner/mlx"
+ "github.com/ollama/ollama/x/mlxrunner/model"
+ "github.com/ollama/ollama/x/mlxrunner/model/base"
+ "github.com/ollama/ollama/x/models/nn"
+ "github.com/ollama/ollama/x/tokenizer"
+)
+
+func init() {
+ base.Register("Qwen3ForCausalLM", newModel)
+}
+
+// Config holds Qwen3 model configuration.
+type Config struct {
+ HiddenSize int32 `json:"hidden_size"`
+ NumHiddenLayers int32 `json:"num_hidden_layers"`
+ IntermediateSize int32 `json:"intermediate_size"`
+ NumAttentionHeads int32 `json:"num_attention_heads"`
+ NumKeyValueHeads int32 `json:"num_key_value_heads"`
+ VocabSize int32 `json:"vocab_size"`
+ RMSNormEps float32 `json:"rms_norm_eps"`
+ RopeTheta float32 `json:"rope_theta"`
+ HeadDim int32 `json:"head_dim"`
+ MaxPositionEmbeddings int32 `json:"max_position_embeddings"`
+ TieWordEmbeddings bool `json:"tie_word_embeddings"`
+
+ // Quantization parameters (set during load based on model quantization).
+ QuantGroupSize int `json:"-"`
+ QuantBits int `json:"-"`
+ QuantMode string `json:"-"`
+ TensorQuant map[string]*model.TensorQuantInfo `json:"-"`
+
+ // Computed fields.
+ Scale float32 `json:"-"`
+ QKNormEps float32 `json:"-"`
+}
+
+// Model is the Qwen3 text-only model.
+type Model struct {
+ EmbedTokens *nn.Embedding
+ Layers []*Layer
+ Norm *nn.RMSNorm
+ LMHead nn.LinearLayer
+
+ tok *tokenizer.Tokenizer
+ *Config
+
+ weightPrefix string
+}
+
+// Layer is a single Qwen3 decoder block.
+type Layer struct {
+ Attention *Attention
+ MLP *MLP
+ AttentionNorm *nn.RMSNorm
+ MLPNorm *nn.RMSNorm
+}
+
+// Attention implements Qwen3 attention with Q/K norms.
+type Attention struct {
+ QProj nn.LinearLayer
+ KProj nn.LinearLayer
+ VProj nn.LinearLayer
+ OProj nn.LinearLayer
+ QNorm *nn.RMSNorm
+ KNorm *nn.RMSNorm
+}
+
+// MLP is the feed-forward network with SwiGLU activation.
+type MLP struct {
+ GateProj nn.LinearLayer
+ UpProj nn.LinearLayer
+ DownProj nn.LinearLayer
+}
+
+func resolveWeightPrefix(tensors map[string]*mlx.Array) string {
+ for _, prefix := range []string{"", "language_model."} {
+ if tensors[prefix+"model.embed_tokens.weight"] != nil {
+ return prefix
+ }
+ }
+ return ""
+}
+
+func newModel(root *model.Root) (base.Model, error) {
+ configData, err := root.Manifest.ReadConfig("config.json")
+ if err != nil {
+ return nil, fmt.Errorf("load config: %w", err)
+ }
+
+ var cfg Config
+ if err := json.Unmarshal(configData, &cfg); err != nil {
+ return nil, fmt.Errorf("parse config: %w", err)
+ }
+
+ if cfg.HiddenSize <= 0 {
+ return nil, fmt.Errorf("invalid hidden_size: %d", cfg.HiddenSize)
+ }
+ if cfg.NumAttentionHeads <= 0 {
+ return nil, fmt.Errorf("invalid num_attention_heads: %d", cfg.NumAttentionHeads)
+ }
+ if cfg.NumKeyValueHeads <= 0 {
+ cfg.NumKeyValueHeads = cfg.NumAttentionHeads
+ }
+ if cfg.HeadDim == 0 {
+ if cfg.HiddenSize%cfg.NumAttentionHeads != 0 {
+ return nil, fmt.Errorf("hidden_size (%d) must be divisible by num_attention_heads (%d)", cfg.HiddenSize, cfg.NumAttentionHeads)
+ }
+ cfg.HeadDim = cfg.HiddenSize / cfg.NumAttentionHeads
+ }
+ if cfg.HeadDim <= 0 {
+ return nil, fmt.Errorf("invalid head_dim: %d", cfg.HeadDim)
+ }
+ if cfg.NumAttentionHeads%cfg.NumKeyValueHeads != 0 {
+ return nil, fmt.Errorf("num_attention_heads (%d) must be divisible by num_key_value_heads (%d)", cfg.NumAttentionHeads, cfg.NumKeyValueHeads)
+ }
+ if cfg.RMSNormEps == 0 {
+ cfg.RMSNormEps = 1e-6
+ }
+ if cfg.RopeTheta == 0 {
+ cfg.RopeTheta = 1000000
+ }
+ cfg.Scale = float32(1.0 / math.Sqrt(float64(cfg.HeadDim)))
+ cfg.QKNormEps = 1e-6
+
+ if qt := root.QuantType(); qt != "" {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams(qt)
+ if gs := root.GroupSize(); gs > 0 {
+ cfg.QuantGroupSize = gs
+ }
+ } else {
+ cfg.QuantGroupSize, cfg.QuantBits, cfg.QuantMode = model.QuantizationParams("")
+ }
+ cfg.TensorQuant = root.AllTensorQuant()
+
+ tokData, err := root.Manifest.ReadConfig("tokenizer.json")
+ if err != nil {
+ return nil, fmt.Errorf("load tokenizer config: %w", err)
+ }
+
+ tokConfig := &tokenizer.TokenizerConfig{
+ ConfigJSON: configData,
+ }
+ if genConfigData, err := root.Manifest.ReadConfig("generation_config.json"); err == nil {
+ tokConfig.GenerationConfigJSON = genConfigData
+ }
+ if tokConfigData, err := root.Manifest.ReadConfig("tokenizer_config.json"); err == nil {
+ tokConfig.TokenizerConfigJSON = tokConfigData
+ }
+
+ tok, err := tokenizer.LoadFromBytesWithConfig(tokData, tokConfig)
+ if err != nil {
+ return nil, fmt.Errorf("parse tokenizer: %w", err)
+ }
+
+ m := &Model{
+ Layers: make([]*Layer, cfg.NumHiddenLayers),
+ Config: &cfg,
+ tok: tok,
+ }
+
+ return m, nil
+}
+
+// LoadWeights receives all tensors loaded from the manifest and assigns them
+// to model fields.
+func (m *Model) LoadWeights(tensors map[string]*mlx.Array) error {
+ m.weightPrefix = resolveWeightPrefix(tensors)
+ prefix := m.weightPrefix
+ linears := model.NewLinearFactory(tensors, m.QuantGroupSize, m.QuantBits, m.QuantMode, m.TensorQuant)
+
+ embedWeight := tensors[prefix+"model.embed_tokens.weight"]
+ if embedWeight == nil {
+ return fmt.Errorf("missing embedding weight: %smodel.embed_tokens.weight", prefix)
+ }
+ m.EmbedTokens = nn.NewEmbedding(embedWeight)
+
+ normWeight := tensors[prefix+"model.norm.weight"]
+ if normWeight == nil {
+ return fmt.Errorf("missing final norm weight: %smodel.norm.weight", prefix)
+ }
+ m.Norm = nn.NewRMSNorm(normWeight, m.RMSNormEps)
+
+ if m.TieWordEmbeddings {
+ m.LMHead = nn.NewLinear(embedWeight, nil)
+ } else if lmHead := linears.Make(prefix + "lm_head"); lmHead != nil {
+ m.LMHead = lmHead
+ } else if lmHead := linears.Make("lm_head"); lmHead != nil {
+ m.LMHead = lmHead
+ } else {
+ // Qwen3 checkpoints commonly tie output projection to embeddings.
+ m.LMHead = nn.NewLinear(embedWeight, nil)
+ }
+
+ for i := int32(0); i < m.NumHiddenLayers; i++ {
+ layerPrefix := fmt.Sprintf("%smodel.layers.%d", prefix, i)
+
+ layer := &Layer{
+ Attention: &Attention{},
+ MLP: &MLP{},
+ }
+
+ if w := tensors[layerPrefix+".input_layernorm.weight"]; w != nil {
+ layer.AttentionNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+ if w := tensors[layerPrefix+".post_attention_layernorm.weight"]; w != nil {
+ layer.MLPNorm = nn.NewRMSNorm(w, m.RMSNormEps)
+ }
+
+ layer.Attention.QProj = linears.Make(layerPrefix + ".self_attn.q_proj")
+ layer.Attention.KProj = linears.Make(layerPrefix + ".self_attn.k_proj")
+ layer.Attention.VProj = linears.Make(layerPrefix + ".self_attn.v_proj")
+ layer.Attention.OProj = linears.Make(layerPrefix + ".self_attn.o_proj")
+
+ if w := tensors[layerPrefix+".self_attn.q_norm.weight"]; w != nil {
+ layer.Attention.QNorm = nn.NewRMSNorm(w, m.QKNormEps)
+ }
+ if w := tensors[layerPrefix+".self_attn.k_norm.weight"]; w != nil {
+ layer.Attention.KNorm = nn.NewRMSNorm(w, m.QKNormEps)
+ }
+
+ layer.MLP.GateProj = linears.Make(layerPrefix + ".mlp.gate_proj")
+ layer.MLP.UpProj = linears.Make(layerPrefix + ".mlp.up_proj")
+ layer.MLP.DownProj = linears.Make(layerPrefix + ".mlp.down_proj")
+
+ if layer.AttentionNorm == nil {
+ return fmt.Errorf("layer %d: missing input_layernorm", i)
+ }
+ if layer.MLPNorm == nil {
+ return fmt.Errorf("layer %d: missing post_attention_layernorm", i)
+ }
+ if layer.Attention.QProj == nil || layer.Attention.KProj == nil || layer.Attention.VProj == nil || layer.Attention.OProj == nil {
+ return fmt.Errorf("layer %d: missing attention projections", i)
+ }
+ if layer.Attention.QNorm == nil || layer.Attention.KNorm == nil {
+ return fmt.Errorf("layer %d: missing attention q/k norms", i)
+ }
+ if layer.MLP.GateProj == nil || layer.MLP.UpProj == nil || layer.MLP.DownProj == nil {
+ return fmt.Errorf("layer %d: missing mlp projections", i)
+ }
+
+ m.Layers[i] = layer
+ }
+
+ collected := mlx.Collect(m)
+ mlx.Eval(collected...)
+
+ return nil
+}
+
+func (m *Model) Forward(tokens *mlx.Array, caches []cache.Cache) *mlx.Array {
+ dims := tokens.Dims()
+ B, L := int32(dims[0]), int32(dims[1])
+
+ h := m.EmbedTokens.Forward(tokens)
+ for i, layer := range m.Layers {
+ var c cache.Cache
+ if caches != nil && i < len(caches) {
+ c = caches[i]
+ }
+ h = layer.Forward(h, c, B, L, m.Config)
+ }
+
+ return m.Norm.Forward(h, m.RMSNormEps)
+}
+
+func (m *Model) Unembed(x *mlx.Array) *mlx.Array {
+ return m.LMHead.Forward(x)
+}
+
+func (m *Model) NumLayers() int {
+ return len(m.Layers)
+}
+
+func (m *Model) Tokenizer() *tokenizer.Tokenizer {
+ return m.tok
+}
+
+func (m *Model) NewCaches() []cache.Cache {
+ caches := make([]cache.Cache, len(m.Layers))
+ for i := range caches {
+ caches[i] = cache.NewKVCache()
+ }
+ return caches
+}
+
+func (l *Layer) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ h := mlx.Add(x, l.Attention.Forward(l.AttentionNorm.Forward(x, cfg.RMSNormEps), c, B, L, cfg))
+ return mlx.Add(h, l.MLP.Forward(l.MLPNorm.Forward(h, cfg.RMSNormEps)))
+}
+
+func (a *Attention) Forward(x *mlx.Array, c cache.Cache, B, L int32, cfg *Config) *mlx.Array {
+ q := a.QProj.Forward(x)
+ k := a.KProj.Forward(x)
+ v := a.VProj.Forward(x)
+
+ q = mlx.Reshape(q, B, L, cfg.NumAttentionHeads, cfg.HeadDim)
+ k = mlx.Reshape(k, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
+ v = mlx.Reshape(v, B, L, cfg.NumKeyValueHeads, cfg.HeadDim)
+
+ q = a.QNorm.Forward(q, cfg.QKNormEps)
+ k = a.KNorm.Forward(k, cfg.QKNormEps)
+
+ q = mlx.Transpose(q, 0, 2, 1, 3)
+ k = mlx.Transpose(k, 0, 2, 1, 3)
+ v = mlx.Transpose(v, 0, 2, 1, 3)
+
+ offset := 0
+ if c != nil {
+ offset = c.Offset()
+ }
+ q = mlx.RoPEWithBase(q, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
+ k = mlx.RoPEWithBase(k, int(cfg.HeadDim), false, cfg.RopeTheta, 1.0, offset)
+
+ if c != nil {
+ k, v = c.Update(k, v)
+ }
+
+ // MLX SDPA supports grouped-query attention directly (Q heads can be a
+ // multiple of K/V heads), so avoid materializing repeated K/V tensors.
+ out := mlx.ScaledDotProductAttentionCausal(q, k, v, cfg.Scale, L > 1)
+ out = mlx.Reshape(mlx.Transpose(out, 0, 2, 1, 3), B, L, cfg.NumAttentionHeads*cfg.HeadDim)
+ return a.OProj.Forward(out)
+}
+
+func (m *MLP) Forward(x *mlx.Array) *mlx.Array {
+ return m.DownProj.Forward(mlx.Mul(mlx.SiLU(m.GateProj.Forward(x)), m.UpProj.Forward(x)))
+}
diff --git a/x/server/show.go b/x/server/show.go
index 8cadb2c6207..f5f45f8c71a 100644
--- a/x/server/show.go
+++ b/x/server/show.go
@@ -5,11 +5,14 @@ import (
"encoding/json"
"fmt"
"io"
+ "math"
"os"
+ "sort"
"strings"
"github.com/ollama/ollama/api"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/manifest"
+ "github.com/ollama/ollama/types/model"
)
// modelConfig represents the HuggingFace config.json structure
@@ -35,28 +38,36 @@ type modelConfig struct {
// GetSafetensorsLLMInfo extracts model information from safetensors LLM models.
// It reads the config.json layer and returns a map compatible with GGML's KV format.
-func GetSafetensorsLLMInfo(modelName string) (map[string]any, error) {
- manifest, err := imagegen.LoadManifest(modelName)
+func GetSafetensorsLLMInfo(name model.Name) (map[string]any, error) {
+ mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
var config modelConfig
- if err := manifest.ReadConfigJSON("config.json", &config); err != nil {
+ if err := mf.ReadConfigJSON("config.json", &config); err != nil {
return nil, fmt.Errorf("failed to read config.json: %w", err)
}
// Calculate total tensor bytes from manifest layers
var totalBytes int64
var tensorCount int64
- for _, layer := range manifest.Manifest.Layers {
- if layer.MediaType == "application/vnd.ollama.image.tensor" {
+ for _, layer := range mf.Layers {
+ if layer.MediaType == manifest.MediaTypeImageTensor {
totalBytes += layer.Size
tensorCount++
}
}
- return buildModelInfo(config, totalBytes, tensorCount), nil
+ info := buildModelInfo(config, totalBytes, tensorCount)
+
+ // For quantized models, byte-based estimation can significantly undercount
+ // parameters. Prefer exact counting from tensor shapes in safetensors headers.
+ if paramCount, err := getParameterCountFromManifest(mf); err == nil && paramCount > 0 {
+ info["general.parameter_count"] = paramCount
+ }
+
+ return info, nil
}
// buildModelInfo constructs the model info map from config and tensor stats.
@@ -104,9 +115,9 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
bytesPerParam = 1
}
- // Subtract safetensors header overhead (88 bytes per tensor file)
- // Each tensor is stored as a minimal safetensors file
- totalBytes := totalTensorBytes - tensorCount*88
+ // Subtract safetensors header overhead per tensor blob.
+ // Headers include __metadata__ with the tensor name, so overhead is ~150 bytes on average.
+ totalBytes := totalTensorBytes - tensorCount*150
paramCount := totalBytes / bytesPerParam
@@ -149,75 +160,181 @@ func buildModelInfo(config modelConfig, totalTensorBytes, tensorCount int64) map
return info
}
+// getParameterCountFromManifest counts model parameters from tensor shapes.
+// This accounts for quantized tensors by using unpacked shapes from
+// getTensorInfoFromManifest.
+func getParameterCountFromManifest(mf *manifest.Manifest) (int64, error) {
+ tensors, err := getTensorInfoFromManifest(mf)
+ if err != nil {
+ return 0, err
+ }
+
+ var total int64
+ for _, tensor := range tensors {
+ if len(tensor.Shape) == 0 {
+ continue
+ }
+
+ elements := int64(1)
+ for _, dim := range tensor.Shape {
+ if dim == 0 {
+ elements = 0
+ break
+ }
+
+ if dim > uint64(math.MaxInt64) {
+ return 0, fmt.Errorf("tensor %s dimension too large: %d", tensor.Name, dim)
+ }
+
+ d := int64(dim)
+ if elements > math.MaxInt64/d {
+ return 0, fmt.Errorf("tensor %s element count overflow", tensor.Name)
+ }
+ elements *= d
+ }
+
+ if elements == 0 {
+ continue
+ }
+ if total > math.MaxInt64-elements {
+ return 0, fmt.Errorf("total parameter count overflow")
+ }
+ total += elements
+ }
+
+ return total, nil
+}
+
// GetSafetensorsTensorInfo extracts tensor information from safetensors model layers.
// Each tensor is stored as a minimal safetensors file with an 88-byte header containing metadata.
-func GetSafetensorsTensorInfo(modelName string) ([]api.Tensor, error) {
- manifest, err := imagegen.LoadManifest(modelName)
+func GetSafetensorsTensorInfo(name model.Name) ([]api.Tensor, error) {
+ mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return nil, fmt.Errorf("failed to load manifest: %w", err)
}
- return getTensorInfoFromManifest(manifest)
+ return getTensorInfoFromManifest(mf)
}
// getTensorInfoFromManifest extracts tensor info from a manifest.
// This is separated for testability.
-func getTensorInfoFromManifest(manifest *imagegen.ModelManifest) ([]api.Tensor, error) {
+// For quantized tensors, reads quant_type from blob __metadata__.
+// For packed blobs (multiple tensors per blob), enumerates all tensors in the blob.
+func getTensorInfoFromManifest(mf *manifest.Manifest) ([]api.Tensor, error) {
var tensors []api.Tensor
- for _, layer := range manifest.Manifest.Layers {
- if layer.MediaType != "application/vnd.ollama.image.tensor" {
+ for _, layer := range mf.Layers {
+ if layer.MediaType != manifest.MediaTypeImageTensor {
continue
}
- // Read the safetensors header from the blob
- blobPath := manifest.BlobPath(layer.Digest)
- info, err := readSafetensorsHeader(blobPath)
+ // Read all tensor entries from the safetensors header
+ blobPath, err := manifest.BlobsPath(layer.Digest)
if err != nil {
- // Skip tensors we can't read
continue
}
- // Convert shape from int to uint64
- shape := make([]uint64, len(info.Shape))
- for i, s := range info.Shape {
- shape[i] = uint64(s)
+ f, err := os.Open(blobPath)
+ if err != nil {
+ continue
}
- tensors = append(tensors, api.Tensor{
- Name: layer.Name,
- Type: info.Dtype,
- Shape: shape,
- })
+ allInfos, err := parseSafetensorsAllHeaders(f)
+ f.Close()
+ if err != nil {
+ continue
+ }
+
+ // Determine if this is a packed blob (multiple main tensors)
+ isPacked := len(allInfos) > 1
+
+ for _, info := range allInfos {
+ tensorName := layer.Name
+ if isPacked {
+ // For packed blobs, use the tensor name from the header
+ tensorName = info.Name
+ }
+
+ if info.QuantType != "" {
+ quantType := strings.ToUpper(info.QuantType)
+
+ shape := make([]uint64, len(info.Shape))
+ for i, s := range info.Shape {
+ shape[i] = uint64(s)
+ }
+
+ var packFactor int64
+ switch strings.ToLower(info.QuantType) {
+ case "int4", "nvfp4":
+ packFactor = 8
+ case "int8", "mxfp8":
+ packFactor = 4
+ }
+ if packFactor > 0 && len(shape) >= 2 {
+ shape[len(shape)-1] = uint64(info.Shape[len(info.Shape)-1] * packFactor)
+ }
+
+ tensors = append(tensors, api.Tensor{
+ Name: tensorName,
+ Type: quantType,
+ Shape: shape,
+ })
+ } else {
+ shape := make([]uint64, len(info.Shape))
+ for i, s := range info.Shape {
+ shape[i] = uint64(s)
+ }
+
+ tensors = append(tensors, api.Tensor{
+ Name: tensorName,
+ Type: info.Dtype,
+ Shape: shape,
+ })
+ }
+ }
}
+ sort.Slice(tensors, func(i, j int) bool {
+ return tensors[i].Name < tensors[j].Name
+ })
+
return tensors, nil
}
// GetSafetensorsDtype returns the quantization type for a safetensors model.
-// If the model is quantized (has _scale tensors), returns the quantization type (e.g., "FP8").
-// Otherwise returns the torch_dtype from config.json.
-func GetSafetensorsDtype(modelName string) (string, error) {
- manifest, err := imagegen.LoadManifest(modelName)
+// Reads quant_type from the first tensor blob's __metadata__.
+// Falls back to torch_dtype from config.json if no quant metadata.
+func GetSafetensorsDtype(name model.Name) (string, error) {
+ mf, err := manifest.ParseNamedManifest(name)
if err != nil {
return "", fmt.Errorf("failed to load manifest: %w", err)
}
- // Check if model is quantized by looking for _scale tensors
- for _, layer := range manifest.Manifest.Layers {
- if layer.MediaType == "application/vnd.ollama.image.tensor" {
- if strings.HasSuffix(layer.Name, "_scale") {
- // Model is quantized - return FP8 (affine quantization)
- return "FP8", nil
- }
+ // Check first tensor blob for quant_type metadata
+ for _, layer := range mf.Layers {
+ if layer.MediaType != manifest.MediaTypeImageTensor {
+ continue
+ }
+ blobPath, err := manifest.BlobsPath(layer.Digest)
+ if err != nil {
+ continue
}
+ info, err := readSafetensorsHeader(blobPath)
+ if err != nil {
+ continue
+ }
+ if info.QuantType != "" {
+ return strings.ToUpper(info.QuantType), nil
+ }
+ // Only check the first tensor blob
+ break
}
// Not quantized - return torch_dtype from config.json
var cfg struct {
TorchDtype string `json:"torch_dtype"`
}
- if err := manifest.ReadConfigJSON("config.json", &cfg); err != nil {
+ if err := mf.ReadConfigJSON("config.json", &cfg); err != nil {
return "", fmt.Errorf("failed to read config.json: %w", err)
}
@@ -226,8 +343,11 @@ func GetSafetensorsDtype(modelName string) (string, error) {
// safetensorsTensorInfo holds metadata about a tensor from a safetensors header
type safetensorsTensorInfo struct {
- Dtype string `json:"dtype"`
- Shape []int64 `json:"shape"`
+ Name string // tensor name from the header key
+ Dtype string `json:"dtype"`
+ Shape []int64 `json:"shape"`
+ QuantType string // from __metadata__.quant_type (e.g., "int4", "int8", "nvfp4", "mxfp8")
+ GroupSize string // from __metadata__.group_size (e.g., "32", "64")
}
// readSafetensorsHeader reads the JSON header from a safetensors file to get tensor metadata.
@@ -244,6 +364,7 @@ func readSafetensorsHeader(path string) (*safetensorsTensorInfo, error) {
// parseSafetensorsHeader parses a safetensors header from a reader.
// This is separated for testability.
+// Parses __metadata__ for quant_type and group_size if present.
func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
// Read header size (8 bytes, little endian)
var headerSize uint64
@@ -268,7 +389,31 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
return nil, fmt.Errorf("failed to parse header: %w", err)
}
- // Find the first (and should be only) tensor entry
+ // Parse metadata if present
+ var quantType, groupSize string
+ if metaRaw, ok := header["__metadata__"]; ok {
+ var meta map[string]string
+ if json.Unmarshal(metaRaw, &meta) == nil {
+ quantType = meta["quant_type"]
+ groupSize = meta["group_size"]
+ }
+ }
+
+ // Find the main tensor entry (not __metadata__, .scale, or .bias)
+ for name, raw := range header {
+ if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
+ continue
+ }
+ var info safetensorsTensorInfo
+ if err := json.Unmarshal(raw, &info); err != nil {
+ return nil, fmt.Errorf("failed to parse tensor info: %w", err)
+ }
+ info.QuantType = quantType
+ info.GroupSize = groupSize
+ return &info, nil
+ }
+
+ // Fall back to first non-metadata tensor entry
for name, raw := range header {
if name == "__metadata__" {
continue
@@ -277,8 +422,134 @@ func parseSafetensorsHeader(r io.Reader) (*safetensorsTensorInfo, error) {
if err := json.Unmarshal(raw, &info); err != nil {
return nil, fmt.Errorf("failed to parse tensor info: %w", err)
}
+ info.QuantType = quantType
+ info.GroupSize = groupSize
return &info, nil
}
return nil, fmt.Errorf("no tensor found in header")
}
+
+// parseSafetensorsAllHeaders parses all tensor entries from a safetensors header.
+// Returns one safetensorsTensorInfo per main tensor (skipping __metadata__, .scale, .bias).
+// For packed blobs this returns multiple entries; for single-tensor blobs, one entry.
+// Each tensor's quant type is inferred from its shape and the presence of .scale/.bias entries
+// when no global __metadata__ quant_type is present.
+func parseSafetensorsAllHeaders(r io.Reader) ([]safetensorsTensorInfo, error) {
+ var headerSize uint64
+ if err := binary.Read(r, binary.LittleEndian, &headerSize); err != nil {
+ return nil, fmt.Errorf("failed to read header size: %w", err)
+ }
+
+ if headerSize > 100*1024*1024 { // 100MB limit for packed blob headers
+ return nil, fmt.Errorf("header size too large: %d", headerSize)
+ }
+
+ headerBytes := make([]byte, headerSize)
+ if _, err := io.ReadFull(r, headerBytes); err != nil {
+ return nil, fmt.Errorf("failed to read header: %w", err)
+ }
+
+ var header map[string]json.RawMessage
+ if err := json.Unmarshal(headerBytes, &header); err != nil {
+ return nil, fmt.Errorf("failed to parse header: %w", err)
+ }
+
+ // Parse global metadata if present
+ var globalQuantType, globalGroupSize string
+ if metaRaw, ok := header["__metadata__"]; ok {
+ var meta map[string]string
+ if json.Unmarshal(metaRaw, &meta) == nil {
+ globalQuantType = meta["quant_type"]
+ globalGroupSize = meta["group_size"]
+ }
+ }
+
+ // Build a set of all keys for checking .scale/.bias presence
+ headerKeys := make(map[string]bool, len(header))
+ for k := range header {
+ headerKeys[k] = true
+ }
+
+ // Collect all main tensor entries (sorted for deterministic output)
+ var mainNames []string
+ for name := range header {
+ if name == "__metadata__" || strings.HasSuffix(name, ".scale") || strings.HasSuffix(name, ".bias") {
+ continue
+ }
+ mainNames = append(mainNames, name)
+ }
+ sort.Strings(mainNames)
+
+ var results []safetensorsTensorInfo
+ for _, name := range mainNames {
+ var info safetensorsTensorInfo
+ if err := json.Unmarshal(header[name], &info); err != nil {
+ return nil, fmt.Errorf("failed to parse tensor info for %s: %w", name, err)
+ }
+ info.Name = name
+
+ if globalQuantType != "" {
+ // Use global metadata
+ info.QuantType = globalQuantType
+ info.GroupSize = globalGroupSize
+ } else if headerKeys[name+".scale"] {
+ // No global metadata, but has .scale - infer quant type from shape
+ info.QuantType = inferQuantType(header, name)
+ }
+
+ results = append(results, info)
+ }
+
+ if len(results) == 0 {
+ return nil, fmt.Errorf("no tensor found in header")
+ }
+
+ return results, nil
+}
+
+// inferQuantType infers the quantization type for a tensor from its shape and scale shape.
+// Returns "int4", "int8", etc. or "" if not quantized.
+func inferQuantType(header map[string]json.RawMessage, name string) string {
+ // Parse the main tensor shape
+ var mainInfo struct {
+ Shape []int64 `json:"shape"`
+ }
+ if json.Unmarshal(header[name], &mainInfo) != nil || len(mainInfo.Shape) < 2 {
+ return ""
+ }
+
+ // Parse scale shape to determine group size
+ scaleRaw, ok := header[name+".scale"]
+ if !ok {
+ return ""
+ }
+ var scaleInfo struct {
+ Shape []int64 `json:"shape"`
+ }
+ if json.Unmarshal(scaleRaw, &scaleInfo) != nil || len(scaleInfo.Shape) < 2 {
+ return ""
+ }
+
+ // Calculate group size: main_cols * pack_factor / scale_cols
+ // Main dtype is U32, so we need to figure out the pack factor
+ // For int4: pack=8, group=32. scale_cols = original_cols / 32 = main_cols * 8 / 32 = main_cols / 4
+ // For int8: pack=4, group=64. scale_cols = original_cols / 64 = main_cols * 4 / 64 = main_cols / 16
+ mainCols := mainInfo.Shape[len(mainInfo.Shape)-1]
+ scaleCols := scaleInfo.Shape[len(scaleInfo.Shape)-1]
+ if scaleCols == 0 {
+ return ""
+ }
+
+ ratio := mainCols / scaleCols // main_packed_cols / scale_cols
+ // int4: ratio = (orig/8) / (orig/32) = 32/8 = 4
+ // int8: ratio = (orig/4) / (orig/64) = 64/4 = 16
+ switch ratio {
+ case 4:
+ return "int4"
+ case 16:
+ return "int8"
+ default:
+ return ""
+ }
+}
diff --git a/x/server/show_test.go b/x/server/show_test.go
index c510b0d5414..52299378732 100644
--- a/x/server/show_test.go
+++ b/x/server/show_test.go
@@ -8,7 +8,7 @@ import (
"path/filepath"
"testing"
- "github.com/ollama/ollama/x/imagegen"
+ "github.com/ollama/ollama/manifest"
)
func TestBuildModelInfo(t *testing.T) {
@@ -36,7 +36,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 262144,
TorchDtype: "bfloat16",
},
- totalTensorBytes: 8_600_000_088, // ~4.3B params * 2 bytes + 88 bytes header
+ totalTensorBytes: 8_600_000_150, // ~4.3B params * 2 bytes + 150 bytes header
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
@@ -57,7 +57,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 32000,
TorchDtype: "float16",
},
- totalTensorBytes: 14_000_000_088, // ~7B params * 2 bytes + 88 bytes header
+ totalTensorBytes: 14_000_000_150, // ~7B params * 2 bytes + 150 bytes header
tensorCount: 1,
wantArch: "llama",
wantContextLen: 4096,
@@ -84,7 +84,7 @@ func TestBuildModelInfo(t *testing.T) {
VocabSize: 262144,
TorchDtype: "bfloat16",
},
- totalTensorBytes: 8_600_000_088,
+ totalTensorBytes: 8_600_000_150,
tensorCount: 1,
wantArch: "gemma3",
wantContextLen: 131072,
@@ -101,7 +101,7 @@ func TestBuildModelInfo(t *testing.T) {
MaxPositionEmbeddings: 2048,
TorchDtype: "float32",
},
- totalTensorBytes: 400_000_088, // 100M params * 4 bytes + 88 bytes header
+ totalTensorBytes: 400_000_150, // 100M params * 4 bytes + 150 bytes header
tensorCount: 1,
wantArch: "test",
wantContextLen: 2048,
@@ -118,7 +118,7 @@ func TestBuildModelInfo(t *testing.T) {
MaxPositionEmbeddings: 1024,
TorchDtype: "bfloat16",
},
- totalTensorBytes: 2_000_880, // 1M params * 2 bytes + 10 tensors * 88 bytes
+ totalTensorBytes: 2_001_500, // 1M params * 2 bytes + 10 tensors * 150 bytes
tensorCount: 10,
wantArch: "test",
wantContextLen: 1024,
@@ -230,42 +230,42 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
{
name: "bfloat16",
dtype: "bfloat16",
- totalBytes: 2_000_088, // 1M * 2 + 88
+ totalBytes: 2_000_150, // 1M * 2 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float16",
dtype: "float16",
- totalBytes: 2_000_088,
+ totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "float32",
dtype: "float32",
- totalBytes: 4_000_088, // 1M * 4 + 88
+ totalBytes: 4_000_150, // 1M * 4 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "int8",
dtype: "int8",
- totalBytes: 1_000_088, // 1M * 1 + 88
+ totalBytes: 1_000_150, // 1M * 1 + 150
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "unknown dtype defaults to 2 bytes",
dtype: "unknown",
- totalBytes: 2_000_088,
+ totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
{
name: "empty dtype defaults to 2 bytes",
dtype: "",
- totalBytes: 2_000_088,
+ totalBytes: 2_000_150,
tensorCount: 1,
wantParamCount: 1_000_000,
},
@@ -288,11 +288,13 @@ func TestBuildModelInfo_BytesPerParam(t *testing.T) {
func TestParseSafetensorsHeader(t *testing.T) {
tests := []struct {
- name string
- header map[string]any
- wantDtype string
- wantShape []int64
- wantErr bool
+ name string
+ header map[string]any
+ wantDtype string
+ wantShape []int64
+ wantQuantType string
+ wantGroupSize string
+ wantErr bool
}{
{
name: "simple tensor",
@@ -307,7 +309,70 @@ func TestParseSafetensorsHeader(t *testing.T) {
wantShape: []int64{2560, 262144},
},
{
- name: "with metadata",
+ name: "tensor keyed by name",
+ header: map[string]any{
+ "model.layers.0.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 2560},
+ "data_offsets": []int64{0, 13107200},
+ },
+ },
+ wantDtype: "BF16",
+ wantShape: []int64{2560, 2560},
+ },
+ {
+ name: "with int4 quant metadata",
+ header: map[string]any{
+ "__metadata__": map[string]any{
+ "quant_type": "int4",
+ "group_size": "32",
+ },
+ "model.layers.0.mlp.up_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{2560, 320},
+ "data_offsets": []int64{0, 3276800},
+ },
+ "model.layers.0.mlp.up_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 80},
+ "data_offsets": []int64{3276800, 3686400},
+ },
+ "model.layers.0.mlp.up_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 80},
+ "data_offsets": []int64{3686400, 4096000},
+ },
+ },
+ wantDtype: "U32",
+ wantShape: []int64{2560, 320},
+ wantQuantType: "int4",
+ wantGroupSize: "32",
+ },
+ {
+ name: "int8 quant metadata",
+ header: map[string]any{
+ "__metadata__": map[string]any{
+ "quant_type": "int8",
+ "group_size": "64",
+ },
+ "model.layers.0.mlp.down_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{2560, 640},
+ "data_offsets": []int64{0, 6553600},
+ },
+ "model.layers.0.mlp.down_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 40},
+ "data_offsets": []int64{6553600, 6963200},
+ },
+ },
+ wantDtype: "U32",
+ wantShape: []int64{2560, 640},
+ wantQuantType: "int8",
+ wantGroupSize: "64",
+ },
+ {
+ name: "with old-style format metadata",
header: map[string]any{
"__metadata__": map[string]any{
"format": "pt",
@@ -371,6 +436,13 @@ func TestParseSafetensorsHeader(t *testing.T) {
}
}
}
+
+ if info.QuantType != tt.wantQuantType {
+ t.Errorf("QuantType = %v, want %v", info.QuantType, tt.wantQuantType)
+ }
+ if info.GroupSize != tt.wantGroupSize {
+ t.Errorf("GroupSize = %v, want %v", info.GroupSize, tt.wantGroupSize)
+ }
})
}
}
@@ -451,10 +523,16 @@ func TestParseSafetensorsHeader_Errors(t *testing.T) {
}
func TestGetTensorInfoFromManifest(t *testing.T) {
- // Create a temp directory for blobs
+ // Create a temp directory for blobs and set OLLAMA_MODELS
tempDir := t.TempDir()
+ t.Setenv("OLLAMA_MODELS", tempDir)
- // Create test tensor blobs
+ blobDir := filepath.Join(tempDir, "blobs")
+ if err := os.MkdirAll(blobDir, 0o755); err != nil {
+ t.Fatalf("failed to create blobs dir: %v", err)
+ }
+
+ // Create test tensor blobs with __metadata__
tensors := []struct {
name string
digest string
@@ -463,28 +541,27 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}{
{
name: "model.embed_tokens.weight",
- digest: "sha256:abc123",
+ digest: "sha256:abc123abc123abc123abc123abc123abc123abc123abc123abc123abc123abc0",
dtype: "BF16",
shape: []int64{262144, 2560},
},
{
name: "model.layers.0.self_attn.q_proj.weight",
- digest: "sha256:def456",
+ digest: "sha256:def456def456def456def456def456def456def456def456def456def456def0",
dtype: "BF16",
shape: []int64{2560, 2560},
},
{
name: "model.norm.weight",
- digest: "sha256:ghi789",
+ digest: "sha256:789789789789789789789789789789789789789789789789789789789789abc0",
dtype: "F32",
shape: []int64{2560},
},
}
- // Create blob files
- var layers []imagegen.ManifestLayer
+ // Create blob files with tensor keyed by name
+ var layers []manifest.Layer
for _, tensor := range tensors {
- // Create safetensors blob
header := map[string]any{
tensor.name: map[string]any{
"dtype": tensor.dtype,
@@ -498,15 +575,17 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
buf.Write(headerJSON)
- // Write blob file
- blobName := "sha256-" + tensor.digest[7:]
- blobPath := filepath.Join(tempDir, blobName)
+ // Write blob file using the digest format expected by GetBlobsPath
+ blobPath, err := manifest.BlobsPath(tensor.digest)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
t.Fatalf("failed to write blob: %v", err)
}
- layers = append(layers, imagegen.ManifestLayer{
- MediaType: "application/vnd.ollama.image.tensor",
+ layers = append(layers, manifest.Layer{
+ MediaType: manifest.MediaTypeImageTensor,
Digest: tensor.digest,
Size: int64(buf.Len() + 1000), // header + fake data
Name: tensor.name,
@@ -514,21 +593,20 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}
// Add a non-tensor layer (should be skipped)
- layers = append(layers, imagegen.ManifestLayer{
+ layers = append(layers, manifest.Layer{
MediaType: "application/vnd.ollama.image.json",
- Digest: "sha256:config",
+ Digest: "sha256:0000000000000000000000000000000000000000000000000000000000000000",
Size: 100,
Name: "config.json",
})
- manifest := &imagegen.ModelManifest{
- Manifest: &imagegen.Manifest{
- Layers: layers,
- },
- BlobDir: tempDir,
+ mf := &manifest.Manifest{
+ SchemaVersion: 2,
+ MediaType: "application/vnd.docker.distribution.manifest.v2+json",
+ Layers: layers,
}
- result, err := getTensorInfoFromManifest(manifest)
+ result, err := getTensorInfoFromManifest(mf)
if err != nil {
t.Fatalf("getTensorInfoFromManifest() error = %v", err)
}
@@ -554,6 +632,572 @@ func TestGetTensorInfoFromManifest(t *testing.T) {
}
}
+func TestGetTensorInfoFromManifest_Quantized(t *testing.T) {
+ // Create a temp directory for blobs and set OLLAMA_MODELS
+ tempDir := t.TempDir()
+ t.Setenv("OLLAMA_MODELS", tempDir)
+
+ blobDir := filepath.Join(tempDir, "blobs")
+ if err := os.MkdirAll(blobDir, 0o755); err != nil {
+ t.Fatalf("failed to create blobs dir: %v", err)
+ }
+
+ // Create a combined quantized blob with __metadata__
+ header := map[string]any{
+ "__metadata__": map[string]string{
+ "quant_type": "int4",
+ "group_size": "32",
+ },
+ "model.layers.0.mlp.up_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{2560, 320}, // packed: 2560 / 8 = 320
+ "data_offsets": []int64{0, 3276800},
+ },
+ "model.layers.0.mlp.up_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 80}, // 2560 / 32 = 80
+ "data_offsets": []int64{3276800, 3686400},
+ },
+ "model.layers.0.mlp.up_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 80},
+ "data_offsets": []int64{3686400, 4096000},
+ },
+ }
+ headerJSON, _ := json.Marshal(header)
+
+ var buf bytes.Buffer
+ binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
+ buf.Write(headerJSON)
+
+ digest := "sha256:aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb11aabb"
+ blobPath, err := manifest.BlobsPath(digest)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
+ if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
+ t.Fatalf("failed to write blob: %v", err)
+ }
+
+ mf := &manifest.Manifest{
+ SchemaVersion: 2,
+ MediaType: "application/vnd.docker.distribution.manifest.v2+json",
+ Layers: []manifest.Layer{
+ {
+ MediaType: manifest.MediaTypeImageTensor,
+ Digest: digest,
+ Size: int64(buf.Len() + 4096000),
+ Name: "model.layers.0.mlp.up_proj.weight",
+ },
+ },
+ }
+
+ result, err := getTensorInfoFromManifest(mf)
+ if err != nil {
+ t.Fatalf("getTensorInfoFromManifest() error = %v", err)
+ }
+
+ if len(result) != 1 {
+ t.Fatalf("got %d tensors, want 1", len(result))
+ }
+
+ tensor := result[0]
+ if tensor.Name != "model.layers.0.mlp.up_proj.weight" {
+ t.Errorf("Name = %v, want model.layers.0.mlp.up_proj.weight", tensor.Name)
+ }
+ if tensor.Type != "INT4" {
+ t.Errorf("Type = %v, want INT4", tensor.Type)
+ }
+ // Shape should be unpacked: 320 * 8 = 2560
+ if len(tensor.Shape) != 2 || tensor.Shape[0] != 2560 || tensor.Shape[1] != 2560 {
+ t.Errorf("Shape = %v, want [2560, 2560]", tensor.Shape)
+ }
+}
+
+func TestGetParameterCountFromManifest(t *testing.T) {
+ // Create a temp directory for blobs and set OLLAMA_MODELS
+ tempDir := t.TempDir()
+ t.Setenv("OLLAMA_MODELS", tempDir)
+
+ blobDir := filepath.Join(tempDir, "blobs")
+ if err := os.MkdirAll(blobDir, 0o755); err != nil {
+ t.Fatalf("failed to create blobs dir: %v", err)
+ }
+
+ // Unquantized tensor: [4,5] = 20 params
+ header1 := map[string]any{
+ "model.embed_tokens.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{4, 5},
+ "data_offsets": []int64{0, 40},
+ },
+ }
+ header1JSON, _ := json.Marshal(header1)
+ var buf1 bytes.Buffer
+ binary.Write(&buf1, binary.LittleEndian, uint64(len(header1JSON)))
+ buf1.Write(header1JSON)
+
+ digest1 := "sha256:1111111111111111111111111111111111111111111111111111111111111111"
+ blobPath1, err := manifest.BlobsPath(digest1)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
+ if err := os.WriteFile(blobPath1, buf1.Bytes(), 0o644); err != nil {
+ t.Fatalf("failed to write blob1: %v", err)
+ }
+
+ // Quantized int4 tensor with packed shape [10,2] -> unpacked [10,16] = 160 params
+ header2 := map[string]any{
+ "__metadata__": map[string]string{
+ "quant_type": "int4",
+ "group_size": "32",
+ },
+ "model.layers.0.mlp.up_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{10, 2},
+ "data_offsets": []int64{0, 80},
+ },
+ "model.layers.0.mlp.up_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10, 1},
+ "data_offsets": []int64{80, 100},
+ },
+ "model.layers.0.mlp.up_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10, 1},
+ "data_offsets": []int64{100, 120},
+ },
+ }
+ header2JSON, _ := json.Marshal(header2)
+ var buf2 bytes.Buffer
+ binary.Write(&buf2, binary.LittleEndian, uint64(len(header2JSON)))
+ buf2.Write(header2JSON)
+
+ digest2 := "sha256:2222222222222222222222222222222222222222222222222222222222222222"
+ blobPath2, err := manifest.BlobsPath(digest2)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
+ if err := os.WriteFile(blobPath2, buf2.Bytes(), 0o644); err != nil {
+ t.Fatalf("failed to write blob2: %v", err)
+ }
+
+ mf := &manifest.Manifest{
+ SchemaVersion: 2,
+ MediaType: "application/vnd.docker.distribution.manifest.v2+json",
+ Layers: []manifest.Layer{
+ {
+ MediaType: manifest.MediaTypeImageTensor,
+ Digest: digest1,
+ Size: int64(buf1.Len() + 40),
+ Name: "model.embed_tokens.weight",
+ },
+ {
+ MediaType: manifest.MediaTypeImageTensor,
+ Digest: digest2,
+ Size: int64(buf2.Len() + 120),
+ Name: "model.layers.0.mlp.up_proj.weight",
+ },
+ },
+ }
+
+ paramCount, err := getParameterCountFromManifest(mf)
+ if err != nil {
+ t.Fatalf("getParameterCountFromManifest() error = %v", err)
+ }
+
+ const want int64 = 180 // 20 + 160
+ if paramCount != want {
+ t.Errorf("parameter_count = %d, want %d", paramCount, want)
+ }
+}
+
+func TestGetParameterCountFromManifest_MixedQuantizedPacked(t *testing.T) {
+ // Create a temp directory for blobs and set OLLAMA_MODELS
+ tempDir := t.TempDir()
+ t.Setenv("OLLAMA_MODELS", tempDir)
+
+ blobDir := filepath.Join(tempDir, "blobs")
+ if err := os.MkdirAll(blobDir, 0o755); err != nil {
+ t.Fatalf("failed to create blobs dir: %v", err)
+ }
+
+ // Packed mixed-precision blob (no global metadata):
+ // - gate_proj: int4 packed [5,8] + scale [5,2] => unpacked [5,64] = 320 params
+ // - down_proj: int8 packed [5,16] + scale [5,1] => unpacked [5,64] = 320 params
+ header := map[string]any{
+ "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{5, 8},
+ "data_offsets": []int64{0, 160},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{5, 2},
+ "data_offsets": []int64{160, 180},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{5, 2},
+ "data_offsets": []int64{180, 200},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{5, 16},
+ "data_offsets": []int64{200, 520},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{5, 1},
+ "data_offsets": []int64{520, 530},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{5, 1},
+ "data_offsets": []int64{530, 540},
+ },
+ }
+ headerJSON, _ := json.Marshal(header)
+ var buf bytes.Buffer
+ binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
+ buf.Write(headerJSON)
+
+ digest := "sha256:3333333333333333333333333333333333333333333333333333333333333333"
+ blobPath, err := manifest.BlobsPath(digest)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
+ if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
+ t.Fatalf("failed to write blob: %v", err)
+ }
+
+ mf := &manifest.Manifest{
+ SchemaVersion: 2,
+ MediaType: "application/vnd.docker.distribution.manifest.v2+json",
+ Layers: []manifest.Layer{
+ {
+ MediaType: manifest.MediaTypeImageTensor,
+ Digest: digest,
+ Size: int64(buf.Len() + 540),
+ Name: "model.layers.0.mlp.experts",
+ },
+ },
+ }
+
+ paramCount, err := getParameterCountFromManifest(mf)
+ if err != nil {
+ t.Fatalf("getParameterCountFromManifest() error = %v", err)
+ }
+
+ const want int64 = 640 // 320 + 320
+ if paramCount != want {
+ t.Errorf("parameter_count = %d, want %d", paramCount, want)
+ }
+}
+
+func TestParseSafetensorsAllHeaders(t *testing.T) {
+ tests := []struct {
+ name string
+ header map[string]any
+ wantCount int
+ wantNames []string
+ wantDtypes []string
+ wantQuants []string
+ wantErr bool
+ }{
+ {
+ name: "single tensor blob",
+ header: map[string]any{
+ "model.layers.0.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 2560},
+ "data_offsets": []int64{0, 13107200},
+ },
+ },
+ wantCount: 1,
+ wantNames: []string{"model.layers.0.weight"},
+ wantDtypes: []string{"BF16"},
+ wantQuants: []string{""},
+ },
+ {
+ name: "packed unquantized blob",
+ header: map[string]any{
+ "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 10240},
+ "data_offsets": []int64{0, 52428800},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 2560},
+ "data_offsets": []int64{52428800, 104857600},
+ },
+ "model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 2560},
+ "data_offsets": []int64{104857600, 157286400},
+ },
+ },
+ wantCount: 3,
+ wantNames: []string{
+ "model.layers.0.mlp.experts.0.down_proj.weight",
+ "model.layers.0.mlp.experts.0.gate_proj.weight",
+ "model.layers.0.mlp.experts.0.up_proj.weight",
+ },
+ wantDtypes: []string{"BF16", "BF16", "BF16"},
+ wantQuants: []string{"", "", ""},
+ },
+ {
+ name: "packed quantized blob with global metadata",
+ header: map[string]any{
+ "__metadata__": map[string]any{
+ "quant_type": "int4",
+ "group_size": "32",
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{10240, 320},
+ "data_offsets": []int64{0, 13107200},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{13107200, 14745600},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{14745600, 16384000},
+ },
+ "model.layers.0.mlp.experts.0.up_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{10240, 320},
+ "data_offsets": []int64{16384000, 29491200},
+ },
+ "model.layers.0.mlp.experts.0.up_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{29491200, 31129600},
+ },
+ "model.layers.0.mlp.experts.0.up_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{31129600, 32768000},
+ },
+ },
+ wantCount: 2,
+ wantNames: []string{
+ "model.layers.0.mlp.experts.0.gate_proj.weight",
+ "model.layers.0.mlp.experts.0.up_proj.weight",
+ },
+ wantDtypes: []string{"U32", "U32"},
+ wantQuants: []string{"int4", "int4"},
+ },
+ {
+ name: "packed mixed-precision blob (no global metadata)",
+ header: map[string]any{
+ "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{10240, 320},
+ "data_offsets": []int64{0, 13107200},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{13107200, 14745600},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{14745600, 16384000},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{2560, 2560},
+ "data_offsets": []int64{16384000, 42598400},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 160},
+ "data_offsets": []int64{42598400, 43417600},
+ },
+ },
+ wantCount: 2,
+ wantNames: []string{
+ "model.layers.0.mlp.experts.0.down_proj.weight",
+ "model.layers.0.mlp.experts.0.gate_proj.weight",
+ },
+ wantDtypes: []string{"U32", "U32"},
+ wantQuants: []string{"int8", "int4"},
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ headerJSON, err := json.Marshal(tt.header)
+ if err != nil {
+ t.Fatalf("failed to marshal header: %v", err)
+ }
+
+ var buf bytes.Buffer
+ if err := binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON))); err != nil {
+ t.Fatalf("failed to write header size: %v", err)
+ }
+ buf.Write(headerJSON)
+
+ results, err := parseSafetensorsAllHeaders(&buf)
+ if (err != nil) != tt.wantErr {
+ t.Errorf("parseSafetensorsAllHeaders() error = %v, wantErr %v", err, tt.wantErr)
+ return
+ }
+ if tt.wantErr {
+ return
+ }
+
+ if len(results) != tt.wantCount {
+ t.Fatalf("got %d tensors, want %d", len(results), tt.wantCount)
+ }
+
+ for i, info := range results {
+ if info.Name != tt.wantNames[i] {
+ t.Errorf("tensor[%d].Name = %v, want %v", i, info.Name, tt.wantNames[i])
+ }
+ if info.Dtype != tt.wantDtypes[i] {
+ t.Errorf("tensor[%d].Dtype = %v, want %v", i, info.Dtype, tt.wantDtypes[i])
+ }
+ if info.QuantType != tt.wantQuants[i] {
+ t.Errorf("tensor[%d].QuantType = %v, want %v", i, info.QuantType, tt.wantQuants[i])
+ }
+ }
+ })
+ }
+}
+
+func TestGetTensorInfoFromManifest_Packed(t *testing.T) {
+ // Create a temp directory for blobs and set OLLAMA_MODELS
+ tempDir := t.TempDir()
+ t.Setenv("OLLAMA_MODELS", tempDir)
+
+ blobDir := filepath.Join(tempDir, "blobs")
+ if err := os.MkdirAll(blobDir, 0o755); err != nil {
+ t.Fatalf("failed to create blobs dir: %v", err)
+ }
+
+ // Create a packed blob with multiple expert tensors (mixed quantization)
+ header := map[string]any{
+ "model.layers.0.mlp.experts.0.gate_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{10240, 320},
+ "data_offsets": []int64{0, 13107200},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{13107200, 14745600},
+ },
+ "model.layers.0.mlp.experts.0.gate_proj.weight.bias": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{10240, 80},
+ "data_offsets": []int64{14745600, 16384000},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight": map[string]any{
+ "dtype": "U32",
+ "shape": []int64{2560, 2560},
+ "data_offsets": []int64{16384000, 42598400},
+ },
+ "model.layers.0.mlp.experts.0.down_proj.weight.scale": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{2560, 160},
+ "data_offsets": []int64{42598400, 43417600},
+ },
+ }
+ headerJSON, _ := json.Marshal(header)
+
+ var buf bytes.Buffer
+ binary.Write(&buf, binary.LittleEndian, uint64(len(headerJSON)))
+ buf.Write(headerJSON)
+
+ packedDigest := "sha256:aaaa000000000000000000000000000000000000000000000000000000000001"
+ blobPath, err := manifest.BlobsPath(packedDigest)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
+ if err := os.WriteFile(blobPath, buf.Bytes(), 0o644); err != nil {
+ t.Fatalf("failed to write packed blob: %v", err)
+ }
+
+ // Also create a regular (single-tensor) blob
+ singleHeader := map[string]any{
+ "model.embed_tokens.weight": map[string]any{
+ "dtype": "BF16",
+ "shape": []int64{262144, 2560},
+ "data_offsets": []int64{0, 1342177280},
+ },
+ }
+ singleHeaderJSON, _ := json.Marshal(singleHeader)
+ var singleBuf bytes.Buffer
+ binary.Write(&singleBuf, binary.LittleEndian, uint64(len(singleHeaderJSON)))
+ singleBuf.Write(singleHeaderJSON)
+
+ singleDigest := "sha256:bbbb000000000000000000000000000000000000000000000000000000000002"
+ singleBlobPath, err := manifest.BlobsPath(singleDigest)
+ if err != nil {
+ t.Fatalf("failed to get blob path: %v", err)
+ }
+ if err := os.WriteFile(singleBlobPath, singleBuf.Bytes(), 0o644); err != nil {
+ t.Fatalf("failed to write single blob: %v", err)
+ }
+
+ mf := &manifest.Manifest{
+ SchemaVersion: 2,
+ MediaType: "application/vnd.docker.distribution.manifest.v2+json",
+ Layers: []manifest.Layer{
+ {
+ MediaType: manifest.MediaTypeImageTensor,
+ Digest: singleDigest,
+ Size: int64(singleBuf.Len()),
+ Name: "model.embed_tokens.weight",
+ },
+ {
+ MediaType: manifest.MediaTypeImageTensor,
+ Digest: packedDigest,
+ Size: int64(buf.Len()),
+ Name: "model.layers.0.mlp.experts", // group prefix
+ },
+ },
+ }
+
+ result, err := getTensorInfoFromManifest(mf)
+ if err != nil {
+ t.Fatalf("getTensorInfoFromManifest() error = %v", err)
+ }
+
+ // Should have 3 tensors: 1 single + 2 packed main tensors
+ if len(result) != 3 {
+ t.Fatalf("got %d tensors, want 3. Tensors: %v", len(result), result)
+ }
+
+ // First tensor should be the single blob
+ if result[0].Name != "model.embed_tokens.weight" {
+ t.Errorf("tensor[0].Name = %v, want model.embed_tokens.weight", result[0].Name)
+ }
+ if result[0].Type != "BF16" {
+ t.Errorf("tensor[0].Type = %v, want BF16", result[0].Type)
+ }
+
+ // Packed tensors should have their actual names (sorted)
+ packedNames := make(map[string]bool)
+ for _, r := range result[1:] {
+ packedNames[r.Name] = true
+ }
+ if !packedNames["model.layers.0.mlp.experts.0.down_proj.weight"] {
+ t.Error("missing packed tensor: model.layers.0.mlp.experts.0.down_proj.weight")
+ }
+ if !packedNames["model.layers.0.mlp.experts.0.gate_proj.weight"] {
+ t.Error("missing packed tensor: model.layers.0.mlp.experts.0.gate_proj.weight")
+ }
+}
+
func TestReadSafetensorsHeader(t *testing.T) {
// Create a temp file with a valid safetensors header
tempDir := t.TempDir()
diff --git a/x/tokenizer/tokenizer.go b/x/tokenizer/tokenizer.go
new file mode 100644
index 00000000000..301e51aea3c
--- /dev/null
+++ b/x/tokenizer/tokenizer.go
@@ -0,0 +1,108 @@
+//go:build mlx
+
+// tokenizer.go - BPE and SentencePiece tokenizer for HuggingFace models
+//
+// Based on standard BPE algorithm (Sennrich et al. 2015) with:
+// - GPT-2 byte-level encoding (OpenAI tiktoken)
+// - HuggingFace tokenizer.json pretokenizer patterns
+// - SentencePiece ▁-style space handling
+
+package tokenizer
+
+import "regexp"
+
+// TokenizerType identifies the tokenization algorithm
+type TokenizerType int
+
+const (
+ TokenizerBPE TokenizerType = iota // GPT-2 style byte-level BPE
+ TokenizerSentencePiece // SentencePiece with ▁ for spaces
+)
+
+// Vocabulary holds the tokenizer vocabulary and merges
+type Vocabulary struct {
+ Values []string
+ Reverse map[string]int32
+ Merges map[string]int
+
+ BOS int32
+ EOS []int32 // Multiple EOS tokens supported (e.g., Gemma has and )
+ PAD int32 // Padding token (often <|endoftext|> or )
+ AddBOS bool
+ AddEOS bool
+
+ // Precomputed byte token IDs for <0xNN> fallback (256 entries, -1 if not found)
+ byteTokens [256]int32
+}
+
+// Tokenizer handles BPE and SentencePiece tokenization
+type Tokenizer struct {
+ vocab *Vocabulary
+ pretokenizer *regexp.Regexp
+ specialTokens map[string]int32 // Special tokens for direct lookup
+ sortedSpecialTokens []string // Special tokens sorted by length, longest first
+ typ TokenizerType // Algorithm type
+}
+
+// Precomputed GPT-2 byte-level encoding table
+// Maps byte values to their encoded rune equivalents
+var byteToRune [256]rune
+
+func init() {
+ for b := 0; b < 256; b++ {
+ r := rune(b)
+ switch {
+ case r == 0x00ad:
+ r = 0x0143
+ case r <= 0x0020:
+ r = r + 0x0100
+ case r >= 0x007f && r <= 0x00a0:
+ r = r + 0x00a2
+ }
+ byteToRune[b] = r
+ }
+}
+
+// VocabSize returns the vocabulary size
+func (t *Tokenizer) VocabSize() int {
+ return len(t.vocab.Values)
+}
+
+// BOS returns the beginning of sequence token ID
+func (t *Tokenizer) BOS() int32 {
+ return t.vocab.BOS
+}
+
+// EOS returns the first end of sequence token ID (for backwards compatibility)
+func (t *Tokenizer) EOS() int32 {
+ if len(t.vocab.EOS) > 0 {
+ return t.vocab.EOS[0]
+ }
+ return -1
+}
+
+// EOSTokens returns all end of sequence token IDs
+func (t *Tokenizer) EOSTokens() []int32 {
+ return t.vocab.EOS
+}
+
+// PAD returns the padding token ID, or -1 if not set
+func (t *Tokenizer) PAD() int32 {
+ return t.vocab.PAD
+}
+
+// IsEOS returns true if the token ID is an end of sequence token
+func (t *Tokenizer) IsEOS(id int32) bool {
+ for _, eos := range t.vocab.EOS {
+ if id == eos {
+ return true
+ }
+ }
+ return false
+}
+
+// GetSpecialToken returns the token ID for a special token string
+func (t *Tokenizer) GetSpecialToken(name string) (int32, bool) {
+ id, ok := t.specialTokens[name]
+ return id, ok
+}
diff --git a/x/tokenizer/tokenizer_benchmark_test.go b/x/tokenizer/tokenizer_benchmark_test.go
new file mode 100644
index 00000000000..e65a5978645
--- /dev/null
+++ b/x/tokenizer/tokenizer_benchmark_test.go
@@ -0,0 +1,251 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+var (
+ benchmarkSinkIDs []int32
+ benchmarkSinkStr string
+ benchmarkSinkTok *Tokenizer
+)
+
+const benchmarkWordPieceJSON = `{
+ "model": {
+ "type": "WordPiece",
+ "vocab": {
+ "[UNK]": 0,
+ "hello": 1,
+ "##world": 2,
+ "##ly": 3,
+ "##hello": 4
+ }
+ },
+ "added_tokens": []
+}`
+
+const benchmarkSentencePieceJSON = `{
+ "model": {
+ "type": "BPE",
+ "vocab": {
+ "\u2581": 0,
+ "h": 1,
+ "e": 2,
+ "l": 3,
+ "o": 4,
+ "w": 5,
+ "r": 6,
+ "d": 7,
+ "<0x0A>": 8
+ },
+ "merges": []
+ },
+ "decoder": {
+ "type": "Sequence",
+ "decoders": [
+ {
+ "type": "Replace",
+ "pattern": {
+ "String": "\u2581"
+ }
+ }
+ ]
+ },
+ "added_tokens": []
+}`
+
+func benchmarkMiniLlamaPath(tb testing.TB) string {
+ tb.Helper()
+
+ _, filename, _, ok := runtime.Caller(0)
+ if !ok {
+ tb.Fatal("failed to resolve benchmark file path")
+ }
+
+ return filepath.Join(filepath.Dir(filename), "..", "imagegen", "tokenizer", "testdata", "mini_llama.json")
+}
+
+func benchmarkLoadMiniLlama(tb testing.TB) *Tokenizer {
+ tb.Helper()
+
+ data := benchmarkLoadMiniLlamaBytes(tb)
+ tok, err := LoadFromBytes(data)
+ if err != nil {
+ tb.Fatalf("failed to load mini llama tokenizer: %v", err)
+ }
+ return tok
+}
+
+func benchmarkLoadMiniLlamaBytes(tb testing.TB) []byte {
+ tb.Helper()
+
+ data, err := os.ReadFile(benchmarkMiniLlamaPath(tb))
+ if err != nil {
+ tb.Fatalf("failed to read mini llama tokenizer: %v", err)
+ }
+ return data
+}
+
+func benchmarkLoadFromBytes(tb testing.TB, data []byte) *Tokenizer {
+ tb.Helper()
+
+ tok, err := LoadFromBytes(data)
+ if err != nil {
+ tb.Fatalf("failed to load tokenizer from bytes: %v", err)
+ }
+ return tok
+}
+
+func BenchmarkTokenizerEncodeBPE(b *testing.B) {
+ tok := benchmarkLoadMiniLlama(b)
+
+ inputs := []struct {
+ name string
+ text string
+ }{
+ {name: "short", text: "Hello, world!"},
+ {name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
+ {name: "long_sequential", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 80)},
+ {name: "long_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
+ {name: "huge_parallel", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)},
+ {name: "special_tokens", text: "<|begin_of_text|>system\nYou are concise.<|end_of_text|>"},
+ }
+
+ for _, input := range inputs {
+ b.Run(input.name, func(b *testing.B) {
+ b.ReportAllocs()
+ b.SetBytes(int64(len(input.text)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkSinkIDs = tok.Encode(input.text, false)
+ }
+ })
+ }
+}
+
+func BenchmarkTokenizerDecodeBPE(b *testing.B) {
+ tok := benchmarkLoadMiniLlama(b)
+
+ inputs := []struct {
+ name string
+ text string
+ }{
+ {name: "medium", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 16)},
+ {name: "long", text: strings.Repeat("The quick brown fox jumps over the lazy dog. ", 160)},
+ }
+
+ for _, input := range inputs {
+ ids := tok.Encode(input.text, false)
+ b.Run(input.name, func(b *testing.B) {
+ b.ReportAllocs()
+ b.SetBytes(int64(len(input.text)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkSinkStr = tok.Decode(ids)
+ }
+ })
+ }
+}
+
+func BenchmarkTokenizerLoadFromBytes(b *testing.B) {
+ data := benchmarkLoadMiniLlamaBytes(b)
+
+ config := &TokenizerConfig{
+ TokenizerConfigJSON: []byte(`{
+ "bos_token": {"content": "<|begin_of_text|>"},
+ "eos_token": {"content": "<|end_of_text|>"},
+ "add_bos_token": true
+ }`),
+ GenerationConfigJSON: []byte(`{"bos_token_id": 128000, "eos_token_id": 128001}`),
+ }
+
+ b.Run("without_config", func(b *testing.B) {
+ b.ReportAllocs()
+ b.SetBytes(int64(len(data)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ tok, err := LoadFromBytes(data)
+ if err != nil {
+ b.Fatalf("LoadFromBytes failed: %v", err)
+ }
+ benchmarkSinkTok = tok
+ }
+ })
+
+ b.Run("with_config", func(b *testing.B) {
+ b.ReportAllocs()
+ b.SetBytes(int64(len(data)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ tok, err := LoadFromBytesWithConfig(data, config)
+ if err != nil {
+ b.Fatalf("LoadFromBytesWithConfig failed: %v", err)
+ }
+ benchmarkSinkTok = tok
+ }
+ })
+}
+
+func BenchmarkTokenizerEncodeWordPiece(b *testing.B) {
+ tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
+ text := strings.Repeat("helloworldly", 16)
+
+ b.ReportAllocs()
+ b.SetBytes(int64(len(text)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkSinkIDs = tok.Encode(text, false)
+ }
+}
+
+func BenchmarkTokenizerDecodeWordPiece(b *testing.B) {
+ tok := benchmarkLoadFromBytes(b, []byte(benchmarkWordPieceJSON))
+ text := strings.Repeat("helloworldly", 16)
+ ids := tok.Encode(text, false)
+
+ b.ReportAllocs()
+ b.SetBytes(int64(len(text)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkSinkStr = tok.Decode(ids)
+ }
+}
+
+func BenchmarkTokenizerEncodeSentencePiece(b *testing.B) {
+ tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
+ text := strings.Repeat("hello world\n", 64)
+
+ b.ReportAllocs()
+ b.SetBytes(int64(len(text)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkSinkIDs = tok.Encode(text, false)
+ }
+}
+
+func BenchmarkTokenizerDecodeSentencePiece(b *testing.B) {
+ tok := benchmarkLoadFromBytes(b, []byte(benchmarkSentencePieceJSON))
+ text := strings.Repeat("hello world\n", 64)
+ ids := tok.Encode(text, false)
+
+ b.ReportAllocs()
+ b.SetBytes(int64(len(text)))
+ b.ResetTimer()
+
+ for i := 0; i < b.N; i++ {
+ benchmarkSinkStr = tok.Decode(ids)
+ }
+}
diff --git a/x/tokenizer/tokenizer_bpe.go b/x/tokenizer/tokenizer_bpe.go
new file mode 100644
index 00000000000..1e625c20a7d
--- /dev/null
+++ b/x/tokenizer/tokenizer_bpe.go
@@ -0,0 +1,175 @@
+//go:build mlx
+
+package tokenizer
+
+import "container/heap"
+
+type bpeMergeNode struct {
+ prev int
+ next int
+ token string
+}
+
+type bpePair struct {
+ left int
+ right int
+ rank int
+ value string
+}
+
+type bpePairHeap []*bpePair
+
+func (h bpePairHeap) Len() int { return len(h) }
+
+func (h bpePairHeap) Less(i, j int) bool {
+ return h[i].rank < h[j].rank || (h[i].rank == h[j].rank && h[i].left < h[j].left)
+}
+
+func (h bpePairHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] }
+
+func (h *bpePairHeap) Push(x any) {
+ *h = append(*h, x.(*bpePair))
+}
+
+func (h *bpePairHeap) Pop() any {
+ old := *h
+ n := len(old)
+ item := old[n-1]
+ *h = old[:n-1]
+ return item
+}
+
+// encodeBPEMerge encodes using BPE merge algorithm.
+// Uses the heap/linked-list pair merge strategy from tokenizer/bytepairencoding.go:
+// merge the lowest-rank valid pair, then only recheck adjacent pairs.
+func (t *Tokenizer) encodeBPEMerge(encoded string, ids []int32) []int32 {
+ runes := []rune(encoded)
+ if len(runes) == 0 {
+ return ids
+ }
+
+ nodes := make([]bpeMergeNode, len(runes))
+ for i := range runes {
+ nodes[i] = bpeMergeNode{
+ prev: i - 1,
+ next: i + 1,
+ token: string(runes[i]),
+ }
+ }
+
+ pairwise := func(left, right int) *bpePair {
+ if left < 0 || right >= len(nodes) {
+ return nil
+ }
+ if nodes[left].token == "" || nodes[right].token == "" {
+ return nil
+ }
+
+ leftToken, rightToken := nodes[left].token, nodes[right].token
+ rank, ok := t.vocab.Merges[leftToken+" "+rightToken]
+ if !ok {
+ return nil
+ }
+
+ value := leftToken + rightToken
+ if _, ok := t.vocab.Reverse[value]; !ok {
+ return nil
+ }
+
+ return &bpePair{
+ left: left,
+ right: right,
+ rank: rank,
+ value: value,
+ }
+ }
+
+ pairs := bpePairHeap{}
+ heap.Init(&pairs)
+ for i := 0; i < len(runes)-1; i++ {
+ if pair := pairwise(i, i+1); pair != nil {
+ heap.Push(&pairs, pair)
+ }
+ }
+
+ for pairs.Len() > 0 {
+ pair := heap.Pop(&pairs).(*bpePair)
+ left, right := nodes[pair.left], nodes[pair.right]
+ if left.token == "" || right.token == "" {
+ continue
+ }
+ if left.next != pair.right || right.prev != pair.left {
+ continue
+ }
+ if left.token+right.token != pair.value {
+ continue
+ }
+
+ nodes[pair.left].token = pair.value
+ nodes[pair.right].token = ""
+ nodes[pair.left].next = right.next
+ if right.next < len(nodes) {
+ nodes[right.next].prev = pair.left
+ }
+
+ if pair := pairwise(nodes[pair.left].prev, pair.left); pair != nil {
+ heap.Push(&pairs, pair)
+ }
+ if pair := pairwise(pair.left, nodes[pair.left].next); pair != nil {
+ heap.Push(&pairs, pair)
+ }
+ }
+
+ for _, node := range nodes {
+ if node.token == "" {
+ continue
+ }
+
+ if id, ok := t.vocab.Reverse[node.token]; ok {
+ ids = append(ids, id)
+ continue
+ }
+
+ ids = t.appendByteFallback(ids, node.token)
+ }
+
+ return ids
+}
+
+func (t *Tokenizer) appendByteFallback(ids []int32, token string) []int32 {
+ if t.typ == TokenizerBPE {
+ for _, r := range token {
+ if b, ok := decodeByteLevelRune(r); ok {
+ if id := t.vocab.byteTokens[b]; id >= 0 {
+ ids = append(ids, id)
+ }
+ }
+ }
+ return ids
+ }
+
+ // SentencePiece fallback uses the UTF-8 bytes for <0xNN> tokens.
+ for _, b := range []byte(token) {
+ if id := t.vocab.byteTokens[b]; id >= 0 {
+ ids = append(ids, id)
+ }
+ }
+ return ids
+}
+
+func decodeByteLevelRune(r rune) (byte, bool) {
+ switch {
+ case r >= 0x00 && r <= 0xFF:
+ return byte(r), true
+ case r == 0x0100:
+ return 0x00, true
+ case r == 0x0143:
+ return 0x00ad, true
+ case r > 0x0100 && r <= 0x0120:
+ return byte(r - 0x0100), true
+ case r > 0x0120 && r <= 0x0142:
+ return byte(r - 0x00a2), true
+ default:
+ return 0, false
+ }
+}
diff --git a/x/tokenizer/tokenizer_correctness_test.go b/x/tokenizer/tokenizer_correctness_test.go
new file mode 100644
index 00000000000..2fe94d27925
--- /dev/null
+++ b/x/tokenizer/tokenizer_correctness_test.go
@@ -0,0 +1,137 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func equalIDs(a, b []int32) bool {
+ if len(a) != len(b) {
+ return false
+ }
+ for i := range a {
+ if a[i] != b[i] {
+ return false
+ }
+ }
+ return true
+}
+
+func TestEncodeRoundtripMiniLlama(t *testing.T) {
+ tok := benchmarkLoadMiniLlama(t)
+
+ inputs := []string{
+ "",
+ "hello",
+ "hello world",
+ " hello world ",
+ "don't we'll they're",
+ "1234567890",
+ "こんにちは世界",
+ "Hello 世界",
+ "func main() {}",
+ "<|begin_of_text|>system\nYou are concise.<|end_of_text|>",
+ strings.Repeat("The quick brown fox jumps over the lazy dog. ", 32),
+ }
+
+ for _, input := range inputs {
+ ids := tok.Encode(input, false)
+ got := tok.Decode(ids)
+ if got != input {
+ t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
+ }
+ }
+}
+
+func TestSplitBySpecialTokensGreedyLongest(t *testing.T) {
+ data := []byte(`{
+ "model": {
+ "type": "BPE",
+ "vocab": {"a": 0, "b": 1},
+ "merges": []
+ },
+ "added_tokens": [
+ {"id": 2, "content": "", "special": true},
+ {"id": 3, "content": "x", "special": true}
+ ]
+ }`)
+
+ tok, err := LoadFromBytes(data)
+ if err != nil {
+ t.Fatalf("failed to load tokenizer: %v", err)
+ }
+
+ input := "axb"
+ want := []string{"a", "x", "b"}
+
+ got := tok.splitBySpecialTokens(input)
+ if len(got) != len(want) {
+ t.Fatalf("split length mismatch: got %v want %v", got, want)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
+ }
+ }
+}
+
+func TestSplitBySpecialTokensFallbackWithoutCache(t *testing.T) {
+ data := []byte(`{
+ "model": {
+ "type": "BPE",
+ "vocab": {"a": 0, "b": 1},
+ "merges": []
+ },
+ "added_tokens": [
+ {"id": 2, "content": "", "special": true},
+ {"id": 3, "content": "x", "special": true}
+ ]
+ }`)
+
+ tok, err := LoadFromBytes(data)
+ if err != nil {
+ t.Fatalf("failed to load tokenizer: %v", err)
+ }
+
+ input := "axb"
+ want := []string{"a", "x", "b"}
+
+ // Simulate construction outside loader path where cache is not set.
+ tok.sortedSpecialTokens = nil
+
+ got := tok.splitBySpecialTokens(input)
+ if len(got) != len(want) {
+ t.Fatalf("split length mismatch: got %v want %v", got, want)
+ }
+ for i := range want {
+ if got[i] != want[i] {
+ t.Fatalf("split mismatch at %d: got %v want %v", i, got, want)
+ }
+ }
+}
+
+func TestEncodeDeterministicAcrossGOMAXPROCS(t *testing.T) {
+ tok := benchmarkLoadMiniLlama(t)
+
+ input := strings.Repeat("The quick brown fox jumps over the lazy dog. ", 640)
+
+ prev := runtime.GOMAXPROCS(0)
+ defer runtime.GOMAXPROCS(prev)
+
+ runtime.GOMAXPROCS(1)
+ seq := tok.Encode(input, false)
+
+ if prev < 2 {
+ runtime.GOMAXPROCS(2)
+ } else {
+ runtime.GOMAXPROCS(prev)
+ }
+ par := tok.Encode(input, false)
+
+ if !equalIDs(seq, par) {
+ t.Fatalf("encode mismatch between sequential and parallel paths: seq=%d par=%d", len(seq), len(par))
+ }
+}
diff --git a/x/tokenizer/tokenizer_decode.go b/x/tokenizer/tokenizer_decode.go
new file mode 100644
index 00000000000..e02d2a88bc7
--- /dev/null
+++ b/x/tokenizer/tokenizer_decode.go
@@ -0,0 +1,56 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "strconv"
+ "strings"
+)
+
+// Decode converts token IDs back to text
+func (t *Tokenizer) Decode(ids []int32) string {
+ var sb strings.Builder
+
+ for _, id := range ids {
+ if int(id) >= len(t.vocab.Values) {
+ continue
+ }
+
+ token := t.vocab.Values[id]
+
+ switch t.typ {
+ case TokenizerSentencePiece:
+ // SentencePiece style: replace ▁ with space, decode byte tokens
+ token = strings.ReplaceAll(token, "▁", " ")
+ // Handle byte fallback tokens like <0x0D>
+ if len(token) == 6 && token[0] == '<' && token[1] == '0' && token[2] == 'x' && token[5] == '>' {
+ if v, err := strconv.ParseUint(token[3:5], 16, 8); err == nil {
+ sb.WriteByte(byte(v))
+ continue
+ }
+ }
+ sb.WriteString(token)
+ default:
+ // GPT-2 BPE style: decode byte-level encoding
+ for _, r := range token {
+ switch {
+ case r == 0x0100:
+ // Mirror GGML tokenizer behavior for NULL byte.
+ // 0x00 is omitted during decode.
+ continue
+ case r == 0x0143:
+ r = 0x00ad
+ case r > 0x0100 && r <= 0x0120:
+ r = r - 0x0100
+ case r > 0x0120 && r <= 0x0142:
+ r = r - 0x00a2
+ }
+
+ // Write as byte, not UTF-8 encoded rune
+ sb.WriteByte(byte(r))
+ }
+ }
+ }
+
+ return sb.String()
+}
diff --git a/x/tokenizer/tokenizer_encode.go b/x/tokenizer/tokenizer_encode.go
new file mode 100644
index 00000000000..1b71ea6d37d
--- /dev/null
+++ b/x/tokenizer/tokenizer_encode.go
@@ -0,0 +1,289 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "runtime"
+ "sort"
+ "strings"
+ "sync"
+ "unicode"
+ "unicode/utf8"
+)
+
+const (
+ encodeParallelMinInputBytes = 4 * 1024
+ encodeParallelMinChunksPerWorker = 8
+)
+
+type tokenMatch struct {
+ start int
+ end int
+}
+
+type encodeChunk struct {
+ text string
+ isSpecial bool
+}
+
+// isNonNewlineWhitespace returns true if s contains only whitespace characters (no newlines)
+func isNonNewlineWhitespace(s string) bool {
+ if s == "" {
+ return false
+ }
+ for _, r := range s {
+ if r == '\n' || r == '\r' {
+ return false
+ }
+ if !unicode.IsSpace(r) {
+ return false
+ }
+ }
+ return true
+}
+
+// splitBySpecialTokens splits text into parts, keeping special tokens as separate elements
+func (t *Tokenizer) splitBySpecialTokens(s string) []string {
+ if len(t.specialTokens) == 0 {
+ return []string{s}
+ }
+
+ tokens := t.sortedSpecialTokens
+ if len(tokens) == 0 {
+ // Fallback for tokenizers constructed outside the loaders.
+ tokens = make([]string, 0, len(t.specialTokens))
+ for tok := range t.specialTokens {
+ tokens = append(tokens, tok)
+ }
+ sort.Slice(tokens, func(i, j int) bool {
+ return len(tokens[i]) > len(tokens[j])
+ })
+ }
+
+ var result []string
+ remaining := s
+
+ for len(remaining) > 0 {
+ found := false
+ for _, tok := range tokens {
+ if strings.HasPrefix(remaining, tok) {
+ result = append(result, tok)
+ remaining = remaining[len(tok):]
+ found = true
+ break
+ }
+ }
+ if !found {
+ // Find next special token position
+ nextPos := len(remaining)
+ for _, tok := range tokens {
+ if idx := strings.Index(remaining, tok); idx != -1 && idx < nextPos {
+ nextPos = idx
+ }
+ }
+ if nextPos > 0 {
+ result = append(result, remaining[:nextPos])
+ }
+ remaining = remaining[nextPos:]
+ }
+ }
+
+ return result
+}
+
+func adjustWhitespaceBoundary(part string, curr, next *tokenMatch) {
+ m := part[curr.start:curr.end]
+ nextText := part[next.start:next.end]
+
+ if !isNonNewlineWhitespace(m) || len(nextText) == 0 {
+ return
+ }
+
+ firstRune, _ := utf8.DecodeRuneInString(nextText)
+ if !unicode.IsLetter(firstRune) {
+ return
+ }
+
+ lastSpaceStart := curr.end
+ for j := curr.end; j > curr.start; {
+ r, size := utf8.DecodeLastRuneInString(part[curr.start:j])
+ if unicode.IsSpace(r) {
+ lastSpaceStart = j - size
+ break
+ }
+ j -= size
+ }
+ if lastSpaceStart > curr.start {
+ curr.end = lastSpaceStart
+ next.start = lastSpaceStart
+ } else {
+ next.start = curr.start
+ curr.end = curr.start
+ }
+}
+
+func (t *Tokenizer) forEachPartChunk(part string, fn func(encodeChunk)) {
+ if _, ok := t.specialTokens[part]; ok {
+ fn(encodeChunk{text: part, isSpecial: true})
+ return
+ }
+
+ if t.pretokenizer == nil {
+ fn(encodeChunk{text: part, isSpecial: false})
+ return
+ }
+
+ re := t.pretokenizer
+ offset := 0
+ loc := re.FindStringIndex(part[offset:])
+ if loc == nil {
+ return
+ }
+
+ curr := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
+ offset += loc[1]
+
+ for {
+ loc = re.FindStringIndex(part[offset:])
+ if loc == nil {
+ if curr.end > curr.start {
+ fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
+ }
+ return
+ }
+
+ next := tokenMatch{start: offset + loc[0], end: offset + loc[1]}
+ offset += loc[1]
+
+ adjustWhitespaceBoundary(part, &curr, &next)
+
+ if curr.end > curr.start {
+ fn(encodeChunk{text: part[curr.start:curr.end], isSpecial: false})
+ }
+ curr = next
+ }
+}
+
+func (t *Tokenizer) appendEncodedChunk(ids []int32, c encodeChunk) []int32 {
+ if c.isSpecial {
+ if id, ok := t.specialTokens[c.text]; ok {
+ return append(ids, id)
+ }
+ return ids
+ }
+
+ return t.encodeChunkInto(c.text, ids)
+}
+
+// Encode tokenizes text to token IDs.
+// Parallel encoding is used only for very large inputs with enough chunks per worker.
+func (t *Tokenizer) Encode(s string, addBOS bool) []int32 {
+ // First: split by special tokens
+ parts := t.splitBySpecialTokens(s)
+
+ // Fast path: encode sequentially without materializing chunk slices.
+ if len(s) < encodeParallelMinInputBytes {
+ var ids []int32
+ for _, part := range parts {
+ t.forEachPartChunk(part, func(c encodeChunk) {
+ ids = t.appendEncodedChunk(ids, c)
+ })
+ }
+
+ if addBOS && t.vocab.BOS >= 0 {
+ ids = append([]int32{t.vocab.BOS}, ids...)
+ }
+ return ids
+ }
+
+ // For large inputs collect chunks to enable parallel processing.
+ var allChunks []encodeChunk
+ for _, part := range parts {
+ t.forEachPartChunk(part, func(c encodeChunk) {
+ allChunks = append(allChunks, c)
+ })
+ }
+
+ // Encode chunks. Use the parallel path only when the chunk count is
+ // large enough to amortize goroutine/synchronization overhead.
+ useParallel := true
+ numWorkers := runtime.GOMAXPROCS(0)
+ if numWorkers > len(allChunks) {
+ numWorkers = len(allChunks)
+ }
+ if numWorkers < 2 || len(allChunks) < numWorkers*encodeParallelMinChunksPerWorker {
+ useParallel = false
+ }
+
+ var ids []int32
+ if !useParallel {
+ for _, c := range allChunks {
+ ids = t.appendEncodedChunk(ids, c)
+ }
+ } else {
+ chunksPer := (len(allChunks) + numWorkers - 1) / numWorkers
+ results := make([][]int32, numWorkers)
+ var wg sync.WaitGroup
+
+ for i := 0; i < numWorkers; i++ {
+ start := i * chunksPer
+ end := start + chunksPer
+ if end > len(allChunks) {
+ end = len(allChunks)
+ }
+ if start >= end {
+ continue
+ }
+
+ wg.Add(1)
+ go func(i int, chunks []encodeChunk) {
+ defer wg.Done()
+ var r []int32
+ for _, c := range chunks {
+ r = t.appendEncodedChunk(r, c)
+ }
+ results[i] = r
+ }(i, allChunks[start:end])
+ }
+ wg.Wait()
+
+ for _, r := range results {
+ ids = append(ids, r...)
+ }
+ }
+
+ if addBOS && t.vocab.BOS >= 0 {
+ ids = append([]int32{t.vocab.BOS}, ids...)
+ }
+ return ids
+}
+
+// encodeChunkInto appends encoded tokens to ids and returns the extended slice.
+// Uses BPE merge algorithm for both BPE and SentencePiece tokenization.
+func (t *Tokenizer) encodeChunkInto(s string, ids []int32) []int32 {
+ if s == "" {
+ return ids
+ }
+
+ // Apply encoding transformation
+ // SentencePiece: replace space with ▁
+ // BPE: convert bytes using precomputed table (GPT-2 byte-level encoding)
+ var encoded string
+ if t.typ == TokenizerSentencePiece {
+ encoded = strings.ReplaceAll(s, " ", "▁")
+ } else {
+ var sb strings.Builder
+ sb.Grow(len(s) * 2)
+ for i := 0; i < len(s); i++ {
+ sb.WriteRune(byteToRune[s[i]])
+ }
+ encoded = sb.String()
+ }
+
+ // Fast path: check if entire chunk is a single token
+ if id, ok := t.vocab.Reverse[encoded]; ok {
+ return append(ids, id)
+ }
+
+ return t.encodeBPEMerge(encoded, ids)
+}
diff --git a/x/tokenizer/tokenizer_ggml_parity_test.go b/x/tokenizer/tokenizer_ggml_parity_test.go
new file mode 100644
index 00000000000..4cef3d3dd40
--- /dev/null
+++ b/x/tokenizer/tokenizer_ggml_parity_test.go
@@ -0,0 +1,207 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "bufio"
+ "encoding/json"
+ "os"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+)
+
+func llama32GGMLFixturePath(tb testing.TB, file string) string {
+ tb.Helper()
+
+ _, filename, _, ok := runtime.Caller(0)
+ if !ok {
+ tb.Fatal("failed to resolve test file path")
+ }
+
+ return filepath.Join(filepath.Dir(filename), "..", "..", "tokenizer", "testdata", "llama3.2", file)
+}
+
+func loadLlama32FromGGMLFixture(tb testing.TB) *Tokenizer {
+ tb.Helper()
+
+ f, err := os.Open(llama32GGMLFixturePath(tb, "encoder.json"))
+ if err != nil {
+ tb.Fatalf("failed to open encoder.json: %v", err)
+ }
+ defer f.Close()
+
+ vocab := make(map[string]int32)
+ if err := json.NewDecoder(f).Decode(&vocab); err != nil {
+ tb.Fatalf("failed to decode encoder.json: %v", err)
+ }
+
+ type addedToken struct {
+ ID int32 `json:"id"`
+ Content string `json:"content"`
+ Special bool `json:"special"`
+ }
+ var addedTokens []addedToken
+ for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
+ if _, ok := vocab[token]; !ok {
+ id := int32(len(vocab))
+ vocab[token] = id
+ addedTokens = append(addedTokens, addedToken{ID: id, Content: token, Special: true})
+ }
+ }
+
+ mf, err := os.Open(llama32GGMLFixturePath(tb, "vocab.bpe"))
+ if err != nil {
+ tb.Fatalf("failed to open vocab.bpe: %v", err)
+ }
+ defer mf.Close()
+
+ var merges []string
+ scanner := bufio.NewScanner(mf)
+ for scanner.Scan() {
+ line := scanner.Text()
+ if strings.HasPrefix(line, "#") {
+ continue
+ }
+ line = strings.TrimSpace(line)
+ if line != "" {
+ merges = append(merges, line)
+ }
+ }
+ if err := scanner.Err(); err != nil {
+ tb.Fatalf("failed to read vocab.bpe: %v", err)
+ }
+
+ payload := struct {
+ Model struct {
+ Type string `json:"type"`
+ Vocab map[string]int32 `json:"vocab"`
+ Merges []string `json:"merges"`
+ } `json:"model"`
+ PreTokenizer struct {
+ Type string `json:"type"`
+ Pretokenizers []struct {
+ Type string `json:"type"`
+ Pattern struct {
+ Regex string `json:"Regex"`
+ } `json:"pattern"`
+ } `json:"pretokenizers"`
+ } `json:"pre_tokenizer"`
+ AddedTokens []addedToken `json:"added_tokens"`
+ }{}
+
+ payload.Model.Type = "BPE"
+ payload.Model.Vocab = vocab
+ payload.Model.Merges = merges
+ payload.PreTokenizer.Type = "Sequence"
+ payload.PreTokenizer.Pretokenizers = []struct {
+ Type string `json:"type"`
+ Pattern struct {
+ Regex string `json:"Regex"`
+ } `json:"pattern"`
+ }{
+ {
+ Type: "Split",
+ Pattern: struct {
+ Regex string `json:"Regex"`
+ }{
+ Regex: `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
+ },
+ },
+ }
+ payload.AddedTokens = addedTokens
+
+ data, err := json.Marshal(payload)
+ if err != nil {
+ tb.Fatalf("failed to marshal synthetic tokenizer.json: %v", err)
+ }
+
+ tok, err := LoadFromBytes(data)
+ if err != nil {
+ tb.Fatalf("failed to load tokenizer from fixture data: %v", err)
+ }
+ return tok
+}
+
+func TestGGMLLlamaKnownEncodings(t *testing.T) {
+ tok := loadLlama32FromGGMLFixture(t)
+
+ cases := map[string][]int32{
+ "hello world": {15339, 1917},
+ "hello <|end_of_text|>": {15339, 220, 128001},
+ "<|begin_of_text|>A B!": {128000, 32, 426, 0},
+ "<|begin_of_text|>A<|end_of_text|>B!": {128000, 32, 128001, 33, 0},
+ "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
+ "<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
+ }
+
+ for input, want := range cases {
+ got := tok.Encode(input, false)
+ if !equalIDs(got, want) {
+ t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
+ }
+ }
+}
+
+func TestGGMLLlamaRepeatedZeros(t *testing.T) {
+ tok := loadLlama32FromGGMLFixture(t)
+
+ cases := map[int][]int32{
+ 1: {15},
+ 2: {410},
+ 3: {931},
+ 4: {931, 15},
+ 5: {931, 410},
+ 6: {931, 931},
+ 7: {931, 931, 15},
+ 8: {931, 931, 410},
+ 9: {931, 931, 931},
+ 10: {931, 931, 931, 15},
+ 11: {931, 931, 931, 410},
+ 12: {931, 931, 931, 931},
+ 13: {931, 931, 931, 931, 15},
+ 14: {931, 931, 931, 931, 410},
+ 15: {931, 931, 931, 931, 931},
+ 16: {931, 931, 931, 931, 931, 15},
+ 17: {931, 931, 931, 931, 931, 410},
+ }
+
+ for n, want := range cases {
+ input := strings.Repeat("0", n)
+ got := tok.Encode(input, false)
+ if !equalIDs(got, want) {
+ t.Fatalf("encode mismatch for %q:\n got: %v\n want: %v", input, got, want)
+ }
+ }
+}
+
+func TestGGMLLlamaRoundtripAndByteBehavior(t *testing.T) {
+ tok := loadLlama32FromGGMLFixture(t)
+
+ cases := []string{
+ "hello",
+ "hello ",
+ "hello ",
+ " hello",
+ " hello ",
+ " hello ",
+ "hello world",
+ "请考试我的软件!12345",
+ }
+
+ for _, input := range cases {
+ ids := tok.Encode(input, false)
+ got := tok.Decode(ids)
+ if got != input {
+ t.Fatalf("roundtrip mismatch for %q: got %q", input, got)
+ }
+ }
+
+ // Match GGML tokenizer behavior: 0x00 is omitted when decoding.
+ ids := tok.Encode(string(rune(0x00)), false)
+ got := tok.Decode(ids)
+ if got != "" {
+ t.Fatalf("expected empty decode for 0x00, got %q (ids=%v)", got, ids)
+ }
+}
diff --git a/x/tokenizer/tokenizer_load.go b/x/tokenizer/tokenizer_load.go
new file mode 100644
index 00000000000..d2a253e179c
--- /dev/null
+++ b/x/tokenizer/tokenizer_load.go
@@ -0,0 +1,458 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "encoding/json"
+ "fmt"
+ "regexp"
+ "sort"
+ "strings"
+)
+
+// TokenizerConfig holds optional configuration data that can be passed to LoadFromBytesWithConfig.
+type TokenizerConfig struct {
+ TokenizerConfigJSON []byte // tokenizer_config.json content
+ GenerationConfigJSON []byte // generation_config.json content
+ SpecialTokensMapJSON []byte // special_tokens_map.json content
+ ConfigJSON []byte // config.json content
+}
+
+// LoadFromBytes loads a tokenizer from tokenizer.json bytes.
+// This is useful when loading from blob storage where the file content is already in memory.
+// Note: This won't load special token config from companion files. Use LoadFromBytesWithConfig
+// to provide tokenizer_config.json data for proper PAD/EOS token loading.
+func LoadFromBytes(data []byte) (*Tokenizer, error) {
+ return loadFromTokenizerJSON(data)
+}
+
+// LoadFromBytesWithConfig loads a tokenizer from tokenizer.json bytes with additional config files.
+// This is useful when loading from blob storage where companion config files are also blobs.
+func LoadFromBytesWithConfig(data []byte, config *TokenizerConfig) (*Tokenizer, error) {
+ t, err := loadFromTokenizerJSON(data)
+ if err != nil {
+ return nil, err
+ }
+
+ if config == nil {
+ return t, nil
+ }
+
+ // Apply special token configs from provided data
+ loadSpecialTokenConfigFromBytes(t, config)
+
+ return t, nil
+}
+
+// loadFromTokenizerJSON parses tokenizer.json content from bytes.
+func loadFromTokenizerJSON(data []byte) (*Tokenizer, error) {
+
+ var raw struct {
+ Model struct {
+ Type string `json:"type"` // "BPE"
+ Vocab map[string]int32 `json:"vocab"`
+ Merges json.RawMessage `json:"merges"` // Can be []string or [][]string (BPE only)
+ } `json:"model"`
+ PreTokenizer json.RawMessage `json:"pre_tokenizer"`
+ Decoder json.RawMessage `json:"decoder"`
+ AddedTokens []struct {
+ ID int32 `json:"id"`
+ Content string `json:"content"`
+ Special bool `json:"special"`
+ } `json:"added_tokens"`
+ }
+
+ if err := json.Unmarshal(data, &raw); err != nil {
+ return nil, fmt.Errorf("failed to parse tokenizer: %w", err)
+ }
+
+ // Covers SentencePiece and BPE models
+ if raw.Model.Type != "BPE" {
+ return nil, fmt.Errorf("unsupported tokenizer type: %s", raw.Model.Type)
+ }
+
+ // Parse merges - can be []string (Llama) or [][]string (GPT-OSS).
+ var mergesStrings []string
+ if raw.Model.Merges != nil {
+ var mergesArrays [][]string
+ if err := json.Unmarshal(raw.Model.Merges, &mergesStrings); err != nil {
+ // Try array of arrays format
+ if err := json.Unmarshal(raw.Model.Merges, &mergesArrays); err != nil {
+ return nil, fmt.Errorf("failed to parse merges: %w", err)
+ }
+ // Convert [][]string to []string
+ mergesStrings = make([]string, len(mergesArrays))
+ for i, pair := range mergesArrays {
+ if len(pair) != 2 {
+ return nil, fmt.Errorf("failed to parse merges: expected merge pair of length 2, got %d", len(pair))
+ }
+ mergesStrings[i] = pair[0] + " " + pair[1]
+ }
+ }
+ }
+
+ // Build tokenizer
+ t := &Tokenizer{
+ vocab: &Vocabulary{
+ Values: make([]string, len(raw.Model.Vocab)),
+ Reverse: raw.Model.Vocab,
+ Merges: make(map[string]int, len(mergesStrings)),
+ BOS: -1,
+ PAD: -1,
+ },
+ specialTokens: make(map[string]int32),
+ }
+
+ // Build values array
+ for token, id := range raw.Model.Vocab {
+ if int(id) >= len(t.vocab.Values) {
+ newValues := make([]string, id+1)
+ copy(newValues, t.vocab.Values)
+ t.vocab.Values = newValues
+ }
+ t.vocab.Values[id] = token
+ }
+
+ // Build merges map
+ for i, merge := range mergesStrings {
+ t.vocab.Merges[merge] = i
+ }
+
+ // Add all added_tokens to vocabulary and special tokens map.
+ // HuggingFace treats ALL added_tokens as special for tokenization purposes -
+ // they bypass BPE and get their own token ID. The "special" flag just indicates
+ // if it's a "truly special" token like BOS/EOS/PAD, but for tokenization we need
+ // to treat all added_tokens as special to match HuggingFace behavior.
+ for _, tok := range raw.AddedTokens {
+ if int(tok.ID) >= len(t.vocab.Values) {
+ newValues := make([]string, tok.ID+1)
+ copy(newValues, t.vocab.Values)
+ t.vocab.Values = newValues
+ }
+ t.vocab.Values[tok.ID] = tok.Content
+ t.specialTokens[tok.Content] = tok.ID // Add ALL added_tokens to special tokens
+ }
+
+ // Precompute byte token IDs for <0xNN> fallback
+ initByteTokens(t)
+
+ // Determine tokenizer type
+ switch {
+ case detectSentencePiece(raw.Decoder):
+ t.typ = TokenizerSentencePiece
+ default:
+ t.typ = TokenizerBPE
+ }
+
+ // Parse and compile pretokenizer pattern (BPE only - SentencePiece doesn't use pretokenizer)
+ if t.typ == TokenizerBPE {
+ pattern := extractPretokenizer(raw.PreTokenizer)
+ if pattern == "" {
+ pattern = `'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+`
+ }
+ re, err := regexp.Compile(rewritePatternForRE2(pattern))
+ if err != nil {
+ return nil, fmt.Errorf("failed to compile pretokenizer regex %q: %w", pattern, err)
+ }
+ t.pretokenizer = re
+ }
+
+ cacheSortedSpecialTokens(t)
+
+ return t, nil
+}
+
+func cacheSortedSpecialTokens(t *Tokenizer) {
+ if len(t.specialTokens) == 0 {
+ t.sortedSpecialTokens = nil
+ return
+ }
+
+ tokens := make([]string, 0, len(t.specialTokens))
+ for tok := range t.specialTokens {
+ tokens = append(tokens, tok)
+ }
+ sort.Slice(tokens, func(i, j int) bool {
+ return len(tokens[i]) > len(tokens[j])
+ })
+ t.sortedSpecialTokens = tokens
+}
+
+type specialTokenConfigData struct {
+ tokenizerConfigJSON []byte
+ generationConfigJSON []byte
+ specialTokensMapJSON []byte
+ configJSON []byte
+}
+
+func applySpecialTokenConfig(t *Tokenizer, config specialTokenConfigData) {
+ parseTokenIDs := func(v interface{}) []int32 {
+ switch val := v.(type) {
+ case float64:
+ return []int32{int32(val)}
+ case []interface{}:
+ ids := make([]int32, 0, len(val))
+ for _, id := range val {
+ if f, ok := id.(float64); ok {
+ ids = append(ids, int32(f))
+ }
+ }
+ return ids
+ }
+ return nil
+ }
+
+ // Priority 1: generation_config.json
+ if len(config.generationConfigJSON) > 0 {
+ var genConfig struct {
+ EOSTokenID interface{} `json:"eos_token_id"`
+ BOSTokenID interface{} `json:"bos_token_id"`
+ }
+ if err := json.Unmarshal(config.generationConfigJSON, &genConfig); err == nil {
+ if ids := parseTokenIDs(genConfig.EOSTokenID); len(ids) > 0 {
+ t.vocab.EOS = ids
+ }
+ if ids := parseTokenIDs(genConfig.BOSTokenID); len(ids) > 0 {
+ t.vocab.BOS = ids[0]
+ }
+ }
+ }
+
+ // Priority 2: config.json
+ if len(config.configJSON) > 0 && (len(t.vocab.EOS) == 0 || t.vocab.BOS < 0) {
+ var modelConfig struct {
+ EOSTokenID interface{} `json:"eos_token_id"`
+ BOSTokenID interface{} `json:"bos_token_id"`
+ }
+ if err := json.Unmarshal(config.configJSON, &modelConfig); err == nil {
+ if len(t.vocab.EOS) == 0 {
+ if ids := parseTokenIDs(modelConfig.EOSTokenID); len(ids) > 0 {
+ t.vocab.EOS = ids
+ }
+ }
+ if t.vocab.BOS < 0 {
+ if ids := parseTokenIDs(modelConfig.BOSTokenID); len(ids) > 0 {
+ t.vocab.BOS = ids[0]
+ }
+ }
+ }
+ }
+
+ // Priority 3: tokenizer_config.json
+ if len(config.tokenizerConfigJSON) > 0 {
+ var tokConfig struct {
+ BOSToken interface{} `json:"bos_token"`
+ EOSToken interface{} `json:"eos_token"`
+ PADToken interface{} `json:"pad_token"`
+ AddBOSToken *bool `json:"add_bos_token"`
+ AddEOSToken *bool `json:"add_eos_token"`
+ }
+ if err := json.Unmarshal(config.tokenizerConfigJSON, &tokConfig); err == nil {
+ if t.vocab.BOS < 0 {
+ if bosStr := extractTokenString(tokConfig.BOSToken); bosStr != "" {
+ if id, ok := t.specialTokens[bosStr]; ok {
+ t.vocab.BOS = id
+ }
+ }
+ }
+ if len(t.vocab.EOS) == 0 {
+ if eosStr := extractTokenString(tokConfig.EOSToken); eosStr != "" {
+ if id, ok := t.specialTokens[eosStr]; ok {
+ t.vocab.EOS = []int32{id}
+ }
+ }
+ }
+ if t.vocab.PAD < 0 {
+ if padStr := extractTokenString(tokConfig.PADToken); padStr != "" {
+ if id, ok := t.specialTokens[padStr]; ok {
+ t.vocab.PAD = id
+ }
+ }
+ }
+ if tokConfig.AddBOSToken != nil {
+ t.vocab.AddBOS = *tokConfig.AddBOSToken
+ }
+ if tokConfig.AddEOSToken != nil {
+ t.vocab.AddEOS = *tokConfig.AddEOSToken
+ }
+ }
+ }
+
+ // Priority 4: special_tokens_map.json
+ if len(config.specialTokensMapJSON) > 0 {
+ var tokensMap map[string]interface{}
+ if err := json.Unmarshal(config.specialTokensMapJSON, &tokensMap); err == nil {
+ if t.vocab.BOS < 0 {
+ if bosStr := extractTokenString(tokensMap["bos_token"]); bosStr != "" {
+ if id, ok := t.specialTokens[bosStr]; ok {
+ t.vocab.BOS = id
+ }
+ }
+ }
+ if len(t.vocab.EOS) == 0 {
+ if eosStr := extractTokenString(tokensMap["eos_token"]); eosStr != "" {
+ if id, ok := t.specialTokens[eosStr]; ok {
+ t.vocab.EOS = []int32{id}
+ }
+ }
+ }
+ if t.vocab.PAD < 0 {
+ if padStr := extractTokenString(tokensMap["pad_token"]); padStr != "" {
+ if id, ok := t.specialTokens[padStr]; ok {
+ t.vocab.PAD = id
+ }
+ }
+ }
+ }
+ }
+}
+
+// extractTokenString extracts the token string from various formats used in HuggingFace configs.
+// Tokens can be represented as:
+// - string: "token"
+// - object: {"content": "token", ...}
+func extractTokenString(v interface{}) string {
+ if v == nil {
+ return ""
+ }
+ // Direct string
+ if s, ok := v.(string); ok {
+ return s
+ }
+ // Object with content field
+ if m, ok := v.(map[string]interface{}); ok {
+ if content, ok := m["content"].(string); ok {
+ return content
+ }
+ }
+ return ""
+}
+
+// rewritePatternForRE2 rewrites HuggingFace pretokenizer regex patterns to be
+// compatible with Go's regexp package (RE2). HuggingFace patterns use PCRE features:
+// - (?!\S) negative lookahead - RE2 doesn't support this
+// - (?i:...) inline case-insensitive groups - RE2 doesn't support this
+//
+// We replace \s+(?!\S)|\s+ with \s+ and fix whitespace boundaries in encodeWithRegex().
+// The lookahead version splits "a b" into ["a", " ", " b"] (space prepended to word).
+// Simple \s+ would give ["a", " ", "b"]. We post-process to match Python's behavior.
+func rewritePatternForRE2(pattern string) string {
+ // Replace lookahead pattern with simple \s+ - we fix boundaries in encodeWithRegex()
+ pattern = strings.ReplaceAll(pattern, `\s+(?!\S)|\s+`, `\s+`)
+
+ // Handle the pattern when it appears with a ? suffix (optional contractions in GPT-4o style)
+ // IMPORTANT: Must be done before the non-optional version to avoid partial replacement
+ pattern = strings.ReplaceAll(pattern,
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)?`,
+ `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?`)
+
+ // Expand case-insensitive contraction pattern to explicit alternations
+ // (?i:'s|'t|'re|'ve|'m|'ll|'d) -> '[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD]
+ pattern = strings.ReplaceAll(pattern,
+ `(?i:'s|'t|'re|'ve|'m|'ll|'d)`,
+ `(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])`)
+
+ return pattern
+}
+
+// loadSpecialTokenConfigFromBytes loads special token configuration from byte slices.
+func loadSpecialTokenConfigFromBytes(t *Tokenizer, config *TokenizerConfig) {
+ applySpecialTokenConfig(t, specialTokenConfigData{
+ tokenizerConfigJSON: config.TokenizerConfigJSON,
+ generationConfigJSON: config.GenerationConfigJSON,
+ specialTokensMapJSON: config.SpecialTokensMapJSON,
+ configJSON: config.ConfigJSON,
+ })
+}
+
+// detectSentencePiece checks if the decoder uses SentencePiece-style (▁ for spaces)
+// vs GPT-2 byte-level encoding
+func detectSentencePiece(data json.RawMessage) bool {
+ if data == nil {
+ return false
+ }
+
+ // Check for Sequence decoder with Replace step (SentencePiece style)
+ var seq struct {
+ Type string `json:"type"`
+ Decoders []struct {
+ Type string `json:"type"`
+ Pattern struct {
+ String string `json:"String"`
+ } `json:"pattern"`
+ } `json:"decoders"`
+ }
+ if err := json.Unmarshal(data, &seq); err == nil {
+ if seq.Type == "Sequence" {
+ for _, dec := range seq.Decoders {
+ // Look for Replace decoder that converts ▁ to space
+ if dec.Type == "Replace" && dec.Pattern.String == "▁" {
+ return true
+ }
+ }
+ }
+ }
+
+ // Check for direct ByteLevel decoder (GPT-2 style)
+ var simple struct {
+ Type string `json:"type"`
+ }
+ if err := json.Unmarshal(data, &simple); err == nil {
+ if simple.Type == "ByteLevel" {
+ return false
+ }
+ }
+
+ return false
+}
+
+// initByteTokens precomputes byte token IDs for <0xNN> fallback encoding
+func initByteTokens(t *Tokenizer) {
+ for i := range t.vocab.byteTokens {
+ t.vocab.byteTokens[i] = -1
+ }
+ for b := 0; b < 256; b++ {
+ token := fmt.Sprintf("<0x%02X>", b)
+ if id, ok := t.vocab.Reverse[token]; ok {
+ t.vocab.byteTokens[b] = id
+ }
+ }
+}
+
+// extractPretokenizer extracts the regex pattern from the pre_tokenizer config
+func extractPretokenizer(data json.RawMessage) string {
+ if data == nil {
+ return ""
+ }
+
+ // Try to parse as a single Split pretokenizer
+ var single struct {
+ Type string `json:"type"`
+ Pattern struct {
+ Regex string `json:"Regex"`
+ } `json:"pattern"`
+ }
+ if err := json.Unmarshal(data, &single); err == nil && single.Pattern.Regex != "" {
+ return single.Pattern.Regex
+ }
+
+ // Try to parse as Sequence of pretokenizers - use first Split pattern
+ var seq struct {
+ Type string `json:"type"`
+ Pretokenizers []struct {
+ Type string `json:"type"`
+ Pattern struct {
+ Regex string `json:"Regex"`
+ } `json:"pattern"`
+ } `json:"pretokenizers"`
+ }
+ if err := json.Unmarshal(data, &seq); err == nil && seq.Type == "Sequence" {
+ for _, pt := range seq.Pretokenizers {
+ if pt.Type == "Split" && pt.Pattern.Regex != "" {
+ return pt.Pattern.Regex
+ }
+ }
+ }
+
+ return ""
+}
diff --git a/x/tokenizer/tokenizer_load_test.go b/x/tokenizer/tokenizer_load_test.go
new file mode 100644
index 00000000000..136399c2ea4
--- /dev/null
+++ b/x/tokenizer/tokenizer_load_test.go
@@ -0,0 +1,26 @@
+//go:build mlx
+
+package tokenizer
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestLoadFromBytesRejectsWordPiece(t *testing.T) {
+ data := []byte(`{
+ "model": {
+ "type": "WordPiece",
+ "vocab": {"[UNK]": 0, "hello": 1}
+ },
+ "added_tokens": []
+ }`)
+
+ _, err := LoadFromBytes(data)
+ if err == nil {
+ t.Fatal("expected WordPiece load to fail")
+ }
+ if !strings.Contains(err.Error(), "unsupported tokenizer type: WordPiece") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
diff --git a/x/tools/webfetch.go b/x/tools/webfetch.go
index 82a00623392..793e184c2be 100644
--- a/x/tools/webfetch.go
+++ b/x/tools/webfetch.go
@@ -15,6 +15,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
)
const (
@@ -71,6 +72,10 @@ type webFetchResponse struct {
// Execute fetches content from a web page.
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
func (w *WebFetchTool) Execute(args map[string]any) (string, error) {
+ if internalcloud.Disabled() {
+ return "", errors.New(internalcloud.DisabledError("web fetch is unavailable"))
+ }
+
urlStr, ok := args["url"].(string)
if !ok || urlStr == "" {
return "", fmt.Errorf("url parameter is required")
diff --git a/x/tools/websearch.go b/x/tools/websearch.go
index 16b0dde2c83..1da124af84b 100644
--- a/x/tools/websearch.go
+++ b/x/tools/websearch.go
@@ -15,6 +15,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
+ internalcloud "github.com/ollama/ollama/internal/cloud"
)
const (
@@ -77,6 +78,10 @@ type webSearchResult struct {
// Execute performs the web search.
// Uses Ollama key signing for authentication - this makes requests via ollama.com API.
func (w *WebSearchTool) Execute(args map[string]any) (string, error) {
+ if internalcloud.Disabled() {
+ return "", errors.New(internalcloud.DisabledError("web search is unavailable"))
+ }
+
query, ok := args["query"].(string)
if !ok || query == "" {
return "", fmt.Errorf("query parameter is required")