Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions ml_mdm/models/unet_mlx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.

import einops

import mlx.core as mx
import mlx.nn as nn

from ml_mdm.models.unet import ResNetConfig


def zero_module_mlx(module):
    """
    Zero out every parameter of an MLX module in place and return it.

    Mirrors the common ``zero_module`` initialization trick used by the
    PyTorch UNet: the zeroed layer makes its enclosing residual branch an
    identity mapping at initialization.

    Args:
        module: Any ``mlx.nn.Module`` instance.

    Returns:
        The same module, with each parameter replaced by a zeros array of
        identical shape and dtype.
    """
    # Function-scope import keeps the module's top-level imports unchanged.
    from mlx.utils import tree_map

    # ``parameters()`` returns a (possibly nested) tree of arrays.  A flat
    # ``.items()`` comprehension would break on containers such as
    # ``nn.Sequential`` whose parameter dict is nested; ``tree_map`` walks
    # the whole tree regardless of structure.
    zeroed_params = tree_map(
        lambda p: mx.zeros(p.shape, dtype=p.dtype), module.parameters()
    )
    module.update(zeroed_params)
    return module


class MLP_MLX(nn.Module):
    """MLX port of the residual MLP block: LayerNorm -> Linear -> GELU -> Linear.

    The final linear layer is zero-initialized (see ``zero_module_mlx``), so
    at initialization ``main(x) == 0`` and the block acts as the identity.
    """

    def __init__(self, channels: int, multiplier: int = 4):
        super().__init__()
        # MLX layers, mirroring the PyTorch MLP in ml_mdm.models.unet.
        self.main = nn.Sequential(
            nn.LayerNorm(channels),
            nn.Linear(channels, multiplier * channels),
            nn.GELU(),
            zero_module_mlx(nn.Linear(multiplier * channels, channels)),
        )

    def forward(self, x):
        # Residual connection around the MLP branch.
        return x + self.main(x)

    def __call__(self, x):
        # Unlike PyTorch, MLX's nn.Module does not route __call__ to
        # ``forward`` automatically; without this, ``model(x)`` raises.
        # Matches the convention already used by ResNet_MLX in this file.
        return self.forward(x)


class ResNet_MLX(nn.Module):
    """MLX port of the time-conditioned ResNet block from ``unet.py``.

    Pipeline: GroupNorm -> SiLU -> Conv -> FiLM-style scale/shift from the
    time embedding -> GroupNorm -> SiLU -> Dropout -> zero-init Conv, with a
    residual connection (1x1 conv on the skip path when channel counts
    differ).

    NOTE(review): MLX ``nn.Conv2d`` expects channels-last (NHWC) input,
    while the PyTorch original is channels-first (NCHW) — confirm the
    caller's tensor layout before trusting numerical parity with PyTorch.
    """

    def __init__(self, time_emb_channels, config: ResNetConfig):
        # TODO(ndjaitly): What about scales of weights.
        super(ResNet_MLX, self).__init__()
        self.config = config
        self.num_groups = config.num_groups_norm
        self.num_channels = config.num_channels
        self.norm1 = nn.GroupNorm(
            config.num_groups_norm, config.num_channels, pytorch_compatible=True
        )
        self.conv1 = nn.Conv2d(
            config.num_channels,
            config.output_channels,
            kernel_size=3,
            padding=1,
            bias=True,
        )
        # Projects the time embedding to a per-channel scale (ta) and
        # shift (tb), hence the factor of 2.
        self.time_layer = nn.Linear(time_emb_channels, config.output_channels * 2)
        self.norm2 = nn.GroupNorm(
            config.num_groups_norm, config.output_channels, pytorch_compatible=True
        )
        self.dropout = nn.Dropout(config.dropout)
        # Zero-init so the residual branch starts as an identity mapping.
        self.conv2 = zero_module_mlx(
            nn.Conv2d(
                config.output_channels,
                config.output_channels,
                kernel_size=3,
                padding=1,
                bias=True,
            )
        )
        if self.config.output_channels != self.config.num_channels:
            # 1x1 conv to match channel counts on the skip connection.
            self.conv3 = nn.Conv2d(
                config.num_channels, config.output_channels, kernel_size=1, bias=True
            )

    def forward(self, x, temb):
        h = nn.silu(self.norm1(x))
        h = self.conv1(h)

        # Project the time embedding and make it broadcastable over the two
        # spatial axes.  The original used torch-only array methods
        # (.unsqueeze / .chunk(dim=...)) that do not exist on mx.array;
        # mx.expand_dims / mx.split are the MLX equivalents.
        t = self.time_layer(nn.silu(temb))
        t = mx.expand_dims(mx.expand_dims(t, -1), -1)
        ta, tb = mx.split(t, 2, axis=1)

        if h.shape[0] > ta.shape[0]:  # HACK. repeat to match the shape.
            n = h.shape[0] // ta.shape[0]
            # Repeat each batch element n times contiguously — equivalent to
            # einops.repeat "b c h w -> (b n) c h w", without the einops dep.
            ta = mx.repeat(ta, n, axis=0)
            tb = mx.repeat(tb, n, axis=0)

        # FiLM-style modulation: scale by (1 + ta), shift by tb.
        h = nn.silu(self.norm2(h) * (1 + ta) + tb)
        h = self.dropout(h)
        h = self.conv2(h)
        if self.config.output_channels != self.config.num_channels:
            x = self.conv3(x)
        return h + x

    def __call__(self, x, temb):
        return self.forward(x, temb)
