diff --git a/fhost.py b/fhost.py index 5578f6e..2e207d4 100755 --- a/fhost.py +++ b/fhost.py @@ -22,6 +22,7 @@ from flask import Flask, abort, make_response, redirect, render_template, \ Request, request, Response, send_from_directory, url_for from flask_sqlalchemy import SQLAlchemy from flask_migrate import Migrate +from werkzeug.datastructures import FileStorage from sqlalchemy import and_, or_ from sqlalchemy.orm import declared_attr import sqlalchemy.types as types @@ -37,9 +38,11 @@ import sys import time import datetime import ipaddress +import io import typing import requests import secrets +import shutil import re from validators import url as url_valid from pathlib import Path @@ -155,6 +158,65 @@ class IPAddress(types.TypeDecorator): return value +class TransferFile(): + def __init__(self, stream, name, content_type): + self.stream = stream + self.name = name + self.sha256 = file_digest(stream, "sha256").hexdigest() + + stream.seek(0, os.SEEK_END) + self.size = stream.tell() + stream.seek(0) + + self.mime, self.mime_detected = self.get_mime(content_type) + self.ext = self.get_ext() + + def get_mime(self, content_type): + try: + guess = mimedetect.from_descriptor(self.stream.fileno()) + except io.UnsupportedOperation: + guess = mimedetect.from_buffer(self.stream.getvalue()) + + app.logger.debug(f"MIME - specified: '{content_type}' - " + f"detected: '{guess}'") + + if (not content_type + or "/" not in content_type + or content_type == "application/octet-stream"): + mime = guess + else: + mime = content_type + + if mime.startswith("text/") and "charset" not in mime: + mime += "; charset=utf-8" + + return mime, guess + + def get_ext(self): + ext = "".join(Path(self.name).suffixes[-2:]) + if len(ext) > app.config["FHOST_MAX_EXT_LENGTH"]: + ext = Path(self.name).suffixes[-1] + gmime = self.mime.split(";")[0] + guess = guess_extension(gmime) + + app.logger.debug(f"extension - specified: '{ext}' - detected: " + f"'{guess}'") + + if not ext: + if gmime in app.config["FHOST_EXT_OVERRIDE"]: + ext = app.config["FHOST_EXT_OVERRIDE"][gmime] + elif guess: + ext = guess + else: + ext = "" + + return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin" + + def save(self, path: os.PathLike): + with open(path, "wb") as of: + shutil.copyfileobj(self.stream, of) + + class File(db.Model): id = db.Column(db.Integer, primary_key=True) sha256 = db.Column(db.String, unique=True) @@ -170,10 +232,10 @@ class File(db.Model): last_vscan = db.Column(db.DateTime) size = db.Column(db.BigInteger) - def __init__(self, sha256, ext, mime, addr, ua, expiration, mgmt_token): - self.sha256 = sha256 - self.ext = ext - self.mime = mime + def __init__(self, file_: TransferFile, addr, ua, expiration, mgmt_token): + self.sha256 = file_.sha256 + self.ext = file_.ext + self.mime = file_.mime self.addr = addr self.ua = ua self.expiration = expiration @@ -246,62 +308,20 @@ class File(db.Model): down to that value. """ @staticmethod - def store(file_, requested_expiration: typing.Optional[int], addr, ua, - secret: bool): - fstream = file_.stream - digest = file_digest(fstream, "sha256").hexdigest() - fstream.seek(0, os.SEEK_END) - flen = fstream.tell() - fstream.seek(0) + def store(file_: TransferFile, requested_expiration: typing.Optional[int], + addr, ua, secret: bool): - def get_mime(): - guess = mimedetect.from_descriptor(fstream.fileno()) - app.logger.debug(f"MIME - specified: '{file_.content_type}' - " - f"detected: '{guess}'") + if len(file_.mime) > 128: + abort(400) - if (not file_.content_type - or "/" not in file_.content_type - or file_.content_type == "application/octet-stream"): - mime = guess - else: - mime = file_.content_type + for flt in MIMEFilter.query.all(): + if flt.check(file_.mime_detected): + abort(403, flt.reason) - if len(mime) > 128: - abort(400) - - for flt in MIMEFilter.query.all(): - if flt.check(guess): - abort(403, flt.reason) - - if mime.startswith("text/") and "charset" not in mime: - mime += "; charset=utf-8" - - return mime - - def get_ext(mime): - ext = "".join(Path(file_.filename).suffixes[-2:]) - if len(ext) > app.config["FHOST_MAX_EXT_LENGTH"]: - ext = Path(file_.filename).suffixes[-1] - gmime = mime.split(";")[0] - guess = guess_extension(gmime) - - app.logger.debug(f"extension - specified: '{ext}' - detected: " - f"'{guess}'") - - if not ext: - if gmime in app.config["FHOST_EXT_OVERRIDE"]: - ext = app.config["FHOST_EXT_OVERRIDE"][gmime] - elif guess: - ext = guess - else: - ext = "" - - return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin" - - expiration = File.get_expiration(requested_expiration, flen) + expiration = File.get_expiration(requested_expiration, file_.size) isnew = True - f = File.query.filter_by(sha256=digest).first() + f = File.query.filter_by(sha256=file_.sha256).first() if f: # If the file already exists if f.removed: @@ -318,10 +338,8 @@ class File(db.Model): f.expiration = max(f.expiration, expiration) isnew = False else: - mime = get_mime() - ext = get_ext(mime) mgmt_token = secrets.token_urlsafe() - f = File(digest, ext, mime, addr, ua, expiration, mgmt_token) + f = File(file_, addr, ua, expiration, mgmt_token) f.addr = addr f.ua = ua @@ -334,12 +352,12 @@ class File(db.Model): storage = Path(app.config["FHOST_STORAGE_PATH"]) storage.mkdir(parents=True, exist_ok=True) - p = storage / digest + p = storage / file_.sha256 if not p.is_file(): file_.save(p) - f.size = flen + f.size = file_.size if not f.nsfw_score and app.config["NSFW_DETECT"]: f.nsfw_score = nsfw.detect(str(p)) @@ -527,8 +545,9 @@ requested_expiration can be: Any value greater that the longest allowed file lifespan will be rounded down to that value. """ -def store_file(f, requested_expiration: typing.Optional[int], addr, ua, - secret: bool): +def store_file(f: TransferFile, requested_expiration: typing.Optional[int], + addr, ua, secret: bool): + sf, isnew = File.store(f, requested_expiration, addr, ua, secret) response = make_response(sf.geturl()) @@ -556,13 +575,10 @@ def store_url(url, addr, ua, secret: bool): length = int(r.headers["content-length"]) if length <= app.config["MAX_CONTENT_LENGTH"]: - def urlfile(**kwargs): - return type('', (), kwargs)() + tf = TransferFile(io.BytesIO(r.raw.read()), + r.headers["content-type"], "") - f = urlfile(read=r.raw.read, - content_type=r.headers["content-type"], filename="") - - return store_file(f, None, addr, ua, secret) + return store_file(tf, None, addr, ua, secret) else: abort(413) else: @@ -661,10 +677,13 @@ def fhost(): addr = addr.ipv4_mapped or addr if "file" in request.files: + f = request.files["file"] + tf = TransferFile(f.stream, f.filename, f.content_type) + try: # Store the file with the requested expiration date return store_file( - request.files["file"], + tf, int(request.form["expires"]), addr, request.user_agent.string, @@ -676,7 +695,7 @@ def fhost(): except KeyError: # No expiration date was requested, store with the max lifespan return store_file( - request.files["file"], + tf, None, addr, request.user_agent.string,