Bläddra i källkod

Work in progress on PostgreSQL database instance

Penta 5 år sedan
förälder
incheckning
c1882bc7c1
6 ändrade filer med 90 tillägg och 47 borttagningar
  1. 5 12
      myanimebot/anilist.py
  2. 13 13
      myanimebot/commands.py
  3. 35 0
      myanimebot/database.py
  4. 1 1
      myanimebot/discord.py
  5. 27 13
      myanimebot/globals.py
  6. 9 8
      myanimebot/utils.py

+ 5 - 12
myanimebot/anilist.py

@@ -9,6 +9,7 @@ import requests
 
 import myanimebot.globals as globals
 import myanimebot.utils as utils
+import myanimebot.database as database
 from myanimebot.discord import send_embed_wrapper, build_embed
 
 
@@ -239,7 +240,7 @@ def get_users_db():
     ''' Returns the registered users using AniList '''
 
 	# TODO Make generic execute
-    cursor = globals.conn.cursor(buffered=True, dictionary=True)
+    cursor = database.create_cursor()
     cursor.execute("SELECT id, {}, servers FROM t_users WHERE service = %s".format(globals.DB_USER_NAME), [globals.SERVICE_ANILIST])
     users_data = cursor.fetchall()
     cursor.close()
@@ -289,15 +290,7 @@ async def send_embed_to_channels(activity : utils.Feed):
 def insert_feed_db(feed: utils.Feed):
     ''' Insert an AniList feed into database '''
 
-    cursor = globals.conn.cursor(buffered=True)
-    cursor.execute("INSERT INTO t_feeds (published, title, url, user, found, type, service) VALUES (FROM_UNIXTIME(%s), %s, %s, %s, NOW(), %s, %s)",
-                    (feed.date_publication.timestamp(),
-                     feed.media.name,
-                     feed.media.url,
-                     feed.user.name,
-                     feed.get_status_str(),
-                     globals.SERVICE_ANILIST))
-    globals.conn.commit()
+    database.insert_feed_db(feed, globals.SERVICE_ANILIST)
 
 
 async def process_new_activities(last_activity_date, users : List[utils.User]):
@@ -343,7 +336,7 @@ def get_last_activity_date_db() -> float:
     globals.conn.commit()
 
     # Get last activity date
-    cursor = globals.conn.cursor(buffered=True, dictionary=True)
+    cursor = database.create_cursor()
     cursor.execute("SELECT published FROM t_feeds WHERE service=%s ORDER BY published DESC LIMIT 1", [globals.SERVICE_ANILIST])
     data = cursor.fetchone()
 
@@ -382,7 +375,7 @@ async def background_check_feed(asyncioloop):
         try:
             await check_new_activities()
         except Exception as e:
-            globals.logger.error('Error while fetching Anilist feeds : ({})'.format(e))
+            globals.logger.exception('Error while fetching Anilist feeds : ({})'.format(e))
 
         await asyncio.sleep(globals.ANILIST_SECONDS_BETWEEN_FETCHES)
 

+ 13 - 13
myanimebot/commands.py

@@ -7,7 +7,7 @@ from typing import List, Tuple
 import myanimebot.utils as utils
 import myanimebot.globals as globals
 import myanimebot.anilist as anilist
-
+import myanimebot.database as database
 
 def build_info_cmd_message(users, server, channels, role, filters : List[utils.Service]) -> str:
     ''' Build the corresponding message for the info command '''
@@ -150,7 +150,7 @@ async def add_user_cmd(words, message):
                 return await message.channel.send("User **{}** is already registered in our database for this server!".format(user))
             else:
                 new_servers = '{},{}'.format(user_servers, server_id)
-                utils.update_user_servers_db(user, service, new_servers)					
+                utils.update_user_servers_db(user, service, new_servers)
                 return await message.channel.send("**{}** added to the database for the server **{}**.".format(user, str(message.guild)))
     except Exception as e:
         globals.logger.warning("Error while adding user '{}' on server '{}': {}".format(user, message.guild, str(e)))
@@ -275,8 +275,8 @@ async def here_cmd(author, server, channel):
         if (str(channel.id) in channels_id):
             await channel.send("Channel **{}** already in use for this server.".format(channel))
         else:
-            cursor = globals.conn.cursor(buffered=True)
-            cursor.execute("UPDATE t_servers SET channel = {} WHERE server = {}".format(channel.id, server.id))
+            cursor = database.create_cursor()
+            cursor.execute("UPDATE t_servers SET channel = {} WHERE server = '{}'".format(channel.id, server.id))
             globals.conn.commit()
             
             await channel.send("Channel updated to: **{}**.".format(channel))
