From fd0246769a26af30380e79e80794800bda237c0c Mon Sep 17 00:00:00 2001 From: agentson Date: Sat, 28 Feb 2026 17:13:56 +0900 Subject: [PATCH] test: add sell qty fallback guard and quantity-basis coverage (#322) --- src/main.py | 17 +++++++++++++++-- tests/test_main.py | 13 +++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/main.py b/src/main.py index 6f8272c..17fcbbd 100644 --- a/src/main.py +++ b/src/main.py @@ -110,6 +110,14 @@ DAILY_TRADE_SESSIONS = 4 # Number of trading sessions per day TRADE_SESSION_INTERVAL_HOURS = 6 # Hours between sessions +def _resolve_sell_qty_for_pnl(*, sell_qty: int | None, buy_qty: int | None) -> int: + """Choose quantity basis for SELL outcome PnL with safe fallback.""" + resolved_sell = int(sell_qty or 0) + if resolved_sell > 0: + return resolved_sell + return max(0, int(buy_qty or 0)) + + async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any: """Call an async function retrying on ConnectionError with exponential backoff. @@ -1658,7 +1666,8 @@ async def trading_cycle( buy_trade = get_latest_buy_trade(db_conn, stock_code, market.code) if buy_trade and buy_trade.get("price") is not None: buy_price = float(buy_trade["price"]) - sell_qty = int(quantity or 0) + buy_qty = int(buy_trade.get("quantity") or 0) + sell_qty = _resolve_sell_qty_for_pnl(sell_qty=quantity, buy_qty=buy_qty) trade_pnl = (trade_price - buy_price) * sell_qty decision_logger.update_outcome( decision_id=buy_trade["decision_id"], @@ -2755,7 +2764,11 @@ async def run_daily_session( buy_trade = get_latest_buy_trade(db_conn, stock_code, market.code) if buy_trade and buy_trade.get("price") is not None: buy_price = float(buy_trade["price"]) - sell_qty = int(quantity or 0) + buy_qty = int(buy_trade.get("quantity") or 0) + sell_qty = _resolve_sell_qty_for_pnl( + sell_qty=quantity, + buy_qty=buy_qty, + ) trade_pnl = (trade_price - buy_price) * sell_qty decision_logger.update_outcome( decision_id=buy_trade["decision_id"], diff --git a/tests/test_main.py b/tests/test_main.py index cdb2651..d5ff5c3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -27,6 +27,7 @@ from src.main import ( _extract_held_qty_from_balance, _handle_market_close, _retry_connection, + _resolve_sell_qty_for_pnl, _run_context_scheduler, _run_evolution_loop, _start_dashboard_server, @@ -119,6 +120,18 @@ class TestExtractAvgPriceFromBalance: result = _extract_avg_price_from_balance(balance, "005930", is_domestic=True) assert result == 0.0 + +def test_resolve_sell_qty_for_pnl_prefers_sell_qty() -> None: + assert _resolve_sell_qty_for_pnl(sell_qty=30, buy_qty=100) == 30 + + +def test_resolve_sell_qty_for_pnl_uses_buy_qty_fallback_when_sell_qty_missing() -> None: + assert _resolve_sell_qty_for_pnl(sell_qty=None, buy_qty=12) == 12 + + +def test_resolve_sell_qty_for_pnl_returns_zero_when_both_missing() -> None: + assert _resolve_sell_qty_for_pnl(sell_qty=None, buy_qty=None) == 0 + def test_returns_zero_when_field_empty_string(self) -> None: """Returns 0.0 when pchs_avg_pric is an empty string.""" balance = {"output1": [{"pdno": "005930", "pchs_avg_pric": ""}]}