Move uploaded file handling to class

So uploading files via remote URLs was completely broken and
apparently nobody noticed. This commit fixes that, too.
Wouldn’t it be nice if there were a test suite!
This commit is contained in:
Mia Herkt 2025-03-01 10:08:27 +01:00
parent a2b322f868
commit 0cd289d981
No known key found for this signature in database

153
fhost.py
View file

@ -22,6 +22,7 @@ from flask import Flask, abort, make_response, redirect, render_template, \
Request, request, Response, send_from_directory, url_for
from flask_sqlalchemy import SQLAlchemy
from flask_migrate import Migrate
from werkzeug.datastructures import FileStorage
from sqlalchemy import and_, or_
from sqlalchemy.orm import declared_attr
import sqlalchemy.types as types
@ -37,9 +38,11 @@ import sys
import time
import datetime
import ipaddress
import io
import typing
import requests
import secrets
import shutil
import re
from validators import url as url_valid
from pathlib import Path
@ -155,6 +158,65 @@ class IPAddress(types.TypeDecorator):
return value
class TransferFile():
def __init__(self, stream, name, content_type):
self.stream = stream
self.name = name
self.sha256 = file_digest(stream, "sha256").hexdigest()
stream.seek(0, os.SEEK_END)
self.size = stream.tell()
stream.seek(0)
self.mime, self.mime_detected = self.get_mime(content_type)
self.ext = self.get_ext()
def get_mime(self, content_type):
try:
guess = mimedetect.from_descriptor(self.stream.fileno())
except io.UnsupportedOperation:
guess = mimedetect.from_buffer(self.stream.getvalue())
app.logger.debug(f"MIME - specified: '{content_type}' - "
f"detected: '{guess}'")
if (not content_type
or "/" not in content_type
or content_type == "application/octet-stream"):
mime = guess
else:
mime = content_type
if mime.startswith("text/") and "charset" not in mime:
mime += "; charset=utf-8"
return mime, guess
def get_ext(self):
ext = "".join(Path(self.name).suffixes[-2:])
if len(ext) > app.config["FHOST_MAX_EXT_LENGTH"]:
ext = Path(self.name).suffixes[-1]
gmime = self.mime.split(";")[0]
guess = guess_extension(gmime)
app.logger.debug(f"extension - specified: '{ext}' - detected: "
f"'{guess}'")
if not ext:
if gmime in app.config["FHOST_EXT_OVERRIDE"]:
ext = app.config["FHOST_EXT_OVERRIDE"][gmime]
elif guess:
ext = guess
else:
ext = ""
return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin"
def save(self, path: os.PathLike):
with open(path, "wb") as of:
shutil.copyfileobj(self.stream, of)
class File(db.Model):
id = db.Column(db.Integer, primary_key=True)
sha256 = db.Column(db.String, unique=True)
@ -170,10 +232,10 @@ class File(db.Model):
last_vscan = db.Column(db.DateTime)
size = db.Column(db.BigInteger)
def __init__(self, sha256, ext, mime, addr, ua, expiration, mgmt_token):
self.sha256 = sha256
self.ext = ext
self.mime = mime
def __init__(self, file_: TransferFile, addr, ua, expiration, mgmt_token):
self.sha256 = file_.sha256
self.ext = file_.ext
self.mime = file_.mime
self.addr = addr
self.ua = ua
self.expiration = expiration
@ -246,62 +308,20 @@ class File(db.Model):
down to that value.
"""
@staticmethod
def store(file_, requested_expiration: typing.Optional[int], addr, ua,
secret: bool):
fstream = file_.stream
digest = file_digest(fstream, "sha256").hexdigest()
fstream.seek(0, os.SEEK_END)
flen = fstream.tell()
fstream.seek(0)
def store(file_: TransferFile, requested_expiration: typing.Optional[int],
addr, ua, secret: bool):
def get_mime():
guess = mimedetect.from_descriptor(fstream.fileno())
app.logger.debug(f"MIME - specified: '{file_.content_type}' - "
f"detected: '{guess}'")
if (not file_.content_type
or "/" not in file_.content_type
or file_.content_type == "application/octet-stream"):
mime = guess
else:
mime = file_.content_type
if len(mime) > 128:
if len(file_.mime) > 128:
abort(400)
for flt in MIMEFilter.query.all():
if flt.check(guess):
if flt.check(file_.mime_detected):
abort(403, flt.reason)
if mime.startswith("text/") and "charset" not in mime:
mime += "; charset=utf-8"
return mime
def get_ext(mime):
ext = "".join(Path(file_.filename).suffixes[-2:])
if len(ext) > app.config["FHOST_MAX_EXT_LENGTH"]:
ext = Path(file_.filename).suffixes[-1]
gmime = mime.split(";")[0]
guess = guess_extension(gmime)
app.logger.debug(f"extension - specified: '{ext}' - detected: "
f"'{guess}'")
if not ext:
if gmime in app.config["FHOST_EXT_OVERRIDE"]:
ext = app.config["FHOST_EXT_OVERRIDE"][gmime]
elif guess:
ext = guess
else:
ext = ""
return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin"
expiration = File.get_expiration(requested_expiration, flen)
expiration = File.get_expiration(requested_expiration, file_.size)
isnew = True
f = File.query.filter_by(sha256=digest).first()
f = File.query.filter_by(sha256=file_.sha256).first()
if f:
# If the file already exists
if f.removed:
@ -318,10 +338,8 @@ class File(db.Model):
f.expiration = max(f.expiration, expiration)
isnew = False
else:
mime = get_mime()
ext = get_ext(mime)
mgmt_token = secrets.token_urlsafe()
f = File(digest, ext, mime, addr, ua, expiration, mgmt_token)
f = File(file_, addr, ua, expiration, mgmt_token)
f.addr = addr
f.ua = ua
@ -334,12 +352,12 @@ class File(db.Model):
storage = Path(app.config["FHOST_STORAGE_PATH"])
storage.mkdir(parents=True, exist_ok=True)
p = storage / digest
p = storage / file_.sha256
if not p.is_file():
file_.save(p)
f.size = flen
f.size = file_.size
if not f.nsfw_score and app.config["NSFW_DETECT"]:
f.nsfw_score = nsfw.detect(str(p))
@ -527,8 +545,9 @@ requested_expiration can be:
Any value greater than the longest allowed file lifespan will be rounded down
to that value.
"""
def store_file(f, requested_expiration: typing.Optional[int], addr, ua,
secret: bool):
def store_file(f: TransferFile, requested_expiration: typing.Optional[int],
addr, ua, secret: bool):
sf, isnew = File.store(f, requested_expiration, addr, ua, secret)
response = make_response(sf.geturl())
@ -556,13 +575,10 @@ def store_url(url, addr, ua, secret: bool):
length = int(r.headers["content-length"])
if length <= app.config["MAX_CONTENT_LENGTH"]:
def urlfile(**kwargs):
return type('', (), kwargs)()
# TransferFile takes (stream, name, content_type). A remote fetch has no
# file name, and the Content-Type header may be absent.
tf = TransferFile(io.BytesIO(r.raw.read()),
                  name="",
                  content_type=r.headers.get("content-type", ""))
f = urlfile(read=r.raw.read,
content_type=r.headers["content-type"], filename="")
return store_file(f, None, addr, ua, secret)
return store_file(tf, None, addr, ua, secret)
else:
abort(413)
else:
@ -661,10 +677,13 @@ def fhost():
addr = addr.ipv4_mapped or addr
if "file" in request.files:
f = request.files["file"]
tf = TransferFile(f.stream, f.filename, f.content_type)
try:
# Store the file with the requested expiration date
return store_file(
request.files["file"],
tf,
int(request.form["expires"]),
addr,
request.user_agent.string,
@ -676,7 +695,7 @@ def fhost():
except KeyError:
# No expiration date was requested, store with the max lifespan
return store_file(
request.files["file"],
tf,
None,
addr,
request.user_agent.string,