@@ -284,8 +284,8 @@ async def here_cmd(author, server, channel):
         cursor.close()
     else:
         # No server found in DB, so register it
-        cursor = globals.conn.cursor(buffered=True)
-        cursor.execute("INSERT INTO t_servers (server, channel) VALUES ({},{})".format(server.id, channel.id))
+        cursor = database.create_cursor()
+        cursor.execute("INSERT INTO t_servers (server, channel) VALUES ('{}',{})".format(server.id, channel.id))
         globals.conn.commit() # TODO Move to corresponding file
         
         await channel.send("Channel **{}** configured for **{}**.".format(channel, server))
@@ -300,7 +300,7 @@ async def stop_cmd(author, server, channel):
 
     if utils.is_server_in_db(server.id):
         # Remove server from DB
-        cursor = globals.conn.cursor(buffered=True)
+        cursor = database.create_cursor()
         cursor.execute("DELETE FROM t_servers WHERE server = {}".format(server.id))
         globals.conn.commit()
 
@@ -321,7 +321,7 @@ async def role_cmd(words, message, author, server, channel):
 
     role_str = words[2]
     if (role_str == "everyone") or (role_str == "@everyone"):
-        cursor = globals.conn.cursor(buffered=True)
+        cursor = database.create_cursor()
         cursor.execute("UPDATE t_servers SET admin_group = NULL WHERE server = %s", [str(server.id)])
         globals.conn.commit()
         cursor.close()
@@ -337,7 +337,7 @@ async def role_cmd(words, message, author, server, channel):
         else:
             roleFound = rolesFound[0]
             # Update db with newly added role
-            cursor = globals.conn.cursor(buffered=True)
+            cursor = database.create_cursor()
             cursor.execute("UPDATE t_servers SET admin_group = %s WHERE server = %s", [str(roleFound.id), str(server.id)])
             globals.conn.commit()
             cursor.close()
@@ -352,7 +352,7 @@ async def top_cmd(words, channel):
 
     if len(words) == 2:
         try:
-            cursor = globals.conn.cursor(buffered=True)
+            cursor = database.create_cursor()
             cursor.execute("SELECT * FROM v_Top")
             data = cursor.fetchone()
             
@@ -365,13 +365,13 @@ async def top_cmd(words, channel):
                         
                     data = cursor.fetchone()
                     
-                cursor = globals.conn.cursor(buffered=True)
+                cursor = database.create_cursor()
                 cursor.execute("SELECT * FROM v_TotalFeeds")
                 data = cursor.fetchone()
                 
                 topText += "\n***Total user entry***: " + str(data[0])
                 
-                cursor = globals.conn.cursor(buffered=True)
+                cursor = database.create_cursor()
                 cursor.execute("SELECT * FROM v_TotalAnimes")
                 data = cursor.fetchone()
                 
@@ -388,7 +388,7 @@ async def top_cmd(words, channel):
         globals.logger.info("Displaying the global top for the keyword: " + keyword)
         
         try:
-            cursor = globals.conn.cursor(buffered=True)
+            cursor = database.create_cursor()
             cursor.callproc('sp_UsersPerKeyword', [str(keyword), '20'])
             for result in cursor.stored_results():
                 data = result.fetchone()

+ 35 - 0
myanimebot/database.py

@@ -0,0 +1,35 @@
+import myanimebot.globals as globals
+import myanimebot.utils as utils
+
+import psycopg2.extras
+
+def create_cursor():
+	if (globals.dbType.lower() == "mariadb") or (globals.dbType.lower() == "mysql") :
+		cursor = globals.conn.cursor(buffered=True, dictionary=True)
+
+	elif (globals.dbType.lower() == "postgresql") or (globals.dbType.lower() == "pgsql") or (globals.dbType.lower() == "posgres") :
+		cursor = globals.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
+
+	return cursor
+	
+def insert_feed_db(feed, service : str):
+	cursor = create_cursor()
+	
+	if (globals.dbType.lower() == "mariadb") or (globals.dbType.lower() == "mysql") :
+		cursor.execute("INSERT INTO t_feeds (published, title, url, user, found, type, service) VALUES (FROM_UNIXTIME(%s), %s, %s, %s, NOW(), %s, %s)",
+						(feed.date_publication.timestamp(),
+						 feed.media.name,
+						 feed.media.url,
+						 feed.user.name,
+						 feed.get_status_str(),
+						 service))
+	elif (globals.dbType.lower() == "postgresql") or (globals.dbType.lower() == "pgsql") or (globals.dbType.lower() == "posgres") :
+		cursor.execute("INSERT INTO t_feeds (published, title, url, \"user\", found, type, service) VALUES (TO_TIMESTAMP(%s), %s, %s, %s, NOW(), %s, %s)",
+						(feed.date_publication.timestamp(),
+						 feed.media.name,
+						 feed.media.url,
+						 feed.user.name,
+						 feed.get_status_str(),
+						 service))
+	
+	globals.conn.commit()

