Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ml-mdm-matryoshka/tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from ml_mdm.language_models.tokenizer import Tokenizer # Tokenizer class from tokenizer.py

def test_tokenizer_bert():
"""Ensure BERT tokenizer loads from the test vocab file without errors."""
f = Path(__file__).parent.parent/"data/bert.vocab" # To solve from relative to absolute import
assert Tokenizer(f, mode="bert")

def test_tokenizer_t5():
"""Ensure T5 tokenizer loads from the test vocab file without errors."""
f = Path(__file__).parent.parent/"data/t5.vocab"
assert Tokenizer(f, mode="tf")

def test_tokenizer():
"""Ensure default tokenizer loads from the Imagenet vocab file."""
f = Path(__file__).parent.parent/"data/imagenet.vocab"
assert Tokenizer(f)

Expand Down
3 changes: 2 additions & 1 deletion ml-mdm-matryoshka/tests/test_unet_mlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def test_pytorch_mlp():


def test_self_attention_1d():
"""Check parity between PyTorch and MLX SelfAttention1D outputs and shapes."""
# Define parameters
channels = 8
num_heads = 2
Expand Down Expand Up @@ -156,4 +157,4 @@ def test_pytorch_mlx_temporal_attention_block():
atol=1e-1, # Significantly increased tolerance
), "Outputs of PyTorch and MLX TemporalAttentionBlock should match"

print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")
print("Test passed for both PyTorch and MLX TemporalAttentionBlock!")