# Releases: samhaswon/skin_segmentation

## v0.0.2

### Changes
- Roughly doubled dataset size to 1,216 samples (1,134 training)
- Added 2 new data augmentations
  - These only really worked out on U2Net and BiRefNet, and had to be omitted for the smaller models.
- Added more dataset information
- Added the dlmv model (DeepLabV3 + MobileNetV3 backbone)
  - Basically just `deeplabv3_mobilenet_v3_large` from `torchvision`
- Added BiRefNet_lite
  - Often just referred to as "BiRefNet" or "birefnet" in this repo
- Added StraightU2Net
  - Referred to as "sunet" in some places
- Added the U2NetP chunk refiner model
- Added quantitative testing results
  - mIoU, mIoU@0.5, MAE, and HCE
  - Inference time (CPU) with PyTorch and ONNX Runtime
- Added various other traditional methods
  - All perform worse than most of the AI methods, except Google's, which is just bad.
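For reference, per-image versions of two of these metrics can be sketched as below. This is a minimal NumPy sketch; the function names and the 0–100 MAE scale are assumptions for illustration, not the repo's actual evaluation code.

```python
import numpy as np

def iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
    """Intersection over union of the thresholded prediction vs. a binary ground truth."""
    p = pred >= threshold
    t = target >= 0.5
    union = np.logical_or(p, t).sum()
    if union == 0:
        return 1.0  # both masks empty: count as a perfect match
    return float(np.logical_and(p, t).sum()) / float(union)

def mae(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean absolute error between masks in [0, 1], scaled to 0-100 (assumed scale)."""
    return float(np.abs(pred - target).mean()) * 100.0
```

mIoU@0.5 is this IoU at a fixed 0.5 threshold, averaged over the dataset; how plain mIoU treats the threshold is not specified in these notes.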
### Usage

All tensor shapes are `(batch, channels, height, width)`.

| Model | Input tensor shape | Output tensor shape |
|---|---|---|
| birefnet* | (1, 3, 1728, 1728) | List or Tuple of (1, 1, 1728, 1728) |
| u2net* | (1, 3, 1024, 1024) | List or Tuple (length 7) of (1, 1, 1024, 1024) |
| u2netp* | (1, 3, 512, 512) | List or Tuple (length 7) of (1, 1, 512, 512) |
| u2netp_chunks | (1, 4, 512, 512) | List or Tuple (length 7) of (1, 1, 512, 512) |
| sunet* | (1, 3, 320, 320) | (List if ONNX) (1, 1, 320, 320) |
| dlmv* | (1, 3, 256, 256) | (List if ONNX) (1, 1, 256, 256) |
For the U2Net models, you want the 0th item of the list or tuple, as demonstrated in the included Session class. Inputs should be normalized to [0, 1] by dividing by 255, and the outputs should be normalized and then multiplied by 255 for use as uint8.
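That normalization can be sketched as follows. This is a minimal NumPy illustration; `preprocess` and `postprocess` are hypothetical helper names (not part of the included Session class), and the min-max rescaling of the output is an assumption.

```python
import numpy as np

def preprocess(image: np.ndarray) -> np.ndarray:
    """Turn a uint8 HWC image into a (1, C, H, W) float32 tensor in [0, 1]."""
    x = image.astype(np.float32) / 255.0  # scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))        # HWC -> CHW
    return x[np.newaxis, ...]             # add batch dim

def postprocess(outputs) -> np.ndarray:
    """Take the 0th output map, normalize it, and scale back to uint8."""
    mask = np.asarray(outputs[0])[0, 0]   # (H, W)
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
    return (mask * 255.0).astype(np.uint8)
```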
This release contains a fair number of files with a variety of names. Here's the idea behind the naming scheme.

- Prefix:
  - u2net: the big U2Net model
  - u2netp: the smaller, more mobile-friendly U2NetP model
  - dlmv: the DeepLabV3 + MobileNetV3 backbone model
  - birefnet: BiRefNet_lite
    - Note: `birefnet_22.onnx` uses opset version 22 (for DeformConv), which might not be available in your runtime.
  - sunet: StraightU2Net
  - _quant: indicates a quantized model
    - Note: PyTorch exports of quantized models have not been tested by me.
    - `_fbgemm` uses PyTorch's `fbgemm` backend (more server/desktop CPU friendly)
    - `_qnnpack` uses PyTorch's `qnnpack` backend (more mobile friendly)
  - u2netp_chunks: the refiner model
    - This model takes in the original image stacked with the output of BiRefNet, U2Net, or U2NetP. It does better with the bigger models, as its main purpose is to refine edges.
- File extension:
  - `.onnx`: ONNX model
  - `.pth`: PyTorch model
    - Note: these are state_dicts, not self-contained checkpoints, meaning you have to instantiate the model before loading. You can find example code for this below.
  - `.pth.tar`: training checkpoint with complete state; `["state"]["state_dict"]` holds the model weights
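The u2netp_chunks model's 4-channel input can be assembled by stacking the image with a coarse mask from one of the base models. A minimal NumPy sketch, assuming the mask occupies the 4th channel (an assumption) and that both inputs are already 512x512:

```python
import numpy as np

def make_chunk_input(image: np.ndarray, coarse_mask: np.ndarray) -> np.ndarray:
    """Stack a 512x512 uint8 RGB image and a coarse uint8 mask into a (1, 4, 512, 512) input."""
    rgb = np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1))   # (3, 512, 512)
    mask = (coarse_mask.astype(np.float32) / 255.0)[np.newaxis, ...]  # (1, 512, 512)
    return np.concatenate([rgb, mask], axis=0)[np.newaxis, ...]       # (1, 4, 512, 512)
```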
This repo includes an example session class in u2net/ that should work for the U2Net models, the DeepLabV3 + MobileNetV3 backbone model, and StraightU2Net.

For torch inference, grab the U2Net/dlmv/StraightU2Net/BiRefNet model files from here: https://github.com/samhaswon/rembg-trainer-cuda/tree/segmentation-gradients/model

Then import the model and torch. Then:
```python
net = U2NETP(3, 1)  # or U2NET
# Or for dlmv:
# net = DeepLabV3MobileNetV3(1)

if torch.cuda.is_available():
    net.load_state_dict(
        torch.load(model_path, weights_only=False)
    )
    net.cuda()  # or load and call `.to(DEVICE)` later
else:
    net.load_state_dict(
        torch.load(
            model_path,
            map_location=torch.device(DEVICE),
            weights_only=False,
        )
    )
net.eval()
```

You could probably set `weights_only=True` and be fine; this is just what I do in my code. It should just be weights in there.
### Results

Base models/methods:

#### Training Set
| Model/Method | mIoU | mIoU@0.5 | MAE | HCE |
|---|---|---|---|---|
| BiRefNet | 0.97259184 | 0.98543735 | 0.65279536 | 100.6 |
| U2Net | 0.95717705 | 0.97403675 | 2.21020127 | 99.0 |
| U2NetP | 0.92909448 | 0.94376292 | 3.53340132 | 131.6 |
| StraightU2Net | 0.86673576 | 0.88170478 | 5.83840019 | 177.9 |
| DeepLabV3MobileNetV3 | 0.87064232 | 0.88061131 | 4.80666945 | 119.4 |
| Google (MediaPipe) | 0.61916837 | 0.61970152 | 31.84890669 | 237.1 |
| ICM | 0.63464635 | 0.63695261 | 29.26669075 | 834.1 |
| Diagonal Elliptical YCbCr | 0.62804600 | 0.63014882 | 33.72749032 | 934.4 |
| Elliptical YCbCr | 0.52903351 | 0.53050420 | 39.72920897 | 908.2 |
| YCbCr | 0.54825962 | 0.54968896 | 51.89878261 | 787.7 |
| YCbCr & HSV | 0.54879407 | 0.55008863 | 41.10210647 | 1119.8 |
| HSV | 0.52135957 | 0.52343017 | 46.99411327 | 1176.1 |
| Face | 0.36567424 | 0.36688121 | 67.68717940 | 1367.4 |
Time: 89343.20s (~24.8 hours)
#### Evaluation Set
| Model/Method | mIoU | mIoU@0.5 | MAE | HCE |
|---|---|---|---|---|
| BiRefNet | 0.95078197 | 0.96431441 | 1.13846406 | 218.1 |
| U2Net | 0.91900272 | 0.93312760 | 1.77361109 | 276.1 |
| U2NetP | 0.81913939 | 0.83019210 | 7.66510450 | 310.3 |
| StraightU2Net | 0.81821826 | 0.83248770 | 6.23616959 | 356.6 |
| DeepLabV3MobileNetV3 | 0.65540375 | 0.66071213 | 14.67025902 | 220.9 |
| Google (MediaPipe) | 0.55579589 | 0.55700667 | 35.15274714 | 482.6 |
| ICM | 0.62055383 | 0.62393307 | 34.90752014 | 1676.0 |
| Diagonal Elliptical YCbCr | 0.60762614 | 0.61074735 | 40.07986746 | 1897.7 |
| Elliptical YCbCr | 0.52158137 | 0.52399477 | 38.98099503 | 1711.4 |
| YCbCr | 0.53553124 | 0.53808545 | 56.83574937 | 1491.5 |
| YCbCr & HSV | 0.56219051 | 0.56487315 | 40.28673917 | 1878.5 |
| HSV | 0.54106430 | 0.54405230 | 44.63998198 | 1942.1 |
| Face | 0.35228740 | 0.35455905 | 70.11496282 | 2602.0 |
Time: 11442.24s (~3.2 hours)
#### Quantized models
| Model | Quantization Engine | mIoU | mIoU@0.5 | MAE | HCE |
|---|---|---|---|---|---|
| U2Net | fbgemm (x86) | 0.91131639 | 0.92721647 | 2.04782018 | 274.7 |
| U2Net | qnnpack | 0.91367015 | 0.92994349 | 2.13429254 | 273.0 |
| U2NetP | fbgemm (x86) | 0.84821305 | 0.86138324 | 4.39863700 | 302.3 |
| U2NetP | qnnpack | 0.84268503 | 0.85858355 | 4.68695845 | 302.0 |
| DeepLabV3MobileNetV3 | fbgemm (x86) | 0.64375089 | 0.64992159 | 14.83924018 | 397.1 |
| DeepLabV3MobileNetV3 | qnnpack | 0.64348288 | 0.64969167 | 14.92561650 | 400.2 |
*Note: BiRefNet in FP32 takes ~14GB of memory with PyTorch, but ~40GB with ONNX Runtime at 1728x1728.
#### Chunked Inference

| Base Model | mIoU | mIoU@0.5 | MAE | HCE |
|---|---|---|---|---|
| BiRefNet | 0.95273912 | 0.96289683 | 1.10493106 | 207.4 |
| U2Net | 0.92778869 | 0.93713340 | 2.06625062 | 246.5 |
| U2NetP | 0.86244147 | 0.86994394 | 4.35943205 | 240.9 |
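These notes don't spell out the exact chunking scheme, but chunked inference implies tiling the input into model-sized pieces. A minimal non-overlapping tiling sketch (`chunk_image` is a hypothetical helper; edges are zero-padded):

```python
import numpy as np

def chunk_image(image: np.ndarray, size: int = 512) -> list:
    """Split an HxWxC image into non-overlapping size x size tiles, zero-padding the edges."""
    h, w = image.shape[:2]
    pad_h = (size - h % size) % size
    pad_w = (size - w % size) % size
    padded = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)))
    tiles = []
    for y in range(0, padded.shape[0], size):
        for x in range(0, padded.shape[1], size):
            tiles.append(padded[y:y + size, x:x + size])
    return tiles
```

In practice each tile would be run through the refiner (alongside the base model's mask for that tile) and the results stitched back together.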
### Checksums

| File | SHA256 checksum |
|---|---|
| u2net.onnx | ac44f925c222a842d51d60f336e766621fc3593bced20a3624663dc0022c97ed |
| u2net.pth | 84ebf1d09899e1b2d4d02532f1d5287027d12cf3b09bbfb133bcfc17c6f8be10 |
| u2net_quant_fbgemm.onnx | ec50b4863b85b320d477402ec1a0a5d785440b6297aa87c6f6c6ecb6f751a555 |
| u2net_quant_fbgemm.pth | a40ed1292fbf9b7b1b6f3aea319cb5fb8aa4563bb4fa90ba616ca6ab8a136e0f |
| u2net_quant_qnnpack.onnx | 912cb9700569f9f0cff51f0988a31bb44ae554406d5565d0c1ca892eb7a90018 |
| u2net_quant_qnnpack.pth | 65146c2a7a1a175f87653c4ed3648ef1ee84b9d9deaf5898033ee983db6b6e5c |
| u2net_state.pth.tar | 119e06f99ceb7cf424caa922aae429db444d435c19985662b1f92e4f0098f03c |
| u2netp.onnx | b6d2e3ecb212d66ce53144de6db9b75ec8b6adfa787bde4143674491bc012f02 |
| u2netp.pth | b52eb0bb45841554a07b88c4d7a3099868b1e6a7f8da9c562ae2e937486c28a1 |
| u2netp_state.pth.tar | a2162d851f6d98d4d8b0469a86e650bf7741d96a01fb1f4cf0066b4bde608d55 |
| u2netp_quant_fbgemm.onnx | c4f7b3fa4fd9d9693e3a666505ca84311a93191eb67cdb138abbdb3aba6a1a4e |
| u2netp_quant_fbgemm.pth | 6396354288987c49c4597864843fc7259c2c11ade5a7fc6608b7eec5305fd25d |
| u2netp_quant_qnnpack.onnx | dd5f2e310e847793714947492f98b5f15b7d49319f85df5cb195b540e70379f9 |
| u2netp_quant_qnnpack.pth | 248eeecad19d0bd26890540882477f6abaec9f71ce76463c75eb8087dae54cfd |
| u2netp_chunks.onnx | 657bdf94e7f1a66d8f0d36b645023cb06368c48b48b47f7c65... |
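To verify a download against the checksums above, the file can be hashed with Python's standard library (`sha256_of` is just an illustrative helper name):

```python
import hashlib

def sha256_of(path: str) -> str:
    """Return the hex SHA256 digest of a file, read in 1 MiB blocks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(1 << 20), b""):
            h.update(block)
    return h.hexdigest()

# Example: compare against the table entry for u2net.onnx
# assert sha256_of("u2net.onnx") == "ac44f925c222a842d51d60f336e766621fc3593bced20a3624663dc0022c97ed"
```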
## v0.0.2-pre

This model expects a `(1, 3, 256, 256)` input and returns a list containing a `(1, 1, 256, 256)` output.
## v0.0.1

- Make session class easier to use
- Create example creation script
- Release non-onnx PyTorch weights for easier custom usage
- Add u2netp and optimized u2netp (`skin_u2netp_o.onnx`) models

| Model | Input tensor shape |
|---|---|
| skin_u2net* | (1, 3, 1024, 1024) |
| skin_u2netp* | (1, 3, 512, 512) |
Full Changelog: v0.0.0...v0.0.1
## Initial Models
Initial model upload