+ 1 - 1
myanimebot/discord.py

@@ -285,7 +285,7 @@ async def update_thumbnail_catalog(asyncioloop):
         globals.logger.info("Automatic check of the thumbnail database on going...")
         reload = 0
         
-        cursor = globals.conn.cursor(buffered=True)
+        cursor = database.create_cursor()
         cursor.execute("SELECT guid, title, thumbnail FROM t_animes")
         data = cursor.fetchone()
 

+ 27 - 13
myanimebot/globals.py

@@ -7,6 +7,7 @@ import discord
 import pytz
 import feedparser
 import mariadb
+import psycopg2
 import pytz
 
 
@@ -45,14 +46,21 @@ except Exception as e:
 
 CONFIG=config["MYANIMEBOT"]
 logLevel=CONFIG.get("logLevel", "INFO")
-dbHost=CONFIG.get("mariadb.host", "127.0.0.1")
-dbUser=CONFIG.get("mariadb.user", "myanimebot")
-dbPassword=CONFIG.get("mariadb.password")
-dbName=CONFIG.get("mariadb.name", "myanimebot")
-dbSSLenabled=CONFIG.getboolean("mariadb.ssl", False)
-dbSSLca=CONFIG.get("mariadb.ssl.ca")
-dbSSLcert=CONFIG.get("mariadb.ssl.cert")
-dbSSLkey=CONFIG.get("mariadb.ssl.key")
+dbType=CONFIG.get("database.type", "mariadb")
+dbMariaHost=CONFIG.get("mariadb.host", "127.0.0.1")
+dbMariaPort=CONFIG.get("mariadb.port", "3306")
+dbMariaUser=CONFIG.get("mariadb.user", "myanimebot")
+dbMariaPassword=CONFIG.get("mariadb.password")
+dbMariaName=CONFIG.get("mariadb.name", "myanimebot")
+dbMariaSSLenabled=CONFIG.getboolean("mariadb.ssl", False)
+dbMariaSSLca=CONFIG.get("mariadb.ssl.ca")
+dbMariaSSLcert=CONFIG.get("mariadb.ssl.cert")
+dbMariaSSLkey=CONFIG.get("mariadb.ssl.key")
+dbPgHost=CONFIG.get("postgresql.host", "127.0.0.1")
+dbPgPort=CONFIG.get("postgresql.port", "5432")
+dbPgUser=CONFIG.get("postgresql.user", "myanimebot")
+dbPgPassword=CONFIG.get("postgresql.password")
+dbPgName=CONFIG.get("postgresql.name", "myanimebot")
 logPath=CONFIG.get("logPath", "myanimebot.log")
 timezone=pytz.timezone(CONFIG.get("timezone", "utc"))
 secondMax=CONFIG.getint("secondMax", 7200)
@@ -99,12 +107,18 @@ logger.debug("DEBUG log: OK")
 # Initialization of the database
 try:
 	# Main database connection
-	if (dbSSLenabled) :
-		conn = mariadb.connect(host=dbHost, user=dbUser, password=dbPassword, database=dbName, ssl_ca=dbSSLca, ssl_cert=dbSSLcert, ssl_key=dbSSLkey)
-	else :
-		conn = mariadb.connect(host=dbHost, user=dbUser, password=dbPassword, database=dbName)
+	if (dbType.lower() == "mariadb") or (dbType.lower() == "mysql") :
+		if (dbSSLenabled) :
+			conn = mariadb.connect(host=dbMariaHost, user=dbMariaUser, password=dbMariaPassword, database=dbMariaName, port=dbMariaPort, ssl_ca=dbMariaSSLca, ssl_cert=dbMariaSSLcert, ssl_key=dbMariaSSLkey)
+		else :
+			conn = mariadb.connect(host=dbMariaHost, user=dbMariaUser, password=dbMariaPassword, database=dbMariaName)
+	elif (dbType.lower() == "postgresql") or (dbType.lower() == "pgsql") or (dbType.lower() == "posgres") :
+		conn = psycopg2.connect(host=dbPgHost, user=dbPgUser, password=dbPgPassword, database=dbPgName, port=dbPgPort)
+	else:
+		logger.critical("'{}' is not a supported database type!".format(dbType))
+		quit()
 except Exception as e:
