Compare commits
2 Commits
c76b9d5c15
...
ee66ecc305
| Author | SHA1 | Date | |
|---|---|---|---|
| ee66ecc305 | |||
|
|
065c9daaad |
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
@@ -339,3 +340,171 @@ class TelegramClient:
|
||||
await self._send_notification(
|
||||
NotificationMessage(priority=NotificationPriority.HIGH, message=message)
|
||||
)
|
||||
|
||||
|
||||
class TelegramCommandHandler:
|
||||
"""Handles incoming Telegram commands via long polling."""
|
||||
|
||||
def __init__(
|
||||
self, client: TelegramClient, polling_interval: float = 1.0
|
||||
) -> None:
|
||||
"""
|
||||
Initialize command handler.
|
||||
|
||||
Args:
|
||||
client: TelegramClient instance for sending responses
|
||||
polling_interval: Polling interval in seconds
|
||||
"""
|
||||
self._client = client
|
||||
self._polling_interval = polling_interval
|
||||
self._commands: dict[str, Callable[[], Awaitable[None]]] = {}
|
||||
self._last_update_id = 0
|
||||
self._polling_task: asyncio.Task[None] | None = None
|
||||
self._running = False
|
||||
|
||||
def register_command(
|
||||
self, command: str, handler: Callable[[], Awaitable[None]]
|
||||
) -> None:
|
||||
"""
|
||||
Register a command handler.
|
||||
|
||||
Args:
|
||||
command: Command name (without leading slash, e.g., "start")
|
||||
handler: Async function to handle the command
|
||||
"""
|
||||
self._commands[command] = handler
|
||||
logger.debug("Registered command handler: /%s", command)
|
||||
|
||||
async def start_polling(self) -> None:
|
||||
"""Start long polling for commands."""
|
||||
if self._running:
|
||||
logger.warning("Command handler already running")
|
||||
return
|
||||
|
||||
if not self._client._enabled:
|
||||
logger.info("Command handler disabled (TelegramClient disabled)")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._polling_task = asyncio.create_task(self._poll_loop())
|
||||
logger.info("Started Telegram command polling")
|
||||
|
||||
async def stop_polling(self) -> None:
|
||||
"""Stop polling and cancel pending tasks."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
if self._polling_task:
|
||||
self._polling_task.cancel()
|
||||
try:
|
||||
await self._polling_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("Stopped Telegram command polling")
|
||||
|
||||
async def _poll_loop(self) -> None:
|
||||
"""Main polling loop that fetches updates."""
|
||||
while self._running:
|
||||
try:
|
||||
updates = await self._get_updates()
|
||||
for update in updates:
|
||||
await self._handle_update(update)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as exc:
|
||||
logger.error("Error in polling loop: %s", exc)
|
||||
|
||||
await asyncio.sleep(self._polling_interval)
|
||||
|
||||
async def _get_updates(self) -> list[dict]:
|
||||
"""
|
||||
Fetch updates from Telegram API.
|
||||
|
||||
Returns:
|
||||
List of update objects
|
||||
"""
|
||||
try:
|
||||
url = f"{self._client.API_BASE.format(token=self._client._bot_token)}/getUpdates"
|
||||
payload = {
|
||||
"offset": self._last_update_id + 1,
|
||||
"timeout": int(self._polling_interval),
|
||||
"allowed_updates": ["message"],
|
||||
}
|
||||
|
||||
session = self._client._get_session()
|
||||
async with session.post(url, json=payload) as resp:
|
||||
if resp.status != 200:
|
||||
error_text = await resp.text()
|
||||
logger.error(
|
||||
"getUpdates API error (status=%d): %s", resp.status, error_text
|
||||
)
|
||||
return []
|
||||
|
||||
data = await resp.json()
|
||||
if not data.get("ok"):
|
||||
logger.error("getUpdates returned ok=false: %s", data)
|
||||
return []
|
||||
|
||||
updates = data.get("result", [])
|
||||
if updates:
|
||||
self._last_update_id = updates[-1]["update_id"]
|
||||
|
||||
return updates
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.debug("getUpdates timeout (normal)")
|
||||
return []
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.error("getUpdates failed: %s", exc)
|
||||
return []
|
||||
except Exception as exc:
|
||||
logger.error("Unexpected error in _get_updates: %s", exc)
|
||||
return []
|
||||
|
||||
async def _handle_update(self, update: dict) -> None:
|
||||
"""
|
||||
Parse and handle a single update.
|
||||
|
||||
Args:
|
||||
update: Update object from Telegram API
|
||||
"""
|
||||
try:
|
||||
message = update.get("message")
|
||||
if not message:
|
||||
return
|
||||
|
||||
# Verify chat_id matches configured chat
|
||||
chat_id = str(message.get("chat", {}).get("id", ""))
|
||||
if chat_id != self._client._chat_id:
|
||||
logger.warning(
|
||||
"Ignoring command from unauthorized chat_id: %s", chat_id
|
||||
)
|
||||
return
|
||||
|
||||
# Extract command text
|
||||
text = message.get("text", "").strip()
|
||||
if not text.startswith("/"):
|
||||
return
|
||||
|
||||
# Parse command (remove leading slash and extract command name)
|
||||
command_parts = text[1:].split()
|
||||
if not command_parts:
|
||||
return
|
||||
|
||||
command_name = command_parts[0]
|
||||
|
||||
# Execute handler
|
||||
handler = self._commands.get(command_name)
|
||||
if handler:
|
||||
logger.info("Executing command: /%s", command_name)
|
||||
await handler()
|
||||
else:
|
||||
logger.debug("Unknown command: /%s", command_name)
|
||||
await self._client.send_message(
|
||||
f"Unknown command: /{command_name}\nUse /help to see available commands."
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("Error handling update: %s", exc)
|
||||
# Don't crash the polling loop on handler errors
|
||||
|
||||
319
tests/test_telegram_commands.py
Normal file
319
tests/test_telegram_commands.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""Tests for Telegram command handler."""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.notifications.telegram_client import TelegramClient, TelegramCommandHandler
|
||||
|
||||
|
||||
class TestCommandHandlerInit:
|
||||
"""Test command handler initialization."""
|
||||
|
||||
def test_init_with_client(self) -> None:
|
||||
"""Handler initializes with TelegramClient."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
assert handler._client is client
|
||||
assert handler._polling_interval == 1.0
|
||||
assert handler._commands == {}
|
||||
assert handler._running is False
|
||||
|
||||
def test_custom_polling_interval(self) -> None:
|
||||
"""Handler accepts custom polling interval."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client, polling_interval=2.5)
|
||||
|
||||
assert handler._polling_interval == 2.5
|
||||
|
||||
|
||||
class TestCommandRegistration:
|
||||
"""Test command registration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_command(self) -> None:
|
||||
"""Commands can be registered."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
async def test_handler() -> None:
|
||||
pass
|
||||
|
||||
handler.register_command("test", test_handler)
|
||||
|
||||
assert "test" in handler._commands
|
||||
assert handler._commands["test"] is test_handler
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_register_multiple_commands(self) -> None:
|
||||
"""Multiple commands can be registered."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
async def handler1() -> None:
|
||||
pass
|
||||
|
||||
async def handler2() -> None:
|
||||
pass
|
||||
|
||||
handler.register_command("start", handler1)
|
||||
handler.register_command("help", handler2)
|
||||
|
||||
assert len(handler._commands) == 2
|
||||
assert handler._commands["start"] is handler1
|
||||
assert handler._commands["help"] is handler2
|
||||
|
||||
|
||||
class TestPollingLifecycle:
|
||||
"""Test polling start/stop."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_polling(self) -> None:
|
||||
"""Polling can be started."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
with patch.object(handler, "_poll_loop", new_callable=AsyncMock):
|
||||
await handler.start_polling()
|
||||
|
||||
assert handler._running is True
|
||||
assert handler._polling_task is not None
|
||||
|
||||
await handler.stop_polling()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_polling_disabled_client(self) -> None:
|
||||
"""Polling not started when client disabled."""
|
||||
client = TelegramClient(enabled=False)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
await handler.start_polling()
|
||||
|
||||
assert handler._running is False
|
||||
assert handler._polling_task is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stop_polling(self) -> None:
|
||||
"""Polling can be stopped."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
with patch.object(handler, "_poll_loop", new_callable=AsyncMock):
|
||||
await handler.start_polling()
|
||||
await handler.stop_polling()
|
||||
|
||||
assert handler._running is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_double_start_ignored(self) -> None:
|
||||
"""Starting already running handler is ignored."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
with patch.object(handler, "_poll_loop", new_callable=AsyncMock):
|
||||
await handler.start_polling()
|
||||
task1 = handler._polling_task
|
||||
|
||||
await handler.start_polling() # Second start
|
||||
task2 = handler._polling_task
|
||||
|
||||
# Should be the same task
|
||||
assert task1 is task2
|
||||
|
||||
await handler.stop_polling()
|
||||
|
||||
|
||||
class TestUpdateHandling:
|
||||
"""Test update parsing and handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_valid_command(self) -> None:
|
||||
"""Valid commands are executed."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
executed = False
|
||||
|
||||
async def test_command() -> None:
|
||||
nonlocal executed
|
||||
executed = True
|
||||
|
||||
handler.register_command("test", test_command)
|
||||
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"chat": {"id": 456},
|
||||
"text": "/test",
|
||||
},
|
||||
}
|
||||
|
||||
await handler._handle_update(update)
|
||||
assert executed is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_unknown_command(self) -> None:
|
||||
"""Unknown commands send help message."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp) as mock_post:
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"chat": {"id": 456},
|
||||
"text": "/unknown",
|
||||
},
|
||||
}
|
||||
|
||||
await handler._handle_update(update)
|
||||
|
||||
# Should send error message
|
||||
assert mock_post.call_count == 1
|
||||
payload = mock_post.call_args.kwargs["json"]
|
||||
assert "Unknown command" in payload["text"]
|
||||
assert "/unknown" in payload["text"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_unauthorized_chat(self) -> None:
|
||||
"""Commands from unauthorized chats are ignored."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
executed = False
|
||||
|
||||
async def test_command() -> None:
|
||||
nonlocal executed
|
||||
executed = True
|
||||
|
||||
handler.register_command("test", test_command)
|
||||
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"chat": {"id": 999}, # Wrong chat_id
|
||||
"text": "/test",
|
||||
},
|
||||
}
|
||||
|
||||
await handler._handle_update(update)
|
||||
assert executed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ignore_non_command_text(self) -> None:
|
||||
"""Non-command text is ignored."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
executed = False
|
||||
|
||||
async def test_command() -> None:
|
||||
nonlocal executed
|
||||
executed = True
|
||||
|
||||
handler.register_command("test", test_command)
|
||||
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"chat": {"id": 456},
|
||||
"text": "Hello, not a command",
|
||||
},
|
||||
}
|
||||
|
||||
await handler._handle_update(update)
|
||||
assert executed is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_update_error_isolation(self) -> None:
|
||||
"""Errors in handlers don't crash the system."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
async def failing_command() -> None:
|
||||
raise ValueError("Test error")
|
||||
|
||||
handler.register_command("fail", failing_command)
|
||||
|
||||
update = {
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"chat": {"id": 456},
|
||||
"text": "/fail",
|
||||
},
|
||||
}
|
||||
|
||||
# Should not raise exception
|
||||
await handler._handle_update(update)
|
||||
|
||||
|
||||
class TestGetUpdates:
|
||||
"""Test getUpdates API interaction."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_updates_success(self) -> None:
|
||||
"""getUpdates fetches and parses updates."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(
|
||||
return_value={
|
||||
"ok": True,
|
||||
"result": [
|
||||
{"update_id": 1, "message": {"text": "/test"}},
|
||||
{"update_id": 2, "message": {"text": "/help"}},
|
||||
],
|
||||
}
|
||||
)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
updates = await handler._get_updates()
|
||||
|
||||
assert len(updates) == 2
|
||||
assert updates[0]["update_id"] == 1
|
||||
assert updates[1]["update_id"] == 2
|
||||
assert handler._last_update_id == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_updates_api_error(self) -> None:
|
||||
"""getUpdates handles API errors gracefully."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 400
|
||||
mock_resp.text = AsyncMock(return_value="Bad Request")
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
updates = await handler._get_updates()
|
||||
|
||||
assert updates == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_updates_empty_result(self) -> None:
|
||||
"""getUpdates handles empty results."""
|
||||
client = TelegramClient(bot_token="123:abc", chat_id="456", enabled=True)
|
||||
handler = TelegramCommandHandler(client)
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value={"ok": True, "result": []})
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("aiohttp.ClientSession.post", return_value=mock_resp):
|
||||
updates = await handler._get_updates()
|
||||
|
||||
assert updates == []
|
||||
Reference in New Issue
Block a user