import re
from collections import namedtuple
from typing import Dict, Tuple, Union
from urllib.parse import quote, urlparse
import aiohttp
import discord
from aiohttp.client_exceptions import ServerDisconnectedError
from . import log
from .enums import ExceptionSeverity, LoadType, PlayerState
from .utils import VoiceChannel
# Names exported by ``from <this module> import *``.
# NOTE(review): ``LoadResult`` is defined in this module and returned by
# ``RESTClient.load_tracks`` but is not listed here - confirm that is intended.
__all__ = ("Track", "RESTClient", "PlaylistInfo")
_PlaylistInfo = namedtuple("PlaylistInfo", "name selectedTrack")


# A plain factory function applies the defaults, so dataclasses (and their
# __post_init__ hook) are not needed just for this bit of preprocessing.
def PlaylistInfo(name=None, selectedTrack=None):
    """
    Build a ``PlaylistInfo`` named tuple, replacing ``None`` with defaults.

    Parameters
    ----------
    name : Optional[str]
        Playlist name; ``None`` becomes ``"Unknown"``.
    selectedTrack : Optional[int]
        Selected track value; ``None`` becomes ``-1``.

    Returns
    -------
    PlaylistInfo
        Named tuple with the fields ``name`` and ``selectedTrack``.
    """
    if name is None:
        name = "Unknown"
    if selectedTrack is None:
        selectedTrack = -1
    return _PlaylistInfo(name, selectedTrack)
# Start-time markers as matched in a query string:
#   YouTube    -> "&t=90" / "?t=90s"   (one group: seconds)
#   SoundCloud -> "#t=1:30"            (two groups: minutes, seconds)
#   Twitch     -> "?t=1h2m3s"          (three groups: hours, minutes, seconds)
_re_youtube_timestamp = re.compile(r"[&?]t=(\d+)s?")
_re_soundcloud_timestamp = re.compile(r"#t=(\d+):(\d+)s?")
_re_twitch_timestamp = re.compile(r"\?t=(\d+)h(\d+)m(\d+)s")


def _find_timestamp_marker(query, query_url):
    """
    Locate a supported start-time marker in ``query``.

    Shared by ``parse_timestamps`` and ``reformat_query`` so both apply the
    exact same detection rules.

    Parameters
    ----------
    query : str
        The raw query (a URL or a ``ytsearch:`` / ``scsearch:`` search).
    query_url : urllib.parse.ParseResult
        The result of ``urlparse(query)``.

    Returns
    -------
    Optional[Tuple[Tuple[str, ...], Tuple[str, ...]]]
        ``(groups, separators)`` where ``groups`` are the matched digit
        strings ordered from the largest unit to seconds, and ``separators``
        are the substrings at which the marker starts in the query.
        ``None`` when no supported marker is present.
    """
    is_yt_search = "ytsearch:" in query
    is_sc_search = "scsearch:" in query
    is_full_url = all([query_url.scheme, query_url.netloc, query_url.path])
    if not (is_full_url or is_yt_search or is_sc_search):
        return None

    # Without a network location (search-prefixed queries) the first path
    # segment is used in place of the host name.
    host = query_url.netloc if query_url.netloc else query_url.path.split("/")[0]
    url_domain = ".".join(host.split(".")[-2:])

    if (
        (url_domain in ("youtube.com", "youtu.be") or is_yt_search)
        and ("&t=" in query or "?t=" in query)
        and not ("playlist?" in query and "&list=" in query)
    ):
        match = _re_youtube_timestamp.search(query)
        if match:
            return match.groups(), ("&t=", "?t=")
    elif (url_domain == "soundcloud.com" or is_sc_search) and "#t=" in query:
        # Sets are skipped unless the link carries "?in=".
        if "/sets/" not in query or "?in=" in query:
            match = _re_soundcloud_timestamp.search(query)
            if match:
                return match.groups(), ("#t=",)
    elif url_domain == "twitch.tv" and "?t=" in query:
        match = _re_twitch_timestamp.search(query)
        if match:
            return match.groups(), ("?t=",)
    return None


