Merge pull request 'workflow: split the session handover gate by execution environment mode (#353)' (#354) from feature/issue-353-ci-handover-mode-v2 into feature/v3-session-policy-stream
All checks were successful
Gitea CI / test (push) Successful in 37s

Reviewed-on: #354
This commit was merged in pull request #354.
2026-03-01 21:00:51 +09:00
74 changed files with 1173 additions and 1393 deletions

View File

@@ -25,7 +25,7 @@ jobs:
         run: pip install ".[dev]"
       - name: Session handover gate
-        run: python3 scripts/session_handover_check.py --strict
+        run: python3 scripts/session_handover_check.py --strict --ci
       - name: Validate governance assets
         env:

View File

@@ -22,7 +22,7 @@ jobs:
         run: pip install ".[dev]"
       - name: Session handover gate
-        run: python3 scripts/session_handover_check.py --strict
+        run: python3 scripts/session_handover_check.py --strict --ci
       - name: Validate governance assets
         env:

View File

@@ -1,9 +1,9 @@
 <!--
 Doc-ID: DOC-REQ-001
-Version: 1.0.0
+Version: 1.0.1
 Status: active
 Owner: strategy
-Updated: 2026-02-26
+Updated: 2026-03-01
 -->
 # Requirements Ledger (Single Source of Truth)
@@ -37,3 +37,4 @@ Updated: 2026-02-26
 - `REQ-OPS-001`: Every time field must state its timezone explicitly (KST/UTC).
 - `REQ-OPS-002`: Numeric policy values in documents are changed only in this ledger.
 - `REQ-OPS-003`: Every implementation task must be accompanied by a test task.
+- `REQ-OPS-004`: The original plan documents (`v2`, `v3`) use the `docs/ouroboros/source/` path as the single canonical location.

View File

@@ -51,6 +51,7 @@ Updated: 2026-02-26
 - `TASK-OPS-001` (`REQ-OPS-001`): Implement rules enforcing timezone notation in time fields and log schemas
 - `TASK-OPS-002` (`REQ-OPS-002`): Add a CI check requiring `01_requirements_registry.md` to be updated first whenever a policy value changes
 - `TASK-OPS-003` (`REQ-OPS-003`): Keep the document validation gate that blocks any `REQ-*` lacking a `TASK-*` or `TEST-*` mapping
+- `TASK-OPS-004` (`REQ-OPS-004`): Standardize the v2/v3 original plan document location to `docs/ouroboros/source/` and verify link consistency
 ## Commit Rules

View File

@@ -29,6 +29,7 @@ Updated: 2026-02-26
 - `TEST-ACC-007` (`REQ-OPS-001`): Validation fails when a time-related field omits its timezone (KST/UTC).
 - `TEST-ACC-008` (`REQ-OPS-002`): Validation fails when a policy value change is not reflected in the ledger.
 - `TEST-ACC-009` (`REQ-OPS-003`): Validation fails when a `REQ-*` exists without a `TASK-*`/`TEST-*` mapping.
+- `TEST-ACC-019` (`REQ-OPS-004`): Links to the v2/v3 original plan documents pass only when they use the `docs/ouroboros/source/` path.
 ## Test Layers

View File

@@ -1,14 +1,14 @@
 <!--
 Doc-ID: DOC-ROOT-001
-Version: 1.0.0
+Version: 1.0.1
 Status: active
 Owner: strategy
-Updated: 2026-02-26
+Updated: 2026-03-01
 -->
 # The Ouroboros Execution Document Hub
-This folder is the document hub that breaks `ouroboros_plan_v2.txt` and `ouroboros_plan_v3.txt` down into implementation-ready work orders.
+This folder is the document hub that breaks `source/ouroboros_plan_v2.txt` and `source/ouroboros_plan_v3.txt` down into implementation-ready work orders.
 ## Reading Order (Routing)
@@ -40,5 +40,5 @@ python3 scripts/validate_ouroboros_docs.py
 ## Original Plan Documents
-- [v2](/home/agentson/repos/The-Ouroboros/ouroboros_plan_v2.txt)
-- [v3](/home/agentson/repos/The-Ouroboros/ouroboros_plan_v3.txt)
+- [v2](./source/ouroboros_plan_v2.txt)
+- [v3](./source/ouroboros_plan_v3.txt)

View File

@@ -66,6 +66,7 @@ def _check_handover_entry(
     *,
     branch: str,
     strict: bool,
+    ci_mode: bool,
     errors: list[str],
 ) -> None:
     if not HANDOVER_LOG.exists():
@@ -88,6 +89,10 @@ def _check_handover_entry(
             errors.append(f"latest handover entry missing token: {token}")
     if strict:
+        if "- next_ticket: #TBD" in latest:
+            errors.append("latest handover entry must not use placeholder next_ticket (#TBD)")
+    if strict and not ci_mode:
         today_utc = datetime.now(UTC).date().isoformat()
         if today_utc not in latest:
             errors.append(
@@ -99,8 +104,6 @@ def _check_handover_entry(
                 "latest handover entry must target current branch "
                 f"({branch_token})"
             )
-        if "- next_ticket: #TBD" in latest:
-            errors.append("latest handover entry must not use placeholder next_ticket (#TBD)")
         if "merged_to_feature_branch=no" in latest:
             errors.append(
                 "process gate indicates not merged; implementation must stay blocked "
@@ -117,6 +120,14 @@ def main() -> int:
         action="store_true",
         help="Enforce today-date and current-branch match on latest handover entry.",
     )
+    parser.add_argument(
+        "--ci",
+        action="store_true",
+        help=(
+            "CI mode: keep structural/token checks and placeholder guard, "
+            "but skip strict today-date/current-branch/merge-gate checks."
+        ),
+    )
     args = parser.parse_args()
     errors: list[str] = []
@@ -125,10 +136,15 @@ def main() -> int:
     branch = _current_branch()
     if not branch:
         errors.append("cannot resolve current git branch")
-    elif branch in {"main", "master"}:
+    elif not args.ci and branch in {"main", "master"}:
         errors.append(f"working branch must not be {branch}")
-    _check_handover_entry(branch=branch, strict=args.strict, errors=errors)
+    _check_handover_entry(
+        branch=branch,
+        strict=args.strict,
+        ci_mode=args.ci,
+        errors=errors,
+    )
     if errors:
         print("[FAIL] session handover check failed")

View File

@@ -2,8 +2,8 @@
 from __future__ import annotations
-from dataclasses import dataclass
 import math
+from dataclasses import dataclass
 @dataclass(frozen=True)

View File

@@ -2,12 +2,11 @@
 from __future__ import annotations
-from dataclasses import dataclass
 import math
+from dataclasses import dataclass
 from random import Random
 from typing import Literal
 OrderSide = Literal["BUY", "SELL"]
@@ -77,7 +76,9 @@ class BacktestExecutionModel:
                 reason="execution_failure",
             )
-        slip_mult = 1.0 + (slippage_bps / 10000.0 if request.side == "BUY" else -slippage_bps / 10000.0)
+        slip_mult = 1.0 + (
+            slippage_bps / 10000.0 if request.side == "BUY" else -slippage_bps / 10000.0
+        )
         exec_price = request.reference_price * slip_mult
         if self._rng.random() < partial_rate:
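As a quick sanity check on the slippage expression reformatted above (illustrative numbers, not taken from the diff): with 5 bps of slippage, a BUY fills 0.05% above the reference price and a SELL 0.05% below it.

```python
# Worked example of the slippage multiplier (illustrative values).
slippage_bps = 5.0
reference_price = 100.0
buy_mult = 1.0 + slippage_bps / 10000.0   # 1.0005
sell_mult = 1.0 - slippage_bps / 10000.0  # 0.9995
assert round(reference_price * buy_mult, 2) == 100.05
assert round(reference_price * sell_mult, 2) == 99.95
```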

View File

@@ -10,8 +10,7 @@ from collections.abc import Sequence
 from dataclasses import dataclass
 from datetime import datetime
 from statistics import mean
-from typing import Literal
-from typing import cast
+from typing import Literal, cast
 from src.analysis.backtest_cost_guard import BacktestCostModel, validate_backtest_cost_model
 from src.analysis.triple_barrier import TripleBarrierSpec, label_with_triple_barrier

View File

@@ -104,6 +104,7 @@ class MarketScanner:
         # Store in L7 real-time layer
         from datetime import UTC, datetime
         timeframe = datetime.now(UTC).isoformat()
         self.context_store.set_context(
             ContextLayer.L7_REALTIME,
@@ -158,12 +159,8 @@ class MarketScanner:
         top_movers = valid_metrics[: self.top_n]
         # Detect breakouts and breakdowns
-        breakouts = [
-            m.stock_code for m in valid_metrics if self.analyzer.is_breakout(m)
-        ]
-        breakdowns = [
-            m.stock_code for m in valid_metrics if self.analyzer.is_breakdown(m)
-        ]
+        breakouts = [m.stock_code for m in valid_metrics if self.analyzer.is_breakout(m)]
+        breakdowns = [m.stock_code for m in valid_metrics if self.analyzer.is_breakdown(m)]
         logger.info(
             "%s scan complete: %d scanned, top momentum=%.1f, %d breakouts, %d breakdowns",
@@ -228,10 +225,9 @@ class MarketScanner:
         # If we removed too many, backfill from current watchlist
         if len(updated) < len(current_watchlist):
-            backfill = [
-                code for code in current_watchlist
-                if code not in updated
-            ][: len(current_watchlist) - len(updated)]
+            backfill = [code for code in current_watchlist if code not in updated][
+                : len(current_watchlist) - len(updated)
+            ]
             updated.extend(backfill)
         logger.info(

View File

@@ -158,7 +158,12 @@ class SmartVolatilityScanner:
             price = latest_close
             latest_high = _safe_float(latest.get("high"))
             latest_low = _safe_float(latest.get("low"))
-            if latest_close > 0 and latest_high > 0 and latest_low > 0 and latest_high >= latest_low:
+            if (
+                latest_close > 0
+                and latest_high > 0
+                and latest_low > 0
+                and latest_high >= latest_low
+            ):
                 intraday_range_pct = (latest_high - latest_low) / latest_close * 100.0
             if volume <= 0:
                 volume = _safe_float(latest.get("volume"))
@@ -234,9 +239,7 @@ class SmartVolatilityScanner:
                     limit=50,
                 )
             except Exception as exc:
-                logger.warning(
-                    "Overseas fluctuation ranking failed for %s: %s", market.code, exc
-                )
+                logger.warning("Overseas fluctuation ranking failed for %s: %s", market.code, exc)
                 fluct_rows = []
         if not fluct_rows:
@@ -250,9 +253,7 @@ class SmartVolatilityScanner:
                     limit=50,
                 )
             except Exception as exc:
-                logger.warning(
-                    "Overseas volume ranking failed for %s: %s", market.code, exc
-                )
+                logger.warning("Overseas volume ranking failed for %s: %s", market.code, exc)
                 volume_rows = []
         for idx, row in enumerate(volume_rows):
@@ -433,16 +434,10 @@ def _extract_intraday_range_pct(row: dict[str, Any], price: float) -> float:
     if price <= 0:
         return 0.0
     high = _safe_float(
-        row.get("high")
-        or row.get("ovrs_hgpr")
-        or row.get("stck_hgpr")
-        or row.get("day_hgpr")
+        row.get("high") or row.get("ovrs_hgpr") or row.get("stck_hgpr") or row.get("day_hgpr")
     )
     low = _safe_float(
-        row.get("low")
-        or row.get("ovrs_lwpr")
-        or row.get("stck_lwpr")
-        or row.get("day_lwpr")
+        row.get("low") or row.get("ovrs_lwpr") or row.get("stck_lwpr") or row.get("day_lwpr")
     )
     if high <= 0 or low <= 0 or high < low:
         return 0.0
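Both the scanner branch and `_extract_intraday_range_pct` compute the same ratio. A worked example with illustrative prices (not taken from the diff):

```python
# Worked example of the intraday range calculation (illustrative prices).
high, low, close = 105.0, 95.0, 100.0
intraday_range_pct = (high - low) / close * 100.0
assert intraday_range_pct == 10.0  # a 10% intraday range
```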

View File

@@ -6,10 +6,10 @@ Implements first-touch labeling with upper/lower/time barriers.
 from __future__ import annotations
 import warnings
+from collections.abc import Sequence
 from dataclasses import dataclass
 from datetime import datetime, timedelta
-from typing import Literal, Sequence
+from typing import Literal
 TieBreakMode = Literal["stop_first", "take_first"]
@@ -92,7 +92,10 @@ def label_with_triple_barrier(
     else:
         assert spec.max_holding_bars is not None
         warnings.warn(
-            "TripleBarrierSpec.max_holding_bars is deprecated; use max_holding_minutes with timestamps instead.",
+            (
+                "TripleBarrierSpec.max_holding_bars is deprecated; "
+                "use max_holding_minutes with timestamps instead."
+            ),
             DeprecationWarning,
             stacklevel=2,
         )

View File

@@ -92,9 +92,7 @@ class VolatilityAnalyzer:
         recent_tr = true_ranges[-period:]
         return sum(recent_tr) / len(recent_tr)
-    def calculate_price_change(
-        self, current_price: float, past_price: float
-    ) -> float:
+    def calculate_price_change(self, current_price: float, past_price: float) -> float:
         """Calculate price change percentage.
         Args:
@@ -108,9 +106,7 @@ class VolatilityAnalyzer:
             return 0.0
         return ((current_price - past_price) / past_price) * 100
-    def calculate_volume_surge(
-        self, current_volume: float, avg_volume: float
-    ) -> float:
+    def calculate_volume_surge(self, current_volume: float, avg_volume: float) -> float:
         """Calculate volume surge ratio.
         Args:
@@ -240,11 +236,7 @@ class VolatilityAnalyzer:
             Momentum score (0-100)
         """
         # Weight recent changes more heavily
-        weighted_change = (
-            price_change_1m * 0.4 +
-            price_change_5m * 0.3 +
-            price_change_15m * 0.2
-        )
+        weighted_change = price_change_1m * 0.4 + price_change_5m * 0.3 + price_change_15m * 0.2
         # Volume contribution (normalized to 0-10 scale)
         volume_contribution = min(10.0, (volume_surge - 1.0) * 5.0)
@@ -301,17 +293,11 @@ class VolatilityAnalyzer:
         if len(close_prices) > 0:
             if len(close_prices) >= 1:
-                price_change_1m = self.calculate_price_change(
-                    current_price, close_prices[-1]
-                )
+                price_change_1m = self.calculate_price_change(current_price, close_prices[-1])
             if len(close_prices) >= 5:
-                price_change_5m = self.calculate_price_change(
-                    current_price, close_prices[-5]
-                )
+                price_change_5m = self.calculate_price_change(current_price, close_prices[-5])
             if len(close_prices) >= 15:
-                price_change_15m = self.calculate_price_change(
-                    current_price, close_prices[-15]
-                )
+                price_change_15m = self.calculate_price_change(current_price, close_prices[-15])
         # Calculate volume surge
         avg_volume = sum(volumes) / len(volumes) if volumes else current_volume
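The momentum weighting collapsed onto one line above is unchanged in behaviour: recent price changes are weighted 0.4/0.3/0.2 and the volume surge contributes up to 10 points. A worked example with illustrative inputs (not taken from the diff):

```python
# Worked example of the momentum inputs (illustrative values).
price_change_1m, price_change_5m, price_change_15m = 2.0, 1.0, 0.5
volume_surge = 2.0
weighted_change = price_change_1m * 0.4 + price_change_5m * 0.3 + price_change_15m * 0.2
volume_contribution = min(10.0, (volume_surge - 1.0) * 5.0)
assert round(weighted_change, 2) == 1.2
assert volume_contribution == 5.0
```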

View File

@@ -7,9 +7,9 @@ This module provides:
 - Health monitoring and alerts
 """
-from src.backup.exporter import BackupExporter, ExportFormat
-from src.backup.scheduler import BackupScheduler, BackupPolicy
 from src.backup.cloud_storage import CloudStorage, S3Config
+from src.backup.exporter import BackupExporter, ExportFormat
+from src.backup.scheduler import BackupPolicy, BackupScheduler
 __all__ = [
     "BackupExporter",

View File

@@ -94,7 +94,9 @@ class CloudStorage:
         if metadata:
             extra_args["Metadata"] = metadata
-        logger.info("Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key)
+        logger.info(
+            "Uploading %s to s3://%s/%s", file_path.name, self.config.bucket_name, object_key
+        )
         try:
             self.client.upload_file(

View File

@@ -14,14 +14,14 @@ import json
 import logging
 import sqlite3
 from datetime import UTC, datetime
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Any
 logger = logging.getLogger(__name__)
-class ExportFormat(str, Enum):
+class ExportFormat(StrEnum):
     """Supported export formats."""
     JSON = "json"
@@ -103,15 +103,11 @@ class BackupExporter:
         elif fmt == ExportFormat.CSV:
             return self._export_csv(output_dir, timestamp, compress, incremental_since)
         elif fmt == ExportFormat.PARQUET:
-            return self._export_parquet(
-                output_dir, timestamp, compress, incremental_since
-            )
+            return self._export_parquet(output_dir, timestamp, compress, incremental_since)
         else:
             raise ValueError(f"Unsupported format: {fmt}")
-    def _get_trades(
-        self, incremental_since: datetime | None = None
-    ) -> list[dict[str, Any]]:
+    def _get_trades(self, incremental_since: datetime | None = None) -> list[dict[str, Any]]:
         """Fetch trades from database.
         Args:
@@ -164,9 +160,7 @@ class BackupExporter:
         data = {
             "export_timestamp": datetime.now(UTC).isoformat(),
-            "incremental_since": (
-                incremental_since.isoformat() if incremental_since else None
-            ),
+            "incremental_since": (incremental_since.isoformat() if incremental_since else None),
             "record_count": len(trades),
             "trades": trades,
         }
@@ -284,8 +278,7 @@ class BackupExporter:
             import pyarrow.parquet as pq
         except ImportError:
             raise ImportError(
-                "pyarrow is required for Parquet export. "
-                "Install with: pip install pyarrow"
+                "pyarrow is required for Parquet export. Install with: pip install pyarrow"
             )
         # Convert to pyarrow table

View File

@@ -14,14 +14,14 @@ import shutil
 import sqlite3
 from dataclasses import dataclass
 from datetime import UTC, datetime, timedelta
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Any
 logger = logging.getLogger(__name__)
-class HealthStatus(str, Enum):
+class HealthStatus(StrEnum):
     """Health check status."""
     HEALTHY = "healthy"
@@ -137,9 +137,13 @@ class HealthMonitor:
         used_percent = (stat.used / stat.total) * 100
         if stat.free < self.min_disk_space_bytes:
+            min_disk_gb = self.min_disk_space_bytes / 1024 / 1024 / 1024
             return HealthCheckResult(
                 status=HealthStatus.UNHEALTHY,
-                message=f"Low disk space: {free_gb:.2f} GB free (minimum: {self.min_disk_space_bytes / 1024 / 1024 / 1024:.2f} GB)",
+                message=(
+                    f"Low disk space: {free_gb:.2f} GB free "
+                    f"(minimum: {min_disk_gb:.2f} GB)"
+                ),
                 details={
                     "free_gb": free_gb,
                     "total_gb": total_gb,

View File

@@ -12,14 +12,14 @@ import logging
 import shutil
 from dataclasses import dataclass
 from datetime import UTC, datetime, timedelta
-from enum import Enum
+from enum import StrEnum
 from pathlib import Path
 from typing import Any
 logger = logging.getLogger(__name__)
-class BackupPolicy(str, Enum):
+class BackupPolicy(StrEnum):
     """Backup retention policies."""
     DAILY = "daily"
@@ -69,9 +69,7 @@ class BackupScheduler:
         for d in [self.daily_dir, self.weekly_dir, self.monthly_dir]:
             d.mkdir(parents=True, exist_ok=True)
-    def create_backup(
-        self, policy: BackupPolicy, verify: bool = True
-    ) -> BackupMetadata:
+    def create_backup(self, policy: BackupPolicy, verify: bool = True) -> BackupMetadata:
         """Create a database backup.
         Args:
@@ -229,9 +227,7 @@ class BackupScheduler:
         return removed
-    def list_backups(
-        self, policy: BackupPolicy | None = None
-    ) -> list[BackupMetadata]:
+    def list_backups(self, policy: BackupPolicy | None = None) -> list[BackupMetadata]:
         """List available backups.
         Args:

View File

@@ -13,8 +13,8 @@ import hashlib
 import json
 import logging
 import time
-from dataclasses import dataclass, field
-from typing import Any, TYPE_CHECKING
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 if TYPE_CHECKING:
     from src.brain.gemini_client import TradeDecision
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
 class CacheEntry:
     """Cached decision with metadata."""
-    decision: "TradeDecision"
+    decision: TradeDecision
     cached_at: float  # Unix timestamp
     hit_count: int = 0
     market_data_hash: str = ""
@@ -239,9 +239,7 @@ class DecisionCache:
         """
         current_time = time.time()
         expired_keys = [
-            k
-            for k, v in self._cache.items()
-            if current_time - v.cached_at > self.ttl_seconds
+            k for k, v in self._cache.items() if current_time - v.cached_at > self.ttl_seconds
         ]
         count = len(expired_keys)

View File

@@ -11,14 +11,14 @@ from __future__ import annotations
 from dataclasses import dataclass
 from datetime import UTC, datetime
-from enum import Enum
+from enum import StrEnum
 from typing import Any
 from src.context.layer import ContextLayer
 from src.context.store import ContextStore
-class DecisionType(str, Enum):
+class DecisionType(StrEnum):
     """Type of trading decision being made."""
     NORMAL = "normal"  # Regular trade decision
@@ -183,9 +183,7 @@ class ContextSelector:
             ContextLayer.L1_LEGACY,
         ]
-        scores = {
-            layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers
-        }
+        scores = {layer: self.score_layer_relevance(layer, decision_type) for layer in all_layers}
         # Filter by minimum score
         selected_layers = [layer for layer, score in scores.items() if score >= min_score]

View File

@@ -25,12 +25,12 @@ from typing import Any
 from google import genai
-from src.config import Settings
-from src.data.news_api import NewsAPI, NewsSentiment
-from src.data.economic_calendar import EconomicCalendar
-from src.data.market_data import MarketData
 from src.brain.cache import DecisionCache
 from src.brain.prompt_optimizer import PromptOptimizer
+from src.config import Settings
+from src.data.economic_calendar import EconomicCalendar
+from src.data.market_data import MarketData
+from src.data.news_api import NewsAPI, NewsSentiment
 logger = logging.getLogger(__name__)
@@ -159,16 +159,12 @@ class GeminiClient:
             return ""
         # Check for upcoming high-impact events
-        upcoming = self._economic_calendar.get_upcoming_events(
-            days_ahead=7, min_impact="HIGH"
-        )
+        upcoming = self._economic_calendar.get_upcoming_events(days_ahead=7, min_impact="HIGH")
         if upcoming.high_impact_count == 0:
             return ""
-        lines = [
-            f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"
-        ]
+        lines = [f"Upcoming High-Impact Events: {upcoming.high_impact_count} in next 7 days"]
         if upcoming.next_major_event is not None:
             event = upcoming.next_major_event
@@ -180,9 +176,7 @@ class GeminiClient:
         # Check for earnings
         earnings_date = self._economic_calendar.get_earnings_date(stock_code)
         if earnings_date is not None:
-            lines.append(
-                f"  Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}"
-            )
+            lines.append(f"  Earnings: {stock_code} on {earnings_date.strftime('%Y-%m-%d')}")
         return "\n".join(lines)
@@ -235,9 +229,7 @@ class GeminiClient:
         # Add foreigner net if non-zero
         if market_data.get("foreigner_net", 0) != 0:
-            market_info_lines.append(
-                f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
-            )
+            market_info_lines.append(f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}")
         market_info = "\n".join(market_info_lines)
@@ -249,8 +241,7 @@ class GeminiClient:
             market_info += f"\n\n{external_context}"
         json_format = (
-            '{"action": "BUY"|"SELL"|"HOLD", '
-            '"confidence": <int 0-100>, "rationale": "<string>"}'
+            '{"action": "BUY"|"SELL"|"HOLD", "confidence": <int 0-100>, "rationale": "<string>"}'
         )
         return (
             f"You are a professional {market_name} trading analyst.\n"
@@ -289,15 +280,12 @@ class GeminiClient:
         # Add foreigner net if non-zero
         if market_data.get("foreigner_net", 0) != 0:
-            market_info_lines.append(
-                f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}"
-            )
+            market_info_lines.append(f"Foreigner Net Buy/Sell: {market_data['foreigner_net']}")
         market_info = "\n".join(market_info_lines)
         json_format = (
-            '{"action": "BUY"|"SELL"|"HOLD", '
-            '"confidence": <int 0-100>, "rationale": "<string>"}'
+            '{"action": "BUY"|"SELL"|"HOLD", "confidence": <int 0-100>, "rationale": "<string>"}'
        )
         return (
             f"You are a professional {market_name} trading analyst.\n"
@@ -339,25 +327,19 @@ class GeminiClient:
             data = json.loads(cleaned)
         except json.JSONDecodeError:
             logger.warning("Malformed JSON from Gemini — defaulting to HOLD")
-            return TradeDecision(
-                action="HOLD", confidence=0, rationale="Malformed JSON response"
-            )
+            return TradeDecision(action="HOLD", confidence=0, rationale="Malformed JSON response")
         # Validate required fields
         if not all(k in data for k in ("action", "confidence", "rationale")):
             logger.warning("Missing fields in Gemini response — defaulting to HOLD")
             # Preserve raw text in rationale so prompt_override callers (e.g. pre_market_planner)
             # can extract their own JSON format from decision.rationale (#245)
-            return TradeDecision(
-                action="HOLD", confidence=0, rationale=raw
-            )
+            return TradeDecision(action="HOLD", confidence=0, rationale=raw)
         action = str(data["action"]).upper()
         if action not in VALID_ACTIONS:
             logger.warning("Invalid action '%s' from Gemini — defaulting to HOLD", action)
-            return TradeDecision(
-                action="HOLD", confidence=0, rationale=f"Invalid action: {action}"
-            )
+            return TradeDecision(action="HOLD", confidence=0, rationale=f"Invalid action: {action}")
         confidence = int(data["confidence"])
         rationale = str(data["rationale"])
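The parser above enforces a small JSON contract and falls back to a HOLD decision on malformed JSON, missing fields, or an invalid action. An illustrative example of a payload that would pass the checks shown in the hunk (field values are made up; the actual `VALID_ACTIONS` definition is not visible in this diff and is assumed to be the three actions named in the prompt):

```python
# Illustrative example of the decision JSON contract (values are made up).
import json

raw = '{"action": "BUY", "confidence": 72, "rationale": "breakout with volume surge"}'
data = json.loads(raw)
assert all(k in data for k in ("action", "confidence", "rationale"))
assert str(data["action"]).upper() in {"BUY", "SELL", "HOLD"}
```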
@@ -445,9 +427,7 @@ class GeminiClient:
             # not a parsed TradeDecision. Skip parse_response to avoid spurious
             # "Missing fields" warnings and return the raw response directly. (#247)
             if "prompt_override" in market_data:
-                logger.info(
-                    "Gemini raw response received (prompt_override, tokens=%d)", token_count
-                )
+                logger.info("Gemini raw response received (prompt_override, tokens=%d)", token_count)
                 # Not a trade decision — don't inflate _total_decisions metrics
                 return TradeDecision(
                     action="HOLD", confidence=0, rationale=raw, token_count=token_count
@@ -546,9 +526,7 @@ class GeminiClient:
     # Batch Decision Making (for daily trading mode)
     # ------------------------------------------------------------------
-    async def decide_batch(
-        self, stocks_data: list[dict[str, Any]]
-    ) -> dict[str, TradeDecision]:
+    async def decide_batch(self, stocks_data: list[dict[str, Any]]) -> dict[str, TradeDecision]:
         """Make decisions for multiple stocks in a single API call.
         This is designed for daily trading mode to minimize API usage
View File

@@ -179,7 +179,8 @@ class PromptOptimizer:
             # Minimal instructions
             prompt = (
                 f"{market_name} trader. Analyze:\n{data_str}\n\n"
-                'Return JSON: {"action":"BUY"|"SELL"|"HOLD","confidence":<0-100>,"rationale":"<text>"}\n'
+                "Return JSON: "
+                '{"action":"BUY"|"SELL"|"HOLD","confidence":<0-100>,"rationale":"<text>"}\n'
                 "Rules: action=BUY/SELL/HOLD, confidence=0-100, rationale=concise. No markdown."
             )
         else:

View File

@@ -58,7 +58,7 @@ class LeakyBucket:
     def __init__(self, rate: float) -> None:
         """Args:
             rate: Maximum requests per second.
         """
         self._rate = rate
         self._interval = 1.0 / rate
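The docstring touch-up above does not change behaviour: the limiter spaces requests evenly, so a rate of 5 requests per second means one permit every 0.2 s. A tiny sketch of the interval arithmetic; the async `acquire()` awaited elsewhere in this diff is assumed, not reproduced:

```python
# Interval arithmetic behind the leaky bucket (illustrative rate).
rate = 5.0             # maximum requests per second
interval = 1.0 / rate  # 0.2 s between permits
assert interval == 0.2
# usage seen elsewhere in this diff: await self._broker._rate_limiter.acquire()
```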
@@ -103,7 +103,8 @@ class KISBroker:
             ssl_ctx.verify_mode = ssl.CERT_NONE
             connector = aiohttp.TCPConnector(ssl=ssl_ctx)
             self._session = aiohttp.ClientSession(
-                timeout=timeout, connector=connector,
+                timeout=timeout,
+                connector=connector,
             )
         return self._session
@@ -224,16 +225,12 @@ class KISBroker:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"get_orderbook failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"get_orderbook failed ({resp.status}): {text}")
                 return await resp.json()
         except (TimeoutError, aiohttp.ClientError) as exc:
             raise ConnectionError(f"Network error fetching orderbook: {exc}") from exc
-    async def get_current_price(
-        self, stock_code: str
-    ) -> tuple[float, float, float]:
+    async def get_current_price(self, stock_code: str) -> tuple[float, float, float]:
         """Fetch current price data for a domestic stock.
         Uses the ``inquire-price`` API (FHKST01010100), which works in both
@@ -265,9 +262,7 @@ class KISBroker:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"get_current_price failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"get_current_price failed ({resp.status}): {text}")
                 data = await resp.json()
                 out = data.get("output", {})
                 return (
@@ -276,9 +271,7 @@ class KISBroker:
                     _f(out.get("frgn_ntby_qty")),
                 )
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching current price: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching current price: {exc}") from exc
     async def get_balance(self) -> dict[str, Any]:
         """Fetch current account balance and holdings."""
@@ -308,9 +301,7 @@ class KISBroker:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"get_balance failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"get_balance failed ({resp.status}): {text}")
                 return await resp.json()
         except (TimeoutError, aiohttp.ClientError) as exc:
             raise ConnectionError(f"Network error fetching balance: {exc}") from exc
@@ -369,9 +360,7 @@ class KISBroker:
             async with session.post(url, headers=headers, json=body) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"send_order failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"send_order failed ({resp.status}): {text}")
                 data = await resp.json()
                 logger.info(
                     "Order submitted",
@@ -449,9 +438,7 @@ class KISBroker:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"fetch_market_rankings failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"fetch_market_rankings failed ({resp.status}): {text}")
                 data = await resp.json()
                 # Parse response - output is a list of ranked stocks
@@ -465,14 +452,16 @@ class KISBroker:
                 rankings = []
                 for item in data.get("output", [])[:limit]:
-                    rankings.append({
-                        "stock_code": item.get("stck_shrn_iscd") or item.get("mksc_shrn_iscd", ""),
-                        "name": item.get("hts_kor_isnm", ""),
-                        "price": _safe_float(item.get("stck_prpr", "0")),
-                        "volume": _safe_float(item.get("acml_vol", "0")),
-                        "change_rate": _safe_float(item.get("prdy_ctrt", "0")),
-                        "volume_increase_rate": _safe_float(item.get("vol_inrt", "0")),
-                    })
+                    rankings.append(
+                        {
+                            "stock_code": item.get("stck_shrn_iscd") or item.get("mksc_shrn_iscd", ""),
+                            "name": item.get("hts_kor_isnm", ""),
+                            "price": _safe_float(item.get("stck_prpr", "0")),
+                            "volume": _safe_float(item.get("acml_vol", "0")),
+                            "change_rate": _safe_float(item.get("prdy_ctrt", "0")),
+                            "volume_increase_rate": _safe_float(item.get("vol_inrt", "0")),
+                        }
+                    )
                 return rankings
         except (TimeoutError, aiohttp.ClientError) as exc:
@@ -522,9 +511,7 @@ class KISBroker:
                 data = await resp.json()
                 return data.get("output", []) or []
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching domestic pending orders: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching domestic pending orders: {exc}") from exc
     async def cancel_domestic_order(
         self,
@@ -575,14 +562,10 @@ class KISBroker:
             async with session.post(url, headers=headers, json=body) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"cancel_domestic_order failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"cancel_domestic_order failed ({resp.status}): {text}")
                 return cast(dict[str, Any], await resp.json())
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error cancelling domestic order: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error cancelling domestic order: {exc}") from exc
     async def get_daily_prices(
         self,
@@ -609,6 +592,7 @@ class KISBroker:
         # Calculate date range (today and N days ago)
         from datetime import datetime, timedelta
         end_date = datetime.now().strftime("%Y%m%d")
         start_date = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d")
@@ -627,9 +611,7 @@ class KISBroker:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"get_daily_prices failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"get_daily_prices failed ({resp.status}): {text}")
                 data = await resp.json()
                 # Parse response
@@ -643,14 +625,16 @@ class KISBroker:
                 prices = []
                 for item in data.get("output2", []):
-                    prices.append({
-                        "date": item.get("stck_bsop_date", ""),
-                        "open": _safe_float(item.get("stck_oprc", "0")),
-                        "high": _safe_float(item.get("stck_hgpr", "0")),
-                        "low": _safe_float(item.get("stck_lwpr", "0")),
-                        "close": _safe_float(item.get("stck_clpr", "0")),
-                        "volume": _safe_float(item.get("acml_vol", "0")),
-                    })
+                    prices.append(
+                        {
+                            "date": item.get("stck_bsop_date", ""),
+                            "open": _safe_float(item.get("stck_oprc", "0")),
+                            "high": _safe_float(item.get("stck_hgpr", "0")),
+                            "low": _safe_float(item.get("stck_lwpr", "0")),
+                            "close": _safe_float(item.get("stck_clpr", "0")),
+                            "volume": _safe_float(item.get("acml_vol", "0")),
+                        }
+                    )
                 # Sort oldest to newest (KIS returns newest first)
                 prices.reverse()

View File

@@ -36,11 +36,11 @@ _CANCEL_TR_ID_MAP: dict[str, tuple[str, str]] = {
     "NYSE": ("TTTT1004U", "VTTT1004U"),
     "AMEX": ("TTTT1004U", "VTTT1004U"),
     "SEHK": ("TTTS1003U", "VTTS1003U"),
     "TSE": ("TTTS0309U", "VTTS0309U"),
     "SHAA": ("TTTS0302U", "VTTS0302U"),
     "SZAA": ("TTTS0306U", "VTTS0306U"),
     "HNX": ("TTTS0312U", "VTTS0312U"),
     "HSX": ("TTTS0312U", "VTTS0312U"),
 }
@@ -56,9 +56,7 @@ class OverseasBroker:
         """
         self._broker = kis_broker
-    async def get_overseas_price(
-        self, exchange_code: str, stock_code: str
-    ) -> dict[str, Any]:
+    async def get_overseas_price(self, exchange_code: str, stock_code: str) -> dict[str, Any]:
         """
         Fetch overseas stock price.
@@ -89,14 +87,10 @@ class OverseasBroker:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"get_overseas_price failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"get_overseas_price failed ({resp.status}): {text}")
                 return await resp.json()
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching overseas price: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching overseas price: {exc}") from exc
     async def fetch_overseas_rankings(
         self,
@@ -154,9 +148,7 @@ class OverseasBroker:
                             ranking_type,
                         )
                         return []
-                    raise ConnectionError(
-                        f"fetch_overseas_rankings failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"fetch_overseas_rankings failed ({resp.status}): {text}")
                 data = await resp.json()
                 rows = self._extract_ranking_rows(data)
@@ -171,9 +163,7 @@ class OverseasBroker:
                 )
                 return []
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching overseas rankings: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching overseas rankings: {exc}") from exc
     async def get_overseas_balance(self, exchange_code: str) -> dict[str, Any]:
         """
@@ -193,9 +183,7 @@ class OverseasBroker:
         # TR_ID: live TTTS3012R, paper VTTS3012R
         # Source: Korea Investment & Securities Open API full reference (20260221) — 'overseas stock balance inquiry' sheet
-        balance_tr_id = (
-            "TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R"
-        )
+        balance_tr_id = "TTTS3012R" if self._broker._settings.MODE == "live" else "VTTS3012R"
         headers = await self._broker._auth_headers(balance_tr_id)
         params = {
             "CANO": self._broker._account_no,
@@ -205,22 +193,16 @@ class OverseasBroker:
             "CTX_AREA_FK200": "",
             "CTX_AREA_NK200": "",
         }
-        url = (
-            f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-balance"
-        )
+        url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-balance"
         try:
             async with session.get(url, headers=headers, params=params) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"get_overseas_balance failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"get_overseas_balance failed ({resp.status}): {text}")
                 return await resp.json()
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching overseas balance: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching overseas balance: {exc}") from exc
     async def get_overseas_buying_power(
         self,
@@ -247,9 +229,7 @@ class OverseasBroker:
         # TR_ID: live TTTS3007R, paper VTTS3007R
        # Source: Korea Investment & Securities Open API full reference (20260221) — 'overseas stock buying power inquiry' sheet
-        ps_tr_id = (
-            "TTTS3007R" if self._broker._settings.MODE == "live" else "VTTS3007R"
-        )
+        ps_tr_id = "TTTS3007R" if self._broker._settings.MODE == "live" else "VTTS3007R"
         headers = await self._broker._auth_headers(ps_tr_id)
         params = {
             "CANO": self._broker._account_no,
@@ -258,9 +238,7 @@ class OverseasBroker:
             "OVRS_ORD_UNPR": f"{price:.2f}",
             "ITEM_CD": stock_code,
         }
-        url = (
-            f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-psamount"
-        )
+        url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-psamount"
         try:
             async with session.get(url, headers=headers, params=params) as resp:
@@ -271,9 +249,7 @@ class OverseasBroker:
                 )
                 return await resp.json()
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching overseas buying power: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching overseas buying power: {exc}") from exc
     async def send_overseas_order(
         self,
@@ -330,9 +306,7 @@ class OverseasBroker:
             async with session.post(url, headers=headers, json=body) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"send_overseas_order failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"send_overseas_order failed ({resp.status}): {text}")
                 data = await resp.json()
                 rt_cd = data.get("rt_cd", "")
                 msg1 = data.get("msg1", "")
@@ -357,13 +331,9 @@ class OverseasBroker:
                 )
                 return data
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error sending overseas order: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error sending overseas order: {exc}") from exc
-    async def get_overseas_pending_orders(
-        self, exchange_code: str
-    ) -> list[dict[str, Any]]:
+    async def get_overseas_pending_orders(self, exchange_code: str) -> list[dict[str, Any]]:
         """Fetch unfilled (pending) overseas orders for a given exchange.
         Args:
@@ -379,9 +349,7 @@ class OverseasBroker:
             ConnectionError: On network or API errors (live mode only).
         """
         if self._broker._settings.MODE != "live":
-            logger.debug(
-                "Pending orders API (TTTS3018R) not supported in paper mode; returning []"
-            )
+            logger.debug("Pending orders API (TTTS3018R) not supported in paper mode; returning []")
             return []
         await self._broker._rate_limiter.acquire()
@@ -398,9 +366,7 @@ class OverseasBroker:
             "CTX_AREA_FK200": "",
             "CTX_AREA_NK200": "",
         }
-        url = (
-            f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs"
-        )
+        url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/inquire-nccs"
         try:
             async with session.get(url, headers=headers, params=params) as resp:
@@ -415,9 +381,7 @@ class OverseasBroker:
                         return output
                 return []
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error fetching pending orders: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error fetching pending orders: {exc}") from exc
     async def cancel_overseas_order(
         self,
@@ -469,22 +433,16 @@ class OverseasBroker:
         headers = await self._broker._auth_headers(tr_id)
         headers["hashkey"] = hash_key
-        url = (
-            f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl"
-        )
+        url = f"{self._broker._base_url}/uapi/overseas-stock/v1/trading/order-rvsecncl"
         try:
             async with session.post(url, headers=headers, json=body) as resp:
                 if resp.status != 200:
                     text = await resp.text()
-                    raise ConnectionError(
-                        f"cancel_overseas_order failed ({resp.status}): {text}"
-                    )
+                    raise ConnectionError(f"cancel_overseas_order failed ({resp.status}): {text}")
                 return await resp.json()
         except (TimeoutError, aiohttp.ClientError) as exc:
-            raise ConnectionError(
-                f"Network error cancelling overseas order: {exc}"
-            ) from exc
+            raise ConnectionError(f"Network error cancelling overseas order: {exc}") from exc
     def _get_currency_code(self, exchange_code: str) -> str:
         """

View File

@@ -111,25 +111,21 @@ class Settings(BaseSettings):
     # Telegram notification type filters (granular control)
     # circuit_breaker is always sent regardless — safety-critical
     TELEGRAM_NOTIFY_TRADES: bool = True  # BUY/SELL execution alerts
     TELEGRAM_NOTIFY_MARKET_OPEN_CLOSE: bool = True  # Market open/close alerts
     TELEGRAM_NOTIFY_FAT_FINGER: bool = True  # Fat-finger rejection alerts
     TELEGRAM_NOTIFY_SYSTEM_EVENTS: bool = True  # System start/shutdown alerts
     TELEGRAM_NOTIFY_PLAYBOOK: bool = True  # Playbook generated/failed alerts
     TELEGRAM_NOTIFY_SCENARIO_MATCH: bool = True  # Scenario matched alerts (most frequent)
     TELEGRAM_NOTIFY_ERRORS: bool = True  # Error alerts
     # Overseas ranking API (KIS endpoint/TR_ID may vary by account/product)
     # Override these from .env if your account uses different specs.
     OVERSEAS_RANKING_ENABLED: bool = True
     OVERSEAS_RANKING_FLUCT_TR_ID: str = "HHDFS76290000"
     OVERSEAS_RANKING_VOLUME_TR_ID: str = "HHDFS76270000"
-    OVERSEAS_RANKING_FLUCT_PATH: str = (
-        "/uapi/overseas-stock/v1/ranking/updown-rate"
-    )
-    OVERSEAS_RANKING_VOLUME_PATH: str = (
-        "/uapi/overseas-stock/v1/ranking/volume-surge"
-    )
+    OVERSEAS_RANKING_FLUCT_PATH: str = "/uapi/overseas-stock/v1/ranking/updown-rate"
+    OVERSEAS_RANKING_VOLUME_PATH: str = "/uapi/overseas-stock/v1/ranking/volume-surge"
     # Dashboard (optional)
     DASHBOARD_ENABLED: bool = False

View File

@@ -222,9 +222,7 @@ class ContextAggregator:
         total_pnl = 0.0
         for month in months:
-            monthly_pnl = self.store.get_context(
-                ContextLayer.L4_MONTHLY, month, "monthly_pnl"
-            )
+            monthly_pnl = self.store.get_context(ContextLayer.L4_MONTHLY, month, "monthly_pnl")
             if monthly_pnl is not None:
                 total_pnl += monthly_pnl
@@ -251,9 +249,7 @@ class ContextAggregator:
             if quarterly_pnl is not None:
                 total_pnl += quarterly_pnl
-        self.store.set_context(
-            ContextLayer.L2_ANNUAL, year, "annual_pnl", round(total_pnl, 2)
-        )
+        self.store.set_context(ContextLayer.L2_ANNUAL, year, "annual_pnl", round(total_pnl, 2))
     def aggregate_legacy_from_annual(self) -> None:
         """Aggregate L1 (legacy) context from all L2 (annual) data."""
@@ -280,9 +276,7 @@ class ContextAggregator:
         self.store.set_context(
             ContextLayer.L1_LEGACY, "LEGACY", "total_pnl", round(total_pnl, 2)
         )
-        self.store.set_context(
-            ContextLayer.L1_LEGACY, "LEGACY", "years_traded", years_traded
-        )
+        self.store.set_context(ContextLayer.L1_LEGACY, "LEGACY", "years_traded", years_traded)
         self.store.set_context(
             ContextLayer.L1_LEGACY,
             "LEGACY",

View File

@@ -3,10 +3,10 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from enum import Enum
+from enum import StrEnum
-class ContextLayer(str, Enum):
+class ContextLayer(StrEnum):
     """7-tier context hierarchy from real-time to generational."""
     L1_LEGACY = "L1_LEGACY"  # Cumulative/generational wisdom

View File

@@ -9,7 +9,7 @@ This module summarizes old context data instead of including raw details:
 from __future__ import annotations
 from dataclasses import dataclass
-from datetime import UTC, datetime, timedelta
+from datetime import UTC, datetime
 from typing import Any
 from src.context.layer import ContextLayer

View File

@@ -11,8 +11,9 @@ Order is fixed:
 from __future__ import annotations
 import inspect
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
-from typing import Any, Awaitable, Callable
+from typing import Any
 StepCallable = Callable[[], Any | Awaitable[Any]]

View File

@@ -15,7 +15,7 @@ from src.markets.schedule import MarketInfo
_LOW_LIQUIDITY_SESSIONS = {"NXT_AFTER", "US_PRE", "US_DAY", "US_AFTER"} _LOW_LIQUIDITY_SESSIONS = {"NXT_AFTER", "US_PRE", "US_DAY", "US_AFTER"}
class OrderPolicyRejected(Exception): class OrderPolicyRejectedError(Exception):
"""Raised when an order violates session policy.""" """Raised when an order violates session policy."""
def __init__(self, message: str, *, session_id: str, market_code: str) -> None: def __init__(self, message: str, *, session_id: str, market_code: str) -> None:
@@ -61,7 +61,9 @@ def classify_session_id(market: MarketInfo, now: datetime | None = None) -> str:
def get_session_info(market: MarketInfo, now: datetime | None = None) -> SessionInfo: def get_session_info(market: MarketInfo, now: datetime | None = None) -> SessionInfo:
session_id = classify_session_id(market, now) session_id = classify_session_id(market, now)
return SessionInfo(session_id=session_id, is_low_liquidity=session_id in _LOW_LIQUIDITY_SESSIONS) return SessionInfo(
session_id=session_id, is_low_liquidity=session_id in _LOW_LIQUIDITY_SESSIONS
)
def validate_order_policy( def validate_order_policy(
@@ -76,7 +78,7 @@ def validate_order_policy(
is_market_order = price <= 0 is_market_order = price <= 0
if info.is_low_liquidity and is_market_order: if info.is_low_liquidity and is_market_order:
raise OrderPolicyRejected( raise OrderPolicyRejectedError(
f"Market order is forbidden in low-liquidity session ({info.session_id})", f"Market order is forbidden in low-liquidity session ({info.session_id})",
session_id=info.session_id, session_id=info.session_id,
market_code=market.code, market_code=market.code,
@@ -84,10 +86,14 @@ def validate_order_policy(
# Guard against accidental unsupported actions. # Guard against accidental unsupported actions.
if order_type not in {"BUY", "SELL"}: if order_type not in {"BUY", "SELL"}:
raise OrderPolicyRejected( raise OrderPolicyRejectedError(
f"Unsupported order_type={order_type}", f"Unsupported order_type={order_type}",
session_id=info.session_id, session_id=info.session_id,
market_code=market.code, market_code=market.code,
) )
return info return info
# Backward compatibility alias
OrderPolicyRejected = OrderPolicyRejectedError
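The rename to the `*Error` suffix stays source-compatible because the old name is re-bound to the same class, so call sites that still catch `OrderPolicyRejected` keep working. A small check (values are illustrative):

try:
    raise OrderPolicyRejectedError(
        "Market order is forbidden in low-liquidity session (US_PRE)",
        session_id="US_PRE",
        market_code="US_NASDAQ",
    )
except OrderPolicyRejected:            # old name still catches the new exception
    pass
assert OrderPolicyRejected is OrderPolicyRejectedError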

View File

@@ -28,9 +28,7 @@ class PriorityTask:
# Task data not used in comparison # Task data not used in comparison
task_id: str = field(compare=False) task_id: str = field(compare=False)
task_data: dict[str, Any] = field(compare=False, default_factory=dict) task_data: dict[str, Any] = field(compare=False, default_factory=dict)
callback: Callable[[], Coroutine[Any, Any, Any]] | None = field( callback: Callable[[], Coroutine[Any, Any, Any]] | None = field(compare=False, default=None)
compare=False, default=None
)
@dataclass @dataclass
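`field(compare=False)` keeps the payload out of the generated ordering methods, so only the priority fields decide heap order. An illustrative stand-in (the real PriorityTask carries more fields):

import heapq
from dataclasses import dataclass, field
from typing import Any

@dataclass(order=True)
class _Item:
    priority: int                                          # compared
    task_id: str = field(compare=False)                    # excluded from ordering
    task_data: dict[str, Any] = field(compare=False, default_factory=dict)

heap: list[_Item] = []
heapq.heappush(heap, _Item(2, "late"))
heapq.heappush(heap, _Item(1, "early"))
assert heapq.heappop(heap).task_id == "early"              # lowest priority value pops first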

View File

@@ -25,7 +25,7 @@ class CircuitBreakerTripped(SystemExit):
) )
class FatFingerRejected(Exception): class FatFingerRejectedError(Exception):
"""Raised when an order exceeds the maximum allowed proportion of cash.""" """Raised when an order exceeds the maximum allowed proportion of cash."""
def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None: def __init__(self, order_amount: float, total_cash: float, max_pct: float) -> None:
@@ -61,7 +61,7 @@ class RiskManager:
def check_fat_finger(self, order_amount: float, total_cash: float) -> None: def check_fat_finger(self, order_amount: float, total_cash: float) -> None:
"""Reject orders that exceed the maximum proportion of available cash.""" """Reject orders that exceed the maximum proportion of available cash."""
if total_cash <= 0: if total_cash <= 0:
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct) raise FatFingerRejectedError(order_amount, total_cash, self._ff_max_pct)
ratio_pct = (order_amount / total_cash) * 100 ratio_pct = (order_amount / total_cash) * 100
if ratio_pct > self._ff_max_pct: if ratio_pct > self._ff_max_pct:
@@ -69,7 +69,7 @@ class RiskManager:
"Fat finger check failed", "Fat finger check failed",
extra={"order_amount": order_amount}, extra={"order_amount": order_amount},
) )
raise FatFingerRejected(order_amount, total_cash, self._ff_max_pct) raise FatFingerRejectedError(order_amount, total_cash, self._ff_max_pct)
def validate_order( def validate_order(
self, self,
@@ -81,3 +81,7 @@ class RiskManager:
self.check_circuit_breaker(current_pnl_pct) self.check_circuit_breaker(current_pnl_pct)
self.check_fat_finger(order_amount, total_cash) self.check_fat_finger(order_amount, total_cash)
logger.info("Order passed risk validation") logger.info("Order passed risk validation")
# Backward compatibility alias
FatFingerRejected = FatFingerRejectedError
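The fat-finger guard is a straight percentage-of-cash test; a worked example with an assumed 20% cap:

order_amount, total_cash, max_pct = 2_500.0, 10_000.0, 20.0   # max_pct is illustrative
ratio_pct = (order_amount / total_cash) * 100                 # 25.0
assert ratio_pct > max_pct                                    # this order would raise FatFingerRejectedError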

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import json import json
import os import os
import sqlite3 import sqlite3
from datetime import UTC, datetime, timezone from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@@ -188,10 +188,7 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI:
return { return {
"market": "all", "market": "all",
"combined": combined, "combined": combined,
"by_market": [ "by_market": [_row_to_performance(row) for row in by_market_rows],
_row_to_performance(row)
for row in by_market_rows
],
} }
row = conn.execute( row = conn.execute(
@@ -401,7 +398,7 @@ def create_dashboard_app(db_path: str, mode: str = "paper") -> FastAPI:
""" """
).fetchall() ).fetchall()
now = datetime.now(timezone.utc) now = datetime.now(UTC)
positions = [] positions = []
for row in rows: for row in rows:
entry_time_str = row["entry_time"] entry_time_str = row["entry_time"]

View File

@@ -9,7 +9,6 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@@ -123,8 +123,7 @@ def init_db(db_path: str) -> sqlite3.Connection:
""" """
) )
decision_columns = { decision_columns = {
row[1] row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()
for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()
} }
if "session_id" not in decision_columns: if "session_id" not in decision_columns:
conn.execute("ALTER TABLE decision_logs ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'") conn.execute("ALTER TABLE decision_logs ADD COLUMN session_id TEXT DEFAULT 'UNKNOWN'")
@@ -185,9 +184,7 @@ def init_db(db_path: str) -> sqlite3.Connection:
conn.execute( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_decision_logs_timestamp ON decision_logs(timestamp)" "CREATE INDEX IF NOT EXISTS idx_decision_logs_timestamp ON decision_logs(timestamp)"
) )
conn.execute( conn.execute("CREATE INDEX IF NOT EXISTS idx_decision_logs_reviewed ON decision_logs(reviewed)")
"CREATE INDEX IF NOT EXISTS idx_decision_logs_reviewed ON decision_logs(reviewed)"
)
conn.execute( conn.execute(
"CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)" "CREATE INDEX IF NOT EXISTS idx_decision_logs_confidence ON decision_logs(confidence)"
) )
@@ -381,9 +378,7 @@ def get_open_position(
return {"decision_id": row[1], "price": row[2], "quantity": row[3], "timestamp": row[4]} return {"decision_id": row[1], "price": row[2], "quantity": row[3], "timestamp": row[4]}
def get_recent_symbols( def get_recent_symbols(conn: sqlite3.Connection, market: str, limit: int = 30) -> list[str]:
conn: sqlite3.Connection, market: str, limit: int = 30
) -> list[str]:
"""Return recent unique symbols for a market, newest first.""" """Return recent unique symbols for a market, newest first."""
cursor = conn.execute( cursor = conn.execute(
""" """

View File

@@ -90,9 +90,7 @@ class ABTester:
sharpe_ratio = None sharpe_ratio = None
if len(pnls) > 1: if len(pnls) > 1:
mean_return = avg_pnl mean_return = avg_pnl
std_return = ( std_return = (sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)) ** 0.5
sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)
) ** 0.5
if std_return > 0: if std_return > 0:
sharpe_ratio = mean_return / std_return sharpe_ratio = mean_return / std_return
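The standard deviation above uses the n-1 (sample/Bessel) denominator, and the resulting ratio is per-trade rather than annualised. Worked on illustrative P&L values:

pnls = [1.0, -0.5, 2.0, 0.5]                     # illustrative per-trade P&L
mean_return = sum(pnls) / len(pnls)              # 0.75
std_return = (sum((p - mean_return) ** 2 for p in pnls) / (len(pnls) - 1)) ** 0.5   # ≈ 1.041
sharpe_ratio = mean_return / std_return if std_return > 0 else None                 # ≈ 0.72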
@@ -198,8 +196,7 @@ class ABTester:
if meets_criteria: if meets_criteria:
logger.info( logger.info(
"Strategy '%s' meets deployment criteria: " "Strategy '%s' meets deployment criteria: win_rate=%.2f%%, trades=%d, avg_pnl=%.2f",
"win_rate=%.2f%%, trades=%d, avg_pnl=%.2f",
result.winner, result.winner,
winning_perf.win_rate, winning_perf.win_rate,
winning_perf.total_trades, winning_perf.total_trades,

View File

@@ -60,9 +60,7 @@ class DailyReviewer:
if isinstance(scenario_match, dict) and scenario_match: if isinstance(scenario_match, dict) and scenario_match:
matched += 1 matched += 1
scenario_match_rate = ( scenario_match_rate = (
round((matched / total_decisions) * 100, 2) round((matched / total_decisions) * 100, 2) if total_decisions else 0.0
if total_decisions
else 0.0
) )
trade_stats = self._conn.execute( trade_stats = self._conn.execute(

View File

@@ -80,26 +80,26 @@ class EvolutionOptimizer:
# Convert to dict format for analysis # Convert to dict format for analysis
failures = [] failures = []
for decision in losing_decisions: for decision in losing_decisions:
failures.append({ failures.append(
"decision_id": decision.decision_id, {
"timestamp": decision.timestamp, "decision_id": decision.decision_id,
"stock_code": decision.stock_code, "timestamp": decision.timestamp,
"market": decision.market, "stock_code": decision.stock_code,
"exchange_code": decision.exchange_code, "market": decision.market,
"action": decision.action, "exchange_code": decision.exchange_code,
"confidence": decision.confidence, "action": decision.action,
"rationale": decision.rationale, "confidence": decision.confidence,
"outcome_pnl": decision.outcome_pnl, "rationale": decision.rationale,
"outcome_accuracy": decision.outcome_accuracy, "outcome_pnl": decision.outcome_pnl,
"context_snapshot": decision.context_snapshot, "outcome_accuracy": decision.outcome_accuracy,
"input_data": decision.input_data, "context_snapshot": decision.context_snapshot,
}) "input_data": decision.input_data,
}
)
return failures return failures
def identify_failure_patterns( def identify_failure_patterns(self, failures: list[dict[str, Any]]) -> dict[str, Any]:
self, failures: list[dict[str, Any]]
) -> dict[str, Any]:
"""Identify patterns in losing decisions. """Identify patterns in losing decisions.
Analyzes: Analyzes:
@@ -143,12 +143,8 @@ class EvolutionOptimizer:
total_confidence += failure.get("confidence", 0) total_confidence += failure.get("confidence", 0)
total_loss += failure.get("outcome_pnl", 0.0) total_loss += failure.get("outcome_pnl", 0.0)
patterns["avg_confidence"] = ( patterns["avg_confidence"] = round(total_confidence / len(failures), 2) if failures else 0.0
round(total_confidence / len(failures), 2) if failures else 0.0 patterns["avg_loss"] = round(total_loss / len(failures), 2) if failures else 0.0
)
patterns["avg_loss"] = (
round(total_loss / len(failures), 2) if failures else 0.0
)
# Convert Counters to regular dicts for JSON serialization # Convert Counters to regular dicts for JSON serialization
patterns["markets"] = dict(patterns["markets"]) patterns["markets"] = dict(patterns["markets"])
@@ -197,7 +193,8 @@ class EvolutionOptimizer:
prompt = ( prompt = (
"You are a quantitative trading strategy developer.\n" "You are a quantitative trading strategy developer.\n"
"Analyze these failed trades and their patterns, then generate an improved strategy.\n\n" "Analyze these failed trades and their patterns, "
"then generate an improved strategy.\n\n"
f"Failure Patterns:\n{json.dumps(patterns, indent=2)}\n\n" f"Failure Patterns:\n{json.dumps(patterns, indent=2)}\n\n"
f"Sample Failed Trades (first 5):\n" f"Sample Failed Trades (first 5):\n"
f"{json.dumps(failures[:5], indent=2, default=str)}\n\n" f"{json.dumps(failures[:5], indent=2, default=str)}\n\n"
@@ -214,7 +211,8 @@ class EvolutionOptimizer:
try: try:
response = await self._client.aio.models.generate_content( response = await self._client.aio.models.generate_content(
model=self._model_name, contents=prompt, model=self._model_name,
contents=prompt,
) )
body = response.text.strip() body = response.text.strip()
except Exception as exc: except Exception as exc:
@@ -280,9 +278,7 @@ class EvolutionOptimizer:
logger.info("Strategy validation PASSED") logger.info("Strategy validation PASSED")
return True return True
else: else:
logger.warning( logger.warning("Strategy validation FAILED:\n%s", result.stdout + result.stderr)
"Strategy validation FAILED:\n%s", result.stdout + result.stderr
)
# Clean up failing strategy # Clean up failing strategy
strategy_path.unlink(missing_ok=True) strategy_path.unlink(missing_ok=True)
return False return False

View File

@@ -187,9 +187,7 @@ class PerformanceTracker:
return metrics return metrics
def calculate_improvement_trend( def calculate_improvement_trend(self, metrics_history: list[StrategyMetrics]) -> dict[str, Any]:
self, metrics_history: list[StrategyMetrics]
) -> dict[str, Any]:
"""Calculate improvement trend from historical metrics. """Calculate improvement trend from historical metrics.
Args: Args:
@@ -229,9 +227,7 @@ class PerformanceTracker:
"period_count": len(metrics_history), "period_count": len(metrics_history),
} }
def generate_dashboard( def generate_dashboard(self, strategy_name: str | None = None) -> PerformanceDashboard:
self, strategy_name: str | None = None
) -> PerformanceDashboard:
"""Generate a comprehensive performance dashboard. """Generate a comprehensive performance dashboard.
Args: Args:
@@ -260,9 +256,7 @@ class PerformanceTracker:
improvement_trend=improvement_trend, improvement_trend=improvement_trend,
) )
def export_dashboard_json( def export_dashboard_json(self, dashboard: PerformanceDashboard) -> str:
self, dashboard: PerformanceDashboard
) -> str:
"""Export dashboard as JSON string. """Export dashboard as JSON string.
Args: Args:

View File

@@ -140,9 +140,7 @@ class DecisionLogger:
) )
self.conn.commit() self.conn.commit()
def update_outcome( def update_outcome(self, decision_id: str, pnl: float, accuracy: int) -> None:
self, decision_id: str, pnl: float, accuracy: int
) -> None:
"""Update the outcome of a decision after trade execution. """Update the outcome of a decision after trade execution.
Args: Args:

View File

@@ -26,12 +26,12 @@ from src.context.aggregator import ContextAggregator
from src.context.layer import ContextLayer from src.context.layer import ContextLayer
from src.context.scheduler import ContextScheduler from src.context.scheduler import ContextScheduler
from src.context.store import ContextStore from src.context.store import ContextStore
from src.core.criticality import CriticalityAssessor
from src.core.blackout_manager import ( from src.core.blackout_manager import (
BlackoutOrderManager, BlackoutOrderManager,
QueuedOrderIntent, QueuedOrderIntent,
parse_blackout_windows_kst, parse_blackout_windows_kst,
) )
from src.core.criticality import CriticalityAssessor
from src.core.kill_switch import KillSwitchOrchestrator from src.core.kill_switch import KillSwitchOrchestrator
from src.core.order_policy import ( from src.core.order_policy import (
OrderPolicyRejected, OrderPolicyRejected,
@@ -52,12 +52,16 @@ from src.evolution.optimizer import EvolutionOptimizer
from src.logging.decision_logger import DecisionLogger from src.logging.decision_logger import DecisionLogger
from src.logging_config import setup_logging from src.logging_config import setup_logging
from src.markets.schedule import MARKETS, MarketInfo, get_next_market_open, get_open_markets from src.markets.schedule import MARKETS, MarketInfo, get_next_market_open, get_open_markets
from src.notifications.telegram_client import NotificationFilter, TelegramClient, TelegramCommandHandler from src.notifications.telegram_client import (
from src.strategy.models import DayPlaybook, MarketOutlook NotificationFilter,
TelegramClient,
TelegramCommandHandler,
)
from src.strategy.exit_rules import ExitRuleConfig, ExitRuleInput, evaluate_exit from src.strategy.exit_rules import ExitRuleConfig, ExitRuleInput, evaluate_exit
from src.strategy.models import DayPlaybook, MarketOutlook
from src.strategy.playbook_store import PlaybookStore from src.strategy.playbook_store import PlaybookStore
from src.strategy.pre_market_planner import PreMarketPlanner
from src.strategy.position_state_machine import PositionState from src.strategy.position_state_machine import PositionState
from src.strategy.pre_market_planner import PreMarketPlanner
from src.strategy.scenario_engine import ScenarioEngine from src.strategy.scenario_engine import ScenarioEngine
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -350,9 +354,7 @@ async def _inject_staged_exit_features(
return return
if "pred_down_prob" not in market_data: if "pred_down_prob" not in market_data:
market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi( market_data["pred_down_prob"] = _estimate_pred_down_prob_from_rsi(market_data.get("rsi"))
market_data.get("rsi")
)
existing_atr = safe_float(market_data.get("atr_value"), 0.0) existing_atr = safe_float(market_data.get("atr_value"), 0.0)
if existing_atr > 0: if existing_atr > 0:
@@ -389,7 +391,7 @@ async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kw
return await coro_factory(*args, **kwargs) return await coro_factory(*args, **kwargs)
except ConnectionError as exc: except ConnectionError as exc:
if attempt < MAX_CONNECTION_RETRIES: if attempt < MAX_CONNECTION_RETRIES:
wait_secs = 2 ** attempt wait_secs = 2**attempt
logger.warning( logger.warning(
"Connection error %s (attempt %d/%d), retrying in %ds: %s", "Connection error %s (attempt %d/%d), retrying in %ds: %s",
label, label,
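`2**attempt` keeps the exponential backoff schedule intact (2 s after the first failure, 4 s after the second, and so on). The retry loop, reduced to its shape (constants and final-attempt handling assumed):

import asyncio

MAX_CONNECTION_RETRIES = 3                 # assumed for illustration

async def retry(coro_factory, *args, **kwargs):
    for attempt in range(1, MAX_CONNECTION_RETRIES + 1):
        try:
            return await coro_factory(*args, **kwargs)
        except ConnectionError:
            if attempt == MAX_CONNECTION_RETRIES:
                raise
            await asyncio.sleep(2**attempt)   # 2 s, then 4 s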
@@ -413,7 +415,7 @@ async def sync_positions_from_broker(
broker: Any, broker: Any,
overseas_broker: Any, overseas_broker: Any,
db_conn: Any, db_conn: Any,
settings: "Settings", settings: Settings,
) -> int: ) -> int:
"""Sync open positions from the live broker into the local DB at startup. """Sync open positions from the live broker into the local DB at startup.
@@ -441,9 +443,7 @@ async def sync_positions_from_broker(
if market.exchange_code in seen_exchange_codes: if market.exchange_code in seen_exchange_codes:
continue continue
seen_exchange_codes.add(market.exchange_code) seen_exchange_codes.add(market.exchange_code)
balance_data = await overseas_broker.get_overseas_balance( balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
market.exchange_code
)
log_market = market_code # e.g. "US_NASDAQ" log_market = market_code # e.g. "US_NASDAQ"
except ConnectionError as exc: except ConnectionError as exc:
logger.warning( logger.warning(
@@ -453,9 +453,7 @@ async def sync_positions_from_broker(
) )
continue continue
held_codes = _extract_held_codes_from_balance( held_codes = _extract_held_codes_from_balance(balance_data, is_domestic=market.is_domestic)
balance_data, is_domestic=market.is_domestic
)
for stock_code in held_codes: for stock_code in held_codes:
if get_open_position(db_conn, stock_code, log_market): if get_open_position(db_conn, stock_code, log_market):
continue # already tracked continue # already tracked
@@ -487,9 +485,7 @@ async def sync_positions_from_broker(
synced += 1 synced += 1
if synced: if synced:
logger.info( logger.info("Startup sync complete: %d position(s) synced from broker", synced)
"Startup sync complete: %d position(s) synced from broker", synced
)
else: else:
logger.info("Startup sync: no new positions to sync from broker") logger.info("Startup sync: no new positions to sync from broker")
return synced return synced
@@ -859,15 +855,9 @@ def _apply_staged_exit_override_for_hold(
pnl_pct = (current_price - entry_price) / entry_price * 100.0 pnl_pct = (current_price - entry_price) / entry_price * 100.0
if exit_eval.reason == "hard_stop": if exit_eval.reason == "hard_stop":
rationale = ( rationale = f"Stop-loss triggered ({pnl_pct:.2f}% <= {stop_loss_threshold:.2f}%)"
f"Stop-loss triggered ({pnl_pct:.2f}% <= "
f"{stop_loss_threshold:.2f}%)"
)
elif exit_eval.reason == "arm_take_profit": elif exit_eval.reason == "arm_take_profit":
rationale = ( rationale = f"Take-profit triggered ({pnl_pct:.2f}% >= {arm_pct:.2f}%)"
f"Take-profit triggered ({pnl_pct:.2f}% >= "
f"{arm_pct:.2f}%)"
)
elif exit_eval.reason == "atr_trailing_stop": elif exit_eval.reason == "atr_trailing_stop":
rationale = "ATR trailing-stop triggered" rationale = "ATR trailing-stop triggered"
elif exit_eval.reason == "be_lock_threat": elif exit_eval.reason == "be_lock_threat":
@@ -978,7 +968,10 @@ def _maybe_queue_order_intent(
) )
if queued: if queued:
logger.warning( logger.warning(
"Blackout active: queued order intent %s %s (%s) qty=%d price=%.4f source=%s pending=%d", (
"Blackout active: queued order intent %s %s (%s) "
"qty=%d price=%.4f source=%s pending=%d"
),
order_type, order_type,
stock_code, stock_code,
market.code, market.code,
@@ -1071,7 +1064,10 @@ async def process_blackout_recovery_orders(
) )
if queued_price <= 0 or current_price <= 0: if queued_price <= 0 or current_price <= 0:
logger.info( logger.info(
"Drop queued intent by price revalidation (invalid price): %s %s (%s) queued=%.4f current=%.4f", (
"Drop queued intent by price revalidation (invalid price): "
"%s %s (%s) queued=%.4f current=%.4f"
),
intent.order_type, intent.order_type,
intent.stock_code, intent.stock_code,
market.code, market.code,
@@ -1082,7 +1078,10 @@ async def process_blackout_recovery_orders(
drift_pct = abs(current_price - queued_price) / queued_price * 100.0 drift_pct = abs(current_price - queued_price) / queued_price * 100.0
if drift_pct > max_drift_pct: if drift_pct > max_drift_pct:
logger.info( logger.info(
"Drop queued intent by price revalidation: %s %s (%s) queued=%.4f current=%.4f drift=%.2f%% max=%.2f%%", (
"Drop queued intent by price revalidation: %s %s (%s) "
"queued=%.4f current=%.4f drift=%.2f%% max=%.2f%%"
),
intent.order_type, intent.order_type,
intent.stock_code, intent.stock_code,
market.code, market.code,
@@ -1375,24 +1374,18 @@ async def trading_cycle(
# 1. Fetch market data # 1. Fetch market data
price_output: dict[str, Any] = {} # Populated for overseas markets; used for fallback metrics price_output: dict[str, Any] = {} # Populated for overseas markets; used for fallback metrics
if market.is_domestic: if market.is_domestic:
current_price, price_change_pct, foreigner_net = await broker.get_current_price( current_price, price_change_pct, foreigner_net = await broker.get_current_price(stock_code)
stock_code
)
balance_data = await broker.get_balance() balance_data = await broker.get_balance()
output2 = balance_data.get("output2", [{}]) output2 = balance_data.get("output2", [{}])
total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0 total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
total_cash = safe_float( total_cash = safe_float(
balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") balance_data.get("output2", [{}])[0].get("dnca_tot_amt", "0") if output2 else "0"
if output2
else "0"
) )
purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0 purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
else: else:
# Overseas market # Overseas market
price_data = await overseas_broker.get_overseas_price( price_data = await overseas_broker.get_overseas_price(market.exchange_code, stock_code)
market.exchange_code, stock_code
)
balance_data = await overseas_broker.get_overseas_balance(market.exchange_code) balance_data = await overseas_broker.get_overseas_balance(market.exchange_code)
output2 = balance_data.get("output2", [{}]) output2 = balance_data.get("output2", [{}])
@@ -1459,11 +1452,7 @@ async def trading_cycle(
total_cash = settings.PAPER_OVERSEAS_CASH total_cash = settings.PAPER_OVERSEAS_CASH
# Calculate daily P&L % # Calculate daily P&L %
pnl_pct = ( pnl_pct = ((total_eval - purchase_total) / purchase_total * 100) if purchase_total > 0 else 0.0
((total_eval - purchase_total) / purchase_total * 100)
if purchase_total > 0
else 0.0
)
market_data: dict[str, Any] = { market_data: dict[str, Any] = {
"stock_code": stock_code, "stock_code": stock_code,
@@ -1491,11 +1480,13 @@ async def trading_cycle(
market_data["rsi"] = max(0.0, min(100.0, 50.0 + price_change_pct * 2.0)) market_data["rsi"] = max(0.0, min(100.0, 50.0 + price_change_pct * 2.0))
if price_output and current_price > 0: if price_output and current_price > 0:
pr_high = safe_float( pr_high = safe_float(
price_output.get("high") or price_output.get("ovrs_hgpr") price_output.get("high")
or price_output.get("ovrs_hgpr")
or price_output.get("stck_hgpr") or price_output.get("stck_hgpr")
) )
pr_low = safe_float( pr_low = safe_float(
price_output.get("low") or price_output.get("ovrs_lwpr") price_output.get("low")
or price_output.get("ovrs_lwpr")
or price_output.get("stck_lwpr") or price_output.get("stck_lwpr")
) )
if pr_high > 0 and pr_low > 0 and pr_high >= pr_low: if pr_high > 0 and pr_low > 0 and pr_high >= pr_low:
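The RSI fallback in this hunk is a clamped linear proxy on the day's percentage move, not a true RSI; for example:

price_change_pct = -3.2
rsi_proxy = max(0.0, min(100.0, 50.0 + price_change_pct * 2.0))   # 43.6, clamped to [0, 100]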
@@ -1512,9 +1503,7 @@ async def trading_cycle(
if open_pos and current_price > 0: if open_pos and current_price > 0:
entry_price = safe_float(open_pos.get("price"), 0.0) entry_price = safe_float(open_pos.get("price"), 0.0)
if entry_price > 0: if entry_price > 0:
market_data["unrealized_pnl_pct"] = ( market_data["unrealized_pnl_pct"] = (current_price - entry_price) / entry_price * 100
(current_price - entry_price) / entry_price * 100
)
entry_ts = open_pos.get("timestamp") entry_ts = open_pos.get("timestamp")
if entry_ts: if entry_ts:
try: try:
@@ -1745,16 +1734,19 @@ async def trading_cycle(
stock_playbook=stock_playbook, stock_playbook=stock_playbook,
settings=settings, settings=settings,
) )
if open_position and decision.action == "HOLD" and _should_force_exit_for_overnight( if (
open_position
and decision.action == "HOLD"
and _should_force_exit_for_overnight(
market=market, market=market,
settings=settings, settings=settings,
)
): ):
decision = TradeDecision( decision = TradeDecision(
action="SELL", action="SELL",
confidence=max(decision.confidence, 85), confidence=max(decision.confidence, 85),
rationale=( rationale=(
"Forced exit by overnight policy" "Forced exit by overnight policy (session close window / kill switch priority)"
" (session close window / kill switch priority)"
), ),
) )
logger.info( logger.info(
@@ -1834,9 +1826,7 @@ async def trading_cycle(
return return
broker_held_qty = ( broker_held_qty = (
_extract_held_qty_from_balance( _extract_held_qty_from_balance(balance_data, stock_code, is_domestic=market.is_domestic)
balance_data, stock_code, is_domestic=market.is_domestic
)
if decision.action == "SELL" if decision.action == "SELL"
else 0 else 0
) )
@@ -1871,7 +1861,10 @@ async def trading_cycle(
) )
if fx_blocked: if fx_blocked:
logger.warning( logger.warning(
"Skip BUY %s (%s): FX buffer guard (remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)", (
"Skip BUY %s (%s): FX buffer guard "
"(remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)"
),
stock_code, stock_code,
market.name, market.name,
remaining_cash, remaining_cash,
@@ -2068,8 +2061,7 @@ async def trading_cycle(
action="SELL", action="SELL",
confidence=0, confidence=0,
rationale=( rationale=(
"[ghost-close] Broker reported no balance;" "[ghost-close] Broker reported no balance; position closed without fill"
" position closed without fill"
), ),
quantity=0, quantity=0,
price=0.0, price=0.0,
@@ -2275,17 +2267,13 @@ async def handle_domestic_pending_orders(
outcome="cancelled", outcome="cancelled",
) )
except Exception as notify_exc: except Exception as notify_exc:
logger.warning( logger.warning("notify_unfilled_order failed: %s", notify_exc)
"notify_unfilled_order failed: %s", notify_exc
)
else: else:
# First unfilled SELL → resubmit at last * 0.996 (-0.4%). # First unfilled SELL → resubmit at last * 0.996 (-0.4%).
try: try:
last_price, _, _ = await broker.get_current_price(stock_code) last_price, _, _ = await broker.get_current_price(stock_code)
if last_price <= 0: if last_price <= 0:
raise ValueError( raise ValueError(f"Invalid price ({last_price}) for {stock_code}")
f"Invalid price ({last_price}) for {stock_code}"
)
new_price = kr_round_down(last_price * 0.996) new_price = kr_round_down(last_price * 0.996)
validate_order_policy( validate_order_policy(
market=MARKETS["KR"], market=MARKETS["KR"],
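The first-unfilled-SELL path reprices at last × 0.996 and snaps the result to a KRX tick via kr_round_down. A worked example, assuming the 100-KRW tick band for this price range:

last_price = 71_300                 # KRW, illustrative
raw = last_price * 0.996            # 71_014.8
new_price = int(raw // 100) * 100   # 71_000 — tick size assumed here; kr_round_down does the real snapping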
@@ -2298,9 +2286,7 @@ async def handle_domestic_pending_orders(
quantity=psbl_qty, quantity=psbl_qty,
price=new_price, price=new_price,
) )
sell_resubmit_counts[key] = ( sell_resubmit_counts[key] = sell_resubmit_counts.get(key, 0) + 1
sell_resubmit_counts.get(key, 0) + 1
)
try: try:
await telegram.notify_unfilled_order( await telegram.notify_unfilled_order(
stock_code=stock_code, stock_code=stock_code,
@@ -2311,9 +2297,7 @@ async def handle_domestic_pending_orders(
new_price=float(new_price), new_price=float(new_price),
) )
except Exception as notify_exc: except Exception as notify_exc:
logger.warning( logger.warning("notify_unfilled_order failed: %s", notify_exc)
"notify_unfilled_order failed: %s", notify_exc
)
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"SELL resubmit failed for KR %s: %s", "SELL resubmit failed for KR %s: %s",
@@ -2381,9 +2365,7 @@ async def handle_overseas_pending_orders(
try: try:
orders = await overseas_broker.get_overseas_pending_orders(exchange_code) orders = await overseas_broker.get_overseas_pending_orders(exchange_code)
except Exception as exc: except Exception as exc:
logger.warning( logger.warning("Failed to fetch pending orders for %s: %s", exchange_code, exc)
"Failed to fetch pending orders for %s: %s", exchange_code, exc
)
continue continue
for order in orders: for order in orders:
@@ -2448,26 +2430,21 @@ async def handle_overseas_pending_orders(
outcome="cancelled", outcome="cancelled",
) )
except Exception as notify_exc: except Exception as notify_exc:
logger.warning( logger.warning("notify_unfilled_order failed: %s", notify_exc)
"notify_unfilled_order failed: %s", notify_exc
)
else: else:
# First unfilled SELL → resubmit at last * 0.996 (-0.4%). # First unfilled SELL → resubmit at last * 0.996 (-0.4%).
try: try:
price_data = await overseas_broker.get_overseas_price( price_data = await overseas_broker.get_overseas_price(
order_exchange, stock_code order_exchange, stock_code
) )
last_price = float( last_price = float(price_data.get("output", {}).get("last", "0") or "0")
price_data.get("output", {}).get("last", "0") or "0"
)
if last_price <= 0: if last_price <= 0:
raise ValueError( raise ValueError(f"Invalid price ({last_price}) for {stock_code}")
f"Invalid price ({last_price}) for {stock_code}"
)
new_price = round(last_price * 0.996, 4) new_price = round(last_price * 0.996, 4)
market_info = next( market_info = next(
( (
m for m in MARKETS.values() m
for m in MARKETS.values()
if m.exchange_code == order_exchange and not m.is_domestic if m.exchange_code == order_exchange and not m.is_domestic
), ),
None, None,
@@ -2485,9 +2462,7 @@ async def handle_overseas_pending_orders(
quantity=nccs_qty, quantity=nccs_qty,
price=new_price, price=new_price,
) )
sell_resubmit_counts[key] = ( sell_resubmit_counts[key] = sell_resubmit_counts.get(key, 0) + 1
sell_resubmit_counts.get(key, 0) + 1
)
try: try:
await telegram.notify_unfilled_order( await telegram.notify_unfilled_order(
stock_code=stock_code, stock_code=stock_code,
@@ -2498,9 +2473,7 @@ async def handle_overseas_pending_orders(
new_price=new_price, new_price=new_price,
) )
except Exception as notify_exc: except Exception as notify_exc:
logger.warning( logger.warning("notify_unfilled_order failed: %s", notify_exc)
"notify_unfilled_order failed: %s", notify_exc
)
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"SELL resubmit failed for %s %s: %s", "SELL resubmit failed for %s %s: %s",
@@ -2659,13 +2632,16 @@ async def run_daily_session(
logger.warning("Playbook notification failed: %s", exc) logger.warning("Playbook notification failed: %s", exc)
logger.info( logger.info(
"Generated playbook for %s: %d stocks, %d scenarios", "Generated playbook for %s: %d stocks, %d scenarios",
market.code, playbook.stock_count, playbook.scenario_count, market.code,
playbook.stock_count,
playbook.scenario_count,
) )
except Exception as exc: except Exception as exc:
logger.error("Playbook generation failed for %s: %s", market.code, exc) logger.error("Playbook generation failed for %s: %s", market.code, exc)
try: try:
await telegram.notify_playbook_failed( await telegram.notify_playbook_failed(
market=market.code, reason=str(exc)[:200], market=market.code,
reason=str(exc)[:200],
) )
except Exception as notify_exc: except Exception as notify_exc:
logger.warning("Playbook failed notification error: %s", notify_exc) logger.warning("Playbook failed notification error: %s", notify_exc)
@@ -2676,12 +2652,10 @@ async def run_daily_session(
for stock_code in watchlist: for stock_code in watchlist:
try: try:
if market.is_domestic: if market.is_domestic:
current_price, price_change_pct, foreigner_net = ( current_price, price_change_pct, foreigner_net = await _retry_connection(
await _retry_connection( broker.get_current_price,
broker.get_current_price, stock_code,
stock_code, label=stock_code,
label=stock_code,
)
) )
else: else:
price_data = await _retry_connection( price_data = await _retry_connection(
@@ -2690,9 +2664,7 @@ async def run_daily_session(
stock_code, stock_code,
label=f"{stock_code}@{market.exchange_code}", label=f"{stock_code}@{market.exchange_code}",
) )
current_price = safe_float( current_price = safe_float(price_data.get("output", {}).get("last", "0"))
price_data.get("output", {}).get("last", "0")
)
# Fallback: if price API returns 0, use scanner candidate price # Fallback: if price API returns 0, use scanner candidate price
if current_price <= 0: if current_price <= 0:
cand_lookup = candidate_map.get(stock_code) cand_lookup = candidate_map.get(stock_code)
@@ -2704,9 +2676,7 @@ async def run_daily_session(
) )
current_price = cand_lookup.price current_price = cand_lookup.price
foreigner_net = 0.0 foreigner_net = 0.0
price_change_pct = safe_float( price_change_pct = safe_float(price_data.get("output", {}).get("rate", "0"))
price_data.get("output", {}).get("rate", "0")
)
# Fall back to scanner candidate price if API returns 0. # Fall back to scanner candidate price if API returns 0.
if current_price <= 0: if current_price <= 0:
cand_lookup = candidate_map.get(stock_code) cand_lookup = candidate_map.get(stock_code)
@@ -2769,15 +2739,9 @@ async def run_daily_session(
if market.is_domestic: if market.is_domestic:
output2 = balance_data.get("output2", [{}]) output2 = balance_data.get("output2", [{}])
total_eval = safe_float( total_eval = safe_float(output2[0].get("tot_evlu_amt", "0")) if output2 else 0
output2[0].get("tot_evlu_amt", "0") total_cash = safe_float(output2[0].get("dnca_tot_amt", "0")) if output2 else 0
) if output2 else 0 purchase_total = safe_float(output2[0].get("pchs_amt_smtl_amt", "0")) if output2 else 0
total_cash = safe_float(
output2[0].get("dnca_tot_amt", "0")
) if output2 else 0
purchase_total = safe_float(
output2[0].get("pchs_amt_smtl_amt", "0")
) if output2 else 0
else: else:
output2 = balance_data.get("output2", [{}]) output2 = balance_data.get("output2", [{}])
if isinstance(output2, list) and output2: if isinstance(output2, list) and output2:
@@ -2788,18 +2752,15 @@ async def run_daily_session(
balance_info = {} balance_info = {}
total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0") total_eval = safe_float(balance_info.get("frcr_evlu_tota", "0") or "0")
purchase_total = safe_float( purchase_total = safe_float(balance_info.get("frcr_buy_amt_smtl", "0") or "0")
balance_info.get("frcr_buy_amt_smtl", "0") or "0"
)
# Fetch available foreign currency cash via inquire-psamount (TTTS3007R/VTTS3007R). # Fetch available foreign currency cash via inquire-psamount (TTTS3007R/VTTS3007R).
# TTTS3012R output2 does not include a cash/deposit field — frcr_dncl_amt_2 does not exist. # TTTS3012R output2 does not include a cash/deposit field.
# frcr_dncl_amt_2 does not exist.
# Use the first stock with a valid price as the reference for the buying power query. # Use the first stock with a valid price as the reference for the buying power query.
# Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 매수가능금액조회' 시트 # Source: 한국투자증권 오픈API 전체문서 (20260221) — '해외주식 매수가능금액조회' 시트
total_cash = 0.0 total_cash = 0.0
ref_stock = next( ref_stock = next((s for s in stocks_data if s.get("current_price", 0) > 0), None)
(s for s in stocks_data if s.get("current_price", 0) > 0), None
)
if ref_stock: if ref_stock:
try: try:
ps_data = await overseas_broker.get_overseas_buying_power( ps_data = await overseas_broker.get_overseas_buying_power(
@@ -2819,11 +2780,7 @@ async def run_daily_session(
# Paper mode fallback: VTS overseas balance API often fails for many accounts. # Paper mode fallback: VTS overseas balance API often fails for many accounts.
# Only activate in paper mode — live mode must use real balance from KIS. # Only activate in paper mode — live mode must use real balance from KIS.
if ( if total_cash <= 0 and settings.MODE == "paper" and settings.PAPER_OVERSEAS_CASH > 0:
total_cash <= 0
and settings.MODE == "paper"
and settings.PAPER_OVERSEAS_CASH > 0
):
total_cash = settings.PAPER_OVERSEAS_CASH total_cash = settings.PAPER_OVERSEAS_CASH
# Capture the day's opening portfolio value on the first market processed # Capture the day's opening portfolio value on the first market processed
@@ -2856,13 +2813,17 @@ async def run_daily_session(
# Evaluate scenarios for each stock (local, no API calls) # Evaluate scenarios for each stock (local, no API calls)
logger.info( logger.info(
"Evaluating %d stocks against playbook for %s", "Evaluating %d stocks against playbook for %s",
len(stocks_data), market.name, len(stocks_data),
market.name,
) )
for stock_data in stocks_data: for stock_data in stocks_data:
stock_code = stock_data["stock_code"] stock_code = stock_data["stock_code"]
stock_playbook = playbook.get_stock_playbook(stock_code) stock_playbook = playbook.get_stock_playbook(stock_code)
match = scenario_engine.evaluate( match = scenario_engine.evaluate(
playbook, stock_code, stock_data, portfolio_data, playbook,
stock_code,
stock_data,
portfolio_data,
) )
decision = TradeDecision( decision = TradeDecision(
action=match.action.value, action=match.action.value,
@@ -2969,9 +2930,13 @@ async def run_daily_session(
stock_playbook=stock_playbook, stock_playbook=stock_playbook,
settings=settings, settings=settings,
) )
if daily_open and decision.action == "HOLD" and _should_force_exit_for_overnight( if (
market=market, daily_open
settings=settings, and decision.action == "HOLD"
and _should_force_exit_for_overnight(
market=market,
settings=settings,
)
): ):
decision = TradeDecision( decision = TradeDecision(
action="SELL", action="SELL",
@@ -3063,16 +3028,21 @@ async def run_daily_session(
) )
continue continue
order_amount = stock_data["current_price"] * quantity order_amount = stock_data["current_price"] * quantity
fx_blocked, remaining_cash, required_buffer = _should_block_overseas_buy_for_fx_buffer( fx_blocked, remaining_cash, required_buffer = (
market=market, _should_block_overseas_buy_for_fx_buffer(
action=decision.action, market=market,
total_cash=total_cash, action=decision.action,
order_amount=order_amount, total_cash=total_cash,
settings=settings, order_amount=order_amount,
settings=settings,
)
) )
if fx_blocked: if fx_blocked:
logger.warning( logger.warning(
"Skip BUY %s (%s): FX buffer guard (remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)", (
"Skip BUY %s (%s): FX buffer guard "
"(remaining=%.2f, required=%.2f, cash=%.2f, order=%.2f)"
),
stock_code, stock_code,
market.name, market.name,
remaining_cash, remaining_cash,
@@ -3090,7 +3060,10 @@ async def run_daily_session(
if now < daily_cooldown_until: if now < daily_cooldown_until:
remaining = int(daily_cooldown_until - now) remaining = int(daily_cooldown_until - now)
logger.info( logger.info(
"Skip BUY %s (%s): insufficient-balance cooldown active (%ds remaining)", (
"Skip BUY %s (%s): insufficient-balance cooldown active "
"(%ds remaining)"
),
stock_code, stock_code,
market.name, market.name,
remaining, remaining,
@@ -3149,13 +3122,9 @@ async def run_daily_session(
# Use limit orders (지정가) for domestic stocks. # Use limit orders (지정가) for domestic stocks.
# KRX tick rounding applied via kr_round_down. # KRX tick rounding applied via kr_round_down.
if decision.action == "BUY": if decision.action == "BUY":
order_price = kr_round_down( order_price = kr_round_down(stock_data["current_price"] * 1.002)
stock_data["current_price"] * 1.002
)
else: else:
order_price = kr_round_down( order_price = kr_round_down(stock_data["current_price"] * 0.998)
stock_data["current_price"] * 0.998
)
try: try:
validate_order_policy( validate_order_policy(
market=market, market=market,
@@ -3260,9 +3229,7 @@ async def run_daily_session(
except Exception as exc: except Exception as exc:
logger.warning("Telegram notification failed: %s", exc) logger.warning("Telegram notification failed: %s", exc)
except Exception as exc: except Exception as exc:
logger.error( logger.error("Order execution failed for %s: %s", stock_code, exc)
"Order execution failed for %s: %s", stock_code, exc
)
continue continue
if decision.action == "SELL" and order_succeeded: if decision.action == "SELL" and order_succeeded:
@@ -3286,7 +3253,9 @@ async def run_daily_session(
accuracy=1 if trade_pnl > 0 else 0, accuracy=1 if trade_pnl > 0 else 0,
) )
if trade_pnl < 0: if trade_pnl < 0:
cooldown_key = _stoploss_cooldown_key(market=market, stock_code=stock_code) cooldown_key = _stoploss_cooldown_key(
market=market, stock_code=stock_code
)
cooldown_minutes = _stoploss_cooldown_minutes( cooldown_minutes = _stoploss_cooldown_minutes(
settings, settings,
market=market, market=market,
@@ -3369,7 +3338,8 @@ async def _handle_market_close(
def _run_context_scheduler( def _run_context_scheduler(
scheduler: ContextScheduler, now: datetime | None = None, scheduler: ContextScheduler,
now: datetime | None = None,
) -> None: ) -> None:
"""Run periodic context scheduler tasks and log when anything executes.""" """Run periodic context scheduler tasks and log when anything executes."""
result = scheduler.run_if_due(now=now) result = scheduler.run_if_due(now=now)
@@ -3438,6 +3408,7 @@ def _start_dashboard_server(settings: Settings) -> threading.Thread | None:
# reported synchronously (avoids the misleading "started" → "failed" log pair). # reported synchronously (avoids the misleading "started" → "failed" log pair).
try: try:
import uvicorn # noqa: F401 import uvicorn # noqa: F401
from src.dashboard import create_dashboard_app # noqa: F401 from src.dashboard import create_dashboard_app # noqa: F401
except ImportError as exc: except ImportError as exc:
logger.warning("Dashboard server unavailable (missing dependency): %s", exc) logger.warning("Dashboard server unavailable (missing dependency): %s", exc)
@@ -3446,6 +3417,7 @@ def _start_dashboard_server(settings: Settings) -> threading.Thread | None:
def _serve() -> None: def _serve() -> None:
try: try:
import uvicorn import uvicorn
from src.dashboard import create_dashboard_app from src.dashboard import create_dashboard_app
app = create_dashboard_app(settings.DB_PATH, mode=settings.MODE) app = create_dashboard_app(settings.DB_PATH, mode=settings.MODE)
@@ -3586,8 +3558,7 @@ async def run(settings: Settings) -> None:
pause_trading.set() pause_trading.set()
logger.info("Trading resumed via Telegram command") logger.info("Trading resumed via Telegram command")
await telegram.send_message( await telegram.send_message(
"<b>▶️ Trading Resumed</b>\n\n" "<b>▶️ Trading Resumed</b>\n\nTrading operations have been restarted."
"Trading operations have been restarted."
) )
async def handle_status() -> None: async def handle_status() -> None:
@@ -3630,9 +3601,7 @@ async def run(settings: Settings) -> None:
except Exception as exc: except Exception as exc:
logger.error("Error in /status handler: %s", exc) logger.error("Error in /status handler: %s", exc)
await telegram.send_message( await telegram.send_message("<b>⚠️ Error</b>\n\nFailed to retrieve trading status.")
"<b>⚠️ Error</b>\n\nFailed to retrieve trading status."
)
async def handle_positions() -> None: async def handle_positions() -> None:
"""Handle /positions command - show account summary.""" """Handle /positions command - show account summary."""
@@ -3643,8 +3612,7 @@ async def run(settings: Settings) -> None:
if not output2: if not output2:
await telegram.send_message( await telegram.send_message(
"<b>💼 Account Summary</b>\n\n" "<b>💼 Account Summary</b>\n\nNo balance information available."
"No balance information available."
) )
return return
@@ -3673,9 +3641,7 @@ async def run(settings: Settings) -> None:
except Exception as exc: except Exception as exc:
logger.error("Error in /positions handler: %s", exc) logger.error("Error in /positions handler: %s", exc)
await telegram.send_message( await telegram.send_message("<b>⚠️ Error</b>\n\nFailed to retrieve positions.")
"<b>⚠️ Error</b>\n\nFailed to retrieve positions."
)
async def handle_report() -> None: async def handle_report() -> None:
"""Handle /report command - show daily summary metrics.""" """Handle /report command - show daily summary metrics."""
@@ -3719,9 +3685,7 @@ async def run(settings: Settings) -> None:
) )
except Exception as exc: except Exception as exc:
logger.error("Error in /report handler: %s", exc) logger.error("Error in /report handler: %s", exc)
await telegram.send_message( await telegram.send_message("<b>⚠️ Error</b>\n\nFailed to generate daily report.")
"<b>⚠️ Error</b>\n\nFailed to generate daily report."
)
async def handle_scenarios() -> None: async def handle_scenarios() -> None:
"""Handle /scenarios command - show today's playbook scenarios.""" """Handle /scenarios command - show today's playbook scenarios."""
@@ -3770,9 +3734,7 @@ async def run(settings: Settings) -> None:
await telegram.send_message("\n".join(lines).strip()) await telegram.send_message("\n".join(lines).strip())
except Exception as exc: except Exception as exc:
logger.error("Error in /scenarios handler: %s", exc) logger.error("Error in /scenarios handler: %s", exc)
await telegram.send_message( await telegram.send_message("<b>⚠️ Error</b>\n\nFailed to retrieve scenarios.")
"<b>⚠️ Error</b>\n\nFailed to retrieve scenarios."
)
async def handle_review() -> None: async def handle_review() -> None:
"""Handle /review command - show recent scorecards.""" """Handle /review command - show recent scorecards."""
@@ -3788,9 +3750,7 @@ async def run(settings: Settings) -> None:
).fetchall() ).fetchall()
if not rows: if not rows:
await telegram.send_message( await telegram.send_message("<b>📝 Recent Reviews</b>\n\nNo scorecards available.")
"<b>📝 Recent Reviews</b>\n\nNo scorecards available."
)
return return
lines = ["<b>📝 Recent Reviews</b>", ""] lines = ["<b>📝 Recent Reviews</b>", ""]
@@ -3808,9 +3768,7 @@ async def run(settings: Settings) -> None:
await telegram.send_message("\n".join(lines)) await telegram.send_message("\n".join(lines))
except Exception as exc: except Exception as exc:
logger.error("Error in /review handler: %s", exc) logger.error("Error in /review handler: %s", exc)
await telegram.send_message( await telegram.send_message("<b>⚠️ Error</b>\n\nFailed to retrieve reviews.")
"<b>⚠️ Error</b>\n\nFailed to retrieve reviews."
)
async def handle_notify(args: list[str]) -> None: async def handle_notify(args: list[str]) -> None:
"""Handle /notify [key] [on|off] — query or change notification filters.""" """Handle /notify [key] [on|off] — query or change notification filters."""
@@ -3845,8 +3803,7 @@ async def run(settings: Settings) -> None:
else: else:
valid = ", ".join(list(status.keys()) + ["all"]) valid = ", ".join(list(status.keys()) + ["all"])
await telegram.send_message( await telegram.send_message(
f"❌ 알 수 없는 키: <code>{key}</code>\n" f"❌ 알 수 없는 키: <code>{key}</code>\n유효한 키: {valid}"
f"유효한 키: {valid}"
) )
return return
@@ -3858,30 +3815,22 @@ async def run(settings: Settings) -> None:
value = toggle == "on" value = toggle == "on"
if telegram.set_notification(key, value): if telegram.set_notification(key, value):
icon = "" if value else "" icon = "" if value else ""
label = f"전체 알림" if key == "all" else f"<code>{key}</code> 알림" label = "전체 알림" if key == "all" else f"<code>{key}</code> 알림"
state = "켜짐" if value else "꺼짐" state = "켜짐" if value else "꺼짐"
await telegram.send_message(f"{icon} {label}{state}") await telegram.send_message(f"{icon} {label}{state}")
logger.info("Notification filter changed via Telegram: %s=%s", key, value) logger.info("Notification filter changed via Telegram: %s=%s", key, value)
else: else:
valid = ", ".join(list(telegram.filter_status().keys()) + ["all"]) valid = ", ".join(list(telegram.filter_status().keys()) + ["all"])
await telegram.send_message( await telegram.send_message(f"❌ 알 수 없는 키: <code>{key}</code>\n유효한 키: {valid}")
f"❌ 알 수 없는 키: <code>{key}</code>\n"
f"유효한 키: {valid}"
)
async def handle_dashboard() -> None: async def handle_dashboard() -> None:
"""Handle /dashboard command - show dashboard URL if enabled.""" """Handle /dashboard command - show dashboard URL if enabled."""
if not settings.DASHBOARD_ENABLED: if not settings.DASHBOARD_ENABLED:
await telegram.send_message( await telegram.send_message("<b>🖥️ Dashboard</b>\n\nDashboard is not enabled.")
"<b>🖥️ Dashboard</b>\n\nDashboard is not enabled."
)
return return
url = f"http://{settings.DASHBOARD_HOST}:{settings.DASHBOARD_PORT}" url = f"http://{settings.DASHBOARD_HOST}:{settings.DASHBOARD_PORT}"
await telegram.send_message( await telegram.send_message(f"<b>🖥️ Dashboard</b>\n\n<b>URL:</b> {url}")
"<b>🖥️ Dashboard</b>\n\n"
f"<b>URL:</b> {url}"
)
command_handler.register_command("help", handle_help) command_handler.register_command("help", handle_help)
command_handler.register_command("stop", handle_stop) command_handler.register_command("stop", handle_stop)
@@ -4182,9 +4131,7 @@ async def run(settings: Settings) -> None:
) )
# Store candidates per market for selection context logging # Store candidates per market for selection context logging
scan_candidates[market.code] = { scan_candidates[market.code] = {c.stock_code: c for c in candidates}
c.stock_code: c for c in candidates
}
logger.info( logger.info(
"Smart Scanner: Found %d candidates for %s: %s", "Smart Scanner: Found %d candidates for %s: %s",
@@ -4194,9 +4141,7 @@ async def run(settings: Settings) -> None:
) )
# Get market-local date for playbook keying # Get market-local date for playbook keying
market_today = datetime.now( market_today = datetime.now(market.timezone).date()
market.timezone
).date()
# Load or generate playbook (1 Gemini call per market per day) # Load or generate playbook (1 Gemini call per market per day)
if market.code not in playbooks: if market.code not in playbooks:
@@ -4234,7 +4179,8 @@ async def run(settings: Settings) -> None:
except Exception as exc: except Exception as exc:
logger.error( logger.error(
"Playbook generation failed for %s: %s", "Playbook generation failed for %s: %s",
market.code, exc, market.code,
exc,
) )
try: try:
await telegram.notify_playbook_failed( await telegram.notify_playbook_failed(
@@ -4279,7 +4225,8 @@ async def run(settings: Settings) -> None:
except Exception as exc: except Exception as exc:
logger.warning( logger.warning(
"Failed to fetch holdings for %s: %s — skipping holdings merge", "Failed to fetch holdings for %s: %s — skipping holdings merge",
market.name, exc, market.name,
exc,
) )
held_codes = [] held_codes = []
@@ -4288,7 +4235,8 @@ async def run(settings: Settings) -> None:
if extra_held: if extra_held:
logger.info( logger.info(
"Holdings added to loop for %s (not in scanner): %s", "Holdings added to loop for %s (not in scanner): %s",
market.name, extra_held, market.name,
extra_held,
) )
if not stock_codes: if not stock_codes:

View File

@@ -211,9 +211,7 @@ def get_open_markets(
return is_market_open(market, now) return is_market_open(market, now)
open_markets = [ open_markets = [
MARKETS[code] MARKETS[code] for code in enabled_markets if code in MARKETS and is_available(MARKETS[code])
for code in enabled_markets
if code in MARKETS and is_available(MARKETS[code])
] ]
return sorted(open_markets, key=lambda m: m.code) return sorted(open_markets, key=lambda m: m.code)
@@ -282,9 +280,7 @@ def get_next_market_open(
# Calculate next open time for this market # Calculate next open time for this market
for days_ahead in range(7): # Check next 7 days for days_ahead in range(7): # Check next 7 days
check_date = market_now.date() + timedelta(days=days_ahead) check_date = market_now.date() + timedelta(days=days_ahead)
check_datetime = datetime.combine( check_datetime = datetime.combine(check_date, market.open_time, tzinfo=market.timezone)
check_date, market.open_time, tzinfo=market.timezone
)
# Skip weekends # Skip weekends
if check_datetime.weekday() >= 5: if check_datetime.weekday() >= 5:
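`datetime.combine(..., tzinfo=market.timezone)` produces an aware, market-local open time, and `weekday() >= 5` skips both Saturday (5) and Sunday (6). Illustrative date and timezone:

from datetime import date, datetime, time
from zoneinfo import ZoneInfo

check_datetime = datetime.combine(date(2026, 3, 7), time(9, 0), tzinfo=ZoneInfo("Asia/Seoul"))
assert check_datetime.weekday() == 5   # a Saturday → skipped by the weekend guard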

View File

@@ -4,7 +4,7 @@ import asyncio
import logging import logging
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import dataclass, fields from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import ClassVar from typing import ClassVar
@@ -136,14 +136,14 @@ class TelegramClient:
self._enabled = enabled self._enabled = enabled
self._rate_limiter = LeakyBucket(rate=rate_limit) self._rate_limiter = LeakyBucket(rate=rate_limit)
self._session: aiohttp.ClientSession | None = None self._session: aiohttp.ClientSession | None = None
self._filter = notification_filter if notification_filter is not None else NotificationFilter() self._filter = (
notification_filter if notification_filter is not None else NotificationFilter()
)
if not enabled: if not enabled:
logger.info("Telegram notifications disabled via configuration") logger.info("Telegram notifications disabled via configuration")
elif bot_token is None or chat_id is None: elif bot_token is None or chat_id is None:
logger.warning( logger.warning("Telegram notifications disabled (missing bot_token or chat_id)")
"Telegram notifications disabled (missing bot_token or chat_id)"
)
self._enabled = False self._enabled = False
else: else:
logger.info("Telegram notifications enabled for chat_id=%s", chat_id) logger.info("Telegram notifications enabled for chat_id=%s", chat_id)
@@ -209,14 +209,12 @@ class TelegramClient:
async with session.post(url, json=payload) as resp: async with session.post(url, json=payload) as resp:
if resp.status != 200: if resp.status != 200:
error_text = await resp.text() error_text = await resp.text()
logger.error( logger.error("Telegram API error (status=%d): %s", resp.status, error_text)
"Telegram API error (status=%d): %s", resp.status, error_text
)
return False return False
logger.debug("Telegram message sent: %s", text[:50]) logger.debug("Telegram message sent: %s", text[:50])
return True return True
except asyncio.TimeoutError: except TimeoutError:
logger.error("Telegram message timeout") logger.error("Telegram message timeout")
return False return False
except aiohttp.ClientError as exc: except aiohttp.ClientError as exc:
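Catching the builtin TimeoutError here is equivalent on the interpreter versions this change targets, since Python 3.11 made asyncio.TimeoutError an alias of the builtin:

import asyncio
assert asyncio.TimeoutError is TimeoutError   # alias since Python 3.11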
@@ -305,9 +303,7 @@ class TelegramClient:
NotificationMessage(priority=NotificationPriority.LOW, message=message) NotificationMessage(priority=NotificationPriority.LOW, message=message)
) )
async def notify_circuit_breaker( async def notify_circuit_breaker(self, pnl_pct: float, threshold: float) -> None:
self, pnl_pct: float, threshold: float
) -> None:
""" """
Notify circuit breaker activation. Notify circuit breaker activation.
@@ -354,9 +350,7 @@ class TelegramClient:
NotificationMessage(priority=NotificationPriority.HIGH, message=message) NotificationMessage(priority=NotificationPriority.HIGH, message=message)
) )
async def notify_system_start( async def notify_system_start(self, mode: str, enabled_markets: list[str]) -> None:
self, mode: str, enabled_markets: list[str]
) -> None:
""" """
Notify system startup. Notify system startup.
@@ -369,9 +363,7 @@ class TelegramClient:
mode_emoji = "📝" if mode == "paper" else "💰" mode_emoji = "📝" if mode == "paper" else "💰"
markets_str = ", ".join(enabled_markets) markets_str = ", ".join(enabled_markets)
message = ( message = (
f"<b>{mode_emoji} System Started</b>\n" f"<b>{mode_emoji} System Started</b>\nMode: {mode.upper()}\nMarkets: {markets_str}"
f"Mode: {mode.upper()}\n"
f"Markets: {markets_str}"
) )
await self._send_notification( await self._send_notification(
NotificationMessage(priority=NotificationPriority.MEDIUM, message=message) NotificationMessage(priority=NotificationPriority.MEDIUM, message=message)
@@ -445,11 +437,7 @@ class TelegramClient:
""" """
if not self._filter.playbook: if not self._filter.playbook:
return return
message = ( message = f"<b>Playbook Failed</b>\nMarket: {market}\nReason: {reason[:200]}"
f"<b>Playbook Failed</b>\n"
f"Market: {market}\n"
f"Reason: {reason[:200]}"
)
await self._send_notification( await self._send_notification(
NotificationMessage(priority=NotificationPriority.HIGH, message=message) NotificationMessage(priority=NotificationPriority.HIGH, message=message)
) )
@@ -469,9 +457,7 @@ class TelegramClient:
if "circuit breaker" in reason.lower() if "circuit breaker" in reason.lower()
else NotificationPriority.MEDIUM else NotificationPriority.MEDIUM
) )
await self._send_notification( await self._send_notification(NotificationMessage(priority=priority, message=message))
NotificationMessage(priority=priority, message=message)
)
async def notify_unfilled_order( async def notify_unfilled_order(
self, self,
@@ -496,11 +482,7 @@ class TelegramClient:
return return
# SELL resubmit is high priority — position liquidation at risk. # SELL resubmit is high priority — position liquidation at risk.
# BUY cancel is medium priority — only cash is freed. # BUY cancel is medium priority — only cash is freed.
priority = ( priority = NotificationPriority.HIGH if action == "SELL" else NotificationPriority.MEDIUM
NotificationPriority.HIGH
if action == "SELL"
else NotificationPriority.MEDIUM
)
outcome_emoji = "🔄" if outcome == "resubmitted" else "" outcome_emoji = "🔄" if outcome == "resubmitted" else ""
outcome_label = "재주문" if outcome == "resubmitted" else "취소됨" outcome_label = "재주문" if outcome == "resubmitted" else "취소됨"
action_emoji = "🔴" if action == "SELL" else "🟢" action_emoji = "🔴" if action == "SELL" else "🟢"
@@ -515,9 +497,7 @@ class TelegramClient:
message = "\n".join(lines) message = "\n".join(lines)
await self._send_notification(NotificationMessage(priority=priority, message=message)) await self._send_notification(NotificationMessage(priority=priority, message=message))
async def notify_error( async def notify_error(self, error_type: str, error_msg: str, context: str) -> None:
self, error_type: str, error_msg: str, context: str
) -> None:
""" """
Notify system error. Notify system error.
@@ -541,9 +521,7 @@ class TelegramClient:
class TelegramCommandHandler: class TelegramCommandHandler:
"""Handles incoming Telegram commands via long polling.""" """Handles incoming Telegram commands via long polling."""
def __init__( def __init__(self, client: TelegramClient, polling_interval: float = 1.0) -> None:
self, client: TelegramClient, polling_interval: float = 1.0
) -> None:
""" """
Initialize command handler. Initialize command handler.
@@ -559,9 +537,7 @@ class TelegramCommandHandler:
self._polling_task: asyncio.Task[None] | None = None self._polling_task: asyncio.Task[None] | None = None
self._running = False self._running = False
def register_command( def register_command(self, command: str, handler: Callable[[], Awaitable[None]]) -> None:
self, command: str, handler: Callable[[], Awaitable[None]]
) -> None:
""" """
Register a command handler (no arguments). Register a command handler (no arguments).
@@ -672,7 +648,7 @@ class TelegramCommandHandler:
return updates return updates
except asyncio.TimeoutError: except TimeoutError:
logger.debug("getUpdates timeout (normal)") logger.debug("getUpdates timeout (normal)")
return [] return []
except aiohttp.ClientError as exc: except aiohttp.ClientError as exc:
@@ -697,9 +673,7 @@ class TelegramCommandHandler:
# Verify chat_id matches configured chat # Verify chat_id matches configured chat
chat_id = str(message.get("chat", {}).get("id", "")) chat_id = str(message.get("chat", {}).get("id", ""))
if chat_id != self._client._chat_id: if chat_id != self._client._chat_id:
logger.warning( logger.warning("Ignoring command from unauthorized chat_id: %s", chat_id)
"Ignoring command from unauthorized chat_id: %s", chat_id
)
return return
# Extract command text # Extract command text
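The `except asyncio.TimeoutError:` → `except TimeoutError:` changes in this file are behavior-preserving on Python 3.11+, where `asyncio.TimeoutError` is an alias of the builtin `TimeoutError`. A minimal sketch of why the builtin catch still covers asyncio timeouts (`fetch_with_timeout` is an illustrative name, not part of this codebase):

```python
import asyncio

# Python 3.11+: asyncio.TimeoutError is the builtin TimeoutError.
assert asyncio.TimeoutError is TimeoutError

async def fetch_with_timeout() -> bool:
    try:
        async with asyncio.timeout(0.01):  # raises TimeoutError when exceeded
            await asyncio.sleep(1)
    except TimeoutError:  # also catches asyncio.TimeoutError
        return False
    return True

print(asyncio.run(fetch_with_timeout()))  # False
```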

View File

@@ -8,12 +8,12 @@ Defines the data contracts for the proactive strategy system:
from __future__ import annotations from __future__ import annotations
from datetime import UTC, date, datetime from datetime import UTC, date, datetime
from enum import Enum from enum import StrEnum
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
class ScenarioAction(str, Enum): class ScenarioAction(StrEnum):
"""Actions that can be taken by scenarios.""" """Actions that can be taken by scenarios."""
BUY = "BUY" BUY = "BUY"
@@ -22,7 +22,7 @@ class ScenarioAction(str, Enum):
REDUCE_ALL = "REDUCE_ALL" REDUCE_ALL = "REDUCE_ALL"
class MarketOutlook(str, Enum): class MarketOutlook(StrEnum):
"""AI's assessment of market direction.""" """AI's assessment of market direction."""
BULLISH = "bullish" BULLISH = "bullish"
@@ -32,7 +32,7 @@ class MarketOutlook(str, Enum):
BEARISH = "bearish" BEARISH = "bearish"
class PlaybookStatus(str, Enum): class PlaybookStatus(StrEnum):
"""Lifecycle status of a playbook.""" """Lifecycle status of a playbook."""
PENDING = "pending" PENDING = "pending"

View File

@@ -6,7 +6,6 @@ Designed for the pre-market strategy system (one playbook per market per day).
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import sqlite3 import sqlite3
from datetime import date from datetime import date
@@ -53,8 +52,10 @@ class PlaybookStore:
row_id = cursor.lastrowid or 0 row_id = cursor.lastrowid or 0
logger.info( logger.info(
"Saved playbook for %s/%s (%d stocks, %d scenarios)", "Saved playbook for %s/%s (%d stocks, %d scenarios)",
playbook.date, playbook.market, playbook.date,
playbook.stock_count, playbook.scenario_count, playbook.market,
playbook.stock_count,
playbook.scenario_count,
) )
return row_id return row_id

View File

@@ -6,10 +6,10 @@ State progression is monotonic (promotion-only) except terminal EXITED.
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import StrEnum
class PositionState(str, Enum): class PositionState(StrEnum):
HOLDING = "HOLDING" HOLDING = "HOLDING"
BE_LOCK = "BE_LOCK" BE_LOCK = "BE_LOCK"
ARMED = "ARMED" ARMED = "ARMED"
@@ -40,12 +40,7 @@ def evaluate_exit_first(inp: StateTransitionInput) -> bool:
EXITED must be evaluated before any promotion. EXITED must be evaluated before any promotion.
""" """
return ( return inp.hard_stop_hit or inp.trailing_stop_hit or inp.model_exit_signal or inp.be_lock_threat
inp.hard_stop_hit
or inp.trailing_stop_hit
or inp.model_exit_signal
or inp.be_lock_threat
)
def promote_state(current: PositionState, inp: StateTransitionInput) -> PositionState: def promote_state(current: PositionState, inp: StateTransitionInput) -> PositionState:
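The collapsed `evaluate_exit_first` return above preserves the exit-before-promotion rule stated in the docstring: any exit trigger short-circuits promotion. A minimal sketch of that check in isolation (the `promote_state` body and the rest of the state machine are not shown in this hunk):

```python
from dataclasses import dataclass

@dataclass
class StateTransitionInput:
    hard_stop_hit: bool = False
    trailing_stop_hit: bool = False
    model_exit_signal: bool = False
    be_lock_threat: bool = False

def evaluate_exit_first(inp: StateTransitionInput) -> bool:
    # Any one exit trigger is enough; promotion is only considered when this is False.
    return inp.hard_stop_hit or inp.trailing_stop_hit or inp.model_exit_signal or inp.be_lock_threat

assert evaluate_exit_first(StateTransitionInput(trailing_stop_hit=True)) is True
assert evaluate_exit_first(StateTransitionInput()) is False
```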

View File

@@ -124,12 +124,14 @@ class PreMarketPlanner:
# 4. Parse response # 4. Parse response
playbook = self._parse_response( playbook = self._parse_response(
decision.rationale, today, market, candidates, cross_market, decision.rationale,
today,
market,
candidates,
cross_market,
current_holdings=current_holdings, current_holdings=current_holdings,
) )
playbook_with_tokens = playbook.model_copy( playbook_with_tokens = playbook.model_copy(update={"token_count": decision.token_count})
update={"token_count": decision.token_count}
)
logger.info( logger.info(
"Generated playbook for %s: %d stocks, %d scenarios, %d tokens", "Generated playbook for %s: %d stocks, %d scenarios, %d tokens",
market, market,
@@ -146,7 +148,9 @@ class PreMarketPlanner:
return self._empty_playbook(today, market) return self._empty_playbook(today, market)
def build_cross_market_context( def build_cross_market_context(
self, target_market: str, today: date | None = None, self,
target_market: str,
today: date | None = None,
) -> CrossMarketContext | None: ) -> CrossMarketContext | None:
"""Build cross-market context from the other market's L6 data. """Build cross-market context from the other market's L6 data.
@@ -192,7 +196,9 @@ class PreMarketPlanner:
) )
def build_self_market_scorecard( def build_self_market_scorecard(
self, market: str, today: date | None = None, self,
market: str,
today: date | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Build previous-day scorecard for the same market.""" """Build previous-day scorecard for the same market."""
if today is None: if today is None:
@@ -320,18 +326,18 @@ class PreMarketPlanner:
f"{context_text}\n" f"{context_text}\n"
f"## Instructions\n" f"## Instructions\n"
f"Return a JSON object with this exact structure:\n" f"Return a JSON object with this exact structure:\n"
f'{{\n' f"{{\n"
f' "market_outlook": "bullish|neutral_to_bullish|neutral' f' "market_outlook": "bullish|neutral_to_bullish|neutral'
f'|neutral_to_bearish|bearish",\n' f'|neutral_to_bearish|bearish",\n'
f' "global_rules": [\n' f' "global_rules": [\n'
f' {{"condition": "portfolio_pnl_pct < -2.0",' f' {{"condition": "portfolio_pnl_pct < -2.0",'
f' "action": "REDUCE_ALL", "rationale": "..."}}\n' f' "action": "REDUCE_ALL", "rationale": "..."}}\n'
f' ],\n' f" ],\n"
f' "stocks": [\n' f' "stocks": [\n'
f' {{\n' f" {{\n"
f' "stock_code": "...",\n' f' "stock_code": "...",\n'
f' "scenarios": [\n' f' "scenarios": [\n'
f' {{\n' f" {{\n"
f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0,' f' "condition": {{"rsi_below": 30, "volume_ratio_above": 2.0,'
f' "unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}},\n' f' "unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}},\n'
f' "action": "BUY|SELL|HOLD",\n' f' "action": "BUY|SELL|HOLD",\n'
@@ -340,11 +346,11 @@ class PreMarketPlanner:
f' "stop_loss_pct": -2.0,\n' f' "stop_loss_pct": -2.0,\n'
f' "take_profit_pct": 3.0,\n' f' "take_profit_pct": 3.0,\n'
f' "rationale": "..."\n' f' "rationale": "..."\n'
f' }}\n' f" }}\n"
f' ]\n' f" ]\n"
f' }}\n' f" }}\n"
f' ]\n' f" ]\n"
f'}}\n\n' f"}}\n\n"
f"Rules:\n" f"Rules:\n"
f"- Max {max_scenarios} scenarios per stock\n" f"- Max {max_scenarios} scenarios per stock\n"
f"- Candidates list is the primary source for BUY candidates\n" f"- Candidates list is the primary source for BUY candidates\n"
@@ -575,8 +581,7 @@ class PreMarketPlanner:
stop_loss_pct=-3.0, stop_loss_pct=-3.0,
take_profit_pct=5.0, take_profit_pct=5.0,
rationale=( rationale=(
f"Rule-based BUY: oversold signal, " f"Rule-based BUY: oversold signal, RSI={c.rsi:.0f} (fallback planner)"
f"RSI={c.rsi:.0f} (fallback planner)"
), ),
) )
) )
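For readability, here is an illustrative instance of the JSON shape the reformatted prompt above asks the model to return (field names follow the prompt text; all values are made up):

```python
import json

example_playbook = {
    "market_outlook": "neutral_to_bullish",
    "global_rules": [
        {"condition": "portfolio_pnl_pct < -2.0", "action": "REDUCE_ALL", "rationale": "cut risk"},
    ],
    "stocks": [
        {
            "stock_code": "005930",
            "scenarios": [
                {
                    "condition": {"rsi_below": 30, "volume_ratio_above": 2.0},
                    "action": "BUY",
                    "stop_loss_pct": -2.0,
                    "take_profit_pct": 3.0,
                    "rationale": "oversold bounce",
                },
            ],
        },
    ],
}
print(json.dumps(example_playbook, indent=2, ensure_ascii=False))
```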

View File

@@ -107,7 +107,9 @@ class ScenarioEngine:
# 2. Find stock playbook # 2. Find stock playbook
stock_pb = playbook.get_stock_playbook(stock_code) stock_pb = playbook.get_stock_playbook(stock_code)
if stock_pb is None: if stock_pb is None:
logger.debug("No playbook for %s — defaulting to %s", stock_code, playbook.default_action) logger.debug(
"No playbook for %s — defaulting to %s", stock_code, playbook.default_action
)
return ScenarioMatch( return ScenarioMatch(
stock_code=stock_code, stock_code=stock_code,
matched_scenario=None, matched_scenario=None,
@@ -135,7 +137,9 @@ class ScenarioEngine:
) )
# 4. No match — default action # 4. No match — default action
logger.debug("No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action) logger.debug(
"No scenario matched for %s — defaulting to %s", stock_code, playbook.default_action
)
return ScenarioMatch( return ScenarioMatch(
stock_code=stock_code, stock_code=stock_code,
matched_scenario=None, matched_scenario=None,
@@ -198,17 +202,27 @@ class ScenarioEngine:
checks.append(price is not None and price < condition.price_below) checks.append(price is not None and price < condition.price_below)
price_change_pct = self._safe_float(market_data.get("price_change_pct")) price_change_pct = self._safe_float(market_data.get("price_change_pct"))
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: if (
condition.price_change_pct_above is not None
or condition.price_change_pct_below is not None
):
if "price_change_pct" not in market_data: if "price_change_pct" not in market_data:
self._warn_missing_key("price_change_pct") self._warn_missing_key("price_change_pct")
if condition.price_change_pct_above is not None: if condition.price_change_pct_above is not None:
checks.append(price_change_pct is not None and price_change_pct > condition.price_change_pct_above) checks.append(
price_change_pct is not None and price_change_pct > condition.price_change_pct_above
)
if condition.price_change_pct_below is not None: if condition.price_change_pct_below is not None:
checks.append(price_change_pct is not None and price_change_pct < condition.price_change_pct_below) checks.append(
price_change_pct is not None and price_change_pct < condition.price_change_pct_below
)
# Position-aware conditions # Position-aware conditions
unrealized_pnl_pct = self._safe_float(market_data.get("unrealized_pnl_pct")) unrealized_pnl_pct = self._safe_float(market_data.get("unrealized_pnl_pct"))
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None: if (
condition.unrealized_pnl_pct_above is not None
or condition.unrealized_pnl_pct_below is not None
):
if "unrealized_pnl_pct" not in market_data: if "unrealized_pnl_pct" not in market_data:
self._warn_missing_key("unrealized_pnl_pct") self._warn_missing_key("unrealized_pnl_pct")
if condition.unrealized_pnl_pct_above is not None: if condition.unrealized_pnl_pct_above is not None:
@@ -227,15 +241,9 @@ class ScenarioEngine:
if "holding_days" not in market_data: if "holding_days" not in market_data:
self._warn_missing_key("holding_days") self._warn_missing_key("holding_days")
if condition.holding_days_above is not None: if condition.holding_days_above is not None:
checks.append( checks.append(holding_days is not None and holding_days > condition.holding_days_above)
holding_days is not None
and holding_days > condition.holding_days_above
)
if condition.holding_days_below is not None: if condition.holding_days_below is not None:
checks.append( checks.append(holding_days is not None and holding_days < condition.holding_days_below)
holding_days is not None
and holding_days < condition.holding_days_below
)
return len(checks) > 0 and all(checks) return len(checks) > 0 and all(checks)
@@ -295,9 +303,15 @@ class ScenarioEngine:
details["volume_ratio"] = self._safe_float(market_data.get("volume_ratio")) details["volume_ratio"] = self._safe_float(market_data.get("volume_ratio"))
if condition.price_above is not None or condition.price_below is not None: if condition.price_above is not None or condition.price_below is not None:
details["current_price"] = self._safe_float(market_data.get("current_price")) details["current_price"] = self._safe_float(market_data.get("current_price"))
if condition.price_change_pct_above is not None or condition.price_change_pct_below is not None: if (
condition.price_change_pct_above is not None
or condition.price_change_pct_below is not None
):
details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct")) details["price_change_pct"] = self._safe_float(market_data.get("price_change_pct"))
if condition.unrealized_pnl_pct_above is not None or condition.unrealized_pnl_pct_below is not None: if (
condition.unrealized_pnl_pct_above is not None
or condition.unrealized_pnl_pct_below is not None
):
details["unrealized_pnl_pct"] = self._safe_float(market_data.get("unrealized_pnl_pct")) details["unrealized_pnl_pct"] = self._safe_float(market_data.get("unrealized_pnl_pct"))
if condition.holding_days_above is not None or condition.holding_days_below is not None: if condition.holding_days_above is not None or condition.holding_days_below is not None:
details["holding_days"] = self._safe_float(market_data.get("holding_days")) details["holding_days"] = self._safe_float(market_data.get("holding_days"))

View File

@@ -4,8 +4,7 @@ from __future__ import annotations
import sqlite3 import sqlite3
import sys import sys
import tempfile from datetime import UTC, datetime
from datetime import UTC, datetime, timedelta
from pathlib import Path from pathlib import Path
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
@@ -48,7 +47,9 @@ def temp_db(tmp_path: Path) -> Path:
cursor.executemany( cursor.executemany(
""" """
INSERT INTO trades (timestamp, stock_code, action, quantity, price, confidence, rationale, pnl) INSERT INTO trades (
timestamp, stock_code, action, quantity, price, confidence, rationale, pnl
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""", """,
test_trades, test_trades,
@@ -73,9 +74,7 @@ class TestBackupExporter:
exporter = BackupExporter(str(temp_db)) exporter = BackupExporter(str(temp_db))
output_dir = tmp_path / "exports" output_dir = tmp_path / "exports"
results = exporter.export_all( results = exporter.export_all(output_dir, formats=[ExportFormat.JSON], compress=False)
output_dir, formats=[ExportFormat.JSON], compress=False
)
assert ExportFormat.JSON in results assert ExportFormat.JSON in results
assert results[ExportFormat.JSON].exists() assert results[ExportFormat.JSON].exists()
@@ -86,9 +85,7 @@ class TestBackupExporter:
exporter = BackupExporter(str(temp_db)) exporter = BackupExporter(str(temp_db))
output_dir = tmp_path / "exports" output_dir = tmp_path / "exports"
results = exporter.export_all( results = exporter.export_all(output_dir, formats=[ExportFormat.JSON], compress=True)
output_dir, formats=[ExportFormat.JSON], compress=True
)
assert ExportFormat.JSON in results assert ExportFormat.JSON in results
assert results[ExportFormat.JSON].suffix == ".gz" assert results[ExportFormat.JSON].suffix == ".gz"
@@ -98,15 +95,13 @@ class TestBackupExporter:
exporter = BackupExporter(str(temp_db)) exporter = BackupExporter(str(temp_db))
output_dir = tmp_path / "exports" output_dir = tmp_path / "exports"
results = exporter.export_all( results = exporter.export_all(output_dir, formats=[ExportFormat.CSV], compress=False)
output_dir, formats=[ExportFormat.CSV], compress=False
)
assert ExportFormat.CSV in results assert ExportFormat.CSV in results
assert results[ExportFormat.CSV].exists() assert results[ExportFormat.CSV].exists()
# Verify CSV content # Verify CSV content
with open(results[ExportFormat.CSV], "r") as f: with open(results[ExportFormat.CSV]) as f:
lines = f.readlines() lines = f.readlines()
assert len(lines) == 4 # Header + 3 rows assert len(lines) == 4 # Header + 3 rows
@@ -146,7 +141,7 @@ class TestBackupExporter:
# Should only have 1 trade (AAPL on Jan 2) # Should only have 1 trade (AAPL on Jan 2)
import json import json
with open(results[ExportFormat.JSON], "r") as f: with open(results[ExportFormat.JSON]) as f:
data = json.load(f) data = json.load(f)
assert data["record_count"] == 1 assert data["record_count"] == 1
assert data["trades"][0]["stock_code"] == "AAPL" assert data["trades"][0]["stock_code"] == "AAPL"
@@ -407,9 +402,7 @@ class TestBackupExporterAdditional:
assert ExportFormat.JSON in results assert ExportFormat.JSON in results
assert ExportFormat.CSV in results assert ExportFormat.CSV in results
def test_export_all_logs_error_on_failure( def test_export_all_logs_error_on_failure(self, temp_db: Path, tmp_path: Path) -> None:
self, temp_db: Path, tmp_path: Path
) -> None:
"""export_all must log an error and continue when one format fails.""" """export_all must log an error and continue when one format fails."""
exporter = BackupExporter(str(temp_db)) exporter = BackupExporter(str(temp_db))
# Patch _export_format to raise on JSON, succeed on CSV # Patch _export_format to raise on JSON, succeed on CSV
@@ -430,9 +423,7 @@ class TestBackupExporterAdditional:
assert ExportFormat.JSON not in results assert ExportFormat.JSON not in results
assert ExportFormat.CSV in results assert ExportFormat.CSV in results
def test_export_csv_empty_trades_no_compress( def test_export_csv_empty_trades_no_compress(self, empty_db: Path, tmp_path: Path) -> None:
self, empty_db: Path, tmp_path: Path
) -> None:
"""CSV export with no trades and compress=False must write header row only.""" """CSV export with no trades and compress=False must write header row only."""
exporter = BackupExporter(str(empty_db)) exporter = BackupExporter(str(empty_db))
results = exporter.export_all( results = exporter.export_all(
@@ -446,9 +437,7 @@ class TestBackupExporterAdditional:
content = out.read_text() content = out.read_text()
assert "timestamp" in content assert "timestamp" in content
def test_export_csv_empty_trades_compressed( def test_export_csv_empty_trades_compressed(self, empty_db: Path, tmp_path: Path) -> None:
self, empty_db: Path, tmp_path: Path
) -> None:
"""CSV export with no trades and compress=True must write gzipped header.""" """CSV export with no trades and compress=True must write gzipped header."""
import gzip import gzip
@@ -465,9 +454,7 @@ class TestBackupExporterAdditional:
content = f.read() content = f.read()
assert "timestamp" in content assert "timestamp" in content
def test_export_csv_with_data_compressed( def test_export_csv_with_data_compressed(self, temp_db: Path, tmp_path: Path) -> None:
self, temp_db: Path, tmp_path: Path
) -> None:
"""CSV export with data and compress=True must write gzipped rows.""" """CSV export with data and compress=True must write gzipped rows."""
import gzip import gzip
@@ -492,6 +479,7 @@ class TestBackupExporterAdditional:
with patch.dict(sys.modules, {"pyarrow": None, "pyarrow.parquet": None}): with patch.dict(sys.modules, {"pyarrow": None, "pyarrow.parquet": None}):
try: try:
import pyarrow # noqa: F401 import pyarrow # noqa: F401
pytest.skip("pyarrow is installed; cannot test ImportError path") pytest.skip("pyarrow is installed; cannot test ImportError path")
except ImportError: except ImportError:
pass pass
@@ -557,9 +545,7 @@ class TestCloudStorage:
importlib.reload(m) importlib.reload(m)
m.CloudStorage(s3_config) m.CloudStorage(s3_config)
def test_upload_file_success( def test_upload_file_success(self, mock_boto3_module, s3_config, tmp_path: Path) -> None:
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file must call client.upload_file and return the object key.""" """upload_file must call client.upload_file and return the object key."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -572,9 +558,7 @@ class TestCloudStorage:
assert key == "backups/backup.json.gz" assert key == "backups/backup.json.gz"
storage.client.upload_file.assert_called_once() storage.client.upload_file.assert_called_once()
def test_upload_file_default_key( def test_upload_file_default_key(self, mock_boto3_module, s3_config, tmp_path: Path) -> None:
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file without object_key must use the filename as key.""" """upload_file without object_key must use the filename as key."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -586,9 +570,7 @@ class TestCloudStorage:
assert key == "myfile.gz" assert key == "myfile.gz"
def test_upload_file_not_found( def test_upload_file_not_found(self, mock_boto3_module, s3_config, tmp_path: Path) -> None:
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""upload_file must raise FileNotFoundError for missing files.""" """upload_file must raise FileNotFoundError for missing files."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -611,9 +593,7 @@ class TestCloudStorage:
with pytest.raises(RuntimeError, match="network error"): with pytest.raises(RuntimeError, match="network error"):
storage.upload_file(test_file) storage.upload_file(test_file)
def test_download_file_success( def test_download_file_success(self, mock_boto3_module, s3_config, tmp_path: Path) -> None:
self, mock_boto3_module, s3_config, tmp_path: Path
) -> None:
"""download_file must call client.download_file and return local path.""" """download_file must call client.download_file and return local path."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -637,11 +617,8 @@ class TestCloudStorage:
with pytest.raises(RuntimeError, match="timeout"): with pytest.raises(RuntimeError, match="timeout"):
storage.download_file("key", tmp_path / "dest.gz") storage.download_file("key", tmp_path / "dest.gz")
def test_list_files_returns_objects( def test_list_files_returns_objects(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""list_files must return parsed file metadata from S3 response.""" """list_files must return parsed file metadata from S3 response."""
from datetime import timezone
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -651,7 +628,7 @@ class TestCloudStorage:
{ {
"Key": "backups/a.gz", "Key": "backups/a.gz",
"Size": 1024, "Size": 1024,
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc), "LastModified": datetime(2026, 1, 1, tzinfo=UTC),
"ETag": '"abc123"', "ETag": '"abc123"',
} }
] ]
@@ -662,9 +639,7 @@ class TestCloudStorage:
assert files[0]["key"] == "backups/a.gz" assert files[0]["key"] == "backups/a.gz"
assert files[0]["size_bytes"] == 1024 assert files[0]["size_bytes"] == 1024
def test_list_files_empty_bucket( def test_list_files_empty_bucket(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""list_files must return empty list when bucket has no objects.""" """list_files must return empty list when bucket has no objects."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -674,9 +649,7 @@ class TestCloudStorage:
files = storage.list_files() files = storage.list_files()
assert files == [] assert files == []
def test_list_files_propagates_error( def test_list_files_propagates_error(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""list_files must re-raise exceptions from the boto3 client.""" """list_files must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -686,9 +659,7 @@ class TestCloudStorage:
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
storage.list_files() storage.list_files()
def test_delete_file_success( def test_delete_file_success(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""delete_file must call client.delete_object with the correct key.""" """delete_file must call client.delete_object with the correct key."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -698,9 +669,7 @@ class TestCloudStorage:
Bucket="test-bucket", Key="backups/old.gz" Bucket="test-bucket", Key="backups/old.gz"
) )
def test_delete_file_propagates_error( def test_delete_file_propagates_error(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""delete_file must re-raise exceptions from the boto3 client.""" """delete_file must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -710,11 +679,8 @@ class TestCloudStorage:
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
storage.delete_file("backups/old.gz") storage.delete_file("backups/old.gz")
def test_get_storage_stats_success( def test_get_storage_stats_success(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""get_storage_stats must aggregate file sizes correctly.""" """get_storage_stats must aggregate file sizes correctly."""
from datetime import timezone
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -724,13 +690,13 @@ class TestCloudStorage:
{ {
"Key": "a.gz", "Key": "a.gz",
"Size": 1024 * 1024, "Size": 1024 * 1024,
"LastModified": datetime(2026, 1, 1, tzinfo=timezone.utc), "LastModified": datetime(2026, 1, 1, tzinfo=UTC),
"ETag": '"x"', "ETag": '"x"',
}, },
{ {
"Key": "b.gz", "Key": "b.gz",
"Size": 1024 * 1024, "Size": 1024 * 1024,
"LastModified": datetime(2026, 1, 2, tzinfo=timezone.utc), "LastModified": datetime(2026, 1, 2, tzinfo=UTC),
"ETag": '"y"', "ETag": '"y"',
}, },
] ]
@@ -741,9 +707,7 @@ class TestCloudStorage:
assert stats["total_size_bytes"] == 2 * 1024 * 1024 assert stats["total_size_bytes"] == 2 * 1024 * 1024
assert stats["total_size_mb"] == pytest.approx(2.0) assert stats["total_size_mb"] == pytest.approx(2.0)
def test_get_storage_stats_on_error( def test_get_storage_stats_on_error(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""get_storage_stats must return error dict without raising on failure.""" """get_storage_stats must return error dict without raising on failure."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -754,9 +718,7 @@ class TestCloudStorage:
assert "error" in stats assert "error" in stats
assert stats["total_files"] == 0 assert stats["total_files"] == 0
def test_verify_connection_success( def test_verify_connection_success(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""verify_connection must return True when head_bucket succeeds.""" """verify_connection must return True when head_bucket succeeds."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -764,9 +726,7 @@ class TestCloudStorage:
result = storage.verify_connection() result = storage.verify_connection()
assert result is True assert result is True
def test_verify_connection_failure( def test_verify_connection_failure(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""verify_connection must return False when head_bucket raises.""" """verify_connection must return False when head_bucket raises."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -776,9 +736,7 @@ class TestCloudStorage:
result = storage.verify_connection() result = storage.verify_connection()
assert result is False assert result is False
def test_enable_versioning( def test_enable_versioning(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""enable_versioning must call put_bucket_versioning.""" """enable_versioning must call put_bucket_versioning."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
@@ -786,9 +744,7 @@ class TestCloudStorage:
storage.enable_versioning() storage.enable_versioning()
storage.client.put_bucket_versioning.assert_called_once() storage.client.put_bucket_versioning.assert_called_once()
def test_enable_versioning_propagates_error( def test_enable_versioning_propagates_error(self, mock_boto3_module, s3_config) -> None:
self, mock_boto3_module, s3_config
) -> None:
"""enable_versioning must re-raise exceptions from the boto3 client.""" """enable_versioning must re-raise exceptions from the boto3 client."""
from src.backup.cloud_storage import CloudStorage from src.backup.cloud_storage import CloudStorage
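The `timezone.utc` → `UTC` substitutions in these fixtures rely on `datetime.UTC`, added in Python 3.11 as an alias of `timezone.utc`, so the constructed timestamps are identical:

```python
from datetime import UTC, datetime, timezone

assert UTC is timezone.utc  # Python 3.11+
assert datetime(2026, 1, 1, tzinfo=UTC) == datetime(2026, 1, 1, tzinfo=timezone.utc)
```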

View File

@@ -323,7 +323,8 @@ class TestPromptOverride:
# Verify the custom prompt was sent, not a built prompt # Verify the custom prompt was sent, not a built prompt
mock_generate.assert_called_once() mock_generate.assert_called_once()
actual_prompt = mock_generate.call_args[1].get( actual_prompt = mock_generate.call_args[1].get(
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None "contents",
mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None,
) )
assert actual_prompt == custom_prompt assert actual_prompt == custom_prompt
# Raw response preserved in rationale without parse_response (#247) # Raw response preserved in rationale without parse_response (#247)
@@ -385,7 +386,8 @@ class TestPromptOverride:
await client.decide(market_data) await client.decide(market_data)
actual_prompt = mock_generate.call_args[1].get( actual_prompt = mock_generate.call_args[1].get(
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None "contents",
mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None,
) )
# The custom prompt must be used, not the compressed prompt # The custom prompt must be used, not the compressed prompt
assert actual_prompt == custom_prompt assert actual_prompt == custom_prompt
@@ -411,7 +413,8 @@ class TestPromptOverride:
await client.decide(market_data) await client.decide(market_data)
actual_prompt = mock_generate.call_args[1].get( actual_prompt = mock_generate.call_args[1].get(
"contents", mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None "contents",
mock_generate.call_args[0][1] if len(mock_generate.call_args[0]) > 1 else None,
) )
# Should contain stock code from build_prompt, not be a custom override # Should contain stock code from build_prompt, not be a custom override
assert "005930" in actual_prompt assert "005930" in actual_prompt

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@@ -99,7 +99,10 @@ class TestTokenManagement:
mock_resp_403 = AsyncMock() mock_resp_403 = AsyncMock()
mock_resp_403.status = 403 mock_resp_403.status = 403
mock_resp_403.text = AsyncMock( mock_resp_403.text = AsyncMock(
return_value='{"error_code":"EGW00133","error_description":"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}' return_value=(
'{"error_code":"EGW00133","error_description":'
'"접근토큰 발급 잠시 후 다시 시도하세요(1분당 1회)"}'
)
) )
mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403) mock_resp_403.__aenter__ = AsyncMock(return_value=mock_resp_403)
mock_resp_403.__aexit__ = AsyncMock(return_value=False) mock_resp_403.__aexit__ = AsyncMock(return_value=False)
@@ -232,9 +235,7 @@ class TestRateLimiter:
mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp) mock_order_resp.__aenter__ = AsyncMock(return_value=mock_order_resp)
mock_order_resp.__aexit__ = AsyncMock(return_value=False) mock_order_resp.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]):
"aiohttp.ClientSession.post", side_effect=[mock_hash_resp, mock_order_resp]
):
with patch.object( with patch.object(
broker._rate_limiter, "acquire", new_callable=AsyncMock broker._rate_limiter, "acquire", new_callable=AsyncMock
) as mock_acquire: ) as mock_acquire:
@@ -405,7 +406,7 @@ class TestFetchMarketRankings:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from src.broker.kis_api import kr_tick_unit, kr_round_down # noqa: E402 from src.broker.kis_api import kr_round_down, kr_tick_unit # noqa: E402
class TestKrTickUnit: class TestKrTickUnit:
@@ -435,13 +436,13 @@ class TestKrTickUnit:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"price, expected_rounded", "price, expected_rounded",
[ [
(188150, 188100), # 100-won tick, 50-won remainder → round down (188150, 188100), # 100-won tick, 50-won remainder → round down
(188100, 188100), # already tick-aligned (188100, 188100), # already tick-aligned
(75050, 75000), # 100-won tick, 50-won remainder → round down (75050, 75000), # 100-won tick, 50-won remainder → round down
(49950, 49950), # aligned to 50-won tick (49950, 49950), # aligned to 50-won tick
(49960, 49950), # 50-won tick, 10-won remainder → round down (49960, 49950), # 50-won tick, 10-won remainder → round down
(1999, 1999), # 1-won tick → unchanged (1999, 1999), # 1-won tick → unchanged
(5003, 5000), # 10-won tick, 3-won remainder → round down (5003, 5000), # 10-won tick, 3-won remainder → round down
], ],
) )
def test_round_down_to_tick(self, price: int, expected_rounded: int) -> None: def test_round_down_to_tick(self, price: int, expected_rounded: int) -> None:
@@ -538,15 +539,13 @@ class TestSendOrderTickRounding:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1, price=188150) await broker.send_order("005930", "BUY", 1, price=188150)
order_call = mock_post.call_args_list[1] order_call = mock_post.call_args_list[1]
body = order_call[1].get("json", {}) body = order_call[1].get("json", {})
assert body["ORD_UNPR"] == "188100" # rounded down assert body["ORD_UNPR"] == "188100" # rounded down
assert body["ORD_DVSN"] == "00" # 지정가 assert body["ORD_DVSN"] == "00" # 지정가
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_limit_order_ord_dvsn_is_00(self, broker: KISBroker) -> None: async def test_limit_order_ord_dvsn_is_00(self, broker: KISBroker) -> None:
@@ -563,9 +562,7 @@ class TestSendOrderTickRounding:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1, price=50000) await broker.send_order("005930", "BUY", 1, price=50000)
order_call = mock_post.call_args_list[1] order_call = mock_post.call_args_list[1]
@@ -587,9 +584,7 @@ class TestSendOrderTickRounding:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "SELL", 1, price=0) await broker.send_order("005930", "SELL", 1, price=0)
order_call = mock_post.call_args_list[1] order_call = mock_post.call_args_list[1]
@@ -628,9 +623,7 @@ class TestTRIDBranchingDomestic:
broker = self._make_broker(settings, "paper") broker = self._make_broker(settings, "paper")
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
mock_resp.json = AsyncMock( mock_resp.json = AsyncMock(return_value={"output1": [], "output2": {}})
return_value={"output1": [], "output2": {}}
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False) mock_resp.__aexit__ = AsyncMock(return_value=False)
@@ -645,9 +638,7 @@ class TestTRIDBranchingDomestic:
broker = self._make_broker(settings, "live") broker = self._make_broker(settings, "live")
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
mock_resp.json = AsyncMock( mock_resp.json = AsyncMock(return_value={"output1": [], "output2": {}})
return_value={"output1": [], "output2": {}}
)
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
mock_resp.__aexit__ = AsyncMock(return_value=False) mock_resp.__aexit__ = AsyncMock(return_value=False)
@@ -672,9 +663,7 @@ class TestTRIDBranchingDomestic:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1) await broker.send_order("005930", "BUY", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
@@ -695,9 +684,7 @@ class TestTRIDBranchingDomestic:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "BUY", 1) await broker.send_order("005930", "BUY", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
@@ -718,9 +705,7 @@ class TestTRIDBranchingDomestic:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "SELL", 1) await broker.send_order("005930", "SELL", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
@@ -741,9 +726,7 @@ class TestTRIDBranchingDomestic:
mock_order.__aenter__ = AsyncMock(return_value=mock_order) mock_order.__aenter__ = AsyncMock(return_value=mock_order)
mock_order.__aexit__ = AsyncMock(return_value=False) mock_order.__aexit__ = AsyncMock(return_value=False)
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.send_order("005930", "SELL", 1) await broker.send_order("005930", "SELL", 1)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
@@ -788,9 +771,7 @@ class TestGetDomesticPendingOrders:
mock_get.assert_not_called() mock_get.assert_not_called()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_live_mode_calls_tttc0084r_with_correct_params( async def test_live_mode_calls_tttc0084r_with_correct_params(self, settings) -> None:
self, settings
) -> None:
"""Live mode must call TTTC0084R with INQR_DVSN_1/2 and paging params.""" """Live mode must call TTTC0084R with INQR_DVSN_1/2 and paging params."""
broker = self._make_broker(settings, "live") broker = self._make_broker(settings, "live")
pending = [{"odno": "001", "pdno": "005930", "psbl_qty": "10"}] pending = [{"odno": "001", "pdno": "005930", "psbl_qty": "10"}]
@@ -872,9 +853,7 @@ class TestCancelDomesticOrder:
broker = self._make_broker(settings, "live") broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5) await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
@@ -886,9 +865,7 @@ class TestCancelDomesticOrder:
broker = self._make_broker(settings, "paper") broker = self._make_broker(settings, "paper")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5) await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
@@ -900,9 +877,7 @@ class TestCancelDomesticOrder:
broker = self._make_broker(settings, "live") broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5) await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 5)
body = mock_post.call_args_list[1][1].get("json", {}) body = mock_post.call_args_list[1][1].get("json", {})
@@ -916,9 +891,7 @@ class TestCancelDomesticOrder:
broker = self._make_broker(settings, "live") broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD123", "BRN456", 3) await broker.cancel_domestic_order("005930", "ORD123", "BRN456", 3)
body = mock_post.call_args_list[1][1].get("json", {}) body = mock_post.call_args_list[1][1].get("json", {})
@@ -932,9 +905,7 @@ class TestCancelDomesticOrder:
broker = self._make_broker(settings, "live") broker = self._make_broker(settings, "live")
mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"}) mock_hash, mock_order = self._make_post_mocks({"rt_cd": "0"})
with patch( with patch("aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]) as mock_post:
"aiohttp.ClientSession.post", side_effect=[mock_hash, mock_order]
) as mock_post:
await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 2) await broker.cancel_domestic_order("005930", "ORD001", "BRNO01", 2)
order_headers = mock_post.call_args_list[1][1].get("headers", {}) order_headers = mock_post.call_args_list[1][1].get("headers", {})
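The parametrized cases above pin down the tick-rounding behavior under test. A plausible implementation consistent with those cases, using the KRX tick-size bands (the real `kr_tick_unit`/`kr_round_down` in `src.broker.kis_api` may differ in detail):

```python
def kr_tick_unit(price: int) -> int:
    # KRX tick-size bands (KRW); boundaries inferred from the test cases above.
    if price < 2_000:
        return 1
    if price < 5_000:
        return 5
    if price < 20_000:
        return 10
    if price < 50_000:
        return 50
    if price < 200_000:
        return 100
    if price < 500_000:
        return 500
    return 1_000

def kr_round_down(price: int) -> int:
    unit = kr_tick_unit(price)
    return (price // unit) * unit

assert kr_round_down(188_150) == 188_100  # 100-won tick
assert kr_round_down(49_960) == 49_950    # 50-won tick
assert kr_round_down(5_003) == 5_000      # 10-won tick
assert kr_round_down(1_999) == 1_999      # 1-won tick
```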

View File

@@ -77,9 +77,7 @@ class TestContextStore:
# Latest by updated_at, which should be the last one set # Latest by updated_at, which should be the last one set
assert latest == "2026-02-02" assert latest == "2026-02-02"
def test_delete_old_contexts( def test_delete_old_contexts(self, store: ContextStore, db_conn: sqlite3.Connection) -> None:
self, store: ContextStore, db_conn: sqlite3.Connection
) -> None:
"""Test deleting contexts older than a cutoff date.""" """Test deleting contexts older than a cutoff date."""
# Insert contexts with specific old timestamps # Insert contexts with specific old timestamps
# (bypassing set_context which uses current time) # (bypassing set_context which uses current time)
@@ -170,9 +168,7 @@ class TestContextAggregator:
log_trade(db_conn, "035720", "HOLD", 75, "Wait", quantity=0, price=0, pnl=0) log_trade(db_conn, "035720", "HOLD", 75, "Wait", quantity=0, price=0, pnl=0)
# Manually set timestamps to the target date # Manually set timestamps to the target date
db_conn.execute( db_conn.execute(f"UPDATE trades SET timestamp = '{date}T10:00:00+00:00'")
f"UPDATE trades SET timestamp = '{date}T10:00:00+00:00'"
)
db_conn.commit() db_conn.commit()
# Aggregate # Aggregate
@@ -194,18 +190,10 @@ class TestContextAggregator:
week = "2026-W06" week = "2026-W06"
# Set daily contexts # Set daily contexts
aggregator.store.set_context( aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "total_pnl_KR", 100.0)
ContextLayer.L6_DAILY, "2026-02-02", "total_pnl_KR", 100.0 aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "total_pnl_KR", 200.0)
) aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence_KR", 80.0)
aggregator.store.set_context( aggregator.store.set_context(ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence_KR", 85.0)
ContextLayer.L6_DAILY, "2026-02-03", "total_pnl_KR", 200.0
)
aggregator.store.set_context(
ContextLayer.L6_DAILY, "2026-02-02", "avg_confidence_KR", 80.0
)
aggregator.store.set_context(
ContextLayer.L6_DAILY, "2026-02-03", "avg_confidence_KR", 85.0
)
# Aggregate # Aggregate
aggregator.aggregate_weekly_from_daily(week) aggregator.aggregate_weekly_from_daily(week)
@@ -223,15 +211,9 @@ class TestContextAggregator:
month = "2026-02" month = "2026-02"
# Set weekly contexts # Set weekly contexts
aggregator.store.set_context( aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl_KR", 100.0)
ContextLayer.L5_WEEKLY, "2026-W05", "weekly_pnl_KR", 100.0 aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl_KR", 200.0)
) aggregator.store.set_context(ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl_KR", 150.0)
aggregator.store.set_context(
ContextLayer.L5_WEEKLY, "2026-W06", "weekly_pnl_KR", 200.0
)
aggregator.store.set_context(
ContextLayer.L5_WEEKLY, "2026-W07", "weekly_pnl_KR", 150.0
)
# Aggregate # Aggregate
aggregator.aggregate_monthly_from_weekly(month) aggregator.aggregate_monthly_from_weekly(month)
@@ -316,6 +298,7 @@ class TestContextAggregator:
store = aggregator.store store = aggregator.store
assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 1000.0 assert store.get_context(ContextLayer.L6_DAILY, date, "total_pnl_KR") == 1000.0
from datetime import date as date_cls from datetime import date as date_cls
trade_date = date_cls.fromisoformat(date) trade_date = date_cls.fromisoformat(date)
iso_year, iso_week, _ = trade_date.isocalendar() iso_year, iso_week, _ = trade_date.isocalendar()
trade_week = f"{iso_year}-W{iso_week:02d}" trade_week = f"{iso_year}-W{iso_week:02d}"
@@ -324,7 +307,9 @@ class TestContextAggregator:
trade_quarter = f"{trade_date.year}-Q{(trade_date.month - 1) // 3 + 1}" trade_quarter = f"{trade_date.year}-Q{(trade_date.month - 1) // 3 + 1}"
trade_year = str(trade_date.year) trade_year = str(trade_date.year)
assert store.get_context(ContextLayer.L4_MONTHLY, trade_month, "monthly_pnl") == 1000.0 assert store.get_context(ContextLayer.L4_MONTHLY, trade_month, "monthly_pnl") == 1000.0
assert store.get_context(ContextLayer.L3_QUARTERLY, trade_quarter, "quarterly_pnl") == 1000.0 assert (
store.get_context(ContextLayer.L3_QUARTERLY, trade_quarter, "quarterly_pnl") == 1000.0
)
assert store.get_context(ContextLayer.L2_ANNUAL, trade_year, "annual_pnl") == 1000.0 assert store.get_context(ContextLayer.L2_ANNUAL, trade_year, "annual_pnl") == 1000.0
@@ -429,9 +414,7 @@ class TestContextSummarizer:
# summarize_layer # summarize_layer
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def test_summarize_layer_no_data( def test_summarize_layer_no_data(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""summarize_layer with no data must return the 'No data' sentinel.""" """summarize_layer with no data must return the 'No data' sentinel."""
result = summarizer.summarize_layer(ContextLayer.L6_DAILY) result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
assert result["count"] == 0 assert result["count"] == 0
@@ -448,15 +431,12 @@ class TestContextSummarizer:
result = summarizer.summarize_layer(ContextLayer.L6_DAILY) result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
assert "total_entries" in result assert "total_entries" in result
def test_summarize_layer_with_dict_values( def test_summarize_layer_with_dict_values(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""summarize_layer must handle dict values by extracting numeric subkeys.""" """summarize_layer must handle dict values by extracting numeric subkeys."""
store = summarizer.store store = summarizer.store
# set_context serialises the value as JSON, so passing a dict works # set_context serialises the value as JSON, so passing a dict works
store.set_context( store.set_context(
ContextLayer.L6_DAILY, "2026-02-01", "metrics", ContextLayer.L6_DAILY, "2026-02-01", "metrics", {"win_rate": 65.0, "label": "good"}
{"win_rate": 65.0, "label": "good"}
) )
result = summarizer.summarize_layer(ContextLayer.L6_DAILY) result = summarizer.summarize_layer(ContextLayer.L6_DAILY)
@@ -464,9 +444,7 @@ class TestContextSummarizer:
# numeric subkey "win_rate" should appear as "metrics.win_rate" # numeric subkey "win_rate" should appear as "metrics.win_rate"
assert "metrics.win_rate" in result assert "metrics.win_rate" in result
def test_summarize_layer_with_string_values( def test_summarize_layer_with_string_values(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""summarize_layer must count string values separately.""" """summarize_layer must count string values separately."""
store = summarizer.store store = summarizer.store
# set_context stores string values as JSON-encoded strings # set_context stores string values as JSON-encoded strings
@@ -480,9 +458,7 @@ class TestContextSummarizer:
# rolling_window_summary # rolling_window_summary
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def test_rolling_window_summary_basic( def test_rolling_window_summary_basic(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""rolling_window_summary must return the expected structure.""" """rolling_window_summary must return the expected structure."""
store = summarizer.store store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 500.0) store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 500.0)
@@ -492,22 +468,16 @@ class TestContextSummarizer:
assert "recent_data" in result assert "recent_data" in result
assert "historical_summary" in result assert "historical_summary" in result
def test_rolling_window_summary_no_older_data( def test_rolling_window_summary_no_older_data(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""rolling_window_summary with summarize_older=False skips history.""" """rolling_window_summary with summarize_older=False skips history."""
result = summarizer.rolling_window_summary( result = summarizer.rolling_window_summary(ContextLayer.L6_DAILY, summarize_older=False)
ContextLayer.L6_DAILY, summarize_older=False
)
assert result["historical_summary"] == {} assert result["historical_summary"] == {}
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# aggregate_to_higher_layer # aggregate_to_higher_layer
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def test_aggregate_to_higher_layer_mean( def test_aggregate_to_higher_layer_mean(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'mean' via dict subkeys returns average.""" """aggregate_to_higher_layer with 'mean' via dict subkeys returns average."""
store = summarizer.store store = summarizer.store
# Use different outer keys but same inner metric key so get_all_contexts # Use different outer keys but same inner metric key so get_all_contexts
@@ -520,9 +490,7 @@ class TestContextSummarizer:
) )
assert result == pytest.approx(150.0) assert result == pytest.approx(150.0)
def test_aggregate_to_higher_layer_sum( def test_aggregate_to_higher_layer_sum(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'sum' must return the total.""" """aggregate_to_higher_layer with 'sum' must return the total."""
store = summarizer.store store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0}) store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
@@ -533,9 +501,7 @@ class TestContextSummarizer:
) )
assert result == pytest.approx(300.0) assert result == pytest.approx(300.0)
def test_aggregate_to_higher_layer_max( def test_aggregate_to_higher_layer_max(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'max' must return the maximum.""" """aggregate_to_higher_layer with 'max' must return the maximum."""
store = summarizer.store store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0}) store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
@@ -546,9 +512,7 @@ class TestContextSummarizer:
) )
assert result == pytest.approx(200.0) assert result == pytest.approx(200.0)
def test_aggregate_to_higher_layer_min( def test_aggregate_to_higher_layer_min(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with 'min' must return the minimum.""" """aggregate_to_higher_layer with 'min' must return the minimum."""
store = summarizer.store store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0}) store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "day1", {"pnl": 100.0})
@@ -559,9 +523,7 @@ class TestContextSummarizer:
) )
assert result == pytest.approx(100.0) assert result == pytest.approx(100.0)
def test_aggregate_to_higher_layer_no_data( def test_aggregate_to_higher_layer_no_data(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""aggregate_to_higher_layer with no matching key must return None.""" """aggregate_to_higher_layer with no matching key must return None."""
result = summarizer.aggregate_to_higher_layer( result = summarizer.aggregate_to_higher_layer(
ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "nonexistent", "mean" ContextLayer.L6_DAILY, ContextLayer.L5_WEEKLY, "nonexistent", "mean"
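The four aggregation tests above pin down the reduction semantics: collect one metric across the lower layer, reduce it with mean/sum/max/min, and return None when nothing matches. A minimal standalone sketch of that contract follows, using illustrative names rather than the real ContextSummarizer/ContextStore API.

```python
# Illustrative only: the reduction behavior asserted above, independent of the
# real ContextSummarizer. The values would come from the lower layer's contexts.
from statistics import mean


def aggregate_values(values: list[float], how: str) -> float | None:
    """Reduce collected metric values; None when there is nothing to aggregate."""
    if not values:
        return None
    reducers = {"mean": mean, "sum": sum, "max": max, "min": min}
    return float(reducers[how](values))


assert aggregate_values([100.0, 200.0], "mean") == 150.0   # matches the mean test
assert aggregate_values([100.0, 200.0], "sum") == 300.0
assert aggregate_values([100.0, 200.0], "max") == 200.0
assert aggregate_values([100.0, 200.0], "min") == 100.0
assert aggregate_values([], "mean") is None                # the no-data case
```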
@@ -585,9 +547,7 @@ class TestContextSummarizer:
# create_compact_summary + format_summary_for_prompt # create_compact_summary + format_summary_for_prompt
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def test_create_compact_summary( def test_create_compact_summary(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""create_compact_summary must produce a dict keyed by layer value.""" """create_compact_summary must produce a dict keyed by layer value."""
store = summarizer.store store = summarizer.store
store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0) store.set_context(ContextLayer.L6_DAILY, "2026-02-01", "pnl", 100.0)
@@ -615,9 +575,7 @@ class TestContextSummarizer:
text = summarizer.format_summary_for_prompt(summary) text = summarizer.format_summary_for_prompt(summary)
assert text == "" assert text == ""
def test_format_summary_non_dict_value( def test_format_summary_non_dict_value(self, summarizer: ContextSummarizer) -> None:
self, summarizer: ContextSummarizer
) -> None:
"""format_summary_for_prompt must render non-dict values as plain text.""" """format_summary_for_prompt must render non-dict values as plain text."""
summary = { summary = {
"daily": { "daily": {

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
import json import json
import sqlite3 import sqlite3
from datetime import UTC, datetime
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@@ -16,8 +17,6 @@ from src.evolution.daily_review import DailyReviewer
from src.evolution.scorecard import DailyScorecard from src.evolution.scorecard import DailyScorecard
from src.logging.decision_logger import DecisionLogger from src.logging.decision_logger import DecisionLogger
from datetime import UTC, datetime
TODAY = datetime.now(UTC).strftime("%Y-%m-%d") TODAY = datetime.now(UTC).strftime("%Y-%m-%d")
@@ -53,7 +52,8 @@ def _log_decision(
def test_generate_scorecard_market_scoped( def test_generate_scorecard_market_scoped(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store) reviewer = DailyReviewer(db_conn, context_store)
logger = DecisionLogger(db_conn) logger = DecisionLogger(db_conn)
@@ -134,7 +134,8 @@ def test_generate_scorecard_market_scoped(
def test_generate_scorecard_top_winners_and_losers( def test_generate_scorecard_top_winners_and_losers(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store) reviewer = DailyReviewer(db_conn, context_store)
logger = DecisionLogger(db_conn) logger = DecisionLogger(db_conn)
@@ -168,7 +169,8 @@ def test_generate_scorecard_top_winners_and_losers(
def test_generate_scorecard_empty_day( def test_generate_scorecard_empty_day(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store) reviewer = DailyReviewer(db_conn, context_store)
scorecard = reviewer.generate_scorecard(TODAY, "KR") scorecard = reviewer.generate_scorecard(TODAY, "KR")
@@ -184,7 +186,8 @@ def test_generate_scorecard_empty_day(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_lessons_without_gemini_returns_empty( async def test_generate_lessons_without_gemini_returns_empty(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store, gemini_client=None) reviewer = DailyReviewer(db_conn, context_store, gemini_client=None)
lessons = await reviewer.generate_lessons( lessons = await reviewer.generate_lessons(
@@ -206,7 +209,8 @@ async def test_generate_lessons_without_gemini_returns_empty(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_lessons_parses_json_array( async def test_generate_lessons_parses_json_array(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
mock_gemini = MagicMock() mock_gemini = MagicMock()
mock_gemini.decide = AsyncMock( mock_gemini.decide = AsyncMock(
@@ -233,7 +237,8 @@ async def test_generate_lessons_parses_json_array(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_lessons_fallback_to_lines( async def test_generate_lessons_fallback_to_lines(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
mock_gemini = MagicMock() mock_gemini = MagicMock()
mock_gemini.decide = AsyncMock( mock_gemini.decide = AsyncMock(
@@ -260,7 +265,8 @@ async def test_generate_lessons_fallback_to_lines(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_lessons_handles_gemini_error( async def test_generate_lessons_handles_gemini_error(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
mock_gemini = MagicMock() mock_gemini = MagicMock()
mock_gemini.decide = AsyncMock(side_effect=RuntimeError("boom")) mock_gemini.decide = AsyncMock(side_effect=RuntimeError("boom"))
@@ -284,7 +290,8 @@ async def test_generate_lessons_handles_gemini_error(
def test_store_scorecard_in_context( def test_store_scorecard_in_context(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store) reviewer = DailyReviewer(db_conn, context_store)
scorecard = DailyScorecard( scorecard = DailyScorecard(
@@ -316,7 +323,8 @@ def test_store_scorecard_in_context(
def test_store_scorecard_key_is_market_scoped( def test_store_scorecard_key_is_market_scoped(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store) reviewer = DailyReviewer(db_conn, context_store)
kr = DailyScorecard( kr = DailyScorecard(
@@ -357,7 +365,8 @@ def test_store_scorecard_key_is_market_scoped(
def test_generate_scorecard_handles_invalid_context_snapshot( def test_generate_scorecard_handles_invalid_context_snapshot(
db_conn: sqlite3.Connection, context_store: ContextStore, db_conn: sqlite3.Connection,
context_store: ContextStore,
) -> None: ) -> None:
reviewer = DailyReviewer(db_conn, context_store) reviewer = DailyReviewer(db_conn, context_store)
db_conn.execute( db_conn.execute(
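The lessons tests in this file fix three behaviors: a JSON array from the model is parsed as-is, non-JSON output falls back to line splitting, and a failed model call yields an empty list. A hedged, self-contained sketch of that parsing contract (not the actual DailyReviewer code) could look like this:

```python
# Assumed parsing shape reconstructed from the test names/assertions only.
import json


def parse_lessons(raw: str | None) -> list[str]:
    """JSON array first, plain lines as fallback, empty list when there is no text."""
    if raw is None:
        return []
    try:
        data = json.loads(raw)
        if isinstance(data, list):
            return [str(item) for item in data]
    except json.JSONDecodeError:
        pass
    return [line.strip("- ").strip() for line in raw.splitlines() if line.strip()]


assert parse_lessons('["Cut losses early", "Size down in chop"]')[0] == "Cut losses early"
assert parse_lessons("- lesson one\n- lesson two") == ["lesson one", "lesson two"]
assert parse_lessons(None) == []
```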

View File

@@ -355,6 +355,7 @@ def test_positions_empty_when_no_trades(tmp_path: Path) -> None:
def _seed_cb_context(conn: sqlite3.Connection, pnl_pct: float, market: str = "KR") -> None: def _seed_cb_context(conn: sqlite3.Connection, pnl_pct: float, market: str = "KR") -> None:
import json as _json import json as _json
conn.execute( conn.execute(
"INSERT OR REPLACE INTO system_metrics (key, value, updated_at) VALUES (?, ?, ?)", "INSERT OR REPLACE INTO system_metrics (key, value, updated_at) VALUES (?, ?, ?)",
( (

View File

@@ -79,7 +79,7 @@ class TestNewsAPI:
# Mock the fetch to avoid real API call # Mock the fetch to avoid real API call
with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch: with patch.object(api, "_fetch_news", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = None mock_fetch.return_value = None
result = await api.get_news_sentiment("AAPL") await api.get_news_sentiment("AAPL")
# Should have attempted refetch since cache expired # Should have attempted refetch since cache expired
mock_fetch.assert_called_once_with("AAPL") mock_fetch.assert_called_once_with("AAPL")
@@ -111,9 +111,7 @@ class TestNewsAPI:
"source": "Reuters", "source": "Reuters",
"time_published": "2026-02-04T10:00:00", "time_published": "2026-02-04T10:00:00",
"url": "https://example.com/1", "url": "https://example.com/1",
"ticker_sentiment": [ "ticker_sentiment": [{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}],
{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}
],
"overall_sentiment_score": "0.75", "overall_sentiment_score": "0.75",
}, },
{ {
@@ -122,9 +120,7 @@ class TestNewsAPI:
"source": "Bloomberg", "source": "Bloomberg",
"time_published": "2026-02-04T09:00:00", "time_published": "2026-02-04T09:00:00",
"url": "https://example.com/2", "url": "https://example.com/2",
"ticker_sentiment": [ "ticker_sentiment": [{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}],
{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}
],
"overall_sentiment_score": "-0.2", "overall_sentiment_score": "-0.2",
}, },
] ]
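The fixture above carries per-ticker sentiment scores as strings; a sentiment aggregate would be computed from exactly those fields. A hypothetical reduction over that payload shape (the real NewsAPI may weight or window articles differently):

```python
# Field names follow the Alpha Vantage-style payload used in the fixture;
# the aggregation itself is an assumption, not the project's implementation.
def average_ticker_sentiment(articles: list[dict], ticker: str) -> float | None:
    scores: list[float] = []
    for article in articles:
        for entry in article.get("ticker_sentiment", []):
            if entry.get("ticker") == ticker:
                scores.append(float(entry["ticker_sentiment_score"]))
    return sum(scores) / len(scores) if scores else None


articles = [
    {"ticker_sentiment": [{"ticker": "AAPL", "ticker_sentiment_score": "0.85"}]},
    {"ticker_sentiment": [{"ticker": "AAPL", "ticker_sentiment_score": "-0.3"}]},
]
assert round(average_ticker_sentiment(articles, "AAPL"), 3) == 0.275
```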
@@ -661,7 +657,9 @@ class TestGeminiClientWithExternalData:
) )
# Mock the Gemini API call # Mock the Gemini API call
with patch.object(client._client.aio.models, "generate_content", new_callable=AsyncMock) as mock_gen: with patch.object(
client._client.aio.models, "generate_content", new_callable=AsyncMock
) as mock_gen:
mock_response = MagicMock() mock_response = MagicMock()
mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}' mock_response.text = '{"action": "BUY", "confidence": 85, "rationale": "Good news"}'
mock_gen.return_value = mock_response mock_gen.return_value = mock_response

View File

@@ -1,7 +1,7 @@
"""Tests for database helper functions.""" """Tests for database helper functions."""
import tempfile
import os import os
import tempfile
from src.db import get_latest_buy_trade, get_open_position, init_db, log_trade from src.db import get_latest_buy_trade, get_open_position, init_db, log_trade
@@ -204,7 +204,8 @@ def test_mode_migration_adds_column_to_existing_db() -> None:
assert "strategy_pnl" in columns assert "strategy_pnl" in columns
assert "fx_pnl" in columns assert "fx_pnl" in columns
migrated = conn.execute( migrated = conn.execute(
"SELECT pnl, strategy_pnl, fx_pnl, session_id FROM trades WHERE stock_code='AAPL' LIMIT 1" "SELECT pnl, strategy_pnl, fx_pnl, session_id "
"FROM trades WHERE stock_code='AAPL' LIMIT 1"
).fetchone() ).fetchone()
assert migrated is not None assert migrated is not None
assert migrated[0] == 123.45 assert migrated[0] == 123.45
@@ -407,9 +408,7 @@ def test_decision_logs_session_id_migration_backfills_unknown() -> None:
conn = init_db(db_path) conn = init_db(db_path)
columns = {row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()} columns = {row[1] for row in conn.execute("PRAGMA table_info(decision_logs)").fetchall()}
assert "session_id" in columns assert "session_id" in columns
row = conn.execute( row = conn.execute("SELECT session_id FROM decision_logs WHERE decision_id='d1'").fetchone()
"SELECT session_id FROM decision_logs WHERE decision_id='d1'"
).fetchone()
assert row is not None assert row is not None
assert row[0] == "UNKNOWN" assert row[0] == "UNKNOWN"
conn.close() conn.close()
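The migration tests above expect an additive column plus an "UNKNOWN" backfill for pre-existing rows. A minimal sketch of that pattern against an in-memory SQLite database (the real init_db covers more tables and columns):

```python
# Self-contained illustration of the additive-column + backfill migration pattern.
import sqlite3


def migrate_decision_logs(conn: sqlite3.Connection) -> None:
    columns = {row[1] for row in conn.execute("PRAGMA table_info(decision_logs)")}
    if "session_id" not in columns:
        conn.execute("ALTER TABLE decision_logs ADD COLUMN session_id TEXT")
        conn.execute(
            "UPDATE decision_logs SET session_id = 'UNKNOWN' WHERE session_id IS NULL"
        )
        conn.commit()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE decision_logs (decision_id TEXT PRIMARY KEY)")
conn.execute("INSERT INTO decision_logs (decision_id) VALUES ('d1')")
migrate_decision_logs(conn)
row = conn.execute("SELECT session_id FROM decision_logs WHERE decision_id='d1'").fetchone()
assert row[0] == "UNKNOWN"
```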

View File

@@ -49,7 +49,10 @@ def test_log_decision_creates_record(logger: DecisionLogger, db_conn: sqlite3.Co
# Verify record exists in database # Verify record exists in database
cursor = db_conn.execute( cursor = db_conn.execute(
"SELECT decision_id, action, confidence, session_id FROM decision_logs WHERE decision_id = ?", (
"SELECT decision_id, action, confidence, session_id "
"FROM decision_logs WHERE decision_id = ?"
),
(decision_id,), (decision_id,),
) )
row = cursor.fetchone() row = cursor.fetchone()

View File

@@ -208,7 +208,9 @@ def test_identify_failure_patterns_empty(optimizer: EvolutionOptimizer) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp_path: Path) -> None: async def test_generate_strategy_creates_file(
optimizer: EvolutionOptimizer, tmp_path: Path
) -> None:
"""Test that generate_strategy creates a strategy file.""" """Test that generate_strategy creates a strategy file."""
failures = [ failures = [
{ {
@@ -234,7 +236,9 @@ async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp
return {"action": "HOLD", "confidence": 50, "rationale": "Waiting"} return {"action": "HOLD", "confidence": 50, "rationale": "Waiting"}
""" """
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): with patch.object(
optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)
):
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
strategy_path = await optimizer.generate_strategy(failures) strategy_path = await optimizer.generate_strategy(failures)
@@ -247,7 +251,8 @@ async def test_generate_strategy_creates_file(optimizer: EvolutionOptimizer, tmp
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_strategy_saves_valid_python_code( async def test_generate_strategy_saves_valid_python_code(
optimizer: EvolutionOptimizer, tmp_path: Path, optimizer: EvolutionOptimizer,
tmp_path: Path,
) -> None: ) -> None:
"""Test that syntactically valid generated code is saved.""" """Test that syntactically valid generated code is saved."""
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}] failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
@@ -255,12 +260,14 @@ async def test_generate_strategy_saves_valid_python_code(
mock_response = Mock() mock_response = Mock()
mock_response.text = ( mock_response.text = (
'price = market_data.get("current_price", 0)\n' 'price = market_data.get("current_price", 0)\n'
'if price > 0:\n' "if price > 0:\n"
' return {"action": "BUY", "confidence": 80, "rationale": "Positive price"}\n' ' return {"action": "BUY", "confidence": 80, "rationale": "Positive price"}\n'
'return {"action": "HOLD", "confidence": 50, "rationale": "No signal"}\n' 'return {"action": "HOLD", "confidence": 50, "rationale": "No signal"}\n'
) )
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): with patch.object(
optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)
):
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
strategy_path = await optimizer.generate_strategy(failures) strategy_path = await optimizer.generate_strategy(failures)
@@ -270,7 +277,9 @@ async def test_generate_strategy_saves_valid_python_code(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_strategy_blocks_invalid_python_code( async def test_generate_strategy_blocks_invalid_python_code(
optimizer: EvolutionOptimizer, tmp_path: Path, caplog: pytest.LogCaptureFixture, optimizer: EvolutionOptimizer,
tmp_path: Path,
caplog: pytest.LogCaptureFixture,
) -> None: ) -> None:
"""Test that syntactically invalid generated code is not saved.""" """Test that syntactically invalid generated code is not saved."""
failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}] failures = [{"decision_id": "1", "timestamp": "2024-01-15T09:30:00+00:00"}]
@@ -281,7 +290,9 @@ async def test_generate_strategy_blocks_invalid_python_code(
' return {"action": "BUY", "confidence": 80, "rationale": "broken"}\n' ' return {"action": "BUY", "confidence": 80, "rationale": "broken"}\n'
) )
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): with patch.object(
optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)
):
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
with caplog.at_level("WARNING"): with caplog.at_level("WARNING"):
strategy_path = await optimizer.generate_strategy(failures) strategy_path = await optimizer.generate_strategy(failures)
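These two tests imply a syntax gate: generated strategy bodies use bare `return`, so they must be wrapped before parsing, and anything that fails to parse is never written to the strategies directory. The wrapping and validation details below are an assumption, not the optimizer's actual implementation.

```python
# Hedged sketch of a syntax gate for LLM-generated strategy bodies.
import ast
import textwrap


def body_is_valid_python(body: str) -> bool:
    """Wrap the snippet in a def so bare `return` parses, then try ast.parse."""
    wrapped = "def _strategy(market_data):\n" + textwrap.indent(body, "    ")
    try:
        ast.parse(wrapped)
        return True
    except SyntaxError:
        return False


valid = (
    'price = market_data.get("current_price", 0)\n'
    'return {"action": "HOLD", "confidence": 50, "rationale": "No signal"}\n'
)
assert body_is_valid_python(valid) is True
assert body_is_valid_python("if price >\n    return 1\n") is False  # would be rejected
```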
@@ -310,6 +321,7 @@ def test_get_performance_summary() -> None:
"""Test getting performance summary from trades table.""" """Test getting performance summary from trades table."""
# Create a temporary database with trades # Create a temporary database with trades
import tempfile import tempfile
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp:
tmp_path = tmp.name tmp_path = tmp.name
@@ -604,7 +616,9 @@ def test_calculate_improvement_trend_declining(performance_tracker: PerformanceT
assert trend["pnl_change"] == -250.0 assert trend["pnl_change"] == -250.0
def test_calculate_improvement_trend_insufficient_data(performance_tracker: PerformanceTracker) -> None: def test_calculate_improvement_trend_insufficient_data(
performance_tracker: PerformanceTracker,
) -> None:
"""Test improvement trend with insufficient data.""" """Test improvement trend with insufficient data."""
metrics = [ metrics = [
StrategyMetrics( StrategyMetrics(
@@ -718,7 +732,9 @@ async def test_full_evolution_pipeline(optimizer: EvolutionOptimizer, tmp_path:
mock_response = Mock() mock_response = Mock()
mock_response.text = 'return {"action": "HOLD", "confidence": 50, "rationale": "Test"}' mock_response.text = 'return {"action": "HOLD", "confidence": 50, "rationale": "Test"}'
with patch.object(optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)): with patch.object(
optimizer._client.aio.models, "generate_content", new=AsyncMock(return_value=mock_response)
):
with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path): with patch("src.evolution.optimizer.STRATEGIES_DIR", tmp_path):
with patch("subprocess.run") as mock_run: with patch("subprocess.run") as mock_run:
mock_run.return_value = Mock(returncode=0, stdout="", stderr="") mock_run.return_value = Mock(returncode=0, stdout="", stderr="")

View File

@@ -103,9 +103,7 @@ class TestSetupLogging:
"""setup_logging must attach a JSON handler to the root logger.""" """setup_logging must attach a JSON handler to the root logger."""
setup_logging(level=logging.DEBUG) setup_logging(level=logging.DEBUG)
root = logging.getLogger() root = logging.getLogger()
json_handlers = [ json_handlers = [h for h in root.handlers if isinstance(h.formatter, JSONFormatter)]
h for h in root.handlers if isinstance(h.formatter, JSONFormatter)
]
assert len(json_handlers) == 1 assert len(json_handlers) == 1
assert root.level == logging.DEBUG assert root.level == logging.DEBUG
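For orientation, the kind of formatter this test looks for on the root logger is roughly the following; the project's JSONFormatter almost certainly emits more fields (timestamps, extras), so treat this as an assumed minimal shape.

```python
# Assumed minimal JSON log formatter, illustrative only.
import json
import logging


class MiniJSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps(
            {"level": record.levelname, "logger": record.name, "msg": record.getMessage()}
        )


handler = logging.StreamHandler()
handler.setFormatter(MiniJSONFormatter())
logging.getLogger().addHandler(handler)
```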

File diff suppressed because it is too large

View File

@@ -173,9 +173,7 @@ class TestGetNextMarketOpen:
"""Should find next Monday opening when called on weekend.""" """Should find next Monday opening when called on weekend."""
# Saturday 2026-02-07 12:00 UTC # Saturday 2026-02-07 12:00 UTC
test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC"))
market, open_time = get_next_market_open( market, open_time = get_next_market_open(enabled_markets=["KR"], now=test_time)
enabled_markets=["KR"], now=test_time
)
assert market.code == "KR" assert market.code == "KR"
# Monday 2026-02-09 09:00 KST # Monday 2026-02-09 09:00 KST
expected = datetime(2026, 2, 9, 9, 0, tzinfo=ZoneInfo("Asia/Seoul")) expected = datetime(2026, 2, 9, 9, 0, tzinfo=ZoneInfo("Asia/Seoul"))
@@ -185,9 +183,7 @@ class TestGetNextMarketOpen:
"""Should find next day opening when called after market close.""" """Should find next day opening when called after market close."""
# Monday 2026-02-02 16:00 KST (after close) # Monday 2026-02-02 16:00 KST (after close)
test_time = datetime(2026, 2, 2, 16, 0, tzinfo=ZoneInfo("Asia/Seoul")) test_time = datetime(2026, 2, 2, 16, 0, tzinfo=ZoneInfo("Asia/Seoul"))
market, open_time = get_next_market_open( market, open_time = get_next_market_open(enabled_markets=["KR"], now=test_time)
enabled_markets=["KR"], now=test_time
)
assert market.code == "KR" assert market.code == "KR"
# Tuesday 2026-02-03 09:00 KST # Tuesday 2026-02-03 09:00 KST
expected = datetime(2026, 2, 3, 9, 0, tzinfo=ZoneInfo("Asia/Seoul")) expected = datetime(2026, 2, 3, 9, 0, tzinfo=ZoneInfo("Asia/Seoul"))
@@ -197,9 +193,7 @@ class TestGetNextMarketOpen:
"""Should find earliest opening market among multiple.""" """Should find earliest opening market among multiple."""
# Saturday 2026-02-07 12:00 UTC # Saturday 2026-02-07 12:00 UTC
test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC"))
market, open_time = get_next_market_open( market, open_time = get_next_market_open(enabled_markets=["KR", "US_NASDAQ"], now=test_time)
enabled_markets=["KR", "US_NASDAQ"], now=test_time
)
# Monday 2026-02-09: KR opens at 09:00 KST = 00:00 UTC # Monday 2026-02-09: KR opens at 09:00 KST = 00:00 UTC
# Monday 2026-02-09: US opens at 09:30 EST = 14:30 UTC # Monday 2026-02-09: US opens at 09:30 EST = 14:30 UTC
# KR opens first # KR opens first
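The UTC conversions in the comments above are easy to verify directly; the snippet below is just that arithmetic with zoneinfo, not the get_next_market_open implementation.

```python
# Cross-timezone comparison the multi-market test relies on (2026-02-09).
from datetime import datetime
from zoneinfo import ZoneInfo

kr_open = datetime(2026, 2, 9, 9, 0, tzinfo=ZoneInfo("Asia/Seoul"))
us_open = datetime(2026, 2, 9, 9, 30, tzinfo=ZoneInfo("America/New_York"))
assert kr_open.astimezone(ZoneInfo("UTC")).hour == 0    # 09:00 KST == 00:00 UTC
assert us_open.astimezone(ZoneInfo("UTC")).hour == 14   # 09:30 EST == 14:30 UTC
assert kr_open < us_open                                 # KR opens first that day
```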
@@ -214,9 +208,7 @@ class TestGetNextMarketOpen:
def test_get_next_market_open_invalid_market(self) -> None: def test_get_next_market_open_invalid_market(self) -> None:
"""Should skip invalid market codes.""" """Should skip invalid market codes."""
test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC")) test_time = datetime(2026, 2, 7, 12, 0, tzinfo=ZoneInfo("UTC"))
market, _ = get_next_market_open( market, _ = get_next_market_open(enabled_markets=["INVALID", "KR"], now=test_time)
enabled_markets=["INVALID", "KR"], now=test_time
)
assert market.code == "KR" assert market.code == "KR"
def test_get_next_market_open_prefers_extended_session(self) -> None: def test_get_next_market_open_prefers_extended_session(self) -> None:

View File

@@ -8,7 +8,7 @@ import aiohttp
import pytest import pytest
from src.broker.kis_api import KISBroker from src.broker.kis_api import KISBroker
from src.broker.overseas import OverseasBroker, _PRICE_EXCHANGE_MAP, _RANKING_EXCHANGE_MAP from src.broker.overseas import _PRICE_EXCHANGE_MAP, _RANKING_EXCHANGE_MAP, OverseasBroker
from src.config import Settings from src.config import Settings
@@ -85,25 +85,27 @@ class TestConfigDefaults:
assert mock_settings.OVERSEAS_RANKING_VOLUME_TR_ID == "HHDFS76270000" assert mock_settings.OVERSEAS_RANKING_VOLUME_TR_ID == "HHDFS76270000"
def test_fluct_path(self, mock_settings: Settings) -> None: def test_fluct_path(self, mock_settings: Settings) -> None:
assert mock_settings.OVERSEAS_RANKING_FLUCT_PATH == "/uapi/overseas-stock/v1/ranking/updown-rate" assert (
mock_settings.OVERSEAS_RANKING_FLUCT_PATH
== "/uapi/overseas-stock/v1/ranking/updown-rate"
)
def test_volume_path(self, mock_settings: Settings) -> None: def test_volume_path(self, mock_settings: Settings) -> None:
assert mock_settings.OVERSEAS_RANKING_VOLUME_PATH == "/uapi/overseas-stock/v1/ranking/volume-surge" assert (
mock_settings.OVERSEAS_RANKING_VOLUME_PATH
== "/uapi/overseas-stock/v1/ranking/volume-surge"
)
class TestFetchOverseasRankings: class TestFetchOverseasRankings:
"""Test fetch_overseas_rankings method.""" """Test fetch_overseas_rankings method."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fluctuation_uses_correct_params( async def test_fluctuation_uses_correct_params(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Fluctuation ranking should use HHDFS76290000, updown-rate path, and correct params.""" """Fluctuation ranking should use HHDFS76290000, updown-rate path, and correct params."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
mock_resp.json = AsyncMock( mock_resp.json = AsyncMock(return_value={"output": [{"symb": "AAPL", "name": "Apple"}]})
return_value={"output": [{"symb": "AAPL", "name": "Apple"}]}
)
mock_session = MagicMock() mock_session = MagicMock()
mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp)) mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp))
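Several tests in this file wrap a prepared response in `_make_async_cm(mock_resp)` so that `async with session.get(...)` yields the mock. The helper itself is not shown in this diff; an assumed equivalent is sketched below.

```python
# Assumed shape of the _make_async_cm test helper; the project's actual helper
# may differ. It makes any object usable as an async context manager result.
from unittest.mock import AsyncMock, MagicMock


def make_async_cm(resp: object) -> MagicMock:
    cm = MagicMock()
    cm.__aenter__ = AsyncMock(return_value=resp)
    cm.__aexit__ = AsyncMock(return_value=False)
    return cm


mock_resp = AsyncMock()
mock_resp.status = 200
mock_session = MagicMock()
mock_session.get = MagicMock(return_value=make_async_cm(mock_resp))  # as in the tests above
```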
@@ -132,15 +134,11 @@ class TestFetchOverseasRankings:
overseas_broker._broker._auth_headers.assert_called_with("HHDFS76290000") overseas_broker._broker._auth_headers.assert_called_with("HHDFS76290000")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_volume_uses_correct_params( async def test_volume_uses_correct_params(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Volume ranking should use HHDFS76270000, volume-surge path, and correct params.""" """Volume ranking should use HHDFS76270000, volume-surge path, and correct params."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
mock_resp.json = AsyncMock( mock_resp.json = AsyncMock(return_value={"output": [{"symb": "TSLA", "name": "Tesla"}]})
return_value={"output": [{"symb": "TSLA", "name": "Tesla"}]}
)
mock_session = MagicMock() mock_session = MagicMock()
mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp)) mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp))
@@ -169,9 +167,7 @@ class TestFetchOverseasRankings:
overseas_broker._broker._auth_headers.assert_called_with("HHDFS76270000") overseas_broker._broker._auth_headers.assert_called_with("HHDFS76270000")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_404_returns_empty_list( async def test_404_returns_empty_list(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""HTTP 404 should return empty list (fallback) instead of raising.""" """HTTP 404 should return empty list (fallback) instead of raising."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 404 mock_resp.status = 404
@@ -186,9 +182,7 @@ class TestFetchOverseasRankings:
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_non_404_error_raises( async def test_non_404_error_raises(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Non-404 HTTP errors should raise ConnectionError.""" """Non-404 HTTP errors should raise ConnectionError."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 500 mock_resp.status = 500
@@ -203,9 +197,7 @@ class TestFetchOverseasRankings:
await overseas_broker.fetch_overseas_rankings("NASD") await overseas_broker.fetch_overseas_rankings("NASD")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_empty_response_returns_empty( async def test_empty_response_returns_empty(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Empty output in response should return empty list.""" """Empty output in response should return empty list."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
@@ -220,18 +212,14 @@ class TestFetchOverseasRankings:
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ranking_disabled_returns_empty( async def test_ranking_disabled_returns_empty(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""When OVERSEAS_RANKING_ENABLED=False, should return empty immediately.""" """When OVERSEAS_RANKING_ENABLED=False, should return empty immediately."""
overseas_broker._broker._settings.OVERSEAS_RANKING_ENABLED = False overseas_broker._broker._settings.OVERSEAS_RANKING_ENABLED = False
result = await overseas_broker.fetch_overseas_rankings("NASD") result = await overseas_broker.fetch_overseas_rankings("NASD")
assert result == [] assert result == []
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_limit_truncates_results( async def test_limit_truncates_results(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Results should be truncated to the specified limit.""" """Results should be truncated to the specified limit."""
rows = [{"symb": f"SYM{i}"} for i in range(20)] rows = [{"symb": f"SYM{i}"} for i in range(20)]
mock_resp = AsyncMock() mock_resp = AsyncMock()
@@ -247,9 +235,7 @@ class TestFetchOverseasRankings:
assert len(result) == 5 assert len(result) == 5
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_network_error_raises( async def test_network_error_raises(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Network errors should raise ConnectionError.""" """Network errors should raise ConnectionError."""
cm = MagicMock() cm = MagicMock()
cm.__aenter__ = AsyncMock(side_effect=aiohttp.ClientError("timeout")) cm.__aenter__ = AsyncMock(side_effect=aiohttp.ClientError("timeout"))
@@ -264,9 +250,7 @@ class TestFetchOverseasRankings:
await overseas_broker.fetch_overseas_rankings("NASD") await overseas_broker.fetch_overseas_rankings("NASD")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_exchange_code_mapping_applied( async def test_exchange_code_mapping_applied(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""All major exchanges should use mapped codes in API params.""" """All major exchanges should use mapped codes in API params."""
for original, mapped in [("NASD", "NAS"), ("NYSE", "NYS"), ("AMEX", "AMS")]: for original, mapped in [("NASD", "NAS"), ("NYSE", "NYS"), ("AMEX", "AMS")]:
mock_resp = AsyncMock() mock_resp = AsyncMock()
@@ -298,7 +282,9 @@ class TestGetOverseasPrice:
mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp)) mock_session.get = MagicMock(return_value=_make_async_cm(mock_resp))
_setup_broker_mocks(overseas_broker, mock_session) _setup_broker_mocks(overseas_broker, mock_session)
overseas_broker._broker._auth_headers = AsyncMock(return_value={"authorization": "Bearer t"}) overseas_broker._broker._auth_headers = AsyncMock(
return_value={"authorization": "Bearer t"}
)
result = await overseas_broker.get_overseas_price("NASD", "AAPL") result = await overseas_broker.get_overseas_price("NASD", "AAPL")
assert result["output"]["last"] == "150.00" assert result["output"]["last"] == "150.00"
@@ -530,11 +516,14 @@ class TestPriceExchangeMap:
def test_price_map_equals_ranking_map(self) -> None: def test_price_map_equals_ranking_map(self) -> None:
assert _PRICE_EXCHANGE_MAP is _RANKING_EXCHANGE_MAP assert _PRICE_EXCHANGE_MAP is _RANKING_EXCHANGE_MAP
@pytest.mark.parametrize("original,expected", [ @pytest.mark.parametrize(
("NASD", "NAS"), "original,expected",
("NYSE", "NYS"), [
("AMEX", "AMS"), ("NASD", "NAS"),
]) ("NYSE", "NYS"),
("AMEX", "AMS"),
],
)
def test_us_exchange_code_mapping(self, original: str, expected: str) -> None: def test_us_exchange_code_mapping(self, original: str, expected: str) -> None:
assert _PRICE_EXCHANGE_MAP[original] == expected assert _PRICE_EXCHANGE_MAP[original] == expected
@@ -574,9 +563,7 @@ class TestOrderRtCdCheck:
return OverseasBroker(broker) return OverseasBroker(broker)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_success_rt_cd_returns_data( async def test_success_rt_cd_returns_data(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""rt_cd='0' → order accepted, data returned.""" """rt_cd='0' → order accepted, data returned."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
@@ -590,9 +577,7 @@ class TestOrderRtCdCheck:
assert result["rt_cd"] == "0" assert result["rt_cd"] == "0"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_error_rt_cd_returns_data_with_msg( async def test_error_rt_cd_returns_data_with_msg(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""rt_cd != '0' → order rejected, data still returned (caller checks rt_cd).""" """rt_cd != '0' → order rejected, data still returned (caller checks rt_cd)."""
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
@@ -623,6 +608,7 @@ class TestPaperOverseasCash:
def test_env_override(self) -> None: def test_env_override(self) -> None:
import os import os
os.environ["PAPER_OVERSEAS_CASH"] = "25000" os.environ["PAPER_OVERSEAS_CASH"] = "25000"
settings = Settings( settings = Settings(
KIS_APP_KEY="k", KIS_APP_KEY="k",
@@ -635,6 +621,7 @@ class TestPaperOverseasCash:
def test_zero_disables_fallback(self) -> None: def test_zero_disables_fallback(self) -> None:
import os import os
os.environ["PAPER_OVERSEAS_CASH"] = "0" os.environ["PAPER_OVERSEAS_CASH"] = "0"
settings = Settings( settings = Settings(
KIS_APP_KEY="k", KIS_APP_KEY="k",
@@ -822,9 +809,7 @@ class TestGetOverseasPendingOrders:
"""Tests for get_overseas_pending_orders method.""" """Tests for get_overseas_pending_orders method."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_paper_mode_returns_empty( async def test_paper_mode_returns_empty(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Paper mode should immediately return [] without any API call.""" """Paper mode should immediately return [] without any API call."""
# Default mock_settings has MODE="paper" # Default mock_settings has MODE="paper"
overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy(
@@ -855,9 +840,7 @@ class TestGetOverseasPendingOrders:
overseas_broker._broker._auth_headers = mock_auth_headers # type: ignore[method-assign] overseas_broker._broker._auth_headers = mock_auth_headers # type: ignore[method-assign]
pending_orders = [ pending_orders = [{"odno": "001", "pdno": "AAPL", "sll_buy_dvsn_cd": "02", "nccs_qty": "5"}]
{"odno": "001", "pdno": "AAPL", "sll_buy_dvsn_cd": "02", "nccs_qty": "5"}
]
mock_resp = AsyncMock() mock_resp = AsyncMock()
mock_resp.status = 200 mock_resp.status = 200
mock_resp.json = AsyncMock(return_value={"output": pending_orders}) mock_resp.json = AsyncMock(return_value={"output": pending_orders})
@@ -879,9 +862,7 @@ class TestGetOverseasPendingOrders:
assert captured_params[0]["OVRS_EXCG_CD"] == "NASD" assert captured_params[0]["OVRS_EXCG_CD"] == "NASD"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_live_mode_connection_error( async def test_live_mode_connection_error(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Network error in live mode should raise ConnectionError.""" """Network error in live mode should raise ConnectionError."""
overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy(
update={"MODE": "live"} update={"MODE": "live"}
@@ -926,55 +907,41 @@ class TestCancelOverseasOrder:
return captured_tr_ids, mock_session return captured_tr_ids, mock_session
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_us_live_uses_tttt1004u( async def test_us_live_uses_tttt1004u(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""US exchange in live mode should use TTTT1004U.""" """US exchange in live mode should use TTTT1004U."""
overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy(
update={"MODE": "live"} update={"MODE": "live"}
) )
captured, _ = self._setup_cancel_mocks( captured, _ = self._setup_cancel_mocks(overseas_broker, {"rt_cd": "0", "msg1": "OK"})
overseas_broker, {"rt_cd": "0", "msg1": "OK"}
)
await overseas_broker.cancel_overseas_order("NASD", "AAPL", "ORD001", 5) await overseas_broker.cancel_overseas_order("NASD", "AAPL", "ORD001", 5)
assert "TTTT1004U" in captured assert "TTTT1004U" in captured
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_us_paper_uses_vttt1004u( async def test_us_paper_uses_vttt1004u(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""US exchange in paper mode should use VTTT1004U.""" """US exchange in paper mode should use VTTT1004U."""
# Default mock_settings has MODE="paper" # Default mock_settings has MODE="paper"
captured, _ = self._setup_cancel_mocks( captured, _ = self._setup_cancel_mocks(overseas_broker, {"rt_cd": "0", "msg1": "OK"})
overseas_broker, {"rt_cd": "0", "msg1": "OK"}
)
await overseas_broker.cancel_overseas_order("NASD", "AAPL", "ORD001", 5) await overseas_broker.cancel_overseas_order("NASD", "AAPL", "ORD001", 5)
assert "VTTT1004U" in captured assert "VTTT1004U" in captured
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_hk_live_uses_ttts1003u( async def test_hk_live_uses_ttts1003u(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""SEHK exchange in live mode should use TTTS1003U.""" """SEHK exchange in live mode should use TTTS1003U."""
overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy( overseas_broker._broker._settings = overseas_broker._broker._settings.model_copy(
update={"MODE": "live"} update={"MODE": "live"}
) )
captured, _ = self._setup_cancel_mocks( captured, _ = self._setup_cancel_mocks(overseas_broker, {"rt_cd": "0", "msg1": "OK"})
overseas_broker, {"rt_cd": "0", "msg1": "OK"}
)
await overseas_broker.cancel_overseas_order("SEHK", "0700", "ORD002", 10) await overseas_broker.cancel_overseas_order("SEHK", "0700", "ORD002", 10)
assert "TTTS1003U" in captured assert "TTTS1003U" in captured
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancel_sets_rvse_cncl_dvsn_cd_02( async def test_cancel_sets_rvse_cncl_dvsn_cd_02(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""Cancel body must include RVSE_CNCL_DVSN_CD='02' and OVRS_ORD_UNPR='0'.""" """Cancel body must include RVSE_CNCL_DVSN_CD='02' and OVRS_ORD_UNPR='0'."""
captured_body: list[dict] = [] captured_body: list[dict] = []
@@ -1005,9 +972,7 @@ class TestCancelOverseasOrder:
assert captured_body[0]["ORGN_ODNO"] == "ORD003" assert captured_body[0]["ORGN_ODNO"] == "ORD003"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cancel_sets_hashkey_header( async def test_cancel_sets_hashkey_header(self, overseas_broker: OverseasBroker) -> None:
self, overseas_broker: OverseasBroker
) -> None:
"""hashkey must be set in the request headers.""" """hashkey must be set in the request headers."""
captured_headers: list[dict] = [] captured_headers: list[dict] = []
overseas_broker._broker._get_hash_key = AsyncMock(return_value="test_hash") # type: ignore[method-assign] overseas_broker._broker._get_hash_key = AsyncMock(return_value="test_hash") # type: ignore[method-assign]

View File

@@ -78,9 +78,7 @@ def _gemini_response_json(
"rationale": "Near circuit breaker", "rationale": "Near circuit breaker",
} }
] ]
return json.dumps( return json.dumps({"market_outlook": outlook, "global_rules": global_rules, "stocks": stocks})
{"market_outlook": outlook, "global_rules": global_rules, "stocks": stocks}
)
def _make_planner( def _make_planner(
@@ -564,8 +562,12 @@ class TestBuildPrompt:
def test_prompt_contains_cross_market(self) -> None: def test_prompt_contains_cross_market(self) -> None:
planner = _make_planner() planner = _make_planner()
cross = CrossMarketContext( cross = CrossMarketContext(
market="US", date="2026-02-07", total_pnl=1.5, market="US",
win_rate=60, index_change_pct=0.8, lessons=["Cut losses early"], date="2026-02-07",
total_pnl=1.5,
win_rate=60,
index_change_pct=0.8,
lessons=["Cut losses early"],
) )
prompt = planner._build_prompt("KR", [_candidate()], {}, None, cross) prompt = planner._build_prompt("KR", [_candidate()], {}, None, cross)
@@ -683,9 +685,7 @@ class TestSmartFallbackPlaybook:
) )
def test_momentum_candidate_gets_buy_on_volume(self) -> None: def test_momentum_candidate_gets_buy_on_volume(self) -> None:
candidates = [ candidates = [_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)]
_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)
]
settings = self._make_settings() settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook( pb = PreMarketPlanner._smart_fallback_playbook(
@@ -707,9 +707,7 @@ class TestSmartFallbackPlaybook:
assert sell_sc.condition.price_change_pct_below == -3.0 assert sell_sc.condition.price_change_pct_below == -3.0
def test_oversold_candidate_gets_buy_on_rsi(self) -> None: def test_oversold_candidate_gets_buy_on_rsi(self) -> None:
candidates = [ candidates = [_candidate(code="005930", signal="oversold", rsi=22.0, volume_ratio=3.5)]
_candidate(code="005930", signal="oversold", rsi=22.0, volume_ratio=3.5)
]
settings = self._make_settings() settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook( pb = PreMarketPlanner._smart_fallback_playbook(
@@ -776,9 +774,7 @@ class TestSmartFallbackPlaybook:
def test_empty_candidates_returns_empty_playbook(self) -> None: def test_empty_candidates_returns_empty_playbook(self) -> None:
settings = self._make_settings() settings = self._make_settings()
pb = PreMarketPlanner._smart_fallback_playbook( pb = PreMarketPlanner._smart_fallback_playbook(date(2026, 2, 17), "US_AMEX", [], settings)
date(2026, 2, 17), "US_AMEX", [], settings
)
assert pb.stock_count == 0 assert pb.stock_count == 0
@@ -814,19 +810,14 @@ class TestSmartFallbackPlaybook:
planner = _make_planner() planner = _make_planner()
planner._gemini.decide = AsyncMock(side_effect=ConnectionError("429 quota exceeded")) planner._gemini.decide = AsyncMock(side_effect=ConnectionError("429 quota exceeded"))
# momentum candidate # momentum candidate
candidates = [ candidates = [_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)]
_candidate(code="CHOW", signal="momentum", volume_ratio=13.64, rsi=100.0)
]
pb = await planner.generate_playbook( pb = await planner.generate_playbook("US_AMEX", candidates, today=date(2026, 2, 18))
"US_AMEX", candidates, today=date(2026, 2, 18)
)
# Should NOT be all-SELL defensive; should have BUY for momentum # Should NOT be all-SELL defensive; should have BUY for momentum
assert pb.stock_count == 1 assert pb.stock_count == 1
buy_scenarios = [ buy_scenarios = [
s for s in pb.stock_playbooks[0].scenarios s for s in pb.stock_playbooks[0].scenarios if s.action == ScenarioAction.BUY
if s.action == ScenarioAction.BUY
] ]
assert len(buy_scenarios) == 1 assert len(buy_scenarios) == 1
assert buy_scenarios[0].condition.volume_ratio_above == 2.0 # VOL_MULTIPLIER default assert buy_scenarios[0].condition.volume_ratio_above == 2.0 # VOL_MULTIPLIER default
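The fallback tests above describe a rule-based playbook used when Gemini is unavailable: momentum candidates get a volume-triggered BUY scenario and oversold candidates an RSI-triggered one, instead of a defensive all-SELL plan. A hedged sketch of that rule, with thresholds and names drawn only from the assertions above:

```python
# Illustrative fallback rule; the real _smart_fallback_playbook builds full
# StockPlaybook/StockScenario models and reads thresholds from Settings.
def fallback_scenario(signal: str, vol_multiplier: float = 2.0, rsi_buy: float = 30.0) -> dict:
    if signal == "momentum":
        return {"action": "BUY", "condition": {"volume_ratio_above": vol_multiplier}}
    if signal == "oversold":
        return {"action": "BUY", "condition": {"rsi_below": rsi_buy}}
    return {"action": "HOLD", "condition": {}}


assert fallback_scenario("momentum")["condition"]["volume_ratio_above"] == 2.0
assert fallback_scenario("oversold")["action"] == "BUY"
```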

View File

@@ -14,7 +14,7 @@ from src.strategy.models import (
StockPlaybook, StockPlaybook,
StockScenario, StockScenario,
) )
from src.strategy.scenario_engine import ScenarioEngine, ScenarioMatch from src.strategy.scenario_engine import ScenarioEngine
@pytest.fixture @pytest.fixture
@@ -162,13 +162,15 @@ class TestEvaluateCondition:
def test_mixed_invalid_types_no_exception(self, engine: ScenarioEngine) -> None: def test_mixed_invalid_types_no_exception(self, engine: ScenarioEngine) -> None:
"""Various invalid types should not raise exceptions.""" """Various invalid types should not raise exceptions."""
cond = StockCondition( cond = StockCondition(
rsi_below=30.0, volume_ratio_above=2.0, rsi_below=30.0,
price_above=100, price_change_pct_below=-1.0, volume_ratio_above=2.0,
price_above=100,
price_change_pct_below=-1.0,
) )
data = { data = {
"rsi": [25], # list "rsi": [25], # list
"volume_ratio": "bad", # non-numeric string "volume_ratio": "bad", # non-numeric string
"current_price": {}, # dict "current_price": {}, # dict
"price_change_pct": object(), # arbitrary object "price_change_pct": object(), # arbitrary object
} }
# Should return False (invalid types → None → False), never raise # Should return False (invalid types → None → False), never raise
@@ -356,9 +358,7 @@ class TestEvaluate:
def test_match_details_populated(self, engine: ScenarioEngine) -> None: def test_match_details_populated(self, engine: ScenarioEngine) -> None:
pb = _playbook(scenarios=[_scenario(rsi_below=30.0, volume_ratio_above=2.0)]) pb = _playbook(scenarios=[_scenario(rsi_below=30.0, volume_ratio_above=2.0)])
result = engine.evaluate( result = engine.evaluate(pb, "005930", {"rsi": 25.0, "volume_ratio": 3.0}, {})
pb, "005930", {"rsi": 25.0, "volume_ratio": 3.0}, {}
)
assert result.match_details.get("rsi") == 25.0 assert result.match_details.get("rsi") == 25.0
assert result.match_details.get("volume_ratio") == 3.0 assert result.match_details.get("volume_ratio") == 3.0
@@ -381,7 +381,9 @@ class TestEvaluate:
), ),
StockPlaybook( StockPlaybook(
stock_code="MSFT", stock_code="MSFT",
scenarios=[_scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80)], scenarios=[
_scenario(rsi_above=75.0, action=ScenarioAction.SELL, confidence=80)
],
), ),
], ],
) )
@@ -450,58 +452,42 @@ class TestEvaluate:
class TestPositionAwareConditions: class TestPositionAwareConditions:
"""Tests for unrealized_pnl_pct and holding_days condition fields.""" """Tests for unrealized_pnl_pct and holding_days condition fields."""
def test_evaluate_condition_unrealized_pnl_above_matches( def test_evaluate_condition_unrealized_pnl_above_matches(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_above should match when P&L exceeds threshold.""" """unrealized_pnl_pct_above should match when P&L exceeds threshold."""
condition = StockCondition(unrealized_pnl_pct_above=3.0) condition = StockCondition(unrealized_pnl_pct_above=3.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 5.0}) is True assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 5.0}) is True
def test_evaluate_condition_unrealized_pnl_above_no_match( def test_evaluate_condition_unrealized_pnl_above_no_match(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_above should NOT match when P&L is below threshold.""" """unrealized_pnl_pct_above should NOT match when P&L is below threshold."""
condition = StockCondition(unrealized_pnl_pct_above=3.0) condition = StockCondition(unrealized_pnl_pct_above=3.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 2.0}) is False assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": 2.0}) is False
def test_evaluate_condition_unrealized_pnl_below_matches( def test_evaluate_condition_unrealized_pnl_below_matches(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_below should match when P&L is under threshold.""" """unrealized_pnl_pct_below should match when P&L is under threshold."""
condition = StockCondition(unrealized_pnl_pct_below=-2.0) condition = StockCondition(unrealized_pnl_pct_below=-2.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -3.5}) is True assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -3.5}) is True
def test_evaluate_condition_unrealized_pnl_below_no_match( def test_evaluate_condition_unrealized_pnl_below_no_match(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""unrealized_pnl_pct_below should NOT match when P&L is above threshold.""" """unrealized_pnl_pct_below should NOT match when P&L is above threshold."""
condition = StockCondition(unrealized_pnl_pct_below=-2.0) condition = StockCondition(unrealized_pnl_pct_below=-2.0)
assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -1.0}) is False assert engine.evaluate_condition(condition, {"unrealized_pnl_pct": -1.0}) is False
def test_evaluate_condition_holding_days_above_matches( def test_evaluate_condition_holding_days_above_matches(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""holding_days_above should match when position held longer than threshold.""" """holding_days_above should match when position held longer than threshold."""
condition = StockCondition(holding_days_above=5) condition = StockCondition(holding_days_above=5)
assert engine.evaluate_condition(condition, {"holding_days": 7}) is True assert engine.evaluate_condition(condition, {"holding_days": 7}) is True
def test_evaluate_condition_holding_days_above_no_match( def test_evaluate_condition_holding_days_above_no_match(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""holding_days_above should NOT match when position held shorter.""" """holding_days_above should NOT match when position held shorter."""
condition = StockCondition(holding_days_above=5) condition = StockCondition(holding_days_above=5)
assert engine.evaluate_condition(condition, {"holding_days": 3}) is False assert engine.evaluate_condition(condition, {"holding_days": 3}) is False
def test_evaluate_condition_holding_days_below_matches( def test_evaluate_condition_holding_days_below_matches(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""holding_days_below should match when position held fewer days.""" """holding_days_below should match when position held fewer days."""
condition = StockCondition(holding_days_below=3) condition = StockCondition(holding_days_below=3)
assert engine.evaluate_condition(condition, {"holding_days": 1}) is True assert engine.evaluate_condition(condition, {"holding_days": 1}) is True
def test_evaluate_condition_holding_days_below_no_match( def test_evaluate_condition_holding_days_below_no_match(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""holding_days_below should NOT match when held more days.""" """holding_days_below should NOT match when held more days."""
condition = StockCondition(holding_days_below=3) condition = StockCondition(holding_days_below=3)
assert engine.evaluate_condition(condition, {"holding_days": 5}) is False assert engine.evaluate_condition(condition, {"holding_days": 5}) is False
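Taken together, the tests above define simple AND-combined threshold semantics for the position-aware fields: every `*_above`/`*_below` bound set on the condition must hold, and a missing field never matches. A standalone sketch of that rule (key names mirror the tests; StockCondition itself carries many more fields):

```python
# Illustrative condition matcher, not the ScenarioEngine implementation.
def matches(condition: dict[str, float], data: dict[str, float]) -> bool:
    for key, bound in condition.items():
        field, _, direction = key.rpartition("_")   # e.g. "holding_days", "above"
        value = data.get(field)
        if not isinstance(value, (int, float)):
            return False                            # missing/invalid field never matches
        if direction == "above" and not value > bound:
            return False
        if direction == "below" and not value < bound:
            return False
    return True


cond = {"unrealized_pnl_pct_above": 3.0, "holding_days_above": 5}
assert matches(cond, {"unrealized_pnl_pct": 4.5, "holding_days": 7}) is True
assert matches(cond, {"unrealized_pnl_pct": 4.5, "holding_days": 3}) is False
assert matches({"holding_days_above": 5}, {}) is False
```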
@@ -513,33 +499,33 @@ class TestPositionAwareConditions:
holding_days_above=5, holding_days_above=5,
) )
# Both met → match # Both met → match
assert engine.evaluate_condition( assert (
condition, engine.evaluate_condition(
{"unrealized_pnl_pct": 4.5, "holding_days": 7}, condition,
) is True {"unrealized_pnl_pct": 4.5, "holding_days": 7},
)
is True
)
# Only pnl met → no match # Only pnl met → no match
assert engine.evaluate_condition( assert (
condition, engine.evaluate_condition(
{"unrealized_pnl_pct": 4.5, "holding_days": 3}, condition,
) is False {"unrealized_pnl_pct": 4.5, "holding_days": 3},
)
is False
)
def test_missing_unrealized_pnl_does_not_match( def test_missing_unrealized_pnl_does_not_match(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""Missing unrealized_pnl_pct key should not match the condition.""" """Missing unrealized_pnl_pct key should not match the condition."""
condition = StockCondition(unrealized_pnl_pct_above=3.0) condition = StockCondition(unrealized_pnl_pct_above=3.0)
assert engine.evaluate_condition(condition, {}) is False assert engine.evaluate_condition(condition, {}) is False
def test_missing_holding_days_does_not_match( def test_missing_holding_days_does_not_match(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""Missing holding_days key should not match the condition.""" """Missing holding_days key should not match the condition."""
condition = StockCondition(holding_days_above=5) condition = StockCondition(holding_days_above=5)
assert engine.evaluate_condition(condition, {}) is False assert engine.evaluate_condition(condition, {}) is False
def test_match_details_includes_position_fields( def test_match_details_includes_position_fields(self, engine: ScenarioEngine) -> None:
self, engine: ScenarioEngine
) -> None:
"""match_details should include position fields when condition specifies them.""" """match_details should include position fields when condition specifies them."""
pb = _playbook( pb = _playbook(
scenarios=[ scenarios=[

View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import importlib.util
from pathlib import Path
def _load_module():
script_path = Path(__file__).resolve().parents[1] / "scripts" / "session_handover_check.py"
spec = importlib.util.spec_from_file_location("session_handover_check", script_path)
assert spec is not None
assert spec.loader is not None
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def test_ci_mode_skips_date_branch_and_merge_gate(monkeypatch, tmp_path) -> None:
module = _load_module()
handover = tmp_path / "session-handover.md"
handover.write_text(
"\n".join(
[
"### 2000-01-01 | session=test",
"- branch: feature/other-branch",
"- docs_checked: docs/workflow.md, docs/commands.md, docs/agent-constraints.md",
"- open_issues_reviewed: #1",
"- next_ticket: #123",
"- process_gate_checked: process_ticket=#1 merged_to_feature_branch=no",
]
),
encoding="utf-8",
)
monkeypatch.setattr(module, "HANDOVER_LOG", handover)
errors: list[str] = []
module._check_handover_entry(
branch="feature/current-branch",
strict=True,
ci_mode=True,
errors=errors,
)
assert errors == []
def test_ci_mode_still_blocks_tbd_next_ticket(monkeypatch, tmp_path) -> None:
module = _load_module()
handover = tmp_path / "session-handover.md"
handover.write_text(
"\n".join(
[
"### 2000-01-01 | session=test",
"- branch: feature/other-branch",
"- docs_checked: docs/workflow.md, docs/commands.md, docs/agent-constraints.md",
"- open_issues_reviewed: #1",
"- next_ticket: #TBD",
"- process_gate_checked: process_ticket=#1 merged_to_feature_branch=no",
]
),
encoding="utf-8",
)
monkeypatch.setattr(module, "HANDOVER_LOG", handover)
errors: list[str] = []
module._check_handover_entry(
branch="feature/current-branch",
strict=True,
ci_mode=True,
errors=errors,
)
assert "latest handover entry must not use placeholder next_ticket (#TBD)" in errors
def test_non_ci_strict_enforces_date_branch_and_merge_gate(monkeypatch, tmp_path) -> None:
module = _load_module()
handover = tmp_path / "session-handover.md"
handover.write_text(
"\n".join(
[
"### 2000-01-01 | session=test",
"- branch: feature/other-branch",
"- docs_checked: docs/workflow.md, docs/commands.md, docs/agent-constraints.md",
"- open_issues_reviewed: #1",
"- next_ticket: #123",
"- process_gate_checked: process_ticket=#1 merged_to_feature_branch=no",
]
),
encoding="utf-8",
)
monkeypatch.setattr(module, "HANDOVER_LOG", handover)
errors: list[str] = []
module._check_handover_entry(
branch="feature/current-branch",
strict=True,
ci_mode=False,
errors=errors,
)
assert any("must contain today's UTC date" in e for e in errors)
assert any("must target current branch" in e for e in errors)
assert any("merged_to_feature_branch=no" in e for e in errors)

View File

@@ -2,9 +2,10 @@
from __future__ import annotations from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
import pytest
from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner from src.analysis.smart_scanner import ScanCandidate, SmartVolatilityScanner
from src.analysis.volatility import VolatilityAnalyzer from src.analysis.volatility import VolatilityAnalyzer
from src.broker.kis_api import KISBroker from src.broker.kis_api import KISBroker
@@ -200,9 +201,7 @@ class TestSmartVolatilityScanner:
assert len(candidates) <= scanner.top_n assert len(candidates) <= scanner.top_n
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_stock_codes( async def test_get_stock_codes(self, scanner: SmartVolatilityScanner) -> None:
self, scanner: SmartVolatilityScanner
) -> None:
"""Test extraction of stock codes from candidates.""" """Test extraction of stock codes from candidates."""
candidates = [ candidates = [
ScanCandidate( ScanCandidate(

View File

@@ -19,7 +19,6 @@ from src.strategy.models import (
StockScenario, StockScenario,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# StockCondition # StockCondition
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -5,7 +5,11 @@ from unittest.mock import AsyncMock, patch
 import aiohttp
 import pytest

-from src.notifications.telegram_client import NotificationFilter, NotificationPriority, TelegramClient
+from src.notifications.telegram_client import (
+    NotificationFilter,
+    NotificationPriority,
+    TelegramClient,
+)


 class TestTelegramClientInit:
@@ -13,9 +17,7 @@ class TestTelegramClientInit:
     def test_disabled_via_flag(self) -> None:
         """Client disabled via enabled=False flag."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=False
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=False)

         assert client._enabled is False

     def test_disabled_missing_token(self) -> None:
@@ -30,9 +32,7 @@ class TestTelegramClientInit:

     def test_enabled_with_credentials(self) -> None:
         """Client enabled when credentials provided."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         assert client._enabled is True

@@ -42,9 +42,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_send_message_success(self) -> None:
         """send_message returns True on successful send."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -76,9 +74,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_send_message_api_error(self) -> None:
         """send_message returns False on API error."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 400
@@ -93,9 +89,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_send_message_with_markdown(self) -> None:
         """send_message supports different parse modes."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -128,9 +122,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_trade_execution_format(self) -> None:
         """Trade notification has correct format."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -163,9 +155,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_playbook_generated_format(self) -> None:
         """Playbook generated notification has expected fields."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -190,9 +180,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_scenario_matched_format(self) -> None:
         """Scenario matched notification has expected fields."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -217,9 +205,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_playbook_failed_format(self) -> None:
         """Playbook failed notification has expected fields."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -240,9 +226,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_circuit_breaker_priority(self) -> None:
         """Circuit breaker uses CRITICAL priority."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -260,9 +244,7 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_api_error_handling(self) -> None:
         """API errors logged but don't crash."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 400
@@ -277,25 +259,19 @@ class TestNotificationSending:
     @pytest.mark.asyncio
     async def test_timeout_handling(self) -> None:
         """Timeouts logged but don't crash."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         with patch(
             "aiohttp.ClientSession.post",
             side_effect=aiohttp.ClientError("Connection timeout"),
         ):
             # Should not raise exception
-            await client.notify_error(
-                error_type="Test Error", error_msg="Test", context="test"
-            )
+            await client.notify_error(error_type="Test Error", error_msg="Test", context="test")

     @pytest.mark.asyncio
     async def test_session_management(self) -> None:
         """Session created and reused correctly."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         # Session should be None initially
         assert client._session is None
@@ -324,9 +300,7 @@ class TestRateLimiting:
         """Rate limiter delays rapid requests."""
         import time

-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True, rate_limit=2.0)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -353,9 +327,7 @@ class TestMessagePriorities:
     @pytest.mark.asyncio
     async def test_low_priority_uses_info_emoji(self) -> None:
         """LOW priority uses emoji."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -371,9 +343,7 @@ class TestMessagePriorities:
     @pytest.mark.asyncio
     async def test_critical_priority_uses_alarm_emoji(self) -> None:
         """CRITICAL priority uses 🚨 emoji."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -389,9 +359,7 @@ class TestMessagePriorities:
     @pytest.mark.asyncio
     async def test_playbook_generated_priority(self) -> None:
         """Playbook generated uses MEDIUM priority emoji."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -412,9 +380,7 @@ class TestMessagePriorities:
     @pytest.mark.asyncio
     async def test_playbook_failed_priority(self) -> None:
         """Playbook failed uses HIGH priority emoji."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -433,9 +399,7 @@ class TestMessagePriorities:
     @pytest.mark.asyncio
     async def test_scenario_matched_priority(self) -> None:
         """Scenario matched uses HIGH priority emoji."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_resp = AsyncMock()
         mock_resp.status = 200
@@ -460,9 +424,7 @@ class TestClientCleanup:
     @pytest.mark.asyncio
     async def test_close_closes_session(self) -> None:
         """close() closes the HTTP session."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         mock_session = AsyncMock()
         mock_session.closed = False
@@ -475,9 +437,7 @@ class TestClientCleanup:
     @pytest.mark.asyncio
     async def test_close_handles_no_session(self) -> None:
         """close() handles None session gracefully."""
-        client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True
-        )
+        client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)

         # Should not raise exception
         await client.close()
@@ -535,8 +495,12 @@ class TestNotificationFilter:
         )

         with patch("aiohttp.ClientSession.post") as mock_post:
             await client.notify_trade_execution(
-                stock_code="005930", market="KR", action="BUY",
-                quantity=10, price=70000.0, confidence=85.0
+                stock_code="005930",
+                market="KR",
+                action="BUY",
+                quantity=10,
+                price=70000.0,
+                confidence=85.0,
             )

         mock_post.assert_not_called()
@@ -556,8 +520,13 @@ class TestNotificationFilter:
     async def test_circuit_breaker_always_sends_regardless_of_filter(self) -> None:
         """notify_circuit_breaker always sends (no filter flag)."""
         nf = NotificationFilter(
-            trades=False, market_open_close=False, fat_finger=False,
-            system_events=False, playbook=False, scenario_match=False, errors=False,
+            trades=False,
+            market_open_close=False,
+            fat_finger=False,
+            system_events=False,
+            playbook=False,
+            scenario_match=False,
+            errors=False,
         )
         client = TelegramClient(
             bot_token="123:abc", chat_id="456", enabled=True, notification_filter=nf
@@ -617,7 +586,7 @@ class TestNotificationFilter:
         nf = NotificationFilter()
         assert nf.set_flag("unknown_key", False) is False

-    def test_as_dict_keys_match_KEYS(self) -> None:
+    def test_as_dict_keys_match_keys(self) -> None:
         """as_dict() returns every key defined in KEYS."""
         nf = NotificationFilter()
         d = nf.as_dict()
@@ -640,10 +609,17 @@ class TestNotificationFilter:
     def test_set_notification_all_on(self) -> None:
         """set_notification('all', True) enables every filter flag."""
         client = TelegramClient(
-            bot_token="123:abc", chat_id="456", enabled=True,
+            bot_token="123:abc",
+            chat_id="456",
+            enabled=True,
             notification_filter=NotificationFilter(
-                trades=False, market_open_close=False, scenario_match=False,
-                fat_finger=False, system_events=False, playbook=False, errors=False,
+                trades=False,
+                market_open_close=False,
+                scenario_match=False,
+                fat_finger=False,
+                system_events=False,
+                playbook=False,
+                errors=False,
             ),
         )
         assert client.set_notification("all", True) is True

View File

@@ -357,8 +357,7 @@ class TestTradingControlCommands:
             pause_event.set()
             await client.send_message(
-                "<b>▶️ Trading Resumed</b>\n\n"
-                "Trading operations have been restarted."
+                "<b>▶️ Trading Resumed</b>\n\nTrading operations have been restarted."
             )

         handler.register_command("resume", mock_resume)

@@ -526,9 +525,7 @@ class TestStatusCommands:
         async def mock_status_error() -> None:
             """Mock /status handler with error."""
-            await client.send_message(
-                "<b>⚠️ Error</b>\n\nFailed to retrieve trading status."
-            )
+            await client.send_message("<b>⚠️ Error</b>\n\nFailed to retrieve trading status.")

         handler.register_command("status", mock_status_error)

@@ -603,10 +600,7 @@ class TestStatusCommands:
         async def mock_positions_empty() -> None:
             """Mock /positions handler with no positions."""
-            message = (
-                "<b>💼 Account Summary</b>\n\n"
-                "No balance information available."
-            )
+            message = "<b>💼 Account Summary</b>\n\nNo balance information available."
             await client.send_message(message)

         handler.register_command("positions", mock_positions_empty)

@@ -639,9 +633,7 @@ class TestStatusCommands:
         async def mock_positions_error() -> None:
             """Mock /positions handler with error."""
-            await client.send_message(
-                "<b>⚠️ Error</b>\n\nFailed to retrieve positions."
-            )
+            await client.send_message("<b>⚠️ Error</b>\n\nFailed to retrieve positions.")

         handler.register_command("positions", mock_positions_error)

View File

@@ -70,7 +70,9 @@ def test_load_changed_files_with_range_uses_git_diff(monkeypatch) -> None:
         assert check is True
         assert capture_output is True
         assert text is True
-        return SimpleNamespace(stdout="docs/ouroboros/85_loss_recovery_action_plan.md\nsrc/main.py\n")
+        return SimpleNamespace(
+            stdout="docs/ouroboros/85_loss_recovery_action_plan.md\nsrc/main.py\n"
+        )

     monkeypatch.setattr(module.subprocess, "run", fake_run)
     changed = module.load_changed_files(["abc...def"], errors)
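
For context, the monkeypatched subprocess.run in this hunk implies a load_changed_files roughly like the sketch below: a single "base...head" range argument is resolved through git diff --name-only with check, capture_output, and text enabled, and stdout is split into paths. The fallback branch and the exact signature are assumptions about the validator script under test, not its actual code.

# Hypothetical sketch consistent with the fake_run assertions above.
from __future__ import annotations

import subprocess


def load_changed_files(args: list[str], errors: list[str]) -> list[str]:
    # A single "<base>...<head>" range is resolved via git; the test asserts
    # that check/capture_output/text are all True on the subprocess call.
    if len(args) == 1 and "..." in args[0]:
        result = subprocess.run(
            ["git", "diff", "--name-only", args[0]],
            check=True,
            capture_output=True,
            text=True,
        )
        return [line for line in result.stdout.splitlines() if line.strip()]
    # Otherwise treat the arguments as explicit file paths; error collection
    # into `errors` is elided in this sketch.
    return list(args)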

View File

@@ -80,9 +80,7 @@ class TestVolatilityAnalyzer:
         # ATR should be roughly the average true range
         assert 3.0 <= atr <= 6.0

-    def test_calculate_atr_insufficient_data(
-        self, volatility_analyzer: VolatilityAnalyzer
-    ) -> None:
+    def test_calculate_atr_insufficient_data(self, volatility_analyzer: VolatilityAnalyzer) -> None:
         """Test ATR with insufficient data returns 0."""
         high_prices = [110.0, 112.0]
         low_prices = [105.0, 107.0]
@@ -120,17 +118,13 @@ class TestVolatilityAnalyzer:
         surge = volatility_analyzer.calculate_volume_surge(1000.0, 0.0)
         assert surge == 1.0

-    def test_calculate_pv_divergence_bullish(
-        self, volatility_analyzer: VolatilityAnalyzer
-    ) -> None:
+    def test_calculate_pv_divergence_bullish(self, volatility_analyzer: VolatilityAnalyzer) -> None:
         """Test bullish price-volume divergence."""
         # Price up + Volume up = bullish
         divergence = volatility_analyzer.calculate_pv_divergence(5.0, 2.0)
         assert divergence > 0.0

-    def test_calculate_pv_divergence_bearish(
-        self, volatility_analyzer: VolatilityAnalyzer
-    ) -> None:
+    def test_calculate_pv_divergence_bearish(self, volatility_analyzer: VolatilityAnalyzer) -> None:
         """Test bearish price-volume divergence."""
         # Price up + Volume down = bearish divergence
         divergence = volatility_analyzer.calculate_pv_divergence(5.0, 0.5)
@@ -144,9 +138,7 @@ class TestVolatilityAnalyzer:
         divergence = volatility_analyzer.calculate_pv_divergence(-5.0, 2.0)
         assert divergence < 0.0

-    def test_calculate_momentum_score(
-        self, volatility_analyzer: VolatilityAnalyzer
-    ) -> None:
+    def test_calculate_momentum_score(self, volatility_analyzer: VolatilityAnalyzer) -> None:
         """Test momentum score calculation."""
         score = volatility_analyzer.calculate_momentum_score(
             price_change_1m=5.0,
@@ -500,9 +492,7 @@ class TestMarketScanner:
         # Should keep all current stocks since they're all in top movers
         assert set(updated) == set(current_watchlist)

-    def test_get_updated_watchlist_max_replacements(
-        self, scanner: MarketScanner
-    ) -> None:
+    def test_get_updated_watchlist_max_replacements(self, scanner: MarketScanner) -> None:
         """Test that max_replacements limit is respected."""

         current_watchlist = ["000660", "035420", "005490"]
@@ -556,8 +546,6 @@
         active_count = 0
         peak_count = 0

-        original_scan = scanner.scan_stock
-
         async def tracking_scan(code: str, market: Any) -> VolatilityMetrics:
             nonlocal active_count, peak_count
             active_count += 1