Browse Source

[3.0.0] Update of this fucking SInk

Penta, 3 tuần trước
mục cha
commit
702eb5a6b2
1 tập tin đã thay đổi với 48 bổ sung và 35 xóa
  1. 48 35
      chatbot.py

+ 48 - 35
chatbot.py

@@ -94,27 +94,9 @@ class STTSink(Sink):
     def __init__(self):
         super().__init__()
         self.user_ws = {}
-        self.buffers = {}          # audio accumulé par user
-        self.last_send = {}        # timestamp dernier envoi
+        self.buffers = {}
         self.last_voice = {}
 
-        asyncio.run_coroutine_threadsafe(self._send_silence_loop(), MAIN_LOOP)
-
-    async def _send_silence_loop(self):
-        while True:
-            await asyncio.sleep(0.2)
-            now = time.time()
-
-            for user_id, last in list(self.last_voice.items()):
-                if now - last > 0.7:
-                    # 200 ms de silence @ 16kHz int16
-                    silence = np.zeros(3200, dtype=np.int16).tobytes()
-
-                    await self._send_audio(user_id, silence)
-
-                    # reset pour éviter le spam
-                    self.last_voice[user_id] = now
-
     async def _get_ws(self, user_id):
         if user_id not in self.user_ws:
             ws = await websockets.connect(WHISPER_WS_URL)
@@ -127,12 +109,16 @@ class STTSink(Sink):
             async for msg in ws:
                 data = json.loads(msg)
                 if data.get("type") == "final":
-                    text = data["text"]
+                    text = data["text"].strip()
+
+                    if self._ignore_text(text):
+                        return
 
                     if reply_mode == ReplyMode.TEXT and reply_text_channel:
                         await reply_text_channel.send(f"🗣️ {text}")
                     else:
                         logger.info(f"[STT][{user_id}] {text}")
+
         except Exception as e:
             logger.warning(f"[STT][{user_id}] WS fermé : {e}")
 
@@ -141,32 +127,59 @@ class STTSink(Sink):
             return
 
         audio = discord_pcm_to_whisper_int16(pcm_bytes)
-
         if not audio:
             return
 
-        self.last_voice[user_id] = time.time()
+        now = time.time()
+        self.last_voice[user_id] = now
 
         if user_id not in self.buffers:
-            self.buffers[user_id] = []
-            self.last_send[user_id] = time.time()
+            self.buffers[user_id] = bytearray()
 
-        self.buffers[user_id].append(audio)
-        now = time.time()
+        self.buffers[user_id].extend(audio)
+
+        # Durée buffer actuelle
+        buffer_sec = len(self.buffers[user_id]) / (16000 * 2)
+
+        # Silence détecté → on flush
+        if buffer_sec >= 2.0:
+            asyncio.run_coroutine_threadsafe(
+                self._flush_if_silence(user_id),
+                MAIN_LOOP
+            )
 
-        # Envoi toutes les ~600 ms
-        if now - self.last_send[user_id] >= 0.6:
-            chunk = b"".join(self.buffers[user_id])
-            self.buffers[user_id].clear()
-            self.last_send[user_id] = now
+    async def _flush_if_silence(self, user_id):
+        await asyncio.sleep(1.2)
 
-            asyncio.run_coroutine_threadsafe(self._send_audio(user_id, chunk), MAIN_LOOP)
+        last = self.last_voice.get(user_id, 0)
+        if time.time() - last < 1.2:
+            return  # toujours en train de parler
 
-            logger.debug(f"[STT] audio envoyé user={user_id} bytes={len(chunk)}")
+        chunk = bytes(self.buffers[user_id])
+        self.buffers[user_id].clear()
+
+        if len(chunk) < 16000 * 2 * 2:
+            return  # trop court
 
-    async def _send_audio(self, user_id, pcm_bytes):
         ws = await self._get_ws(user_id)
-        await ws.send(pcm_bytes)
+        await ws.send(chunk)
+
+        logger.debug(f"[STT] chunk envoyé user={user_id} bytes={len(chunk)}")
+
+    def _ignore_text(self, text: str) -> bool:
+        BAD = [
+            "amara",
+            "sous-titres",
+            "merci",
+            "musique",
+            "applaudissements"
+        ]
+
+        t = text.lower()
+        return (
+            len(t) < 3
+            or any(b in t for b in BAD)
+        )
 
 # Liste pour stocker l'historique des conversations
 conversation_history = []