Johnounet 1 an în urmă
părinte
comite
88740579d4
1 a modificat fișierele cu 42 adăugiri și 17 ștergeri
  1. 42 17
      chatbot.py

+ 42 - 17
chatbot.py

@@ -79,21 +79,46 @@ openai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)
 # Charger l'encodeur pour le modèle GPT-4o
 encoding = tiktoken.get_encoding("o200k_base")
 
-def resize_image(image_bytes, mode='high'):
-    with Image.open(BytesIO(image_bytes)) as img:
-        if mode == 'high':
-            # Redimensionner pour le mode haute fidélité
-            img.thumbnail((2000, 2000))
-            if min(img.size) < 768:
-                scale = 768 / min(img.size)
-                new_size = tuple(int(x * scale) for x in img.size)
-                img = img.resize(new_size, Image.Resampling.LANCZOS)
-        elif mode == 'low':
-            # Redimensionner pour le mode basse fidélité
-            img = img.resize((512, 512))
-        buffer = BytesIO()
-        img.save(buffer, format=img.format)
-        return buffer.getvalue()
+def resize_image(image_bytes, mode='high', attachment_filename=None):
+    try:
+        with Image.open(BytesIO(image_bytes)) as img:
+            original_format = img.format  # Store the original format
+
+            if mode == 'high':
+                # Redimensionner pour le mode haute fidélité
+                img.thumbnail((2000, 2000))
+                if min(img.size) < 768:
+                    scale = 768 / min(img.size)
+                    new_size = tuple(int(x * scale) for x in img.size)
+                    img = img.resize(new_size, Image.Resampling.LANCZOS)
+            elif mode == 'low':
+                # Redimensionner pour le mode basse fidélité
+                img = img.resize((512, 512))
+
+            buffer = BytesIO()
+
+            img_format = img.format
+            if not img_format:
+                if attachment_filename:
+                    _, ext = os.path.splitext(attachment_filename)
+                    ext = ext.lower()
+                    format_mapping = {
+                        '.jpg': 'JPEG',
+                        '.jpeg': 'JPEG',
+                        '.png': 'PNG',
+                        '.gif': 'GIF',
+                        '.bmp': 'BMP',
+                        '.tiff': 'TIFF'
+                    }
+                    img_format = format_mapping.get(ext, 'PNG')
+                else:
+                    img_format = 'PNG'
+
+            img.save(buffer, format=img_format)
+            return buffer.getvalue()
+    except Exception as e:
+        logger.error(f"Error resizing image: {e}")
+        raise
 
 def contains_ascii_art(text):
     """
@@ -228,7 +253,7 @@ async def read_text_file(attachment):
 
 async def encode_image_from_attachment(attachment, mode='high'):
     image_data = await attachment.read()
-    resized_image = resize_image(image_data, mode=mode)
+    resized_image = resize_image(image_data, mode=mode, attachment_filename=attachment.filename)
     return base64.b64encode(resized_image).decode('utf-8')
 
 async def summarize_text(text, max_tokens=50):
@@ -359,7 +384,7 @@ async def on_message(message):
                 file_content = await read_text_file(attachment)
                 break
             # Vérifier si c'est une image
-            elif attachment.content_type.startswith('image'):
+            elif attachment.content_type in ['image/jpeg', 'image/png', 'image/gif', 'image/bmp', 'image/tiff']:
                 image_data = await encode_image_from_attachment(attachment, mode='high')
                 break