import os
import logging
from typing import List, Optional, Dict, Tuple, Any, Set
from datetime import datetime, timedelta, timezone
from dataclasses import dataclass

import psycopg2
from dotenv import load_dotenv


@dataclass
class DBConfig:
    """Connection settings for the PostgreSQL database."""
    host: str
    port: str
    database: str
    user: str
    password: str
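
# Example .env consumed by load_dotenv() in DatabaseManager.__init__ below
# (illustrative values; only DB_HOST, DB_PORT, and DB_NAME have usable defaults):
#
#   DB_HOST=localhost
#   DB_PORT=5432
#   DB_NAME=postgres
#   DB_USER=app_user
#   DB_PASSWORD=change_me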


class DatabaseManager:
    """Postgres-backed store for per-ticker stock news."""

    def __init__(self, config: Optional[DBConfig] = None):
        # Fall back to environment variables (optionally via a .env file)
        # when no explicit config is supplied.
        if config is None:
            load_dotenv()
            config = DBConfig(
                host=os.getenv('DB_HOST', 'localhost'),
                port=os.getenv('DB_PORT', '5432'),
                database=os.getenv('DB_NAME', 'postgres'),
                user=os.getenv('DB_USER', ''),
                password=os.getenv('DB_PASSWORD', '')
            )
        self.config = config
        self._setup_logging()
        self.ensure_schema()

    def _setup_logging(self):
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )

    def get_connection(self):
        """Open a new connection; callers are responsible for closing it."""
        return psycopg2.connect(
            host=self.config.host,
            port=self.config.port,
            database=self.config.database,
            user=self.config.user,
            password=self.config.password
        )

    def ensure_schema(self):
        """Ensure all required tables and columns exist in the database."""
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            # Check if the stocknews table exists
            cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_name = 'stocknews'
                );
            """)

            if not cursor.fetchone()[0]:
                logging.info("Creating stocknews table")
                cursor.execute("""
                    CREATE TABLE stocknews (
                        id SERIAL PRIMARY KEY,
                        ticker_id INTEGER NOT NULL,
                        news_type VARCHAR(50) NOT NULL,
                        title TEXT NOT NULL,
                        content TEXT,
                        url TEXT,
                        published_at TIMESTAMP WITH TIME ZONE NOT NULL,
                        sentiment_score FLOAT,
                        source VARCHAR(100),
                        created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP,
                        CONSTRAINT fk_ticker
                            FOREIGN KEY(ticker_id)
                            REFERENCES tickers(id)
                            ON DELETE CASCADE
                    );
                    CREATE INDEX idx_stocknews_ticker_date ON stocknews(ticker_id, published_at);
                    CREATE INDEX idx_stocknews_title ON stocknews(title text_pattern_ops);
                """)
                conn.commit()
                logging.info("Successfully created stocknews table")

        except Exception as e:
            conn.rollback()
            logging.error(f"Error ensuring schema: {e}")
            raise
        finally:
            cursor.close()
            conn.close()
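
    # ensure_schema() assumes a pre-existing `tickers` table. A minimal sketch
    # of the columns this module relies on, inferred from the queries in this
    # file (column types are guesses; the real table may define more):
    #
    #   CREATE TABLE tickers (
    #       id              SERIAL PRIMARY KEY,
    #       yf_ticker       VARCHAR(20),
    #       last_checked_at TIMESTAMP WITH TIME ZONE
    #   );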

    def _parse_date(self, date_str: str) -> datetime:
        """Parse an API date string, trying each known format in turn."""
        date_formats = [
            '%Y-%m-%dT%H:%M:%S.%fZ',
            '%Y-%m-%dT%H:%M:%SZ',
            '%Y-%m-%d %H:%M:%S',
            '%Y-%m-%d'
        ]

        for date_format in date_formats:
            try:
                parsed = datetime.strptime(date_str, date_format)
                # A trailing 'Z' means UTC; attach the timezone so values land
                # correctly in the TIMESTAMP WITH TIME ZONE column.
                if date_format.endswith('Z'):
                    parsed = parsed.replace(tzinfo=timezone.utc)
                return parsed
            except ValueError:
                continue

        raise ValueError(f"Unable to parse date: {date_str}")
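
    # Illustrative parses (the first two formats are treated as UTC):
    #   _parse_date('2024-01-15T09:30:00.000Z') -> 2024-01-15 09:30:00+00:00
    #   _parse_date('2024-01-15')               -> 2024-01-15 00:00:00 (naive)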

    def _convert_sentiment_to_score(self, sentiment: Any) -> float:
        """Normalize a numeric or label sentiment to a score in [-1.0, 1.0]."""
        if isinstance(sentiment, (int, float)):
            return max(-1.0, min(1.0, float(sentiment)))

        if isinstance(sentiment, str):
            sentiment_map = {
                'positive': 1.0,
                'negative': -1.0,
                'neutral': 0.0,
            }
            return sentiment_map.get(sentiment.lower(), 0.0)

        return 0.0

    def get_latest_news_date(self, ticker_id: int) -> datetime:
        """Return the newest stored published_at for a ticker, defaulting to 30 days ago."""
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            cursor.execute("""
                SELECT MAX(published_at)
                FROM stocknews
                WHERE ticker_id = %s
            """, (ticker_id,))

            latest_date = cursor.fetchone()[0]
            return latest_date if latest_date is not None else datetime.now() - timedelta(days=30)
        finally:
            cursor.close()
            conn.close()

    def get_existing_titles(self, ticker_id: int, from_date: datetime) -> Set[str]:
        """Return the titles already stored for a ticker since from_date."""
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            cursor.execute("""
                SELECT title
                FROM stocknews
                WHERE ticker_id = %s AND published_at >= %s
            """, (ticker_id, from_date))

            return {row[0] for row in cursor.fetchall()}
        finally:
            cursor.close()
            conn.close()

    def save_news_batch(self, news_items: List[Tuple[int, str, Dict]]) -> int:
        """Insert news items that are not already stored; return the number saved."""
        if not news_items:
            return 0

        conn = self.get_connection()
        cursor = conn.cursor()
        items_saved = 0

        try:
            ticker_ids = set()

            for ticker_id, news_type, news_item in news_items:
                ticker_ids.add(ticker_id)
                date_str = news_item.get('publishedDate', news_item.get('date'))
                if not date_str:
                    continue

                try:
                    published_at = self._parse_date(date_str)
                    title = news_item.get('title', '').strip()

                    if not title:
                        continue

                    # Check for existing news with normalized comparison:
                    # case-insensitive title match on the same calendar day.
                    cursor.execute("""
                        SELECT COUNT(*) FROM stocknews
                        WHERE ticker_id = %s
                        AND LOWER(TRIM(title)) = LOWER(%s)
                        AND date_trunc('day', published_at) = date_trunc('day', %s::timestamptz)
                    """, (ticker_id, title, published_at))

                    if cursor.fetchone()[0] == 0:
                        sentiment_score = self._convert_sentiment_to_score(news_item.get('sentiment'))

                        cursor.execute("""
                            INSERT INTO stocknews
                            (ticker_id, news_type, title, content, url, published_at,
                             sentiment_score, source, created_at)
                            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                        """, (
                            ticker_id,
                            news_type,
                            title,
                            news_item.get('text', news_item.get('content')),
                            news_item.get('url'),
                            published_at,
                            sentiment_score,
                            news_item.get('source', 'FMP'),
                            datetime.now()
                        ))
                        items_saved += 1

                except ValueError as e:
                    logging.error(f"Error processing news item: {e}")
                    continue

            if ticker_ids:
                cursor.execute("""
                    UPDATE tickers
                    SET last_checked_at = NOW()
                    WHERE id = ANY(%s)
                """, (list(ticker_ids),))

            conn.commit()
            logging.info(f"Successfully saved {items_saved} news items")
            return items_saved

        except Exception as e:
            conn.rollback()
            logging.error(f"Error in batch save: {e}")
            raise
        finally:
            cursor.close()
            conn.close()
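
    # Shape of a save_news_batch() input item, inferred from the .get() calls
    # above (illustrative values; which fields are present depends on the
    # upstream API response):
    #
    #   (42, 'stock_news', {
    #       'publishedDate': '2024-01-15T09:30:00.000Z',
    #       'title': 'Example Corp announces quarterly results',
    #       'text': 'Full article body...',
    #       'url': 'https://example.com/article',
    #       'sentiment': 'positive',
    #       'source': 'FMP',
    #   })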

    def get_tickers_for_update(self, update_interval_minutes: int = 15) -> List[Tuple[int, str]]:
        conn = self.get_connection()
        cursor = conn.cursor()

        try:
            # Parameterize the interval arithmetically; a %s placeholder
            # embedded inside a quoted INTERVAL literal is fragile.
            cursor.execute("""
                SELECT t.id, t.yf_ticker
                FROM tickers t
                WHERE t.yf_ticker IS NOT NULL
                  AND (
                      t.last_checked_at IS NULL
                      OR t.last_checked_at < NOW() - (%s * INTERVAL '1 minute')
                  )
                ORDER BY t.last_checked_at NULLS FIRST
            """, (update_interval_minutes,))

            return cursor.fetchall()
        finally:
            cursor.close()
            conn.close()
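

# A minimal usage sketch, assuming the DB_* variables are set (in the
# environment or a .env file) and the referenced `tickers` table already
# exists. The fetching of news from the upstream API is left out because it
# lives outside this module.
if __name__ == '__main__':
    db = DatabaseManager()  # loads config from env and ensures the schema
    for ticker_id, yf_ticker in db.get_tickers_for_update(update_interval_minutes=15):
        logging.info("Ticker %s (id=%s) is due for a news refresh", yf_ticker, ticker_id)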