def parse_timestamps(data):
    """
    Attach the start time requested in the query to every loaded track.

    Parameters
    ----------
    data : dict
        Load result payload. Reads ``data["loadType"]``, ``data["query"]``
        and ``data["tracks"]``; each track must have an ``"info"`` dict.

    Returns
    -------
    list
        ``data["tracks"]`` untouched for playlist results or when the query
        cannot be parsed as a URL; otherwise a new list of the same track
        dicts, each with ``track["info"]["timestamp"]`` set to the start time
        in milliseconds (``0`` when the query has no supported marker).
    """
    if data["loadType"] == LoadType.PLAYLIST_LOADED:
        return data["tracks"]

    query = data["query"]
    try:
        query_url = urlparse(query)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt/SystemExit must propagate.
        query_url = None
    if not query_url:
        return data["tracks"]

    # The start time depends only on the query, so work it out once instead
    # of once per track.
    start_seconds = 0
    try:
        found = _find_timestamp_marker(query, query_url)
        if found is not None:
            groups, _separators = found
            total = 0
            # Groups run from the largest unit down to seconds and each unit
            # is 60x the next, so a base-60 fold yields the total seconds.
            for part in groups:
                total = total * 60 + int(part)
            start_seconds = total
    except Exception:
        # Deliberately best-effort: a malformed query must never prevent the
        # tracks from loading, it just means "start from the beginning".
        start_seconds = 0

    start_ms = start_seconds * 1000
    new_tracks = []
    for track in data["tracks"]:
        track["info"]["timestamp"] = start_ms
        new_tracks.append(track)
    return new_tracks


def reformat_query(query):
    """
    Strip a supported start-time marker from ``query``.

    Parameters
    ----------
    query : str
        The raw query (a URL or a ``ytsearch:`` / ``scsearch:`` search).

    Returns
    -------
    str
        The query truncated at the start-time marker, or the query unchanged
        when no supported marker is found or the query cannot be analysed.

    Notes
    -----
    Everything from the marker onwards is dropped, including any parameters
    that follow it.
    NOTE(review): a link such as ``watch?t=30&v=ID`` therefore loses its
    ``v`` parameter - confirm whether that case needs handling.
    """
    try:
        found = _find_timestamp_marker(query, urlparse(query))
        if found is not None:
            _groups, separators = found
            for separator in separators:
                query = query.split(separator)[0]
    except Exception:
        # Deliberately best-effort: fall back to the query as it was given.
        pass
    return query
class Track:
    """
    Information about a Lavalink track.

    Attributes
    ----------
    requester : discord.User
        The user who requested the track.
    track_identifier : str
        Track identifier used by the Lavalink player to play tracks.
    seekable : bool
        Boolean determining if seeking can be done on this track.
    author : str
        The author of this track.
    length : int
        The length of this track in milliseconds.
    is_stream : bool
        Determines whether Lavalink will stream this track.
    position : int
        Current seeked position to begin playback.
    title : str
        Title of this track.
    uri : str
        The playback url of this track.
    start_timestamp: int
        The track start time in milliseconds as provided by the query.
    extras : dict
        The ``"extras"`` entry of the track payload (empty when absent).
    """

    def __init__(self, data):
        """
        Parameters
        ----------
        data : dict
            Track payload; reads the ``"track"``, ``"info"`` and ``"extras"``
            keys, all of which are optional.
        """
        self.requester = None
        self.track_identifier = data.get("track")
        self._info = data.get("info", {})
        self.seekable = self._info.get("isSeekable", False)
        self.author = self._info.get("author")
        self.length = self._info.get("length", 0)
        self.is_stream = self._info.get("isStream", False)
        self.position = self._info.get("position")
        self.title = self._info.get("title")
        self.uri = self._info.get("uri")
        self.start_timestamp = self._info.get("timestamp", 0)
        self.extras = data.get("extras", {})

    @property
    def thumbnail(self):
        """Optional[str]: Returns a thumbnail URL for YouTube tracks."""
        # ``uri`` is None when the payload carries no "uri"; guard it so the
        # substring test cannot raise TypeError.
        if self.uri and "youtube" in self.uri and "identifier" in self._info:
            return "https://img.youtube.com/vi/{}/mqdefault.jpg".format(self._info["identifier"])
        return None

    def __eq__(self, other):
        """Tracks are equal when their track identifiers are equal."""
        if isinstance(other, Track):
            return self.track_identifier == other.track_identifier
        return NotImplemented

    def __ne__(self, other):
        """Inverse of ``__eq__``; defers to the other operand for non-tracks."""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
        """Hash on the same field that ``__eq__`` compares."""
        # Hashing only ``track_identifier`` keeps equal tracks in the same
        # hash bucket and avoids sorting a list that may mix None with str,
        # which raises TypeError.
        return hash(self.track_identifier)

    def __repr__(self):
        return (
            "<Track: "
            f"track_identifier={self.track_identifier!r}, "
            f"author={self.author!r}, "
            f"length={self.length}, "
            f"is_stream={self.is_stream}, uri={self.uri!r}, title={self.title!r}>"
        )
