ed67fe
From 7fe3dff241c11206616bf6229be898854ce0d066 Mon Sep 17 00:00:00 2001
ed67fe
From: Lumir Balhar <lbalhar@redhat.com>
ed67fe
Date: Mon, 14 Jun 2021 11:33:36 +0200
ed67fe
Subject: [PATCH] CVE-2021-28675
ed67fe
ed67fe
---
ed67fe
 src/PIL/ImageFile.py      | 12 ++++++++++--
ed67fe
 src/PIL/PsdImagePlugin.py | 33 +++++++++++++++++++++++----------
ed67fe
 2 files changed, 33 insertions(+), 12 deletions(-)
ed67fe
ed67fe
diff --git a/src/PIL/ImageFile.py b/src/PIL/ImageFile.py
ed67fe
index 1a3c4aa..2cef9ee 100644
ed67fe
--- a/src/PIL/ImageFile.py
ed67fe
+++ b/src/PIL/ImageFile.py
ed67fe
@@ -522,12 +522,18 @@ def _safe_read(fp, size):
ed67fe
 
ed67fe
     :param fp: File handle.  Must implement a read method.
ed67fe
     :param size: Number of bytes to read.
ed67fe
-    :returns: A string containing up to size bytes of data.
ed67fe
+    :returns: A string containing size bytes of data.
ed67fe
+
ed67fe
+    Raises an OSError if the file is truncated and the read can not be completed
ed67fe
+
ed67fe
     """
ed67fe
     if size <= 0:
ed67fe
         return b""
ed67fe
     if size <= SAFEBLOCK:
ed67fe
-        return fp.read(size)
ed67fe
+        data = fp.read(size)
ed67fe
+        if len(data) < size:
ed67fe
+            raise OSError("Truncated File Read")
ed67fe
+        return data
ed67fe
     data = []
ed67fe
     while size > 0:
ed67fe
         block = fp.read(min(size, SAFEBLOCK))
ed67fe
@@ -535,6 +541,8 @@ def _safe_read(fp, size):
ed67fe
             break
ed67fe
         data.append(block)
ed67fe
         size -= len(block)
ed67fe
+    if sum(len(d) for d in data) < size:
ed67fe
+        raise OSError("Truncated File Read")
ed67fe
     return b"".join(data)
ed67fe
 
ed67fe
 
ed67fe
diff --git a/src/PIL/PsdImagePlugin.py b/src/PIL/PsdImagePlugin.py
ed67fe
index fe2a2ff..add9996 100644
ed67fe
--- a/src/PIL/PsdImagePlugin.py
ed67fe
+++ b/src/PIL/PsdImagePlugin.py
ed67fe
@@ -18,6 +18,8 @@
ed67fe
 
ed67fe
 __version__ = "0.4"
ed67fe
 
ed67fe
+import io
ed67fe
+
ed67fe
 from . import Image, ImageFile, ImagePalette
ed67fe
 from ._binary import i8, i16be as i16, i32be as i32
ed67fe
 
ed67fe
@@ -114,7 +116,8 @@ class PsdImageFile(ImageFile.ImageFile):
ed67fe
             end = self.fp.tell() + size
ed67fe
             size = i32(read(4))
ed67fe
             if size:
ed67fe
-                self.layers = _layerinfo(self.fp)
ed67fe
+                _layer_data = io.BytesIO(ImageFile._safe_read(self.fp, size))
ed67fe
+                self.layers = _layerinfo(_layer_data, size)
ed67fe
             self.fp.seek(end)
ed67fe
 
ed67fe
         #
ed67fe
@@ -164,11 +167,20 @@ class PsdImageFile(ImageFile.ImageFile):
ed67fe
             Image.Image.load(self)
ed67fe
 
ed67fe
 
ed67fe
-def _layerinfo(file):
ed67fe
+def _layerinfo(fp, ct_bytes):
ed67fe
     # read layerinfo block
ed67fe
     layers = []
ed67fe
-    read = file.read
ed67fe
-    for i in range(abs(i16(read(2)))):
ed67fe
+
ed67fe
+    def read(size):
ed67fe
+        return ImageFile._safe_read(fp, size)
ed67fe
+
ed67fe
+    ct = i16(read(2))
ed67fe
+
ed67fe
+    # sanity check
ed67fe
+    if ct_bytes < (abs(ct) * 20):
ed67fe
+        raise SyntaxError("Layer block too short for number of layers requested")
ed67fe
+
ed67fe
+    for i in range(abs(ct)):
ed67fe
 
ed67fe
         # bounding box
ed67fe
         y0 = i32(read(4))
ed67fe
@@ -179,7 +191,8 @@ def _layerinfo(file):
ed67fe
         # image info
ed67fe
         info = []
ed67fe
         mode = []
ed67fe
-        types = list(range(i16(read(2))))
ed67fe
+        ct_types = i16(read(2))
ed67fe
+        types = list(range(ct_types))
ed67fe
         if len(types) > 4:
ed67fe
             continue
ed67fe
 
ed67fe
@@ -212,7 +225,7 @@ def _layerinfo(file):
ed67fe
         size = i32(read(4))  # length of the extra data field
ed67fe
         combined = 0
ed67fe
         if size:
ed67fe
-            data_end = file.tell() + size
ed67fe
+            data_end = fp.tell() + size
ed67fe
 
ed67fe
             length = i32(read(4))
ed67fe
             if length:
ed67fe
@@ -220,12 +233,12 @@ def _layerinfo(file):
ed67fe
                 mask_x = i32(read(4))
ed67fe
                 mask_h = i32(read(4)) - mask_y
ed67fe
                 mask_w = i32(read(4)) - mask_x
ed67fe
-                file.seek(length - 16, 1)
ed67fe
+                fp.seek(length - 16, 1)
ed67fe
             combined += length + 4
ed67fe
 
ed67fe
             length = i32(read(4))
ed67fe
             if length:
ed67fe
-                file.seek(length, 1)
ed67fe
+                fp.seek(length, 1)
ed67fe
             combined += length + 4
ed67fe
 
ed67fe
             length = i8(read(1))
ed67fe
@@ -235,7 +248,7 @@ def _layerinfo(file):
ed67fe
                 name = read(length).decode('latin-1', 'replace')
ed67fe
             combined += length + 1
ed67fe
 
ed67fe
-            file.seek(data_end)
ed67fe
+            fp.seek(data_end)
ed67fe
         layers.append((name, mode, (x0, y0, x1, y1)))
ed67fe
 
ed67fe
     # get tiles
ed67fe
@@ -243,7 +256,7 @@ def _layerinfo(file):
ed67fe
     for name, mode, bbox in layers:
ed67fe
         tile = []
ed67fe
         for m in mode:
ed67fe
-            t = _maketile(file, m, bbox, 1)
ed67fe
+            t = _maketile(fp, m, bbox, 1)
ed67fe
             if t:
ed67fe
                 tile.extend(t)
ed67fe
         layers[i] = name, mode, bbox, tile
ed67fe
-- 
ed67fe
2.31.1
ed67fe