nsfw_detect: Use PyAV instead of ffmpegthumbnailer

This commit is contained in:
Mia Herkt 2022-11-29 21:46:33 +01:00
parent 14cfe3da58
commit eb0b1d2f69
No known key found for this signature in database
GPG key ID: 72E154B8622EC191
2 changed files with 20 additions and 9 deletions

View file

@ -56,7 +56,7 @@ neural network model. This works for images and video files and requires
the following: the following:
* Caffe Python module (built for Python 3) * Caffe Python module (built for Python 3)
* ``ffmpegthumbnailer`` executable in ``$PATH`` * `PyAV <https://github.com/PyAV-Org/PyAV>`_
Network Security Considerations Network Security Considerations

View file

@ -22,11 +22,12 @@ import numpy as np
import os import os
import sys import sys
from io import BytesIO from io import BytesIO
from subprocess import run, PIPE, DEVNULL
from pathlib import Path from pathlib import Path
os.environ["GLOG_minloglevel"] = "2" # seriously :| os.environ["GLOG_minloglevel"] = "2" # seriously :|
import caffe import caffe
import av
av.logging.set_level(av.logging.PANIC)
class NSFWDetector: class NSFWDetector:
def __init__(self): def __init__(self):
@ -49,7 +50,7 @@ class NSFWDetector:
self.caffe_transformer.set_channel_swap('data', (2, 1, 0)) self.caffe_transformer.set_channel_swap('data', (2, 1, 0))
def _compute(self, img): def _compute(self, img):
image = caffe.io.load_image(BytesIO(img)) image = caffe.io.load_image(img)
H, W, _ = image.shape H, W, _ = image.shape
_, _, h, w = self.nsfw_net.blobs["data"].data.shape _, _, h, w = self.nsfw_net.blobs["data"].data.shape
@ -71,13 +72,23 @@ class NSFWDetector:
def detect(self, fpath): def detect(self, fpath):
try: try:
ff = run([ with av.open(fpath) as container:
"ffmpegthumbnailer", "-m", "-o-", "-s256", "-t50%", "-a", try: container.seek(int(container.duration / 2))
"-cpng", "-i", fpath except: container.seek(0)
], stdout=PIPE, stderr=DEVNULL, check=True)
image_data = ff.stdout
scores = self._compute(image_data) frame = next(container.decode(video=0))
if frame.width >= frame.height:
w = 256
h = int(frame.height * (256 / frame.width))
else:
w = int(frame.width * (256 / frame.height))
h = 256
frame = frame.reformat(width=w, height=h, format="rgb24")
img = BytesIO()
frame.to_image().save(img, format="ppm")
scores = self._compute(img)
except: except:
return -1.0 return -1.0