99 lines
3.5 KiB
Python
99 lines
3.5 KiB
Python
|
import asyncio
|
||
|
import re
|
||
|
import typing
|
||
|
|
||
|
|
||
|
# Server lines look like "<nickname>: <content>". The nickname is a run of
# word characters and whitespace ending at a colon; whitespace right after
# the colon is consumed by the separator, and the rest becomes the content.
server_message_regex = re.compile(
    r"^(?P<nickname>[\w\s]+)"  # word characters / spaces before the colon
    r":\s*"                    # the colon plus any padding after it
    r"(?P<content>.*)$"        # remainder of the line
)
|
||
|
|
||
|
|
||
|
class MoonchatMessage(typing.NamedTuple):
    """A chat line split into its sender and its text."""

    # Sender name: the part of the line before the colon.
    nickname: str
    # Message body: the part of the line after the colon separator.
    content: str
|
||
|
|
||
|
|
||
|
class MessageDecodeError(ValueError):
    """Signals a server line that cannot be decoded into a message.

    It is a ``ValueError`` subclass; ``Moonchat.messages(ignore_invalid=True)``
    suppresses it instead of propagating it.
    """
|
||
|
|
||
|
|
||
|
class Moonchat:
    """Line-oriented async chat client over an asyncio reader/writer pair.

    Outgoing messages are encoded with ``encoding`` and newline-terminated.
    Incoming lines are expected to look like ``nickname: content`` (see
    ``server_message_regex``) and are parsed into ``MoonchatMessage``.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        encoding: str = 'ascii',
    ):
        self.reader = reader
        self.writer = writer
        self.encoding = encoding
        # Set once by close(); every send/receive method checks it first.
        self.closed = False

    def close(self):
        """Close the connection; calling it again is a no-op.

        If the transport supports it, EOF is written before closing. This
        does not wait for the transport to finish closing.
        """
        if self.closed:
            return
        self.closed = True
        if not self.writer.is_closing():
            if self.writer.can_write_eof():
                self.writer.write_eof()
            self.writer.close()

    @classmethod
    async def connect(cls, ip: str, port: int, encoding='ascii', **kwargs):
        """Open a connection and wrap it in a client.

        ``ip``, ``port`` and any extra keyword arguments are passed to
        ``asyncio.open_connection``. A falsy ``encoding`` (e.g. ``None``)
        falls back to the constructor's default instead of failing.
        """
        reader, writer = await asyncio.open_connection(ip, port, **kwargs)
        if encoding:
            return cls(reader, writer, encoding=encoding)
        return cls(reader, writer)

    def encode_message(self, message: str) -> bytes:
        """Encode ``message``, appending ``'\\n'`` unless it already ends with one.

        Raises:
            UnicodeEncodeError: if ``message`` is not representable in
                ``self.encoding``.
        """
        return (message.removesuffix('\n') + '\n').encode(self.encoding)

    def decode_message(self, data: bytes) -> MoonchatMessage:
        """Parse one raw server line into a ``MoonchatMessage``.

        Surrounding whitespace (including the trailing newline) is stripped
        before matching against ``server_message_regex``.

        Raises:
            MessageDecodeError: if ``data`` is not valid in ``self.encoding``
                or does not have the ``nickname: content`` shape. It
                subclasses ``ValueError``, so ``ValueError`` handlers still
                catch it.
        """
        try:
            unparsed = data.decode(self.encoding).strip()
        except UnicodeDecodeError as err:
            raise MessageDecodeError(f"cannot decode message: {data!r}") from err
        regex_match = server_message_regex.match(unparsed)
        if regex_match is None:
            # Must be MessageDecodeError: messages() relies on this type to
            # honour ignore_invalid.
            raise MessageDecodeError("cannot decode malformed message: " + unparsed)
        return MoonchatMessage(**regex_match.groupdict())

    async def send_message(self, message: str) -> bool:
        """Encode ``message`` and send it. Return whether it was written."""
        return await self.send_message_raw(self.encode_message(message))

    async def send_message_raw(self, message: bytes | bytearray | memoryview) -> bool:
        """Write ``message`` unchanged and wait for the writer to drain.

        Returns False in three cases:
          - the client is already closed;
          - the transport is closing;
          - the connection fails while writing.

        In the last two cases the client is also closed.
        """
        if self.closed:
            return False
        if self.writer.is_closing():
            self.close()
            return False
        try:
            self.writer.write(message)
            await self.writer.drain()
        except ConnectionError:
            # Peer reset / broken pipe: report failure through the documented
            # bool result rather than an exception.
            self.close()
            return False
        return True

    async def recieve_message_raw(self) -> bytes | None:
        """Read the next newline-terminated line from the server.

        Returns the line including its trailing newline, or None once no
        complete line can be read. A trailing partial line at EOF is
        discarded. A connection error is treated like EOF. In both cases
        the client is closed.
        """
        if self.closed:
            return None
        try:
            line = await self.reader.readline()
        except ConnectionError:
            line = b''
        # readline() only returns data without a newline at EOF.
        if not line.endswith(b'\n'):
            self.close()
            return None
        return line

    async def recieve_message(self) -> MoonchatMessage | None:
        """Read and parse the next message, or return None once the stream ends.

        Raises:
            MessageDecodeError: if the received line cannot be decoded/parsed.
        """
        raw_message = await self.recieve_message_raw()
        if raw_message is None:
            return None
        return self.decode_message(raw_message)

    async def raw_messages(self):
        """Yield raw (still encoded) lines until the connection is closed."""
        while not self.closed:
            message = await self.recieve_message_raw()
            if message is not None:
                yield message

    async def messages(self, ignore_invalid=False):
        """Yield parsed messages until the connection is closed.

        Args:
            ignore_invalid: when True, lines that fail to decode are skipped.
                Otherwise the ``MessageDecodeError`` propagates.
        """
        while not self.closed:
            try:
                message = await self.recieve_message()
            except MessageDecodeError:
                if ignore_invalid:
                    # Skip the bad line entirely so nothing stale is yielded.
                    continue
                raise
            if message is not None:
                yield message