Source code for lavalink.node

from __future__ import annotations

import asyncio
import secrets
import string
import typing
from collections import namedtuple
from typing import KeysView, List, Optional, ValuesView

import aiohttp
from discord.backoff import ExponentialBackoff
from discord.ext.commands import Bot

from . import log, ws_ll_log, ws_rll_log, __version__
from .enums import LavalinkEvents, LavalinkIncomingOp, LavalinkOutgoingOp, NodeState, PlayerState
from .player import Player
from .rest_api import Track
from .utils import VoiceChannel, is_loop_closed
from .errors import AbortingNodeConnection, NodeNotReady, NodeNotFound, PlayerNotFound

# Names exported by a star-import of this module.
__all__ = [
    "Stats",
    "Node",
    "NodeStats",
    "get_node",
    "get_nodes_stats",
    "get_all_nodes",
]

# Module-wide registry of Node instances. Node.__init__ appends the new node
# here and the module-level lookup helpers iterate over it.
_nodes: List[Node] = []

# Lightweight records built from Lavalink payloads. The camelCase field names
# of CPUInfo mirror the payload keys because instances are built via ``**cpu``.
PositionTime = namedtuple("PositionTime", ["position", "time", "connected"])
MemoryInfo = namedtuple("MemoryInfo", ["reservable", "used", "free", "allocated"])
CPUInfo = namedtuple("CPUInfo", ["cores", "systemLoad", "lavalinkLoad"])


# Originally Added in: https://github.com/PythonistaGuild/Wavelink/pull/66
class _Key:
    def __init__(self, length: int = 32):
        self.length: int = length
        self.persistent: str = ""
        self.__repr__()

    def __repr__(self):
        """Generate a new key, return it and make it persistent"""
        alphabet = string.ascii_letters + string.digits + "#$%&()*+,-./:;<=>?@[]^_~!"
        key = "".join(secrets.choice(alphabet) for _ in range(self.length))
        self.persistent = key
        return key

    def __str__(self):
        """Return the persistent key."""
        # Ensure output is not a non-string
        # Since input could be Any object.
        if not self.persistent:
            return self.__repr__()
        return str(self.persistent)


class Stats:
    """Snapshot of node statistics handed to the event handler.

    ``memory`` and ``cpu`` are mappings that are expanded as keyword arguments
    into ``MemoryInfo`` and ``CPUInfo`` respectively, so their keys must match
    those namedtuples' field names exactly.
    """

    def __init__(self, memory, players, active_players, cpu, uptime):
        # Structured sub-records first, in the same order as the original so a
        # malformed ``memory`` mapping is reported before a malformed ``cpu``.
        self.memory = MemoryInfo(**memory)
        self.players = players
        self.active_players = active_players
        self.cpu_info = CPUInfo(**cpu)
        self.uptime = uptime
# Node stats related class below and how it is called is originally from:
# https://github.com/PythonistaGuild/Wavelink/blob/abba49e9806af3c50886f82054ea603129ad08b9/wavelink/stats.py#L41
# https://github.com/PythonistaGuild/Wavelink/blob/abba49e9806af3c50886f82054ea603129ad08b9/wavelink/websocket.py#L132
class NodeStats:
    """Flattened view of a node ``stats`` payload.

    Parameters
    ----------
    data : dict
        The stats payload. Must contain ``uptime``, ``players``,
        ``playingPlayers``, a ``memory`` mapping (``free``, ``used``,
        ``allocated``, ``reservable``) and a ``cpu`` mapping (``cores``,
        ``systemLoad``, ``lavalinkLoad``). ``frameStats`` is optional.

    Raises
    ------
    KeyError
        If any of the required keys is missing.
    """

    def __init__(self, data: dict):
        self.uptime = data["uptime"]

        self.players = data["players"]
        self.playing_players = data["playingPlayers"]

        memory = data["memory"]
        self.memory_free = memory["free"]
        self.memory_used = memory["used"]
        self.memory_allocated = memory["allocated"]
        self.memory_reservable = memory["reservable"]

        cpu = data["cpu"]
        self.cpu_cores = cpu["cores"]
        self.system_load = cpu["systemLoad"]
        self.lavalink_load = cpu["lavalinkLoad"]

        # ``frameStats`` may be absent *or* present with a null value; a plain
        # ``.get("frameStats", {})`` only covers the former and would crash on
        # ``None.get``. Treat both as "no frame statistics" (-1 sentinels).
        frame_stats = data.get("frameStats") or {}
        self.frames_sent = frame_stats.get("sent", -1)
        self.frames_nulled = frame_stats.get("nulled", -1)
        self.frames_deficit = frame_stats.get("deficit", -1)

    def __repr__(self):
        return (
            "<NodeStats: "
            f"uptime={self.uptime}, "
            f"players={self.players}, "
            f"playing_players={self.playing_players}, "
            f"memory_free={self.memory_free}, memory_used={self.memory_used}, "
            f"cpu_cores={self.cpu_cores}, system_load={self.system_load}, "
            f"lavalink_load={self.lavalink_load}>"
        )