-	logger.critical("Can't connect to the database: " + str(e))
+	logger.critical("Can't connect to the database: {}".format(e))
 	quit()
 
 

+ 9 - 8
myanimebot/utils.py

@@ -3,6 +3,7 @@ from enum import Enum
 from typing import List
 
 import myanimebot.globals as globals
+import myanimebot.database as database
 
 
 # TODO Redo all of the desc/status system
@@ -246,7 +247,7 @@ def get_channels(server_id: int) -> dict:
     if server_id is None: return None
 
     # TODO Make generic execute
-    cursor = globals.conn.cursor(buffered=True, dictionary=True)
+    cursor = database.create_cursor()
     cursor.execute("SELECT channel FROM t_servers WHERE server = %s", [server_id])
     channels = cursor.fetchall()
     cursor.close()
@@ -259,7 +260,7 @@ def is_server_in_db(server_id : str) -> bool:
     if server_id is None:
         return False
 
-    cursor = globals.conn.cursor(buffered=True)
+    cursor = database.create_cursor()
     cursor.execute("SELECT server FROM t_servers WHERE server=%s", [server_id])
     data = cursor.fetchone()
     cursor.close()
@@ -269,7 +270,7 @@ def is_server_in_db(server_id : str) -> bool:
 def get_users() -> List[dict]:
     '''Returns all registered users'''
 
-    cursor = globals.conn.cursor(buffered=True, dictionary=True)
+    cursor = database.create_cursor()
     cursor.execute('SELECT {}, service, servers FROM t_users'.format(globals.DB_USER_NAME))
     users = cursor.fetchall()
     cursor.close()
@@ -281,7 +282,7 @@ def get_user_servers(user_name : str, service : Service) -> str:
     if user_name is None or service is None:
         return
 
-    cursor = globals.conn.cursor(buffered=True, dictionary=True)
+    cursor = database.create_cursor()
     cursor.execute("SELECT servers FROM t_users WHERE LOWER({})=%s AND service=%s".format(globals.DB_USER_NAME),
                      [user_name.lower(), service.value])
     user_servers = cursor.fetchone()
@@ -314,7 +315,7 @@ def delete_user_from_db(user_name : str, service : Service) -> bool:
         globals.logger.warning("Error while trying to delete user '{}' with service '{}'".format(user_name, service))
         return False
 
-    cursor = globals.conn.cursor(buffered=True)
+    cursor = database.create_cursor()
     cursor.execute("DELETE FROM t_users WHERE LOWER({}) = %s AND service=%s".format(globals.DB_USER_NAME),
                          [user_name.lower(), service.value])
     globals.conn.commit()
@@ -327,7 +328,7 @@ def update_user_servers_db(user_name : str, service : Service, servers : str) ->
         globals.logger.warning("Error while trying to update user's servers. User '{}' with service '{}' and servers '{}'".format(user_name, service, servers))
         return False
 
-    cursor = globals.conn.cursor(buffered=True)
+    cursor = database.create_cursor()
     cursor.execute("UPDATE t_users SET servers = %s WHERE LOWER({}) = %s AND service=%s".format(globals.DB_USER_NAME),
                           [servers, user_name.lower(), service.value])
     globals.conn.commit()
@@ -342,7 +343,7 @@ def insert_user_into_db(user_name : str, service : Service, servers : str) -> bo
         globals.logger.warning("Error while trying to add user '{}' with service '{}' and servers '{}'".format(user_name, service, servers))
         return False
 
-    cursor = globals.conn.cursor(buffered=True)
+    cursor = database.create_cursor()
     cursor.execute("INSERT INTO t_users ({}, service, servers) VALUES (%s, %s, %s)".format(globals.DB_USER_NAME),
                         [user_name, service.value, servers])
     globals.conn.commit()
@@ -352,7 +353,7 @@ def insert_user_into_db(user_name : str, service : Service, servers : str) -> bo
 def get_allowed_role(server : int) -> int:
     '''Return the allowed role for a given server'''
 
-    cursor = globals.conn.cursor(buffered=True, dictionary=True)
+    cursor = database.create_cursor()
     cursor.execute("SELECT admin_group FROM t_servers WHERE server=%s LIMIT 1", [str(server)])
     allowedRole = cursor.fetchone()
     cursor.close()