Merging PR_218: openai_rev package with new Streamlit chat app
venv/lib/python3.9/site-packages/streamlit/web/server/__init__.py (23 lines, new file)
@@ -0,0 +1,23 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.routes import (
    allow_cross_origin_requests as allow_cross_origin_requests,
)
from streamlit.web.server.server import Server as Server
from streamlit.web.server.server import (
    server_address_is_unix_socket as server_address_is_unix_socket,
)
from streamlit.web.server.stats_request_handler import StatsRequestHandler
venv/lib/python3.9/site-packages/streamlit/web/server/app_static_file_handler.py (72 lines, new file)
@@ -0,0 +1,72 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mimetypes
import os
from pathlib import Path
from typing import Optional

import tornado.web

from streamlit.logger import get_logger
from streamlit.web.server.routes import AssetsFileHandler

_LOGGER = get_logger(__name__)


# We agreed on these limitations for the initial release of static file sharing,
# based on security concerns from the SiS and Community Cloud teams.
# The maximum possible size of a single served static file.
MAX_APP_STATIC_FILE_SIZE = 200 * 1024 * 1024  # 200 MB
# The list of file extensions that we serve with the corresponding Content-Type header.
# All files with other extensions will be served with Content-Type: text/plain
SAFE_APP_STATIC_FILE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".webp")


class AppStaticFileHandler(AssetsFileHandler):
    def initialize(self, path: str, default_filename: Optional[str] = None) -> None:
        super().initialize(path, default_filename)
        mimetypes.add_type("image/webp", ".webp")

    def validate_absolute_path(self, root: str, absolute_path: str) -> Optional[str]:
        full_path = os.path.realpath(absolute_path)

        if os.path.isdir(full_path):
            # we don't want to serve directories, and serve only files
            raise tornado.web.HTTPError(404)

        if os.path.commonprefix([full_path, root]) != root:
            # Don't allow misbehaving clients to break out of the static files directory
            _LOGGER.warning(
                "Serving files outside of the static directory is not supported"
            )
            raise tornado.web.HTTPError(404)

        if (
            os.path.exists(full_path)
            and os.path.getsize(full_path) > MAX_APP_STATIC_FILE_SIZE
        ):
            raise tornado.web.HTTPError(
                404,
                "File is too large, its size should not exceed "
                f"{MAX_APP_STATIC_FILE_SIZE} bytes",
                reason="File is too large",
            )

        return super().validate_absolute_path(root, absolute_path)

    def set_extra_headers(self, path: str) -> None:
        if Path(path).suffix not in SAFE_APP_STATIC_FILE_EXTENSIONS:
            self.set_header("Content-Type", "text/plain")
        self.set_header("X-Content-Type-Options", "nosniff")
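As an aside, the realpath-plus-commonprefix check above is the handler's whole path-traversal defense. A minimal standalone sketch of the same idea (not part of this commit), with a made-up root directory and POSIX paths assumed:

import os

def is_inside(root: str, requested: str) -> bool:
    # Resolve symlinks and ".." segments first, then require the root as a
    # prefix, mirroring validate_absolute_path above.
    root = os.path.realpath(root)
    full_path = os.path.realpath(os.path.join(root, requested))
    return os.path.commonprefix([full_path, root]) == root

print(is_inside("/srv/app/static", "img/logo.png"))      # True
print(is_inside("/srv/app/static", "../../etc/passwd"))  # False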
venv/lib/python3.9/site-packages/streamlit/web/server/browser_websocket_handler.py (189 lines, new file)
@@ -0,0 +1,189 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import binascii
import json
from typing import Any, Awaitable, Dict, List, Optional, Union

import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.websocket import WebSocketHandler
from typing_extensions import Final

from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime import Runtime, SessionClient, SessionClientDisconnectedError
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import is_url_from_allowed_origins

_LOGGER: Final = get_logger(__name__)


class BrowserWebSocketHandler(WebSocketHandler, SessionClient):
    """Handles a WebSocket connection from the browser"""

    def initialize(self, runtime: Runtime) -> None:
        self._runtime = runtime
        self._session_id: Optional[str] = None
        # The XSRF cookie is normally set when xsrf_form_html is used, but in a
        # pure-Javascript application that does not use any regular forms we just
        # need to read the self.xsrf_token manually to set the cookie as a side
        # effect. See https://www.tornadoweb.org/en/stable/guide/security.html#cross-site-request-forgery-protection
        # for more details.
        if config.get_option("server.enableXsrfProtection"):
            _ = self.xsrf_token

    def check_origin(self, origin: str) -> bool:
        """Set up CORS."""
        return super().check_origin(origin) or is_url_from_allowed_origins(origin)

    def write_forward_msg(self, msg: ForwardMsg) -> None:
        """Send a ForwardMsg to the browser."""
        try:
            self.write_message(serialize_forward_msg(msg), binary=True)
        except tornado.websocket.WebSocketClosedError as e:
            raise SessionClientDisconnectedError from e

    def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
        """Return the first subprotocol in the given list.

        This method is used by Tornado to select a protocol when the
        Sec-WebSocket-Protocol header is set in an HTTP Upgrade request.

        NOTE: We repurpose the Sec-WebSocket-Protocol header here in a slightly
        unfortunate (but necessary) way. The browser WebSocket API doesn't allow us to
        set arbitrary HTTP headers, and this header is the only one where we have the
        ability to set it to arbitrary values, so we use it to pass tokens (in this
        case, the previous session ID to allow us to reconnect to it) from client to
        server as the *second* value in the list.

        The reason why the auth token is set as the second value is that, when
        Sec-WebSocket-Protocol is set, many clients expect the server to respond with a
        selected subprotocol to use. We don't want that reply to be the token, so we
        by convention have the client always set the first protocol to "streamlit" and
        select that.
        """
        if subprotocols:
            return subprotocols[0]

        return None

    def open(self, *args, **kwargs) -> Optional[Awaitable[None]]:
        # Extract user info from the X-Streamlit-User header
        is_public_cloud_app = False

        try:
            header_content = self.request.headers["X-Streamlit-User"]
            payload = base64.b64decode(header_content)
            user_obj = json.loads(payload)
            email = user_obj["email"]
            is_public_cloud_app = user_obj["isPublicCloudApp"]
        except (KeyError, binascii.Error, json.decoder.JSONDecodeError):
            email = "test@localhost.com"

        user_info: Dict[str, Optional[str]] = dict()
        if is_public_cloud_app:
            user_info["email"] = None
        else:
            user_info["email"] = email

        existing_session_id = None
        try:
            ws_protocols = [
                p.strip()
                for p in self.request.headers["Sec-Websocket-Protocol"].split(",")
            ]

            if len(ws_protocols) > 1:
                # See the NOTE in the docstring of the select_subprotocol method above
                # for a detailed explanation of why this is done.
                existing_session_id = ws_protocols[1]
        except KeyError:
            # Just let existing_session_id=None if we run into any error while trying to
            # extract it from the Sec-Websocket-Protocol header.
            pass

        self._session_id = self._runtime.connect_session(
            client=self,
            user_info=user_info,
            existing_session_id=existing_session_id,
        )
        return None

    def on_close(self) -> None:
        if not self._session_id:
            return
        self._runtime.disconnect_session(self._session_id)
        self._session_id = None

    def get_compression_options(self) -> Optional[Dict[Any, Any]]:
        """Enable WebSocket compression.

        Returning an empty dict enables websocket compression. Returning
        None disables it.

        (See the docstring in the parent class.)
        """
        if config.get_option("server.enableWebsocketCompression"):
            return {}
        return None

    def on_message(self, payload: Union[str, bytes]) -> None:
        if not self._session_id:
            return

        try:
            if isinstance(payload, str):
                # Sanity check. (The frontend should only be sending us bytes;
                # Protobuf.ParseFromString does not accept str input.)
                raise RuntimeError(
                    "WebSocket received an unexpected `str` message. "
                    "(We expect `bytes` only.)"
                )

            msg = BackMsg()
            msg.ParseFromString(payload)
            _LOGGER.debug("Received the following back message:\n%s", msg)

        except Exception as ex:
            _LOGGER.error(ex)
            self._runtime.handle_backmsg_deserialization_exception(self._session_id, ex)
            return

        # "debug_disconnect_websocket" and "debug_shutdown_runtime" are special
        # developmentMode-only messages used in e2e tests to test reconnect handling and
        # disabling widgets.
        if msg.WhichOneof("type") == "debug_disconnect_websocket":
            if config.get_option("global.developmentMode"):
                self.close()
            else:
                _LOGGER.warning(
                    "Client tried to disconnect websocket when not in development mode."
                )
        elif msg.WhichOneof("type") == "debug_shutdown_runtime":
            if config.get_option("global.developmentMode"):
                self._runtime.stop()
            else:
                _LOGGER.warning(
                    "Client tried to shut down runtime when not in development mode."
                )
        else:
            # AppSession handles all other BackMsg types.
            self._runtime.handle_backmsg(self._session_id, msg)
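For context, the Sec-WebSocket-Protocol trick described in select_subprotocol's docstring looks roughly like this from the client side. This is an illustrative sketch only (not part of the commit); it assumes the third-party websockets package, and the host and session id are placeholders:

import asyncio
import websockets  # third-party package, assumed available

async def reconnect(previous_session_id: str) -> None:
    # The first value must be "streamlit" (the server selects it); the second
    # is the token the server reads back as existing_session_id in open().
    async with websockets.connect(
        "ws://localhost:8501/_stcore/stream",
        subprotocols=["streamlit", previous_session_id],
    ) as ws:
        first_msg = await ws.recv()  # a serialized ForwardMsg (bytes)
        print(len(first_msg))

asyncio.run(reconnect("prev-session-id"))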
venv/lib/python3.9/site-packages/streamlit/web/server/component_request_handler.py (115 lines, new file)
@@ -0,0 +1,115 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import mimetypes
import os

import tornado.web

import streamlit.web.server.routes
from streamlit.components.v1.components import ComponentRegistry
from streamlit.logger import get_logger

_LOGGER = get_logger(__name__)


class ComponentRequestHandler(tornado.web.RequestHandler):
    def initialize(self, registry: ComponentRegistry):
        self._registry = registry

    def get(self, path: str) -> None:
        parts = path.split("/")
        component_name = parts[0]
        component_root = self._registry.get_component_path(component_name)
        if component_root is None:
            self.write("not found")
            self.set_status(404)
            return

        # follow symlinks to get an accurate normalized path
        component_root = os.path.realpath(component_root)
        filename = "/".join(parts[1:])
        abspath = os.path.realpath(os.path.join(component_root, filename))

        # Do NOT expose anything outside of the component root.
        if os.path.commonprefix([component_root, abspath]) != component_root or (
            not os.path.normpath(abspath).startswith(
                component_root
            )  # this is a recommendation from CodeQL, probably a bit redundant
        ):
            self.write("forbidden")
            self.set_status(403)
            return
        try:
            with open(abspath, "rb") as file:
                contents = file.read()
        except OSError as e:
            _LOGGER.error(
                "ComponentRequestHandler: GET %s read error", abspath, exc_info=e
            )
            self.write("read error")
            self.set_status(404)
            return

        self.write(contents)
        self.set_header("Content-Type", self.get_content_type(abspath))

        self.set_extra_headers(path)

    def set_extra_headers(self, path) -> None:
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        is_index_url = len(path) == 0

        if is_index_url or path.endswith(".html"):
            self.set_header("Cache-Control", "no-cache")
        else:
            self.set_header("Cache-Control", "public")

    def set_default_headers(self) -> None:
        if streamlit.web.server.routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self) -> None:
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    @staticmethod
    def get_content_type(abspath) -> str:
        """Returns the ``Content-Type`` header to be used for this request.
        From tornado.web.StaticFileHandler.
        """
        mime_type, encoding = mimetypes.guess_type(abspath)
        # per RFC 6713, use the appropriate type for a gzip compressed file
        if encoding == "gzip":
            return "application/gzip"
        # As of 2015-07-21 there is no bzip2 encoding defined at
        # http://www.iana.org/assignments/media-types/media-types.xhtml
        # So for that (and any other encoding), use octet-stream.
        elif encoding is not None:
            return "application/octet-stream"
        elif mime_type is not None:
            return mime_type
        # if mime_type not detected, use application/octet-stream
        else:
            return "application/octet-stream"

    @staticmethod
    def get_url(file_id: str) -> str:
        """Return the URL for a component file with the given ID."""
        return "components/{}".format(file_id)
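For illustration, a custom component's assets end up being fetched through this handler via the component/(.*) route registered in server.py further down. A hedged sketch using the requests package; the host, port, and component name are placeholders:

import requests

resp = requests.get(
    "http://localhost:8501/component/my_component.my_component/index.html",
    timeout=5,
)
print(resp.status_code, resp.headers["Content-Type"])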
venv/lib/python3.9/site-packages/streamlit/web/server/media_file_handler.py (146 lines, new file)
@@ -0,0 +1,146 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from urllib.parse import quote, unquote_plus

import tornado.web

from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorageError
from streamlit.runtime.memory_media_file_storage import (
    MemoryMediaFileStorage,
    get_extension_for_mimetype,
)
from streamlit.string_util import generate_download_filename_from_title
from streamlit.web.server import allow_cross_origin_requests

_LOGGER = get_logger(__name__)


class MediaFileHandler(tornado.web.StaticFileHandler):
    _storage: MemoryMediaFileStorage

    @classmethod
    def initialize_storage(cls, storage: MemoryMediaFileStorage) -> None:
        """Set the MemoryMediaFileStorage object used by instances of this
        handler. Must be called on server startup.
        """
        # This is a class method, rather than an instance method, because
        # `get_content()` is a class method and needs to access the storage
        # instance.
        cls._storage = storage

    def set_default_headers(self) -> None:
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def set_extra_headers(self, path: str) -> None:
        """Add Content-Disposition header for downloadable files.

        Set header value to "attachment" indicating that file should be saved
        locally instead of displaying inline in browser.

        We also set filename to specify the filename for downloaded files.
        Used for serving downloadable files, like files stored via the
        `st.download_button` widget.
        """
        media_file = self._storage.get_file(path)

        if media_file and media_file.kind == MediaFileKind.DOWNLOADABLE:
            filename = media_file.filename

            if not filename:
                title = self.get_argument("title", "", True)
                title = unquote_plus(title)
                filename = generate_download_filename_from_title(title)
                filename = (
                    f"{filename}{get_extension_for_mimetype(media_file.mimetype)}"
                )

            try:
                # Check that the value can be encoded in latin1. Latin1 is
                # the default encoding for headers.
                filename.encode("latin1")
                file_expr = 'filename="{}"'.format(filename)
            except UnicodeEncodeError:
                # RFC5987 syntax.
                # See: https://datatracker.ietf.org/doc/html/rfc5987
                file_expr = "filename*=utf-8''{}".format(quote(filename))

            self.set_header("Content-Disposition", f"attachment; {file_expr}")

    # Overriding StaticFileHandler to use the MediaFileManager
    #
    # From the Tornado docs:
    # To replace all interaction with the filesystem (e.g. to serve
    # static content from a database), override `get_content`,
    # `get_content_size`, `get_modified_time`, `get_absolute_path`, and
    # `validate_absolute_path`.
    def validate_absolute_path(self, root: str, absolute_path: str) -> str:
        try:
            self._storage.get_file(absolute_path)
        except MediaFileStorageError:
            _LOGGER.error("MediaFileHandler: Missing file %s", absolute_path)
            raise tornado.web.HTTPError(404, "not found")

        return absolute_path

    def get_content_size(self) -> int:
        abspath = self.absolute_path
        if abspath is None:
            return 0

        media_file = self._storage.get_file(abspath)
        return media_file.content_size

    def get_modified_time(self) -> None:
        # We do not track last modified time, but this can be improved to
        # allow caching among files in the MediaFileManager
        return None

    @classmethod
    def get_absolute_path(cls, root: str, path: str) -> str:
        # All files are stored in memory, so the absolute path is just the
        # path itself. In the MediaFileHandler, it's just the filename
        return path

    @classmethod
    def get_content(
        cls, abspath: str, start: Optional[int] = None, end: Optional[int] = None
    ):
        _LOGGER.debug("MediaFileHandler: GET %s", abspath)

        try:
            # abspath is the hash as used in `get_absolute_path`
            media_file = cls._storage.get_file(abspath)
        except Exception:
            _LOGGER.error("MediaFileHandler: Missing file %s", abspath)
            return None

        _LOGGER.debug(
            "MediaFileHandler: Sending %s file %s", media_file.mimetype, abspath
        )

        # If there is no start and end, just return the full content
        if start is None and end is None:
            return media_file.content

        if start is None:
            start = 0
        if end is None:
            end = len(media_file.content)

        # content is bytes, so range requests are served by slicing with start and end
        return media_file.content[start:end]
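The latin1-versus-RFC 5987 branching in set_extra_headers is the only subtle part here; the sketch below reproduces just that logic outside the handler (the sample filenames are arbitrary):

from urllib.parse import quote

def content_disposition(filename: str) -> str:
    # Mirrors the fallback in MediaFileHandler.set_extra_headers above.
    try:
        filename.encode("latin1")
        file_expr = 'filename="{}"'.format(filename)
    except UnicodeEncodeError:
        file_expr = "filename*=utf-8''{}".format(quote(filename))
    return f"attachment; {file_expr}"

print(content_disposition("report.csv"))  # attachment; filename="report.csv"
print(content_disposition("отчёт.csv"))   # attachment; filename*=utf-8''%D0%BE... (percent-encoded)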
venv/lib/python3.9/site-packages/streamlit/web/server/routes.py (259 lines, new file)
@@ -0,0 +1,259 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import tornado.web

from streamlit import config, file_util
from streamlit.logger import get_logger
from streamlit.runtime.runtime_util import serialize_forward_msg
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice

_LOGGER = get_logger(__name__)


def allow_cross_origin_requests():
    """True if cross-origin requests are allowed.

    We only allow cross-origin requests when CORS protection has been disabled
    with server.enableCORS=False or if using the Node server. When using the
    Node server, we have a dev and prod port, which count as two origins.

    """
    return not config.get_option("server.enableCORS") or config.get_option(
        "global.developmentMode"
    )


class StaticFileHandler(tornado.web.StaticFileHandler):
    def initialize(self, path, default_filename, get_pages):
        self._pages = get_pages()

        super().initialize(path=path, default_filename=default_filename)

    def set_extra_headers(self, path):
        """Disable cache for HTML files.

        Other assets like JS and CSS are suffixed with their hash, so they can
        be cached indefinitely.
        """
        is_index_url = len(path) == 0

        if is_index_url or path.endswith(".html"):
            self.set_header("Cache-Control", "no-cache")
        else:
            self.set_header("Cache-Control", "public")

    def parse_url_path(self, url_path: str) -> str:
        url_parts = url_path.split("/")

        maybe_page_name = url_parts[0]
        if maybe_page_name in self._pages:
            # If we're trying to navigate to a page, we return "index.html"
            # directly here instead of deferring to the superclass below after
            # modifying the url_path. The reason why is that tornado handles
            # requests to "directories" (which is what navigating to a page
            # looks like) by appending a trailing '/' if there is none and
            # redirecting.
            #
            # This would work, but it
            #   * adds an unnecessary redirect+roundtrip
            #   * adds a trailing '/' to the URL appearing in the browser, which
            #     looks bad
            if len(url_parts) == 1:
                return "index.html"

            url_path = "/".join(url_parts[1:])

        return super().parse_url_path(url_path)

    def write_error(self, status_code: int, **kwargs) -> None:
        if status_code == 404:
            index_file = os.path.join(file_util.get_static_dir(), "index.html")
            self.render(index_file)
        else:
            super().write_error(status_code, **kwargs)


class AssetsFileHandler(tornado.web.StaticFileHandler):
    # CORS protection should be disabled as we need access
    # to this endpoint from the inner iframe.
    def set_default_headers(self):
        self.set_header("Access-Control-Allow-Origin", "*")


class AddSlashHandler(tornado.web.RequestHandler):
    @tornado.web.addslash
    def get(self):
        pass


class _SpecialRequestHandler(tornado.web.RequestHandler):
    """Superclass for "special" endpoints, like /healthz."""

    def set_default_headers(self):
        self.set_header("Cache-Control", "no-cache")
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self):
        """/OPTIONS handler for preflight CORS checks.

        When a browser is making a CORS request, it may sometimes first
        send an OPTIONS request, to check whether the server understands the
        CORS protocol. This is optional, and doesn't happen for every request
        or in every browser. If an OPTIONS request does get sent, and is not
        then handled by the server, the browser will fail the underlying
        request.

        The proper way to handle this is to send a 204 response ("no content")
        with the CORS headers attached. (These headers are automatically added
        to every outgoing response, including OPTIONS responses,
        via set_default_headers().)

        See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
        """
        self.set_status(204)
        self.finish()


class HealthHandler(_SpecialRequestHandler):
    def initialize(self, callback):
        """Initialize the handler

        Parameters
        ----------
        callback : callable
            A function that returns True if the server is healthy

        """
        self._callback = callback

    async def get(self):
        if self.request.uri and "_stcore/" not in self.request.uri:
            new_path = (
                "/_stcore/script-health-check"
                if "script-health-check" in self.request.uri
                else "/_stcore/health"
            )
            emit_endpoint_deprecation_notice(self, new_path=new_path)

        ok, msg = await self._callback()
        if ok:
            self.write(msg)
            self.set_status(200)

            # Tornado will set the _xsrf cookie automatically for the page on
            # request for the document. However, if the server is reset and
            # server.enableXsrfProtection is updated, the browser does not reload the document.
            # Manually setting the cookie on /healthz since it is pinged when the
            # browser is disconnected from the server.
            if config.get_option("server.enableXsrfProtection"):
                self.set_cookie("_xsrf", self.xsrf_token)

        else:
            # 503 = SERVICE_UNAVAILABLE
            self.set_status(503)
            self.write(msg)


# NOTE: We eventually want to get rid of this hard-coded list entirely as we don't want
# to have links to Community Cloud live in the open source library in a way that affects
# functionality (links advertising Community Cloud are probably okay 🙂). In the long
# run, this list will most likely be replaced by a config option allowing us to more
# granularly control what domains a Streamlit app should accept cross-origin iframe
# messages from.
ALLOWED_MESSAGE_ORIGINS = [
    "https://devel.streamlit.test",
    "https://*.streamlit.apptest",
    "https://*.streamlitapp.test",
    "https://*.streamlitapp.com",
    "https://share.streamlit.io",
    "https://share-demo.streamlit.io",
    "https://share-head.streamlit.io",
    "https://share-staging.streamlit.io",
    "https://*.demo.streamlit.run",
    "https://*.head.streamlit.run",
    "https://*.staging.streamlit.run",
    "https://*.streamlit.run",
    "https://*.demo.streamlit.app",
    "https://*.head.streamlit.app",
    "https://*.staging.streamlit.app",
    "https://*.streamlit.app",
]


class AllowedMessageOriginsHandler(_SpecialRequestHandler):
    async def get(self) -> None:
        # ALLOWED_MESSAGE_ORIGINS must be wrapped in a dictionary because Tornado
        # disallows writing lists directly into responses due to potential XSS
        # vulnerabilities.
        # See https://www.tornadoweb.org/en/stable/web.html#tornado.web.RequestHandler.write
        self.write(
            {
                "allowedOrigins": ALLOWED_MESSAGE_ORIGINS,
                "useExternalAuthToken": False,
            }
        )
        self.set_status(200)


class MessageCacheHandler(tornado.web.RequestHandler):
    """Returns ForwardMsgs from our MessageCache"""

    def initialize(self, cache):
        """Initializes the handler.

        Parameters
        ----------
        cache : MessageCache

        """
        self._cache = cache

    def set_default_headers(self):
        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def get(self):
        msg_hash = self.get_argument("hash", None)
        if msg_hash is None:
            # Hash is missing! This is a malformed request.
            _LOGGER.error(
                "HTTP request for cached message is missing the hash attribute."
            )
            self.set_status(404)
            raise tornado.web.Finish()

        message = self._cache.get_message(msg_hash)
        if message is None:
            # Message not in our cache.
            _LOGGER.error(
                "HTTP request for cached message could not be fulfilled. "
                "No such message"
            )
            self.set_status(404)
            raise tornado.web.Finish()

        _LOGGER.debug("MessageCache HIT")
        msg_str = serialize_forward_msg(message)
        self.set_header("Content-Type", "application/octet-stream")
        self.write(msg_str)
        self.set_status(200)

    def options(self):
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()
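As a usage illustration, MessageCacheHandler above serves cached ForwardMsgs by hash. A hedged client-side sketch with the requests package; host, port, and the hash value are placeholders:

import requests

resp = requests.get(
    "http://localhost:8501/_stcore/message",
    params={"hash": "<forward-msg-hash>"},
    timeout=5,
)
if resp.status_code == 200:
    # The body is a serialized ForwardMsg (application/octet-stream).
    print(len(resp.content), resp.headers["Content-Type"])
else:
    print("message not cached:", resp.status_code)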
venv/lib/python3.9/site-packages/streamlit/web/server/server.py (410 lines, new file)
@@ -0,0 +1,410 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import errno
import logging
import os
import socket
import ssl
import sys
from pathlib import Path
from typing import Any, Awaitable, List, Optional, Union

import click
import tornado.concurrent
import tornado.locks
import tornado.netutil
import tornado.web
import tornado.websocket
from tornado.httpserver import HTTPServer
from typing_extensions import Final

from streamlit import config, file_util, source_util, util
from streamlit.components.v1.components import ComponentRegistry
from streamlit.config_option import ConfigOption
from streamlit.logger import get_logger
from streamlit.runtime import Runtime, RuntimeConfig, RuntimeState
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage
from streamlit.runtime.runtime_util import get_max_message_size_bytes
from streamlit.web.cache_storage_manager_config import (
    create_default_cache_storage_manager,
)
from streamlit.web.server.app_static_file_handler import AppStaticFileHandler
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler
from streamlit.web.server.component_request_handler import ComponentRequestHandler
from streamlit.web.server.media_file_handler import MediaFileHandler
from streamlit.web.server.routes import (
    AddSlashHandler,
    AllowedMessageOriginsHandler,
    AssetsFileHandler,
    HealthHandler,
    MessageCacheHandler,
    StaticFileHandler,
)
from streamlit.web.server.server_util import make_url_path_regex
from streamlit.web.server.stats_request_handler import StatsRequestHandler
from streamlit.web.server.upload_file_request_handler import (
    UPLOAD_FILE_ROUTE,
    UploadFileRequestHandler,
)

LOGGER = get_logger(__name__)

TORNADO_SETTINGS = {
    # Gzip HTTP responses.
    "compress_response": True,
    # Ping every 1s to keep WS alive.
    # 2021.06.22: this value was previously 20s, and was causing
    # connection instability for a small number of users. This smaller
    # ping_interval fixes that instability.
    # https://github.com/streamlit/streamlit/issues/3196
    "websocket_ping_interval": 1,
    # If we don't get a ping response within 30s, the connection
    # is timed out.
    "websocket_ping_timeout": 30,
}

# When server.port is not available it will look for the next available port
# up to MAX_PORT_SEARCH_RETRIES.
MAX_PORT_SEARCH_RETRIES = 100

# When server.address starts with this prefix, the server will bind
# to a unix socket.
UNIX_SOCKET_PREFIX = "unix://"

MEDIA_ENDPOINT: Final = "/media"
STREAM_ENDPOINT: Final = r"_stcore/stream"
METRIC_ENDPOINT: Final = r"(?:st-metrics|_stcore/metrics)"
MESSAGE_ENDPOINT: Final = r"_stcore/message"
HEALTH_ENDPOINT: Final = r"(?:healthz|_stcore/health)"
ALLOWED_MESSAGE_ORIGIN_ENDPOINT: Final = r"_stcore/allowed-message-origins"
SCRIPT_HEALTH_CHECK_ENDPOINT: Final = (
    r"(?:script-health-check|_stcore/script-health-check)"
)


class RetriesExceeded(Exception):
    pass


def server_port_is_manually_set() -> bool:
    return config.is_manually_set("server.port")


def server_address_is_unix_socket() -> bool:
    address = config.get_option("server.address")
    return address is not None and address.startswith(UNIX_SOCKET_PREFIX)


def start_listening(app: tornado.web.Application) -> None:
    """Makes the server start listening at the configured port.

    In case the port is already taken it tries listening to the next available
    port. It will error after MAX_PORT_SEARCH_RETRIES attempts.

    """
    cert_file = config.get_option("server.sslCertFile")
    key_file = config.get_option("server.sslKeyFile")
    ssl_options = _get_ssl_options(cert_file, key_file)

    http_server = HTTPServer(
        app,
        max_buffer_size=config.get_option("server.maxUploadSize") * 1024 * 1024,
        ssl_options=ssl_options,
    )

    if server_address_is_unix_socket():
        start_listening_unix_socket(http_server)
    else:
        start_listening_tcp_socket(http_server)


def _get_ssl_options(
    cert_file: Optional[str], key_file: Optional[str]
) -> Union[ssl.SSLContext, None]:
    if bool(cert_file) != bool(key_file):
        LOGGER.error(
            "Options 'server.sslCertFile' and 'server.sslKeyFile' must "
            "be set together. Set missing options or delete existing options."
        )
        sys.exit(1)
    if cert_file and key_file:
        # ssl_ctx.load_cert_chain raises an exception as below, but it is not
        # sufficiently user-friendly:
        # FileNotFoundError: [Errno 2] No such file or directory
        if not Path(cert_file).exists():
            LOGGER.error("Cert file '%s' does not exist.", cert_file)
            sys.exit(1)
        if not Path(key_file).exists():
            LOGGER.error("Key file '%s' does not exist.", key_file)
            sys.exit(1)

        ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        # When the SSL certificate fails to load, an exception is raised as below,
        # but it is not sufficiently user-friendly.
        # ssl.SSLError: [SSL] PEM lib (_ssl.c:4067)
        try:
            ssl_ctx.load_cert_chain(cert_file, key_file)
        except ssl.SSLError:
            LOGGER.error(
                "Failed to load SSL certificate. Make sure "
                "cert file '%s' and key file '%s' are correct.",
                cert_file,
                key_file,
            )
            sys.exit(1)

        return ssl_ctx
    return None


def start_listening_unix_socket(http_server: HTTPServer) -> None:
    address = config.get_option("server.address")
    file_name = os.path.expanduser(address[len(UNIX_SOCKET_PREFIX) :])

    unix_socket = tornado.netutil.bind_unix_socket(file_name)
    http_server.add_socket(unix_socket)


def start_listening_tcp_socket(http_server: HTTPServer) -> None:
    call_count = 0

    port = None
    while call_count < MAX_PORT_SEARCH_RETRIES:
        address = config.get_option("server.address")
        port = config.get_option("server.port")

        try:
            http_server.listen(port, address)
            break  # It worked! So let's break out of the loop.

        except (OSError, socket.error) as e:
            if e.errno == errno.EADDRINUSE:
                if server_port_is_manually_set():
                    LOGGER.error("Port %s is already in use", port)
                    sys.exit(1)
                else:
                    LOGGER.debug(
                        "Port %s already in use, trying to use the next one.", port
                    )
                    port += 1
                    # Save port 3000 because it is used for the development
                    # server in the front end.
                    if port == 3000:
                        port += 1

                    config.set_option(
                        "server.port", port, ConfigOption.STREAMLIT_DEFINITION
                    )
                    call_count += 1
            else:
                raise

    if call_count >= MAX_PORT_SEARCH_RETRIES:
        raise RetriesExceeded(
            f"Cannot start Streamlit server. Port {port} is already in use, and "
            f"Streamlit was unable to find a free port after {MAX_PORT_SEARCH_RETRIES} attempts.",
        )


class Server:
    def __init__(self, main_script_path: str, command_line: Optional[str]):
        """Create the server. It won't be started yet."""
        _set_tornado_log_levels()

        self._main_script_path = main_script_path

        # Initialize MediaFileStorage and its associated endpoint
        media_file_storage = MemoryMediaFileStorage(MEDIA_ENDPOINT)
        MediaFileHandler.initialize_storage(media_file_storage)

        self._runtime = Runtime(
            RuntimeConfig(
                script_path=main_script_path,
                command_line=command_line,
                media_file_storage=media_file_storage,
                cache_storage_manager=create_default_cache_storage_manager(),
            ),
        )

        self._runtime.stats_mgr.register_provider(media_file_storage)

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def main_script_path(self) -> str:
        return self._main_script_path

    async def start(self) -> None:
        """Start the server.

        When this returns, Streamlit is ready to accept new sessions.
        """

        LOGGER.debug("Starting server...")

        app = self._create_app()
        start_listening(app)

        port = config.get_option("server.port")
        LOGGER.debug("Server started on port %s", port)

        await self._runtime.start()

    @property
    def stopped(self) -> Awaitable[None]:
        """A Future that completes when the Server's run loop has exited."""
        return self._runtime.stopped

    def _create_app(self) -> tornado.web.Application:
        """Create our tornado web app."""
        base = config.get_option("server.baseUrlPath")

        routes: List[Any] = [
            (
                make_url_path_regex(base, STREAM_ENDPOINT),
                BrowserWebSocketHandler,
                dict(runtime=self._runtime),
            ),
            (
                make_url_path_regex(base, HEALTH_ENDPOINT),
                HealthHandler,
                dict(callback=lambda: self._runtime.is_ready_for_browser_connection),
            ),
            (
                make_url_path_regex(base, MESSAGE_ENDPOINT),
                MessageCacheHandler,
                dict(cache=self._runtime.message_cache),
            ),
            (
                make_url_path_regex(base, METRIC_ENDPOINT),
                StatsRequestHandler,
                dict(stats_manager=self._runtime.stats_mgr),
            ),
            (
                make_url_path_regex(base, ALLOWED_MESSAGE_ORIGIN_ENDPOINT),
                AllowedMessageOriginsHandler,
            ),
            (
                make_url_path_regex(
                    base,
                    UPLOAD_FILE_ROUTE,
                ),
                UploadFileRequestHandler,
                dict(
                    file_mgr=self._runtime.uploaded_file_mgr,
                    is_active_session=self._runtime.is_active_session,
                ),
            ),
            (
                make_url_path_regex(base, "assets/(.*)"),
                AssetsFileHandler,
                {"path": "%s/" % file_util.get_assets_dir()},
            ),
            (
                make_url_path_regex(base, f"{MEDIA_ENDPOINT}/(.*)"),
                MediaFileHandler,
                {"path": ""},
            ),
            (
                make_url_path_regex(base, "component/(.*)"),
                ComponentRequestHandler,
                dict(registry=ComponentRegistry.instance()),
            ),
        ]

        if config.get_option("server.scriptHealthCheckEnabled"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, SCRIPT_HEALTH_CHECK_ENDPOINT),
                        HealthHandler,
                        dict(
                            callback=lambda: self._runtime.does_script_run_without_error()
                        ),
                    )
                ]
            )

        if config.get_option("server.enableStaticServing"):
            routes.extend(
                [
                    (
                        make_url_path_regex(base, "app/static/(.*)"),
                        AppStaticFileHandler,
                        {"path": file_util.get_app_static_dir(self.main_script_path)},
                    ),
                ]
            )

        if config.get_option("global.developmentMode"):
            LOGGER.debug("Serving static content from the Node dev server")
        else:
            static_path = file_util.get_static_dir()
            LOGGER.debug("Serving static content from %s", static_path)

            routes.extend(
                [
                    (
                        make_url_path_regex(base, "(.*)"),
                        StaticFileHandler,
                        {
                            "path": "%s/" % static_path,
                            "default_filename": "index.html",
                            "get_pages": lambda: set(
                                [
                                    page_info["page_name"]
                                    for page_info in source_util.get_pages(
                                        self.main_script_path
                                    ).values()
                                ]
                            ),
                        },
                    ),
                    (make_url_path_regex(base, trailing_slash=False), AddSlashHandler),
                ]
            )

        return tornado.web.Application(
            routes,
            cookie_secret=config.get_option("server.cookieSecret"),
            xsrf_cookies=config.get_option("server.enableXsrfProtection"),
            # Set the websocket message size. The default value is too low.
            websocket_max_message_size=get_max_message_size_bytes(),
            **TORNADO_SETTINGS,  # type: ignore[arg-type]
        )

    @property
    def browser_is_connected(self) -> bool:
        return self._runtime.state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED

    @property
    def is_running_hello(self) -> bool:
        from streamlit.hello import Hello

        return self._main_script_path == Hello.__file__

    def stop(self) -> None:
        click.secho("  Stopping...", fg="blue")
        self._runtime.stop()


def _set_tornado_log_levels() -> None:
    if not config.get_option("global.developmentMode"):
        # Hide logs unless they're super important.
        # Example of stuff we don't care about: 404 about .js.map files.
        logging.getLogger("tornado.access").setLevel(logging.ERROR)
        logging.getLogger("tornado.application").setLevel(logging.ERROR)
        logging.getLogger("tornado.general").setLevel(logging.ERROR)
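The Server class above is normally driven by streamlit run, but the constructor and the start()/stopped API shown in this file suggest the following minimal programmatic startup. This is only a sketch under that assumption; the script path is a placeholder, and real startup performs additional bootstrapping (config loading, signal handling) elsewhere:

import asyncio

from streamlit.web.server.server import Server

async def main() -> None:
    # "app.py" is a placeholder for the main script path.
    server = Server("app.py", command_line=None)
    await server.start()   # returns once Streamlit can accept sessions
    await server.stopped   # resolves when the runtime's run loop exits

asyncio.run(main())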
venv/lib/python3.9/site-packages/streamlit/web/server/server_util.py (126 lines, new file)
@@ -0,0 +1,126 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Server related utility functions"""

from typing import Optional
from urllib.parse import urljoin

import tornado.web

from streamlit import config, net_util, url_util


def is_url_from_allowed_origins(url: str) -> bool:
    """Return True if URL is from allowed origins (for CORS purpose).

    Allowed origins:
    1. localhost
    2. The internal and external IP addresses of the machine where this
       function was called from.

    If `server.enableCORS` is False, this allows all origins.
    """
    if not config.get_option("server.enableCORS"):
        # Allow everything when CORS is disabled.
        return True

    hostname = url_util.get_hostname(url)

    allowed_domains = [  # List[Union[str, Callable[[], Optional[str]]]]
        # Check localhost first.
        "localhost",
        "0.0.0.0",
        "127.0.0.1",
        # Try to avoid making unnecessary HTTP requests by checking if the user
        # manually specified a server address.
        _get_server_address_if_manually_set,
        # Then try the options that depend on HTTP requests or opening sockets.
        net_util.get_internal_ip,
        net_util.get_external_ip,
    ]

    for allowed_domain in allowed_domains:
        if callable(allowed_domain):
            allowed_domain = allowed_domain()

        if allowed_domain is None:
            continue

        if hostname == allowed_domain:
            return True

    return False


def _get_server_address_if_manually_set() -> Optional[str]:
    if config.is_manually_set("browser.serverAddress"):
        return url_util.get_hostname(config.get_option("browser.serverAddress"))
    return None


def make_url_path_regex(*path, **kwargs) -> str:
    """Get a regex of the form ^/foo/bar/baz/?$ for a path (foo, bar, baz)."""
    path = [x.strip("/") for x in path if x]  # Filter out falsy components.
    path_format = r"^/%s/?$" if kwargs.get("trailing_slash", True) else r"^/%s$"
    return path_format % "/".join(path)


def get_url(host_ip: str) -> str:
    """Get the URL for any app served at the given host_ip.

    Parameters
    ----------
    host_ip : str
        The IP address of the machine that is running the Streamlit Server.

    Returns
    -------
    str
        The URL.
    """
    protocol = "https" if config.get_option("server.sslCertFile") else "http"

    port = _get_browser_address_bar_port()
    base_path = config.get_option("server.baseUrlPath").strip("/")

    if base_path:
        base_path = "/" + base_path

    host_ip = host_ip.strip("/")
    return f"{protocol}://{host_ip}:{port}{base_path}"


def _get_browser_address_bar_port() -> int:
    """Get the app URL that will be shown in the browser's address bar.

    That is, this is the port where static assets will be served from. In dev,
    this is different from the URL that will be used to connect to the
    server-browser websocket.

    """
    if config.get_option("global.developmentMode"):
        return 3000
    return int(config.get_option("browser.serverPort"))


def emit_endpoint_deprecation_notice(
    handler: tornado.web.RequestHandler, new_path: str
) -> None:
    """
    Emits the warning about deprecation of an HTTP endpoint in the HTTP headers.
    """
    handler.set_header("Deprecation", True)
    new_url = urljoin(f"{handler.request.protocol}://{handler.request.host}", new_path)
    handler.set_header("Link", f'<{new_url}>; rel="alternate"')
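A quick illustration of make_url_path_regex, consistent with its docstring above ("myapp" stands in for a server.baseUrlPath value):

from streamlit.web.server.server_util import make_url_path_regex

print(make_url_path_regex("myapp", "_stcore/stream"))
# ^/myapp/_stcore/stream/?$
print(make_url_path_regex("myapp", trailing_slash=False))
# ^/myapp$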
venv/lib/python3.9/site-packages/streamlit/web/server/stats_request_handler.py (88 lines, new file)
@@ -0,0 +1,88 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import tornado.web

from streamlit.proto.openmetrics_data_model_pb2 import GAUGE
from streamlit.proto.openmetrics_data_model_pb2 import MetricSet as MetricSetProto
from streamlit.runtime.stats import CacheStat, StatsManager
from streamlit.web.server.server_util import emit_endpoint_deprecation_notice


class StatsRequestHandler(tornado.web.RequestHandler):
    def initialize(self, stats_manager: StatsManager) -> None:
        self._manager = stats_manager

    def set_default_headers(self):
        # Avoid a circular import
        from streamlit.web.server import allow_cross_origin_requests

        if allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self):
        """/OPTIONS handler for preflight CORS checks."""
        self.set_status(204)
        self.finish()

    def get(self) -> None:
        if self.request.uri and "_stcore/" not in self.request.uri:
            emit_endpoint_deprecation_notice(self, new_path="/_stcore/metrics")

        stats = self._manager.get_stats()

        # If the request asked for protobuf output, we return a serialized
        # protobuf. Else we return text.
        if "application/x-protobuf" in self.request.headers.get_list("Accept"):
            self.write(self._stats_to_proto(stats).SerializeToString())
            self.set_header("Content-Type", "application/x-protobuf")
            self.set_status(200)
        else:
            self.write(self._stats_to_text(self._manager.get_stats()))
            self.set_header("Content-Type", "application/openmetrics-text")
            self.set_status(200)

    @staticmethod
    def _stats_to_text(stats: List[CacheStat]) -> str:
        metric_type = "# TYPE cache_memory_bytes gauge"
        metric_unit = "# UNIT cache_memory_bytes bytes"
        metric_help = "# HELP Total memory consumed by a cache."
        openmetrics_eof = "# EOF\n"

        # Format: header, stats, EOF
        result = [metric_type, metric_unit, metric_help]
        result.extend(stat.to_metric_str() for stat in stats)
        result.append(openmetrics_eof)

        return "\n".join(result)

    @staticmethod
    def _stats_to_proto(stats: List[CacheStat]) -> MetricSetProto:
        metric_set = MetricSetProto()

        metric_family = metric_set.metric_families.add()
        metric_family.name = "cache_memory_bytes"
        metric_family.type = GAUGE
        metric_family.unit = "bytes"
        metric_family.help = "Total memory consumed by a cache."

        for stat in stats:
            metric_proto = metric_family.metrics.add()
            stat.marshall_metric_proto(metric_proto)

        return metric_set
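For reference, the content negotiation in StatsRequestHandler.get can be exercised like this (a sketch with the requests package; host and port are assumed defaults):

import requests

text = requests.get("http://localhost:8501/_stcore/metrics", timeout=5)
print(text.headers["Content-Type"])   # application/openmetrics-text

proto = requests.get(
    "http://localhost:8501/_stcore/metrics",
    headers={"Accept": "application/x-protobuf"},
    timeout=5,
)
print(proto.headers["Content-Type"])  # application/x-protobuf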
venv/lib/python3.9/site-packages/streamlit/web/server/upload_file_request_handler.py (157 lines, new file)
@@ -0,0 +1,157 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, List

import tornado.httputil
import tornado.web

from streamlit import config
from streamlit.logger import get_logger
from streamlit.runtime.uploaded_file_manager import UploadedFileManager, UploadedFileRec
from streamlit.web.server import routes, server_util

# /_stcore/upload_file/(optional session id)/(optional widget id)
UPLOAD_FILE_ROUTE = (
    r"/_stcore/upload_file/?(?P<session_id>[^/]*)?/?(?P<widget_id>[^/]*)?"
)
LOGGER = get_logger(__name__)


class UploadFileRequestHandler(tornado.web.RequestHandler):
    """Implements the POST /upload_file endpoint."""

    def initialize(
        self, file_mgr: UploadedFileManager, is_active_session: Callable[[str], bool]
    ):
        """
        Parameters
        ----------
        file_mgr : UploadedFileManager
            The server's singleton UploadedFileManager. All file uploads
            go here.
        is_active_session:
            A function that returns true if a session_id belongs to an active
            session.
        """
        self._file_mgr = file_mgr
        self._is_active_session = is_active_session

    def set_default_headers(self):
        self.set_header("Access-Control-Allow-Methods", "POST, OPTIONS")
        self.set_header("Access-Control-Allow-Headers", "Content-Type")
        if config.get_option("server.enableXsrfProtection"):
            self.set_header(
                "Access-Control-Allow-Origin",
                server_util.get_url(config.get_option("browser.serverAddress")),
            )
            self.set_header("Access-Control-Allow-Headers", "X-Xsrftoken, Content-Type")
            self.set_header("Vary", "Origin")
            self.set_header("Access-Control-Allow-Credentials", "true")
        elif routes.allow_cross_origin_requests():
            self.set_header("Access-Control-Allow-Origin", "*")

    def options(self, **kwargs):
        """/OPTIONS handler for preflight CORS checks.

        When a browser is making a CORS request, it may sometimes first
        send an OPTIONS request, to check whether the server understands the
        CORS protocol. This is optional, and doesn't happen for every request
        or in every browser. If an OPTIONS request does get sent, and is not
        then handled by the server, the browser will fail the underlying
        request.

        The proper way to handle this is to send a 204 response ("no content")
        with the CORS headers attached. (These headers are automatically added
        to every outgoing response, including OPTIONS responses,
        via set_default_headers().)

        See https://developer.mozilla.org/en-US/docs/Glossary/Preflight_request
        """
        self.set_status(204)
        self.finish()

    @staticmethod
    def _require_arg(args: Dict[str, List[bytes]], name: str) -> str:
        """Return the value of the argument with the given name.

        A human-readable exception will be raised if the argument doesn't
        exist. This will be used as the body for the error response returned
        from the request.
        """
        try:
            arg = args[name]
        except KeyError:
            raise Exception(f"Missing '{name}'")

        if len(arg) != 1:
            raise Exception(f"Expected 1 '{name}' arg, but got {len(arg)}")

        # Convert bytes to string
        return arg[0].decode("utf-8")

    def post(self, **kwargs):
        """Receive an uploaded file and add it to our UploadedFileManager.
        Return the file's ID, so that the client can refer to it.
        """
        args: Dict[str, List[bytes]] = {}
        files: Dict[str, List[Any]] = {}

        tornado.httputil.parse_body_arguments(
            content_type=self.request.headers["Content-Type"],
            body=self.request.body,
            arguments=args,
            files=files,
        )

        try:
            session_id = self._require_arg(args, "sessionId")
            widget_id = self._require_arg(args, "widgetId")
            if not self._is_active_session(session_id):
                raise Exception(f"Invalid session_id: '{session_id}'")

        except Exception as e:
            self.send_error(400, reason=str(e))
            return

        # Create an UploadedFile object for each file.
        # We assign an initial, invalid file_id to each file in this loop.
        # The file_mgr will assign unique file IDs and return them in `add_file`,
        # below.
        uploaded_files: List[UploadedFileRec] = []
        for _, flist in files.items():
            for file in flist:
                uploaded_files.append(
                    UploadedFileRec(
                        id=0,
                        name=file["filename"],
                        type=file["content_type"],
                        data=file["body"],
                    )
                )

        if len(uploaded_files) != 1:
            self.send_error(
                400, reason=f"Expected 1 file, but got {len(uploaded_files)}"
            )
            return

        added_file = self._file_mgr.add_file(
            session_id=session_id, widget_id=widget_id, file=uploaded_files[0]
        )

        # Return the file_id to the client. (The client will parse
        # the string back to an int.)
        self.write(str(added_file.id))
        self.set_status(200)
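A hedged sketch of the multipart request this handler expects, using the requests package. The session and widget ids are placeholders and must belong to an active session; the sketch also assumes XSRF protection is disabled, since no X-Xsrftoken header is sent:

import requests

resp = requests.post(
    "http://localhost:8501/_stcore/upload_file",
    data={"sessionId": "<session-id>", "widgetId": "<widget-id>"},
    files={"file": ("data.csv", b"a,b\n1,2\n", "text/csv")},
    timeout=5,
)
print(resp.status_code, resp.text)  # 200 and the new file id on success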
venv/lib/python3.9/site-packages/streamlit/web/server/websocket_headers.py (47 lines, new file)
@@ -0,0 +1,47 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional

from streamlit import runtime
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner import get_script_run_ctx
from streamlit.web.server.browser_websocket_handler import BrowserWebSocketHandler


@gather_metrics("_get_websocket_headers")
def _get_websocket_headers() -> Optional[Dict[str, str]]:
    """Return a copy of the HTTP request headers for the current session's
    WebSocket connection. If there's no active session, return None instead.

    Raise an error if the server is not running.

    Note to the intrepid: this is an UNSUPPORTED, INTERNAL API. (We don't have plans
    to remove it without a replacement, but we don't consider this a production-ready
    function, and its signature may change without a deprecation warning.)
    """
    ctx = get_script_run_ctx()
    if ctx is None:
        return None

    session_client = runtime.get_instance().get_client(ctx.session_id)
    if session_client is None:
        return None

    if not isinstance(session_client, BrowserWebSocketHandler):
        raise RuntimeError(
            f"SessionClient is not a BrowserWebSocketHandler! ({session_client})"
        )

    return dict(session_client.request.headers)
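Finally, a short usage sketch of this internal helper from inside a running Streamlit script (assuming the module is importable as streamlit.web.server.websocket_headers, per the file path above):

import streamlit as st
from streamlit.web.server.websocket_headers import _get_websocket_headers

headers = _get_websocket_headers()
if headers is not None:
    st.write(headers.get("User-Agent"))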