diff --git a/src/brain/gemini_client.py b/src/brain/gemini_client.py index 0d48d28..8163624 100644 --- a/src/brain/gemini_client.py +++ b/src/brain/gemini_client.py @@ -49,15 +49,40 @@ class GeminiClient: The prompt instructs Gemini to return valid JSON with action, confidence, and rationale fields. """ + market_name = market_data.get("market_name", "Korean stock market") + + # Build market data section dynamically based on available fields + market_info_lines = [ + f"Market: {market_name}", + f"Stock Code: {market_data['stock_code']}", + f"Current Price: {market_data['current_price']}", + ] + + # Add orderbook if available (domestic markets) + if "orderbook" in market_data: + market_info_lines.append( + f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}" + ) + + # Add foreigner net if non-zero + if market_data.get("foreigner_net", 0) != 0: + market_info_lines.append( + f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}" + ) + + market_info = "\n".join(market_info_lines) + + json_format = ( + '{"action": "BUY"|"SELL"|"HOLD", ' + '"confidence": , "rationale": ""}' + ) return ( - "You are a professional Korean stock market trading analyst.\n" - "Analyze the following market data and decide whether to BUY, SELL, or HOLD.\n\n" - f"Stock Code: {market_data['stock_code']}\n" - f"Current Price: {market_data['current_price']}\n" - f"Orderbook: {json.dumps(market_data['orderbook'], ensure_ascii=False)}\n" - f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}\n\n" + f"You are a professional {market_name} trading analyst.\n" + "Analyze the following market data and decide whether to " + "BUY, SELL, or HOLD.\n\n" + f"{market_info}\n\n" "You MUST respond with ONLY valid JSON in the following format:\n" - '{"action": "BUY"|"SELL"|"HOLD", "confidence": , "rationale": ""}\n\n' + f"{json_format}\n\n" "Rules:\n" "- action must be exactly one of: BUY, SELL, HOLD\n" "- confidence must be an integer from 0 to 100\n" diff --git a/src/broker/kis_api.py b/src/broker/kis_api.py index 0983a29..e3d88ba 100644 --- a/src/broker/kis_api.py +++ b/src/broker/kis_api.py @@ -6,11 +6,8 @@ Handles token refresh, rate limiting (leaky bucket), and hash key generation. from __future__ import annotations import asyncio -import hashlib -import json import logging import ssl -import time from typing import Any import aiohttp @@ -168,7 +165,7 @@ class KISBroker: f"get_orderbook failed ({resp.status}): {text}" ) return await resp.json() - except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + except (TimeoutError, aiohttp.ClientError) as exc: raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc async def get_balance(self) -> dict[str, Any]: @@ -200,7 +197,7 @@ class KISBroker: f"get_balance failed ({resp.status}): {text}" ) return await resp.json() - except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + except (TimeoutError, aiohttp.ClientError) as exc: raise ConnectionError(f"Network error fetching balance: {exc}") from exc async def send_order( @@ -253,5 +250,5 @@ class KISBroker: }, ) return data - except (aiohttp.ClientError, asyncio.TimeoutError) as exc: + except (TimeoutError, aiohttp.ClientError) as exc: raise ConnectionError(f"Network error sending order: {exc}") from exc diff --git a/src/broker/overseas.py b/src/broker/overseas.py new file mode 100644 index 0000000..874df83 --- /dev/null +++ b/src/broker/overseas.py @@ -0,0 +1,200 @@ +"""KIS Overseas Stock API client.""" + +from __future__ import annotations + +import logging +from typing import Any + +import aiohttp + +from src.broker.kis_api import KISBroker + +logger = logging.getLogger(__name__) + + +class OverseasBroker: + """KIS Overseas Stock API wrapper that reuses KISBroker infrastructure.""" + + def __init__(self, kis_broker: KISBroker) -> None: + """ + Initialize overseas broker. + + Args: + kis_broker: Domestic KIS broker instance to reuse session/token/rate limiter + """ + self._broker = kis_broker + + async def get_overseas_price( + self, exchange_code: str, stock_code: str + ) -> dict[str, Any]: + """ + Fetch overseas stock price. + + Args: + exchange_code: Exchange code (e.g., "NASD", "NYSE", "TSE") + stock_code: Stock ticker symbol + + Returns: + API response with price data + + Raises: + ConnectionError: On network or API errors + """ + await self._broker._rate_limiter.acquire() + session = self._broker._get_session() + + headers = await self._broker._auth_headers("HHDFS00000300") + params = { + "AUTH": "", + "EXCD": exchange_code, + "SYMB": stock_code, + } + url = f"{self._broker._base_url}/uapi/overseas-price/v1/quotations/price" + + try: + async with session.get(url, headers=headers, params=params) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"get_overseas_price failed ({resp.status}): {text}" + ) + return await resp.json() + except (TimeoutError, aiohttp.ClientError) as exc: + raise ConnectionError( + f"Network error fetching overseas price: {exc}" + ) from exc + + async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]: + """ + Fetch overseas account balance. + + Args: + exchange_code: Exchange code (e.g., "NASD", "NYSE") + + Returns: + API response with balance data + + Raises: + ConnectionError: On network or API errors + """ + await self._broker._rate_limiter.acquire() + session = self._broker._get_session() + + # Virtual trading TR_ID for overseas balance inquiry + headers = await self._broker._auth_headers("VTTS3012R") + params = { + "CANO": self._broker._account_no, + "ACNT_PRDT_CD": self._broker._product_cd, + "OVRS_EXCG_CD": exchange_code, + "TR_CRCY_CD": self._get_currency_code(exchange_code), + "CTX_AREA_FK200": "", + "CTX_AREA_NK200": "", + } + url = ( + f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-balance" + ) + + try: + async with session.get(url, headers=headers, params=params) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"get_overseas_balance failed ({resp.status}): {text}" + ) + return await resp.json() + except (TimeoutError, aiohttp.ClientError) as exc: + raise ConnectionError( + f"Network error fetching overseas balance: {exc}" + ) from exc + + async def send_overseas_order( + self, + exchange_code: str, + stock_code: str, + order_type: str, # "BUY" or "SELL" + quantity: int, + price: float = 0.0, + ) -> dict[str, Any]: + """ + Submit overseas stock order. + + Args: + exchange_code: Exchange code (e.g., "NASD", "NYSE") + stock_code: Stock ticker symbol + order_type: "BUY" or "SELL" + quantity: Number of shares + price: Order price (0 for market order) + + Returns: + API response with order result + + Raises: + ConnectionError: On network or API errors + """ + await self._broker._rate_limiter.acquire() + session = self._broker._get_session() + + # Virtual trading TR_IDs for overseas orders + tr_id = "VTTT1002U" if order_type == "BUY" else "VTTT1006U" + + body = { + "CANO": self._broker._account_no, + "ACNT_PRDT_CD": self._broker._product_cd, + "OVRS_EXCG_CD": exchange_code, + "PDNO": stock_code, + "ORD_DVSN": "00" if price > 0 else "01", # 00=지정가, 01=시장가 + "ORD_QTY": str(quantity), + "OVRS_ORD_UNPR": str(price) if price > 0 else "0", + "ORD_SVR_DVSN_CD": "0", # 0=해외주문 + } + + hash_key = await self._broker._get_hash_key(body) + headers = await self._broker._auth_headers(tr_id) + headers["hashkey"] = hash_key + + url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order" + + try: + async with session.post(url, headers=headers, json=body) as resp: + if resp.status != 200: + text = await resp.text() + raise ConnectionError( + f"send_overseas_order failed ({resp.status}): {text}" + ) + data = await resp.json() + logger.info( + "Overseas order submitted", + extra={ + "exchange": exchange_code, + "stock_code": stock_code, + "action": order_type, + }, + ) + return data + except (TimeoutError, aiohttp.ClientError) as exc: + raise ConnectionError( + f"Network error sending overseas order: {exc}" + ) from exc + + def _get_currency_code(self, exchange_code: str) -> str: + """ + Map exchange code to currency code. + + Args: + exchange_code: Exchange code + + Returns: + Currency code (e.g., "USD", "JPY") + """ + currency_map = { + "NASD": "USD", + "NYSE": "USD", + "AMEX": "USD", + "TSE": "JPY", + "SEHK": "HKD", + "SHAA": "CNY", + "SZAA": "CNY", + "HNX": "VND", + "HSX": "VND", + } + return currency_map.get(exchange_code, "USD") diff --git a/src/config.py b/src/config.py index d9bd569..da94f19 100644 --- a/src/config.py +++ b/src/config.py @@ -33,6 +33,9 @@ class Settings(BaseSettings): # Trading mode MODE: str = Field(default="paper", pattern="^(paper|live)$") + # Market selection (comma-separated market codes) + ENABLED_MARKETS: str = "KR" + model_config = {"env_file": ".env", "env_file_encoding": "utf-8"} @property @@ -42,3 +45,8 @@ class Settings(BaseSettings): @property def account_product_code(self) -> str: return self.KIS_ACCOUNT_NO.split("-")[1] + + @property + def enabled_market_list(self) -> list[str]: + """Parse ENABLED_MARKETS into list of market codes.""" + return [m.strip() for m in self.ENABLED_MARKETS.split(",") if m.strip()] diff --git a/src/core/risk_manager.py b/src/core/risk_manager.py index 22d755b..7fd559b 100644 --- a/src/core/risk_manager.py +++ b/src/core/risk_manager.py @@ -7,7 +7,6 @@ Changes require human approval and two passing test suites. from __future__ import annotations import logging -from dataclasses import dataclass from src.config import Settings diff --git a/src/db.py b/src/db.py index c8d4957..f61d84d 100644 --- a/src/db.py +++ b/src/db.py @@ -3,9 +3,8 @@ from __future__ import annotations import sqlite3 -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path -from typing import Any def init_db(db_path: str) -> sqlite3.Connection: @@ -24,10 +23,22 @@ def init_db(db_path: str) -> sqlite3.Connection: rationale TEXT, quantity INTEGER, price REAL, - pnl REAL DEFAULT 0.0 + pnl REAL DEFAULT 0.0, + market TEXT DEFAULT 'KR', + exchange_code TEXT DEFAULT 'KRX' ) """ ) + + # Migration: Add market and exchange_code columns if they don't exist + cursor = conn.execute("PRAGMA table_info(trades)") + columns = {row[1] for row in cursor.fetchall()} + + if "market" not in columns: + conn.execute("ALTER TABLE trades ADD COLUMN market TEXT DEFAULT 'KR'") + if "exchange_code" not in columns: + conn.execute("ALTER TABLE trades ADD COLUMN exchange_code TEXT DEFAULT 'KRX'") + conn.commit() return conn @@ -41,15 +52,20 @@ def log_trade( quantity: int = 0, price: float = 0.0, pnl: float = 0.0, + market: str = "KR", + exchange_code: str = "KRX", ) -> None: """Insert a trade record into the database.""" conn.execute( """ - INSERT INTO trades (timestamp, stock_code, action, confidence, rationale, quantity, price, pnl) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) + INSERT INTO trades ( + timestamp, stock_code, action, confidence, rationale, + quantity, price, pnl, market, exchange_code + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( - datetime.now(timezone.utc).isoformat(), + datetime.now(UTC).isoformat(), stock_code, action, confidence, @@ -57,6 +73,8 @@ def log_trade( quantity, price, pnl, + market, + exchange_code, ), ) conn.commit() diff --git a/src/evolution/optimizer.py b/src/evolution/optimizer.py index eb6d1dc..98016d5 100644 --- a/src/evolution/optimizer.py +++ b/src/evolution/optimizer.py @@ -14,7 +14,7 @@ import logging import sqlite3 import subprocess import textwrap -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -136,7 +136,7 @@ class EvolutionOptimizer: body = "\n".join(lines[1:-1]) # Create strategy file - timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S") version = f"v{timestamp}" class_name = f"Strategy_{version}" file_name = f"{version}_evolved.py" @@ -149,7 +149,7 @@ class EvolutionOptimizer: content = STRATEGY_TEMPLATE.format( name=version, - timestamp=datetime.now(timezone.utc).isoformat(), + timestamp=datetime.now(UTC).isoformat(), rationale="Auto-evolved from failure analysis", class_name=class_name, body=indented_body.strip(), diff --git a/src/logging_config.py b/src/logging_config.py index 54fd50a..644b99e 100644 --- a/src/logging_config.py +++ b/src/logging_config.py @@ -2,20 +2,19 @@ from __future__ import annotations +import json import logging import sys -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any -import json - class JSONFormatter(logging.Formatter): """Emit log records as single-line JSON objects.""" def format(self, record: logging.LogRecord) -> str: log_entry: dict[str, Any] = { - "timestamp": datetime.now(timezone.utc).isoformat(), + "timestamp": datetime.now(UTC).isoformat(), "level": record.levelname, "logger": record.name, "message": record.getMessage(), diff --git a/src/main.py b/src/main.py index cacb329..c95d86c 100644 --- a/src/main.py +++ b/src/main.py @@ -10,66 +10,93 @@ import argparse import asyncio import logging import signal -import sys +from datetime import UTC, datetime from typing import Any from src.brain.gemini_client import GeminiClient from src.broker.kis_api import KISBroker +from src.broker.overseas import OverseasBroker from src.config import Settings from src.core.risk_manager import CircuitBreakerTripped, RiskManager from src.db import init_db, log_trade from src.logging_config import setup_logging +from src.markets.schedule import MarketInfo, get_next_market_open, get_open_markets logger = logging.getLogger(__name__) -# Target stock codes to monitor -WATCHLIST = ["005930", "000660", "035420"] # Samsung, SK Hynix, NAVER +# Target stock codes to monitor per market +WATCHLISTS = { + "KR": ["005930", "000660", "035420"], # Samsung, SK Hynix, NAVER + "US_NASDAQ": ["AAPL", "MSFT", "GOOGL"], # Example US stocks + "US_NYSE": ["JPM", "BAC"], # Example NYSE stocks + "JP": ["7203", "6758"], # Toyota, Sony +} TRADE_INTERVAL_SECONDS = 60 +MAX_CONNECTION_RETRIES = 3 async def trading_cycle( broker: KISBroker, + overseas_broker: OverseasBroker, brain: GeminiClient, risk: RiskManager, db_conn: Any, + market: MarketInfo, stock_code: str, ) -> None: """Execute one trading cycle for a single stock.""" # 1. Fetch market data - orderbook = await broker.get_orderbook(stock_code) - balance_data = await broker.get_balance() + if market.is_domestic: + orderbook = await broker.get_orderbook(stock_code) + balance_data = await broker.get_balance() - output2 = balance_data.get("output2", [{}]) - total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0 - total_cash = float( - balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") - if output2 - else "0" - ) - purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0 + output2 = balance_data.get("output2", [{}]) + total_eval = float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0 + total_cash = float( + balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") + if output2 + else "0" + ) + purchase_total = float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0 + + current_price = float(orderbook.get("output1", {}).get("stck_prpr", "0")) + foreigner_net = float(orderbook.get("output1", {}).get("frgn_ntby_qty", "0")) + else: + # Overseas market + price_data = await overseas_broker.get_overseas_price( + market.exchange_code, stock_code + ) + balance_data = await overseas_broker.get_overseas_balance(market.exchange_code) + + output2 = balance_data.get("output2", [{}]) + total_eval = float(output2[0].get("frcr_evlu_tota", "0")) if output2 else 0 + total_cash = float(output2[0].get("frcr_dncl_amt_2", "0")) if output2 else 0 + purchase_total = float(output2[0].get("frcr_buy_amt_smtl", "0")) if output2 else 0 + + current_price = float(price_data.get("output", {}).get("last", "0")) + foreigner_net = 0.0 # Not available for overseas # Calculate daily P&L % - pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0 - - current_price = float( - orderbook.get("output1", {}).get("stck_prpr", "0") + pnl_pct = ( + ((total_eval - purchase_total) / purchase_total * 100) + if purchase_total > 0 + else 0.0 ) market_data = { "stock_code": stock_code, + "market_name": market.name, "current_price": current_price, - "orderbook": orderbook.get("output1", {}), - "foreigner_net": float( - orderbook.get("output1", {}).get("frgn_ntby_qty", "0") - ), + "foreigner_net": foreigner_net, } # 2. Ask the brain for a decision decision = await brain.decide(market_data) logger.info( - "Decision for %s: %s (confidence=%d)", + "Decision for %s (%s): %s (confidence=%d)", stock_code, + market.name, decision.action, decision.confidence, ) @@ -88,12 +115,21 @@ async def trading_cycle( ) # 5. Send order - result = await broker.send_order( - stock_code=stock_code, - order_type=decision.action, - quantity=quantity, - price=0, # market order - ) + if market.is_domestic: + result = await broker.send_order( + stock_code=stock_code, + order_type=decision.action, + quantity=quantity, + price=0, # market order + ) + else: + result = await overseas_broker.send_overseas_order( + exchange_code=market.exchange_code, + stock_code=stock_code, + order_type=decision.action, + quantity=quantity, + price=0.0, # market order + ) logger.info("Order result: %s", result.get("msg1", "OK")) # 6. Log trade @@ -103,12 +139,15 @@ async def trading_cycle( action=decision.action, confidence=decision.confidence, rationale=decision.rationale, + market=market.code, + exchange_code=market.exchange_code, ) async def run(settings: Settings) -> None: - """Main async loop — iterate over watchlist on a timer.""" + """Main async loop — iterate over open markets on a timer.""" broker = KISBroker(settings) + overseas_broker = OverseasBroker(broker) brain = GeminiClient(settings) risk = RiskManager(settings) db_conn = init_db(settings.DB_PATH) @@ -124,27 +163,93 @@ async def run(settings: Settings) -> None: loop.add_signal_handler(sig, _signal_handler) logger.info("The Ouroboros is alive. Mode: %s", settings.MODE) - logger.info("Watchlist: %s", WATCHLIST) + logger.info("Enabled markets: %s", settings.enabled_market_list) try: while not shutdown.is_set(): - for code in WATCHLIST: + # Get currently open markets + open_markets = get_open_markets(settings.enabled_market_list) + + if not open_markets: + # No markets open — wait until next market opens + try: + next_market, next_open_time = get_next_market_open( + settings.enabled_market_list + ) + now = datetime.now(UTC) + wait_seconds = (next_open_time - now).total_seconds() + logger.info( + "No markets open. Next market: %s, opens in %.1f hours", + next_market.name, + wait_seconds / 3600, + ) + await asyncio.wait_for(shutdown.wait(), timeout=wait_seconds) + except TimeoutError: + continue # Market should be open now + except ValueError as exc: + logger.error("Failed to find next market open: %s", exc) + await asyncio.sleep(TRADE_INTERVAL_SECONDS) + continue + + # Process each open market + for market in open_markets: if shutdown.is_set(): break - try: - await trading_cycle(broker, brain, risk, db_conn, code) - except CircuitBreakerTripped: - logger.critical("Circuit breaker tripped — shutting down") - raise - except ConnectionError as exc: - logger.error("Connection error for %s: %s", code, exc) - except Exception as exc: - logger.exception("Unexpected error for %s: %s", code, exc) + + # Get watchlist for this market + watchlist = WATCHLISTS.get(market.code, []) + if not watchlist: + logger.debug("No watchlist for market %s", market.code) + continue + + logger.info("Processing market: %s (%d stocks)", market.name, len(watchlist)) + + # Process each stock in the watchlist + for stock_code in watchlist: + if shutdown.is_set(): + break + + # Retry logic for connection errors + for attempt in range(1, MAX_CONNECTION_RETRIES + 1): + try: + await trading_cycle( + broker, + overseas_broker, + brain, + risk, + db_conn, + market, + stock_code, + ) + break # Success — exit retry loop + except CircuitBreakerTripped: + logger.critical("Circuit breaker tripped — shutting down") + raise + except ConnectionError as exc: + if attempt < MAX_CONNECTION_RETRIES: + logger.warning( + "Connection error for %s (attempt %d/%d): %s", + stock_code, + attempt, + MAX_CONNECTION_RETRIES, + exc, + ) + await asyncio.sleep(2**attempt) # Exponential backoff + else: + logger.error( + "Connection error for %s (all retries exhausted): %s", + stock_code, + exc, + ) + break # Give up on this stock + except Exception as exc: + logger.exception("Unexpected error for %s: %s", stock_code, exc) + break # Don't retry on unexpected errors # Wait for next cycle or shutdown try: await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS) - except asyncio.TimeoutError: + except TimeoutError: pass # Normal — timeout means it's time for next cycle finally: await broker.close() diff --git a/src/markets/__init__.py b/src/markets/__init__.py new file mode 100644 index 0000000..b49e4df --- /dev/null +++ b/src/markets/__init__.py @@ -0,0 +1 @@ +"""Global market scheduling and timezone management.""" diff --git a/src/markets/schedule.py b/src/markets/schedule.py new file mode 100644 index 0000000..0adfe56 --- /dev/null +++ b/src/markets/schedule.py @@ -0,0 +1,252 @@ +"""Market schedule management with timezone support.""" + +from dataclasses import dataclass +from datetime import datetime, time, timedelta +from zoneinfo import ZoneInfo + + +@dataclass(frozen=True) +class MarketInfo: + """Information about a trading market.""" + + code: str # Market code for internal use (e.g., "KR", "US_NASDAQ") + exchange_code: str # KIS API exchange code (e.g., "NASD", "NYSE") + name: str # Human-readable name + timezone: ZoneInfo # Market timezone + open_time: time # Market open time in local timezone + close_time: time # Market close time in local timezone + is_domestic: bool # True for Korean market, False for overseas + lunch_break: tuple[time, time] | None = None # (start, end) or None + + +# 10 global markets with their schedules +MARKETS: dict[str, MarketInfo] = { + "KR": MarketInfo( + code="KR", + exchange_code="KRX", + name="Korea Exchange", + timezone=ZoneInfo("Asia/Seoul"), + open_time=time(9, 0), + close_time=time(15, 30), + is_domestic=True, + lunch_break=None, # KRX removed lunch break + ), + "US_NASDAQ": MarketInfo( + code="US_NASDAQ", + exchange_code="NASD", + name="NASDAQ", + timezone=ZoneInfo("America/New_York"), + open_time=time(9, 30), + close_time=time(16, 0), + is_domestic=False, + lunch_break=None, + ), + "US_NYSE": MarketInfo( + code="US_NYSE", + exchange_code="NYSE", + name="New York Stock Exchange", + timezone=ZoneInfo("America/New_York"), + open_time=time(9, 30), + close_time=time(16, 0), + is_domestic=False, + lunch_break=None, + ), + "US_AMEX": MarketInfo( + code="US_AMEX", + exchange_code="AMEX", + name="NYSE American", + timezone=ZoneInfo("America/New_York"), + open_time=time(9, 30), + close_time=time(16, 0), + is_domestic=False, + lunch_break=None, + ), + "JP": MarketInfo( + code="JP", + exchange_code="TSE", + name="Tokyo Stock Exchange", + timezone=ZoneInfo("Asia/Tokyo"), + open_time=time(9, 0), + close_time=time(15, 0), + is_domestic=False, + lunch_break=(time(11, 30), time(12, 30)), + ), + "HK": MarketInfo( + code="HK", + exchange_code="SEHK", + name="Hong Kong Stock Exchange", + timezone=ZoneInfo("Asia/Hong_Kong"), + open_time=time(9, 30), + close_time=time(16, 0), + is_domestic=False, + lunch_break=(time(12, 0), time(13, 0)), + ), + "CN_SHA": MarketInfo( + code="CN_SHA", + exchange_code="SHAA", + name="Shanghai Stock Exchange", + timezone=ZoneInfo("Asia/Shanghai"), + open_time=time(9, 30), + close_time=time(15, 0), + is_domestic=False, + lunch_break=(time(11, 30), time(13, 0)), + ), + "CN_SZA": MarketInfo( + code="CN_SZA", + exchange_code="SZAA", + name="Shenzhen Stock Exchange", + timezone=ZoneInfo("Asia/Shanghai"), + open_time=time(9, 30), + close_time=time(15, 0), + is_domestic=False, + lunch_break=(time(11, 30), time(13, 0)), + ), + "VN_HAN": MarketInfo( + code="VN_HAN", + exchange_code="HNX", + name="Hanoi Stock Exchange", + timezone=ZoneInfo("Asia/Ho_Chi_Minh"), + open_time=time(9, 0), + close_time=time(15, 0), + is_domestic=False, + lunch_break=(time(11, 30), time(13, 0)), + ), + "VN_HCM": MarketInfo( + code="VN_HCM", + exchange_code="HSX", + name="Ho Chi Minh Stock Exchange", + timezone=ZoneInfo("Asia/Ho_Chi_Minh"), + open_time=time(9, 0), + close_time=time(15, 0), + is_domestic=False, + lunch_break=(time(11, 30), time(13, 0)), + ), +} + + +def is_market_open(market: MarketInfo, now: datetime | None = None) -> bool: + """ + Check if a market is currently open for trading. + + Args: + market: Market information + now: Current time (defaults to datetime.now(UTC) for testing) + + Returns: + True if market is open, False otherwise + + Note: + Does not account for holidays (KIS API will reject orders on holidays) + """ + if now is None: + now = datetime.now(ZoneInfo("UTC")) + + # Convert to market's local timezone + local_now = now.astimezone(market.timezone) + + # Check if it's a weekend + if local_now.weekday() >= 5: # Saturday=5, Sunday=6 + return False + + current_time = local_now.time() + + # Check if within trading hours + if current_time < market.open_time or current_time >= market.close_time: + return False + + # Check lunch break + if market.lunch_break: + lunch_start, lunch_end = market.lunch_break + if lunch_start <= current_time < lunch_end: + return False + + return True + + +def get_open_markets( + enabled_markets: list[str] | None = None, now: datetime | None = None +) -> list[MarketInfo]: + """ + Get list of currently open markets. + + Args: + enabled_markets: List of market codes to check (defaults to all markets) + now: Current time (defaults to datetime.now(UTC) for testing) + + Returns: + List of open markets, sorted by market code + """ + if enabled_markets is None: + enabled_markets = list(MARKETS.keys()) + + open_markets = [ + MARKETS[code] + for code in enabled_markets + if code in MARKETS and is_market_open(MARKETS[code], now) + ] + + return sorted(open_markets, key=lambda m: m.code) + + +def get_next_market_open( + enabled_markets: list[str] | None = None, now: datetime | None = None +) -> tuple[MarketInfo, datetime]: + """ + Find the next market that will open and when. + + Args: + enabled_markets: List of market codes to check (defaults to all markets) + now: Current time (defaults to datetime.now(UTC) for testing) + + Returns: + Tuple of (market, open_datetime) for the next market to open + + Raises: + ValueError: If no enabled markets are configured + """ + if now is None: + now = datetime.now(ZoneInfo("UTC")) + + if enabled_markets is None: + enabled_markets = list(MARKETS.keys()) + + if not enabled_markets: + raise ValueError("No enabled markets configured") + + next_open_time: datetime | None = None + next_market: MarketInfo | None = None + + for code in enabled_markets: + if code not in MARKETS: + continue + + market = MARKETS[code] + market_now = now.astimezone(market.timezone) + + # Calculate next open time for this market + for days_ahead in range(7): # Check next 7 days + check_date = market_now.date() + timedelta(days=days_ahead) + check_datetime = datetime.combine( + check_date, market.open_time, tzinfo=market.timezone + ) + + # Skip weekends + if check_datetime.weekday() >= 5: + continue + + # Skip if this open time already passed today + if check_datetime <= market_now: + continue + + # Convert to UTC for comparison + check_datetime_utc = check_datetime.astimezone(ZoneInfo("UTC")) + + if next_open_time is None or check_datetime_utc < next_open_time: + next_open_time = check_datetime_utc + next_market = market + break + + if next_market is None or next_open_time is None: + raise ValueError("Could not find next market open time") + + return next_market, next_open_time diff --git a/tests/conftest.py b/tests/conftest.py index fdb0b08..30fae40 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -20,4 +20,5 @@ def settings() -> Settings: FAT_FINGER_PCT=30.0, CONFIDENCE_THRESHOLD=80, DB_PATH=":memory:", + ENABLED_MARKETS="KR", ) diff --git a/tests/test_brain.py b/tests/test_brain.py index 204fcd1..d464f12 100644 --- a/tests/test_brain.py +++ b/tests/test_brain.py @@ -2,12 +2,7 @@ from __future__ import annotations -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from src.brain.gemini_client import GeminiClient, TradeDecision - +from src.brain.gemini_client import GeminiClient # --------------------------------------------------------------------------- # Response Parsing diff --git a/tests/test_broker.py b/tests/test_broker.py index 5d8d2ac..fc88996 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -3,14 +3,12 @@ from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch -import aiohttp import pytest from src.broker.kis_api import KISBroker - # --------------------------------------------------------------------------- # Token Management # --------------------------------------------------------------------------- @@ -68,7 +66,7 @@ class TestNetworkErrorHandling: with patch( "aiohttp.ClientSession.get", - side_effect=asyncio.TimeoutError(), + side_effect=TimeoutError(), ): with pytest.raises(ConnectionError): await broker.get_orderbook("005930") diff --git a/tests/test_market_schedule.py b/tests/test_market_schedule.py new file mode 100644 index 0000000..ea33ab4 --- /dev/null +++ b/tests/test_market_schedule.py @@ -0,0 +1,201 @@ +"""Tests for market schedule management.""" + +from datetime import datetime +from zoneinfo import ZoneInfo + +import pytest + +from src.markets.schedule import ( + MARKETS, + get_next_market_open, + get_open_markets, + is_market_open, +) + + +class TestMarketInfo: + """Test MarketInfo dataclass.""" + + def test_market_info_immutable(self) -> None: + """MarketInfo should be frozen.""" + market = MARKETS["KR"] + with pytest.raises(AttributeError): + market.code = "US" # type: ignore[misc] + + def test_all_markets_defined(self) -> None: + """All 10 markets should be defined.""" + expected_markets = { + "KR", + "US_NASDAQ", + "US_NYSE", + "US_AMEX", + "JP", + "HK", + "CN_SHA", + "CN_SZA", + "VN_HAN", + "VN_HCM", + } + assert set(MARKETS.keys()) == expected_markets + + +class TestIsMarketOpen: + """Test is_market_open function.""" + + def test_kr_market_open_weekday(self) -> None: + """KR market should be open during trading hours on weekday.""" + # Monday 2026-02-02 10:00 KST + test_time = datetime(2026, 2, 2, 10, 0, tzinfo=ZoneInfo("Asia/Seoul")) + assert is_market_open(MARKETS["KR"], test_time) + + def test_kr_market_closed_before_open(self) -> None: + """KR market should be closed before 9:00.""" + # Monday 2026-02-02 08:30 KST + test_time = datetime(2026, 2, 2, 8, 30, tzinfo=ZoneInfo("Asia/Seoul")) + assert not is_market_open(MARKETS["KR"], test_time) + + def test_kr_market_closed_after_close(self) -> None: + """KR market should be closed after 15:30.""" + # Monday 2026-02-02 15:30 KST (exact close time) + test_time = datetime(2026, 2, 2, 15, 30, tzinfo=ZoneInfo("Asia/Seoul")) + assert not is_market_open(MARKETS["KR"], test_time) + + def test_kr_market_closed_weekend(self) -> None: + """KR market should be closed on weekends.""" + # Saturday 2026-02-07 10:00 KST + test_time = datetime(2026, 2, 7, 10, 0, tzinfo=ZoneInfo("Asia/Seoul")) + assert not is_market_open(MARKETS["KR"], test_time) + + # Sunday 2026-02-08 10:00 KST + test_time = datetime(2026, 2, 8, 10, 0, tzinfo=ZoneInfo("Asia/Seoul")) + assert not is_market_open(MARKETS["KR"], test_time) + + def test_us_nasdaq_open_with_dst(self) -> None: + """US markets should respect DST.""" + # Monday 2026-06-01 10:00 EDT (DST in effect) + test_time = datetime(2026, 6, 1, 10, 0, tzinfo=ZoneInfo("America/New_York")) + assert is_market_open(MARKETS["US_NASDAQ"], test_time) + + # Monday 2026-12-07 10:00 EST (no DST) + test_time = datetime(2026, 12, 7, 10, 0, tzinfo=ZoneInfo("America/New_York")) + assert is_market_open(MARKETS["US_NASDAQ"], test_time) + + def test_jp_market_lunch_break(self) -> None: + """JP market should be closed during lunch break.""" + # Monday 2026-02-02 12:00 JST (lunch break) + test_time = datetime(2026, 2, 2, 12, 0, tzinfo=ZoneInfo("Asia/Tokyo")) + assert not is_market_open(MARKETS["JP"], test_time) + + # Before lunch + test_time = datetime(2026, 2, 2, 11, 0, tzinfo=ZoneInfo("Asia/Tokyo")) + assert is_market_open(MARKETS["JP"], test_time) + + # After lunch + test_time = datetime(2026, 2, 2, 13, 0, tzinfo=ZoneInfo("Asia/Tokyo")) + assert is_market_open(MARKETS["JP"], test_time) + + def test_hk_market_lunch_break(self) -> None: + """HK market should be closed during lunch break.""" + # Monday 2026-02-02 12:30 HKT (lunch break) + test_time = datetime(2026, 2, 2, 12, 30, tzinfo=ZoneInfo("Asia/Hong_Kong")) + assert not is_market_open(MARKETS["HK"], test_time) + + def test_timezone_conversion(self) -> None: + """Should correctly convert timezones.""" + # 2026-02-02 10:00 KST = 2026-02-02 01:00 UTC + test_time = datetime(2026, 2, 2, 1, 0, tzinfo=ZoneInfo("UTC")) + assert is_market_open(MARKETS["KR"], test_time) + + +class TestGetOpenMarkets: + """Test get_open_markets function.""" + + def test_get_open_markets_all_closed(self) -> None: + """Should return empty list when all markets closed.""" + # Sunday 2026-02-08 12:00 UTC (all markets closed) + test_time = datetime(2026, 2, 8, 12, 0, tzinfo=ZoneInfo("UTC")) + assert get_open_markets(now=test_time) == [] + + def test_get_open_markets_kr_only(self) -> None: + """Should return only KR when filtering enabled markets.""" + # Monday 2026-02-02 10:00 KST = 01:00 UTC + test_time = datetime(2026, 2, 2, 1, 0, tzinfo=ZoneInfo("UTC")) + open_markets = get_open_markets(enabled_markets=["KR"], now=test_time) + assert len(open_markets) == 1 + assert open_markets[0].code == "KR" + + def test_get_open_markets_multiple(self) -> None: + """Should return multiple markets when open.""" + # Monday 2026-02-02 14:30 EST = 19:30 UTC + # US markets: 9:30-16:00 EST → 14:30-21:00 UTC (open) + test_time = datetime(2026, 2, 2, 19, 30, tzinfo=ZoneInfo("UTC")) + open_markets = get_open_markets( + enabled_markets=["US_NASDAQ", "US_NYSE", "US_AMEX"], now=test_time + ) + assert len(open_markets) == 3 + codes = {m.code for m in open_markets} + assert codes == {"US_NASDAQ", "US_NYSE", "US_AMEX"} + + def test_get_open_markets_sorted(self) -> None: + """Should return markets sorted by code.""" + # Monday 2026-02-02 14:30 EST + test_time = datetime(2026, 2, 2, 19, 30, tzinfo=ZoneInfo("UTC")) + open_markets = get_open_markets( + enabled_markets=["US_NYSE", "US_AMEX", "US_NASDAQ"], now=test_time + ) + codes = [m.code for m in open_markets] + assert codes == sorted(codes) + + +class TestGetNextMarketOpen: + """Test get_next_market_open function.""" + + def test_get_next_market_open_weekend(self) -> None: + """Should find next Monday opening when called on weekend.""" + # Saturday 2026-02-07 12:00 UTC + test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) + market, open_time = get_next_market_open( + enabled_markets=["KR"], now=test_time + ) + assert market.code == "KR" + # Monday 2026-02-09 09:00 KST + expected = datetime(2026, 2, 9, 9, 0, tzinfo=ZoneInfo("Asia/Seoul")) + assert open_time == expected.astimezone(ZoneInfo("UTC")) + + def test_get_next_market_open_after_close(self) -> None: + """Should find next day opening when called after market close.""" + # Monday 2026-02-02 16:00 KST (after close) + test_time = datetime(2026, 2, 2, 16, 0, tzinfo=ZoneInfo("Asia/Seoul")) + market, open_time = get_next_market_open( + enabled_markets=["KR"], now=test_time + ) + assert market.code == "KR" + # Tuesday 2026-02-03 09:00 KST + expected = datetime(2026, 2, 3, 9, 0, tzinfo=ZoneInfo("Asia/Seoul")) + assert open_time == expected.astimezone(ZoneInfo("UTC")) + + def test_get_next_market_open_multiple_markets(self) -> None: + """Should find earliest opening market among multiple.""" + # Saturday 2026-02-07 12:00 UTC + test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) + market, open_time = get_next_market_open( + enabled_markets=["KR", "US_NASDAQ"], now=test_time + ) + # Monday 2026-02-09: KR opens at 09:00 KST = 00:00 UTC + # Monday 2026-02-09: US opens at 09:30 EST = 14:30 UTC + # KR opens first + assert market.code == "KR" + + def test_get_next_market_open_no_markets(self) -> None: + """Should raise ValueError when no markets enabled.""" + test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) + with pytest.raises(ValueError, match="No enabled markets"): + get_next_market_open(enabled_markets=[], now=test_time) + + def test_get_next_market_open_invalid_market(self) -> None: + """Should skip invalid market codes.""" + test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) + market, _ = get_next_market_open( + enabled_markets=["INVALID", "KR"], now=test_time + ) + assert market.code == "KR" diff --git a/tests/test_risk.py b/tests/test_risk.py index bebdb5a..9a3e74b 100644 --- a/tests/test_risk.py +++ b/tests/test_risk.py @@ -10,7 +10,6 @@ from src.core.risk_manager import ( RiskManager, ) - # --------------------------------------------------------------------------- # Circuit Breaker Tests # ---------------------------------------------------------------------------