Merging PR_218: openai_rev package with new Streamlit chat app

Author: noptuno
Date:   2023-04-27 20:29:30 -04:00
Parent: 479b8d6d10
Commit: 355dee533b

8378 changed files with 2,931,636 additions and 3 deletions


@@ -0,0 +1,23 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.routes import (
allow_cross_origin_requests as allow_cross_origin_requests,
)
from streamlit.web.server.server import Server as Server
from streamlit.web.server.server import (
server_address_is_unix_socket as server_address_is_unix_socket,
)
from streamlit.web.server.stats_request_handler import StatsRequestHandler


@@ -0,0 +1,72 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mimetypes
import os
from pathlib import Path
from typing import Optional
import tornado.web
from streamlit.logger import get_logger
from streamlit.web.server.routes import AssetsFileHandler
_LOGGER = get_logger(__name__)
# We agreed on these limitations for the initial release of static file sharing,
# based on security concerns from the SiS and Community Cloud teams
# The maximum possible size of single serving static file.
MAX_APP_STATIC_FILE_SIZE = 200 * 1024 * 1024 # 200 MB
# The list of file extensions that we serve with the corresponding Content-Type header.
# All files with other extensions will be served with Content-Type: text/plain
SAFE_APP_STATIC_FILE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp")
class AppStaticFileHandler(AssetsFileHandler):
def initialize(self, path: str, default_filename: Optional[str] = None) -> None:
super().initialize(path, default_filename)
mimetypes.add_type("image/webp", ".webp")
def validate_absolute_path(self, root: str, absolute_path: str) -> Optional[str]:
full_path = os.path.realpath(absolute_path)
if os.path.isdir(full_path):
# we don't want to serve directories, and serve only files
raise tornado.web.HTTPError(404)
if os.path.commonprefix([full_path, root]) != root:
# Don't allow misbehaving clients to break out of the static files directory
_LOGGER.warning(
"Serving files outside of the static directory is not supported"
)
raise tornado.web.HTTPError(404)
if (
os.path.exists(full_path)
and os.path.getsize(full_path) > MAX_APP_STATIC_FILE_SIZE
):
raise tornado.web.HTTPError(
404,
"File is too large, its size should not exceed "
f"{MAX_APP_STATIC_FILE_SIZE} bytes",
reason="File is too large",
)
return super().validate_absolute_path(root, absolute_path)
def set_extra_headers(self, path: str) -> None:
if Path(path).suffix not in SAFE_APP_STATIC_FILE_EXTENSIONS:
self.set_header("Content-Type", "text/plain")
self.set_header("X-Content-Type-Options", "nosniff")


@@ -0,0 +1,189 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import binascii
import json
from typing import Any, Awaitable, Dict, List, Optional, Union
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.websocket import WebSocketHandler
from typing_extensions import Final
from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime import Runtime, SessionClient, SessionClientDisconnectedError
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import is_url_from_allowed_origins
_LOGGER: Final = get_logger(__name__)
class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
"""Handles a WebSocket connection from the browser"""
def initialize(self, runtime: Runtime) -> None:
self._runtime = runtime
self._session_id: Optional[str] = None
# The XSRF cookie is normally set when xsrf_form_html is used, but in a
# pure-Javascript application that does not use any regular forms we just
# need to read the self.xsrf_token manually to set the cookie as a side
# effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
# for more details.
if config.get_option("server.enableXsrfProtection"):
_ = self.xsrf_token
def check_origin(self, origin: str) -> bool:
"""Set up CORS."""
return super().check_origin(origin) or is_url_from_allowed_origins(origin)
def write_forward_msg(self, msg: ForwardMsg) -> None:
"""Send a ForwardMsg to the browser."""
try:
self.write_message(serialize_forward_msg(msg), binary=True)
except tornado.websocket.WebSocketClosedError as e:
raise SessionClientDisconnectedError from e
def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
"""Return the first subprotocol in the given list.
This method is used by Tornado to select a protocol when the
Sec-WebSocket-Protocol header is set in an HTTP Upgrade request.
NOTE: We repurpose the Sec-WebSocket-Protocol header here in a slightly
unfortunate (but necessary) way. The browser WebSocket API doesn't allow us to
set arbitrary HTTP headers, and this header is the only one where we have the
ability to set it to arbitrary values, so we use it to pass tokens (in this
case, the previous session ID to allow us to reconnect to it) from client to
server as the *second* value in the list.
The reason why the auth token is set as the second value is that, when
Sec-WebSocket-Protocol is set, many clients expect the server to respond with a
selected subprotocol to use. We don't want that reply to be the token, so we
by convention have the client always set the first protocol to "streamlit" and
select that.
"""
if subprotocols:
return subprotocols[0]
return None
def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
# Extract user info from the X-Streamlit-User header
is_public_cloud_app = False
try:
header_content = self.request.headers["X-Streamlit-User"]
payload = base64.b64decode(header_content)
user_obj = json.loads(payload)
email = user_obj["email"]
is_public_cloud_app = user_obj["isPublicCloudApp"]
except (KeyError, binascii.Error, json.decoder.JSONDecodeError):
email = "test@localhost.com"
user_info: Dict[str, Optional[str]] = dict()
if is_public_cloud_app:
user_info["email"] = None
else:
user_info["email"] = email
existing_session_id = None
try:
ws_protocols = [
p.strip()
for p in self.request.headers["Sec-Websocket-Protocol"].split(",")
]
if len(ws_protocols) > 1:
# See the NOTE in the docstring of the select_subprotocol method above
# for a detailed explanation of why this is done.
existing_session_id = ws_protocols[1]
except KeyError:
# Just let existing_session_id=None if we run into any error while trying to
# extract it from the Sec-Websocket-Protocol header.
pass
self._session_id = self._runtime.connect_session(
client=self,
user_info=user_info,
existing_session_id=existing_session_id,
)
return None
def on_close(self) -> None:
if not self._session_id:
return
self._runtime.disconnect_session(self._session_id)
self._session_id = None
def get_compression_options(self) -> Optional[Dict[Any, Any]]:
"""Enable WebSocket compression.
Returning an empty dict enables websocket compression. Returning
None disables it.
(See the docstring in the parent class.)
"""
if config.get_option("server.enableWebsocketCompression"):
return {}
return None
def on_message(self, payload: Union[str, bytes]) -> None:
if not self._session_id:
return
try:
if isinstance(payload, str):
# Sanity check. (The frontend should only be sending us bytes;
# Protobuf.ParseFromString does not accept str input.)
raise RuntimeError(
"WebSocket received an unexpected `str` message. "
"(We expect `bytes` only.)"
)
msg = BackMsg()
msg.ParseFromString(payload)
_LOGGER.debug("Received the following back message:\n%s", msg)
except Exception as ex:
_LOGGER.error(ex)
self._runtime.handle_backmsg_deserialization_exception(self._session_id, ex)
return
# "debug_disconnect_websocket" and "debug_shutdown_runtime" are special
# developmentMode-only messages used in e2e tests to test reconnect handling and
# disabling widgets.
if msg.WhichOneof("type") == "debug_disconnect_websocket":
if config.get_option("global.developmentMode"):
self.close()
else:
_LOGGER.warning(
"Client tried to disconnect websocket when not in development mode."
)
elif msg.WhichOneof("type") == "debug_shutdown_runtime":
if config.get_option("global.developmentMode"):
self._runtime.stop()
else:
_LOGGER.warning(
"Client tried to shut down runtime when not in development mode."
)
else:
# AppSession handles all other BackMsg types.
self._runtime.handle_backmsg(self._session_id, msg)
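For illustration only (not part of this commit): the select_subprotocol docstring above describes how the client smuggles the previous session id through Sec-WebSocket-Protocol. This is a sketch of just that parsing step outside of Tornado, with made-up header values.

from typing import Optional

def extract_existing_session_id(headers: dict) -> Optional[str]:
    # By convention the client sends ["streamlit", <previous session id>];
    # the first entry is the subprotocol the server echoes back.
    try:
        protocols = [p.strip() for p in headers["Sec-Websocket-Protocol"].split(",")]
    except KeyError:
        return None
    return protocols[1] if len(protocols) > 1 else None

print(extract_existing_session_id({"Sec-Websocket-Protocol": "streamlit, 0f1e2d3c"}))  # 0f1e2d3c
print(extract_existing_session_id({}))                                                 # None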


@@ -0,0 +1,115 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import mimetypes
import os
import tornado.web
import streamlit.web.server.routes
from streamlit.components.v1.components import ComponentRegistry
from streamlit.logger import get_logger
_LOGGER = get_logger(__name__)
class ComponentRequestHandler(tornado.web.RequestHandler):
def initialize(self, registry: ComponentRegistry):
self._registry = registry
def get(self, path: str) -> None:
parts = path.split("/")
component_name = parts[0]
component_root = self._registry.get_component_path(component_name)
if component_root is None:
self.write("not found")
self.set_status(404)
return
# follow symlinks to get an accurate normalized path
component_root = os.path.realpath(component_root)
filename = "/".join(parts[1:])
abspath = os.path.realpath(os.path.join(component_root, filename))
# Do NOT expose anything outside of the component root.
if os.path.commonprefix([component_root, abspath]) != component_root or (
not os.path.normpath(abspath).startswith(
component_root
) # this is a recommendation from CodeQL, probably a bit redundant
):
self.write("forbidden")
self.set_status(403)
return
try:
with open(abspath, "rb") as file:
contents = file.read()
except (OSError) as e:
_LOGGER.error(
"ComponentRequestHandler: GET %s read error", abspath, exc_info=e
)
self.write("read error")
self.set_status(404)
return
self.write(contents)
self.set_header("Content-Type", self.get_content_type(abspath))
self.set_extra_headers(path)
def set_extra_headers(self, path) -> None:
"""Disable cache for HTML files.
Other assets like JS and CSS are suffixed with their hash, so they can
be cached indefinitely.
"""
is_index_url = len(path) == 0
if is_index_url or path.endswith(".html"):
self.set_header("Cache-Control", "no-cache")
else:
self.set_header("Cache-Control", "public")
def set_default_headers(self) -> None:
if streamlit.web.server.routes.allow_cross_origin_requests():
self.set_header("Access-Control-Allow-Origin", "*")
def options(self) -> None:
"""/OPTIONS handler for preflight CORS checks."""
self.set_status(204)
self.finish()
@staticmethod
def get_content_type(abspath) -> str:
"""Returns the ``Content-Type`` header to be used for this request.
From tornado.web.StaticFileHandler.
"""
mime_type, encoding = mimetypes.guess_type(abspath)
# per RFC 6713, use the appropriate type for a gzip compressed file
if encoding == "gzip":
return "application/gzip"
# As of 2015-07-21 there is no bzip2 encoding defined at
# http://www.iana.org/assignments/media-types/media-types.xhtml
# So for that (and any other encoding), use octet-stream.
elif encoding is not None:
return "application/octet-stream"
elif mime_type is not None:
return mime_type
# if mime_type not detected, use application/octet-stream
else:
return "application/octet-stream"
@staticmethod
def get_url(file_id: str) -> str:
"""Return the URL for a component file with the given ID."""
return "components/{}".format(file_id)


@@ -0,0 +1,146 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from urllib.parse import quote, unquote_plus
import tornado.web
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
from streamlit.runtime.memory_media_file_storage import (
MemoryMediaFileStorage,
get_extension_for_mimetype,
)
from streamlit.string_util import generate_download_filename_from_title
from streamlit.web.server import allow_cross_origin_requests
_LOGGER = get_logger(__name__)
class MediaFileHandler(tornado.web.StaticFileHandler):
_storage: MemoryMediaFileStorage
@classmethod
def initialize_storage(cls, storage: MemoryMediaFileStorage) -> None:
"""Set the MemoryMediaFileStorage object used by instances of this
handler. Must be called on server startup.
"""
# This is a class method, rather than an instance method, because
# `get_content()` is a class method and needs to access the storage
# instance.
cls._storage = storage
def set_default_headers(self) -> None:
if allow_cross_origin_requests():
self.set_header("Access-Control-Allow-Origin", "*")
def set_extra_headers(self, path: str) -> None:
"""Add Content-Disposition header for downloadable files.
Set header value to "attachment" indicating that file should be saved
locally instead of displaying inline in browser.
We also set filename to specify the filename for downloaded files.
Used for serving downloadable files, like files stored via the
`st.download_button` widget.
"""
media_file = self._storage.get_file(path)
if media_file and media_file.kind == MediaFileKind.DOWNLOADABLE:
filename = media_file.filename
if not filename:
title = self.get_argument("title", "", True)
title = unquote_plus(title)
filename = generate_download_filename_from_title(title)
filename = (
f"{filename}{get_extension_for_mimetype(media_file.mimetype)}"
)
try:
# Check that the value can be encoded in latin1. Latin1 is
# the default encoding for headers.
filename.encode("latin1")
file_expr = 'filename="{}"'.format(filename)
except UnicodeEncodeError:
# RFC5987 syntax.
# See: https://datatracker.ietf.org/doc/html/rfc5987
file_expr = "filename*=utf-8''{}".format(quote(filename))
self.set_header("Content-Disposition", f"attachment; {file_expr}")
# Overriding StaticFileHandler to use the MediaFileManager
#
# From the Tornado docs:
# To replace all interaction with the filesystem (e.g. to serve
# static content from a database), override `get_content`,
# `get_content_size`, `get_modified_time`, `get_absolute_path`, and
# `validate_absolute_path`.
def validate_absolute_path(self, root: str, absolute_path: str) -> str:
try:
self._storage.get_file(absolute_path)
except MediaFileStorageError:
_LOGGER.error("MediaFileHandler: Missing file %s", absolute_path)
raise tornado.web.HTTPError(404, "not found")
return absolute_path
def get_content_size(self) -> int:
abspath = self.absolute_path
if abspath is None:
return 0
media_file = self._storage.get_file(abspath)
return media_file.content_size
def get_modified_time(self) -> None:
# We do not track last modified time, but this can be improved to
# allow caching among files in the MediaFileManager
return None
@classmethod
def get_absolute_path(cls, root: str, path: str) -> str:
# All files are stored in memory, so the absolute path is just the
# path itself. In the MediaFileHandler, it's just the filename
return path
@classmethod
def get_content(
cls, abspath: str, start: Optional[int] = None, end: Optional[int] = None
):
_LOGGER.debug("MediaFileHandler: GET %s", abspath)
try:
# abspath is the hash, as returned by `get_absolute_path` above
media_file = cls._storage.get_file(abspath)
except Exception:
_LOGGER.error("MediaFileHandler: Missing file %s", abspath)
return None
_LOGGER.debug(
"MediaFileHandler: Sending %s file %s", media_file.mimetype, abspath
)
# If there is no start and end, just return the full content
if start is None and end is None:
return media_file.content
if start is None:
start = 0
if end is None:
end = len(media_file.content)
# content is bytes, so it can be sliced directly with start and end
return media_file.content[start:end]
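For illustration only (not part of this commit): the Content-Disposition logic in set_extra_headers above picks between a plain Latin-1 filename and RFC 5987 encoding. A self-contained sketch with made-up filenames.

from urllib.parse import quote

def content_disposition(filename: str) -> str:
    try:
        # Latin-1 is the default encoding for HTTP header values.
        filename.encode("latin1")
        file_expr = 'filename="{}"'.format(filename)
    except UnicodeEncodeError:
        # Fall back to RFC 5987 syntax for non-Latin-1 names.
        file_expr = "filename*=utf-8''{}".format(quote(filename))
    return f"attachment; {file_expr}"

print(content_disposition("report.csv"))  # attachment; filename="report.csv"
print(content_disposition("отчёт.csv"))   # attachment; filename*=utf-8''%D0%BE%D1%82%D1%87%D1%91%D1%82.csv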


@@ -0,0 +1,259 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tornado.web
from streamlit import config, file_util
from streamlit.logger import get_logger
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice
_LOGGER = get_logger(__name__)
def allow_cross_origin_requests():
"""True if cross-origin requests are allowed.
We only allow cross-origin requests when CORS protection has been disabled
with server.enableCORS=False or if using the Node server. When using the
Node server, we have a dev and prod port, which count as two origins.
"""
return not config.get_option("server.enableCORS") or config.get_option(
"global.developmentMode"
)
class StaticFileHandler(tornado.web.StaticFileHandler):
def initialize(self, path, default_filename, get_pages):
self._pages = get_pages()
super().initialize(path=path, default_filename=default_filename)
def set_extra_headers(self, path):
"""Disable cache for HTML files.
Other assets like JS and CSS are suffixed with their hash, so they can
be cached indefinitely.
"""
is_index_url = len(path) == 0
if is_index_url or path.endswith(".html"):
self.set_header("Cache-Control", "no-cache")
else:
self.set_header("Cache-Control", "public")
def parse_url_path(self, url_path: str) -> str:
url_parts = url_path.split("/")
maybe_page_name = url_parts[0]
if maybe_page_name in self._pages:
# If we're trying to navigate to a page, we return "index.html"
# directly here instead of deferring to the superclass below after
# modifying the url_path. The reason why is that tornado handles
# requests to "directories" (which is what navigating to a page
# looks like) by appending a trailing '/' if there is none and
# redirecting.
#
# This would work, but it
# * adds an unnecessary redirect+roundtrip
# * adds a trailing '/' to the URL appearing in the browser, which
# looks bad
if len(url_parts) == 1:
return "index.html"
url_path = "/".join(url_parts[1:])
return super().parse_url_path(url_path)
def write_error(self, status_code: int, **kwargs) -> None:
if status_code == 404:
index_file = os.path.join(file_util.get_static_dir(), "index.html")
self.render(index_file)
else:
super().write_error(status_code, **kwargs)
class AssetsFileHandler(tornado.web.StaticFileHandler):
# CORS protection should be disabled as we need access
# to this endpoint from the inner iframe.
def set_default_headers(self):
self.set_header("Access-Control-Allow-Origin", "*")
class AddSlashHandler(tornado.web.RequestHandler):
@tornado.web.addslash
def get(self):
pass
class _SpecialRequestHandler(tornado.web.RequestHandler):
"""Superclass for "special" endpoints, like /healthz."""
def set_default_headers(self):
self.set_header("Cache-Control", "no-cache")
if allow_cross_origin_requests():
self.set_header("Access-Control-Allow-Origin", "*")
def options(self):
"""/OPTIONS handler for preflight CORS checks.
When a browser is making a CORS request, it may sometimes first
send an OPTIONS request, to check whether the server understands the
CORS protocol. This is optional, and doesn't happen for every request
or in every browser. If an OPTIONS request does get sent, and is not
then handled by the server, the browser will fail the underlying
request.
The proper way to handle this is to send a 204 response ("no content")
with the CORS headers attached. (These headers are automatically added
to every outgoing response, including OPTIONS responses,
via set_default_headers().)
See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
"""
self.set_status(204)
self.finish()
class HealthHandler(_SpecialRequestHandler):
def initialize(self, callback):
"""Initialize the handler
Parameters
----------
callback : callable
A function that returns True if the server is healthy
"""
self._callback = callback
async def get(self):
if self.request.uri and "_stcore/" not in self.request.uri:
new_path = (
"/_stcore/script-health-check"
if "script-health-check" in self.request.uri
else "/_stcore/health"
)
emit_endpoint_deprecation_notice(self, new_path=new_path)
ok, msg = await self._callback()
if ok:
self.write(msg)
self.set_status(200)
# Tornado will set the _xsrf cookie automatically for the page on
# request for the document. However, if the server is reset and
# server.enableXsrfProtection is updated, the browser does not reload the document.
# Manually setting the cookie on /healthz since it is pinged when the
# browser is disconnected from the server.
if config.get_option("server.enableXsrfProtection"):
self.set_cookie("_xsrf", self.xsrf_token)
else:
# 503 = SERVICE_UNAVAILABLE
self.set_status(503)
self.write(msg)
# NOTE: We eventually want to get rid of this hard-coded list entirely as we don't want
# to have links to Community Cloud live in the open source library in a way that affects
# functionality (links advertising Community Cloud are probably okay 🙂). In the long
# run, this list will most likely be replaced by a config option allowing us to more
# granularly control what domains a Streamlit app should accept cross-origin iframe
# messages from.
ALLOWED_MESSAGE_ORIGINS = [
"https://devel.streamlit.test",
"https://*.streamlit.apptest",
"https://*.streamlitapp.test",
"https://*.streamlitapp.com",
"https://share.streamlit.io",
"https://share-demo.streamlit.io",
"https://share-head.streamlit.io",
"https://share-staging.streamlit.io",
"https://*.demo.streamlit.run",
"https://*.head.streamlit.run",
"https://*.staging.streamlit.run",
"https://*.streamlit.run",
"https://*.demo.streamlit.app",
"https://*.head.streamlit.app",
"https://*.staging.streamlit.app",
"https://*.streamlit.app",
]
class AllowedMessageOriginsHandler(_SpecialRequestHandler):
async def get(self) -> None:
# ALLOWED_MESSAGE_ORIGINS must be wrapped in a dictionary because Tornado
# disallows writing lists directly into responses due to potential XSS
# vulnerabilities.
# See https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write
self.write(
{
"allowedOrigins": ALLOWED_MESSAGE_ORIGINS,
"useExternalAuthToken": False,
}
)
self.set_status(200)
class MessageCacheHandler(tornado.web.RequestHandler):
"""Returns ForwardMsgs from our MessageCache"""
def initialize(self, cache):
"""Initializes the handler.
Parameters
----------
cache : MessageCache
"""
self._cache = cache
def set_default_headers(self):
if allow_cross_origin_requests():
self.set_header("Access-Control-Allow-Origin", "*")
def get(self):
msg_hash = self.get_argument("hash", None)
if msg_hash is None:
# Hash is missing! This is a malformed request.
_LOGGER.error(
"HTTP request for cached message is missing the hash attribute."
)
self.set_status(404)
raise tornado.web.Finish()
message = self._cache.get_message(msg_hash)
if message is None:
# Message not in our cache.
_LOGGER.error(
"HTTP request for cached message could not be fulfilled. "
"No such message"
)
self.set_status(404)
raise tornado.web.Finish()
_LOGGER.debug("MessageCache HIT")
msg_str = serialize_forward_msg(message)
self.set_header("Content-Type", "application/octet-stream")
self.write(msg_str)
self.set_status(200)
def options(self):
"""/OPTIONS handler for preflight CORS checks."""
self.set_status(204)
self.finish()
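For illustration only (not part of this commit): one way to exercise the HealthHandler above once a server is running. localhost:8501 is assumed to be the local default address, and the response body shown is just an example.

import urllib.request

with urllib.request.urlopen("http://localhost:8501/_stcore/health") as resp:
    print(resp.status, resp.read().decode())  # e.g. "200 ok" when the runtime is ready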


@@ -0,0 +1,410 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import logging
import os
import socket
import ssl
import sys
from pathlib import Path
from typing import Any, Awaitable, List, Optional, Union
import click
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.httpserver import HTTPServer
from typing_extensions import Final
from streamlit import config, file_util, source_util, util
from streamlit.components.v1.components import ComponentRegistry
from streamlit.config_option import ConfigOption
from streamlit.logger import get_logger
from streamlit.runtime import Runtime, RuntimeConfig, RuntimeState
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.runtime_util import get_max_message_size_bytes
from streamlit.web.cache_storage_manager_config import (
create_default_cache_storage_manager,
)
from streamlit.web.server.app_static_file_handler import AppStaticFileHandler
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.media_file_handler import MediaFileHandler
from streamlit.web.server.routes import (
AddSlashHandler,
AllowedMessageOriginsHandler,
AssetsFileHandler,
HealthHandler,
MessageCacheHandler,
StaticFileHandler,
)
from streamlit.web.server.server_util import make_url_path_regex
from streamlit.web.server.stats_request_handler import StatsRequestHandler
from streamlit.web.server.upload_file_request_handler import (
UPLOAD_FILE_ROUTE,
UploadFileRequestHandler,
)
LOGGER = get_logger(__name__)
TORNADO_SETTINGS = {
# Gzip HTTP responses.
"compress_response": True,
# Ping every 1s to keep WS alive.
# 2021.06.22: this value was previously 20s, and was causing
# connection instability for a small number of users. This smaller
# ping_interval fixes that instability.
# https://github.com/streamlit/streamlit/issues/3196
"websocket_ping_interval": 1,
# If we don't get a ping response within 30s, the connection
# is timed out.
"websocket_ping_timeout": 30,
}
# When server.port is not available it will look for the next available port
# up to MAX_PORT_SEARCH_RETRIES.
MAX_PORT_SEARCH_RETRIES = 100
# When server.address starts with this prefix, the server will bind
# to an unix socket.
UNIX_SOCKET_PREFIX = "unix://"
MEDIA_ENDPOINT: Final = "/media"
STREAM_ENDPOINT: Final = r"_stcore/stream"
METRIC_ENDPOINT: Final = r"(?:st-metrics|_stcore/metrics)"
MESSAGE_ENDPOINT: Final = r"_stcore/message"
HEALTH_ENDPOINT: Final = r"(?:healthz|_stcore/health)"
ALLOWED_MESSAGE_ORIGIN_ENDPOINT: Final = r"_stcore/allowed-message-origins"
SCRIPT_HEALTH_CHECK_ENDPOINT: Final = (
r"(?:script-health-check|_stcore/script-health-check)"
)
class RetriesExceeded(Exception):
pass
def server_port_is_manually_set() -> bool:
return config.is_manually_set("server.port")
def server_address_is_unix_socket() -> bool:
address = config.get_option("server.address")
return address is not None and address.startswith(UNIX_SOCKET_PREFIX)
def start_listening(app: tornado.web.Application) -> None:
"""Makes the server start listening at the configured port.
In case the port is already taken it tries listening to the next available
port. It will error after MAX_PORT_SEARCH_RETRIES attempts.
"""
cert_file = config.get_option("server.sslCertFile")
key_file = config.get_option("server.sslKeyFile")
ssl_options = _get_ssl_options(cert_file, key_file)
http_server = HTTPServer(
app,
max_buffer_size=config.get_option("server.maxUploadSize") * 1024 * 1024,
ssl_options=ssl_options,
)
if server_address_is_unix_socket():
start_listening_unix_socket(http_server)
else:
start_listening_tcp_socket(http_server)
def _get_ssl_options(
cert_file: Optional[str], key_file: Optional[str]
) -> Union[ssl.SSLContext, None]:
if bool(cert_file) != bool(key_file):
LOGGER.error(
"Options 'server.sslCertFile' and 'server.sslKeyFile' must "
"be set together. Set missing options or delete existing options."
)
sys.exit(1)
if cert_file and key_file:
# ssl_ctx.load_cert_chain raise exception as below, but it is not
# sufficiently user-friendly
# FileNotFoundError: [Errno 2] No such file or directory
if not Path(cert_file).exists():
LOGGER.error("Cert file '%s' does not exist.", cert_file)
sys.exit(1)
if not Path(key_file).exists():
LOGGER.error("Key file '%s' does not exist.", key_file)
sys.exit(1)
ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
# When the SSL certificate fails to load, an exception is raised as below,
# but it is not sufficiently user-friendly.
# ssl.SSLError: [SSL] PEM lib (_ssl.c:4067)
try:
ssl_ctx.load_cert_chain(cert_file, key_file)
except ssl.SSLError:
LOGGER.error(
"Failed to load SSL certificate. Make sure "
"cert file '%s' and key file '%s' are correct.",
cert_file,
key_file,
)
sys.exit(1)
return ssl_ctx
return None
def start_listening_unix_socket(http_server: HTTPServer) -> None:
address = config.get_option("server.address")
file_name = os.path.expanduser(address[len(UNIX_SOCKET_PREFIX) :])
unix_socket = tornado.netutil.bind_unix_socket(file_name)
http_server.add_socket(unix_socket)
def start_listening_tcp_socket(http_server: HTTPServer) -> None:
call_count = 0
port = None
while call_count < MAX_PORT_SEARCH_RETRIES:
address = config.get_option("server.address")
port = config.get_option("server.port")
try:
http_server.listen(port, address)
break # It worked! So let's break out of the loop.
except (OSError, socket.error) as e:
if e.errno == errno.EADDRINUSE:
if server_port_is_manually_set():
LOGGER.error("Port %s is already in use", port)
sys.exit(1)
else:
LOGGER.debug(
"Port %s already in use, trying to use the next one.", port
)
port += 1
# Save port 3000 because it is used for the development
# server in the front end.
if port == 3000:
port += 1
config.set_option(
"server.port", port, ConfigOption.STREAMLIT_DEFINITION
)
call_count += 1
else:
raise
if call_count >= MAX_PORT_SEARCH_RETRIES:
raise RetriesExceeded(
f"Cannot start Streamlit server. Port {port} is already in use, and "
f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.",
)
class Server:
def __init__(self, main_script_path: str, command_line: Optional[str]):
"""Create the server. It won't be started yet."""
_set_tornado_log_levels()
self._main_script_path = main_script_path
# Initialize MediaFileStorage and its associated endpoint
media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
MediaFileHandler.initialize_storage(media_file_storage)
self._runtime = Runtime(
RuntimeConfig(
script_path=main_script_path,
command_line=command_line,
media_file_storage=media_file_storage,
cache_storage_manager=create_default_cache_storage_manager(),
),
)
self._runtime.stats_mgr.register_provider(media_file_storage)
def __repr__(self) -> str:
return util.repr_(self)
@property
def main_script_path(self) -> str:
return self._main_script_path
async def start(self) -> None:
"""Start the server.
When this returns, Streamlit is ready to accept new sessions.
"""
LOGGER.debug("Starting server...")
app = self._create_app()
start_listening(app)
port = config.get_option("server.port")
LOGGER.debug("Server started on port %s", port)
await self._runtime.start()
@property
def stopped(self) -> Awaitable[None]:
"""A Future that completes when the Server's run loop has exited."""
return self._runtime.stopped
def _create_app(self) -> tornado.web.Application:
"""Create our tornado web app."""
base = config.get_option("server.baseUrlPath")
routes: List[Any] = [
(
make_url_path_regex(base, STREAM_ENDPOINT),
BrowserWebSocketHandler,
dict(runtime=self._runtime),
),
(
make_url_path_regex(base, HEALTH_ENDPOINT),
HealthHandler,
dict(callback=lambda: self._runtime.is_ready_for_browser_connection),
),
(
make_url_path_regex(base, MESSAGE_ENDPOINT),
MessageCacheHandler,
dict(cache=self._runtime.message_cache),
),
(
make_url_path_regex(base, METRIC_ENDPOINT),
StatsRequestHandler,
dict(stats_manager=self._runtime.stats_mgr),
),
(
make_url_path_regex(base, ALLOWED_MESSAGE_ORIGIN_ENDPOINT),
AllowedMessageOriginsHandler,
),
(
make_url_path_regex(
base,
UPLOAD_FILE_ROUTE,
),
UploadFileRequestHandler,
dict(
file_mgr=self._runtime.uploaded_file_mgr,
is_active_session=self._runtime.is_active_session,
),
),
(
make_url_path_regex(base, "assets/(.*)"),
AssetsFileHandler,
{"path": "%s/" % file_util.get_assets_dir()},
),
(
make_url_path_regex(base, f"{MEDIA_ENDPOINT}/(.*)"),
MediaFileHandler,
{"path": ""},
),
(
make_url_path_regex(base, "component/(.*)"),
ComponentRequestHandler,
dict(registry=ComponentRegistry.instance()),
),
]
if config.get_option("server.scriptHealthCheckEnabled"):
routes.extend(
[
(
make_url_path_regex(base, SCRIPT_HEALTH_CHECK_ENDPOINT),
HealthHandler,
dict(
callback=lambda: self._runtime.does_script_run_without_error()
),
)
]
)
if config.get_option("server.enableStaticServing"):
routes.extend(
[
(
make_url_path_regex(base, "app/static/(.*)"),
AppStaticFileHandler,
{"path": file_util.get_app_static_dir(self.main_script_path)},
),
]
)
if config.get_option("global.developmentMode"):
LOGGER.debug("Serving static content from the Node dev server")
else:
static_path = file_util.get_static_dir()
LOGGER.debug("Serving static content from %s", static_path)
routes.extend(
[
(
make_url_path_regex(base, "(.*)"),
StaticFileHandler,
{
"path": "%s/" % static_path,
"default_filename": "index.html",
"get_pages": lambda: set(
[
page_info["page_name"]
for page_info in source_util.get_pages(
self.main_script_path
).values()
]
),
},
),
(make_url_path_regex(base, trailing_slash=False), AddSlashHandler),
]
)
return tornado.web.Application(
routes,
cookie_secret=config.get_option("server.cookieSecret"),
xsrf_cookies=config.get_option("server.enableXsrfProtection"),
# Set the websocket message size. The default value is too low.
websocket_max_message_size=get_max_message_size_bytes(),
**TORNADO_SETTINGS, # type: ignore[arg-type]
)
@property
def browser_is_connected(self) -> bool:
return self._runtime.state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
@property
def is_running_hello(self) -> bool:
from streamlit.hello import Hello
return self._main_script_path == Hello.__file__
def stop(self) -> None:
click.secho(" Stopping...", fg="blue")
self._runtime.stop()
def _set_tornado_log_levels() -> None:
if not config.get_option("global.developmentMode"):
# Hide logs unless they're super important.
# Example of stuff we don't care about: 404 about .js.map files.
logging.getLogger("tornado.access").setLevel(logging.ERROR)
logging.getLogger("tornado.application").setLevel(logging.ERROR)
logging.getLogger("tornado.general").setLevel(logging.ERROR)


@@ -0,0 +1,126 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Server related utility functions"""
from typing import Optional
from urllib.parse import urljoin
import tornado.web
from streamlit import config, net_util, url_util
def is_url_from_allowed_origins(url: str) -> bool:
"""Return True if URL is from allowed origins (for CORS purpose).
Allowed origins:
1. localhost
2. The internal and external IP addresses of the machine where this
function was called from.
If `server.enableCORS` is False, this allows all origins.
"""
if not config.get_option("server.enableCORS"):
# Allow everything when CORS is disabled.
return True
hostname = url_util.get_hostname(url)
allowed_domains = [ # List[Union[str, Callable[[], Optional[str]]]]
# Check localhost first.
"localhost",
"0.0.0.0",
"127.0.0.1",
# Try to avoid making unnecessary HTTP requests by checking if the user
# manually specified a server address.
_get_server_address_if_manually_set,
# Then try the options that depend on HTTP requests or opening sockets.
net_util.get_internal_ip,
net_util.get_external_ip,
]
for allowed_domain in allowed_domains:
if callable(allowed_domain):
allowed_domain = allowed_domain()
if allowed_domain is None:
continue
if hostname == allowed_domain:
return True
return False
def _get_server_address_if_manually_set() -> Optional[str]:
if config.is_manually_set("browser.serverAddress"):
return url_util.get_hostname(config.get_option("browser.serverAddress"))
return None
def make_url_path_regex(*path, **kwargs) -> str:
"""Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz)."""
path = [x.strip("/") for x in path if x]  # Filter out falsy components.
path_format = r"^/%s/?$" if kwargs.get("trailing_slash", True) else r"^/%s$"
return path_format % "/".join(path)
def get_url(host_ip: str) -> str:
"""Get the URL for any app served at the given host_ip.
Parameters
----------
host_ip : str
The IP address of the machine that is running the Streamlit Server.
Returns
-------
str
The URL.
"""
protocol = "https" if config.get_option("server.sslCertFile") else "http"
port = _get_browser_address_bar_port()
base_path = config.get_option("server.baseUrlPath").strip("/")
if base_path:
base_path = "/" + base_path
host_ip = host_ip.strip("/")
return f"{protocol}://{host_ip}:{port}{base_path}"
def _get_browser_address_bar_port() -> int:
"""Get the app URL that will be shown in the browser's address bar.
That is, this is the port where static assets will be served from. In dev,
this is different from the URL that will be used to connect to the
server-browser websocket.
"""
if config.get_option("global.developmentMode"):
return 3000
return int(config.get_option("browser.serverPort"))
def emit_endpoint_deprecation_notice(
handler: tornado.web.RequestHandler, new_path: str
) -> None:
"""
Emit a deprecation notice for a legacy HTTP endpoint via the response headers.
"""
handler.set_header("Deprecation", True)
new_url = urljoin(f"{handler.request.protocol}://{handler.request.host}", new_path)
handler.set_header("Link", f'<{new_url}>; rel="alternate"')


@@ -0,0 +1,88 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import tornado.web
from streamlit.proto.openmetrics_data_model_pb2 import GAUGE
from streamlit.proto.openmetrics_data_model_pb2 import MetricSet as MetricSetProto
from streamlit.runtime.stats import CacheStat, StatsManager
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice
class StatsRequestHandler(tornado.web.RequestHandler):
def initialize(self, stats_manager: StatsManager) -> None:
self._manager = stats_manager
def set_default_headers(self):
# Avoid a circular import
from streamlit.web.server import allow_cross_origin_requests
if allow_cross_origin_requests():
self.set_header("Access-Control-Allow-Origin", "*")
def options(self):
"""/OPTIONS handler for preflight CORS checks."""
self.set_status(204)
self.finish()
def get(self) -> None:
if self.request.uri and "_stcore/" not in self.request.uri:
emit_endpoint_deprecation_notice(self, new_path="/_stcore/metrics")
stats = self._manager.get_stats()
# If the request asked for protobuf output, we return a serialized
# protobuf. Else we return text.
if "application/x-protobuf" in self.request.headers.get_list("Accept"):
self.write(self._stats_to_proto(stats).SerializeToString())
self.set_header("Content-Type", "application/x-protobuf")
self.set_status(200)
else:
self.write(self._stats_to_text(self._manager.get_stats()))
self.set_header("Content-Type", "application/openmetrics-text")
self.set_status(200)
@staticmethod
def _stats_to_text(stats: List[CacheStat]) -> str:
metric_type = "# TYPE cache_memory_bytes gauge"
metric_unit = "# UNIT cache_memory_bytes bytes"
metric_help = "# HELP Total memory consumed by a cache."
openmetrics_eof = "# EOF\n"
# Format: header, stats, EOF
result = [metric_type, metric_unit, metric_help]
result.extend(stat.to_metric_str() for stat in stats)
result.append(openmetrics_eof)
return "\n".join(result)
@staticmethod
def _stats_to_proto(stats: List[CacheStat]) -> MetricSetProto:
metric_set = MetricSetProto()
metric_family = metric_set.metric_families.add()
metric_family.name = "cache_memory_bytes"
metric_family.type = GAUGE
metric_family.unit = "bytes"
metric_family.help = "Total memory consumed by a cache."
for stat in stats:
metric_proto = metric_family.metrics.add()
stat.marshall_metric_proto(metric_proto)
return metric_set
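For illustration only (not part of this commit): the StatsRequestHandler above switches output format on the Accept header. A client-side sketch; localhost:8501 is a placeholder for a locally running app.

import urllib.request

req = urllib.request.Request(
    "http://localhost:8501/_stcore/metrics",
    headers={"Accept": "application/x-protobuf"},  # omit to get the OpenMetrics text format
)
with urllib.request.urlopen(req) as resp:
    print(resp.headers["Content-Type"])  # application/x-protobuf (or application/openmetrics-text)
    payload = resp.read()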


@@ -0,0 +1,157 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List
import tornado.httputil
import tornado.web
from streamlit import config
from streamlit.logger import get_logger
from streamlit.runtime.uploaded_file_manager import UploadedFileManager, UploadedFileRec
from streamlit.web.server import routes, server_util
# /_stcore/upload_file/(optional session id)/(optional widget id)
UPLOAD_FILE_ROUTE = (
r"/_stcore/upload_file/?(?P<session_id>[^/]*)?/?(?P<widget_id>[^/]*)?"
)
LOGGER = get_logger(__name__)
class UploadFileRequestHandler(tornado.web.RequestHandler):
"""Implements the POST /upload_file endpoint."""
def initialize(
self, file_mgr: UploadedFileManager, is_active_session: Callable[[str], bool]
):
"""
Parameters
----------
file_mgr : UploadedFileManager
The server's singleton UploadedFileManager. All file uploads
go here.
is_active_session:
A function that returns true if a session_id belongs to an active
session.
"""
self._file_mgr = file_mgr
self._is_active_session = is_active_session
def set_default_headers(self):
self.set_header("Access-Control-Allow-Methods", "POST, OPTIONS")
self.set_header("Access-Control-Allow-Headers", "Content-Type")
if config.get_option("server.enableXsrfProtection"):
self.set_header(
"Access-Control-Allow-Origin",
server_util.get_url(config.get_option("browser.serverAddress")),
)
self.set_header("Access-Control-Allow-Headers", "X-Xsrftoken, Content-Type")
self.set_header("Vary", "Origin")
self.set_header("Access-Control-Allow-Credentials", "true")
elif routes.allow_cross_origin_requests():
self.set_header("Access-Control-Allow-Origin", "*")
def options(self, **kwargs):
"""/OPTIONS handler for preflight CORS checks.
When a browser is making a CORS request, it may sometimes first
send an OPTIONS request, to check whether the server understands the
CORS protocol. This is optional, and doesn't happen for every request
or in every browser. If an OPTIONS request does get sent, and is not
then handled by the server, the browser will fail the underlying
request.
The proper way to handle this is to send a 204 response ("no content")
with the CORS headers attached. (These headers are automatically added
to every outgoing response, including OPTIONS responses,
via set_default_headers().)
See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
"""
self.set_status(204)
self.finish()
@staticmethod
def _require_arg(args: Dict[str, List[bytes]], name: str) -> str:
"""Return the value of the argument with the given name.
A human-readable exception will be raised if the argument doesn't
exist. This will be used as the body for the error response returned
from the request.
"""
try:
arg = args[name]
except KeyError:
raise Exception(f"Missing '{name}'")
if len(arg) != 1:
raise Exception(f"Expected 1 '{name}' arg, but got {len(arg)}")
# Convert bytes to string
return arg[0].decode("utf-8")
def post(self, **kwargs):
"""Receive an uploaded file and add it to our UploadedFileManager.
Return the file's ID, so that the client can refer to it.
"""
args: Dict[str, List[bytes]] = {}
files: Dict[str, List[Any]] = {}
tornado.httputil.parse_body_arguments(
content_type=self.request.headers["Content-Type"],
body=self.request.body,
arguments=args,
files=files,
)
try:
session_id = self._require_arg(args, "sessionId")
widget_id = self._require_arg(args, "widgetId")
if not self._is_active_session(session_id):
raise Exception(f"Invalid session_id: '{session_id}'")
except Exception as e:
self.send_error(400, reason=str(e))
return
# Create an UploadedFile object for each file.
# We assign an initial, invalid file_id to each file in this loop.
# The file_mgr will assign unique file IDs and return in `add_file`,
# below.
uploaded_files: List[UploadedFileRec] = []
for _, flist in files.items():
for file in flist:
uploaded_files.append(
UploadedFileRec(
id=0,
name=file["filename"],
type=file["content_type"],
data=file["body"],
)
)
if len(uploaded_files) != 1:
self.send_error(
400, reason=f"Expected 1 file, but got {len(uploaded_files)}"
)
return
added_file = self._file_mgr.add_file(
session_id=session_id, widget_id=widget_id, file=uploaded_files[0]
)
# Return the file_id to the client. (The client will parse
# the string back to an int.)
self.write(str(added_file.id))
self.set_status(200)
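For illustration only (not part of this commit): a sketch of the multipart POST that UploadFileRequestHandler.post expects, using the third-party requests package. The URL, session id, widget id, and file contents are all placeholders.

import requests  # third-party; pip install requests

resp = requests.post(
    "http://localhost:8501/_stcore/upload_file",
    data={"sessionId": "some-session-id", "widgetId": "some-widget-id"},
    files={"file": ("notes.txt", b"hello", "text/plain")},
)
print(resp.status_code, resp.text)  # 200 and the new file id on success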


@@ -0,0 +1,47 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional
from streamlit import runtime
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
@gather_metrics("_get_websocket_headers")
def _get_websocket_headers() -> Optional[Dict[str, str]]:
"""Return a copy of the HTTP request headers for the current session's
WebSocket connection. If there's no active session, return None instead.
Raise an error if the server is not running.
Note to the intrepid: this is an UNSUPPORTED, INTERNAL API. (We don't have plans
to remove it without a replacement, but we don't consider this a production-ready
function, and its signature may change without a deprecation warning.)
"""
ctx = get_script_run_ctx()
if ctx is None:
return None
session_client = runtime.get_instance().get_client(ctx.session_id)
if session_client is None:
return None
if not isinstance(session_client, BrowserWebSocketHandler):
raise RuntimeError(
f"SessionClient is not a BrowserWebSocketHandler! ({session_client})"
)
return dict(session_client.request.headers)
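For illustration only (not part of this commit): a usage sketch for the helper above. It assumes the module is importable as streamlit.web.server.websocket_headers (the diff does not show file names) and only works inside a script run by a live Streamlit server.

import streamlit as st
from streamlit.web.server.websocket_headers import _get_websocket_headers  # assumed module path

headers = _get_websocket_headers()
if headers is not None:
    st.write(headers.get("User-Agent"))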