diff --git a/src/main.py b/src/main.py index 4bb7b60..34f850f 100644 --- a/src/main.py +++ b/src/main.py @@ -110,6 +110,14 @@ DAILY_TRADE_SESSIONS = 4 # Number of trading sessions per day TRADE_SESSION_INTERVAL_HOURS = 6 # Hours between sessions +def _resolve_sell_qty_for_pnl(*, sell_qty: int | None, buy_qty: int | None) -> int: + """Choose quantity basis for SELL outcome PnL with safe fallback.""" + resolved_sell = int(sell_qty or 0) + if resolved_sell > 0: + return resolved_sell + return max(0, int(buy_qty or 0)) + + async def _retry_connection(coro_factory: Any, *args: Any, label: str = "", **kwargs: Any) -> Any: """Call an async function retrying on ConnectionError with exponential backoff. @@ -1667,8 +1675,9 @@ async def trading_cycle( ) if buy_trade and buy_trade.get("price") is not None: buy_price = float(buy_trade["price"]) - buy_qty = int(buy_trade.get("quantity") or 1) - trade_pnl = (trade_price - buy_price) * buy_qty + buy_qty = int(buy_trade.get("quantity") or 0) + sell_qty = _resolve_sell_qty_for_pnl(sell_qty=quantity, buy_qty=buy_qty) + trade_pnl = (trade_price - buy_price) * sell_qty decision_logger.update_outcome( decision_id=buy_trade["decision_id"], pnl=trade_pnl, @@ -2772,8 +2781,12 @@ async def run_daily_session( ) if buy_trade and buy_trade.get("price") is not None: buy_price = float(buy_trade["price"]) - buy_qty = int(buy_trade.get("quantity") or 1) - trade_pnl = (trade_price - buy_price) * buy_qty + buy_qty = int(buy_trade.get("quantity") or 0) + sell_qty = _resolve_sell_qty_for_pnl( + sell_qty=quantity, + buy_qty=buy_qty, + ) + trade_pnl = (trade_price - buy_price) * sell_qty decision_logger.update_outcome( decision_id=buy_trade["decision_id"], pnl=trade_pnl, diff --git a/tests/test_main.py b/tests/test_main.py index 63ee0da..d5ff5c3 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -27,6 +27,7 @@ from src.main import ( _extract_held_qty_from_balance, _handle_market_close, _retry_connection, + _resolve_sell_qty_for_pnl, _run_context_scheduler, _run_evolution_loop, _start_dashboard_server, @@ -119,6 +120,18 @@ class TestExtractAvgPriceFromBalance: result = _extract_avg_price_from_balance(balance, "005930", is_domestic=True) assert result == 0.0 + +def test_resolve_sell_qty_for_pnl_prefers_sell_qty() -> None: + assert _resolve_sell_qty_for_pnl(sell_qty=30, buy_qty=100) == 30 + + +def test_resolve_sell_qty_for_pnl_uses_buy_qty_fallback_when_sell_qty_missing() -> None: + assert _resolve_sell_qty_for_pnl(sell_qty=None, buy_qty=12) == 12 + + +def test_resolve_sell_qty_for_pnl_returns_zero_when_both_missing() -> None: + assert _resolve_sell_qty_for_pnl(sell_qty=None, buy_qty=None) == 0 + def test_returns_zero_when_field_empty_string(self) -> None: """Returns 0.0 when pchs_avg_pric is an empty string.""" balance = {"output1": [{"pdno": "005930", "pchs_avg_pric": ""}]} @@ -2750,6 +2763,9 @@ async def test_sell_order_uses_broker_balance_qty_not_db() -> None: assert call_kwargs["order_type"] == "SELL" # Must use broker-confirmed qty (5), NOT DB-recorded ordered qty (10) assert call_kwargs["quantity"] == 5 + updated_buy = decision_logger.get_decision_by_id(buy_decision_id) + assert updated_buy is not None + assert updated_buy.outcome_pnl == -25.0 @pytest.mark.asyncio