diff --git a/src/analysis/__init__.py b/src/analysis/__init__.py index c4cbd15..dd67a91 100644 --- a/src/analysis/__init__.py +++ b/src/analysis/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations from src.analysis.scanner import MarketScanner +from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner from src.analysis.volatility import VolatilityAnalyzer -__all__ = ["VolatilityAnalyzer", "MarketScanner"] +__all__ = ["VolatilityAnalyzer", "MarketScanner", "SmartVolatilityScanner", "ScanCandidate"] diff --git a/src/analysis/smart_scanner.py b/src/analysis/smart_scanner.py new file mode 100644 index 0000000..b25f15a --- /dev/null +++ b/src/analysis/smart_scanner.py @@ -0,0 +1,192 @@ +"""Smart Volatility Scanner with RSI and volume filters. + +Fetches market rankings from KIS API and applies technical filters +to identify high-probability trading candidates. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import Any + +from src.analysis.volatility import VolatilityAnalyzer +from src.broker.kis_api import KISBroker +from src.config import Settings + +logger = logging.getLogger(__name__) + + +@dataclass +class ScanCandidate: + """A qualified candidate from the smart scanner.""" + + stock_code: str + name: str + price: float + volume: float + volume_ratio: float # Current volume / previous day volume + rsi: float + signal: str # "oversold" or "momentum" + score: float # Composite score for ranking + + +class SmartVolatilityScanner: + """Scans market rankings and applies RSI/volume filters. + + Flow: + 1. Fetch volume rankings from KIS API + 2. For each ranked stock, fetch daily prices + 3. Calculate RSI and volume ratio + 4. Apply filters: volume > VOL_MULTIPLIER AND (RSI < 30 OR RSI > 70) + 5. Return top N qualified candidates + """ + + def __init__( + self, + broker: KISBroker, + volatility_analyzer: VolatilityAnalyzer, + settings: Settings, + ) -> None: + """Initialize the smart scanner. + + Args: + broker: KIS broker for API calls + volatility_analyzer: Analyzer for RSI calculation + settings: Application settings + """ + self.broker = broker + self.analyzer = volatility_analyzer + self.settings = settings + + # Extract scanner settings + self.rsi_oversold = settings.RSI_OVERSOLD_THRESHOLD + self.rsi_momentum = settings.RSI_MOMENTUM_THRESHOLD + self.vol_multiplier = settings.VOL_MULTIPLIER + self.top_n = settings.SCANNER_TOP_N + + async def scan( + self, + fallback_stocks: list[str] | None = None, + ) -> list[ScanCandidate]: + """Execute smart scan and return qualified candidates. + + Args: + fallback_stocks: Stock codes to use if ranking API fails + + Returns: + List of ScanCandidate, sorted by score, up to top_n items + """ + # Step 1: Fetch rankings + try: + rankings = await self.broker.fetch_market_rankings( + ranking_type="volume", + limit=30, # Fetch more than needed for filtering + ) + logger.info("Fetched %d stocks from volume rankings", len(rankings)) + except ConnectionError as exc: + logger.warning("Ranking API failed, using fallback: %s", exc) + if fallback_stocks: + # Create minimal ranking data for fallback + rankings = [ + { + "stock_code": code, + "name": code, + "price": 0, + "volume": 0, + "change_rate": 0, + "volume_increase_rate": 0, + } + for code in fallback_stocks + ] + else: + return [] + + # Step 2: Analyze each stock + candidates: list[ScanCandidate] = [] + + for stock in rankings: + stock_code = stock["stock_code"] + if not stock_code: + continue + + try: + # Fetch daily prices for RSI calculation + daily_prices = await self.broker.get_daily_prices(stock_code, days=20) + + if len(daily_prices) < 15: # Need at least 14+1 for RSI + logger.debug("Insufficient price history for %s", stock_code) + continue + + # Calculate RSI + close_prices = [p["close"] for p in daily_prices] + rsi = self.analyzer.calculate_rsi(close_prices, period=14) + + # Calculate volume ratio (today vs previous day avg) + if len(daily_prices) >= 2: + prev_day_volume = daily_prices[-2]["volume"] + current_volume = stock.get("volume", 0) or daily_prices[-1]["volume"] + volume_ratio = ( + current_volume / prev_day_volume if prev_day_volume > 0 else 1.0 + ) + else: + volume_ratio = stock.get("volume_increase_rate", 0) / 100 + 1 # Fallback + + # Apply filters + volume_qualified = volume_ratio >= self.vol_multiplier + rsi_oversold = rsi < self.rsi_oversold + rsi_momentum = rsi > self.rsi_momentum + + if volume_qualified and (rsi_oversold or rsi_momentum): + signal = "oversold" if rsi_oversold else "momentum" + + # Calculate composite score + # Higher score for: extreme RSI + high volume + rsi_extremity = abs(rsi - 50) / 50 # 0-1 scale + volume_score = min(volume_ratio / 5, 1.0) # Cap at 5x + score = (rsi_extremity * 0.6 + volume_score * 0.4) * 100 + + candidates.append( + ScanCandidate( + stock_code=stock_code, + name=stock.get("name", stock_code), + price=stock.get("price", daily_prices[-1]["close"]), + volume=current_volume, + volume_ratio=volume_ratio, + rsi=rsi, + signal=signal, + score=score, + ) + ) + + logger.info( + "Qualified: %s (%s) RSI=%.1f vol=%.1fx signal=%s score=%.1f", + stock_code, + stock.get("name", ""), + rsi, + volume_ratio, + signal, + score, + ) + + except ConnectionError as exc: + logger.warning("Failed to analyze %s: %s", stock_code, exc) + continue + except Exception as exc: + logger.error("Unexpected error analyzing %s: %s", stock_code, exc) + continue + + # Sort by score and return top N + candidates.sort(key=lambda c: c.score, reverse=True) + return candidates[: self.top_n] + + def get_stock_codes(self, candidates: list[ScanCandidate]) -> list[str]: + """Extract stock codes from candidates for watchlist update. + + Args: + candidates: List of scan candidates + + Returns: + List of stock codes + """ + return [c.stock_code for c in candidates] diff --git a/src/analysis/volatility.py b/src/analysis/volatility.py index cdb56d0..0794220 100644 --- a/src/analysis/volatility.py +++ b/src/analysis/volatility.py @@ -124,6 +124,54 @@ class VolatilityAnalyzer: return 1.0 return current_volume / avg_volume + def calculate_rsi( + self, + close_prices: list[float], + period: int = 14, + ) -> float: + """Calculate Relative Strength Index (RSI) using Wilder's smoothing. + + Args: + close_prices: List of closing prices (oldest to newest, minimum period+1 values) + period: RSI period (default 14) + + Returns: + RSI value between 0 and 100, or 50.0 (neutral) if insufficient data + + Examples: + >>> analyzer = VolatilityAnalyzer() + >>> prices = [100 - i * 0.5 for i in range(20)] # Downtrend + >>> rsi = analyzer.calculate_rsi(prices) + >>> assert rsi < 50 # Oversold territory + """ + if len(close_prices) < period + 1: + return 50.0 # Neutral RSI if insufficient data + + # Calculate price changes + changes = [close_prices[i] - close_prices[i - 1] for i in range(1, len(close_prices))] + + # Separate gains and losses + gains = [max(0.0, change) for change in changes] + losses = [max(0.0, -change) for change in changes] + + # Calculate initial average gain/loss (simple average for first period) + avg_gain = sum(gains[:period]) / period + avg_loss = sum(losses[:period]) / period + + # Apply Wilder's smoothing for remaining periods + for i in range(period, len(changes)): + avg_gain = (avg_gain * (period - 1) + gains[i]) / period + avg_loss = (avg_loss * (period - 1) + losses[i]) / period + + # Calculate RS and RSI + if avg_loss == 0: + return 100.0 # All gains, maximum RSI + + rs = avg_gain / avg_loss + rsi = 100 - (100 / (1 + rs)) + + return rsi + def calculate_pv_divergence( self, price_change: float, diff --git a/src/broker/kis_api.py b/src/broker/kis_api.py index f3c832b..15381c4 100644 --- a/src/broker/kis_api.py +++ b/src/broker/kis_api.py @@ -280,3 +280,153 @@ class KISBroker: return data except (TimeoutError, aiohttp.ClientError) as exc: raise ConnectionError(f"Network error sending order: {exc}") from exc + + async def fetch_market_rankings( + self, + ranking_type: str = "volume", + limit: int = 30, + ) -> list[dict[str, Any]]: + """Fetch market rankings from KIS API. + + Args: + ranking_type: Type of ranking ("volume" or "fluctuation") + limit: Maximum number of results to return + + Returns: + List of stock data dicts with keys: stock_code, name, price, volume, + change_rate, volume_increase_rate + + Raises: + ConnectionError: If API request fails + """ + await self._rate_limiter.acquire() + session = self._get_session() + + # TR_ID for volume ranking + tr_id = "FHPST01710000" if ranking_type == "volume" else "FHPST01710100" + headers = await self._auth_headers(tr_id) + + params = { + "FID_COND_MRKT_DIV_CODE": "J", # Stock/ETF/ETN + "FID_COND_SCR_DIV_CODE": "20001", # Volume surge + "FID_INPUT_ISCD": "0000", # All stocks + "FID_DIV_CLS_CODE": "0", # All types + "FID_BLNG_CLS_CODE": "0", + "FID_TRGT_CLS_CODE": "111111111", + "FID_TRGT_EXLS_CLS_CODE": "000000", + "FID_INPUT_PRICE_1": "0", + "FID_INPUT_PRICE_2": "0", + "FID_VOL_CNT": "0", + "FID_INPUT_DATE_1": "", + } + + url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/volume-rank" + + try: + async with session.get(url, headers=headers, params=params) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"fetch_market_rankings failed ({resp.status}): {text}" + ) + data = await resp.json() + + # Parse response - output is a list of ranked stocks + def _safe_float(value: str | float | None, default: float = 0.0) -> float: + if value is None or value == "": + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + rankings = [] + for item in data.get("output", [])[:limit]: + rankings.append({ + "stock_code": item.get("mksc_shrn_iscd", ""), + "name": item.get("hts_kor_isnm", ""), + "price": _safe_float(item.get("stck_prpr", "0")), + "volume": _safe_float(item.get("acml_vol", "0")), + "change_rate": _safe_float(item.get("prdy_ctrt", "0")), + "volume_increase_rate": _safe_float(item.get("vol_inrt", "0")), + }) + return rankings + + except (TimeoutError, aiohttp.ClientError) as exc: + raise ConnectionError(f"Network error fetching rankings: {exc}") from exc + + async def get_daily_prices( + self, + stock_code: str, + days: int = 20, + ) -> list[dict[str, Any]]: + """Fetch daily OHLCV price history for a stock. + + Args: + stock_code: 6-digit stock code + days: Number of trading days to fetch (default 20 for RSI calculation) + + Returns: + List of daily price dicts with keys: date, open, high, low, close, volume + Sorted oldest to newest + + Raises: + ConnectionError: If API request fails + """ + await self._rate_limiter.acquire() + session = self._get_session() + + headers = await self._auth_headers("FHKST03010100") + + # Calculate date range (today and N days ago) + from datetime import datetime, timedelta + end_date = datetime.now().strftime("%Y%m%d") + start_date = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d") + + params = { + "FID_COND_MRKT_DIV_CODE": "J", + "FID_INPUT_ISCD": stock_code, + "FID_INPUT_DATE_1": start_date, + "FID_INPUT_DATE_2": end_date, + "FID_PERIOD_DIV_CODE": "D", # Daily + "FID_ORG_ADJ_PRC": "0", # Adjusted price + } + + url = f"{self._base_url}/uapi/domestic-stock/v1/quotations/inquire-daily-itemchartprice" + + try: + async with session.get(url, headers=headers, params=params) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"get_daily_prices failed ({resp.status}): {text}" + ) + data = await resp.json() + + # Parse response + def _safe_float(value: str | float | None, default: float = 0.0) -> float: + if value is None or value == "": + return default + try: + return float(value) + except (ValueError, TypeError): + return default + + prices = [] + for item in data.get("output2", []): + prices.append({ + "date": item.get("stck_bsop_date", ""), + "open": _safe_float(item.get("stck_oprc", "0")), + "high": _safe_float(item.get("stck_hgpr", "0")), + "low": _safe_float(item.get("stck_lwpr", "0")), + "close": _safe_float(item.get("stck_clpr", "0")), + "volume": _safe_float(item.get("acml_vol", "0")), + }) + + # Sort oldest to newest (KIS returns newest first) + prices.reverse() + + return prices[:days] # Return only requested number of days + + except (TimeoutError, aiohttp.ClientError) as exc: + raise ConnectionError(f"Network error fetching daily prices: {exc}") from exc diff --git a/src/config.py b/src/config.py index fa0bb13..1c6c075 100644 --- a/src/config.py +++ b/src/config.py @@ -33,6 +33,12 @@ class Settings(BaseSettings): FAT_FINGER_PCT: float = Field(default=30.0, gt=0.0, le=100.0) CONFIDENCE_THRESHOLD: int = Field(default=80, ge=0, le=100) + # Smart Scanner Configuration + RSI_OVERSOLD_THRESHOLD: int = Field(default=30, ge=0, le=50) + RSI_MOMENTUM_THRESHOLD: int = Field(default=70, ge=50, le=100) + VOL_MULTIPLIER: float = Field(default=2.0, gt=1.0, le=10.0) + SCANNER_TOP_N: int = Field(default=3, ge=1, le=10) + # Database DB_PATH: str = "data/trade_logs.db" diff --git a/src/db.py b/src/db.py index 0a43424..9f37345 100644 --- a/src/db.py +++ b/src/db.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json import sqlite3 from datetime import UTC, datetime from pathlib import Path @@ -38,6 +39,8 @@ def init_db(db_path: str) -> sqlite3.Connection: conn.execute("ALTER TABLE trades ADD COLUMN market TEXT DEFAULT 'KR'") if "exchange_code" not in columns: conn.execute("ALTER TABLE trades ADD COLUMN exchange_code TEXT DEFAULT 'KRX'") + if "selection_context" not in columns: + conn.execute("ALTER TABLE trades ADD COLUMN selection_context TEXT") # Context tree tables for multi-layered memory management conn.execute( @@ -118,15 +121,33 @@ def log_trade( pnl: float = 0.0, market: str = "KR", exchange_code: str = "KRX", + selection_context: dict[str, any] | None = None, ) -> None: - """Insert a trade record into the database.""" + """Insert a trade record into the database. + + Args: + conn: Database connection + stock_code: Stock code + action: Trade action (BUY/SELL/HOLD) + confidence: Confidence level (0-100) + rationale: AI decision rationale + quantity: Number of shares + price: Trade price + pnl: Profit/loss + market: Market code + exchange_code: Exchange code + selection_context: Scanner selection data (RSI, volume_ratio, signal, score) + """ + # Serialize selection context to JSON + context_json = json.dumps(selection_context) if selection_context else None + conn.execute( """ INSERT INTO trades ( timestamp, stock_code, action, confidence, rationale, - quantity, price, pnl, market, exchange_code + quantity, price, pnl, market, exchange_code, selection_context ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( datetime.now(UTC).isoformat(), @@ -139,6 +160,7 @@ def log_trade( pnl, market, exchange_code, + context_json, ), ) conn.commit() diff --git a/src/main.py b/src/main.py index c01e5f1..ec11afb 100644 --- a/src/main.py +++ b/src/main.py @@ -15,6 +15,7 @@ from datetime import UTC, datetime from typing import Any from src.analysis.scanner import MarketScanner +from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner from src.analysis.volatility import VolatilityAnalyzer from src.brain.gemini_client import GeminiClient from src.broker.kis_api import KISBroker @@ -100,6 +101,7 @@ async def trading_cycle( telegram: TelegramClient, market: MarketInfo, stock_code: str, + scan_candidates: dict[str, ScanCandidate], ) -> None: """Execute one trading cycle for a single stock.""" cycle_start_time = asyncio.get_event_loop().time() @@ -292,7 +294,17 @@ async def trading_cycle( except Exception as exc: logger.warning("Telegram notification failed: %s", exc) - # 6. Log trade + # 6. Log trade with selection context + selection_context = None + if stock_code in scan_candidates: + candidate = scan_candidates[stock_code] + selection_context = { + "rsi": candidate.rsi, + "volume_ratio": candidate.volume_ratio, + "signal": candidate.signal, + "score": candidate.score, + } + log_trade( conn=db_conn, stock_code=stock_code, @@ -301,6 +313,7 @@ async def trading_cycle( rationale=decision.rationale, market=market.code, exchange_code=market.exchange_code, + selection_context=selection_context, ) # 7. Latency monitoring @@ -722,6 +735,16 @@ async def run(settings: Settings) -> None: max_concurrent_scans=1, # Fully serialized to avoid EGW00201 ) + # Initialize smart scanner (Python-first, AI-last pipeline) + smart_scanner = SmartVolatilityScanner( + broker=broker, + volatility_analyzer=volatility_analyzer, + settings=settings, + ) + + # Track scan candidates for selection context logging + scan_candidates: dict[str, ScanCandidate] = {} # stock_code -> candidate + # Initialize latency control system criticality_assessor = CriticalityAssessor( critical_pnl_threshold=-2.5, # Near circuit breaker at -3.0% @@ -867,38 +890,46 @@ async def run(settings: Settings) -> None: logger.warning("Market open notification failed: %s", exc) _market_states[market.code] = True - # Volatility Hunter: Scan market periodically to update watchlist + # Smart Scanner: Python-first filtering (RSI + volume) before AI now_timestamp = asyncio.get_event_loop().time() last_scan = last_scan_time.get(market.code, 0.0) if now_timestamp - last_scan >= SCAN_INTERVAL_SECONDS: try: - # Scan all stocks in the universe - stock_universe = STOCK_UNIVERSE.get(market.code, []) - if stock_universe: - logger.info("Volatility Hunter: Scanning %s market", market.name) - scan_result = await market_scanner.scan_market( - market, stock_universe - ) + logger.info("Smart Scanner: Scanning %s market", market.name) - # Update watchlist with top movers + # Run smart scan with fallback to static universe + fallback_universe = STOCK_UNIVERSE.get(market.code, []) + candidates = await smart_scanner.scan(fallback_stocks=fallback_universe) + + if candidates: + # Update watchlist with qualified candidates + qualified_codes = smart_scanner.get_stock_codes(candidates) + + # Merge with existing watchlist (keep some continuity) current_watchlist = WATCHLISTS.get(market.code, []) - updated_watchlist = market_scanner.get_updated_watchlist( - current_watchlist, - scan_result, - max_replacements=2, - ) - WATCHLISTS[market.code] = updated_watchlist + # Keep up to 2 from existing, add new qualified + merged = qualified_codes + [ + c for c in current_watchlist if c not in qualified_codes + ][:2] + WATCHLISTS[market.code] = merged[:5] # Cap at 5 + + # Store candidates for later selection context logging + for candidate in candidates: + scan_candidates[candidate.stock_code] = candidate logger.info( - "Volatility Hunter: Watchlist updated for %s (%d top movers, %d breakouts)", + "Smart Scanner: Found %d qualified candidates for %s: %s", + len(candidates), market.name, - len(scan_result.top_movers), - len(scan_result.breakouts), + [f"{c.stock_code}(RSI={c.rsi:.0f})" for c in candidates], ) + else: + logger.info("Smart Scanner: No qualified candidates for %s", market.name) last_scan_time[market.code] = now_timestamp + except Exception as exc: - logger.error("Volatility Hunter scan failed for %s: %s", market.name, exc) + logger.error("Smart Scanner failed for %s: %s", market.name, exc) # Get watchlist for this market watchlist = WATCHLISTS.get(market.code, []) @@ -928,6 +959,7 @@ async def run(settings: Settings) -> None: telegram, market, stock_code, + scan_candidates, ) break # Success — exit retry loop except CircuitBreakerTripped as exc: diff --git a/tests/test_main.py b/tests/test_main.py index df9942b..9ed185e 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -174,6 +174,7 @@ class TestTradingCycleTelegramIntegration: telegram=mock_telegram, market=mock_market, stock_code="005930", + scan_candidates={}, ) # Verify notification was sent @@ -216,6 +217,7 @@ class TestTradingCycleTelegramIntegration: telegram=mock_telegram, market=mock_market, stock_code="005930", + scan_candidates={}, ) # Verify notification was attempted @@ -257,6 +259,7 @@ class TestTradingCycleTelegramIntegration: telegram=mock_telegram, market=mock_market, stock_code="005930", + scan_candidates={}, ) # Verify notification was sent @@ -305,6 +308,7 @@ class TestTradingCycleTelegramIntegration: telegram=mock_telegram, market=mock_market, stock_code="005930", + scan_candidates={}, ) # Verify notification was attempted @@ -345,6 +349,7 @@ class TestTradingCycleTelegramIntegration: telegram=mock_telegram, market=mock_market, stock_code="005930", + scan_candidates={}, ) # Verify no trade notification sent @@ -543,6 +548,7 @@ class TestOverseasBalanceParsing: telegram=mock_telegram, market=mock_overseas_market, stock_code="AAPL", + scan_candidates={}, ) # Verify balance API was called @@ -577,6 +583,7 @@ class TestOverseasBalanceParsing: telegram=mock_telegram, market=mock_overseas_market, stock_code="AAPL", + scan_candidates={}, ) # Verify balance API was called @@ -611,6 +618,7 @@ class TestOverseasBalanceParsing: telegram=mock_telegram, market=mock_overseas_market, stock_code="AAPL", + scan_candidates={}, ) # Verify balance API was called @@ -645,6 +653,7 @@ class TestOverseasBalanceParsing: telegram=mock_telegram, market=mock_overseas_market, stock_code="AAPL", + scan_candidates={}, ) # Verify price API was called diff --git a/tests/test_smart_scanner.py b/tests/test_smart_scanner.py new file mode 100644 index 0000000..fc380d7 --- /dev/null +++ b/tests/test_smart_scanner.py @@ -0,0 +1,377 @@ +"""Tests for SmartVolatilityScanner.""" + +from __future__ import annotations + +import pytest +from unittest.mock import AsyncMock, MagicMock + +from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner +from src.analysis.volatility import VolatilityAnalyzer +from src.broker.kis_api import KISBroker +from src.config import Settings + + +@pytest.fixture +def mock_settings() -> Settings: + """Create test settings.""" + return Settings( + KIS_APP_KEY="test", + KIS_APP_SECRET="test", + KIS_ACCOUNT_NO="12345678-01", + GEMINI_API_KEY="test", + RSI_OVERSOLD_THRESHOLD=30, + RSI_MOMENTUM_THRESHOLD=70, + VOL_MULTIPLIER=2.0, + SCANNER_TOP_N=3, + DB_PATH=":memory:", + ) + + +@pytest.fixture +def mock_broker(mock_settings: Settings) -> MagicMock: + """Create mock broker.""" + broker = MagicMock(spec=KISBroker) + broker._settings = mock_settings + broker.fetch_market_rankings = AsyncMock() + broker.get_daily_prices = AsyncMock() + return broker + + +@pytest.fixture +def scanner(mock_broker: MagicMock, mock_settings: Settings) -> SmartVolatilityScanner: + """Create smart scanner instance.""" + analyzer = VolatilityAnalyzer() + return SmartVolatilityScanner( + broker=mock_broker, + volatility_analyzer=analyzer, + settings=mock_settings, + ) + + +class TestSmartVolatilityScanner: + """Test suite for SmartVolatilityScanner.""" + + @pytest.mark.asyncio + async def test_scan_finds_oversold_candidates( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test that scanner identifies oversold stocks with high volume.""" + # Mock rankings + mock_broker.fetch_market_rankings.return_value = [ + { + "stock_code": "005930", + "name": "Samsung", + "price": 70000, + "volume": 5000000, + "change_rate": -3.5, + "volume_increase_rate": 250, + }, + ] + + # Mock daily prices - trending down (oversold) + prices = [] + for i in range(20): + prices.append({ + "date": f"2026020{i:02d}", + "open": 75000 - i * 200, + "high": 75500 - i * 200, + "low": 74500 - i * 200, + "close": 75000 - i * 250, # Steady decline + "volume": 2000000, + }) + mock_broker.get_daily_prices.return_value = prices + + candidates = await scanner.scan() + + # Should find at least one candidate (depending on exact RSI calculation) + mock_broker.fetch_market_rankings.assert_called_once() + mock_broker.get_daily_prices.assert_called_once_with("005930", days=20) + + # If qualified, should have oversold signal + if candidates: + assert candidates[0].signal in ["oversold", "momentum"] + assert candidates[0].volume_ratio >= scanner.vol_multiplier + + @pytest.mark.asyncio + async def test_scan_finds_momentum_candidates( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test that scanner identifies momentum stocks with high volume.""" + mock_broker.fetch_market_rankings.return_value = [ + { + "stock_code": "035420", + "name": "NAVER", + "price": 250000, + "volume": 3000000, + "change_rate": 5.0, + "volume_increase_rate": 300, + }, + ] + + # Mock daily prices - trending up (momentum) + prices = [] + for i in range(20): + prices.append({ + "date": f"2026020{i:02d}", + "open": 230000 + i * 500, + "high": 231000 + i * 500, + "low": 229000 + i * 500, + "close": 230500 + i * 500, # Steady rise + "volume": 1000000, + }) + mock_broker.get_daily_prices.return_value = prices + + candidates = await scanner.scan() + + mock_broker.fetch_market_rankings.assert_called_once() + + @pytest.mark.asyncio + async def test_scan_filters_low_volume( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test that stocks with low volume ratio are filtered out.""" + mock_broker.fetch_market_rankings.return_value = [ + { + "stock_code": "000660", + "name": "SK Hynix", + "price": 150000, + "volume": 500000, + "change_rate": -5.0, + "volume_increase_rate": 50, # Only 50% increase (< 200%) + }, + ] + + # Low volume + prices = [] + for i in range(20): + prices.append({ + "date": f"2026020{i:02d}", + "open": 150000 - i * 100, + "high": 151000 - i * 100, + "low": 149000 - i * 100, + "close": 150000 - i * 150, # Declining (would be oversold) + "volume": 1000000, # Current 500k < 2x prev day 1M + }) + mock_broker.get_daily_prices.return_value = prices + + candidates = await scanner.scan() + + # Should be filtered out due to low volume ratio + assert len(candidates) == 0 + + @pytest.mark.asyncio + async def test_scan_filters_neutral_rsi( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test that stocks with neutral RSI are filtered out.""" + mock_broker.fetch_market_rankings.return_value = [ + { + "stock_code": "051910", + "name": "LG Chem", + "price": 500000, + "volume": 3000000, + "change_rate": 0.5, + "volume_increase_rate": 300, # High volume + }, + ] + + # Flat prices (neutral RSI ~50) + prices = [] + for i in range(20): + prices.append({ + "date": f"2026020{i:02d}", + "open": 500000 + (i % 2) * 100, # Small oscillation + "high": 500500, + "low": 499500, + "close": 500000 + (i % 2) * 50, + "volume": 1000000, + }) + mock_broker.get_daily_prices.return_value = prices + + candidates = await scanner.scan() + + # Should be filtered out (RSI ~50, not < 30 or > 70) + assert len(candidates) == 0 + + @pytest.mark.asyncio + async def test_scan_uses_fallback_on_api_error( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test fallback to static list when ranking API fails.""" + mock_broker.fetch_market_rankings.side_effect = ConnectionError("API unavailable") + + # Fallback stocks should still be analyzed + prices = [] + for i in range(20): + prices.append({ + "date": f"2026020{i:02d}", + "open": 50000 - i * 50, + "high": 51000 - i * 50, + "low": 49000 - i * 50, + "close": 50000 - i * 75, # Declining + "volume": 1000000, + }) + mock_broker.get_daily_prices.return_value = prices + + candidates = await scanner.scan(fallback_stocks=["005930", "000660"]) + + # Should not crash + assert isinstance(candidates, list) + + @pytest.mark.asyncio + async def test_scan_returns_top_n_only( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test that scan returns at most top_n candidates.""" + # Return many stocks + mock_broker.fetch_market_rankings.return_value = [ + { + "stock_code": f"00{i}000", + "name": f"Stock{i}", + "price": 10000 * i, + "volume": 5000000, + "change_rate": -10, + "volume_increase_rate": 500, + } + for i in range(1, 10) + ] + + # All oversold with high volume + def make_prices(code: str) -> list[dict]: + prices = [] + for i in range(20): + prices.append({ + "date": f"2026020{i:02d}", + "open": 10000 - i * 100, + "high": 10500 - i * 100, + "low": 9500 - i * 100, + "close": 10000 - i * 150, + "volume": 1000000, + }) + return prices + + mock_broker.get_daily_prices.side_effect = make_prices + + candidates = await scanner.scan() + + # Should respect top_n limit (3) + assert len(candidates) <= scanner.top_n + + @pytest.mark.asyncio + async def test_scan_skips_insufficient_price_history( + self, scanner: SmartVolatilityScanner, mock_broker: MagicMock + ) -> None: + """Test that stocks with insufficient history are skipped.""" + mock_broker.fetch_market_rankings.return_value = [ + { + "stock_code": "005930", + "name": "Samsung", + "price": 70000, + "volume": 5000000, + "change_rate": -5.0, + "volume_increase_rate": 300, + }, + ] + + # Only 5 days of data (need 15+ for RSI) + mock_broker.get_daily_prices.return_value = [ + { + "date": f"2026020{i:02d}", + "open": 70000, + "high": 71000, + "low": 69000, + "close": 70000, + "volume": 2000000, + } + for i in range(5) + ] + + candidates = await scanner.scan() + + # Should skip due to insufficient data + assert len(candidates) == 0 + + @pytest.mark.asyncio + async def test_get_stock_codes( + self, scanner: SmartVolatilityScanner + ) -> None: + """Test extraction of stock codes from candidates.""" + candidates = [ + ScanCandidate( + stock_code="005930", + name="Samsung", + price=70000, + volume=5000000, + volume_ratio=2.5, + rsi=28, + signal="oversold", + score=85.0, + ), + ScanCandidate( + stock_code="035420", + name="NAVER", + price=250000, + volume=3000000, + volume_ratio=3.0, + rsi=75, + signal="momentum", + score=88.0, + ), + ] + + codes = scanner.get_stock_codes(candidates) + + assert codes == ["005930", "035420"] + + +class TestRSICalculation: + """Test RSI calculation in VolatilityAnalyzer.""" + + def test_rsi_oversold(self) -> None: + """Test RSI calculation for downtrending prices.""" + analyzer = VolatilityAnalyzer() + + # Steadily declining prices + prices = [100 - i * 0.5 for i in range(20)] + rsi = analyzer.calculate_rsi(prices, period=14) + + assert rsi < 50 # Should be oversold territory + + def test_rsi_overbought(self) -> None: + """Test RSI calculation for uptrending prices.""" + analyzer = VolatilityAnalyzer() + + # Steadily rising prices + prices = [100 + i * 0.5 for i in range(20)] + rsi = analyzer.calculate_rsi(prices, period=14) + + assert rsi > 50 # Should be overbought territory + + def test_rsi_neutral(self) -> None: + """Test RSI calculation for flat prices.""" + analyzer = VolatilityAnalyzer() + + # Flat prices with small oscillation + prices = [100 + (i % 2) * 0.1 for i in range(20)] + rsi = analyzer.calculate_rsi(prices, period=14) + + assert 40 < rsi < 60 # Should be near neutral + + def test_rsi_insufficient_data(self) -> None: + """Test RSI returns neutral when insufficient data.""" + analyzer = VolatilityAnalyzer() + + prices = [100, 101, 102] # Only 3 prices, need 15+ + rsi = analyzer.calculate_rsi(prices, period=14) + + assert rsi == 50.0 # Default neutral + + def test_rsi_all_gains(self) -> None: + """Test RSI returns 100 when all gains (no losses).""" + analyzer = VolatilityAnalyzer() + + # Monotonic increase + prices = [100 + i for i in range(20)] + rsi = analyzer.calculate_rsi(prices, period=14) + + assert rsi == 100.0 # Maximum RSI