class Node:
    # Class-level default so the flag can be read before ``connect`` sets it.
    _is_shutdown: bool = False

    def __init__(
        self,
        *,
        event_handler: typing.Callable,
        host: str,
        password: str,
        user_id: int,
        num_shards: int,
        port: Optional[int] = None,
        resume_key: Optional[str] = None,
        resume_timeout: float = 60,
        bot: Bot = None,
        secured: bool = False,
    ):
        """
        Represents a Lavalink node.

        Parameters
        ----------
        event_handler
            Function to dispatch events to.
        host : str
            Lavalink player host.
        password : str
            Password for the Lavalink player.
        port : Optional[int]
            Port of the Lavalink player event websocket. When omitted,
            443 is used if ``secured`` is true and 80 otherwise.
        user_id : int
            User ID of the bot.
        num_shards : int
            Number of shards to which the bot is currently connected.
        resume_key : Optional[str]
            A resume key used for resuming a session upon re-establishing a WebSocket connection to Lavalink.
            A random key is generated when omitted.
        resume_timeout : float
            How long the node should wait for a connection while disconnected before clearing all players.
        bot : Bot
            The Bot object that connects to discord.
        secured : bool
            Whether to connect with ``wss://`` instead of ``ws://``.
        """
        self.bot = bot
        self.event_handler = event_handler
        self.host = host
        self.secured = secured
        if port is None:
            if self.secured:
                self.port = 443
            else:
                self.port = 80
        else:
            self.port = port
        self.password = password

        self._resume_key = resume_key
        if self._resume_key is None:
            self._resume_key = self._gen_key()
        self._resume_timeout = resume_timeout
        self._resuming_configured = False
        self.num_shards = num_shards
        self.user_id = user_id

        self._ready_event = asyncio.Event()

        self._ws = None
        self._listener_task = None
        self.session = aiohttp.ClientSession()
        self.reconnect_task = None
        self.try_connect_task = None
        # Defined up front so the attribute exists before the first successful
        # connection; ``_multi_try_connect`` overwrites it from the handshake.
        self.session_resumed = False

        # Payloads queued by ``send`` while the websocket is unavailable.
        self._queue: List = []
        self._players_dict = {}

        self.state = NodeState.CONNECTING
        self._state_handlers: List = []
        self._retries = 0

        self.stats = None

        if self not in _nodes:
            _nodes.append(self)

        # Websocket message types that signal the connection is going away.
        self._closers = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
        )

        self.register_state_handler(self.node_state_handler)

    def __repr__(self):
        return (
            "<Node: "
            f"state={self.state.name}, "
            f"host={self.host}, "
            f"port={self.port}, "
            f"password={'*' * len(self.password)}, resume_key={self._resume_key}, "
            f"shards={self.num_shards}, user={self.user_id}, stats={self.stats}>"
        )

    @property
    def headers(self) -> dict:
        """Headers sent with the websocket handshake (includes the password)."""
        return self._get_connect_headers()

    @property
    def players(self) -> ValuesView[Player]:
        """Live view of every player registered on this node."""
        return self._players_dict.values()

    @property
    def guild_ids(self) -> KeysView[int]:
        """Live view of the guild IDs that have a player on this node."""
        return self._players_dict.keys()

    def _gen_key(self):
        """Return a resume key: a fresh ``_Key`` if none is set, else the rotated existing one."""
        if self._resume_key is None:
            return _Key()
        else:
            # If this is a _Key instance, calling __repr__ generates and
            # stores a new persistent key.
            # We do not check the instance type since that would just be
            # one extra call, which is useless.
            self._resume_key.__repr__()
            return self._resume_key

    async def connect(self, timeout: float = None, *, shutdown: bool = False):
        """
        Connects to the Lavalink player event websocket.

        Parameters
        ----------
        timeout : float
            Time after which to timeout on attempting to connect to the Lavalink websocket,
            ``None`` is considered never, but the underlying code may stop trying past a
            certain point.
        shutdown : bool
            Whether the node was told to shut down

        Raises
        ------
        asyncio.TimeoutError
            If the websocket failed to connect after the given time.
        AbortingNodeConnection
            If the connection attempt must be aborted during a reconnect attempt
        """
        self._is_shutdown = shutdown
        if self.secured:
            uri = f"wss://{self.host}:{self.port}"
        else:
            uri = f"ws://{self.host}:{self.port}"

        # Never write the plaintext password to the log: mask it the same way
        # ``__repr__`` does. The real headers are still used for the handshake.
        safe_headers = {**self.headers, "Authorization": "*" * len(self.password)}
        ws_ll_log.info("Lavalink WS connecting to %s with headers %s", uri, safe_headers)
        # Only one connection attempt may be in flight at a time.
        if self.try_connect_task is not None:
            self.try_connect_task.cancel()
        self.try_connect_task = asyncio.create_task(self._multi_try_connect(uri))
        try:
            await asyncio.wait_for(self.try_connect_task, timeout=timeout)
        except asyncio.CancelledError:
            raise AbortingNodeConnection

    async def _configure_resume(self):
        """Send the ``configureResuming`` op once per connection lifetime."""
        if self._resuming_configured:
            return
        if self._resume_key and self._resume_timeout and self._resume_timeout > 0:
            await self.send(
                dict(
                    op="configureResuming",
                    key=str(self._resume_key),
                    timeout=self._resume_timeout,
                )
            )
            self._resuming_configured = True
            ws_ll_log.debug("Node Resuming has been configured.")

    async def wait_until_ready(self, *, timeout: Optional[float] = None):
        """Block until the node is ready; raises ``asyncio.TimeoutError`` after ``timeout``."""
        await asyncio.wait_for(self._ready_event.wait(), timeout=timeout)

    def _get_connect_headers(self) -> dict:
        """Build the handshake headers, adding ``Resume-Key`` when one is set."""
        # Num-Shards is not used on Lavalink jar files >= v3.4
        # but kept for compatibility to avoid NPEs on older builds
        headers = {
            "Authorization": self.password,
            "User-Id": str(self.user_id),
            "Num-Shards": str(self.num_shards),
            "Client-Name": f"Red-Lavalink/{__version__}",
        }
        if self._resume_key:
            headers["Resume-Key"] = str(self._resume_key)
        return headers

    @property
    def lavalink_major_version(self):
        """Value of the ``Lavalink-Major-Version`` handshake response header.

        Raises
        ------
        NodeNotReady
            If the node is not in the ready state.
        """
        if not self.ready:
            raise NodeNotReady("Node not ready!")
        return self._ws.response_headers.get("Lavalink-Major-Version")

    @property
    def ready(self) -> bool:
        """
        Whether the underlying node is ready for requests.
        """
        return self.state == NodeState.READY

    async def _multi_try_connect(self, uri):
        """Open the websocket, retrying with exponential backoff.

        Raises ``asyncio.TimeoutError`` after more than five failed attempts or
        on a handshake error, and ``asyncio.CancelledError`` if the node is
        shut down while connecting.
        """
        backoff = ExponentialBackoff()
        attempt = 1
        if self._listener_task is not None:
            self._listener_task.cancel()
        if self._ws is not None:
            await self._ws.close(code=4006, message=b"Reconnecting")

        while self._is_shutdown is False and (self._ws is None or self._ws.closed):
            self._retries += 1
            if self._is_shutdown is True:
                ws_ll_log.error("Lavalink node was shutdown during a connect attempt.")
                raise asyncio.CancelledError
            try:
                ws = await self.session.ws_connect(url=uri, headers=self.headers, heartbeat=60)
            except (OSError, aiohttp.ClientConnectionError):
                if attempt > 5:
                    raise asyncio.TimeoutError
                delay = backoff.delay()
                ws_ll_log.warning("Failed connect attempt %s, retrying in %s", attempt, delay)
                await asyncio.sleep(delay)
                attempt += 1
            except aiohttp.WSServerHandshakeError:
                ws_ll_log.error("Failed connect WSServerHandshakeError")
                raise asyncio.TimeoutError
            else:
                # The shutdown flag may have flipped while awaiting the handshake.
                if self._is_shutdown is True:
                    ws_ll_log.error("Lavalink node was shutdown during a connect attempt.")
                    raise asyncio.CancelledError
                # NOTE(review): header values are strings, so a literal "false"
                # would be truthy here - confirm what the server actually sends.
                self.session_resumed = ws._response.headers.get("Session-Resumed", False)
                if self._ws is not None and self.session_resumed:
                    ws_ll_log.info("WEBSOCKET Resumed Session with key: %s", self._resume_key)
                self._ws = ws
                break
        if self._is_shutdown is True:
            raise asyncio.CancelledError
        ws_ll_log.info("Lavalink WS connected to %s", uri)
        ws_ll_log.debug("Creating Lavalink WS listener.")
        if self._is_shutdown is False:
            self._listener_task = asyncio.create_task(self.listener())
            asyncio.create_task(self._configure_resume())
            # Flush everything ``send`` queued while the socket was down. Work
            # on a copy so a failing send that re-queues does not loop forever.
            if self._queue:
                temp = self._queue.copy()
                self._queue.clear()
                for data in temp:
                    await self.send(data)
            self._ready_event.set()
            self.update_state(NodeState.READY)

    async def listener(self):
        """
        Listener task for receiving ops from Lavalink.
        """
        while self._is_shutdown is False:
            msg = await self._ws.receive()
            if msg.type in self._closers:
                if self._resuming_configured:
                    # Resuming is configured: schedule a reconnect (unless one
                    # is already under way) and stop listening on this socket.
                    if self.state != NodeState.RECONNECTING:
                        if self.reconnect_task is not None:
                            self.reconnect_task.cancel()
                        ws_ll_log.info("[NODE] | NODE Resuming: %s", msg.extra)
                        self.update_state(NodeState.RECONNECTING)
                        self.reconnect_task = asyncio.create_task(
                            self._reconnect(shutdown=self._is_shutdown)
                        )
                    return
                else:
                    ws_ll_log.info("[NODE] | Listener closing: %s", msg.extra)
                    break
            elif msg.type == aiohttp.WSMsgType.TEXT:
                data = msg.json()
                try:
                    op = LavalinkIncomingOp(data.get("op"))
                except ValueError:
                    ws_ll_log.verbose("[NODE] | Received unknown op: %s", data)
                else:
                    ws_ll_log.trace("[NODE] | Received known op: %s", data)
                    # Handle in a separate task so the receive loop is not blocked.
                    asyncio.create_task(self._handle_op(op, data))
            elif msg.type == aiohttp.WSMsgType.ERROR:
                exc = self._ws.exception()
                ws_ll_log.warning(
                    "[NODE] | An exception occurred on the websocket - Attempting to reconnect"
                )
                ws_ll_log.debug("[NODE] | Exception in WebSocket!", exc_info=exc)
                break
            else:
                ws_ll_log.debug(
                    "[NODE] | WebSocket connection received unexpected message: %s:%s",
                    msg.type,
                    msg.data,
                )
        # Reached on ``break`` or when the shutdown flag ends the loop.
        if self.state != NodeState.RECONNECTING and not self._is_shutdown:
            ws_ll_log.warning(
                "[NODE] | %s - WS %s SHUTDOWN %s.", self, not self._ws.closed, self._is_shutdown
            )
            if self.reconnect_task is not None:
                self.reconnect_task.cancel()
            self.update_state(NodeState.RECONNECTING)
            self.reconnect_task = asyncio.create_task(self._reconnect(shutdown=self._is_shutdown))

    async def _handle_op(self, op: LavalinkIncomingOp, data):
        """Translate an incoming op into a call to ``event_handler``."""
        if op == LavalinkIncomingOp.EVENT:
            try:
                event = LavalinkEvents(data.get("type"))
            except ValueError:
                ws_ll_log.verbose("Unknown event type: %s", data)
            else:
                self.event_handler(op, event, data)
        elif op == LavalinkIncomingOp.PLAYER_UPDATE:
            state = data.get("state", {})
            position = PositionTime(
                position=state.get("position", 0),
                time=state.get("time", 0),
                connected=state.get("connected", False),
            )
            self.event_handler(op, position, data)
        elif op == LavalinkIncomingOp.STATS:
            stats = Stats(
                memory=data.get("memory"),
                players=data.get("players"),
                active_players=data.get("playingPlayers"),
                cpu=data.get("cpu"),
                uptime=data.get("uptime"),
            )
            self.stats = NodeStats(data)
            self.event_handler(op, stats, data)
        else:
            ws_ll_log.verbose("Unknown op type: %r", data)

    async def _reconnect(self, *, shutdown: bool = False):
        """Repeatedly call ``connect`` until it succeeds, is aborted, or gives up."""
        self._ready_event.clear()

        if self._is_shutdown is True or shutdown:
            ws_ll_log.info("[NODE] | Shutting down Lavalink WS.")
            return
        if self.state != NodeState.CONNECTING:
            self.update_state(NodeState.RECONNECTING)
        if self.state != NodeState.RECONNECTING:
            return
        backoff = ExponentialBackoff(base=1)
        attempt = 1
        # ``connect`` moves the state to READY on success, which ends the loop.
        while self.state == NodeState.RECONNECTING:
            attempt += 1
            if attempt > 10:
                ws_ll_log.info("[NODE] | Failed reconnection attempt too many times, aborting ...")
                asyncio.create_task(self.disconnect())
                return
            try:
                await self.connect(shutdown=shutdown)
            except AbortingNodeConnection:
                return
            except asyncio.TimeoutError:
                delay = backoff.delay()
                ws_ll_log.warning(
                    "[NODE] | Lavalink WS reconnect attempt %s, retrying in %s",
                    attempt,
                    delay,
                )
                await asyncio.sleep(delay)
            else:
                ws_ll_log.info("[NODE] | Reconnect successful.")
                self.dispatch_reconnect()
                self._retries = 0

    def dispatch_reconnect(self):
        """Send a synthetic ``WEBSOCKET_CLOSED`` event for every guild on this node."""
        for guild_id in self.guild_ids:
            self.event_handler(
                LavalinkIncomingOp.EVENT,
                LavalinkEvents.WEBSOCKET_CLOSED,
                {
                    "guildId": guild_id,
                    "code": 42069,
                    "reason": "Lavalink WS reconnected",
                    "byRemote": True,
                    "retries": self._retries,
                },
            )

    def update_state(self, next_state: NodeState):
        """Switch to ``next_state`` and schedule every registered state handler."""
        if next_state == self.state:
            return

        ws_ll_log.verbose("Changing node state: %s -> %s", self.state.name, next_state.name)
        old_state = self.state
        self.state = next_state
        if is_loop_closed():
            ws_ll_log.debug("Event loop closed, not notifying state handlers.")
            return
        for handler in self._state_handlers:
            asyncio.create_task(handler(next_state, old_state))

    def register_state_handler(self, func):
        """Register a coroutine function called as ``func(next_state, old_state)``.

        Raises
        ------
        ValueError
            If ``func`` is not a coroutine function.
        """
        if not asyncio.iscoroutinefunction(func):
            raise ValueError("Argument must be a coroutine object.")

        if func not in self._state_handlers:
            self._state_handlers.append(func)

    def unregister_state_handler(self, func):
        """Remove a state handler; raises ``ValueError`` if it was not registered."""
        self._state_handlers.remove(func)

    async def create_player(self, channel: VoiceChannel, *, self_deaf: bool = False) -> Player:
        """
        Connects to a discord voice channel.

        This function is safe to repeatedly call as it will return an existing
        player if there is one.

        Parameters
        ----------
        channel: VoiceChannel
        self_deaf: bool

        Returns
        -------
        Player
            The created Player object.
        """
        if self._already_in_guild(channel):
            player = self.get_player(channel.guild.id)
            await player.move_to(channel, self_deaf=self_deaf)
        else:
            player: Player = await channel.connect(cls=Player, self_deaf=self_deaf)  # type: ignore
        return player

    def _already_in_guild(self, channel: VoiceChannel) -> bool:
        """Whether this node already has a player for the channel's guild."""
        return channel.guild.id in self._players_dict

    def get_player(self, guild_id: int) -> Player:
        """
        Gets a Player object from a guild ID.

        Parameters
        ----------
        guild_id : int
            Discord guild ID.

        Returns
        -------
        Player

        Raises
        ------
        PlayerNotFound
            If that guild does not have a Player, e.g. is not connected to any
            voice channel.
        """
        if guild_id in self._players_dict:
            return self._players_dict[guild_id]
        raise PlayerNotFound("No such player for that guild.")

    async def node_state_handler(self, next_state: NodeState, old_state: NodeState):
        """Mirror a node state change onto every player of this node."""
        ws_rll_log.debug("Received node state update: %s -> %s", old_state.name, next_state.name)
        if next_state == NodeState.READY:
            await self.update_player_states(PlayerState.READY)
        elif next_state == NodeState.DISCONNECTING:
            await self.update_player_states(PlayerState.DISCONNECTING)
        elif next_state in (NodeState.CONNECTING, NodeState.RECONNECTING):
            await self.update_player_states(PlayerState.NODE_BUSY)

    async def update_player_states(self, state: PlayerState):
        """Apply ``state`` to every player of this node."""
        for player in self.players:
            await player.update_state(state)

    async def refresh_player_state(self, player: Player):
        """Set ``player``'s state to the one matching the node's current state."""
        if self.ready:
            await player.update_state(PlayerState.READY)
        elif self.state == NodeState.DISCONNECTING:
            await player.update_state(PlayerState.DISCONNECTING)
        else:
            await player.update_state(PlayerState.NODE_BUSY)

    def remove_player(self, player: Player):
        """Forget ``player``; refused (and logged) unless it is disconnecting."""
        if player.state != PlayerState.DISCONNECTING:
            log.error(
                "Attempting to remove a player (%r) from player list with state: %s",
                player,
                player.state.name,
            )
            return
        guild_id = player.channel.guild.id
        if guild_id in self._players_dict:
            del self._players_dict[guild_id]

    async def disconnect(self):
        """
        Shuts down and disconnects the websocket.
        """
        self._is_shutdown = True
        self._ready_event.clear()
        self._queue.clear()
        if (
            self.try_connect_task is not None
            and not self.try_connect_task.cancelled()
            and not is_loop_closed()
        ):
            self.try_connect_task.cancel()
        if (
            self.reconnect_task is not None
            and not self.reconnect_task.cancelled()
            and not is_loop_closed()
        ):
            self.reconnect_task.cancel()

        self.update_state(NodeState.DISCONNECTING)

        # Tell the server to drop the resume configuration while the socket is
        # still usable, so the session is not kept alive after shutdown.
        if self._resuming_configured and not (self._ws is None or self._ws.closed):
            await self.send(dict(op="configureResuming", key=None))
        self._resuming_configured = False

        # Iterate over a snapshot: players are removed from the dict as they
        # disconnect.
        for p in tuple(self.players):
            await p.disconnect(force=True)
        log.debug("Disconnected all players.")

        if self._ws is not None and not self._ws.closed:
            await self._ws.close()

        if (
            self._listener_task is not None
            and not self._listener_task.cancelled()
            and not is_loop_closed()
        ):
            self._listener_task.cancel()

        await self.session.close()

        self._state_handlers = []
        # Remove only this node, in place. Clearing the whole registry when it
        # holds a single entry could drop a different node, and an unguarded
        # ``remove`` raises ValueError when ``disconnect`` is called twice.
        if self in _nodes:
            _nodes.remove(self)
        ws_ll_log.info("Shutdown Lavalink WS.")

    async def send(self, data):
        """Send ``data`` as JSON, or queue it while the websocket is unavailable."""
        if self._ws is None or self._ws.closed:
            self._queue.append(data)
        else:
            ws_ll_log.trace("Sending data to Lavalink node: %s", data)
            await self._ws.send_json(data)

    async def send_lavalink_voice_update(self, guild_id, session_id, event):
        """Forward a voice update for ``guild_id`` to the node."""
        await self.send(
            {
                "op": LavalinkOutgoingOp.VOICE_UPDATE.value,
                "guildId": str(guild_id),
                "sessionId": session_id,
                "event": event,
            }
        )

    async def destroy_guild(self, guild_id: int):
        """Send the destroy op for ``guild_id``."""
        await self.send({"op": LavalinkOutgoingOp.DESTROY.value, "guildId": str(guild_id)})

    async def no_event_stop(self, guild_id: int):
        """Send the stop op without dispatching a ``QUEUE_END`` event."""
        await self.send({"op": LavalinkOutgoingOp.STOP.value, "guildId": str(guild_id)})

    # Player commands
    async def stop(self, guild_id: int):
        """Send the stop op, then dispatch ``QUEUE_END`` to the event handler."""
        await self.no_event_stop(guild_id=guild_id)
        self.event_handler(
            LavalinkIncomingOp.EVENT, LavalinkEvents.QUEUE_END, {"guildId": str(guild_id)}
        )

    async def no_stop_play(
        self,
        guild_id: int,
        track: Track,
        replace: bool = True,
        start: int = 0,
        pause: bool = False,
    ):
        """Send the play op for ``track`` without stopping first."""
        await self.send(
            {
                "op": LavalinkOutgoingOp.PLAY.value,
                "guildId": str(guild_id),
                "track": track.track_identifier,
                "noReplace": not replace,
                "startTime": str(start),
                "pause": pause,
            }
        )

    async def play(
        self,
        guild_id: int,
        track: Track,
        replace: bool = True,
        start: int = 0,
        pause: bool = False,
    ):
        """Play ``track``; currently a thin alias of ``no_stop_play``."""
        await self.no_stop_play(
            guild_id=guild_id, track=track, replace=replace, start=start, pause=pause
        )

    async def pause(self, guild_id, paused):
        """Send the pause op with the desired ``paused`` flag."""
        await self.send(
            {"op": LavalinkOutgoingOp.PAUSE.value, "guildId": str(guild_id), "pause": paused}
        )

    async def volume(self, guild_id: int, _volume: int):
        """Send the volume op."""
        await self.send(
            {"op": LavalinkOutgoingOp.VOLUME.value, "guildId": str(guild_id), "volume": _volume}
        )

    async def seek(self, guild_id: int, position: int):
        """Send the seek op."""
        await self.send(
            {"op": LavalinkOutgoingOp.SEEK.value, "guildId": str(guild_id), "position": position}
        )