class LoadResult:
    """
    The result of a load_tracks request.

    Attributes
    ----------
    load_type : LoadType
        The result of the loadtracks request
    is_playlist : Optional[bool]
        ``True``/``False`` when the playlist flag resolves to a real boolean,
        otherwise ``None``
    playlist_info : Optional[PlaylistInfo]
        The playlist information detected by Lavalink
    tracks : Tuple[Track, ...]
        The tracks that were loaded, if any
    """

    def __init__(self, data):
        """
        Parameters
        ----------
        data : dict
            Raw response payload. It is kept as-is (not copied) and any
            missing ``loadType``, ``exception``, ``playlistInfo`` or
            ``tracks`` entry is filled in on it.
        """
        self._raw = data

        # Fill the gaps in the same order the keys are consumed below.
        data.setdefault("loadType", LoadType.LOAD_FAILED)

        needs_exception = "exception" not in data
        if needs_exception and data["loadType"] != LoadType.LOAD_FAILED:
            # Only failed loads get a synthesized exception entry.
            needs_exception = False
        if needs_exception:
            default_message = "Lavalink API returned an unsupported response, Please report it."
            data["exception"] = {
                "message": (
                    f"Timestamp: {data.get('timestamp', 'Unknown')}\n"
                    f"Status Code: {data.get('status', 'Unknown')}\n"
                    f"Error: {data.get('error', 'Unknown')}\n"
                    f"Query: {data.get('query', 'Unknown')}\n"
                    f"Load Type: {data['loadType']}\n"
                    f"Message: {data.get('message', default_message)}"
                ),
                "severity": ExceptionSeverity.SUSPICIOUS,
            }

        data.setdefault("playlistInfo", {})
        data.setdefault("tracks", [])

        self.load_type = LoadType(data["loadType"])

        playlist_flag = data.get("isPlaylist") or self.load_type == LoadType.PLAYLIST_LOADED
        # Anything that is not a genuine boolean is reported as "unknown".
        self.is_playlist = playlist_flag if isinstance(playlist_flag, bool) else None
        if playlist_flag is True:
            self.playlist_info = PlaylistInfo(**data["playlistInfo"])
        else:
            self.playlist_info = None

        if data.get("query"):
            raw_tracks = parse_timestamps(data)
        else:
            raw_tracks = data["tracks"]
        self.tracks = tuple(map(Track, raw_tracks))

    @property
    def has_error(self):
        """bool: Whether the load type is ``LoadType.LOAD_FAILED``."""
        return self.load_type == LoadType.LOAD_FAILED

    @property
    def exception_message(self) -> Union[str, None]:
        """
        On Lavalink V3, if there was an exception during a load or get tracks call
        this property will be populated with the error message.
        If there was no error this property will be ``None``.
        """
        if not self.has_error:
            return None
        return self._raw.get("exception", {}).get("message")

    @property
    def exception_severity(self) -> Union[ExceptionSeverity, None]:
        """
        The severity of the load exception, or ``None`` when the load did not
        fail or the exception entry carries no severity.
        """
        if not self.has_error:
            return None
        severity = self._raw.get("exception", {}).get("severity")
        if severity is None:
            return None
        return ExceptionSeverity(severity)
