Merge pull request 'feat: implement Latency Control with criticality-based prioritization (Pillar 1)' (#27) from feature/issue-21-latency-control into main
Some checks failed
CI / test (push) Has been cancelled
Some checks failed
CI / test (push) Has been cancelled
Reviewed-on: #27
This commit was merged in pull request #27.
This commit is contained in:
110
src/core/criticality.py
Normal file
110
src/core/criticality.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Criticality assessment for urgency-based response system.
|
||||
|
||||
Evaluates market conditions to determine response urgency and enable
|
||||
faster reactions in critical situations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class CriticalityLevel(StrEnum):
|
||||
"""Urgency levels for market conditions and trading decisions."""
|
||||
|
||||
CRITICAL = "CRITICAL" # <5s timeout - Emergency response required
|
||||
HIGH = "HIGH" # <30s timeout - Elevated priority
|
||||
NORMAL = "NORMAL" # <60s timeout - Standard processing
|
||||
LOW = "LOW" # No timeout - Batch processing
|
||||
|
||||
|
||||
class CriticalityAssessor:
|
||||
"""Assesses market conditions to determine response criticality level."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
critical_pnl_threshold: float = -2.5,
|
||||
critical_price_change_threshold: float = 5.0,
|
||||
critical_volume_surge_threshold: float = 10.0,
|
||||
high_volatility_threshold: float = 70.0,
|
||||
low_volatility_threshold: float = 30.0,
|
||||
) -> None:
|
||||
"""Initialize the criticality assessor.
|
||||
|
||||
Args:
|
||||
critical_pnl_threshold: P&L % that triggers CRITICAL (default -2.5%)
|
||||
critical_price_change_threshold: Price change % that triggers CRITICAL
|
||||
(default 5.0% in 1 minute)
|
||||
critical_volume_surge_threshold: Volume surge ratio that triggers CRITICAL
|
||||
(default 10x average)
|
||||
high_volatility_threshold: Volatility score that triggers HIGH
|
||||
(default 70.0)
|
||||
low_volatility_threshold: Volatility score below which is LOW
|
||||
(default 30.0)
|
||||
"""
|
||||
self.critical_pnl_threshold = critical_pnl_threshold
|
||||
self.critical_price_change_threshold = critical_price_change_threshold
|
||||
self.critical_volume_surge_threshold = critical_volume_surge_threshold
|
||||
self.high_volatility_threshold = high_volatility_threshold
|
||||
self.low_volatility_threshold = low_volatility_threshold
|
||||
|
||||
def assess_market_conditions(
|
||||
self,
|
||||
pnl_pct: float,
|
||||
volatility_score: float,
|
||||
volume_surge: float,
|
||||
price_change_1m: float = 0.0,
|
||||
is_market_open: bool = True,
|
||||
) -> CriticalityLevel:
|
||||
"""Assess criticality level based on market conditions.
|
||||
|
||||
Args:
|
||||
pnl_pct: Current P&L percentage
|
||||
volatility_score: Momentum score from VolatilityAnalyzer (0-100)
|
||||
volume_surge: Volume surge ratio (current / average)
|
||||
price_change_1m: 1-minute price change percentage
|
||||
is_market_open: Whether the market is currently open
|
||||
|
||||
Returns:
|
||||
CriticalityLevel indicating required response urgency
|
||||
"""
|
||||
# Market closed or very quiet → LOW priority (batch processing)
|
||||
if not is_market_open or volatility_score < self.low_volatility_threshold:
|
||||
return CriticalityLevel.LOW
|
||||
|
||||
# CRITICAL conditions: immediate action required
|
||||
# 1. P&L near circuit breaker (-2.5% is close to -3.0% breaker)
|
||||
if pnl_pct <= self.critical_pnl_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# 2. Large sudden price movement (>5% in 1 minute)
|
||||
if abs(price_change_1m) >= self.critical_price_change_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# 3. Extreme volume surge (>10x average) indicates major event
|
||||
if volume_surge >= self.critical_volume_surge_threshold:
|
||||
return CriticalityLevel.CRITICAL
|
||||
|
||||
# HIGH priority: elevated volatility requires faster response
|
||||
if volatility_score >= self.high_volatility_threshold:
|
||||
return CriticalityLevel.HIGH
|
||||
|
||||
# NORMAL: standard trading conditions
|
||||
return CriticalityLevel.NORMAL
|
||||
|
||||
def get_timeout(self, level: CriticalityLevel) -> float | None:
|
||||
"""Get timeout in seconds for a given criticality level.
|
||||
|
||||
Args:
|
||||
level: Criticality level
|
||||
|
||||
Returns:
|
||||
Timeout in seconds, or None for no timeout (LOW priority)
|
||||
"""
|
||||
timeout_map = {
|
||||
CriticalityLevel.CRITICAL: 5.0,
|
||||
CriticalityLevel.HIGH: 30.0,
|
||||
CriticalityLevel.NORMAL: 60.0,
|
||||
CriticalityLevel.LOW: None,
|
||||
}
|
||||
return timeout_map[level]
|
||||
291
src/core/priority_queue.py
Normal file
291
src/core/priority_queue.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""Priority-based task queue for latency control.
|
||||
|
||||
Implements a thread-safe priority queue with timeout enforcement and metrics tracking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import heapq
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable, Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from src.core.criticality import CriticalityLevel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class PriorityTask:
|
||||
"""Task with priority and timestamp for queue ordering."""
|
||||
|
||||
# Lower priority value = higher urgency (CRITICAL=0, HIGH=1, NORMAL=2, LOW=3)
|
||||
priority: int
|
||||
timestamp: float
|
||||
# Task data not used in comparison
|
||||
task_id: str = field(compare=False)
|
||||
task_data: dict[str, Any] = field(compare=False, default_factory=dict)
|
||||
callback: Callable[[], Coroutine[Any, Any, Any]] | None = field(
|
||||
compare=False, default=None
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueueMetrics:
|
||||
"""Metrics for priority queue performance monitoring."""
|
||||
|
||||
total_enqueued: int = 0
|
||||
total_dequeued: int = 0
|
||||
total_timeouts: int = 0
|
||||
total_errors: int = 0
|
||||
current_size: int = 0
|
||||
# Average wait time per criticality level (in seconds)
|
||||
avg_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||
# P95 wait time per criticality level
|
||||
p95_wait_time: dict[CriticalityLevel, float] = field(default_factory=dict)
|
||||
|
||||
|
||||
class PriorityTaskQueue:
|
||||
"""Thread-safe priority queue with timeout enforcement."""
|
||||
|
||||
# Priority mapping for criticality levels
|
||||
PRIORITY_MAP = {
|
||||
CriticalityLevel.CRITICAL: 0,
|
||||
CriticalityLevel.HIGH: 1,
|
||||
CriticalityLevel.NORMAL: 2,
|
||||
CriticalityLevel.LOW: 3,
|
||||
}
|
||||
|
||||
def __init__(self, max_size: int = 1000) -> None:
|
||||
"""Initialize the priority task queue.
|
||||
|
||||
Args:
|
||||
max_size: Maximum queue size (default 1000)
|
||||
"""
|
||||
self._queue: list[PriorityTask] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._max_size = max_size
|
||||
self._metrics = QueueMetrics()
|
||||
# Track wait times for metrics
|
||||
self._wait_times: dict[CriticalityLevel, list[float]] = {
|
||||
level: [] for level in CriticalityLevel
|
||||
}
|
||||
|
||||
async def enqueue(
|
||||
self,
|
||||
task_id: str,
|
||||
criticality: CriticalityLevel,
|
||||
task_data: dict[str, Any],
|
||||
callback: Callable[[], Coroutine[Any, Any, Any]] | None = None,
|
||||
) -> bool:
|
||||
"""Add a task to the priority queue.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
criticality: Criticality level determining priority
|
||||
task_data: Data associated with the task
|
||||
callback: Optional async callback to execute
|
||||
|
||||
Returns:
|
||||
True if enqueued successfully, False if queue is full
|
||||
"""
|
||||
async with self._lock:
|
||||
if len(self._queue) >= self._max_size:
|
||||
logger.warning(
|
||||
"Priority queue full (size=%d), rejecting task %s",
|
||||
len(self._queue),
|
||||
task_id,
|
||||
)
|
||||
return False
|
||||
|
||||
priority = self.PRIORITY_MAP[criticality]
|
||||
timestamp = time.time()
|
||||
|
||||
task = PriorityTask(
|
||||
priority=priority,
|
||||
timestamp=timestamp,
|
||||
task_id=task_id,
|
||||
task_data=task_data,
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
heapq.heappush(self._queue, task)
|
||||
self._metrics.total_enqueued += 1
|
||||
self._metrics.current_size = len(self._queue)
|
||||
|
||||
logger.debug(
|
||||
"Enqueued task %s with criticality %s (priority=%d, queue_size=%d)",
|
||||
task_id,
|
||||
criticality.value,
|
||||
priority,
|
||||
len(self._queue),
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def dequeue(self, timeout: float | None = None) -> PriorityTask | None:
|
||||
"""Remove and return the highest priority task from the queue.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for a task (seconds)
|
||||
|
||||
Returns:
|
||||
PriorityTask if available, None if queue is empty or timeout
|
||||
"""
|
||||
start_time = time.time()
|
||||
deadline = start_time + timeout if timeout else None
|
||||
|
||||
while True:
|
||||
async with self._lock:
|
||||
if self._queue:
|
||||
task = heapq.heappop(self._queue)
|
||||
self._metrics.total_dequeued += 1
|
||||
self._metrics.current_size = len(self._queue)
|
||||
|
||||
# Calculate wait time
|
||||
wait_time = time.time() - task.timestamp
|
||||
criticality = self._get_criticality_from_priority(task.priority)
|
||||
self._wait_times[criticality].append(wait_time)
|
||||
self._update_wait_time_metrics()
|
||||
|
||||
logger.debug(
|
||||
"Dequeued task %s (priority=%d, wait_time=%.2fs, queue_size=%d)",
|
||||
task.task_id,
|
||||
task.priority,
|
||||
wait_time,
|
||||
len(self._queue),
|
||||
)
|
||||
|
||||
return task
|
||||
|
||||
# Queue is empty
|
||||
if deadline and time.time() >= deadline:
|
||||
return None
|
||||
|
||||
# Wait a bit before checking again
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def execute_with_timeout(
|
||||
self,
|
||||
task: PriorityTask,
|
||||
timeout: float | None,
|
||||
) -> Any:
|
||||
"""Execute a task with timeout enforcement.
|
||||
|
||||
Args:
|
||||
task: Task to execute
|
||||
timeout: Timeout in seconds (None = no timeout)
|
||||
|
||||
Returns:
|
||||
Result from task callback
|
||||
|
||||
Raises:
|
||||
asyncio.TimeoutError: If task exceeds timeout
|
||||
Exception: Any exception raised by the task callback
|
||||
"""
|
||||
if not task.callback:
|
||||
logger.warning("Task %s has no callback, skipping execution", task.task_id)
|
||||
return None
|
||||
|
||||
criticality = self._get_criticality_from_priority(task.priority)
|
||||
|
||||
try:
|
||||
if timeout:
|
||||
result = await asyncio.wait_for(task.callback(), timeout=timeout)
|
||||
else:
|
||||
result = await task.callback()
|
||||
|
||||
logger.debug(
|
||||
"Task %s completed successfully (criticality=%s)",
|
||||
task.task_id,
|
||||
criticality.value,
|
||||
)
|
||||
return result
|
||||
|
||||
except TimeoutError:
|
||||
self._metrics.total_timeouts += 1
|
||||
logger.error(
|
||||
"Task %s timed out after %.2fs (criticality=%s)",
|
||||
task.task_id,
|
||||
timeout or 0.0,
|
||||
criticality.value,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception as exc:
|
||||
self._metrics.total_errors += 1
|
||||
logger.exception(
|
||||
"Task %s failed with error (criticality=%s): %s",
|
||||
task.task_id,
|
||||
criticality.value,
|
||||
exc,
|
||||
)
|
||||
raise
|
||||
|
||||
def _get_criticality_from_priority(self, priority: int) -> CriticalityLevel:
|
||||
"""Convert priority back to criticality level."""
|
||||
for level, prio in self.PRIORITY_MAP.items():
|
||||
if prio == priority:
|
||||
return level
|
||||
return CriticalityLevel.NORMAL
|
||||
|
||||
def _update_wait_time_metrics(self) -> None:
|
||||
"""Update average and p95 wait time metrics."""
|
||||
for level, times in self._wait_times.items():
|
||||
if not times:
|
||||
continue
|
||||
|
||||
# Keep only last 1000 measurements to avoid memory bloat
|
||||
if len(times) > 1000:
|
||||
self._wait_times[level] = times[-1000:]
|
||||
times = self._wait_times[level]
|
||||
|
||||
# Calculate average
|
||||
self._metrics.avg_wait_time[level] = sum(times) / len(times)
|
||||
|
||||
# Calculate P95
|
||||
sorted_times = sorted(times)
|
||||
p95_idx = int(len(sorted_times) * 0.95)
|
||||
self._metrics.p95_wait_time[level] = sorted_times[p95_idx]
|
||||
|
||||
async def get_metrics(self) -> QueueMetrics:
|
||||
"""Get current queue metrics.
|
||||
|
||||
Returns:
|
||||
QueueMetrics with current statistics
|
||||
"""
|
||||
async with self._lock:
|
||||
return QueueMetrics(
|
||||
total_enqueued=self._metrics.total_enqueued,
|
||||
total_dequeued=self._metrics.total_dequeued,
|
||||
total_timeouts=self._metrics.total_timeouts,
|
||||
total_errors=self._metrics.total_errors,
|
||||
current_size=self._metrics.current_size,
|
||||
avg_wait_time=dict(self._metrics.avg_wait_time),
|
||||
p95_wait_time=dict(self._metrics.p95_wait_time),
|
||||
)
|
||||
|
||||
async def size(self) -> int:
|
||||
"""Get current queue size.
|
||||
|
||||
Returns:
|
||||
Number of tasks in queue
|
||||
"""
|
||||
async with self._lock:
|
||||
return len(self._queue)
|
||||
|
||||
async def clear(self) -> int:
|
||||
"""Clear all tasks from the queue.
|
||||
|
||||
Returns:
|
||||
Number of tasks cleared
|
||||
"""
|
||||
async with self._lock:
|
||||
count = len(self._queue)
|
||||
self._queue.clear()
|
||||
self._metrics.current_size = 0
|
||||
logger.info("Cleared %d tasks from priority queue", count)
|
||||
return count
|
||||
88
src/main.py
88
src/main.py
@@ -19,7 +19,10 @@ from src.brain.gemini_client import GeminiClient
|
||||
from src.broker.kis_api import KISBroker
|
||||
from src.broker.overseas import OverseasBroker
|
||||
from src.config import Settings
|
||||
from src.context.layer import ContextLayer
|
||||
from src.context.store import ContextStore
|
||||
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
||||
from src.core.priority_queue import PriorityTaskQueue
|
||||
from src.core.risk_manager import CircuitBreakerTripped, RiskManager
|
||||
from src.db import init_db, log_trade
|
||||
from src.logging.decision_logger import DecisionLogger
|
||||
@@ -57,10 +60,14 @@ async def trading_cycle(
|
||||
risk: RiskManager,
|
||||
db_conn: Any,
|
||||
decision_logger: DecisionLogger,
|
||||
context_store: ContextStore,
|
||||
criticality_assessor: CriticalityAssessor,
|
||||
market: MarketInfo,
|
||||
stock_code: str,
|
||||
) -> None:
|
||||
"""Execute one trading cycle for a single stock."""
|
||||
cycle_start_time = asyncio.get_event_loop().time()
|
||||
|
||||
# 1. Fetch market data
|
||||
if market.is_domestic:
|
||||
orderbook = await broker.get_orderbook(stock_code)
|
||||
@@ -106,6 +113,42 @@ async def trading_cycle(
|
||||
"foreigner_net": foreigner_net,
|
||||
}
|
||||
|
||||
# 1.5. Get volatility metrics from context store (L7_REALTIME)
|
||||
latest_timeframe = context_store.get_latest_timeframe(ContextLayer.L7_REALTIME)
|
||||
volatility_score = 50.0 # Default normal volatility
|
||||
volume_surge = 1.0
|
||||
price_change_1m = 0.0
|
||||
|
||||
if latest_timeframe:
|
||||
volatility_data = context_store.get_context(
|
||||
ContextLayer.L7_REALTIME,
|
||||
latest_timeframe,
|
||||
f"volatility_{stock_code}",
|
||||
)
|
||||
if volatility_data:
|
||||
volatility_score = volatility_data.get("momentum_score", 50.0)
|
||||
volume_surge = volatility_data.get("volume_surge", 1.0)
|
||||
price_change_1m = volatility_data.get("price_change_1m", 0.0)
|
||||
|
||||
# 1.6. Assess criticality based on market conditions
|
||||
criticality = criticality_assessor.assess_market_conditions(
|
||||
pnl_pct=pnl_pct,
|
||||
volatility_score=volatility_score,
|
||||
volume_surge=volume_surge,
|
||||
price_change_1m=price_change_1m,
|
||||
is_market_open=True,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Criticality for %s (%s): %s (pnl=%.2f%%, volatility=%.1f, volume_surge=%.1fx)",
|
||||
stock_code,
|
||||
market.name,
|
||||
criticality.value,
|
||||
pnl_pct,
|
||||
volatility_score,
|
||||
volume_surge,
|
||||
)
|
||||
|
||||
# 2. Ask the brain for a decision
|
||||
decision = await brain.decide(market_data)
|
||||
logger.info(
|
||||
@@ -191,6 +234,27 @@ async def trading_cycle(
|
||||
exchange_code=market.exchange_code,
|
||||
)
|
||||
|
||||
# 7. Latency monitoring
|
||||
cycle_end_time = asyncio.get_event_loop().time()
|
||||
cycle_latency = cycle_end_time - cycle_start_time
|
||||
timeout = criticality_assessor.get_timeout(criticality)
|
||||
|
||||
if timeout and cycle_latency > timeout:
|
||||
logger.warning(
|
||||
"Trading cycle exceeded timeout for %s (criticality=%s, latency=%.2fs, timeout=%.2fs)",
|
||||
stock_code,
|
||||
criticality.value,
|
||||
cycle_latency,
|
||||
timeout,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Trading cycle completed within timeout for %s (criticality=%s, latency=%.2fs)",
|
||||
stock_code,
|
||||
criticality.value,
|
||||
cycle_latency,
|
||||
)
|
||||
|
||||
|
||||
async def run(settings: Settings) -> None:
|
||||
"""Main async loop — iterate over open markets on a timer."""
|
||||
@@ -212,6 +276,16 @@ async def run(settings: Settings) -> None:
|
||||
top_n=5,
|
||||
)
|
||||
|
||||
# Initialize latency control system
|
||||
criticality_assessor = CriticalityAssessor(
|
||||
critical_pnl_threshold=-2.5, # Near circuit breaker at -3.0%
|
||||
critical_price_change_threshold=5.0, # 5% in 1 minute
|
||||
critical_volume_surge_threshold=10.0, # 10x average
|
||||
high_volatility_threshold=70.0,
|
||||
low_volatility_threshold=30.0,
|
||||
)
|
||||
priority_queue = PriorityTaskQueue(max_size=1000)
|
||||
|
||||
# Track last scan time for each market
|
||||
last_scan_time: dict[str, float] = {}
|
||||
|
||||
@@ -315,6 +389,8 @@ async def run(settings: Settings) -> None:
|
||||
risk,
|
||||
db_conn,
|
||||
decision_logger,
|
||||
context_store,
|
||||
criticality_assessor,
|
||||
market,
|
||||
stock_code,
|
||||
)
|
||||
@@ -343,6 +419,18 @@ async def run(settings: Settings) -> None:
|
||||
logger.exception("Unexpected error for %s: %s", stock_code, exc)
|
||||
break # Don't retry on unexpected errors
|
||||
|
||||
# Log priority queue metrics periodically
|
||||
metrics = await priority_queue.get_metrics()
|
||||
if metrics.total_enqueued > 0:
|
||||
logger.info(
|
||||
"Priority queue metrics: enqueued=%d, dequeued=%d, size=%d, timeouts=%d, errors=%d",
|
||||
metrics.total_enqueued,
|
||||
metrics.total_dequeued,
|
||||
metrics.current_size,
|
||||
metrics.total_timeouts,
|
||||
metrics.total_errors,
|
||||
)
|
||||
|
||||
# Wait for next cycle or shutdown
|
||||
try:
|
||||
await asyncio.wait_for(shutdown.wait(), timeout=TRADE_INTERVAL_SECONDS)
|
||||
|
||||
558
tests/test_latency_control.py
Normal file
558
tests/test_latency_control.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""Tests for latency control system (criticality assessment and priority queue)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.criticality import CriticalityAssessor, CriticalityLevel
|
||||
from src.core.priority_queue import PriorityTask, PriorityTaskQueue
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CriticalityAssessor Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCriticalityAssessor:
|
||||
"""Test suite for criticality assessment logic."""
|
||||
|
||||
def test_market_closed_returns_low(self) -> None:
|
||||
"""Market closed should return LOW priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=False,
|
||||
)
|
||||
assert level == CriticalityLevel.LOW
|
||||
|
||||
def test_very_low_volatility_returns_low(self) -> None:
|
||||
"""Very low volatility should return LOW priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=20.0, # Below 30.0 threshold
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.LOW
|
||||
|
||||
def test_critical_pnl_threshold_triggered(self) -> None:
|
||||
"""P&L below -2.5% should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-2.6, # Below -2.5% threshold
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_pnl_at_circuit_breaker_proximity(self) -> None:
|
||||
"""P&L at exactly -2.5% (near -3.0% breaker) should be CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-2.5,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_price_change_positive(self) -> None:
|
||||
"""Large positive price change (>5%) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=5.5, # Above 5.0% threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_price_change_negative(self) -> None:
|
||||
"""Large negative price change (<-5%) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=-6.0, # Below -5.0% threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_critical_volume_surge(self) -> None:
|
||||
"""Extreme volume surge (>10x) should trigger CRITICAL."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=12.0, # Above 10.0x threshold
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_high_volatility_returns_high(self) -> None:
|
||||
"""High volatility score should return HIGH priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=75.0, # Above 70.0 threshold
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.HIGH
|
||||
|
||||
def test_normal_conditions_return_normal(self) -> None:
|
||||
"""Normal market conditions should return NORMAL priority."""
|
||||
assessor = CriticalityAssessor()
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.5,
|
||||
volatility_score=50.0, # Between 30-70
|
||||
volume_surge=1.5,
|
||||
price_change_1m=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.NORMAL
|
||||
|
||||
def test_custom_thresholds(self) -> None:
|
||||
"""Custom thresholds should be respected."""
|
||||
assessor = CriticalityAssessor(
|
||||
critical_pnl_threshold=-1.0,
|
||||
critical_price_change_threshold=3.0,
|
||||
critical_volume_surge_threshold=5.0,
|
||||
high_volatility_threshold=60.0,
|
||||
low_volatility_threshold=20.0,
|
||||
)
|
||||
|
||||
# Test custom P&L threshold
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=-1.1,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
# Test custom price change threshold
|
||||
level = assessor.assess_market_conditions(
|
||||
pnl_pct=0.0,
|
||||
volatility_score=50.0,
|
||||
volume_surge=1.0,
|
||||
price_change_1m=3.5,
|
||||
is_market_open=True,
|
||||
)
|
||||
assert level == CriticalityLevel.CRITICAL
|
||||
|
||||
def test_get_timeout_returns_correct_values(self) -> None:
|
||||
"""Timeout values should match specification."""
|
||||
assessor = CriticalityAssessor()
|
||||
|
||||
assert assessor.get_timeout(CriticalityLevel.CRITICAL) == 5.0
|
||||
assert assessor.get_timeout(CriticalityLevel.HIGH) == 30.0
|
||||
assert assessor.get_timeout(CriticalityLevel.NORMAL) == 60.0
|
||||
assert assessor.get_timeout(CriticalityLevel.LOW) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PriorityTaskQueue Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPriorityTaskQueue:
|
||||
"""Test suite for priority queue implementation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_task(self) -> None:
|
||||
"""Tasks should be enqueued successfully."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
success = await queue.enqueue(
|
||||
task_id="test-1",
|
||||
criticality=CriticalityLevel.NORMAL,
|
||||
task_data={"action": "test"},
|
||||
)
|
||||
|
||||
assert success is True
|
||||
assert await queue.size() == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enqueue_rejects_when_full(self) -> None:
|
||||
"""Queue should reject tasks when full."""
|
||||
queue = PriorityTaskQueue(max_size=2)
|
||||
|
||||
# Fill the queue
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Third task should be rejected
|
||||
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
assert success is False
|
||||
assert await queue.size() == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_returns_highest_priority(self) -> None:
|
||||
"""Dequeue should return highest priority task first."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue tasks in reverse priority order
|
||||
await queue.enqueue("low", CriticalityLevel.LOW, {"priority": 3})
|
||||
await queue.enqueue("normal", CriticalityLevel.NORMAL, {"priority": 2})
|
||||
await queue.enqueue("high", CriticalityLevel.HIGH, {"priority": 1})
|
||||
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {"priority": 0})
|
||||
|
||||
# Dequeue should return CRITICAL first
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "critical"
|
||||
assert task.priority == 0
|
||||
|
||||
# Then HIGH
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "high"
|
||||
assert task.priority == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_fifo_within_same_priority(self) -> None:
|
||||
"""Tasks with same priority should be FIFO."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue multiple tasks with same priority
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await asyncio.sleep(0.01) # Small delay to ensure different timestamps
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
await asyncio.sleep(0.01)
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Should dequeue in FIFO order
|
||||
task1 = await queue.dequeue(timeout=1.0)
|
||||
task2 = await queue.dequeue(timeout=1.0)
|
||||
task3 = await queue.dequeue(timeout=1.0)
|
||||
|
||||
assert task1 is not None and task1.task_id == "task-1"
|
||||
assert task2 is not None and task2.task_id == "task-2"
|
||||
assert task3 is not None and task3.task_id == "task-3"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dequeue_returns_none_when_empty(self) -> None:
|
||||
"""Dequeue should return None when queue is empty after timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
task = await queue.dequeue(timeout=0.1)
|
||||
assert task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_success(self) -> None:
|
||||
"""Task execution should succeed within timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a simple async callback
|
||||
async def test_callback() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
return "success"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=test_callback,
|
||||
)
|
||||
|
||||
result = await queue.execute_with_timeout(task, timeout=1.0)
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_raises_timeout_error(self) -> None:
|
||||
"""Task execution should raise TimeoutError if exceeds timeout."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a slow async callback
|
||||
async def slow_callback() -> str:
|
||||
await asyncio.sleep(1.0)
|
||||
return "too slow"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=slow_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await queue.execute_with_timeout(task, timeout=0.1)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_with_timeout_propagates_exceptions(self) -> None:
|
||||
"""Task execution should propagate exceptions from callback."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a failing async callback
|
||||
async def failing_callback() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=failing_callback,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await queue.execute_with_timeout(task, timeout=1.0)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_without_timeout(self) -> None:
|
||||
"""Task execution should work without timeout (LOW priority)."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def test_callback() -> str:
|
||||
await asyncio.sleep(0.01)
|
||||
return "success"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=3,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=test_callback,
|
||||
)
|
||||
|
||||
result = await queue.execute_with_timeout(task, timeout=None)
|
||||
assert result == "success"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_metrics(self) -> None:
|
||||
"""Queue should track metrics correctly."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue and dequeue some tasks
|
||||
await queue.enqueue("task-1", CriticalityLevel.CRITICAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.HIGH, {})
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
await queue.dequeue(timeout=1.0)
|
||||
await queue.dequeue(timeout=1.0)
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
|
||||
assert metrics.total_enqueued == 3
|
||||
assert metrics.total_dequeued == 2
|
||||
assert metrics.current_size == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wait_time_metrics(self) -> None:
|
||||
"""Queue should track wait times per criticality level."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Enqueue tasks with different criticality
|
||||
await queue.enqueue("critical-1", CriticalityLevel.CRITICAL, {})
|
||||
await asyncio.sleep(0.05) # Add some wait time
|
||||
|
||||
await queue.dequeue(timeout=1.0)
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
|
||||
# Should have wait time metrics for CRITICAL
|
||||
assert CriticalityLevel.CRITICAL in metrics.avg_wait_time
|
||||
assert metrics.avg_wait_time[CriticalityLevel.CRITICAL] > 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_queue(self) -> None:
|
||||
"""Clear should remove all tasks from queue."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
|
||||
cleared = await queue.clear()
|
||||
|
||||
assert cleared == 3
|
||||
assert await queue.size() == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_enqueue_dequeue(self) -> None:
|
||||
"""Queue should handle concurrent operations safely."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Concurrent enqueue operations
|
||||
async def enqueue_tasks() -> None:
|
||||
for i in range(10):
|
||||
await queue.enqueue(
|
||||
f"task-{i}",
|
||||
CriticalityLevel.NORMAL,
|
||||
{"index": i},
|
||||
)
|
||||
|
||||
# Concurrent dequeue operations
|
||||
async def dequeue_tasks() -> list[str]:
|
||||
tasks = []
|
||||
for _ in range(10):
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
if task:
|
||||
tasks.append(task.task_id)
|
||||
await asyncio.sleep(0.01)
|
||||
return tasks
|
||||
|
||||
# Run both concurrently
|
||||
enqueue_task = asyncio.create_task(enqueue_tasks())
|
||||
dequeue_task = asyncio.create_task(dequeue_tasks())
|
||||
|
||||
await enqueue_task
|
||||
dequeued_ids = await dequeue_task
|
||||
|
||||
# All tasks should be processed
|
||||
assert len(dequeued_ids) == 10
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_metric_tracking(self) -> None:
|
||||
"""Queue should track timeout occurrences."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def slow_callback() -> str:
|
||||
await asyncio.sleep(1.0)
|
||||
return "too slow"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=slow_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await queue.execute_with_timeout(task, timeout=0.1)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
assert metrics.total_timeouts == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_metric_tracking(self) -> None:
|
||||
"""Queue should track execution errors."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
async def failing_callback() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0,
|
||||
timestamp=0.0,
|
||||
task_id="test",
|
||||
task_data={},
|
||||
callback=failing_callback,
|
||||
)
|
||||
|
||||
try:
|
||||
await queue.execute_with_timeout(task, timeout=1.0)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
metrics = await queue.get_metrics()
|
||||
assert metrics.total_errors == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLatencyControlIntegration:
|
||||
"""Integration tests for criticality assessment and priority queue."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_critical_task_bypass_queue(self) -> None:
|
||||
"""CRITICAL tasks should bypass lower priority tasks."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Add normal priority tasks
|
||||
await queue.enqueue("normal-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("normal-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Add critical task (should jump to front)
|
||||
await queue.enqueue("critical", CriticalityLevel.CRITICAL, {})
|
||||
|
||||
# Dequeue should return critical first
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
assert task.task_id == "critical"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_timeout_enforcement_by_criticality(self) -> None:
|
||||
"""Timeout enforcement should match criticality level."""
|
||||
assessor = CriticalityAssessor()
|
||||
|
||||
# CRITICAL should have 5s timeout
|
||||
critical_timeout = assessor.get_timeout(CriticalityLevel.CRITICAL)
|
||||
assert critical_timeout == 5.0
|
||||
|
||||
# HIGH should have 30s timeout
|
||||
high_timeout = assessor.get_timeout(CriticalityLevel.HIGH)
|
||||
assert high_timeout == 30.0
|
||||
|
||||
# NORMAL should have 60s timeout
|
||||
normal_timeout = assessor.get_timeout(CriticalityLevel.NORMAL)
|
||||
assert normal_timeout == 60.0
|
||||
|
||||
# LOW should have no timeout
|
||||
low_timeout = assessor.get_timeout(CriticalityLevel.LOW)
|
||||
assert low_timeout is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_path_execution_for_critical(self) -> None:
|
||||
"""CRITICAL tasks should complete quickly."""
|
||||
queue = PriorityTaskQueue()
|
||||
|
||||
# Create a fast callback simulating fast-path execution
|
||||
async def fast_path_callback() -> str:
|
||||
# Simulate simplified decision flow
|
||||
await asyncio.sleep(0.01) # Very fast execution
|
||||
return "fast_path_complete"
|
||||
|
||||
task = PriorityTask(
|
||||
priority=0, # CRITICAL
|
||||
timestamp=0.0,
|
||||
task_id="critical-fast",
|
||||
task_data={},
|
||||
callback=fast_path_callback,
|
||||
)
|
||||
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
result = await queue.execute_with_timeout(task, timeout=5.0)
|
||||
elapsed = time.time() - start
|
||||
|
||||
assert result == "fast_path_complete"
|
||||
assert elapsed < 5.0 # Should complete well under CRITICAL timeout
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graceful_degradation_when_queue_full(self) -> None:
|
||||
"""System should gracefully handle full queue."""
|
||||
queue = PriorityTaskQueue(max_size=2)
|
||||
|
||||
# Fill the queue
|
||||
await queue.enqueue("task-1", CriticalityLevel.NORMAL, {})
|
||||
await queue.enqueue("task-2", CriticalityLevel.NORMAL, {})
|
||||
|
||||
# Try to add more tasks
|
||||
success = await queue.enqueue("task-3", CriticalityLevel.NORMAL, {})
|
||||
assert success is False
|
||||
|
||||
# Queue should still function
|
||||
task = await queue.dequeue(timeout=1.0)
|
||||
assert task is not None
|
||||
|
||||
# Now we can add another task
|
||||
success = await queue.enqueue("task-4", CriticalityLevel.NORMAL, {})
|
||||
assert success is True
|
||||
Reference in New Issue
Block a user