def get_node(guild_id: int = None, *, ignore_ready_status: bool = False) -> Node:
    """
    Gets a node based on a guild ID, useful for noding separation. If the
    guild ID does not already have a node association, the least used
    node is returned. Skips over nodes that are not yet ready.

    Parameters
    ----------
    guild_id : int
    ignore_ready_status : bool

    Returns
    -------
    Node

    Raises
    ------
    NodeNotFound
        If no eligible node exists.
    """
    fewest_guilds = 1e10
    fallback = None

    for candidate in _nodes:
        assigned = candidate.guild_ids

        # Nodes that are not ready are invisible unless explicitly requested.
        if ignore_ready_status is False and not candidate.ready:
            continue
        if len(assigned) < fewest_guilds:
            fewest_guilds, fallback = len(assigned), candidate

        # A node that already serves this guild wins immediately.
        if guild_id in assigned:
            return candidate

    if fallback is not None:
        return fallback
    raise NodeNotFound("No Lavalink nodes found.")


def get_nodes_stats():
    """Return the latest stats object (or ``None``) of every registered node."""
    return [registered.stats for registered in _nodes]


def get_all_nodes() -> List[Node]:
    """Return a new list containing every registered node."""
    return list(_nodes)


async def disconnect():
    """Disconnect every registered node, one after another."""
    # Iterate over a snapshot because disconnecting mutates the registry.
    for registered in list(_nodes):
        await registered.disconnect()