class RESTClient:
    """
    Client class used to access the REST endpoints on a Lavalink node.
    """

    def __init__(self, client: discord.Client, channel: VoiceChannel):
        """
        Parameters
        ----------
        client : discord.Client
            The Discord client this REST client belongs to.
        channel : VoiceChannel
            The voice channel; its guild and id are recorded as well.
        """
        # Imported here rather than at module level, as in the original code.
        from lavalink.node import get_node

        self.node = get_node()
        self.client = client
        self.state: PlayerState = PlayerState.CREATED
        self.channel: discord.VoiceChannel = channel
        self.guild: discord.Guild = channel.guild
        self._last_channel_id = channel.id
        self.secured: bool = self.node.secured
        self._session: aiohttp.ClientSession = self.node.session
        if self.secured:
            protocol = "https"
        else:
            protocol = "http"
        # Base of every loadtracks request; the quoted query is appended to it.
        self._uri: str = f"{protocol}://{self.node.host}:{self.node.port}/loadtracks?identifier="
        self._headers: Dict[str, str] = {"Authorization": self.node.password}
        # Ensures the get_tracks() deprecation warning is logged only once.
        self._warned: bool = False

    def __check_node_ready(self):
        """Raise ``RuntimeError`` unless the state is ``PlayerState.READY``."""
        if self.state != PlayerState.READY:
            raise RuntimeError("Cannot execute REST request when node not ready.")

    async def _get(self, url):
        """
        GET ``url`` from the node and return the decoded JSON body.

        If the server disconnects while the player is disconnecting, a
        ``LOAD_FAILED`` payload is returned instead of raising; in any other
        state the ``ServerDisconnectedError`` is logged and re-raised.
        """
        try:
            async with self._session.get(url, headers=self._headers) as resp:
                # content_type=None: decode the body whatever Content-Type
                # header the node sends.
                data = await resp.json(content_type=None)
        except ServerDisconnectedError:
            if self.state == PlayerState.DISCONNECTING:
                return {
                    "loadType": LoadType.LOAD_FAILED,
                    "exception": {
                        "message": "Load tracks interrupted by player disconnect.",
                        "severity": ExceptionSeverity.COMMON,
                    },
                    "tracks": [],
                }
            log.debug("Received server disconnected error when player state = %s", self.state.name)
            raise
        return data

    async def load_tracks(self, query) -> LoadResult:
        """
        Executes a loadtracks request. Only works on Lavalink V3.

        Parameters
        ----------
        query : str

        Returns
        -------
        LoadResult

        Raises
        ------
        RuntimeError
            If the state is not ``PlayerState.READY``.
        """
        self.__check_node_ready()
        _raw_url = str(query)
        parsed_url = reformat_query(_raw_url)
        url = self._uri + quote(parsed_url)
        data = await self._get(url)
        if isinstance(data, list):
            # A bare list of tracks is wrapped so that it looks like a
            # regular response payload.
            data = {"loadType": LoadType.V2_COMPAT, "tracks": data}
        elif not isinstance(data, dict):
            # Unexpected body (e.g. JSON ``null`` or a scalar): hand an empty
            # payload to LoadResult, which fills in a LOAD_FAILED result,
            # instead of implicitly returning None.
            data = {}
        data["query"] = _raw_url
        data["encodedquery"] = url
        return LoadResult(data)

    async def get_tracks(self, query) -> Tuple[Track, ...]:
        """
        Gets tracks from lavalink.

        Deprecated in favour of ``load_tracks``; a warning is logged the
        first time this is called on an instance.

        Parameters
        ----------
        query : str

        Returns
        -------
        Tuple[Track, ...]
        """
        if not self._warned:
            # ``warning`` replaces the deprecated ``warn`` alias.
            log.warning("get_tracks() is now deprecated. Please switch to using load_tracks().")
            self._warned = True
        result = await self.load_tracks(query)
        return result.tracks

    async def search_yt(self, query) -> LoadResult:
        """
        Gets track results from YouTube from Lavalink.

        Parameters
        ----------
        query : str
            Search terms; sent to ``load_tracks`` prefixed with ``ytsearch:``.

        Returns
        -------
        LoadResult
        """
        return await self.load_tracks("ytsearch:{}".format(query))

    async def search_sc(self, query) -> LoadResult:
        """
        Gets track results from SoundCloud from Lavalink.

        Parameters
        ----------
        query : str
            Search terms; sent to ``load_tracks`` prefixed with ``scsearch:``.

        Returns
        -------
        LoadResult
        """
        return await self.load_tracks("scsearch:{}".format(query))