From a2b322f8687f72462fa9ac892ef8459fb374b074 Mon Sep 17 00:00:00 2001 From: Mia Herkt Date: Fri, 27 Sep 2024 20:45:42 +0200 Subject: [PATCH] Avoid holding in-memory copies of file content MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Werkzeug uses tempfile.SpooledTemporaryFile, so we can make use of file-like object properties. This may result in more disk writes, but that’s probably better than eating up RAM. I hope this fixes #84. --- fhost.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/fhost.py b/fhost.py index e363f29..5578f6e 100755 --- a/fhost.py +++ b/fhost.py @@ -27,7 +27,7 @@ from sqlalchemy.orm import declared_attr import sqlalchemy.types as types from jinja2.exceptions import * from jinja2 import ChoiceLoader, FileSystemLoader -from hashlib import sha256 +from hashlib import file_digest from magic import Magic from mimetypes import guess_extension import click @@ -248,11 +248,14 @@ class File(db.Model): @staticmethod def store(file_, requested_expiration: typing.Optional[int], addr, ua, secret: bool): - data = file_.read() - digest = sha256(data).hexdigest() + fstream = file_.stream + digest = file_digest(fstream, "sha256").hexdigest() + fstream.seek(0, os.SEEK_END) + flen = fstream.tell() + fstream.seek(0) def get_mime(): - guess = mimedetect.from_buffer(data) + guess = mimedetect.from_descriptor(fstream.fileno()) app.logger.debug(f"MIME - specified: '{file_.content_type}' - " f"detected: '{guess}'") @@ -295,7 +298,7 @@ class File(db.Model): return ext[:app.config["FHOST_MAX_EXT_LENGTH"]] or ".bin" - expiration = File.get_expiration(requested_expiration, len(data)) + expiration = File.get_expiration(requested_expiration, flen) isnew = True f = File.query.filter_by(sha256=digest).first() @@ -334,10 +337,9 @@ class File(db.Model): p = storage / digest if not p.is_file(): - with open(p, "wb") as of: - of.write(data) + file_.save(p) - f.size = len(data) + f.size = flen if not f.nsfw_score and app.config["NSFW_DETECT"]: f.nsfw_score = nsfw.detect(str(p))