import os
import logging
from typing import List, Optional, Dict, Tuple, Any, Set
from datetime import datetime, timedelta
from dataclasses import dataclass

import psycopg2
from dotenv import load_dotenv


@dataclass
class DBConfig:
    host: str
    port: str
    database: str
    user: str
    password: str


class DatabaseManager:
    def __init__(self, config: Optional[DBConfig] = None):
        if config is None:
            load_dotenv()
            config = DBConfig(
                host=os.getenv('DB_HOST', 'localhost'),
                port=os.getenv('DB_PORT', '5432'),
                database=os.getenv('DB_NAME', 'postgres'),
                user=os.getenv('DB_USER', ''),
                password=os.getenv('DB_PASSWORD', '')
            )
        self.config = config
        self._setup_logging()
        self.ensure_schema()

    def _setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )

    def get_connection(self):
        return psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            database=self.config.database,
            user=self.config.user,
            password=self.config.password
        )

    def ensure_schema(self):
        """Ensure all required tables and columns exist in the database."""
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            # Check if the stocknews table exists
            cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'stocknews'
                );
            """)
            if not cursor.fetchone()[0]:
                logging.info("Creating stocknews table")
                cursor.execute("""
                    CREATE TABLE stocknews (
                        id SERIAL PRIMARY KEY,
                        ticker_id INTEGER NOT NULL,
                        news_type VARCHAR(50) NOT NULL,
                        title TEXT NOT NULL,
                        content TEXT,
                        url TEXT,
                        published_at TIMESTAMP WITH TIME ZONE NOT NULL,
                        sentiment_score FLOAT,
                        source VARCHAR(100),
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                        CONSTRAINT fk_ticker FOREIGN KEY(ticker_id)
                            REFERENCES tickers(id) ON DELETE CASCADE
                    );
                    CREATE INDEX idx_stocknews_ticker_date
                        ON stocknews(ticker_id, published_at);
                    CREATE INDEX idx_stocknews_title
                        ON stocknews(title text_pattern_ops);
                """)
                conn.commit()
                logging.info("Successfully created stocknews table")
        except Exception as e:
            conn.rollback()
            logging.error(f"Error ensuring schema: {e}")
            raise
        finally:
            cursor.close()
            conn.close()

    def _parse_date(self, date_str: str) -> datetime:
        date_formats = [
            '%Y-%m-%dT%H:%M:%S.%fZ',
            '%Y-%m-%dT%H:%M:%SZ',
            '%Y-%m-%d %H:%M:%S',
            '%Y-%m-%d'
        ]
        for date_format in date_formats:
            try:
                return datetime.strptime(date_str, date_format)
            except ValueError:
                continue
        raise ValueError(f"Unable to parse date: {date_str}")

    def _convert_sentiment_to_score(self, sentiment: Any) -> float:
        if isinstance(sentiment, (int, float)):
            # Clamp numeric sentiment to [-1.0, 1.0]
            return max(-1.0, min(1.0, float(sentiment)))
        if isinstance(sentiment, str):
            sentiment_map = {
                'positive': 1.0,
                'negative': -1.0,
                'neutral': 0.0,
            }
            return sentiment_map.get(sentiment.lower(), 0.0)
        return 0.0

    def get_latest_news_date(self, ticker_id: int) -> datetime:
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT MAX(published_at)
                FROM stocknews
                WHERE ticker_id = %s
            """, (ticker_id,))
            latest_date = cursor.fetchone()[0]
            # Fall back to a 30-day lookback when the ticker has no stored news
            return latest_date if latest_date else datetime.now() - timedelta(days=30)
        finally:
            cursor.close()
            conn.close()

    def get_existing_titles(self, ticker_id: int, from_date: datetime) -> Set[str]:
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            cursor.execute("""
                SELECT title
                FROM stocknews
                WHERE ticker_id = %s AND published_at >= %s
            """, (ticker_id, from_date))
            return {row[0] for row in cursor.fetchall()}
        finally:
            cursor.close()
            conn.close()

    def save_news_batch(self, news_items: List[Tuple[int, str, Dict]]) -> int:
        if not news_items:
            return 0
        conn = self.get_connection()
        cursor = conn.cursor()
        items_saved = 0
        try:
            ticker_ids = set()
            for ticker_id, news_type, news_item in news_items:
                ticker_ids.add(ticker_id)
                date_str = news_item.get('publishedDate', news_item.get('date'))
                if not date_str:
                    continue
                try:
                    published_at = self._parse_date(date_str)
                    title = news_item.get('title', '').strip()
                    if not title:
                        continue
                    # Check for existing news with normalized comparison:
                    # case-insensitive title match within the same calendar day
                    cursor.execute("""
                        SELECT COUNT(*)
                        FROM stocknews
                        WHERE ticker_id = %s
                          AND LOWER(TRIM(title)) = LOWER(%s)
                          AND date_trunc('day', published_at) = date_trunc('day', %s::timestamp)
                    """, (ticker_id, title, published_at))
                    if cursor.fetchone()[0] == 0:
                        sentiment_score = self._convert_sentiment_to_score(news_item.get('sentiment'))
                        cursor.execute("""
                            INSERT INTO stocknews
                                (ticker_id, news_type, title, content, url,
                                 published_at, sentiment_score, source, created_at)
                            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        """, (
                            ticker_id,
                            news_type,
                            title,
                            news_item.get('text', news_item.get('content')),
                            news_item.get('url'),
                            published_at,
                            sentiment_score,
                            news_item.get('source', 'FMP'),
                            datetime.now()
                        ))
                        items_saved += 1
                except ValueError as e:
                    logging.error(f"Error processing news item: {e}")
                    continue
            if ticker_ids:
                cursor.execute("""
                    UPDATE tickers
                    SET last_checked_at = NOW()
                    WHERE id = ANY(%s)
                """, (list(ticker_ids),))
            conn.commit()
            logging.info(f"Successfully saved {items_saved} news items")
            return items_saved
        except Exception as e:
            conn.rollback()
            logging.error(f"Error in batch save: {e}")
            raise
        finally:
            cursor.close()
            conn.close()

    def get_tickers_for_update(self, update_interval_minutes: int = 15) -> List[Tuple[int, str]]:
        conn = self.get_connection()
        cursor = conn.cursor()
        try:
            # psycopg2 substitutes %s client-side even inside the quoted
            # literal, so INTERVAL '%s minutes' expands to a valid interval
            # (e.g. INTERVAL '15 minutes') because the parameter is an integer.
            cursor.execute("""
                SELECT t.id, t.yf_ticker
                FROM tickers t
                WHERE t.yf_ticker IS NOT NULL
                  AND (
                      t.last_checked_at IS NULL
                      OR t.last_checked_at < NOW() - INTERVAL '%s minutes'
                  )
                ORDER BY t.last_checked_at NULLS FIRST
            """, (update_interval_minutes,))
            return cursor.fetchall()
        finally:
            cursor.close()
            conn.close()
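
# --- Usage sketch (illustrative only) ---------------------------------------
# A minimal example of how this manager might be driven by a polling job.
# Assumptions not provided by this module: the .env file supplies DB_HOST,
# DB_PORT, DB_NAME, DB_USER and DB_PASSWORD; fetch_news_for() is a hypothetical
# news API client (e.g. an FMP wrapper), stubbed here with a hard-coded item
# whose keys (publishedDate/title/text/url/sentiment) mirror what
# save_news_batch() reads.
if __name__ == '__main__':
    db = DatabaseManager()  # loads DB_* settings from .env via load_dotenv()

    for ticker_id, yf_ticker in db.get_tickers_for_update(update_interval_minutes=15):
        # Replace this stub with a real fetch, e.g.:
        # raw_items = fetch_news_for(yf_ticker, since=db.get_latest_news_date(ticker_id))
        raw_items = [{
            'publishedDate': '2024-01-15T09:30:00Z',
            'title': f'Example headline for {yf_ticker}',
            'text': 'Example body text.',
            'url': 'https://example.com/article',
            'sentiment': 'positive',
        }]
        saved = db.save_news_batch(
            [(ticker_id, 'stock_news', item) for item in raw_items]
        )
        logging.info(f"{yf_ticker}: saved {saved} new items")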