1
0
Fork 0
forked from mia/0x0

Replace NSFW detector implementation

This commit is contained in:
Mia Herkt 2024-09-25 18:12:39 +02:00
parent 3330a85c2c
commit 6393538333
No known key found for this signature in database
8 changed files with 21 additions and 3566 deletions

View file

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
"""
Copyright © 2020 Mia Herkt
Copyright © 2024 Mia Herkt
Licensed under the EUPL, Version 1.2 or - as soon as approved
by the European Commission - subsequent versions of the EUPL
(the "License");
@@ -18,57 +18,16 @@
and limitations under the License.
"""
import numpy as np
import os
import sys
from io import BytesIO
from pathlib import Path
os.environ["GLOG_minloglevel"] = "2" # seriously :|
import caffe
import av
av.logging.set_level(av.logging.PANIC)
from transformers import pipeline
# NOTE(review): the lines below are the *removed* side of this commit's
# diff — the old caffe-based constructor.  The scrape stripped the
# Python indentation, so they appear flat; do not treat as runnable.
class NSFWDetector:
def __init__(self):
# The caffe model files are expected next to this script, in ./nsfw_model
npath = Path(__file__).parent / "nsfw_model"
# Load the NSFW classification net (ResNet-50-1by2 caffemodel) in TEST phase.
self.nsfw_net = caffe.Net(
str(npath / "deploy.prototxt"),
caffe.TEST,
weights = str(npath / "resnet_50_1by2_nsfw.caffemodel")
)
# Preprocessor shaped to the net's input blob.
self.caffe_transformer = caffe.io.Transformer({
'data': self.nsfw_net.blobs['data'].data.shape
})
# move image channels to outermost
self.caffe_transformer.set_transpose('data', (2, 0, 1))
# subtract the dataset-mean value in each channel
self.caffe_transformer.set_mean('data', np.array([104, 117, 123]))
# rescale from [0, 1] to [0, 255]
self.caffe_transformer.set_raw_scale('data', 255)
# swap channels from RGB to BGR
self.caffe_transformer.set_channel_swap('data', (2, 1, 0))
def _compute(self, img):
    """Run the caffe NSFW net on one image and return its probability
    vector as floats (index 1 is the NSFW score used by callers).

    ``img`` is anything ``caffe.io.load_image`` accepts (path or
    file-like object).  The image is center-cropped to the net's
    input size before preprocessing.
    """
    loaded = caffe.io.load_image(img)
    src_h, src_w, _ = loaded.shape
    # Input blob shape is (batch, channels, height, width).
    _, _, net_h, net_w = self.nsfw_net.blobs["data"].data.shape
    # Center-crop offsets, clamped to zero for undersized images.
    top = int(max((src_h - net_h) / 2, 0))
    left = int(max((src_w - net_w) / 2, 0))
    cropped = loaded[top:top + net_h, left:left + net_w, :]
    # Preprocess and prepend a batch dimension of 1.
    batch = self.caffe_transformer.preprocess('data', cropped)
    batch.shape = (1,) + batch.shape
    in_blob = self.nsfw_net.inputs[0]
    layers = ["prob"]
    results = self.nsfw_net.forward_all(
        blobs=layers, **{in_blob: batch})
    return results[layers[0]][0].astype(float)
# NOTE(review): this is the *added* side of the diff — the entire
# replacement for the caffe setup above: a Hugging Face transformers
# image-classification pipeline (in the real file this line sits
# inside NSFWDetector.__init__).
self.classifier = pipeline("image-classification", model="giacomoarienti/nsfw-classifier")
# NOTE(review): this span is a unified-diff fragment of detect();
# the hunk header below elides the method's middle (container setup),
# and removed and added lines are fused together with indentation
# stripped.  Roughly: the resize/ppm/_compute path belongs to the old
# caffe implementation, while the self.classifier call and the
# max-score return belong to the new transformers implementation —
# verify against the repository.
def detect(self, fpath):
try:
@ -77,23 +36,13 @@ class NSFWDetector:
# Fall back to the start of the container if the initial seek fails.
except: container.seek(0)
# Decode the first video frame and convert it to a PIL image.
frame = next(container.decode(video=0))
img = frame.to_image()
# New path: classify the frame with the transformers pipeline.
res = self.classifier(img)
# Old path: scale the longer edge to 256 preserving aspect ratio.
if frame.width >= frame.height:
w = 256
h = int(frame.height * (256 / frame.width))
else:
w = int(frame.width * (256 / frame.height))
h = 256
frame = frame.reformat(width=w, height=h, format="rgb24")
# Old path: serialize to PPM in memory and score with the caffe net.
img = BytesIO()
frame.to_image().save(img, format="ppm")
scores = self._compute(img)
except:
return -1.0
return scores[1]
# New path: highest score over labels other than the safe ones.
return max([x["score"] for x in res if x["label"] not in ["neutral", "drawings"]])
# -1.0 signals "could not classify" to callers on any failure.
except: pass
return -1.0
# Script entry point: build a detector instance.  The visible excerpt
# ends here — presumably CLI handling follows in the full file.
if __name__ == "__main__":
n = NSFWDetector()