
model precision #69

@Fans0014

Description


I'm trying to use the model from this link (https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr1024/vis_model.pth) with precision torch.float16. However, I get NaN results. By debugging the model's inference, I found that some activation values in the model exceed the maximum value representable in torch.float16.

Using torch.float32 instead, I inserted a print statement
print(">>>Debug unet-620: ", x.mean(), temb.mean())
at https://github.com/apple/ml-mdm/blob/main/ml-mdm-matryoshka/ml_mdm/models/unet.py#L543
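A lighter-weight way to locate the offending layers, instead of editing the repo's source, is to register forward hooks that flag any module whose float32 output would overflow float16. This is only a debugging sketch (the hook helper and the toy model below are my own, not part of ml-mdm); the same `attach_overflow_hooks` call should work on the loaded UNet:

```python
import torch
import torch.nn as nn

FP16_MAX = torch.finfo(torch.float16).max  # 65504.0

def attach_overflow_hooks(model: nn.Module):
    """Register forward hooks that report every named submodule whose
    float32 output magnitude exceeds the float16 maximum."""
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor):
                peak = output.abs().max().item()
                if peak > FP16_MAX:
                    print(f"float16 overflow in {name}: |max| = {peak:.1f}")
        return hook

    for name, module in model.named_modules():
        if name:  # skip the unnamed root module
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles  # call h.remove() on each handle to detach

# Toy demonstration: a linear layer whose outputs clearly overflow float16.
toy = nn.Sequential(nn.Linear(4, 4))
with torch.no_grad():
    toy[0].weight.fill_(1e6)
    toy[0].bias.zero_()
handles = attach_overflow_hooks(toy)
_ = toy(torch.full((1, 4), 100.0))  # prints an overflow report for "0"
```

Running the hooked UNet once in float32 would then list exactly which submodules produce the out-of-range activations seen in the logs below.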

And I got the below logs

Debug unet-620: tensor(-0.0274, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(1145381., device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(142.4138, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-22.8761, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-3.5595, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-0.2938, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-0.3821, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.0356, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.1432, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.1935, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.7042, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.0997, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-0.4697, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-1.0218, device='cuda:0') tensor(-0.2857, device='cuda:0')
Debug unet-620: tensor(-6.7193, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-9.7318, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-3.2291, device='cuda:0') tensor(-2.7265, device='cuda:0')
Debug unet-620: tensor(-137.9206, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-129216.2422, device='cuda:0') tensor(1961.6932, device='cuda:0')
Debug unet-620: tensor(-672989.1250, device='cuda:0') tensor(1961.6932, device='cuda:0')
====================================

It's clear that some values exceed 65504 (the maximum value for torch.float16). Is there any way I can finetune this model to reduce the intermediate activation values so that it can run with torch.float16? Alternatively, could you please provide a new model whose activation ranges support torch.float16?
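Until a retrained checkpoint exists, one workaround worth trying (my suggestion, not something the repo documents) is torch.bfloat16 rather than torch.float16: bfloat16 keeps float32's exponent range (~3.4e38), so activation means like the 672989.1250 above stay finite, at the cost of mantissa precision. A minimal illustration with the values from the logs:

```python
import torch

# Two of the float32 activation means reported in the debug logs above.
big = torch.tensor([672989.125, 1145381.0])

as_fp16 = big.to(torch.float16)   # both exceed 65504, so they overflow
as_bf16 = big.to(torch.bfloat16)  # same exponent range as float32: finite

print(as_fp16)  # tensor([inf, inf], dtype=torch.float16)
print(as_bf16)  # finite, but rounded to bfloat16's 8-bit mantissa
```

If the hardware supports it, running inference under `torch.autocast(device_type="cuda", dtype=torch.bfloat16)` is a drop-in alternative to a wholesale `.half()` cast; whether this checkpoint tolerates bfloat16 rounding without quality loss is something I have not verified.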
