Merging PR_218 openai_rev package with new streamlit chat app

This commit is contained in:
noptuno
2023-04-27 20:29:30 -04:00
parent 479b8d6d10
commit 355dee533b
8378 changed files with 2931636 additions and 3 deletions


@@ -0,0 +1,40 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Explicitly re-export public symbols from runtime.py and session_manager.py
from streamlit.runtime.runtime import Runtime as Runtime
from streamlit.runtime.runtime import RuntimeConfig as RuntimeConfig
from streamlit.runtime.runtime import RuntimeState as RuntimeState
from streamlit.runtime.session_manager import SessionClient as SessionClient
from streamlit.runtime.session_manager import (
SessionClientDisconnectedError as SessionClientDisconnectedError,
)
def get_instance() -> Runtime:
"""Return the singleton Runtime instance. Raise an Error if the
Runtime hasn't been created yet.
"""
return Runtime.instance()
def exists() -> bool:
"""True if the singleton Runtime instance has been created.
When a Streamlit app is running in "raw mode" - that is, when the
app is run via `python app.py` instead of `streamlit run app.py` -
the Runtime will not exist, and various Streamlit functions need
to adapt.
"""
return Runtime.exists()
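
The module above gives library code a safe way to reach the Runtime singleton. A minimal usage sketch, assuming the guarded pattern used elsewhere in this commit (the `clear_session_media` helper name is illustrative, not part of the diff):

# Guard access to the Runtime so the same code also works in "raw mode"
# (`python app.py`), where no Runtime singleton exists.
from streamlit import runtime

def clear_session_media(session_id: str) -> None:
    if runtime.exists():
        # Only dereference the singleton after the existence check.
        runtime.get_instance().media_file_mgr.clear_session_refs(session_id)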


@@ -0,0 +1,806 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import sys
import uuid
from enum import Enum
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
import streamlit.elements.exception as exception_utils
from streamlit import config, runtime, source_util
from streamlit.case_converters import to_snake_case
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ClientState_pb2 import ClientState
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.GitInfo_pb2 import GitInfo
from streamlit.proto.NewSession_pb2 import (
Config,
CustomThemeConfig,
NewSession,
UserInfo,
)
from streamlit.proto.PagesChanged_pb2 import PagesChanged
from streamlit.runtime import caching, legacy_caching
from streamlit.runtime.credentials import Credentials
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
from streamlit.runtime.metrics_util import Installation
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.scriptrunner import RerunData, ScriptRunner, ScriptRunnerEvent
from streamlit.runtime.secrets import secrets_singleton
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.version import STREAMLIT_VERSION_STRING
from streamlit.watcher import LocalSourcesWatcher
LOGGER = get_logger(__name__)
if TYPE_CHECKING:
from streamlit.runtime.state import SessionState
class AppSessionState(Enum):
APP_NOT_RUNNING = "APP_NOT_RUNNING"
APP_IS_RUNNING = "APP_IS_RUNNING"
SHUTDOWN_REQUESTED = "SHUTDOWN_REQUESTED"
def _generate_scriptrun_id() -> str:
"""Randomly generate a unique ID for a script execution."""
return str(uuid.uuid4())
class AppSession:
"""
Contains session data for a single "user" of an active app
(that is, a connected browser tab).
Each AppSession has its own ScriptData, root DeltaGenerator, ScriptRunner,
and widget state.
An AppSession is attached to each thread involved in running its script.
"""
def __init__(
self,
script_data: ScriptData,
uploaded_file_manager: UploadedFileManager,
message_enqueued_callback: Optional[Callable[[], None]],
local_sources_watcher: LocalSourcesWatcher,
user_info: Dict[str, Optional[str]],
) -> None:
"""Initialize the AppSession.
Parameters
----------
script_data : ScriptData
Object storing parameters related to running a script
uploaded_file_manager : UploadedFileManager
Used to manage files uploaded by users via the Streamlit web client.
message_enqueued_callback : Callable[[], None]
A callback invoked after a message is enqueued.
local_sources_watcher: LocalSourcesWatcher
The file watcher that lets the session know local files have changed.
user_info: Dict
A dict that contains information about the current user. For now,
it only contains the user's email address.
{
"email": "example@example.com"
}
Information about the current user is optionally provided when a
websocket connection is initialized via the "X-Streamlit-User" header.
"""
# Each AppSession has a unique string ID.
self.id = str(uuid.uuid4())
self._event_loop = asyncio.get_running_loop()
self._script_data = script_data
self._uploaded_file_mgr = uploaded_file_manager
# The browser queue contains messages that haven't yet been
# delivered to the browser. Periodically, the server flushes
# this queue and delivers its contents to the browser.
self._browser_queue = ForwardMsgQueue()
self._message_enqueued_callback = message_enqueued_callback
self._state = AppSessionState.APP_NOT_RUNNING
# Need to remember the client state here because when a script reruns
# due to the source code changing we need to pass in the previous client state.
self._client_state = ClientState()
self._local_sources_watcher: Optional[
LocalSourcesWatcher
] = local_sources_watcher
self._stop_config_listener: Optional[Callable[[], bool]] = None
self._stop_pages_listener: Optional[Callable[[], bool]] = None
self.register_file_watchers()
self._run_on_save = config.get_option("server.runOnSave")
self._scriptrunner: Optional[ScriptRunner] = None
# This needs to be lazily imported to avoid a dependency cycle.
from streamlit.runtime.state import SessionState
self._session_state = SessionState()
self._user_info = user_info
self._debug_last_backmsg_id: Optional[str] = None
LOGGER.debug("AppSession initialized (id=%s)", self.id)
def __del__(self) -> None:
"""Ensure that we call shutdown() when an AppSession is garbage collected."""
self.shutdown()
def register_file_watchers(self) -> None:
"""Register handlers to be called when various files are changed.
Files that we watch include:
* source files that already exist (for edits)
* `.py` files in the main script's `pages/` directory (for file additions
and deletions)
* project and user-level config.toml files
* the project-level secrets.toml files
This method is called automatically on AppSession construction, but it may be
called again when a session is disconnected and is being reconnected to.
"""
if self._local_sources_watcher is None:
self._local_sources_watcher = LocalSourcesWatcher(
self._script_data.main_script_path
)
self._local_sources_watcher.register_file_change_callback(
self._on_source_file_changed
)
self._stop_config_listener = config.on_config_parsed(
self._on_source_file_changed, force_connect=True
)
self._stop_pages_listener = source_util.register_pages_changed_callback(
self._on_pages_changed
)
secrets_singleton.file_change_listener.connect(self._on_secrets_file_changed)
def disconnect_file_watchers(self) -> None:
"""Disconnect the file watcher handlers registered by register_file_watchers."""
if self._local_sources_watcher is not None:
self._local_sources_watcher.close()
if self._stop_config_listener is not None:
self._stop_config_listener()
if self._stop_pages_listener is not None:
self._stop_pages_listener()
secrets_singleton.file_change_listener.disconnect(self._on_secrets_file_changed)
self._local_sources_watcher = None
self._stop_config_listener = None
self._stop_pages_listener = None
def flush_browser_queue(self) -> List[ForwardMsg]:
"""Clear the forward message queue and return the messages it contained.
The Server calls this periodically to deliver new messages
to the browser connected to this app.
Returns
-------
list[ForwardMsg]
The messages that were removed from the queue and should
be delivered to the browser.
"""
return self._browser_queue.flush()
def shutdown(self) -> None:
"""Shut down the AppSession.
It's an error to use an AppSession after it's been shut down.
"""
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
LOGGER.debug("Shutting down (id=%s)", self.id)
# Clear any unused session files in upload file manager and media
# file manager
self._uploaded_file_mgr.remove_session_files(self.id)
if runtime.exists():
runtime.get_instance().media_file_mgr.clear_session_refs(self.id)
runtime.get_instance().media_file_mgr.remove_orphaned_files()
# Shut down the ScriptRunner, if one is active.
# self._state must not be set to SHUTDOWN_REQUESTED until
# *after* this is called.
self.request_script_stop()
self._state = AppSessionState.SHUTDOWN_REQUESTED
# Disconnect all file watchers if we haven't already, although we will have
# generally already done so by the time we get here.
self.disconnect_file_watchers()
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
"""Enqueue a new ForwardMsg to our browser queue.
This can be called on both the main thread and a ScriptRunner
run thread.
Parameters
----------
msg : ForwardMsg
The message to enqueue
"""
if not config.get_option("client.displayEnabled"):
return
if self._debug_last_backmsg_id:
msg.debug_last_backmsg_id = self._debug_last_backmsg_id
self._browser_queue.enqueue(msg)
if self._message_enqueued_callback:
self._message_enqueued_callback()
def handle_backmsg(self, msg: BackMsg) -> None:
"""Process a BackMsg."""
try:
msg_type = msg.WhichOneof("type")
if msg_type == "rerun_script":
if msg.debug_last_backmsg_id:
self._debug_last_backmsg_id = msg.debug_last_backmsg_id
self._handle_rerun_script_request(msg.rerun_script)
elif msg_type == "load_git_info":
self._handle_git_information_request()
elif msg_type == "clear_cache":
self._handle_clear_cache_request()
elif msg_type == "set_run_on_save":
self._handle_set_run_on_save_request(msg.set_run_on_save)
elif msg_type == "stop_script":
self._handle_stop_script_request()
else:
LOGGER.warning('No handler for "%s"', msg_type)
except Exception as ex:
LOGGER.error(ex)
self.handle_backmsg_exception(ex)
def handle_backmsg_exception(self, e: BaseException) -> None:
"""Handle an Exception raised while processing a BackMsg from the browser."""
# This does a few things:
# 1) Clears the current app in the browser.
# 2) Marks the current app as "stopped" in the browser.
# 3) HACK: Resets any script params that may have been broken (e.g. the
# command-line when rerunning with wrong argv[0])
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
)
self._on_scriptrunner_event(
self._scriptrunner,
ScriptRunnerEvent.SCRIPT_STARTED,
page_script_hash="",
)
self._on_scriptrunner_event(
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
)
# Send an Exception message to the frontend.
# Because _on_scriptrunner_event does its work in an eventloop callback,
# this exception ForwardMsg *must* also be enqueued in a callback,
# so that it will be enqueued *after* the various ForwardMsgs that
# _on_scriptrunner_event sends.
self._event_loop.call_soon_threadsafe(
lambda: self._enqueue_forward_msg(self._create_exception_message(e))
)
def request_rerun(self, client_state: Optional[ClientState]) -> None:
"""Signal that we're interested in running the script.
If the script is not already running, it will be started immediately.
Otherwise, a rerun will be requested.
Parameters
----------
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ClientState protobuf to run the script with, or None
to use previous client state.
"""
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
LOGGER.warning("Discarding rerun request after shutdown")
return
if client_state:
rerun_data = RerunData(
client_state.query_string,
client_state.widget_states,
client_state.page_script_hash,
client_state.page_name,
)
else:
rerun_data = RerunData()
if self._scriptrunner is not None:
if bool(config.get_option("runner.fastReruns")):
# If fastReruns is enabled, we don't send rerun requests to our
# existing ScriptRunner. Instead, we tell it to shut down. We'll
# then spin up a new ScriptRunner, below, to handle the rerun
# immediately.
self._scriptrunner.request_stop()
self._scriptrunner = None
else:
# fastReruns is not enabled. Send our ScriptRunner a rerun
# request. If the request is accepted, we're done.
success = self._scriptrunner.request_rerun(rerun_data)
if success:
return
# If we are here, then either we have no ScriptRunner, or our
# current ScriptRunner is shutting down and cannot handle a rerun
# request - so we'll create and start a new ScriptRunner.
self._create_scriptrunner(rerun_data)
def request_script_stop(self) -> None:
"""Request that the scriptrunner stop execution.
Does nothing if no scriptrunner exists.
"""
if self._scriptrunner is not None:
self._scriptrunner.request_stop()
def _create_scriptrunner(self, initial_rerun_data: RerunData) -> None:
"""Create and run a new ScriptRunner with the given RerunData."""
self._scriptrunner = ScriptRunner(
session_id=self.id,
main_script_path=self._script_data.main_script_path,
client_state=self._client_state,
session_state=self._session_state,
uploaded_file_mgr=self._uploaded_file_mgr,
initial_rerun_data=initial_rerun_data,
user_info=self._user_info,
)
self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
self._scriptrunner.start()
@property
def session_state(self) -> "SessionState":
return self._session_state
def _should_rerun_on_file_change(self, filepath: str) -> bool:
main_script_path = self._script_data.main_script_path
pages = source_util.get_pages(main_script_path)
changed_page_script_hash = next(
filter(lambda k: pages[k]["script_path"] == filepath, pages),
None,
)
if changed_page_script_hash is not None:
current_page_script_hash = self._client_state.page_script_hash
return changed_page_script_hash == current_page_script_hash
return True
def _on_source_file_changed(self, filepath: Optional[str] = None) -> None:
"""One of our source files changed. Schedule a rerun if appropriate."""
if filepath is not None and not self._should_rerun_on_file_change(filepath):
return
if self._run_on_save:
self.request_rerun(self._client_state)
else:
self._enqueue_forward_msg(self._create_file_change_message())
def _on_secrets_file_changed(self, _) -> None:
"""Called when `secrets.file_change_listener` emits a Signal."""
# NOTE: At the time of writing, this function only calls `_on_source_file_changed`.
# The reason behind creating this function instead of just passing `_on_source_file_changed`
# to `connect` / `disconnect` directly is that every function that is passed to `connect` / `disconnect`
# must have at least one argument for `sender` (in this case we don't really care about it, thus `_`),
# and introducing an unnecessary argument to `_on_source_file_changed` just for this purpose sounded finicky.
self._on_source_file_changed()
def _on_pages_changed(self, _) -> None:
msg = ForwardMsg()
_populate_app_pages(msg.pages_changed, self._script_data.main_script_path)
self._enqueue_forward_msg(msg)
def _clear_queue(self) -> None:
self._browser_queue.clear()
def _on_scriptrunner_event(
self,
sender: Optional[ScriptRunner],
event: ScriptRunnerEvent,
forward_msg: Optional[ForwardMsg] = None,
exception: Optional[BaseException] = None,
client_state: Optional[ClientState] = None,
page_script_hash: Optional[str] = None,
) -> None:
"""Called when our ScriptRunner emits an event.
This is generally called from the sender ScriptRunner's script thread.
We forward the event on to _handle_scriptrunner_event_on_event_loop,
which will be called on the main thread.
"""
self._event_loop.call_soon_threadsafe(
lambda: self._handle_scriptrunner_event_on_event_loop(
sender, event, forward_msg, exception, client_state, page_script_hash
)
)
def _handle_scriptrunner_event_on_event_loop(
self,
sender: Optional[ScriptRunner],
event: ScriptRunnerEvent,
forward_msg: Optional[ForwardMsg] = None,
exception: Optional[BaseException] = None,
client_state: Optional[ClientState] = None,
page_script_hash: Optional[str] = None,
) -> None:
"""Handle a ScriptRunner event.
This function must only be called on our eventloop thread.
Parameters
----------
sender : ScriptRunner | None
The ScriptRunner that emitted the event. (This may be set to
None when called from `handle_backmsg_exception`, if no
ScriptRunner was active when the backmsg exception was raised.)
event : ScriptRunnerEvent
The event type.
forward_msg : ForwardMsg | None
The ForwardMsg to send to the frontend. Set only for the
ENQUEUE_FORWARD_MSG event.
exception : BaseException | None
An exception thrown during compilation. Set only for the
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ScriptRunner's final ClientState. Set only for the
SHUTDOWN event.
page_script_hash : str | None
A hash of the script path corresponding to the page currently being
run. Set only for the SCRIPT_STARTED event.
"""
assert (
self._event_loop == asyncio.get_running_loop()
), "This function must only be called on the eventloop thread the AppSession was created on."
if sender is not self._scriptrunner:
# This event was sent by a non-current ScriptRunner; ignore it.
# This can happen after spinning up a new ScriptRunner (to handle a
# rerun request, for example) while another ScriptRunner is still
# shutting down. The shutting-down ScriptRunner may still
# emit events.
LOGGER.debug("Ignoring event from non-current ScriptRunner: %s", event)
return
prev_state = self._state
if event == ScriptRunnerEvent.SCRIPT_STARTED:
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
self._state = AppSessionState.APP_IS_RUNNING
assert (
page_script_hash is not None
), "page_script_hash must be set for the SCRIPT_STARTED event"
self._clear_queue()
self._enqueue_forward_msg(
self._create_new_session_message(page_script_hash)
)
elif (
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR
):
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
self._state = AppSessionState.APP_NOT_RUNNING
script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
script_finished_msg = self._create_script_finished_message(
ForwardMsg.FINISHED_SUCCESSFULLY
if script_succeeded
else ForwardMsg.FINISHED_WITH_COMPILE_ERROR
)
self._enqueue_forward_msg(script_finished_msg)
self._debug_last_backmsg_id = None
if script_succeeded:
# The script completed successfully: update our
# LocalSourcesWatcher to account for any source code changes
# that change which modules should be watched.
if self._local_sources_watcher:
self._local_sources_watcher.update_watched_modules()
else:
# The script didn't complete successfully: send the exception
# to the frontend.
assert (
exception is not None
), "exception must be set for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event"
msg = ForwardMsg()
exception_utils.marshall(
msg.session_event.script_compilation_exception, exception
)
self._enqueue_forward_msg(msg)
elif event == ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN:
script_finished_msg = self._create_script_finished_message(
ForwardMsg.FINISHED_EARLY_FOR_RERUN
)
self._enqueue_forward_msg(script_finished_msg)
if self._local_sources_watcher:
self._local_sources_watcher.update_watched_modules()
elif event == ScriptRunnerEvent.SHUTDOWN:
assert (
client_state is not None
), "client_state must be set for the SHUTDOWN event"
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
# Only clear media files if the script is done running AND the
# session is actually shutting down.
runtime.get_instance().media_file_mgr.clear_session_refs(self.id)
self._client_state = client_state
self._scriptrunner = None
elif event == ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
assert (
forward_msg is not None
), "null forward_msg in ENQUEUE_FORWARD_MSG event"
self._enqueue_forward_msg(forward_msg)
# Send a message if our run state changed
app_was_running = prev_state == AppSessionState.APP_IS_RUNNING
app_is_running = self._state == AppSessionState.APP_IS_RUNNING
if app_is_running != app_was_running:
self._enqueue_forward_msg(self._create_session_status_changed_message())
def _create_session_status_changed_message(self) -> ForwardMsg:
"""Create and return a session_status_changed ForwardMsg."""
msg = ForwardMsg()
msg.session_status_changed.run_on_save = self._run_on_save
msg.session_status_changed.script_is_running = (
self._state == AppSessionState.APP_IS_RUNNING
)
return msg
def _create_file_change_message(self) -> ForwardMsg:
"""Create and return a 'script_changed_on_disk' ForwardMsg."""
msg = ForwardMsg()
msg.session_event.script_changed_on_disk = True
return msg
def _create_new_session_message(self, page_script_hash: str) -> ForwardMsg:
"""Create and return a new_session ForwardMsg."""
msg = ForwardMsg()
msg.new_session.script_run_id = _generate_scriptrun_id()
msg.new_session.name = self._script_data.name
msg.new_session.main_script_path = self._script_data.main_script_path
msg.new_session.page_script_hash = page_script_hash
_populate_app_pages(msg.new_session, self._script_data.main_script_path)
_populate_config_msg(msg.new_session.config)
_populate_theme_msg(msg.new_session.custom_theme)
# Immutable session data. We send this every time a new session is
# started, to avoid having to track whether the client has already
# received it. It does not change from run to run; it's up to the
# client to perform one-time initialization only once.
imsg = msg.new_session.initialize
_populate_user_info_msg(imsg.user_info)
imsg.environment_info.streamlit_version = STREAMLIT_VERSION_STRING
imsg.environment_info.python_version = ".".join(map(str, sys.version_info))
imsg.session_status.run_on_save = self._run_on_save
imsg.session_status.script_is_running = (
self._state == AppSessionState.APP_IS_RUNNING
)
imsg.command_line = self._script_data.command_line
imsg.session_id = self.id
return msg
def _create_script_finished_message(
self, status: "ForwardMsg.ScriptFinishedStatus.ValueType"
) -> ForwardMsg:
"""Create and return a script_finished ForwardMsg."""
msg = ForwardMsg()
msg.script_finished = status
return msg
def _create_exception_message(self, e: BaseException) -> ForwardMsg:
"""Create and return an Exception ForwardMsg."""
msg = ForwardMsg()
exception_utils.marshall(msg.delta.new_element.exception, e)
return msg
def _handle_git_information_request(self) -> None:
msg = ForwardMsg()
try:
from streamlit.git_util import GitRepo
repo = GitRepo(self._script_data.main_script_path)
repo_info = repo.get_repo_info()
if repo_info is None:
return
repository_name, branch, module = repo_info
msg.git_info_changed.repository = repository_name
msg.git_info_changed.branch = branch
msg.git_info_changed.module = module
msg.git_info_changed.untracked_files[:] = repo.untracked_files
msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files
if repo.is_head_detached:
msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
elif len(repo.ahead_commits) > 0:
msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
else:
msg.git_info_changed.state = GitInfo.GitStates.DEFAULT
self._enqueue_forward_msg(msg)
except Exception as ex:
# Users may never even install Git in the first place, so this
# error requires no action. It can be useful for debugging.
LOGGER.debug("Obtaining Git information produced an error", exc_info=ex)
def _handle_rerun_script_request(
self, client_state: Optional[ClientState] = None
) -> None:
"""Tell the ScriptRunner to re-run its script.
Parameters
----------
client_state : streamlit.proto.ClientState_pb2.ClientState | None
The ClientState protobuf to run the script with, or None
to use previous client state.
"""
self.request_rerun(client_state)
def _handle_stop_script_request(self) -> None:
"""Tell the ScriptRunner to stop running its script."""
self.request_script_stop()
def _handle_clear_cache_request(self) -> None:
"""Clear this app's cache.
Because this cache is global, it will be cleared for all users.
"""
legacy_caching.clear_cache()
caching.cache_data.clear()
caching.cache_resource.clear()
self._session_state.clear()
def _handle_set_run_on_save_request(self, new_value: bool) -> None:
"""Change our run_on_save flag to the given value.
The browser will be notified of the change.
Parameters
----------
new_value : bool
New run_on_save value
"""
self._run_on_save = new_value
self._enqueue_forward_msg(self._create_session_status_changed_message())
def _populate_config_msg(msg: Config) -> None:
msg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
msg.max_cached_message_age = config.get_option("global.maxCachedMessageAge")
msg.mapbox_token = config.get_option("mapbox.token")
msg.allow_run_on_save = config.get_option("server.allowRunOnSave")
msg.hide_top_bar = config.get_option("ui.hideTopBar")
msg.hide_sidebar_nav = config.get_option("ui.hideSidebarNav")
def _populate_theme_msg(msg: CustomThemeConfig) -> None:
enum_encoded_options = {"base", "font"}
theme_opts = config.get_options_for_section("theme")
if not any(theme_opts.values()):
return
for option_name, option_val in theme_opts.items():
if option_name not in enum_encoded_options and option_val is not None:
setattr(msg, to_snake_case(option_name), option_val)
# NOTE: If unset, base and font will default to the protobuf enum zero
# values, which are BaseTheme.LIGHT and FontFamily.SANS_SERIF,
# respectively. This is why we both don't handle the cases explicitly and
# also only log a warning when receiving invalid base/font options.
base_map = {
"light": msg.BaseTheme.LIGHT,
"dark": msg.BaseTheme.DARK,
}
base = theme_opts["base"]
if base is not None:
if base not in base_map:
LOGGER.warning(
f'"{base}" is an invalid value for theme.base.'
f" Allowed values include {list(base_map.keys())}."
' Setting theme.base to "light".'
)
else:
msg.base = base_map[base]
font_map = {
"sans serif": msg.FontFamily.SANS_SERIF,
"serif": msg.FontFamily.SERIF,
"monospace": msg.FontFamily.MONOSPACE,
}
font = theme_opts["font"]
if font is not None:
if font not in font_map:
LOGGER.warning(
f'"{font}" is an invalid value for theme.font.'
f" Allowed values include {list(font_map.keys())}."
' Setting theme.font to "sans serif".'
)
else:
msg.font = font_map[font]
def _populate_user_info_msg(msg: UserInfo) -> None:
msg.installation_id = Installation.instance().installation_id
msg.installation_id_v3 = Installation.instance().installation_id_v3
if Credentials.get_current().activation:
msg.email = Credentials.get_current().activation.email
else:
msg.email = ""
def _populate_app_pages(
msg: Union[NewSession, PagesChanged], main_script_path: str
) -> None:
for page_script_hash, page_info in source_util.get_pages(main_script_path).items():
page_proto = msg.app_pages.add()
page_proto.page_script_hash = page_script_hash
page_proto.page_name = page_info["page_name"]
page_proto.icon = page_info["icon"]


@@ -0,0 +1,132 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Iterator, Union
from google.protobuf.message import Message
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_data_api import (
CACHE_DATA_MESSAGE_REPLAY_CTX,
CacheDataAPI,
_data_caches,
)
from streamlit.runtime.caching.cache_errors import CACHE_DOCS_URL as CACHE_DOCS_URL
from streamlit.runtime.caching.cache_resource_api import (
CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
CacheResourceAPI,
_resource_caches,
)
from streamlit.runtime.state.common import WidgetMetadata
def save_element_message(
delta_type: str,
element_proto: Message,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Save the message for an element to a thread-local callstack, so it can
be used later to replay the element when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
def save_block_message(
block_proto: Block,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Save the message for a block to a thread-local callstack, so it can
be used later to replay the block when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
def save_widget_metadata(metadata: WidgetMetadata[Any]) -> None:
"""Save a widget's metadata to a thread-local callstack, so the widget
can be registered again when that widget is replayed.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
def save_media_data(
image_data: Union[bytes, str], mimetype: str, image_id: str
) -> None:
CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
CACHE_DATA_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
dg, st_func_name
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
dg, st_func_name
)
@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
with CACHE_DATA_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning(), CACHE_RESOURCE_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning():
yield
# Explicitly export public symbols
from streamlit.runtime.caching.cache_data_api import (
get_data_cache_stats_provider as get_data_cache_stats_provider,
)
from streamlit.runtime.caching.cache_resource_api import (
get_resource_cache_stats_provider as get_resource_cache_stats_provider,
)
# Create and export public API singletons.
cache_data = CacheDataAPI(decorator_metric_name="cache_data")
cache_resource = CacheResourceAPI(decorator_metric_name="cache_resource")
# Deprecated singletons
_MEMO_WARNING = (
f"`st.experimental_memo` is deprecated. Please use the new command `st.cache_data` instead, "
f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
)
experimental_memo = CacheDataAPI(
decorator_metric_name="experimental_memo", deprecation_warning=_MEMO_WARNING
)
_SINGLETON_WARNING = (
f"`st.experimental_singleton` is deprecated. Please use the new command `st.cache_resource` instead, "
f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
)
experimental_singleton = CacheResourceAPI(
decorator_metric_name="experimental_singleton",
deprecation_warning=_SINGLETON_WARNING,
)
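
This `streamlit.runtime.caching` module is what wires `st.cache_data` and `st.cache_resource` (plus the deprecated `experimental_memo` / `experimental_singleton` aliases) into the public API. A minimal usage sketch for an app script; the function bodies and the URL are illustrative placeholders:

import streamlit as st

@st.cache_data(ttl=600)  # pickle-based data cache; entries expire after 10 minutes
def load_rows(url: str):
    # Placeholder for an expensive fetch/transform; the return value must be pickleable.
    return [url.upper()]

@st.cache_resource  # shared, unpickled resource reused across reruns and sessions
def get_connection():
    # Placeholder for e.g. a database connection object.
    return {"dsn": "example"}

st.write(load_rows("https://example.com/data"))
if st.button("Clear data caches"):
    st.cache_data.clear()  # clears every @st.cache_data cache defined above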


@@ -0,0 +1,678 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_data: pickle-based caching"""
from __future__ import annotations
import pickle
import threading
import types
from datetime import timedelta
from typing import Any, Callable, TypeVar, Union, cast, overload
from typing_extensions import Literal, TypeAlias
import streamlit as st
from streamlit import runtime
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
ElementMsgData,
MsgData,
MultiCacheResults,
)
from streamlit.runtime.caching.storage import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
InvalidCacheStorageContext,
)
from streamlit.runtime.caching.storage.dummy_cache_storage import (
MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
_LOGGER = get_logger(__name__)
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)
# The cache persistence options we support: "disk" or None
CachePersistType: TypeAlias = Union[Literal["disk"], None]
class CachedDataFuncInfo(CachedFuncInfo):
"""Implements the CachedFuncInfo interface for @st.cache_data"""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
persist: CachePersistType,
max_entries: int | None,
ttl: float | timedelta | None,
allow_widgets: bool,
):
super().__init__(
func,
show_spinner=show_spinner,
allow_widgets=allow_widgets,
)
self.persist = persist
self.max_entries = max_entries
self.ttl = ttl
self.validate_params()
@property
def cache_type(self) -> CacheType:
return CacheType.DATA
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_DATA_MESSAGE_REPLAY_CTX
@property
def display_name(self) -> str:
"""A human-readable name for the cached function"""
return f"{self.func.__module__}.{self.func.__qualname__}"
def get_function_cache(self, function_key: str) -> Cache:
return _data_caches.get_cache(
key=function_key,
persist=self.persist,
max_entries=self.max_entries,
ttl=self.ttl,
display_name=self.display_name,
allow_widgets=self.allow_widgets,
)
def validate_params(self) -> None:
"""
Validate that the params passed to @st.cache_data are compatible with the cache storage.
When called, this method may log warnings if the cache params are invalid
for the current storage.
"""
_data_caches.validate_cache_params(
function_name=self.func.__name__,
persist=self.persist,
max_entries=self.max_entries,
ttl=self.ttl,
)
class DataCaches(CacheStatsProvider):
"""Manages all DataCache instances"""
def __init__(self):
self._caches_lock = threading.Lock()
self._function_caches: dict[str, DataCache] = {}
def get_cache(
self,
key: str,
persist: CachePersistType,
max_entries: int | None,
ttl: int | float | timedelta | None,
display_name: str,
allow_widgets: bool,
) -> DataCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._caches_lock:
cache = self._function_caches.get(key)
if (
cache is not None
and cache.ttl_seconds == ttl_seconds
and cache.max_entries == max_entries
and cache.persist == persist
):
return cache
# Close the existing cache's storage, if it exists.
if cache is not None:
_LOGGER.debug(
"Closing existing DataCache storage "
"(key=%s, persist=%s, max_entries=%s, ttl=%s) "
"before creating new one with different params",
key,
persist,
max_entries,
ttl,
)
cache.storage.close()
# Create a new cache object and put it in our dict
_LOGGER.debug(
"Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
key,
persist,
max_entries,
ttl,
)
cache_context = self.create_cache_storage_context(
function_key=key,
function_name=display_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
cache_storage_manager = self.get_storage_manager()
storage = cache_storage_manager.create(cache_context)
cache = DataCache(
key=key,
storage=storage,
persist=persist,
max_entries=max_entries,
ttl_seconds=ttl_seconds,
display_name=display_name,
allow_widgets=allow_widgets,
)
self._function_caches[key] = cache
return cache
def clear_all(self) -> None:
"""Clear all in-memory and on-disk caches."""
with self._caches_lock:
try:
# Try to clear everything in one shot via the storage manager's
# clear_all method; if that isn't implemented, fall back to
# clearing each available storage one by one.
self.get_storage_manager().clear_all()
except NotImplementedError:
for data_cache in self._function_caches.values():
data_cache.clear()
data_cache.storage.close()
self._function_caches = {}
def get_stats(self) -> list[CacheStat]:
with self._caches_lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats: list[CacheStat] = []
for cache in function_caches.values():
stats.extend(cache.get_stats())
return stats
def validate_cache_params(
self,
function_name: str,
persist: CachePersistType,
max_entries: int | None,
ttl: int | float | timedelta | None,
) -> None:
"""Validate that the cache params are valid for given storage.
Raises
------
InvalidCacheStorageContext
Raised if the cache storage manager is not able to work with provided
CacheStorageContext.
"""
ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
cache_context = self.create_cache_storage_context(
function_key="DUMMY_KEY",
function_name=function_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
try:
self.get_storage_manager().check_context(cache_context)
except InvalidCacheStorageContext as e:
_LOGGER.error(
"Cache params for function %s are incompatible with current "
"cache storage manager: %s",
function_name,
e,
)
raise
def create_cache_storage_context(
self,
function_key: str,
function_name: str,
persist: CachePersistType,
ttl_seconds: float | None,
max_entries: int | None,
) -> CacheStorageContext:
return CacheStorageContext(
function_key=function_key,
function_display_name=function_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
def get_storage_manager(self) -> CacheStorageManager:
if runtime.exists():
return runtime.get_instance().cache_storage_manager
else:
# When running in "raw mode", we can't access the CacheStorageManager,
# so we fall back to a MemoryCacheStorageManager.
_LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
return MemoryCacheStorageManager()
# Singleton DataCaches instance
_data_caches = DataCaches()
def get_data_cache_stats_provider() -> CacheStatsProvider:
"""Return the StatsProvider for all @st.cache_data functions."""
return _data_caches
class CacheDataAPI:
"""Implements the public st.cache_data API: the @st.cache_data decorator, and
st.cache_data.clear().
"""
def __init__(
self, decorator_metric_name: str, deprecation_warning: str | None = None
):
"""Create a CacheDataAPI instance.
Parameters
----------
decorator_metric_name
The metric name to record for decorator usage. `@st.experimental_memo` is
deprecated, but we're still supporting it and tracking its usage separately
from `@st.cache_data`.
deprecation_warning
An optional deprecation warning to show when the API is accessed.
"""
# Parameterize the decorator metric name.
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
self._decorator = gather_metrics( # type: ignore
decorator_metric_name, self._decorator
)
self._deprecation_warning = deprecation_warning
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
def __call__(self, func: F) -> F:
...
# Decorator with arguments
@overload
def __call__(
self,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
) -> Callable[[F], F]:
...
def __call__(
self,
func: F | None = None,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
):
return self._decorator(
func,
ttl=ttl,
max_entries=max_entries,
persist=persist,
show_spinner=show_spinner,
experimental_allow_widgets=experimental_allow_widgets,
)
def _decorator(
self,
func: F | None = None,
*,
ttl: float | timedelta | None,
max_entries: int | None,
show_spinner: bool | str,
persist: CachePersistType | bool,
experimental_allow_widgets: bool,
):
"""Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).
Cached objects are stored in "pickled" form, which means that the return
value of a cached function must be pickleable. Each caller of the cached
function gets its own copy of the cached data.
You can clear a function's cache with ``func.clear()`` or clear the entire
cache with ``st.cache_data.clear()``.
To cache global resources, use ``st.cache_resource`` instead. Learn more
about caching at https://docs.streamlit.io/library/advanced-features/caching.
Parameters
----------
func : callable
The function to cache. Streamlit hashes the function's source code.
ttl : float or timedelta or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
Note that ttl is incompatible with ``persist="disk"`` - ``ttl`` will be
ignored if ``persist`` is specified.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is None.
show_spinner : boolean or string
Enable the spinner. Default is True to show a spinner when there is
a "cache miss" and the cached data is being created. If string,
value of show_spinner param will be used for spinner text.
persist : str or boolean or None
Optional location to persist cached data to. Passing "disk" (or True)
will persist the cached data to the local disk. None (or False) will disable
persistence. The default is None.
experimental_allow_widgets : boolean
Allow widgets to be used in the cached function. Defaults to False.
Support for widgets in cached functions is currently experimental.
Setting this parameter to True may lead to excessive memory use since the
widget value is treated as an additional input parameter to the cache.
We may remove support for this option at any time without notice.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
...
>>> d1 = fetch_and_clean_data(DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> d2 = fetch_and_clean_data(DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the data in d1 is the same as in d2.
>>>
>>> d3 = fetch_and_clean_data(DATA_URL_2)
>>> # This is a different URL, so the function executes.
To set the ``persist`` parameter, use this command as follows:
>>> import streamlit as st
>>>
>>> @st.cache_data(persist="disk")
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
By default, all parameters to a cached function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
...
>>> connection = make_database_connection()
>>> d1 = fetch_and_clean_data(connection, num_rows=10)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> another_connection = make_database_connection()
>>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _db_connection parameter was different
>>> # in both calls.
A cached function's cache can be procedurally cleared:
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
...
>>> fetch_and_clean_data.clear()
>>> # Clear all cached entries for this function.
"""
# Parse our persist value into a string
persist_string: CachePersistType
if persist is True:
persist_string = "disk"
elif persist is False:
persist_string = None
else:
persist_string = persist
if persist_string not in (None, "disk"):
# We'll eventually have more persist options.
raise StreamlitAPIException(
f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
)
self._maybe_show_deprecation_warning()
def wrapper(f):
return make_cached_func_wrapper(
CachedDataFuncInfo(
func=f,
persist=persist_string,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
allow_widgets=experimental_allow_widgets,
)
)
if func is None:
return wrapper
return make_cached_func_wrapper(
CachedDataFuncInfo(
func=cast(types.FunctionType, func),
persist=persist_string,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
allow_widgets=experimental_allow_widgets,
)
)
@gather_metrics("clear_data_caches")
def clear(self) -> None:
"""Clear all in-memory and on-disk data caches."""
self._maybe_show_deprecation_warning()
_data_caches.clear_all()
def _maybe_show_deprecation_warning(self):
"""If the API is being accessed with the deprecated `st.experimental_memo` name,
show a deprecation warning.
"""
if self._deprecation_warning is not None:
show_deprecation_warning(self._deprecation_warning)
class DataCache(Cache):
"""Manages cached values for a single st.cache_data function."""
def __init__(
self,
key: str,
storage: CacheStorage,
persist: CachePersistType,
max_entries: int | None,
ttl_seconds: float | None,
display_name: str,
allow_widgets: bool = False,
):
super().__init__()
self.key = key
self.display_name = display_name
self.storage = storage
self.ttl_seconds = ttl_seconds
self.max_entries = max_entries
self.persist = persist
self.allow_widgets = allow_widgets
def get_stats(self) -> list[CacheStat]:
if isinstance(self.storage, CacheStatsProvider):
return self.storage.get_stats()
return []
def read_result(self, key: str) -> CachedResult:
"""Read a value and messages from the cache. Raise `CacheKeyNotFoundError`
if the value doesn't exist, and `CacheError` if the value exists but can't
be unpickled.
"""
try:
pickled_entry = self.storage.get(key)
except CacheStorageKeyNotFoundError as e:
raise CacheKeyNotFoundError(str(e)) from e
except CacheStorageError as e:
raise CacheError(str(e)) from e
try:
entry = pickle.loads(pickled_entry)
if not isinstance(entry, MultiCacheResults):
# Loaded an old cache file format, remove it and let the caller
# rerun the function.
self.storage.delete(key)
raise CacheKeyNotFoundError()
ctx = get_script_run_ctx()
if not ctx:
raise CacheKeyNotFoundError()
widget_key = entry.get_current_widget_key(ctx, CacheType.DATA)
if widget_key in entry.results:
return entry.results[widget_key]
else:
raise CacheKeyNotFoundError()
except pickle.UnpicklingError as exc:
raise CacheError(f"Failed to unpickle {key}") from exc
@gather_metrics("_cache_data_object")
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache.
The value must be pickleable.
"""
ctx = get_script_run_ctx()
if ctx is None:
return
main_id = st._main.id
sidebar_id = st.sidebar.id
if self.allow_widgets:
widgets = {
msg.widget_metadata.widget_id
for msg in messages
if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
}
else:
widgets = set()
multi_cache_results: MultiCacheResults | None = None
# Try to read existing results from cache storage, falling back to a new result instance
try:
multi_cache_results = self._read_multi_results_from_storage(key)
except (CacheKeyNotFoundError, pickle.UnpicklingError):
pass
if multi_cache_results is None:
multi_cache_results = MultiCacheResults(widget_ids=widgets, results={})
multi_cache_results.widget_ids.update(widgets)
widget_key = multi_cache_results.get_current_widget_key(ctx, CacheType.DATA)
result = CachedResult(value, messages, main_id, sidebar_id)
multi_cache_results.results[widget_key] = result
try:
pickled_entry = pickle.dumps(multi_cache_results)
except (pickle.PicklingError, TypeError) as exc:
raise CacheError(f"Failed to pickle {key}") from exc
self.storage.set(key, pickled_entry)
def _clear(self) -> None:
self.storage.clear()
def _read_multi_results_from_storage(self, key: str) -> MultiCacheResults:
"""Look up the results from storage and ensure it has the right type.
Raises a `CacheKeyNotFoundError` if the key has no entry, or if the
entry is malformed.
"""
try:
pickled = self.storage.get(key)
except CacheStorageKeyNotFoundError as e:
raise CacheKeyNotFoundError(str(e)) from e
maybe_results = pickle.loads(pickled)
if isinstance(maybe_results, MultiCacheResults):
return maybe_results
else:
self.storage.delete(key)
raise CacheKeyNotFoundError()
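
DataCache above persists each entry as a pickled `MultiCacheResults` and, on read, unpickles it, type-checks it, and deletes anything written in an older format so the caller simply recomputes. A standalone sketch of that write/validate/evict pattern, with `Entry` and `store` as illustrative stand-ins for `MultiCacheResults` and a `CacheStorage`:

import pickle
from dataclasses import dataclass

@dataclass
class Entry:               # stand-in for MultiCacheResults
    value: object

store: dict = {}           # stand-in for a CacheStorage backend

def write(key: str, value: object) -> None:
    store[key] = pickle.dumps(Entry(value))  # the value must be pickleable

def read(key: str) -> object:
    try:
        raw = store[key]
    except KeyError:
        raise KeyError(f"cache miss for {key!r}")
    entry = pickle.loads(raw)
    if not isinstance(entry, Entry):
        # Unexpected/legacy format: evict it and report a miss so the caller recomputes.
        del store[key]
        raise KeyError(f"stale cache entry for {key!r}")
    return entry.value

write("k", [1, 2, 3])
assert read("k") == [1, 2, 3]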


@@ -0,0 +1,176 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from typing import Any, Optional
from streamlit import type_util
from streamlit.errors import (
MarkdownFormattedException,
StreamlitAPIException,
StreamlitAPIWarning,
)
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
CACHE_DOCS_URL = "https://docs.streamlit.io/library/advanced-features/caching"
def get_cached_func_name_md(func: Any) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
elif hasattr(type(func), "__name__"):
return f"`{type(func).__name__}`"
return f"`{type(func)}`"
def get_return_value_type(return_value: Any) -> str:
if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
return f"`{return_value.__module__}.{type(return_value).__name__}`"
return get_cached_func_name_md(return_value)
class UnhashableTypeError(Exception):
pass
class UnhashableParamError(StreamlitAPIException):
def __init__(
self,
cache_type: CacheType,
func: types.FunctionType,
arg_name: Optional[str],
arg_value: Any,
orig_exc: BaseException,
):
msg = self._create_message(cache_type, func, arg_name, arg_value)
super().__init__(msg)
self.with_traceback(orig_exc.__traceback__)
@staticmethod
def _create_message(
cache_type: CacheType,
func: types.FunctionType,
arg_name: Optional[str],
arg_value: Any,
) -> str:
arg_name_str = arg_name if arg_name is not None else "(unnamed)"
arg_type = type_util.get_fqn_type(arg_value)
func_name = func.__name__
arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"
return (
f"""
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
To address this, you can tell Streamlit not to hash this argument by adding a
leading underscore to the argument's name in the function signature:
```
@st.{get_decorator_api_name(cache_type)}
def {func_name}({arg_replacement_name}, ...):
...
```
"""
).strip("\n")
class CacheKeyNotFoundError(Exception):
pass
class CacheError(Exception):
pass
class CachedStFunctionWarning(StreamlitAPIWarning):
def __init__(
self,
cache_type: CacheType,
st_func_name: str,
cached_func: types.FunctionType,
):
args = {
"st_func_name": f"`st.{st_func_name}()`",
"func_name": self._get_cached_func_name_md(cached_func),
"decorator_name": get_decorator_api_name(cache_type),
}
msg = (
"""
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.
How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.%(decorator_name)s(experimental_allow_widgets=True)`
to enable widget replay and suppress this warning.
"""
% args
).strip("\n")
super().__init__(msg)
@staticmethod
def _get_cached_func_name_md(func: types.FunctionType) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
else:
return "a cached function"
class CacheReplayClosureError(StreamlitAPIException):
def __init__(
self,
cache_type: CacheType,
cached_func: types.FunctionType,
):
func_name = get_cached_func_name_md(cached_func)
decorator_name = get_decorator_api_name(cache_type)
msg = (
f"""
While running {func_name}, a streamlit element is called on some layout block created outside the function.
This is incompatible with replaying the cached effect of that element, because the
referenced block might not exist when the replay happens.
How to fix this:
* Move the creation of $THING inside {func_name}.
* Move the call to the streamlit element outside of {func_name}.
* Remove the `@st.{decorator_name}` decorator from {func_name}.
"""
).strip("\n")
super().__init__(msg)
class UnserializableReturnValueError(MarkdownFormattedException):
def __init__(self, func: types.FunctionType, return_value: types.FunctionType):
MarkdownFormattedException.__init__(
self,
f"""
Cannot serialize the return value (of type {get_return_value_type(return_value)}) in {get_cached_func_name_md(func)}.
`st.cache_data` uses [pickle](https://docs.python.org/3/library/pickle.html) to
serialize the function's return value and safely store it in the cache without mutating the original object. Please convert the return value to a pickle-serializable type.
If you want to cache unserializable objects such as database connections or Tensorflow
sessions, use `st.cache_resource` instead (see [our docs]({CACHE_DOCS_URL}) for differences).""",
)
class UnevaluatedDataFrameError(StreamlitAPIException):
"""Used to display a message about uncollected dataframe being used"""
pass
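
The `UnhashableParamError` message above points users at the leading-underscore escape hatch for arguments Streamlit cannot hash. A minimal sketch of what that looks like at a call site, assuming an in-memory SQLite connection as the unhashable argument (the function and table names are illustrative):

import sqlite3
import streamlit as st

@st.cache_data
def fetch_rows(_conn: sqlite3.Connection, table: str):
    # _conn is skipped when computing the cache key; only `table` participates.
    return _conn.execute(f"SELECT * FROM {table}").fetchall()  # illustrative query

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (x)")
rows = fetch_rows(conn, "t")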


@@ -0,0 +1,520 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_resource implementation"""
from __future__ import annotations
import math
import threading
import types
from datetime import timedelta
from typing import Any, Callable, TypeVar, cast, overload
from cachetools import TTLCache
from pympler import asizeof
from typing_extensions import TypeAlias
import streamlit as st
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
ElementMsgData,
MsgData,
MultiCacheResults,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
_LOGGER = get_logger(__name__)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)
ValidateFunc: TypeAlias = Callable[[Any], bool]
def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
"""True if the two validate functions are equal for the purposes of
determining whether a given function cache needs to be recreated.
"""
# To "properly" test for function equality here, we'd need to compare function bytecode.
# For performance reasons, we've decided not to do that for now.
return (a is None and b is None) or (a is not None and b is not None)
class ResourceCaches(CacheStatsProvider):
"""Manages all ResourceCache instances"""
def __init__(self):
self._caches_lock = threading.Lock()
self._function_caches: dict[str, ResourceCache] = {}
def get_cache(
self,
key: str,
display_name: str,
max_entries: int | float | None,
ttl: float | timedelta | None,
validate: ValidateFunc | None,
allow_widgets: bool,
) -> ResourceCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
if max_entries is None:
max_entries = math.inf
ttl_seconds = ttl_to_seconds(ttl)
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._caches_lock:
cache = self._function_caches.get(key)
if (
cache is not None
and cache.ttl_seconds == ttl_seconds
and cache.max_entries == max_entries
and _equal_validate_funcs(cache.validate, validate)
):
return cache
# Create a new cache object and put it in our dict
_LOGGER.debug("Creating new ResourceCache (key=%s)", key)
cache = ResourceCache(
key=key,
display_name=display_name,
max_entries=max_entries,
ttl_seconds=ttl_seconds,
validate=validate,
allow_widgets=allow_widgets,
)
self._function_caches[key] = cache
return cache
def clear_all(self) -> None:
"""Clear all resource caches."""
with self._caches_lock:
self._function_caches = {}
def get_stats(self) -> list[CacheStat]:
with self._caches_lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats: list[CacheStat] = []
for cache in function_caches.values():
stats.extend(cache.get_stats())
return stats
# Singleton ResourceCaches instance
_resource_caches = ResourceCaches()
def get_resource_cache_stats_provider() -> CacheStatsProvider:
"""Return the StatsProvider for all @st.cache_resource functions."""
return _resource_caches
class CachedResourceFuncInfo(CachedFuncInfo):
"""Implements the CachedFuncInfo interface for @st.cache_resource"""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
max_entries: int | None,
ttl: float | timedelta | None,
validate: ValidateFunc | None,
allow_widgets: bool,
):
super().__init__(
func,
show_spinner=show_spinner,
allow_widgets=allow_widgets,
)
self.max_entries = max_entries
self.ttl = ttl
self.validate = validate
@property
def cache_type(self) -> CacheType:
return CacheType.RESOURCE
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_RESOURCE_MESSAGE_REPLAY_CTX
@property
def display_name(self) -> str:
"""A human-readable name for the cached function"""
return f"{self.func.__module__}.{self.func.__qualname__}"
def get_function_cache(self, function_key: str) -> Cache:
return _resource_caches.get_cache(
key=function_key,
display_name=self.display_name,
max_entries=self.max_entries,
ttl=self.ttl,
validate=self.validate,
allow_widgets=self.allow_widgets,
)
class CacheResourceAPI:
"""Implements the public st.cache_resource API: the @st.cache_resource decorator,
and st.cache_resource.clear().
"""
def __init__(
self, decorator_metric_name: str, deprecation_warning: str | None = None
):
"""Create a CacheResourceAPI instance.
Parameters
----------
decorator_metric_name
The metric name to record for decorator usage. `@st.experimental_singleton` is
deprecated, but we're still supporting it and tracking its usage separately
from `@st.cache_resource`.
deprecation_warning
An optional deprecation warning to show when the API is accessed.
"""
# Parameterize the decorator metric name.
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
self._decorator = gather_metrics(decorator_metric_name, self._decorator) # type: ignore
self._deprecation_warning = deprecation_warning
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
def __call__(self, func: F) -> F:
...
# Decorator with arguments
@overload
def __call__(
self,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
) -> Callable[[F], F]:
...
def __call__(
self,
func: F | None = None,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
):
return self._decorator(
func,
ttl=ttl,
max_entries=max_entries,
show_spinner=show_spinner,
validate=validate,
experimental_allow_widgets=experimental_allow_widgets,
)
def _decorator(
self,
func: F | None,
*,
ttl: float | timedelta | None,
max_entries: int | None,
show_spinner: bool | str,
validate: ValidateFunc | None,
experimental_allow_widgets: bool,
):
"""Decorator to cache functions that return global resources (e.g. database connections, ML models).
Cached objects are shared across all users, sessions, and reruns. They
must be thread-safe because they can be accessed from multiple threads
concurrently. If thread safety is an issue, consider using ``st.session_state``
to store resources per session instead.
You can clear a function's cache with ``func.clear()`` or clear the entire
cache with ``st.cache_resource.clear()``.
To cache data, use ``st.cache_data`` instead. Learn more about caching at
https://docs.streamlit.io/library/advanced-features/caching.
Parameters
----------
func : callable
The function that creates the cached resource. Streamlit hashes the
function's source code.
ttl : float or timedelta or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is None.
show_spinner : boolean or string
Enable the spinner. The default is True, which shows a spinner when there is
a "cache miss" and the cached resource is being created. If a string is
provided, it is used as the spinner text.
validate : callable or None
An optional validation function for cached data. ``validate`` is called
each time the cached value is accessed. It receives the cached value as
its only parameter and it must return a boolean. If ``validate`` returns
False, the current cached value is discarded, and the decorated function
is called to compute a new value. This is useful e.g. to check the
health of database connections.
experimental_allow_widgets : boolean
Allow widgets to be used in the cached function. Defaults to False.
Support for widgets in cached functions is currently experimental.
Setting this parameter to True may lead to excessive memory use since the
widget value is treated as an additional input parameter to the cache.
We may remove support for this option at any time without notice.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(url):
... # Create a database session object that points to the URL.
... return session
...
>>> s1 = get_database_session(SESSION_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> s2 = get_database_session(SESSION_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the connection object in s1 is the same as in s2.
>>>
>>> s3 = get_database_session(SESSION_URL_2)
>>> # This is a different URL, so the function executes.
By default, all parameters to a cache_resource function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(_sessionmaker, url):
... # Create a database connection object that points to the URL.
... return connection
...
>>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _sessionmaker parameter was different
>>> # in both calls.
A cache_resource function's cache can be procedurally cleared:
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(_sessionmaker, url):
... # Create a database connection object that points to the URL.
... return connection
...
>>> get_database_session.clear()
>>> # Clear all cached entries for this function.
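A ``validate`` function can be used to discard stale resources such as broken
database connections. A minimal sketch (``check_connection`` and ``create_session``
are hypothetical helpers, not part of Streamlit):
>>> import streamlit as st
>>>
>>> def _connection_alive(session) -> bool:
...     return check_connection(session)  # hypothetical health check
...
>>> @st.cache_resource(validate=_connection_alive)
... def get_database_session(url):
...     return create_session(url)  # hypothetical session factory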
"""
self._maybe_show_deprecation_warning()
# Support passing the params via function decorator, e.g.
# @st.cache_resource(show_spinner=False)
if func is None:
return lambda f: make_cached_func_wrapper(
CachedResourceFuncInfo(
func=f,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
validate=validate,
allow_widgets=experimental_allow_widgets,
)
)
return make_cached_func_wrapper(
CachedResourceFuncInfo(
func=cast(types.FunctionType, func),
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
validate=validate,
allow_widgets=experimental_allow_widgets,
)
)
@gather_metrics("clear_resource_caches")
def clear(self) -> None:
"""Clear all cache_resource caches."""
self._maybe_show_deprecation_warning()
_resource_caches.clear_all()
def _maybe_show_deprecation_warning(self):
"""If the API is being accessed with the deprecated `st.experimental_singleton` name,
show a deprecation warning.
"""
if self._deprecation_warning is not None:
show_deprecation_warning(self._deprecation_warning)
class ResourceCache(Cache):
"""Manages cached values for a single st.cache_resource function."""
def __init__(
self,
key: str,
max_entries: float,
ttl_seconds: float,
validate: ValidateFunc | None,
display_name: str,
allow_widgets: bool,
):
super().__init__()
self.key = key
self.display_name = display_name
self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache(
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
)
self._mem_cache_lock = threading.Lock()
self.validate = validate
self.allow_widgets = allow_widgets
@property
def max_entries(self) -> float:
return cast(float, self._mem_cache.maxsize)
@property
def ttl_seconds(self) -> float:
return cast(float, self._mem_cache.ttl)
def read_result(self, key: str) -> CachedResult:
"""Read a value and associated messages from the cache.
Raise `CacheKeyNotFoundError` if the value doesn't exist.
"""
with self._mem_cache_lock:
if key not in self._mem_cache:
# key does not exist in cache.
raise CacheKeyNotFoundError()
multi_results: MultiCacheResults = self._mem_cache[key]
ctx = get_script_run_ctx()
if not ctx:
# ScriptRunCtx does not exist (we're probably running in "raw" mode).
raise CacheKeyNotFoundError()
widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
if widget_key not in multi_results.results:
# widget_key does not exist in cache (this combination of widgets hasn't been
# seen for the value_key yet).
raise CacheKeyNotFoundError()
result = multi_results.results[widget_key]
if self.validate is not None and not self.validate(result.value):
# Validate failed: delete the entry and raise an error.
del multi_results.results[widget_key]
raise CacheKeyNotFoundError()
return result
@gather_metrics("_cache_resource_object")
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache."""
ctx = get_script_run_ctx()
if ctx is None:
return
main_id = st._main.id
sidebar_id = st.sidebar.id
if self.allow_widgets:
widgets = {
msg.widget_metadata.widget_id
for msg in messages
if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
}
else:
widgets = set()
with self._mem_cache_lock:
try:
multi_results = self._mem_cache[key]
except KeyError:
multi_results = MultiCacheResults(widget_ids=widgets, results={})
multi_results.widget_ids.update(widgets)
widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
result = CachedResult(value, messages, main_id, sidebar_id)
multi_results.results[widget_key] = result
self._mem_cache[key] = multi_results
def _clear(self) -> None:
with self._mem_cache_lock:
self._mem_cache.clear()
def get_stats(self) -> list[CacheStat]:
# Shallow clone our cache. Computing item sizes is potentially
# expensive, and we want to minimize the time we spend holding
# the lock.
with self._mem_cache_lock:
cache_entries = list(self._mem_cache.values())
return [
CacheStat(
category_name="st_cache_resource",
cache_name=self.display_name,
byte_length=asizeof.asizeof(entry),
)
for entry in cache_entries
]

View File

@@ -0,0 +1,31 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class CacheType(enum.Enum):
"""The function cache types we implement."""
DATA = "DATA"
RESOURCE = "RESOURCE"
def get_decorator_api_name(cache_type: CacheType) -> str:
"""Return the name of the public decorator API for the given CacheType."""
if cache_type is CacheType.DATA:
return "cache_data"
if cache_type is CacheType.RESOURCE:
return "cache_resource"
raise RuntimeError(f"Unrecognized CacheType '{cache_type}'")

View File

@@ -0,0 +1,449 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common cache logic shared by st.cache_data and st.cache_resource."""
from __future__ import annotations
import functools
import hashlib
import inspect
import math
import threading
import time
import types
from abc import abstractmethod
from collections import defaultdict
from datetime import timedelta
from typing import Any, Callable, overload
from typing_extensions import Literal
from streamlit import type_util
from streamlit.elements.spinner import spinner
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import (
CacheError,
CacheKeyNotFoundError,
UnevaluatedDataFrameError,
UnhashableParamError,
UnhashableTypeError,
UnserializableReturnValueError,
get_cached_func_name_md,
)
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
MsgData,
replay_cached_messages,
)
from streamlit.runtime.caching.hashing import update_hash
_LOGGER = get_logger(__name__)
# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
TTLCACHE_TIMER = time.monotonic
@overload
def ttl_to_seconds(
ttl: float | timedelta | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
...
@overload
def ttl_to_seconds(ttl: float | timedelta | None) -> float:
...
def ttl_to_seconds(
ttl: float | timedelta | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a ttl value to a float representing "number of seconds".
"""
if coerce_none_to_inf and ttl is None:
return math.inf
if isinstance(ttl, timedelta):
return ttl.total_seconds()
return ttl
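# Illustrative examples: ttl_to_seconds(timedelta(minutes=5)) -> 300.0,
# ttl_to_seconds(None) -> math.inf, and
# ttl_to_seconds(None, coerce_none_to_inf=False) -> None.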
# We show a special "UnevaluatedDataFrame" warning for cached funcs
# that attempt to return one of these unserializable types:
UNEVALUATED_DATAFRAME_TYPES = (
"snowflake.snowpark.table.Table",
"snowflake.snowpark.dataframe.DataFrame",
"pyspark.sql.dataframe.DataFrame",
)
class Cache:
"""Function cache interface. Caches persist across script runs."""
def __init__(self):
self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
self._value_locks_lock = threading.Lock()
@abstractmethod
def read_result(self, value_key: str) -> CachedResult:
"""Read a value and associated messages from the cache.
Raises
------
CacheKeyNotFoundError
Raised if value_key is not in the cache.
"""
raise NotImplementedError
@abstractmethod
def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache, overwriting any existing
result that uses the value_key.
"""
# We *could* `del self._value_locks[value_key]` here, since nobody will be taking
# a compute_value_lock for this value_key after the result is written.
raise NotImplementedError
def compute_value_lock(self, value_key: str) -> threading.Lock:
"""Return the lock that should be held while computing a new cached value.
In a popular app with a cache that hasn't been pre-warmed, many sessions may try
to access a not-yet-cached value simultaneously. We use a lock to ensure that
only one of those sessions computes the value, and the others block until
the value is computed.
"""
with self._value_locks_lock:
return self._value_locks[value_key]
def clear(self):
"""Clear all values from this cache."""
with self._value_locks_lock:
self._value_locks.clear()
self._clear()
@abstractmethod
def _clear(self) -> None:
"""Subclasses must implement this to perform cache-clearing logic."""
raise NotImplementedError
class CachedFuncInfo:
"""Encapsulates data for a cached function instance.
CachedFuncInfo instances are scoped to a single script run - they're not
persistent.
"""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
allow_widgets: bool,
):
self.func = func
self.show_spinner = show_spinner
self.allow_widgets = allow_widgets
@property
def cache_type(self) -> CacheType:
raise NotImplementedError
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
raise NotImplementedError
def get_function_cache(self, function_key: str) -> Cache:
"""Get or create the function cache for the given key."""
raise NotImplementedError
def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
"""Create a callable wrapper around a CachedFunctionInfo.
Calling the wrapper will return the cached value if it's already been
computed, and will call the underlying function to compute and cache the
value otherwise.
The wrapper also has a `clear` function that can be called to clear
all of the wrapper's cached values.
"""
cached_func = CachedFunc(info)
# We'd like to simply return `cached_func`, which is already a Callable.
# But using `functools.update_wrapper` on the CachedFunc instance
# itself results in errors when our caching decorators are used to decorate
# member functions. (See https://github.com/streamlit/streamlit/issues/6109)
@functools.wraps(info.func)
def wrapper(*args, **kwargs):
return cached_func(*args, **kwargs)
# Give our wrapper its `clear` function.
# (This results in a spurious mypy error that we suppress.)
wrapper.clear = cached_func.clear # type: ignore
return wrapper
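# Illustrative usage sketch: the @st.cache_data / @st.cache_resource decorators
# ultimately do something like the following with a concrete CachedFuncInfo
# subclass instance (the names below are for illustration only).
#
#     wrapped = make_cached_func_wrapper(some_cached_func_info)
#     value = wrapped(1, 2)  # returns the cached value, or computes and caches it
#     wrapped.clear()        # drops every cached value for this function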
class CachedFunc:
def __init__(self, info: CachedFuncInfo):
self._info = info
self._function_key = _make_function_key(info.cache_type, info.func)
def __call__(self, *args, **kwargs) -> Any:
"""The wrapper. We'll only call our underlying function on a cache miss."""
name = self._info.func.__qualname__
if isinstance(self._info.show_spinner, bool):
if len(args) == 0 and len(kwargs) == 0:
message = f"Running `{name}()`."
else:
message = f"Running `{name}(...)`."
else:
message = self._info.show_spinner
if self._info.show_spinner or isinstance(self._info.show_spinner, str):
with spinner(message):
return self._get_or_create_cached_value(args, kwargs)
else:
return self._get_or_create_cached_value(args, kwargs)
def _get_or_create_cached_value(
self, func_args: tuple[Any, ...], func_kwargs: dict[str, Any]
) -> Any:
# Retrieve the function's cache object. We must do this "just-in-time"
# (as opposed to in the constructor), because caches can be invalidated
# at any time.
cache = self._info.get_function_cache(self._function_key)
# Generate the key for the cached value. This is based on the
# arguments passed to the function.
value_key = _make_value_key(
cache_type=self._info.cache_type,
func=self._info.func,
func_args=func_args,
func_kwargs=func_kwargs,
)
try:
cached_result = cache.read_result(value_key)
return self._handle_cache_hit(cached_result)
except CacheKeyNotFoundError:
return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)
def _handle_cache_hit(self, result: CachedResult) -> Any:
"""Handle a cache hit: replay the result's cached messages, and return its value."""
replay_cached_messages(
result,
self._info.cache_type,
self._info.func,
)
return result.value
def _handle_cache_miss(
self,
cache: Cache,
value_key: str,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
) -> Any:
"""Handle a cache miss: compute a new cached value, write it back to the cache,
and return that newly-computed value.
"""
# Implementation notes:
# - We take a "compute_value_lock" before computing our value. This ensures that
# multiple sessions don't try to compute the same value simultaneously.
#
# - We use a different lock for each value_key, as opposed to a single lock for
# the entire cache, so that unrelated value computations don't block on each other.
#
# - When retrieving a cache entry that may not yet exist, we use a "double-checked locking"
# strategy: first we try to retrieve the cache entry without taking a value lock. (This
# happens in `_get_or_create_cached_value()`.) If that fails because the value hasn't
# been computed yet, we take the value lock and then immediately try to retrieve cache entry
# *again*, while holding the lock. If the cache entry exists at this point, it means that
# another thread computed the value before us.
#
# This means that the happy path ("cache entry exists") is a wee bit faster because
# no lock is acquired. But the unhappy path ("cache entry needs to be recomputed") is
# a wee bit slower, because we do two lookups for the entry.
with cache.compute_value_lock(value_key):
# We've acquired the lock - but another thread may have acquired it first
# and already computed the value. So we need to test for a cache hit again,
# before computing.
try:
cached_result = cache.read_result(value_key)
# Another thread computed the value before us. Early exit!
return self._handle_cache_hit(cached_result)
except CacheKeyNotFoundError:
# We acquired the lock before any other thread. Compute the value!
with self._info.cached_message_replay_ctx.calling_cached_function(
self._info.func, self._info.allow_widgets
):
computed_value = self._info.func(*func_args, **func_kwargs)
# We've computed our value, and now we need to write it back to the cache
# along with any "replay messages" that were generated during value computation.
messages = self._info.cached_message_replay_ctx._most_recent_messages
try:
cache.write_result(value_key, computed_value, messages)
return computed_value
except (CacheError, RuntimeError):
# An exception was thrown while we tried to write to the cache. Report it to the user.
# (We catch `RuntimeError` here because it will be raised by Apache Spark if we do not
# collect dataframe before using `st.cache_data`.)
if any(
type_util.is_type(computed_value, type_name)
for type_name in UNEVALUATED_DATAFRAME_TYPES
):
raise UnevaluatedDataFrameError(
f"""
The function {get_cached_func_name_md(self._info.func)} is decorated with `st.cache_data` but it returns an unevaluated dataframe
of type `{type_util.get_fqn_type(computed_value)}`. Please call `collect()` or `to_pandas()` on the dataframe before returning it,
so `st.cache_data` can serialize and cache it."""
)
raise UnserializableReturnValueError(
return_value=computed_value, func=self._info.func
)
def clear(self):
"""Clear the wrapped function's associated cache."""
cache = self._info.get_function_cache(self._function_key)
cache.clear()
def _make_value_key(
cache_type: CacheType,
func: types.FunctionType,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
) -> str:
"""Create the key for a value within a cache.
This key is generated from the function's arguments. All arguments
will be hashed, except for those named with a leading "_".
Raises
------
StreamlitAPIException
Raised (with a nicely-formatted explanation message) if we encounter
an un-hashable arg.
"""
# Create a (name, value) list of all *args and **kwargs passed to the
# function.
arg_pairs: list[tuple[str | None, Any]] = []
for arg_idx in range(len(func_args)):
arg_name = _get_positional_arg_name(func, arg_idx)
arg_pairs.append((arg_name, func_args[arg_idx]))
for kw_name, kw_val in func_kwargs.items():
# **kwargs ordering is preserved, per PEP 468
# https://www.python.org/dev/peps/pep-0468/, so this iteration is
# deterministic.
arg_pairs.append((kw_name, kw_val))
# Create the hash from each arg value, except for those args whose name
# starts with "_". (Underscore-prefixed args are deliberately excluded from
# hashing.)
args_hasher = hashlib.new("md5")
for arg_name, arg_value in arg_pairs:
if arg_name is not None and arg_name.startswith("_"):
_LOGGER.debug("Not hashing %s because it starts with _", arg_name)
continue
try:
update_hash(
(arg_name, arg_value),
hasher=args_hasher,
cache_type=cache_type,
)
except UnhashableTypeError as exc:
raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
value_key = args_hasher.hexdigest()
_LOGGER.debug("Cache key: %s", value_key)
return value_key
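# For example, for a function defined as `def load(_conn, limit=10)`, a call
# `load(conn, limit=20)` hashes only ("limit", 20); the underscore-prefixed
# connection argument is deliberately skipped.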
def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
"""Create the unique key for a function's cache.
A function's key is stable across reruns of the app, and changes when
the function's source code changes.
"""
func_hasher = hashlib.new("md5")
# Include the function's __module__ and __qualname__ strings in the hash.
# This means that two identical functions in different modules
# will not share a hash; it also means that two identical *nested*
# functions in the same module will not share a hash.
update_hash(
(func.__module__, func.__qualname__),
hasher=func_hasher,
cache_type=cache_type,
)
# Include the function's source code in its hash. If the source code can't
# be retrieved, fall back to the function's bytecode instead.
source_code: str | bytes
try:
source_code = inspect.getsource(func)
except OSError as e:
_LOGGER.debug(
"Failed to retrieve function's source code when building its key; falling back to bytecode. err={0}",
e,
)
source_code = func.__code__.co_code
update_hash(
source_code,
hasher=func_hasher,
cache_type=cache_type,
)
cache_key = func_hasher.hexdigest()
return cache_key
def _get_positional_arg_name(func: types.FunctionType, arg_index: int) -> str | None:
"""Return the name of a function's positional argument.
If arg_index is out of range, or refers to a parameter that is not a
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
return None instead.
"""
if arg_index < 0:
return None
params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values())
if arg_index >= len(params):
return None
if params[arg_index].kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
):
return params[arg_index].name
return None

View File

@@ -0,0 +1,478 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import hashlib
import threading
import types
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Union
from google.protobuf.message import Message
from typing_extensions import Protocol, runtime_checkable
import streamlit as st
from streamlit import runtime, util
from streamlit.elements import NONWIDGET_ELEMENTS, WIDGETS
from streamlit.logger import get_logger
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_errors import (
CachedStFunctionWarning,
CacheReplayClosureError,
)
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.hashing import update_hash
from streamlit.runtime.scriptrunner.script_run_context import (
ScriptRunContext,
get_script_run_ctx,
)
from streamlit.runtime.state.common import WidgetMetadata
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
_LOGGER = get_logger(__name__)
@runtime_checkable
class Widget(Protocol):
id: str
@dataclass(frozen=True)
class WidgetMsgMetadata:
"""Everything needed for replaying a widget and treating it as an implicit
argument to a cached function, beyond what is stored for all elements.
"""
widget_id: str
widget_value: Any
metadata: WidgetMetadata[Any]
@dataclass(frozen=True)
class MediaMsgData:
media: bytes | str
mimetype: str
media_id: str
@dataclass(frozen=True)
class ElementMsgData:
"""An element's message and related metadata for
replaying that element's function call.
widget_metadata is filled in if and only if this element is a widget.
media_data is filled in iff this is a media element (image, audio, video).
"""
delta_type: str
message: Message
id_of_dg_called_on: str
returned_dgs_id: str
widget_metadata: WidgetMsgMetadata | None = None
media_data: list[MediaMsgData] | None = None
@dataclass(frozen=True)
class BlockMsgData:
message: Block
id_of_dg_called_on: str
returned_dgs_id: str
MsgData = Union[ElementMsgData, BlockMsgData]
"""
Note [Cache result structure]
The cache for a decorated function's results is split into two parts to enable
handling widgets invoked by the function.
Widgets act as implicit additional inputs to the cached function, so they should
be used when deriving the cache key. However, we don't know what widgets are
involved without first calling the function! So, we use the first execution
of the function with a particular cache key to record what widgets are used,
and use the current values of those widgets to derive a second cache key to
look up the function execution's results. The combination of first and second
cache keys acts as one true cache key, just split up because the second part depends
on the first.
We need to treat widgets as implicit arguments of the cached function, because
the behavior of the function, including what elements are created and what it
returns, can be and usually will be influenced by the values of those widgets.
For example:
> @st.memo
> def example_fn(x):
> y = x + 1
> if st.checkbox("hi"):
> st.write("you checked the thing")
> y = 0
> return y
>
> example_fn(2)
If the checkbox is checked, the function call should return 0 and the checkbox and
message should be rendered. If the checkbox isn't checked, only the checkbox should
render, and the function will return 3.
There is a small catch in this. Which widgets execute can depend on the values of
prior widgets. If we replace the `st.write` call in the example with a slider, the
first execution would miss the slider because it was never called. When we later
execute the function with the checkbox checked, the widget cache key would not
include the state of the slider, and we would incorrectly get a cache hit for a
different slider value.
In principle the cache could be function+args key -> 1st widget key -> 2nd widget key
... -> final widget key, with each widget dependent on the exact values of the widgets
seen prior. This would prevent unnecessary cache misses due to differing values of widgets
that wouldn't affect the function's execution because they aren't even created.
But this would add even more complexity and both conceptual and runtime overhead, so it is
unclear if it would be worth doing.
Instead, we can keep the widgets as one cache key, and if we encounter a new widget
while executing the function, we just update the list of widgets to include it.
This will cause existing cached results to be invalidated, which is bad, but to
avoid it we would need to keep around the full list of widgets and values for each
widget cache key so we could compute the updated key, which is probably too expensive
to be worth it.
"""
@dataclass
class CachedResult:
"""The full results of calling a cache-decorated function, enough to
replay the st functions called while executing it.
"""
value: Any
messages: list[MsgData]
main_id: str
sidebar_id: str
@dataclass
class MultiCacheResults:
"""Widgets called by a cache-decorated function, and a mapping of the
widget-derived cache key to the final results of executing the function.
"""
widget_ids: set[str]
results: dict[str, CachedResult]
def get_current_widget_key(
self, ctx: ScriptRunContext, cache_type: CacheType
) -> str:
state = ctx.session_state
# Compute the key using only widgets that have values. A missing widget
# can be ignored because we only care about getting different keys
# for different widget values, and for that purpose doing nothing
# to the running hash is just as good as including the widget with a
# sentinel value. But by excluding it, we might get to reuse a result
# saved before we knew about that widget.
widget_values = [
(wid, state[wid]) for wid in sorted(self.widget_ids) if wid in state
]
widget_key = _make_widget_key(widget_values, cache_type)
return widget_key
"""
Note [DeltaGenerator method invocation]
There are two top level DG instances defined for all apps:
`main`, which is for putting elements in the main part of the app
`sidebar`, for the sidebar
There are 3 different ways an st function can be invoked:
1. Implicitly on the main DG instance (plain `st.foo` calls)
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)
To simplify replaying messages from a cached function result, we convert all of these
to explicit invocations. How they get rewritten depends on whether the invocation was
implicit or explicit, and whether the target DG has been seen/produced during replay.
Implicit invocation on a known DG -> Explicit invocation on that DG
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
with st.container():
my_cache_decorated_function()
This is situation 2 above, and the DG is a block entirely outside our function call,
so we interpret it as "put this element in the enclosing contextmanager block"
(or main if there isn't one), which is achieved by invoking on main.
Explicit invocation on a known DG -> No change needed
Explicit invocation on an unknown DG -> Raise an error
We have no way to identify the target DG, and it may not even be present in the
current script run, so the least surprising thing to do is raise an error.
"""
class CachedMessageReplayContext(threading.local):
"""A utility for storing messages generated by `st` commands called inside
a cached function.
Data is stored in a thread-local object, so it's safe to use an instance
of this class across multiple threads.
"""
def __init__(self, cache_type: CacheType):
self._cached_func_stack: list[types.FunctionType] = []
self._suppress_st_function_warning = 0
self._cached_message_stack: list[list[MsgData]] = []
self._seen_dg_stack: list[set[str]] = []
self._most_recent_messages: list[MsgData] = []
self._registered_metadata: WidgetMetadata[Any] | None = None
self._media_data: list[MediaMsgData] = []
self._cache_type = cache_type
self._allow_widgets: int = 0
def __repr__(self) -> str:
return util.repr_(self)
@contextlib.contextmanager
def calling_cached_function(
self, func: types.FunctionType, allow_widgets: bool
) -> Iterator[None]:
"""Context manager that should wrap the invocation of a cached function.
It allows us to track any `st.foo` messages that are generated from inside the function
for playback during cache retrieval.
"""
self._cached_func_stack.append(func)
self._cached_message_stack.append([])
self._seen_dg_stack.append(set())
if allow_widgets:
self._allow_widgets += 1
try:
yield
finally:
self._cached_func_stack.pop()
self._most_recent_messages = self._cached_message_stack.pop()
self._seen_dg_stack.pop()
if allow_widgets:
self._allow_widgets -= 1
def save_element_message(
self,
delta_type: str,
element_proto: Message,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Record the element protobuf as having been produced during any currently
executing cached functions, so they can be replayed any time the function's
execution is skipped because its result is already in the cache.
"""
if not runtime.exists():
return
if len(self._cached_message_stack) >= 1:
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
# Arrow dataframes have an ID but only set it when used as data editor
# widgets, so we have to check that the ID has been actually set to
# know if an element is a widget.
if isinstance(element_proto, Widget) and element_proto.id:
wid = element_proto.id
# TODO replace `Message` with a more precise type
if not self._registered_metadata:
_LOGGER.error(
"Trying to save widget message that wasn't registered. This should not be possible."
)
raise AttributeError
widget_meta = WidgetMsgMetadata(
wid, None, metadata=self._registered_metadata
)
else:
widget_meta = None
media_data = self._media_data
element_msg_data = ElementMsgData(
delta_type,
element_proto,
id_to_save,
returned_dg_id,
widget_meta,
media_data,
)
for msgs in self._cached_message_stack:
if self._allow_widgets or widget_meta is None:
msgs.append(element_msg_data)
# Reset instance state, now that it has been used for the
# associated element.
self._media_data = []
self._registered_metadata = None
for s in self._seen_dg_stack:
s.add(returned_dg_id)
def save_block_message(
self,
block_proto: Block,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
for msgs in self._cached_message_stack:
msgs.append(BlockMsgData(block_proto, id_to_save, returned_dg_id))
for s in self._seen_dg_stack:
s.add(returned_dg_id)
def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
"""Select the id of the DG that this message should be invoked on
during message replay.
See Note [DeltaGenerator method invocation]
invoked_id is the DG the st function was called on, usually `st._main`.
acting_on_id is the DG the st function ultimately runs on, which may be different
if the invoked DG delegated to another one because it was in a `with` block.
"""
if len(self._seen_dg_stack) > 0 and acting_on_id in self._seen_dg_stack[-1]:
return acting_on_id
else:
return invoked_id
def save_widget_metadata(self, metadata: WidgetMetadata[Any]) -> None:
self._registered_metadata = metadata
def save_image_data(
self, image_data: bytes | str, mimetype: str, image_id: str
) -> None:
self._media_data.append(MediaMsgData(image_data, mimetype, image_id))
@contextlib.contextmanager
def suppress_cached_st_function_warning(self) -> Iterator[None]:
self._suppress_st_function_warning += 1
try:
yield
finally:
self._suppress_st_function_warning -= 1
assert self._suppress_st_function_warning >= 0
def maybe_show_cached_st_function_warning(
self,
dg: "DeltaGenerator",
st_func_name: str,
) -> None:
"""If appropriate, warn about calling st.foo inside @memo.
DeltaGenerator's @_with_element and @_widget wrappers use this to warn
the user when they're calling st.foo() from within a function that is
wrapped in @st.cache.
Parameters
----------
dg : DeltaGenerator
The DeltaGenerator to publish the warning to.
st_func_name : str
The name of the Streamlit function that was called.
"""
# There are some elements not in either list, which we still want to warn about.
# Ideally we will fix this by either updating the lists or creating a better
# way of categorizing elements.
if st_func_name in NONWIDGET_ELEMENTS:
return
if st_func_name in WIDGETS and self._allow_widgets > 0:
return
if len(self._cached_func_stack) > 0 and self._suppress_st_function_warning <= 0:
cached_func = self._cached_func_stack[-1]
self._show_cached_st_function_warning(dg, st_func_name, cached_func)
def _show_cached_st_function_warning(
self,
dg: "DeltaGenerator",
st_func_name: str,
cached_func: types.FunctionType,
) -> None:
# Avoid infinite recursion by suppressing additional cached
# function warnings from within the cached function warning.
with self.suppress_cached_st_function_warning():
e = CachedStFunctionWarning(self._cache_type, st_func_name, cached_func)
dg.exception(e)
def replay_cached_messages(
result: CachedResult, cache_type: CacheType, cached_func: types.FunctionType
) -> None:
"""Replay the st element function calls that happened when executing a
cache-decorated function.
When a cache function is executed, we record the element and block messages
produced, and use those to reproduce the DeltaGenerator calls, so the elements
will appear in the web app even when execution of the function is skipped
because the result was cached.
To make this work, for each st function call we record an identifier for the
DG it was effectively called on (see Note [DeltaGenerator method invocation]).
We also record the identifier for each DG returned by an st function call, if
it returns one. Then, for each recorded message, we get the current DG instance
corresponding to the DG the message was originally called on, and enqueue the
message using that, recording any new DGs produced in case a later st function
call is on one of them.
"""
from streamlit.delta_generator import DeltaGenerator
from streamlit.runtime.state.widgets import register_widget_from_metadata
# Maps originally recorded dg ids to this script run's version of that dg
returned_dgs: dict[str, DeltaGenerator] = {}
returned_dgs[result.main_id] = st._main
returned_dgs[result.sidebar_id] = st.sidebar
ctx = get_script_run_ctx()
try:
for msg in result.messages:
if isinstance(msg, ElementMsgData):
if msg.widget_metadata is not None:
register_widget_from_metadata(
msg.widget_metadata.metadata,
ctx,
None,
msg.delta_type,
)
if msg.media_data is not None:
for data in msg.media_data:
runtime.get_instance().media_file_mgr.add(
data.media, data.mimetype, data.media_id
)
dg = returned_dgs[msg.id_of_dg_called_on]
maybe_dg = dg._enqueue(msg.delta_type, msg.message)
if isinstance(maybe_dg, DeltaGenerator):
returned_dgs[msg.returned_dgs_id] = maybe_dg
elif isinstance(msg, BlockMsgData):
dg = returned_dgs[msg.id_of_dg_called_on]
new_dg = dg._block(msg.message)
returned_dgs[msg.returned_dgs_id] = new_dg
except KeyError:
raise CacheReplayClosureError(cache_type, cached_func)
def _make_widget_key(widgets: list[tuple[str, Any]], cache_type: CacheType) -> str:
"""Generate a key for the given list of widgets used in a cache-decorated function.
Keys are generated by hashing the IDs and values of the widgets in the given list.
"""
func_hasher = hashlib.new("md5")
for widget_id_val in widgets:
update_hash(widget_id_val, func_hasher, cache_type)
return func_hasher.hexdigest()

View File

@@ -0,0 +1,395 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hashing for st.memo and st.singleton."""
import collections
import dataclasses
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import unittest.mock
import weakref
from enum import Enum
from typing import Any, Dict, List, Optional, Pattern
from streamlit import type_util, util
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.uploaded_file_manager import UploadedFile
# If a dataframe has more than this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE = 100000
_PANDAS_SAMPLE_SIZE = 10000
# Similar to dataframes, we also sample large numpy arrays.
_NP_SIZE_LARGE = 1000000
_NP_SAMPLE_SIZE = 100000
# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
def update_hash(val: Any, hasher, cache_type: CacheType) -> None:
"""Updates a hashlib hasher with the hash of val.
This is the main entrypoint to hashing.py.
"""
ch = _CacheFuncHasher(cache_type)
ch.update(hasher, val)
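# Illustrative usage (mirrors how the caching layer builds its keys):
#
#     import hashlib
#     h = hashlib.new("md5")
#     update_hash(("url", "https://example.com"), hasher=h, cache_type=CacheType.DATA)
#     key = h.hexdigest()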
class _HashStack:
"""Stack of what has been hashed, for debug and circular reference detection.
This internally keeps 1 stack per thread.
Internally, this stores the ID of pushed objects rather than the objects
themselves because otherwise the "in" operator inside __contains__ would
fail for objects that don't return a boolean for "==" operator. For
example, arr == 10 where arr is a NumPy array returns another NumPy array.
This causes the "in" to crash since it expects a boolean.
"""
def __init__(self):
self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict()
def __repr__(self) -> str:
return util.repr_(self)
def push(self, val: Any):
self._stack[id(val)] = val
def pop(self):
self._stack.popitem()
def __contains__(self, val: Any):
return id(val) in self._stack
class _HashStacks:
"""Stacks of what has been hashed, with at most 1 stack per thread."""
def __init__(self):
self._stacks: weakref.WeakKeyDictionary[
threading.Thread, _HashStack
] = weakref.WeakKeyDictionary()
def __repr__(self) -> str:
return util.repr_(self)
@property
def current(self) -> _HashStack:
current_thread = threading.current_thread()
stack = self._stacks.get(current_thread, None)
if stack is None:
stack = _HashStack()
self._stacks[current_thread] = stack
return stack
hash_stacks = _HashStacks()
def _int_to_bytes(i: int) -> bytes:
num_bytes = (i.bit_length() + 8) // 8
return i.to_bytes(num_bytes, "little", signed=True)
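# For example, _int_to_bytes(255) == b"\xff\x00" and _int_to_bytes(-1) == b"\xff"
# (little-endian, signed).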
def _key(obj: Optional[Any]) -> Any:
"""Return key for memoization."""
if obj is None:
return None
def is_simple(obj):
return (
isinstance(obj, bytes)
or isinstance(obj, bytearray)
or isinstance(obj, str)
or isinstance(obj, float)
or isinstance(obj, int)
or isinstance(obj, bool)
or obj is None
)
if is_simple(obj):
return obj
if isinstance(obj, tuple):
if all(map(is_simple, obj)):
return obj
if isinstance(obj, list):
if all(map(is_simple, obj)):
return ("__l", tuple(obj))
if (
type_util.is_type(obj, "pandas.core.frame.DataFrame")
or type_util.is_type(obj, "numpy.ndarray")
or inspect.isbuiltin(obj)
or inspect.isroutine(obj)
or inspect.iscode(obj)
):
return id(obj)
return NoResult
class _CacheFuncHasher:
"""A hasher that can hash objects with cycles."""
def __init__(self, cache_type: CacheType):
self._hashes: Dict[Any, bytes] = {}
# The number of the bytes in the hash.
self.size = 0
self.cache_type = cache_type
def __repr__(self) -> str:
return util.repr_(self)
def to_bytes(self, obj: Any) -> bytes:
"""Add memoization to _to_bytes and protect against cycles in data structures."""
tname = type(obj).__qualname__.encode()
key = (tname, _key(obj))
# Memoize if possible.
if key[1] is not NoResult:
if key in self._hashes:
return self._hashes[key]
# Break recursive cycles.
if obj in hash_stacks.current:
return _CYCLE_PLACEHOLDER
hash_stacks.current.push(obj)
try:
# Hash the input
b = b"%s:%s" % (tname, self._to_bytes(obj))
# Note: the size calculation may over-count, because calls to to_bytes
# made from within _to_bytes get double-counted.
self.size += sys.getsizeof(b)
if key[1] is not NoResult:
self._hashes[key] = b
finally:
# In case an UnhashableTypeError (or other) error is thrown, clean up the
# stack so we don't get false positives in future hashing calls
hash_stacks.current.pop()
return b
def update(self, hasher, obj: Any) -> None:
"""Update the provided hasher with the hash of an object."""
b = self.to_bytes(obj)
hasher.update(b)
def _to_bytes(self, obj: Any) -> bytes:
"""Hash objects to bytes, including code with dependencies.
Python's built in `hash` does not produce consistent results across
runs.
"""
if isinstance(obj, unittest.mock.Mock):
# Mock objects can appear to be infinitely
# deep, so we don't try to hash them at all.
return self.to_bytes(id(obj))
elif isinstance(obj, bytes) or isinstance(obj, bytearray):
return obj
elif isinstance(obj, str):
return obj.encode()
elif isinstance(obj, float):
return self.to_bytes(hash(obj))
elif isinstance(obj, int):
return _int_to_bytes(obj)
elif isinstance(obj, (list, tuple)):
h = hashlib.new("md5")
for item in obj:
self.update(h, item)
return h.digest()
elif isinstance(obj, dict):
h = hashlib.new("md5")
for item in obj.items():
self.update(h, item)
return h.digest()
elif obj is None:
return b"0"
elif obj is True:
return b"1"
elif obj is False:
return b"0"
elif dataclasses.is_dataclass(obj):
return self.to_bytes(dataclasses.asdict(obj))
elif isinstance(obj, Enum):
return str(obj).encode()
elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
obj, "pandas.core.series.Series"
):
import pandas as pd
if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
try:
return b"%s" % pd.util.hash_pandas_object(obj).sum()
except TypeError:
# Use pickle if pandas cannot hash the object for example if
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
elif type_util.is_type(obj, "numpy.ndarray"):
h = hashlib.new("md5")
self.update(h, obj.shape)
if obj.size >= _NP_SIZE_LARGE:
import numpy as np
state = np.random.RandomState(0)
obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)
self.update(h, obj.tobytes())
return h.digest()
elif type_util.is_type(obj, "PIL.Image.Image"):
import numpy as np
# we don't just hash the results of obj.tobytes() because we want to use
# the sampling logic for numpy data
np_array = np.frombuffer(obj.tobytes(), dtype="uint8")
return self.to_bytes(np_array)
elif inspect.isbuiltin(obj):
return bytes(obj.__name__.encode())
elif type_util.is_type(obj, "builtins.mappingproxy") or type_util.is_type(
obj, "builtins.dict_items"
):
return self.to_bytes(dict(obj))
elif type_util.is_type(obj, "builtins.getset_descriptor"):
return bytes(obj.__qualname__.encode())
elif isinstance(obj, UploadedFile):
# UploadedFile is a BytesIO (thus IOBase) but has a name.
# It does not have a timestamp so this must come before
# temporary files
h = hashlib.new("md5")
self.update(h, obj.name)
self.update(h, obj.tell())
self.update(h, obj.getvalue())
return h.digest()
elif hasattr(obj, "name") and (
isinstance(obj, io.IOBase)
# Handle temporary files used during testing
or isinstance(obj, tempfile._TemporaryFileWrapper)
):
# Hash files as name + last modification date + offset.
# NB: we're using hasattr("name") to differentiate between
# on-disk and in-memory StringIO/BytesIO file representations.
# That means that this condition must come *before* the next
# condition, which just checks for StringIO/BytesIO.
h = hashlib.new("md5")
obj_name = getattr(obj, "name", "wonthappen") # Just to appease MyPy.
self.update(h, obj_name)
self.update(h, os.path.getmtime(obj_name))
self.update(h, obj.tell())
return h.digest()
elif isinstance(obj, Pattern):
return self.to_bytes([obj.pattern, obj.flags])
elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
# Hash in-memory StringIO/BytesIO by their full contents
# and seek position.
h = hashlib.new("md5")
self.update(h, obj.tell())
self.update(h, obj.getvalue())
return h.digest()
elif type_util.is_type(obj, "numpy.ufunc"):
# For numpy.remainder, this returns remainder.
return bytes(obj.__name__.encode())
elif inspect.ismodule(obj):
# TODO: Figure out how to best show this kind of warning to the
# user. In the meantime, show nothing. This scenario is too common,
# so the current warning is quite annoying...
# st.warning(('Streamlit does not support hashing modules. '
# 'We did not hash `%s`.') % obj.__name__)
# TODO: Hash more than just the name for internal modules.
return self.to_bytes(obj.__name__)
elif inspect.isclass(obj):
# TODO: Figure out how to best show this kind of warning to the
# user. In the meantime, show nothing. This scenario is too common,
# (e.g. in every "except" statement) so the current warning is
# quite annoying...
# st.warning(('Streamlit does not support hashing classes. '
# 'We did not hash `%s`.') % obj.__name__)
# TODO: Hash more than just the name of classes.
return self.to_bytes(obj.__name__)
elif isinstance(obj, functools.partial):
# The return value of functools.partial is not a plain function:
# it's a callable object that remembers the original function plus
# the values you pickled into it. So here we need to special-case it.
h = hashlib.new("md5")
self.update(h, obj.args)
self.update(h, obj.func)
self.update(h, obj.keywords)
return h.digest()
else:
# As a last resort, hash the output of the object's __reduce__ method
h = hashlib.new("md5")
try:
reduce_data = obj.__reduce__()
except Exception as ex:
raise UnhashableTypeError() from ex
for item in reduce_data:
self.update(h, item)
return h.digest()
class NoResult:
"""Placeholder class for return values when None is meaningful."""
pass

View File

@@ -0,0 +1,29 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage as CacheStorage,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageContext as CacheStorageContext,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageError as CacheStorageError,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageKeyNotFoundError as CacheStorageKeyNotFoundError,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageManager as CacheStorageManager,
)

View File

@@ -0,0 +1,240 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Declares the CacheStorageContext dataclass, which contains parameter information for
each function decorated by `@st.cache_data` (for example: ttl, max_entries etc.)
Declares the CacheStorageManager protocol, which implementations are used
to create CacheStorage instances and to optionally clear all cache storages,
that were created by this manager, and to check if the context is valid for the storage.
Declares the CacheStorage protocol, which implementations are used to store cached
values for a single `@st.cache_data` decorated function serialized as bytes.
How these classes work together
-------------------------------
- CacheStorageContext : this is a dataclass that contains the parameters from
`@st.cache_data` that are passed to the CacheStorageManager.create() method.
- CacheStorageManager : each instance of this is able to create CacheStorage
instances, and optionally to clear data of all cache storages.
- CacheStorage : each instance of this is able to get, set, delete, and clear
entries for a single `@st.cache_data` decorated function.
┌───────────────────────────────┐
│ │
│ CacheStorageManager │
│ │
│ - clear_all(optional) │
│ - check_context │
│ │
└──┬────────────────────────────┘
│ ┌──────────────────────┐
│ │ CacheStorage │
│ create(context)│ │
└────────────────► - get │
│ - set │
│ - delete │
│ - close (optional)│
│ - clear │
└──────────────────────┘
"""
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from typing_extensions import Literal, Protocol
class CacheStorageError(Exception):
"""Base exception raised by the cache storage"""
class CacheStorageKeyNotFoundError(CacheStorageError):
"""Raised when the key is not found in the cache storage"""
class InvalidCacheStorageContext(CacheStorageError):
"""Raised if the cache storage manager is not able to work with
provided CacheStorageContext.
"""
@dataclass(frozen=True)
class CacheStorageContext:
"""Context passed to the cache storage during initialization
These are the normalized parameters that are passed to the CacheStorageManager.create()
method.
Parameters
----------
function_key: str
A hash computed based on function name and source code decorated
by `@st.cache_data`
function_display_name: str
The display name of the function that is decorated by `@st.cache_data`
ttl_seconds : float or None
The time-to-live for the keys in storage, in seconds. If None, the entry
will never expire.
max_entries : int or None
The maximum number of entries to store in the cache storage.
If None, the cache storage will not limit the number of entries.
persist : Literal["disk"] or None
The persistence mode for the cache storage.
Legacy parameter used by Streamlit's current cache storage implementation.
A cache storage implementation may ignore it if the storage does not support
persistence or is persistent by default.
"""
function_key: str
function_display_name: str
ttl_seconds: float | None = None
max_entries: int | None = None
persist: Literal["disk"] | None = None
class CacheStorage(Protocol):
"""Cache storage protocol, that should be implemented by the concrete cache storages.
Used to store cached values for a single `@st.cache_data` decorated function
serialized as bytes.
CacheStorage instances should be created by `CacheStorageManager.create()` method.
Notes
-----
Threading: the methods of this protocol may be called from multiple threads.
It is the responsibility of the concrete implementation to provide thread-safety
guarantees.
"""
@abstractmethod
def get(self, key: str) -> bytes:
"""Returns the stored value for the key.
Raises
------
CacheStorageKeyNotFoundError
Raised if the key is not in the storage.
"""
raise NotImplementedError
@abstractmethod
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key"""
raise NotImplementedError
@abstractmethod
def delete(self, key: str) -> None:
"""Delete a given key"""
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
"""Remove all keys for the storage"""
raise NotImplementedError
def close(self) -> None:
"""Closes the cache storage, it is optional to implement, and should be used
to close open resources, before we delete the storage instance.
e.g. close the database connection etc.
"""
pass
class CacheStorageManager(Protocol):
"""Cache storage manager protocol, that should be implemented by the concrete
cache storage managers.
It is responsible for:
- Creating cache storage instances for the specific
decorated functions,
- Validating the context for the cache storages.
- Optionally clearing all cache storages in optimal way.
It should be created during Runtime initialization.
"""
@abstractmethod
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance
Please note that the ttl, max_entries and other context fields are specific
for whole storage, not for individual key.
Notes
-----
Threading: Should be safe to call from any thread.
"""
raise NotImplementedError
def clear_all(self) -> None:
"""Remove everything what possible from the cache storages in optimal way.
meaningful default behaviour is to raise NotImplementedError, so this is not
abstractmethod.
The method is optional to implement: cache data API will fall back to remove
all available storages one by one via storage.clear() method
if clear_all raises NotImplementedError.
Raises
------
NotImplementedError
Raised if the storage manager does not provide the ability to clear
all storages at once in an efficient way.
Notes
-----
Threading: this method may be called from multiple threads.
It is the responsibility of the concrete implementation to provide
thread-safety guarantees.
"""
raise NotImplementedError
def check_context(self, context: CacheStorageContext) -> None:
"""Checks if the context is valid for the storage manager.
This method should not return anything, but log message or raise an exception
if the context is invalid.
In case of raising an exception, we not handle it and let the exception to be
propagated.
check_context is called only once at the moment of creating `@st.cache_data`
decorator for specific function, so it is not called for every cache hit.
Parameters
----------
context: CacheStorageContext
The context to check for the storage manager. A dummy function_key is used
in the context, since the real key is not yet computed when this method is called.
Raises
------
InvalidCacheStorageContext
Raised if the cache storage manager is not able to work with the provided
CacheStorageContext. When possible, a message should be logged instead, since
this exception will be propagated to the user.
Notes
-----
Threading: Should be safe to call from any thread.
"""
pass
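# --- Illustrative sketch (not part of the original module) ------------------
# A minimal, hypothetical pair of implementations of the protocols above,
# backed by a plain dict. It only shows how CacheStorageManager.create() hands
# a CacheStorageContext to a CacheStorage; the real implementations live in
# the concrete storage modules (e.g. the local-disk storage).
class _DictCacheStorage(CacheStorage):
    """Hypothetical storage that keeps entries in a per-instance dict."""

    def __init__(self, context: CacheStorageContext) -> None:
        self._context = context
        self._entries: dict[str, bytes] = {}

    def get(self, key: str) -> bytes:
        try:
            return self._entries[key]
        except KeyError as ex:
            raise CacheStorageKeyNotFoundError(key) from ex

    def set(self, key: str, value: bytes) -> None:
        self._entries[key] = value

    def delete(self, key: str) -> None:
        self._entries.pop(key, None)

    def clear(self) -> None:
        self._entries.clear()


class _DictCacheStorageManager(CacheStorageManager):
    """Hypothetical manager; create() is called once per decorated function."""

    def create(self, context: CacheStorageContext) -> CacheStorage:
        return _DictCacheStorage(context)

    def check_context(self, context: CacheStorageContext) -> None:
        # Accept any context in this sketch; a real manager might log or raise
        # InvalidCacheStorageContext here.
        pass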

View File

@@ -0,0 +1,60 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
InMemoryCacheStorageWrapper,
)
class MemoryCacheStorageManager(CacheStorageManager):
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance wrapped with in-memory cache layer"""
persist_storage = DummyCacheStorage()
return InMemoryCacheStorageWrapper(
persist_storage=persist_storage, context=context
)
def clear_all(self) -> None:
raise NotImplementedError
def check_context(self, context: CacheStorageContext) -> None:
pass
class DummyCacheStorage(CacheStorage):
def get(self, key: str) -> bytes:
"""
Dummy get for a given key;
always raises a CacheStorageKeyNotFoundError.
"""
raise CacheStorageKeyNotFoundError("Key not found in dummy cache")
def set(self, key: str, value: bytes) -> None:
pass
def delete(self, key: str) -> None:
pass
def clear(self) -> None:
pass
def close(self) -> None:
pass
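# --- Illustrative usage sketch (not part of the original module) ------------
# Shows how the manager above wires DummyCacheStorage behind the in-memory
# wrapper: set() lands in the memory layer, so the following get() succeeds
# even though the dummy persistent storage never stores anything. The function
# name and the context values are hypothetical.
def _example_memory_storage_roundtrip() -> bytes:
    manager = MemoryCacheStorageManager()
    context = CacheStorageContext(
        function_key="example-function-key",
        function_display_name="example_function",
        ttl_seconds=None,  # entries never expire
        max_entries=100,  # keep at most 100 entries in memory
    )
    storage = manager.create(context)
    storage.set("value-key", b"cached-bytes")
    return storage.get("value-key")  # -> b"cached-bytes", served from memory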

View File

@@ -0,0 +1,145 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import threading
from cachetools import TTLCache
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageKeyNotFoundError,
)
from streamlit.runtime.stats import CacheStat
_LOGGER = get_logger(__name__)
class InMemoryCacheStorageWrapper(CacheStorage):
"""
In-memory cache storage wrapper.
This class wraps a cache storage and adds an in-memory cache front layer,
which is used to reduce the number of calls to the storage.
The in-memory cache is a TTL cache, which means that the entries are
automatically removed if a given time to live (TTL) has passed.
The in-memory cache is also an LRU cache, which means that the entries
are automatically removed if the cache size exceeds a given maxsize.
If the storage implements its own maxsize strategy, it is recommended
(but not required) that it use the same LRU strategy; otherwise
different items may end up being evicted from the memory cache and
from the storage.
Notes
-----
Threading: the in-memory caching layer is thread safe: we hold self._mem_cache_lock
while working with the self._mem_cache object.
However, we do not hold this lock when calling into the underlying storage,
so it is the responsibility of that storage to ensure that it is safe to use
from multiple threads.
"""
def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
self.function_key = context.function_key
self.function_display_name = context.function_display_name
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
self._mem_cache: TTLCache[str, bytes] = TTLCache(
maxsize=self.max_entries,
ttl=self.ttl_seconds,
timer=cache_utils.TTLCACHE_TIMER,
)
self._mem_cache_lock = threading.Lock()
self._persist_storage = persist_storage
@property
def ttl_seconds(self) -> float:
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
@property
def max_entries(self) -> float:
return float(self._max_entries) if self._max_entries is not None else math.inf
def get(self, key: str) -> bytes:
"""
Returns the stored value for the key or raise CacheStorageKeyNotFoundError if
the key is not found
"""
try:
entry_bytes = self._read_from_mem_cache(key)
except CacheStorageKeyNotFoundError:
entry_bytes = self._persist_storage.get(key)
self._write_to_mem_cache(key, entry_bytes)
return entry_bytes
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key"""
self._write_to_mem_cache(key, value)
self._persist_storage.set(key, value)
def delete(self, key: str) -> None:
"""Delete a given key"""
self._remove_from_mem_cache(key)
self._persist_storage.delete(key)
def clear(self) -> None:
"""Delete all keys for the in memory cache, and also the persistent storage"""
with self._mem_cache_lock:
self._mem_cache.clear()
self._persist_storage.clear()
def get_stats(self) -> list[CacheStat]:
"""Returns a list of stats in bytes for the cache memory storage per item"""
stats = []
with self._mem_cache_lock:
for item in self._mem_cache.values():
stats.append(
CacheStat(
category_name="st_cache_data",
cache_name=self.function_display_name,
byte_length=len(item),
)
)
return stats
def close(self) -> None:
"""Closes the cache storage"""
self._persist_storage.close()
def _read_from_mem_cache(self, key: str) -> bytes:
with self._mem_cache_lock:
if key in self._mem_cache:
entry = bytes(self._mem_cache[key])
_LOGGER.debug("Memory cache HIT: %s", key)
return entry
else:
_LOGGER.debug("Memory cache MISS: %s", key)
raise CacheStorageKeyNotFoundError("Key not found in mem cache")
def _write_to_mem_cache(self, key: str, entry_bytes: bytes) -> None:
with self._mem_cache_lock:
self._mem_cache[key] = entry_bytes
def _remove_from_mem_cache(self, key: str) -> None:
with self._mem_cache_lock:
self._mem_cache.pop(key, None)
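# --- Illustrative sketch (not part of the original module) ------------------
# Demonstrates the read-through behaviour described above: on a memory-cache
# miss, get() falls back to the wrapped persistent storage and repopulates the
# memory layer. The dict-backed storage and the function names below are
# hypothetical.
class _DictBackedStorage(CacheStorage):
    def __init__(self) -> None:
        self._data: dict[str, bytes] = {}

    def get(self, key: str) -> bytes:
        try:
            return self._data[key]
        except KeyError as ex:
            raise CacheStorageKeyNotFoundError(key) from ex

    def set(self, key: str, value: bytes) -> None:
        self._data[key] = value

    def delete(self, key: str) -> None:
        self._data.pop(key, None)

    def clear(self) -> None:
        self._data.clear()


def _example_read_through() -> bytes:
    persist = _DictBackedStorage()
    persist.set("key", b"persisted-bytes")  # pre-populate the "disk" layer
    wrapper = InMemoryCacheStorageWrapper(
        persist_storage=persist,
        context=CacheStorageContext(
            function_key="example-key",
            function_display_name="example_function",
        ),
    )
    # Memory miss -> falls back to `persist`, then caches the bytes in memory.
    return wrapper.get("key")  # -> b"persisted-bytes"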

View File

@@ -0,0 +1,222 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Declares the LocalDiskCacheStorageManager class, which is used
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
InMemoryCacheStorageWrapper wrapper allows to have first layer of in-memory cache,
before accessing to LocalDiskCacheStorage itself.
Declares the LocalDiskCacheStorage class, which is used to store cached
values on disk.
How these classes work together
-------------------------------
- LocalDiskCacheStorageManager : each instance of this is able
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
and to clear data from the cache storage folder. It is also the
LocalDiskCacheStorageManager's responsibility to check whether the context is valid
for the storage, and to log a warning if it is not.
- LocalDiskCacheStorage : each instance of this is able to get, set, delete, and clear
entries from disk for a single `@st.cache_data` decorated function if `persist="disk"`
is used in CacheStorageContext.
┌───────────────────────────────┐
│ LocalDiskCacheStorageManager │
│ │
│ - clear_all │
│ - check_context │
│ │
└──┬────────────────────────────┘
│ ┌──────────────────────────────┐
│ │ │
│ create(context)│ InMemoryCacheStorageWrapper │
└────────────────► │
│ ┌─────────────────────┐ │
│ │ │ │
│ │ LocalDiskStorage │ │
│ │ │ │
│ └─────────────────────┘ │
│ │
└──────────────────────────────┘
"""
from __future__ import annotations
import math
import os
import shutil
from streamlit import util
from streamlit.file_util import get_streamlit_file_path, streamlit_read, streamlit_write
from streamlit.logger import get_logger
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
InMemoryCacheStorageWrapper,
)
# Streamlit directory where persisted @st.cache_data objects live.
# (This is the same directory that @st.cache persisted objects live.
# But @st.cache_data uses a different extension, so they don't overlap.)
_CACHE_DIR_NAME = "cache"
# The extension for our persisted @st.cache_data objects.
# (`@st.cache_data` was originally called `@st.memo`)
_CACHED_FILE_EXTENSION = "memo"
_LOGGER = get_logger(__name__)
class LocalDiskCacheStorageManager(CacheStorageManager):
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance wrapped with in-memory cache layer"""
persist_storage = LocalDiskCacheStorage(context)
return InMemoryCacheStorageWrapper(
persist_storage=persist_storage, context=context
)
def clear_all(self) -> None:
cache_path = get_cache_folder_path()
if os.path.isdir(cache_path):
shutil.rmtree(cache_path)
def check_context(self, context: CacheStorageContext) -> None:
if (
context.persist == "disk"
and context.ttl_seconds is not None
and not math.isinf(context.ttl_seconds)
):
_LOGGER.warning(
f"The cached function '{context.function_display_name}' has a TTL "
"that will be ignored. Persistent cached functions currently don't "
"support TTL."
)
class LocalDiskCacheStorage(CacheStorage):
"""Cache storage that persists data to disk
This is the default cache persistence layer for `@st.cache_data`
"""
def __init__(self, context: CacheStorageContext):
self.function_key = context.function_key
self.persist = context.persist
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
@property
def ttl_seconds(self) -> float:
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
@property
def max_entries(self) -> float:
return float(self._max_entries) if self._max_entries is not None else math.inf
def get(self, key: str) -> bytes:
"""
Returns the stored value for the key if persisted; raises
CacheStorageKeyNotFoundError if the key is not found or the storage is not
configured with persist="disk".
"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
with streamlit_read(path, binary=True) as input:
value = input.read()
_LOGGER.debug("Disk cache HIT: %s", key)
return bytes(value)
except FileNotFoundError:
raise CacheStorageKeyNotFoundError("Key not found in disk cache")
except Exception as ex:
_LOGGER.error(ex)
raise CacheStorageError("Unable to read from cache") from ex
else:
raise CacheStorageKeyNotFoundError(
f"Local disk cache storage is disabled (persist={self.persist})"
)
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
with streamlit_write(path, binary=True) as output:
output.write(value)
except util.Error as e:
_LOGGER.debug(e)
# Clean up file so we don't leave zero byte files.
try:
os.remove(path)
except (FileNotFoundError, IOError, OSError):
# If we can't remove the file, it's not a big deal.
pass
raise CacheStorageError("Unable to write to cache") from e
def delete(self, key: str) -> None:
"""Delete a cache file from disk. If the file does not exist on disk,
return silently. If another exception occurs, log it. Does not throw.
"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
os.remove(path)
except FileNotFoundError:
# The file is already removed.
pass
except Exception as ex:
_LOGGER.exception(
"Unable to remove a file from the disk cache", exc_info=ex
)
def clear(self) -> None:
"""Delete all keys for the current storage"""
cache_dir = get_cache_folder_path()
if os.path.isdir(cache_dir):
# We try to remove all files in the cache directory that start with
# the function key, whether `clear` was called for a persistent
# storage or not, to avoid leaving orphaned files in the cache directory.
for file_name in os.listdir(cache_dir):
if self._is_cache_file(file_name):
os.remove(os.path.join(cache_dir, file_name))
def close(self) -> None:
"""Dummy implementation of close, we don't need to actually "close" anything"""
def _get_cache_file_path(self, value_key: str) -> str:
"""Return the path of the disk cache file for the given value."""
cache_dir = get_cache_folder_path()
return os.path.join(
cache_dir, f"{self.function_key}-{value_key}.{_CACHED_FILE_EXTENSION}"
)
def _is_cache_file(self, fname: str) -> bool:
"""Return true if the given file name is a cache file for this storage."""
return fname.startswith(f"{self.function_key}-") and fname.endswith(
f".{_CACHED_FILE_EXTENSION}"
)
def get_cache_folder_path() -> str:
return get_streamlit_file_path(_CACHE_DIR_NAME)
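# --- Illustrative usage sketch (not part of the original module) ------------
# Shows the disk-backed path described above: with persist="disk", set()
# writes a "<function_key>-<value_key>.memo" file under the Streamlit cache
# folder, and get() reads it back (through the in-memory wrapper). The
# function name and the context values are hypothetical.
def _example_disk_storage_roundtrip() -> bytes:
    manager = LocalDiskCacheStorageManager()
    context = CacheStorageContext(
        function_key="example-function-key",
        function_display_name="example_function",
        persist="disk",
    )
    manager.check_context(context)  # warns only if a TTL is also configured
    storage = manager.create(context)  # in-memory wrapper around disk storage
    storage.set("value-key", b"cached-bytes")
    return storage.get("value-key")  # -> b"cached-bytes"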

View File

@@ -0,0 +1,308 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Manage the user's Streamlit credentials."""
import os
import sys
import textwrap
from collections import namedtuple
from typing import Optional
import click
import toml
from streamlit import env_util, file_util, util
from streamlit.logger import get_logger
LOGGER = get_logger(__name__)
if env_util.IS_WINDOWS:
_CONFIG_FILE_PATH = r"%userprofile%/.streamlit/config.toml"
else:
_CONFIG_FILE_PATH = "~/.streamlit/config.toml"
_Activation = namedtuple(
"_Activation",
[
"email", # str : the user's email.
"is_valid", # boolean : whether the email is valid.
],
)
def email_prompt() -> str:
# Emoji can cause encoding errors on non-UTF-8 terminals
# (See https://github.com/streamlit/streamlit/issues/2284.)
# WT_SESSION is a Windows Terminal specific environment variable. If it exists,
# we are on the latest Windows Terminal that supports emojis
show_emoji = sys.stdout.encoding == "utf-8" and (
not env_util.IS_WINDOWS or os.environ.get("WT_SESSION")
)
# IMPORTANT: Break the text below at 80 chars.
return """
{0}%(welcome)s
If you'd like to receive helpful onboarding emails, news, offers, promotions,
and the occasional swag, please enter your email address below. Otherwise,
leave this field blank.
%(email)s""".format(
"👋 " if show_emoji else ""
) % {
"welcome": click.style("Welcome to Streamlit!", bold=True),
"email": click.style("Email: ", fg="blue"),
}
# IMPORTANT: Break the text below at 80 chars.
_TELEMETRY_TEXT = """
You can find our privacy policy at %(link)s
Summary:
- This open source library collects usage statistics.
- We cannot see and do not store information contained inside Streamlit apps,
such as text, charts, images, etc.
- Telemetry data is stored in servers in the United States.
- If you'd like to opt out, add the following to %(config)s,
creating that file if necessary:
[browser]
gatherUsageStats = false
""" % {
"link": click.style("https://streamlit.io/privacy-policy", underline=True),
"config": click.style(_CONFIG_FILE_PATH),
}
_TELEMETRY_HEADLESS_TEXT = """
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
"""
# IMPORTANT: Break the text below at 80 chars.
_INSTRUCTIONS_TEXT = """
%(start)s
%(prompt)s %(hello)s
""" % {
"start": click.style("Get started by typing:", fg="blue", bold=True),
"prompt": click.style("$", fg="blue"),
"hello": click.style("streamlit hello", bold=True),
}
class Credentials(object):
"""Credentials class."""
_singleton: Optional["Credentials"] = None
@classmethod
def get_current(cls):
"""Return the singleton instance."""
if cls._singleton is None:
Credentials()
return Credentials._singleton
def __init__(self):
"""Initialize class."""
if Credentials._singleton is not None:
raise RuntimeError(
"Credentials already initialized. Use .get_current() instead"
)
self.activation = None
self._conf_file = _get_credential_file_path()
Credentials._singleton = self
def __repr__(self) -> str:
return util.repr_(self)
def load(self, auto_resolve=False) -> None:
"""Load from toml file."""
if self.activation is not None:
LOGGER.error("Credentials already loaded. Not rereading file.")
return
try:
with open(self._conf_file, "r") as f:
data = toml.load(f).get("general")
if data is None:
raise Exception
self.activation = _verify_email(data.get("email"))
except FileNotFoundError:
if auto_resolve:
self.activate(show_instructions=not auto_resolve)
return
raise RuntimeError(
'Credentials not found. Please run "streamlit activate".'
)
except Exception:
if auto_resolve:
self.reset()
self.activate(show_instructions=not auto_resolve)
return
raise Exception(
textwrap.dedent(
"""
Unable to load credentials from %s.
Run "streamlit reset" and try again.
"""
)
% (self._conf_file)
)
def _check_activated(self, auto_resolve=True):
"""Check if streamlit is activated.
Used by `streamlit run script.py`
"""
try:
self.load(auto_resolve)
except (Exception, RuntimeError) as e:
_exit(str(e))
if self.activation is None or not self.activation.is_valid:
_exit("Activation email not valid.")
@classmethod
def reset(cls):
"""Reset credentials by removing file.
This is used by `streamlit activate reset` in case a user wants
to start over.
"""
c = Credentials.get_current()
c.activation = None
try:
os.remove(c._conf_file)
except OSError as e:
LOGGER.error("Error removing credentials file: %s" % e)
def save(self):
"""Save to toml file."""
if self.activation is None:
return
# Create intermediate directories if necessary
os.makedirs(os.path.dirname(self._conf_file), exist_ok=True)
# Write the file
data = {"email": self.activation.email}
with open(self._conf_file, "w") as f:
toml.dump({"general": data}, f)
def activate(self, show_instructions: bool = True) -> None:
"""Activate Streamlit.
Used by `streamlit activate`.
"""
try:
self.load()
except RuntimeError:
# RuntimeError is raised if the credentials file is not found. In that case,
# `self.activation` is None and we will show the activation prompt below.
pass
if self.activation:
if self.activation.is_valid:
_exit("Already activated")
else:
_exit(
"Activation not valid. Please run "
"`streamlit activate reset` then `streamlit activate`"
)
else:
activated = False
while not activated:
email = click.prompt(
text=email_prompt(),
prompt_suffix="",
default="",
show_default=False,
)
self.activation = _verify_email(email)
if self.activation.is_valid:
self.save()
click.secho(_TELEMETRY_TEXT)
if show_instructions:
click.secho(_INSTRUCTIONS_TEXT)
activated = True
else: # pragma: nocover
LOGGER.error("Please try again.")
def _verify_email(email: str) -> _Activation:
"""Verify the user's email address.
The email can either be an empty string (if the user chooses not to enter
it), or a string with a single '@' somewhere in it.
Parameters
----------
email : str
Returns
-------
_Activation
An _Activation object. Its 'is_valid' property will be True only if
the email was validated.
"""
email = email.strip()
# We deliberately use simple email validation here
# since we do not use email address anywhere to send emails.
if len(email) > 0 and email.count("@") != 1:
LOGGER.error("That doesn't look like an email :(")
return _Activation(None, False)
return _Activation(email, True)
def _exit(message): # pragma: nocover
"""Exit program with error."""
LOGGER.error(message)
sys.exit(-1)
def _get_credential_file_path():
return file_util.get_streamlit_file_path("credentials.toml")
def _check_credential_file_exists():
return os.path.exists(_get_credential_file_path())
def check_credentials():
"""Check credentials and potentially activate.
Note
----
If there is no credential file and we are in headless mode, we should not
check, since credential would be automatically set to an empty string.
"""
from streamlit import config
if not _check_credential_file_exists() and config.get_option("server.headless"):
if not config.is_manually_set("browser.gatherUsageStats"):
# If not manually defined, show short message about usage stats gathering.
click.secho(_TELEMETRY_HEADLESS_TEXT)
return
Credentials.get_current()._check_activated()
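# --- Illustrative sketch (not part of the original module) ------------------
# _verify_email() accepts an empty string (the user declined to share an
# email) and any string containing exactly one "@"; anything else yields an
# invalid _Activation. The function name below is hypothetical.
def _example_email_validation() -> None:
    assert _verify_email("").is_valid  # blank input is allowed
    assert _verify_email("user@example.com").is_valid  # single "@" is allowed
    assert not _verify_email("not-an-email").is_valid  # no "@" -> invalid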

View File

@@ -0,0 +1,270 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
from typing import TYPE_CHECKING, Dict, List, MutableMapping, Optional
from weakref import WeakKeyDictionary
from streamlit import config, util
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
if TYPE_CHECKING:
from streamlit.runtime.app_session import AppSession
LOGGER = get_logger(__name__)
def populate_hash_if_needed(msg: ForwardMsg) -> str:
"""Computes and assigns the unique hash for a ForwardMsg.
If the ForwardMsg already has a hash, this is a no-op.
Parameters
----------
msg : ForwardMsg
Returns
-------
string
The message's hash, returned here for convenience. (The hash
will also be assigned to the ForwardMsg; callers do not need
to do this.)
"""
if msg.hash == "":
# Move the message's metadata aside. It's not part of the
# hash calculation.
metadata = msg.metadata
msg.ClearField("metadata")
# MD5 is good enough for what we need, which is uniqueness.
hasher = hashlib.md5()
hasher.update(msg.SerializeToString())
msg.hash = hasher.hexdigest()
# Restore metadata.
msg.metadata.CopyFrom(metadata)
return msg.hash
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
"""Create a ForwardMsg that refers to the given message via its hash.
The reference message will also get a copy of the source message's
metadata.
Parameters
----------
msg : ForwardMsg
The ForwardMsg to create the reference to.
Returns
-------
ForwardMsg
A new ForwardMsg that "points" to the original message via the
ref_hash field.
"""
ref_msg = ForwardMsg()
ref_msg.ref_hash = populate_hash_if_needed(msg)
ref_msg.metadata.CopyFrom(msg.metadata)
return ref_msg
class ForwardMsgCache(CacheStatsProvider):
"""A cache of ForwardMsgs.
Large ForwardMsgs (e.g. those containing big DataFrame payloads) are
stored in this cache. The server can choose to send a ForwardMsg's hash,
rather than the message itself, to a client. Clients can then
request messages from this cache via another endpoint.
This cache is *not* thread safe. It's intended to only be accessed by
the server thread.
"""
class Entry:
"""Cache entry.
Stores the cached message, and the set of AppSessions
that we've sent the cached message to.
"""
def __init__(self, msg: ForwardMsg):
self.msg = msg
self._session_script_run_counts: MutableMapping[
"AppSession", int
] = WeakKeyDictionary()
def __repr__(self) -> str:
return util.repr_(self)
def add_session_ref(self, session: "AppSession", script_run_count: int) -> None:
"""Adds a reference to a AppSession that has referenced
this Entry's message.
Parameters
----------
session : AppSession
script_run_count : int
The session's run count at the time of the call
"""
prev_run_count = self._session_script_run_counts.get(session, 0)
if script_run_count < prev_run_count:
LOGGER.error(
"New script_run_count (%s) is < prev_run_count (%s). "
"This should never happen!" % (script_run_count, prev_run_count)
)
script_run_count = prev_run_count
self._session_script_run_counts[session] = script_run_count
def has_session_ref(self, session: "AppSession") -> bool:
return session in self._session_script_run_counts
def get_session_ref_age(
self, session: "AppSession", script_run_count: int
) -> int:
"""The age of the given session's reference to the Entry,
given a new script_run_count.
"""
return script_run_count - self._session_script_run_counts[session]
def remove_session_ref(self, session: "AppSession") -> None:
del self._session_script_run_counts[session]
def has_refs(self) -> bool:
"""True if this Entry has references from any AppSession.
If not, it can be removed from the cache.
"""
return len(self._session_script_run_counts) > 0
def __init__(self):
self._entries: Dict[str, "ForwardMsgCache.Entry"] = {}
def __repr__(self) -> str:
return util.repr_(self)
def add_message(
self, msg: ForwardMsg, session: "AppSession", script_run_count: int
) -> None:
"""Add a ForwardMsg to the cache.
The cache will also record a reference to the given AppSession,
so that it can track which sessions have already received
each given ForwardMsg.
Parameters
----------
msg : ForwardMsg
session : AppSession
script_run_count : int
The number of times the session's script has run
"""
populate_hash_if_needed(msg)
entry = self._entries.get(msg.hash, None)
if entry is None:
entry = ForwardMsgCache.Entry(msg)
self._entries[msg.hash] = entry
entry.add_session_ref(session, script_run_count)
def get_message(self, hash: str) -> Optional[ForwardMsg]:
"""Return the message with the given ID if it exists in the cache.
Parameters
----------
hash : string
The id of the message to retrieve.
Returns
-------
ForwardMsg | None
"""
entry = self._entries.get(hash, None)
return entry.msg if entry else None
def has_message_reference(
self, msg: ForwardMsg, session: "AppSession", script_run_count: int
) -> bool:
"""Return True if a session has a reference to a message."""
populate_hash_if_needed(msg)
entry = self._entries.get(msg.hash, None)
if entry is None or not entry.has_session_ref(session):
return False
# Ensure we're not expired
age = entry.get_session_ref_age(session, script_run_count)
return age <= int(config.get_option("global.maxCachedMessageAge"))
def remove_expired_session_entries(
self, session: "AppSession", script_run_count: int
) -> None:
"""Remove any cached messages that have expired from the given session.
This should be called each time an AppSession finishes executing.
Parameters
----------
session : AppSession
script_run_count : int
The number of times the session's script has run
"""
max_age = config.get_option("global.maxCachedMessageAge")
# Operate on a copy of our entries dict.
# We may be deleting from it.
for msg_hash, entry in self._entries.copy().items():
if not entry.has_session_ref(session):
continue
age = entry.get_session_ref_age(session, script_run_count)
if age > max_age:
LOGGER.debug(
"Removing expired entry [session=%s, hash=%s, age=%s]",
id(session),
msg_hash,
age,
)
entry.remove_session_ref(session)
if not entry.has_refs():
# The entry has no more references. Remove it from
# the cache completely.
del self._entries[msg_hash]
def clear(self) -> None:
"""Remove all entries from the cache"""
self._entries.clear()
def get_stats(self) -> List[CacheStat]:
stats: List[CacheStat] = []
for entry_hash, entry in self._entries.items():
stats.append(
CacheStat(
category_name="ForwardMessageCache",
cache_name="",
byte_length=entry.msg.ByteSize(),
)
)
return stats
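# --- Illustrative sketch (not part of the original module) ------------------
# Shows the hash/reference flow described above: populate_hash_if_needed()
# assigns a stable MD5 hash to a ForwardMsg, and create_reference_msg() builds
# a small message that points at it via ref_hash. The function name and the
# script_finished payload below are assumptions made for illustration only.
def _example_reference_roundtrip() -> None:
    msg = ForwardMsg()
    msg.script_finished = ForwardMsg.FINISHED_SUCCESSFULLY

    msg_hash = populate_hash_if_needed(msg)
    ref = create_reference_msg(msg)

    assert msg.hash == msg_hash
    assert ref.ref_hash == msg_hash  # the reference points at the original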

View File

@@ -0,0 +1,143 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from streamlit.logger import get_logger
from streamlit.proto.Delta_pb2 import Delta
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
LOGGER = get_logger(__name__)
class ForwardMsgQueue:
"""Accumulates a session's outgoing ForwardMsgs.
Each AppSession adds messages to its queue, and the Server periodically
flushes all session queues and delivers their messages to the appropriate
clients.
ForwardMsgQueue is not thread-safe - a queue should only be used from
a single thread.
"""
def __init__(self):
self._queue: List[ForwardMsg] = []
# A mapping of (delta_path -> _queue.indexof(msg)) for each
# Delta message in the queue. We use this for coalescing
# redundant outgoing Deltas (where a newer Delta supersedes
# an older Delta, with the same delta_path, that's still in the
# queue).
self._delta_index_map: Dict[Tuple[int, ...], int] = dict()
def get_debug(self) -> Dict[str, Any]:
from google.protobuf.json_format import MessageToDict
return {
"queue": [MessageToDict(m) for m in self._queue],
"ids": list(self._delta_index_map.keys()),
}
def is_empty(self) -> bool:
return len(self._queue) == 0
def enqueue(self, msg: ForwardMsg) -> None:
"""Add message into queue, possibly composing it with another message."""
if not _is_composable_message(msg):
self._queue.append(msg)
return
# If there's a Delta message with the same delta_path already in
# the queue - meaning that it refers to the same location in
# the app - we attempt to combine this new Delta into the old
# one. This is an optimization that prevents redundant Deltas
# from being sent to the frontend.
delta_key = tuple(msg.metadata.delta_path)
if delta_key in self._delta_index_map:
index = self._delta_index_map[delta_key]
old_msg = self._queue[index]
composed_delta = _maybe_compose_deltas(old_msg.delta, msg.delta)
if composed_delta is not None:
new_msg = ForwardMsg()
new_msg.delta.CopyFrom(composed_delta)
new_msg.metadata.CopyFrom(msg.metadata)
self._queue[index] = new_msg
return
# No composition occurred. Append this message to the queue, and
# store its index for potential future composition.
self._delta_index_map[delta_key] = len(self._queue)
self._queue.append(msg)
def clear(self) -> None:
"""Clear the queue."""
self._queue = []
self._delta_index_map = dict()
def flush(self) -> List[ForwardMsg]:
"""Clear the queue and return a list of the messages it contained
before being cleared.
"""
queue = self._queue
self.clear()
return queue
def _is_composable_message(msg: ForwardMsg) -> bool:
"""True if the ForwardMsg is potentially composable with other ForwardMsgs."""
if not msg.HasField("delta"):
# Non-delta messages are never composable.
return False
# We never compose add_rows messages in Python, because the add_rows
# operation can raise errors, and we don't have a good way of handling
# those errors in the message queue.
delta_type = msg.delta.WhichOneof("type")
return delta_type != "add_rows" and delta_type != "arrow_add_rows"
def _maybe_compose_deltas(old_delta: Delta, new_delta: Delta) -> Optional[Delta]:
"""Combines new_delta onto old_delta if possible.
If the combination takes place, the function returns a new Delta that
should replace old_delta in the queue.
If the new_delta is incompatible with old_delta, the function returns None.
In this case, the new_delta should just be appended to the queue as normal.
"""
old_delta_type = old_delta.WhichOneof("type")
if old_delta_type == "add_block":
# We never replace add_block deltas, because blocks can have
# other dependent deltas later in the queue. For example:
#
# placeholder = st.empty()
# placeholder.columns(1)
# placeholder.empty()
#
# The call to "placeholder.columns(1)" creates two blocks, a parent
# container with delta_path (0, 0), and a column child with
# delta_path (0, 0, 0). If the final "placeholder.empty()" Delta
# is composed with the parent container Delta, the frontend will
# throw an error when it tries to add that column child to what is
# now just an element, and not a block.
return None
new_delta_type = new_delta.WhichOneof("type")
if new_delta_type == "new_element":
return new_delta
if new_delta_type == "add_block":
return new_delta
return None
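# --- Illustrative sketch (not part of the original module) ------------------
# Shows the coalescing behaviour of enqueue(): two "new_element" Deltas with
# the same delta_path collapse into a single queued message, with the newer
# Delta winning. The helper names and the text element payload below are
# assumptions made for illustration only.
def _make_text_delta_msg(delta_path, body) -> ForwardMsg:
    msg = ForwardMsg()
    msg.metadata.delta_path[:] = delta_path
    msg.delta.new_element.text.body = body
    return msg


def _example_delta_coalescing() -> None:
    queue = ForwardMsgQueue()
    queue.enqueue(_make_text_delta_msg([0, 0], "first"))
    queue.enqueue(_make_text_delta_msg([0, 0], "second"))  # same delta_path

    flushed = queue.flush()
    assert len(flushed) == 1  # coalesced into a single message
    assert flushed[0].delta.new_element.text.body == "second"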

View File

@@ -0,0 +1,20 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.legacy_caching.caching import cache as cache
from streamlit.runtime.legacy_caching.caching import clear_cache as clear_cache
from streamlit.runtime.legacy_caching.caching import get_cache_path as get_cache_path
from streamlit.runtime.legacy_caching.caching import (
maybe_show_cached_st_function_warning as maybe_show_cached_st_function_warning,
)

View File

@@ -0,0 +1,895 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A library of caching utilities."""
import contextlib
import functools
import hashlib
import inspect
import math
import os
import pickle
import shutil
import threading
import time
from collections import namedtuple
from dataclasses import dataclass
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Optional,
TypeVar,
Union,
cast,
overload,
)
from cachetools import TTLCache
from pympler.asizeof import asizeof
import streamlit as st
from streamlit import config, file_util, util
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.elements.spinner import spinner
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.errors import StreamlitAPIWarning
from streamlit.logger import get_logger
from streamlit.runtime.caching import CACHE_DOCS_URL
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
from streamlit.runtime.legacy_caching.hashing import (
HashFuncsDict,
HashReason,
update_hash,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
_LOGGER = get_logger(__name__)
# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
_TTLCACHE_TIMER = time.monotonic
_CacheEntry = namedtuple("_CacheEntry", ["value", "hash"])
_DiskCacheEntry = namedtuple("_DiskCacheEntry", ["value"])
# When we show the "st.cache is deprecated" warning, we make a recommendation about which new
# cache decorator to switch to for the following data types:
NEW_CACHE_FUNC_RECOMMENDATIONS: Dict[str, CacheType] = {
# cache_data recommendations:
"str": CacheType.DATA,
"float": CacheType.DATA,
"int": CacheType.DATA,
"bytes": CacheType.DATA,
"bool": CacheType.DATA,
"datetime.datetime": CacheType.DATA,
"pandas.DataFrame": CacheType.DATA,
"pandas.Series": CacheType.DATA,
"numpy.bool_": CacheType.DATA,
"numpy.bool8": CacheType.DATA,
"numpy.ndarray": CacheType.DATA,
"numpy.float_": CacheType.DATA,
"numpy.float16": CacheType.DATA,
"numpy.float32": CacheType.DATA,
"numpy.float64": CacheType.DATA,
"numpy.float96": CacheType.DATA,
"numpy.float128": CacheType.DATA,
"numpy.int_": CacheType.DATA,
"numpy.int8": CacheType.DATA,
"numpy.int16": CacheType.DATA,
"numpy.int32": CacheType.DATA,
"numpy.int64": CacheType.DATA,
"numpy.intp": CacheType.DATA,
"numpy.uint8": CacheType.DATA,
"numpy.uint16": CacheType.DATA,
"numpy.uint32": CacheType.DATA,
"numpy.uint64": CacheType.DATA,
"numpy.uintp": CacheType.DATA,
"PIL.Image.Image": CacheType.DATA,
"plotly.graph_objects.Figure": CacheType.DATA,
"matplotlib.figure.Figure": CacheType.DATA,
"altair.Chart": CacheType.DATA,
# cache_resource recommendations:
"pyodbc.Connection": CacheType.RESOURCE,
"pymongo.mongo_client.MongoClient": CacheType.RESOURCE,
"mysql.connector.MySQLConnection": CacheType.RESOURCE,
"psycopg2.connection": CacheType.RESOURCE,
"psycopg2.extensions.connection": CacheType.RESOURCE,
"snowflake.connector.connection.SnowflakeConnection": CacheType.RESOURCE,
"snowflake.snowpark.sessions.Session": CacheType.RESOURCE,
"sqlalchemy.engine.base.Engine": CacheType.RESOURCE,
"sqlite3.Connection": CacheType.RESOURCE,
"torch.nn.Module": CacheType.RESOURCE,
"tensorflow.keras.Model": CacheType.RESOURCE,
"tensorflow.Module": CacheType.RESOURCE,
"tensorflow.compat.v1.Session": CacheType.RESOURCE,
"transformers.Pipeline": CacheType.RESOURCE,
"transformers.PreTrainedTokenizer": CacheType.RESOURCE,
"transformers.PreTrainedTokenizerFast": CacheType.RESOURCE,
"transformers.PreTrainedTokenizerBase": CacheType.RESOURCE,
"transformers.PreTrainedModel": CacheType.RESOURCE,
"transformers.TFPreTrainedModel": CacheType.RESOURCE,
"transformers.FlaxPreTrainedModel": CacheType.RESOURCE,
}
def _make_deprecation_warning(cached_value: Any) -> str:
"""Build a deprecation warning string for a cache function that has returned the given
value.
"""
typename = type(cached_value).__qualname__
cache_type_rec = NEW_CACHE_FUNC_RECOMMENDATIONS.get(typename)
if cache_type_rec is not None:
# We have a recommended cache func for the cached value:
return (
f"`st.cache` is deprecated. Please use one of Streamlit's new caching commands,\n"
f"`st.cache_data` or `st.cache_resource`. Based on this function's return value\n"
f"of type `{typename}`, we recommend using `st.{get_decorator_api_name(cache_type_rec)}`.\n\n"
f"More information [in our docs]({CACHE_DOCS_URL})."
)
# We do not have a recommended cache func for the cached value:
return (
f"`st.cache` is deprecated. Please use one of Streamlit's new caching commands,\n"
f"`st.cache_data` or `st.cache_resource`.\n\n"
f"More information [in our docs]({CACHE_DOCS_URL})."
)
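# --- Illustrative sketch (not part of the original module) ------------------
# _make_deprecation_warning() picks a recommendation from
# NEW_CACHE_FUNC_RECOMMENDATIONS based on the cached value's type; a str return
# value maps to CacheType.DATA above, so the message points users at
# st.cache_data. The function name below is hypothetical.
def _example_deprecation_message() -> str:
    warning = _make_deprecation_warning("a cached string value")
    assert "st.cache_data" in warning
    return warning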
@dataclass
class MemCache:
cache: TTLCache
display_name: str
class _MemCaches(CacheStatsProvider):
"""Manages all in-memory st.cache caches"""
def __init__(self):
# Contains a cache object for each st.cache'd function
self._lock = threading.RLock()
self._function_caches: Dict[str, MemCache] = {}
def __repr__(self) -> str:
return util.repr_(self)
def get_cache(
self,
key: str,
max_entries: Optional[float],
ttl: Optional[float],
display_name: str = "",
) -> MemCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
if max_entries is None:
max_entries = math.inf
if ttl is None:
ttl = math.inf
if not isinstance(max_entries, (int, float)):
raise RuntimeError("max_entries must be an int")
if not isinstance(ttl, (int, float)):
raise RuntimeError("ttl must be a float")
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._lock:
mem_cache = self._function_caches.get(key)
if (
mem_cache is not None
and mem_cache.cache.ttl == ttl
and mem_cache.cache.maxsize == max_entries
):
return mem_cache
# Create a new cache object and put it in our dict
_LOGGER.debug(
"Creating new mem_cache (key=%s, max_entries=%s, ttl=%s)",
key,
max_entries,
ttl,
)
ttl_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER)
mem_cache = MemCache(ttl_cache, display_name)
self._function_caches[key] = mem_cache
return mem_cache
def clear(self) -> None:
"""Clear all caches"""
with self._lock:
self._function_caches = {}
def get_stats(self) -> List[CacheStat]:
with self._lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats = [
CacheStat("st_cache", cache.display_name, asizeof(c))
for cache in function_caches.values()
for c in cache.cache
]
return stats
# Our singleton _MemCaches instance
_mem_caches = _MemCaches()
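# --- Illustrative sketch (not part of the original module) ------------------
# get_cache() returns the existing MemCache only when key, ttl and max_entries
# all match; changing either parameter silently creates a fresh cache for that
# key. The function name below is hypothetical.
def _example_mem_cache_reuse() -> None:
    caches = _MemCaches()
    first = caches.get_cache("func-key", max_entries=10, ttl=60)
    same = caches.get_cache("func-key", max_entries=10, ttl=60)
    replaced = caches.get_cache("func-key", max_entries=20, ttl=60)

    assert first is same  # identical params -> same MemCache object
    assert replaced is not first  # new params -> a new cache replaces the old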
# A thread-local counter that's incremented when we enter @st.cache
# and decremented when we exit.
class ThreadLocalCacheInfo(threading.local):
def __init__(self):
self.cached_func_stack: List[Callable[..., Any]] = []
self.suppress_st_function_warning = 0
def __repr__(self) -> str:
return util.repr_(self)
_cache_info = ThreadLocalCacheInfo()
@contextlib.contextmanager
def _calling_cached_function(func: Callable[..., Any]) -> Iterator[None]:
_cache_info.cached_func_stack.append(func)
try:
yield
finally:
_cache_info.cached_func_stack.pop()
@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
_cache_info.suppress_st_function_warning += 1
try:
yield
finally:
_cache_info.suppress_st_function_warning -= 1
assert _cache_info.suppress_st_function_warning >= 0
def _show_cached_st_function_warning(
dg: "st.delta_generator.DeltaGenerator",
st_func_name: str,
cached_func: Callable[..., Any],
) -> None:
# Avoid infinite recursion by suppressing additional cached
# function warnings from within the cached function warning.
with suppress_cached_st_function_warning():
e = CachedStFunctionWarning(st_func_name, cached_func)
dg.exception(e)
def maybe_show_cached_st_function_warning(
dg: "st.delta_generator.DeltaGenerator", st_func_name: str
) -> None:
"""If appropriate, warn about calling st.foo inside @cache.
DeltaGenerator's @_with_element and @_widget wrappers use this to warn
the user when they're calling st.foo() from within a function that is
wrapped in @st.cache.
Parameters
----------
dg : DeltaGenerator
The DeltaGenerator to publish the warning to.
st_func_name : str
The name of the Streamlit function that was called.
"""
if (
len(_cache_info.cached_func_stack) > 0
and _cache_info.suppress_st_function_warning <= 0
):
cached_func = _cache_info.cached_func_stack[-1]
_show_cached_st_function_warning(dg, st_func_name, cached_func)
def _read_from_mem_cache(
mem_cache: MemCache,
key: str,
allow_output_mutation: bool,
func_or_code: Callable[..., Any],
hash_funcs: Optional[HashFuncsDict],
) -> Any:
cache = mem_cache.cache
if key in cache:
entry = cache[key]
if not allow_output_mutation:
computed_output_hash = _get_output_hash(
entry.value, func_or_code, hash_funcs
)
stored_output_hash = entry.hash
if computed_output_hash != stored_output_hash:
_LOGGER.debug("Cached object was mutated: %s", key)
raise CachedObjectMutationError(entry.value, func_or_code)
_LOGGER.debug("Memory cache HIT: %s", type(entry.value))
return entry.value
else:
_LOGGER.debug("Memory cache MISS: %s", key)
raise CacheKeyNotFoundError("Key not found in mem cache")
def _write_to_mem_cache(
mem_cache: MemCache,
key: str,
value: Any,
allow_output_mutation: bool,
func_or_code: Callable[..., Any],
hash_funcs: Optional[HashFuncsDict],
) -> None:
if allow_output_mutation:
hash = None
else:
hash = _get_output_hash(value, func_or_code, hash_funcs)
mem_cache.display_name = f"{func_or_code.__module__}.{func_or_code.__qualname__}"
mem_cache.cache[key] = _CacheEntry(value=value, hash=hash)
def _get_output_hash(
value: Any, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict]
) -> bytes:
hasher = hashlib.new("md5")
update_hash(
value,
hasher=hasher,
hash_funcs=hash_funcs,
hash_reason=HashReason.CACHING_FUNC_OUTPUT,
hash_source=func_or_code,
)
return hasher.digest()
def _read_from_disk_cache(key: str) -> Any:
path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key)
try:
with file_util.streamlit_read(path, binary=True) as input:
entry = pickle.load(input)
value = entry.value
_LOGGER.debug("Disk cache HIT: %s", type(value))
except util.Error as e:
_LOGGER.error(e)
raise CacheError("Unable to read from cache: %s" % e)
except FileNotFoundError:
raise CacheKeyNotFoundError("Key not found in disk cache")
return value
def _write_to_disk_cache(key: str, value: Any) -> None:
path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key)
try:
with file_util.streamlit_write(path, binary=True) as output:
entry = _DiskCacheEntry(value=value)
pickle.dump(entry, output, pickle.HIGHEST_PROTOCOL)
except util.Error as e:
_LOGGER.debug(e)
# Clean up file so we don't leave zero byte files.
try:
os.remove(path)
except (FileNotFoundError, IOError, OSError):
# If we can't remove the file, it's not a big deal.
pass
raise CacheError("Unable to write to cache: %s" % e)
def _read_from_cache(
mem_cache: MemCache,
key: str,
persist: bool,
allow_output_mutation: bool,
func_or_code: Callable[..., Any],
hash_funcs: Optional[HashFuncsDict] = None,
) -> Any:
"""Read a value from the cache.
Our goal is to read from memory if possible. If the data was mutated (hash
changed), we show a warning. If reading from memory fails, we either read
from disk or rerun the code.
"""
try:
return _read_from_mem_cache(
mem_cache, key, allow_output_mutation, func_or_code, hash_funcs
)
except CachedObjectMutationError as e:
handle_uncaught_app_exception(CachedObjectMutationWarning(e))
return e.cached_value
except CacheKeyNotFoundError as e:
if persist:
value = _read_from_disk_cache(key)
_write_to_mem_cache(
mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs
)
return value
raise e
@gather_metrics("_cache_object")
def _write_to_cache(
mem_cache: MemCache,
key: str,
value: Any,
persist: bool,
allow_output_mutation: bool,
func_or_code: Callable[..., Any],
hash_funcs: Optional[HashFuncsDict] = None,
):
_write_to_mem_cache(
mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs
)
if persist:
_write_to_disk_cache(key, value)
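# --- Illustrative sketch (not part of the original module) ------------------
# A memory-only round trip through the read/write helpers above, mirroring
# what the @st.cache wrapper does on a miss followed by a hit. The function
# name and key values are hypothetical; persist=False skips the disk layer.
def _example_legacy_cache_roundtrip() -> Any:
    mem_cache = _mem_caches.get_cache("example-key", max_entries=None, ttl=None)
    _write_to_cache(
        mem_cache,
        key="example-value-key",
        value={"rows": [1, 2, 3]},
        persist=False,
        allow_output_mutation=False,
        func_or_code=_example_legacy_cache_roundtrip,
    )
    # Reading back goes through the same mutation check as a real cache hit.
    return _read_from_cache(
        mem_cache,
        key="example-value-key",
        persist=False,
        allow_output_mutation=False,
        func_or_code=_example_legacy_cache_roundtrip,
    )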
F = TypeVar("F", bound=Callable[..., Any])
@overload
def cache(
func: F,
persist: bool = False,
allow_output_mutation: bool = False,
show_spinner: bool = True,
suppress_st_warning: bool = False,
hash_funcs: Optional[HashFuncsDict] = None,
max_entries: Optional[int] = None,
ttl: Optional[float] = None,
) -> F:
...
@overload
def cache(
func: None = None,
persist: bool = False,
allow_output_mutation: bool = False,
show_spinner: bool = True,
suppress_st_warning: bool = False,
hash_funcs: Optional[HashFuncsDict] = None,
max_entries: Optional[int] = None,
ttl: Optional[float] = None,
) -> Callable[[F], F]:
...
def cache(
func: Optional[F] = None,
persist: bool = False,
allow_output_mutation: bool = False,
show_spinner: bool = True,
suppress_st_warning: bool = False,
hash_funcs: Optional[HashFuncsDict] = None,
max_entries: Optional[int] = None,
ttl: Optional[float] = None,
) -> Union[Callable[[F], F], F]:
"""Function decorator to memoize function executions.
Parameters
----------
func : callable
The function to cache. Streamlit hashes the function and dependent code.
persist : boolean
Whether to persist the cache on disk.
allow_output_mutation : boolean
Streamlit shows a warning when return values are mutated, as that
can have unintended consequences. This is done by hashing the return value internally.
If you know what you're doing and would like to override this warning, set this to True.
show_spinner : boolean
Enable the spinner. Default is True to show a spinner when there is
a cache miss.
suppress_st_warning : boolean
Suppress warnings about calling Streamlit commands from within
the cached function.
hash_funcs : dict or None
Mapping of types or fully qualified names to hash functions. This is used to override
the behavior of the hasher inside Streamlit's caching mechanism: when the hasher
encounters an object, it will first check to see if its type matches a key in this
dict and, if so, will use the provided function to generate a hash for it. See below
for an example of how this can be used.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is None.
ttl : float or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
...
>>> d1 = fetch_and_clean_data(DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> d2 = fetch_and_clean_data(DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the data in d1 is the same as in d2.
>>>
>>> d3 = fetch_and_clean_data(DATA_URL_2)
>>> # This is a different URL, so the function executes.
To set the ``persist`` parameter, use this command as follows:
>>> @st.cache(persist=True)
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
To disable hashing return values, set the ``allow_output_mutation`` parameter to ``True``:
>>> @st.cache(allow_output_mutation=True)
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
To override the default hashing behavior, pass a custom hash function.
You can do that by mapping a type (e.g. ``MongoClient``) to a hash function (``id``) like this:
>>> @st.cache(hash_funcs={MongoClient: id})
... def connect_to_database(url):
... return MongoClient(url)
Alternatively, you can map the type's fully-qualified name
(e.g. ``"pymongo.mongo_client.MongoClient"``) to the hash function instead:
>>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id})
... def connect_to_database(url):
... return MongoClient(url)
"""
_LOGGER.debug("Entering st.cache: %s", func)
# Support passing the params via function decorator, e.g.
# @st.cache(persist=True, allow_output_mutation=True)
if func is None:
def wrapper(f: F) -> F:
return cache(
func=f,
persist=persist,
allow_output_mutation=allow_output_mutation,
show_spinner=show_spinner,
suppress_st_warning=suppress_st_warning,
hash_funcs=hash_funcs,
max_entries=max_entries,
ttl=ttl,
)
return wrapper
else:
# To make mypy type narrow Optional[F] -> F
non_optional_func = func
cache_key = None
@functools.wraps(non_optional_func)
def wrapped_func(*args, **kwargs):
"""Wrapper function that only calls the underlying function on a cache miss.
Cached objects are stored in the cache/ directory.
"""
if not config.get_option("client.caching"):
_LOGGER.debug("Purposefully skipping cache")
return non_optional_func(*args, **kwargs)
name = non_optional_func.__qualname__
if len(args) == 0 and len(kwargs) == 0:
message = "Running `%s()`." % name
else:
message = "Running `%s(...)`." % name
def get_or_create_cached_value():
nonlocal cache_key
if cache_key is None:
# Delay generating the cache key until the first call.
# This way we can see values of globals, including functions
# defined after this one.
# If we generated the key earlier we would only hash those
# globals by name, and miss changes in their code or value.
cache_key = _hash_func(non_optional_func, hash_funcs)
# First, get the cache that's attached to this function.
# This cache's key is generated (above) from the function's code.
mem_cache = _mem_caches.get_cache(cache_key, max_entries, ttl)
# Next, calculate the key for the value we'll be searching for
# within that cache. This key is generated from both the function's
# code and the arguments that are passed into it. (Even though this
# key is used to index into a per-function cache, it must be
# globally unique, because it is *also* used for a global on-disk
# cache that is *not* per-function.)
value_hasher = hashlib.new("md5")
if args:
update_hash(
args,
hasher=value_hasher,
hash_funcs=hash_funcs,
hash_reason=HashReason.CACHING_FUNC_ARGS,
hash_source=non_optional_func,
)
if kwargs:
update_hash(
kwargs,
hasher=value_hasher,
hash_funcs=hash_funcs,
hash_reason=HashReason.CACHING_FUNC_ARGS,
hash_source=non_optional_func,
)
value_key = value_hasher.hexdigest()
# Avoid recomputing the body's hash by just appending the
# previously-computed hash to the arg hash.
value_key = "%s-%s" % (value_key, cache_key)
_LOGGER.debug("Cache key: %s", value_key)
try:
return_value = _read_from_cache(
mem_cache=mem_cache,
key=value_key,
persist=persist,
allow_output_mutation=allow_output_mutation,
func_or_code=non_optional_func,
hash_funcs=hash_funcs,
)
_LOGGER.debug("Cache hit: %s", non_optional_func)
except CacheKeyNotFoundError:
_LOGGER.debug("Cache miss: %s", non_optional_func)
with _calling_cached_function(non_optional_func):
if suppress_st_warning:
with suppress_cached_st_function_warning():
return_value = non_optional_func(*args, **kwargs)
else:
return_value = non_optional_func(*args, **kwargs)
_write_to_cache(
mem_cache=mem_cache,
key=value_key,
value=return_value,
persist=persist,
allow_output_mutation=allow_output_mutation,
func_or_code=non_optional_func,
hash_funcs=hash_funcs,
)
# st.cache is deprecated. We show a warning every time it's used.
show_deprecation_warning(_make_deprecation_warning(return_value))
return return_value
if show_spinner:
with spinner(message):
return get_or_create_cached_value()
else:
return get_or_create_cached_value()
# Make this a well-behaved decorator by preserving important function
# attributes.
try:
wrapped_func.__dict__.update(non_optional_func.__dict__)
except AttributeError:
# For normal functions this should never happen, but if so it's not problematic.
pass
return cast(F, wrapped_func)
def _hash_func(func: Callable[..., Any], hash_funcs: Optional[HashFuncsDict]) -> str:
# Create the unique key for a function's cache. The cache will be retrieved
# from inside the wrapped function.
#
# A naive implementation would involve simply creating the cache object
# right in the wrapper, which in a normal Python script would be executed
# only once. But in Streamlit, we reload all modules related to a user's
# app when the app is re-run, which means that - among other things - all
# function decorators in the app will be re-run, and so any decorator-local
# objects will be recreated.
#
# Furthermore, our caches can be destroyed and recreated (in response to
# cache clearing, for example), which means that retrieving the function's
# cache in the decorator (so that the wrapped function can save a lookup)
# is incorrect: the cache itself may be recreated between
# decorator-evaluation time and decorated-function-execution time. So we
# must retrieve the cache object *and* perform the cached-value lookup
# inside the decorated function.
func_hasher = hashlib.new("md5")
# Include the function's __module__ and __qualname__ strings in the hash.
# This means that two identical functions in different modules
# will not share a hash; it also means that two identical *nested*
# functions in the same module will not share a hash.
# We do not pass `hash_funcs` here, because we don't want our function's
# name to get an unexpected hash.
update_hash(
(func.__module__, func.__qualname__),
hasher=func_hasher,
hash_funcs=None,
hash_reason=HashReason.CACHING_FUNC_BODY,
hash_source=func,
)
# Include the function's body in the hash. We *do* pass hash_funcs here,
# because this step will be hashing any objects referenced in the function
# body.
update_hash(
func,
hasher=func_hasher,
hash_funcs=hash_funcs,
hash_reason=HashReason.CACHING_FUNC_BODY,
hash_source=func,
)
cache_key = func_hasher.hexdigest()
_LOGGER.debug(
"mem_cache key for %s.%s: %s", func.__module__, func.__qualname__, cache_key
)
return cache_key
def clear_cache() -> bool:
"""Clear the memoization cache.
Returns
-------
boolean
True if the disk cache was cleared. False otherwise (e.g. cache file
doesn't exist on disk).
"""
_clear_mem_cache()
return _clear_disk_cache()
def get_cache_path() -> str:
return file_util.get_streamlit_file_path("cache")
def _clear_disk_cache() -> bool:
# TODO: Only delete disk cache for functions related to the user's current
# script.
cache_path = get_cache_path()
if os.path.isdir(cache_path):
shutil.rmtree(cache_path)
return True
return False
def _clear_mem_cache() -> None:
_mem_caches.clear()
class CacheError(Exception):
pass
class CacheKeyNotFoundError(Exception):
pass
class CachedObjectMutationError(ValueError):
# This is used internally, but never shown to the user.
# Users see CachedObjectMutationWarning instead.
def __init__(self, cached_value, func_or_code):
self.cached_value = cached_value
if inspect.iscode(func_or_code):
self.cached_func_name = "a code block"
else:
self.cached_func_name = _get_cached_func_name_md(func_or_code)
def __repr__(self) -> str:
return util.repr_(self)
class CachedStFunctionWarning(StreamlitAPIWarning):
def __init__(self, st_func_name, cached_func):
msg = self._get_message(st_func_name, cached_func)
super(CachedStFunctionWarning, self).__init__(msg)
def _get_message(self, st_func_name, cached_func):
args = {
"st_func_name": "`st.%s()` or `st.write()`" % st_func_name,
"func_name": _get_cached_func_name_md(cached_func),
}
return (
"""
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.
How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.cache(suppress_st_warning=True)`
to suppress the warning.
"""
% args
).strip("\n")
class CachedObjectMutationWarning(StreamlitAPIWarning):
def __init__(self, orig_exc):
msg = self._get_message(orig_exc)
super(CachedObjectMutationWarning, self).__init__(msg)
def _get_message(self, orig_exc):
return (
"""
Return value of %(func_name)s was mutated between runs.
By default, Streamlit's cache should be treated as immutable, or it may behave
in unexpected ways. You received this warning because Streamlit detected
that an object returned by %(func_name)s was mutated outside of %(func_name)s.
How to fix this:
* If you did not mean to mutate that return value:
- If possible, inspect your code to find and remove that mutation.
- Otherwise, you could also clone the returned value so you can freely
mutate it.
* If you actually meant to mutate the return value and know the consequences of
doing so, annotate the function with `@st.cache(allow_output_mutation=True)`.
For more information and detailed solutions check out [our documentation](https://docs.streamlit.io/library/advanced-features/caching).
"""
% {"func_name": orig_exc.cached_func_name}
).strip("\n")
def _get_cached_func_name_md(func: Callable[..., Any]) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
else:
return "a cached function"

File diff suppressed because it is too large

View File

@@ -0,0 +1,230 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides global MediaFileManager object as `media_file_manager`."""
import collections
import threading
from typing import Dict, Optional, Set, Union
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorage
LOGGER = get_logger(__name__)
def _get_session_id() -> str:
"""Get the active AppSession's session_id."""
from streamlit.runtime.scriptrunner import get_script_run_ctx
ctx = get_script_run_ctx()
if ctx is None:
        # This is only None when running "python myscript.py" rather than
        # "streamlit run myscript.py", in which case the session ID doesn't
        # matter and can just be a constant, as there's only ever one session.
return "dontcare"
else:
return ctx.session_id
class MediaFileMetadata:
"""Metadata that the MediaFileManager needs for each file it manages."""
def __init__(self, kind: MediaFileKind = MediaFileKind.MEDIA):
self._kind = kind
self._is_marked_for_delete = False
@property
def kind(self) -> MediaFileKind:
return self._kind
@property
def is_marked_for_delete(self) -> bool:
return self._is_marked_for_delete
def mark_for_delete(self) -> None:
self._is_marked_for_delete = True
class MediaFileManager:
"""In-memory file manager for MediaFile objects.
This keeps track of:
- Which files exist, and what their IDs are. This is important so we can
serve files by ID -- that's the whole point of this class!
- Which files are being used by which AppSession (by ID). This is
important so we can remove files from memory when no more sessions need
them.
    - The exact location in the app where each file is being used (i.e. the
      file's "coordinates"). This is important so we can mark a file as "not
      being used by a certain session" if it gets replaced by another file at
      the same coordinates. For example, when doing an animation where the same
      image is constantly replaced with new frames. (This doesn't solve the case
where the file's coordinates keep changing for some reason, though! e.g.
if new elements keep being prepended to the app. Unlikely to happen, but
we should address it at some point.)
"""
def __init__(self, storage: MediaFileStorage):
self._storage = storage
# Dict of [file_id -> MediaFileMetadata]
self._file_metadata: Dict[str, MediaFileMetadata] = dict()
# Dict[session ID][coordinates] -> file_id.
self._files_by_session_and_coord: Dict[
str, Dict[str, str]
] = collections.defaultdict(dict)
# MediaFileManager is used from multiple threads, so all operations
# need to be protected with a Lock. (This is not an RLock, which
# means taking it multiple times from the same thread will deadlock.)
self._lock = threading.Lock()
def _get_inactive_file_ids(self) -> Set[str]:
"""Compute the set of files that are stored in the manager, but are
not referenced by any active session. These are files that can be
safely deleted.
Thread safety: callers must hold `self._lock`.
"""
# Get the set of all our file IDs.
file_ids = set(self._file_metadata.keys())
# Subtract all IDs that are in use by each session
for session_file_ids_by_coord in self._files_by_session_and_coord.values():
file_ids.difference_update(session_file_ids_by_coord.values())
return file_ids
def remove_orphaned_files(self) -> None:
"""Remove all files that are no longer referenced by any active session.
Safe to call from any thread.
"""
LOGGER.debug("Removing orphaned files...")
with self._lock:
for file_id in self._get_inactive_file_ids():
file = self._file_metadata[file_id]
if file.kind == MediaFileKind.MEDIA:
self._delete_file(file_id)
elif file.kind == MediaFileKind.DOWNLOADABLE:
if file.is_marked_for_delete:
self._delete_file(file_id)
else:
file.mark_for_delete()
def _delete_file(self, file_id: str) -> None:
"""Delete the given file from storage, and remove its metadata from
self._files_by_id.
Thread safety: callers must hold `self._lock`.
"""
LOGGER.debug("Deleting File: %s", file_id)
self._storage.delete_file(file_id)
del self._file_metadata[file_id]
def clear_session_refs(self, session_id: Optional[str] = None) -> None:
"""Remove the given session's file references.
(This does not remove any files from the manager - you must call
`remove_orphaned_files` for that.)
Should be called whenever ScriptRunner completes and when a session ends.
Safe to call from any thread.
"""
if session_id is None:
session_id = _get_session_id()
LOGGER.debug("Disconnecting files for session with ID %s", session_id)
with self._lock:
if session_id in self._files_by_session_and_coord:
del self._files_by_session_and_coord[session_id]
LOGGER.debug(
"Sessions still active: %r", self._files_by_session_and_coord.keys()
)
LOGGER.debug(
"Files: %s; Sessions with files: %s",
len(self._file_metadata),
len(self._files_by_session_and_coord),
)
def add(
self,
path_or_data: Union[bytes, str],
mimetype: str,
coordinates: str,
file_name: Optional[str] = None,
is_for_static_download: bool = False,
) -> str:
"""Add a new MediaFile with the given parameters and return its URL.
        If an identical file already exists, the existing URL is returned
        and the current session is registered as a user.
Safe to call from any thread.
Parameters
----------
path_or_data : bytes or str
If bytes: the media file's raw data. If str: the name of a file
to load from disk.
mimetype : str
The mime type for the file. E.g. "audio/mpeg".
This string will be used in the "Content-Type" header when the file
is served over HTTP.
coordinates : str
Unique string identifying an element's location.
            Prevents memory leaks from "forgotten" file IDs when element media
            is replaced in place (e.g. an st.image stream).
coordinates should be of the form: "1.(3.-14).5"
file_name : str or None
Optional file_name. Used to set the filename in the response header.
is_for_static_download: bool
            If True, the data is stored for downloading as a file rather than
            as media to be rendered in the page. Defaults to False.
Returns
-------
str
The url that the frontend can use to fetch the media.
Raises
------
If a filename is passed, any Exception raised when trying to read the
file will be re-raised.
"""
session_id = _get_session_id()
with self._lock:
kind = (
MediaFileKind.DOWNLOADABLE
if is_for_static_download
else MediaFileKind.MEDIA
)
file_id = self._storage.load_and_get_id(
path_or_data, mimetype, kind, file_name
)
metadata = MediaFileMetadata(kind=kind)
self._file_metadata[file_id] = metadata
self._files_by_session_and_coord[session_id][coordinates] = file_id
return self._storage.get_url(file_id)
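An illustrative lifecycle sketch for MediaFileManager (module paths and the "/media" endpoint are taken from this package; the coordinate string is just an example value):

from streamlit.runtime.media_file_manager import MediaFileManager
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage

manager = MediaFileManager(storage=MemoryMediaFileStorage(media_endpoint="/media"))

# Register raw bytes at an element coordinate. Outside of a script run the
# session id falls back to a constant, so this also works in a plain script.
url = manager.add(b"fake-image-bytes", mimetype="image/png", coordinates="1.(3.-14).5")
print(url)  # e.g. "/media/<sha224-of-content>.png"

# When the script run finishes (or the session ends), drop the session's
# references and then delete any files no longer referenced by any session.
manager.clear_session_refs()
manager.remove_orphaned_files()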

View File

@@ -0,0 +1,143 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from enum import Enum
from typing import Optional, Union
from typing_extensions import Protocol
class MediaFileKind(Enum):
# st.image, st.video, st.audio files
MEDIA = "media"
# st.download_button files
DOWNLOADABLE = "downloadable"
class MediaFileStorageError(Exception):
"""Exception class for errors raised by MediaFileStorage.
When running in "development mode", the full text of these errors
is displayed in the frontend, so errors should be human-readable
(and actionable).
When running in "release mode", errors are redacted on the
frontend; we instead show a generic "Something went wrong!" message.
"""
class MediaFileStorage(Protocol):
@abstractmethod
def load_and_get_id(
self,
path_or_data: Union[str, bytes],
mimetype: str,
kind: MediaFileKind,
filename: Optional[str] = None,
) -> str:
"""Load the given file path or bytes into the manager and return
an ID that uniquely identifies it.
        It's an error to pass a URL to this function. (Media stored at
        external URLs can be served directly to the Streamlit frontend;
        there's no need to store this data in MediaFileStorage.)
Parameters
----------
path_or_data
A path to a file, or the file's raw data as bytes.
mimetype
            The media's mimetype. Used to set the Content-Type header when
serving the media over HTTP.
kind
The kind of file this is: either MEDIA, or DOWNLOADABLE.
filename : str or None
Optional filename. Used to set the filename in the response header.
Returns
-------
str
The unique ID of the media file.
Raises
------
MediaFileStorageError
Raised if the media can't be loaded (for example, if a file
path is invalid).
"""
raise NotImplementedError
@abstractmethod
def get_url(self, file_id: str) -> str:
"""Return a URL for a file in the manager.
Parameters
----------
file_id
            The file's ID, as returned from load_and_get_id().
Returns
-------
str
A URL that the frontend can load the file from. Because this
URL may expire, it should not be cached!
Raises
------
MediaFileStorageError
Raised if the manager doesn't contain an object with the given ID.
"""
raise NotImplementedError
@abstractmethod
def delete_file(self, file_id: str) -> None:
"""Delete a file from the manager.
This should be called when a given file is no longer referenced
by any connected client, so that the MediaFileStorage can free its
resources.
        Calling `delete_file` on a file_id that doesn't exist is allowed
        and is a no-op. (This means that calling `delete_file` multiple
        times with the same file_id is not an error.)
Note: implementations can choose to ignore `delete_file` calls -
this function is a *suggestion*, not a *command*. Callers should
not rely on file deletion happening immediately (or at all).
Parameters
----------
file_id
            The file's ID, as returned from load_and_get_id().
Returns
-------
None
Raises
------
MediaFileStorageError
Raised if file deletion fails for any reason. Note that these
failures will generally not be shown on the frontend (file
deletion usually occurs on session disconnect).
"""
raise NotImplementedError
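A hypothetical minimal implementation, sketched only to show how a concrete class satisfies this protocol, including the "delete is a no-op for unknown ids" rule (the class name, temp-dir layout, and "/media" route are illustrative assumptions, not part of Streamlit):

import hashlib
import os
import tempfile

from streamlit.runtime.media_file_storage import (
    MediaFileStorage,
    MediaFileStorageError,
)

class TempDirMediaFileStorage(MediaFileStorage):
    def __init__(self, media_endpoint: str = "/media") -> None:
        self._root = tempfile.mkdtemp(prefix="st_media_")
        self._media_endpoint = media_endpoint

    def load_and_get_id(self, path_or_data, mimetype, kind, filename=None):
        # Accept either raw bytes or a path to a file on disk.
        if isinstance(path_or_data, str):
            with open(path_or_data, "rb") as f:
                data = f.read()
        else:
            data = path_or_data
        file_id = hashlib.sha224(data + mimetype.encode()).hexdigest()
        with open(os.path.join(self._root, file_id), "wb") as f:
            f.write(data)
        return file_id

    def get_url(self, file_id: str) -> str:
        if not os.path.exists(os.path.join(self._root, file_id)):
            raise MediaFileStorageError(f"No media file with id '{file_id}'")
        return f"{self._media_endpoint}/{file_id}"

    def delete_file(self, file_id: str) -> None:
        # Deleting an unknown id is allowed and silently ignored.
        try:
            os.remove(os.path.join(self._root, file_id))
        except FileNotFoundError:
            pass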

View File

@@ -0,0 +1,183 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MediaFileStorage implementation that stores files in memory."""
import contextlib
import hashlib
import mimetypes
import os.path
from typing import Dict, List, NamedTuple, Optional, Union
from typing_extensions import Final
from streamlit.logger import get_logger
from streamlit.runtime.media_file_storage import (
MediaFileKind,
MediaFileStorage,
MediaFileStorageError,
)
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
LOGGER = get_logger(__name__)
# Mimetype -> filename extension map for the `get_extension_for_mimetype`
# function. We use Python's `mimetypes.guess_extension` for most mimetypes,
# but (as of Python 3.9) `mimetypes.guess_extension("audio/wav")` returns None,
# so we handle it ourselves.
PREFERRED_MIMETYPE_EXTENSION_MAP: Final = {
"audio/wav": ".wav",
}
def _calculate_file_id(
data: bytes, mimetype: str, filename: Optional[str] = None
) -> str:
"""Hash data, mimetype, and an optional filename to generate a stable file ID.
Parameters
----------
data
Content of in-memory file in bytes. Other types will throw TypeError.
mimetype
Any string. Will be converted to bytes and used to compute a hash.
filename
Any string. Will be converted to bytes and used to compute a hash.
"""
filehash = hashlib.new("sha224")
filehash.update(data)
filehash.update(bytes(mimetype.encode()))
if filename is not None:
filehash.update(bytes(filename.encode()))
return filehash.hexdigest()
def get_extension_for_mimetype(mimetype: str) -> str:
if mimetype in PREFERRED_MIMETYPE_EXTENSION_MAP:
return PREFERRED_MIMETYPE_EXTENSION_MAP[mimetype]
extension = mimetypes.guess_extension(mimetype, strict=False)
if extension is None:
return ""
return extension
class MemoryFile(NamedTuple):
"""A MediaFile stored in memory."""
content: bytes
mimetype: str
kind: MediaFileKind
filename: Optional[str]
@property
def content_size(self) -> int:
return len(self.content)
class MemoryMediaFileStorage(MediaFileStorage, CacheStatsProvider):
def __init__(self, media_endpoint: str):
"""Create a new MemoryMediaFileStorage instance
Parameters
----------
media_endpoint
The name of the local endpoint that media is served from.
This endpoint should start with a forward-slash (e.g. "/media").
"""
self._files_by_id: Dict[str, MemoryFile] = {}
self._media_endpoint = media_endpoint
def load_and_get_id(
self,
path_or_data: Union[str, bytes],
mimetype: str,
kind: MediaFileKind,
filename: Optional[str] = None,
) -> str:
"""Add a file to the manager and return its ID."""
file_data: bytes
if isinstance(path_or_data, str):
file_data = self._read_file(path_or_data)
else:
file_data = path_or_data
# Because our file_ids are stable, if we already have a file with the
# given ID, we don't need to create a new one.
file_id = _calculate_file_id(file_data, mimetype, filename)
if file_id not in self._files_by_id:
LOGGER.debug("Adding media file %s", file_id)
media_file = MemoryFile(
content=file_data, mimetype=mimetype, kind=kind, filename=filename
)
self._files_by_id[file_id] = media_file
return file_id
def get_file(self, filename: str) -> MemoryFile:
"""Return the MemoryFile with the given filename. Filenames are of the
form "file_id.extension". (Note that this is *not* the optional
user-specified filename for download files.)
Raises a MediaFileStorageError if no such file exists.
"""
file_id = os.path.splitext(filename)[0]
try:
return self._files_by_id[file_id]
except KeyError as e:
raise MediaFileStorageError(
f"Bad filename '{filename}'. (No media file with id '{file_id}')"
) from e
def get_url(self, file_id: str) -> str:
"""Get a URL for a given media file. Raise a MediaFileStorageError if
no such file exists.
"""
media_file = self.get_file(file_id)
extension = get_extension_for_mimetype(media_file.mimetype)
return f"{self._media_endpoint}/{file_id}{extension}"
def delete_file(self, file_id: str) -> None:
"""Delete the file with the given ID."""
# We swallow KeyErrors here - it's not an error to delete a file
# that doesn't exist.
with contextlib.suppress(KeyError):
del self._files_by_id[file_id]
def _read_file(self, filename: str) -> bytes:
"""Read a file into memory. Raise MediaFileStorageError if we can't."""
try:
with open(filename, "rb") as f:
return f.read()
except Exception as ex:
raise MediaFileStorageError(f"Error opening '{filename}'") from ex
def get_stats(self) -> List[CacheStat]:
# We operate on a copy of our dict, to avoid race conditions
# with other threads that may be manipulating the cache.
files_by_id = self._files_by_id.copy()
stats: List[CacheStat] = []
for file_id, file in files_by_id.items():
stats.append(
CacheStat(
category_name="st_memory_media_file_storage",
cache_name="",
byte_length=len(file.content),
)
)
return stats
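A small sketch of the content-addressed behavior described above (illustrative only; the extension in the printed URL comes from mimetypes.guess_extension):

from streamlit.runtime.media_file_storage import MediaFileKind
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage

storage = MemoryMediaFileStorage(media_endpoint="/media")

first = storage.load_and_get_id(b"hello", "text/plain", MediaFileKind.MEDIA)
second = storage.load_and_get_id(b"hello", "text/plain", MediaFileKind.MEDIA)
assert first == second          # identical bytes + mimetype -> identical id

print(storage.get_url(first))   # "/media/<id>.txt"
storage.delete_file(first)      # deleting is idempotent:
storage.delete_file(first)      # a second delete is a silent no-op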

View File

@@ -0,0 +1,72 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, MutableMapping, Optional
from cachetools import TTLCache
from streamlit.runtime.session_manager import SessionInfo, SessionStorage
class MemorySessionStorage(SessionStorage):
"""A SessionStorage that stores sessions in memory.
At most maxsize sessions are stored with a TTL of ttl seconds. This class is really
just a thin wrapper around cachetools.TTLCache that complies with the SessionStorage
protocol.
"""
# NOTE: The defaults for maxsize and ttl are chosen arbitrarily for now. These
# numbers are reasonable as the main problems we're trying to solve at the moment are
# caused by transient disconnects that are usually just short network blips. In the
# future, we may want to increase both to support use cases such as saving state for
# much longer periods of time. For example, we may want session state to persist if
# a user closes their laptop lid and comes back to an app hours later.
def __init__(
self,
maxsize: int = 128,
ttl_seconds: int = 2 * 60, # 2 minutes
) -> None:
"""Instantiate a new MemorySessionStorage.
Parameters
----------
maxsize
The maximum number of sessions we allow to be stored in this
MemorySessionStorage. If an entry needs to be removed because we have
exceeded this number, either
* an expired entry is removed, or
* the least recently used entry is removed (if no entries have expired).
ttl_seconds
The time in seconds for an entry added to a MemorySessionStorage to live.
After this amount of time has passed for a given entry, it becomes
inaccessible and will be removed eventually.
"""
self._cache: MutableMapping[str, SessionInfo] = TTLCache(
maxsize=maxsize, ttl=ttl_seconds
)
def get(self, session_id: str) -> Optional[SessionInfo]:
return self._cache.get(session_id, None)
def save(self, session_info: SessionInfo) -> None:
self._cache[session_info.session.id] = session_info
def delete(self, session_id: str) -> None:
del self._cache[session_id]
def list(self) -> List[SessionInfo]:
return list(self._cache.values())
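A minimal sketch of the TTL / LRU semantics MemorySessionStorage inherits from cachetools.TTLCache (plain strings stand in for SessionInfo objects; the maxsize and ttl values are illustrative):

import time

from cachetools import TTLCache

cache = TTLCache(maxsize=2, ttl=1)   # at most 2 entries, each living ~1 second
cache["s1"] = "session-info-1"
cache["s2"] = "session-info-2"
cache["s3"] = "session-info-3"       # over maxsize: the LRU entry ("s1") is evicted

print(list(cache.keys()))            # ["s2", "s3"]
time.sleep(1.1)
print(cache.get("s3"))               # None - the entry expired after its TTL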

View File

@@ -0,0 +1,373 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import os
import sys
import threading
import time
import uuid
from collections.abc import Sized
from functools import wraps
from timeit import default_timer as timer
from typing import Any, Callable, List, Optional, Set, TypeVar, Union, cast, overload
from typing_extensions import Final
from streamlit import config, util
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.PageProfile_pb2 import Argument, Command
_LOGGER = get_logger(__name__)
# Limit the number of commands to keep the page profile message small
# since Segment allows only a maximum of 32kb per event.
_MAX_TRACKED_COMMANDS: Final = 200
# Only track a maximum of 25 uses per unique command since some apps use
# commands excessively (e.g. calling add_rows thousands of times in one rerun),
# making the page profile useless.
_MAX_TRACKED_PER_COMMAND: Final = 25
# A mapping to convert from the actual name to preferred/shorter representations
_OBJECT_NAME_MAPPING: Final = {
"streamlit.delta_generator.DeltaGenerator": "DG",
"pandas.core.frame.DataFrame": "DataFrame",
"plotly.graph_objs._figure.Figure": "PlotlyFigure",
"bokeh.plotting.figure.Figure": "BokehFigure",
"matplotlib.figure.Figure": "MatplotlibFigure",
"pandas.io.formats.style.Styler": "PandasStyler",
"pandas.core.indexes.base.Index": "PandasIndex",
"pandas.core.series.Series": "PandasSeries",
}
# A list of dependencies to check for attribution
_ATTRIBUTIONS_TO_CHECK: Final = [
"snowflake",
"torch",
"tensorflow",
"streamlit_extras",
"streamlit_pydantic",
"plost",
]
_ETC_MACHINE_ID_PATH = "/etc/machine-id"
_DBUS_MACHINE_ID_PATH = "/var/lib/dbus/machine-id"
def _get_machine_id_v3() -> str:
"""Get the machine ID
This is a unique identifier for a user for tracking metrics in Segment,
that is broken in different ways in some Linux distros and Docker images.
- at times just a hash of '', which means many machines map to the same ID
- at times a hash of the same string, when running in a Docker container
"""
machine_id = str(uuid.getnode())
if os.path.isfile(_ETC_MACHINE_ID_PATH):
with open(_ETC_MACHINE_ID_PATH, "r") as f:
machine_id = f.read()
elif os.path.isfile(_DBUS_MACHINE_ID_PATH):
with open(_DBUS_MACHINE_ID_PATH, "r") as f:
machine_id = f.read()
return machine_id
class Installation:
_instance_lock = threading.Lock()
_instance: Optional["Installation"] = None
@classmethod
def instance(cls) -> "Installation":
"""Returns the singleton Installation"""
# We use a double-checked locking optimization to avoid the overhead
# of acquiring the lock in the common case:
# https://en.wikipedia.org/wiki/Double-checked_locking
if cls._instance is None:
with cls._instance_lock:
if cls._instance is None:
cls._instance = Installation()
return cls._instance
def __init__(self):
self.installation_id_v3 = str(
uuid.uuid5(uuid.NAMESPACE_DNS, _get_machine_id_v3())
)
def __repr__(self) -> str:
return util.repr_(self)
@property
def installation_id(self):
return self.installation_id_v3
def _get_type_name(obj: object) -> str:
"""Get a simplified name for the type of the given object."""
with contextlib.suppress(Exception):
obj_type = type(obj)
type_name = "unknown"
if hasattr(obj_type, "__qualname__"):
type_name = obj_type.__qualname__
elif hasattr(obj_type, "__name__"):
type_name = obj_type.__name__
if obj_type.__module__ != "builtins":
# Add the full module path
type_name = f"{obj_type.__module__}.{type_name}"
if type_name in _OBJECT_NAME_MAPPING:
type_name = _OBJECT_NAME_MAPPING[type_name]
return type_name
return "failed"
def _get_top_level_module(func: Callable[..., Any]) -> str:
"""Get the top level module for the given function."""
module = inspect.getmodule(func)
if module is None or not module.__name__:
return "unknown"
return module.__name__.split(".")[0]
def _get_arg_metadata(arg: object) -> Optional[str]:
"""Get metadata information related to the value of the given object."""
with contextlib.suppress(Exception):
if isinstance(arg, (bool)):
return f"val:{arg}"
if isinstance(arg, Sized):
return f"len:{len(arg)}"
return None
def _get_command_telemetry(
_command_func: Callable[..., Any], _command_name: str, *args, **kwargs
) -> Command:
"""Get telemetry information for the given callable and its arguments."""
arg_keywords = inspect.getfullargspec(_command_func).args
self_arg: Optional[Any] = None
arguments: List[Argument] = []
is_method = inspect.ismethod(_command_func)
name = _command_name
for i, arg in enumerate(args):
pos = i
if is_method:
# If func is a method, ignore the first argument (self)
i = i + 1
keyword = arg_keywords[i] if len(arg_keywords) > i else f"{i}"
if keyword == "self":
self_arg = arg
continue
argument = Argument(k=keyword, t=_get_type_name(arg), p=pos)
arg_metadata = _get_arg_metadata(arg)
if arg_metadata:
argument.m = arg_metadata
arguments.append(argument)
for kwarg, kwarg_value in kwargs.items():
argument = Argument(k=kwarg, t=_get_type_name(kwarg_value))
arg_metadata = _get_arg_metadata(kwarg_value)
if arg_metadata:
argument.m = arg_metadata
arguments.append(argument)
top_level_module = _get_top_level_module(_command_func)
if top_level_module != "streamlit":
        # If the gather_metrics decorator is used outside of the streamlit
        # library, we add a prefix to the tracked command name:
name = f"external:{top_level_module}:{name}"
if (
name == "create_instance"
and self_arg
and hasattr(self_arg, "name")
and self_arg.name
):
name = f"component:{self_arg.name}"
return Command(name=name, args=arguments)
def to_microseconds(seconds: float) -> int:
"""Convert seconds into microseconds."""
return int(seconds * 1_000_000)
F = TypeVar("F", bound=Callable[..., Any])
@overload
def gather_metrics(
name: str,
func: F,
) -> F:
...
@overload
def gather_metrics(
name: str,
func: None = None,
) -> Callable[[F], F]:
...
def gather_metrics(name: str, func: Optional[F] = None) -> Union[Callable[[F], F], F]:
"""Function decorator to add telemetry tracking to commands.
    Parameters
    ----------
    name : str
        The name to use for telemetry tracking. This overrides the wrapped
        function's own name.
    func : callable or None
        The function to track for telemetry. If None, the decorator was called
        with arguments and returns a wrapper to apply to the function.

    Example
    -------
    >>> @gather_metrics("my_command")
    ... def my_command(url):
    ...     return url

    >>> @gather_metrics(name="custom_name")
    ... def my_command(url):
    ...     return url
"""
if not name:
_LOGGER.warning("gather_metrics: name is empty")
name = "undefined"
if func is None:
# Support passing the params via function decorator
def wrapper(f: F) -> F:
return gather_metrics(
name=name,
func=f,
)
return wrapper
else:
# To make mypy type narrow Optional[F] -> F
non_optional_func = func
@wraps(non_optional_func)
def wrapped_func(*args, **kwargs):
exec_start = timer()
# get_script_run_ctx gets imported here to prevent circular dependencies
from streamlit.runtime.scriptrunner import get_script_run_ctx
ctx = get_script_run_ctx(suppress_warning=True)
tracking_activated = (
ctx is not None
and ctx.gather_usage_stats
and not ctx.command_tracking_deactivated
and len(ctx.tracked_commands)
< _MAX_TRACKED_COMMANDS # Prevent too much memory usage
)
command_telemetry: Optional[Command] = None
if ctx and tracking_activated:
try:
command_telemetry = _get_command_telemetry(
non_optional_func, name, *args, **kwargs
)
if (
command_telemetry.name not in ctx.tracked_commands_counter
or ctx.tracked_commands_counter[command_telemetry.name]
< _MAX_TRACKED_PER_COMMAND
):
ctx.tracked_commands.append(command_telemetry)
ctx.tracked_commands_counter.update([command_telemetry.name])
# Deactivate tracking to prevent calls inside already tracked commands
ctx.command_tracking_deactivated = True
except Exception as ex:
# Always capture all exceptions since we want to make sure that
# the telemetry never causes any issues.
_LOGGER.debug("Failed to collect command telemetry", exc_info=ex)
try:
result = non_optional_func(*args, **kwargs)
finally:
            # Reactivate tracking; this runs whether or not the command
            # raised an exception.
if ctx:
ctx.command_tracking_deactivated = False
if tracking_activated and command_telemetry:
# Set the execution time to the measured value
command_telemetry.time = to_microseconds(timer() - exec_start)
return result
with contextlib.suppress(AttributeError):
# Make this a well-behaved decorator by preserving important function
# attributes.
wrapped_func.__dict__.update(non_optional_func.__dict__)
wrapped_func.__signature__ = inspect.signature(non_optional_func) # type: ignore
return cast(F, wrapped_func)
def create_page_profile_message(
commands: List[Command],
exec_time: int,
prep_time: int,
uncaught_exception: Optional[str] = None,
) -> ForwardMsg:
"""Create and return the full PageProfile ForwardMsg."""
msg = ForwardMsg()
msg.page_profile.commands.extend(commands)
msg.page_profile.exec_time = exec_time
msg.page_profile.prep_time = prep_time
msg.page_profile.headless = config.get_option("server.headless")
# Collect all config options that have been manually set
config_options: Set[str] = set()
if config._config_options:
for option_name in config._config_options.keys():
if not config.is_manually_set(option_name):
# We only care about manually defined options
continue
config_option = config._config_options[option_name]
if config_option.is_default:
option_name = f"{option_name}:default"
config_options.add(option_name)
msg.page_profile.config.extend(config_options)
# Check the predefined set of modules for attribution
attributions: Set[str] = {
attribution
for attribution in _ATTRIBUTIONS_TO_CHECK
if attribution in sys.modules
}
msg.page_profile.os = str(sys.platform)
msg.page_profile.timezone = str(time.tzname)
msg.page_profile.attributions.extend(attributions)
if uncaught_exception:
msg.page_profile.uncaught_exception = uncaught_exception
return msg
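A hedged usage sketch of the decorator defined in this module; outside a Streamlit script run there is no ScriptRunContext, so the wrapper simply calls through without recording any telemetry:

from streamlit.runtime.metrics_util import gather_metrics

@gather_metrics("my_command")
def my_command(url: str) -> str:
    return url

# With no ScriptRunContext, tracking is skipped and the call behaves like the
# undecorated function (apart from a timer() call around it).
print(my_command("https://example.com"))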

View File

@@ -0,0 +1,726 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import asyncio
import sys
import time
import traceback
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Awaitable, Dict, NamedTuple, Optional, Tuple, Type
from typing_extensions import Final
from streamlit import config
from streamlit.logger import get_logger
from streamlit.proto.BackMsg_pb2 import BackMsg
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.caching import (
get_data_cache_stats_provider,
get_resource_cache_stats_provider,
)
from streamlit.runtime.caching.storage.local_disk_cache_storage import (
LocalDiskCacheStorageManager,
)
from streamlit.runtime.forward_msg_cache import (
ForwardMsgCache,
create_reference_msg,
populate_hash_if_needed,
)
from streamlit.runtime.legacy_caching.caching import _mem_caches
from streamlit.runtime.media_file_manager import MediaFileManager
from streamlit.runtime.media_file_storage import MediaFileStorage
from streamlit.runtime.memory_session_storage import MemorySessionStorage
from streamlit.runtime.runtime_util import is_cacheable_msg
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.session_manager import (
ActiveSessionInfo,
SessionClient,
SessionClientDisconnectedError,
SessionManager,
SessionStorage,
)
from streamlit.runtime.state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SessionStateStatProvider,
)
from streamlit.runtime.stats import StatsManager
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.runtime.websocket_session_manager import WebsocketSessionManager
from streamlit.watcher import LocalSourcesWatcher
if TYPE_CHECKING:
from streamlit.runtime.caching.storage import CacheStorageManager
# Wait for the script run result for 60s and if no result is available give up
SCRIPT_RUN_CHECK_TIMEOUT: Final = 60
LOGGER: Final = get_logger(__name__)
class RuntimeStoppedError(Exception):
"""Raised by operations on a Runtime instance that is stopped."""
@dataclass(frozen=True)
class RuntimeConfig:
"""Config options for StreamlitRuntime."""
# The filesystem path of the Streamlit script to run.
script_path: str
# The (optional) command line that Streamlit was started with
# (e.g. "streamlit run app.py")
command_line: Optional[str]
# The storage backend for Streamlit's MediaFileManager.
media_file_storage: MediaFileStorage
# The cache storage backend for Streamlit's st.cache_data.
cache_storage_manager: CacheStorageManager = field(
default_factory=LocalDiskCacheStorageManager
)
# The SessionManager class to be used.
session_manager_class: Type[SessionManager] = WebsocketSessionManager
# The SessionStorage instance for the SessionManager to use.
session_storage: SessionStorage = field(default_factory=MemorySessionStorage)
class RuntimeState(Enum):
INITIAL = "INITIAL"
NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED"
ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED"
STOPPING = "STOPPING"
STOPPED = "STOPPED"
class AsyncObjects(NamedTuple):
"""Container for all asyncio objects that Runtime manages.
These cannot be initialized until the Runtime's eventloop is assigned.
"""
# The eventloop that Runtime is running on.
eventloop: asyncio.AbstractEventLoop
# Set after Runtime.stop() is called. Never cleared.
must_stop: asyncio.Event
# Set when a client connects; cleared when we have no connected clients.
has_connection: asyncio.Event
# Set after a ForwardMsg is enqueued; cleared when we flush ForwardMsgs.
need_send_data: asyncio.Event
# Completed when the Runtime has started.
started: asyncio.Future[None]
# Completed when the Runtime has stopped.
stopped: asyncio.Future[None]
class Runtime:
_instance: Optional[Runtime] = None
@classmethod
def instance(cls) -> Runtime:
"""Return the singleton Runtime instance. Raise an Error if the
Runtime hasn't been created yet.
"""
if cls._instance is None:
raise RuntimeError("Runtime hasn't been created!")
return cls._instance
@classmethod
def exists(cls) -> bool:
"""True if the singleton Runtime instance has been created.
When a Streamlit app is running in "raw mode" - that is, when the
app is run via `python app.py` instead of `streamlit run app.py` -
the Runtime will not exist, and various Streamlit functions need
to adapt.
"""
return cls._instance is not None
def __init__(self, config: RuntimeConfig):
"""Create a Runtime instance. It won't be started yet.
Runtime is *not* thread-safe. Its public methods are generally
safe to call only on the same thread that its event loop runs on.
Parameters
----------
config
Config options.
"""
if Runtime._instance is not None:
raise RuntimeError("Runtime instance already exists!")
Runtime._instance = self
# Will be created when we start.
self._async_objs: Optional[AsyncObjects] = None
# The task that runs our main loop. We need to save a reference
# to it so that it doesn't get garbage collected while running.
self._loop_coroutine_task: Optional[asyncio.Task[None]] = None
self._main_script_path = config.script_path
self._command_line = config.command_line or ""
self._state = RuntimeState.INITIAL
# Initialize managers
self._message_cache = ForwardMsgCache()
self._uploaded_file_mgr = UploadedFileManager()
self._uploaded_file_mgr.on_files_updated.connect(self._on_files_updated)
self._media_file_mgr = MediaFileManager(storage=config.media_file_storage)
self._cache_storage_manager = config.cache_storage_manager
self._session_mgr = config.session_manager_class(
session_storage=config.session_storage,
uploaded_file_manager=self._uploaded_file_mgr,
message_enqueued_callback=self._enqueued_some_message,
)
self._stats_mgr = StatsManager()
self._stats_mgr.register_provider(get_data_cache_stats_provider())
self._stats_mgr.register_provider(get_resource_cache_stats_provider())
self._stats_mgr.register_provider(_mem_caches)
self._stats_mgr.register_provider(self._message_cache)
self._stats_mgr.register_provider(self._uploaded_file_mgr)
self._stats_mgr.register_provider(SessionStateStatProvider(self._session_mgr))
@property
def state(self) -> RuntimeState:
return self._state
@property
def message_cache(self) -> ForwardMsgCache:
return self._message_cache
@property
def uploaded_file_mgr(self) -> UploadedFileManager:
return self._uploaded_file_mgr
@property
def cache_storage_manager(self) -> CacheStorageManager:
return self._cache_storage_manager
@property
def media_file_mgr(self) -> MediaFileManager:
return self._media_file_mgr
@property
def stats_mgr(self) -> StatsManager:
return self._stats_mgr
@property
def stopped(self) -> Awaitable[None]:
"""A Future that completes when the Runtime's run loop has exited."""
return self._get_async_objs().stopped
# NOTE: A few Runtime methods listed as threadsafe (get_client, _on_files_updated,
# and is_active_session) currently rely on the implementation detail that
# WebsocketSessionManager's get_active_session_info and is_active_session methods
# happen to be threadsafe. This may change with future SessionManager implementations,
# at which point we'll need to formalize our thread safety rules for each
# SessionManager method.
def get_client(self, session_id: str) -> Optional[SessionClient]:
"""Get the SessionClient for the given session_id, or None
if no such session exists.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info is None:
return None
return session_info.client
def _on_files_updated(self, session_id: str) -> None:
"""Event handler for UploadedFileManager.on_file_added.
Ensures that uploaded files from stale sessions get deleted.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
if not self._session_mgr.is_active_session(session_id):
# If an uploaded file doesn't belong to an active session,
# remove it so it doesn't stick around forever.
self._uploaded_file_mgr.remove_session_files(session_id)
async def start(self) -> None:
"""Start the runtime. This must be called only once, before
any other functions are called.
When this coroutine returns, Streamlit is ready to accept new sessions.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
# Create our AsyncObjects. We need to have a running eventloop to
# instantiate our various synchronization primitives.
async_objs = AsyncObjects(
eventloop=asyncio.get_running_loop(),
must_stop=asyncio.Event(),
has_connection=asyncio.Event(),
need_send_data=asyncio.Event(),
started=asyncio.Future(),
stopped=asyncio.Future(),
)
self._async_objs = async_objs
if sys.version_info >= (3, 8, 0):
# Python 3.8+ supports a create_task `name` parameter, which can
# make debugging a bit easier.
self._loop_coroutine_task = asyncio.create_task(
self._loop_coroutine(), name="Runtime.loop_coroutine"
)
else:
self._loop_coroutine_task = asyncio.create_task(self._loop_coroutine())
await async_objs.started
def stop(self) -> None:
"""Request that Streamlit close all sessions and stop running.
Note that Streamlit won't stop running immediately.
Notes
-----
Threading: SAFE. May be called from any thread.
"""
async_objs = self._get_async_objs()
def stop_on_eventloop():
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
return
LOGGER.debug("Runtime stopping...")
self._set_state(RuntimeState.STOPPING)
async_objs.must_stop.set()
async_objs.eventloop.call_soon_threadsafe(stop_on_eventloop)
def is_active_session(self, session_id: str) -> bool:
"""True if the session_id belongs to an active session.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
return self._session_mgr.is_active_session(session_id)
def connect_session(
self,
client: SessionClient,
user_info: Dict[str, Optional[str]],
existing_session_id: Optional[str] = None,
) -> str:
"""Create a new session (or connect to an existing one) and return its unique ID.
Parameters
----------
client
A concrete SessionClient implementation for communicating with
the session's client.
user_info
A dict that contains information about the session's user. For now,
it only (optionally) contains the user's email address.
{
"email": "example@example.com"
}
Returns
-------
str
The session's unique string ID.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
raise RuntimeStoppedError(f"Can't connect_session (state={self._state})")
session_id = self._session_mgr.connect_session(
client=client,
script_data=ScriptData(self._main_script_path, self._command_line or ""),
user_info=user_info,
existing_session_id=existing_session_id,
)
self._set_state(RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED)
self._get_async_objs().has_connection.set()
return session_id
def create_session(
self,
client: SessionClient,
user_info: Dict[str, Optional[str]],
existing_session_id: Optional[str] = None,
) -> str:
"""Create a new session (or connect to an existing one) and return its unique ID.
Notes
-----
This method is simply an alias for connect_session added for backwards
compatibility.
"""
LOGGER.warning("create_session is deprecated! Use connect_session instead.")
return self.connect_session(
client=client, user_info=user_info, existing_session_id=existing_session_id
)
def close_session(self, session_id: str) -> None:
"""Close and completely shut down a session.
This differs from disconnect_session in that it always completely shuts down a
session, permanently losing any associated state (session state, uploaded files,
etc.).
This function may be called multiple times for the same session,
which is not an error. (Subsequent calls just no-op.)
Parameters
----------
session_id
The session's unique ID.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
self._session_mgr.close_session(session_id)
self._on_session_disconnected()
def disconnect_session(self, session_id: str) -> None:
"""Disconnect a session. It will stop producing ForwardMsgs.
Differs from close_session because disconnected sessions can be reconnected to
for a brief window (depending on the SessionManager/SessionStorage
implementations used by the runtime).
This function may be called multiple times for the same session,
which is not an error. (Subsequent calls just no-op.)
Parameters
----------
session_id
The session's unique ID.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
self._session_mgr.disconnect_session(session_id)
self._on_session_disconnected()
def handle_backmsg(self, session_id: str, msg: BackMsg) -> None:
"""Send a BackMsg to an active session.
Parameters
----------
session_id
The session's unique ID.
msg
The BackMsg to deliver to the session.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
raise RuntimeStoppedError(f"Can't handle_backmsg (state={self._state})")
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info is None:
LOGGER.debug(
"Discarding BackMsg for disconnected session (id=%s)", session_id
)
return
session_info.session.handle_backmsg(msg)
def handle_backmsg_deserialization_exception(
self, session_id: str, exc: BaseException
) -> None:
"""Handle an Exception raised during deserialization of a BackMsg.
Parameters
----------
session_id
The session's unique ID.
exc
The Exception.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
raise RuntimeStoppedError(
f"Can't handle_backmsg_deserialization_exception (state={self._state})"
)
session_info = self._session_mgr.get_active_session_info(session_id)
if session_info is None:
LOGGER.debug(
"Discarding BackMsg Exception for disconnected session (id=%s)",
session_id,
)
return
session_info.session.handle_backmsg_exception(exc)
@property
async def is_ready_for_browser_connection(self) -> Tuple[bool, str]:
if self._state not in (
RuntimeState.INITIAL,
RuntimeState.STOPPING,
RuntimeState.STOPPED,
):
return True, "ok"
return False, "unavailable"
async def does_script_run_without_error(self) -> Tuple[bool, str]:
"""Load and execute the app's script to verify it runs without an error.
Returns
-------
(True, "ok") if the script completes without error, or (False, err_msg)
if the script raises an exception.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
# NOTE: We create an AppSession directly here instead of using the
# SessionManager intentionally. This isn't a "real" session and is only being
# used to test that the script runs without error.
session = AppSession(
script_data=ScriptData(self._main_script_path, self._command_line),
uploaded_file_manager=self._uploaded_file_mgr,
message_enqueued_callback=self._enqueued_some_message,
local_sources_watcher=LocalSourcesWatcher(self._main_script_path),
user_info={"email": "test@test.com"},
)
try:
session.request_rerun(None)
now = time.perf_counter()
while (
SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state
and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT
):
await asyncio.sleep(0.1)
if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
return False, "timeout"
ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
msg = "ok" if ok else "error"
return ok, msg
finally:
session.shutdown()
def _set_state(self, new_state: RuntimeState) -> None:
LOGGER.debug("Runtime state: %s -> %s", self._state, new_state)
self._state = new_state
async def _loop_coroutine(self) -> None:
"""The main Runtime loop.
This function won't exit until `stop` is called.
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
async_objs = self._get_async_objs()
try:
if self._state == RuntimeState.INITIAL:
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
pass
else:
raise RuntimeError(f"Bad Runtime state at start: {self._state}")
# Signal that we're started and ready to accept sessions
async_objs.started.set_result(None)
while not async_objs.must_stop.is_set():
if self._state == RuntimeState.NO_SESSIONS_CONNECTED:
await asyncio.wait(
(
asyncio.create_task(async_objs.must_stop.wait()),
asyncio.create_task(async_objs.has_connection.wait()),
),
return_when=asyncio.FIRST_COMPLETED,
)
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
async_objs.need_send_data.clear()
for active_session_info in self._session_mgr.list_active_sessions():
msg_list = active_session_info.session.flush_browser_queue()
for msg in msg_list:
try:
self._send_message(active_session_info, msg)
except SessionClientDisconnectedError:
self._session_mgr.disconnect_session(
active_session_info.session.id
)
# Yield for a tick after sending a message.
await asyncio.sleep(0)
# Yield for a few milliseconds between session message
# flushing.
await asyncio.sleep(0.01)
else:
# Break out of the thread loop if we encounter any other state.
break
await asyncio.wait(
(
asyncio.create_task(async_objs.must_stop.wait()),
asyncio.create_task(async_objs.need_send_data.wait()),
),
return_when=asyncio.FIRST_COMPLETED,
)
# Shut down all AppSessions.
for session_info in self._session_mgr.list_sessions():
# NOTE: We want to fully shut down sessions when the runtime stops for
# now, but this may change in the future if/when our notion of a session
# is no longer so tightly coupled to a browser tab.
self._session_mgr.close_session(session_info.session.id)
self._set_state(RuntimeState.STOPPED)
async_objs.stopped.set_result(None)
except Exception as e:
async_objs.stopped.set_exception(e)
traceback.print_exc()
LOGGER.info(
"""
Please report this bug at https://github.com/streamlit/streamlit/issues.
"""
)
def _send_message(self, session_info: ActiveSessionInfo, msg: ForwardMsg) -> None:
"""Send a message to a client.
If the client is likely to have already cached the message, we may
instead send a "reference" message that contains only the hash of the
message.
Parameters
----------
session_info : ActiveSessionInfo
The ActiveSessionInfo associated with websocket
msg : ForwardMsg
The message to send to the client
Notes
-----
Threading: UNSAFE. Must be called on the eventloop thread.
"""
msg.metadata.cacheable = is_cacheable_msg(msg)
msg_to_send = msg
if msg.metadata.cacheable:
populate_hash_if_needed(msg)
if self._message_cache.has_message_reference(
msg, session_info.session, session_info.script_run_count
):
# This session has probably cached this message. Send
# a reference instead.
LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
msg_to_send = create_reference_msg(msg)
# Cache the message so it can be referenced in the future.
# If the message is already cached, this will reset its
# age.
LOGGER.debug("Caching message (hash=%s)", msg.hash)
self._message_cache.add_message(
msg, session_info.session, session_info.script_run_count
)
# If this was a `script_finished` message, we increment the
# script_run_count for this session, and update the cache
if (
msg.WhichOneof("type") == "script_finished"
and msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
):
LOGGER.debug(
"Script run finished successfully; "
"removing expired entries from MessageCache "
"(max_age=%s)",
config.get_option("global.maxCachedMessageAge"),
)
session_info.script_run_count += 1
self._message_cache.remove_expired_session_entries(
session_info.session, session_info.script_run_count
)
# Ship it off!
session_info.client.write_forward_msg(msg_to_send)
def _enqueued_some_message(self) -> None:
"""Callback called by AppSession after the AppSession has enqueued a
        message. Sets the "need_send_data" event, which causes our core
loop to wake up and flush client message queues.
Notes
-----
Threading: SAFE. May be called on any thread.
"""
async_objs = self._get_async_objs()
async_objs.eventloop.call_soon_threadsafe(async_objs.need_send_data.set)
def _get_async_objs(self) -> AsyncObjects:
"""Return our AsyncObjects instance. If the Runtime hasn't been
started, this will raise an error.
"""
if self._async_objs is None:
raise RuntimeError("Runtime hasn't started yet!")
return self._async_objs
def _on_session_disconnected(self) -> None:
"""Set the runtime state to NO_SESSIONS_CONNECTED if the last active
session was disconnected.
"""
if (
self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
and self._session_mgr.num_active_sessions() == 0
):
self._get_async_objs().has_connection.clear()
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
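# --- Editor's note: hedged illustration, not part of the original file. ---
# `_enqueued_some_message` above wakes the Runtime's event loop from an arbitrary
# thread by scheduling `Event.set` via `call_soon_threadsafe`. The self-contained
# sketch below reproduces that pattern with plain asyncio; the names `wake` and
# `producer` are illustrative only.
if __name__ == "__main__":
    import asyncio
    import threading
    import time

    async def _demo_wakeup() -> None:
        loop = asyncio.get_running_loop()
        wake = asyncio.Event()

        def producer() -> None:
            # Simulates an AppSession enqueueing a message on a non-eventloop thread.
            time.sleep(0.1)
            loop.call_soon_threadsafe(wake.set)  # thread-safe, like need_send_data.set

        threading.Thread(target=producer, daemon=True).start()
        await wake.wait()  # the core loop wakes up here and would flush queues
        print("woken by producer thread")

    asyncio.run(_demo_wakeup())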

View File

@@ -0,0 +1,98 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Runtime-related utility functions"""
from typing import Any, Optional
from streamlit import config
from streamlit.errors import MarkdownFormattedException
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed
class MessageSizeError(MarkdownFormattedException):
"""Exception raised when a websocket message is larger than the configured limit."""
def __init__(self, failed_msg_str: Any):
msg = self._get_message(failed_msg_str)
super(MessageSizeError, self).__init__(msg)
def _get_message(self, failed_msg_str: Any) -> str:
# This needs to have zero indentation otherwise the markdown will render incorrectly.
return (
(
"""
**Data of size {message_size_mb:.1f} MB exceeds the message size limit of {message_size_limit_mb} MB.**
This is often caused by a large chart or dataframe. Please decrease the amount of data sent
to the browser, or increase the limit by setting the config option `server.maxMessageSize`.
[Click here to learn more about config options](https://docs.streamlit.io/library/advanced-features/configuration#set-configuration-options).
_Note that increasing the limit may lead to long loading times and large memory consumption
of the client's browser and the Streamlit server._
"""
)
.format(
message_size_mb=len(failed_msg_str) / 1e6,
message_size_limit_mb=(get_max_message_size_bytes() / 1e6),
)
.strip("\n")
)
def is_cacheable_msg(msg: ForwardMsg) -> bool:
"""True if the given message qualifies for caching."""
if msg.WhichOneof("type") in {"ref_hash", "initialize"}:
# Some message types never get cached
return False
return msg.ByteSize() >= int(config.get_option("global.minCachedMessageSize"))
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
"""Serialize a ForwardMsg to send to a client.
If the message is too large, it will be converted to an exception message
instead.
"""
populate_hash_if_needed(msg)
msg_str = msg.SerializeToString()
if len(msg_str) > get_max_message_size_bytes():
import streamlit.elements.exception as exception
# Overwrite the offending ForwardMsg.delta with an error to display.
# This assumes that the size limit wasn't exceeded due to metadata.
exception.marshall(msg.delta.new_element.exception, MessageSizeError(msg_str))
msg_str = msg.SerializeToString()
return msg_str
# This needs to be initialized lazily to avoid calling config.get_option() and
# thus initializing config options when this file is first imported.
_max_message_size_bytes: Optional[int] = None
def get_max_message_size_bytes() -> int:
"""Returns the max websocket message size in bytes.
    This lazily loads the value from the config on first use and caches it in a
    module-level variable.
"""
global _max_message_size_bytes
if _max_message_size_bytes is None:
_max_message_size_bytes = config.get_option("server.maxMessageSize") * int(1e6)
return _max_message_size_bytes
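# --- Editor's note: hedged usage sketch, not part of the original file. ---
# A guarded demonstration of the helpers above. The size limit comes from
# `server.maxMessageSize` (configured in MB, converted to bytes), and a tiny
# message falls below `global.minCachedMessageSize`, so it is not cacheable.
# The proto field path `delta.new_element.markdown.body` is assumed here purely
# for illustration.
if __name__ == "__main__":
    demo_msg = ForwardMsg()
    demo_msg.delta.new_element.markdown.body = "hello"
    print("max message size (bytes):", get_max_message_size_bytes())
    print("cacheable:", is_cacheable_msg(demo_msg))  # False: below the caching threshold
    print("serialized length:", len(serialize_forward_msg(demo_msg)))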

View File

@@ -0,0 +1,44 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
@dataclass(frozen=True)
class ScriptData:
"""Contains parameters related to running a script."""
main_script_path: str
command_line: str
script_folder: str = field(init=False)
name: str = field(init=False)
def __post_init__(self) -> None:
"""Set some computed values derived from main_script_path.
The usage of object.__setattr__ is necessary because trying to set
        self.script_folder or self.name normally, even within the __init__ method,
        would raise a FrozenInstanceError since we declared this dataclass to be frozen.
We do this in __post_init__ so that we can use the auto-generated __init__
method that most dataclasses use.
"""
main_script_path = os.path.abspath(self.main_script_path)
script_folder = os.path.dirname(main_script_path)
object.__setattr__(self, "script_folder", script_folder)
basename = os.path.basename(main_script_path)
name = str(os.path.splitext(basename)[0])
object.__setattr__(self, "name", name)
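# --- Editor's note: hedged usage sketch, not part of the original file. ---
# ScriptData derives `script_folder` and `name` from `main_script_path` in
# __post_init__; the path below is a made-up example.
if __name__ == "__main__":
    data = ScriptData(
        main_script_path="/apps/demo/streamlit_app.py",
        command_line="streamlit run streamlit_app.py",
    )
    print(data.script_folder)  # e.g. /apps/demo (on a POSIX system)
    print(data.name)           # streamlit_app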

View File

@@ -0,0 +1,33 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Explicitly export public symbols
from streamlit.runtime.scriptrunner.script_requests import RerunData as RerunData
from streamlit.runtime.scriptrunner.script_run_context import (
ScriptRunContext as ScriptRunContext,
)
from streamlit.runtime.scriptrunner.script_run_context import (
add_script_run_ctx as add_script_run_ctx,
)
from streamlit.runtime.scriptrunner.script_run_context import (
get_script_run_ctx as get_script_run_ctx,
)
from streamlit.runtime.scriptrunner.script_runner import (
RerunException as RerunException,
)
from streamlit.runtime.scriptrunner.script_runner import ScriptRunner as ScriptRunner
from streamlit.runtime.scriptrunner.script_runner import (
ScriptRunnerEvent as ScriptRunnerEvent,
)
from streamlit.runtime.scriptrunner.script_runner import StopException as StopException

View File

@@ -0,0 +1,197 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import sys
from typing_extensions import Final
# When a Streamlit app is magicified, we insert a `magic_funcs` import near the top of
# its module's AST:
# import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
MAGIC_MODULE_NAME: Final = "__streamlitmagic__"
def add_magic(code, script_path):
"""Modifies the code to support magic Streamlit commands.
Parameters
----------
code : str
The Python code.
script_path : str
The path to the script file.
Returns
-------
ast.Module
The syntax tree for the code.
"""
# Pass script_path so we get pretty exceptions.
tree = ast.parse(code, script_path, "exec")
return _modify_ast_subtree(tree, is_root=True)
def _modify_ast_subtree(tree, body_attr="body", is_root=False):
"""Parses magic commands and modifies the given AST (sub)tree."""
body = getattr(tree, body_attr)
for i, node in enumerate(body):
node_type = type(node)
# Parse the contents of functions, With statements, and for statements
if (
node_type is ast.FunctionDef
or node_type is ast.With
or node_type is ast.For
or node_type is ast.While
or node_type is ast.AsyncFunctionDef
or node_type is ast.AsyncWith
or node_type is ast.AsyncFor
):
_modify_ast_subtree(node)
# Parse the contents of try statements
elif node_type is ast.Try:
for j, inner_node in enumerate(node.handlers):
node.handlers[j] = _modify_ast_subtree(inner_node)
finally_node = _modify_ast_subtree(node, body_attr="finalbody")
node.finalbody = finally_node.finalbody
_modify_ast_subtree(node)
        # Recursively parse both branches of if statements
elif node_type is ast.If:
_modify_ast_subtree(node)
_modify_ast_subtree(node, "orelse")
# Convert standalone expression nodes to st.write
elif node_type is ast.Expr:
value = _get_st_write_from_expr(node, i, parent_type=type(tree))
if value is not None:
node.value = value
if is_root:
# Import Streamlit so we can use it in the new_value above.
_insert_import_statement(tree)
ast.fix_missing_locations(tree)
return tree
def _insert_import_statement(tree):
"""Insert Streamlit import statement at the top(ish) of the tree."""
st_import = _build_st_import_statement()
# If the 0th node is already an import statement, put the Streamlit
# import below that, so we don't break "from __future__ import".
if tree.body and type(tree.body[0]) in (ast.ImportFrom, ast.Import):
tree.body.insert(1, st_import)
# If the 0th node is a docstring and the 1st is an import statement,
# put the Streamlit import below those, so we don't break "from
# __future__ import".
elif (
len(tree.body) > 1
and (type(tree.body[0]) is ast.Expr and _is_docstring_node(tree.body[0].value))
and type(tree.body[1]) in (ast.ImportFrom, ast.Import)
):
tree.body.insert(2, st_import)
else:
tree.body.insert(0, st_import)
def _build_st_import_statement():
"""Build AST node for `import magic_funcs as __streamlitmagic__`."""
return ast.Import(
names=[
ast.alias(
name="streamlit.runtime.scriptrunner.magic_funcs",
asname=MAGIC_MODULE_NAME,
)
]
)
def _build_st_write_call(nodes):
"""Build AST node for `__streamlitmagic__.transparent_write(*nodes)`."""
return ast.Call(
func=ast.Attribute(
attr="transparent_write",
value=ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load()),
ctx=ast.Load(),
),
args=nodes,
keywords=[],
kwargs=None,
starargs=None,
)
def _get_st_write_from_expr(node, i, parent_type):
# Don't change function calls
if type(node.value) is ast.Call:
return None
# Don't change Docstring nodes
if (
i == 0
and _is_docstring_node(node.value)
and parent_type in (ast.FunctionDef, ast.Module)
):
return None
# Don't change yield nodes
if type(node.value) is ast.Yield or type(node.value) is ast.YieldFrom:
return None
# Don't change await nodes
if type(node.value) is ast.Await:
return None
# If tuple, call st.write on the 0th element (rather than the
# whole tuple). This allows us to add a comma at the end of a statement
# to turn it into an expression that should be st-written. Ex:
# "np.random.randn(1000, 2),"
if type(node.value) is ast.Tuple:
args = node.value.elts
st_write = _build_st_write_call(args)
# st.write all strings.
elif type(node.value) is ast.Str:
args = [node.value]
st_write = _build_st_write_call(args)
# st.write all variables.
elif type(node.value) is ast.Name:
args = [node.value]
st_write = _build_st_write_call(args)
# st.write everything else
else:
args = [node.value]
st_write = _build_st_write_call(args)
return st_write
def _is_docstring_node(node):
if sys.version_info >= (3, 8, 0):
return type(node) is ast.Constant and type(node.value) is str
else:
return type(node) is ast.Str
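# --- Editor's note: hedged demonstration, not part of the original file. ---
# A guarded sketch of what the rewrite does to a tiny script: standalone
# expressions become `__streamlitmagic__.transparent_write(...)` calls, and the
# magic import is inserted just after the leading docstring/imports. Printing the
# result requires `ast.unparse`, which is only available on Python 3.9+.
if __name__ == "__main__":
    demo_source = '"""Docstring."""\nimport math\nx = math.pi\nx\n"hello", x\n'
    transformed = add_magic(demo_source, "<magic-demo>")
    if hasattr(ast, "unparse"):
        print(ast.unparse(transformed))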

View File

@@ -0,0 +1,30 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from streamlit.runtime.metrics_util import gather_metrics
@gather_metrics("magic")
def transparent_write(*args: Any) -> Any:
"""The function that gets magic-ified into Streamlit apps.
This is just st.write, but returns the arguments you passed to it.
"""
import streamlit as st
st.write(*args)
if len(args) == 1:
return args[0]
return args
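# --- Editor's note: hedged usage sketch, not part of the original file. ---
# transparent_write forwards its arguments to st.write and then returns them, so
# a magic-rewritten expression keeps its value. Run outside a Streamlit session
# ("bare mode"), st.write is assumed to be effectively a no-op that may log a
# missing-ScriptRunContext warning.
if __name__ == "__main__":
    print(transparent_write("hello"))       # hello
    print(transparent_write("hello", 123))  # ('hello', 123)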

View File

@@ -0,0 +1,180 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Optional, cast
from streamlit.proto.WidgetStates_pb2 import WidgetStates
from streamlit.runtime.state import coalesce_widget_states
class ScriptRequestType(Enum):
# The ScriptRunner should continue running its script.
CONTINUE = "CONTINUE"
# If the script is running, it should be stopped as soon
# as the ScriptRunner reaches an interrupt point.
# This is a terminal state.
STOP = "STOP"
# A script rerun has been requested. The ScriptRunner should
# handle this request as soon as it reaches an interrupt point.
RERUN = "RERUN"
@dataclass(frozen=True)
class RerunData:
"""Data attached to RERUN requests. Immutable."""
query_string: str = ""
widget_states: Optional[WidgetStates] = None
page_script_hash: str = ""
page_name: str = ""
@dataclass(frozen=True)
class ScriptRequest:
"""A STOP or RERUN request and associated data."""
type: ScriptRequestType
_rerun_data: Optional[RerunData] = None
@property
def rerun_data(self) -> RerunData:
if self.type is not ScriptRequestType.RERUN:
raise RuntimeError("RerunData is only set for RERUN requests.")
return cast(RerunData, self._rerun_data)
class ScriptRequests:
"""An interface for communicating with a ScriptRunner. Thread-safe.
AppSession makes requests of a ScriptRunner through this class, and
ScriptRunner handles those requests.
"""
def __init__(self):
self._lock = threading.Lock()
self._state = ScriptRequestType.CONTINUE
self._rerun_data = RerunData()
def request_stop(self) -> None:
"""Request that the ScriptRunner stop running. A stopped ScriptRunner
can't be used anymore. STOP requests succeed unconditionally.
"""
with self._lock:
self._state = ScriptRequestType.STOP
def request_rerun(self, new_data: RerunData) -> bool:
"""Request that the ScriptRunner rerun its script.
If the ScriptRunner has been stopped, this request can't be honored:
return False.
Otherwise, record the request and return True. The ScriptRunner will
handle the rerun request as soon as it reaches an interrupt point.
"""
with self._lock:
if self._state == ScriptRequestType.STOP:
# We can't rerun after being stopped.
return False
if self._state == ScriptRequestType.CONTINUE:
# If we're running, we can handle a rerun request
# unconditionally.
self._state = ScriptRequestType.RERUN
self._rerun_data = new_data
return True
if self._state == ScriptRequestType.RERUN:
# If we have an existing Rerun request, we coalesce this
# new request into it.
if self._rerun_data.widget_states is None:
# The existing request's widget_states is None, which
# means it wants to rerun with whatever the most
# recent script execution's widget state was.
# We have no meaningful state to merge with, and
# so we simply overwrite the existing request.
self._rerun_data = new_data
return True
if new_data.widget_states is not None:
# Both the existing and the new request have
# non-null widget_states. Merge them together.
coalesced_states = coalesce_widget_states(
self._rerun_data.widget_states, new_data.widget_states
)
self._rerun_data = RerunData(
query_string=new_data.query_string,
widget_states=coalesced_states,
page_script_hash=new_data.page_script_hash,
page_name=new_data.page_name,
)
return True
# If old widget_states is NOT None, and new widget_states IS
# None, then this new request is entirely redundant. Leave
# our existing rerun_data as is.
return True
# We'll never get here
            raise RuntimeError(f"Unrecognized ScriptRequestType: {self._state}")
def on_scriptrunner_yield(self) -> Optional[ScriptRequest]:
"""Called by the ScriptRunner when it's at a yield point.
If we have no request, return None.
If we have a RERUN request, return the request and set our internal
state to CONTINUE.
If we have a STOP request, return the request and remain stopped.
"""
if self._state == ScriptRequestType.CONTINUE:
# We avoid taking a lock in the common case. If a STOP or RERUN
# request is received between the `if` and `return`, it will be
# handled at the next `on_scriptrunner_yield`, or when
# `on_scriptrunner_ready` is called.
return None
with self._lock:
if self._state == ScriptRequestType.RERUN:
self._state = ScriptRequestType.CONTINUE
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
assert self._state == ScriptRequestType.STOP
return ScriptRequest(ScriptRequestType.STOP)
def on_scriptrunner_ready(self) -> ScriptRequest:
"""Called by the ScriptRunner when it's about to run its script for
the first time, and also after its script has successfully completed.
If we have a RERUN request, return the request and set
our internal state to CONTINUE.
If we have a STOP request or no request, set our internal state
to STOP.
"""
with self._lock:
if self._state == ScriptRequestType.RERUN:
self._state = ScriptRequestType.CONTINUE
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
# If we don't have a rerun request, unconditionally change our
# state to STOP.
self._state = ScriptRequestType.STOP
return ScriptRequest(ScriptRequestType.STOP)
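# --- Editor's note: hedged demonstration, not part of the original file. ---
# A guarded walk through the request state machine above: fresh instances start
# in CONTINUE, a RERUN is handed out once at the next ready/yield point, and a
# stopped instance refuses further rerun requests.
if __name__ == "__main__":
    requests = ScriptRequests()
    requests.request_rerun(RerunData(query_string="a=1"))
    ready = requests.on_scriptrunner_ready()
    print(ready.type, ready.rerun_data.query_string)  # ScriptRequestType.RERUN a=1
    print(requests.on_scriptrunner_yield())           # None: state is back to CONTINUE
    requests.request_stop()
    print(requests.request_rerun(RerunData()))        # False: can't rerun after STOP
    print(requests.on_scriptrunner_ready().type)      # ScriptRequestType.STOP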

View File

@@ -0,0 +1,170 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import threading
from dataclasses import dataclass, field
from typing import Callable, Counter, Dict, List, Optional, Set
from typing_extensions import Final, TypeAlias
from streamlit import runtime
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.PageProfile_pb2 import Command
from streamlit.runtime.state import SafeSessionState
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
LOGGER: Final = get_logger(__name__)
UserInfo: TypeAlias = Dict[str, Optional[str]]
@dataclass
class ScriptRunContext:
"""A context object that contains data for a "script run" - that is,
data that's scoped to a single ScriptRunner execution (and therefore also
scoped to a single connected "session").
ScriptRunContext is used internally by virtually every `st.foo()` function.
It is accessed only from the script thread that's created by ScriptRunner.
Streamlit code typically retrieves the active ScriptRunContext via the
`get_script_run_ctx` function.
"""
session_id: str
_enqueue: Callable[[ForwardMsg], None]
query_string: str
session_state: SafeSessionState
uploaded_file_mgr: UploadedFileManager
page_script_hash: str
user_info: UserInfo
gather_usage_stats: bool = False
command_tracking_deactivated: bool = False
tracked_commands: List[Command] = field(default_factory=list)
tracked_commands_counter: Counter[str] = field(default_factory=collections.Counter)
_set_page_config_allowed: bool = True
_has_script_started: bool = False
widget_ids_this_run: Set[str] = field(default_factory=set)
widget_user_keys_this_run: Set[str] = field(default_factory=set)
form_ids_this_run: Set[str] = field(default_factory=set)
cursors: Dict[int, "streamlit.cursor.RunningCursor"] = field(default_factory=dict)
dg_stack: List["streamlit.delta_generator.DeltaGenerator"] = field(
default_factory=list
)
def reset(self, query_string: str = "", page_script_hash: str = "") -> None:
self.cursors = {}
self.widget_ids_this_run = set()
self.widget_user_keys_this_run = set()
self.form_ids_this_run = set()
self.query_string = query_string
self.page_script_hash = page_script_hash
# Permit set_page_config when the ScriptRunContext is reused on a rerun
self._set_page_config_allowed = True
self._has_script_started = False
self.command_tracking_deactivated: bool = False
self.tracked_commands = []
self.tracked_commands_counter = collections.Counter()
def on_script_start(self) -> None:
self._has_script_started = True
def enqueue(self, msg: ForwardMsg) -> None:
"""Enqueue a ForwardMsg for this context's session."""
if msg.HasField("page_config_changed") and not self._set_page_config_allowed:
raise StreamlitAPIException(
"`set_page_config()` can only be called once per app, "
+ "and must be called as the first Streamlit command in your script.\n\n"
+ "For more information refer to the [docs]"
+ "(https://docs.streamlit.io/library/api-reference/utilities/st.set_page_config)."
)
# We want to disallow set_page config if one of the following occurs:
# - set_page_config was called on this message
# - The script has already started and a different st call occurs (a delta)
if msg.HasField("page_config_changed") or (
msg.HasField("delta") and self._has_script_started
):
self._set_page_config_allowed = False
# Pass the message up to our associated ScriptRunner.
self._enqueue(msg)
SCRIPT_RUN_CONTEXT_ATTR_NAME: Final = "streamlit_script_run_ctx"
def add_script_run_ctx(
thread: Optional[threading.Thread] = None, ctx: Optional[ScriptRunContext] = None
):
"""Adds the current ScriptRunContext to a newly-created thread.
This should be called from this thread's parent thread,
before the new thread starts.
Parameters
----------
thread : threading.Thread
The thread to attach the current ScriptRunContext to.
ctx : ScriptRunContext or None
The ScriptRunContext to add, or None to use the current thread's
ScriptRunContext.
Returns
-------
threading.Thread
The same thread that was passed in, for chaining.
"""
if thread is None:
thread = threading.current_thread()
if ctx is None:
ctx = get_script_run_ctx()
if ctx is not None:
setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, ctx)
return thread
def get_script_run_ctx(suppress_warning: bool = False) -> Optional[ScriptRunContext]:
"""
Parameters
----------
suppress_warning : bool
If True, don't log a warning if there's no ScriptRunContext.
Returns
-------
ScriptRunContext | None
The current thread's ScriptRunContext, or None if it doesn't have one.
"""
thread = threading.current_thread()
ctx: Optional[ScriptRunContext] = getattr(
thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, None
)
if ctx is None and runtime.exists() and not suppress_warning:
# Only warn about a missing ScriptRunContext if suppress_warning is False, and
# we were started via `streamlit run`. Otherwise, the user is likely running a
# script "bare", and doesn't need to be warned about streamlit
# bits that are irrelevant when not connected to a session.
LOGGER.warning("Thread '%s': missing ScriptRunContext", thread.name)
return ctx
# Needed to avoid circular dependencies while running tests.
import streamlit
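# --- Editor's note: hedged usage sketch, not part of the original file. ---
# Threads spawned from app code don't inherit the ScriptRunContext, so Streamlit
# calls made on them hit the "missing ScriptRunContext" warning above and their
# output is dropped. The usual workaround, sketched here for code running inside
# a `streamlit run` session, is to attach the current context to the worker
# before starting it.
if __name__ == "__main__":
    import streamlit as st

    def _worker() -> None:
        st.write("hello from a worker thread")  # works because the ctx was attached

    worker_thread = threading.Thread(target=_worker)
    add_script_run_ctx(worker_thread)  # copy this thread's ScriptRunContext onto the worker
    worker_thread.start()
    worker_thread.join()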

View File

@@ -0,0 +1,704 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import sys
import threading
import types
from contextlib import contextmanager
from enum import Enum
from timeit import default_timer as timer
from typing import Callable, Dict, Optional
from blinker import Signal
from streamlit import config, runtime, source_util, util
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.logger import get_logger
from streamlit.proto.ClientState_pb2 import ClientState
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.scriptrunner import magic
from streamlit.runtime.scriptrunner.script_requests import (
RerunData,
ScriptRequests,
ScriptRequestType,
)
from streamlit.runtime.scriptrunner.script_run_context import (
ScriptRunContext,
add_script_run_ctx,
get_script_run_ctx,
)
from streamlit.runtime.state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SafeSessionState,
SessionState,
)
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.vendor.ipython.modified_sys_path import modified_sys_path
_LOGGER = get_logger(__name__)
class ScriptRunnerEvent(Enum):
## "Control" events. These are emitted when the ScriptRunner's state changes.
# The script started running.
SCRIPT_STARTED = "SCRIPT_STARTED"
# The script run stopped because of a compile error.
SCRIPT_STOPPED_WITH_COMPILE_ERROR = "SCRIPT_STOPPED_WITH_COMPILE_ERROR"
# The script run stopped because it ran to completion, or was
# interrupted by the user.
SCRIPT_STOPPED_WITH_SUCCESS = "SCRIPT_STOPPED_WITH_SUCCESS"
# The script run stopped in order to start a script run with newer widget state.
SCRIPT_STOPPED_FOR_RERUN = "SCRIPT_STOPPED_FOR_RERUN"
    # The ScriptRunner is done processing requests and is shut down.
SHUTDOWN = "SHUTDOWN"
## "Data" events. These are emitted when the ScriptRunner's script has
## data to send to the frontend.
# The script has a ForwardMsg to send to the frontend.
ENQUEUE_FORWARD_MSG = "ENQUEUE_FORWARD_MSG"
"""
Note [Threading]
There are two kinds of threads in Streamlit, the main thread and script threads.
The main thread is started by invoking the Streamlit CLI, and bootstraps the
framework and runs the Tornado webserver.
A script thread is created by a ScriptRunner when it starts. The script thread
is where the ScriptRunner executes, including running the user script itself,
processing messages to/from the frontend, and all the Streamlit library function
calls in the user script.
It is possible for the user script to spawn its own threads, which could call
Streamlit functions. We restrict the ScriptRunner's execution control to the
script thread. Calling Streamlit functions from other threads is unlikely to
work correctly due to lack of ScriptRunContext, so we may add a guard against
it in the future.
"""
class ScriptRunner:
def __init__(
self,
session_id: str,
main_script_path: str,
client_state: ClientState,
session_state: SessionState,
uploaded_file_mgr: UploadedFileManager,
initial_rerun_data: RerunData,
user_info: Dict[str, Optional[str]],
):
"""Initialize the ScriptRunner.
(The ScriptRunner won't start executing until start() is called.)
Parameters
----------
session_id : str
The AppSession's id.
main_script_path : str
Path to our main app script.
client_state : ClientState
The current state from the client (widgets and query params).
uploaded_file_mgr : UploadedFileManager
The File manager to store the data uploaded by the file_uploader widget.
user_info: Dict
A dict that contains information about the current user. For now,
it only contains the user's email address.
{
"email": "example@example.com"
}
Information about the current user is optionally provided when a
websocket connection is initialized via the "X-Streamlit-User" header.
"""
self._session_id = session_id
self._main_script_path = main_script_path
self._uploaded_file_mgr = uploaded_file_mgr
self._user_info = user_info
# Initialize SessionState with the latest widget states
session_state.set_widgets_from_proto(client_state.widget_states)
self._client_state = client_state
self._session_state = SafeSessionState(session_state)
self._requests = ScriptRequests()
self._requests.request_rerun(initial_rerun_data)
self.on_event = Signal(
doc="""Emitted when a ScriptRunnerEvent occurs.
This signal is generally emitted on the ScriptRunner's script
thread (which is *not* the same thread that the ScriptRunner was
created on).
Parameters
----------
sender: ScriptRunner
The sender of the event (this ScriptRunner).
event : ScriptRunnerEvent
forward_msg : ForwardMsg | None
The ForwardMsg to send to the frontend. Set only for the
ENQUEUE_FORWARD_MSG event.
exception : BaseException | None
Our compile error. Set only for the
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
widget_states : streamlit.proto.WidgetStates_pb2.WidgetStates | None
The ScriptRunner's final WidgetStates. Set only for the
SHUTDOWN event.
"""
)
# Set to true while we're executing. Used by
# _maybe_handle_execution_control_request.
self._execing = False
# This is initialized in start()
self._script_thread: Optional[threading.Thread] = None
def __repr__(self) -> str:
return util.repr_(self)
def request_stop(self) -> None:
"""Request that the ScriptRunner stop running its script and
shut down. The ScriptRunner will handle this request when it reaches
an interrupt point.
Safe to call from any thread.
"""
self._requests.request_stop()
# "Disconnect" our SafeSessionState wrapper from its underlying
# SessionState instance. This will cause all further session_state
# operations in this ScriptRunner to no-op.
#
# After `request_stop` is called, our script will continue executing
# until it reaches a yield point. AppSession may also *immediately*
# spin up a new ScriptRunner after this call, which means we'll
# potentially have two active ScriptRunners for a brief period while
# this one is shutting down. Disconnecting our SessionState ensures
# that this ScriptRunner's thread won't introduce SessionState-
# related race conditions during this script overlap.
self._session_state.disconnect()
def request_rerun(self, rerun_data: RerunData) -> bool:
"""Request that the ScriptRunner interrupt its currently-running
script and restart it.
If the ScriptRunner has been stopped, this request can't be honored:
return False.
Otherwise, record the request and return True. The ScriptRunner will
handle the rerun request as soon as it reaches an interrupt point.
Safe to call from any thread.
"""
return self._requests.request_rerun(rerun_data)
    def start(self) -> None:
        """Start a new thread to process our script requests.
This must be called only once.
"""
if self._script_thread is not None:
raise Exception("ScriptRunner was already started")
self._script_thread = threading.Thread(
target=self._run_script_thread,
name="ScriptRunner.scriptThread",
)
self._script_thread.start()
def _get_script_run_ctx(self) -> ScriptRunContext:
"""Get the ScriptRunContext for the current thread.
Returns
-------
ScriptRunContext
The ScriptRunContext for the current thread.
Raises
------
AssertionError
If called outside of a ScriptRunner thread.
RuntimeError
If there is no ScriptRunContext for the current thread.
"""
assert self._is_in_script_thread()
ctx = get_script_run_ctx()
if ctx is None:
# This should never be possible on the script_runner thread.
raise RuntimeError(
"ScriptRunner thread has a null ScriptRunContext. Something has gone very wrong!"
)
return ctx
def _run_script_thread(self) -> None:
"""The entry point for the script thread.
        Processes requests from our ScriptRequests object, which will contain at
        least the initial RERUN request that triggers the first script run.
        When a STOP request is dequeued, this function exits and its thread
        terminates.
"""
assert self._is_in_script_thread()
_LOGGER.debug("Beginning script thread")
# Create and attach the thread's ScriptRunContext
ctx = ScriptRunContext(
session_id=self._session_id,
_enqueue=self._enqueue_forward_msg,
query_string=self._client_state.query_string,
session_state=self._session_state,
uploaded_file_mgr=self._uploaded_file_mgr,
page_script_hash=self._client_state.page_script_hash,
user_info=self._user_info,
gather_usage_stats=bool(config.get_option("browser.gatherUsageStats")),
)
add_script_run_ctx(threading.current_thread(), ctx)
request = self._requests.on_scriptrunner_ready()
while request.type == ScriptRequestType.RERUN:
# When the script thread starts, we'll have a pending rerun
# request that we'll handle immediately. When the script finishes,
# it's possible that another request has come in that we need to
# handle, which is why we call _run_script in a loop.
self._run_script(request.rerun_data)
request = self._requests.on_scriptrunner_ready()
assert request.type == ScriptRequestType.STOP
# Send a SHUTDOWN event before exiting. This includes the widget values
# as they existed after our last successful script run, which the
# AppSession will pass on to the next ScriptRunner that gets
# created.
client_state = ClientState()
client_state.query_string = ctx.query_string
client_state.page_script_hash = ctx.page_script_hash
widget_states = self._session_state.get_widget_states()
client_state.widget_states.widgets.extend(widget_states)
self.on_event.send(
self, event=ScriptRunnerEvent.SHUTDOWN, client_state=client_state
)
def _is_in_script_thread(self) -> bool:
"""True if the calling function is running in the script thread"""
return self._script_thread == threading.current_thread()
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
"""Enqueue a ForwardMsg to our browser queue.
This private function is called by ScriptRunContext only.
It may be called from the script thread OR the main thread.
"""
# Whenever we enqueue a ForwardMsg, we also handle any pending
# execution control request. This means that a script can be
# cleanly interrupted and stopped inside most `st.foo` calls.
#
# (If "runner.installTracer" is true, then we'll actually be
# handling these requests in a callback called after every Python
# instruction instead.)
if not config.get_option("runner.installTracer"):
self._maybe_handle_execution_control_request()
# Pass the message to our associated AppSession.
self.on_event.send(
self, event=ScriptRunnerEvent.ENQUEUE_FORWARD_MSG, forward_msg=msg
)
    def _maybe_handle_execution_control_request(self) -> None:
        """Check our ScriptRequests object to see if it has a pending
        STOP or RERUN request.
This function is called every time the app script enqueues a
ForwardMsg, which means that most `st.foo` commands - which generally
involve sending a ForwardMsg to the frontend - act as implicit
yield points in the script's execution.
"""
if not self._is_in_script_thread():
# We can only handle execution_control_request if we're on the
# script execution thread. However, it's possible for deltas to
# be enqueued (and, therefore, for this function to be called)
# in separate threads, so we check for that here.
return
if not self._execing:
# If the _execing flag is not set, we're not actually inside
# an exec() call. This happens when our script exec() completes,
# we change our state to STOPPED, and a statechange-listener
            # enqueues a new ForwardMsg.
return
request = self._requests.on_scriptrunner_yield()
if request is None:
# No RERUN or STOP request.
return
if request.type == ScriptRequestType.RERUN:
raise RerunException(request.rerun_data)
assert request.type == ScriptRequestType.STOP
raise StopException()
def _install_tracer(self) -> None:
"""Install function that runs before each line of the script."""
def trace_calls(frame, event, arg):
self._maybe_handle_execution_control_request()
return trace_calls
# Python interpreters are not required to implement sys.settrace.
if hasattr(sys, "settrace"):
sys.settrace(trace_calls)
@contextmanager
def _set_execing_flag(self):
"""A context for setting the ScriptRunner._execing flag.
Used by _maybe_handle_execution_control_request to ensure that
we only handle requests while we're inside an exec() call
"""
if self._execing:
raise RuntimeError("Nested set_execing_flag call")
self._execing = True
try:
yield
finally:
self._execing = False
def _run_script(self, rerun_data: RerunData) -> None:
"""Run our script.
Parameters
----------
rerun_data: RerunData
The RerunData to use.
"""
assert self._is_in_script_thread()
_LOGGER.debug("Running script %s", rerun_data)
start_time: float = timer()
prep_time: float = 0 # This will be overwritten once preparations are done.
# Reset DeltaGenerators, widgets, media files.
runtime.get_instance().media_file_mgr.clear_session_refs()
main_script_path = self._main_script_path
pages = source_util.get_pages(main_script_path)
# Safe because pages will at least contain the app's main page.
main_page_info = list(pages.values())[0]
current_page_info = None
uncaught_exception = None
if rerun_data.page_script_hash:
current_page_info = pages.get(rerun_data.page_script_hash, None)
elif not rerun_data.page_script_hash and rerun_data.page_name:
# If a user navigates directly to a non-main page of an app, we get
# the first script run request before the list of pages has been
# sent to the frontend. In this case, we choose the first script
# with a name matching the requested page name.
current_page_info = next(
filter(
# There seems to be this weird bug with mypy where it
# thinks that p can be None (which is impossible given the
# types of pages), so we add `p and` at the beginning of
# the predicate to circumvent this.
lambda p: p and (p["page_name"] == rerun_data.page_name),
pages.values(),
),
None,
)
else:
# If no information about what page to run is given, default to
# running the main page.
current_page_info = main_page_info
page_script_hash = (
current_page_info["page_script_hash"]
if current_page_info is not None
else main_page_info["page_script_hash"]
)
ctx = self._get_script_run_ctx()
ctx.reset(
query_string=rerun_data.query_string,
page_script_hash=page_script_hash,
)
self.on_event.send(
self,
event=ScriptRunnerEvent.SCRIPT_STARTED,
page_script_hash=page_script_hash,
)
# Compile the script. Any errors thrown here will be surfaced
# to the user via a modal dialog in the frontend, and won't result
# in their previous script elements disappearing.
try:
if current_page_info:
script_path = current_page_info["script_path"]
else:
script_path = main_script_path
# At this point, we know that either
# * the script corresponding to the hash requested no longer
# exists, or
# * we were not able to find a script with the requested page
# name.
# In both of these cases, we want to send a page_not_found
# message to the frontend.
msg = ForwardMsg()
msg.page_not_found.page_name = rerun_data.page_name
ctx.enqueue(msg)
with source_util.open_python_file(script_path) as f:
filebody = f.read()
if config.get_option("runner.magicEnabled"):
filebody = magic.add_magic(filebody, script_path)
code = compile( # type: ignore
filebody,
# Pass in the file path so it can show up in exceptions.
script_path,
# We're compiling entire blocks of Python, so we need "exec"
# mode (as opposed to "eval" or "single").
mode="exec",
# Don't inherit any flags or "future" statements.
flags=0,
dont_inherit=1,
# Use the default optimization options.
optimize=-1,
)
except Exception as ex:
# We got a compile error. Send an error event and bail immediately.
_LOGGER.debug("Fatal script error: %s", ex)
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
self.on_event.send(
self,
event=ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
exception=ex,
)
return
# If we get here, we've successfully compiled our script. The next step
# is to run it. Errors thrown during execution will be shown to the
# user as ExceptionElements.
if config.get_option("runner.installTracer"):
self._install_tracer()
# This will be set to a RerunData instance if our execution
# is interrupted by a RerunException.
rerun_exception_data: Optional[RerunData] = None
try:
            # Create a fake module. This gives us a fresh global namespace to
# execute the code in.
# TODO(vdonato): Double-check that we're okay with naming the
# module for every page `__main__`. I'm pretty sure this is
# necessary given that people will likely often write
# ```
# if __name__ == "__main__":
# ...
# ```
# in their scripts.
module = _new_module("__main__")
# Install the fake module as the __main__ module. This allows
# the pickle module to work inside the user's code, since it now
# can know the module where the pickled objects stem from.
# IMPORTANT: This means we can't use "if __name__ == '__main__'" in
# our code, as it will point to the wrong module!!!
sys.modules["__main__"] = module
# Add special variables to the module's globals dict.
# Note: The following is a requirement for the CodeHasher to
# work correctly. The CodeHasher is scoped to
# files contained in the directory of __main__.__file__, which we
# assume is the main script directory.
module.__dict__["__file__"] = script_path
with modified_sys_path(self._main_script_path), self._set_execing_flag():
# Run callbacks for widgets whose values have changed.
if rerun_data.widget_states is not None:
self._session_state.on_script_will_rerun(rerun_data.widget_states)
ctx.on_script_start()
prep_time = timer() - start_time
exec(code, module.__dict__)
self._session_state.maybe_check_serializable()
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = True
except RerunException as e:
rerun_exception_data = e.rerun_data
except StopException:
# This is thrown when the script executes `st.stop()`.
# We don't have to do anything here.
pass
except Exception as ex:
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
uncaught_exception = ex
handle_uncaught_app_exception(uncaught_exception)
finally:
if rerun_exception_data:
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN
else:
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
if ctx.gather_usage_stats:
try:
# Prevent issues with circular import
from streamlit.runtime.metrics_util import (
create_page_profile_message,
to_microseconds,
)
# Create and send page profile information
ctx.enqueue(
create_page_profile_message(
ctx.tracked_commands,
exec_time=to_microseconds(timer() - start_time),
prep_time=to_microseconds(prep_time),
uncaught_exception=type(uncaught_exception).__name__
if uncaught_exception
else None,
)
)
except Exception as ex:
# Always capture all exceptions since we want to make sure that
# the telemetry never causes any issues.
_LOGGER.debug("Failed to create page profile", exc_info=ex)
self._on_script_finished(ctx, finished_event)
# Use _log_if_error() to make sure we never ever ever stop running the
# script without meaning to.
_log_if_error(_clean_problem_modules)
if rerun_exception_data is not None:
self._run_script(rerun_exception_data)
def _on_script_finished(
self, ctx: ScriptRunContext, event: ScriptRunnerEvent
) -> None:
"""Called when our script finishes executing, even if it finished
early with an exception. We perform post-run cleanup here.
"""
# Tell session_state to update itself in response
self._session_state.on_script_finished(ctx.widget_ids_this_run)
# Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
# even if we were stopped with an exception.)
self.on_event.send(self, event=event)
# Remove orphaned files now that the script has run and files in use
# are marked as active.
runtime.get_instance().media_file_mgr.remove_orphaned_files()
# Force garbage collection to run, to help avoid memory use building up
# This is usually not an issue, but sometimes GC takes time to kick in and
# causes apps to go over resource limits, and forcing it to run between
# script runs is low cost, since we aren't doing much work anyway.
if config.get_option("runner.postScriptGC"):
gc.collect(2)
class ScriptControlException(BaseException):
"""Base exception for ScriptRunner."""
pass
class StopException(ScriptControlException):
"""Silently stop the execution of the user's script."""
pass
class RerunException(ScriptControlException):
"""Silently stop and rerun the user's script."""
def __init__(self, rerun_data: RerunData):
"""Construct a RerunException
Parameters
----------
rerun_data : RerunData
The RerunData that should be used to rerun the script
"""
self.rerun_data = rerun_data
def __repr__(self) -> str:
return util.repr_(self)
def _clean_problem_modules() -> None:
"""Some modules are stateful, so we have to clear their state."""
if "keras" in sys.modules:
try:
keras = sys.modules["keras"]
keras.backend.clear_session()
except Exception:
# We don't want to crash the app if we can't clear the Keras session.
pass
if "matplotlib.pyplot" in sys.modules:
try:
plt = sys.modules["matplotlib.pyplot"]
plt.close("all")
except Exception:
# We don't want to crash the app if we can't close matplotlib
pass
def _new_module(name: str) -> types.ModuleType:
"""Create a new module with the given name."""
return types.ModuleType(name)
# This is deliberately not a decorator: an explicit call makes it clear at the
# calling location that the wrapped function's exceptions are logged and swallowed.
def _log_if_error(fn: Callable[[], None]) -> None:
try:
fn()
except Exception as e:
_LOGGER.warning(e)
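# --- Editor's note: hedged illustration, not part of the original file. ---
# A minimal, self-contained sketch of the compile/exec technique `_run_script`
# uses: compile the script body in "exec" mode with its (here made-up) path so
# tracebacks are readable, create a fresh module to stand in for `__main__`, and
# execute the code in that module's namespace. Installing the module into
# sys.modules is deliberately omitted in this sketch.
if __name__ == "__main__":
    demo_path = "<scriptrunner-demo>"
    demo_body = 'print("script says:", __name__, __file__)\n'
    demo_code = compile(demo_body, demo_path, mode="exec", flags=0, dont_inherit=1, optimize=-1)
    demo_module = _new_module("__main__")
    demo_module.__dict__["__file__"] = demo_path
    exec(demo_code, demo_module.__dict__)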

View File

@@ -0,0 +1,352 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import threading
from copy import deepcopy
from typing import (
Any,
Dict,
ItemsView,
Iterator,
KeysView,
List,
Mapping,
NoReturn,
Optional,
ValuesView,
)
import toml
from blinker import Signal
from typing_extensions import Final
import streamlit as st
import streamlit.watcher.path_watcher
from streamlit import file_util, runtime
from streamlit.logger import get_logger
_LOGGER = get_logger(__name__)
SECRETS_FILE_LOCS: Final[List[str]] = [
file_util.get_streamlit_file_path("secrets.toml"),
# NOTE: The order here is important! Project-level secrets should overwrite global
# secrets.
file_util.get_project_streamlit_file_path("secrets.toml"),
]
def _missing_attr_error_message(attr_name: str) -> str:
return (
f'st.secrets has no attribute "{attr_name}". '
f"Did you forget to add it to secrets.toml or the app settings on Streamlit Cloud? "
f"More info: https://docs.streamlit.io/streamlit-cloud/get-started/deploy-an-app/connect-to-data-sources/secrets-management"
)
def _missing_key_error_message(key: str) -> str:
return (
f'st.secrets has no key "{key}". '
f"Did you forget to add it to secrets.toml or the app settings on Streamlit Cloud? "
f"More info: https://docs.streamlit.io/streamlit-cloud/get-started/deploy-an-app/connect-to-data-sources/secrets-management"
)
class AttrDict(Mapping[str, Any]):
"""
We use AttrDict to wrap up dictionary values from secrets
to provide dot access to nested secrets
"""
def __init__(self, value):
self.__dict__["__nested_secrets__"] = dict(value)
@staticmethod
def _maybe_wrap_in_attr_dict(value) -> Any:
if not isinstance(value, Mapping):
return value
else:
return AttrDict(value)
def __len__(self) -> int:
return len(self.__nested_secrets__)
def __iter__(self) -> Iterator[str]:
return iter(self.__nested_secrets__)
def __getitem__(self, key: str) -> Any:
try:
value = self.__nested_secrets__[key]
return self._maybe_wrap_in_attr_dict(value)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def __getattr__(self, attr_name: str) -> Any:
try:
value = self.__nested_secrets__[attr_name]
return self._maybe_wrap_in_attr_dict(value)
except KeyError:
raise AttributeError(_missing_attr_error_message(attr_name))
def __repr__(self):
return repr(self.__nested_secrets__)
def __setitem__(self, key, value) -> NoReturn:
raise TypeError("Secrets does not support item assignment.")
def __setattr__(self, key, value) -> NoReturn:
raise TypeError("Secrets does not support attribute assignment.")
def to_dict(self) -> Dict[str, Any]:
return deepcopy(self.__nested_secrets__)
class Secrets(Mapping[str, Any]):
"""A dict-like class that stores secrets.
Parses secrets.toml on-demand. Cannot be externally mutated.
Safe to use from multiple threads.
"""
def __init__(self, file_paths: List[str]):
# Our secrets dict.
self._secrets: Optional[Mapping[str, Any]] = None
self._lock = threading.RLock()
self._file_watchers_installed = False
self._file_paths = file_paths
self.file_change_listener = Signal(
doc="Emitted when a `secrets.toml` file has been changed."
)
    def load_if_toml_exists(self) -> bool:
        """Load secrets.toml files from disk if they exist. If none exist,
no exception will be raised. (If a file exists but is malformed,
an exception *will* be raised.)
Returns True if a secrets.toml file was successfully parsed, False otherwise.
Thread-safe.
"""
try:
self._parse(print_exceptions=False)
return True
except FileNotFoundError:
# No secrets.toml files exist. That's fine.
return False
def _reset(self) -> None:
"""Clear the secrets dictionary and remove any secrets that were
added to os.environ.
Thread-safe.
"""
with self._lock:
if self._secrets is None:
return
for k, v in self._secrets.items():
self._maybe_delete_environment_variable(k, v)
self._secrets = None
def _parse(self, print_exceptions: bool) -> Mapping[str, Any]:
"""Parse our secrets.toml files if they're not already parsed.
This function is safe to call from multiple threads.
Parameters
----------
print_exceptions : bool
If True, then exceptions will be printed with `st.error` before
being re-raised.
Raises
------
FileNotFoundError
Raised if secrets.toml doesn't exist.
"""
# Avoid taking a lock for the common case where secrets are already
# loaded.
secrets = self._secrets
if secrets is not None:
return secrets
with self._lock:
if self._secrets is not None:
return self._secrets
# It's fine for a user to only have one secrets.toml file defined, so
# we ignore individual FileNotFoundErrors when attempting to read files
            # below and only raise an exception if we weren't able to read *any* secrets
# file.
found_secrets_file = False
secrets = {}
for path in self._file_paths:
try:
with open(path, encoding="utf-8") as f:
secrets_file_str = f.read()
found_secrets_file = True
except FileNotFoundError:
continue
try:
secrets.update(toml.loads(secrets_file_str))
except:
if print_exceptions:
st.error(f"Error parsing secrets file at {path}")
raise
if not found_secrets_file:
err_msg = f"No secrets files found. Valid paths for a secrets.toml file are: {', '.join(self._file_paths)}"
if print_exceptions:
st.error(err_msg)
raise FileNotFoundError(err_msg)
if len([p for p in self._file_paths if os.path.exists(p)]) > 1:
                _LOGGER.info(
                    f"Secrets found in multiple locations: {', '.join(self._file_paths)}. "
                    "When multiple secrets.toml files exist, local secrets will take precedence over global secrets."
)
for k, v in secrets.items():
self._maybe_set_environment_variable(k, v)
self._secrets = secrets
self._maybe_install_file_watchers()
return self._secrets
@staticmethod
def _maybe_set_environment_variable(k: Any, v: Any) -> None:
"""Add the given key/value pair to os.environ if the value
is a string, int, or float.
"""
value_type = type(v)
if value_type in (str, int, float):
os.environ[k] = str(v)
@staticmethod
def _maybe_delete_environment_variable(k: Any, v: Any) -> None:
"""Remove the given key/value pair from os.environ if the value
is a string, int, or float.
"""
value_type = type(v)
if value_type in (str, int, float) and os.environ.get(k) == v:
del os.environ[k]
def _maybe_install_file_watchers(self) -> None:
with self._lock:
if self._file_watchers_installed:
return
for path in self._file_paths:
try:
streamlit.watcher.path_watcher.watch_file(
path,
self._on_secrets_file_changed,
watcher_type="poll",
)
except FileNotFoundError:
# A user may only have one secrets.toml file defined, so we'd expect
# FileNotFoundErrors to be raised when attempting to install a
# watcher on the nonexistent ones.
pass
# We set file_watchers_installed to True even if the installation attempt
# failed to avoid repeatedly trying to install it.
self._file_watchers_installed = True
def _on_secrets_file_changed(self, changed_file_path) -> None:
with self._lock:
_LOGGER.debug("Secrets file %s changed, reloading", changed_file_path)
self._reset()
self._parse(print_exceptions=True)
# Emit a signal to notify receivers that the `secrets.toml` file
# has been changed.
self.file_change_listener.send()
def __getattr__(self, key: str) -> Any:
"""Return the value with the given key. If no such key
exists, raise an AttributeError.
Thread-safe.
"""
try:
value = self._parse(True)[key]
if not isinstance(value, Mapping):
return value
else:
return AttrDict(value)
# We add FileNotFoundError since __getattr__ is expected to only raise
# AttributeError. Without handling FileNotFoundError, unittest mocks
# fail during mock creation on Python 3.9.
except (KeyError, FileNotFoundError):
raise AttributeError(_missing_attr_error_message(key))
def __getitem__(self, key: str) -> Any:
"""Return the value with the given key. If no such key
exists, raise a KeyError.
Thread-safe.
"""
try:
value = self._parse(True)[key]
if not isinstance(value, Mapping):
return value
else:
return AttrDict(value)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def __repr__(self) -> str:
# If the runtime is NOT initialized, it is a method call outside
# the streamlit app, so we avoid reading the secrets file as it may not exist.
# If the runtime is initialized, display the contents of the file and
# the file must already exist.
"""A string representation of the contents of the dict. Thread-safe."""
if not runtime.exists():
return f"{self.__class__.__name__}(file_paths={self._file_paths!r})"
return repr(self._parse(True))
def __len__(self) -> int:
"""The number of entries in the dict. Thread-safe."""
return len(self._parse(True))
def has_key(self, k: str) -> bool:
"""True if the given key is in the dict. Thread-safe."""
return k in self._parse(True)
def keys(self) -> KeysView[str]:
"""A view of the keys in the dict. Thread-safe."""
return self._parse(True).keys()
def values(self) -> ValuesView[Any]:
"""A view of the values in the dict. Thread-safe."""
return self._parse(True).values()
def items(self) -> ItemsView[str, Any]:
"""A view of the key-value items in the dict. Thread-safe."""
return self._parse(True).items()
def __contains__(self, key: Any) -> bool:
"""True if the given key is in the dict. Thread-safe."""
return key in self._parse(True)
def __iter__(self) -> Iterator[str]:
"""An iterator over the keys in the dict. Thread-safe."""
return iter(self._parse(True))
secrets_singleton: Final = Secrets(SECRETS_FILE_LOCS)
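# A minimal usage sketch (illustrative only; assumes write access to a temp
# dir): exercising the Secrets mapping against a throwaway secrets.toml. In an
# app, the same lookups happen through the public st.secrets object, which
# wraps the singleton above.
import tempfile

_tmp_dir = tempfile.mkdtemp()
_tmp_path = os.path.join(_tmp_dir, "secrets.toml")
with open(_tmp_path, "w", encoding="utf-8") as _f:
    _f.write('db_username = "jane"\n\n[postgres]\nhost = "localhost"\n')

_demo = Secrets([_tmp_path])
assert _demo["db_username"] == "jane"        # plain str/int/float values...
assert os.environ["db_username"] == "jane"   # ...are also mirrored into os.environ
assert _demo.postgres.host == "localhost"    # TOML tables come back as AttrDicts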

View File

@@ -0,0 +1,381 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, cast
from typing_extensions import Protocol
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
class SessionClientDisconnectedError(Exception):
"""Raised by operations on a disconnected SessionClient."""
class SessionClient(Protocol):
"""Interface for sending data to a session's client."""
@abstractmethod
def write_forward_msg(self, msg: ForwardMsg) -> None:
"""Deliver a ForwardMsg to the client.
If the SessionClient has been disconnected, it should raise a
SessionClientDisconnectedError.
"""
raise NotImplementedError
@dataclass
class ActiveSessionInfo:
"""Type containing data related to an active session.
This type is nearly identical to SessionInfo. The difference is that when using it,
we are guaranteed that SessionClient is not None.
"""
client: SessionClient
session: AppSession
script_run_count: int = 0
@dataclass
class SessionInfo:
"""Type containing data related to an AppSession.
For each AppSession, the Runtime tracks that session's
script_run_count. This is used to track the age of messages in
the ForwardMsgCache.
"""
client: Optional[SessionClient]
session: AppSession
script_run_count: int = 0
def is_active(self) -> bool:
return self.client is not None
def to_active(self) -> ActiveSessionInfo:
assert self.is_active(), "A SessionInfo with no client cannot be active!"
# NOTE: The cast here (rather than copying this SessionInfo's fields into a new
# ActiveSessionInfo) is important as the Runtime expects to be able to mutate
# what's returned from get_active_session_info to increment script_run_count.
return cast(ActiveSessionInfo, self)
class SessionStorageError(Exception):
"""Exception class for errors raised by SessionStorage.
The original error that causes a SessionStorageError to be (re)raised will generally
be an I/O error specific to the concrete SessionStorage implementation.
"""
class SessionStorage(Protocol):
@abstractmethod
def get(self, session_id: str) -> Optional[SessionInfo]:
"""Return the SessionInfo corresponding to session_id, or None if one does not
exist.
Parameters
----------
session_id
The unique ID of the session being fetched.
Returns
-------
SessionInfo or None
Raises
------
SessionStorageError
Raised if an error occurs while attempting to fetch the session. This will
generally happen if there is an error with the underlying storage backend
(e.g. if we lose our connection to it).
"""
raise NotImplementedError
@abstractmethod
def save(self, session_info: SessionInfo) -> None:
"""Save the given session.
Parameters
----------
session_info
The SessionInfo being saved.
Raises
------
SessionStorageError
Raised if an error occurs while saving the given session.
"""
raise NotImplementedError
@abstractmethod
def delete(self, session_id: str) -> None:
"""Mark the session corresponding to session_id for deletion and stop tracking
it.
Note that:
* Calling delete on an ID corresponding to a nonexistent session is a no-op.
* Calling delete on an ID should cause the given session to no longer be
tracked by this SessionStorage, but exactly when and how the session's data
is eventually cleaned up is a detail left up to the implementation.
Parameters
----------
session_id
The unique ID of the session to delete.
Raises
------
SessionStorageError
Raised if an error occurs while attempting to delete the session.
"""
raise NotImplementedError
@abstractmethod
def list(self) -> List[SessionInfo]:
"""List all sessions tracked by this SessionStorage.
Returns
-------
List[SessionInfo]
Raises
------
SessionStorageError
Raised if an error occurs while attempting to list sessions.
"""
raise NotImplementedError
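# A minimal sketch (illustrative only) of a SessionStorage implementation: a
# plain in-memory dict keyed by session id. A production implementation would
# typically wrap a persistent or evicting backend and translate its I/O errors
# into SessionStorageError.
class InMemorySessionStorage(SessionStorage):
    def __init__(self) -> None:
        self._cache: Dict[str, SessionInfo] = {}

    def get(self, session_id: str) -> Optional[SessionInfo]:
        return self._cache.get(session_id)

    def save(self, session_info: SessionInfo) -> None:
        # AppSession exposes a unique .id, which we use as the storage key.
        self._cache[session_info.session.id] = session_info

    def delete(self, session_id: str) -> None:
        # Deleting an unknown id is a no-op, per the protocol contract.
        self._cache.pop(session_id, None)

    def list(self) -> List[SessionInfo]:
        return list(self._cache.values())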
class SessionManager(Protocol):
"""SessionManagers are responsible for encapsulating all session lifecycle behavior
that the Streamlit Runtime may care about.
A SessionManager must define the following required methods:
* __init__
* connect_session
* close_session
* get_session_info
* list_sessions
SessionManager implementations may also choose to define the notions of active and
inactive sessions. The precise definitions of active/inactive are left to the
concrete implementation. SessionManagers that wish to differentiate between active
and inactive sessions should have the required methods listed above operate on *all*
sessions. Additionally, they should define the following methods for working with
active sessions:
* disconnect_session
* get_active_session_info
* is_active_session
* list_active_sessions
When the active-session methods are left undefined, their default
implementations simply delegate to the naturally corresponding required methods.
The Runtime, unless there's a good reason to do otherwise, should generally work
with the active-session versions of a SessionManager's methods. There isn't currently
a need for us to be able to operate on inactive sessions stored in SessionStorage
outside of the SessionManager itself. However, it's highly likely that we'll
eventually have to do so, which is why the abstractions allow for this now.
Notes
-----
Threading: All SessionManager methods are *not* threadsafe -- they must be called
from the runtime's eventloop thread.
"""
@abstractmethod
def __init__(
self,
session_storage: SessionStorage,
uploaded_file_manager: UploadedFileManager,
message_enqueued_callback: Optional[Callable[[], None]],
) -> None:
"""Initialize a SessionManager with the given SessionStorage.
Parameters
----------
session_storage
The SessionStorage instance backing this SessionManager.
uploaded_file_manager
Used to manage files uploaded by users via the Streamlit web client.
message_enqueued_callback
A callback invoked after a message is enqueued to be sent to a web client.
"""
raise NotImplementedError
@abstractmethod
def connect_session(
self,
client: SessionClient,
script_data: ScriptData,
user_info: Dict[str, Optional[str]],
existing_session_id: Optional[str] = None,
) -> str:
"""Create a new session or connect to an existing one.
Parameters
----------
client
A concrete SessionClient implementation for communicating with
the session's client.
script_data
Contains parameters related to running a script.
user_info
A dict that contains information about the session's user. For now,
it only (optionally) contains the user's email address.
{
"email": "example@example.com"
}
existing_session_id
The ID of an existing session to reconnect to. If one is not provided, a new
session is created. Note that whether a SessionManager supports reconnection
to an existing session is left up to the concrete SessionManager
implementation. Those that do not support reconnection should simply ignore
this argument.
Returns
-------
str
The session's unique string ID.
"""
raise NotImplementedError
@abstractmethod
def close_session(self, session_id: str) -> None:
"""Close and completely delete the session with the given id.
This function may be called multiple times for the same session,
which is not an error. (Subsequent calls just no-op.)
Parameters
----------
session_id
The session's unique ID.
"""
raise NotImplementedError
@abstractmethod
def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
"""Return the SessionInfo for the given id, or None if no such session
exists.
Parameters
----------
session_id
The session's unique ID.
Returns
-------
SessionInfo or None
"""
raise NotImplementedError
@abstractmethod
def list_sessions(self) -> List[SessionInfo]:
"""Return the SessionInfo for all sessions managed by this SessionManager.
Returns
-------
List[SessionInfo]
"""
raise NotImplementedError
def num_sessions(self) -> int:
"""Return the number of sessions tracked by this SessionManager.
Subclasses of SessionManager shouldn't provide their own implementation of this
method without a *very* good reason.
Returns
-------
int
"""
return len(self.list_sessions())
# NOTE: The following methods only need to be overwritten when a concrete
# SessionManager implementation has a notion of active vs inactive sessions.
# If left unimplemented in a subclass, the default implementations of these methods
# call corresponding SessionManager methods in a natural way.
def disconnect_session(self, session_id: str) -> None:
"""Disconnect the given session.
This method should be idempotent.
Parameters
----------
session_id
The session's unique ID.
"""
self.close_session(session_id)
def get_active_session_info(self, session_id: str) -> Optional[ActiveSessionInfo]:
"""Return the ActiveSessionInfo for the given id, or None if either no such
session exists or the session is not active.
Parameters
----------
session_id
The active session's unique ID.
Returns
-------
ActiveSessionInfo or None
"""
session = self.get_session_info(session_id)
if session is None or not session.is_active():
return None
return session.to_active()
def is_active_session(self, session_id: str) -> bool:
"""Return True if the given session exists and is active, False otherwise.
Returns
-------
bool
"""
return self.get_active_session_info(session_id) is not None
def list_active_sessions(self) -> List[ActiveSessionInfo]:
"""Return the session info for all active sessions tracked by this SessionManager.
Returns
-------
List[ActiveSessionInfo]
"""
return [s.to_active() for s in self.list_sessions()]
def num_active_sessions(self) -> int:
"""Return the number of active sessions tracked by this SessionManager.
Subclasses of SessionManager shouldn't provide their own implementation of this
method without a *very* good reason.
Returns
-------
int
"""
return len(self.list_active_sessions())
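# A small sketch (illustrative only) of the active/inactive distinction the
# protocol above builds on: a SessionInfo is "active" exactly when it still has
# a client, and the default helpers (is_active_session, list_active_sessions,
# ...) reduce to that check plus the required methods. The session object below
# is a stand-in, since SessionInfo's fields aren't type-checked at runtime.
from types import SimpleNamespace

class _NoopClient(SessionClient):
    def write_forward_msg(self, msg: ForwardMsg) -> None:
        pass  # a real client would forward msg to the browser

_fake_session = SimpleNamespace(id="sess-1")

_disconnected = SessionInfo(client=None, session=_fake_session)
assert not _disconnected.is_active()

_connected = SessionInfo(client=_NoopClient(), session=_fake_session)
assert _connected.is_active()
assert _connected.to_active() is _connected  # same object, narrowed type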

View File

@@ -0,0 +1,40 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.state.common import WidgetArgs as WidgetArgs
from streamlit.runtime.state.common import WidgetCallback as WidgetCallback
from streamlit.runtime.state.common import WidgetKwargs as WidgetKwargs
# Explicitly re-export public symbols
from streamlit.runtime.state.safe_session_state import (
SafeSessionState as SafeSessionState,
)
from streamlit.runtime.state.session_state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY as SCRIPT_RUN_WITHOUT_ERRORS_KEY,
)
from streamlit.runtime.state.session_state import SessionState as SessionState
from streamlit.runtime.state.session_state import (
SessionStateStatProvider as SessionStateStatProvider,
)
from streamlit.runtime.state.session_state_proxy import (
SessionStateProxy as SessionStateProxy,
)
from streamlit.runtime.state.session_state_proxy import (
get_session_state as get_session_state,
)
from streamlit.runtime.state.widgets import NoValue as NoValue
from streamlit.runtime.state.widgets import (
coalesce_widget_states as coalesce_widget_states,
)
from streamlit.runtime.state.widgets import register_widget as register_widget

View File

@@ -0,0 +1,183 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions and data structures shared by session_state.py and widgets.py"""
from __future__ import annotations
import hashlib
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union
from typing_extensions import Final, TypeAlias
from streamlit.errors import StreamlitAPIException
from streamlit.proto.Arrow_pb2 import Arrow
from streamlit.proto.Button_pb2 import Button
from streamlit.proto.CameraInput_pb2 import CameraInput
from streamlit.proto.Checkbox_pb2 import Checkbox
from streamlit.proto.ColorPicker_pb2 import ColorPicker
from streamlit.proto.Components_pb2 import ComponentInstance
from streamlit.proto.DateInput_pb2 import DateInput
from streamlit.proto.DownloadButton_pb2 import DownloadButton
from streamlit.proto.FileUploader_pb2 import FileUploader
from streamlit.proto.MultiSelect_pb2 import MultiSelect
from streamlit.proto.NumberInput_pb2 import NumberInput
from streamlit.proto.Radio_pb2 import Radio
from streamlit.proto.Selectbox_pb2 import Selectbox
from streamlit.proto.Slider_pb2 import Slider
from streamlit.proto.TextArea_pb2 import TextArea
from streamlit.proto.TextInput_pb2 import TextInput
from streamlit.proto.TimeInput_pb2 import TimeInput
from streamlit.type_util import ValueFieldName
# Protobuf types for all widgets.
WidgetProto: TypeAlias = Union[
Arrow,
Button,
CameraInput,
Checkbox,
ColorPicker,
ComponentInstance,
DateInput,
DownloadButton,
FileUploader,
MultiSelect,
NumberInput,
Radio,
Selectbox,
Slider,
TextArea,
TextInput,
TimeInput,
]
GENERATED_WIDGET_ID_PREFIX: Final = "$$GENERATED_WIDGET_ID"
T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)
WidgetArgs: TypeAlias = Tuple[Any, ...]
WidgetKwargs: TypeAlias = Dict[str, Any]
WidgetCallback: TypeAlias = Callable[..., None]
# A deserializer receives the value from whatever field is set on the
# WidgetState proto, and returns a regular python value. A serializer
# receives a regular python value, and returns something suitable for
# a value field on WidgetState proto. They should be inverses.
WidgetDeserializer: TypeAlias = Callable[[Any, str], T]
WidgetSerializer: TypeAlias = Callable[[T], Any]
@dataclass(frozen=True)
class WidgetMetadata(Generic[T]):
"""Metadata associated with a single widget. Immutable."""
id: str
deserializer: WidgetDeserializer[T] = field(repr=False)
serializer: WidgetSerializer[T] = field(repr=False)
value_type: ValueFieldName
# An optional user-code callback invoked when the widget's value changes.
# Widget callbacks are called at the start of a script run, before the
# body of the script is executed.
callback: WidgetCallback | None = None
callback_args: WidgetArgs | None = None
callback_kwargs: WidgetKwargs | None = None
@dataclass(frozen=True)
class RegisterWidgetResult(Generic[T_co]):
"""Result returned by the `register_widget` family of functions/methods.
Should be usable by widget code to determine what value to return, and
whether to update the UI.
Parameters
----------
value : T_co
The widget's current value, or, in cases where the true widget value
could not be determined, an appropriate fallback value.
This value should be returned by the widget call.
value_changed : bool
True if the widget's value is different from the value most recently
returned from the frontend.
Implies an update to the frontend is needed.
"""
value: T_co
value_changed: bool
@classmethod
def failure(
cls, deserializer: WidgetDeserializer[T_co]
) -> "RegisterWidgetResult[T_co]":
"""The canonical way to construct a RegisterWidgetResult in cases
where the true widget value could not be determined.
"""
return cls(value=deserializer(None, ""), value_changed=False)
def compute_widget_id(
element_type: str, element_proto: WidgetProto, user_key: Optional[str] = None
) -> str:
"""Compute the widget id for the given widget. This id is stable: a given
set of inputs to this function will always produce the same widget id output.
The widget id includes the user_key so widgets with identical arguments can
use it to be distinct.
The widget id includes an easily identified prefix, and the user_key as a
suffix, to make it easy to identify it and know if a key maps to it.
Does not mutate the element_proto object.
"""
h = hashlib.new("md5")
h.update(element_type.encode("utf-8"))
h.update(element_proto.SerializeToString())
return f"{GENERATED_WIDGET_ID_PREFIX}-{h.hexdigest()}-{user_key}"
def user_key_from_widget_id(widget_id: str) -> Optional[str]:
"""Return the user key portion of a widget id, or None if the id does not
have a user key.
TODO This will incorrectly indicate no user key if the user actually provides
"None" as a key, but we can't avoid this kind of problem while storing the
string representation of the no-user-key sentinel as part of the widget id.
"""
user_key = widget_id.split("-", maxsplit=2)[-1]
user_key = None if user_key == "None" else user_key
return user_key
def is_widget_id(key: str) -> bool:
"""True if the given session_state key has the structure of a widget ID."""
return key.startswith(GENERATED_WIDGET_ID_PREFIX)
def is_keyed_widget_id(key: str) -> bool:
"""True if the given session_state key has the structure of a widget ID with a user_key."""
return is_widget_id(key) and not key.endswith("-None")
def require_valid_user_key(key: str) -> None:
"""Raise an Exception if the given user_key is invalid."""
if is_widget_id(key):
raise StreamlitAPIException(
f"Keys beginning with {GENERATED_WIDGET_ID_PREFIX} are reserved."
)
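# A minimal sketch (illustrative only) of the helpers above, round-tripping a
# user key through a widget id computed from a Checkbox proto.
from streamlit.proto.Checkbox_pb2 import Checkbox

_proto = Checkbox()
_proto.label = "I agree"
_wid = compute_widget_id("checkbox", _proto, user_key="agree")

assert _wid.startswith(GENERATED_WIDGET_ID_PREFIX)
assert is_widget_id(_wid) and is_keyed_widget_id(_wid)
assert user_key_from_widget_id(_wid) == "agree"

# Without a user key the id ends in "-None", which is why is_keyed_widget_id
# and user_key_from_widget_id treat that suffix as "no key".
assert not is_keyed_widget_id(compute_widget_id("checkbox", _proto))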

View File

@@ -0,0 +1,134 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import Any, Dict, List, Optional, Set
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.state.common import RegisterWidgetResult, T, WidgetMetadata
from streamlit.runtime.state.session_state import SessionState
class SafeSessionState:
"""Thread-safe wrapper around SessionState.
When AppSession gets a re-run request, it can interrupt its existing
ScriptRunner and spin up a new ScriptRunner to handle the request.
When this happens, the existing ScriptRunner will continue executing
its script until it reaches a yield point - but during this time, it
must not mutate its SessionState. An interrupted ScriptRunner assigns
a dummy SessionState instance to its wrapper to prevent further mutation.
"""
def __init__(self, state: SessionState):
self._state = state
# TODO: we'd prefer this be a threading.Lock instead of RLock -
# but `call_callbacks` first needs to be rewritten.
self._lock = threading.RLock()
self._disconnected = False
def disconnect(self) -> None:
"""Disconnect the wrapper from its underlying SessionState.
ScriptRunner calls this when it gets a stop request. After this
function is called, all future SessionState interactions are no-ops.
"""
with self._lock:
self._disconnected = True
def register_widget(
self, metadata: WidgetMetadata[T], user_key: Optional[str]
) -> RegisterWidgetResult[T]:
with self._lock:
if self._disconnected:
return RegisterWidgetResult.failure(metadata.deserializer)
return self._state.register_widget(metadata, user_key)
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
with self._lock:
if self._disconnected:
return
# TODO: rewrite this to copy the callbacks list into a local
# variable so that we don't need to hold our lock for the
# duration. (This will also allow us to downgrade our RLock
# to a Lock.)
self._state.on_script_will_rerun(latest_widget_states)
def on_script_finished(self, widget_ids_this_run: Set[str]) -> None:
with self._lock:
if self._disconnected:
return
self._state.on_script_finished(widget_ids_this_run)
def maybe_check_serializable(self) -> None:
with self._lock:
if self._disconnected:
return
self._state.maybe_check_serializable()
def get_widget_states(self) -> List[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
with self._lock:
if self._disconnected:
return []
return self._state.get_widget_states()
def is_new_state_value(self, user_key: str) -> bool:
with self._lock:
if self._disconnected:
return False
return self._state.is_new_state_value(user_key)
@property
def filtered_state(self) -> Dict[str, Any]:
"""The combined session and widget state, excluding keyless widgets."""
with self._lock:
if self._disconnected:
return {}
return self._state.filtered_state
def __getitem__(self, key: str) -> Any:
with self._lock:
if self._disconnected:
raise KeyError(key)
return self._state[key]
def __setitem__(self, key: str, value: Any) -> None:
with self._lock:
if self._disconnected:
return
self._state[key] = value
def __delitem__(self, key: str) -> None:
with self._lock:
if self._disconnected:
raise KeyError(key)
del self._state[key]
def __contains__(self, key: str) -> bool:
with self._lock:
if self._disconnected:
return False
return key in self._state
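# A minimal sketch (illustrative only) of the disconnect semantics described
# above: after disconnect(), writes are silently dropped and reads behave as if
# the state were empty, so an interrupted ScriptRunner can't corrupt anything.
_safe = SafeSessionState(SessionState())
_safe["color"] = "blue"
assert _safe["color"] == "blue"

_safe.disconnect()
_safe["color"] = "red"            # silently dropped
assert "color" not in _safe       # reads now see an empty state
assert _safe.get_widget_states() == []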

View File

@@ -0,0 +1,648 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import json
import pickle
from copy import deepcopy
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Any,
Iterator,
KeysView,
List,
MutableMapping,
Union,
cast,
)
from pympler.asizeof import asizeof
from typing_extensions import Final, TypeAlias
import streamlit as st
from streamlit import config
from streamlit.errors import StreamlitAPIException, UnserializableSessionStateError
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
from streamlit.runtime.state.common import (
RegisterWidgetResult,
T,
WidgetMetadata,
is_keyed_widget_id,
is_widget_id,
)
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
from streamlit.type_util import ValueFieldName, is_array_value_field_name
if TYPE_CHECKING:
from streamlit.runtime.session_manager import SessionManager
STREAMLIT_INTERNAL_KEY_PREFIX: Final = "$$STREAMLIT_INTERNAL_KEY"
SCRIPT_RUN_WITHOUT_ERRORS_KEY: Final = (
f"{STREAMLIT_INTERNAL_KEY_PREFIX}_SCRIPT_RUN_WITHOUT_ERRORS"
)
@dataclass(frozen=True)
class Serialized:
"""A widget value that's serialized to a protobuf. Immutable."""
value: WidgetStateProto
@dataclass(frozen=True)
class Value:
"""A widget value that's not serialized. Immutable."""
value: Any
WState: TypeAlias = Union[Value, Serialized]
@dataclass
class WStates(MutableMapping[str, Any]):
"""A mapping of widget IDs to values. Widget values can be stored in
serialized or deserialized form, but when values are retrieved from the
mapping, they'll always be deserialized.
"""
states: dict[str, WState] = field(default_factory=dict)
widget_metadata: dict[str, WidgetMetadata[Any]] = field(default_factory=dict)
def __getitem__(self, k: str) -> Any:
"""Return the value of the widget with the given key.
If the widget's value is currently stored in serialized form, it
will be deserialized first.
"""
wstate = self.states.get(k)
if wstate is None:
raise KeyError(k)
if isinstance(wstate, Value):
# The widget's value is already deserialized - return it directly.
return wstate.value
# The widget's value is serialized. We deserialize it, and return
# the deserialized value.
metadata = self.widget_metadata.get(k)
if metadata is None:
# No deserializer, which should only happen if the state came
# from a reconnecting browser and the script is trying to access
# it. Pretend it doesn't exist.
raise KeyError(k)
value_field_name = cast(
ValueFieldName,
wstate.value.WhichOneof("value"),
)
value = wstate.value.__getattribute__(value_field_name)
if is_array_value_field_name(value_field_name):
# Array types are messages with data in a `data` field
value = value.data
elif value_field_name == "json_value":
value = json.loads(value)
deserialized = metadata.deserializer(value, metadata.id)
# Update metadata to reflect information from WidgetState proto
self.set_widget_metadata(
replace(
metadata,
value_type=value_field_name,
)
)
self.states[k] = Value(deserialized)
return deserialized
def __setitem__(self, k: str, v: WState) -> None:
self.states[k] = v
def __delitem__(self, k: str) -> None:
del self.states[k]
def __len__(self) -> int:
return len(self.states)
def __iter__(self):
# For this and many other methods, we can't simply delegate to the
# states field, because we need to invoke `__getitem__` for any
# values, to handle deserialization and unwrapping of values.
for key in self.states:
yield key
def keys(self) -> KeysView[str]:
return KeysView(self.states)
def items(self) -> set[tuple[str, Any]]: # type: ignore[override]
return {(k, self[k]) for k in self}
def values(self) -> set[Any]: # type: ignore[override]
return {self[wid] for wid in self}
def update(self, other: "WStates") -> None: # type: ignore[override]
"""Copy all widget values and metadata from 'other' into this mapping,
overwriting any data in this mapping that's also present in 'other'.
"""
self.states.update(other.states)
self.widget_metadata.update(other.widget_metadata)
def set_widget_from_proto(self, widget_state: WidgetStateProto) -> None:
"""Set a widget's serialized value, overwriting any existing value it has."""
self[widget_state.id] = Serialized(widget_state)
def set_from_value(self, k: str, v: Any) -> None:
"""Set a widget's deserialized value, overwriting any existing value it has."""
self[k] = Value(v)
def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
"""Set a widget's metadata, overwriting any existing metadata it has."""
self.widget_metadata[widget_meta.id] = widget_meta
def remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
self.states = {k: v for k, v in self.states.items() if k in active_widget_ids}
def get_serialized(self, k: str) -> WidgetStateProto | None:
"""Get the serialized value of the widget with the given id.
If the widget doesn't exist, return None. If the widget exists but
is not in serialized form, it will be serialized first.
"""
item = self.states.get(k)
if item is None:
# No such widget: return None.
return None
if isinstance(item, Serialized):
# Widget value is serialized: return it directly.
return item.value
# Widget value is not serialized: serialize it first!
metadata = self.widget_metadata.get(k)
if metadata is None:
# We're missing the widget's metadata. (Can this happen?)
return None
widget = WidgetStateProto()
widget.id = k
field = metadata.value_type
serialized = metadata.serializer(item.value)
if is_array_value_field_name(field):
arr = getattr(widget, field)
arr.data.extend(serialized)
elif field == "json_value":
setattr(widget, field, json.dumps(serialized))
elif field == "file_uploader_state_value":
widget.file_uploader_state_value.CopyFrom(serialized)
else:
setattr(widget, field, serialized)
return widget
def as_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
states = [
self.get_serialized(widget_id)
for widget_id in self.states.keys()
if self.get_serialized(widget_id)
]
states = cast(List[WidgetStateProto], states)
return states
def call_callback(self, widget_id: str) -> None:
"""Call the given widget's callback and return the callback's
return value. If the widget has no callback, return None.
If the widget doesn't exist, raise an Exception.
"""
metadata = self.widget_metadata.get(widget_id)
assert metadata is not None
callback = metadata.callback
if callback is None:
return
args = metadata.callback_args or ()
kwargs = metadata.callback_kwargs or {}
callback(*args, **kwargs)
def _missing_key_error_message(key: str) -> str:
return (
f'st.session_state has no key "{key}". Did you forget to initialize it? '
f"More info: https://docs.streamlit.io/library/advanced-features/session-state#initialization"
)
@dataclass
class SessionState:
"""SessionState allows users to store values that persist between app
reruns.
Example
-------
>>> if "num_script_runs" not in st.session_state:
... st.session_state.num_script_runs = 0
>>> st.session_state.num_script_runs += 1
>>> st.write(st.session_state.num_script_runs) # writes 1
The next time your script runs, the value of
st.session_state.num_script_runs will be preserved.
>>> st.session_state.num_script_runs += 1
>>> st.write(st.session_state.num_script_runs) # writes 2
"""
# All the values from previous script runs, squished together to save memory
_old_state: dict[str, Any] = field(default_factory=dict)
# Values set in session state during the current script run, possibly for
# setting a widget's value. Keyed by a user provided string.
_new_session_state: dict[str, Any] = field(default_factory=dict)
# Widget values from the frontend; usually a change to one of them prompted the script rerun
_new_widget_state: WStates = field(default_factory=WStates)
# Keys used for widgets will be eagerly converted to the matching widget id
_key_id_mapping: dict[str, str] = field(default_factory=dict)
# is it possible for a value to get through this without being deserialized?
def _compact_state(self) -> None:
"""Copy all current session_state and widget_state values into our
_old_state dict, and then clear our current session_state and
widget_state.
"""
for key_or_wid in self:
self._old_state[key_or_wid] = self[key_or_wid]
self._new_session_state.clear()
self._new_widget_state.clear()
def clear(self) -> None:
"""Reset self completely, clearing all current and old values."""
self._old_state.clear()
self._new_session_state.clear()
self._new_widget_state.clear()
self._key_id_mapping.clear()
@property
def filtered_state(self) -> dict[str, Any]:
"""The combined session and widget state, excluding keyless widgets."""
wid_key_map = self._reverse_key_wid_map
state: dict[str, Any] = {}
# We can't write `for k, v in self.items()` here because doing so will
# run into a `KeyError` if widget metadata has been cleared (which
# happens when the Streamlit server restarts or the cache is cleared),
# and we then receive a widget's state from a browser.
for k in self._keys():
if not is_widget_id(k) and not _is_internal_key(k):
state[k] = self[k]
elif is_keyed_widget_id(k):
try:
key = wid_key_map[k]
state[key] = self[k]
except KeyError:
# The widget id no longer maps to a key; it's a not-yet-cleared
# value in old state for a widget that has been reset.
pass
return state
@property
def _reverse_key_wid_map(self) -> dict[str, str]:
"""Return a mapping of widget_id : widget_key."""
wid_key_map = {v: k for k, v in self._key_id_mapping.items()}
return wid_key_map
def _keys(self) -> set[str]:
"""All keys active in Session State, with widget keys converted
to widget ids when one is known. (This includes autogenerated keys
for widgets that don't have user_keys defined, and which aren't
exposed to user code.)
"""
old_keys = {self._get_widget_id(k) for k in self._old_state.keys()}
new_widget_keys = set(self._new_widget_state.keys())
new_session_state_keys = {
self._get_widget_id(k) for k in self._new_session_state.keys()
}
return old_keys | new_widget_keys | new_session_state_keys
def is_new_state_value(self, user_key: str) -> bool:
"""True if a value with the given key is in the current session state."""
return user_key in self._new_session_state
def __iter__(self) -> Iterator[Any]:
"""Return an iterator over the keys of the SessionState.
This is a shortcut for `iter(self.keys())`
"""
return iter(self._keys())
def __len__(self) -> int:
"""Return the number of items in SessionState."""
return len(self._keys())
def __getitem__(self, key: str) -> Any:
wid_key_map = self._reverse_key_wid_map
widget_id = self._get_widget_id(key)
if widget_id in wid_key_map and widget_id == key:
# the "key" is a raw widget id, so get its associated user key for lookup
key = wid_key_map[widget_id]
try:
return self._getitem(widget_id, key)
except KeyError:
raise KeyError(_missing_key_error_message(key))
def _getitem(self, widget_id: str | None, user_key: str | None) -> Any:
"""Get the value of an entry in Session State, using either the
user-provided key or a widget id as appropriate for the internal dict
being accessed.
At least one of the arguments must have a value.
"""
assert user_key is not None or widget_id is not None
if user_key is not None:
try:
return self._new_session_state[user_key]
except KeyError:
pass
if widget_id is not None:
try:
return self._new_widget_state[widget_id]
except KeyError:
pass
# Typically, there won't be both a widget id and an associated state key in
# old state at the same time, so the order we check is arbitrary.
# The exception is if session state is set and then a later run has
# a widget created, so the widget id entry should be newer.
# The opposite case shouldn't happen, because setting the value of a widget
# through session state will result in the next widget state reflecting that
# value.
if widget_id is not None:
try:
return self._old_state[widget_id]
except KeyError:
pass
if user_key is not None:
try:
return self._old_state[user_key]
except KeyError:
pass
# We'll never get here
raise KeyError
def __setitem__(self, user_key: str, value: Any) -> None:
"""Set the value of the session_state entry with the given user_key.
If the key corresponds to a widget or form that's been instantiated
during the current script run, raise a StreamlitAPIException instead.
"""
from streamlit.runtime.scriptrunner import get_script_run_ctx
ctx = get_script_run_ctx()
if ctx is not None:
widget_id = self._key_id_mapping.get(user_key, None)
widget_ids = ctx.widget_ids_this_run
form_ids = ctx.form_ids_this_run
if widget_id in widget_ids or user_key in form_ids:
raise StreamlitAPIException(
f"`st.session_state.{user_key}` cannot be modified after the widget"
f" with key `{user_key}` is instantiated."
)
self._new_session_state[user_key] = value
def __delitem__(self, key: str) -> None:
widget_id = self._get_widget_id(key)
if not (key in self or widget_id in self):
raise KeyError(_missing_key_error_message(key))
if key in self._new_session_state:
del self._new_session_state[key]
if key in self._old_state:
del self._old_state[key]
if key in self._key_id_mapping:
del self._key_id_mapping[key]
if widget_id in self._new_widget_state:
del self._new_widget_state[widget_id]
if widget_id in self._old_state:
del self._old_state[widget_id]
def set_widgets_from_proto(self, widget_states: WidgetStatesProto) -> None:
"""Set the value of all widgets represented in the given WidgetStatesProto."""
for state in widget_states.widgets:
self._new_widget_state.set_widget_from_proto(state)
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
"""Called by ScriptRunner before its script re-runs.
Update widget data and call callbacks on widgets whose value changed
between the previous and current script runs.
"""
# Update ourselves with the new widget_states. The old widget states,
# used to skip callbacks if values haven't changed, are also preserved.
self._compact_state()
self.set_widgets_from_proto(latest_widget_states)
self._call_callbacks()
def _call_callbacks(self) -> None:
"""Call any callback associated with each widget whose value
changed between the previous and current script runs.
"""
from streamlit.runtime.scriptrunner import RerunException
changed_widget_ids = [
wid for wid in self._new_widget_state if self._widget_changed(wid)
]
for wid in changed_widget_ids:
try:
self._new_widget_state.call_callback(wid)
except RerunException:
st.warning(
"Calling st.experimental_rerun() within a callback is a no-op."
)
def _widget_changed(self, widget_id: str) -> bool:
"""True if the given widget's value changed between the previous
script run and the current script run.
"""
new_value = self._new_widget_state.get(widget_id)
old_value = self._old_state.get(widget_id)
changed: bool = new_value != old_value
return changed
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
"""Called by ScriptRunner after its script finishes running.
Updates widgets to prepare for the next script run.
Parameters
----------
widget_ids_this_run: set[str]
The IDs of the widgets that were accessed during the script
run. Any widget state whose ID does *not* appear in this set
is considered "stale" and will be removed.
"""
self._reset_triggers()
self._remove_stale_widgets(widget_ids_this_run)
def _reset_triggers(self) -> None:
"""Set all trigger values in our state dictionary to False."""
for state_id in self._new_widget_state:
metadata = self._new_widget_state.widget_metadata.get(state_id)
if metadata is not None and metadata.value_type == "trigger_value":
self._new_widget_state[state_id] = Value(False)
for state_id in self._old_state:
metadata = self._new_widget_state.widget_metadata.get(state_id)
if metadata is not None and metadata.value_type == "trigger_value":
self._old_state[state_id] = False
def _remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
self._new_widget_state.remove_stale_widgets(active_widget_ids)
# Remove entries from _old_state corresponding to
# widgets not in widget_ids.
self._old_state = {
k: v
for k, v in self._old_state.items()
if (k in active_widget_ids or not is_widget_id(k))
}
def _set_widget_metadata(self, widget_metadata: WidgetMetadata[Any]) -> None:
"""Set a widget's metadata."""
widget_id = widget_metadata.id
self._new_widget_state.widget_metadata[widget_id] = widget_metadata
def get_widget_states(self) -> list[WidgetStateProto]:
"""Return a list of serialized widget values for each widget with a value."""
return self._new_widget_state.as_widget_states()
def _get_widget_id(self, k: str) -> str:
"""Turns a value that might be a widget id or a user provided key into
an appropriate widget id.
"""
return self._key_id_mapping.get(k, k)
def _set_key_widget_mapping(self, widget_id: str, user_key: str) -> None:
self._key_id_mapping[user_key] = widget_id
def register_widget(
self, metadata: WidgetMetadata[T], user_key: str | None
) -> RegisterWidgetResult[T]:
"""Register a widget with the SessionState.
Returns
-------
RegisterWidgetResult[T]
Contains the widget's current value, and a bool that will be True
if the frontend needs to be updated with the current value.
"""
widget_id = metadata.id
self._set_widget_metadata(metadata)
if user_key is not None:
# If the widget has a user_key, update its user_key:widget_id mapping
self._set_key_widget_mapping(widget_id, user_key)
if widget_id not in self and (user_key is None or user_key not in self):
# This is the first time the widget is registered, so we save its
# value in widget state.
deserializer = metadata.deserializer
initial_widget_value = deepcopy(deserializer(None, metadata.id))
self._new_widget_state.set_from_value(widget_id, initial_widget_value)
# Get the current value of the widget for use as its return value.
# We return a copy, so that reference types can't be accidentally
# mutated by user code.
widget_value = cast(T, self[widget_id])
widget_value = deepcopy(widget_value)
# widget_value_changed indicates to the caller that the widget's
# current value is different from what is in the frontend.
widget_value_changed = user_key is not None and self.is_new_state_value(
user_key
)
return RegisterWidgetResult(widget_value, widget_value_changed)
def __contains__(self, key: str) -> bool:
try:
self[key]
except KeyError:
return False
else:
return True
def get_stats(self) -> list[CacheStat]:
stat = CacheStat("st_session_state", "", asizeof(self))
return [stat]
def _check_serializable(self) -> None:
"""Verify that everything added to session state can be serialized.
We use pickleability as the metric for serializability, and test for
pickleability by just trying it.
"""
for k in self:
try:
pickle.dumps(self[k])
except Exception as e:
err_msg = f"""Cannot serialize the value (of type `{type(self[k])}`) of '{k}' in st.session_state.
Streamlit has been configured to use [pickle](https://docs.python.org/3/library/pickle.html) to
serialize session_state values. Please convert the value to a pickle-serializable type. To learn
more about this behavior, see [our docs](https://docs.streamlit.io/knowledge-base/using-streamlit/serializable-session-state). """
raise UnserializableSessionStateError(err_msg) from e
def maybe_check_serializable(self) -> None:
"""Verify that session state can be serialized, if the relevant config
option is set.
See `_check_serializable` for details."""
if config.get_option("runner.enforceSerializableSessionState"):
self._check_serializable()
def _is_internal_key(key: str) -> bool:
return key.startswith(STREAMLIT_INTERNAL_KEY_PREFIX)
@dataclass
class SessionStateStatProvider(CacheStatsProvider):
_session_mgr: "SessionManager"
def get_stats(self) -> list[CacheStat]:
stats: list[CacheStat] = []
for session_info in self._session_mgr.list_active_sessions():
session_state = session_info.session.session_state
stats.extend(session_state.get_stats())
return stats
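# A minimal sketch (illustrative only) of the serializability check above.
# With the corresponding config option enabled (assuming the usual
# section.option mapping, i.e. [runner] enforceSerializableSessionState = true
# in .streamlit/config.toml), maybe_check_serializable() applies this check.
import threading

_state = SessionState()
_state["lock"] = threading.Lock()          # a lock can't be pickled
try:
    _state._check_serializable()
except UnserializableSessionStateError:
    pass                                   # raised for the "lock" entry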

View File

@@ -0,0 +1,142 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Iterator, MutableMapping
from typing_extensions import Final
from streamlit import logger as _logger
from streamlit import runtime
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.state.common import require_valid_user_key
from streamlit.runtime.state.safe_session_state import SafeSessionState
from streamlit.runtime.state.session_state import SessionState
from streamlit.type_util import Key
LOGGER: Final = _logger.get_logger(__name__)
_state_use_warning_already_displayed: bool = False
def get_session_state() -> SafeSessionState:
"""Get the SessionState object for the current session.
Note that in streamlit scripts, this function should not be called
directly. Instead, SessionState objects should be accessed via
st.session_state.
"""
global _state_use_warning_already_displayed
from streamlit.runtime.scriptrunner import get_script_run_ctx
ctx = get_script_run_ctx()
# If there is no script run context because the script is run bare, have
# session state act as an always empty dictionary, and print a warning.
if ctx is None:
if not _state_use_warning_already_displayed:
_state_use_warning_already_displayed = True
if not runtime.exists():
LOGGER.warning(
"Session state does not function when running a script without `streamlit run`"
)
return SafeSessionState(SessionState())
return ctx.session_state
class SessionStateProxy(MutableMapping[Key, Any]):
"""A stateless singleton that proxies `st.session_state` interactions
to the current script thread's SessionState instance.
The proxy API differs slightly from SessionState: it does not allow
callers to get, set, or iterate over "keyless" widgets (that is, widgets
that were created without a user_key, and have autogenerated keys).
"""
def __iter__(self) -> Iterator[Any]:
"""Iterator over user state and keyed widget values."""
# TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
return iter(get_session_state().filtered_state)
def __len__(self) -> int:
"""Number of user state and keyed widget values in session_state."""
return len(get_session_state().filtered_state)
def __str__(self) -> str:
"""String representation of user state and keyed widget values."""
return str(get_session_state().filtered_state)
def __getitem__(self, key: Key) -> Any:
"""Return the state or widget value with the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
return get_session_state()[key]
@gather_metrics("session_state.set_item")
def __setitem__(self, key: Key, value: Any) -> None:
"""Set the value of the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
get_session_state()[key] = value
def __delitem__(self, key: Key) -> None:
"""Delete the value with the given key.
Raises
------
StreamlitAPIException
If the key is not a valid SessionState user key.
"""
key = str(key)
require_valid_user_key(key)
del get_session_state()[key]
def __getattr__(self, key: str) -> Any:
try:
return self[key]
except KeyError:
raise AttributeError(_missing_attr_error_message(key))
@gather_metrics("session_state.set_attr")
def __setattr__(self, key: str, value: Any) -> None:
self[key] = value
def __delattr__(self, key: str) -> None:
try:
del self[key]
except KeyError:
raise AttributeError(_missing_attr_error_message(key))
def to_dict(self) -> Dict[str, Any]:
"""Return a dict containing all session_state and keyed widget values."""
return get_session_state().filtered_state
def _missing_attr_error_message(attr_name: str) -> str:
return (
f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
f"More info: https://docs.streamlit.io/library/advanced-features/session-state#initialization"
)
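# A usage sketch (illustrative only; meant to run inside a Streamlit app, where
# st.session_state is an instance of this proxy):
#
#   import streamlit as st
#
#   if "counter" not in st.session_state:
#       st.session_state.counter = 0          # attribute access
#   st.session_state["counter"] += 1          # item access hits the same entry
#
#   # Keys that look like autogenerated widget ids are reserved and rejected:
#   st.session_state["$$GENERATED_WIDGET_ID-abc-None"] = 1  # StreamlitAPIException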

View File

@@ -0,0 +1,281 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap
from types import MappingProxyType
from typing import TYPE_CHECKING, Dict, Mapping, Optional
from typing_extensions import Final, TypeAlias
from streamlit.errors import DuplicateWidgetID
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
from streamlit.runtime.state.common import (
RegisterWidgetResult,
T,
WidgetArgs,
WidgetCallback,
WidgetDeserializer,
WidgetKwargs,
WidgetMetadata,
WidgetProto,
WidgetSerializer,
compute_widget_id,
user_key_from_widget_id,
)
from streamlit.type_util import ValueFieldName
if TYPE_CHECKING:
from streamlit.runtime.scriptrunner import ScriptRunContext
ElementType: TypeAlias = str
# NOTE: We use this table to start with a best-effort guess for the value_type
# of each widget. Once we actually receive a proto for a widget from the
# frontend, the guess is updated to be the correct type. Unfortunately, we're
# not able to always rely on the proto as the type may be needed earlier.
# Thankfully, in these cases (when value_type == "trigger_value"), the static
# table here being slightly inaccurate should never pose a problem.
ELEMENT_TYPE_TO_VALUE_TYPE: Final[
Mapping[ElementType, ValueFieldName]
] = MappingProxyType(
{
"button": "trigger_value",
"download_button": "trigger_value",
"checkbox": "bool_value",
"camera_input": "file_uploader_state_value",
"color_picker": "string_value",
"date_input": "string_array_value",
"file_uploader": "file_uploader_state_value",
"multiselect": "int_array_value",
"number_input": "double_value",
"radio": "int_value",
"selectbox": "int_value",
"slider": "double_array_value",
"text_area": "string_value",
"text_input": "string_value",
"time_input": "string_value",
"component_instance": "json_value",
"data_editor": "string_value",
}
)
class NoValue:
"""Return this from DeltaGenerator.foo_widget() when you want the st.foo_widget()
call to return None. This is needed because `DeltaGenerator._enqueue`
replaces `None` with a `DeltaGenerator` (for use in non-widget elements).
"""
pass
def register_widget(
element_type: ElementType,
element_proto: WidgetProto,
deserializer: WidgetDeserializer[T],
serializer: WidgetSerializer[T],
ctx: Optional["ScriptRunContext"],
user_key: Optional[str] = None,
widget_func_name: Optional[str] = None,
on_change_handler: Optional[WidgetCallback] = None,
args: Optional[WidgetArgs] = None,
kwargs: Optional[WidgetKwargs] = None,
) -> RegisterWidgetResult[T]:
"""Register a widget with Streamlit, and return its current value.
NOTE: This function should be called after the proto has been filled.
Parameters
----------
element_type : ElementType
The type of the element as stored in proto.
element_proto : WidgetProto
The proto of the specified type (e.g. Button/Multiselect/Slider proto)
deserializer : WidgetDeserializer[T]
Called to convert a widget's protobuf value to the value returned by
its st.<widget_name> function.
serializer : WidgetSerializer[T]
Called to convert a widget's value to its protobuf representation.
ctx : Optional[ScriptRunContext]
Used to ensure uniqueness of widget IDs, and to look up widget values.
user_key : Optional[str]
Optional user-specified string to use as the widget ID.
If this is None, we'll generate an ID by hashing the element.
widget_func_name : Optional[str]
The widget's DeltaGenerator function name, if it's different from
its element_type. Custom components are a special case: they all have
the element_type "component_instance", but are instantiated with
dynamically-named functions.
on_change_handler : Optional[WidgetCallback]
An optional callback invoked when the widget's value changes.
args : Optional[WidgetArgs]
args to pass to on_change_handler when invoked
kwargs : Optional[WidgetKwargs]
kwargs to pass to on_change_handler when invoked
Returns
-------
register_widget_result : RegisterWidgetResult[T]
Provides information on which value to return to the widget caller,
and whether the UI needs updating.
- Unhappy path:
- Our ScriptRunContext doesn't exist (meaning that we're running
as a "bare script" outside streamlit).
- We are disconnected from the SessionState instance.
In both cases we'll return a fallback RegisterWidgetResult[T].
- Happy path:
- The widget has already been registered on a previous run but the
user hasn't interacted with it on the client. The widget will have
the default value it was first created with. We then return a
RegisterWidgetResult[T], containing this value.
- The widget has already been registered and the user *has*
interacted with it. The widget will have that most recent
user-specified value. We then return a RegisterWidgetResult[T],
containing this value.
For both paths a widget return value is provided, allowing the widgets
to be used in a non-streamlit setting.
"""
widget_id = compute_widget_id(element_type, element_proto, user_key)
element_proto.id = widget_id
# Create the widget's updated metadata, and register it with session_state.
metadata = WidgetMetadata(
widget_id,
deserializer,
serializer,
value_type=ELEMENT_TYPE_TO_VALUE_TYPE[element_type],
callback=on_change_handler,
callback_args=args,
callback_kwargs=kwargs,
)
return register_widget_from_metadata(metadata, ctx, widget_func_name, element_type)
def register_widget_from_metadata(
metadata: WidgetMetadata[T],
ctx: Optional["ScriptRunContext"],
widget_func_name: Optional[str],
element_type: ElementType,
) -> RegisterWidgetResult[T]:
"""Register a widget and return its value, using an already constructed
`WidgetMetadata`.
This is split out from `register_widget` to allow caching code to replay
widgets by saving and reusing the completed metadata.
See `register_widget` for details on what this returns.
"""
# Local import to avoid import cycle
import streamlit.runtime.caching as caching
if ctx is None:
# Early-out if we don't have a script run context (which probably means
# we're running as a "bare" Python script, and not via `streamlit run`).
return RegisterWidgetResult.failure(deserializer=metadata.deserializer)
widget_id = metadata.id
user_key = user_key_from_widget_id(widget_id)
# Ensure another widget with the same user key hasn't already been registered.
if user_key is not None:
if user_key not in ctx.widget_user_keys_this_run:
ctx.widget_user_keys_this_run.add(user_key)
else:
raise DuplicateWidgetID(
_build_duplicate_widget_message(
widget_func_name if widget_func_name is not None else element_type,
user_key,
)
)
# Ensure another widget with the same id hasn't already been registered.
new_widget = widget_id not in ctx.widget_ids_this_run
if new_widget:
ctx.widget_ids_this_run.add(widget_id)
else:
raise DuplicateWidgetID(
_build_duplicate_widget_message(
widget_func_name if widget_func_name is not None else element_type,
user_key,
)
)
# Save the widget metadata for cached result replay
caching.save_widget_metadata(metadata)
return ctx.session_state.register_widget(metadata, user_key)
def coalesce_widget_states(
old_states: WidgetStates, new_states: WidgetStates
) -> WidgetStates:
"""Coalesce an older WidgetStates into a newer one, and return a new
WidgetStates containing the result.
For most widget values, we just take the latest version.
However, any trigger_values (which are set by buttons) that are True in
`old_states` will be set to True in the coalesced result, so that button
presses don't go missing.
"""
states_by_id: Dict[str, WidgetState] = {
wstate.id: wstate for wstate in new_states.widgets
}
for old_state in old_states.widgets:
if old_state.WhichOneof("value") == "trigger_value" and old_state.trigger_value:
# Ensure the corresponding new_state is also a trigger;
# otherwise, a widget that was previously a button but no longer is
# could get a bad value.
new_trigger_val = states_by_id.get(old_state.id)
if (
new_trigger_val
and new_trigger_val.WhichOneof("value") == "trigger_value"
):
states_by_id[old_state.id] = old_state
coalesced = WidgetStates()
coalesced.widgets.extend(states_by_id.values())
return coalesced
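# A minimal, hedged sketch of the coalescing rule above, assuming the
# WidgetStates/WidgetState protos are importable from the path below (the
# widget id is illustrative). A True trigger_value recorded in the older state
# survives coalescing, so a button press isn't dropped.
from streamlit.proto.WidgetStates_pb2 import WidgetStates

old = WidgetStates()
old_btn = old.widgets.add()
old_btn.id = "button-1"
old_btn.trigger_value = True  # the press we don't want to lose

new = WidgetStates()
new_btn = new.widgets.add()
new_btn.id = "button-1"
new_btn.trigger_value = False  # the newer state hasn't seen the press

merged = coalesce_widget_states(old, new)
assert merged.widgets[0].trigger_value  # the press is preserved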
def _build_duplicate_widget_message(
widget_func_name: str, user_key: Optional[str] = None
) -> str:
if user_key is not None:
message = textwrap.dedent(
"""
There are multiple widgets with the same `key='{user_key}'`.
To fix this, please make sure that the `key` argument is unique for each
widget you create.
"""
)
else:
message = textwrap.dedent(
"""
There are multiple identical `st.{widget_type}` widgets with the
same generated key.
When a widget is created, it's assigned an internal key based on
its structure. Multiple widgets with an identical structure will
result in the same internal key, which causes this error.
To fix this error, please pass a unique `key` argument to
`st.{widget_type}`.
"""
)
return message.strip("\n").format(widget_type=widget_func_name, user_key=user_key)
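# Hedged example of the user-facing failure mode this message describes, as a
# separate app script run via `streamlit run app.py` (file name assumed for
# illustration). Two widgets sharing the same user key raise DuplicateWidgetID;
# giving each widget a unique `key` resolves it.
import streamlit as st

st.button("Save", key="save")
st.button("Save again", key="save")  # DuplicateWidgetID: same user key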

View File

@@ -0,0 +1,88 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import List, NamedTuple
from typing_extensions import Protocol, runtime_checkable
from streamlit.proto.openmetrics_data_model_pb2 import Metric as MetricProto
class CacheStat(NamedTuple):
"""Describes a single cache entry.
Properties
----------
category_name : str
A human-readable name for the cache "category" that the entry belongs
to - e.g. "st.memo", "session_state", etc.
cache_name : str
A human-readable name for the cache instance that the entry belongs to.
For "st.memo" and other function decorator caches, this might be the
name of the cached function. If the cache category doesn't have
multiple separate cache instances, this can just be the empty string.
byte_length : int
The entry's memory footprint in bytes.
"""
category_name: str
cache_name: str
byte_length: int
def to_metric_str(self) -> str:
return 'cache_memory_bytes{cache_type="%s",cache="%s"} %s' % (
self.category_name,
self.cache_name,
self.byte_length,
)
def marshall_metric_proto(self, metric: MetricProto) -> None:
"""Fill an OpenMetrics `Metric` protobuf object."""
label = metric.labels.add()
label.name = "cache_type"
label.value = self.category_name
label = metric.labels.add()
label.name = "cache"
label.value = self.cache_name
metric_point = metric.metric_points.add()
metric_point.gauge_value.int_value = self.byte_length
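# Hedged usage sketch for CacheStat.to_metric_str(); the category, cache name,
# and byte count are illustrative. The expected string follows directly from
# the format above.
stat = CacheStat(category_name="st.memo", cache_name="load_data", byte_length=2048)
assert (
    stat.to_metric_str()
    == 'cache_memory_bytes{cache_type="st.memo",cache="load_data"} 2048'
)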
@runtime_checkable
class CacheStatsProvider(Protocol):
@abstractmethod
def get_stats(self) -> List[CacheStat]:
raise NotImplementedError
class StatsManager:
def __init__(self):
self._cache_stats_providers: List[CacheStatsProvider] = []
def register_provider(self, provider: CacheStatsProvider) -> None:
"""Register a CacheStatsProvider with the manager.
This function is not thread-safe. Call it immediately after
creation.
"""
self._cache_stats_providers.append(provider)
def get_stats(self) -> List[CacheStat]:
"""Return a list containing all stats from each registered provider."""
all_stats: List[CacheStat] = []
for provider in self._cache_stats_providers:
all_stats.extend(provider.get_stats())
return all_stats
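# Hedged sketch of wiring a provider into StatsManager. _FixedStatsProvider is
# a hypothetical provider used only for illustration; because
# CacheStatsProvider is a Protocol, any object exposing get_stats() works.
class _FixedStatsProvider:
    def get_stats(self) -> List[CacheStat]:
        return [CacheStat("st.memo", "load_data", 1024)]

manager = StatsManager()
manager.register_provider(_FixedStatsProvider())
assert manager.get_stats() == [CacheStat("st.memo", "load_data", 1024)]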

View File

@@ -0,0 +1,334 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import threading
from typing import Dict, List, NamedTuple, Tuple
from blinker import Signal
from streamlit import util
from streamlit.logger import get_logger
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
LOGGER = get_logger(__name__)
class UploadedFileRec(NamedTuple):
"""Metadata and raw bytes for an uploaded file. Immutable."""
id: int
name: str
type: str
data: bytes
class UploadedFile(io.BytesIO):
"""A mutable uploaded file.
This class extends BytesIO, which has copy-on-write semantics when
initialized with `bytes`.
"""
def __init__(self, record: UploadedFileRec):
# BytesIO's copy-on-write semantics aren't mentioned in the Python docs -
# possibly because it's a CPython-only optimization that isn't guaranteed
# to exist in other Python runtimes. But it's detailed
# here: https://hg.python.org/cpython/rev/79a5fbe2c78f
super(UploadedFile, self).__init__(record.data)
self.id = record.id
self.name = record.name
self.type = record.type
self.size = len(record.data)
def __eq__(self, other: object) -> bool:
if not isinstance(other, UploadedFile):
return NotImplemented
return self.id == other.id
def __repr__(self) -> str:
return util.repr_(self)
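# Hedged usage sketch: an UploadedFile wraps an UploadedFileRec's bytes and
# behaves like a read-only file object (the record's fields are illustrative).
rec = UploadedFileRec(id=1, name="notes.txt", type="text/plain", data=b"hello")
uploaded = UploadedFile(rec)
assert uploaded.name == "notes.txt" and uploaded.size == 5
assert uploaded.read() == b"hello"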
class UploadedFileManager(CacheStatsProvider):
"""Holds files uploaded by users of the running Streamlit app,
and emits an event signal when a file is added.
This class can be used safely from multiple threads simultaneously.
"""
def __init__(self):
# List of files for a given widget in a given session.
self._files_by_id: Dict[Tuple[str, str], List[UploadedFileRec]] = {}
# A counter that generates unique file IDs. Each file ID is greater
# than the previous ID, which means we can use IDs to compare files
# by age.
self._file_id_counter = 1
self._file_id_lock = threading.Lock()
# Prevents concurrent access to the _files_by_id dict.
# In remove_session_files(), we iterate over the dict's keys. It's
# an error to mutate a dict while iterating; this lock prevents that.
self._files_lock = threading.Lock()
self.on_files_updated = Signal(
doc="""Emitted when a file list is added to the manager or updated.
Parameters
----------
session_id : str
The session_id for the session whose files were updated.
"""
)
def __repr__(self) -> str:
return util.repr_(self)
def add_file(
self,
session_id: str,
widget_id: str,
file: UploadedFileRec,
) -> UploadedFileRec:
"""Add a file to the FileManager, and return a new UploadedFileRec
with its ID assigned.
The "on_files_updated" Signal will be emitted.
Safe to call from any thread.
Parameters
----------
session_id
The ID of the session that owns the file.
widget_id
The widget ID of the FileUploader that created the file.
file
The file to add.
Returns
-------
UploadedFileRec
The added file, which has its unique ID assigned.
"""
files_by_widget = session_id, widget_id
# Assign the file a unique ID
file_id = self._get_next_file_id()
file = UploadedFileRec(
id=file_id, name=file.name, type=file.type, data=file.data
)
with self._files_lock:
file_list = self._files_by_id.get(files_by_widget, None)
if file_list is not None:
file_list.append(file)
else:
self._files_by_id[files_by_widget] = [file]
self.on_files_updated.send(session_id)
return file
def get_all_files(self, session_id: str, widget_id: str) -> List[UploadedFileRec]:
"""Return all the files stored for the given widget.
Safe to call from any thread.
Parameters
----------
session_id
The ID of the session that owns the files.
widget_id
The widget ID of the FileUploader that created the files.
"""
file_list_id = (session_id, widget_id)
with self._files_lock:
return self._files_by_id.get(file_list_id, []).copy()
def get_files(
self, session_id: str, widget_id: str, file_ids: List[int]
) -> List[UploadedFileRec]:
"""Return the files with the given widget_id and file_ids.
Safe to call from any thread.
Parameters
----------
session_id
The ID of the session that owns the files.
widget_id
The widget ID of the FileUploader that created the files.
file_ids
List of file IDs. Only files whose IDs are in this list will be
returned.
"""
return [
f for f in self.get_all_files(session_id, widget_id) if f.id in file_ids
]
def remove_orphaned_files(
self,
session_id: str,
widget_id: str,
newest_file_id: int,
active_file_ids: List[int],
) -> None:
"""Remove 'orphaned' files: files that have been uploaded and
subsequently deleted, but haven't yet been removed from memory.
Because FileUploader can live inside forms, file deletion is made a
bit tricky: a file deletion should only happen after the form is
submitted.
FileUploader's widget value is an array of numbers that has two parts:
- The first number is always 'this.state.newestServerFileId'.
- The remaining 0 or more numbers are the file IDs of all the
uploader's uploaded files.
When the server receives the widget value, it deletes "orphaned"
uploaded files. An orphaned file is any file associated with a given
FileUploader whose file ID is not in the active_file_ids, and whose
ID is <= `newestServerFileId`.
This logic ensures that a FileUploader within a form doesn't have any
of its "unsubmitted" uploads prematurely deleted when the script is
re-run. (A worked sketch of this rule follows the class definition.)
Safe to call from any thread.
"""
file_list_id = (session_id, widget_id)
with self._files_lock:
file_list = self._files_by_id.get(file_list_id)
if file_list is None:
return
# Remove orphaned files from the list:
# - `f.id in active_file_ids`:
# File is currently tracked by the widget. DON'T remove.
# - `f.id > newest_file_id`:
# file was uploaded *after* the widget was most recently
# updated. (It's probably in a form.) DON'T remove.
# - `f.id < newest_file_id and f.id not in active_file_ids`:
# File is not currently tracked by the widget, and was uploaded
# *before* this most recent update. This means it's been deleted
# by the user on the frontend, and is now "orphaned". Remove!
new_list = [
f for f in file_list if f.id > newest_file_id or f.id in active_file_ids
]
self._files_by_id[file_list_id] = new_list
num_removed = len(file_list) - len(new_list)
if num_removed > 0:
LOGGER.debug("Removed %s orphaned files" % num_removed)
def remove_file(self, session_id: str, widget_id: str, file_id: int) -> bool:
"""Remove the file list with the given ID, if it exists.
The "on_files_updated" Signal will be emitted.
Safe to call from any thread.
Returns
-------
bool
True if the file was removed, or False if no such file exists.
"""
file_list_id = (session_id, widget_id)
with self._files_lock:
file_list = self._files_by_id.get(file_list_id, None)
if file_list is None:
return False
# Remove the file from its list.
new_file_list = [file for file in file_list if file.id != file_id]
self._files_by_id[file_list_id] = new_file_list
self.on_files_updated.send(session_id)
return True
def _remove_files(self, session_id: str, widget_id: str) -> None:
"""Remove the file list for the provided widget in the
provided session, if it exists.
Does not emit any signals.
Safe to call from any thread.
"""
files_by_widget = session_id, widget_id
with self._files_lock:
self._files_by_id.pop(files_by_widget, None)
def remove_files(self, session_id: str, widget_id: str) -> None:
"""Remove the file list for the provided widget in the
provided session, if it exists.
The "on_files_updated" Signal will be emitted.
Safe to call from any thread.
Parameters
----------
session_id : str
The ID of the session that owns the files.
widget_id : str
The widget ID of the FileUploader that created the files.
"""
self._remove_files(session_id, widget_id)
self.on_files_updated.send(session_id)
def remove_session_files(self, session_id: str) -> None:
"""Remove all files that belong to the given session.
Safe to call from any thread.
Parameters
----------
session_id : str
The ID of the session whose files we're removing.
"""
# Copy the keys into a list, because we'll be mutating the dictionary.
with self._files_lock:
all_ids = list(self._files_by_id.keys())
for files_id in all_ids:
if files_id[0] == session_id:
self.remove_files(*files_id)
def _get_next_file_id(self) -> int:
"""Return the next file ID and increment our ID counter."""
with self._file_id_lock:
file_id = self._file_id_counter
self._file_id_counter += 1
return file_id
def get_stats(self) -> List[CacheStat]:
"""Return the manager's CacheStats.
Safe to call from any thread.
"""
with self._files_lock:
# Flatten all files into a single list
all_files: List[UploadedFileRec] = []
for file_list in self._files_by_id.values():
all_files.extend(file_list)
return [
CacheStat(
category_name="UploadedFileManager",
cache_name="",
byte_length=len(file.data),
)
for file in all_files
]
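# Hedged usage sketch for UploadedFileManager, covering add_file(), the orphan
# rule described in remove_orphaned_files(), and get_stats(). Session and
# widget IDs are illustrative stand-ins.
_mgr = UploadedFileManager()
for _name in ("a.txt", "b.txt", "c.txt", "d.txt"):  # files are assigned ids 1..4
    _mgr.add_file(
        "session-1",
        "uploader-1",
        UploadedFileRec(id=0, name=_name, type="text/plain", data=b"data"),
    )

# With newest_file_id=3 and active_file_ids=[3]: files 1 and 2 are orphans
# (older than the update and no longer tracked), file 3 is still active, and
# file 4 was uploaded after the widget's last update (e.g. inside a form), so
# it is preserved.
_mgr.remove_orphaned_files("session-1", "uploader-1", newest_file_id=3, active_file_ids=[3])
assert [f.id for f in _mgr.get_all_files("session-1", "uploader-1")] == [3, 4]

# Each remaining file contributes one CacheStat with its byte footprint.
assert [stat.byte_length for stat in _mgr.get_stats()] == [4, 4]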

View File

@@ -0,0 +1,157 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, cast
from typing_extensions import Final
from streamlit.logger import get_logger
from streamlit.runtime.app_session import AppSession
from streamlit.runtime.script_data import ScriptData
from streamlit.runtime.session_manager import (
ActiveSessionInfo,
SessionClient,
SessionInfo,
SessionManager,
SessionStorage,
)
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.watcher import LocalSourcesWatcher
LOGGER: Final = get_logger(__name__)
class WebsocketSessionManager(SessionManager):
"""A SessionManager used to manage sessions with lifecycles tied to those of a
browser tab's websocket connection.
WebsocketSessionManagers differentiate between "active" and "inactive" sessions.
Active sessions are those with a currently active websocket connection; inactive
sessions are those without one. Eventual cleanup of inactive sessions is a detail left
to the specific SessionStorage that a WebsocketSessionManager is instantiated with.
"""
def __init__(
self,
session_storage: SessionStorage,
uploaded_file_manager: UploadedFileManager,
message_enqueued_callback: Optional[Callable[[], None]],
) -> None:
self._session_storage = session_storage
self._uploaded_file_mgr = uploaded_file_manager
self._message_enqueued_callback = message_enqueued_callback
# Mapping of AppSession.id -> ActiveSessionInfo.
self._active_session_info_by_id: Dict[str, ActiveSessionInfo] = {}
def connect_session(
self,
client: SessionClient,
script_data: ScriptData,
user_info: Dict[str, Optional[str]],
existing_session_id: Optional[str] = None,
) -> str:
if existing_session_id in self._active_session_info_by_id:
LOGGER.warning(
"Session with id %s is already connected! Connecting to a new session.",
existing_session_id,
)
session_info = (
existing_session_id
and existing_session_id not in self._active_session_info_by_id
and self._session_storage.get(existing_session_id)
)
if session_info:
existing_session = session_info.session
existing_session.register_file_watchers()
self._active_session_info_by_id[existing_session.id] = ActiveSessionInfo(
client,
existing_session,
session_info.script_run_count,
)
self._session_storage.delete(existing_session.id)
return existing_session.id
session = AppSession(
script_data=script_data,
uploaded_file_manager=self._uploaded_file_mgr,
message_enqueued_callback=self._message_enqueued_callback,
local_sources_watcher=LocalSourcesWatcher(script_data.main_script_path),
user_info=user_info,
)
LOGGER.debug(
"Created new session for client %s. Session ID: %s", id(client), session.id
)
assert (
session.id not in self._active_session_info_by_id
), f"session.id '{session.id}' registered multiple times!"
self._active_session_info_by_id[session.id] = ActiveSessionInfo(client, session)
return session.id
def disconnect_session(self, session_id: str) -> None:
if session_id in self._active_session_info_by_id:
active_session_info = self._active_session_info_by_id[session_id]
session = active_session_info.session
session.request_script_stop()
session.disconnect_file_watchers()
self._session_storage.save(
SessionInfo(
client=None,
session=session,
script_run_count=active_session_info.script_run_count,
)
)
del self._active_session_info_by_id[session_id]
def get_active_session_info(self, session_id: str) -> Optional[ActiveSessionInfo]:
return self._active_session_info_by_id.get(session_id)
def is_active_session(self, session_id: str) -> bool:
return session_id in self._active_session_info_by_id
def list_active_sessions(self) -> List[ActiveSessionInfo]:
return list(self._active_session_info_by_id.values())
def close_session(self, session_id: str) -> None:
if session_id in self._active_session_info_by_id:
active_session_info = self._active_session_info_by_id[session_id]
del self._active_session_info_by_id[session_id]
active_session_info.session.shutdown()
return
session_info = self._session_storage.get(session_id)
if session_info:
self._session_storage.delete(session_id)
session_info.session.shutdown()
def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
session_info = self.get_active_session_info(session_id)
if session_info:
return cast(SessionInfo, session_info)
return self._session_storage.get(session_id)
def list_sessions(self) -> List[SessionInfo]:
return (
cast(List[SessionInfo], self.list_active_sessions())
+ self._session_storage.list()
)
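# Hedged, self-contained sketch of the reconnection check used in
# connect_session() above: the chained `and` expression short-circuits, so the
# lookup yields either a falsy value (no id supplied, id already active, or
# nothing in storage) or the stored record. The dict-backed storage and ids
# below are stand-ins for illustration only.
def _reconnect_lookup(existing_session_id, active_ids, storage):
    return (
        existing_session_id
        and existing_session_id not in active_ids
        and storage.get(existing_session_id)
    )

_storage = {"s1": "stored-session-record"}
assert _reconnect_lookup(None, set(), _storage) is None            # no id supplied
assert _reconnect_lookup("s1", {"s1"}, _storage) is False          # id already active
assert _reconnect_lookup("missing", set(), _storage) is None       # not in storage
assert _reconnect_lookup("s1", set(), _storage) == "stored-session-record"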