PEP8 compliance

Author: Mia Herkt
Date: 2024-09-27 17:39:18 +02:00
Parent: a2147cc964
Commit: de19212a71
18 changed files with 376 additions and 299 deletions


@@ -5,4 +5,4 @@ print("Instead, please run")
 print("")
 print(" $ FLASK_APP=fhost flask prune")
 print("")
-exit(1);
+exit(1)

fhost.py (178 changes)

@@ -1,8 +1,7 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 """
-Copyright © 2020 Mia Herkt
+Copyright © 2024 Mia Herkt
 Licensed under the EUPL, Version 1.2 or - as soon as approved
 by the European Commission - subsequent versions of the EUPL
 (the "License");
@@ -19,7 +18,8 @@
 and limitations under the License.
 """
-from flask import Flask, abort, make_response, redirect, request, send_from_directory, url_for, Response, render_template, Request
+from flask import Flask, abort, make_response, redirect, render_template, \
+    Request, request, Response, send_from_directory, url_for
 from flask_sqlalchemy import SQLAlchemy
 from flask_migrate import Migrate
 from sqlalchemy import and_, or_
@@ -47,7 +47,9 @@ from pathlib import Path
 app = Flask(__name__, instance_relative_config=True)
 app.config.update(
     SQLALCHEMY_TRACK_MODIFICATIONS=False,
-    PREFERRED_URL_SCHEME = "https", # nginx users: make sure to have 'uwsgi_param UWSGI_SCHEME $scheme;' in your config
+    PREFERRED_URL_SCHEME="https",  # nginx users: make sure to have
+                                   # 'uwsgi_param UWSGI_SCHEME $scheme;' in
+                                   # your config
     MAX_CONTENT_LENGTH=256 * 1024 * 1024,
     MAX_URL_LENGTH=4096,
     USE_X_SENDFILE=False,
@@ -77,7 +79,8 @@ app.config.update(
         "PUA.Win.Packer.XmMusicFile",
     ],
     VSCAN_INTERVAL=datetime.timedelta(days=7),
-    URL_ALPHABET = "DEQhd2uFteibPwq0SWBInTpA_jcZL5GKz3YCR14Ulk87Jors9vNHgfaOmMXy6Vx-",
+    URL_ALPHABET="DEQhd2uFteibPwq0SWBInTpA_jcZL5GKz3YCR14Ulk87Jors9vNHgfaOmMX"
+                 "y6Vx-",
 )
 app.config.from_pyfile("config.py")
@@ -95,7 +98,7 @@ if app.config["NSFW_DETECT"]:
 try:
     mimedetect = Magic(mime=True, mime_encoding=False)
-except:
+except TypeError:
     print("""Error: You have installed the wrong version of the 'magic' module.
 Please install python-magic.""")
     sys.exit(1)
@@ -103,6 +106,7 @@ Please install python-magic.""")
 db = SQLAlchemy(app)
 migrate = Migrate(app, db)
 class URL(db.Model):
     __tablename__ = "URL"
     id = db.Column(db.Integer, primary_key=True)
@@ -117,6 +121,7 @@ class URL(db.Model):
     def geturl(self):
         return url_for("get", path=self.getname(), _external=True) + "\n"
+    @staticmethod
     def get(url):
         u = URL.query.filter_by(url=url).first()
@@ -127,6 +132,7 @@ class URL(db.Model):
         return u
 class IPAddress(types.TypeDecorator):
     impl = types.LargeBinary
     cache_ok = True
@@ -175,18 +181,19 @@ class File(db.Model):
     @property
     def is_nsfw(self) -> bool:
-        return self.nsfw_score and self.nsfw_score > app.config["NSFW_THRESHOLD"]
+        if self.nsfw_score:
+            return self.nsfw_score > app.config["NSFW_THRESHOLD"]
+        return False
     def getname(self):
         return u"{0}{1}".format(su.enbase(self.id), self.ext)
     def geturl(self):
         n = self.getname()
+        a = "nsfw" if self.is_nsfw else None
-        if self.is_nsfw:
-            return url_for("get", path=n, secret=self.secret, _external=True, _anchor="nsfw") + "\n"
-        else:
-            return url_for("get", path=n, secret=self.secret, _external=True) + "\n"
+        return url_for("get", path=n, secret=self.secret,
+                       _external=True, _anchor=a) + "\n"
     def getpath(self) -> Path:
         return Path(app.config["FHOST_STORAGE_PATH"]) / self.sha256
@@ -197,33 +204,37 @@ class File(db.Model):
         self.removed = permanent
         self.getpath().unlink(missing_ok=True)
-    # Returns the epoch millisecond that a file should expire
-    #
-    # Uses the expiration time provided by the user (requested_expiration)
-    # upper-bounded by an algorithm that computes the size based on the size of the
-    # file.
-    #
-    # That is, all files are assigned a computed expiration, which can voluntarily
-    # shortened by the user either by providing a timestamp in epoch millis or a
-    # duration in hours.
+    """
+    Returns the epoch millisecond that a file should expire
+
+    Uses the expiration time provided by the user (requested_expiration)
+    upper-bounded by an algorithm that computes the size based on the size of
+    the file.
+
+    That is, all files are assigned a computed expiration, which can be
+    voluntarily shortened by the user either by providing a timestamp in
+    milliseconds since epoch or a duration in hours.
+    """
+    @staticmethod
     def get_expiration(requested_expiration, size) -> int:
-        current_epoch_millis = time.time() * 1000;
+        current_epoch_millis = time.time() * 1000
         # Maximum lifetime of the file in milliseconds
-        this_files_max_lifespan = get_max_lifespan(size);
+        max_lifespan = get_max_lifespan(size)
         # The latest allowed expiration date for this file, in epoch millis
-        this_files_max_expiration = this_files_max_lifespan + 1000 * time.time();
+        max_expiration = max_lifespan + 1000 * time.time()
         if requested_expiration is None:
-            return this_files_max_expiration
+            return max_expiration
         elif requested_expiration < 1650460320000:
             # Treat the requested expiration time as a duration in hours
             requested_expiration_ms = requested_expiration * 60 * 60 * 1000
-            return min(this_files_max_expiration, current_epoch_millis + requested_expiration_ms)
+            return min(max_expiration,
+                       current_epoch_millis + requested_expiration_ms)
         else:
-            # Treat the requested expiration time as a timestamp in epoch millis
-            return min(this_files_max_expiration, requested_expiration)
+            # Treat expiration time as a timestamp in epoch millis
+            return min(max_expiration, requested_expiration)
     """
     requested_expiration can be:
@@ -231,18 +242,23 @@ class File(db.Model):
     - a duration (in hours) that the file should live for
     - a timestamp in epoch millis that the file should expire at
-    Any value greater that the longest allowed file lifespan will be rounded down to that
-    value.
+    Any value greater that the longest allowed file lifespan will be rounded
+    down to that value.
     """
-    def store(file_, requested_expiration: typing.Optional[int], addr, ua, secret: bool):
+    @staticmethod
+    def store(file_, requested_expiration: typing.Optional[int], addr, ua,
+              secret: bool):
         data = file_.read()
         digest = sha256(data).hexdigest()
         def get_mime():
             guess = mimedetect.from_buffer(data)
-            app.logger.debug(f"MIME - specified: '{file_.content_type}' - detected: '{guess}'")
+            app.logger.debug(f"MIME - specified: '{file_.content_type}' - "
+                             f"detected: '{guess}'")
-            if not file_.content_type or not "/" in file_.content_type or file_.content_type == "application/octet-stream":
+            if (not file_.content_type
+                    or "/" not in file_.content_type
+                    or file_.content_type == "application/octet-stream"):
                 mime = guess
             else:
                 mime = file_.content_type
@@ -254,7 +270,7 @@ class File(db.Model):
                 if flt.check(guess):
                     abort(403, flt.reason)
-            if mime.startswith("text/") and not "charset" in mime:
+            if mime.startswith("text/") and "charset" not in mime:
                 mime += "; charset=utf-8"
             return mime
@@ -266,7 +282,8 @@ class File(db.Model):
             gmime = mime.split(";")[0]
             guess = guess_extension(gmime)
-            app.logger.debug(f"extension - specified: '{ext}' - detected: '{guess}'")
+            app.logger.debug(f"extension - specified: '{ext}' - detected: "
+                             f"'{guess}'")
             if not ext:
                 if gmime in app.config["FHOST_EXT_OVERRIDE"]:
@@ -309,7 +326,8 @@ class File(db.Model):
         if isnew:
             f.secret = None
             if secret:
-                f.secret = secrets.token_urlsafe(app.config["FHOST_SECRET_BYTES"])
+                f.secret = \
+                    secrets.token_urlsafe(app.config["FHOST_SECRET_BYTES"])
         storage = Path(app.config["FHOST_STORAGE_PATH"])
         storage.mkdir(parents=True, exist_ok=True)
@@ -471,17 +489,21 @@ class UrlEncoder(object):
             result += self.alphabet.index(c) * (n ** i)
         return result
 su = UrlEncoder(alphabet=app.config["URL_ALPHABET"], min_length=1)
 def fhost_url(scheme=None):
     if not scheme:
         return url_for(".fhost", _external=True).rstrip("/")
     else:
         return url_for(".fhost", _external=True, _scheme=scheme).rstrip("/")
 def is_fhost_url(url):
     return url.startswith(fhost_url()) or url.startswith(fhost_url("https"))
 def shorten(url):
     if len(url) > app.config["MAX_URL_LENGTH"]:
         abort(414)
@@ -493,16 +515,18 @@ def shorten(url):
     return u.geturl()
 """
 requested_expiration can be:
 - None, to use the longest allowed file lifespan
 - a duration (in hours) that the file should live for
 - a timestamp in epoch millis that the file should expire at
-Any value greater that the longest allowed file lifespan will be rounded down to that
-value.
+Any value greater that the longest allowed file lifespan will be rounded down
+to that value.
 """
-def store_file(f, requested_expiration: typing.Optional[int], addr, ua, secret: bool):
+def store_file(f, requested_expiration: typing.Optional[int], addr, ua,
+               secret: bool):
     sf, isnew = File.store(f, requested_expiration, addr, ua, secret)
     response = make_response(sf.geturl())
@@ -513,6 +537,7 @@ def store_file(f, requested_expiration: typing.Optional[int], addr, ua, secret:
     return response
 def store_url(url, addr, ua, secret: bool):
     if is_fhost_url(url):
         abort(400)
@@ -526,13 +551,14 @@ def store_url(url, addr, ua, secret: bool):
         return str(e) + "\n"
     if "content-length" in r.headers:
-        l = int(r.headers["content-length"])
-        if l <= app.config["MAX_CONTENT_LENGTH"]:
+        length = int(r.headers["content-length"])
+        if length <= app.config["MAX_CONTENT_LENGTH"]:
             def urlfile(**kwargs):
                 return type('', (), kwargs)()
-            f = urlfile(read=r.raw.read, content_type=r.headers["content-type"], filename="")
+            f = urlfile(read=r.raw.read,
+                        content_type=r.headers["content-type"], filename="")
             return store_file(f, None, addr, ua, secret)
         else:
@@ -540,10 +566,9 @@ def store_url(url, addr, ua, secret: bool):
     else:
         abort(411)
 def manage_file(f):
-    try:
-        assert(request.form["token"] == f.mgmt_token)
-    except:
+    if request.form["token"] != f.mgmt_token:
         abort(401)
     if "delete" in request.form:
@@ -562,6 +587,7 @@ def manage_file(f):
             abort(400)
 @app.route("/<path:path>", methods=["GET", "POST"])
 @app.route("/s/<secret>/<path:path>", methods=["GET", "POST"])
 def get(path, secret=None):
@@ -598,7 +624,9 @@ def get(path, secret=None):
                 response.headers["Content-Length"] = f.size
                 response.headers["X-Accel-Redirect"] = "/" + str(fpath)
             else:
-                response = send_from_directory(app.config["FHOST_STORAGE_PATH"], f.sha256, mimetype = f.mime)
+                response = send_from_directory(
+                    app.config["FHOST_STORAGE_PATH"], f.sha256,
+                    mimetype=f.mime)
             response.headers["X-Expires"] = f.expiration
             return response
@@ -616,6 +644,7 @@ def get(path, secret=None):
     abort(404)
 @app.route("/", methods=["GET", "POST"])
 def fhost():
     if request.method == "POST":
@@ -665,12 +694,14 @@ def fhost():
     else:
         return render_template("index.html")
 @app.route("/robots.txt")
 def robots():
     return """User-agent: *
 Disallow: /
 """
 @app.errorhandler(400)
 @app.errorhandler(401)
 @app.errorhandler(403)
@@ -682,20 +713,23 @@ Disallow: /
 @app.errorhandler(451)
 def ehandler(e):
     try:
-        return render_template(f"{e.code}.html", id=id, request=request, description=e.description), e.code
+        return render_template(f"{e.code}.html", id=id, request=request,
+                               description=e.description), e.code
     except TemplateNotFound:
         return "Segmentation fault\n", e.code
 @app.cli.command("prune")
 def prune():
     """
     Clean up expired files
-    Deletes any files from the filesystem which have hit their expiration time. This
-    doesn't remove them from the database, only from the filesystem. It's recommended
-    that server owners run this command regularly, or set it up on a timer.
+    Deletes any files from the filesystem which have hit their expiration time.
+    This doesn't remove them from the database, only from the filesystem.
+    It is recommended that server owners run this command regularly, or set it
+    up on a timer.
     """
-    current_time = time.time() * 1000;
+    current_time = time.time() * 1000
     # The path to where uploaded files are stored
     storage = Path(app.config["FHOST_STORAGE_PATH"])
@@ -709,7 +743,7 @@ def prune():
         )
     )
-    files_removed = 0;
+    files_removed = 0
    # For every expired file...
     for file in expired_files:
@@ -722,31 +756,33 @@ def prune():
         # Remove it from the file system
         try:
             os.remove(file_path)
-            files_removed += 1;
+            files_removed += 1
         except FileNotFoundError:
             pass  # If the file was already gone, we're good
         except OSError as e:
             print(e)
             print(
                 "\n------------------------------------"
-                "Encountered an error while trying to remove file {file_path}. Double"
-                "check to make sure the server is configured correctly, permissions are"
-                "okay, and everything is ship shape, then try again.")
-            return;
+                "Encountered an error while trying to remove file {file_path}."
+                "Make sure the server is configured correctly, permissions "
+                "are okay, and everything is ship shape, then try again.")
+            return
         # Finally, mark that the file was removed
-        file.expiration = None;
+        file.expiration = None
     db.session.commit()
     print(f"\nDone! {files_removed} file(s) removed")
""" For a file of a given size, determine the largest allowed lifespan of that file
Based on the current app's configuration: Specifically, the MAX_CONTENT_LENGTH, as well """
as FHOST_{MIN,MAX}_EXPIRATION. For a file of a given size, determine the largest allowed lifespan of that file
This lifespan may be shortened by a user's request, but no files should be allowed to Based on the current app's configuration:
expire at a point after this number. Specifically, the MAX_CONTENT_LENGTH, as well as FHOST_{MIN,MAX}_EXPIRATION.
This lifespan may be shortened by a user's request, but no files should be
allowed to expire at a point after this number.
Value returned is a duration in milliseconds. Value returned is a duration in milliseconds.
""" """
@@ -756,11 +792,13 @@ def get_max_lifespan(filesize: int) -> int:
     max_size = app.config.get("MAX_CONTENT_LENGTH", 256 * 1024 * 1024)
     return min_exp + int((-max_exp + min_exp) * (filesize / max_size - 1) ** 3)
 def do_vscan(f):
     if f["path"].is_file():
         with open(f["path"], "rb") as scanf:
             try:
-                f["result"] = list(app.config["VSCAN_SOCKET"].instream(scanf).values())[0]
+                res = list(app.config["VSCAN_SOCKET"].instream(scanf).values())
+                f["result"] = res[0]
             except:
                 f["result"] = ("SCAN FAILED", None)
     else:
@@ -768,11 +806,12 @@ def do_vscan(f):
     return f
 @app.cli.command("vscan")
 def vscan():
     if not app.config["VSCAN_SOCKET"]:
-        print("""Error: Virus scanning enabled but no connection method specified.
-Please set VSCAN_SOCKET.""")
+        print("Error: Virus scanning enabled but no connection method "
+              "specified.\nPlease set VSCAN_SOCKET.")
         sys.exit(1)
     qp = Path(app.config["VSCAN_QUARANTINE_PATH"])
@@ -786,9 +825,11 @@ Please set VSCAN_SOCKET.""")
                                 File.last_vscan == None),
                            File.removed == False)
     else:
-        res = File.query.filter(File.last_vscan == None, File.removed == False)
+        res = File.query.filter(File.last_vscan == None,
+                                File.removed == False)
-    work = [{"path" : f.getpath(), "name" : f.getname(), "id" : f.id} for f in res]
+    work = [{"path": f.getpath(), "name": f.getname(), "id": f.id}
+            for f in res]
     results = []
     for i, r in enumerate(p.imap_unordered(do_vscan, work)):
@@ -803,7 +844,8 @@ Please set VSCAN_SOCKET.""")
             results.append({
                 "id": r["id"],
-                "last_vscan" : None if r["result"][0] == "SCAN FAILED" else datetime.datetime.now(),
+                "last_vscan": None if r["result"][0] == "SCAN FAILED"
+                else datetime.datetime.now(),
                 "removed": found})
     db.session.bulk_update_mappings(File, results)


@@ -81,6 +81,7 @@ def run_migrations_online():
     finally:
         connection.close()
 if context.is_offline_mode():
     run_migrations_offline()
 else:


@@ -15,12 +15,8 @@ import sqlalchemy as sa
 def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column('file', sa.Column('mgmt_token', sa.String(), nullable=True))
-    # ### end Alembic commands ###
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column('file', 'mgmt_token')
-    # ### end Alembic commands ###


@@ -15,13 +15,11 @@ import sqlalchemy as sa
 def upgrade():
-    ### commands auto generated by Alembic - please adjust! ###
     op.create_table('URL',
     sa.Column('id', sa.Integer(), nullable=False),
     sa.Column('url', sa.UnicodeText(), nullable=True),
     sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('url')
-    )
+    sa.UniqueConstraint('url'))
     op.create_table('file',
     sa.Column('id', sa.Integer(), nullable=False),
     sa.Column('sha256', sa.String(), nullable=True),
@@ -30,13 +28,9 @@ def upgrade():
     sa.Column('addr', sa.UnicodeText(), nullable=True),
     sa.Column('removed', sa.Boolean(), nullable=True),
     sa.PrimaryKeyConstraint('id'),
-    sa.UniqueConstraint('sha256')
-    )
+    sa.UniqueConstraint('sha256'))
-    ### end Alembic commands ###
 def downgrade():
-    ### commands auto generated by Alembic - please adjust! ###
     op.drop_table('file')
     op.drop_table('URL')
-    ### end Alembic commands ###


@@ -19,6 +19,7 @@ from pathlib import Path
 Base = automap_base()
 def upgrade():
     op.add_column('file', sa.Column('size', sa.BigInteger(), nullable=True))
     bind = op.get_bind()


@@ -19,45 +19,46 @@ import ipaddress
 Base = automap_base()
 def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('request_filter',
                     sa.Column('id', sa.Integer(), nullable=False),
                     sa.Column('type', sa.String(length=20), nullable=False),
                     sa.Column('comment', sa.UnicodeText(), nullable=True),
-                    sa.Column('addr', sa.LargeBinary(length=16), nullable=True),
+                    sa.Column('addr', sa.LargeBinary(length=16),
+                              nullable=True),
                     sa.Column('net', sa.Text(), nullable=True),
                     sa.Column('regex', sa.UnicodeText(), nullable=True),
                     sa.PrimaryKeyConstraint('id'),
-                    sa.UniqueConstraint('addr')
-                    )
+                    sa.UniqueConstraint('addr'))
     with op.batch_alter_table('request_filter', schema=None) as batch_op:
-        batch_op.create_index(batch_op.f('ix_request_filter_type'), ['type'], unique=False)
-    # ### end Alembic commands ###
+        batch_op.create_index(batch_op.f('ix_request_filter_type'), ['type'],
+                              unique=False)
     bind = op.get_bind()
     Base.prepare(autoload_with=bind)
     RequestFilter = Base.classes.request_filter
     session = Session(bind=bind)
-    if "FHOST_UPLOAD_BLACKLIST" in current_app.config:
-        if current_app.config["FHOST_UPLOAD_BLACKLIST"]:
-            with current_app.open_instance_resource(current_app.config["FHOST_UPLOAD_BLACKLIST"], "r") as bl:
-                for l in bl.readlines():
-                    if not l.startswith("#"):
-                        l = l.strip()
-                        if l.endswith(":"):
-                            # old implementation uses str.startswith,
-                            # which does not translate to networks
-                            current_app.logger.warning(f"Ignored address: {l}")
-                            continue
-                        flt = RequestFilter(type="addr", addr=ipaddress.ip_address(l).packed)
-                        session.add(flt)
+    blp = current_app.config.get("FHOST_UPLOAD_BLACKLIST")
+    if blp:
+        with current_app.open_instance_resource(blp, "r") as bl:
+            for line in bl.readlines():
+                if not line.startswith("#"):
+                    line = line.strip()
+                    if line.endswith(":"):
+                        # old implementation uses str.startswith,
+                        # which does not translate to networks
+                        current_app.logger.warning(
+                            f"Ignored address: {line}")
+                        continue
+                    addr = ipaddress.ip_address(line).packed
+                    flt = RequestFilter(type="addr", addr=addr)
+                    session.add(flt)
-    if "FHOST_MIME_BLACKLIST" in current_app.config:
-        for mime in current_app.config["FHOST_MIME_BLACKLIST"]:
-            flt = RequestFilter(type="mime", regex=mime)
-            session.add(flt)
+    for mime in current_app.config.get("FHOST_MIME_BLACKLIST", []):
+        flt = RequestFilter(type="mime", regex=mime)
+        session.add(flt)
@@ -72,9 +73,7 @@ def upgrade():
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     with op.batch_alter_table('request_filter', schema=None) as batch_op:
         batch_op.drop_index(batch_op.f('ix_request_filter_type'))
     op.drop_table('request_filter')
-    # ### end Alembic commands ###


@@ -15,12 +15,9 @@ import sqlalchemy as sa
 def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
-    op.add_column('file', sa.Column('last_vscan', sa.DateTime(), nullable=True))
-    # ### end Alembic commands ###
+    op.add_column('file', sa.Column('last_vscan', sa.DateTime(),
+                                    nullable=True))
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column('file', 'last_vscan')
-    # ### end Alembic commands ###


@@ -15,12 +15,8 @@ import sqlalchemy as sa
 def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column('file', sa.Column('nsfw_score', sa.Float(), nullable=True))
-    # ### end Alembic commands ###
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column('file', 'nsfw_score')
-    # ### end Alembic commands ###


@@ -21,24 +21,29 @@ from sqlalchemy.orm import Session
 import os
 import time
-""" For a file of a given size, determine the largest allowed lifespan of that file
-Based on the current app's configuration: Specifically, the MAX_CONTENT_LENGTH, as well
-as FHOST_{MIN,MAX}_EXPIRATION.
-This lifespan may be shortened by a user's request, but no files should be allowed to
-expire at a point after this number.
+"""
+For a file of a given size, determine the largest allowed lifespan of that file
+Based on the current app's configuration:
+Specifically, the MAX_CONTENT_LENGTH, as well as FHOST_{MIN,MAX}_EXPIRATION.
+This lifespan may be shortened by a user's request, but no files should be
+allowed to expire at a point after this number.
 Value returned is a duration in milliseconds.
 """
 def get_max_lifespan(filesize: int) -> int:
-    min_exp = current_app.config.get("FHOST_MIN_EXPIRATION", 30 * 24 * 60 * 60 * 1000)
-    max_exp = current_app.config.get("FHOST_MAX_EXPIRATION", 365 * 24 * 60 * 60 * 1000)
-    max_size = current_app.config.get("MAX_CONTENT_LENGTH", 256 * 1024 * 1024)
+    cfg = current_app.config
+    min_exp = cfg.get("FHOST_MIN_EXPIRATION", 30 * 24 * 60 * 60 * 1000)
+    max_exp = cfg.get("FHOST_MAX_EXPIRATION", 365 * 24 * 60 * 60 * 1000)
+    max_size = cfg.get("MAX_CONTENT_LENGTH", 256 * 1024 * 1024)
     return min_exp + int((-max_exp + min_exp) * (filesize / max_size - 1) ** 3)
 Base = automap_base()
 def upgrade():
     op.add_column('file', sa.Column('expiration', sa.BigInteger()))
@@ -48,7 +53,7 @@ def upgrade():
     session = Session(bind=bind)
     storage = Path(current_app.config["FHOST_STORAGE_PATH"])
-    current_time = time.time() * 1000;
+    current_time = time.time() * 1000
     # List of file hashes which have not expired yet
     # This could get really big for some servers
@@ -74,13 +79,18 @@ def upgrade():
     for file in files:
         file_path = storage / file.sha256
         stat = os.stat(file_path)
-        max_age = get_max_lifespan(stat.st_size) # How long the file is allowed to live, in ms
-        file_birth = stat.st_mtime * 1000 # When the file was created, in ms
-        updates.append({'id': file.id, 'expiration': int(file_birth + max_age)})
+        # How long the file is allowed to live, in ms
+        max_age = get_max_lifespan(stat.st_size)
+        # When the file was created, in ms
+        file_birth = stat.st_mtime * 1000
+        updates.append({
+            'id': file.id,
+            'expiration': int(file_birth + max_age)})
     # Apply coalesced updates
     session.bulk_update_mappings(File, updates)
     session.commit()
 def downgrade():
     op.drop_column('file', 'expiration')


@@ -15,16 +15,10 @@ import sqlalchemy as sa
 def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     with op.batch_alter_table('file', schema=None) as batch_op:
         batch_op.add_column(sa.Column('ua', sa.UnicodeText(), nullable=True))
-    # ### end Alembic commands ###
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     with op.batch_alter_table('file', schema=None) as batch_op:
         batch_op.drop_column('ua')
-    # ### end Alembic commands ###


@@ -15,12 +15,8 @@ import sqlalchemy as sa
 def upgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column('file', sa.Column('secret', sa.String(), nullable=True))
-    # ### end Alembic commands ###
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column('file', 'secret')
-    # ### end Alembic commands ###

mod.py (95 changes)

@@ -14,10 +14,11 @@ from jinja2.filters import do_filesizeformat
 import ipaddress
 from fhost import db, File, AddrFilter, su, app as fhost_app
-from modui import *
+from modui import FileTable, mime, MpvWidget, Notification
 fhost_app.app_context().push()
 class NullptrMod(Screen):
     BINDINGS = [
         ("q", "quit_app", "Quit"),
@@ -67,53 +68,58 @@ class NullptrMod(Screen):
         self.finput.display = False
         ftable = self.query_one("#ftable")
         ftable.focus()
+        q = ftable.base_query
         if len(message.value):
             match self.filter_col:
                 case 1:
-                    try: ftable.query = ftable.base_query.filter(File.id == su.debase(message.value))
-                    except ValueError: pass
+                    try:
+                        q = q.filter(File.id == su.debase(message.value))
+                    except ValueError:
+                        return
                 case 2:
                     try:
                         addr = ipaddress.ip_address(message.value)
                         if type(addr) is ipaddress.IPv6Address:
                             addr = addr.ipv4_mapped or addr
-                        q = ftable.base_query.filter(File.addr == addr)
-                        ftable.query = q
-                    except ValueError: pass
-                case 3: ftable.query = ftable.base_query.filter(File.mime.like(message.value))
-                case 4: ftable.query = ftable.base_query.filter(File.ext.like(message.value))
-                case 5: ftable.query = ftable.base_query.filter(File.ua.like(message.value))
-        else:
-            ftable.query = ftable.base_query
+                        q = q.filter(File.addr == addr)
+                    except ValueError:
+                        return
+                case 3: q = q.filter(File.mime.like(message.value))
+                case 4: q = q.filter(File.ext.like(message.value))
+                case 5: q = q.filter(File.ua.like(message.value))
+        ftable.query = q
     def action_remove_file(self, permanent: bool) -> None:
         if self.current_file:
             self.current_file.delete(permanent)
             db.session.commit()
-            self.mount(Notification(f"{'Banned' if permanent else 'Removed'} file {self.current_file.getname()}"))
+            self.mount(Notification(f"{'Banned' if permanent else 'Removed'}"
+                                    f"file {self.current_file.getname()}"))
             self.action_refresh()
     def action_ban_ip(self, nuke: bool) -> None:
         if self.current_file:
-            if AddrFilter.query.filter(AddrFilter.addr ==
-                                       self.current_file.addr).scalar():
-                txt = f"{self.current_file.addr.compressed} is already banned"
+            addr = self.current_file.addr
+            if AddrFilter.query.filter(AddrFilter.addr == addr).scalar():
+                txt = f"{addr.compressed} is already banned"
             else:
-                db.session.add(AddrFilter(self.current_file.addr))
+                db.session.add(AddrFilter(addr))
                 db.session.commit()
-                txt = f"Banned {self.current_file.addr.compressed}"
+                txt = f"Banned {addr.compressed}"
             if nuke:
                 tsize = 0
                 trm = 0
-                for f in File.query.filter(File.addr == self.current_file.addr):
+                for f in File.query.filter(File.addr == addr):
                     if f.getpath().is_file():
                         tsize += f.size or f.getpath().stat().st_size
                         trm += 1
                     f.delete(True)
                 db.session.commit()
-                txt += f", removed {trm} {'files' if trm != 1 else 'file'} totaling {do_filesizeformat(tsize, True)}"
+                txt += f", removed {trm} {'files' if trm != 1 else 'file'} " \
+                       f"totaling {do_filesizeformat(tsize, True)}"
             self.mount(Notification(txt))
             self._refresh_layout()
             ftable = self.query_one("#ftable")
@@ -150,11 +156,14 @@ class NullptrMod(Screen):
         self.finput = self.query_one("#filter_input")
         self.mimehandler = mime.MIMEHandler()
-        self.mimehandler.register(mime.MIMECategory.Archive, self.handle_libarchive)
+        self.mimehandler.register(mime.MIMECategory.Archive,
+                                  self.handle_libarchive)
         self.mimehandler.register(mime.MIMECategory.Text, self.handle_text)
         self.mimehandler.register(mime.MIMECategory.AV, self.handle_mpv)
-        self.mimehandler.register(mime.MIMECategory.Document, self.handle_mupdf)
-        self.mimehandler.register(mime.MIMECategory.Fallback, self.handle_libarchive)
+        self.mimehandler.register(mime.MIMECategory.Document,
+                                  self.handle_mupdf)
+        self.mimehandler.register(mime.MIMECategory.Fallback,
+                                  self.handle_libarchive)
         self.mimehandler.register(mime.MIMECategory.Fallback, self.handle_mpv)
         self.mimehandler.register(mime.MIMECategory.Fallback, self.handle_raw)
@@ -181,7 +190,8 @@ class NullptrMod(Screen):
             self.mpvw.styles.height = "40%"
             self.mpvw.start_mpv("hex://" + imgdata, 0)
-            self.ftlog.write(Text.from_markup(f"[bold]Pages:[/bold] {doc.page_count}"))
+            self.ftlog.write(
+                Text.from_markup(f"[bold]Pages:[/bold] {doc.page_count}"))
            self.ftlog.write(Text.from_markup("[bold]Metadata:[/bold]"))
             for k, v in doc.metadata.items():
                 self.ftlog.write(Text.from_markup(f" [bold]{k}:[/bold] {v}"))
@@ -206,7 +216,8 @@ class NullptrMod(Screen):
                 for k, v in c.metadata.items():
                     self.ftlog.write(f" {k}: {v}")
                 for s in c.streams:
-                    self.ftlog.write(Text(f"Stream {s.index}:", style="bold"))
+                    self.ftlog.write(
+                        Text(f"Stream {s.index}:", style="bold"))
                     self.ftlog.write(f" Type: {s.type}")
                     if s.base_rate:
                         self.ftlog.write(f" Frame rate: {s.base_rate}")
@@ -225,24 +236,31 @@ class NullptrMod(Screen):
                 else:
                     c = chr(s)
                     s = c
-                if c.isalpha(): return f"\0[chartreuse1]{s}\0[/chartreuse1]"
-                if c.isdigit(): return f"\0[gold1]{s}\0[/gold1]"
+                if c.isalpha():
+                    return f"\0[chartreuse1]{s}\0[/chartreuse1]"
+                if c.isdigit():
+                    return f"\0[gold1]{s}\0[/gold1]"
                 if not c.isprintable():
                     g = "grey50" if c == "\0" else "cadet_blue"
                     return f"\0[{g}]{s if len(s) == 2 else '.'}\0[/{g}]"
                 return s
-            return Text.from_markup("\n".join(f"{' '.join(map(fmt, map(''.join, zip(*[iter(c.hex())] * 2))))}"
-                                              f"{' ' * (16 - len(c))}"
-                                              f" {''.join(map(fmt, c))}"
-                                              for c in map(lambda x: bytes([n for n in x if n != None]),
-                                                           zip_longest(*[iter(binf.read(min(length, 16 * 10)))] * 16))))
+            return Text.from_markup(
+                "\n".join(' '.join(
+                    map(fmt, map(''.join, zip(*[iter(c.hex())] * 2)))) +
+                    f"{' ' * (16 - len(c))} {''.join(map(fmt, c))}"
+                    for c in
+                    map(lambda x: bytes([n for n in x if n is not None]),
+                        zip_longest(
+                            *[iter(binf.read(min(length, 16 * 10)))] * 16))))
         with open(self.current_file.getpath(), "rb") as binf:
             self.ftlog.write(hexdump(binf, self.current_file.size))
             if self.current_file.size > 16*10*2:
                 binf.seek(self.current_file.size-16*10)
                 self.ftlog.write(" [...] ".center(64, ''))
-                self.ftlog.write(hexdump(binf, self.current_file.size - binf.tell()))
+                self.ftlog.write(hexdump(binf,
+                                         self.current_file.size - binf.tell()))
         return True
@@ -253,7 +271,9 @@ class NullptrMod(Screen):
         self.finfo.add_rows([
             ("ID:", str(f.id)),
             ("File name:", f.getname()),
-            ("URL:", f.geturl() if fhost_app.config["SERVER_NAME"] else "⚠ Set SERVER_NAME in config.py to display"),
+            ("URL:", f.geturl()
+             if fhost_app.config["SERVER_NAME"]
+             else "⚠ Set SERVER_NAME in config.py to display"),
             ("File size:", do_filesizeformat(f.size, True)),
             ("MIME type:", f.mime),
             ("SHA256 checksum:", f.sha256),
@@ -261,9 +281,14 @@ class NullptrMod(Screen):
             ("User agent:", Text(f.ua or "")),
             ("Management token:", f.mgmt_token),
             ("Secret:", f.secret),
-            ("Is NSFW:", ("Yes" if f.is_nsfw else "No") + (f" (Score: {f.nsfw_score:0.4f})" if f.nsfw_score else " (Not scanned)")),
+            ("Is NSFW:", ("Yes" if f.is_nsfw else "No") +
+             (f" (Score: {f.nsfw_score:0.4f})"
+              if f.nsfw_score else " (Not scanned)")),
             ("Is banned:", "Yes" if f.removed else "No"),
-            ("Expires:", time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime(File.get_expiration(f.expiration, f.size)/1000)))
+            ("Expires:",
+             time.strftime("%Y-%m-%d %H:%M:%S",
+                           time.gmtime(File.get_expiration(f.expiration,
+                                                           f.size)/1000)))
         ])
         self.mpvw.stop_mpv(True)
@@ -273,6 +298,7 @@ class NullptrMod(Screen):
         self.mimehandler.handle(f.mime, f.ext)
         self.ftlog.scroll_to(x=0, y=0, animate=False)
 class NullptrModApp(App):
     CSS_PATH = "mod.css"
@@ -282,6 +308,7 @@ class NullptrModApp(App):
         self.install_screen(self.main_screen, name="main")
         self.push_screen("main")
 if __name__ == "__main__":
     app = NullptrModApp()
     app.run()


@@ -7,12 +7,14 @@ from jinja2.filters import do_filesizeformat
 from fhost import File
 from modui import mime
 class FileTable(DataTable):
     query = Reactive(None)
     order_col = Reactive(0)
     order_desc = Reactive(True)
     limit = 10000
-    colmap = [File.id, File.removed, File.nsfw_score, None, File.ext, File.size, File.mime]
+    colmap = [File.id, File.removed, File.nsfw_score, None, File.ext,
+              File.size, File.mime]
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -33,6 +35,8 @@ class FileTable(DataTable):
     def watch_query(self, old, value) -> None:
         def fmt_file(f: File) -> tuple:
+            mimemoji = mime.mimemoji.get(f.mime.split('/')[0],
+                                         mime.mimemoji.get(f.mime)) or ' '
             return (
                 str(f.id),
                 "🔴" if f.removed else " ",
@@ -40,14 +44,15 @@ class FileTable(DataTable):
                 "👻" if not f.getpath().is_file() else " ",
                 f.getname(),
                 do_filesizeformat(f.size, True),
-                f"{mime.mimemoji.get(f.mime.split('/')[0], mime.mimemoji.get(f.mime)) or ' '} " + f.mime,
+                f"{mimemoji} {f.mime}",
             )
         if (self.query):
             order = FileTable.colmap[self.order_col]
             q = self.query
-            if order: q = q.order_by(order.desc() if self.order_desc else order, File.id)
+            if order:
+                q = q.order_by(order.desc() if self.order_desc
+                               else order, File.id)
             qres = list(map(fmt_file, q.limit(self.limit)))
             ri = 0


@@ -34,9 +34,9 @@ mimemoji = {
     "application/pgp-encrypted": "🔏",
 }
-MIMECategory = Enum("MIMECategory",
-    ["Archive", "Text", "AV", "Document", "Fallback"]
-)
+MIMECategory = Enum("MIMECategory", ["Archive", "Text", "AV", "Document",
+                                     "Fallback"])
 class MIMEHandler:
     def __init__(self):
@@ -115,12 +115,14 @@ class MIMEHandler:
         cat = getcat(mime)
         for handler in self.handlers[cat][1]:
             try:
-                if handler(cat): return
+                if handler(cat):
+                    return
             except: pass
         for handler in self.handlers[MIMECategory.Fallback][1]:
             try:
-                if handler(None): return
+                if handler(None):
+                    return
             except: pass
         raise RuntimeError(f"Unhandled MIME type category: {cat}")


@@ -1,5 +1,9 @@
 import time
-import fcntl, struct, termios
+import fcntl
+import struct
+import termios
 from sys import stdout
 from textual import events, log
@@ -7,6 +11,7 @@ from textual.widgets import Static
 from fhost import app as fhost_app
 class MpvWidget(Static):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -14,8 +19,10 @@ class MpvWidget(Static):
         self.mpv = None
         self.vo = fhost_app.config.get("MOD_PREVIEW_PROTO")
-        if not self.vo in ["sixel", "kitty"]:
-            self.update("⚠ Previews not enabled. \n\nSet MOD_PREVIEW_PROTO to 'sixel' or 'kitty' in config.py,\nwhichever is supported by your terminal.")
+        if self.vo not in ["sixel", "kitty"]:
+            self.update("⚠ Previews not enabled. \n\nSet MOD_PREVIEW_PROTO "
+                        "to 'sixel' or 'kitty' in config.py,\nwhichever is "
+                        "supported by your terminal.")
         else:
             try:
                 import mpv
@@ -27,28 +34,35 @@ class MpvWidget(Static):
                 self.mpv[f"vo-sixel-buffered"] = True
                 self.mpv["audio"] = False
                 self.mpv["loop-file"] = "inf"
-                self.mpv["image-display-duration"] = 0.5 if self.vo == "sixel" else "inf"
+                self.mpv["image-display-duration"] = 0.5 \
+                    if self.vo == "sixel" else "inf"
             except Exception as e:
                 self.mpv = None
-                self.update(f"⚠ Previews require python-mpv with libmpv 0.36.0 or later \n\nError was:\n{type(e).__name__}: {e}")
+                self.update("⚠ Previews require python-mpv with libmpv "
+                            "0.36.0 or later \n\nError was:\n"
+                            f"{type(e).__name__}: {e}")
-    def start_mpv(self, f: str|None = None, pos: float|str|None = None) -> None:
+    def start_mpv(self, f: str | None = None,
+                  pos: float | str | None = None) -> None:
         self.display = True
         self.screen._refresh_layout()
         if self.mpv:
             if self.content_region.x:
-                r, c, w, h = struct.unpack('hhhh', fcntl.ioctl(0, termios.TIOCGWINSZ, '12345678'))
+                winsz = fcntl.ioctl(0, termios.TIOCGWINSZ, '12345678')
+                r, c, w, h = struct.unpack('hhhh', winsz)
                 width = int((w / c) * self.content_region.width)
-                height = int((h / r) * (self.content_region.height + (1 if self.vo == "sixel" else 0)))
+                height = int((h / r) * (self.content_region.height +
+                             (1 if self.vo == "sixel" else 0)))
                 self.mpv[f"vo-{self.vo}-left"] = self.content_region.x + 1
                 self.mpv[f"vo-{self.vo}-top"] = self.content_region.y + 1
-                self.mpv[f"vo-{self.vo}-rows"] = self.content_region.height + (1 if self.vo == "sixel" else 0)
+                self.mpv[f"vo-{self.vo}-rows"] = self.content_region.height + \
+                    (1 if self.vo == "sixel" else 0)
                 self.mpv[f"vo-{self.vo}-cols"] = self.content_region.width
                 self.mpv[f"vo-{self.vo}-width"] = width
                 self.mpv[f"vo-{self.vo}-height"] = height
-            if pos != None:
+            if pos is not None:
                 self.mpv["start"] = pos
             if f:


@@ -1,5 +1,6 @@
 from textual.widgets import Static
 class Notification(Static):
     def on_mount(self) -> None:
         self.set_timer(3, self.remove)


@@ -18,32 +18,34 @@
 and limitations under the License.
 """
-import os
 import sys
-from pathlib import Path
 import av
 from transformers import pipeline
 class NSFWDetector:
     def __init__(self):
-        self.classifier = pipeline("image-classification", model="giacomoarienti/nsfw-classifier")
+        self.classifier = pipeline("image-classification",
+                                    model="giacomoarienti/nsfw-classifier")
     def detect(self, fpath):
         try:
             with av.open(fpath) as container:
-                try: container.seek(int(container.duration / 2))
+                try:
+                    container.seek(int(container.duration / 2))
                 except: container.seek(0)
                 frame = next(container.decode(video=0))
                 img = frame.to_image()
                 res = self.classifier(img)
-                return max([x["score"] for x in res if x["label"] not in ["neutral", "drawings"]])
+                return max([x["score"] for x in res
+                            if x["label"] not in ["neutral", "drawings"]])
         except: pass
         return -1.0
 if __name__ == "__main__":
     n = NSFWDetector()
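The rest of the __main__ block is cut off in this view. As a rough usage
sketch, illustrative only and not part of the commit, a small driver living in
the same module as the NSFWDetector class above could look like this:

import sys

detector = NSFWDetector()
for path in sys.argv[1:]:
    score = detector.detect(path)   # -1.0 means the file could not be scored
    print(f"{score:0.4f} {path}")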