107 changes: 107 additions & 0 deletions tests/test_mlx_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.

import mlx.core as mx
import numpy as np
import torch

from ml_mdm.models.unet import MLP, ResNet, ResNetConfig
from ml_mdm.models.unet_mlx import MLP_MLX, ResNet_MLX


def test_pytorch_mlp():
    """
    Check that the PyTorch MLP and its MLX port agree.

    Both MLPs zero-initialize their final linear layer, so at init each
    block is an identity function and the two implementations must produce
    numerically identical outputs for the same input.
    """
    channels = 8  # Number of channels
    multiplier = 4  # Multiplier for hidden dimensions

    # Build one instance of each implementation.
    torch_model = MLP(channels=channels, multiplier=multiplier)
    mlx_model = MLP_MLX(channels=channels, multiplier=multiplier)
    torch_model.eval()
    mlx_model.eval()

    # Shared random input: batch of 2, `channels` features.
    torch_input = torch.randn(2, channels)

    ## PyTorch side
    torch_output = torch_model(torch_input)
    assert torch_output.shape == torch_input.shape, "Output shape mismatch"
    assert torch.allclose(
        torch_output, torch_input, atol=1e-5
    ), "Output should be close to input as the final layer is zero-initialized"

    ## MLX side: feed the exact same values through the port.
    mlx_output = mlx_model.forward(mx.array(torch_input.numpy()))
    assert isinstance(mlx_output, mx.array)
    assert mlx_output.shape == torch_input.shape, "MLX MLP: Output shape mismatch"

    # Numerical equivalence via numpy.
    assert np.allclose(
        torch_output.detach().numpy(), np.array(mlx_output), atol=1e-5
    ), "Outputs of PyTorch MLP and MLX MLP should match"

    print("Test passed for both PyTorch and MLX MLP!")


def test_pytorch_ResNet():
    """
    Check that the PyTorch ResNet block and its MLX port agree.

    Dropout is set to 0 so both forward passes are deterministic and can be
    compared element-wise.
    """
    # Define parameters
    batch_size = 2
    time_emb_channels = 32
    height = 16
    width = 16

    # Create config
    config = ResNetConfig(
        num_channels=64,
        output_channels=128,
        num_groups_norm=32,
        dropout=0.0,  # Set to 0 for deterministic comparison
        use_attention_ffn=False,
    )

    # Create model instances
    pytorch_resnet = ResNet(time_emb_channels=time_emb_channels, config=config)
    mlx_resnet = ResNet_MLX(time_emb_channels=time_emb_channels, config=config)

    # Set both models to evaluation mode
    pytorch_resnet.eval()
    mlx_resnet.eval()

    # Dummy input: (batch, channels, height, width) plus a time embedding.
    x_torch = torch.randn(batch_size, config.num_channels, height, width)
    temb_torch = torch.randn(batch_size, time_emb_channels)

    # The PyTorch block returns an (output, activations) pair — presumably;
    # matches the unpacking in the original test, verify against unet.ResNet.
    output_torch, _ = pytorch_resnet(x_torch, temb_torch)

    # Convert inputs to MLX tensors
    x_mlx = mx.array(x_torch.numpy())
    temb_mlx = mx.array(temb_torch.numpy())

    # BUG FIX: ResNet_MLX.forward returns a single array (`h + x`), not an
    # (output, activations) pair.  The old two-name unpacking silently split
    # the batch (size 2) along axis 0 instead of erroring, so the comparison
    # below was against a per-sample slice rather than the full output.
    output_mlx = mlx_resnet(x_mlx, temb_mlx)

    # Verify outputs match
    assert np.allclose(
        output_torch.detach().numpy(), np.array(output_mlx), atol=1e-5
    ), "PyTorch and MLX ResNet outputs should match"

    print("Test passed for ResNet implementations!")