
Releases: samhaswon/skin_segmentation

v0.0.2

30 Jan 21:00
91022cf


Changes

  • Roughly doubled dataset size to 1,216 samples (1,134 training)

  • Added 2 new data augmentations

    • These only really paid off for U2Net and BiRefNet; they had to be omitted for the smaller models.
  • Added more dataset information

  • Added dlmv model (DeepLabV3 + MobileNetV3 backbone)

    • Basically just deeplabv3_mobilenet_v3_large from torchvision
  • Added BiRefNet_lite

    • Often just referred to as "BiRefNet" or "birefnet" in this repo
  • Added StraightU2Net

    • Referred to as "sunet" in some places
  • Added U2NetP chunk refiner model

  • Added quantitative testing results

    • mIoU, mIoU@0.5, MAE, and HCE.

    • Inference time (CPU) with PyTorch and ONNX Runtime.

  • Added various other traditional methods

    • All of these perform worse than most of the AI methods; the exception is Google's MediaPipe, which performs poorly enough that several traditional methods beat it.
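Since the results below report mIoU and MAE, here is a minimal NumPy sketch of how those two metrics are typically computed for binary masks (the helper names are made up; the repo's actual evaluation code may differ, and HCE is not sketched here):

```python
import numpy as np

def iou(pred: np.ndarray, gt: np.ndarray) -> float:
    """Intersection over union of two binary masks."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0  # both masks empty: count as a perfect match
    return float(np.logical_and(pred, gt).sum() / union)

def mae(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean absolute error between two [0, 1] masks."""
    return float(np.abs(pred.astype(np.float32) - gt.astype(np.float32)).mean())

# Toy 2x2 masks: 1 intersecting pixel, 2 pixels in the union
pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [0, 0]])
print(iou(pred, gt))  # 1 / 2 = 0.5
print(mae(pred, gt))  # 1 differing pixel / 4 pixels = 0.25
```

mIoU averages this IoU over the dataset; mIoU@0.5 does the same after thresholding the soft prediction at 0.5.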

Usage

All tensor shapes are (batch, channels, height, width).

| Model | Input tensor shape | Output tensor shape |
| --- | --- | --- |
| birefnet* | (1, 3, 1728, 1728) | List or tuple of (1, 1, 1728, 1728) |
| u2net* | (1, 3, 1024, 1024) | List or tuple (length 7) of (1, 1, 1024, 1024) |
| u2netp* | (1, 3, 512, 512) | List or tuple (length 7) of (1, 1, 512, 512) |
| u2netp_chunks | (1, 4, 512, 512) | List or tuple (length 7) of (1, 1, 512, 512) |
| sunet* | (1, 3, 320, 320) | (1, 1, 320, 320) (list if ONNX) |
| dlmv* | (1, 3, 256, 256) | (1, 1, 256, 256) (list if ONNX) |

For the U2Net models, use the 0th item of the list or tuple, as demonstrated in the included Session class. Inputs should be normalized to [0, 1] by dividing by 255; outputs should be normalized and then multiplied by 255 for use as uint8.
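As an illustration of the normalization above, here is a minimal NumPy sketch (the `preprocess`/`postprocess` names are made up; the included Session class does the real work, including resizing to the model's input size):

```python
import numpy as np

def preprocess(image: np.ndarray) -> np.ndarray:
    """(H, W, 3) uint8 image -> (1, 3, H, W) float32 tensor in [0, 1]."""
    x = image.astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))    # HWC -> CHW
    return np.expand_dims(x, axis=0)  # add the batch dimension

def postprocess(outputs) -> np.ndarray:
    """Take the 0th output, re-normalize to [0, 1], and scale to uint8."""
    mask = outputs[0][0, 0]  # (1, 1, H, W) -> (H, W)
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
    return (mask * 255).astype(np.uint8)

image = np.zeros((512, 512, 3), dtype=np.uint8)
assert preprocess(image).shape == (1, 3, 512, 512)
```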

This release contains a fair number of files with a variety of names. Here's the idea behind the naming scheme.

  • Prefix:

    • u2net: The big U2Net model

    • u2netp: The smaller, more mobile-friendly U2NetP model

    • dlmv: The DeepLabV3 + MobileNetV3 backbone model

    • birefnet: BiRefNet_lite

      • Note: birefnet_22.onnx uses opset version 22, namely DeformConv, which might not be available in your runtime.
    • sunet: StraightU2Net

  • _quant: Indicates a quantized model:

    • Note: PyTorch exports of quantized models have not been tested by me.

    • _fbgemm uses PyTorch's fbgemm backend

      • More server/desktop CPU friendly
    • _qnnpack uses PyTorch's qnnpack backend

      • More mobile friendly.
  • u2netp_chunks: The refiner model

    • This model takes the original image stacked with the output of BiRefNet, U2Net, or U2NetP. It does better with the bigger models, as its main purpose is to refine edges.
  • File extension

    • .onnx: ONNX model

    • .pth: PyTorch model

      • Note: these are state_dicts, not self-contained checkpoints, meaning you have to instantiate the model before loading. You can find example code for this below.
    • .pth.tar: Training checkpoint with complete state

      • ["state"]["state_dict"] for model weights

This repo includes an example session class in u2net/ that should work for both U2Net models, the DeepLabV3 + MobileNetV3 backbone model, and StraightU2Net.

For torch inference, grab the U2Net/dlmv/StraightU2Net/BiRefNet files from here: https://github.com/samhaswon/rembg-trainer-cuda/tree/segmentation-gradients/model

Then import torch and the model class, and load the weights:

```python
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

net = U2NETP(3, 1)  # or U2NET(3, 1)
# Or for dlmv:
# net = DeepLabV3MobileNetV3(1)
if torch.cuda.is_available():
    net.load_state_dict(
        torch.load(model_path, weights_only=False)
    )
    net.cuda()  # or load and call `.to(DEVICE)` later
else:
    net.load_state_dict(
        torch.load(
            model_path,
            map_location=torch.device(DEVICE),
            weights_only=False
        )
    )
net.eval()
```

You could probably set weights_only=True and be fine; this is just what I do in my code. The .pth files should contain only weights.

Results

Base models/methods:

Training Set

| Model/Method | mIoU | mIoU@0.5 | MAE | HCE |
| --- | --- | --- | --- | --- |
| BiRefNet | 0.97259184 | 0.98543735 | 0.65279536 | 100.6 |
| U2Net | 0.95717705 | 0.97403675 | 2.21020127 | 99.0 |
| U2NetP | 0.92909448 | 0.94376292 | 3.53340132 | 131.6 |
| StraightU2Net | 0.86673576 | 0.88170478 | 5.83840019 | 177.9 |
| DeepLabV3MobileNetV3 | 0.87064232 | 0.88061131 | 4.80666945 | 119.4 |
| Google (MediaPipe) | 0.61916837 | 0.61970152 | 31.84890669 | 237.1 |
| ICM | 0.63464635 | 0.63695261 | 29.26669075 | 834.1 |
| Diagonal Elliptical YCbCr | 0.62804600 | 0.63014882 | 33.72749032 | 934.4 |
| Elliptical YCbCr | 0.52903351 | 0.53050420 | 39.72920897 | 908.2 |
| YCbCr | 0.54825962 | 0.54968896 | 51.89878261 | 787.7 |
| YCbCr & HSV | 0.54879407 | 0.55008863 | 41.10210647 | 1119.8 |
| HSV | 0.52135957 | 0.52343017 | 46.99411327 | 1176.1 |
| Face | 0.36567424 | 0.36688121 | 67.68717940 | 1367.4 |

Time: 89343.20s (~24.8 hours)

Evaluation Set

| Model/Method | mIoU | mIoU@0.5 | MAE | HCE |
| --- | --- | --- | --- | --- |
| BiRefNet | 0.95078197 | 0.96431441 | 1.13846406 | 218.1 |
| U2Net | 0.91900272 | 0.93312760 | 1.77361109 | 276.1 |
| U2NetP | 0.81913939 | 0.83019210 | 7.66510450 | 310.3 |
| StraightU2Net | 0.81821826 | 0.83248770 | 6.23616959 | 356.6 |
| DeepLabV3MobileNetV3 | 0.65540375 | 0.66071213 | 14.67025902 | 220.9 |
| Google (MediaPipe) | 0.55579589 | 0.55700667 | 35.15274714 | 482.6 |
| ICM | 0.62055383 | 0.62393307 | 34.90752014 | 1676.0 |
| Diagonal Elliptical YCbCr | 0.60762614 | 0.61074735 | 40.07986746 | 1897.7 |
| Elliptical YCbCr | 0.52158137 | 0.52399477 | 38.98099503 | 1711.4 |
| YCbCr | 0.53553124 | 0.53808545 | 56.83574937 | 1491.5 |
| YCbCr & HSV | 0.56219051 | 0.56487315 | 40.28673917 | 1878.5 |
| HSV | 0.54106430 | 0.54405230 | 44.63998198 | 1942.1 |
| Face | 0.35228740 | 0.35455905 | 70.11496282 | 2602.0 |

Time: 11442.24s

Quantized models:

| Model | Quantization Engine | mIoU | mIoU@0.5 | MAE | HCE |
| --- | --- | --- | --- | --- | --- |
| U2Net | fbgemm (x86) | 0.91131639 | 0.92721647 | 2.04782018 | 274.7 |
| U2Net | qnnpack | 0.91367015 | 0.92994349 | 2.13429254 | 273.0 |
| U2NetP | fbgemm (x86) | 0.84821305 | 0.86138324 | 4.39863700 | 302.3 |
| U2NetP | qnnpack | 0.84268503 | 0.85858355 | 4.68695845 | 302.0 |
| DeepLabV3MobileNetV3 | fbgemm (x86) | 0.64375089 | 0.64992159 | 14.83924018 | 397.1 |
| DeepLabV3MobileNetV3 | qnnpack | 0.64348288 | 0.64969167 | 14.92561650 | 400.2 |

*Note: BiRefNet in FP32 takes ~14GB of memory with PyTorch, but ~40GB with ONNX Runtime at 1728x1728.

Chunked Inference

| Base Model | mIoU | mIoU@0.5 | MAE | HCE |
| --- | --- | --- | --- | --- |
| BiRefNet | 0.95273912 | 0.96289683 | 1.10493106 | 207.4 |
| U2Net | 0.92778869 | 0.93713340 | 2.06625062 | 246.5 |
| U2NetP | 0.86244147 | 0.86994394 | 4.35943205 | 240.9 |

Checksums

| File | SHA256 checksum |
| --- | --- |
| u2net.onnx | ac44f925c222a842d51d60f336e766621fc3593bced20a3624663dc0022c97ed |
| u2net.pth | 84ebf1d09899e1b2d4d02532f1d5287027d12cf3b09bbfb133bcfc17c6f8be10 |
| u2net_quant_fbgemm.onnx | ec50b4863b85b320d477402ec1a0a5d785440b6297aa87c6f6c6ecb6f751a555 |
| u2net_quant_fbgemm.pth | a40ed1292fbf9b7b1b6f3aea319cb5fb8aa4563bb4fa90ba616ca6ab8a136e0f |
| u2net_quant_qnnpack.onnx | 912cb9700569f9f0cff51f0988a31bb44ae554406d5565d0c1ca892eb7a90018 |
| u2net_quant_qnnpack.pth | 65146c2a7a1a175f87653c4ed3648ef1ee84b9d9deaf5898033ee983db6b6e5c |
| u2net_state.pth.tar | 119e06f99ceb7cf424caa922aae429db444d435c19985662b1f92e4f0098f03c |
| u2netp.onnx | b6d2e3ecb212d66ce53144de6db9b75ec8b6adfa787bde4143674491bc012f02 |
| u2netp.pth | b52eb0bb45841554a07b88c4d7a3099868b1e6a7f8da9c562ae2e937486c28a1 |
| u2netp_state.pth.tar | a2162d851f6d98d4d8b0469a86e650bf7741d96a01fb1f4cf0066b4bde608d55 |
| u2netp_quant_fbgemm.onnx | c4f7b3fa4fd9d9693e3a666505ca84311a93191eb67cdb138abbdb3aba6a1a4e |
| u2netp_quant_fbgemm.pth | 6396354288987c49c4597864843fc7259c2c11ade5a7fc6608b7eec5305fd25d |
| u2netp_quant_qnnpack.onnx | dd5f2e310e847793714947492f98b5f15b7d49319f85df5cb195b540e70379f9 |
| u2netp_quant_qnnpack.pth | 248eeecad19d0bd26890540882477f6abaec9f71ce76463c75eb8087dae54cfd |
| u2netp_chunks.onnx | 657bdf94e7f1a66d8f0d36b645023cb06368c48b48b47f7c65... |
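To verify a downloaded file against the table above, you can use `sha256sum <file>` on Linux, or a small Python helper like this sketch:

```python
import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Compute the SHA256 hex digest of a file, reading in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            h.update(chunk)
    return h.hexdigest()

# e.g. sha256_of("u2net.onnx") should match the table above
```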

v0.0.2-pre

03 Oct 14:16


Pre-release

This model expects a (1, 3, 256, 256) input and returns a list containing a (1, 1, 256, 256) output.

v0.0.1

05 Sep 17:29


  • Make session class easier to use
  • Create example creation script
  • Release non-onnx PyTorch weights for easier custom usage
  • Add u2netp and optimized u2netp (skin_u2netp_o.onnx) models
| Model | Input tensor shape |
| --- | --- |
| skin_u2net* | (1, 3, 1024, 1024) |
| skin_u2netp* | (1, 3, 512, 512) |

Full Changelog: v0.0.0...v0.0.1

Initial Models

28 Apr 03:28


Initial model upload