#!/usr/bin/env python3
import os
import io
import ssl
import json
import time
import queue
import struct
import pprint
import logging
import datetime
import threading
import traceback

import PIL
import PIL.Image
import flask

from websockets.sync.server import serve as websocket_serve
from websockets.exceptions import ConnectionClosed

from nano_llm.utils import ArgParser
class WebServer():
    """
    HTTP/HTTPS Flask webserver with a separate websocket server for messaging with a web client.

    Websocket messages in both directions start with a 32-byte binary header (see
    ``send_message()`` and ``websocket_listener()``) followed by the payload.  Only the most
    recently connected websocket client is tracked (``self.websocket``).
    """
    MESSAGE_JSON = 0    #: JSON websocket message (dict)
    MESSAGE_TEXT = 1    #: Text websocket message (str)
    MESSAGE_BINARY = 2  #: Binary websocket message (bytes)
    MESSAGE_FILE = 3    #: File upload from client (bytes)
    MESSAGE_AUDIO = 4   #: Audio samples (bytes, int16)
    MESSAGE_IMAGE = 5   #: Image message (PIL.Image)

    Instance = None       #: Singleton instance (set to the most recently constructed server)
    MessageHandlers = []  #: Message handlers (class-level, so shared by every instance)

    def __init__(self, web_host='0.0.0.0', web_port=8050, ws_port=49000,
                 ssl_cert=None, ssl_key=None, root=None, index='index.html',
                 mounts={'/tmp/uploads': '/uploads'}, msg_callback=None, web_trace=False,
                 **kwargs):
        """
        Create HTTP/HTTPS Flask webserver with websocket messaging.

        Use this by either creating an instance and providing ``msg_callback``,
        or inherit from it and implement ``on_message()`` in a subclass.
        You can also add Flask routes to ``WebServer.app`` before ``start()`` is called.

        Args:
          web_host (str): network interface to bind to (0.0.0.0 for all)
          web_port (int): port to serve HTTP/HTTPS webpages on
          ws_port (int): port to use for websocket communication
          ssl_cert (str): path to PEM-encoded SSL/TLS cert file for enabling HTTPS
          ssl_key (str): path to PEM-encoded SSL/TLS cert key for enabling HTTPS
          root (str): the root directory for serving site files (should have static/ and templates/).
                      Defaults to ``../web`` relative to this source file.
          index (str): the name of the site's index page (should be under ``{root}/templates``)
          mounts (dict): server directories to expose over HTTP, mapping filesystem path -> URL prefix.
                         Paths under /tmp are created if missing.  A mount whose path or URL contains
                         'upload' becomes the directory that files uploaded from the client are saved
                         to (the last such mount wins; uploads aren't saved if there is none).
          msg_callback (callable): websocket message handler (see WebServer.on_message() for signature)
          web_trace (bool): if true, additional debug messages will be printed when --log-level=debug

        The kwargs (plus ``ws_port``) are passed as variables to the Jinja render_template()
        used in the index file.  The worker threads are created here but only started by ``start()``.
        """
        WebServer.Instance = self

        self.host = web_host
        self.port = web_port
        self.root = root
        self.trace = web_trace
        self.index = index
        self.kwargs = kwargs
        self.mounts = mounts  # only iterated below (never mutated), so the default dict is not modified
        self.upload_dir = None
        self.alert_count = 0
        self.websocket = None

        if not self.root:
            self.root = os.path.abspath(os.path.join(os.path.dirname(__file__), '../web'))

        self.msg_count_rx = 0
        self.msg_count_tx = 0

        self.add_message_handler(msg_callback)

        # flask server
        self.app = flask.Flask(__name__,
            static_folder=os.path.join(self.root, 'static'),
            template_folder=os.path.join(self.root, 'templates'),
        )

        # NOTE(review): per the Flask docs, USE_X_SENDFILE sets an X-Sendfile header instead of
        # serving the file data from Flask, and only makes sense behind a front-end server that
        # recognizes it -- confirm how files are actually delivered in this deployment.
        self.app.use_x_sendfile = True

        # setup default index route
        self.app.add_url_rule('/', view_func=self.send_index, methods=['GET'])

        # setup mounted paths (server filesystem path -> URL prefix)
        for path, mount in self.mounts.items():
            if path.startswith('/tmp'):
                os.makedirs(path, exist_ok=True)

            if 'upload' in path or 'upload' in mount:
                self.upload_dir = path

            logging.info(f"mounting webserver path {path} to {mount}")
            self.app.add_url_rule(f"{mount}/<path:path>", view_func=SendFromDirectory(path).send, endpoint=path, methods=['GET'])

        logging.debug(f"webserver root directory: {self.root} upload directory: {self.upload_dir}")

        # SSL / HTTPS (the same context is used for both the webserver and the websocket)
        self.ssl_key = ssl_key
        self.ssl_cert = ssl_cert
        self.ssl_context = None
        self.web_protocol = "http"

        if self.ssl_cert and self.ssl_key:
            self.web_protocol = "https"
            self.ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
            self.ssl_context.load_cert_chain(certfile=self.ssl_cert, keyfile=self.ssl_key)

        # websocket
        self.ws_port = ws_port
        self.kwargs['ws_port'] = ws_port  # made available to the index template (see send_index)

        # max_size=None removes the websocket library's limit on incoming message size
        self.ws_server = websocket_serve(self.on_websocket, host=self.host, port=self.ws_port, ssl_context=self.ssl_context, max_size=None)
        self.ws_thread = threading.Thread(target=self.ws_server.serve_forever, daemon=True)

        # NOTE(review): debug=True enables the Werkzeug debugger, and web_host defaults to all
        # interfaces; Flask advises against running the debugger outside of development.
        self.web_thread = threading.Thread(
            target=lambda: self.app.run(host=self.host, port=self.port, ssl_context=self.ssl_context, debug=True, use_reloader=False),
            daemon=True,
        )

        # https://stackoverflow.com/a/52282788
        logging.getLogger('asyncio').setLevel(logging.INFO)
        logging.getLogger('asyncio.coroutines').setLevel(logging.INFO)
        logging.getLogger('websockets.server').setLevel(logging.INFO)
        logging.getLogger('websockets.protocol').setLevel(logging.INFO)

    def start(self):
        """
        Call this to start the webserver listening for new connections.
        It will start the websocket and HTTP worker threads (daemon threads) and then return control to the user.
        """
        logging.info(f"starting webserver @ {self.web_protocol}://{self.host}:{self.port}")
        self.ws_thread.start()
        self.web_thread.start()

    @property
    def connected(self):
        """
        Returns true if the server is connected to any clients, otherwise false.
        """
        return self.num_clients > 0

    @property
    def num_clients(self):
        """
        Returns the number of actively connected clients (0 or 1, since only the latest connection is tracked).
        """
        return 0 if self.websocket is None else 1

    @classmethod
    def add_listener(cls, callback):
        """
        Register a message handler that will be called when new websocket messages are received.
        Alias of ``add_message_handler()``.
        """
        cls.add_message_handler(callback)

    @classmethod
    def add_message_handler(cls, callback):
        """
        Register a message handler that will be called when new websocket messages are received.

        Args:
          callback (callable|list|None): a handler, or a list of handlers, with the same signature
                                         as ``on_message()``.  None is ignored.
        """
        if callback is None:
            return

        if not isinstance(callback, list):
            callback = [callback]

        # extend the shared list in place (rather than rebinding it on a subclass)
        cls.MessageHandlers.extend(callback)

    def on_message(self, payload, payload_size=None, msg_type=MESSAGE_JSON, msg_id=None, metadata=None, timestamp=None, path=None, **kwargs):
        """
        Handler for received websocket messages. Implement this in a subclass to process messages,
        otherwise ``msg_callback`` needs to be provided during initialization.

        The default implementation forwards the message to every registered message handler;
        an exception raised by one handler is logged and doesn't stop the others.

        Args:
          payload (dict|str|bytes|PIL.Image): If this is a JSON message, will be a dict.
                               If this is a text message, will be a string.
                               If this is an image message, will be a PIL.Image (or bytes if decoding failed).
                               Otherwise, will be a bytes array.
          payload_size (int): size of the payload (in bytes), as reported by the message header
          msg_type (int): MESSAGE_JSON (0), MESSAGE_TEXT (1), MESSAGE_BINARY (2),
                          MESSAGE_FILE (3), MESSAGE_AUDIO (4), MESSAGE_IMAGE (5)
          msg_id (int): the monotonically-increasing message ID number
          metadata (str): message-specific string or other data (for uploads, the file extension)
          timestamp (int): time that the message was sent (milliseconds since the Unix epoch)
          path (str): if this is a file or image upload, the file path on the server

        Raises:
          NotImplementedError: if no message handlers have been registered
        """
        if not self.MessageHandlers:
            raise NotImplementedError(f"{type(self)} did not implement on_message or have a msg_callback provided")

        for callback in self.MessageHandlers:
            try:
                callback(payload, payload_size=payload_size, msg_type=msg_type, msg_id=msg_id,
                         metadata=metadata, timestamp=timestamp, path=path, **kwargs)
            except Exception:
                logging.error(f"Exception occurred handling websocket message:\n\n{pprint.pformat(payload, indent=2) if msg_type==WebServer.MESSAGE_JSON else ''}\n{traceback.format_exc()}")

    def send_message(self, payload, type=None, timestamp=None):
        """
        Send a websocket message to the client.

        Args:
          payload (dict|str|bytes): message contents.  When sent as JSON, anything that isn't
                                    already a str is serialized with ``json.dumps()``.
          type (int): one of the MESSAGE_* types.  If None, it's inferred from the payload:
                      str -> MESSAGE_TEXT, bytes -> MESSAGE_BINARY, anything else -> MESSAGE_JSON.
          timestamp (float): milliseconds since the Unix epoch (defaults to the current time)

        The message is dropped (with a debug log) when no client is connected, and send
        failures are logged as warnings instead of being raised.
        """
        if timestamp is None:
            timestamp = time.time() * 1000

        if type is None:
            if isinstance(payload, str):
                type = WebServer.MESSAGE_TEXT
            elif isinstance(payload, bytes):
                type = WebServer.MESSAGE_BINARY
            else:
                type = WebServer.MESSAGE_JSON

        # read once, so a concurrent disconnect can't reset it to None between the check and send()
        websocket = self.websocket

        if websocket is None:
            logging.debug(f"send_message() - no websocket clients connected, dropping {self.msg_type_str(type)} message")
            return

        data = payload

        if type == WebServer.MESSAGE_JSON and not isinstance(data, str):  # json.dumps() might have already been called
            data = json.dumps(data)

        if isinstance(data, str):
            # json.dumps() output is pure ASCII (ensure_ascii defaults to True), so UTF-8 produces
            # the same bytes for it; encoding every str here also handles an explicit `type`
            data = data.encode('utf-8')
        elif not isinstance(data, bytes):
            data = bytes(data)

        if self.trace and logging.getLogger().isEnabledFor(logging.DEBUG):
            msg_text = '\n' + pprint.pformat(payload) if type <= WebServer.MESSAGE_TEXT else ''
            logging.debug(f"sending {WebServer.msg_type_str(type)} websocket message (type={type} size={len(data)}){msg_text}")

        try:
            # 32-byte message header format (network byte order):
            #
            #   0  uint64  message_id    (message_count_tx)
            #   8  uint64  timestamp     (milliseconds since Unix epoch)
            #  16  uint16  magic_number  (42)
            #  18  uint16  message_type  (0=json, 1=text, >=2 binary)
            #  20  uint32  payload_size  (in bytes)
            #  24  uint32  unused        (padding)
            #  28  uint32  unused        (padding)
            header = struct.pack('!QQHHIII',
                self.msg_count_tx,
                int(timestamp),
                42, type,
                len(data),
                0, 0,
            )
            websocket.send(header + data)
            self.msg_count_tx += 1
        except Exception as err:
            logging.warning(f"failed to send websocket message to client ({err})")

    def send_alert(self, message, level='warning', category='', timeout=3.5):
        """
        Send an alert to the client (as a JSON ``{'alert': ...}`` message) and mirror it to the log.

        Args:
          message (str): the alert text
          level (str): 'error' and 'warning' are logged at those levels; anything else is logged as info
          category (str): passed through to the client in the alert
          timeout (float): passed to the client in milliseconds (``int(timeout*1000)``)

        Returns:
          dict: the alert that was sent
        """
        alert = {
            'id': self.alert_count,
            # 12-hour clock without a leading zero (portable form of the glibc-only '%-I' directive)
            'time': datetime.datetime.now().strftime('%I:%M:%S').lstrip('0'),
            'message': message,
            'level': level,
            'category': category,
            'timeout': int(timeout*1000),
        }

        self.send_message({'alert': alert})
        self.alert_count += 1

        if level == 'error':
            logging.error(message)
        elif level == 'warning':
            logging.warning(message)
        else:
            logging.info(message)

        return alert

    def on_websocket(self, websocket):
        """
        Connection handler for the websocket server, called for each new client connection.

        Makes this the active client, notifies the message handlers with a
        ``{'client_state': 'connected'}`` JSON message, then runs the receive loop until the
        connection ends.
        """
        self.websocket = websocket  # TODO handle multiple clients
        remote_address = websocket.remote_address

        logging.info(f"new websocket connection from {remote_address}")

        for callback in self.MessageHandlers:
            try:
                callback({'client_state': 'connected'}, msg_type=WebServer.MESSAGE_JSON, timestamp=int(time.time()*1000))
            except Exception:
                logging.error(f"Exception occurred handling client_state 'connected' message\n{traceback.format_exc()}")

        try:
            self.websocket_listener(websocket)
        except ConnectionClosed:
            logging.info(f"websocket connection with {remote_address} was closed")
        finally:
            # release the connection even if the listener died from some other exception, otherwise
            # `connected` keeps reporting a dead client.  If the client refreshed, a newer connection
            # may already have replaced this one, so only clear it if it's still ours.
            if self.websocket is websocket:
                self.websocket = None

    def websocket_listener(self, websocket):
        """
        Receive loop for a client connection.  It runs until ``websocket.recv()`` raises.

        Each binary message has the same 32-byte header as ``send_message()`` uses, except that
        bytes 24-32 carry a NUL-terminated metadata string.  The payload is decoded according to
        its type, file/image uploads are saved to the upload directory, and the result is passed
        to ``on_message()``.  Invalid or undecodable messages are logged and dropped without
        closing the connection.
        """
        logging.info(f"listening on websocket connection from {websocket.remote_address}")

        header_size = 32

        while True:
            msg = websocket.recv()

            if isinstance(msg, str):
                logging.warning(f'dropping text-mode websocket message from {websocket.remote_address} "{msg}"')
                continue

            if len(msg) <= header_size:
                logging.warning(f"dropping invalid websocket message from {websocket.remote_address} (size={len(msg)})")
                continue

            msg_id, timestamp, magic_number, msg_type, payload_size = \
                struct.unpack_from('!QQHHI', msg)

            if magic_number != 42:
                logging.warning(f"dropping invalid websocket message from {websocket.remote_address} (magic_number={magic_number} size={len(msg)})")
                continue

            if msg_id != self.msg_count_rx:
                logging.debug(f"recieved websocket message from {websocket.remote_address} with out-of-order ID {msg_id} (last={self.msg_count_rx})")
                self.msg_count_rx = msg_id

            self.msg_count_rx += 1

            msg_payload_size = len(msg) - header_size

            if payload_size != msg_payload_size:
                logging.warning(f"recieved invalid websocket message from {websocket.remote_address} (payload_size={payload_size} actual={msg_payload_size})")

            # decode the metadata and payload; a malformed message shouldn't end the connection
            try:
                metadata = msg[24:32].split(b'\x00')[0].decode()
                payload = msg[header_size:]

                if msg_type == WebServer.MESSAGE_JSON:
                    payload = json.loads(payload)
                elif msg_type == WebServer.MESSAGE_TEXT:
                    payload = payload.decode('utf-8')
            except ValueError as err:  # includes UnicodeDecodeError and json.JSONDecodeError
                logging.warning(f"dropping malformed websocket message from {websocket.remote_address} (type={msg_type} size={len(msg)}): {err}")
                continue

            if self.trace and msg_type != WebServer.MESSAGE_AUDIO and logging.getLogger().isEnabledFor(logging.DEBUG):
                msg_text = '\n' + pprint.pformat(payload) if msg_type <= WebServer.MESSAGE_TEXT else ''
                logging.debug(f"recieved {WebServer.msg_type_str(msg_type)} websocket message from {websocket.remote_address} (type={msg_type} size={payload_size}){msg_text}")

            # save uploaded files/images to the upload dir
            filename = None

            if self.upload_dir and metadata and msg_type in (WebServer.MESSAGE_FILE, WebServer.MESSAGE_IMAGE):
                # the extension comes from the client, so strip path separators to keep the file inside upload_dir
                extension = metadata.replace('/', '_').replace('\\', '_')
                time_str = datetime.datetime.fromtimestamp(timestamp/1000, tz=datetime.timezone.utc).strftime('%Y%m%d_%H%M%S')
                filename = os.path.join(self.upload_dir, f"{time_str}.{extension}")
                # NOTE(review): the file is written on a background thread, so it may not exist yet
                # when on_message() is given this path below
                threading.Thread(target=self.save_upload, args=[payload, filename]).start()

            # decode images in-memory
            if msg_type == WebServer.MESSAGE_IMAGE:
                try:
                    payload = PIL.Image.open(io.BytesIO(payload))
                    if filename:
                        payload.filename = filename
                except Exception as err:
                    logging.error(f"failed to load invalid/corrupted {metadata} image uploaded from client ({err})")

            self.on_message(payload, payload_size=payload_size, msg_type=msg_type, msg_id=msg_id, metadata=metadata, timestamp=timestamp, path=filename)

    def save_upload(self, payload, path):
        """
        Write the raw bytes of a client upload to ``path`` (overwriting any existing file).
        """
        logging.debug(f"saving client upload to {path}")

        with open(path, 'wb') as file:
            file.write(payload)

    def send_index(self):
        """
        Flask view for '/': renders the index template with the constructor's extra kwargs (and ws_port).
        """
        return flask.render_template(self.index, **self.kwargs)

    @staticmethod
    def msg_type_str(type):
        """
        Return the lowercase name of a MESSAGE_* type (e.g. 'json'), or raise ValueError if unknown.
        """
        if type == WebServer.MESSAGE_JSON:
            return "json"
        elif type == WebServer.MESSAGE_TEXT:
            return "text"
        elif type == WebServer.MESSAGE_BINARY:
            return "binary"
        elif type == WebServer.MESSAGE_FILE:
            return "file"
        elif type == WebServer.MESSAGE_AUDIO:
            return "audio"
        elif type == WebServer.MESSAGE_IMAGE:
            return "image"
        else:
            raise ValueError(f"unknown message type {type}")
class SendFromDirectory():
    """
    Flask view helper that serves files out of one fixed directory on the server.

    The bound ``send`` method of an instance is what gets registered as the view
    function for a ``<mount>/<path:path>`` URL rule.
    """
    def __init__(self, root):
        # directory on the server that requested paths are resolved against
        self.root = root

    def send(self, path):
        """
        Return a Flask response for the file at ``path`` (relative to ``self.root``).
        ``flask.send_from_directory()`` takes care of keeping the path inside the directory.
        """
        options = dict(conditional=False, max_age=120, use_x_sendfile=True)
        return flask.send_from_directory(self.root, path, **options)
if __name__ == "__main__":
    # standalone server, configured from the command line (nano_llm's 'web' and 'log' argument groups)
    parser = ArgParser(extras=['web', 'log'])

    parser.add_argument("--index", "--page", type=str, default="index.html", help="the filename of the site's index html page (should be under web/templates)")
    parser.add_argument("--root", type=str, default=None, help="the root directory for serving site files (should have static/ and templates/)")

    args = parser.parse_args()

    webserver = WebServer(**vars(args))
    webserver.start()

    # the server threads are daemons, so block here to keep the process alive
    webserver.web_thread.join()