From 6f047a6daf79355b12b3247080a9c0e549505c74 Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:02:48 +0900 Subject: [PATCH 1/7] ci: add --ci mode for session handover gate in workflows (#353) --- .gitea/workflows/ci.yml | 2 +- .github/workflows/ci.yml | 2 +- scripts/session_handover_check.py | 18 ++++++++++++++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 9fa9522..9ee06db 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -25,7 +25,7 @@ jobs: run: pip install ".[dev]" - name: Session handover gate - run: python3 scripts/session_handover_check.py --strict + run: python3 scripts/session_handover_check.py --strict --ci - name: Validate governance assets env: diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index da84fc7..40f340d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,7 +22,7 @@ jobs: run: pip install ".[dev]" - name: Session handover gate - run: python3 scripts/session_handover_check.py --strict + run: python3 scripts/session_handover_check.py --strict --ci - name: Validate governance assets env: diff --git a/scripts/session_handover_check.py b/scripts/session_handover_check.py index b2ded16..7b354be 100755 --- a/scripts/session_handover_check.py +++ b/scripts/session_handover_check.py @@ -66,6 +66,7 @@ def _check_handover_entry( *, branch: str, strict: bool, + ci_mode: bool, errors: list[str], ) -> None: if not HANDOVER_LOG.exists(): @@ -87,7 +88,7 @@ def _check_handover_entry( if token not in latest: errors.append(f"latest handover entry missing token: {token}") - if strict: + if strict and not ci_mode: today_utc = datetime.now(UTC).date().isoformat() if today_utc not in latest: errors.append( @@ -117,6 +118,14 @@ def main() -> int: action="store_true", help="Enforce today-date and current-branch match on latest handover entry.", ) + parser.add_argument( + "--ci", + action="store_true", + help=( + "CI mode: keep structural/token checks but skip strict " + "today-date/current-branch matching." + ), + ) args = parser.parse_args() errors: list[str] = [] @@ -128,7 +137,12 @@ def main() -> int: elif branch in {"main", "master"}: errors.append(f"working branch must not be {branch}") - _check_handover_entry(branch=branch, strict=args.strict, errors=errors) + _check_handover_entry( + branch=branch, + strict=args.strict, + ci_mode=args.ci, + errors=errors, + ) if errors: print("[FAIL] session handover check failed") -- 2.49.1 From 5730f0db2acc37cffbe08ef7fd85a40230e97476 Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:17:13 +0900 Subject: [PATCH 2/7] ci: fix lint baseline and stabilize failing main tests --- src/analysis/backtest_cost_guard.py | 2 +- src/analysis/backtest_execution_model.py | 7 +- src/analysis/backtest_pipeline.py | 3 +- src/analysis/scanner.py | 16 +- src/analysis/smart_scanner.py | 25 +- src/analysis/triple_barrier.py | 9 +- src/analysis/volatility.py | 26 +- src/backup/__init__.py | 4 +- src/backup/cloud_storage.py | 4 +- src/backup/exporter.py | 19 +- src/backup/health_monitor.py | 10 +- src/backup/scheduler.py | 12 +- src/brain/cache.py | 10 +- src/brain/context_selector.py | 8 +- src/brain/gemini_client.py | 54 +-- src/brain/prompt_optimizer.py | 3 +- src/broker/kis_api.py | 86 ++-- src/broker/overseas.py | 90 ++-- src/config.py | 20 +- src/context/aggregator.py | 12 +- src/context/layer.py | 4 +- src/context/summarizer.py | 2 +- src/core/kill_switch.py | 3 +- src/core/order_policy.py | 14 +- src/core/priority_queue.py | 4 +- src/core/risk_manager.py | 10 +- src/dashboard/app.py | 9 +- src/data/economic_calendar.py | 1 - src/db.py | 11 +- src/evolution/ab_test.py | 7 +- src/evolution/daily_review.py | 4 +- src/evolution/optimizer.py | 52 ++- src/evolution/performance_tracker.py | 12 +- src/logging/decision_logger.py | 4 +- src/main.py | 332 +++++++-------- src/markets/schedule.py | 8 +- src/notifications/telegram_client.py | 62 +-- src/strategy/models.py | 8 +- src/strategy/playbook_store.py | 7 +- src/strategy/position_state_machine.py | 11 +- src/strategy/pre_market_planner.py | 39 +- src/strategy/scenario_engine.py | 46 +- tests/test_backup.py | 108 ++--- tests/test_brain.py | 9 +- tests/test_broker.py | 89 ++-- tests/test_context.py | 96 ++--- tests/test_daily_review.py | 33 +- tests/test_dashboard.py | 1 + tests/test_data_integration.py | 14 +- tests/test_db.py | 9 +- tests/test_decision_logger.py | 5 +- tests/test_evolution.py | 34 +- tests/test_logging_config.py | 4 +- tests/test_main.py | 509 +++++++++++++---------- tests/test_market_schedule.py | 16 +- tests/test_overseas_broker.py | 127 ++---- tests/test_pre_market_planner.py | 35 +- tests/test_scenario_engine.py | 86 ++-- tests/test_smart_scanner.py | 7 +- tests/test_strategy_models.py | 1 - tests/test_telegram.py | 126 +++--- tests/test_telegram_commands.py | 16 +- tests/test_validate_governance_assets.py | 4 +- tests/test_volatility.py | 22 +- 64 files changed, 1041 insertions(+), 1380 deletions(-) diff --git a/src/analysis/backtest_cost_guard.py b/src/analysis/backtest_cost_guard.py index 8f2cf98..97e1cd3 100644 --- a/src/analysis/backtest_cost_guard.py +++ b/src/analysis/backtest_cost_guard.py @@ -2,8 +2,8 @@ from __future__ import annotations -from dataclasses import dataclass import math +from dataclasses import dataclass @dataclass(frozen=True) diff --git a/src/analysis/backtest_execution_model.py b/src/analysis/backtest_execution_model.py index 24798dc..704b804 100644 --- a/src/analysis/backtest_execution_model.py +++ b/src/analysis/backtest_execution_model.py @@ -2,12 +2,11 @@ from __future__ import annotations -from dataclasses import dataclass import math +from dataclasses import dataclass from random import Random from typing import Literal - OrderSide = Literal["BUY", "SELL"] @@ -77,7 +76,9 @@ class BacktestExecutionModel: reason="execution_failure", ) - slip_mult = 1.0 + (slippage_bps / 10000.0 if request.side == "BUY" else -slippage_bps / 10000.0) + slip_mult = 1.0 + ( + slippage_bps / 10000.0 if request.side == "BUY" else -slippage_bps / 10000.0 + ) exec_price = request.reference_price * slip_mult if self._rng.random() < partial_rate: diff --git a/src/analysis/backtest_pipeline.py b/src/analysis/backtest_pipeline.py index ba49289..985e0e0 100644 --- a/src/analysis/backtest_pipeline.py +++ b/src/analysis/backtest_pipeline.py @@ -10,8 +10,7 @@ from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime from statistics import mean -from typing import Literal -from typing import cast +from typing import Literal, cast from src.analysis.backtest_cost_guard import BacktestCostModel, validate_backtest_cost_model from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier diff --git a/src/analysis/scanner.py b/src/analysis/scanner.py index 50d34ba..8b9d379 100644 --- a/src/analysis/scanner.py +++ b/src/analysis/scanner.py @@ -104,6 +104,7 @@ class MarketScanner: # Store in L7 real-time layer from datetime import UTC, datetime + timeframe = datetime.now(UTC).isoformat() self.context_store.set_context( ContextLayer.L7_REALTIME, @@ -158,12 +159,8 @@ class MarketScanner: top_movers = valid_metrics[: self.top_n] # Detect breakouts and breakdowns - breakouts = [ - m.stock_code for m in valid_metrics if self.analyzer.is_breakout(m) - ] - breakdowns = [ - m.stock_code for m in valid_metrics if self.analyzer.is_breakdown(m) - ] + breakouts = [m.stock_code for m in valid_metrics if self.analyzer.is_breakout(m)] + breakdowns = [m.stock_code for m in valid_metrics if self.analyzer.is_breakdown(m)] logger.info( "%s scan complete: %d scanned, top momentum=%.1f, %d breakouts, %d breakdowns", @@ -228,10 +225,9 @@ class MarketScanner: # If we removed too many, backfill from current watchlist if len(updated) < len(current_watchlist): - backfill = [ - code for code in current_watchlist - if code not in updated - ][: len(current_watchlist) - len(updated)] + backfill = [code for code in current_watchlist if code not in updated][ + : len(current_watchlist) - len(updated) + ] updated.extend(backfill) logger.info( diff --git a/src/analysis/smart_scanner.py b/src/analysis/smart_scanner.py index 7717166..63d3fe1 100644 --- a/src/analysis/smart_scanner.py +++ b/src/analysis/smart_scanner.py @@ -158,7 +158,12 @@ class SmartVolatilityScanner: price = latest_close latest_high = _safe_float(latest.get("high")) latest_low = _safe_float(latest.get("low")) - if latest_close > 0 and latest_high > 0 and latest_low > 0 and latest_high >= latest_low: + if ( + latest_close > 0 + and latest_high > 0 + and latest_low > 0 + and latest_high >= latest_low + ): intraday_range_pct = (latest_high - latest_low) / latest_close * 100.0 if volume <= 0: volume = _safe_float(latest.get("volume")) @@ -234,9 +239,7 @@ class SmartVolatilityScanner: limit=50, ) except Exception as exc: - logger.warning( - "Overseas fluctuation ranking failed for %s: %s", market.code, exc - ) + logger.warning("Overseas fluctuation ranking failed for %s: %s", market.code, exc) fluct_rows = [] if not fluct_rows: @@ -250,9 +253,7 @@ class SmartVolatilityScanner: limit=50, ) except Exception as exc: - logger.warning( - "Overseas volume ranking failed for %s: %s", market.code, exc - ) + logger.warning("Overseas volume ranking failed for %s: %s", market.code, exc) volume_rows = [] for idx, row in enumerate(volume_rows): @@ -433,16 +434,10 @@ def _extract_intraday_range_pct(row: dict[str, Any], price: float) -> float: if price <= 0: return 0.0 high = _safe_float( - row.get("high") - or row.get("ovrs_hgpr") - or row.get("stck_hgpr") - or row.get("day_hgpr") + row.get("high") or row.get("ovrs_hgpr") or row.get("stck_hgpr") or row.get("day_hgpr") ) low = _safe_float( - row.get("low") - or row.get("ovrs_lwpr") - or row.get("stck_lwpr") - or row.get("day_lwpr") + row.get("low") or row.get("ovrs_lwpr") or row.get("stck_lwpr") or row.get("day_lwpr") ) if high <= 0 or low <= 0 or high < low: return 0.0 diff --git a/src/analysis/triple_barrier.py b/src/analysis/triple_barrier.py index 793250d..11c7018 100644 --- a/src/analysis/triple_barrier.py +++ b/src/analysis/triple_barrier.py @@ -6,10 +6,10 @@ Implements first-touch labeling with upper/lower/time barriers. from __future__ import annotations import warnings +from collections.abc import Sequence from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Literal, Sequence - +from typing import Literal TieBreakMode = Literal["stop_first", "take_first"] @@ -92,7 +92,10 @@ def label_with_triple_barrier( else: assert spec.max_holding_bars is not None warnings.warn( - "TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.", + ( + "TripleBarrierSpec.max_holding_bars is deprecated; " + "use max_holding_minutes with timestamps instead." + ), DeprecationWarning, stacklevel=2, ) diff --git a/src/analysis/volatility.py b/src/analysis/volatility.py index 0794220..a974e0d 100644 --- a/src/analysis/volatility.py +++ b/src/analysis/volatility.py @@ -92,9 +92,7 @@ class VolatilityAnalyzer: recent_tr = true_ranges[-period:] return sum(recent_tr) / len(recent_tr) - def calculate_price_change( - self, current_price: float, past_price: float - ) -> float: + def calculate_price_change(self, current_price: float, past_price: float) -> float: """Calculate price change percentage. Args: @@ -108,9 +106,7 @@ class VolatilityAnalyzer: return 0.0 return ((current_price - past_price) / past_price) * 100 - def calculate_volume_surge( - self, current_volume: float, avg_volume: float - ) -> float: + def calculate_volume_surge(self, current_volume: float, avg_volume: float) -> float: """Calculate volume surge ratio. Args: @@ -240,11 +236,7 @@ class VolatilityAnalyzer: Momentum score (0-100) """ # Weight recent changes more heavily - weighted_change = ( - price_change_1m * 0.4 + - price_change_5m * 0.3 + - price_change_15m * 0.2 - ) + weighted_change = price_change_1m * 0.4 + price_change_5m * 0.3 + price_change_15m * 0.2 # Volume contribution (normalized to 0-10 scale) volume_contribution = min(10.0, (volume_surge - 1.0) * 5.0) @@ -301,17 +293,11 @@ class VolatilityAnalyzer: if len(close_prices) > 0: if len(close_prices) >= 1: - price_change_1m = self.calculate_price_change( - current_price, close_prices[-1] - ) + price_change_1m = self.calculate_price_change(current_price, close_prices[-1]) if len(close_prices) >= 5: - price_change_5m = self.calculate_price_change( - current_price, close_prices[-5] - ) + price_change_5m = self.calculate_price_change(current_price, close_prices[-5]) if len(close_prices) >= 15: - price_change_15m = self.calculate_price_change( - current_price, close_prices[-15] - ) + price_change_15m = self.calculate_price_change(current_price, close_prices[-15]) # Calculate volume surge avg_volume = sum(volumes) / len(volumes) if volumes else current_volume diff --git a/src/backup/__init__.py b/src/backup/__init__.py index a58e700..069fdd6 100644 --- a/src/backup/__init__.py +++ b/src/backup/__init__.py @@ -7,9 +7,9 @@ This module provides: - Health monitoring and alerts """ -from src.backup.exporter import BackupExporter, ExportFormat -from src.backup.scheduler import BackupScheduler, BackupPolicy from src.backup.cloud_storage import CloudStorage, S3Config +from src.backup.exporter import BackupExporter, ExportFormat +from src.backup.scheduler import BackupPolicy, BackupScheduler __all__ = [ "BackupExporter", diff --git a/src/backup/cloud_storage.py b/src/backup/cloud_storage.py index 4850e8d..ba62f4c 100644 --- a/src/backup/cloud_storage.py +++ b/src/backup/cloud_storage.py @@ -94,7 +94,9 @@ class CloudStorage: if metadata: extra_args["Metadata"] = metadata - logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key) + logger.info( + "Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key + ) try: self.client.upload_file( diff --git a/src/backup/exporter.py b/src/backup/exporter.py index f5b3cd6..979982d 100644 --- a/src/backup/exporter.py +++ b/src/backup/exporter.py @@ -14,14 +14,14 @@ import json import logging import sqlite3 from datetime import UTC, datetime -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import Any logger = logging.getLogger(__name__) -class ExportFormat(str, Enum): +class ExportFormat(StrEnum): """Supported export formats.""" JSON = "json" @@ -103,15 +103,11 @@ class BackupExporter: elif fmt == ExportFormat.CSV: return self._export_csv(output_dir, timestamp, compress, incremental_since) elif fmt == ExportFormat.PARQUET: - return self._export_parquet( - output_dir, timestamp, compress, incremental_since - ) + return self._export_parquet(output_dir, timestamp, compress, incremental_since) else: raise ValueError(f"Unsupported format: {fmt}") - def _get_trades( - self, incremental_since: datetime | None = None - ) -> list[dict[str, Any]]: + def _get_trades(self, incremental_since: datetime | None = None) -> list[dict[str, Any]]: """Fetch trades from database. Args: @@ -164,9 +160,7 @@ class BackupExporter: data = { "export_timestamp": datetime.now(UTC).isoformat(), - "incremental_since": ( - incremental_since.isoformat() if incremental_since else None - ), + "incremental_since": (incremental_since.isoformat() if incremental_since else None), "record_count": len(trades), "trades": trades, } @@ -284,8 +278,7 @@ class BackupExporter: import pyarrow.parquet as pq except ImportError: raise ImportError( - "pyarrow is required for Parquet export. " - "Install with: pip install pyarrow" + "pyarrow is required for Parquet export. Install with: pip install pyarrow" ) # Convert to pyarrow table diff --git a/src/backup/health_monitor.py b/src/backup/health_monitor.py index 4ec8406..a2c6fc9 100644 --- a/src/backup/health_monitor.py +++ b/src/backup/health_monitor.py @@ -14,14 +14,14 @@ import shutil import sqlite3 from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import Any logger = logging.getLogger(__name__) -class HealthStatus(str, Enum): +class HealthStatus(StrEnum): """Health check status.""" HEALTHY = "healthy" @@ -137,9 +137,13 @@ class HealthMonitor: used_percent = (stat.used / stat.total) * 100 if stat.free < self.min_disk_space_bytes: + min_disk_gb = self.min_disk_space_bytes / 1024 / 1024 / 1024 return HealthCheckResult( status=HealthStatus.UNHEALTHY, - message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)", + message=( + f"Low disk space: {free_gb:.2f} GB free " + f"(minimum: {min_disk_gb:.2f} GB)" + ), details={ "free_gb": free_gb, "total_gb": total_gb, diff --git a/src/backup/scheduler.py b/src/backup/scheduler.py index c9f16d6..3b9f633 100644 --- a/src/backup/scheduler.py +++ b/src/backup/scheduler.py @@ -12,14 +12,14 @@ import logging import shutil from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from enum import Enum +from enum import StrEnum from pathlib import Path from typing import Any logger = logging.getLogger(__name__) -class BackupPolicy(str, Enum): +class BackupPolicy(StrEnum): """Backup retention policies.""" DAILY = "daily" @@ -69,9 +69,7 @@ class BackupScheduler: for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]: d.mkdir(parents=True, exist_ok=True) - def create_backup( - self, policy: BackupPolicy, verify: bool = True - ) -> BackupMetadata: + def create_backup(self, policy: BackupPolicy, verify: bool = True) -> BackupMetadata: """Create a database backup. Args: @@ -229,9 +227,7 @@ class BackupScheduler: return removed - def list_backups( - self, policy: BackupPolicy | None = None - ) -> list[BackupMetadata]: + def list_backups(self, policy: BackupPolicy | None = None) -> list[BackupMetadata]: """List available backups. Args: diff --git a/src/brain/cache.py b/src/brain/cache.py index cf9190b..cf5f540 100644 --- a/src/brain/cache.py +++ b/src/brain/cache.py @@ -13,8 +13,8 @@ import hashlib import json import logging import time -from dataclasses import dataclass, field -from typing import Any, TYPE_CHECKING +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from src.brain.gemini_client import TradeDecision @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) class CacheEntry: """Cached decision with metadata.""" - decision: "TradeDecision" + decision: TradeDecision cached_at: float # Unix timestamp hit_count: int = 0 market_data_hash: str = "" @@ -239,9 +239,7 @@ class DecisionCache: """ current_time = time.time() expired_keys = [ - k - for k, v in self._cache.items() - if current_time - v.cached_at > self.ttl_seconds + k for k, v in self._cache.items() if current_time - v.cached_at > self.ttl_seconds ] count = len(expired_keys) diff --git a/src/brain/context_selector.py b/src/brain/context_selector.py index 47620e4..119eb78 100644 --- a/src/brain/context_selector.py +++ b/src/brain/context_selector.py @@ -11,14 +11,14 @@ from __future__ import annotations from dataclasses import dataclass from datetime import UTC, datetime -from enum import Enum +from enum import StrEnum from typing import Any from src.context.layer import ContextLayer from src.context.store import ContextStore -class DecisionType(str, Enum): +class DecisionType(StrEnum): """Type of trading decision being made.""" NORMAL = "normal" # Regular trade decision @@ -183,9 +183,7 @@ class ContextSelector: ContextLayer.L1_LEGACY, ] - scores = { - layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers - } + scores = {layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers} # Filter by minimum score selected_layers = [layer for layer, score in scores.items() if score >= min_score] diff --git a/src/brain/gemini_client.py b/src/brain/gemini_client.py index c664eb2..6e61c40 100644 --- a/src/brain/gemini_client.py +++ b/src/brain/gemini_client.py @@ -25,12 +25,12 @@ from typing import Any from google import genai -from src.config import Settings -from src.data.news_api import NewsAPI, NewsSentiment -from src.data.economic_calendar import EconomicCalendar -from src.data.market_data import MarketData from src.brain.cache import DecisionCache from src.brain.prompt_optimizer import PromptOptimizer +from src.config import Settings +from src.data.economic_calendar import EconomicCalendar +from src.data.market_data import MarketData +from src.data.news_api import NewsAPI, NewsSentiment logger = logging.getLogger(__name__) @@ -159,16 +159,12 @@ class GeminiClient: return "" # Check for upcoming high-impact events - upcoming = self._economic_calendar.get_upcoming_events( - days_ahead=7, min_impact="HIGH" - ) + upcoming = self._economic_calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH") if upcoming.high_impact_count == 0: return "" - lines = [ - f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days" - ] + lines = [f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"] if upcoming.next_major_event is not None: event = upcoming.next_major_event @@ -180,9 +176,7 @@ class GeminiClient: # Check for earnings earnings_date = self._economic_calendar.get_earnings_date(stock_code) if earnings_date is not None: - lines.append( - f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}" - ) + lines.append(f" Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}") return "\n".join(lines) @@ -235,9 +229,7 @@ class GeminiClient: # Add foreigner net if non-zero if market_data.get("foreigner_net", 0) != 0: - market_info_lines.append( - f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}" - ) + market_info_lines.append(f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}") market_info = "\n".join(market_info_lines) @@ -249,8 +241,7 @@ class GeminiClient: market_info += f"\n\n{external_context}" json_format = ( - '{"action": "BUY"|"SELL"|"HOLD", ' - '"confidence": , "rationale": ""}' + '{"action": "BUY"|"SELL"|"HOLD", "confidence": , "rationale": ""}' ) return ( f"You are a professional {market_name} trading analyst.\n" @@ -289,15 +280,12 @@ class GeminiClient: # Add foreigner net if non-zero if market_data.get("foreigner_net", 0) != 0: - market_info_lines.append( - f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}" - ) + market_info_lines.append(f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}") market_info = "\n".join(market_info_lines) json_format = ( - '{"action": "BUY"|"SELL"|"HOLD", ' - '"confidence": , "rationale": ""}' + '{"action": "BUY"|"SELL"|"HOLD", "confidence": , "rationale": ""}' ) return ( f"You are a professional {market_name} trading analyst.\n" @@ -339,25 +327,19 @@ class GeminiClient: data = json.loads(cleaned) except json.JSONDecodeError: logger.warning("Malformed JSON from Gemini — defaulting to HOLD") - return TradeDecision( - action="HOLD", confidence=0, rationale="Malformed JSON response" - ) + return TradeDecision(action="HOLD", confidence=0, rationale="Malformed JSON response") # Validate required fields if not all(k in data for k in ("action", "confidence", "rationale")): logger.warning("Missing fields in Gemini response — defaulting to HOLD") # Preserve raw text in rationale so prompt_override callers (e.g. pre_market_planner) # can extract their own JSON format from decision.rationale (#245) - return TradeDecision( - action="HOLD", confidence=0, rationale=raw - ) + return TradeDecision(action="HOLD", confidence=0, rationale=raw) action = str(data["action"]).upper() if action not in VALID_ACTIONS: logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action) - return TradeDecision( - action="HOLD", confidence=0, rationale=f"Invalid action: {action}" - ) + return TradeDecision(action="HOLD", confidence=0, rationale=f"Invalid action: {action}") confidence = int(data["confidence"]) rationale = str(data["rationale"]) @@ -445,9 +427,7 @@ class GeminiClient: # not a parsed TradeDecision. Skip parse_response to avoid spurious # "Missing fields" warnings and return the raw response directly. (#247) if "prompt_override" in market_data: - logger.info( - "Gemini raw response received (prompt_override, tokens=%d)", token_count - ) + logger.info("Gemini raw response received (prompt_override, tokens=%d)", token_count) # Not a trade decision — don't inflate _total_decisions metrics return TradeDecision( action="HOLD", confidence=0, rationale=raw, token_count=token_count @@ -546,9 +526,7 @@ class GeminiClient: # Batch Decision Making (for daily trading mode) # ------------------------------------------------------------------ - async def decide_batch( - self, stocks_data: list[dict[str, Any]] - ) -> dict[str, TradeDecision]: + async def decide_batch(self, stocks_data: list[dict[str, Any]]) -> dict[str, TradeDecision]: """Make decisions for multiple stocks in a single API call. This is designed for daily trading mode to minimize API usage diff --git a/src/brain/prompt_optimizer.py b/src/brain/prompt_optimizer.py index fdc0d99..c85edc8 100644 --- a/src/brain/prompt_optimizer.py +++ b/src/brain/prompt_optimizer.py @@ -179,7 +179,8 @@ class PromptOptimizer: # Minimal instructions prompt = ( f"{market_name} trader. Analyze:\n{data_str}\n\n" - 'Return JSON: {"action":"BUY"|"SELL"|"HOLD","confidence":<0-100>,"rationale":""}\n' + "Return JSON: " + '{"action":"BUY"|"SELL"|"HOLD","confidence":<0-100>,"rationale":""}\n' "Rules: action=BUY/SELL/HOLD, confidence=0-100, rationale=concise. No markdown." ) else: diff --git a/src/broker/kis_api.py b/src/broker/kis_api.py index 953a604..269463b 100644 --- a/src/broker/kis_api.py +++ b/src/broker/kis_api.py @@ -58,7 +58,7 @@ class LeakyBucket: def __init__(self, rate: float) -> None: """Args: - rate: Maximum requests per second. + rate: Maximum requests per second. """ self._rate = rate self._interval = 1.0 / rate @@ -103,7 +103,8 @@ class KISBroker: ssl_ctx.verify_mode = ssl.CERT_NONE connector = aiohttp.TCPConnector(ssl=ssl_ctx) self._session = aiohttp.ClientSession( - timeout=timeout, connector=connector, + timeout=timeout, + connector=connector, ) return self._session @@ -224,16 +225,12 @@ class KISBroker: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"get_orderbook failed ({resp.status}): {text}" - ) + raise ConnectionError(f"get_orderbook failed ({resp.status}): {text}") return await resp.json() except (TimeoutError, aiohttp.ClientError) as exc: raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc - async def get_current_price( - self, stock_code: str - ) -> tuple[float, float, float]: + async def get_current_price(self, stock_code: str) -> tuple[float, float, float]: """Fetch current price data for a domestic stock. Uses the ``inquire-price`` API (FHKST01010100), which works in both @@ -265,9 +262,7 @@ class KISBroker: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"get_current_price failed ({resp.status}): {text}" - ) + raise ConnectionError(f"get_current_price failed ({resp.status}): {text}") data = await resp.json() out = data.get("output", {}) return ( @@ -276,9 +271,7 @@ class KISBroker: _f(out.get("frgn_ntby_qty")), ) except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching current price: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching current price: {exc}") from exc async def get_balance(self) -> dict[str, Any]: """Fetch current account balance and holdings.""" @@ -308,9 +301,7 @@ class KISBroker: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"get_balance failed ({resp.status}): {text}" - ) + raise ConnectionError(f"get_balance failed ({resp.status}): {text}") return await resp.json() except (TimeoutError, aiohttp.ClientError) as exc: raise ConnectionError(f"Network error fetching balance: {exc}") from exc @@ -369,9 +360,7 @@ class KISBroker: async with session.post(url, headers=headers, json=body) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"send_order failed ({resp.status}): {text}" - ) + raise ConnectionError(f"send_order failed ({resp.status}): {text}") data = await resp.json() logger.info( "Order submitted", @@ -449,9 +438,7 @@ class KISBroker: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"fetch_market_rankings failed ({resp.status}): {text}" - ) + raise ConnectionError(f"fetch_market_rankings failed ({resp.status}): {text}") data = await resp.json() # Parse response - output is a list of ranked stocks @@ -465,14 +452,16 @@ class KISBroker: rankings = [] for item in data.get("output", [])[:limit]: - rankings.append({ - "stock_code": item.get("stck_shrn_iscd") or item.get("mksc_shrn_iscd", ""), - "name": item.get("hts_kor_isnm", ""), - "price": _safe_float(item.get("stck_prpr", "0")), - "volume": _safe_float(item.get("acml_vol", "0")), - "change_rate": _safe_float(item.get("prdy_ctrt", "0")), - "volume_increase_rate": _safe_float(item.get("vol_inrt", "0")), - }) + rankings.append( + { + "stock_code": item.get("stck_shrn_iscd") or item.get("mksc_shrn_iscd", ""), + "name": item.get("hts_kor_isnm", ""), + "price": _safe_float(item.get("stck_prpr", "0")), + "volume": _safe_float(item.get("acml_vol", "0")), + "change_rate": _safe_float(item.get("prdy_ctrt", "0")), + "volume_increase_rate": _safe_float(item.get("vol_inrt", "0")), + } + ) return rankings except (TimeoutError, aiohttp.ClientError) as exc: @@ -522,9 +511,7 @@ class KISBroker: data = await resp.json() return data.get("output", []) or [] except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching domestic pending orders: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching domestic pending orders: {exc}") from exc async def cancel_domestic_order( self, @@ -575,14 +562,10 @@ class KISBroker: async with session.post(url, headers=headers, json=body) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"cancel_domestic_order failed ({resp.status}): {text}" - ) + raise ConnectionError(f"cancel_domestic_order failed ({resp.status}): {text}") return cast(dict[str, Any], await resp.json()) except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error cancelling domestic order: {exc}" - ) from exc + raise ConnectionError(f"Network error cancelling domestic order: {exc}") from exc async def get_daily_prices( self, @@ -609,6 +592,7 @@ class KISBroker: # Calculate date range (today and N days ago) from datetime import datetime, timedelta + end_date = datetime.now().strftime("%Y%m%d") start_date = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d") @@ -627,9 +611,7 @@ class KISBroker: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"get_daily_prices failed ({resp.status}): {text}" - ) + raise ConnectionError(f"get_daily_prices failed ({resp.status}): {text}") data = await resp.json() # Parse response @@ -643,14 +625,16 @@ class KISBroker: prices = [] for item in data.get("output2", []): - prices.append({ - "date": item.get("stck_bsop_date", ""), - "open": _safe_float(item.get("stck_oprc", "0")), - "high": _safe_float(item.get("stck_hgpr", "0")), - "low": _safe_float(item.get("stck_lwpr", "0")), - "close": _safe_float(item.get("stck_clpr", "0")), - "volume": _safe_float(item.get("acml_vol", "0")), - }) + prices.append( + { + "date": item.get("stck_bsop_date", ""), + "open": _safe_float(item.get("stck_oprc", "0")), + "high": _safe_float(item.get("stck_hgpr", "0")), + "low": _safe_float(item.get("stck_lwpr", "0")), + "close": _safe_float(item.get("stck_clpr", "0")), + "volume": _safe_float(item.get("acml_vol", "0")), + } + ) # Sort oldest to newest (KIS returns newest first) prices.reverse() diff --git a/src/broker/overseas.py b/src/broker/overseas.py index d98ea67..5120ed6 100644 --- a/src/broker/overseas.py +++ b/src/broker/overseas.py @@ -36,11 +36,11 @@ _CANCEL_TR_ID_MAP: dict[str, tuple[str, str]] = { "NYSE": ("TTTT1004U", "VTTT1004U"), "AMEX": ("TTTT1004U", "VTTT1004U"), "SEHK": ("TTTS1003U", "VTTS1003U"), - "TSE": ("TTTS0309U", "VTTS0309U"), + "TSE": ("TTTS0309U", "VTTS0309U"), "SHAA": ("TTTS0302U", "VTTS0302U"), "SZAA": ("TTTS0306U", "VTTS0306U"), - "HNX": ("TTTS0312U", "VTTS0312U"), - "HSX": ("TTTS0312U", "VTTS0312U"), + "HNX": ("TTTS0312U", "VTTS0312U"), + "HSX": ("TTTS0312U", "VTTS0312U"), } @@ -56,9 +56,7 @@ class OverseasBroker: """ self._broker = kis_broker - async def get_overseas_price( - self, exchange_code: str, stock_code: str - ) -> dict[str, Any]: + async def get_overseas_price(self, exchange_code: str, stock_code: str) -> dict[str, Any]: """ Fetch overseas stock price. @@ -89,14 +87,10 @@ class OverseasBroker: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"get_overseas_price failed ({resp.status}): {text}" - ) + raise ConnectionError(f"get_overseas_price failed ({resp.status}): {text}") return await resp.json() except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching overseas price: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching overseas price: {exc}") from exc async def fetch_overseas_rankings( self, @@ -154,9 +148,7 @@ class OverseasBroker: ranking_type, ) return [] - raise ConnectionError( - f"fetch_overseas_rankings failed ({resp.status}): {text}" - ) + raise ConnectionError(f"fetch_overseas_rankings failed ({resp.status}): {text}") data = await resp.json() rows = self._extract_ranking_rows(data) @@ -171,9 +163,7 @@ class OverseasBroker: ) return [] except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching overseas rankings: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching overseas rankings: {exc}") from exc async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]: """ @@ -193,9 +183,7 @@ class OverseasBroker: # TR_ID: 실전 TTTS3012R, 모의 VTTS3012R # Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 잔고조회' 시트 - balance_tr_id = ( - "TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R" - ) + balance_tr_id = "TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R" headers = await self._broker._auth_headers(balance_tr_id) params = { "CANO": self._broker._account_no, @@ -205,22 +193,16 @@ class OverseasBroker: "CTX_AREA_FK200": "", "CTX_AREA_NK200": "", } - url = ( - f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-balance" - ) + url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-balance" try: async with session.get(url, headers=headers, params=params) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"get_overseas_balance failed ({resp.status}): {text}" - ) + raise ConnectionError(f"get_overseas_balance failed ({resp.status}): {text}") return await resp.json() except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching overseas balance: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching overseas balance: {exc}") from exc async def get_overseas_buying_power( self, @@ -247,9 +229,7 @@ class OverseasBroker: # TR_ID: 실전 TTTS3007R, 모의 VTTS3007R # Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 매수가능금액조회' 시트 - ps_tr_id = ( - "TTTS3007R" if self._broker._settings.MODE == "live" else "VTTS3007R" - ) + ps_tr_id = "TTTS3007R" if self._broker._settings.MODE == "live" else "VTTS3007R" headers = await self._broker._auth_headers(ps_tr_id) params = { "CANO": self._broker._account_no, @@ -258,9 +238,7 @@ class OverseasBroker: "OVRS_ORD_UNPR": f"{price:.2f}", "ITEM_CD": stock_code, } - url = ( - f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-psamount" - ) + url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-psamount" try: async with session.get(url, headers=headers, params=params) as resp: @@ -271,9 +249,7 @@ class OverseasBroker: ) return await resp.json() except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching overseas buying power: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching overseas buying power: {exc}") from exc async def send_overseas_order( self, @@ -330,9 +306,7 @@ class OverseasBroker: async with session.post(url, headers=headers, json=body) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"send_overseas_order failed ({resp.status}): {text}" - ) + raise ConnectionError(f"send_overseas_order failed ({resp.status}): {text}") data = await resp.json() rt_cd = data.get("rt_cd", "") msg1 = data.get("msg1", "") @@ -357,13 +331,9 @@ class OverseasBroker: ) return data except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error sending overseas order: {exc}" - ) from exc + raise ConnectionError(f"Network error sending overseas order: {exc}") from exc - async def get_overseas_pending_orders( - self, exchange_code: str - ) -> list[dict[str, Any]]: + async def get_overseas_pending_orders(self, exchange_code: str) -> list[dict[str, Any]]: """Fetch unfilled (pending) overseas orders for a given exchange. Args: @@ -379,9 +349,7 @@ class OverseasBroker: ConnectionError: On network or API errors (live mode only). """ if self._broker._settings.MODE != "live": - logger.debug( - "Pending orders API (TTTS3018R) not supported in paper mode; returning []" - ) + logger.debug("Pending orders API (TTTS3018R) not supported in paper mode; returning []") return [] await self._broker._rate_limiter.acquire() @@ -398,9 +366,7 @@ class OverseasBroker: "CTX_AREA_FK200": "", "CTX_AREA_NK200": "", } - url = ( - f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs" - ) + url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs" try: async with session.get(url, headers=headers, params=params) as resp: @@ -415,9 +381,7 @@ class OverseasBroker: return output return [] except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error fetching pending orders: {exc}" - ) from exc + raise ConnectionError(f"Network error fetching pending orders: {exc}") from exc async def cancel_overseas_order( self, @@ -469,22 +433,16 @@ class OverseasBroker: headers = await self._broker._auth_headers(tr_id) headers["hashkey"] = hash_key - url = ( - f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl" - ) + url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl" try: async with session.post(url, headers=headers, json=body) as resp: if resp.status != 200: text = await resp.text() - raise ConnectionError( - f"cancel_overseas_order failed ({resp.status}): {text}" - ) + raise ConnectionError(f"cancel_overseas_order failed ({resp.status}): {text}") return await resp.json() except (TimeoutError, aiohttp.ClientError) as exc: - raise ConnectionError( - f"Network error cancelling overseas order: {exc}" - ) from exc + raise ConnectionError(f"Network error cancelling overseas order: {exc}") from exc def _get_currency_code(self, exchange_code: str) -> str: """ diff --git a/src/config.py b/src/config.py index 671b95b..81290e3 100644 --- a/src/config.py +++ b/src/config.py @@ -111,25 +111,21 @@ class Settings(BaseSettings): # Telegram notification type filters (granular control) # circuit_breaker is always sent regardless — safety-critical - TELEGRAM_NOTIFY_TRADES: bool = True # BUY/SELL execution alerts + TELEGRAM_NOTIFY_TRADES: bool = True # BUY/SELL execution alerts TELEGRAM_NOTIFY_MARKET_OPEN_CLOSE: bool = True # Market open/close alerts - TELEGRAM_NOTIFY_FAT_FINGER: bool = True # Fat-finger rejection alerts - TELEGRAM_NOTIFY_SYSTEM_EVENTS: bool = True # System start/shutdown alerts - TELEGRAM_NOTIFY_PLAYBOOK: bool = True # Playbook generated/failed alerts - TELEGRAM_NOTIFY_SCENARIO_MATCH: bool = True # Scenario matched alerts (most frequent) - TELEGRAM_NOTIFY_ERRORS: bool = True # Error alerts + TELEGRAM_NOTIFY_FAT_FINGER: bool = True # Fat-finger rejection alerts + TELEGRAM_NOTIFY_SYSTEM_EVENTS: bool = True # System start/shutdown alerts + TELEGRAM_NOTIFY_PLAYBOOK: bool = True # Playbook generated/failed alerts + TELEGRAM_NOTIFY_SCENARIO_MATCH: bool = True # Scenario matched alerts (most frequent) + TELEGRAM_NOTIFY_ERRORS: bool = True # Error alerts # Overseas ranking API (KIS endpoint/TR_ID may vary by account/product) # Override these from .env if your account uses different specs. OVERSEAS_RANKING_ENABLED: bool = True OVERSEAS_RANKING_FLUCT_TR_ID: str = "HHDFS76290000" OVERSEAS_RANKING_VOLUME_TR_ID: str = "HHDFS76270000" - OVERSEAS_RANKING_FLUCT_PATH: str = ( - "/uapi/overseas-stock/v1/ranking/updown-rate" - ) - OVERSEAS_RANKING_VOLUME_PATH: str = ( - "/uapi/overseas-stock/v1/ranking/volume-surge" - ) + OVERSEAS_RANKING_FLUCT_PATH: str = "/uapi/overseas-stock/v1/ranking/updown-rate" + OVERSEAS_RANKING_VOLUME_PATH: str = "/uapi/overseas-stock/v1/ranking/volume-surge" # Dashboard (optional) DASHBOARD_ENABLED: bool = False diff --git a/src/context/aggregator.py b/src/context/aggregator.py index 8eaecab..36e3982 100644 --- a/src/context/aggregator.py +++ b/src/context/aggregator.py @@ -222,9 +222,7 @@ class ContextAggregator: total_pnl = 0.0 for month in months: - monthly_pnl = self.store.get_context( - ContextLayer.L4_MONTHLY, month, "monthly_pnl" - ) + monthly_pnl = self.store.get_context(ContextLayer.L4_MONTHLY, month, "monthly_pnl") if monthly_pnl is not None: total_pnl += monthly_pnl @@ -251,9 +249,7 @@ class ContextAggregator: if quarterly_pnl is not None: total_pnl += quarterly_pnl - self.store.set_context( - ContextLayer.L2_ANNUAL, year, "annual_pnl", round(total_pnl, 2) - ) + self.store.set_context(ContextLayer.L2_ANNUAL, year, "annual_pnl", round(total_pnl, 2)) def aggregate_legacy_from_annual(self) -> None: """Aggregate L1 (legacy) context from all L2 (annual) data.""" @@ -280,9 +276,7 @@ class ContextAggregator: self.store.set_context( ContextLayer.L1_LEGACY, "LEGACY", "total_pnl", round(total_pnl, 2) ) - self.store.set_context( - ContextLayer.L1_LEGACY, "LEGACY", "years_traded", years_traded - ) + self.store.set_context(ContextLayer.L1_LEGACY, "LEGACY", "years_traded", years_traded) self.store.set_context( ContextLayer.L1_LEGACY, "LEGACY", diff --git a/src/context/layer.py b/src/context/layer.py index fdad474..7c40d34 100644 --- a/src/context/layer.py +++ b/src/context/layer.py @@ -3,10 +3,10 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Enum +from enum import StrEnum -class ContextLayer(str, Enum): +class ContextLayer(StrEnum): """7-tier context hierarchy from real-time to generational.""" L1_LEGACY = "L1_LEGACY" # Cumulative/generational wisdom diff --git a/src/context/summarizer.py b/src/context/summarizer.py index c154ff7..8bc024d 100644 --- a/src/context/summarizer.py +++ b/src/context/summarizer.py @@ -9,7 +9,7 @@ This module summarizes old context data instead of including raw details: from __future__ import annotations from dataclasses import dataclass -from datetime import UTC, datetime, timedelta +from datetime import UTC, datetime from typing import Any from src.context.layer import ContextLayer diff --git a/src/core/kill_switch.py b/src/core/kill_switch.py index 9f2231b..71a3cdf 100644 --- a/src/core/kill_switch.py +++ b/src/core/kill_switch.py @@ -11,8 +11,9 @@ Order is fixed: from __future__ import annotations import inspect +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Awaitable, Callable +from typing import Any StepCallable = Callable[[], Any | Awaitable[Any]] diff --git a/src/core/order_policy.py b/src/core/order_policy.py index 5fbb43a..a347996 100644 --- a/src/core/order_policy.py +++ b/src/core/order_policy.py @@ -15,7 +15,7 @@ from src.markets.schedule import MarketInfo _LOW_LIQUIDITY_SESSIONS = {"NXT_AFTER", "US_PRE", "US_DAY", "US_AFTER"} -class OrderPolicyRejected(Exception): +class OrderPolicyRejectedError(Exception): """Raised when an order violates session policy.""" def __init__(self, message: str, *, session_id: str, market_code: str) -> None: @@ -61,7 +61,9 @@ def classify_session_id(market: MarketInfo, now: datetime | None = None) -> str: def get_session_info(market: MarketInfo, now: datetime | None = None) -> SessionInfo: session_id = classify_session_id(market, now) - return SessionInfo(session_id=session_id, is_low_liquidity=session_id in _LOW_LIQUIDITY_SESSIONS) + return SessionInfo( + session_id=session_id, is_low_liquidity=session_id in _LOW_LIQUIDITY_SESSIONS + ) def validate_order_policy( @@ -76,7 +78,7 @@ def validate_order_policy( is_market_order = price <= 0 if info.is_low_liquidity and is_market_order: - raise OrderPolicyRejected( + raise OrderPolicyRejectedError( f"Market order is forbidden in low-liquidity session ({info.session_id})", session_id=info.session_id, market_code=market.code, @@ -84,10 +86,14 @@ def validate_order_policy( # Guard against accidental unsupported actions. if order_type not in {"BUY", "SELL"}: - raise OrderPolicyRejected( + raise OrderPolicyRejectedError( f"Unsupported order_type={order_type}", session_id=info.session_id, market_code=market.code, ) return info + + +# Backward compatibility alias +OrderPolicyRejected = OrderPolicyRejectedError diff --git a/src/core/priority_queue.py b/src/core/priority_queue.py index 92f9ace..1010491 100644 --- a/src/core/priority_queue.py +++ b/src/core/priority_queue.py @@ -28,9 +28,7 @@ class PriorityTask: # Task data not used in comparison task_id: str = field(compare=False) task_data: dict[str, Any] = field(compare=False, default_factory=dict) - callback: Callable[[], Coroutine[Any, Any, Any]] | None = field( - compare=False, default=None - ) + callback: Callable[[], Coroutine[Any, Any, Any]] | None = field(compare=False, default=None) @dataclass diff --git a/src/core/risk_manager.py b/src/core/risk_manager.py index 7fd559b..8ce405b 100644 --- a/src/core/risk_manager.py +++ b/src/core/risk_manager.py @@ -25,7 +25,7 @@ class CircuitBreakerTripped(SystemExit): ) -class FatFingerRejected(Exception): +class FatFingerRejectedError(Exception): """Raised when an order exceeds the maximum allowed proportion of cash.""" def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None: @@ -61,7 +61,7 @@ class RiskManager: def check_fat_finger(self, order_amount: float, total_cash: float) -> None: """Reject orders that exceed the maximum proportion of available cash.""" if total_cash <= 0: - raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct) + raise FatFingerRejectedError(order_amount, total_cash, self._ff_max_pct) ratio_pct = (order_amount / total_cash) * 100 if ratio_pct > self._ff_max_pct: @@ -69,7 +69,7 @@ class RiskManager: "Fat finger check failed", extra={"order_amount": order_amount}, ) - raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct) + raise FatFingerRejectedError(order_amount, total_cash, self._ff_max_pct) def validate_order( self, @@ -81,3 +81,7 @@ class RiskManager: self.check_circuit_breaker(current_pnl_pct) self.check_fat_finger(order_amount, total_cash) logger.info("Order passed risk validation") + + +# Backward compatibility alias +FatFingerRejected = FatFingerRejectedError diff --git a/src/dashboard/app.py b/src/dashboard/app.py index e9d8e26..2c42676 100644 --- a/src/dashboard/app.py +++ b/src/dashboard/app.py @@ -5,7 +5,7 @@ from __future__ import annotations import json import os import sqlite3 -from datetime import UTC, datetime, timezone +from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -188,10 +188,7 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI: return { "market": "all", "combined": combined, - "by_market": [ - _row_to_performance(row) - for row in by_market_rows - ], + "by_market": [_row_to_performance(row) for row in by_market_rows], } row = conn.execute( @@ -401,7 +398,7 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI: """ ).fetchall() - now = datetime.now(timezone.utc) + now = datetime.now(UTC) positions = [] for row in rows: entry_time_str = row["entry_time"] diff --git a/src/data/economic_calendar.py b/src/data/economic_calendar.py index 9f662b6..3057ebe 100644 --- a/src/data/economic_calendar.py +++ b/src/data/economic_calendar.py @@ -9,7 +9,6 @@ from __future__ import annotations import logging from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any logger = logging.getLogger(__name__) diff --git a/src/db.py b/src/db.py index e161de3..d7cbeb5 100644 --- a/src/db.py +++ b/src/db.py @@ -123,8 +123,7 @@ def init_db(db_path: str) -> sqlite3.Connection: """ ) decision_columns = { - row[1] - for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall() + row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall() } if "session_id" not in decision_columns: conn.execute("ALTER TABLE decision_logs ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'") @@ -185,9 +184,7 @@ def init_db(db_path: str) -> sqlite3.Connection: conn.execute( "CREATE INDEX IF NOT EXISTS idx_decision_logs_timestamp ON decision_logs(timestamp)" ) - conn.execute( - "CREATE INDEX IF NOT EXISTS idx_decision_logs_reviewed ON decision_logs(reviewed)" - ) + conn.execute("CREATE INDEX IF NOT EXISTS idx_decision_logs_reviewed ON decision_logs(reviewed)") conn.execute( "CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)" ) @@ -381,9 +378,7 @@ def get_open_position( return {"decision_id": row[1], "price": row[2], "quantity": row[3], "timestamp": row[4]} -def get_recent_symbols( - conn: sqlite3.Connection, market: str, limit: int = 30 -) -> list[str]: +def get_recent_symbols(conn: sqlite3.Connection, market: str, limit: int = 30) -> list[str]: """Return recent unique symbols for a market, newest first.""" cursor = conn.execute( """ diff --git a/src/evolution/ab_test.py b/src/evolution/ab_test.py index e9ed3df..daf8854 100644 --- a/src/evolution/ab_test.py +++ b/src/evolution/ab_test.py @@ -90,9 +90,7 @@ class ABTester: sharpe_ratio = None if len(pnls) > 1: mean_return = avg_pnl - std_return = ( - sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1) - ) ** 0.5 + std_return = (sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)) ** 0.5 if std_return > 0: sharpe_ratio = mean_return / std_return @@ -198,8 +196,7 @@ class ABTester: if meets_criteria: logger.info( - "Strategy '%s' meets deployment criteria: " - "win_rate=%.2f%%, trades=%d, avg_pnl=%.2f", + "Strategy '%s' meets deployment criteria: win_rate=%.2f%%, trades=%d, avg_pnl=%.2f", result.winner, winning_perf.win_rate, winning_perf.total_trades, diff --git a/src/evolution/daily_review.py b/src/evolution/daily_review.py index fd4eb0c..eb37100 100644 --- a/src/evolution/daily_review.py +++ b/src/evolution/daily_review.py @@ -60,9 +60,7 @@ class DailyReviewer: if isinstance(scenario_match, dict) and scenario_match: matched += 1 scenario_match_rate = ( - round((matched / total_decisions) * 100, 2) - if total_decisions - else 0.0 + round((matched / total_decisions) * 100, 2) if total_decisions else 0.0 ) trade_stats = self._conn.execute( diff --git a/src/evolution/optimizer.py b/src/evolution/optimizer.py index c9ef719..4369c54 100644 --- a/src/evolution/optimizer.py +++ b/src/evolution/optimizer.py @@ -80,26 +80,26 @@ class EvolutionOptimizer: # Convert to dict format for analysis failures = [] for decision in losing_decisions: - failures.append({ - "decision_id": decision.decision_id, - "timestamp": decision.timestamp, - "stock_code": decision.stock_code, - "market": decision.market, - "exchange_code": decision.exchange_code, - "action": decision.action, - "confidence": decision.confidence, - "rationale": decision.rationale, - "outcome_pnl": decision.outcome_pnl, - "outcome_accuracy": decision.outcome_accuracy, - "context_snapshot": decision.context_snapshot, - "input_data": decision.input_data, - }) + failures.append( + { + "decision_id": decision.decision_id, + "timestamp": decision.timestamp, + "stock_code": decision.stock_code, + "market": decision.market, + "exchange_code": decision.exchange_code, + "action": decision.action, + "confidence": decision.confidence, + "rationale": decision.rationale, + "outcome_pnl": decision.outcome_pnl, + "outcome_accuracy": decision.outcome_accuracy, + "context_snapshot": decision.context_snapshot, + "input_data": decision.input_data, + } + ) return failures - def identify_failure_patterns( - self, failures: list[dict[str, Any]] - ) -> dict[str, Any]: + def identify_failure_patterns(self, failures: list[dict[str, Any]]) -> dict[str, Any]: """Identify patterns in losing decisions. Analyzes: @@ -143,12 +143,8 @@ class EvolutionOptimizer: total_confidence += failure.get("confidence", 0) total_loss += failure.get("outcome_pnl", 0.0) - patterns["avg_confidence"] = ( - round(total_confidence / len(failures), 2) if failures else 0.0 - ) - patterns["avg_loss"] = ( - round(total_loss / len(failures), 2) if failures else 0.0 - ) + patterns["avg_confidence"] = round(total_confidence / len(failures), 2) if failures else 0.0 + patterns["avg_loss"] = round(total_loss / len(failures), 2) if failures else 0.0 # Convert Counters to regular dicts for JSON serialization patterns["markets"] = dict(patterns["markets"]) @@ -197,7 +193,8 @@ class EvolutionOptimizer: prompt = ( "You are a quantitative trading strategy developer.\n" - "Analyze these failed trades and their patterns, then generate an improved strategy.\n\n" + "Analyze these failed trades and their patterns, " + "then generate an improved strategy.\n\n" f"Failure Patterns:\n{json.dumps(patterns, indent=2)}\n\n" f"Sample Failed Trades (first 5):\n" f"{json.dumps(failures[:5], indent=2, default=str)}\n\n" @@ -214,7 +211,8 @@ class EvolutionOptimizer: try: response = await self._client.aio.models.generate_content( - model=self._model_name, contents=prompt, + model=self._model_name, + contents=prompt, ) body = response.text.strip() except Exception as exc: @@ -280,9 +278,7 @@ class EvolutionOptimizer: logger.info("Strategy validation PASSED") return True else: - logger.warning( - "Strategy validation FAILED:\n%s", result.stdout + result.stderr - ) + logger.warning("Strategy validation FAILED:\n%s", result.stdout + result.stderr) # Clean up failing strategy strategy_path.unlink(missing_ok=True) return False diff --git a/src/evolution/performance_tracker.py b/src/evolution/performance_tracker.py index fd3476c..c7bc7e1 100644 --- a/src/evolution/performance_tracker.py +++ b/src/evolution/performance_tracker.py @@ -187,9 +187,7 @@ class PerformanceTracker: return metrics - def calculate_improvement_trend( - self, metrics_history: list[StrategyMetrics] - ) -> dict[str, Any]: + def calculate_improvement_trend(self, metrics_history: list[StrategyMetrics]) -> dict[str, Any]: """Calculate improvement trend from historical metrics. Args: @@ -229,9 +227,7 @@ class PerformanceTracker: "period_count": len(metrics_history), } - def generate_dashboard( - self, strategy_name: str | None = None - ) -> PerformanceDashboard: + def generate_dashboard(self, strategy_name: str | None = None) -> PerformanceDashboard: """Generate a comprehensive performance dashboard. Args: @@ -260,9 +256,7 @@ class PerformanceTracker: improvement_trend=improvement_trend, ) - def export_dashboard_json( - self, dashboard: PerformanceDashboard - ) -> str: + def export_dashboard_json(self, dashboard: PerformanceDashboard) -> str: """Export dashboard as JSON string. Args: diff --git a/src/logging/decision_logger.py b/src/logging/decision_logger.py index cd19b28..5a05d84 100644 --- a/src/logging/decision_logger.py +++ b/src/logging/decision_logger.py @@ -140,9 +140,7 @@ class DecisionLogger: ) self.conn.commit() - def update_outcome( - self, decision_id: str, pnl: float, accuracy: int - ) -> None: + def update_outcome(self, decision_id: str, pnl: float, accuracy: int) -> None: """Update the outcome of a decision after trade execution. Args: diff --git a/src/main.py b/src/main.py index da6f3e9..512f4f2 100644 --- a/src/main.py +++ b/src/main.py @@ -26,12 +26,12 @@ from src.context.aggregator import ContextAggregator from src.context.layer import ContextLayer from src.context.scheduler import ContextScheduler from src.context.store import ContextStore -from src.core.criticality import CriticalityAssessor from src.core.blackout_manager import ( BlackoutOrderManager, QueuedOrderIntent, parse_blackout_windows_kst, ) +from src.core.criticality import CriticalityAssessor from src.core.kill_switch import KillSwitchOrchestrator from src.core.order_policy import ( OrderPolicyRejected, @@ -52,12 +52,16 @@ from src.evolution.optimizer import EvolutionOptimizer from src.logging.decision_logger import DecisionLogger from src.logging_config import setup_logging from src.markets.schedule import MARKETS, MarketInfo, get_next_market_open, get_open_markets -from src.notifications.telegram_client import NotificationFilter, TelegramClient, TelegramCommandHandler -from src.strategy.models import DayPlaybook, MarketOutlook +from src.notifications.telegram_client import ( + NotificationFilter, + TelegramClient, + TelegramCommandHandler, +) from src.strategy.exit_rules import ExitRuleConfig, ExitRuleInput, evaluate_exit +from src.strategy.models import DayPlaybook, MarketOutlook from src.strategy.playbook_store import PlaybookStore -from src.strategy.pre_market_planner import PreMarketPlanner from src.strategy.position_state_machine import PositionState +from src.strategy.pre_market_planner import PreMarketPlanner from src.strategy.scenario_engine import ScenarioEngine logger = logging.getLogger(__name__) @@ -350,9 +354,7 @@ async def _inject_staged_exit_features( return if "pred_down_prob" not in market_data: - market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi( - market_data.get("rsi") - ) + market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi(market_data.get("rsi")) existing_atr = safe_float(market_data.get("atr_value"), 0.0) if existing_atr > 0: @@ -389,7 +391,7 @@ async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kw return await coro_factory(*args, **kwargs) except ConnectionError as exc: if attempt < MAX_CONNECTION_RETRIES: - wait_secs = 2 ** attempt + wait_secs = 2**attempt logger.warning( "Connection error %s (attempt %d/%d), retrying in %ds: %s", label, @@ -413,7 +415,7 @@ async def sync_positions_from_broker( broker: Any, overseas_broker: Any, db_conn: Any, - settings: "Settings", + settings: Settings, ) -> int: """Sync open positions from the live broker into the local DB at startup. @@ -441,9 +443,7 @@ async def sync_positions_from_broker( if market.exchange_code in seen_exchange_codes: continue seen_exchange_codes.add(market.exchange_code) - balance_data = await overseas_broker.get_overseas_balance( - market.exchange_code - ) + balance_data = await overseas_broker.get_overseas_balance(market.exchange_code) log_market = market_code # e.g. "US_NASDAQ" except ConnectionError as exc: logger.warning( @@ -453,9 +453,7 @@ async def sync_positions_from_broker( ) continue - held_codes = _extract_held_codes_from_balance( - balance_data, is_domestic=market.is_domestic - ) + held_codes = _extract_held_codes_from_balance(balance_data, is_domestic=market.is_domestic) for stock_code in held_codes: if get_open_position(db_conn, stock_code, log_market): continue # already tracked @@ -487,9 +485,7 @@ async def sync_positions_from_broker( synced += 1 if synced: - logger.info( - "Startup sync complete: %d position(s) synced from broker", synced - ) + logger.info("Startup sync complete: %d position(s) synced from broker", synced) else: logger.info("Startup sync: no new positions to sync from broker") return synced @@ -859,15 +855,9 @@ def _apply_staged_exit_override_for_hold( pnl_pct = (current_price - entry_price) / entry_price * 100.0 if exit_eval.reason == "hard_stop": - rationale = ( - f"Stop-loss triggered ({pnl_pct:.2f}% <= " - f"{stop_loss_threshold:.2f}%)" - ) + rationale = f"Stop-loss triggered ({pnl_pct:.2f}% <= {stop_loss_threshold:.2f}%)" elif exit_eval.reason == "arm_take_profit": - rationale = ( - f"Take-profit triggered ({pnl_pct:.2f}% >= " - f"{arm_pct:.2f}%)" - ) + rationale = f"Take-profit triggered ({pnl_pct:.2f}% >= {arm_pct:.2f}%)" elif exit_eval.reason == "atr_trailing_stop": rationale = "ATR trailing-stop triggered" elif exit_eval.reason == "be_lock_threat": @@ -978,7 +968,10 @@ def _maybe_queue_order_intent( ) if queued: logger.warning( - "Blackout active: queued order intent %s %s (%s) qty=%d price=%.4f source=%s pending=%d", + ( + "Blackout active: queued order intent %s %s (%s) " + "qty=%d price=%.4f source=%s pending=%d" + ), order_type, stock_code, market.code, @@ -1071,7 +1064,10 @@ async def process_blackout_recovery_orders( ) if queued_price <= 0 or current_price <= 0: logger.info( - "Drop queued intent by price revalidation (invalid price): %s %s (%s) queued=%.4f current=%.4f", + ( + "Drop queued intent by price revalidation (invalid price): " + "%s %s (%s) queued=%.4f current=%.4f" + ), intent.order_type, intent.stock_code, market.code, @@ -1082,7 +1078,10 @@ async def process_blackout_recovery_orders( drift_pct = abs(current_price - queued_price) / queued_price * 100.0 if drift_pct > max_drift_pct: logger.info( - "Drop queued intent by price revalidation: %s %s (%s) queued=%.4f current=%.4f drift=%.2f%% max=%.2f%%", + ( + "Drop queued intent by price revalidation: %s %s (%s) " + "queued=%.4f current=%.4f drift=%.2f%% max=%.2f%%" + ), intent.order_type, intent.stock_code, market.code, @@ -1375,24 +1374,18 @@ async def trading_cycle( # 1. Fetch market data price_output: dict[str, Any] = {} # Populated for overseas markets; used for fallback metrics if market.is_domestic: - current_price, price_change_pct, foreigner_net = await broker.get_current_price( - stock_code - ) + current_price, price_change_pct, foreigner_net = await broker.get_current_price(stock_code) balance_data = await broker.get_balance() output2 = balance_data.get("output2", [{}]) total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0 total_cash = safe_float( - balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") - if output2 - else "0" + balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") if output2 else "0" ) purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0 else: # Overseas market - price_data = await overseas_broker.get_overseas_price( - market.exchange_code, stock_code - ) + price_data = await overseas_broker.get_overseas_price(market.exchange_code, stock_code) balance_data = await overseas_broker.get_overseas_balance(market.exchange_code) output2 = balance_data.get("output2", [{}]) @@ -1459,11 +1452,7 @@ async def trading_cycle( total_cash = settings.PAPER_OVERSEAS_CASH # Calculate daily P&L % - pnl_pct = ( - ((total_eval - purchase_total) / purchase_total * 100) - if purchase_total > 0 - else 0.0 - ) + pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0 market_data: dict[str, Any] = { "stock_code": stock_code, @@ -1491,11 +1480,13 @@ async def trading_cycle( market_data["rsi"] = max(0.0, min(100.0, 50.0 + price_change_pct * 2.0)) if price_output and current_price > 0: pr_high = safe_float( - price_output.get("high") or price_output.get("ovrs_hgpr") + price_output.get("high") + or price_output.get("ovrs_hgpr") or price_output.get("stck_hgpr") ) pr_low = safe_float( - price_output.get("low") or price_output.get("ovrs_lwpr") + price_output.get("low") + or price_output.get("ovrs_lwpr") or price_output.get("stck_lwpr") ) if pr_high > 0 and pr_low > 0 and pr_high >= pr_low: @@ -1512,9 +1503,7 @@ async def trading_cycle( if open_pos and current_price > 0: entry_price = safe_float(open_pos.get("price"), 0.0) if entry_price > 0: - market_data["unrealized_pnl_pct"] = ( - (current_price - entry_price) / entry_price * 100 - ) + market_data["unrealized_pnl_pct"] = (current_price - entry_price) / entry_price * 100 entry_ts = open_pos.get("timestamp") if entry_ts: try: @@ -1745,16 +1734,19 @@ async def trading_cycle( stock_playbook=stock_playbook, settings=settings, ) - if open_position and decision.action == "HOLD" and _should_force_exit_for_overnight( + if ( + open_position + and decision.action == "HOLD" + and _should_force_exit_for_overnight( market=market, settings=settings, + ) ): decision = TradeDecision( action="SELL", confidence=max(decision.confidence, 85), rationale=( - "Forced exit by overnight policy" - " (session close window / kill switch priority)" + "Forced exit by overnight policy (session close window / kill switch priority)" ), ) logger.info( @@ -1834,9 +1826,7 @@ async def trading_cycle( return broker_held_qty = ( - _extract_held_qty_from_balance( - balance_data, stock_code, is_domestic=market.is_domestic - ) + _extract_held_qty_from_balance(balance_data, stock_code, is_domestic=market.is_domestic) if decision.action == "SELL" else 0 ) @@ -1871,7 +1861,10 @@ async def trading_cycle( ) if fx_blocked: logger.warning( - "Skip BUY %s (%s): FX buffer guard (remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)", + ( + "Skip BUY %s (%s): FX buffer guard " + "(remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)" + ), stock_code, market.name, remaining_cash, @@ -2068,8 +2061,7 @@ async def trading_cycle( action="SELL", confidence=0, rationale=( - "[ghost-close] Broker reported no balance;" - " position closed without fill" + "[ghost-close] Broker reported no balance; position closed without fill" ), quantity=0, price=0.0, @@ -2275,17 +2267,13 @@ async def handle_domestic_pending_orders( outcome="cancelled", ) except Exception as notify_exc: - logger.warning( - "notify_unfilled_order failed: %s", notify_exc - ) + logger.warning("notify_unfilled_order failed: %s", notify_exc) else: # First unfilled SELL → resubmit at last * 0.996 (-0.4%). try: last_price, _, _ = await broker.get_current_price(stock_code) if last_price <= 0: - raise ValueError( - f"Invalid price ({last_price}) for {stock_code}" - ) + raise ValueError(f"Invalid price ({last_price}) for {stock_code}") new_price = kr_round_down(last_price * 0.996) validate_order_policy( market=MARKETS["KR"], @@ -2298,9 +2286,7 @@ async def handle_domestic_pending_orders( quantity=psbl_qty, price=new_price, ) - sell_resubmit_counts[key] = ( - sell_resubmit_counts.get(key, 0) + 1 - ) + sell_resubmit_counts[key] = sell_resubmit_counts.get(key, 0) + 1 try: await telegram.notify_unfilled_order( stock_code=stock_code, @@ -2311,9 +2297,7 @@ async def handle_domestic_pending_orders( new_price=float(new_price), ) except Exception as notify_exc: - logger.warning( - "notify_unfilled_order failed: %s", notify_exc - ) + logger.warning("notify_unfilled_order failed: %s", notify_exc) except Exception as exc: logger.error( "SELL resubmit failed for KR %s: %s", @@ -2381,9 +2365,7 @@ async def handle_overseas_pending_orders( try: orders = await overseas_broker.get_overseas_pending_orders(exchange_code) except Exception as exc: - logger.warning( - "Failed to fetch pending orders for %s: %s", exchange_code, exc - ) + logger.warning("Failed to fetch pending orders for %s: %s", exchange_code, exc) continue for order in orders: @@ -2448,26 +2430,21 @@ async def handle_overseas_pending_orders( outcome="cancelled", ) except Exception as notify_exc: - logger.warning( - "notify_unfilled_order failed: %s", notify_exc - ) + logger.warning("notify_unfilled_order failed: %s", notify_exc) else: # First unfilled SELL → resubmit at last * 0.996 (-0.4%). try: price_data = await overseas_broker.get_overseas_price( order_exchange, stock_code ) - last_price = float( - price_data.get("output", {}).get("last", "0") or "0" - ) + last_price = float(price_data.get("output", {}).get("last", "0") or "0") if last_price <= 0: - raise ValueError( - f"Invalid price ({last_price}) for {stock_code}" - ) + raise ValueError(f"Invalid price ({last_price}) for {stock_code}") new_price = round(last_price * 0.996, 4) market_info = next( ( - m for m in MARKETS.values() + m + for m in MARKETS.values() if m.exchange_code == order_exchange and not m.is_domestic ), None, @@ -2485,9 +2462,7 @@ async def handle_overseas_pending_orders( quantity=nccs_qty, price=new_price, ) - sell_resubmit_counts[key] = ( - sell_resubmit_counts.get(key, 0) + 1 - ) + sell_resubmit_counts[key] = sell_resubmit_counts.get(key, 0) + 1 try: await telegram.notify_unfilled_order( stock_code=stock_code, @@ -2498,9 +2473,7 @@ async def handle_overseas_pending_orders( new_price=new_price, ) except Exception as notify_exc: - logger.warning( - "notify_unfilled_order failed: %s", notify_exc - ) + logger.warning("notify_unfilled_order failed: %s", notify_exc) except Exception as exc: logger.error( "SELL resubmit failed for %s %s: %s", @@ -2659,13 +2632,16 @@ async def run_daily_session( logger.warning("Playbook notification failed: %s", exc) logger.info( "Generated playbook for %s: %d stocks, %d scenarios", - market.code, playbook.stock_count, playbook.scenario_count, + market.code, + playbook.stock_count, + playbook.scenario_count, ) except Exception as exc: logger.error("Playbook generation failed for %s: %s", market.code, exc) try: await telegram.notify_playbook_failed( - market=market.code, reason=str(exc)[:200], + market=market.code, + reason=str(exc)[:200], ) except Exception as notify_exc: logger.warning("Playbook failed notification error: %s", notify_exc) @@ -2676,12 +2652,10 @@ async def run_daily_session( for stock_code in watchlist: try: if market.is_domestic: - current_price, price_change_pct, foreigner_net = ( - await _retry_connection( - broker.get_current_price, - stock_code, - label=stock_code, - ) + current_price, price_change_pct, foreigner_net = await _retry_connection( + broker.get_current_price, + stock_code, + label=stock_code, ) else: price_data = await _retry_connection( @@ -2690,9 +2664,7 @@ async def run_daily_session( stock_code, label=f"{stock_code}@{market.exchange_code}", ) - current_price = safe_float( - price_data.get("output", {}).get("last", "0") - ) + current_price = safe_float(price_data.get("output", {}).get("last", "0")) # Fallback: if price API returns 0, use scanner candidate price if current_price <= 0: cand_lookup = candidate_map.get(stock_code) @@ -2704,9 +2676,7 @@ async def run_daily_session( ) current_price = cand_lookup.price foreigner_net = 0.0 - price_change_pct = safe_float( - price_data.get("output", {}).get("rate", "0") - ) + price_change_pct = safe_float(price_data.get("output", {}).get("rate", "0")) # Fall back to scanner candidate price if API returns 0. if current_price <= 0: cand_lookup = candidate_map.get(stock_code) @@ -2769,15 +2739,9 @@ async def run_daily_session( if market.is_domestic: output2 = balance_data.get("output2", [{}]) - total_eval = safe_float( - output2[0].get("tot_evlu_amt", "0") - ) if output2 else 0 - total_cash = safe_float( - output2[0].get("dnca_tot_amt", "0") - ) if output2 else 0 - purchase_total = safe_float( - output2[0].get("pchs_amt_smtl_amt", "0") - ) if output2 else 0 + total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0 + total_cash = safe_float(output2[0].get("dnca_tot_amt", "0")) if output2 else 0 + purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0 else: output2 = balance_data.get("output2", [{}]) if isinstance(output2, list) and output2: @@ -2788,18 +2752,15 @@ async def run_daily_session( balance_info = {} total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0") - purchase_total = safe_float( - balance_info.get("frcr_buy_amt_smtl", "0") or "0" - ) + purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0") # Fetch available foreign currency cash via inquire-psamount (TTTS3007R/VTTS3007R). - # TTTS3012R output2 does not include a cash/deposit field — frcr_dncl_amt_2 does not exist. + # TTTS3012R output2 does not include a cash/deposit field. + # frcr_dncl_amt_2 does not exist. # Use the first stock with a valid price as the reference for the buying power query. # Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 매수가능금액조회' 시트 total_cash = 0.0 - ref_stock = next( - (s for s in stocks_data if s.get("current_price", 0) > 0), None - ) + ref_stock = next((s for s in stocks_data if s.get("current_price", 0) > 0), None) if ref_stock: try: ps_data = await overseas_broker.get_overseas_buying_power( @@ -2819,11 +2780,7 @@ async def run_daily_session( # Paper mode fallback: VTS overseas balance API often fails for many accounts. # Only activate in paper mode — live mode must use real balance from KIS. - if ( - total_cash <= 0 - and settings.MODE == "paper" - and settings.PAPER_OVERSEAS_CASH > 0 - ): + if total_cash <= 0 and settings.MODE == "paper" and settings.PAPER_OVERSEAS_CASH > 0: total_cash = settings.PAPER_OVERSEAS_CASH # Capture the day's opening portfolio value on the first market processed @@ -2856,13 +2813,17 @@ async def run_daily_session( # Evaluate scenarios for each stock (local, no API calls) logger.info( "Evaluating %d stocks against playbook for %s", - len(stocks_data), market.name, + len(stocks_data), + market.name, ) for stock_data in stocks_data: stock_code = stock_data["stock_code"] stock_playbook = playbook.get_stock_playbook(stock_code) match = scenario_engine.evaluate( - playbook, stock_code, stock_data, portfolio_data, + playbook, + stock_code, + stock_data, + portfolio_data, ) decision = TradeDecision( action=match.action.value, @@ -2969,9 +2930,13 @@ async def run_daily_session( stock_playbook=stock_playbook, settings=settings, ) - if daily_open and decision.action == "HOLD" and _should_force_exit_for_overnight( - market=market, - settings=settings, + if ( + daily_open + and decision.action == "HOLD" + and _should_force_exit_for_overnight( + market=market, + settings=settings, + ) ): decision = TradeDecision( action="SELL", @@ -3063,16 +3028,21 @@ async def run_daily_session( ) continue order_amount = stock_data["current_price"] * quantity - fx_blocked, remaining_cash, required_buffer = _should_block_overseas_buy_for_fx_buffer( - market=market, - action=decision.action, - total_cash=total_cash, - order_amount=order_amount, - settings=settings, + fx_blocked, remaining_cash, required_buffer = ( + _should_block_overseas_buy_for_fx_buffer( + market=market, + action=decision.action, + total_cash=total_cash, + order_amount=order_amount, + settings=settings, + ) ) if fx_blocked: logger.warning( - "Skip BUY %s (%s): FX buffer guard (remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)", + ( + "Skip BUY %s (%s): FX buffer guard " + "(remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)" + ), stock_code, market.name, remaining_cash, @@ -3090,7 +3060,10 @@ async def run_daily_session( if now < daily_cooldown_until: remaining = int(daily_cooldown_until - now) logger.info( - "Skip BUY %s (%s): insufficient-balance cooldown active (%ds remaining)", + ( + "Skip BUY %s (%s): insufficient-balance cooldown active " + "(%ds remaining)" + ), stock_code, market.name, remaining, @@ -3149,13 +3122,9 @@ async def run_daily_session( # Use limit orders (지정가) for domestic stocks. # KRX tick rounding applied via kr_round_down. if decision.action == "BUY": - order_price = kr_round_down( - stock_data["current_price"] * 1.002 - ) + order_price = kr_round_down(stock_data["current_price"] * 1.002) else: - order_price = kr_round_down( - stock_data["current_price"] * 0.998 - ) + order_price = kr_round_down(stock_data["current_price"] * 0.998) try: validate_order_policy( market=market, @@ -3260,9 +3229,7 @@ async def run_daily_session( except Exception as exc: logger.warning("Telegram notification failed: %s", exc) except Exception as exc: - logger.error( - "Order execution failed for %s: %s", stock_code, exc - ) + logger.error("Order execution failed for %s: %s", stock_code, exc) continue if decision.action == "SELL" and order_succeeded: @@ -3286,7 +3253,9 @@ async def run_daily_session( accuracy=1 if trade_pnl > 0 else 0, ) if trade_pnl < 0: - cooldown_key = _stoploss_cooldown_key(market=market, stock_code=stock_code) + cooldown_key = _stoploss_cooldown_key( + market=market, stock_code=stock_code + ) cooldown_minutes = _stoploss_cooldown_minutes( settings, market=market, @@ -3369,7 +3338,8 @@ async def _handle_market_close( def _run_context_scheduler( - scheduler: ContextScheduler, now: datetime | None = None, + scheduler: ContextScheduler, + now: datetime | None = None, ) -> None: """Run periodic context scheduler tasks and log when anything executes.""" result = scheduler.run_if_due(now=now) @@ -3438,6 +3408,7 @@ def _start_dashboard_server(settings: Settings) -> threading.Thread | None: # reported synchronously (avoids the misleading "started" → "failed" log pair). try: import uvicorn # noqa: F401 + from src.dashboard import create_dashboard_app # noqa: F401 except ImportError as exc: logger.warning("Dashboard server unavailable (missing dependency): %s", exc) @@ -3446,6 +3417,7 @@ def _start_dashboard_server(settings: Settings) -> threading.Thread | None: def _serve() -> None: try: import uvicorn + from src.dashboard import create_dashboard_app app = create_dashboard_app(settings.DB_PATH, mode=settings.MODE) @@ -3586,8 +3558,7 @@ async def run(settings: Settings) -> None: pause_trading.set() logger.info("Trading resumed via Telegram command") await telegram.send_message( - "▶️ Trading Resumed\n\n" - "Trading operations have been restarted." + "▶️ Trading Resumed\n\nTrading operations have been restarted." ) async def handle_status() -> None: @@ -3630,9 +3601,7 @@ async def run(settings: Settings) -> None: except Exception as exc: logger.error("Error in /status handler: %s", exc) - await telegram.send_message( - "⚠️ Error\n\nFailed to retrieve trading status." - ) + await telegram.send_message("⚠️ Error\n\nFailed to retrieve trading status.") async def handle_positions() -> None: """Handle /positions command - show account summary.""" @@ -3643,8 +3612,7 @@ async def run(settings: Settings) -> None: if not output2: await telegram.send_message( - "💼 Account Summary\n\n" - "No balance information available." + "💼 Account Summary\n\nNo balance information available." ) return @@ -3673,9 +3641,7 @@ async def run(settings: Settings) -> None: except Exception as exc: logger.error("Error in /positions handler: %s", exc) - await telegram.send_message( - "⚠️ Error\n\nFailed to retrieve positions." - ) + await telegram.send_message("⚠️ Error\n\nFailed to retrieve positions.") async def handle_report() -> None: """Handle /report command - show daily summary metrics.""" @@ -3719,9 +3685,7 @@ async def run(settings: Settings) -> None: ) except Exception as exc: logger.error("Error in /report handler: %s", exc) - await telegram.send_message( - "⚠️ Error\n\nFailed to generate daily report." - ) + await telegram.send_message("⚠️ Error\n\nFailed to generate daily report.") async def handle_scenarios() -> None: """Handle /scenarios command - show today's playbook scenarios.""" @@ -3770,9 +3734,7 @@ async def run(settings: Settings) -> None: await telegram.send_message("\n".join(lines).strip()) except Exception as exc: logger.error("Error in /scenarios handler: %s", exc) - await telegram.send_message( - "⚠️ Error\n\nFailed to retrieve scenarios." - ) + await telegram.send_message("⚠️ Error\n\nFailed to retrieve scenarios.") async def handle_review() -> None: """Handle /review command - show recent scorecards.""" @@ -3788,9 +3750,7 @@ async def run(settings: Settings) -> None: ).fetchall() if not rows: - await telegram.send_message( - "📝 Recent Reviews\n\nNo scorecards available." - ) + await telegram.send_message("📝 Recent Reviews\n\nNo scorecards available.") return lines = ["📝 Recent Reviews", ""] @@ -3808,9 +3768,7 @@ async def run(settings: Settings) -> None: await telegram.send_message("\n".join(lines)) except Exception as exc: logger.error("Error in /review handler: %s", exc) - await telegram.send_message( - "⚠️ Error\n\nFailed to retrieve reviews." - ) + await telegram.send_message("⚠️ Error\n\nFailed to retrieve reviews.") async def handle_notify(args: list[str]) -> None: """Handle /notify [key] [on|off] — query or change notification filters.""" @@ -3845,8 +3803,7 @@ async def run(settings: Settings) -> None: else: valid = ", ".join(list(status.keys()) + ["all"]) await telegram.send_message( - f"❌ 알 수 없는 키: {key}\n" - f"유효한 키: {valid}" + f"❌ 알 수 없는 키: {key}\n유효한 키: {valid}" ) return @@ -3858,30 +3815,22 @@ async def run(settings: Settings) -> None: value = toggle == "on" if telegram.set_notification(key, value): icon = "✅" if value else "❌" - label = f"전체 알림" if key == "all" else f"{key} 알림" + label = "전체 알림" if key == "all" else f"{key} 알림" state = "켜짐" if value else "꺼짐" await telegram.send_message(f"{icon} {label} → {state}") logger.info("Notification filter changed via Telegram: %s=%s", key, value) else: valid = ", ".join(list(telegram.filter_status().keys()) + ["all"]) - await telegram.send_message( - f"❌ 알 수 없는 키: {key}\n" - f"유효한 키: {valid}" - ) + await telegram.send_message(f"❌ 알 수 없는 키: {key}\n유효한 키: {valid}") async def handle_dashboard() -> None: """Handle /dashboard command - show dashboard URL if enabled.""" if not settings.DASHBOARD_ENABLED: - await telegram.send_message( - "🖥️ Dashboard\n\nDashboard is not enabled." - ) + await telegram.send_message("🖥️ Dashboard\n\nDashboard is not enabled.") return url = f"http://{settings.DASHBOARD_HOST}:{settings.DASHBOARD_PORT}" - await telegram.send_message( - "🖥️ Dashboard\n\n" - f"URL: {url}" - ) + await telegram.send_message(f"🖥️ Dashboard\n\nURL: {url}") command_handler.register_command("help", handle_help) command_handler.register_command("stop", handle_stop) @@ -4182,9 +4131,7 @@ async def run(settings: Settings) -> None: ) # Store candidates per market for selection context logging - scan_candidates[market.code] = { - c.stock_code: c for c in candidates - } + scan_candidates[market.code] = {c.stock_code: c for c in candidates} logger.info( "Smart Scanner: Found %d candidates for %s: %s", @@ -4194,9 +4141,7 @@ async def run(settings: Settings) -> None: ) # Get market-local date for playbook keying - market_today = datetime.now( - market.timezone - ).date() + market_today = datetime.now(market.timezone).date() # Load or generate playbook (1 Gemini call per market per day) if market.code not in playbooks: @@ -4234,7 +4179,8 @@ async def run(settings: Settings) -> None: except Exception as exc: logger.error( "Playbook generation failed for %s: %s", - market.code, exc, + market.code, + exc, ) try: await telegram.notify_playbook_failed( @@ -4279,7 +4225,8 @@ async def run(settings: Settings) -> None: except Exception as exc: logger.warning( "Failed to fetch holdings for %s: %s — skipping holdings merge", - market.name, exc, + market.name, + exc, ) held_codes = [] @@ -4288,7 +4235,8 @@ async def run(settings: Settings) -> None: if extra_held: logger.info( "Holdings added to loop for %s (not in scanner): %s", - market.name, extra_held, + market.name, + extra_held, ) if not stock_codes: diff --git a/src/markets/schedule.py b/src/markets/schedule.py index 9d142d9..a87408e 100644 --- a/src/markets/schedule.py +++ b/src/markets/schedule.py @@ -211,9 +211,7 @@ def get_open_markets( return is_market_open(market, now) open_markets = [ - MARKETS[code] - for code in enabled_markets - if code in MARKETS and is_available(MARKETS[code]) + MARKETS[code] for code in enabled_markets if code in MARKETS and is_available(MARKETS[code]) ] return sorted(open_markets, key=lambda m: m.code) @@ -282,9 +280,7 @@ def get_next_market_open( # Calculate next open time for this market for days_ahead in range(7): # Check next 7 days check_date = market_now.date() + timedelta(days=days_ahead) - check_datetime = datetime.combine( - check_date, market.open_time, tzinfo=market.timezone - ) + check_datetime = datetime.combine(check_date, market.open_time, tzinfo=market.timezone) # Skip weekends if check_datetime.weekday() >= 5: diff --git a/src/notifications/telegram_client.py b/src/notifications/telegram_client.py index 0030645..381c5dd 100644 --- a/src/notifications/telegram_client.py +++ b/src/notifications/telegram_client.py @@ -4,7 +4,7 @@ import asyncio import logging import time from collections.abc import Awaitable, Callable -from dataclasses import dataclass, fields +from dataclasses import dataclass from enum import Enum from typing import ClassVar @@ -136,14 +136,14 @@ class TelegramClient: self._enabled = enabled self._rate_limiter = LeakyBucket(rate=rate_limit) self._session: aiohttp.ClientSession | None = None - self._filter = notification_filter if notification_filter is not None else NotificationFilter() + self._filter = ( + notification_filter if notification_filter is not None else NotificationFilter() + ) if not enabled: logger.info("Telegram notifications disabled via configuration") elif bot_token is None or chat_id is None: - logger.warning( - "Telegram notifications disabled (missing bot_token or chat_id)" - ) + logger.warning("Telegram notifications disabled (missing bot_token or chat_id)") self._enabled = False else: logger.info("Telegram notifications enabled for chat_id=%s", chat_id) @@ -209,14 +209,12 @@ class TelegramClient: async with session.post(url, json=payload) as resp: if resp.status != 200: error_text = await resp.text() - logger.error( - "Telegram API error (status=%d): %s", resp.status, error_text - ) + logger.error("Telegram API error (status=%d): %s", resp.status, error_text) return False logger.debug("Telegram message sent: %s", text[:50]) return True - except asyncio.TimeoutError: + except TimeoutError: logger.error("Telegram message timeout") return False except aiohttp.ClientError as exc: @@ -305,9 +303,7 @@ class TelegramClient: NotificationMessage(priority=NotificationPriority.LOW, message=message) ) - async def notify_circuit_breaker( - self, pnl_pct: float, threshold: float - ) -> None: + async def notify_circuit_breaker(self, pnl_pct: float, threshold: float) -> None: """ Notify circuit breaker activation. @@ -354,9 +350,7 @@ class TelegramClient: NotificationMessage(priority=NotificationPriority.HIGH, message=message) ) - async def notify_system_start( - self, mode: str, enabled_markets: list[str] - ) -> None: + async def notify_system_start(self, mode: str, enabled_markets: list[str]) -> None: """ Notify system startup. @@ -369,9 +363,7 @@ class TelegramClient: mode_emoji = "📝" if mode == "paper" else "💰" markets_str = ", ".join(enabled_markets) message = ( - f"{mode_emoji} System Started\n" - f"Mode: {mode.upper()}\n" - f"Markets: {markets_str}" + f"{mode_emoji} System Started\nMode: {mode.upper()}\nMarkets: {markets_str}" ) await self._send_notification( NotificationMessage(priority=NotificationPriority.MEDIUM, message=message) @@ -445,11 +437,7 @@ class TelegramClient: """ if not self._filter.playbook: return - message = ( - f"Playbook Failed\n" - f"Market: {market}\n" - f"Reason: {reason[:200]}" - ) + message = f"Playbook Failed\nMarket: {market}\nReason: {reason[:200]}" await self._send_notification( NotificationMessage(priority=NotificationPriority.HIGH, message=message) ) @@ -469,9 +457,7 @@ class TelegramClient: if "circuit breaker" in reason.lower() else NotificationPriority.MEDIUM ) - await self._send_notification( - NotificationMessage(priority=priority, message=message) - ) + await self._send_notification(NotificationMessage(priority=priority, message=message)) async def notify_unfilled_order( self, @@ -496,11 +482,7 @@ class TelegramClient: return # SELL resubmit is high priority — position liquidation at risk. # BUY cancel is medium priority — only cash is freed. - priority = ( - NotificationPriority.HIGH - if action == "SELL" - else NotificationPriority.MEDIUM - ) + priority = NotificationPriority.HIGH if action == "SELL" else NotificationPriority.MEDIUM outcome_emoji = "🔄" if outcome == "resubmitted" else "❌" outcome_label = "재주문" if outcome == "resubmitted" else "취소됨" action_emoji = "🔴" if action == "SELL" else "🟢" @@ -515,9 +497,7 @@ class TelegramClient: message = "\n".join(lines) await self._send_notification(NotificationMessage(priority=priority, message=message)) - async def notify_error( - self, error_type: str, error_msg: str, context: str - ) -> None: + async def notify_error(self, error_type: str, error_msg: str, context: str) -> None: """ Notify system error. @@ -541,9 +521,7 @@ class TelegramClient: class TelegramCommandHandler: """Handles incoming Telegram commands via long polling.""" - def __init__( - self, client: TelegramClient, polling_interval: float = 1.0 - ) -> None: + def __init__(self, client: TelegramClient, polling_interval: float = 1.0) -> None: """ Initialize command handler. @@ -559,9 +537,7 @@ class TelegramCommandHandler: self._polling_task: asyncio.Task[None] | None = None self._running = False - def register_command( - self, command: str, handler: Callable[[], Awaitable[None]] - ) -> None: + def register_command(self, command: str, handler: Callable[[], Awaitable[None]]) -> None: """ Register a command handler (no arguments). @@ -672,7 +648,7 @@ class TelegramCommandHandler: return updates - except asyncio.TimeoutError: + except TimeoutError: logger.debug("getUpdates timeout (normal)") return [] except aiohttp.ClientError as exc: @@ -697,9 +673,7 @@ class TelegramCommandHandler: # Verify chat_id matches configured chat chat_id = str(message.get("chat", {}).get("id", "")) if chat_id != self._client._chat_id: - logger.warning( - "Ignoring command from unauthorized chat_id: %s", chat_id - ) + logger.warning("Ignoring command from unauthorized chat_id: %s", chat_id) return # Extract command text diff --git a/src/strategy/models.py b/src/strategy/models.py index f7090f7..68375da 100644 --- a/src/strategy/models.py +++ b/src/strategy/models.py @@ -8,12 +8,12 @@ Defines the data contracts for the proactive strategy system: from __future__ import annotations from datetime import UTC, date, datetime -from enum import Enum +from enum import StrEnum from pydantic import BaseModel, Field, field_validator -class ScenarioAction(str, Enum): +class ScenarioAction(StrEnum): """Actions that can be taken by scenarios.""" BUY = "BUY" @@ -22,7 +22,7 @@ class ScenarioAction(str, Enum): REDUCE_ALL = "REDUCE_ALL" -class MarketOutlook(str, Enum): +class MarketOutlook(StrEnum): """AI's assessment of market direction.""" BULLISH = "bullish" @@ -32,7 +32,7 @@ class MarketOutlook(str, Enum): BEARISH = "bearish" -class PlaybookStatus(str, Enum): +class PlaybookStatus(StrEnum): """Lifecycle status of a playbook.""" PENDING = "pending" diff --git a/src/strategy/playbook_store.py b/src/strategy/playbook_store.py index 4b47356..95f2a2f 100644 --- a/src/strategy/playbook_store.py +++ b/src/strategy/playbook_store.py @@ -6,7 +6,6 @@ Designed for the pre-market strategy system (one playbook per market per day). from __future__ import annotations -import json import logging import sqlite3 from datetime import date @@ -53,8 +52,10 @@ class PlaybookStore: row_id = cursor.lastrowid or 0 logger.info( "Saved playbook for %s/%s (%d stocks, %d scenarios)", - playbook.date, playbook.market, - playbook.stock_count, playbook.scenario_count, + playbook.date, + playbook.market, + playbook.stock_count, + playbook.scenario_count, ) return row_id diff --git a/src/strategy/position_state_machine.py b/src/strategy/position_state_machine.py index 6a9e3a6..79f993f 100644 --- a/src/strategy/position_state_machine.py +++ b/src/strategy/position_state_machine.py @@ -6,10 +6,10 @@ State progression is monotonic (promotion-only) except terminal EXITED. from __future__ import annotations from dataclasses import dataclass -from enum import Enum +from enum import StrEnum -class PositionState(str, Enum): +class PositionState(StrEnum): HOLDING = "HOLDING" BE_LOCK = "BE_LOCK" ARMED = "ARMED" @@ -40,12 +40,7 @@ def evaluate_exit_first(inp: StateTransitionInput) -> bool: EXITED must be evaluated before any promotion. """ - return ( - inp.hard_stop_hit - or inp.trailing_stop_hit - or inp.model_exit_signal - or inp.be_lock_threat - ) + return inp.hard_stop_hit or inp.trailing_stop_hit or inp.model_exit_signal or inp.be_lock_threat def promote_state(current: PositionState, inp: StateTransitionInput) -> PositionState: diff --git a/src/strategy/pre_market_planner.py b/src/strategy/pre_market_planner.py index 1f30b11..7370a16 100644 --- a/src/strategy/pre_market_planner.py +++ b/src/strategy/pre_market_planner.py @@ -124,12 +124,14 @@ class PreMarketPlanner: # 4. Parse response playbook = self._parse_response( - decision.rationale, today, market, candidates, cross_market, + decision.rationale, + today, + market, + candidates, + cross_market, current_holdings=current_holdings, ) - playbook_with_tokens = playbook.model_copy( - update={"token_count": decision.token_count} - ) + playbook_with_tokens = playbook.model_copy(update={"token_count": decision.token_count}) logger.info( "Generated playbook for %s: %d stocks, %d scenarios, %d tokens", market, @@ -146,7 +148,9 @@ class PreMarketPlanner: return self._empty_playbook(today, market) def build_cross_market_context( - self, target_market: str, today: date | None = None, + self, + target_market: str, + today: date | None = None, ) -> CrossMarketContext | None: """Build cross-market context from the other market's L6 data. @@ -192,7 +196,9 @@ class PreMarketPlanner: ) def build_self_market_scorecard( - self, market: str, today: date | None = None, + self, + market: str, + today: date | None = None, ) -> dict[str, Any] | None: """Build previous-day scorecard for the same market.""" if today is None: @@ -320,18 +326,18 @@ class PreMarketPlanner: f"{context_text}\n" f"## Instructions\n" f"Return a JSON object with this exact structure:\n" - f'{{\n' + f"{{\n" f' "market_outlook": "bullish|neutral_to_bullish|neutral' f'|neutral_to_bearish|bearish",\n' f' "global_rules": [\n' f' {{"condition": "portfolio_pnl_pct < -2.0",' f' "action": "REDUCE_ALL", "rationale": "..."}}\n' - f' ],\n' + f" ],\n" f' "stocks": [\n' - f' {{\n' + f" {{\n" f' "stock_code": "...",\n' f' "scenarios": [\n' - f' {{\n' + f" {{\n" f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0,' f' "unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}},\n' f' "action": "BUY|SELL|HOLD",\n' @@ -340,11 +346,11 @@ class PreMarketPlanner: f' "stop_loss_pct": -2.0,\n' f' "take_profit_pct": 3.0,\n' f' "rationale": "..."\n' - f' }}\n' - f' ]\n' - f' }}\n' - f' ]\n' - f'}}\n\n' + f" }}\n" + f" ]\n" + f" }}\n" + f" ]\n" + f"}}\n\n" f"Rules:\n" f"- Max {max_scenarios} scenarios per stock\n" f"- Candidates list is the primary source for BUY candidates\n" @@ -575,8 +581,7 @@ class PreMarketPlanner: stop_loss_pct=-3.0, take_profit_pct=5.0, rationale=( - f"Rule-based BUY: oversold signal, " - f"RSI={c.rsi:.0f} (fallback planner)" + f"Rule-based BUY: oversold signal, RSI={c.rsi:.0f} (fallback planner)" ), ) ) diff --git a/src/strategy/scenario_engine.py b/src/strategy/scenario_engine.py index f1cd530..bf8f217 100644 --- a/src/strategy/scenario_engine.py +++ b/src/strategy/scenario_engine.py @@ -107,7 +107,9 @@ class ScenarioEngine: # 2. Find stock playbook stock_pb = playbook.get_stock_playbook(stock_code) if stock_pb is None: - logger.debug("No playbook for %s — defaulting to %s", stock_code, playbook.default_action) + logger.debug( + "No playbook for %s — defaulting to %s", stock_code, playbook.default_action + ) return ScenarioMatch( stock_code=stock_code, matched_scenario=None, @@ -135,7 +137,9 @@ class ScenarioEngine: ) # 4. No match — default action - logger.debug("No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action) + logger.debug( + "No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action + ) return ScenarioMatch( stock_code=stock_code, matched_scenario=None, @@ -198,17 +202,27 @@ class ScenarioEngine: checks.append(price is not None and price < condition.price_below) price_change_pct = self._safe_float(market_data.get("price_change_pct")) - if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: + if ( + condition.price_change_pct_above is not None + or condition.price_change_pct_below is not None + ): if "price_change_pct" not in market_data: self._warn_missing_key("price_change_pct") if condition.price_change_pct_above is not None: - checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above) + checks.append( + price_change_pct is not None and price_change_pct > condition.price_change_pct_above + ) if condition.price_change_pct_below is not None: - checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below) + checks.append( + price_change_pct is not None and price_change_pct < condition.price_change_pct_below + ) # Position-aware conditions unrealized_pnl_pct = self._safe_float(market_data.get("unrealized_pnl_pct")) - if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None: + if ( + condition.unrealized_pnl_pct_above is not None + or condition.unrealized_pnl_pct_below is not None + ): if "unrealized_pnl_pct" not in market_data: self._warn_missing_key("unrealized_pnl_pct") if condition.unrealized_pnl_pct_above is not None: @@ -227,15 +241,9 @@ class ScenarioEngine: if "holding_days" not in market_data: self._warn_missing_key("holding_days") if condition.holding_days_above is not None: - checks.append( - holding_days is not None - and holding_days > condition.holding_days_above - ) + checks.append(holding_days is not None and holding_days > condition.holding_days_above) if condition.holding_days_below is not None: - checks.append( - holding_days is not None - and holding_days < condition.holding_days_below - ) + checks.append(holding_days is not None and holding_days < condition.holding_days_below) return len(checks) > 0 and all(checks) @@ -295,9 +303,15 @@ class ScenarioEngine: details["volume_ratio"] = self._safe_float(market_data.get("volume_ratio")) if condition.price_above is not None or condition.price_below is not None: details["current_price"] = self._safe_float(market_data.get("current_price")) - if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: + if ( + condition.price_change_pct_above is not None + or condition.price_change_pct_below is not None + ): details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct")) - if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None: + if ( + condition.unrealized_pnl_pct_above is not None + or condition.unrealized_pnl_pct_below is not None + ): details["unrealized_pnl_pct"] = self._safe_float(market_data.get("unrealized_pnl_pct")) if condition.holding_days_above is not None or condition.holding_days_below is not None: details["holding_days"] = self._safe_float(market_data.get("holding_days")) diff --git a/tests/test_backup.py b/tests/test_backup.py index 0ecfa3e..3e82e39 100644 --- a/tests/test_backup.py +++ b/tests/test_backup.py @@ -4,8 +4,7 @@ from __future__ import annotations import sqlite3 import sys -import tempfile -from datetime import UTC, datetime, timedelta +from datetime import UTC, datetime from pathlib import Path from unittest.mock import MagicMock, patch @@ -48,7 +47,9 @@ def temp_db(tmp_path: Path) -> Path: cursor.executemany( """ - INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl) + INSERT INTO trades ( + timestamp, stock_code, action, quantity, price, confidence, rationale, pnl + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, test_trades, @@ -73,9 +74,7 @@ class TestBackupExporter: exporter = BackupExporter(str(temp_db)) output_dir = tmp_path / "exports" - results = exporter.export_all( - output_dir, formats=[ExportFormat.JSON], compress=False - ) + results = exporter.export_all(output_dir, formats=[ExportFormat.JSON], compress=False) assert ExportFormat.JSON in results assert results[ExportFormat.JSON].exists() @@ -86,9 +85,7 @@ class TestBackupExporter: exporter = BackupExporter(str(temp_db)) output_dir = tmp_path / "exports" - results = exporter.export_all( - output_dir, formats=[ExportFormat.JSON], compress=True - ) + results = exporter.export_all(output_dir, formats=[ExportFormat.JSON], compress=True) assert ExportFormat.JSON in results assert results[ExportFormat.JSON].suffix == ".gz" @@ -98,15 +95,13 @@ class TestBackupExporter: exporter = BackupExporter(str(temp_db)) output_dir = tmp_path / "exports" - results = exporter.export_all( - output_dir, formats=[ExportFormat.CSV], compress=False - ) + results = exporter.export_all(output_dir, formats=[ExportFormat.CSV], compress=False) assert ExportFormat.CSV in results assert results[ExportFormat.CSV].exists() # Verify CSV content - with open(results[ExportFormat.CSV], "r") as f: + with open(results[ExportFormat.CSV]) as f: lines = f.readlines() assert len(lines) == 4 # Header + 3 rows @@ -146,7 +141,7 @@ class TestBackupExporter: # Should only have 1 trade (AAPL on Jan 2) import json - with open(results[ExportFormat.JSON], "r") as f: + with open(results[ExportFormat.JSON]) as f: data = json.load(f) assert data["record_count"] == 1 assert data["trades"][0]["stock_code"] == "AAPL" @@ -407,9 +402,7 @@ class TestBackupExporterAdditional: assert ExportFormat.JSON in results assert ExportFormat.CSV in results - def test_export_all_logs_error_on_failure( - self, temp_db: Path, tmp_path: Path - ) -> None: + def test_export_all_logs_error_on_failure(self, temp_db: Path, tmp_path: Path) -> None: """export_all must log an error and continue when one format fails.""" exporter = BackupExporter(str(temp_db)) # Patch _export_format to raise on JSON, succeed on CSV @@ -430,9 +423,7 @@ class TestBackupExporterAdditional: assert ExportFormat.JSON not in results assert ExportFormat.CSV in results - def test_export_csv_empty_trades_no_compress( - self, empty_db: Path, tmp_path: Path - ) -> None: + def test_export_csv_empty_trades_no_compress(self, empty_db: Path, tmp_path: Path) -> None: """CSV export with no trades and compress=False must write header row only.""" exporter = BackupExporter(str(empty_db)) results = exporter.export_all( @@ -446,9 +437,7 @@ class TestBackupExporterAdditional: content = out.read_text() assert "timestamp" in content - def test_export_csv_empty_trades_compressed( - self, empty_db: Path, tmp_path: Path - ) -> None: + def test_export_csv_empty_trades_compressed(self, empty_db: Path, tmp_path: Path) -> None: """CSV export with no trades and compress=True must write gzipped header.""" import gzip @@ -465,9 +454,7 @@ class TestBackupExporterAdditional: content = f.read() assert "timestamp" in content - def test_export_csv_with_data_compressed( - self, temp_db: Path, tmp_path: Path - ) -> None: + def test_export_csv_with_data_compressed(self, temp_db: Path, tmp_path: Path) -> None: """CSV export with data and compress=True must write gzipped rows.""" import gzip @@ -492,6 +479,7 @@ class TestBackupExporterAdditional: with patch.dict(sys.modules, {"pyarrow": None, "pyarrow.parquet": None}): try: import pyarrow # noqa: F401 + pytest.skip("pyarrow is installed; cannot test ImportError path") except ImportError: pass @@ -557,9 +545,7 @@ class TestCloudStorage: importlib.reload(m) m.CloudStorage(s3_config) - def test_upload_file_success( - self, mock_boto3_module, s3_config, tmp_path: Path - ) -> None: + def test_upload_file_success(self, mock_boto3_module, s3_config, tmp_path: Path) -> None: """upload_file must call client.upload_file and return the object key.""" from src.backup.cloud_storage import CloudStorage @@ -572,9 +558,7 @@ class TestCloudStorage: assert key == "backups/backup.json.gz" storage.client.upload_file.assert_called_once() - def test_upload_file_default_key( - self, mock_boto3_module, s3_config, tmp_path: Path - ) -> None: + def test_upload_file_default_key(self, mock_boto3_module, s3_config, tmp_path: Path) -> None: """upload_file without object_key must use the filename as key.""" from src.backup.cloud_storage import CloudStorage @@ -586,9 +570,7 @@ class TestCloudStorage: assert key == "myfile.gz" - def test_upload_file_not_found( - self, mock_boto3_module, s3_config, tmp_path: Path - ) -> None: + def test_upload_file_not_found(self, mock_boto3_module, s3_config, tmp_path: Path) -> None: """upload_file must raise FileNotFoundError for missing files.""" from src.backup.cloud_storage import CloudStorage @@ -611,9 +593,7 @@ class TestCloudStorage: with pytest.raises(RuntimeError, match="network error"): storage.upload_file(test_file) - def test_download_file_success( - self, mock_boto3_module, s3_config, tmp_path: Path - ) -> None: + def test_download_file_success(self, mock_boto3_module, s3_config, tmp_path: Path) -> None: """download_file must call client.download_file and return local path.""" from src.backup.cloud_storage import CloudStorage @@ -637,11 +617,8 @@ class TestCloudStorage: with pytest.raises(RuntimeError, match="timeout"): storage.download_file("key", tmp_path / "dest.gz") - def test_list_files_returns_objects( - self, mock_boto3_module, s3_config - ) -> None: + def test_list_files_returns_objects(self, mock_boto3_module, s3_config) -> None: """list_files must return parsed file metadata from S3 response.""" - from datetime import timezone from src.backup.cloud_storage import CloudStorage @@ -651,7 +628,7 @@ class TestCloudStorage: { "Key": "backups/a.gz", "Size": 1024, - "LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc), + "LastModified": datetime(2026, 1, 1, tzinfo=UTC), "ETag": '"abc123"', } ] @@ -662,9 +639,7 @@ class TestCloudStorage: assert files[0]["key"] == "backups/a.gz" assert files[0]["size_bytes"] == 1024 - def test_list_files_empty_bucket( - self, mock_boto3_module, s3_config - ) -> None: + def test_list_files_empty_bucket(self, mock_boto3_module, s3_config) -> None: """list_files must return empty list when bucket has no objects.""" from src.backup.cloud_storage import CloudStorage @@ -674,9 +649,7 @@ class TestCloudStorage: files = storage.list_files() assert files == [] - def test_list_files_propagates_error( - self, mock_boto3_module, s3_config - ) -> None: + def test_list_files_propagates_error(self, mock_boto3_module, s3_config) -> None: """list_files must re-raise exceptions from the boto3 client.""" from src.backup.cloud_storage import CloudStorage @@ -686,9 +659,7 @@ class TestCloudStorage: with pytest.raises(RuntimeError): storage.list_files() - def test_delete_file_success( - self, mock_boto3_module, s3_config - ) -> None: + def test_delete_file_success(self, mock_boto3_module, s3_config) -> None: """delete_file must call client.delete_object with the correct key.""" from src.backup.cloud_storage import CloudStorage @@ -698,9 +669,7 @@ class TestCloudStorage: Bucket="test-bucket", Key="backups/old.gz" ) - def test_delete_file_propagates_error( - self, mock_boto3_module, s3_config - ) -> None: + def test_delete_file_propagates_error(self, mock_boto3_module, s3_config) -> None: """delete_file must re-raise exceptions from the boto3 client.""" from src.backup.cloud_storage import CloudStorage @@ -710,11 +679,8 @@ class TestCloudStorage: with pytest.raises(RuntimeError): storage.delete_file("backups/old.gz") - def test_get_storage_stats_success( - self, mock_boto3_module, s3_config - ) -> None: + def test_get_storage_stats_success(self, mock_boto3_module, s3_config) -> None: """get_storage_stats must aggregate file sizes correctly.""" - from datetime import timezone from src.backup.cloud_storage import CloudStorage @@ -724,13 +690,13 @@ class TestCloudStorage: { "Key": "a.gz", "Size": 1024 * 1024, - "LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc), + "LastModified": datetime(2026, 1, 1, tzinfo=UTC), "ETag": '"x"', }, { "Key": "b.gz", "Size": 1024 * 1024, - "LastModified": datetime(2026, 1, 2, tzinfo=timezone.utc), + "LastModified": datetime(2026, 1, 2, tzinfo=UTC), "ETag": '"y"', }, ] @@ -741,9 +707,7 @@ class TestCloudStorage: assert stats["total_size_bytes"] == 2 * 1024 * 1024 assert stats["total_size_mb"] == pytest.approx(2.0) - def test_get_storage_stats_on_error( - self, mock_boto3_module, s3_config - ) -> None: + def test_get_storage_stats_on_error(self, mock_boto3_module, s3_config) -> None: """get_storage_stats must return error dict without raising on failure.""" from src.backup.cloud_storage import CloudStorage @@ -754,9 +718,7 @@ class TestCloudStorage: assert "error" in stats assert stats["total_files"] == 0 - def test_verify_connection_success( - self, mock_boto3_module, s3_config - ) -> None: + def test_verify_connection_success(self, mock_boto3_module, s3_config) -> None: """verify_connection must return True when head_bucket succeeds.""" from src.backup.cloud_storage import CloudStorage @@ -764,9 +726,7 @@ class TestCloudStorage: result = storage.verify_connection() assert result is True - def test_verify_connection_failure( - self, mock_boto3_module, s3_config - ) -> None: + def test_verify_connection_failure(self, mock_boto3_module, s3_config) -> None: """verify_connection must return False when head_bucket raises.""" from src.backup.cloud_storage import CloudStorage @@ -776,9 +736,7 @@ class TestCloudStorage: result = storage.verify_connection() assert result is False - def test_enable_versioning( - self, mock_boto3_module, s3_config - ) -> None: + def test_enable_versioning(self, mock_boto3_module, s3_config) -> None: """enable_versioning must call put_bucket_versioning.""" from src.backup.cloud_storage import CloudStorage @@ -786,9 +744,7 @@ class TestCloudStorage: storage.enable_versioning() storage.client.put_bucket_versioning.assert_called_once() - def test_enable_versioning_propagates_error( - self, mock_boto3_module, s3_config - ) -> None: + def test_enable_versioning_propagates_error(self, mock_boto3_module, s3_config) -> None: """enable_versioning must re-raise exceptions from the boto3 client.""" from src.backup.cloud_storage import CloudStorage diff --git a/tests/test_brain.py b/tests/test_brain.py index c857720..9bf99d8 100644 --- a/tests/test_brain.py +++ b/tests/test_brain.py @@ -323,7 +323,8 @@ class TestPromptOverride: # Verify the custom prompt was sent, not a built prompt mock_generate.assert_called_once() actual_prompt = mock_generate.call_args[1].get( - "contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None + "contents", + mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None, ) assert actual_prompt == custom_prompt # Raw response preserved in rationale without parse_response (#247) @@ -385,7 +386,8 @@ class TestPromptOverride: await client.decide(market_data) actual_prompt = mock_generate.call_args[1].get( - "contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None + "contents", + mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None, ) # The custom prompt must be used, not the compressed prompt assert actual_prompt == custom_prompt @@ -411,7 +413,8 @@ class TestPromptOverride: await client.decide(market_data) actual_prompt = mock_generate.call_args[1].get( - "contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None + "contents", + mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None, ) # Should contain stock code from build_prompt, not be a custom override assert "005930" in actual_prompt diff --git a/tests/test_broker.py b/tests/test_broker.py index 5213013..16ad45f 100644 --- a/tests/test_broker.py +++ b/tests/test_broker.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, patch import pytest @@ -99,7 +99,10 @@ class TestTokenManagement: mock_resp_403 = AsyncMock() mock_resp_403.status = 403 mock_resp_403.text = AsyncMock( - return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}' + return_value=( + '{"error_code":"EGW00133","error_description":' + '"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}' + ) ) mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403) mock_resp_403.__aexit__ = AsyncMock(return_value=False) @@ -232,9 +235,7 @@ class TestRateLimiter: mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp) mock_order_resp.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp] - ): + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]): with patch.object( broker._rate_limiter, "acquire", new_callable=AsyncMock ) as mock_acquire: @@ -405,7 +406,7 @@ class TestFetchMarketRankings: # --------------------------------------------------------------------------- -from src.broker.kis_api import kr_tick_unit, kr_round_down # noqa: E402 +from src.broker.kis_api import kr_round_down, kr_tick_unit # noqa: E402 class TestKrTickUnit: @@ -435,13 +436,13 @@ class TestKrTickUnit: @pytest.mark.parametrize( "price, expected_rounded", [ - (188150, 188100), # 100원 단위, 50원 잔여 → 내림 - (188100, 188100), # 이미 정렬됨 - (75050, 75000), # 100원 단위, 50원 잔여 → 내림 - (49950, 49950), # 50원 단위 정렬됨 - (49960, 49950), # 50원 단위, 10원 잔여 → 내림 - (1999, 1999), # 1원 단위 → 그대로 - (5003, 5000), # 10원 단위, 3원 잔여 → 내림 + (188150, 188100), # 100원 단위, 50원 잔여 → 내림 + (188100, 188100), # 이미 정렬됨 + (75050, 75000), # 100원 단위, 50원 잔여 → 내림 + (49950, 49950), # 50원 단위 정렬됨 + (49960, 49950), # 50원 단위, 10원 잔여 → 내림 + (1999, 1999), # 1원 단위 → 그대로 + (5003, 5000), # 10원 단위, 3원 잔여 → 내림 ], ) def test_round_down_to_tick(self, price: int, expected_rounded: int) -> None: @@ -538,15 +539,13 @@ class TestSendOrderTickRounding: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "BUY", 1, price=188150) order_call = mock_post.call_args_list[1] body = order_call[1].get("json", {}) assert body["ORD_UNPR"] == "188100" # rounded down - assert body["ORD_DVSN"] == "00" # 지정가 + assert body["ORD_DVSN"] == "00" # 지정가 @pytest.mark.asyncio async def test_limit_order_ord_dvsn_is_00(self, broker: KISBroker) -> None: @@ -563,9 +562,7 @@ class TestSendOrderTickRounding: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "BUY", 1, price=50000) order_call = mock_post.call_args_list[1] @@ -587,9 +584,7 @@ class TestSendOrderTickRounding: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "SELL", 1, price=0) order_call = mock_post.call_args_list[1] @@ -628,9 +623,7 @@ class TestTRIDBranchingDomestic: broker = self._make_broker(settings, "paper") mock_resp = AsyncMock() mock_resp.status = 200 - mock_resp.json = AsyncMock( - return_value={"output1": [], "output2": {}} - ) + mock_resp.json = AsyncMock(return_value={"output1": [], "output2": {}}) mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) mock_resp.__aexit__ = AsyncMock(return_value=False) @@ -645,9 +638,7 @@ class TestTRIDBranchingDomestic: broker = self._make_broker(settings, "live") mock_resp = AsyncMock() mock_resp.status = 200 - mock_resp.json = AsyncMock( - return_value={"output1": [], "output2": {}} - ) + mock_resp.json = AsyncMock(return_value={"output1": [], "output2": {}}) mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) mock_resp.__aexit__ = AsyncMock(return_value=False) @@ -672,9 +663,7 @@ class TestTRIDBranchingDomestic: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "BUY", 1) order_headers = mock_post.call_args_list[1][1].get("headers", {}) @@ -695,9 +684,7 @@ class TestTRIDBranchingDomestic: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "BUY", 1) order_headers = mock_post.call_args_list[1][1].get("headers", {}) @@ -718,9 +705,7 @@ class TestTRIDBranchingDomestic: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "SELL", 1) order_headers = mock_post.call_args_list[1][1].get("headers", {}) @@ -741,9 +726,7 @@ class TestTRIDBranchingDomestic: mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aexit__ = AsyncMock(return_value=False) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.send_order("005930", "SELL", 1) order_headers = mock_post.call_args_list[1][1].get("headers", {}) @@ -788,9 +771,7 @@ class TestGetDomesticPendingOrders: mock_get.assert_not_called() @pytest.mark.asyncio - async def test_live_mode_calls_tttc0084r_with_correct_params( - self, settings - ) -> None: + async def test_live_mode_calls_tttc0084r_with_correct_params(self, settings) -> None: """Live mode must call TTTC0084R with INQR_DVSN_1/2 and paging params.""" broker = self._make_broker(settings, "live") pending = [{"odno": "001", "pdno": "005930", "psbl_qty": "10"}] @@ -872,9 +853,7 @@ class TestCancelDomesticOrder: broker = self._make_broker(settings, "live") mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5) order_headers = mock_post.call_args_list[1][1].get("headers", {}) @@ -886,9 +865,7 @@ class TestCancelDomesticOrder: broker = self._make_broker(settings, "paper") mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5) order_headers = mock_post.call_args_list[1][1].get("headers", {}) @@ -900,9 +877,7 @@ class TestCancelDomesticOrder: broker = self._make_broker(settings, "live") mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5) body = mock_post.call_args_list[1][1].get("json", {}) @@ -916,9 +891,7 @@ class TestCancelDomesticOrder: broker = self._make_broker(settings, "live") mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.cancel_domestic_order("005930", "ORD123", "BRN456", 3) body = mock_post.call_args_list[1][1].get("json", {}) @@ -932,9 +905,7 @@ class TestCancelDomesticOrder: broker = self._make_broker(settings, "live") mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) - with patch( - "aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order] - ) as mock_post: + with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post: await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 2) order_headers = mock_post.call_args_list[1][1].get("headers", {}) diff --git a/tests/test_context.py b/tests/test_context.py index a1d1f29..3abc58d 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -77,9 +77,7 @@ class TestContextStore: # Latest by updated_at, which should be the last one set assert latest == "2026-02-02" - def test_delete_old_contexts( - self, store: ContextStore, db_conn: sqlite3.Connection - ) -> None: + def test_delete_old_contexts(self, store: ContextStore, db_conn: sqlite3.Connection) -> None: """Test deleting contexts older than a cutoff date.""" # Insert contexts with specific old timestamps # (bypassing set_context which uses current time) @@ -170,9 +168,7 @@ class TestContextAggregator: log_trade(db_conn, "035720", "HOLD", 75, "Wait", quantity=0, price=0, pnl=0) # Manually set timestamps to the target date - db_conn.execute( - f"UPDATE trades SET timestamp = '{date}T10:00:00+00:00'" - ) + db_conn.execute(f"UPDATE trades SET timestamp = '{date}T10:00:00+00:00'") db_conn.commit() # Aggregate @@ -194,18 +190,10 @@ class TestContextAggregator: week = "2026-W06" # Set daily contexts - aggregator.store.set_context( - ContextLayer.L6_DAILY, "2026-02-02", "total_pnl_KR", 100.0 - ) - aggregator.store.set_context( - ContextLayer.L6_DAILY, "2026-02-03", "total_pnl_KR", 200.0 - ) - aggregator.store.set_context( - ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence_KR", 80.0 - ) - aggregator.store.set_context( - ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence_KR", 85.0 - ) + aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl_KR", 100.0) + aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "total_pnl_KR", 200.0) + aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence_KR", 80.0) + aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence_KR", 85.0) # Aggregate aggregator.aggregate_weekly_from_daily(week) @@ -223,15 +211,9 @@ class TestContextAggregator: month = "2026-02" # Set weekly contexts - aggregator.store.set_context( - ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl_KR", 100.0 - ) - aggregator.store.set_context( - ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl_KR", 200.0 - ) - aggregator.store.set_context( - ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl_KR", 150.0 - ) + aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl_KR", 100.0) + aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl_KR", 200.0) + aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl_KR", 150.0) # Aggregate aggregator.aggregate_monthly_from_weekly(month) @@ -316,6 +298,7 @@ class TestContextAggregator: store = aggregator.store assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 1000.0 from datetime import date as date_cls + trade_date = date_cls.fromisoformat(date) iso_year, iso_week, _ = trade_date.isocalendar() trade_week = f"{iso_year}-W{iso_week:02d}" @@ -324,7 +307,9 @@ class TestContextAggregator: trade_quarter = f"{trade_date.year}-Q{(trade_date.month - 1) // 3 + 1}" trade_year = str(trade_date.year) assert store.get_context(ContextLayer.L4_MONTHLY, trade_month, "monthly_pnl") == 1000.0 - assert store.get_context(ContextLayer.L3_QUARTERLY, trade_quarter, "quarterly_pnl") == 1000.0 + assert ( + store.get_context(ContextLayer.L3_QUARTERLY, trade_quarter, "quarterly_pnl") == 1000.0 + ) assert store.get_context(ContextLayer.L2_ANNUAL, trade_year, "annual_pnl") == 1000.0 @@ -429,9 +414,7 @@ class TestContextSummarizer: # summarize_layer # ------------------------------------------------------------------ - def test_summarize_layer_no_data( - self, summarizer: ContextSummarizer - ) -> None: + def test_summarize_layer_no_data(self, summarizer: ContextSummarizer) -> None: """summarize_layer with no data must return the 'No data' sentinel.""" result = summarizer.summarize_layer(ContextLayer.L6_DAILY) assert result["count"] == 0 @@ -448,15 +431,12 @@ class TestContextSummarizer: result = summarizer.summarize_layer(ContextLayer.L6_DAILY) assert "total_entries" in result - def test_summarize_layer_with_dict_values( - self, summarizer: ContextSummarizer - ) -> None: + def test_summarize_layer_with_dict_values(self, summarizer: ContextSummarizer) -> None: """summarize_layer must handle dict values by extracting numeric subkeys.""" store = summarizer.store # set_context serialises the value as JSON, so passing a dict works store.set_context( - ContextLayer.L6_DAILY, "2026-02-01", "metrics", - {"win_rate": 65.0, "label": "good"} + ContextLayer.L6_DAILY, "2026-02-01", "metrics", {"win_rate": 65.0, "label": "good"} ) result = summarizer.summarize_layer(ContextLayer.L6_DAILY) @@ -464,9 +444,7 @@ class TestContextSummarizer: # numeric subkey "win_rate" should appear as "metrics.win_rate" assert "metrics.win_rate" in result - def test_summarize_layer_with_string_values( - self, summarizer: ContextSummarizer - ) -> None: + def test_summarize_layer_with_string_values(self, summarizer: ContextSummarizer) -> None: """summarize_layer must count string values separately.""" store = summarizer.store # set_context stores string values as JSON-encoded strings @@ -480,9 +458,7 @@ class TestContextSummarizer: # rolling_window_summary # ------------------------------------------------------------------ - def test_rolling_window_summary_basic( - self, summarizer: ContextSummarizer - ) -> None: + def test_rolling_window_summary_basic(self, summarizer: ContextSummarizer) -> None: """rolling_window_summary must return the expected structure.""" store = summarizer.store store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 500.0) @@ -492,22 +468,16 @@ class TestContextSummarizer: assert "recent_data" in result assert "historical_summary" in result - def test_rolling_window_summary_no_older_data( - self, summarizer: ContextSummarizer - ) -> None: + def test_rolling_window_summary_no_older_data(self, summarizer: ContextSummarizer) -> None: """rolling_window_summary with summarize_older=False skips history.""" - result = summarizer.rolling_window_summary( - ContextLayer.L6_DAILY, summarize_older=False - ) + result = summarizer.rolling_window_summary(ContextLayer.L6_DAILY, summarize_older=False) assert result["historical_summary"] == {} # ------------------------------------------------------------------ # aggregate_to_higher_layer # ------------------------------------------------------------------ - def test_aggregate_to_higher_layer_mean( - self, summarizer: ContextSummarizer - ) -> None: + def test_aggregate_to_higher_layer_mean(self, summarizer: ContextSummarizer) -> None: """aggregate_to_higher_layer with 'mean' via dict subkeys returns average.""" store = summarizer.store # Use different outer keys but same inner metric key so get_all_contexts @@ -520,9 +490,7 @@ class TestContextSummarizer: ) assert result == pytest.approx(150.0) - def test_aggregate_to_higher_layer_sum( - self, summarizer: ContextSummarizer - ) -> None: + def test_aggregate_to_higher_layer_sum(self, summarizer: ContextSummarizer) -> None: """aggregate_to_higher_layer with 'sum' must return the total.""" store = summarizer.store store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0}) @@ -533,9 +501,7 @@ class TestContextSummarizer: ) assert result == pytest.approx(300.0) - def test_aggregate_to_higher_layer_max( - self, summarizer: ContextSummarizer - ) -> None: + def test_aggregate_to_higher_layer_max(self, summarizer: ContextSummarizer) -> None: """aggregate_to_higher_layer with 'max' must return the maximum.""" store = summarizer.store store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0}) @@ -546,9 +512,7 @@ class TestContextSummarizer: ) assert result == pytest.approx(200.0) - def test_aggregate_to_higher_layer_min( - self, summarizer: ContextSummarizer - ) -> None: + def test_aggregate_to_higher_layer_min(self, summarizer: ContextSummarizer) -> None: """aggregate_to_higher_layer with 'min' must return the minimum.""" store = summarizer.store store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0}) @@ -559,9 +523,7 @@ class TestContextSummarizer: ) assert result == pytest.approx(100.0) - def test_aggregate_to_higher_layer_no_data( - self, summarizer: ContextSummarizer - ) -> None: + def test_aggregate_to_higher_layer_no_data(self, summarizer: ContextSummarizer) -> None: """aggregate_to_higher_layer with no matching key must return None.""" result = summarizer.aggregate_to_higher_layer( ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "nonexistent", "mean" @@ -585,9 +547,7 @@ class TestContextSummarizer: # create_compact_summary + format_summary_for_prompt # ------------------------------------------------------------------ - def test_create_compact_summary( - self, summarizer: ContextSummarizer - ) -> None: + def test_create_compact_summary(self, summarizer: ContextSummarizer) -> None: """create_compact_summary must produce a dict keyed by layer value.""" store = summarizer.store store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0) @@ -615,9 +575,7 @@ class TestContextSummarizer: text = summarizer.format_summary_for_prompt(summary) assert text == "" - def test_format_summary_non_dict_value( - self, summarizer: ContextSummarizer - ) -> None: + def test_format_summary_non_dict_value(self, summarizer: ContextSummarizer) -> None: """format_summary_for_prompt must render non-dict values as plain text.""" summary = { "daily": { diff --git a/tests/test_daily_review.py b/tests/test_daily_review.py index 38765e6..e127b84 100644 --- a/tests/test_daily_review.py +++ b/tests/test_daily_review.py @@ -4,6 +4,7 @@ from __future__ import annotations import json import sqlite3 +from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock @@ -16,8 +17,6 @@ from src.evolution.daily_review import DailyReviewer from src.evolution.scorecard import DailyScorecard from src.logging.decision_logger import DecisionLogger -from datetime import UTC, datetime - TODAY = datetime.now(UTC).strftime("%Y-%m-%d") @@ -53,7 +52,8 @@ def _log_decision( def test_generate_scorecard_market_scoped( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store) logger = DecisionLogger(db_conn) @@ -134,7 +134,8 @@ def test_generate_scorecard_market_scoped( def test_generate_scorecard_top_winners_and_losers( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store) logger = DecisionLogger(db_conn) @@ -168,7 +169,8 @@ def test_generate_scorecard_top_winners_and_losers( def test_generate_scorecard_empty_day( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store) scorecard = reviewer.generate_scorecard(TODAY, "KR") @@ -184,7 +186,8 @@ def test_generate_scorecard_empty_day( @pytest.mark.asyncio async def test_generate_lessons_without_gemini_returns_empty( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store, gemini_client=None) lessons = await reviewer.generate_lessons( @@ -206,7 +209,8 @@ async def test_generate_lessons_without_gemini_returns_empty( @pytest.mark.asyncio async def test_generate_lessons_parses_json_array( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: mock_gemini = MagicMock() mock_gemini.decide = AsyncMock( @@ -233,7 +237,8 @@ async def test_generate_lessons_parses_json_array( @pytest.mark.asyncio async def test_generate_lessons_fallback_to_lines( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: mock_gemini = MagicMock() mock_gemini.decide = AsyncMock( @@ -260,7 +265,8 @@ async def test_generate_lessons_fallback_to_lines( @pytest.mark.asyncio async def test_generate_lessons_handles_gemini_error( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: mock_gemini = MagicMock() mock_gemini.decide = AsyncMock(side_effect=RuntimeError("boom")) @@ -284,7 +290,8 @@ async def test_generate_lessons_handles_gemini_error( def test_store_scorecard_in_context( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store) scorecard = DailyScorecard( @@ -316,7 +323,8 @@ def test_store_scorecard_in_context( def test_store_scorecard_key_is_market_scoped( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store) kr = DailyScorecard( @@ -357,7 +365,8 @@ def test_store_scorecard_key_is_market_scoped( def test_generate_scorecard_handles_invalid_context_snapshot( - db_conn: sqlite3.Connection, context_store: ContextStore, + db_conn: sqlite3.Connection, + context_store: ContextStore, ) -> None: reviewer = DailyReviewer(db_conn, context_store) db_conn.execute( diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 8620c44..106ff54 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -355,6 +355,7 @@ def test_positions_empty_when_no_trades(tmp_path: Path) -> None: def _seed_cb_context(conn: sqlite3.Connection, pnl_pct: float, market: str = "KR") -> None: import json as _json + conn.execute( "INSERT OR REPLACE INTO system_metrics (key, value, updated_at) VALUES (?, ?, ?)", ( diff --git a/tests/test_data_integration.py b/tests/test_data_integration.py index 45b1e2a..ea41199 100644 --- a/tests/test_data_integration.py +++ b/tests/test_data_integration.py @@ -79,7 +79,7 @@ class TestNewsAPI: # Mock the fetch to avoid real API call with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch: mock_fetch.return_value = None - result = await api.get_news_sentiment("AAPL") + await api.get_news_sentiment("AAPL") # Should have attempted refetch since cache expired mock_fetch.assert_called_once_with("AAPL") @@ -111,9 +111,7 @@ class TestNewsAPI: "source": "Reuters", "time_published": "2026-02-04T10:00:00", "url": "https://example.com/1", - "ticker_sentiment": [ - {"ticker": "AAPL", "ticker_sentiment_score": "0.85"} - ], + "ticker_sentiment": [{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}], "overall_sentiment_score": "0.75", }, { @@ -122,9 +120,7 @@ class TestNewsAPI: "source": "Bloomberg", "time_published": "2026-02-04T09:00:00", "url": "https://example.com/2", - "ticker_sentiment": [ - {"ticker": "AAPL", "ticker_sentiment_score": "-0.3"} - ], + "ticker_sentiment": [{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}], "overall_sentiment_score": "-0.2", }, ] @@ -661,7 +657,9 @@ class TestGeminiClientWithExternalData: ) # Mock the Gemini API call - with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen: + with patch.object( + client._client.aio.models, "generate_content", new_callable=AsyncMock + ) as mock_gen: mock_response = MagicMock() mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}' mock_gen.return_value = mock_response diff --git a/tests/test_db.py b/tests/test_db.py index fb2feb9..4f4d7a2 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -1,7 +1,7 @@ """Tests for database helper functions.""" -import tempfile import os +import tempfile from src.db import get_latest_buy_trade, get_open_position, init_db, log_trade @@ -204,7 +204,8 @@ def test_mode_migration_adds_column_to_existing_db() -> None: assert "strategy_pnl" in columns assert "fx_pnl" in columns migrated = conn.execute( - "SELECT pnl, strategy_pnl, fx_pnl, session_id FROM trades WHERE stock_code='AAPL' LIMIT 1" + "SELECT pnl, strategy_pnl, fx_pnl, session_id " + "FROM trades WHERE stock_code='AAPL' LIMIT 1" ).fetchone() assert migrated is not None assert migrated[0] == 123.45 @@ -407,9 +408,7 @@ def test_decision_logs_session_id_migration_backfills_unknown() -> None: conn = init_db(db_path) columns = {row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()} assert "session_id" in columns - row = conn.execute( - "SELECT session_id FROM decision_logs WHERE decision_id='d1'" - ).fetchone() + row = conn.execute("SELECT session_id FROM decision_logs WHERE decision_id='d1'").fetchone() assert row is not None assert row[0] == "UNKNOWN" conn.close() diff --git a/tests/test_decision_logger.py b/tests/test_decision_logger.py index dec3a64..ebb1572 100644 --- a/tests/test_decision_logger.py +++ b/tests/test_decision_logger.py @@ -49,7 +49,10 @@ def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Co # Verify record exists in database cursor = db_conn.execute( - "SELECT decision_id, action, confidence, session_id FROM decision_logs WHERE decision_id = ?", + ( + "SELECT decision_id, action, confidence, session_id " + "FROM decision_logs WHERE decision_id = ?" + ), (decision_id,), ) row = cursor.fetchone() diff --git a/tests/test_evolution.py b/tests/test_evolution.py index d5ad349..cdcd38c 100644 --- a/tests/test_evolution.py +++ b/tests/test_evolution.py @@ -208,7 +208,9 @@ def test_identify_failure_patterns_empty(optimizer: EvolutionOptimizer) -> None: @pytest.mark.asyncio -async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp_path: Path) -> None: +async def test_generate_strategy_creates_file( + optimizer: EvolutionOptimizer, tmp_path: Path +) -> None: """Test that generate_strategy creates a strategy file.""" failures = [ { @@ -234,7 +236,9 @@ async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp return {"action": "HOLD", "confidence": 50, "rationale": "Waiting"} """ - with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): + with patch.object( + optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response) + ): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): strategy_path = await optimizer.generate_strategy(failures) @@ -247,7 +251,8 @@ async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp @pytest.mark.asyncio async def test_generate_strategy_saves_valid_python_code( - optimizer: EvolutionOptimizer, tmp_path: Path, + optimizer: EvolutionOptimizer, + tmp_path: Path, ) -> None: """Test that syntactically valid generated code is saved.""" failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}] @@ -255,12 +260,14 @@ async def test_generate_strategy_saves_valid_python_code( mock_response = Mock() mock_response.text = ( 'price = market_data.get("current_price", 0)\n' - 'if price > 0:\n' + "if price > 0:\n" ' return {"action": "BUY", "confidence": 80, "rationale": "Positive price"}\n' 'return {"action": "HOLD", "confidence": 50, "rationale": "No signal"}\n' ) - with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): + with patch.object( + optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response) + ): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): strategy_path = await optimizer.generate_strategy(failures) @@ -270,7 +277,9 @@ async def test_generate_strategy_saves_valid_python_code( @pytest.mark.asyncio async def test_generate_strategy_blocks_invalid_python_code( - optimizer: EvolutionOptimizer, tmp_path: Path, caplog: pytest.LogCaptureFixture, + optimizer: EvolutionOptimizer, + tmp_path: Path, + caplog: pytest.LogCaptureFixture, ) -> None: """Test that syntactically invalid generated code is not saved.""" failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}] @@ -281,7 +290,9 @@ async def test_generate_strategy_blocks_invalid_python_code( ' return {"action": "BUY", "confidence": 80, "rationale": "broken"}\n' ) - with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): + with patch.object( + optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response) + ): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): with caplog.at_level("WARNING"): strategy_path = await optimizer.generate_strategy(failures) @@ -310,6 +321,7 @@ def test_get_performance_summary() -> None: """Test getting performance summary from trades table.""" # Create a temporary database with trades import tempfile + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: tmp_path = tmp.name @@ -604,7 +616,9 @@ def test_calculate_improvement_trend_declining(performance_tracker: PerformanceT assert trend["pnl_change"] == -250.0 -def test_calculate_improvement_trend_insufficient_data(performance_tracker: PerformanceTracker) -> None: +def test_calculate_improvement_trend_insufficient_data( + performance_tracker: PerformanceTracker, +) -> None: """Test improvement trend with insufficient data.""" metrics = [ StrategyMetrics( @@ -718,7 +732,9 @@ async def test_full_evolution_pipeline(optimizer: EvolutionOptimizer, tmp_path: mock_response = Mock() mock_response.text = 'return {"action": "HOLD", "confidence": 50, "rationale": "Test"}' - with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): + with patch.object( + optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response) + ): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): with patch("subprocess.run") as mock_run: mock_run.return_value = Mock(returncode=0, stdout="", stderr="") diff --git a/tests/test_logging_config.py b/tests/test_logging_config.py index 526f692..387623e 100644 --- a/tests/test_logging_config.py +++ b/tests/test_logging_config.py @@ -103,9 +103,7 @@ class TestSetupLogging: """setup_logging must attach a JSON handler to the root logger.""" setup_logging(level=logging.DEBUG) root = logging.getLogger() - json_handlers = [ - h for h in root.handlers if isinstance(h.formatter, JSONFormatter) - ] + json_handlers = [h for h in root.handlers if isinstance(h.formatter, JSONFormatter)] assert len(json_handlers) == 1 assert root.level == logging.DEBUG diff --git a/tests/test_main.py b/tests/test_main.py index bacedc1..95b2c40 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -4,45 +4,45 @@ from datetime import UTC, date, datetime from unittest.mock import ANY, AsyncMock, MagicMock, patch import pytest -import src.main as main_module +import src.main as main_module from src.config import Settings from src.context.layer import ContextLayer from src.context.scheduler import ScheduleResult -from src.core.order_policy import OrderPolicyRejected +from src.core.order_policy import OrderPolicyRejected, get_session_info from src.core.risk_manager import CircuitBreakerTripped, FatFingerRejected from src.db import init_db, log_trade from src.evolution.scorecard import DailyScorecard from src.logging.decision_logger import DecisionLogger from src.main import ( - KILL_SWITCH, + _RUNTIME_EXIT_PEAKS, + _RUNTIME_EXIT_STATES, _SESSION_RISK_LAST_BY_MARKET, _SESSION_RISK_OVERRIDES_BY_MARKET, _SESSION_RISK_PROFILES_MAP, _STOPLOSS_REENTRY_COOLDOWN_UNTIL, + KILL_SWITCH, + _apply_dashboard_flag, _apply_staged_exit_override_for_hold, _compute_kr_atr_value, - _estimate_pred_down_prob_from_rsi, - _inject_staged_exit_features, - _RUNTIME_EXIT_PEAKS, - _RUNTIME_EXIT_STATES, - _should_force_exit_for_overnight, - _should_block_overseas_buy_for_fx_buffer, - _trigger_emergency_kill_switch, - _apply_dashboard_flag, + _compute_kr_dynamic_stop_loss_pct, _determine_order_quantity, + _estimate_pred_down_prob_from_rsi, _extract_avg_price_from_balance, _extract_held_codes_from_balance, _extract_held_qty_from_balance, _handle_market_close, - _retry_connection, + _inject_staged_exit_features, _resolve_market_setting, _resolve_sell_qty_for_pnl, + _retry_connection, _run_context_scheduler, _run_evolution_loop, + _should_block_overseas_buy_for_fx_buffer, + _should_force_exit_for_overnight, _start_dashboard_server, _stoploss_cooldown_minutes, - _compute_kr_dynamic_stop_loss_pct, + _trigger_emergency_kill_switch, handle_domestic_pending_orders, handle_overseas_pending_orders, process_blackout_recovery_orders, @@ -336,10 +336,7 @@ async def test_inject_staged_exit_features_sets_pred_down_prob_and_atr_for_kr() broker = MagicMock() broker.get_daily_prices = AsyncMock( - return_value=[ - {"high": 102.0 + i, "low": 98.0 + i, "close": 100.0 + i} - for i in range(40) - ] + return_value=[{"high": 102.0 + i, "low": 98.0 + i, "close": 100.0 + i} for i in range(40)] ) await _inject_staged_exit_features( @@ -483,9 +480,7 @@ class TestExtractHeldQtyFromBalance: def test_overseas_returns_ord_psbl_qty_first(self) -> None: """ord_psbl_qty (주문가능수량) takes priority over ovrs_cblc_qty.""" - balance = { - "output1": [{"ovrs_pdno": "AAPL", "ord_psbl_qty": "8", "ovrs_cblc_qty": "10"}] - } + balance = {"output1": [{"ovrs_pdno": "AAPL", "ord_psbl_qty": "8", "ovrs_cblc_qty": "10"}]} assert _extract_held_qty_from_balance(balance, "AAPL", is_domestic=False) == 8 def test_overseas_fallback_to_ovrs_cblc_qty_when_ord_psbl_qty_absent(self) -> None: @@ -809,9 +804,7 @@ class TestTradingCycleTelegramIntegration: def mock_criticality_assessor(self) -> MagicMock: """Create mock criticality assessor.""" assessor = MagicMock() - assessor.assess_market_conditions = MagicMock( - return_value=MagicMock(value="NORMAL") - ) + assessor.assess_market_conditions = MagicMock(return_value=MagicMock(value="NORMAL")) assessor.get_timeout = MagicMock(return_value=5.0) return assessor @@ -1199,9 +1192,7 @@ class TestOverseasBalanceParsing: def mock_overseas_broker_with_list(self) -> MagicMock: """Create mock overseas broker returning list format.""" broker = MagicMock() - broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "150.50"}} - ) + broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "150.50"}}) broker.get_overseas_balance = AsyncMock( return_value={ "output2": [ @@ -1221,9 +1212,7 @@ class TestOverseasBalanceParsing: def mock_overseas_broker_with_dict(self) -> MagicMock: """Create mock overseas broker returning dict format.""" broker = MagicMock() - broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "150.50"}} - ) + broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "150.50"}}) broker.get_overseas_balance = AsyncMock( return_value={ "output2": { @@ -1241,9 +1230,7 @@ class TestOverseasBalanceParsing: def mock_overseas_broker_with_empty(self) -> MagicMock: """Create mock overseas broker returning empty output2.""" broker = MagicMock() - broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "150.50"}} - ) + broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "150.50"}}) broker.get_overseas_balance = AsyncMock(return_value={"output2": []}) broker.get_overseas_buying_power = AsyncMock( return_value={"output": {"ovrs_ord_psbl_amt": "0.00"}} @@ -1327,9 +1314,7 @@ class TestOverseasBalanceParsing: def mock_criticality_assessor(self) -> MagicMock: """Create mock criticality assessor.""" assessor = MagicMock() - assessor.assess_market_conditions = MagicMock( - return_value=MagicMock(value="NORMAL") - ) + assessor.assess_market_conditions = MagicMock(return_value=MagicMock(value="NORMAL")) assessor.get_timeout = MagicMock(return_value=5.0) return assessor @@ -1492,9 +1477,7 @@ class TestOverseasBalanceParsing: def mock_overseas_broker_with_buy_scenario(self) -> MagicMock: """Create mock overseas broker that returns a valid price for BUY orders.""" broker = MagicMock() - broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "182.50"}} - ) + broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "182.50"}}) broker.get_overseas_balance = AsyncMock( return_value={ "output2": [ @@ -1615,9 +1598,7 @@ class TestOverseasBalanceParsing: overseas_broker.get_overseas_buying_power = AsyncMock( return_value={"output": {"ovrs_ord_psbl_amt": "50000.00"}} ) - overseas_broker.send_overseas_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + overseas_broker.send_overseas_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) sell_engine = MagicMock(spec=ScenarioEngine) sell_engine.evaluate = MagicMock(return_value=_make_sell_match("AAPL")) @@ -1709,8 +1690,10 @@ class TestOverseasBalanceParsing: ) overseas_broker.send_overseas_order.assert_called_once() - sent_price = overseas_broker.send_overseas_order.call_args[1].get("price") or \ - overseas_broker.send_overseas_order.call_args[0][4] + sent_price = ( + overseas_broker.send_overseas_order.call_args[1].get("price") + or overseas_broker.send_overseas_order.call_args[0][4] + ) # 50.1234 * 1.002 = 50.2235... rounded to 2 decimals = 50.22 assert sent_price == round(50.1234 * 1.002, 2), ( f"Expected 2-decimal price {round(50.1234 * 1.002, 2)} but got {sent_price} (#252)" @@ -1753,25 +1736,33 @@ class TestOverseasBalanceParsing: engine = MagicMock(spec=ScenarioEngine) engine.evaluate = MagicMock(return_value=_make_buy_match()) - await trading_cycle( - broker=mock_domestic_broker, - overseas_broker=overseas_broker, - scenario_engine=engine, - playbook=mock_playbook, - risk=mock_risk, - db_conn=db_conn, - decision_logger=decision_logger, - context_store=mock_context_store, - criticality_assessor=mock_criticality_assessor, - telegram=mock_telegram, - market=mock_overseas_market, - stock_code="PENNYX", - scan_candidates={}, - ) + with patch( + "src.main._resolve_market_setting", + side_effect=lambda **kwargs: ( + 0.1 if kwargs.get("key") == "US_MIN_PRICE" else kwargs.get("default") + ), + ): + await trading_cycle( + broker=mock_domestic_broker, + overseas_broker=overseas_broker, + scenario_engine=engine, + playbook=mock_playbook, + risk=mock_risk, + db_conn=db_conn, + decision_logger=decision_logger, + context_store=mock_context_store, + criticality_assessor=mock_criticality_assessor, + telegram=mock_telegram, + market=mock_overseas_market, + stock_code="PENNYX", + scan_candidates={}, + ) overseas_broker.send_overseas_order.assert_called_once() - sent_price = overseas_broker.send_overseas_order.call_args[1].get("price") or \ - overseas_broker.send_overseas_order.call_args[0][4] + sent_price = ( + overseas_broker.send_overseas_order.call_args[1].get("price") + or overseas_broker.send_overseas_order.call_args[0][4] + ) # 0.5678 * 1.002 = 0.56893... rounded to 4 decimals = 0.5689 assert sent_price == round(0.5678 * 1.002, 4), ( f"Expected 4-decimal price {round(0.5678 * 1.002, 4)} but got {sent_price} (#252)" @@ -1821,7 +1812,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_scenario_engine_called_with_enriched_market_data( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test scenario engine receives market_data enriched with scanner metrics.""" from src.analysis.smart_scanner import ScanCandidate @@ -1831,9 +1825,14 @@ class TestScenarioEngineIntegration: playbook = _make_playbook() candidate = ScanCandidate( - stock_code="005930", name="Samsung", price=50000, - volume=1000000, volume_ratio=3.5, rsi=25.0, - signal="oversold", score=85.0, + stock_code="005930", + name="Samsung", + price=50000, + volume=1000000, + volume_ratio=3.5, + rsi=25.0, + signal="oversold", + score=85.0, ) with ( @@ -1877,7 +1876,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_trading_cycle_sets_l7_context_keys( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test L7 context is written with market-scoped keys.""" from src.analysis.smart_scanner import ScanCandidate @@ -1888,9 +1890,14 @@ class TestScenarioEngineIntegration: context_store = MagicMock(get_latest_timeframe=MagicMock(return_value=None)) candidate = ScanCandidate( - stock_code="005930", name="Samsung", price=50000, - volume=1000000, volume_ratio=3.5, rsi=25.0, - signal="oversold", score=85.0, + stock_code="005930", + name="Samsung", + price=50000, + volume=1000000, + volume_ratio=3.5, + rsi=25.0, + signal="oversold", + score=85.0, ) with patch("src.main.log_trade"): @@ -1940,7 +1947,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_scan_candidates_market_scoped( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test scan_candidates uses market-scoped lookup, ignoring other markets.""" from src.analysis.smart_scanner import ScanCandidate @@ -1950,9 +1960,14 @@ class TestScenarioEngineIntegration: # Candidate stored under US market — should NOT be found for KR market us_candidate = ScanCandidate( - stock_code="005930", name="Overlap", price=100, - volume=500000, volume_ratio=5.0, rsi=15.0, - signal="oversold", score=90.0, + stock_code="005930", + name="Overlap", + price=100, + volume=500000, + volume_ratio=5.0, + rsi=15.0, + signal="oversold", + score=90.0, ) with patch("src.main.log_trade"): @@ -1982,7 +1997,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_scenario_engine_called_without_scanner_data( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test scenario engine works when stock has no scan candidate.""" engine = MagicMock(spec=ScenarioEngine) @@ -2020,7 +2038,9 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_holding_overseas_stock_derives_volume_ratio_from_price_api( - self, mock_broker: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test overseas holding stocks derive volume_ratio from get_overseas_price high/low.""" engine = MagicMock(spec=ScenarioEngine) @@ -2035,15 +2055,17 @@ class TestScenarioEngineIntegration: os_broker = MagicMock() # price_change_pct=5.0, high=106, low=94 → intraday_range=12% → volume_ratio=max(1,6)=6 - os_broker.get_overseas_price = AsyncMock(return_value={ - "output": {"last": "100.0", "rate": "5.0", "high": "106.0", "low": "94.0"} - }) - os_broker.get_overseas_balance = AsyncMock(return_value={ - "output2": [{"frcr_evlu_tota": "10000", "frcr_buy_amt_smtl": "9000"}] - }) - os_broker.get_overseas_buying_power = AsyncMock(return_value={ - "output": {"ovrs_ord_psbl_amt": "500"} - }) + os_broker.get_overseas_price = AsyncMock( + return_value={ + "output": {"last": "100.0", "rate": "5.0", "high": "106.0", "low": "94.0"} + } + ) + os_broker.get_overseas_balance = AsyncMock( + return_value={"output2": [{"frcr_evlu_tota": "10000", "frcr_buy_amt_smtl": "9000"}]} + ) + os_broker.get_overseas_buying_power = AsyncMock( + return_value={"output": {"ovrs_ord_psbl_amt": "500"}} + ) with patch("src.main.log_trade"): await trading_cycle( @@ -2075,7 +2097,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_scenario_matched_notification_sent( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test telegram notification sent when a scenario matches.""" # Create a match with matched_scenario (not None) @@ -2125,7 +2150,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_no_scenario_matched_notification_on_default_hold( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test no scenario notification when default HOLD is returned.""" engine = MagicMock(spec=ScenarioEngine) @@ -2156,7 +2184,10 @@ class TestScenarioEngineIntegration: @pytest.mark.asyncio async def test_decision_logger_receives_scenario_match_details( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test decision logger context includes scenario match details.""" match = ScenarioMatch( @@ -2193,13 +2224,16 @@ class TestScenarioEngineIntegration: decision_logger.log_decision.assert_called_once() call_kwargs = decision_logger.log_decision.call_args.kwargs - assert call_kwargs["session_id"] == "KRX_REG" + assert call_kwargs["session_id"] == get_session_info(mock_market).session_id assert "scenario_match" in call_kwargs["context_snapshot"] assert call_kwargs["context_snapshot"]["scenario_match"]["rsi"] == 45.0 @pytest.mark.asyncio async def test_reduce_all_does_not_execute_order( - self, mock_broker: MagicMock, mock_market: MagicMock, mock_telegram: MagicMock, + self, + mock_broker: MagicMock, + mock_market: MagicMock, + mock_telegram: MagicMock, ) -> None: """Test REDUCE_ALL action does not trigger order execution.""" match = ScenarioMatch( @@ -2340,7 +2374,9 @@ async def test_stoploss_reentry_cooldown_blocks_buy_when_active() -> None: broker.get_balance = AsyncMock( return_value={ "output1": [], - "output2": [{"tot_evlu_amt": "100000", "dnca_tot_amt": "50000", "pchs_amt_smtl_amt": "50000"}], + "output2": [ + {"tot_evlu_amt": "100000", "dnca_tot_amt": "50000", "pchs_amt_smtl_amt": "50000"} + ], } ) broker.send_order = AsyncMock(return_value={"msg1": "OK"}) @@ -2359,7 +2395,9 @@ async def test_stoploss_reentry_cooldown_blocks_buy_when_active() -> None: risk=MagicMock(validate_order=MagicMock(), check_circuit_breaker=MagicMock()), db_conn=db_conn, decision_logger=DecisionLogger(db_conn), - context_store=MagicMock(get_latest_timeframe=MagicMock(return_value=None), set_context=MagicMock()), + context_store=MagicMock( + get_latest_timeframe=MagicMock(return_value=None), set_context=MagicMock() + ), criticality_assessor=MagicMock( assess_market_conditions=MagicMock(return_value=MagicMock(value="NORMAL")), get_timeout=MagicMock(return_value=5.0), @@ -2389,7 +2427,9 @@ async def test_stoploss_reentry_cooldown_allows_buy_after_expiry() -> None: broker.get_balance = AsyncMock( return_value={ "output1": [], - "output2": [{"tot_evlu_amt": "100000", "dnca_tot_amt": "50000", "pchs_amt_smtl_amt": "50000"}], + "output2": [ + {"tot_evlu_amt": "100000", "dnca_tot_amt": "50000", "pchs_amt_smtl_amt": "50000"} + ], } ) broker.send_order = AsyncMock(return_value={"msg1": "OK"}) @@ -2408,7 +2448,9 @@ async def test_stoploss_reentry_cooldown_allows_buy_after_expiry() -> None: risk=MagicMock(validate_order=MagicMock(), check_circuit_breaker=MagicMock()), db_conn=db_conn, decision_logger=DecisionLogger(db_conn), - context_store=MagicMock(get_latest_timeframe=MagicMock(return_value=None), set_context=MagicMock()), + context_store=MagicMock( + get_latest_timeframe=MagicMock(return_value=None), set_context=MagicMock() + ), criticality_assessor=MagicMock( assess_market_conditions=MagicMock(return_value=MagicMock(value="NORMAL")), get_timeout=MagicMock(return_value=5.0), @@ -3419,6 +3461,7 @@ def test_start_dashboard_server_returns_none_when_uvicorn_missing() -> None: DASHBOARD_ENABLED=True, ) import builtins + real_import = builtins.__import__ def mock_import(name: str, *args: object, **kwargs: object) -> object: @@ -3446,8 +3489,13 @@ class TestBuyCooldown: broker.get_current_price = AsyncMock(return_value=(100.0, 1.0, 0.0)) broker.get_balance = AsyncMock( return_value={ - "output2": [{"tot_evlu_amt": "1000000", "dnca_tot_amt": "500000", - "pchs_amt_smtl_amt": "500000"}] + "output2": [ + { + "tot_evlu_amt": "1000000", + "dnca_tot_amt": "500000", + "pchs_amt_smtl_amt": "500000", + } + ] } ) broker.send_order = AsyncMock(return_value={"msg1": "OK"}) @@ -3475,13 +3523,22 @@ class TestBuyCooldown: def mock_overseas_broker(self) -> MagicMock: broker = MagicMock() broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "1.0", "rate": "0.0", - "high": "1.05", "low": "0.95", "tvol": "1000000"}} + return_value={ + "output": { + "last": "1.0", + "rate": "0.0", + "high": "1.05", + "low": "0.95", + "tvol": "1000000", + } + } + ) + broker.get_overseas_balance = AsyncMock( + return_value={ + "output1": [], + "output2": [{"frcr_evlu_tota": "50000", "frcr_buy_amt_smtl": "0"}], + } ) - broker.get_overseas_balance = AsyncMock(return_value={ - "output1": [], - "output2": [{"frcr_evlu_tota": "50000", "frcr_buy_amt_smtl": "0"}], - }) broker.get_overseas_buying_power = AsyncMock( return_value={"output": {"ovrs_ord_psbl_amt": "50000"}} ) @@ -3501,7 +3558,9 @@ class TestBuyCooldown: @pytest.mark.asyncio async def test_cooldown_set_on_insufficient_balance( - self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, + self, + mock_broker: MagicMock, + mock_overseas_broker: MagicMock, mock_overseas_market: MagicMock, ) -> None: """BUY cooldown entry is created after 주문가능금액 rejection.""" @@ -3509,7 +3568,12 @@ class TestBuyCooldown: engine.evaluate = MagicMock(return_value=self._make_buy_match_overseas("MLECW")) buy_cooldown: dict[str, float] = {} - with patch("src.main.log_trade"): + with patch("src.main.log_trade"), patch( + "src.main._resolve_market_setting", + side_effect=lambda **kwargs: ( + 0.1 if kwargs.get("key") == "US_MIN_PRICE" else kwargs.get("default") + ), + ): await trading_cycle( broker=mock_broker, overseas_broker=mock_overseas_broker, @@ -3540,7 +3604,9 @@ class TestBuyCooldown: @pytest.mark.asyncio async def test_cooldown_skips_buy( - self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, + self, + mock_broker: MagicMock, + mock_overseas_broker: MagicMock, mock_overseas_market: MagicMock, ) -> None: """BUY is skipped when cooldown is active for the stock.""" @@ -3548,10 +3614,9 @@ class TestBuyCooldown: engine.evaluate = MagicMock(return_value=self._make_buy_match_overseas("MLECW")) import asyncio + # Set an active cooldown (expires far in the future) - buy_cooldown: dict[str, float] = { - "US_NASDAQ:MLECW": asyncio.get_event_loop().time() + 600 - } + buy_cooldown: dict[str, float] = {"US_NASDAQ:MLECW": asyncio.get_event_loop().time() + 600} with patch("src.main.log_trade"): await trading_cycle( @@ -3584,7 +3649,9 @@ class TestBuyCooldown: @pytest.mark.asyncio async def test_cooldown_not_set_on_other_errors( - self, mock_broker: MagicMock, mock_overseas_market: MagicMock, + self, + mock_broker: MagicMock, + mock_overseas_market: MagicMock, ) -> None: """Cooldown is NOT set for non-balance-related rejections.""" engine = MagicMock(spec=ScenarioEngine) @@ -3592,13 +3659,22 @@ class TestBuyCooldown: # Different rejection reason overseas_broker = MagicMock() overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "1.0", "rate": "0.0", - "high": "1.05", "low": "0.95", "tvol": "1000000"}} + return_value={ + "output": { + "last": "1.0", + "rate": "0.0", + "high": "1.05", + "low": "0.95", + "tvol": "1000000", + } + } + ) + overseas_broker.get_overseas_balance = AsyncMock( + return_value={ + "output1": [], + "output2": [{"frcr_evlu_tota": "50000", "frcr_buy_amt_smtl": "0"}], + } ) - overseas_broker.get_overseas_balance = AsyncMock(return_value={ - "output1": [], - "output2": [{"frcr_evlu_tota": "50000", "frcr_buy_amt_smtl": "0"}], - }) overseas_broker.get_overseas_buying_power = AsyncMock( return_value={"output": {"ovrs_ord_psbl_amt": "50000"}} ) @@ -3638,14 +3714,21 @@ class TestBuyCooldown: @pytest.mark.asyncio async def test_no_cooldown_param_still_works( - self, mock_broker: MagicMock, mock_overseas_broker: MagicMock, + self, + mock_broker: MagicMock, + mock_overseas_broker: MagicMock, mock_overseas_market: MagicMock, ) -> None: """trading_cycle works normally when buy_cooldown is None (default).""" engine = MagicMock(spec=ScenarioEngine) engine.evaluate = MagicMock(return_value=self._make_buy_match_overseas("MLECW")) - with patch("src.main.log_trade"): + with patch("src.main.log_trade"), patch( + "src.main._resolve_market_setting", + side_effect=lambda **kwargs: ( + 0.1 if kwargs.get("key") == "US_MIN_PRICE" else kwargs.get("default") + ), + ): await trading_cycle( broker=mock_broker, overseas_broker=mock_overseas_broker, @@ -3722,6 +3805,7 @@ class TestMarketOutlookConfidenceThreshold: self, confidence: int, stock_code: str = "005930" ) -> ScenarioMatch: from src.strategy.models import StockScenario + scenario = StockScenario( condition=StockCondition(rsi_below=30), action=ScenarioAction.BUY, @@ -3736,10 +3820,9 @@ class TestMarketOutlookConfidenceThreshold: rationale="Test buy", ) - def _make_playbook_with_outlook( - self, outlook_str: str, market: str = "KR" - ) -> DayPlaybook: + def _make_playbook_with_outlook(self, outlook_str: str, market: str = "KR") -> DayPlaybook: from src.strategy.models import MarketOutlook + outlook_map = { "bearish": MarketOutlook.BEARISH, "bullish": MarketOutlook.BULLISH, @@ -3991,7 +4074,15 @@ async def test_buy_suppressed_when_open_position_exists() -> None: overseas_broker = MagicMock() overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "51.0", "rate": "2.0", "high": "52.0", "low": "50.0", "tvol": "1000000"}} + return_value={ + "output": { + "last": "51.0", + "rate": "2.0", + "high": "52.0", + "low": "50.0", + "tvol": "1000000", + } + } ) overseas_broker.get_overseas_balance = AsyncMock( return_value={ @@ -4058,7 +4149,15 @@ async def test_buy_proceeds_when_no_open_position() -> None: overseas_broker = MagicMock() overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "100.0", "rate": "1.0", "high": "101.0", "low": "99.0", "tvol": "500000"}} + return_value={ + "output": { + "last": "100.0", + "rate": "1.0", + "high": "101.0", + "low": "99.0", + "tvol": "500000", + } + } ) overseas_broker.get_overseas_balance = AsyncMock( return_value={ @@ -4160,9 +4259,7 @@ class TestOverseasBrokerIntegration: ) overseas_broker = MagicMock() - overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "182.50"}} - ) + overseas_broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "182.50"}}) # 브로커: 여전히 AAPL 10주 보유 중 (SELL 미체결) overseas_broker.get_overseas_balance = AsyncMock( return_value={ @@ -4236,9 +4333,7 @@ class TestOverseasBrokerIntegration: # DB: 레코드 없음 (신규 포지션) overseas_broker = MagicMock() - overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "182.50"}} - ) + overseas_broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "182.50"}}) # 브로커: AAPL 미보유 overseas_broker.get_overseas_balance = AsyncMock( return_value={ @@ -4306,9 +4401,7 @@ class TestOverseasBrokerIntegration: db_conn = init_db(":memory:") overseas_broker = MagicMock() - overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "182.50"}} - ) + overseas_broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "182.50"}}) overseas_broker.get_overseas_balance = AsyncMock( return_value={ "output1": [], @@ -4387,6 +4480,7 @@ class TestRetryConnection: @pytest.mark.asyncio async def test_success_on_first_attempt(self) -> None: """Returns the result immediately when the first call succeeds.""" + async def ok() -> str: return "data" @@ -4596,9 +4690,7 @@ class TestDailyCBBaseline: return_value=self._make_domestic_balance(tot_evlu_amt=55000.0) ) # Price data for the stock - broker.get_current_price = AsyncMock( - return_value=(100.0, 1.5, 100.0) - ) + broker.get_current_price = AsyncMock(return_value=(100.0, 1.5, 100.0)) market = MagicMock() market.name = "KR" @@ -4643,8 +4735,10 @@ class TestDailyCBBaseline: async def _passthrough(fn, *a, label: str = "", **kw): # type: ignore[override] return await fn(*a, **kw) - with patch("src.main.get_open_markets", return_value=[market]), \ - patch("src.main._retry_connection", new=_passthrough): + with ( + patch("src.main.get_open_markets", return_value=[market]), + patch("src.main._retry_connection", new=_passthrough), + ): result = await run_daily_session( broker=broker, overseas_broker=MagicMock(), @@ -4720,8 +4814,10 @@ class TestDailyCBBaseline: async def _passthrough(fn, *a, label: str = "", **kw): # type: ignore[override] return await fn(*a, **kw) - with patch("src.main.get_open_markets", return_value=[market]), \ - patch("src.main._retry_connection", new=_passthrough): + with ( + patch("src.main.get_open_markets", return_value=[market]), + patch("src.main._retry_connection", new=_passthrough), + ): result = await run_daily_session( broker=broker, overseas_broker=MagicMock(), @@ -4844,8 +4940,10 @@ async def test_run_daily_session_applies_staged_exit_override_on_hold() -> None: async def _passthrough(fn, *a, label: str = "", **kw): # type: ignore[override] return await fn(*a, **kw) - with patch("src.main.get_open_markets", return_value=[market]), \ - patch("src.main._retry_connection", new=_passthrough): + with ( + patch("src.main.get_open_markets", return_value=[market]), + patch("src.main._retry_connection", new=_passthrough), + ): await run_daily_session( broker=broker, overseas_broker=MagicMock(), @@ -5032,17 +5130,14 @@ class TestSyncPositionsFromBroker: db_conn = init_db(":memory:") broker = MagicMock() - broker.get_balance = AsyncMock( - return_value=self._domestic_balance("005930", qty=7) - ) + broker.get_balance = AsyncMock(return_value=self._domestic_balance("005930", qty=7)) overseas_broker = MagicMock() - synced = await sync_positions_from_broker( - broker, overseas_broker, db_conn, settings - ) + synced = await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) assert synced == 1 from src.db import get_open_position + pos = get_open_position(db_conn, "005930", "KR") assert pos is not None assert pos["quantity"] == 7 @@ -5066,14 +5161,10 @@ class TestSyncPositionsFromBroker: ) broker = MagicMock() - broker.get_balance = AsyncMock( - return_value=self._domestic_balance("005930", qty=5) - ) + broker.get_balance = AsyncMock(return_value=self._domestic_balance("005930", qty=5)) overseas_broker = MagicMock() - synced = await sync_positions_from_broker( - broker, overseas_broker, db_conn, settings - ) + synced = await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) assert synced == 0 @@ -5089,12 +5180,11 @@ class TestSyncPositionsFromBroker: return_value=self._overseas_balance("AAPL", qty=10) ) - synced = await sync_positions_from_broker( - broker, overseas_broker, db_conn, settings - ) + synced = await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) assert synced == 1 from src.db import get_open_position + pos = get_open_position(db_conn, "AAPL", "US_NASDAQ") assert pos is not None assert pos["quantity"] == 10 @@ -5106,14 +5196,10 @@ class TestSyncPositionsFromBroker: db_conn = init_db(":memory:") broker = MagicMock() - broker.get_balance = AsyncMock( - return_value={"output1": [], "output2": [{}]} - ) + broker.get_balance = AsyncMock(return_value={"output1": [], "output2": [{}]}) overseas_broker = MagicMock() - synced = await sync_positions_from_broker( - broker, overseas_broker, db_conn, settings - ) + synced = await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) assert synced == 0 @@ -5124,14 +5210,10 @@ class TestSyncPositionsFromBroker: db_conn = init_db(":memory:") broker = MagicMock() - broker.get_balance = AsyncMock( - side_effect=ConnectionError("KIS unreachable") - ) + broker.get_balance = AsyncMock(side_effect=ConnectionError("KIS unreachable")) overseas_broker = MagicMock() - synced = await sync_positions_from_broker( - broker, overseas_broker, db_conn, settings - ) + synced = await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) assert synced == 0 # Failure treated as no-op @@ -5151,9 +5233,7 @@ class TestSyncPositionsFromBroker: return_value={"output1": [], "output2": [{}]} ) - await sync_positions_from_broker( - broker, overseas_broker, db_conn, settings - ) + await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) # Two distinct exchange codes (NASD, NYSE) → 2 calls assert overseas_broker.get_overseas_balance.call_count == 2 @@ -5166,7 +5246,9 @@ class TestSyncPositionsFromBroker: balance = { "output1": [{"pdno": "005930", "ord_psbl_qty": "5", "pchs_avg_pric": "68000.0"}], - "output2": [{"tot_evlu_amt": "1000000", "dnca_tot_amt": "500000", "pchs_amt_smtl_amt": "500000"}], + "output2": [ + {"tot_evlu_amt": "1000000", "dnca_tot_amt": "500000", "pchs_amt_smtl_amt": "500000"} + ], } broker = MagicMock() broker.get_balance = AsyncMock(return_value=balance) @@ -5175,6 +5257,7 @@ class TestSyncPositionsFromBroker: await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) from src.db import get_open_position + pos = get_open_position(db_conn, "005930", "KR") assert pos is not None assert pos["price"] == 68000.0 @@ -5196,6 +5279,7 @@ class TestSyncPositionsFromBroker: await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) from src.db import get_open_position + pos = get_open_position(db_conn, "AAPL", "US_NASDAQ") assert pos is not None assert pos["price"] == 170.0 @@ -5209,7 +5293,9 @@ class TestSyncPositionsFromBroker: # No pchs_avg_pric in output1 balance = { "output1": [{"pdno": "005930", "ord_psbl_qty": "5"}], - "output2": [{"tot_evlu_amt": "1000000", "dnca_tot_amt": "500000", "pchs_amt_smtl_amt": "500000"}], + "output2": [ + {"tot_evlu_amt": "1000000", "dnca_tot_amt": "500000", "pchs_amt_smtl_amt": "500000"} + ], } broker = MagicMock() broker.get_balance = AsyncMock(return_value=balance) @@ -5218,6 +5304,7 @@ class TestSyncPositionsFromBroker: await sync_positions_from_broker(broker, overseas_broker, db_conn, settings) from src.db import get_open_position + pos = get_open_position(db_conn, "005930", "KR") assert pos is not None assert pos["price"] == 0.0 @@ -5345,12 +5432,8 @@ class TestHandleOverseasPendingOrders: "ovrs_excg_cd": "NASD", } overseas_broker = MagicMock() - overseas_broker.get_overseas_pending_orders = AsyncMock( - return_value=[pending_order] - ) - overseas_broker.cancel_overseas_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + overseas_broker.get_overseas_pending_orders = AsyncMock(return_value=[pending_order]) + overseas_broker.cancel_overseas_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) sell_resubmit_counts: dict[str, int] = {} buy_cooldown: dict[str, float] = {} @@ -5385,18 +5468,10 @@ class TestHandleOverseasPendingOrders: "ovrs_excg_cd": "NASD", } overseas_broker = MagicMock() - overseas_broker.get_overseas_pending_orders = AsyncMock( - return_value=[pending_order] - ) - overseas_broker.cancel_overseas_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) - overseas_broker.get_overseas_price = AsyncMock( - return_value={"output": {"last": "200.0"}} - ) - overseas_broker.send_overseas_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + overseas_broker.get_overseas_pending_orders = AsyncMock(return_value=[pending_order]) + overseas_broker.cancel_overseas_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) + overseas_broker.get_overseas_price = AsyncMock(return_value={"output": {"last": "200.0"}}) + overseas_broker.send_overseas_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) sell_resubmit_counts: dict[str, int] = {} @@ -5427,9 +5502,7 @@ class TestHandleOverseasPendingOrders: "ovrs_excg_cd": "NASD", } overseas_broker = MagicMock() - overseas_broker.get_overseas_pending_orders = AsyncMock( - return_value=[pending_order] - ) + overseas_broker.get_overseas_pending_orders = AsyncMock(return_value=[pending_order]) overseas_broker.cancel_overseas_order = AsyncMock( return_value={"rt_cd": "1", "msg1": "Error"} # failure ) @@ -5458,12 +5531,8 @@ class TestHandleOverseasPendingOrders: "ovrs_excg_cd": "NASD", } overseas_broker = MagicMock() - overseas_broker.get_overseas_pending_orders = AsyncMock( - return_value=[pending_order] - ) - overseas_broker.cancel_overseas_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + overseas_broker.get_overseas_pending_orders = AsyncMock(return_value=[pending_order]) + overseas_broker.cancel_overseas_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) overseas_broker.send_overseas_order = AsyncMock() # Already resubmitted once @@ -5536,9 +5605,7 @@ class TestHandleDomesticPendingOrders: } broker = MagicMock() broker.get_domestic_pending_orders = AsyncMock(return_value=[pending_order]) - broker.cancel_domestic_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + broker.cancel_domestic_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) sell_resubmit_counts: dict[str, int] = {} buy_cooldown: dict[str, float] = {} @@ -5577,17 +5644,13 @@ class TestHandleDomesticPendingOrders: } broker = MagicMock() broker.get_domestic_pending_orders = AsyncMock(return_value=[pending_order]) - broker.cancel_domestic_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + broker.cancel_domestic_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) broker.get_current_price = AsyncMock(return_value=(50000.0, 0.0, 0.0)) broker.send_order = AsyncMock(return_value={"rt_cd": "0"}) sell_resubmit_counts: dict[str, int] = {} - await handle_domestic_pending_orders( - broker, telegram, settings, sell_resubmit_counts - ) + await handle_domestic_pending_orders(broker, telegram, settings, sell_resubmit_counts) broker.cancel_domestic_order.assert_called_once() broker.send_order.assert_called_once() @@ -5621,9 +5684,7 @@ class TestHandleDomesticPendingOrders: sell_resubmit_counts: dict[str, int] = {} - await handle_domestic_pending_orders( - broker, telegram, settings, sell_resubmit_counts - ) + await handle_domestic_pending_orders(broker, telegram, settings, sell_resubmit_counts) broker.send_order.assert_not_called() telegram.notify_unfilled_order.assert_not_called() @@ -5643,17 +5704,13 @@ class TestHandleDomesticPendingOrders: } broker = MagicMock() broker.get_domestic_pending_orders = AsyncMock(return_value=[pending_order]) - broker.cancel_domestic_order = AsyncMock( - return_value={"rt_cd": "0", "msg1": "OK"} - ) + broker.cancel_domestic_order = AsyncMock(return_value={"rt_cd": "0", "msg1": "OK"}) broker.send_order = AsyncMock() # Already resubmitted once sell_resubmit_counts: dict[str, int] = {"KR:005930": 1} - await handle_domestic_pending_orders( - broker, telegram, settings, sell_resubmit_counts - ) + await handle_domestic_pending_orders(broker, telegram, settings, sell_resubmit_counts) broker.cancel_domestic_order.assert_called_once() broker.send_order.assert_not_called() @@ -5867,9 +5924,7 @@ class TestOverseasGhostPositionClose: current_price = 1.5 # ord_psbl_qty=5 means the code passes the qty check and a SELL is sent balance_data = { - "output1": [ - {"ovrs_pdno": stock_code, "ord_psbl_qty": "5", "ovrs_cblc_qty": "5"} - ], + "output1": [{"ovrs_pdno": stock_code, "ord_psbl_qty": "5", "ovrs_cblc_qty": "5"}], "output2": [{"tot_evlu_amt": "10000"}], } sell_result = {"rt_cd": "1", "msg1": "모의투자 잔고내역이 없습니다"} @@ -5905,9 +5960,11 @@ class TestOverseasGhostPositionClose: settings.POSITION_SIZING_ENABLED = False settings.PAPER_OVERSEAS_CASH = 0 - with patch("src.main.log_trade") as mock_log_trade, patch( - "src.main.get_open_position", return_value=None - ), patch("src.main.get_latest_buy_trade", return_value=None): + with ( + patch("src.main.log_trade") as mock_log_trade, + patch("src.main.get_open_position", return_value=None), + patch("src.main.get_latest_buy_trade", return_value=None), + ): await trading_cycle( broker=domestic_broker, overseas_broker=overseas_broker, @@ -5976,8 +6033,9 @@ class TestOverseasGhostPositionClose: db_conn = MagicMock() - with patch("src.main.log_trade") as mock_log_trade, patch( - "src.main.get_open_position", return_value=None + with ( + patch("src.main.log_trade") as mock_log_trade, + patch("src.main.get_open_position", return_value=None), ): await trading_cycle( broker=domestic_broker, @@ -6168,7 +6226,10 @@ async def test_us_min_price_filter_boundary(price: float, should_block: bool) -> return_value={"output": {"last": str(price), "rate": "0.0"}} ) overseas_broker.get_overseas_balance = AsyncMock( - return_value={"output1": [], "output2": [{"frcr_evlu_tota": "10000", "frcr_buy_amt_smtl": "0"}]} + return_value={ + "output1": [], + "output2": [{"frcr_evlu_tota": "10000", "frcr_buy_amt_smtl": "0"}], + } ) overseas_broker.get_overseas_buying_power = AsyncMock( return_value={"output": {"ovrs_ord_psbl_amt": "10000"}} diff --git a/tests/test_market_schedule.py b/tests/test_market_schedule.py index 49110bc..8723c2f 100644 --- a/tests/test_market_schedule.py +++ b/tests/test_market_schedule.py @@ -173,9 +173,7 @@ class TestGetNextMarketOpen: """Should find next Monday opening when called on weekend.""" # Saturday 2026-02-07 12:00 UTC test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) - market, open_time = get_next_market_open( - enabled_markets=["KR"], now=test_time - ) + market, open_time = get_next_market_open(enabled_markets=["KR"], now=test_time) assert market.code == "KR" # Monday 2026-02-09 09:00 KST expected = datetime(2026, 2, 9, 9, 0, tzinfo=ZoneInfo("Asia/Seoul")) @@ -185,9 +183,7 @@ class TestGetNextMarketOpen: """Should find next day opening when called after market close.""" # Monday 2026-02-02 16:00 KST (after close) test_time = datetime(2026, 2, 2, 16, 0, tzinfo=ZoneInfo("Asia/Seoul")) - market, open_time = get_next_market_open( - enabled_markets=["KR"], now=test_time - ) + market, open_time = get_next_market_open(enabled_markets=["KR"], now=test_time) assert market.code == "KR" # Tuesday 2026-02-03 09:00 KST expected = datetime(2026, 2, 3, 9, 0, tzinfo=ZoneInfo("Asia/Seoul")) @@ -197,9 +193,7 @@ class TestGetNextMarketOpen: """Should find earliest opening market among multiple.""" # Saturday 2026-02-07 12:00 UTC test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) - market, open_time = get_next_market_open( - enabled_markets=["KR", "US_NASDAQ"], now=test_time - ) + market, open_time = get_next_market_open(enabled_markets=["KR", "US_NASDAQ"], now=test_time) # Monday 2026-02-09: KR opens at 09:00 KST = 00:00 UTC # Monday 2026-02-09: US opens at 09:30 EST = 14:30 UTC # KR opens first @@ -214,9 +208,7 @@ class TestGetNextMarketOpen: def test_get_next_market_open_invalid_market(self) -> None: """Should skip invalid market codes.""" test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) - market, _ = get_next_market_open( - enabled_markets=["INVALID", "KR"], now=test_time - ) + market, _ = get_next_market_open(enabled_markets=["INVALID", "KR"], now=test_time) assert market.code == "KR" def test_get_next_market_open_prefers_extended_session(self) -> None: diff --git a/tests/test_overseas_broker.py b/tests/test_overseas_broker.py index bd74cd9..6ac6f9b 100644 --- a/tests/test_overseas_broker.py +++ b/tests/test_overseas_broker.py @@ -8,7 +8,7 @@ import aiohttp import pytest from src.broker.kis_api import KISBroker -from src.broker.overseas import OverseasBroker, _PRICE_EXCHANGE_MAP, _RANKING_EXCHANGE_MAP +from src.broker.overseas import _PRICE_EXCHANGE_MAP, _RANKING_EXCHANGE_MAP, OverseasBroker from src.config import Settings @@ -85,25 +85,27 @@ class TestConfigDefaults: assert mock_settings.OVERSEAS_RANKING_VOLUME_TR_ID == "HHDFS76270000" def test_fluct_path(self, mock_settings: Settings) -> None: - assert mock_settings.OVERSEAS_RANKING_FLUCT_PATH == "/uapi/overseas-stock/v1/ranking/updown-rate" + assert ( + mock_settings.OVERSEAS_RANKING_FLUCT_PATH + == "/uapi/overseas-stock/v1/ranking/updown-rate" + ) def test_volume_path(self, mock_settings: Settings) -> None: - assert mock_settings.OVERSEAS_RANKING_VOLUME_PATH == "/uapi/overseas-stock/v1/ranking/volume-surge" + assert ( + mock_settings.OVERSEAS_RANKING_VOLUME_PATH + == "/uapi/overseas-stock/v1/ranking/volume-surge" + ) class TestFetchOverseasRankings: """Test fetch_overseas_rankings method.""" @pytest.mark.asyncio - async def test_fluctuation_uses_correct_params( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_fluctuation_uses_correct_params(self, overseas_broker: OverseasBroker) -> None: """Fluctuation ranking should use HHDFS76290000, updown-rate path, and correct params.""" mock_resp = AsyncMock() mock_resp.status = 200 - mock_resp.json = AsyncMock( - return_value={"output": [{"symb": "AAPL", "name": "Apple"}]} - ) + mock_resp.json = AsyncMock(return_value={"output": [{"symb": "AAPL", "name": "Apple"}]}) mock_session = MagicMock() mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp)) @@ -132,15 +134,11 @@ class TestFetchOverseasRankings: overseas_broker._broker._auth_headers.assert_called_with("HHDFS76290000") @pytest.mark.asyncio - async def test_volume_uses_correct_params( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_volume_uses_correct_params(self, overseas_broker: OverseasBroker) -> None: """Volume ranking should use HHDFS76270000, volume-surge path, and correct params.""" mock_resp = AsyncMock() mock_resp.status = 200 - mock_resp.json = AsyncMock( - return_value={"output": [{"symb": "TSLA", "name": "Tesla"}]} - ) + mock_resp.json = AsyncMock(return_value={"output": [{"symb": "TSLA", "name": "Tesla"}]}) mock_session = MagicMock() mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp)) @@ -169,9 +167,7 @@ class TestFetchOverseasRankings: overseas_broker._broker._auth_headers.assert_called_with("HHDFS76270000") @pytest.mark.asyncio - async def test_404_returns_empty_list( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_404_returns_empty_list(self, overseas_broker: OverseasBroker) -> None: """HTTP 404 should return empty list (fallback) instead of raising.""" mock_resp = AsyncMock() mock_resp.status = 404 @@ -186,9 +182,7 @@ class TestFetchOverseasRankings: assert result == [] @pytest.mark.asyncio - async def test_non_404_error_raises( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_non_404_error_raises(self, overseas_broker: OverseasBroker) -> None: """Non-404 HTTP errors should raise ConnectionError.""" mock_resp = AsyncMock() mock_resp.status = 500 @@ -203,9 +197,7 @@ class TestFetchOverseasRankings: await overseas_broker.fetch_overseas_rankings("NASD") @pytest.mark.asyncio - async def test_empty_response_returns_empty( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_empty_response_returns_empty(self, overseas_broker: OverseasBroker) -> None: """Empty output in response should return empty list.""" mock_resp = AsyncMock() mock_resp.status = 200 @@ -220,18 +212,14 @@ class TestFetchOverseasRankings: assert result == [] @pytest.mark.asyncio - async def test_ranking_disabled_returns_empty( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_ranking_disabled_returns_empty(self, overseas_broker: OverseasBroker) -> None: """When OVERSEAS_RANKING_ENABLED=False, should return empty immediately.""" overseas_broker._broker._settings.OVERSEAS_RANKING_ENABLED = False result = await overseas_broker.fetch_overseas_rankings("NASD") assert result == [] @pytest.mark.asyncio - async def test_limit_truncates_results( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_limit_truncates_results(self, overseas_broker: OverseasBroker) -> None: """Results should be truncated to the specified limit.""" rows = [{"symb": f"SYM{i}"} for i in range(20)] mock_resp = AsyncMock() @@ -247,9 +235,7 @@ class TestFetchOverseasRankings: assert len(result) == 5 @pytest.mark.asyncio - async def test_network_error_raises( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_network_error_raises(self, overseas_broker: OverseasBroker) -> None: """Network errors should raise ConnectionError.""" cm = MagicMock() cm.__aenter__ = AsyncMock(side_effect=aiohttp.ClientError("timeout")) @@ -264,9 +250,7 @@ class TestFetchOverseasRankings: await overseas_broker.fetch_overseas_rankings("NASD") @pytest.mark.asyncio - async def test_exchange_code_mapping_applied( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_exchange_code_mapping_applied(self, overseas_broker: OverseasBroker) -> None: """All major exchanges should use mapped codes in API params.""" for original, mapped in [("NASD", "NAS"), ("NYSE", "NYS"), ("AMEX", "AMS")]: mock_resp = AsyncMock() @@ -298,7 +282,9 @@ class TestGetOverseasPrice: mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp)) _setup_broker_mocks(overseas_broker, mock_session) - overseas_broker._broker._auth_headers = AsyncMock(return_value={"authorization": "Bearer t"}) + overseas_broker._broker._auth_headers = AsyncMock( + return_value={"authorization": "Bearer t"} + ) result = await overseas_broker.get_overseas_price("NASD", "AAPL") assert result["output"]["last"] == "150.00" @@ -530,11 +516,14 @@ class TestPriceExchangeMap: def test_price_map_equals_ranking_map(self) -> None: assert _PRICE_EXCHANGE_MAP is _RANKING_EXCHANGE_MAP - @pytest.mark.parametrize("original,expected", [ - ("NASD", "NAS"), - ("NYSE", "NYS"), - ("AMEX", "AMS"), - ]) + @pytest.mark.parametrize( + "original,expected", + [ + ("NASD", "NAS"), + ("NYSE", "NYS"), + ("AMEX", "AMS"), + ], + ) def test_us_exchange_code_mapping(self, original: str, expected: str) -> None: assert _PRICE_EXCHANGE_MAP[original] == expected @@ -574,9 +563,7 @@ class TestOrderRtCdCheck: return OverseasBroker(broker) @pytest.mark.asyncio - async def test_success_rt_cd_returns_data( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_success_rt_cd_returns_data(self, overseas_broker: OverseasBroker) -> None: """rt_cd='0' → order accepted, data returned.""" mock_resp = AsyncMock() mock_resp.status = 200 @@ -590,9 +577,7 @@ class TestOrderRtCdCheck: assert result["rt_cd"] == "0" @pytest.mark.asyncio - async def test_error_rt_cd_returns_data_with_msg( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_error_rt_cd_returns_data_with_msg(self, overseas_broker: OverseasBroker) -> None: """rt_cd != '0' → order rejected, data still returned (caller checks rt_cd).""" mock_resp = AsyncMock() mock_resp.status = 200 @@ -623,6 +608,7 @@ class TestPaperOverseasCash: def test_env_override(self) -> None: import os + os.environ["PAPER_OVERSEAS_CASH"] = "25000" settings = Settings( KIS_APP_KEY="k", @@ -635,6 +621,7 @@ class TestPaperOverseasCash: def test_zero_disables_fallback(self) -> None: import os + os.environ["PAPER_OVERSEAS_CASH"] = "0" settings = Settings( KIS_APP_KEY="k", @@ -822,9 +809,7 @@ class TestGetOverseasPendingOrders: """Tests for get_overseas_pending_orders method.""" @pytest.mark.asyncio - async def test_paper_mode_returns_empty( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_paper_mode_returns_empty(self, overseas_broker: OverseasBroker) -> None: """Paper mode should immediately return [] without any API call.""" # Default mock_settings has MODE="paper" overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( @@ -855,9 +840,7 @@ class TestGetOverseasPendingOrders: overseas_broker._broker._auth_headers = mock_auth_headers # type: ignore[method-assign] - pending_orders = [ - {"odno": "001", "pdno": "AAPL", "sll_buy_dvsn_cd": "02", "nccs_qty": "5"} - ] + pending_orders = [{"odno": "001", "pdno": "AAPL", "sll_buy_dvsn_cd": "02", "nccs_qty": "5"}] mock_resp = AsyncMock() mock_resp.status = 200 mock_resp.json = AsyncMock(return_value={"output": pending_orders}) @@ -879,9 +862,7 @@ class TestGetOverseasPendingOrders: assert captured_params[0]["OVRS_EXCG_CD"] == "NASD" @pytest.mark.asyncio - async def test_live_mode_connection_error( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_live_mode_connection_error(self, overseas_broker: OverseasBroker) -> None: """Network error in live mode should raise ConnectionError.""" overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( update={"MODE": "live"} @@ -926,55 +907,41 @@ class TestCancelOverseasOrder: return captured_tr_ids, mock_session @pytest.mark.asyncio - async def test_us_live_uses_tttt1004u( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_us_live_uses_tttt1004u(self, overseas_broker: OverseasBroker) -> None: """US exchange in live mode should use TTTT1004U.""" overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( update={"MODE": "live"} ) - captured, _ = self._setup_cancel_mocks( - overseas_broker, {"rt_cd": "0", "msg1": "OK"} - ) + captured, _ = self._setup_cancel_mocks(overseas_broker, {"rt_cd": "0", "msg1": "OK"}) await overseas_broker.cancel_overseas_order("NASD", "AAPL", "ORD001", 5) assert "TTTT1004U" in captured @pytest.mark.asyncio - async def test_us_paper_uses_vttt1004u( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_us_paper_uses_vttt1004u(self, overseas_broker: OverseasBroker) -> None: """US exchange in paper mode should use VTTT1004U.""" # Default mock_settings has MODE="paper" - captured, _ = self._setup_cancel_mocks( - overseas_broker, {"rt_cd": "0", "msg1": "OK"} - ) + captured, _ = self._setup_cancel_mocks(overseas_broker, {"rt_cd": "0", "msg1": "OK"}) await overseas_broker.cancel_overseas_order("NASD", "AAPL", "ORD001", 5) assert "VTTT1004U" in captured @pytest.mark.asyncio - async def test_hk_live_uses_ttts1003u( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_hk_live_uses_ttts1003u(self, overseas_broker: OverseasBroker) -> None: """SEHK exchange in live mode should use TTTS1003U.""" overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( update={"MODE": "live"} ) - captured, _ = self._setup_cancel_mocks( - overseas_broker, {"rt_cd": "0", "msg1": "OK"} - ) + captured, _ = self._setup_cancel_mocks(overseas_broker, {"rt_cd": "0", "msg1": "OK"}) await overseas_broker.cancel_overseas_order("SEHK", "0700", "ORD002", 10) assert "TTTS1003U" in captured @pytest.mark.asyncio - async def test_cancel_sets_rvse_cncl_dvsn_cd_02( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_cancel_sets_rvse_cncl_dvsn_cd_02(self, overseas_broker: OverseasBroker) -> None: """Cancel body must include RVSE_CNCL_DVSN_CD='02' and OVRS_ORD_UNPR='0'.""" captured_body: list[dict] = [] @@ -1005,9 +972,7 @@ class TestCancelOverseasOrder: assert captured_body[0]["ORGN_ODNO"] == "ORD003" @pytest.mark.asyncio - async def test_cancel_sets_hashkey_header( - self, overseas_broker: OverseasBroker - ) -> None: + async def test_cancel_sets_hashkey_header(self, overseas_broker: OverseasBroker) -> None: """hashkey must be set in the request headers.""" captured_headers: list[dict] = [] overseas_broker._broker._get_hash_key = AsyncMock(return_value="test_hash") # type: ignore[method-assign] diff --git a/tests/test_pre_market_planner.py b/tests/test_pre_market_planner.py index 50e2a3b..b35161c 100644 --- a/tests/test_pre_market_planner.py +++ b/tests/test_pre_market_planner.py @@ -78,9 +78,7 @@ def _gemini_response_json( "rationale": "Near circuit breaker", } ] - return json.dumps( - {"market_outlook": outlook, "global_rules": global_rules, "stocks": stocks} - ) + return json.dumps({"market_outlook": outlook, "global_rules": global_rules, "stocks": stocks}) def _make_planner( @@ -564,8 +562,12 @@ class TestBuildPrompt: def test_prompt_contains_cross_market(self) -> None: planner = _make_planner() cross = CrossMarketContext( - market="US", date="2026-02-07", total_pnl=1.5, - win_rate=60, index_change_pct=0.8, lessons=["Cut losses early"], + market="US", + date="2026-02-07", + total_pnl=1.5, + win_rate=60, + index_change_pct=0.8, + lessons=["Cut losses early"], ) prompt = planner._build_prompt("KR", [_candidate()], {}, None, cross) @@ -683,9 +685,7 @@ class TestSmartFallbackPlaybook: ) def test_momentum_candidate_gets_buy_on_volume(self) -> None: - candidates = [ - _candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0) - ] + candidates = [_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)] settings = self._make_settings() pb = PreMarketPlanner._smart_fallback_playbook( @@ -707,9 +707,7 @@ class TestSmartFallbackPlaybook: assert sell_sc.condition.price_change_pct_below == -3.0 def test_oversold_candidate_gets_buy_on_rsi(self) -> None: - candidates = [ - _candidate(code="005930", signal="oversold", rsi=22.0, volume_ratio=3.5) - ] + candidates = [_candidate(code="005930", signal="oversold", rsi=22.0, volume_ratio=3.5)] settings = self._make_settings() pb = PreMarketPlanner._smart_fallback_playbook( @@ -776,9 +774,7 @@ class TestSmartFallbackPlaybook: def test_empty_candidates_returns_empty_playbook(self) -> None: settings = self._make_settings() - pb = PreMarketPlanner._smart_fallback_playbook( - date(2026, 2, 17), "US_AMEX", [], settings - ) + pb = PreMarketPlanner._smart_fallback_playbook(date(2026, 2, 17), "US_AMEX", [], settings) assert pb.stock_count == 0 @@ -814,19 +810,14 @@ class TestSmartFallbackPlaybook: planner = _make_planner() planner._gemini.decide = AsyncMock(side_effect=ConnectionError("429 quota exceeded")) # momentum candidate - candidates = [ - _candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0) - ] + candidates = [_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)] - pb = await planner.generate_playbook( - "US_AMEX", candidates, today=date(2026, 2, 18) - ) + pb = await planner.generate_playbook("US_AMEX", candidates, today=date(2026, 2, 18)) # Should NOT be all-SELL defensive; should have BUY for momentum assert pb.stock_count == 1 buy_scenarios = [ - s for s in pb.stock_playbooks[0].scenarios - if s.action == ScenarioAction.BUY + s for s in pb.stock_playbooks[0].scenarios if s.action == ScenarioAction.BUY ] assert len(buy_scenarios) == 1 assert buy_scenarios[0].condition.volume_ratio_above == 2.0 # VOL_MULTIPLIER default diff --git a/tests/test_scenario_engine.py b/tests/test_scenario_engine.py index 4fcea51..4b6bbd5 100644 --- a/tests/test_scenario_engine.py +++ b/tests/test_scenario_engine.py @@ -14,7 +14,7 @@ from src.strategy.models import ( StockPlaybook, StockScenario, ) -from src.strategy.scenario_engine import ScenarioEngine, ScenarioMatch +from src.strategy.scenario_engine import ScenarioEngine @pytest.fixture @@ -162,13 +162,15 @@ class TestEvaluateCondition: def test_mixed_invalid_types_no_exception(self, engine: ScenarioEngine) -> None: """Various invalid types should not raise exceptions.""" cond = StockCondition( - rsi_below=30.0, volume_ratio_above=2.0, - price_above=100, price_change_pct_below=-1.0, + rsi_below=30.0, + volume_ratio_above=2.0, + price_above=100, + price_change_pct_below=-1.0, ) data = { - "rsi": [25], # list + "rsi": [25], # list "volume_ratio": "bad", # non-numeric string - "current_price": {}, # dict + "current_price": {}, # dict "price_change_pct": object(), # arbitrary object } # Should return False (invalid types → None → False), never raise @@ -356,9 +358,7 @@ class TestEvaluate: def test_match_details_populated(self, engine: ScenarioEngine) -> None: pb = _playbook(scenarios=[_scenario(rsi_below=30.0, volume_ratio_above=2.0)]) - result = engine.evaluate( - pb, "005930", {"rsi": 25.0, "volume_ratio": 3.0}, {} - ) + result = engine.evaluate(pb, "005930", {"rsi": 25.0, "volume_ratio": 3.0}, {}) assert result.match_details.get("rsi") == 25.0 assert result.match_details.get("volume_ratio") == 3.0 @@ -381,7 +381,9 @@ class TestEvaluate: ), StockPlaybook( stock_code="MSFT", - scenarios=[_scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80)], + scenarios=[ + _scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80) + ], ), ], ) @@ -450,58 +452,42 @@ class TestEvaluate: class TestPositionAwareConditions: """Tests for unrealized_pnl_pct and holding_days condition fields.""" - def test_evaluate_condition_unrealized_pnl_above_matches( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_unrealized_pnl_above_matches(self, engine: ScenarioEngine) -> None: """unrealized_pnl_pct_above should match when P&L exceeds threshold.""" condition = StockCondition(unrealized_pnl_pct_above=3.0) assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 5.0}) is True - def test_evaluate_condition_unrealized_pnl_above_no_match( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_unrealized_pnl_above_no_match(self, engine: ScenarioEngine) -> None: """unrealized_pnl_pct_above should NOT match when P&L is below threshold.""" condition = StockCondition(unrealized_pnl_pct_above=3.0) assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 2.0}) is False - def test_evaluate_condition_unrealized_pnl_below_matches( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_unrealized_pnl_below_matches(self, engine: ScenarioEngine) -> None: """unrealized_pnl_pct_below should match when P&L is under threshold.""" condition = StockCondition(unrealized_pnl_pct_below=-2.0) assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -3.5}) is True - def test_evaluate_condition_unrealized_pnl_below_no_match( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_unrealized_pnl_below_no_match(self, engine: ScenarioEngine) -> None: """unrealized_pnl_pct_below should NOT match when P&L is above threshold.""" condition = StockCondition(unrealized_pnl_pct_below=-2.0) assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -1.0}) is False - def test_evaluate_condition_holding_days_above_matches( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_holding_days_above_matches(self, engine: ScenarioEngine) -> None: """holding_days_above should match when position held longer than threshold.""" condition = StockCondition(holding_days_above=5) assert engine.evaluate_condition(condition, {"holding_days": 7}) is True - def test_evaluate_condition_holding_days_above_no_match( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_holding_days_above_no_match(self, engine: ScenarioEngine) -> None: """holding_days_above should NOT match when position held shorter.""" condition = StockCondition(holding_days_above=5) assert engine.evaluate_condition(condition, {"holding_days": 3}) is False - def test_evaluate_condition_holding_days_below_matches( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_holding_days_below_matches(self, engine: ScenarioEngine) -> None: """holding_days_below should match when position held fewer days.""" condition = StockCondition(holding_days_below=3) assert engine.evaluate_condition(condition, {"holding_days": 1}) is True - def test_evaluate_condition_holding_days_below_no_match( - self, engine: ScenarioEngine - ) -> None: + def test_evaluate_condition_holding_days_below_no_match(self, engine: ScenarioEngine) -> None: """holding_days_below should NOT match when held more days.""" condition = StockCondition(holding_days_below=3) assert engine.evaluate_condition(condition, {"holding_days": 5}) is False @@ -513,33 +499,33 @@ class TestPositionAwareConditions: holding_days_above=5, ) # Both met → match - assert engine.evaluate_condition( - condition, - {"unrealized_pnl_pct": 4.5, "holding_days": 7}, - ) is True + assert ( + engine.evaluate_condition( + condition, + {"unrealized_pnl_pct": 4.5, "holding_days": 7}, + ) + is True + ) # Only pnl met → no match - assert engine.evaluate_condition( - condition, - {"unrealized_pnl_pct": 4.5, "holding_days": 3}, - ) is False + assert ( + engine.evaluate_condition( + condition, + {"unrealized_pnl_pct": 4.5, "holding_days": 3}, + ) + is False + ) - def test_missing_unrealized_pnl_does_not_match( - self, engine: ScenarioEngine - ) -> None: + def test_missing_unrealized_pnl_does_not_match(self, engine: ScenarioEngine) -> None: """Missing unrealized_pnl_pct key should not match the condition.""" condition = StockCondition(unrealized_pnl_pct_above=3.0) assert engine.evaluate_condition(condition, {}) is False - def test_missing_holding_days_does_not_match( - self, engine: ScenarioEngine - ) -> None: + def test_missing_holding_days_does_not_match(self, engine: ScenarioEngine) -> None: """Missing holding_days key should not match the condition.""" condition = StockCondition(holding_days_above=5) assert engine.evaluate_condition(condition, {}) is False - def test_match_details_includes_position_fields( - self, engine: ScenarioEngine - ) -> None: + def test_match_details_includes_position_fields(self, engine: ScenarioEngine) -> None: """match_details should include position fields when condition specifies them.""" pb = _playbook( scenarios=[ diff --git a/tests/test_smart_scanner.py b/tests/test_smart_scanner.py index bb8200f..5fa1c07 100644 --- a/tests/test_smart_scanner.py +++ b/tests/test_smart_scanner.py @@ -2,9 +2,10 @@ from __future__ import annotations -import pytest from unittest.mock import AsyncMock, MagicMock +import pytest + from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner from src.analysis.volatility import VolatilityAnalyzer from src.broker.kis_api import KISBroker @@ -200,9 +201,7 @@ class TestSmartVolatilityScanner: assert len(candidates) <= scanner.top_n @pytest.mark.asyncio - async def test_get_stock_codes( - self, scanner: SmartVolatilityScanner - ) -> None: + async def test_get_stock_codes(self, scanner: SmartVolatilityScanner) -> None: """Test extraction of stock codes from candidates.""" candidates = [ ScanCandidate( diff --git a/tests/test_strategy_models.py b/tests/test_strategy_models.py index 9ea40e0..7cee5eb 100644 --- a/tests/test_strategy_models.py +++ b/tests/test_strategy_models.py @@ -19,7 +19,6 @@ from src.strategy.models import ( StockScenario, ) - # --------------------------------------------------------------------------- # StockCondition # --------------------------------------------------------------------------- diff --git a/tests/test_telegram.py b/tests/test_telegram.py index 606b4e7..6af177c 100644 --- a/tests/test_telegram.py +++ b/tests/test_telegram.py @@ -5,7 +5,11 @@ from unittest.mock import AsyncMock, patch import aiohttp import pytest -from src.notifications.telegram_client import NotificationFilter, NotificationPriority, TelegramClient +from src.notifications.telegram_client import ( + NotificationFilter, + NotificationPriority, + TelegramClient, +) class TestTelegramClientInit: @@ -13,9 +17,7 @@ class TestTelegramClientInit: def test_disabled_via_flag(self) -> None: """Client disabled via enabled=False flag.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=False - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=False) assert client._enabled is False def test_disabled_missing_token(self) -> None: @@ -30,9 +32,7 @@ class TestTelegramClientInit: def test_enabled_with_credentials(self) -> None: """Client enabled when credentials provided.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) assert client._enabled is True @@ -42,9 +42,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_send_message_success(self) -> None: """send_message returns True on successful send.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -76,9 +74,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_send_message_api_error(self) -> None: """send_message returns False on API error.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 400 @@ -93,9 +89,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_send_message_with_markdown(self) -> None: """send_message supports different parse modes.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -128,9 +122,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_trade_execution_format(self) -> None: """Trade notification has correct format.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -163,9 +155,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_playbook_generated_format(self) -> None: """Playbook generated notification has expected fields.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -190,9 +180,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_scenario_matched_format(self) -> None: """Scenario matched notification has expected fields.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -217,9 +205,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_playbook_failed_format(self) -> None: """Playbook failed notification has expected fields.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -240,9 +226,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_circuit_breaker_priority(self) -> None: """Circuit breaker uses CRITICAL priority.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -260,9 +244,7 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_api_error_handling(self) -> None: """API errors logged but don't crash.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 400 @@ -277,25 +259,19 @@ class TestNotificationSending: @pytest.mark.asyncio async def test_timeout_handling(self) -> None: """Timeouts logged but don't crash.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) with patch( "aiohttp.ClientSession.post", side_effect=aiohttp.ClientError("Connection timeout"), ): # Should not raise exception - await client.notify_error( - error_type="Test Error", error_msg="Test", context="test" - ) + await client.notify_error(error_type="Test Error", error_msg="Test", context="test") @pytest.mark.asyncio async def test_session_management(self) -> None: """Session created and reused correctly.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) # Session should be None initially assert client._session is None @@ -324,9 +300,7 @@ class TestRateLimiting: """Rate limiter delays rapid requests.""" import time - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0 - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0) mock_resp = AsyncMock() mock_resp.status = 200 @@ -353,9 +327,7 @@ class TestMessagePriorities: @pytest.mark.asyncio async def test_low_priority_uses_info_emoji(self) -> None: """LOW priority uses ℹ️ emoji.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -371,9 +343,7 @@ class TestMessagePriorities: @pytest.mark.asyncio async def test_critical_priority_uses_alarm_emoji(self) -> None: """CRITICAL priority uses 🚨 emoji.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -389,9 +359,7 @@ class TestMessagePriorities: @pytest.mark.asyncio async def test_playbook_generated_priority(self) -> None: """Playbook generated uses MEDIUM priority emoji.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -412,9 +380,7 @@ class TestMessagePriorities: @pytest.mark.asyncio async def test_playbook_failed_priority(self) -> None: """Playbook failed uses HIGH priority emoji.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -433,9 +399,7 @@ class TestMessagePriorities: @pytest.mark.asyncio async def test_scenario_matched_priority(self) -> None: """Scenario matched uses HIGH priority emoji.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_resp = AsyncMock() mock_resp.status = 200 @@ -460,9 +424,7 @@ class TestClientCleanup: @pytest.mark.asyncio async def test_close_closes_session(self) -> None: """close() closes the HTTP session.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) mock_session = AsyncMock() mock_session.closed = False @@ -475,9 +437,7 @@ class TestClientCleanup: @pytest.mark.asyncio async def test_close_handles_no_session(self) -> None: """close() handles None session gracefully.""" - client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True - ) + client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True) # Should not raise exception await client.close() @@ -535,8 +495,12 @@ class TestNotificationFilter: ) with patch("aiohttp.ClientSession.post") as mock_post: await client.notify_trade_execution( - stock_code="005930", market="KR", action="BUY", - quantity=10, price=70000.0, confidence=85.0 + stock_code="005930", + market="KR", + action="BUY", + quantity=10, + price=70000.0, + confidence=85.0, ) mock_post.assert_not_called() @@ -556,8 +520,13 @@ class TestNotificationFilter: async def test_circuit_breaker_always_sends_regardless_of_filter(self) -> None: """notify_circuit_breaker always sends (no filter flag).""" nf = NotificationFilter( - trades=False, market_open_close=False, fat_finger=False, - system_events=False, playbook=False, scenario_match=False, errors=False, + trades=False, + market_open_close=False, + fat_finger=False, + system_events=False, + playbook=False, + scenario_match=False, + errors=False, ) client = TelegramClient( bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf @@ -617,7 +586,7 @@ class TestNotificationFilter: nf = NotificationFilter() assert nf.set_flag("unknown_key", False) is False - def test_as_dict_keys_match_KEYS(self) -> None: + def test_as_dict_keys_match_keys(self) -> None: """as_dict() returns every key defined in KEYS.""" nf = NotificationFilter() d = nf.as_dict() @@ -640,10 +609,17 @@ class TestNotificationFilter: def test_set_notification_all_on(self) -> None: """set_notification('all', True) enables every filter flag.""" client = TelegramClient( - bot_token="123:abc", chat_id="456", enabled=True, + bot_token="123:abc", + chat_id="456", + enabled=True, notification_filter=NotificationFilter( - trades=False, market_open_close=False, scenario_match=False, - fat_finger=False, system_events=False, playbook=False, errors=False, + trades=False, + market_open_close=False, + scenario_match=False, + fat_finger=False, + system_events=False, + playbook=False, + errors=False, ), ) assert client.set_notification("all", True) is True diff --git a/tests/test_telegram_commands.py b/tests/test_telegram_commands.py index a184549..9615022 100644 --- a/tests/test_telegram_commands.py +++ b/tests/test_telegram_commands.py @@ -357,8 +357,7 @@ class TestTradingControlCommands: pause_event.set() await client.send_message( - "▶️ Trading Resumed\n\n" - "Trading operations have been restarted." + "▶️ Trading Resumed\n\nTrading operations have been restarted." ) handler.register_command("resume", mock_resume) @@ -526,9 +525,7 @@ class TestStatusCommands: async def mock_status_error() -> None: """Mock /status handler with error.""" - await client.send_message( - "⚠️ Error\n\nFailed to retrieve trading status." - ) + await client.send_message("⚠️ Error\n\nFailed to retrieve trading status.") handler.register_command("status", mock_status_error) @@ -603,10 +600,7 @@ class TestStatusCommands: async def mock_positions_empty() -> None: """Mock /positions handler with no positions.""" - message = ( - "💼 Account Summary\n\n" - "No balance information available." - ) + message = "💼 Account Summary\n\nNo balance information available." await client.send_message(message) handler.register_command("positions", mock_positions_empty) @@ -639,9 +633,7 @@ class TestStatusCommands: async def mock_positions_error() -> None: """Mock /positions handler with error.""" - await client.send_message( - "⚠️ Error\n\nFailed to retrieve positions." - ) + await client.send_message("⚠️ Error\n\nFailed to retrieve positions.") handler.register_command("positions", mock_positions_error) diff --git a/tests/test_validate_governance_assets.py b/tests/test_validate_governance_assets.py index 3a0bc0b..719d801 100644 --- a/tests/test_validate_governance_assets.py +++ b/tests/test_validate_governance_assets.py @@ -70,7 +70,9 @@ def test_load_changed_files_with_range_uses_git_diff(monkeypatch) -> None: assert check is True assert capture_output is True assert text is True - return SimpleNamespace(stdout="docs/ouroboros/85_loss_recovery_action_plan.md\nsrc/main.py\n") + return SimpleNamespace( + stdout="docs/ouroboros/85_loss_recovery_action_plan.md\nsrc/main.py\n" + ) monkeypatch.setattr(module.subprocess, "run", fake_run) changed = module.load_changed_files(["abc...def"], errors) diff --git a/tests/test_volatility.py b/tests/test_volatility.py index 02f0234..25b08b1 100644 --- a/tests/test_volatility.py +++ b/tests/test_volatility.py @@ -80,9 +80,7 @@ class TestVolatilityAnalyzer: # ATR should be roughly the average true range assert 3.0 <= atr <= 6.0 - def test_calculate_atr_insufficient_data( - self, volatility_analyzer: VolatilityAnalyzer - ) -> None: + def test_calculate_atr_insufficient_data(self, volatility_analyzer: VolatilityAnalyzer) -> None: """Test ATR with insufficient data returns 0.""" high_prices = [110.0, 112.0] low_prices = [105.0, 107.0] @@ -120,17 +118,13 @@ class TestVolatilityAnalyzer: surge = volatility_analyzer.calculate_volume_surge(1000.0, 0.0) assert surge == 1.0 - def test_calculate_pv_divergence_bullish( - self, volatility_analyzer: VolatilityAnalyzer - ) -> None: + def test_calculate_pv_divergence_bullish(self, volatility_analyzer: VolatilityAnalyzer) -> None: """Test bullish price-volume divergence.""" # Price up + Volume up = bullish divergence = volatility_analyzer.calculate_pv_divergence(5.0, 2.0) assert divergence > 0.0 - def test_calculate_pv_divergence_bearish( - self, volatility_analyzer: VolatilityAnalyzer - ) -> None: + def test_calculate_pv_divergence_bearish(self, volatility_analyzer: VolatilityAnalyzer) -> None: """Test bearish price-volume divergence.""" # Price up + Volume down = bearish divergence divergence = volatility_analyzer.calculate_pv_divergence(5.0, 0.5) @@ -144,9 +138,7 @@ class TestVolatilityAnalyzer: divergence = volatility_analyzer.calculate_pv_divergence(-5.0, 2.0) assert divergence < 0.0 - def test_calculate_momentum_score( - self, volatility_analyzer: VolatilityAnalyzer - ) -> None: + def test_calculate_momentum_score(self, volatility_analyzer: VolatilityAnalyzer) -> None: """Test momentum score calculation.""" score = volatility_analyzer.calculate_momentum_score( price_change_1m=5.0, @@ -500,9 +492,7 @@ class TestMarketScanner: # Should keep all current stocks since they're all in top movers assert set(updated) == set(current_watchlist) - def test_get_updated_watchlist_max_replacements( - self, scanner: MarketScanner - ) -> None: + def test_get_updated_watchlist_max_replacements(self, scanner: MarketScanner) -> None: """Test that max_replacements limit is respected.""" current_watchlist = ["000660", "035420", "005490"] @@ -556,8 +546,6 @@ class TestMarketScanner: active_count = 0 peak_count = 0 - original_scan = scanner.scan_stock - async def tracking_scan(code: str, market: Any) -> VolatilityMetrics: nonlocal active_count, peak_count active_count += 1 -- 2.49.1 From 4c0b55d67c9d80572704cd954f2b25022d60df2d Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:22:13 +0900 Subject: [PATCH 3/7] docs: replace absolute plan links with repo-relative paths --- docs/ouroboros/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ouroboros/README.md b/docs/ouroboros/README.md index 6e53e6c..e64062d 100644 --- a/docs/ouroboros/README.md +++ b/docs/ouroboros/README.md @@ -38,5 +38,5 @@ python3 scripts/validate_ouroboros_docs.py ## 원본 계획 문서 -- [v2](/home/agentson/repos/The-Ouroboros/ouroboros_plan_v2.txt) -- [v3](/home/agentson/repos/The-Ouroboros/ouroboros_plan_v3.txt) +- [v2](../../ouroboros_plan_v2.txt) +- [v3](../../ouroboros_plan_v3.txt) -- 2.49.1 From 2c6e9802be81725c2e752f47d1d307154e23155b Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:23:34 +0900 Subject: [PATCH 4/7] docs: sync requirements registry metadata for policy doc changes --- docs/ouroboros/01_requirements_registry.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ouroboros/01_requirements_registry.md b/docs/ouroboros/01_requirements_registry.md index d01269e..23f8868 100644 --- a/docs/ouroboros/01_requirements_registry.md +++ b/docs/ouroboros/01_requirements_registry.md @@ -3,7 +3,7 @@ Doc-ID: DOC-REQ-001 Version: 1.0.0 Status: active Owner: strategy -Updated: 2026-02-26 +Updated: 2026-03-01 --> # 요구사항 원장 (Single Source of Truth) -- 2.49.1 From 05be1120858d192d7b3baf11f938618c72ff55de Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:25:39 +0900 Subject: [PATCH 5/7] docs: move v2/v3 source plans under docs/ouroboros/source --- docs/ouroboros/01_requirements_registry.md | 3 ++- docs/ouroboros/30_code_level_work_orders.md | 1 + docs/ouroboros/40_acceptance_and_test_plan.md | 1 + docs/ouroboros/README.md | 10 +++++----- .../ouroboros/source/ouroboros_plan_v2.txt | 0 .../ouroboros/source/ouroboros_plan_v3.txt | 0 6 files changed, 9 insertions(+), 6 deletions(-) rename ouroboros_plan_v2.txt => docs/ouroboros/source/ouroboros_plan_v2.txt (100%) rename ouroboros_plan_v3.txt => docs/ouroboros/source/ouroboros_plan_v3.txt (100%) diff --git a/docs/ouroboros/01_requirements_registry.md b/docs/ouroboros/01_requirements_registry.md index 23f8868..7248955 100644 --- a/docs/ouroboros/01_requirements_registry.md +++ b/docs/ouroboros/01_requirements_registry.md @@ -1,6 +1,6 @@ # The Ouroboros 실행 문서 허브 -이 폴더는 `ouroboros_plan_v2.txt`, `ouroboros_plan_v3.txt`를 구현 가능한 작업 지시서 수준으로 분해한 문서 허브다. +이 폴더는 `source/ouroboros_plan_v2.txt`, `source/ouroboros_plan_v3.txt`를 구현 가능한 작업 지시서 수준으로 분해한 문서 허브다. ## 읽기 순서 (Routing) @@ -38,5 +38,5 @@ python3 scripts/validate_ouroboros_docs.py ## 원본 계획 문서 -- [v2](../../ouroboros_plan_v2.txt) -- [v3](../../ouroboros_plan_v3.txt) +- [v2](./source/ouroboros_plan_v2.txt) +- [v3](./source/ouroboros_plan_v3.txt) diff --git a/ouroboros_plan_v2.txt b/docs/ouroboros/source/ouroboros_plan_v2.txt similarity index 100% rename from ouroboros_plan_v2.txt rename to docs/ouroboros/source/ouroboros_plan_v2.txt diff --git a/ouroboros_plan_v3.txt b/docs/ouroboros/source/ouroboros_plan_v3.txt similarity index 100% rename from ouroboros_plan_v3.txt rename to docs/ouroboros/source/ouroboros_plan_v3.txt -- 2.49.1 From 940a7e094bed5d787190e34a90068ba3018dde5b Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:35:22 +0900 Subject: [PATCH 6/7] workflow: skip main/master branch guard in --ci mode --- scripts/session_handover_check.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/session_handover_check.py b/scripts/session_handover_check.py index 7b354be..dfe200b 100755 --- a/scripts/session_handover_check.py +++ b/scripts/session_handover_check.py @@ -134,7 +134,7 @@ def main() -> int: branch = _current_branch() if not branch: errors.append("cannot resolve current git branch") - elif branch in {"main", "master"}: + elif not args.ci and branch in {"main", "master"}: errors.append(f"working branch must not be {branch}") _check_handover_entry( -- 2.49.1 From 8f2c08e2b7e5f89aa528f54cfca955a953281e8c Mon Sep 17 00:00:00 2001 From: agentson Date: Sun, 1 Mar 2026 20:43:06 +0900 Subject: [PATCH 7/7] test: add ci-mode coverage for session handover gate --- scripts/session_handover_check.py | 10 +-- tests/test_session_handover_check.py | 100 +++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 tests/test_session_handover_check.py diff --git a/scripts/session_handover_check.py b/scripts/session_handover_check.py index dfe200b..68ae256 100755 --- a/scripts/session_handover_check.py +++ b/scripts/session_handover_check.py @@ -88,6 +88,10 @@ def _check_handover_entry( if token not in latest: errors.append(f"latest handover entry missing token: {token}") + if strict: + if "- next_ticket: #TBD" in latest: + errors.append("latest handover entry must not use placeholder next_ticket (#TBD)") + if strict and not ci_mode: today_utc = datetime.now(UTC).date().isoformat() if today_utc not in latest: @@ -100,8 +104,6 @@ def _check_handover_entry( "latest handover entry must target current branch " f"({branch_token})" ) - if "- next_ticket: #TBD" in latest: - errors.append("latest handover entry must not use placeholder next_ticket (#TBD)") if "merged_to_feature_branch=no" in latest: errors.append( "process gate indicates not merged; implementation must stay blocked " @@ -122,8 +124,8 @@ def main() -> int: "--ci", action="store_true", help=( - "CI mode: keep structural/token checks but skip strict " - "today-date/current-branch matching." + "CI mode: keep structural/token checks and placeholder guard, " + "but skip strict today-date/current-branch/merge-gate checks." ), ) args = parser.parse_args() diff --git a/tests/test_session_handover_check.py b/tests/test_session_handover_check.py new file mode 100644 index 0000000..8c4aedb --- /dev/null +++ b/tests/test_session_handover_check.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import importlib.util +from pathlib import Path + + +def _load_module(): + script_path = Path(__file__).resolve().parents[1] / "scripts" / "session_handover_check.py" + spec = importlib.util.spec_from_file_location("session_handover_check", script_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_ci_mode_skips_date_branch_and_merge_gate(monkeypatch, tmp_path) -> None: + module = _load_module() + handover = tmp_path / "session-handover.md" + handover.write_text( + "\n".join( + [ + "### 2000-01-01 | session=test", + "- branch: feature/other-branch", + "- docs_checked: docs/workflow.md, docs/commands.md, docs/agent-constraints.md", + "- open_issues_reviewed: #1", + "- next_ticket: #123", + "- process_gate_checked: process_ticket=#1 merged_to_feature_branch=no", + ] + ), + encoding="utf-8", + ) + monkeypatch.setattr(module, "HANDOVER_LOG", handover) + + errors: list[str] = [] + module._check_handover_entry( + branch="feature/current-branch", + strict=True, + ci_mode=True, + errors=errors, + ) + assert errors == [] + + +def test_ci_mode_still_blocks_tbd_next_ticket(monkeypatch, tmp_path) -> None: + module = _load_module() + handover = tmp_path / "session-handover.md" + handover.write_text( + "\n".join( + [ + "### 2000-01-01 | session=test", + "- branch: feature/other-branch", + "- docs_checked: docs/workflow.md, docs/commands.md, docs/agent-constraints.md", + "- open_issues_reviewed: #1", + "- next_ticket: #TBD", + "- process_gate_checked: process_ticket=#1 merged_to_feature_branch=no", + ] + ), + encoding="utf-8", + ) + monkeypatch.setattr(module, "HANDOVER_LOG", handover) + + errors: list[str] = [] + module._check_handover_entry( + branch="feature/current-branch", + strict=True, + ci_mode=True, + errors=errors, + ) + assert "latest handover entry must not use placeholder next_ticket (#TBD)" in errors + + +def test_non_ci_strict_enforces_date_branch_and_merge_gate(monkeypatch, tmp_path) -> None: + module = _load_module() + handover = tmp_path / "session-handover.md" + handover.write_text( + "\n".join( + [ + "### 2000-01-01 | session=test", + "- branch: feature/other-branch", + "- docs_checked: docs/workflow.md, docs/commands.md, docs/agent-constraints.md", + "- open_issues_reviewed: #1", + "- next_ticket: #123", + "- process_gate_checked: process_ticket=#1 merged_to_feature_branch=no", + ] + ), + encoding="utf-8", + ) + monkeypatch.setattr(module, "HANDOVER_LOG", handover) + + errors: list[str] = [] + module._check_handover_entry( + branch="feature/current-branch", + strict=True, + ci_mode=False, + errors=errors, + ) + assert any("must contain today's UTC date" in e for e in errors) + assert any("must target current branch" in e for e in errors) + assert any("merged_to_feature_branch=no" in e for e in errors) -- 2.49.1