Move uploaded file handling to class

So uploading files via remote URLs was completely broken and
apparently nobody noticed. This commit fixes that, too.
Wouldn’t it be nice if there were a test suite!
This commit is contained in:
Mia Herkt 2025-03-01 10:08:27 +01:00
parent a2b322f868
commit 0cd289d981
No known key found for this signature in database

159
fhost.py
View file

@ -22,6 +22,7 @@ from flask import Flask, abort, make_response, redirect, render_template, \
Request, request, Response, send_from_directory, url_for Request, request, Response, send_from_directory, url_for
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate from flask_migrate import Migrate
from werkzeug.datastructures import FileStorage
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from sqlalchemy.orm import declared_attr from sqlalchemy.orm import declared_attr
import sqlalchemy.types as types import sqlalchemy.types as types
@ -37,9 +38,11 @@ import sys
import time import time
import datetime import datetime
import ipaddress import ipaddress
import io
import typing import typing
import requests import requests
import secrets import secrets
import shutil
import re import re
from validators import url as url_valid from validators import url as url_valid
from pathlib import Path from pathlib import Path
@ -155,6 +158,65 @@ class IPAddress(types.TypeDecorator):
return value return value
class TransferFile():
def __init__(self, stream, name, content_type):
self.stream = stream
self.name = name
self.sha256 = file_digest(stream, "sha256").hexdigest()
stream.seek(0, os.SEEK_END)
self.size = stream.tell()
stream.seek(0)
self.mime, self.mime_detected = self.get_mime(content_type)
self.ext = self.get_ext()
def get_mime(self, content_type):
try:
guess = mimedetect.from_descriptor(self.stream.fileno())
except io.UnsupportedOperation:
guess = mimedetect.from_buffer(self.stream.getvalue())
app.logger.debug(f"MIME - specified: '{content_type}' - "
f"detected: '{guess}'")
if (not content_type
or "/" not in content_type
or content_type == "application/octet-stream"):
mime = guess
else:
mime = content_type
if mime.startswith("text/") and "charset" not in mime:
mime += "; charset=utf-8"
return mime, guess
def get_ext(self):
ext = "".join(Path(self.name).suffixes[-2:])
if len(ext) > app.config["FHOST_MAX_EXT_LENGTH"]:
ext = Path(self.name).suffixes[-1]
gmime = self.mime.split(";")[0]
guess = guess_extension(gmime)
app.logger.debug(f"extension - specified: '{ext}' - detected: "
f"'{guess}'")
if not ext:
if gmime in app.config["FHOST_EXT_OVERRIDE"]:
ext = app.config["FHOST_EXT_OVERRIDE"][gmime]
elif guess:
ext = guess
else:
ext = ""
return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin"
def save(self, path: os.PathLike):
with open(path, "wb") as of:
shutil.copyfileobj(self.stream, of)
class File(db.Model): class File(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
sha256 = db.Column(db.String, unique=True) sha256 = db.Column(db.String, unique=True)
@ -170,10 +232,10 @@ class File(db.Model):
last_vscan = db.Column(db.DateTime) last_vscan = db.Column(db.DateTime)
size = db.Column(db.BigInteger) size = db.Column(db.BigInteger)
def __init__(self, sha256, ext, mime, addr, ua, expiration, mgmt_token): def __init__(self, file_: TransferFile, addr, ua, expiration, mgmt_token):
self.sha256 = sha256 self.sha256 = file_.sha256
self.ext = ext self.ext = file_.ext
self.mime = mime self.mime = file_.mime
self.addr = addr self.addr = addr
self.ua = ua self.ua = ua
self.expiration = expiration self.expiration = expiration
@ -246,62 +308,20 @@ class File(db.Model):
down to that value. down to that value.
""" """
@staticmethod @staticmethod
def store(file_, requested_expiration: typing.Optional[int], addr, ua, def store(file_: TransferFile, requested_expiration: typing.Optional[int],
secret: bool): addr, ua, secret: bool):
fstream = file_.stream
digest = file_digest(fstream, "sha256").hexdigest()
fstream.seek(0, os.SEEK_END)
flen = fstream.tell()
fstream.seek(0)
def get_mime(): if len(file_.mime) > 128:
guess = mimedetect.from_descriptor(fstream.fileno()) abort(400)
app.logger.debug(f"MIME - specified: '{file_.content_type}' - "
f"detected: '{guess}'")
if (not file_.content_type for flt in MIMEFilter.query.all():
or "/" not in file_.content_type if flt.check(file_.mime_detected):
or file_.content_type == "application/octet-stream"): abort(403, flt.reason)
mime = guess
else:
mime = file_.content_type
if len(mime) > 128: expiration = File.get_expiration(requested_expiration, file_.size)
abort(400)
for flt in MIMEFilter.query.all():
if flt.check(guess):
abort(403, flt.reason)
if mime.startswith("text/") and "charset" not in mime:
mime += "; charset=utf-8"
return mime
def get_ext(mime):
ext = "".join(Path(file_.filename).suffixes[-2:])
if len(ext) > app.config["FHOST_MAX_EXT_LENGTH"]:
ext = Path(file_.filename).suffixes[-1]
gmime = mime.split(";")[0]
guess = guess_extension(gmime)
app.logger.debug(f"extension - specified: '{ext}' - detected: "
f"'{guess}'")
if not ext:
if gmime in app.config["FHOST_EXT_OVERRIDE"]:
ext = app.config["FHOST_EXT_OVERRIDE"][gmime]
elif guess:
ext = guess
else:
ext = ""
return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin"
expiration = File.get_expiration(requested_expiration, flen)
isnew = True isnew = True
f = File.query.filter_by(sha256=digest).first() f = File.query.filter_by(sha256=file_.sha256).first()
if f: if f:
# If the file already exists # If the file already exists
if f.removed: if f.removed:
@ -318,10 +338,8 @@ class File(db.Model):
f.expiration = max(f.expiration, expiration) f.expiration = max(f.expiration, expiration)
isnew = False isnew = False
else: else:
mime = get_mime()
ext = get_ext(mime)
mgmt_token = secrets.token_urlsafe() mgmt_token = secrets.token_urlsafe()
f = File(digest, ext, mime, addr, ua, expiration, mgmt_token) f = File(file_, addr, ua, expiration, mgmt_token)
f.addr = addr f.addr = addr
f.ua = ua f.ua = ua
@ -334,12 +352,12 @@ class File(db.Model):
storage = Path(app.config["FHOST_STORAGE_PATH"]) storage = Path(app.config["FHOST_STORAGE_PATH"])
storage.mkdir(parents=True, exist_ok=True) storage.mkdir(parents=True, exist_ok=True)
p = storage / digest p = storage / file_.sha256
if not p.is_file(): if not p.is_file():
file_.save(p) file_.save(p)
f.size = flen f.size = file_.size
if not f.nsfw_score and app.config["NSFW_DETECT"]: if not f.nsfw_score and app.config["NSFW_DETECT"]:
f.nsfw_score = nsfw.detect(str(p)) f.nsfw_score = nsfw.detect(str(p))
@ -527,8 +545,9 @@ requested_expiration can be:
Any value greater than the longest allowed file lifespan will be rounded down Any value greater than the longest allowed file lifespan will be rounded down
to that value. to that value.
""" """
def store_file(f, requested_expiration: typing.Optional[int], addr, ua, def store_file(f: TransferFile, requested_expiration: typing.Optional[int],
secret: bool): addr, ua, secret: bool):
sf, isnew = File.store(f, requested_expiration, addr, ua, secret) sf, isnew = File.store(f, requested_expiration, addr, ua, secret)
response = make_response(sf.geturl()) response = make_response(sf.geturl())
@ -556,13 +575,10 @@ def store_url(url, addr, ua, secret: bool):
length = int(r.headers["content-length"]) length = int(r.headers["content-length"])
if length <= app.config["MAX_CONTENT_LENGTH"]: if length <= app.config["MAX_CONTENT_LENGTH"]:
def urlfile(**kwargs): tf = TransferFile(io.BytesIO(r.raw.read()),
return type('', (), kwargs)() r.headers["content-type"], "")
f = urlfile(read=r.raw.read, return store_file(tf, None, addr, ua, secret)
content_type=r.headers["content-type"], filename="")
return store_file(f, None, addr, ua, secret)
else: else:
abort(413) abort(413)
else: else:
@ -661,10 +677,13 @@ def fhost():
addr = addr.ipv4_mapped or addr addr = addr.ipv4_mapped or addr
if "file" in request.files: if "file" in request.files:
f = request.files["file"]
tf = TransferFile(f.stream, f.filename, f.content_type)
try: try:
# Store the file with the requested expiration date # Store the file with the requested expiration date
return store_file( return store_file(
request.files["file"], tf,
int(request.form["expires"]), int(request.form["expires"]),
addr, addr,
request.user_agent.string, request.user_agent.string,
@ -676,7 +695,7 @@ def fhost():
except KeyError: except KeyError:
# No expiration date was requested, store with the max lifespan # No expiration date was requested, store with the max lifespan
return store_file( return store_file(
request.files["file"], tf,
None, None,
addr, addr,
request.user_agent.string, request.user_agent.string,