Merging PR_218 openai_rev package with new streamlit chat app
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Explicitly re-export public symbols from runtime.py and session_manager.py
|
||||
from streamlit.runtime.runtime import Runtime as Runtime
|
||||
from streamlit.runtime.runtime import RuntimeConfig as RuntimeConfig
|
||||
from streamlit.runtime.runtime import RuntimeState as RuntimeState
|
||||
from streamlit.runtime.session_manager import SessionClient as SessionClient
|
||||
from streamlit.runtime.session_manager import (
|
||||
SessionClientDisconnectedError as SessionClientDisconnectedError,
|
||||
)
|
||||
|
||||
|
||||
def get_instance() -> Runtime:
|
||||
"""Return the singleton Runtime instance. Raise an Error if the
|
||||
Runtime hasn't been created yet.
|
||||
"""
|
||||
return Runtime.instance()
|
||||
|
||||
|
||||
def exists() -> bool:
|
||||
"""True if the singleton Runtime instance has been created.
|
||||
|
||||
When a Streamlit app is running in "raw mode" - that is, when the
|
||||
app is run via `python app.py` instead of `streamlit run app.py` -
|
||||
the Runtime will not exist, and various Streamlit functions need
|
||||
to adapt.
|
||||
"""
|
||||
return Runtime.exists()
|
||||
@@ -0,0 +1,806 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import sys
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
|
||||
|
||||
import streamlit.elements.exception as exception_utils
|
||||
from streamlit import config, runtime, source_util
|
||||
from streamlit.case_converters import to_snake_case
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.BackMsg_pb2 import BackMsg
|
||||
from streamlit.proto.ClientState_pb2 import ClientState
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.proto.GitInfo_pb2 import GitInfo
|
||||
from streamlit.proto.NewSession_pb2 import (
|
||||
Config,
|
||||
CustomThemeConfig,
|
||||
NewSession,
|
||||
UserInfo,
|
||||
)
|
||||
from streamlit.proto.PagesChanged_pb2 import PagesChanged
|
||||
from streamlit.runtime import caching, legacy_caching
|
||||
from streamlit.runtime.credentials import Credentials
|
||||
from streamlit.runtime.forward_msg_queue import ForwardMsgQueue
|
||||
from streamlit.runtime.metrics_util import Installation
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.scriptrunner import RerunData, ScriptRunner, ScriptRunnerEvent
|
||||
from streamlit.runtime.secrets import secrets_singleton
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
from streamlit.version import STREAMLIT_VERSION_STRING
|
||||
from streamlit.watcher import LocalSourcesWatcher
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.state import SessionState
|
||||
|
||||
|
||||
class AppSessionState(Enum):
|
||||
APP_NOT_RUNNING = "APP_NOT_RUNNING"
|
||||
APP_IS_RUNNING = "APP_IS_RUNNING"
|
||||
SHUTDOWN_REQUESTED = "SHUTDOWN_REQUESTED"
|
||||
|
||||
|
||||
def _generate_scriptrun_id() -> str:
|
||||
"""Randomly generate a unique ID for a script execution."""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class AppSession:
|
||||
"""
|
||||
Contains session data for a single "user" of an active app
|
||||
(that is, a connected browser tab).
|
||||
|
||||
Each AppSession has its own ScriptData, root DeltaGenerator, ScriptRunner,
|
||||
and widget state.
|
||||
|
||||
An AppSession is attached to each thread involved in running its script.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
script_data: ScriptData,
|
||||
uploaded_file_manager: UploadedFileManager,
|
||||
message_enqueued_callback: Optional[Callable[[], None]],
|
||||
local_sources_watcher: LocalSourcesWatcher,
|
||||
user_info: Dict[str, Optional[str]],
|
||||
) -> None:
|
||||
"""Initialize the AppSession.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
script_data : ScriptData
|
||||
Object storing parameters related to running a script
|
||||
|
||||
uploaded_file_manager : UploadedFileManager
|
||||
Used to manage files uploaded by users via the Streamlit web client.
|
||||
|
||||
message_enqueued_callback : Callable[[], None]
|
||||
After enqueuing a message, this callable notification will be invoked.
|
||||
|
||||
local_sources_watcher: LocalSourcesWatcher
|
||||
The file watcher that lets the session know local files have changed.
|
||||
|
||||
user_info: Dict
|
||||
A dict that contains information about the current user. For now,
|
||||
it only contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
|
||||
Information about the current user is optionally provided when a
|
||||
websocket connection is initialized via the "X-Streamlit-User" header.
|
||||
|
||||
"""
|
||||
# Each AppSession has a unique string ID.
|
||||
self.id = str(uuid.uuid4())
|
||||
|
||||
self._event_loop = asyncio.get_running_loop()
|
||||
self._script_data = script_data
|
||||
self._uploaded_file_mgr = uploaded_file_manager
|
||||
|
||||
# The browser queue contains messages that haven't yet been
|
||||
# delivered to the browser. Periodically, the server flushes
|
||||
# this queue and delivers its contents to the browser.
|
||||
self._browser_queue = ForwardMsgQueue()
|
||||
self._message_enqueued_callback = message_enqueued_callback
|
||||
|
||||
self._state = AppSessionState.APP_NOT_RUNNING
|
||||
|
||||
# Need to remember the client state here because when a script reruns
|
||||
# due to the source code changing we need to pass in the previous client state.
|
||||
self._client_state = ClientState()
|
||||
|
||||
self._local_sources_watcher: Optional[
|
||||
LocalSourcesWatcher
|
||||
] = local_sources_watcher
|
||||
self._stop_config_listener: Optional[Callable[[], bool]] = None
|
||||
self._stop_pages_listener: Optional[Callable[[], bool]] = None
|
||||
|
||||
self.register_file_watchers()
|
||||
|
||||
self._run_on_save = config.get_option("server.runOnSave")
|
||||
|
||||
self._scriptrunner: Optional[ScriptRunner] = None
|
||||
|
||||
# This needs to be lazily imported to avoid a dependency cycle.
|
||||
from streamlit.runtime.state import SessionState
|
||||
|
||||
self._session_state = SessionState()
|
||||
self._user_info = user_info
|
||||
|
||||
self._debug_last_backmsg_id: Optional[str] = None
|
||||
|
||||
LOGGER.debug("AppSession initialized (id=%s)", self.id)
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Ensure that we call shutdown() when an AppSession is garbage collected."""
|
||||
self.shutdown()
|
||||
|
||||
def register_file_watchers(self) -> None:
|
||||
"""Register handlers to be called when various files are changed.
|
||||
|
||||
Files that we watch include:
|
||||
* source files that already exist (for edits)
|
||||
* `.py` files in the the main script's `pages/` directory (for file additions
|
||||
and deletions)
|
||||
* project and user-level config.toml files
|
||||
* the project-level secrets.toml files
|
||||
|
||||
This method is called automatically on AppSession construction, but it may be
|
||||
called again in the case when a session is disconnected and is being reconnect
|
||||
to.
|
||||
"""
|
||||
if self._local_sources_watcher is None:
|
||||
self._local_sources_watcher = LocalSourcesWatcher(
|
||||
self._script_data.main_script_path
|
||||
)
|
||||
|
||||
self._local_sources_watcher.register_file_change_callback(
|
||||
self._on_source_file_changed
|
||||
)
|
||||
self._stop_config_listener = config.on_config_parsed(
|
||||
self._on_source_file_changed, force_connect=True
|
||||
)
|
||||
self._stop_pages_listener = source_util.register_pages_changed_callback(
|
||||
self._on_pages_changed
|
||||
)
|
||||
secrets_singleton.file_change_listener.connect(self._on_secrets_file_changed)
|
||||
|
||||
def disconnect_file_watchers(self) -> None:
|
||||
"""Disconnect the file watcher handlers registered by register_file_watchers."""
|
||||
if self._local_sources_watcher is not None:
|
||||
self._local_sources_watcher.close()
|
||||
if self._stop_config_listener is not None:
|
||||
self._stop_config_listener()
|
||||
if self._stop_pages_listener is not None:
|
||||
self._stop_pages_listener()
|
||||
|
||||
secrets_singleton.file_change_listener.disconnect(self._on_secrets_file_changed)
|
||||
|
||||
self._local_sources_watcher = None
|
||||
self._stop_config_listener = None
|
||||
self._stop_pages_listener = None
|
||||
|
||||
def flush_browser_queue(self) -> List[ForwardMsg]:
|
||||
"""Clear the forward message queue and return the messages it contained.
|
||||
|
||||
The Server calls this periodically to deliver new messages
|
||||
to the browser connected to this app.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[ForwardMsg]
|
||||
The messages that were removed from the queue and should
|
||||
be delivered to the browser.
|
||||
|
||||
"""
|
||||
return self._browser_queue.flush()
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shut down the AppSession.
|
||||
|
||||
It's an error to use a AppSession after it's been shut down.
|
||||
|
||||
"""
|
||||
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
|
||||
LOGGER.debug("Shutting down (id=%s)", self.id)
|
||||
# Clear any unused session files in upload file manager and media
|
||||
# file manager
|
||||
self._uploaded_file_mgr.remove_session_files(self.id)
|
||||
|
||||
if runtime.exists():
|
||||
runtime.get_instance().media_file_mgr.clear_session_refs(self.id)
|
||||
runtime.get_instance().media_file_mgr.remove_orphaned_files()
|
||||
|
||||
# Shut down the ScriptRunner, if one is active.
|
||||
# self._state must not be set to SHUTDOWN_REQUESTED until
|
||||
# *after* this is called.
|
||||
self.request_script_stop()
|
||||
|
||||
self._state = AppSessionState.SHUTDOWN_REQUESTED
|
||||
|
||||
# Disconnect all file watchers if we haven't already, although we will have
|
||||
# generally already done so by the time we get here.
|
||||
self.disconnect_file_watchers()
|
||||
|
||||
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Enqueue a new ForwardMsg to our browser queue.
|
||||
|
||||
This can be called on both the main thread and a ScriptRunner
|
||||
run thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
The message to enqueue
|
||||
|
||||
"""
|
||||
if not config.get_option("client.displayEnabled"):
|
||||
return
|
||||
|
||||
if self._debug_last_backmsg_id:
|
||||
msg.debug_last_backmsg_id = self._debug_last_backmsg_id
|
||||
|
||||
self._browser_queue.enqueue(msg)
|
||||
if self._message_enqueued_callback:
|
||||
self._message_enqueued_callback()
|
||||
|
||||
def handle_backmsg(self, msg: BackMsg) -> None:
|
||||
"""Process a BackMsg."""
|
||||
try:
|
||||
msg_type = msg.WhichOneof("type")
|
||||
|
||||
if msg_type == "rerun_script":
|
||||
if msg.debug_last_backmsg_id:
|
||||
self._debug_last_backmsg_id = msg.debug_last_backmsg_id
|
||||
|
||||
self._handle_rerun_script_request(msg.rerun_script)
|
||||
elif msg_type == "load_git_info":
|
||||
self._handle_git_information_request()
|
||||
elif msg_type == "clear_cache":
|
||||
self._handle_clear_cache_request()
|
||||
elif msg_type == "set_run_on_save":
|
||||
self._handle_set_run_on_save_request(msg.set_run_on_save)
|
||||
elif msg_type == "stop_script":
|
||||
self._handle_stop_script_request()
|
||||
else:
|
||||
LOGGER.warning('No handler for "%s"', msg_type)
|
||||
|
||||
except Exception as ex:
|
||||
LOGGER.error(ex)
|
||||
self.handle_backmsg_exception(ex)
|
||||
|
||||
def handle_backmsg_exception(self, e: BaseException) -> None:
|
||||
"""Handle an Exception raised while processing a BackMsg from the browser."""
|
||||
# This does a few things:
|
||||
# 1) Clears the current app in the browser.
|
||||
# 2) Marks the current app as "stopped" in the browser.
|
||||
# 3) HACK: Resets any script params that may have been broken (e.g. the
|
||||
# command-line when rerunning with wrong argv[0])
|
||||
|
||||
self._on_scriptrunner_event(
|
||||
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
)
|
||||
self._on_scriptrunner_event(
|
||||
self._scriptrunner,
|
||||
ScriptRunnerEvent.SCRIPT_STARTED,
|
||||
page_script_hash="",
|
||||
)
|
||||
self._on_scriptrunner_event(
|
||||
self._scriptrunner, ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
)
|
||||
|
||||
# Send an Exception message to the frontend.
|
||||
# Because _on_scriptrunner_event does its work in an eventloop callback,
|
||||
# this exception ForwardMsg *must* also be enqueued in a callback,
|
||||
# so that it will be enqueued *after* the various ForwardMsgs that
|
||||
# _on_scriptrunner_event sends.
|
||||
self._event_loop.call_soon_threadsafe(
|
||||
lambda: self._enqueue_forward_msg(self._create_exception_message(e))
|
||||
)
|
||||
|
||||
def request_rerun(self, client_state: Optional[ClientState]) -> None:
|
||||
"""Signal that we're interested in running the script.
|
||||
|
||||
If the script is not already running, it will be started immediately.
|
||||
Otherwise, a rerun will be requested.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client_state : streamlit.proto.ClientState_pb2.ClientState | None
|
||||
The ClientState protobuf to run the script with, or None
|
||||
to use previous client state.
|
||||
|
||||
"""
|
||||
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
|
||||
LOGGER.warning("Discarding rerun request after shutdown")
|
||||
return
|
||||
|
||||
if client_state:
|
||||
rerun_data = RerunData(
|
||||
client_state.query_string,
|
||||
client_state.widget_states,
|
||||
client_state.page_script_hash,
|
||||
client_state.page_name,
|
||||
)
|
||||
else:
|
||||
rerun_data = RerunData()
|
||||
|
||||
if self._scriptrunner is not None:
|
||||
if bool(config.get_option("runner.fastReruns")):
|
||||
# If fastReruns is enabled, we don't send rerun requests to our
|
||||
# existing ScriptRunner. Instead, we tell it to shut down. We'll
|
||||
# then spin up a new ScriptRunner, below, to handle the rerun
|
||||
# immediately.
|
||||
self._scriptrunner.request_stop()
|
||||
self._scriptrunner = None
|
||||
else:
|
||||
# fastReruns is not enabled. Send our ScriptRunner a rerun
|
||||
# request. If the request is accepted, we're done.
|
||||
success = self._scriptrunner.request_rerun(rerun_data)
|
||||
if success:
|
||||
return
|
||||
|
||||
# If we are here, then either we have no ScriptRunner, or our
|
||||
# current ScriptRunner is shutting down and cannot handle a rerun
|
||||
# request - so we'll create and start a new ScriptRunner.
|
||||
self._create_scriptrunner(rerun_data)
|
||||
|
||||
def request_script_stop(self) -> None:
|
||||
"""Request that the scriptrunner stop execution.
|
||||
|
||||
Does nothing if no scriptrunner exists.
|
||||
"""
|
||||
if self._scriptrunner is not None:
|
||||
self._scriptrunner.request_stop()
|
||||
|
||||
def _create_scriptrunner(self, initial_rerun_data: RerunData) -> None:
|
||||
"""Create and run a new ScriptRunner with the given RerunData."""
|
||||
self._scriptrunner = ScriptRunner(
|
||||
session_id=self.id,
|
||||
main_script_path=self._script_data.main_script_path,
|
||||
client_state=self._client_state,
|
||||
session_state=self._session_state,
|
||||
uploaded_file_mgr=self._uploaded_file_mgr,
|
||||
initial_rerun_data=initial_rerun_data,
|
||||
user_info=self._user_info,
|
||||
)
|
||||
self._scriptrunner.on_event.connect(self._on_scriptrunner_event)
|
||||
self._scriptrunner.start()
|
||||
|
||||
@property
|
||||
def session_state(self) -> "SessionState":
|
||||
return self._session_state
|
||||
|
||||
def _should_rerun_on_file_change(self, filepath: str) -> bool:
|
||||
main_script_path = self._script_data.main_script_path
|
||||
pages = source_util.get_pages(main_script_path)
|
||||
|
||||
changed_page_script_hash = next(
|
||||
filter(lambda k: pages[k]["script_path"] == filepath, pages),
|
||||
None,
|
||||
)
|
||||
|
||||
if changed_page_script_hash is not None:
|
||||
current_page_script_hash = self._client_state.page_script_hash
|
||||
return changed_page_script_hash == current_page_script_hash
|
||||
|
||||
return True
|
||||
|
||||
def _on_source_file_changed(self, filepath: Optional[str] = None) -> None:
|
||||
"""One of our source files changed. Schedule a rerun if appropriate."""
|
||||
if filepath is not None and not self._should_rerun_on_file_change(filepath):
|
||||
return
|
||||
|
||||
if self._run_on_save:
|
||||
self.request_rerun(self._client_state)
|
||||
else:
|
||||
self._enqueue_forward_msg(self._create_file_change_message())
|
||||
|
||||
def _on_secrets_file_changed(self, _) -> None:
|
||||
"""Called when `secrets.file_change_listener` emits a Signal."""
|
||||
|
||||
# NOTE: At the time of writing, this function only calls `_on_source_file_changed`.
|
||||
# The reason behind creating this function instead of just passing `_on_source_file_changed`
|
||||
# to `connect` / `disconnect` directly is that every function that is passed to `connect` / `disconnect`
|
||||
# must have at least one argument for `sender` (in this case we don't really care about it, thus `_`),
|
||||
# and introducing an unnecessary argument to `_on_source_file_changed` just for this purpose sounded finicky.
|
||||
self._on_source_file_changed()
|
||||
|
||||
def _on_pages_changed(self, _) -> None:
|
||||
msg = ForwardMsg()
|
||||
_populate_app_pages(msg.pages_changed, self._script_data.main_script_path)
|
||||
self._enqueue_forward_msg(msg)
|
||||
|
||||
def _clear_queue(self) -> None:
|
||||
self._browser_queue.clear()
|
||||
|
||||
def _on_scriptrunner_event(
|
||||
self,
|
||||
sender: Optional[ScriptRunner],
|
||||
event: ScriptRunnerEvent,
|
||||
forward_msg: Optional[ForwardMsg] = None,
|
||||
exception: Optional[BaseException] = None,
|
||||
client_state: Optional[ClientState] = None,
|
||||
page_script_hash: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Called when our ScriptRunner emits an event.
|
||||
|
||||
This is generally called from the sender ScriptRunner's script thread.
|
||||
We forward the event on to _handle_scriptrunner_event_on_event_loop,
|
||||
which will be called on the main thread.
|
||||
"""
|
||||
self._event_loop.call_soon_threadsafe(
|
||||
lambda: self._handle_scriptrunner_event_on_event_loop(
|
||||
sender, event, forward_msg, exception, client_state, page_script_hash
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_scriptrunner_event_on_event_loop(
|
||||
self,
|
||||
sender: Optional[ScriptRunner],
|
||||
event: ScriptRunnerEvent,
|
||||
forward_msg: Optional[ForwardMsg] = None,
|
||||
exception: Optional[BaseException] = None,
|
||||
client_state: Optional[ClientState] = None,
|
||||
page_script_hash: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Handle a ScriptRunner event.
|
||||
|
||||
This function must only be called on our eventloop thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sender : ScriptRunner | None
|
||||
The ScriptRunner that emitted the event. (This may be set to
|
||||
None when called from `handle_backmsg_exception`, if no
|
||||
ScriptRunner was active when the backmsg exception was raised.)
|
||||
|
||||
event : ScriptRunnerEvent
|
||||
The event type.
|
||||
|
||||
forward_msg : ForwardMsg | None
|
||||
The ForwardMsg to send to the frontend. Set only for the
|
||||
ENQUEUE_FORWARD_MSG event.
|
||||
|
||||
exception : BaseException | None
|
||||
An exception thrown during compilation. Set only for the
|
||||
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
|
||||
|
||||
client_state : streamlit.proto.ClientState_pb2.ClientState | None
|
||||
The ScriptRunner's final ClientState. Set only for the
|
||||
SHUTDOWN event.
|
||||
|
||||
page_script_hash : str | None
|
||||
A hash of the script path corresponding to the page currently being
|
||||
run. Set only for the SCRIPT_STARTED event.
|
||||
"""
|
||||
|
||||
assert (
|
||||
self._event_loop == asyncio.get_running_loop()
|
||||
), "This function must only be called on the eventloop thread the AppSession was created on."
|
||||
|
||||
if sender is not self._scriptrunner:
|
||||
# This event was sent by a non-current ScriptRunner; ignore it.
|
||||
# This can happen after sppinng up a new ScriptRunner (to handle a
|
||||
# rerun request, for example) while another ScriptRunner is still
|
||||
# shutting down. The shutting-down ScriptRunner may still
|
||||
# emit events.
|
||||
LOGGER.debug("Ignoring event from non-current ScriptRunner: %s", event)
|
||||
return
|
||||
|
||||
prev_state = self._state
|
||||
|
||||
if event == ScriptRunnerEvent.SCRIPT_STARTED:
|
||||
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
|
||||
self._state = AppSessionState.APP_IS_RUNNING
|
||||
|
||||
assert (
|
||||
page_script_hash is not None
|
||||
), "page_script_hash must be set for the SCRIPT_STARTED event"
|
||||
|
||||
self._clear_queue()
|
||||
self._enqueue_forward_msg(
|
||||
self._create_new_session_message(page_script_hash)
|
||||
)
|
||||
|
||||
elif (
|
||||
event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
or event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR
|
||||
):
|
||||
if self._state != AppSessionState.SHUTDOWN_REQUESTED:
|
||||
self._state = AppSessionState.APP_NOT_RUNNING
|
||||
|
||||
script_succeeded = event == ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
|
||||
script_finished_msg = self._create_script_finished_message(
|
||||
ForwardMsg.FINISHED_SUCCESSFULLY
|
||||
if script_succeeded
|
||||
else ForwardMsg.FINISHED_WITH_COMPILE_ERROR
|
||||
)
|
||||
self._enqueue_forward_msg(script_finished_msg)
|
||||
|
||||
self._debug_last_backmsg_id = None
|
||||
|
||||
if script_succeeded:
|
||||
# The script completed successfully: update our
|
||||
# LocalSourcesWatcher to account for any source code changes
|
||||
# that change which modules should be watched.
|
||||
if self._local_sources_watcher:
|
||||
self._local_sources_watcher.update_watched_modules()
|
||||
else:
|
||||
# The script didn't complete successfully: send the exception
|
||||
# to the frontend.
|
||||
assert (
|
||||
exception is not None
|
||||
), "exception must be set for the SCRIPT_STOPPED_WITH_COMPILE_ERROR event"
|
||||
msg = ForwardMsg()
|
||||
exception_utils.marshall(
|
||||
msg.session_event.script_compilation_exception, exception
|
||||
)
|
||||
self._enqueue_forward_msg(msg)
|
||||
|
||||
elif event == ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN:
|
||||
script_finished_msg = self._create_script_finished_message(
|
||||
ForwardMsg.FINISHED_EARLY_FOR_RERUN
|
||||
)
|
||||
self._enqueue_forward_msg(script_finished_msg)
|
||||
if self._local_sources_watcher:
|
||||
self._local_sources_watcher.update_watched_modules()
|
||||
|
||||
elif event == ScriptRunnerEvent.SHUTDOWN:
|
||||
assert (
|
||||
client_state is not None
|
||||
), "client_state must be set for the SHUTDOWN event"
|
||||
|
||||
if self._state == AppSessionState.SHUTDOWN_REQUESTED:
|
||||
# Only clear media files if the script is done running AND the
|
||||
# session is actually shutting down.
|
||||
runtime.get_instance().media_file_mgr.clear_session_refs(self.id)
|
||||
|
||||
self._client_state = client_state
|
||||
self._scriptrunner = None
|
||||
|
||||
elif event == ScriptRunnerEvent.ENQUEUE_FORWARD_MSG:
|
||||
assert (
|
||||
forward_msg is not None
|
||||
), "null forward_msg in ENQUEUE_FORWARD_MSG event"
|
||||
self._enqueue_forward_msg(forward_msg)
|
||||
|
||||
# Send a message if our run state changed
|
||||
app_was_running = prev_state == AppSessionState.APP_IS_RUNNING
|
||||
app_is_running = self._state == AppSessionState.APP_IS_RUNNING
|
||||
if app_is_running != app_was_running:
|
||||
self._enqueue_forward_msg(self._create_session_status_changed_message())
|
||||
|
||||
def _create_session_status_changed_message(self) -> ForwardMsg:
|
||||
"""Create and return a session_status_changed ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.session_status_changed.run_on_save = self._run_on_save
|
||||
msg.session_status_changed.script_is_running = (
|
||||
self._state == AppSessionState.APP_IS_RUNNING
|
||||
)
|
||||
return msg
|
||||
|
||||
def _create_file_change_message(self) -> ForwardMsg:
|
||||
"""Create and return a 'script_changed_on_disk' ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.session_event.script_changed_on_disk = True
|
||||
return msg
|
||||
|
||||
def _create_new_session_message(self, page_script_hash: str) -> ForwardMsg:
|
||||
"""Create and return a new_session ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
|
||||
msg.new_session.script_run_id = _generate_scriptrun_id()
|
||||
msg.new_session.name = self._script_data.name
|
||||
msg.new_session.main_script_path = self._script_data.main_script_path
|
||||
msg.new_session.page_script_hash = page_script_hash
|
||||
|
||||
_populate_app_pages(msg.new_session, self._script_data.main_script_path)
|
||||
_populate_config_msg(msg.new_session.config)
|
||||
_populate_theme_msg(msg.new_session.custom_theme)
|
||||
|
||||
# Immutable session data. We send this every time a new session is
|
||||
# started, to avoid having to track whether the client has already
|
||||
# received it. It does not change from run to run; it's up to the
|
||||
# to perform one-time initialization only once.
|
||||
imsg = msg.new_session.initialize
|
||||
|
||||
_populate_user_info_msg(imsg.user_info)
|
||||
|
||||
imsg.environment_info.streamlit_version = STREAMLIT_VERSION_STRING
|
||||
imsg.environment_info.python_version = ".".join(map(str, sys.version_info))
|
||||
|
||||
imsg.session_status.run_on_save = self._run_on_save
|
||||
imsg.session_status.script_is_running = (
|
||||
self._state == AppSessionState.APP_IS_RUNNING
|
||||
)
|
||||
|
||||
imsg.command_line = self._script_data.command_line
|
||||
imsg.session_id = self.id
|
||||
|
||||
return msg
|
||||
|
||||
def _create_script_finished_message(
|
||||
self, status: "ForwardMsg.ScriptFinishedStatus.ValueType"
|
||||
) -> ForwardMsg:
|
||||
"""Create and return a script_finished ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.script_finished = status
|
||||
return msg
|
||||
|
||||
def _create_exception_message(self, e: BaseException) -> ForwardMsg:
|
||||
"""Create and return an Exception ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
exception_utils.marshall(msg.delta.new_element.exception, e)
|
||||
return msg
|
||||
|
||||
def _handle_git_information_request(self) -> None:
|
||||
msg = ForwardMsg()
|
||||
|
||||
try:
|
||||
from streamlit.git_util import GitRepo
|
||||
|
||||
repo = GitRepo(self._script_data.main_script_path)
|
||||
|
||||
repo_info = repo.get_repo_info()
|
||||
if repo_info is None:
|
||||
return
|
||||
|
||||
repository_name, branch, module = repo_info
|
||||
|
||||
msg.git_info_changed.repository = repository_name
|
||||
msg.git_info_changed.branch = branch
|
||||
msg.git_info_changed.module = module
|
||||
|
||||
msg.git_info_changed.untracked_files[:] = repo.untracked_files
|
||||
msg.git_info_changed.uncommitted_files[:] = repo.uncommitted_files
|
||||
|
||||
if repo.is_head_detached:
|
||||
msg.git_info_changed.state = GitInfo.GitStates.HEAD_DETACHED
|
||||
elif len(repo.ahead_commits) > 0:
|
||||
msg.git_info_changed.state = GitInfo.GitStates.AHEAD_OF_REMOTE
|
||||
else:
|
||||
msg.git_info_changed.state = GitInfo.GitStates.DEFAULT
|
||||
|
||||
self._enqueue_forward_msg(msg)
|
||||
except Exception as ex:
|
||||
# Users may never even install Git in the first place, so this
|
||||
# error requires no action. It can be useful for debugging.
|
||||
LOGGER.debug("Obtaining Git information produced an error", exc_info=ex)
|
||||
|
||||
def _handle_rerun_script_request(
|
||||
self, client_state: Optional[ClientState] = None
|
||||
) -> None:
|
||||
"""Tell the ScriptRunner to re-run its script.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client_state : streamlit.proto.ClientState_pb2.ClientState | None
|
||||
The ClientState protobuf to run the script with, or None
|
||||
to use previous client state.
|
||||
|
||||
"""
|
||||
self.request_rerun(client_state)
|
||||
|
||||
def _handle_stop_script_request(self) -> None:
|
||||
"""Tell the ScriptRunner to stop running its script."""
|
||||
self.request_script_stop()
|
||||
|
||||
def _handle_clear_cache_request(self) -> None:
|
||||
"""Clear this app's cache.
|
||||
|
||||
Because this cache is global, it will be cleared for all users.
|
||||
|
||||
"""
|
||||
legacy_caching.clear_cache()
|
||||
caching.cache_data.clear()
|
||||
caching.cache_resource.clear()
|
||||
self._session_state.clear()
|
||||
|
||||
def _handle_set_run_on_save_request(self, new_value: bool) -> None:
|
||||
"""Change our run_on_save flag to the given value.
|
||||
|
||||
The browser will be notified of the change.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_value : bool
|
||||
New run_on_save value
|
||||
|
||||
"""
|
||||
self._run_on_save = new_value
|
||||
self._enqueue_forward_msg(self._create_session_status_changed_message())
|
||||
|
||||
|
||||
def _populate_config_msg(msg: Config) -> None:
|
||||
msg.gather_usage_stats = config.get_option("browser.gatherUsageStats")
|
||||
msg.max_cached_message_age = config.get_option("global.maxCachedMessageAge")
|
||||
msg.mapbox_token = config.get_option("mapbox.token")
|
||||
msg.allow_run_on_save = config.get_option("server.allowRunOnSave")
|
||||
msg.hide_top_bar = config.get_option("ui.hideTopBar")
|
||||
msg.hide_sidebar_nav = config.get_option("ui.hideSidebarNav")
|
||||
|
||||
|
||||
def _populate_theme_msg(msg: CustomThemeConfig) -> None:
|
||||
enum_encoded_options = {"base", "font"}
|
||||
theme_opts = config.get_options_for_section("theme")
|
||||
|
||||
if not any(theme_opts.values()):
|
||||
return
|
||||
|
||||
for option_name, option_val in theme_opts.items():
|
||||
if option_name not in enum_encoded_options and option_val is not None:
|
||||
setattr(msg, to_snake_case(option_name), option_val)
|
||||
|
||||
# NOTE: If unset, base and font will default to the protobuf enum zero
|
||||
# values, which are BaseTheme.LIGHT and FontFamily.SANS_SERIF,
|
||||
# respectively. This is why we both don't handle the cases explicitly and
|
||||
# also only log a warning when receiving invalid base/font options.
|
||||
base_map = {
|
||||
"light": msg.BaseTheme.LIGHT,
|
||||
"dark": msg.BaseTheme.DARK,
|
||||
}
|
||||
base = theme_opts["base"]
|
||||
if base is not None:
|
||||
if base not in base_map:
|
||||
LOGGER.warning(
|
||||
f'"{base}" is an invalid value for theme.base.'
|
||||
f" Allowed values include {list(base_map.keys())}."
|
||||
' Setting theme.base to "light".'
|
||||
)
|
||||
else:
|
||||
msg.base = base_map[base]
|
||||
|
||||
font_map = {
|
||||
"sans serif": msg.FontFamily.SANS_SERIF,
|
||||
"serif": msg.FontFamily.SERIF,
|
||||
"monospace": msg.FontFamily.MONOSPACE,
|
||||
}
|
||||
font = theme_opts["font"]
|
||||
if font is not None:
|
||||
if font not in font_map:
|
||||
LOGGER.warning(
|
||||
f'"{font}" is an invalid value for theme.font.'
|
||||
f" Allowed values include {list(font_map.keys())}."
|
||||
' Setting theme.font to "sans serif".'
|
||||
)
|
||||
else:
|
||||
msg.font = font_map[font]
|
||||
|
||||
|
||||
def _populate_user_info_msg(msg: UserInfo) -> None:
|
||||
msg.installation_id = Installation.instance().installation_id
|
||||
msg.installation_id_v3 = Installation.instance().installation_id_v3
|
||||
if Credentials.get_current().activation:
|
||||
msg.email = Credentials.get_current().activation.email
|
||||
else:
|
||||
msg.email = ""
|
||||
|
||||
|
||||
def _populate_app_pages(
|
||||
msg: Union[NewSession, PagesChanged], main_script_path: str
|
||||
) -> None:
|
||||
for page_script_hash, page_info in source_util.get_pages(main_script_path).items():
|
||||
page_proto = msg.app_pages.add()
|
||||
|
||||
page_proto.page_script_hash = page_script_hash
|
||||
page_proto.page_name = page_info["page_name"]
|
||||
page_proto.icon = page_info["icon"]
|
||||
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
from typing import Any, Iterator, Union
|
||||
|
||||
from google.protobuf.message import Message
|
||||
|
||||
from streamlit.proto.Block_pb2 import Block
|
||||
from streamlit.runtime.caching.cache_data_api import (
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX,
|
||||
CacheDataAPI,
|
||||
_data_caches,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_errors import CACHE_DOCS_URL as CACHE_DOCS_URL
|
||||
from streamlit.runtime.caching.cache_resource_api import (
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
|
||||
CacheResourceAPI,
|
||||
_resource_caches,
|
||||
)
|
||||
from streamlit.runtime.state.common import WidgetMetadata
|
||||
|
||||
|
||||
def save_element_message(
|
||||
delta_type: str,
|
||||
element_proto: Message,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
"""Save the message for an element to a thread-local callstack, so it can
|
||||
be used later to replay the element when a cache-decorated function's
|
||||
execution is skipped.
|
||||
"""
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
|
||||
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
|
||||
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
|
||||
|
||||
def save_block_message(
|
||||
block_proto: Block,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
"""Save the message for a block to a thread-local callstack, so it can
|
||||
be used later to replay the block when a cache-decorated function's
|
||||
execution is skipped.
|
||||
"""
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
|
||||
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
|
||||
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
|
||||
)
|
||||
|
||||
|
||||
def save_widget_metadata(metadata: WidgetMetadata[Any]) -> None:
|
||||
"""Save a widget's metadata to a thread-local callstack, so the widget
|
||||
can be registered again when that widget is replayed.
|
||||
"""
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
|
||||
|
||||
|
||||
def save_media_data(
|
||||
image_data: Union[bytes, str], mimetype: str, image_id: str
|
||||
) -> None:
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
|
||||
|
||||
|
||||
def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
|
||||
dg, st_func_name
|
||||
)
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
|
||||
dg, st_func_name
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cached_st_function_warning() -> Iterator[None]:
|
||||
with CACHE_DATA_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning(), CACHE_RESOURCE_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning():
|
||||
yield
|
||||
|
||||
|
||||
# Explicitly export public symbols
|
||||
from streamlit.runtime.caching.cache_data_api import (
|
||||
get_data_cache_stats_provider as get_data_cache_stats_provider,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_resource_api import (
|
||||
get_resource_cache_stats_provider as get_resource_cache_stats_provider,
|
||||
)
|
||||
|
||||
# Create and export public API singletons.
|
||||
cache_data = CacheDataAPI(decorator_metric_name="cache_data")
|
||||
cache_resource = CacheResourceAPI(decorator_metric_name="cache_resource")
|
||||
|
||||
# Deprecated singletons
|
||||
_MEMO_WARNING = (
|
||||
f"`st.experimental_memo` is deprecated. Please use the new command `st.cache_data` instead, "
|
||||
f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
|
||||
)
|
||||
|
||||
experimental_memo = CacheDataAPI(
|
||||
decorator_metric_name="experimental_memo", deprecation_warning=_MEMO_WARNING
|
||||
)
|
||||
|
||||
_SINGLETON_WARNING = (
|
||||
f"`st.experimental_singleton` is deprecated. Please use the new command `st.cache_resource` instead, "
|
||||
f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
|
||||
)
|
||||
|
||||
experimental_singleton = CacheResourceAPI(
|
||||
decorator_metric_name="experimental_singleton",
|
||||
deprecation_warning=_SINGLETON_WARNING,
|
||||
)
|
||||
@@ -0,0 +1,678 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""@st.cache_data: pickle-based caching"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pickle
|
||||
import threading
|
||||
import types
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, TypeVar, Union, cast, overload
|
||||
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import runtime
|
||||
from streamlit.deprecation_util import show_deprecation_warning
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.cache_utils import (
|
||||
Cache,
|
||||
CachedFuncInfo,
|
||||
make_cached_func_wrapper,
|
||||
ttl_to_seconds,
|
||||
)
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
ElementMsgData,
|
||||
MsgData,
|
||||
MultiCacheResults,
|
||||
)
|
||||
from streamlit.runtime.caching.storage import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageError,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
InvalidCacheStorageContext,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.dummy_cache_storage import (
|
||||
MemoryCacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)
|
||||
|
||||
# The cache persistence options we support: "disk" or None
|
||||
CachePersistType: TypeAlias = Union[Literal["disk"], None]
|
||||
|
||||
|
||||
class CachedDataFuncInfo(CachedFuncInfo):
|
||||
"""Implements the CachedFuncInfo interface for @st.cache_data"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: types.FunctionType,
|
||||
show_spinner: bool | str,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl: float | timedelta | None,
|
||||
allow_widgets: bool,
|
||||
):
|
||||
super().__init__(
|
||||
func,
|
||||
show_spinner=show_spinner,
|
||||
allow_widgets=allow_widgets,
|
||||
)
|
||||
self.persist = persist
|
||||
self.max_entries = max_entries
|
||||
self.ttl = ttl
|
||||
|
||||
self.validate_params()
|
||||
|
||||
@property
|
||||
def cache_type(self) -> CacheType:
|
||||
return CacheType.DATA
|
||||
|
||||
@property
|
||||
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
|
||||
return CACHE_DATA_MESSAGE_REPLAY_CTX
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""A human-readable name for the cached function"""
|
||||
return f"{self.func.__module__}.{self.func.__qualname__}"
|
||||
|
||||
def get_function_cache(self, function_key: str) -> Cache:
|
||||
return _data_caches.get_cache(
|
||||
key=function_key,
|
||||
persist=self.persist,
|
||||
max_entries=self.max_entries,
|
||||
ttl=self.ttl,
|
||||
display_name=self.display_name,
|
||||
allow_widgets=self.allow_widgets,
|
||||
)
|
||||
|
||||
def validate_params(self) -> None:
|
||||
"""
|
||||
Validate the params passed to @st.cache_data are compatible with cache storage
|
||||
|
||||
When called, this method could log warnings if cache params are invalid
|
||||
for current storage.
|
||||
"""
|
||||
_data_caches.validate_cache_params(
|
||||
function_name=self.func.__name__,
|
||||
persist=self.persist,
|
||||
max_entries=self.max_entries,
|
||||
ttl=self.ttl,
|
||||
)
|
||||
|
||||
|
||||
class DataCaches(CacheStatsProvider):
|
||||
"""Manages all DataCache instances"""
|
||||
|
||||
def __init__(self):
|
||||
self._caches_lock = threading.Lock()
|
||||
self._function_caches: dict[str, DataCache] = {}
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key: str,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl: int | float | timedelta | None,
|
||||
display_name: str,
|
||||
allow_widgets: bool,
|
||||
) -> DataCache:
|
||||
"""Return the mem cache for the given key.
|
||||
|
||||
If it doesn't exist, create a new one with the given params.
|
||||
"""
|
||||
|
||||
ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
|
||||
|
||||
# Get the existing cache, if it exists, and validate that its params
|
||||
# haven't changed.
|
||||
with self._caches_lock:
|
||||
cache = self._function_caches.get(key)
|
||||
if (
|
||||
cache is not None
|
||||
and cache.ttl_seconds == ttl_seconds
|
||||
and cache.max_entries == max_entries
|
||||
and cache.persist == persist
|
||||
):
|
||||
return cache
|
||||
|
||||
# Close the existing cache's storage, if it exists.
|
||||
if cache is not None:
|
||||
_LOGGER.debug(
|
||||
"Closing existing DataCache storage "
|
||||
"(key=%s, persist=%s, max_entries=%s, ttl=%s) "
|
||||
"before creating new one with different params",
|
||||
key,
|
||||
persist,
|
||||
max_entries,
|
||||
ttl,
|
||||
)
|
||||
cache.storage.close()
|
||||
|
||||
# Create a new cache object and put it in our dict
|
||||
_LOGGER.debug(
|
||||
"Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
|
||||
key,
|
||||
persist,
|
||||
max_entries,
|
||||
ttl,
|
||||
)
|
||||
|
||||
cache_context = self.create_cache_storage_context(
|
||||
function_key=key,
|
||||
function_name=display_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
)
|
||||
cache_storage_manager = self.get_storage_manager()
|
||||
storage = cache_storage_manager.create(cache_context)
|
||||
|
||||
cache = DataCache(
|
||||
key=key,
|
||||
storage=storage,
|
||||
persist=persist,
|
||||
max_entries=max_entries,
|
||||
ttl_seconds=ttl_seconds,
|
||||
display_name=display_name,
|
||||
allow_widgets=allow_widgets,
|
||||
)
|
||||
self._function_caches[key] = cache
|
||||
return cache
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Clear all in-memory and on-disk caches."""
|
||||
with self._caches_lock:
|
||||
try:
|
||||
# try to remove in optimal way if such ability provided by
|
||||
# storage manager clear_all method;
|
||||
# if not implemented, fallback to remove all
|
||||
# available storages one by one
|
||||
self.get_storage_manager().clear_all()
|
||||
except NotImplementedError:
|
||||
for data_cache in self._function_caches.values():
|
||||
data_cache.clear()
|
||||
data_cache.storage.close()
|
||||
self._function_caches = {}
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
with self._caches_lock:
|
||||
# Shallow-clone our caches. We don't want to hold the global
|
||||
# lock during stats-gathering.
|
||||
function_caches = self._function_caches.copy()
|
||||
|
||||
stats: list[CacheStat] = []
|
||||
for cache in function_caches.values():
|
||||
stats.extend(cache.get_stats())
|
||||
return stats
|
||||
|
||||
def validate_cache_params(
|
||||
self,
|
||||
function_name: str,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl: int | float | timedelta | None,
|
||||
) -> None:
|
||||
"""Validate that the cache params are valid for given storage.
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidCacheStorageContext
|
||||
Raised if the cache storage manager is not able to work with provided
|
||||
CacheStorageContext.
|
||||
"""
|
||||
|
||||
ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
|
||||
|
||||
cache_context = self.create_cache_storage_context(
|
||||
function_key="DUMMY_KEY",
|
||||
function_name=function_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
)
|
||||
try:
|
||||
self.get_storage_manager().check_context(cache_context)
|
||||
except InvalidCacheStorageContext as e:
|
||||
_LOGGER.error(
|
||||
"Cache params for function %s are incompatible with current "
|
||||
"cache storage manager: %s",
|
||||
function_name,
|
||||
e,
|
||||
)
|
||||
raise
|
||||
|
||||
def create_cache_storage_context(
|
||||
self,
|
||||
function_key: str,
|
||||
function_name: str,
|
||||
persist: CachePersistType,
|
||||
ttl_seconds: float | None,
|
||||
max_entries: int | None,
|
||||
) -> CacheStorageContext:
|
||||
return CacheStorageContext(
|
||||
function_key=function_key,
|
||||
function_display_name=function_name,
|
||||
ttl_seconds=ttl_seconds,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
)
|
||||
|
||||
def get_storage_manager(self) -> CacheStorageManager:
|
||||
if runtime.exists():
|
||||
return runtime.get_instance().cache_storage_manager
|
||||
else:
|
||||
# When running in "raw mode", we can't access the CacheStorageManager,
|
||||
# so we're falling back to InMemoryCache.
|
||||
_LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
|
||||
return MemoryCacheStorageManager()
|
||||
|
||||
|
||||
# Singleton DataCaches instance
|
||||
_data_caches = DataCaches()
|
||||
|
||||
|
||||
def get_data_cache_stats_provider() -> CacheStatsProvider:
|
||||
"""Return the StatsProvider for all @st.cache_data functions."""
|
||||
return _data_caches
|
||||
|
||||
|
||||
class CacheDataAPI:
|
||||
"""Implements the public st.cache_data API: the @st.cache_data decorator, and
|
||||
st.cache_data.clear().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, decorator_metric_name: str, deprecation_warning: str | None = None
|
||||
):
|
||||
"""Create a CacheDataAPI instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decorator_metric_name
|
||||
The metric name to record for decorator usage. `@st.experimental_memo` is
|
||||
deprecated, but we're still supporting it and tracking its usage separately
|
||||
from `@st.cache_data`.
|
||||
|
||||
deprecation_warning
|
||||
An optional deprecation warning to show when the API is accessed.
|
||||
"""
|
||||
|
||||
# Parameterize the decorator metric name.
|
||||
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
|
||||
self._decorator = gather_metrics( # type: ignore
|
||||
decorator_metric_name, self._decorator
|
||||
)
|
||||
self._deprecation_warning = deprecation_warning
|
||||
|
||||
# Type-annotate the decorator function.
|
||||
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
# Bare decorator usage
|
||||
@overload
|
||||
def __call__(self, func: F) -> F:
|
||||
...
|
||||
|
||||
# Decorator with arguments
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
ttl: float | timedelta | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
persist: CachePersistType | bool = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
) -> Callable[[F], F]:
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
func: F | None = None,
|
||||
*,
|
||||
ttl: float | timedelta | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
persist: CachePersistType | bool = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
):
|
||||
return self._decorator(
|
||||
func,
|
||||
ttl=ttl,
|
||||
max_entries=max_entries,
|
||||
persist=persist,
|
||||
show_spinner=show_spinner,
|
||||
experimental_allow_widgets=experimental_allow_widgets,
|
||||
)
|
||||
|
||||
def _decorator(
|
||||
self,
|
||||
func: F | None = None,
|
||||
*,
|
||||
ttl: float | timedelta | None,
|
||||
max_entries: int | None,
|
||||
show_spinner: bool | str,
|
||||
persist: CachePersistType | bool,
|
||||
experimental_allow_widgets: bool,
|
||||
):
|
||||
"""Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).
|
||||
|
||||
Cached objects are stored in "pickled" form, which means that the return
|
||||
value of a cached function must be pickleable. Each caller of the cached
|
||||
function gets its own copy of the cached data.
|
||||
|
||||
You can clear a function's cache with ``func.clear()`` or clear the entire
|
||||
cache with ``st.cache_data.clear()``.
|
||||
|
||||
To cache global resources, use ``st.cache_resource`` instead. Learn more
|
||||
about caching at https://docs.streamlit.io/library/advanced-features/caching.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to cache. Streamlit hashes the function's source code.
|
||||
|
||||
ttl : float or timedelta or None
|
||||
The maximum number of seconds to keep an entry in the cache, or
|
||||
None if cache entries should not expire. The default is None.
|
||||
Note that ttl is incompatible with ``persist="disk"`` - ``ttl`` will be
|
||||
ignored if ``persist`` is specified.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to keep in the cache, or None
|
||||
for an unbounded cache. (When a new entry is added to a full cache,
|
||||
the oldest cached entry will be removed.) The default is None.
|
||||
|
||||
show_spinner : boolean or string
|
||||
Enable the spinner. Default is True to show a spinner when there is
|
||||
a "cache miss" and the cached data is being created. If string,
|
||||
value of show_spinner param will be used for spinner text.
|
||||
|
||||
persist : str or boolean or None
|
||||
Optional location to persist cached data to. Passing "disk" (or True)
|
||||
will persist the cached data to the local disk. None (or False) will disable
|
||||
persistence. The default is None.
|
||||
|
||||
experimental_allow_widgets : boolean
|
||||
Allow widgets to be used in the cached function. Defaults to False.
|
||||
Support for widgets in cached functions is currently experimental.
|
||||
Setting this parameter to True may lead to excessive memory use since the
|
||||
widget value is treated as an additional input parameter to the cache.
|
||||
We may remove support for this option at any time without notice.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
...
|
||||
>>> d1 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> d2 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value. This means that now the data in d1 is the same as in d2.
|
||||
>>>
|
||||
>>> d3 = fetch_and_clean_data(DATA_URL_2)
|
||||
>>> # This is a different URL, so the function executes.
|
||||
|
||||
To set the ``persist`` parameter, use this command as follows:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data(persist="disk")
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
|
||||
By default, all parameters to a cached function must be hashable.
|
||||
Any parameter whose name begins with ``_`` will not be hashed. You can use
|
||||
this as an "escape hatch" for parameters that are not hashable:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
... def fetch_and_clean_data(_db_connection, num_rows):
|
||||
... # Fetch data from _db_connection here, and then clean it up.
|
||||
... return data
|
||||
...
|
||||
>>> connection = make_database_connection()
|
||||
>>> d1 = fetch_and_clean_data(connection, num_rows=10)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> another_connection = make_database_connection()
|
||||
>>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value - even though the _database_connection parameter was different
|
||||
>>> # in both calls.
|
||||
|
||||
A cached function's cache can be procedurally cleared:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_data
|
||||
... def fetch_and_clean_data(_db_connection, num_rows):
|
||||
... # Fetch data from _db_connection here, and then clean it up.
|
||||
... return data
|
||||
...
|
||||
>>> fetch_and_clean_data.clear()
|
||||
>>> # Clear all cached entries for this function.
|
||||
|
||||
"""
|
||||
|
||||
# Parse our persist value into a string
|
||||
persist_string: CachePersistType
|
||||
if persist is True:
|
||||
persist_string = "disk"
|
||||
elif persist is False:
|
||||
persist_string = None
|
||||
else:
|
||||
persist_string = persist
|
||||
|
||||
if persist_string not in (None, "disk"):
|
||||
# We'll eventually have more persist options.
|
||||
raise StreamlitAPIException(
|
||||
f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
|
||||
)
|
||||
|
||||
self._maybe_show_deprecation_warning()
|
||||
|
||||
def wrapper(f):
|
||||
return make_cached_func_wrapper(
|
||||
CachedDataFuncInfo(
|
||||
func=f,
|
||||
persist=persist_string,
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
allow_widgets=experimental_allow_widgets,
|
||||
)
|
||||
)
|
||||
|
||||
if func is None:
|
||||
return wrapper
|
||||
|
||||
return make_cached_func_wrapper(
|
||||
CachedDataFuncInfo(
|
||||
func=cast(types.FunctionType, func),
|
||||
persist=persist_string,
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
allow_widgets=experimental_allow_widgets,
|
||||
)
|
||||
)
|
||||
|
||||
@gather_metrics("clear_data_caches")
|
||||
def clear(self) -> None:
|
||||
"""Clear all in-memory and on-disk data caches."""
|
||||
self._maybe_show_deprecation_warning()
|
||||
_data_caches.clear_all()
|
||||
|
||||
def _maybe_show_deprecation_warning(self):
|
||||
"""If the API is being accessed with the deprecated `st.experimental_memo` name,
|
||||
show a deprecation warning.
|
||||
"""
|
||||
if self._deprecation_warning is not None:
|
||||
show_deprecation_warning(self._deprecation_warning)
|
||||
|
||||
|
||||
class DataCache(Cache):
|
||||
"""Manages cached values for a single st.cache_data function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
storage: CacheStorage,
|
||||
persist: CachePersistType,
|
||||
max_entries: int | None,
|
||||
ttl_seconds: float | None,
|
||||
display_name: str,
|
||||
allow_widgets: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.display_name = display_name
|
||||
self.storage = storage
|
||||
self.ttl_seconds = ttl_seconds
|
||||
self.max_entries = max_entries
|
||||
self.persist = persist
|
||||
self.allow_widgets = allow_widgets
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
if isinstance(self.storage, CacheStatsProvider):
|
||||
return self.storage.get_stats()
|
||||
return []
|
||||
|
||||
def read_result(self, key: str) -> CachedResult:
|
||||
"""Read a value and messages from the cache. Raise `CacheKeyNotFoundError`
|
||||
if the value doesn't exist, and `CacheError` if the value exists but can't
|
||||
be unpickled.
|
||||
"""
|
||||
try:
|
||||
pickled_entry = self.storage.get(key)
|
||||
except CacheStorageKeyNotFoundError as e:
|
||||
raise CacheKeyNotFoundError(str(e)) from e
|
||||
except CacheStorageError as e:
|
||||
raise CacheError(str(e)) from e
|
||||
|
||||
try:
|
||||
entry = pickle.loads(pickled_entry)
|
||||
if not isinstance(entry, MultiCacheResults):
|
||||
# Loaded an old cache file format, remove it and let the caller
|
||||
# rerun the function.
|
||||
self.storage.delete(key)
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if not ctx:
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
widget_key = entry.get_current_widget_key(ctx, CacheType.DATA)
|
||||
if widget_key in entry.results:
|
||||
return entry.results[widget_key]
|
||||
else:
|
||||
raise CacheKeyNotFoundError()
|
||||
except pickle.UnpicklingError as exc:
|
||||
raise CacheError(f"Failed to unpickle {key}") from exc
|
||||
|
||||
@gather_metrics("_cache_data_object")
|
||||
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
|
||||
"""Write a value and associated messages to the cache.
|
||||
The value must be pickleable.
|
||||
"""
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
|
||||
main_id = st._main.id
|
||||
sidebar_id = st.sidebar.id
|
||||
|
||||
if self.allow_widgets:
|
||||
widgets = {
|
||||
msg.widget_metadata.widget_id
|
||||
for msg in messages
|
||||
if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
|
||||
}
|
||||
else:
|
||||
widgets = set()
|
||||
|
||||
multi_cache_results: MultiCacheResults | None = None
|
||||
|
||||
# Try to find the entry in cache storage, falling back to a new result instance.
|
||||
try:
|
||||
multi_cache_results = self._read_multi_results_from_storage(key)
|
||||
except (CacheKeyNotFoundError, pickle.UnpicklingError):
|
||||
pass
|
||||
|
||||
if multi_cache_results is None:
|
||||
multi_cache_results = MultiCacheResults(widget_ids=widgets, results={})
|
||||
multi_cache_results.widget_ids.update(widgets)
|
||||
widget_key = multi_cache_results.get_current_widget_key(ctx, CacheType.DATA)
|
||||
|
||||
result = CachedResult(value, messages, main_id, sidebar_id)
|
||||
multi_cache_results.results[widget_key] = result
|
||||
|
||||
try:
|
||||
pickled_entry = pickle.dumps(multi_cache_results)
|
||||
except (pickle.PicklingError, TypeError) as exc:
|
||||
raise CacheError(f"Failed to pickle {key}") from exc
|
||||
|
||||
self.storage.set(key, pickled_entry)
|
||||
|
||||
def _clear(self) -> None:
|
||||
self.storage.clear()
|
||||
|
||||
def _read_multi_results_from_storage(self, key: str) -> MultiCacheResults:
|
||||
"""Look up the results from storage and ensure it has the right type.
|
||||
|
||||
Raises a `CacheKeyNotFoundError` if the key has no entry, or if the
|
||||
entry is malformed.
|
||||
"""
|
||||
try:
|
||||
pickled = self.storage.get(key)
|
||||
except CacheStorageKeyNotFoundError as e:
|
||||
raise CacheKeyNotFoundError(str(e)) from e
|
||||
|
||||
maybe_results = pickle.loads(pickled)
|
||||
|
||||
if isinstance(maybe_results, MultiCacheResults):
|
||||
return maybe_results
|
||||
else:
|
||||
self.storage.delete(key)
|
||||
raise CacheKeyNotFoundError()
|
||||
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import types
|
||||
from typing import Any, Optional
|
||||
|
||||
from streamlit import type_util
|
||||
from streamlit.errors import (
|
||||
MarkdownFormattedException,
|
||||
StreamlitAPIException,
|
||||
StreamlitAPIWarning,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
|
||||
|
||||
CACHE_DOCS_URL = "https://docs.streamlit.io/library/advanced-features/caching"
|
||||
|
||||
|
||||
def get_cached_func_name_md(func: Any) -> str:
|
||||
"""Get markdown representation of the function name."""
|
||||
if hasattr(func, "__name__"):
|
||||
return "`%s()`" % func.__name__
|
||||
elif hasattr(type(func), "__name__"):
|
||||
return f"`{type(func).__name__}`"
|
||||
return f"`{type(func)}`"
|
||||
|
||||
|
||||
def get_return_value_type(return_value: Any) -> str:
|
||||
if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
|
||||
return f"`{return_value.__module__}.{type(return_value).__name__}`"
|
||||
return get_cached_func_name_md(return_value)
|
||||
|
||||
|
||||
class UnhashableTypeError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnhashableParamError(StreamlitAPIException):
|
||||
def __init__(
|
||||
self,
|
||||
cache_type: CacheType,
|
||||
func: types.FunctionType,
|
||||
arg_name: Optional[str],
|
||||
arg_value: Any,
|
||||
orig_exc: BaseException,
|
||||
):
|
||||
msg = self._create_message(cache_type, func, arg_name, arg_value)
|
||||
super().__init__(msg)
|
||||
self.with_traceback(orig_exc.__traceback__)
|
||||
|
||||
@staticmethod
|
||||
def _create_message(
|
||||
cache_type: CacheType,
|
||||
func: types.FunctionType,
|
||||
arg_name: Optional[str],
|
||||
arg_value: Any,
|
||||
) -> str:
|
||||
arg_name_str = arg_name if arg_name is not None else "(unnamed)"
|
||||
arg_type = type_util.get_fqn_type(arg_value)
|
||||
func_name = func.__name__
|
||||
arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"
|
||||
|
||||
return (
|
||||
f"""
|
||||
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
|
||||
|
||||
To address this, you can tell Streamlit not to hash this argument by adding a
|
||||
leading underscore to the argument's name in the function signature:
|
||||
|
||||
```
|
||||
@st.{get_decorator_api_name(cache_type)}
|
||||
def {func_name}({arg_replacement_name}, ...):
|
||||
...
|
||||
```
|
||||
"""
|
||||
).strip("\n")
|
||||
|
||||
|
||||
class CacheKeyNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CacheError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CachedStFunctionWarning(StreamlitAPIWarning):
|
||||
def __init__(
|
||||
self,
|
||||
cache_type: CacheType,
|
||||
st_func_name: str,
|
||||
cached_func: types.FunctionType,
|
||||
):
|
||||
args = {
|
||||
"st_func_name": f"`st.{st_func_name}()`",
|
||||
"func_name": self._get_cached_func_name_md(cached_func),
|
||||
"decorator_name": get_decorator_api_name(cache_type),
|
||||
}
|
||||
|
||||
msg = (
|
||||
"""
|
||||
Your script uses %(st_func_name)s to write to your Streamlit app from within
|
||||
some cached code at %(func_name)s. This code will only be called when we detect
|
||||
a cache "miss", which can lead to unexpected results.
|
||||
|
||||
How to fix this:
|
||||
* Move the %(st_func_name)s call outside %(func_name)s.
|
||||
* Or, if you know what you're doing, use `@st.%(decorator_name)s(experimental_allow_widgets=True)`
|
||||
to enable widget replay and suppress this warning.
|
||||
"""
|
||||
% args
|
||||
).strip("\n")
|
||||
|
||||
super().__init__(msg)
|
||||
|
||||
@staticmethod
|
||||
def _get_cached_func_name_md(func: types.FunctionType) -> str:
|
||||
"""Get markdown representation of the function name."""
|
||||
if hasattr(func, "__name__"):
|
||||
return "`%s()`" % func.__name__
|
||||
else:
|
||||
return "a cached function"
|
||||
|
||||
|
||||
class CacheReplayClosureError(StreamlitAPIException):
|
||||
def __init__(
|
||||
self,
|
||||
cache_type: CacheType,
|
||||
cached_func: types.FunctionType,
|
||||
):
|
||||
func_name = get_cached_func_name_md(cached_func)
|
||||
decorator_name = get_decorator_api_name(cache_type)
|
||||
|
||||
msg = (
|
||||
f"""
|
||||
While running {func_name}, a Streamlit element was called on a layout block created outside the function.
|
||||
This is incompatible with replaying the cached effect of that element, because the
|
||||
referenced block might not exist when the replay happens.
|
||||
|
||||
How to fix this:
|
||||
* Move the creation of the layout block inside {func_name}.
|
||||
* Move the call to the Streamlit element outside of {func_name}.
|
||||
* Remove the `@st.{decorator_name}` decorator from {func_name}.
|
||||
"""
|
||||
).strip("\n")
|
||||
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class UnserializableReturnValueError(MarkdownFormattedException):
|
||||
def __init__(self, func: types.FunctionType, return_value: types.FunctionType):
|
||||
MarkdownFormattedException.__init__(
|
||||
self,
|
||||
f"""
|
||||
Cannot serialize the return value (of type {get_return_value_type(return_value)}) in {get_cached_func_name_md(func)}.
|
||||
`st.cache_data` uses [pickle](https://docs.python.org/3/library/pickle.html) to
|
||||
serialize the function’s return value and safely store it in the cache without mutating the original object. Please convert the return value to a pickle-serializable type.
|
||||
If you want to cache unserializable objects such as database connections or Tensorflow
|
||||
sessions, use `st.cache_resource` instead (see [our docs]({CACHE_DOCS_URL}) for differences).""",
|
||||
)
|
||||
|
||||
|
||||
class UnevaluatedDataFrameError(StreamlitAPIException):
|
||||
"""Used to display a message about uncollected dataframe being used"""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,520 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""@st.cache_resource implementation"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
import types
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, TypeVar, cast, overload
|
||||
|
||||
from cachetools import TTLCache
|
||||
from pympler import asizeof
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit.deprecation_util import show_deprecation_warning
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching import cache_utils
|
||||
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.cache_utils import (
|
||||
Cache,
|
||||
CachedFuncInfo,
|
||||
make_cached_func_wrapper,
|
||||
ttl_to_seconds,
|
||||
)
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
ElementMsgData,
|
||||
MsgData,
|
||||
MultiCacheResults,
|
||||
)
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)
|
||||
|
||||
ValidateFunc: TypeAlias = Callable[[Any], bool]
|
||||
|
||||
|
||||
def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
|
||||
"""True if the two validate functions are equal for the purposes of
|
||||
determining whether a given function cache needs to be recreated.
|
||||
"""
|
||||
# To "properly" test for function equality here, we'd need to compare function bytecode.
|
||||
# For performance reasons, we've decided not to do that for now.
|
||||
return (a is None and b is None) or (a is not None and b is not None)
|
||||
|
||||
|
||||
class ResourceCaches(CacheStatsProvider):
|
||||
"""Manages all ResourceCache instances"""
|
||||
|
||||
def __init__(self):
|
||||
self._caches_lock = threading.Lock()
|
||||
self._function_caches: dict[str, ResourceCache] = {}
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key: str,
|
||||
display_name: str,
|
||||
max_entries: int | float | None,
|
||||
ttl: float | timedelta | None,
|
||||
validate: ValidateFunc | None,
|
||||
allow_widgets: bool,
|
||||
) -> ResourceCache:
|
||||
"""Return the mem cache for the given key.
|
||||
|
||||
If it doesn't exist, create a new one with the given params.
|
||||
"""
|
||||
if max_entries is None:
|
||||
max_entries = math.inf
|
||||
|
||||
ttl_seconds = ttl_to_seconds(ttl)
|
||||
|
||||
# Get the existing cache, if it exists, and validate that its params
|
||||
# haven't changed.
|
||||
with self._caches_lock:
|
||||
cache = self._function_caches.get(key)
|
||||
if (
|
||||
cache is not None
|
||||
and cache.ttl_seconds == ttl_seconds
|
||||
and cache.max_entries == max_entries
|
||||
and _equal_validate_funcs(cache.validate, validate)
|
||||
):
|
||||
return cache
|
||||
|
||||
# Create a new cache object and put it in our dict
|
||||
_LOGGER.debug("Creating new ResourceCache (key=%s)", key)
|
||||
cache = ResourceCache(
|
||||
key=key,
|
||||
display_name=display_name,
|
||||
max_entries=max_entries,
|
||||
ttl_seconds=ttl_seconds,
|
||||
validate=validate,
|
||||
allow_widgets=allow_widgets,
|
||||
)
|
||||
self._function_caches[key] = cache
|
||||
return cache
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Clear all resource caches."""
|
||||
with self._caches_lock:
|
||||
self._function_caches = {}
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
with self._caches_lock:
|
||||
# Shallow-clone our caches. We don't want to hold the global
|
||||
# lock during stats-gathering.
|
||||
function_caches = self._function_caches.copy()
|
||||
|
||||
stats: list[CacheStat] = []
|
||||
for cache in function_caches.values():
|
||||
stats.extend(cache.get_stats())
|
||||
return stats
|
||||
|
||||
|
||||
# Singleton ResourceCaches instance
|
||||
_resource_caches = ResourceCaches()
|
||||
|
||||
|
||||
def get_resource_cache_stats_provider() -> CacheStatsProvider:
|
||||
"""Return the StatsProvider for all @st.cache_resource functions."""
|
||||
return _resource_caches
|
||||
|
||||
|
||||
class CachedResourceFuncInfo(CachedFuncInfo):
|
||||
"""Implements the CachedFuncInfo interface for @st.cache_resource"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: types.FunctionType,
|
||||
show_spinner: bool | str,
|
||||
max_entries: int | None,
|
||||
ttl: float | timedelta | None,
|
||||
validate: ValidateFunc | None,
|
||||
allow_widgets: bool,
|
||||
):
|
||||
super().__init__(
|
||||
func,
|
||||
show_spinner=show_spinner,
|
||||
allow_widgets=allow_widgets,
|
||||
)
|
||||
self.max_entries = max_entries
|
||||
self.ttl = ttl
|
||||
self.validate = validate
|
||||
|
||||
@property
|
||||
def cache_type(self) -> CacheType:
|
||||
return CacheType.RESOURCE
|
||||
|
||||
@property
|
||||
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
|
||||
return CACHE_RESOURCE_MESSAGE_REPLAY_CTX
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""A human-readable name for the cached function"""
|
||||
return f"{self.func.__module__}.{self.func.__qualname__}"
|
||||
|
||||
def get_function_cache(self, function_key: str) -> Cache:
|
||||
return _resource_caches.get_cache(
|
||||
key=function_key,
|
||||
display_name=self.display_name,
|
||||
max_entries=self.max_entries,
|
||||
ttl=self.ttl,
|
||||
validate=self.validate,
|
||||
allow_widgets=self.allow_widgets,
|
||||
)
|
||||
|
||||
|
||||
class CacheResourceAPI:
|
||||
"""Implements the public st.cache_resource API: the @st.cache_resource decorator,
|
||||
and st.cache_resource.clear().
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, decorator_metric_name: str, deprecation_warning: str | None = None
|
||||
):
|
||||
"""Create a CacheResourceAPI instance.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
decorator_metric_name
|
||||
The metric name to record for decorator usage. `@st.experimental_singleton` is
|
||||
deprecated, but we're still supporting it and tracking its usage separately
|
||||
from `@st.cache_resource`.
|
||||
|
||||
deprecation_warning
|
||||
An optional deprecation warning to show when the API is accessed.
|
||||
"""
|
||||
|
||||
# Parameterize the decorator metric name.
|
||||
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
|
||||
self._decorator = gather_metrics(decorator_metric_name, self._decorator) # type: ignore
|
||||
self._deprecation_warning = deprecation_warning
|
||||
|
||||
# Type-annotate the decorator function.
|
||||
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
# Bare decorator usage
|
||||
@overload
|
||||
def __call__(self, func: F) -> F:
|
||||
...
|
||||
|
||||
# Decorator with arguments
|
||||
@overload
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
ttl: float | timedelta | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
validate: ValidateFunc | None = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
) -> Callable[[F], F]:
|
||||
...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
func: F | None = None,
|
||||
*,
|
||||
ttl: float | timedelta | None = None,
|
||||
max_entries: int | None = None,
|
||||
show_spinner: bool | str = True,
|
||||
validate: ValidateFunc | None = None,
|
||||
experimental_allow_widgets: bool = False,
|
||||
):
|
||||
return self._decorator(
|
||||
func,
|
||||
ttl=ttl,
|
||||
max_entries=max_entries,
|
||||
show_spinner=show_spinner,
|
||||
validate=validate,
|
||||
experimental_allow_widgets=experimental_allow_widgets,
|
||||
)
|
||||
|
||||
def _decorator(
|
||||
self,
|
||||
func: F | None,
|
||||
*,
|
||||
ttl: float | timedelta | None,
|
||||
max_entries: int | None,
|
||||
show_spinner: bool | str,
|
||||
validate: ValidateFunc | None,
|
||||
experimental_allow_widgets: bool,
|
||||
):
|
||||
"""Decorator to cache functions that return global resources (e.g. database connections, ML models).
|
||||
|
||||
Cached objects are shared across all users, sessions, and reruns. They
|
||||
must be thread-safe because they can be accessed from multiple threads
|
||||
concurrently. If thread safety is an issue, consider using ``st.session_state``
|
||||
to store resources per session instead.
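
A minimal sketch of the per-session alternative mentioned above (the
``make_client`` helper is an illustrative assumption, not a Streamlit API):

>>> import streamlit as st
>>>
>>> if "client" not in st.session_state:
...     # One client per session, instead of one cached client shared by all sessions.
...     st.session_state.client = make_client()
...
>>> client = st.session_state.client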
|
||||
|
||||
You can clear a function's cache with ``func.clear()`` or clear the entire
|
||||
cache with ``st.cache_resource.clear()``.
|
||||
|
||||
To cache data, use ``st.cache_data`` instead. Learn more about caching at
|
||||
https://docs.streamlit.io/library/advanced-features/caching.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function that creates the cached resource. Streamlit hashes the
|
||||
function's source code.
|
||||
|
||||
ttl : float or timedelta or None
|
||||
The maximum number of seconds to keep an entry in the cache, or
|
||||
None if cache entries should not expire. The default is None.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to keep in the cache, or None
|
||||
for an unbounded cache. (When a new entry is added to a full cache,
|
||||
the oldest cached entry will be removed.) The default is None.
|
||||
|
||||
show_spinner : boolean or string
|
||||
Enable the spinner. Default is True to show a spinner when there is
|
||||
a "cache miss" and the cached resource is being created. If string,
|
||||
value of show_spinner param will be used for spinner text.
|
||||
|
||||
validate : callable or None
|
||||
An optional validation function for cached data. ``validate`` is called
|
||||
each time the cached value is accessed. It receives the cached value as
|
||||
its only parameter and it must return a boolean. If ``validate`` returns
|
||||
False, the current cached value is discarded, and the decorated function
|
||||
is called to compute a new value. This is useful, e.g., to check the
|
||||
health of database connections (see the ``validate`` example at the end of this docstring).
|
||||
|
||||
experimental_allow_widgets : boolean
|
||||
Allow widgets to be used in the cached function. Defaults to False.
|
||||
Support for widgets in cached functions is currently experimental.
|
||||
Setting this parameter to True may lead to excessive memory use since the
|
||||
widget value is treated as an additional input parameter to the cache.
|
||||
We may remove support for this option at any time without notice.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_resource
|
||||
... def get_database_session(url):
|
||||
... # Create a database session object that points to the URL.
|
||||
... return session
|
||||
...
|
||||
>>> s1 = get_database_session(SESSION_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> s2 = get_database_session(SESSION_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value. This means that now the connection object in s1 is the same as in s2.
|
||||
>>>
|
||||
>>> s3 = get_database_session(SESSION_URL_2)
|
||||
>>> # This is a different URL, so the function executes.
|
||||
|
||||
By default, all parameters to a cache_resource function must be hashable.
|
||||
Any parameter whose name begins with ``_`` will not be hashed. You can use
|
||||
this as an "escape hatch" for parameters that are not hashable:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_resource
|
||||
... def get_database_session(_sessionmaker, url):
|
||||
... # Create a database connection object that points to the URL.
|
||||
... return connection
|
||||
...
|
||||
>>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value - even though the _sessionmaker parameter was different
|
||||
>>> # in both calls.
|
||||
|
||||
A cache_resource function's cache can be procedurally cleared:
|
||||
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache_resource
|
||||
... def get_database_session(_sessionmaker, url):
|
||||
... # Create a database connection object that points to the URL.
|
||||
... return connection
|
||||
...
|
||||
>>> get_database_session.clear()
|
||||
>>> # Clear all cached entries for this function.
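
The ``validate`` parameter described above can be used to discard stale
resources automatically. A hedged sketch (``connection_is_healthy`` and
``create_engine`` are illustrative helpers, not Streamlit APIs):

>>> import streamlit as st
>>>
>>> def connection_is_healthy(conn) -> bool:
...     # Returning False discards the cached value and reruns the function.
...     return conn.is_healthy()
...
>>> @st.cache_resource(validate=connection_is_healthy)
... def get_database_connection(url):
...     return create_engine(url)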
|
||||
|
||||
"""
|
||||
self._maybe_show_deprecation_warning()
|
||||
|
||||
# Support passing the params via function decorator, e.g.
|
||||
# @st.cache_resource(show_spinner=False)
|
||||
if func is None:
|
||||
return lambda f: make_cached_func_wrapper(
|
||||
CachedResourceFuncInfo(
|
||||
func=f,
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
validate=validate,
|
||||
allow_widgets=experimental_allow_widgets,
|
||||
)
|
||||
)
|
||||
|
||||
return make_cached_func_wrapper(
|
||||
CachedResourceFuncInfo(
|
||||
func=cast(types.FunctionType, func),
|
||||
show_spinner=show_spinner,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
validate=validate,
|
||||
allow_widgets=experimental_allow_widgets,
|
||||
)
|
||||
)
|
||||
|
||||
@gather_metrics("clear_resource_caches")
|
||||
def clear(self) -> None:
|
||||
"""Clear all cache_resource caches."""
|
||||
self._maybe_show_deprecation_warning()
|
||||
_resource_caches.clear_all()
|
||||
|
||||
def _maybe_show_deprecation_warning(self):
|
||||
"""If the API is being accessed with the deprecated `st.experimental_singleton` name,
|
||||
show a deprecation warning.
|
||||
"""
|
||||
if self._deprecation_warning is not None:
|
||||
show_deprecation_warning(self._deprecation_warning)
|
||||
|
||||
|
||||
class ResourceCache(Cache):
|
||||
"""Manages cached values for a single st.cache_resource function."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
max_entries: float,
|
||||
ttl_seconds: float,
|
||||
validate: ValidateFunc | None,
|
||||
display_name: str,
|
||||
allow_widgets: bool,
|
||||
):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.display_name = display_name
|
||||
self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache(
|
||||
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
|
||||
)
|
||||
self._mem_cache_lock = threading.Lock()
|
||||
self.validate = validate
|
||||
self.allow_widgets = allow_widgets
|
||||
|
||||
@property
|
||||
def max_entries(self) -> float:
|
||||
return cast(float, self._mem_cache.maxsize)
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> float:
|
||||
return cast(float, self._mem_cache.ttl)
|
||||
|
||||
def read_result(self, key: str) -> CachedResult:
|
||||
"""Read a value and associated messages from the cache.
|
||||
Raise `CacheKeyNotFoundError` if the value doesn't exist.
|
||||
"""
|
||||
with self._mem_cache_lock:
|
||||
if key not in self._mem_cache:
|
||||
# key does not exist in cache.
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
multi_results: MultiCacheResults = self._mem_cache[key]
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if not ctx:
|
||||
# ScriptRunCtx does not exist (we're probably running in "raw" mode).
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
|
||||
if widget_key not in multi_results.results:
|
||||
# widget_key does not exist in cache (this combination of widgets hasn't been
|
||||
# seen for the value_key yet).
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
result = multi_results.results[widget_key]
|
||||
|
||||
if self.validate is not None and not self.validate(result.value):
|
||||
# Validate failed: delete the entry and raise an error.
|
||||
del multi_results.results[widget_key]
|
||||
raise CacheKeyNotFoundError()
|
||||
|
||||
return result
|
||||
|
||||
@gather_metrics("_cache_resource_object")
|
||||
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
|
||||
"""Write a value and associated messages to the cache."""
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
return
|
||||
|
||||
main_id = st._main.id
|
||||
sidebar_id = st.sidebar.id
|
||||
if self.allow_widgets:
|
||||
widgets = {
|
||||
msg.widget_metadata.widget_id
|
||||
for msg in messages
|
||||
if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
|
||||
}
|
||||
else:
|
||||
widgets = set()
|
||||
|
||||
with self._mem_cache_lock:
|
||||
try:
|
||||
multi_results = self._mem_cache[key]
|
||||
except KeyError:
|
||||
multi_results = MultiCacheResults(widget_ids=widgets, results={})
|
||||
|
||||
multi_results.widget_ids.update(widgets)
|
||||
widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
|
||||
|
||||
result = CachedResult(value, messages, main_id, sidebar_id)
|
||||
multi_results.results[widget_key] = result
|
||||
self._mem_cache[key] = multi_results
|
||||
|
||||
def _clear(self) -> None:
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache.clear()
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
# Shallow clone our cache. Computing item sizes is potentially
|
||||
# expensive, and we want to minimize the time we spend holding
|
||||
# the lock.
|
||||
with self._mem_cache_lock:
|
||||
cache_entries = list(self._mem_cache.values())
|
||||
|
||||
return [
|
||||
CacheStat(
|
||||
category_name="st_cache_resource",
|
||||
cache_name=self.display_name,
|
||||
byte_length=asizeof.asizeof(entry),
|
||||
)
|
||||
for entry in cache_entries
|
||||
]
|
||||
@@ -0,0 +1,31 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
|
||||
|
||||
class CacheType(enum.Enum):
|
||||
"""The function cache types we implement."""
|
||||
|
||||
DATA = "DATA"
|
||||
RESOURCE = "RESOURCE"
|
||||
|
||||
|
||||
def get_decorator_api_name(cache_type: CacheType) -> str:
|
||||
"""Return the name of the public decorator API for the given CacheType."""
|
||||
if cache_type is CacheType.DATA:
|
||||
return "cache_data"
|
||||
if cache_type is CacheType.RESOURCE:
|
||||
return "cache_resource"
|
||||
raise RuntimeError(f"Unrecognized CacheType '{cache_type}'")
|
||||
@@ -0,0 +1,449 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Common cache logic shared by st.cache_data and st.cache_resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, overload
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from streamlit import type_util
|
||||
from streamlit.elements.spinner import spinner
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.cache_errors import (
|
||||
CacheError,
|
||||
CacheKeyNotFoundError,
|
||||
UnevaluatedDataFrameError,
|
||||
UnhashableParamError,
|
||||
UnhashableTypeError,
|
||||
UnserializableReturnValueError,
|
||||
get_cached_func_name_md,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
MsgData,
|
||||
replay_cached_messages,
|
||||
)
|
||||
from streamlit.runtime.caching.hashing import update_hash
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
# The timer function we use with TTLCache. This is the default timer func, but
|
||||
# is exposed here as a constant so that it can be patched in unit tests.
|
||||
TTLCACHE_TIMER = time.monotonic
|
||||
|
||||
|
||||
@overload
|
||||
def ttl_to_seconds(
|
||||
ttl: float | timedelta | None, *, coerce_none_to_inf: Literal[False]
|
||||
) -> float | None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def ttl_to_seconds(ttl: float | timedelta | None) -> float:
|
||||
...
|
||||
|
||||
|
||||
def ttl_to_seconds(
|
||||
ttl: float | timedelta | None, *, coerce_none_to_inf: bool = True
|
||||
) -> float | None:
|
||||
"""
|
||||
Convert a ttl value to a float representing "number of seconds".
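
For illustration, given the conversions implemented below:

>>> from datetime import timedelta
>>> ttl_to_seconds(timedelta(minutes=5))
300.0
>>> ttl_to_seconds(None)  # coerce_none_to_inf defaults to True
inf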
|
||||
"""
|
||||
if coerce_none_to_inf and ttl is None:
|
||||
return math.inf
|
||||
if isinstance(ttl, timedelta):
|
||||
return ttl.total_seconds()
|
||||
return ttl
|
||||
|
||||
|
||||
# We show a special "UnevaluatedDataFrame" warning for cached funcs
|
||||
# that attempt to return one of these unserializable types:
|
||||
UNEVALUATED_DATAFRAME_TYPES = (
|
||||
"snowflake.snowpark.table.Table",
|
||||
"snowflake.snowpark.dataframe.DataFrame",
|
||||
"pyspark.sql.dataframe.DataFrame",
|
||||
)
|
||||
|
||||
|
||||
class Cache:
|
||||
"""Function cache interface. Caches persist across script runs."""
|
||||
|
||||
def __init__(self):
|
||||
self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
|
||||
self._value_locks_lock = threading.Lock()
|
||||
|
||||
@abstractmethod
|
||||
def read_result(self, value_key: str) -> CachedResult:
|
||||
"""Read a value and associated messages from the cache.
|
||||
|
||||
Raises
|
||||
------
|
||||
CacheKeyNotFoundError
|
||||
Raised if value_key is not in the cache.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
|
||||
"""Write a value and associated messages to the cache, overwriting any existing
|
||||
result that uses the value_key.
|
||||
"""
|
||||
# We *could* `del self._value_locks[value_key]` here, since nobody will be taking
|
||||
# a compute_value_lock for this value_key after the result is written.
|
||||
raise NotImplementedError
|
||||
|
||||
def compute_value_lock(self, value_key: str) -> threading.Lock:
|
||||
"""Return the lock that should be held while computing a new cached value.
|
||||
In a popular app with a cache that hasn't been pre-warmed, many sessions may try
|
||||
to access a not-yet-cached value simultaneously. We use a lock to ensure that
|
||||
only one of those sessions computes the value, and the others block until
|
||||
the value is computed.
|
||||
"""
|
||||
with self._value_locks_lock:
|
||||
return self._value_locks[value_key]
|
||||
|
||||
def clear(self):
|
||||
"""Clear all values from this cache."""
|
||||
with self._value_locks_lock:
|
||||
self._value_locks.clear()
|
||||
self._clear()
|
||||
|
||||
@abstractmethod
|
||||
def _clear(self) -> None:
|
||||
"""Subclasses must implement this to perform cache-clearing logic."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CachedFuncInfo:
|
||||
"""Encapsulates data for a cached function instance.
|
||||
|
||||
CachedFuncInfo instances are scoped to a single script run - they're not
|
||||
persistent.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: types.FunctionType,
|
||||
show_spinner: bool | str,
|
||||
allow_widgets: bool,
|
||||
):
|
||||
self.func = func
|
||||
self.show_spinner = show_spinner
|
||||
self.allow_widgets = allow_widgets
|
||||
|
||||
@property
|
||||
def cache_type(self) -> CacheType:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_function_cache(self, function_key: str) -> Cache:
|
||||
"""Get or create the function cache for the given key."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
|
||||
"""Create a callable wrapper around a CachedFunctionInfo.
|
||||
|
||||
Calling the wrapper will return the cached value if it's already been
|
||||
computed, and will call the underlying function to compute and cache the
|
||||
value otherwise.
|
||||
|
||||
The wrapper also has a `clear` function that can be called to clear
|
||||
all of the wrapper's cached values.
|
||||
"""
|
||||
cached_func = CachedFunc(info)
|
||||
|
||||
# We'd like to simply return `cached_func`, which is already a Callable.
|
||||
# But using `functools.update_wrapper` on the CachedFunc instance
|
||||
# itself results in errors when our caching decorators are used to decorate
|
||||
# member functions. (See https://github.com/streamlit/streamlit/issues/6109)
|
||||
|
||||
@functools.wraps(info.func)
|
||||
def wrapper(*args, **kwargs):
|
||||
return cached_func(*args, **kwargs)
|
||||
|
||||
# Give our wrapper its `clear` function.
|
||||
# (This results in a spurious mypy error that we suppress.)
|
||||
wrapper.clear = cached_func.clear # type: ignore
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class CachedFunc:
|
||||
def __init__(self, info: CachedFuncInfo):
|
||||
self._info = info
|
||||
self._function_key = _make_function_key(info.cache_type, info.func)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
"""The wrapper. We'll only call our underlying function on a cache miss."""
|
||||
|
||||
name = self._info.func.__qualname__
|
||||
|
||||
if isinstance(self._info.show_spinner, bool):
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
message = f"Running `{name}()`."
|
||||
else:
|
||||
message = f"Running `{name}(...)`."
|
||||
else:
|
||||
message = self._info.show_spinner
|
||||
|
||||
if self._info.show_spinner or isinstance(self._info.show_spinner, str):
|
||||
with spinner(message):
|
||||
return self._get_or_create_cached_value(args, kwargs)
|
||||
else:
|
||||
return self._get_or_create_cached_value(args, kwargs)
|
||||
|
||||
def _get_or_create_cached_value(
|
||||
self, func_args: tuple[Any, ...], func_kwargs: dict[str, Any]
|
||||
) -> Any:
|
||||
# Retrieve the function's cache object. We must do this "just-in-time"
|
||||
# (as opposed to in the constructor), because caches can be invalidated
|
||||
# at any time.
|
||||
cache = self._info.get_function_cache(self._function_key)
|
||||
|
||||
# Generate the key for the cached value. This is based on the
|
||||
# arguments passed to the function.
|
||||
value_key = _make_value_key(
|
||||
cache_type=self._info.cache_type,
|
||||
func=self._info.func,
|
||||
func_args=func_args,
|
||||
func_kwargs=func_kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
cached_result = cache.read_result(value_key)
|
||||
return self._handle_cache_hit(cached_result)
|
||||
except CacheKeyNotFoundError:
|
||||
return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)
|
||||
|
||||
def _handle_cache_hit(self, result: CachedResult) -> Any:
|
||||
"""Handle a cache hit: replay the result's cached messages, and return its value."""
|
||||
replay_cached_messages(
|
||||
result,
|
||||
self._info.cache_type,
|
||||
self._info.func,
|
||||
)
|
||||
return result.value
|
||||
|
||||
def _handle_cache_miss(
|
||||
self,
|
||||
cache: Cache,
|
||||
value_key: str,
|
||||
func_args: tuple[Any, ...],
|
||||
func_kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""Handle a cache miss: compute a new cached value, write it back to the cache,
|
||||
and return that newly-computed value.
|
||||
"""
|
||||
|
||||
# Implementation notes:
|
||||
# - We take a "compute_value_lock" before computing our value. This ensures that
|
||||
# multiple sessions don't try to compute the same value simultaneously.
|
||||
#
|
||||
# - We use a different lock for each value_key, as opposed to a single lock for
|
||||
# the entire cache, so that unrelated value computations don't block on each other.
|
||||
#
|
||||
# - When retrieving a cache entry that may not yet exist, we use a "double-checked locking"
|
||||
# strategy: first we try to retrieve the cache entry without taking a value lock. (This
|
||||
# happens in `_get_or_create_cached_value()`.) If that fails because the value hasn't
|
||||
been computed yet, we take the value lock and then immediately try to retrieve the cache entry
|
||||
# *again*, while holding the lock. If the cache entry exists at this point, it means that
|
||||
# another thread computed the value before us.
|
||||
#
|
||||
# This means that the happy path ("cache entry exists") is a wee bit faster because
|
||||
# no lock is acquired. But the unhappy path ("cache entry needs to be recomputed") is
|
||||
# a wee bit slower, because we do two lookups for the entry.
|
||||
|
||||
with cache.compute_value_lock(value_key):
|
||||
# We've acquired the lock - but another thread may have acquired it first
|
||||
# and already computed the value. So we need to test for a cache hit again,
|
||||
# before computing.
|
||||
try:
|
||||
cached_result = cache.read_result(value_key)
|
||||
# Another thread computed the value before us. Early exit!
|
||||
return self._handle_cache_hit(cached_result)
|
||||
|
||||
except CacheKeyNotFoundError:
|
||||
# We acquired the lock before any other thread. Compute the value!
|
||||
with self._info.cached_message_replay_ctx.calling_cached_function(
|
||||
self._info.func, self._info.allow_widgets
|
||||
):
|
||||
computed_value = self._info.func(*func_args, **func_kwargs)
|
||||
|
||||
# We've computed our value, and now we need to write it back to the cache
|
||||
# along with any "replay messages" that were generated during value computation.
|
||||
messages = self._info.cached_message_replay_ctx._most_recent_messages
|
||||
try:
|
||||
cache.write_result(value_key, computed_value, messages)
|
||||
return computed_value
|
||||
except (CacheError, RuntimeError):
|
||||
# An exception was thrown while we tried to write to the cache. Report it to the user.
|
||||
# (We catch `RuntimeError` here because it will be raised by Apache Spark if we do not
|
||||
# collect dataframe before using `st.cache_data`.)
|
||||
if any(
|
||||
type_util.is_type(computed_value, type_name)
|
||||
for type_name in UNEVALUATED_DATAFRAME_TYPES
|
||||
):
|
||||
raise UnevaluatedDataFrameError(
|
||||
f"""
|
||||
The function {get_cached_func_name_md(self._info.func)} is decorated with `st.cache_data` but it returns an unevaluated dataframe
|
||||
of type `{type_util.get_fqn_type(computed_value)}`. Please call `collect()` or `to_pandas()` on the dataframe before returning it,
|
||||
so `st.cache_data` can serialize and cache it."""
|
||||
)
|
||||
raise UnserializableReturnValueError(
|
||||
return_value=computed_value, func=self._info.func
|
||||
)
|
||||
|
||||
def clear(self):
|
||||
"""Clear the wrapped function's associated cache."""
|
||||
cache = self._info.get_function_cache(self._function_key)
|
||||
cache.clear()
|
||||
|
||||
|
||||
def _make_value_key(
|
||||
cache_type: CacheType,
|
||||
func: types.FunctionType,
|
||||
func_args: tuple[Any, ...],
|
||||
func_kwargs: dict[str, Any],
|
||||
) -> str:
|
||||
"""Create the key for a value within a cache.
|
||||
|
||||
This key is generated from the function's arguments. All arguments
|
||||
will be hashed, except for those named with a leading "_".
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
Raised (with a nicely-formatted explanation message) if we encounter
|
||||
an un-hashable arg.
|
||||
"""
|
||||
|
||||
# Create a (name, value) list of all *args and **kwargs passed to the
|
||||
# function.
|
||||
arg_pairs: list[tuple[str | None, Any]] = []
|
||||
for arg_idx in range(len(func_args)):
|
||||
arg_name = _get_positional_arg_name(func, arg_idx)
|
||||
arg_pairs.append((arg_name, func_args[arg_idx]))
|
||||
|
||||
for kw_name, kw_val in func_kwargs.items():
|
||||
# **kwargs ordering is preserved, per PEP 468
|
||||
# https://www.python.org/dev/peps/pep-0468/, so this iteration is
|
||||
# deterministic.
|
||||
arg_pairs.append((kw_name, kw_val))
|
||||
|
||||
# Create the hash from each arg value, except for those args whose name
|
||||
# starts with "_". (Underscore-prefixed args are deliberately excluded from
|
||||
# hashing.)
|
||||
args_hasher = hashlib.new("md5")
|
||||
for arg_name, arg_value in arg_pairs:
|
||||
if arg_name is not None and arg_name.startswith("_"):
|
||||
_LOGGER.debug("Not hashing %s because it starts with _", arg_name)
|
||||
continue
|
||||
|
||||
try:
|
||||
update_hash(
|
||||
(arg_name, arg_value),
|
||||
hasher=args_hasher,
|
||||
cache_type=cache_type,
|
||||
)
|
||||
except UnhashableTypeError as exc:
|
||||
raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
|
||||
|
||||
value_key = args_hasher.hexdigest()
|
||||
_LOGGER.debug("Cache key: %s", value_key)
|
||||
|
||||
return value_key
|
||||
|
||||
|
||||
def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
|
||||
"""Create the unique key for a function's cache.
|
||||
|
||||
A function's key is stable across reruns of the app, and changes when
|
||||
the function's source code changes.
|
||||
"""
|
||||
func_hasher = hashlib.new("md5")
|
||||
|
||||
# Include the function's __module__ and __qualname__ strings in the hash.
|
||||
# This means that two identical functions in different modules
|
||||
# will not share a hash; it also means that two identical *nested*
|
||||
# functions in the same module will not share a hash.
|
||||
update_hash(
|
||||
(func.__module__, func.__qualname__),
|
||||
hasher=func_hasher,
|
||||
cache_type=cache_type,
|
||||
)
|
||||
|
||||
# Include the function's source code in its hash. If the source code can't
|
||||
# be retrieved, fall back to the function's bytecode instead.
|
||||
source_code: str | bytes
|
||||
try:
|
||||
source_code = inspect.getsource(func)
|
||||
except OSError as e:
|
||||
_LOGGER.debug(
|
||||
"Failed to retrieve function's source code when building its key; falling back to bytecode. err={0}",
|
||||
e,
|
||||
)
|
||||
source_code = func.__code__.co_code
|
||||
|
||||
update_hash(
|
||||
source_code,
|
||||
hasher=func_hasher,
|
||||
cache_type=cache_type,
|
||||
)
|
||||
|
||||
cache_key = func_hasher.hexdigest()
|
||||
return cache_key
|
||||
|
||||
|
||||
def _get_positional_arg_name(func: types.FunctionType, arg_index: int) -> str | None:
|
||||
"""Return the name of a function's positional argument.
|
||||
|
||||
If arg_index is out of range, or refers to a parameter that is not a
|
||||
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
|
||||
return None instead.
|
||||
"""
|
||||
if arg_index < 0:
|
||||
return None
|
||||
|
||||
params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values())
|
||||
if arg_index >= len(params):
|
||||
return None
|
||||
|
||||
if params[arg_index].kind in (
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
):
|
||||
return params[arg_index].name
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,478 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import threading
|
||||
import types
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Iterator, Union
|
||||
|
||||
from google.protobuf.message import Message
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import runtime, util
|
||||
from streamlit.elements import NONWIDGET_ELEMENTS, WIDGETS
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.Block_pb2 import Block
|
||||
from streamlit.runtime.caching.cache_errors import (
|
||||
CachedStFunctionWarning,
|
||||
CacheReplayClosureError,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.hashing import update_hash
|
||||
from streamlit.runtime.scriptrunner.script_run_context import (
|
||||
ScriptRunContext,
|
||||
get_script_run_ctx,
|
||||
)
|
||||
from streamlit.runtime.state.common import WidgetMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Widget(Protocol):
|
||||
id: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WidgetMsgMetadata:
|
||||
"""Everything needed for replaying a widget and treating it as an implicit
|
||||
argument to a cached function, beyond what is stored for all elements.
|
||||
"""
|
||||
|
||||
widget_id: str
|
||||
widget_value: Any
|
||||
metadata: WidgetMetadata[Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MediaMsgData:
|
||||
media: bytes | str
|
||||
mimetype: str
|
||||
media_id: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ElementMsgData:
|
||||
"""An element's message and related metadata for
|
||||
replaying that element's function call.
|
||||
|
||||
widget_metadata is filled in if and only if this element is a widget.
|
||||
media_data is filled in if and only if this is a media element (image, audio, video).
|
||||
"""
|
||||
|
||||
delta_type: str
|
||||
message: Message
|
||||
id_of_dg_called_on: str
|
||||
returned_dgs_id: str
|
||||
widget_metadata: WidgetMsgMetadata | None = None
|
||||
media_data: list[MediaMsgData] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BlockMsgData:
|
||||
message: Block
|
||||
id_of_dg_called_on: str
|
||||
returned_dgs_id: str
|
||||
|
||||
|
||||
MsgData = Union[ElementMsgData, BlockMsgData]
|
||||
|
||||
"""
|
||||
Note [Cache result structure]
|
||||
|
||||
The cache for a decorated function's results is split into two parts to enable
|
||||
handling widgets invoked by the function.
|
||||
|
||||
Widgets act as implicit additional inputs to the cached function, so they should
|
||||
be used when deriving the cache key. However, we don't know what widgets are
|
||||
involved without first calling the function! So, we use the first execution
|
||||
of the function with a particular cache key to record what widgets are used,
|
||||
and use the current values of those widgets to derive a second cache key to
|
||||
look up the function execution's results. The combination of first and second
|
||||
cache keys act as one true cache key, just split up because the second part depends
|
||||
on the first.
|
||||
|
||||
We need to treat widgets as implicit arguments of the cached function, because
|
||||
the behavior of the function, including what elements are created and what it
|
||||
returns, can be and usually will be influenced by the values of those widgets.
|
||||
For example:
|
||||
> @st.cache_data
|
||||
> def example_fn(x):
|
||||
> y = x + 1
|
||||
> if st.checkbox("hi"):
|
||||
> st.write("you checked the thing")
|
||||
> y = 0
|
||||
> return y
|
||||
>
|
||||
> example_fn(2)
|
||||
|
||||
If the checkbox is checked, the function call should return 0 and the checkbox and
|
||||
message should be rendered. If the checkbox isn't checked, only the checkbox should
|
||||
render, and the function will return 3.
|
||||
|
||||
|
||||
There is a small catch in this. Since what widgets execute could depend on the values of
|
||||
any prior widgets, if we replace the `st.write` call in the example with a slider,
|
||||
the first time it runs, we would miss the slider because it wasn't called,
|
||||
so when we later execute the function with the checkbox checked, the widget cache key
|
||||
would not include the state of the slider, and would incorrectly get a cache hit
|
||||
for a different slider value.
|
||||
|
||||
In principle the cache could be function+args key -> 1st widget key -> 2nd widget key
|
||||
... -> final widget key, with each widget dependent on the exact values of the widgets
|
||||
seen prior. This would prevent unnecessary cache misses due to differing values of widgets
|
||||
that wouldn't affect the function's execution because they aren't even created.
|
||||
But this would add even more complexity and both conceptual and runtime overhead, so it is
|
||||
unclear if it would be worth doing.
|
||||
|
||||
Instead, we can keep the widgets as one cache key, and if we encounter a new widget
|
||||
while executing the function, we just update the list of widgets to include it.
|
||||
This will cause existing cached results to be invalidated, which is bad, but to
|
||||
avoid it we would need to keep around the full list of widgets and values for each
|
||||
widget cache key so we could compute the updated key, which is probably too expensive
|
||||
to be worth it.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedResult:
|
||||
"""The full results of calling a cache-decorated function, enough to
|
||||
replay the st functions called while executing it.
|
||||
"""
|
||||
|
||||
value: Any
|
||||
messages: list[MsgData]
|
||||
main_id: str
|
||||
sidebar_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiCacheResults:
|
||||
"""Widgets called by a cache-decorated function, and a mapping of the
|
||||
widget-derived cache key to the final results of executing the function.
|
||||
"""
|
||||
|
||||
widget_ids: set[str]
|
||||
results: dict[str, CachedResult]
|
||||
|
||||
def get_current_widget_key(
|
||||
self, ctx: ScriptRunContext, cache_type: CacheType
|
||||
) -> str:
|
||||
state = ctx.session_state
|
||||
# Compute the key using only widgets that have values. A missing widget
|
||||
# can be ignored because we only care about getting different keys
|
||||
# for different widget values, and for that purpose doing nothing
|
||||
# to the running hash is just as good as including the widget with a
|
||||
# sentinel value. But by excluding it, we might get to reuse a result
|
||||
# saved before we knew about that widget.
|
||||
widget_values = [
|
||||
(wid, state[wid]) for wid in sorted(self.widget_ids) if wid in state
|
||||
]
|
||||
widget_key = _make_widget_key(widget_values, cache_type)
|
||||
return widget_key
|
||||
|
||||
|
||||
"""
|
||||
Note [DeltaGenerator method invocation]
|
||||
There are two top level DG instances defined for all apps:
|
||||
`main`, which is for putting elements in the main part of the app
|
||||
`sidebar`, for the sidebar
|
||||
|
||||
There are 3 different ways an st function can be invoked:
|
||||
1. Implicitly on the main DG instance (plain `st.foo` calls)
|
||||
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
|
||||
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)
|
||||
|
||||
To simplify replaying messages from a cached function result, we convert all of these
|
||||
to explicit invocations. How they get rewritten depends on if the invocation was
|
||||
implicit vs explicit, and if the target DG has been seen/produced during replay.
|
||||
|
||||
Implicit invocation on a known DG -> Explicit invocation on that DG
|
||||
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
|
||||
with st.container():
|
||||
my_cache_decorated_function()
|
||||
|
||||
This is situation 2 above, and the DG is a block entirely outside our function call,
|
||||
so we interpret it as "put this element in the enclosing contextmanager block"
|
||||
(or main if there isn't one), which is achieved by invoking on main.
|
||||
Explicit invocation on a known DG -> No change needed
|
||||
Explicit invocation on an unknown DG -> Raise an error
|
||||
We have no way to identify the target DG, and it may not even be present in the
|
||||
current script run, so the least surprising thing to do is raise an error.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class CachedMessageReplayContext(threading.local):
|
||||
"""A utility for storing messages generated by `st` commands called inside
|
||||
a cached function.
|
||||
|
||||
Data is stored in a thread-local object, so it's safe to use an instance
|
||||
of this class across multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_type: CacheType):
|
||||
self._cached_func_stack: list[types.FunctionType] = []
|
||||
self._suppress_st_function_warning = 0
|
||||
self._cached_message_stack: list[list[MsgData]] = []
|
||||
self._seen_dg_stack: list[set[str]] = []
|
||||
self._most_recent_messages: list[MsgData] = []
|
||||
self._registered_metadata: WidgetMetadata[Any] | None = None
|
||||
self._media_data: list[MediaMsgData] = []
|
||||
self._cache_type = cache_type
|
||||
self._allow_widgets: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def calling_cached_function(
|
||||
self, func: types.FunctionType, allow_widgets: bool
|
||||
) -> Iterator[None]:
|
||||
"""Context manager that should wrap the invocation of a cached function.
|
||||
It allows us to track any `st.foo` messages that are generated from inside the function
|
||||
for playback during cache retrieval.
|
||||
"""
|
||||
self._cached_func_stack.append(func)
|
||||
self._cached_message_stack.append([])
|
||||
self._seen_dg_stack.append(set())
|
||||
if allow_widgets:
|
||||
self._allow_widgets += 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._cached_func_stack.pop()
|
||||
self._most_recent_messages = self._cached_message_stack.pop()
|
||||
self._seen_dg_stack.pop()
|
||||
if allow_widgets:
|
||||
self._allow_widgets -= 1
|
||||
|
||||
def save_element_message(
|
||||
self,
|
||||
delta_type: str,
|
||||
element_proto: Message,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
"""Record the element protobuf as having been produced during any currently
|
||||
executing cached functions, so they can be replayed any time the function's
|
||||
execution is skipped because they're in the cache.
|
||||
"""
|
||||
if not runtime.exists():
|
||||
return
|
||||
if len(self._cached_message_stack) >= 1:
|
||||
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
|
||||
# Arrow dataframes have an ID but only set it when used as data editor
|
||||
# widgets, so we have to check that the ID has been actually set to
|
||||
# know if an element is a widget.
|
||||
if isinstance(element_proto, Widget) and element_proto.id:
|
||||
wid = element_proto.id
|
||||
# TODO replace `Message` with a more precise type
|
||||
if not self._registered_metadata:
|
||||
_LOGGER.error(
|
||||
"Trying to save widget message that wasn't registered. This should not be possible."
|
||||
)
|
||||
raise AttributeError
|
||||
widget_meta = WidgetMsgMetadata(
|
||||
wid, None, metadata=self._registered_metadata
|
||||
)
|
||||
else:
|
||||
widget_meta = None
|
||||
|
||||
media_data = self._media_data
|
||||
|
||||
element_msg_data = ElementMsgData(
|
||||
delta_type,
|
||||
element_proto,
|
||||
id_to_save,
|
||||
returned_dg_id,
|
||||
widget_meta,
|
||||
media_data,
|
||||
)
|
||||
for msgs in self._cached_message_stack:
|
||||
if self._allow_widgets or widget_meta is None:
|
||||
msgs.append(element_msg_data)
|
||||
|
||||
# Reset instance state, now that it has been used for the
|
||||
# associated element.
|
||||
self._media_data = []
|
||||
self._registered_metadata = None
|
||||
|
||||
for s in self._seen_dg_stack:
|
||||
s.add(returned_dg_id)
|
||||
|
||||
def save_block_message(
|
||||
self,
|
||||
block_proto: Block,
|
||||
invoked_dg_id: str,
|
||||
used_dg_id: str,
|
||||
returned_dg_id: str,
|
||||
) -> None:
|
||||
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
|
||||
for msgs in self._cached_message_stack:
|
||||
msgs.append(BlockMsgData(block_proto, id_to_save, returned_dg_id))
|
||||
for s in self._seen_dg_stack:
|
||||
s.add(returned_dg_id)
|
||||
|
||||
def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
|
||||
"""Select the id of the DG that this message should be invoked on
|
||||
during message replay.
|
||||
|
||||
See Note [DeltaGenerator method invocation]
|
||||
|
||||
invoked_id is the DG the st function was called on, usually `st._main`.
|
||||
acting_on_id is the DG the st function ultimately runs on, which may be different
|
||||
if the invoked DG delegated to another one because it was in a `with` block.
|
||||
"""
|
||||
if len(self._seen_dg_stack) > 0 and acting_on_id in self._seen_dg_stack[-1]:
|
||||
return acting_on_id
|
||||
else:
|
||||
return invoked_id
|
||||
|
||||
def save_widget_metadata(self, metadata: WidgetMetadata[Any]) -> None:
|
||||
self._registered_metadata = metadata
|
||||
|
||||
def save_image_data(
|
||||
self, image_data: bytes | str, mimetype: str, image_id: str
|
||||
) -> None:
|
||||
self._media_data.append(MediaMsgData(image_data, mimetype, image_id))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cached_st_function_warning(self) -> Iterator[None]:
|
||||
self._suppress_st_function_warning += 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._suppress_st_function_warning -= 1
|
||||
assert self._suppress_st_function_warning >= 0
|
||||
|
||||
def maybe_show_cached_st_function_warning(
|
||||
self,
|
||||
dg: "DeltaGenerator",
|
||||
st_func_name: str,
|
||||
) -> None:
|
||||
"""If appropriate, warn about calling st.foo inside @memo.
|
||||
|
||||
DeltaGenerator's @_with_element and @_widget wrappers use this to warn
|
||||
the user when they're calling st.foo() from within a function that is
|
||||
wrapped in @st.cache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dg : DeltaGenerator
|
||||
The DeltaGenerator to publish the warning to.
|
||||
|
||||
st_func_name : str
|
||||
The name of the Streamlit function that was called.
|
||||
|
||||
"""
|
||||
# There are some elements not in either list, which we still want to warn about.
|
||||
# Ideally we will fix this by either updating the lists or creating a better
|
||||
# way of categorizing elements.
|
||||
if st_func_name in NONWIDGET_ELEMENTS:
|
||||
return
|
||||
if st_func_name in WIDGETS and self._allow_widgets > 0:
|
||||
return
|
||||
|
||||
if len(self._cached_func_stack) > 0 and self._suppress_st_function_warning <= 0:
|
||||
cached_func = self._cached_func_stack[-1]
|
||||
self._show_cached_st_function_warning(dg, st_func_name, cached_func)
|
||||
|
||||
def _show_cached_st_function_warning(
|
||||
self,
|
||||
dg: "DeltaGenerator",
|
||||
st_func_name: str,
|
||||
cached_func: types.FunctionType,
|
||||
) -> None:
|
||||
# Avoid infinite recursion by suppressing additional cached
|
||||
# function warnings from within the cached function warning.
|
||||
with self.suppress_cached_st_function_warning():
|
||||
e = CachedStFunctionWarning(self._cache_type, st_func_name, cached_func)
|
||||
dg.exception(e)
|
||||
|
||||
|
||||
def replay_cached_messages(
|
||||
result: CachedResult, cache_type: CacheType, cached_func: types.FunctionType
|
||||
) -> None:
|
||||
"""Replay the st element function calls that happened when executing a
|
||||
cache-decorated function.
|
||||
|
||||
When a cache function is executed, we record the element and block messages
|
||||
produced, and use those to reproduce the DeltaGenerator calls, so the elements
|
||||
will appear in the web app even when execution of the function is skipped
|
||||
because the result was cached.
|
||||
|
||||
To make this work, for each st function call we record an identifier for the
|
||||
DG it was effectively called on (see Note [DeltaGenerator method invocation]).
|
||||
We also record the identifier for each DG returned by an st function call, if
|
||||
it returns one. Then, for each recorded message, we get the current DG instance
|
||||
corresponding to the DG the message was originally called on, and enqueue the
|
||||
message using that, recording any new DGs produced in case a later st function
|
||||
call is on one of them.
|
||||
"""
|
||||
from streamlit.delta_generator import DeltaGenerator
|
||||
from streamlit.runtime.state.widgets import register_widget_from_metadata
|
||||
|
||||
# Maps originally recorded dg ids to this script run's version of that dg
|
||||
returned_dgs: dict[str, DeltaGenerator] = {}
|
||||
returned_dgs[result.main_id] = st._main
|
||||
returned_dgs[result.sidebar_id] = st.sidebar
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
try:
|
||||
for msg in result.messages:
|
||||
if isinstance(msg, ElementMsgData):
|
||||
if msg.widget_metadata is not None:
|
||||
register_widget_from_metadata(
|
||||
msg.widget_metadata.metadata,
|
||||
ctx,
|
||||
None,
|
||||
msg.delta_type,
|
||||
)
|
||||
if msg.media_data is not None:
|
||||
for data in msg.media_data:
|
||||
runtime.get_instance().media_file_mgr.add(
|
||||
data.media, data.mimetype, data.media_id
|
||||
)
|
||||
dg = returned_dgs[msg.id_of_dg_called_on]
|
||||
maybe_dg = dg._enqueue(msg.delta_type, msg.message)
|
||||
if isinstance(maybe_dg, DeltaGenerator):
|
||||
returned_dgs[msg.returned_dgs_id] = maybe_dg
|
||||
elif isinstance(msg, BlockMsgData):
|
||||
dg = returned_dgs[msg.id_of_dg_called_on]
|
||||
new_dg = dg._block(msg.message)
|
||||
returned_dgs[msg.returned_dgs_id] = new_dg
|
||||
except KeyError:
|
||||
raise CacheReplayClosureError(cache_type, cached_func)
|
||||
|
||||
|
||||
def _make_widget_key(widgets: list[tuple[str, Any]], cache_type: CacheType) -> str:
|
||||
"""Generate a key for the given list of widgets used in a cache-decorated function.
|
||||
|
||||
Keys are generated by hashing the IDs and values of the widgets in the given list.
|
||||
"""
|
||||
func_hasher = hashlib.new("md5")
|
||||
for widget_id_val in widgets:
|
||||
update_hash(widget_id_val, func_hasher, cache_type)
|
||||
|
||||
return func_hasher.hexdigest()
|
||||
@@ -0,0 +1,395 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Hashing for st.memo and st.singleton."""
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import unittest.mock
|
||||
import weakref
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Pattern
|
||||
|
||||
from streamlit import type_util, util
|
||||
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFile
|
||||
|
||||
# If a dataframe has more than this many rows, we consider it large and hash a sample.
|
||||
_PANDAS_ROWS_LARGE = 100000
|
||||
_PANDAS_SAMPLE_SIZE = 10000
|
||||
|
||||
# Similar to dataframes, we also sample large numpy arrays.
|
||||
_NP_SIZE_LARGE = 1000000
|
||||
_NP_SAMPLE_SIZE = 100000
|
||||
|
||||
# Arbitrary item to denote where we found a cycle in a hashed object.
|
||||
# This allows us to hash self-referencing lists, dictionaries, etc.
|
||||
_CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
|
||||
|
||||
|
||||
def update_hash(val: Any, hasher, cache_type: CacheType) -> None:
|
||||
"""Updates a hashlib hasher with the hash of val.
|
||||
|
||||
This is the main entrypoint to hashing.py.
|
||||
"""
|
||||
ch = _CacheFuncHasher(cache_type)
|
||||
ch.update(hasher, val)
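
# Minimal usage sketch (illustrative; the CacheType member name below is an
# assumption and may differ in this codebase):
#
#   import hashlib
#   hasher = hashlib.new("md5")
#   update_hash({"a": 1, "b": [2, 3]}, hasher, CacheType.MEMO)
#   key = hasher.hexdigest()   # stable across runs, unlike the built-in hash()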
|
||||
|
||||
|
||||
class _HashStack:
|
||||
"""Stack of what has been hashed, for debug and circular reference detection.
|
||||
|
||||
This internally keeps 1 stack per thread.
|
||||
|
||||
Internally, this stores the ID of pushed objects rather than the objects
|
||||
themselves because otherwise the "in" operator inside __contains__ would
|
||||
fail for objects that don't return a boolean for "==" operator. For
|
||||
example, arr == 10 where arr is a NumPy array returns another NumPy array.
|
||||
This causes the "in" to crash since it expects a boolean.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def push(self, val: Any):
|
||||
self._stack[id(val)] = val
|
||||
|
||||
def pop(self):
|
||||
self._stack.popitem()
|
||||
|
||||
def __contains__(self, val: Any):
|
||||
return id(val) in self._stack
|
||||
|
||||
|
||||
class _HashStacks:
|
||||
"""Stacks of what has been hashed, with at most 1 stack per thread."""
|
||||
|
||||
def __init__(self):
|
||||
self._stacks: weakref.WeakKeyDictionary[
|
||||
threading.Thread, _HashStack
|
||||
] = weakref.WeakKeyDictionary()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@property
|
||||
def current(self) -> _HashStack:
|
||||
current_thread = threading.current_thread()
|
||||
|
||||
stack = self._stacks.get(current_thread, None)
|
||||
|
||||
if stack is None:
|
||||
stack = _HashStack()
|
||||
self._stacks[current_thread] = stack
|
||||
|
||||
return stack
|
||||
|
||||
|
||||
hash_stacks = _HashStacks()
|
||||
|
||||
|
||||
def _int_to_bytes(i: int) -> bytes:
|
||||
num_bytes = (i.bit_length() + 8) // 8
|
||||
return i.to_bytes(num_bytes, "little", signed=True)
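
# Examples (illustrative): signed little-endian encoding with one extra byte of
# headroom keeps the sign unambiguous.
#
#   _int_to_bytes(0)    == b"\x00"
#   _int_to_bytes(255)  == b"\xff\x00"
#   _int_to_bytes(-255) == b"\x01\xff"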
|
||||
|
||||
|
||||
def _key(obj: Optional[Any]) -> Any:
|
||||
"""Return key for memoization."""
|
||||
|
||||
if obj is None:
|
||||
return None
|
||||
|
||||
def is_simple(obj):
|
||||
return (
|
||||
isinstance(obj, bytes)
|
||||
or isinstance(obj, bytearray)
|
||||
or isinstance(obj, str)
|
||||
or isinstance(obj, float)
|
||||
or isinstance(obj, int)
|
||||
or isinstance(obj, bool)
|
||||
or obj is None
|
||||
)
|
||||
|
||||
if is_simple(obj):
|
||||
return obj
|
||||
|
||||
if isinstance(obj, tuple):
|
||||
if all(map(is_simple, obj)):
|
||||
return obj
|
||||
|
||||
if isinstance(obj, list):
|
||||
if all(map(is_simple, obj)):
|
||||
return ("__l", tuple(obj))
|
||||
|
||||
if (
|
||||
type_util.is_type(obj, "pandas.core.frame.DataFrame")
|
||||
or type_util.is_type(obj, "numpy.ndarray")
|
||||
or inspect.isbuiltin(obj)
|
||||
or inspect.isroutine(obj)
|
||||
or inspect.iscode(obj)
|
||||
):
|
||||
return id(obj)
|
||||
|
||||
return NoResult
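
# Examples of the memoization keys produced above (illustrative):
#
#   _key(3)            -> 3                    # simple value, safe to memoize on
#   _key([1, 2])       -> ("__l", (1, 2))      # list of simple values
#   _key(lambda x: x)  -> id(obj)              # routines are keyed by identity
#   _key(object())     -> NoResult             # no stable memoization key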
|
||||
|
||||
|
||||
class _CacheFuncHasher:
|
||||
"""A hasher that can hash objects with cycles."""
|
||||
|
||||
def __init__(self, cache_type: CacheType):
|
||||
self._hashes: Dict[Any, bytes] = {}
|
||||
|
||||
# The number of bytes in the hash.
|
||||
self.size = 0
|
||||
|
||||
self.cache_type = cache_type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def to_bytes(self, obj: Any) -> bytes:
|
||||
"""Add memoization to _to_bytes and protect against cycles in data structures."""
|
||||
tname = type(obj).__qualname__.encode()
|
||||
key = (tname, _key(obj))
|
||||
|
||||
# Memoize if possible.
|
||||
if key[1] is not NoResult:
|
||||
if key in self._hashes:
|
||||
return self._hashes[key]
|
||||
|
||||
# Break recursive cycles.
|
||||
if obj in hash_stacks.current:
|
||||
return _CYCLE_PLACEHOLDER
|
||||
|
||||
hash_stacks.current.push(obj)
|
||||
|
||||
try:
|
||||
# Hash the input
|
||||
b = b"%s:%s" % (tname, self._to_bytes(obj))
|
||||
|
||||
# Hmmm... It's possible that the size calculation is wrong. When we
|
||||
# call to_bytes inside _to_bytes things get double-counted.
|
||||
self.size += sys.getsizeof(b)
|
||||
|
||||
if key[1] is not NoResult:
|
||||
self._hashes[key] = b
|
||||
|
||||
finally:
|
||||
# In case an UnhashableTypeError (or other) error is thrown, clean up the
|
||||
# stack so we don't get false positives in future hashing calls
|
||||
hash_stacks.current.pop()
|
||||
|
||||
return b
|
||||
|
||||
def update(self, hasher, obj: Any) -> None:
|
||||
"""Update the provided hasher with the hash of an object."""
|
||||
b = self.to_bytes(obj)
|
||||
hasher.update(b)
|
||||
|
||||
def _to_bytes(self, obj: Any) -> bytes:
|
||||
"""Hash objects to bytes, including code with dependencies.
|
||||
|
||||
Python's built-in `hash` does not produce consistent results across
|
||||
runs.
|
||||
"""
|
||||
|
||||
if isinstance(obj, unittest.mock.Mock):
|
||||
# Mock objects can appear to be infinitely
|
||||
# deep, so we don't try to hash them at all.
|
||||
return self.to_bytes(id(obj))
|
||||
|
||||
elif isinstance(obj, bytes) or isinstance(obj, bytearray):
|
||||
return obj
|
||||
|
||||
elif isinstance(obj, str):
|
||||
return obj.encode()
|
||||
|
||||
elif isinstance(obj, float):
|
||||
return self.to_bytes(hash(obj))
|
||||
|
||||
elif isinstance(obj, int):
|
||||
return _int_to_bytes(obj)
|
||||
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
h = hashlib.new("md5")
|
||||
for item in obj:
|
||||
self.update(h, item)
|
||||
return h.digest()
|
||||
|
||||
elif isinstance(obj, dict):
|
||||
h = hashlib.new("md5")
|
||||
for item in obj.items():
|
||||
self.update(h, item)
|
||||
return h.digest()
|
||||
|
||||
elif obj is None:
|
||||
return b"0"
|
||||
|
||||
elif obj is True:
|
||||
return b"1"
|
||||
|
||||
elif obj is False:
|
||||
return b"0"
|
||||
|
||||
elif dataclasses.is_dataclass(obj):
|
||||
return self.to_bytes(dataclasses.asdict(obj))
|
||||
|
||||
elif isinstance(obj, Enum):
|
||||
return str(obj).encode()
|
||||
|
||||
elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
|
||||
obj, "pandas.core.series.Series"
|
||||
):
|
||||
import pandas as pd
|
||||
|
||||
if len(obj) >= _PANDAS_ROWS_LARGE:
|
||||
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
|
||||
try:
|
||||
return b"%s" % pd.util.hash_pandas_object(obj).sum()
|
||||
except TypeError:
|
||||
# Use pickle if pandas cannot hash the object for example if
|
||||
# it contains unhashable objects.
|
||||
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
elif type_util.is_type(obj, "numpy.ndarray"):
|
||||
h = hashlib.new("md5")
|
||||
self.update(h, obj.shape)
|
||||
|
||||
if obj.size >= _NP_SIZE_LARGE:
|
||||
import numpy as np
|
||||
|
||||
state = np.random.RandomState(0)
|
||||
obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)
|
||||
|
||||
self.update(h, obj.tobytes())
|
||||
return h.digest()
|
||||
elif type_util.is_type(obj, "PIL.Image.Image"):
|
||||
import numpy as np
|
||||
|
||||
# we don't just hash the results of obj.tobytes() because we want to use
|
||||
# the sampling logic for numpy data
|
||||
np_array = np.frombuffer(obj.tobytes(), dtype="uint8")
|
||||
return self.to_bytes(np_array)
|
||||
|
||||
elif inspect.isbuiltin(obj):
|
||||
return bytes(obj.__name__.encode())
|
||||
|
||||
elif type_util.is_type(obj, "builtins.mappingproxy") or type_util.is_type(
|
||||
obj, "builtins.dict_items"
|
||||
):
|
||||
return self.to_bytes(dict(obj))
|
||||
|
||||
elif type_util.is_type(obj, "builtins.getset_descriptor"):
|
||||
return bytes(obj.__qualname__.encode())
|
||||
|
||||
elif isinstance(obj, UploadedFile):
|
||||
# UploadedFile is a BytesIO (thus IOBase) but has a name.
|
||||
# It does not have a timestamp so this must come before
|
||||
# temporary files
|
||||
h = hashlib.new("md5")
|
||||
self.update(h, obj.name)
|
||||
self.update(h, obj.tell())
|
||||
self.update(h, obj.getvalue())
|
||||
return h.digest()
|
||||
|
||||
elif hasattr(obj, "name") and (
|
||||
isinstance(obj, io.IOBase)
|
||||
# Handle temporary files used during testing
|
||||
or isinstance(obj, tempfile._TemporaryFileWrapper)
|
||||
):
|
||||
# Hash files as name + last modification date + offset.
|
||||
# NB: we're using hasattr("name") to differentiate between
|
||||
# on-disk and in-memory StringIO/BytesIO file representations.
|
||||
# That means that this condition must come *before* the next
|
||||
# condition, which just checks for StringIO/BytesIO.
|
||||
h = hashlib.new("md5")
|
||||
obj_name = getattr(obj, "name", "wonthappen") # Just to appease MyPy.
|
||||
self.update(h, obj_name)
|
||||
self.update(h, os.path.getmtime(obj_name))
|
||||
self.update(h, obj.tell())
|
||||
return h.digest()
|
||||
|
||||
elif isinstance(obj, Pattern):
|
||||
return self.to_bytes([obj.pattern, obj.flags])
|
||||
|
||||
elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
|
||||
# Hash in-memory StringIO/BytesIO by their full contents
|
||||
# and seek position.
|
||||
h = hashlib.new("md5")
|
||||
self.update(h, obj.tell())
|
||||
self.update(h, obj.getvalue())
|
||||
return h.digest()
|
||||
|
||||
elif type_util.is_type(obj, "numpy.ufunc"):
|
||||
# For numpy.remainder, this returns remainder.
|
||||
return bytes(obj.__name__.encode())
|
||||
|
||||
elif inspect.ismodule(obj):
|
||||
# TODO: Figure out how to best show this kind of warning to the
|
||||
# user. In the meantime, show nothing. This scenario is too common,
|
||||
# so the current warning is quite annoying...
|
||||
# st.warning(('Streamlit does not support hashing modules. '
|
||||
# 'We did not hash `%s`.') % obj.__name__)
|
||||
# TODO: Hash more than just the name for internal modules.
|
||||
return self.to_bytes(obj.__name__)
|
||||
|
||||
elif inspect.isclass(obj):
|
||||
# TODO: Figure out how to best show this kind of warning to the
|
||||
# user. In the meantime, show nothing. This scenario is too common,
|
||||
# (e.g. in every "except" statement) so the current warning is
|
||||
# quite annoying...
|
||||
# st.warning(('Streamlit does not support hashing classes. '
|
||||
# 'We did not hash `%s`.') % obj.__name__)
|
||||
# TODO: Hash more than just the name of classes.
|
||||
return self.to_bytes(obj.__name__)
|
||||
|
||||
elif isinstance(obj, functools.partial):
|
||||
# The return value of functools.partial is not a plain function:
|
||||
# it's a callable object that remembers the original function plus
|
||||
# the values you pickled into it. So here we need to special-case it.
|
||||
h = hashlib.new("md5")
|
||||
self.update(h, obj.args)
|
||||
self.update(h, obj.func)
|
||||
self.update(h, obj.keywords)
|
||||
return h.digest()
|
||||
|
||||
else:
|
||||
# As a last resort, hash the output of the object's __reduce__ method
|
||||
h = hashlib.new("md5")
|
||||
try:
|
||||
reduce_data = obj.__reduce__()
|
||||
except Exception as ex:
|
||||
raise UnhashableTypeError() from ex
|
||||
|
||||
for item in reduce_data:
|
||||
self.update(h, item)
|
||||
return h.digest()
|
||||
|
||||
|
||||
class NoResult:
|
||||
"""Placeholder class for return values when None is meaningful."""
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage as CacheStorage,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorageContext as CacheStorageContext,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorageError as CacheStorageError,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorageKeyNotFoundError as CacheStorageKeyNotFoundError,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorageManager as CacheStorageManager,
|
||||
)
|
||||
@@ -0,0 +1,240 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" Declares the CacheStorageContext dataclass, which contains parameter information for
|
||||
each function decorated by `@st.cache_data` (for example: ttl, max_entries etc.)
|
||||
|
||||
Declares the CacheStorageManager protocol, which implementations are used
|
||||
to create CacheStorage instances and to optionally clear all cache storages,
|
||||
that were created by this manager, and to check if the context is valid for the storage.
|
||||
|
||||
Declares the CacheStorage protocol, which implementations are used to store cached
|
||||
values for a single `@st.cache_data` decorated function serialized as bytes.
|
||||
|
||||
How these classes work together
|
||||
-------------------------------
|
||||
- CacheStorageContext : this is a dataclass that contains the parameters from
|
||||
`@st.cache_data` that are passed to the CacheStorageManager.create() method.
|
||||
|
||||
- CacheStorageManager : each instance of this is able to create CacheStorage
|
||||
instances, and optionally to clear data of all cache storages.
|
||||
|
||||
- CacheStorage : each instance of this is able to get, set, delete, and clear
|
||||
entries for a single `@st.cache_data` decorated function.
|
||||
|
||||
┌───────────────────────────────┐
|
||||
│ │
|
||||
│ CacheStorageManager │
|
||||
│ │
|
||||
│ - clear_all(optional) │
|
||||
│ - check_context │
|
||||
│ │
|
||||
└──┬────────────────────────────┘
|
||||
│
|
||||
│ ┌──────────────────────┐
|
||||
│ │ CacheStorage │
|
||||
│ create(context)│ │
|
||||
└────────────────► - get │
|
||||
│ - set │
|
||||
│ - delete │
|
||||
│ - close (optional)│
|
||||
│ - clear │
|
||||
└──────────────────────┘
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing_extensions import Literal, Protocol
|
||||
|
||||
|
||||
class CacheStorageError(Exception):
|
||||
"""Base exception raised by the cache storage"""
|
||||
|
||||
|
||||
class CacheStorageKeyNotFoundError(CacheStorageError):
|
||||
"""Raised when the key is not found in the cache storage"""
|
||||
|
||||
|
||||
class InvalidCacheStorageContext(CacheStorageError):
|
||||
"""Raised if the cache storage manager is not able to work with
|
||||
provided CacheStorageContext.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CacheStorageContext:
|
||||
"""Context passed to the cache storage during initialization
|
||||
This is the normalized parameters that are passed to CacheStorageManager.create()
|
||||
method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
function_key: str
|
||||
A hash computed based on function name and source code decorated
|
||||
by `@st.cache_data`
|
||||
|
||||
function_display_name: str
|
||||
The display name of the function that is decorated by `@st.cache_data`
|
||||
|
||||
ttl_seconds : float or None
|
||||
The time-to-live for the keys in storage, in seconds. If None, the entry
|
||||
will never expire.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to store in the cache storage.
|
||||
If None, the cache storage will not limit the number of entries.
|
||||
|
||||
persist : Literal["disk"] or None
|
||||
The persistence mode for the cache storage.
|
||||
Legacy parameter used in Streamlit's current cache storage implementation.
It can be ignored by a cache storage implementation if the storage does not
support persistence or is persistent by default.
|
||||
"""
|
||||
|
||||
function_key: str
|
||||
function_display_name: str
|
||||
ttl_seconds: float | None = None
|
||||
max_entries: int | None = None
|
||||
persist: Literal["disk"] | None = None
|
||||
|
||||
|
||||
class CacheStorage(Protocol):
|
||||
"""Cache storage protocol, that should be implemented by the concrete cache storages.
|
||||
Used to store cached values for a single `@st.cache_data` decorated function
|
||||
serialized as bytes.
|
||||
|
||||
CacheStorage instances should be created by `CacheStorageManager.create()` method.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: the methods of this protocol may be called from multiple threads.
It is the responsibility of the concrete implementation to provide thread
safety guarantees.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, key: str) -> bytes:
|
||||
"""Returns the stored value for the key.
|
||||
|
||||
Raises
|
||||
------
|
||||
CacheStorageKeyNotFoundError
|
||||
Raised if the key is not in the storage.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
"""Sets the value for a given key"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a given key"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
"""Remove all keys for the storage"""
|
||||
raise NotImplementedError
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the cache storage, it is optional to implement, and should be used
|
||||
to close open resources, before we delete the storage instance.
|
||||
e.g. close the database connection etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class CacheStorageManager(Protocol):
|
||||
"""Cache storage manager protocol, that should be implemented by the concrete
|
||||
cache storage managers.
|
||||
|
||||
It is responsible for:
|
||||
- Creating cache storage instances for the specific
|
||||
decorated functions,
|
||||
- Validating the context for the cache storages.
|
||||
- Optionally clearing all cache storages in optimal way.
|
||||
|
||||
It should be created during Runtime initialization.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def create(self, context: CacheStorageContext) -> CacheStorage:
|
||||
"""Creates a new cache storage instance
|
||||
Please note that the ttl, max_entries and other context fields are specific
|
||||
for whole storage, not for individual key.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: Should be safe to call from any thread.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Remove everything what possible from the cache storages in optimal way.
|
||||
meaningful default behaviour is to raise NotImplementedError, so this is not
|
||||
abstractmethod.
|
||||
|
||||
The method is optional to implement: cache data API will fall back to remove
|
||||
all available storages one by one via storage.clear() method
|
||||
if clear_all raises NotImplementedError.
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError
|
||||
Raised if the storage manager does not provide the ability to clear
all storages at once in an optimal way.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: this method may be called from multiple threads.
It is the responsibility of the concrete implementation to provide
thread safety guarantees.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def check_context(self, context: CacheStorageContext) -> None:
|
||||
"""Checks if the context is valid for the storage manager.
|
||||
This method should not return anything; it should log a message or raise an
exception if the context is invalid.

If an exception is raised, we do not handle it and let it propagate.
|
||||
|
||||
check_context is called only once at the moment of creating `@st.cache_data`
|
||||
decorator for specific function, so it is not called for every cache hit.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
context: CacheStorageContext
|
||||
The context to check for the storage manager. A dummy function_key is used
in the context, since the real one has not been computed at the point this
method is called.
|
||||
|
||||
Raises
|
||||
------
|
||||
InvalidCacheStorageContext
|
||||
Raised if the cache storage manager is not able to work with the provided
CacheStorageContext. When possible we should log a message instead, since
this exception will be propagated to the user.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: Should be safe to call from any thread.
|
||||
"""
|
||||
|
||||
pass
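
# Illustrative sketch (not part of this module): a minimal dict-backed storage
# that satisfies the CacheStorage protocol. This is only an example; a real
# implementation would also need to provide its own thread safety guarantees.
#
#   class DictCacheStorage(CacheStorage):
#       def __init__(self) -> None:
#           self._data: dict[str, bytes] = {}
#
#       def get(self, key: str) -> bytes:
#           try:
#               return self._data[key]
#           except KeyError:
#               raise CacheStorageKeyNotFoundError(key)
#
#       def set(self, key: str, value: bytes) -> None:
#           self._data[key] = value
#
#       def delete(self, key: str) -> None:
#           self._data.pop(key, None)
#
#       def clear(self) -> None:
#           self._data.clear()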
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
|
||||
InMemoryCacheStorageWrapper,
|
||||
)
|
||||
|
||||
|
||||
class MemoryCacheStorageManager(CacheStorageManager):
|
||||
def create(self, context: CacheStorageContext) -> CacheStorage:
|
||||
"""Creates a new cache storage instance wrapped with in-memory cache layer"""
|
||||
persist_storage = DummyCacheStorage()
|
||||
return InMemoryCacheStorageWrapper(
|
||||
persist_storage=persist_storage, context=context
|
||||
)
|
||||
|
||||
def clear_all(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def check_context(self, context: CacheStorageContext) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DummyCacheStorage(CacheStorage):
|
||||
def get(self, key: str) -> bytes:
|
||||
"""
|
||||
Dummy implementation of getting the value for a given key;
always raises a CacheStorageKeyNotFoundError.
|
||||
"""
|
||||
raise CacheStorageKeyNotFoundError("Key not found in dummy cache")
|
||||
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
pass
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
pass
|
||||
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
@@ -0,0 +1,145 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import threading
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching import cache_utils
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageKeyNotFoundError,
|
||||
)
|
||||
from streamlit.runtime.stats import CacheStat
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class InMemoryCacheStorageWrapper(CacheStorage):
|
||||
"""
|
||||
In-memory cache storage wrapper.
|
||||
|
||||
This class wraps a cache storage and adds an in-memory cache front layer,
|
||||
which is used to reduce the number of calls to the storage.
|
||||
|
||||
The in-memory cache is a TTL cache, which means that the entries are
|
||||
automatically removed if a given time to live (TTL) has passed.
|
||||
|
||||
The in-memory cache is also an LRU cache, which means that the entries
|
||||
are automatically removed if the cache size exceeds a given maxsize.
|
||||
|
||||
If the storage implements its own maxsize strategy, it is recommended
(but not necessary) that it implement the same LRU strategy; otherwise a
situation may arise where different items are deleted from the memory
cache and from the storage.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: the in-memory caching layer is thread safe: we hold
self._mem_cache_lock when working with the self._mem_cache object.
However, we do not hold this lock when calling into the underlying storage,
so it is the responsibility of that storage to ensure that it is safe to use
from multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
|
||||
self.function_key = context.function_key
|
||||
self.function_display_name = context.function_display_name
|
||||
self._ttl_seconds = context.ttl_seconds
|
||||
self._max_entries = context.max_entries
|
||||
self._mem_cache: TTLCache[str, bytes] = TTLCache(
|
||||
maxsize=self.max_entries,
|
||||
ttl=self.ttl_seconds,
|
||||
timer=cache_utils.TTLCACHE_TIMER,
|
||||
)
|
||||
self._mem_cache_lock = threading.Lock()
|
||||
self._persist_storage = persist_storage
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> float:
|
||||
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
|
||||
|
||||
@property
|
||||
def max_entries(self) -> float:
|
||||
return float(self._max_entries) if self._max_entries is not None else math.inf
|
||||
|
||||
def get(self, key: str) -> bytes:
|
||||
"""
|
||||
Returns the stored value for the key, or raises CacheStorageKeyNotFoundError
if the key is not found.
|
||||
"""
|
||||
try:
|
||||
entry_bytes = self._read_from_mem_cache(key)
|
||||
except CacheStorageKeyNotFoundError:
|
||||
entry_bytes = self._persist_storage.get(key)
|
||||
self._write_to_mem_cache(key, entry_bytes)
|
||||
return entry_bytes
|
||||
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
"""Sets the value for a given key"""
|
||||
self._write_to_mem_cache(key, value)
|
||||
self._persist_storage.set(key, value)
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a given key"""
|
||||
self._remove_from_mem_cache(key)
|
||||
self._persist_storage.delete(key)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete all keys for the in memory cache, and also the persistent storage"""
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache.clear()
|
||||
self._persist_storage.clear()
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
"""Returns a list of stats in bytes for the cache memory storage per item"""
|
||||
stats = []
|
||||
|
||||
with self._mem_cache_lock:
|
||||
for item in self._mem_cache.values():
|
||||
stats.append(
|
||||
CacheStat(
|
||||
category_name="st_cache_data",
|
||||
cache_name=self.function_display_name,
|
||||
byte_length=len(item),
|
||||
)
|
||||
)
|
||||
return stats
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the cache storage"""
|
||||
self._persist_storage.close()
|
||||
|
||||
def _read_from_mem_cache(self, key: str) -> bytes:
|
||||
with self._mem_cache_lock:
|
||||
if key in self._mem_cache:
|
||||
entry = bytes(self._mem_cache[key])
|
||||
_LOGGER.debug("Memory cache HIT: %s", key)
|
||||
return entry
|
||||
|
||||
else:
|
||||
_LOGGER.debug("Memory cache MISS: %s", key)
|
||||
raise CacheStorageKeyNotFoundError("Key not found in mem cache")
|
||||
|
||||
def _write_to_mem_cache(self, key: str, entry_bytes: bytes) -> None:
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache[key] = entry_bytes
|
||||
|
||||
def _remove_from_mem_cache(self, key: str) -> None:
|
||||
with self._mem_cache_lock:
|
||||
self._mem_cache.pop(key, None)
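
# Read-through behaviour in a nutshell (illustrative; `some_storage` stands in
# for any CacheStorage implementation and the values below are hypothetical):
#
#   wrapper = InMemoryCacheStorageWrapper(
#       persist_storage=some_storage,
#       context=CacheStorageContext(
#           function_key="abc", function_display_name="my_func", ttl_seconds=60
#       ),
#   )
#   wrapper.set("k", b"v")   # writes to the memory layer and the wrapped storage
#   wrapper.get("k")         # served from memory until the 60 second TTL expires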
|
||||
@@ -0,0 +1,222 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Declares the LocalDiskCacheStorageManager class, which is used
|
||||
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper.
The InMemoryCacheStorageWrapper provides a first layer of in-memory cache
before the LocalDiskCacheStorage itself is accessed.
|
||||
|
||||
Declares the LocalDiskCacheStorage class, which is used to store cached
|
||||
values on disk.
|
||||
|
||||
How these classes work together
|
||||
-------------------------------
|
||||
|
||||
- LocalDiskCacheStorageManager : each instance of this is able
|
||||
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
|
||||
and to clear data from the cache storage folder. It is also the
LocalDiskCacheStorageManager's responsibility to check whether the context is
valid for the storage, and to log a warning if it is not.
|
||||
|
||||
- LocalDiskCacheStorage : each instance of this is able to get, set, delete, and clear
|
||||
entries from disk for a single `@st.cache_data` decorated function if `persist="disk"`
|
||||
is used in CacheStorageContext.
|
||||
|
||||
|
||||
┌───────────────────────────────┐
|
||||
│ LocalDiskCacheStorageManager │
|
||||
│ │
|
||||
│ - clear_all │
|
||||
│ - check_context │
|
||||
│ │
|
||||
└──┬────────────────────────────┘
|
||||
│
|
||||
│ ┌──────────────────────────────┐
|
||||
│ │ │
|
||||
│ create(context)│ InMemoryCacheStorageWrapper │
|
||||
└────────────────► │
|
||||
│ ┌─────────────────────┐ │
|
||||
│ │ │ │
|
||||
│ │ LocalDiskStorage │ │
|
||||
│ │ │ │
|
||||
│ └─────────────────────┘ │
|
||||
│ │
|
||||
└──────────────────────────────┘
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.file_util import get_streamlit_file_path, streamlit_read, streamlit_write
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.storage.cache_storage_protocol import (
|
||||
CacheStorage,
|
||||
CacheStorageContext,
|
||||
CacheStorageError,
|
||||
CacheStorageKeyNotFoundError,
|
||||
CacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
|
||||
InMemoryCacheStorageWrapper,
|
||||
)
|
||||
|
||||
# Streamlit directory where persisted @st.cache_data objects live.
|
||||
# (This is the same directory in which @st.cache persisted objects live.
|
||||
# But @st.cache_data uses a different extension, so they don't overlap.)
|
||||
_CACHE_DIR_NAME = "cache"
|
||||
|
||||
# The extension for our persisted @st.cache_data objects.
|
||||
# (`@st.cache_data` was originally called `@st.memo`)
|
||||
_CACHED_FILE_EXTENSION = "memo"
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class LocalDiskCacheStorageManager(CacheStorageManager):
|
||||
def create(self, context: CacheStorageContext) -> CacheStorage:
|
||||
"""Creates a new cache storage instance wrapped with in-memory cache layer"""
|
||||
persist_storage = LocalDiskCacheStorage(context)
|
||||
return InMemoryCacheStorageWrapper(
|
||||
persist_storage=persist_storage, context=context
|
||||
)
|
||||
|
||||
def clear_all(self) -> None:
|
||||
cache_path = get_cache_folder_path()
|
||||
if os.path.isdir(cache_path):
|
||||
shutil.rmtree(cache_path)
|
||||
|
||||
def check_context(self, context: CacheStorageContext) -> None:
|
||||
if (
|
||||
context.persist == "disk"
|
||||
and context.ttl_seconds is not None
|
||||
and not math.isinf(context.ttl_seconds)
|
||||
):
|
||||
_LOGGER.warning(
|
||||
f"The cached function '{context.function_display_name}' has a TTL "
|
||||
"that will be ignored. Persistent cached functions currently don't "
|
||||
"support TTL."
|
||||
)
|
||||
|
||||
|
||||
class LocalDiskCacheStorage(CacheStorage):
|
||||
"""Cache storage that persists data to disk
|
||||
This is the default cache persistence layer for `@st.cache_data`
|
||||
"""
|
||||
|
||||
def __init__(self, context: CacheStorageContext):
|
||||
self.function_key = context.function_key
|
||||
self.persist = context.persist
|
||||
self._ttl_seconds = context.ttl_seconds
|
||||
self._max_entries = context.max_entries
|
||||
|
||||
@property
|
||||
def ttl_seconds(self) -> float:
|
||||
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
|
||||
|
||||
@property
|
||||
def max_entries(self) -> float:
|
||||
return float(self._max_entries) if self._max_entries is not None else math.inf
|
||||
|
||||
def get(self, key: str) -> bytes:
|
||||
"""
|
||||
Returns the stored value for the key if persisted,
|
||||
raises CacheStorageKeyNotFoundError if the key is not found or the storage
is not configured with persist="disk"
|
||||
"""
|
||||
if self.persist == "disk":
|
||||
path = self._get_cache_file_path(key)
|
||||
try:
|
||||
with streamlit_read(path, binary=True) as input:
|
||||
value = input.read()
|
||||
_LOGGER.debug("Disk cache HIT: %s", key)
|
||||
return bytes(value)
|
||||
except FileNotFoundError:
|
||||
raise CacheStorageKeyNotFoundError("Key not found in disk cache")
|
||||
except Exception as ex:
|
||||
_LOGGER.error(ex)
|
||||
raise CacheStorageError("Unable to read from cache") from ex
|
||||
else:
|
||||
raise CacheStorageKeyNotFoundError(
|
||||
f"Local disk cache storage is disabled (persist={self.persist})"
|
||||
)
|
||||
|
||||
def set(self, key: str, value: bytes) -> None:
|
||||
"""Sets the value for a given key"""
|
||||
if self.persist == "disk":
|
||||
path = self._get_cache_file_path(key)
|
||||
try:
|
||||
with streamlit_write(path, binary=True) as output:
|
||||
output.write(value)
|
||||
except util.Error as e:
|
||||
_LOGGER.debug(e)
|
||||
# Clean up file so we don't leave zero byte files.
|
||||
try:
|
||||
os.remove(path)
|
||||
except (FileNotFoundError, IOError, OSError):
|
||||
# If we can't remove the file, it's not a big deal.
|
||||
pass
|
||||
raise CacheStorageError("Unable to write to cache") from e
|
||||
|
||||
def delete(self, key: str) -> None:
|
||||
"""Delete a cache file from disk. If the file does not exist on disk,
|
||||
return silently. If another exception occurs, log it. Does not throw.
|
||||
"""
|
||||
if self.persist == "disk":
|
||||
path = self._get_cache_file_path(key)
|
||||
try:
|
||||
os.remove(path)
|
||||
except FileNotFoundError:
|
||||
# The file is already removed.
|
||||
pass
|
||||
except Exception as ex:
|
||||
_LOGGER.exception(
|
||||
"Unable to remove a file from the disk cache", exc_info=ex
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Delete all keys for the current storage"""
|
||||
cache_dir = get_cache_folder_path()
|
||||
|
||||
if os.path.isdir(cache_dir):
|
||||
# We try to remove all files in the cache directory that start with
|
||||
# the function key, whether or not `clear` was called on a storage with
# `self.persist` enabled, to avoid leaving orphaned files in the cache directory.
|
||||
for file_name in os.listdir(cache_dir):
|
||||
if self._is_cache_file(file_name):
|
||||
os.remove(os.path.join(cache_dir, file_name))
|
||||
|
||||
def close(self) -> None:
|
||||
"""Dummy implementation of close, we don't need to actually "close" anything"""
|
||||
|
||||
def _get_cache_file_path(self, value_key: str) -> str:
|
||||
"""Return the path of the disk cache file for the given value."""
|
||||
cache_dir = get_cache_folder_path()
|
||||
return os.path.join(
|
||||
cache_dir, f"{self.function_key}-{value_key}.{_CACHED_FILE_EXTENSION}"
|
||||
)
|
||||
|
||||
def _is_cache_file(self, fname: str) -> bool:
|
||||
"""Return true if the given file name is a cache file for this storage."""
|
||||
return fname.startswith(f"{self.function_key}-") and fname.endswith(
|
||||
f".{_CACHED_FILE_EXTENSION}"
|
||||
)
|
||||
|
||||
|
||||
def get_cache_folder_path() -> str:
|
||||
return get_streamlit_file_path(_CACHE_DIR_NAME)
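
# Illustrative example of the resulting on-disk layout (the home directory path
# shown is typical, not guaranteed):
#
#   ~/.streamlit/cache/<function_key>-<value_key>.memo
#
# e.g. a function key "f3ab" and value key "09cd" would be persisted as
# ~/.streamlit/cache/f3ab-09cd.memo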
|
||||
@@ -0,0 +1,308 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Manage the user's Streamlit credentials."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import toml
|
||||
|
||||
from streamlit import env_util, file_util, util
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
if env_util.IS_WINDOWS:
|
||||
_CONFIG_FILE_PATH = r"%userprofile%/.streamlit/config.toml"
|
||||
else:
|
||||
_CONFIG_FILE_PATH = "~/.streamlit/config.toml"
|
||||
|
||||
_Activation = namedtuple(
|
||||
"_Activation",
|
||||
[
|
||||
"email", # str : the user's email.
|
||||
"is_valid", # boolean : whether the email is valid.
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def email_prompt() -> str:
|
||||
# Emoji can cause encoding errors on non-UTF-8 terminals
|
||||
# (See https://github.com/streamlit/streamlit/issues/2284.)
|
||||
# WT_SESSION is a Windows Terminal specific environment variable. If it exists,
|
||||
# we are on the latest Windows Terminal that supports emojis
|
||||
show_emoji = sys.stdout.encoding == "utf-8" and (
|
||||
not env_util.IS_WINDOWS or os.environ.get("WT_SESSION")
|
||||
)
|
||||
|
||||
# IMPORTANT: Break the text below at 80 chars.
|
||||
return """
|
||||
{0}%(welcome)s
|
||||
|
||||
If you’d like to receive helpful onboarding emails, news, offers, promotions,
|
||||
and the occasional swag, please enter your email address below. Otherwise,
|
||||
leave this field blank.
|
||||
|
||||
%(email)s""".format(
|
||||
"👋 " if show_emoji else ""
|
||||
) % {
|
||||
"welcome": click.style("Welcome to Streamlit!", bold=True),
|
||||
"email": click.style("Email: ", fg="blue"),
|
||||
}
|
||||
|
||||
|
||||
# IMPORTANT: Break the text below at 80 chars.
|
||||
_TELEMETRY_TEXT = """
|
||||
You can find our privacy policy at %(link)s
|
||||
|
||||
Summary:
|
||||
- This open source library collects usage statistics.
|
||||
- We cannot see and do not store information contained inside Streamlit apps,
|
||||
such as text, charts, images, etc.
|
||||
- Telemetry data is stored in servers in the United States.
|
||||
- If you'd like to opt out, add the following to %(config)s,
|
||||
creating that file if necessary:
|
||||
|
||||
[browser]
|
||||
gatherUsageStats = false
|
||||
""" % {
|
||||
"link": click.style("https://streamlit.io/privacy-policy", underline=True),
|
||||
"config": click.style(_CONFIG_FILE_PATH),
|
||||
}
|
||||
|
||||
_TELEMETRY_HEADLESS_TEXT = """
|
||||
Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
|
||||
"""
|
||||
|
||||
# IMPORTANT: Break the text below at 80 chars.
|
||||
_INSTRUCTIONS_TEXT = """
|
||||
%(start)s
|
||||
%(prompt)s %(hello)s
|
||||
""" % {
|
||||
"start": click.style("Get started by typing:", fg="blue", bold=True),
|
||||
"prompt": click.style("$", fg="blue"),
|
||||
"hello": click.style("streamlit hello", bold=True),
|
||||
}
|
||||
|
||||
|
||||
class Credentials(object):
|
||||
"""Credentials class."""
|
||||
|
||||
_singleton: Optional["Credentials"] = None
|
||||
|
||||
@classmethod
|
||||
def get_current(cls):
|
||||
"""Return the singleton instance."""
|
||||
if cls._singleton is None:
|
||||
Credentials()
|
||||
|
||||
return Credentials._singleton
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize class."""
|
||||
if Credentials._singleton is not None:
|
||||
raise RuntimeError(
|
||||
"Credentials already initialized. Use .get_current() instead"
|
||||
)
|
||||
|
||||
self.activation = None
|
||||
self._conf_file = _get_credential_file_path()
|
||||
|
||||
Credentials._singleton = self
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def load(self, auto_resolve=False) -> None:
|
||||
"""Load from toml file."""
|
||||
if self.activation is not None:
|
||||
LOGGER.error("Credentials already loaded. Not rereading file.")
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self._conf_file, "r") as f:
|
||||
data = toml.load(f).get("general")
|
||||
if data is None:
|
||||
raise Exception
|
||||
self.activation = _verify_email(data.get("email"))
|
||||
except FileNotFoundError:
|
||||
if auto_resolve:
|
||||
self.activate(show_instructions=not auto_resolve)
|
||||
return
|
||||
raise RuntimeError(
|
||||
'Credentials not found. Please run "streamlit activate".'
|
||||
)
|
||||
except Exception:
|
||||
if auto_resolve:
|
||||
self.reset()
|
||||
self.activate(show_instructions=not auto_resolve)
|
||||
return
|
||||
raise Exception(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
Unable to load credentials from %s.
|
||||
Run "streamlit reset" and try again.
|
||||
"""
|
||||
)
|
||||
% (self._conf_file)
|
||||
)
|
||||
|
||||
def _check_activated(self, auto_resolve=True):
|
||||
"""Check if streamlit is activated.
|
||||
|
||||
Used by `streamlit run script.py`
|
||||
"""
|
||||
try:
|
||||
self.load(auto_resolve)
|
||||
except (Exception, RuntimeError) as e:
|
||||
_exit(str(e))
|
||||
|
||||
if self.activation is None or not self.activation.is_valid:
|
||||
_exit("Activation email not valid.")
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
"""Reset credentials by removing file.
|
||||
|
||||
This is used by `streamlit activate reset` in case a user wants
|
||||
to start over.
|
||||
"""
|
||||
c = Credentials.get_current()
|
||||
c.activation = None
|
||||
|
||||
try:
|
||||
os.remove(c._conf_file)
|
||||
except OSError as e:
|
||||
LOGGER.error("Error removing credentials file: %s" % e)
|
||||
|
||||
def save(self):
|
||||
"""Save to toml file."""
|
||||
if self.activation is None:
|
||||
return
|
||||
|
||||
# Create intermediate directories if necessary
|
||||
os.makedirs(os.path.dirname(self._conf_file), exist_ok=True)
|
||||
|
||||
# Write the file
|
||||
data = {"email": self.activation.email}
|
||||
with open(self._conf_file, "w") as f:
|
||||
toml.dump({"general": data}, f)
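# For reference (editor's addition, not part of the original module): the
# credentials.toml written above has the shape below; the address is
# illustrative.
#
#     [general]
#     email = "user@example.com"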
|
||||
|
||||
def activate(self, show_instructions: bool = True) -> None:
|
||||
"""Activate Streamlit.
|
||||
|
||||
Used by `streamlit activate`.
|
||||
"""
|
||||
try:
|
||||
self.load()
|
||||
except RuntimeError:
|
||||
# A RuntimeError is raised if the credentials file is not found. In that case,
|
||||
# `self.activation` is None and we will show the activation prompt below.
|
||||
pass
|
||||
|
||||
if self.activation:
|
||||
if self.activation.is_valid:
|
||||
_exit("Already activated")
|
||||
else:
|
||||
_exit(
|
||||
"Activation not valid. Please run "
|
||||
"`streamlit activate reset` then `streamlit activate`"
|
||||
)
|
||||
else:
|
||||
activated = False
|
||||
|
||||
while not activated:
|
||||
email = click.prompt(
|
||||
text=email_prompt(),
|
||||
prompt_suffix="",
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
|
||||
self.activation = _verify_email(email)
|
||||
if self.activation.is_valid:
|
||||
self.save()
|
||||
click.secho(_TELEMETRY_TEXT)
|
||||
if show_instructions:
|
||||
click.secho(_INSTRUCTIONS_TEXT)
|
||||
activated = True
|
||||
else: # pragma: nocover
|
||||
LOGGER.error("Please try again.")
|
||||
|
||||
|
||||
def _verify_email(email: str) -> _Activation:
|
||||
"""Verify the user's email address.
|
||||
|
||||
The email can either be an empty string (if the user chooses not to enter
|
||||
it), or a string with a single '@' somewhere in it.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
email : str
|
||||
|
||||
Returns
|
||||
-------
|
||||
_Activation
|
||||
An _Activation object. Its 'is_valid' property will be True only if
|
||||
the email was validated.
|
||||
|
||||
"""
|
||||
email = email.strip()
|
||||
|
||||
# We deliberately use simple email validation here
|
||||
# since we do not use email address anywhere to send emails.
|
||||
if len(email) > 0 and email.count("@") != 1:
|
||||
LOGGER.error("That doesn't look like an email :(")
|
||||
return _Activation(None, False)
|
||||
|
||||
return _Activation(email, True)
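# Illustrative examples (editor's addition, not part of the original module)
# of the rule above: an empty string and a single-'@' address are accepted,
# anything else is rejected.
#
#     >>> _verify_email("jane@example.com").is_valid
#     True
#     >>> _verify_email("").is_valid
#     True
#     >>> _verify_email("not-an-email").is_valid
#     False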
|
||||
|
||||
|
||||
def _exit(message): # pragma: nocover
|
||||
"""Exit program with error."""
|
||||
LOGGER.error(message)
|
||||
sys.exit(-1)
|
||||
|
||||
|
||||
def _get_credential_file_path():
|
||||
return file_util.get_streamlit_file_path("credentials.toml")
|
||||
|
||||
|
||||
def _check_credential_file_exists():
|
||||
return os.path.exists(_get_credential_file_path())
|
||||
|
||||
|
||||
def check_credentials():
|
||||
"""Check credentials and potentially activate.
|
||||
|
||||
Note
|
||||
----
|
||||
If there is no credential file and we are in headless mode, we should not
check, since the credential would automatically be set to an empty string.
|
||||
|
||||
"""
|
||||
from streamlit import config
|
||||
|
||||
if not _check_credential_file_exists() and config.get_option("server.headless"):
|
||||
if not config.is_manually_set("browser.gatherUsageStats"):
|
||||
# If not manually defined, show short message about usage stats gathering.
|
||||
click.secho(_TELEMETRY_HEADLESS_TEXT)
|
||||
return
|
||||
Credentials.get_current()._check_activated()
|
||||
@@ -0,0 +1,270 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Dict, List, MutableMapping, Optional
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
from streamlit import config, util
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
def populate_hash_if_needed(msg: ForwardMsg) -> str:
|
||||
"""Computes and assigns the unique hash for a ForwardMsg.
|
||||
|
||||
If the ForwardMsg already has a hash, this is a no-op.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
|
||||
Returns
|
||||
-------
|
||||
string
|
||||
The message's hash, returned here for convenience. (The hash
|
||||
will also be assigned to the ForwardMsg; callers do not need
|
||||
to do this.)
|
||||
|
||||
"""
|
||||
if msg.hash == "":
|
||||
# Move the message's metadata aside. It's not part of the
|
||||
# hash calculation.
|
||||
metadata = msg.metadata
|
||||
msg.ClearField("metadata")
|
||||
|
||||
# MD5 is good enough for what we need, which is uniqueness.
|
||||
hasher = hashlib.md5()
|
||||
hasher.update(msg.SerializeToString())
|
||||
msg.hash = hasher.hexdigest()
|
||||
|
||||
# Restore metadata.
|
||||
msg.metadata.CopyFrom(metadata)
|
||||
|
||||
return msg.hash
|
||||
|
||||
|
||||
def create_reference_msg(msg: ForwardMsg) -> ForwardMsg:
|
||||
"""Create a ForwardMsg that refers to the given message via its hash.
|
||||
|
||||
The reference message will also get a copy of the source message's
|
||||
metadata.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
The ForwardMsg to create the reference to.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ForwardMsg
|
||||
A new ForwardMsg that "points" to the original message via the
|
||||
ref_hash field.
|
||||
|
||||
"""
|
||||
ref_msg = ForwardMsg()
|
||||
ref_msg.ref_hash = populate_hash_if_needed(msg)
|
||||
ref_msg.metadata.CopyFrom(msg.metadata)
|
||||
return ref_msg
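# Illustrative sketch (editor's addition, not part of the original module):
# a ForwardMsg is hashed once, after which a small reference message can
# stand in for it on the wire.
#
#     >>> msg = ForwardMsg()
#     >>> populate_hash_if_needed(msg) == msg.hash
#     True
#     >>> ref = create_reference_msg(msg)
#     >>> ref.ref_hash == msg.hash
#     True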
|
||||
|
||||
|
||||
class ForwardMsgCache(CacheStatsProvider):
|
||||
"""A cache of ForwardMsgs.
|
||||
|
||||
Large ForwardMsgs (e.g. those containing big DataFrame payloads) are
|
||||
stored in this cache. The server can choose to send a ForwardMsg's hash,
|
||||
rather than the message itself, to a client. Clients can then
|
||||
request messages from this cache via another endpoint.
|
||||
|
||||
This cache is *not* thread safe. It's intended to only be accessed by
|
||||
the server thread.
|
||||
|
||||
"""
|
||||
|
||||
class Entry:
|
||||
"""Cache entry.
|
||||
|
||||
Stores the cached message, and the set of AppSessions
|
||||
that we've sent the cached message to.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, msg: ForwardMsg):
|
||||
self.msg = msg
|
||||
self._session_script_run_counts: MutableMapping[
|
||||
"AppSession", int
|
||||
] = WeakKeyDictionary()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def add_session_ref(self, session: "AppSession", script_run_count: int) -> None:
|
||||
"""Adds a reference to a AppSession that has referenced
|
||||
this Entry's message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : AppSession
|
||||
script_run_count : int
|
||||
The session's run count at the time of the call
|
||||
|
||||
"""
|
||||
prev_run_count = self._session_script_run_counts.get(session, 0)
|
||||
if script_run_count < prev_run_count:
|
||||
LOGGER.error(
|
||||
"New script_run_count (%s) is < prev_run_count (%s). "
|
||||
"This should never happen!" % (script_run_count, prev_run_count)
|
||||
)
|
||||
script_run_count = prev_run_count
|
||||
self._session_script_run_counts[session] = script_run_count
|
||||
|
||||
def has_session_ref(self, session: "AppSession") -> bool:
|
||||
return session in self._session_script_run_counts
|
||||
|
||||
def get_session_ref_age(
|
||||
self, session: "AppSession", script_run_count: int
|
||||
) -> int:
|
||||
"""The age of the given session's reference to the Entry,
|
||||
given a new script_run_count.
|
||||
|
||||
"""
|
||||
return script_run_count - self._session_script_run_counts[session]
|
||||
|
||||
def remove_session_ref(self, session: "AppSession") -> None:
|
||||
del self._session_script_run_counts[session]
|
||||
|
||||
def has_refs(self) -> bool:
|
||||
"""True if this Entry has references from any AppSession.
|
||||
|
||||
If not, it can be removed from the cache.
|
||||
"""
|
||||
return len(self._session_script_run_counts) > 0
|
||||
|
||||
def __init__(self):
|
||||
self._entries: Dict[str, "ForwardMsgCache.Entry"] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def add_message(
|
||||
self, msg: ForwardMsg, session: "AppSession", script_run_count: int
|
||||
) -> None:
|
||||
"""Add a ForwardMsg to the cache.
|
||||
|
||||
The cache will also record a reference to the given AppSession,
|
||||
so that it can track which sessions have already received
|
||||
each given ForwardMsg.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
msg : ForwardMsg
|
||||
session : AppSession
|
||||
script_run_count : int
|
||||
The number of times the session's script has run
|
||||
|
||||
"""
|
||||
populate_hash_if_needed(msg)
|
||||
entry = self._entries.get(msg.hash, None)
|
||||
if entry is None:
|
||||
entry = ForwardMsgCache.Entry(msg)
|
||||
self._entries[msg.hash] = entry
|
||||
entry.add_session_ref(session, script_run_count)
|
||||
|
||||
def get_message(self, hash: str) -> Optional[ForwardMsg]:
|
||||
"""Return the message with the given ID if it exists in the cache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hash : string
|
||||
The id of the message to retrieve.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ForwardMsg | None
|
||||
|
||||
"""
|
||||
entry = self._entries.get(hash, None)
|
||||
return entry.msg if entry else None
|
||||
|
||||
def has_message_reference(
|
||||
self, msg: ForwardMsg, session: "AppSession", script_run_count: int
|
||||
) -> bool:
|
||||
"""Return True if a session has a reference to a message."""
|
||||
populate_hash_if_needed(msg)
|
||||
|
||||
entry = self._entries.get(msg.hash, None)
|
||||
if entry is None or not entry.has_session_ref(session):
|
||||
return False
|
||||
|
||||
# Ensure we're not expired
|
||||
age = entry.get_session_ref_age(session, script_run_count)
|
||||
return age <= int(config.get_option("global.maxCachedMessageAge"))
|
||||
|
||||
def remove_expired_session_entries(
|
||||
self, session: "AppSession", script_run_count: int
|
||||
) -> None:
|
||||
"""Remove any cached messages that have expired from the given session.
|
||||
|
||||
This should be called each time an AppSession finishes executing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session : AppSession
|
||||
script_run_count : int
|
||||
The number of times the session's script has run
|
||||
|
||||
"""
|
||||
max_age = config.get_option("global.maxCachedMessageAge")
|
||||
|
||||
# Operate on a copy of our entries dict.
|
||||
# We may be deleting from it.
|
||||
for msg_hash, entry in self._entries.copy().items():
|
||||
if not entry.has_session_ref(session):
|
||||
continue
|
||||
|
||||
age = entry.get_session_ref_age(session, script_run_count)
|
||||
if age > max_age:
|
||||
LOGGER.debug(
|
||||
"Removing expired entry [session=%s, hash=%s, age=%s]",
|
||||
id(session),
|
||||
msg_hash,
|
||||
age,
|
||||
)
|
||||
entry.remove_session_ref(session)
|
||||
if not entry.has_refs():
|
||||
# The entry has no more references. Remove it from
|
||||
# the cache completely.
|
||||
del self._entries[msg_hash]
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries from the cache"""
|
||||
self._entries.clear()
|
||||
|
||||
def get_stats(self) -> List[CacheStat]:
|
||||
stats: List[CacheStat] = []
|
||||
for entry_hash, entry in self._entries.items():
|
||||
stats.append(
|
||||
CacheStat(
|
||||
category_name="ForwardMessageCache",
|
||||
cache_name="",
|
||||
byte_length=entry.msg.ByteSize(),
|
||||
)
|
||||
)
|
||||
return stats
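# Illustrative usage sketch (editor's addition, not part of the original
# module). `FakeSession` is a hypothetical stand-in for an AppSession; any
# weak-referenceable object can act as the session key.
#
#     >>> class FakeSession:
#     ...     pass
#     >>> session = FakeSession()
#     >>> cache = ForwardMsgCache()
#     >>> msg = ForwardMsg()
#     >>> cache.add_message(msg, session, script_run_count=1)
#     >>> cache.get_message(msg.hash) is msg
#     True
#     >>> cache.has_message_reference(msg, session, script_run_count=1)
#     True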
|
||||
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.Delta_pb2 import Delta
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class ForwardMsgQueue:
|
||||
"""Accumulates a session's outgoing ForwardMsgs.
|
||||
|
||||
Each AppSession adds messages to its queue, and the Server periodically
|
||||
flushes all session queues and delivers their messages to the appropriate
|
||||
clients.
|
||||
|
||||
ForwardMsgQueue is not thread-safe - a queue should only be used from
|
||||
a single thread.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._queue: List[ForwardMsg] = []
|
||||
# A mapping of (delta_path -> _queue.indexof(msg)) for each
|
||||
# Delta message in the queue. We use this for coalescing
|
||||
# redundant outgoing Deltas (where a newer Delta supersedes
|
||||
# an older Delta, with the same delta_path, that's still in the
|
||||
# queue).
|
||||
self._delta_index_map: Dict[Tuple[int, ...], int] = dict()
|
||||
|
||||
def get_debug(self) -> Dict[str, Any]:
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
return {
|
||||
"queue": [MessageToDict(m) for m in self._queue],
|
||||
"ids": list(self._delta_index_map.keys()),
|
||||
}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return len(self._queue) == 0
|
||||
|
||||
def enqueue(self, msg: ForwardMsg) -> None:
|
||||
"""Add message into queue, possibly composing it with another message."""
|
||||
if not _is_composable_message(msg):
|
||||
self._queue.append(msg)
|
||||
return
|
||||
|
||||
# If there's a Delta message with the same delta_path already in
|
||||
# the queue - meaning that it refers to the same location in
|
||||
# the app - we attempt to combine this new Delta into the old
|
||||
# one. This is an optimization that prevents redundant Deltas
|
||||
# from being sent to the frontend.
|
||||
delta_key = tuple(msg.metadata.delta_path)
|
||||
if delta_key in self._delta_index_map:
|
||||
index = self._delta_index_map[delta_key]
|
||||
old_msg = self._queue[index]
|
||||
composed_delta = _maybe_compose_deltas(old_msg.delta, msg.delta)
|
||||
if composed_delta is not None:
|
||||
new_msg = ForwardMsg()
|
||||
new_msg.delta.CopyFrom(composed_delta)
|
||||
new_msg.metadata.CopyFrom(msg.metadata)
|
||||
self._queue[index] = new_msg
|
||||
return
|
||||
|
||||
# No composition occurred. Append this message to the queue, and
|
||||
# store its index for potential future composition.
|
||||
self._delta_index_map[delta_key] = len(self._queue)
|
||||
self._queue.append(msg)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the queue."""
|
||||
self._queue = []
|
||||
self._delta_index_map = dict()
|
||||
|
||||
def flush(self) -> List[ForwardMsg]:
|
||||
"""Clear the queue and return a list of the messages it contained
|
||||
before being cleared.
|
||||
"""
|
||||
queue = self._queue
|
||||
self.clear()
|
||||
return queue
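# Illustrative sketch (editor's addition, not part of the original module):
# two composable Deltas enqueued at the same delta_path coalesce into one
# queued message, so flush() returns a single ForwardMsg.
#
#     >>> queue = ForwardMsgQueue()
#     >>> first, second = ForwardMsg(), ForwardMsg()
#     >>> for m in (first, second):
#     ...     m.metadata.delta_path.extend([0, 1])
#     ...     m.delta.new_element.SetInParent()
#     >>> queue.enqueue(first)
#     >>> queue.enqueue(second)
#     >>> len(queue.flush())
#     1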
|
||||
|
||||
|
||||
def _is_composable_message(msg: ForwardMsg) -> bool:
|
||||
"""True if the ForwardMsg is potentially composable with other ForwardMsgs."""
|
||||
if not msg.HasField("delta"):
|
||||
# Non-delta messages are never composable.
|
||||
return False
|
||||
|
||||
# We never compose add_rows messages in Python, because the add_rows
|
||||
# operation can raise errors, and we don't have a good way of handling
|
||||
# those errors in the message queue.
|
||||
delta_type = msg.delta.WhichOneof("type")
|
||||
return delta_type != "add_rows" and delta_type != "arrow_add_rows"
|
||||
|
||||
|
||||
def _maybe_compose_deltas(old_delta: Delta, new_delta: Delta) -> Optional[Delta]:
|
||||
"""Combines new_delta onto old_delta if possible.
|
||||
|
||||
If the combination takes place, the function returns a new Delta that
|
||||
should replace old_delta in the queue.
|
||||
|
||||
If the new_delta is incompatible with old_delta, the function returns None.
|
||||
In this case, the new_delta should just be appended to the queue as normal.
|
||||
"""
|
||||
old_delta_type = old_delta.WhichOneof("type")
|
||||
if old_delta_type == "add_block":
|
||||
# We never replace add_block deltas, because blocks can have
|
||||
# other dependent deltas later in the queue. For example:
|
||||
#
|
||||
# placeholder = st.empty()
|
||||
# placeholder.columns(1)
|
||||
# placeholder.empty()
|
||||
#
|
||||
# The call to "placeholder.columns(1)" creates two blocks, a parent
|
||||
# container with delta_path (0, 0), and a column child with
|
||||
# delta_path (0, 0, 0). If the final "placeholder.empty()" Delta
|
||||
# is composed with the parent container Delta, the frontend will
|
||||
# throw an error when it tries to add that column child to what is
|
||||
# now just an element, and not a block.
|
||||
return None
|
||||
|
||||
new_delta_type = new_delta.WhichOneof("type")
|
||||
if new_delta_type == "new_element":
|
||||
return new_delta
|
||||
|
||||
if new_delta_type == "add_block":
|
||||
return new_delta
|
||||
|
||||
return None
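# In short (editor's summary of the rules above): an existing add_block Delta
# is never replaced, while a non-block Delta is simply superseded by a newer
# new_element or add_block Delta at the same path.
#
#     >>> old_delta, new_delta = Delta(), Delta()
#     >>> old_delta.add_block.SetInParent()
#     >>> _maybe_compose_deltas(old_delta, new_delta) is None
#     True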
|
||||
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.legacy_caching.caching import cache as cache
|
||||
from streamlit.runtime.legacy_caching.caching import clear_cache as clear_cache
|
||||
from streamlit.runtime.legacy_caching.caching import get_cache_path as get_cache_path
|
||||
from streamlit.runtime.legacy_caching.caching import (
|
||||
maybe_show_cached_st_function_warning as maybe_show_cached_st_function_warning,
|
||||
)
|
||||
@@ -0,0 +1,895 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A library of caching utilities."""
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
from cachetools import TTLCache
|
||||
from pympler.asizeof import asizeof
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import config, file_util, util
|
||||
from streamlit.deprecation_util import show_deprecation_warning
|
||||
from streamlit.elements.spinner import spinner
|
||||
from streamlit.error_util import handle_uncaught_app_exception
|
||||
from streamlit.errors import StreamlitAPIWarning
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching import CACHE_DOCS_URL
|
||||
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
|
||||
from streamlit.runtime.legacy_caching.hashing import (
|
||||
HashFuncsDict,
|
||||
HashReason,
|
||||
update_hash,
|
||||
)
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
# The timer function we use with TTLCache. This is the default timer func, but
|
||||
# is exposed here as a constant so that it can be patched in unit tests.
|
||||
_TTLCACHE_TIMER = time.monotonic
|
||||
|
||||
|
||||
_CacheEntry = namedtuple("_CacheEntry", ["value", "hash"])
|
||||
_DiskCacheEntry = namedtuple("_DiskCacheEntry", ["value"])
|
||||
|
||||
# When we show the "st.cache is deprecated" warning, we make a recommendation about which new
|
||||
# cache decorator to switch to for the following data types:
|
||||
NEW_CACHE_FUNC_RECOMMENDATIONS: Dict[str, CacheType] = {
|
||||
# cache_data recommendations:
|
||||
"str": CacheType.DATA,
|
||||
"float": CacheType.DATA,
|
||||
"int": CacheType.DATA,
|
||||
"bytes": CacheType.DATA,
|
||||
"bool": CacheType.DATA,
|
||||
"datetime.datetime": CacheType.DATA,
|
||||
"pandas.DataFrame": CacheType.DATA,
|
||||
"pandas.Series": CacheType.DATA,
|
||||
"numpy.bool_": CacheType.DATA,
|
||||
"numpy.bool8": CacheType.DATA,
|
||||
"numpy.ndarray": CacheType.DATA,
|
||||
"numpy.float_": CacheType.DATA,
|
||||
"numpy.float16": CacheType.DATA,
|
||||
"numpy.float32": CacheType.DATA,
|
||||
"numpy.float64": CacheType.DATA,
|
||||
"numpy.float96": CacheType.DATA,
|
||||
"numpy.float128": CacheType.DATA,
|
||||
"numpy.int_": CacheType.DATA,
|
||||
"numpy.int8": CacheType.DATA,
|
||||
"numpy.int16": CacheType.DATA,
|
||||
"numpy.int32": CacheType.DATA,
|
||||
"numpy.int64": CacheType.DATA,
|
||||
"numpy.intp": CacheType.DATA,
|
||||
"numpy.uint8": CacheType.DATA,
|
||||
"numpy.uint16": CacheType.DATA,
|
||||
"numpy.uint32": CacheType.DATA,
|
||||
"numpy.uint64": CacheType.DATA,
|
||||
"numpy.uintp": CacheType.DATA,
|
||||
"PIL.Image.Image": CacheType.DATA,
|
||||
"plotly.graph_objects.Figure": CacheType.DATA,
|
||||
"matplotlib.figure.Figure": CacheType.DATA,
|
||||
"altair.Chart": CacheType.DATA,
|
||||
# cache_resource recommendations:
|
||||
"pyodbc.Connection": CacheType.RESOURCE,
|
||||
"pymongo.mongo_client.MongoClient": CacheType.RESOURCE,
|
||||
"mysql.connector.MySQLConnection": CacheType.RESOURCE,
|
||||
"psycopg2.connection": CacheType.RESOURCE,
|
||||
"psycopg2.extensions.connection": CacheType.RESOURCE,
|
||||
"snowflake.connector.connection.SnowflakeConnection": CacheType.RESOURCE,
|
||||
"snowflake.snowpark.sessions.Session": CacheType.RESOURCE,
|
||||
"sqlalchemy.engine.base.Engine": CacheType.RESOURCE,
|
||||
"sqlite3.Connection": CacheType.RESOURCE,
|
||||
"torch.nn.Module": CacheType.RESOURCE,
|
||||
"tensorflow.keras.Model": CacheType.RESOURCE,
|
||||
"tensorflow.Module": CacheType.RESOURCE,
|
||||
"tensorflow.compat.v1.Session": CacheType.RESOURCE,
|
||||
"transformers.Pipeline": CacheType.RESOURCE,
|
||||
"transformers.PreTrainedTokenizer": CacheType.RESOURCE,
|
||||
"transformers.PreTrainedTokenizerFast": CacheType.RESOURCE,
|
||||
"transformers.PreTrainedTokenizerBase": CacheType.RESOURCE,
|
||||
"transformers.PreTrainedModel": CacheType.RESOURCE,
|
||||
"transformers.TFPreTrainedModel": CacheType.RESOURCE,
|
||||
"transformers.FlaxPreTrainedModel": CacheType.RESOURCE,
|
||||
}
|
||||
|
||||
|
||||
def _make_deprecation_warning(cached_value: Any) -> str:
|
||||
"""Build a deprecation warning string for a cache function that has returned the given
|
||||
value.
|
||||
"""
|
||||
typename = type(cached_value).__qualname__
|
||||
cache_type_rec = NEW_CACHE_FUNC_RECOMMENDATIONS.get(typename)
|
||||
if cache_type_rec is not None:
|
||||
# We have a recommended cache func for the cached value:
|
||||
return (
|
||||
f"`st.cache` is deprecated. Please use one of Streamlit's new caching commands,\n"
|
||||
f"`st.cache_data` or `st.cache_resource`. Based on this function's return value\n"
|
||||
f"of type `{typename}`, we recommend using `st.{get_decorator_api_name(cache_type_rec)}`.\n\n"
|
||||
f"More information [in our docs]({CACHE_DOCS_URL})."
|
||||
)
|
||||
|
||||
# We do not have a recommended cache func for the cached value:
|
||||
return (
|
||||
f"`st.cache` is deprecated. Please use one of Streamlit's new caching commands,\n"
|
||||
f"`st.cache_data` or `st.cache_resource`.\n\n"
|
||||
f"More information [in our docs]({CACHE_DOCS_URL})."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemCache:
|
||||
cache: TTLCache
|
||||
display_name: str
|
||||
|
||||
|
||||
class _MemCaches(CacheStatsProvider):
|
||||
"""Manages all in-memory st.cache caches"""
|
||||
|
||||
def __init__(self):
|
||||
# Contains a cache object for each st.cache'd function
|
||||
self._lock = threading.RLock()
|
||||
self._function_caches: Dict[str, MemCache] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key: str,
|
||||
max_entries: Optional[float],
|
||||
ttl: Optional[float],
|
||||
display_name: str = "",
|
||||
) -> MemCache:
|
||||
"""Return the mem cache for the given key.
|
||||
|
||||
If it doesn't exist, create a new one with the given params.
|
||||
"""
|
||||
|
||||
if max_entries is None:
|
||||
max_entries = math.inf
|
||||
if ttl is None:
|
||||
ttl = math.inf
|
||||
|
||||
if not isinstance(max_entries, (int, float)):
|
||||
raise RuntimeError("max_entries must be an int")
|
||||
if not isinstance(ttl, (int, float)):
|
||||
raise RuntimeError("ttl must be a float")
|
||||
|
||||
# Get the existing cache, if it exists, and validate that its params
|
||||
# haven't changed.
|
||||
with self._lock:
|
||||
mem_cache = self._function_caches.get(key)
|
||||
if (
|
||||
mem_cache is not None
|
||||
and mem_cache.cache.ttl == ttl
|
||||
and mem_cache.cache.maxsize == max_entries
|
||||
):
|
||||
return mem_cache
|
||||
|
||||
# Create a new cache object and put it in our dict
|
||||
_LOGGER.debug(
|
||||
"Creating new mem_cache (key=%s, max_entries=%s, ttl=%s)",
|
||||
key,
|
||||
max_entries,
|
||||
ttl,
|
||||
)
|
||||
ttl_cache = TTLCache(maxsize=max_entries, ttl=ttl, timer=_TTLCACHE_TIMER)
|
||||
mem_cache = MemCache(ttl_cache, display_name)
|
||||
self._function_caches[key] = mem_cache
|
||||
return mem_cache
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all caches"""
|
||||
with self._lock:
|
||||
self._function_caches = {}
|
||||
|
||||
def get_stats(self) -> List[CacheStat]:
|
||||
with self._lock:
|
||||
# Shallow-clone our caches. We don't want to hold the global
|
||||
# lock during stats-gathering.
|
||||
function_caches = self._function_caches.copy()
|
||||
|
||||
stats = [
|
||||
CacheStat("st_cache", cache.display_name, asizeof(c))
|
||||
for cache in function_caches.values()
|
||||
for c in cache.cache
|
||||
]
|
||||
return stats
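# Illustrative sketch (editor's addition, not part of the original module):
# each key gets its own TTLCache-backed MemCache, created on first use and
# returned unchanged on later calls with the same params.
#
#     >>> caches = _MemCaches()
#     >>> mem_cache = caches.get_cache("func-key", max_entries=100, ttl=60.0)
#     >>> mem_cache.cache["value-key"] = _CacheEntry(value=42, hash=None)
#     >>> caches.get_cache("func-key", 100, 60.0) is mem_cache
#     True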
|
||||
|
||||
|
||||
# Our singleton _MemCaches instance
|
||||
_mem_caches = _MemCaches()
|
||||
|
||||
|
||||
# Thread-local state that tracks the stack of currently-executing @st.cache
# functions, plus a counter used to suppress cached-st-function warnings.
|
||||
class ThreadLocalCacheInfo(threading.local):
|
||||
def __init__(self):
|
||||
self.cached_func_stack: List[Callable[..., Any]] = []
|
||||
self.suppress_st_function_warning = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
_cache_info = ThreadLocalCacheInfo()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _calling_cached_function(func: Callable[..., Any]) -> Iterator[None]:
|
||||
_cache_info.cached_func_stack.append(func)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_cache_info.cached_func_stack.pop()
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_cached_st_function_warning() -> Iterator[None]:
|
||||
_cache_info.suppress_st_function_warning += 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_cache_info.suppress_st_function_warning -= 1
|
||||
assert _cache_info.suppress_st_function_warning >= 0
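# Illustrative sketch (editor's addition): the counter above lets code run
# Streamlit calls from inside cached functions without triggering the
# cached-st-function warning.
#
#     >>> with suppress_cached_st_function_warning():
#     ...     st.write("drawn from cached code without a warning")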
|
||||
|
||||
|
||||
def _show_cached_st_function_warning(
|
||||
dg: "st.delta_generator.DeltaGenerator",
|
||||
st_func_name: str,
|
||||
cached_func: Callable[..., Any],
|
||||
) -> None:
|
||||
# Avoid infinite recursion by suppressing additional cached
|
||||
# function warnings from within the cached function warning.
|
||||
with suppress_cached_st_function_warning():
|
||||
e = CachedStFunctionWarning(st_func_name, cached_func)
|
||||
dg.exception(e)
|
||||
|
||||
|
||||
def maybe_show_cached_st_function_warning(
|
||||
dg: "st.delta_generator.DeltaGenerator", st_func_name: str
|
||||
) -> None:
|
||||
"""If appropriate, warn about calling st.foo inside @cache.
|
||||
|
||||
DeltaGenerator's @_with_element and @_widget wrappers use this to warn
|
||||
the user when they're calling st.foo() from within a function that is
|
||||
wrapped in @st.cache.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dg : DeltaGenerator
|
||||
The DeltaGenerator to publish the warning to.
|
||||
|
||||
st_func_name : str
|
||||
The name of the Streamlit function that was called.
|
||||
|
||||
"""
|
||||
if (
|
||||
len(_cache_info.cached_func_stack) > 0
|
||||
and _cache_info.suppress_st_function_warning <= 0
|
||||
):
|
||||
cached_func = _cache_info.cached_func_stack[-1]
|
||||
_show_cached_st_function_warning(dg, st_func_name, cached_func)
|
||||
|
||||
|
||||
def _read_from_mem_cache(
|
||||
mem_cache: MemCache,
|
||||
key: str,
|
||||
allow_output_mutation: bool,
|
||||
func_or_code: Callable[..., Any],
|
||||
hash_funcs: Optional[HashFuncsDict],
|
||||
) -> Any:
|
||||
cache = mem_cache.cache
|
||||
if key in cache:
|
||||
entry = cache[key]
|
||||
|
||||
if not allow_output_mutation:
|
||||
computed_output_hash = _get_output_hash(
|
||||
entry.value, func_or_code, hash_funcs
|
||||
)
|
||||
stored_output_hash = entry.hash
|
||||
|
||||
if computed_output_hash != stored_output_hash:
|
||||
_LOGGER.debug("Cached object was mutated: %s", key)
|
||||
raise CachedObjectMutationError(entry.value, func_or_code)
|
||||
|
||||
_LOGGER.debug("Memory cache HIT: %s", type(entry.value))
|
||||
return entry.value
|
||||
|
||||
else:
|
||||
_LOGGER.debug("Memory cache MISS: %s", key)
|
||||
raise CacheKeyNotFoundError("Key not found in mem cache")
|
||||
|
||||
|
||||
def _write_to_mem_cache(
|
||||
mem_cache: MemCache,
|
||||
key: str,
|
||||
value: Any,
|
||||
allow_output_mutation: bool,
|
||||
func_or_code: Callable[..., Any],
|
||||
hash_funcs: Optional[HashFuncsDict],
|
||||
) -> None:
|
||||
if allow_output_mutation:
|
||||
hash = None
|
||||
else:
|
||||
hash = _get_output_hash(value, func_or_code, hash_funcs)
|
||||
|
||||
mem_cache.display_name = f"{func_or_code.__module__}.{func_or_code.__qualname__}"
|
||||
mem_cache.cache[key] = _CacheEntry(value=value, hash=hash)
|
||||
|
||||
|
||||
def _get_output_hash(
|
||||
value: Any, func_or_code: Callable[..., Any], hash_funcs: Optional[HashFuncsDict]
|
||||
) -> bytes:
|
||||
hasher = hashlib.new("md5")
|
||||
update_hash(
|
||||
value,
|
||||
hasher=hasher,
|
||||
hash_funcs=hash_funcs,
|
||||
hash_reason=HashReason.CACHING_FUNC_OUTPUT,
|
||||
hash_source=func_or_code,
|
||||
)
|
||||
return hasher.digest()
|
||||
|
||||
|
||||
def _read_from_disk_cache(key: str) -> Any:
|
||||
path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key)
|
||||
try:
|
||||
with file_util.streamlit_read(path, binary=True) as input:
|
||||
entry = pickle.load(input)
|
||||
value = entry.value
|
||||
_LOGGER.debug("Disk cache HIT: %s", type(value))
|
||||
except util.Error as e:
|
||||
_LOGGER.error(e)
|
||||
raise CacheError("Unable to read from cache: %s" % e)
|
||||
|
||||
except FileNotFoundError:
|
||||
raise CacheKeyNotFoundError("Key not found in disk cache")
|
||||
return value
|
||||
|
||||
|
||||
def _write_to_disk_cache(key: str, value: Any) -> None:
|
||||
path = file_util.get_streamlit_file_path("cache", "%s.pickle" % key)
|
||||
|
||||
try:
|
||||
with file_util.streamlit_write(path, binary=True) as output:
|
||||
entry = _DiskCacheEntry(value=value)
|
||||
pickle.dump(entry, output, pickle.HIGHEST_PROTOCOL)
|
||||
except util.Error as e:
|
||||
_LOGGER.debug(e)
|
||||
# Clean up file so we don't leave zero byte files.
|
||||
try:
|
||||
os.remove(path)
|
||||
except (FileNotFoundError, IOError, OSError):
|
||||
# If we can't remove the file, it's not a big deal.
|
||||
pass
|
||||
raise CacheError("Unable to write to cache: %s" % e)
|
||||
|
||||
|
||||
def _read_from_cache(
|
||||
mem_cache: MemCache,
|
||||
key: str,
|
||||
persist: bool,
|
||||
allow_output_mutation: bool,
|
||||
func_or_code: Callable[..., Any],
|
||||
hash_funcs: Optional[HashFuncsDict] = None,
|
||||
) -> Any:
|
||||
"""Read a value from the cache.
|
||||
|
||||
Our goal is to read from memory if possible. If the data was mutated (hash
|
||||
changed), we show a warning. If reading from memory fails, we either read
|
||||
from disk or rerun the code.
|
||||
"""
|
||||
try:
|
||||
return _read_from_mem_cache(
|
||||
mem_cache, key, allow_output_mutation, func_or_code, hash_funcs
|
||||
)
|
||||
|
||||
except CachedObjectMutationError as e:
|
||||
handle_uncaught_app_exception(CachedObjectMutationWarning(e))
|
||||
return e.cached_value
|
||||
|
||||
except CacheKeyNotFoundError as e:
|
||||
if persist:
|
||||
value = _read_from_disk_cache(key)
|
||||
_write_to_mem_cache(
|
||||
mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs
|
||||
)
|
||||
return value
|
||||
raise e
|
||||
|
||||
|
||||
@gather_metrics("_cache_object")
|
||||
def _write_to_cache(
|
||||
mem_cache: MemCache,
|
||||
key: str,
|
||||
value: Any,
|
||||
persist: bool,
|
||||
allow_output_mutation: bool,
|
||||
func_or_code: Callable[..., Any],
|
||||
hash_funcs: Optional[HashFuncsDict] = None,
|
||||
):
|
||||
_write_to_mem_cache(
|
||||
mem_cache, key, value, allow_output_mutation, func_or_code, hash_funcs
|
||||
)
|
||||
if persist:
|
||||
_write_to_disk_cache(key, value)
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
|
||||
def cache(
|
||||
func: F,
|
||||
persist: bool = False,
|
||||
allow_output_mutation: bool = False,
|
||||
show_spinner: bool = True,
|
||||
suppress_st_warning: bool = False,
|
||||
hash_funcs: Optional[HashFuncsDict] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
ttl: Optional[float] = None,
|
||||
) -> F:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def cache(
|
||||
func: None = None,
|
||||
persist: bool = False,
|
||||
allow_output_mutation: bool = False,
|
||||
show_spinner: bool = True,
|
||||
suppress_st_warning: bool = False,
|
||||
hash_funcs: Optional[HashFuncsDict] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
ttl: Optional[float] = None,
|
||||
) -> Callable[[F], F]:
|
||||
...
|
||||
|
||||
|
||||
def cache(
|
||||
func: Optional[F] = None,
|
||||
persist: bool = False,
|
||||
allow_output_mutation: bool = False,
|
||||
show_spinner: bool = True,
|
||||
suppress_st_warning: bool = False,
|
||||
hash_funcs: Optional[HashFuncsDict] = None,
|
||||
max_entries: Optional[int] = None,
|
||||
ttl: Optional[float] = None,
|
||||
) -> Union[Callable[[F], F], F]:
|
||||
"""Function decorator to memoize function executions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to cache. Streamlit hashes the function and dependent code.
|
||||
|
||||
persist : boolean
|
||||
Whether to persist the cache on disk.
|
||||
|
||||
allow_output_mutation : boolean
|
||||
Streamlit shows a warning when return values are mutated, as that
|
||||
can have unintended consequences. This is done by hashing the return value internally.
|
||||
|
||||
If you know what you're doing and would like to override this warning, set this to True.
|
||||
|
||||
show_spinner : boolean
|
||||
Enable the spinner. Default is True to show a spinner when there is
|
||||
a cache miss.
|
||||
|
||||
suppress_st_warning : boolean
|
||||
Suppress warnings about calling Streamlit commands from within
|
||||
the cached function.
|
||||
|
||||
hash_funcs : dict or None
|
||||
Mapping of types or fully qualified names to hash functions. This is used to override
|
||||
the behavior of the hasher inside Streamlit's caching mechanism: when the hasher
|
||||
encounters an object, it will first check to see if its type matches a key in this
|
||||
dict and, if so, will use the provided function to generate a hash for it. See below
|
||||
for an example of how this can be used.
|
||||
|
||||
max_entries : int or None
|
||||
The maximum number of entries to keep in the cache, or None
|
||||
for an unbounded cache. (When a new entry is added to a full cache,
|
||||
the oldest cached entry will be removed.) The default is None.
|
||||
|
||||
ttl : float or None
|
||||
The maximum number of seconds to keep an entry in the cache, or
|
||||
None if cache entries should not expire. The default is None.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> import streamlit as st
|
||||
>>>
|
||||
>>> @st.cache
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
...
|
||||
>>> d1 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Actually executes the function, since this is the first time it was
|
||||
>>> # encountered.
|
||||
>>>
|
||||
>>> d2 = fetch_and_clean_data(DATA_URL_1)
|
||||
>>> # Does not execute the function. Instead, returns its previously computed
|
||||
>>> # value. This means that now the data in d1 is the same as in d2.
|
||||
>>>
|
||||
>>> d3 = fetch_and_clean_data(DATA_URL_2)
|
||||
>>> # This is a different URL, so the function executes.
|
||||
|
||||
To set the ``persist`` parameter, use this command as follows:
|
||||
|
||||
>>> @st.cache(persist=True)
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
|
||||
To disable hashing return values, set the ``allow_output_mutation`` parameter to ``True``:
|
||||
|
||||
>>> @st.cache(allow_output_mutation=True)
|
||||
... def fetch_and_clean_data(url):
|
||||
... # Fetch data from URL here, and then clean it up.
|
||||
... return data
|
||||
|
||||
|
||||
To override the default hashing behavior, pass a custom hash function.
|
||||
You can do that by mapping a type (e.g. ``MongoClient``) to a hash function (``id``) like this:
|
||||
|
||||
>>> @st.cache(hash_funcs={MongoClient: id})
|
||||
... def connect_to_database(url):
|
||||
... return MongoClient(url)
|
||||
|
||||
Alternatively, you can map the type's fully-qualified name
|
||||
(e.g. ``"pymongo.mongo_client.MongoClient"``) to the hash function instead:
|
||||
|
||||
>>> @st.cache(hash_funcs={"pymongo.mongo_client.MongoClient": id})
|
||||
... def connect_to_database(url):
|
||||
... return MongoClient(url)
|
||||
|
||||
"""
|
||||
_LOGGER.debug("Entering st.cache: %s", func)
|
||||
|
||||
# Support passing the params via function decorator, e.g.
|
||||
# @st.cache(persist=True, allow_output_mutation=True)
|
||||
if func is None:
|
||||
|
||||
def wrapper(f: F) -> F:
|
||||
return cache(
|
||||
func=f,
|
||||
persist=persist,
|
||||
allow_output_mutation=allow_output_mutation,
|
||||
show_spinner=show_spinner,
|
||||
suppress_st_warning=suppress_st_warning,
|
||||
hash_funcs=hash_funcs,
|
||||
max_entries=max_entries,
|
||||
ttl=ttl,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
# To make mypy type narrow Optional[F] -> F
|
||||
non_optional_func = func
|
||||
|
||||
cache_key = None
|
||||
|
||||
@functools.wraps(non_optional_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""Wrapper function that only calls the underlying function on a cache miss.
|
||||
|
||||
Cached objects are stored in the cache/ directory.
|
||||
"""
|
||||
|
||||
if not config.get_option("client.caching"):
|
||||
_LOGGER.debug("Purposefully skipping cache")
|
||||
return non_optional_func(*args, **kwargs)
|
||||
|
||||
name = non_optional_func.__qualname__
|
||||
|
||||
if len(args) == 0 and len(kwargs) == 0:
|
||||
message = "Running `%s()`." % name
|
||||
else:
|
||||
message = "Running `%s(...)`." % name
|
||||
|
||||
def get_or_create_cached_value():
|
||||
nonlocal cache_key
|
||||
if cache_key is None:
|
||||
# Delay generating the cache key until the first call.
|
||||
# This way we can see values of globals, including functions
|
||||
# defined after this one.
|
||||
# If we generated the key earlier we would only hash those
|
||||
# globals by name, and miss changes in their code or value.
|
||||
cache_key = _hash_func(non_optional_func, hash_funcs)
|
||||
|
||||
# First, get the cache that's attached to this function.
|
||||
# This cache's key is generated (above) from the function's code.
|
||||
mem_cache = _mem_caches.get_cache(cache_key, max_entries, ttl)
|
||||
|
||||
# Next, calculate the key for the value we'll be searching for
|
||||
# within that cache. This key is generated from both the function's
|
||||
# code and the arguments that are passed into it. (Even though this
|
||||
# key is used to index into a per-function cache, it must be
|
||||
# globally unique, because it is *also* used for a global on-disk
|
||||
# cache that is *not* per-function.)
|
||||
value_hasher = hashlib.new("md5")
|
||||
|
||||
if args:
|
||||
update_hash(
|
||||
args,
|
||||
hasher=value_hasher,
|
||||
hash_funcs=hash_funcs,
|
||||
hash_reason=HashReason.CACHING_FUNC_ARGS,
|
||||
hash_source=non_optional_func,
|
||||
)
|
||||
|
||||
if kwargs:
|
||||
update_hash(
|
||||
kwargs,
|
||||
hasher=value_hasher,
|
||||
hash_funcs=hash_funcs,
|
||||
hash_reason=HashReason.CACHING_FUNC_ARGS,
|
||||
hash_source=non_optional_func,
|
||||
)
|
||||
|
||||
value_key = value_hasher.hexdigest()
|
||||
|
||||
# Avoid recomputing the body's hash by just appending the
|
||||
# previously-computed hash to the arg hash.
|
||||
value_key = "%s-%s" % (value_key, cache_key)
|
||||
|
||||
_LOGGER.debug("Cache key: %s", value_key)
|
||||
|
||||
try:
|
||||
return_value = _read_from_cache(
|
||||
mem_cache=mem_cache,
|
||||
key=value_key,
|
||||
persist=persist,
|
||||
allow_output_mutation=allow_output_mutation,
|
||||
func_or_code=non_optional_func,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
_LOGGER.debug("Cache hit: %s", non_optional_func)
|
||||
|
||||
except CacheKeyNotFoundError:
|
||||
_LOGGER.debug("Cache miss: %s", non_optional_func)
|
||||
|
||||
with _calling_cached_function(non_optional_func):
|
||||
if suppress_st_warning:
|
||||
with suppress_cached_st_function_warning():
|
||||
return_value = non_optional_func(*args, **kwargs)
|
||||
else:
|
||||
return_value = non_optional_func(*args, **kwargs)
|
||||
|
||||
_write_to_cache(
|
||||
mem_cache=mem_cache,
|
||||
key=value_key,
|
||||
value=return_value,
|
||||
persist=persist,
|
||||
allow_output_mutation=allow_output_mutation,
|
||||
func_or_code=non_optional_func,
|
||||
hash_funcs=hash_funcs,
|
||||
)
|
||||
|
||||
# st.cache is deprecated. We show a warning every time it's used.
|
||||
show_deprecation_warning(_make_deprecation_warning(return_value))
|
||||
|
||||
return return_value
|
||||
|
||||
if show_spinner:
|
||||
with spinner(message):
|
||||
return get_or_create_cached_value()
|
||||
else:
|
||||
return get_or_create_cached_value()
|
||||
|
||||
# Make this a well-behaved decorator by preserving important function
|
||||
# attributes.
|
||||
try:
|
||||
wrapped_func.__dict__.update(non_optional_func.__dict__)
|
||||
except AttributeError:
|
||||
# For normal functions this should never happen, but if so it's not problematic.
|
||||
pass
|
||||
|
||||
return cast(F, wrapped_func)
|
||||
|
||||
|
||||
def _hash_func(func: Callable[..., Any], hash_funcs: Optional[HashFuncsDict]) -> str:
|
||||
# Create the unique key for a function's cache. The cache will be retrieved
|
||||
# from inside the wrapped function.
|
||||
#
|
||||
# A naive implementation would involve simply creating the cache object
|
||||
# right in the wrapper, which in a normal Python script would be executed
|
||||
# only once. But in Streamlit, we reload all modules related to a user's
|
||||
# app when the app is re-run, which means that - among other things - all
|
||||
# function decorators in the app will be re-run, and so any decorator-local
|
||||
# objects will be recreated.
|
||||
#
|
||||
# Furthermore, our caches can be destroyed and recreated (in response to
|
||||
# cache clearing, for example), which means that retrieving the function's
|
||||
# cache in the decorator (so that the wrapped function can save a lookup)
|
||||
# is incorrect: the cache itself may be recreated between
|
||||
# decorator-evaluation time and decorated-function-execution time. So we
|
||||
# must retrieve the cache object *and* perform the cached-value lookup
|
||||
# inside the decorated function.
|
||||
func_hasher = hashlib.new("md5")
|
||||
|
||||
# Include the function's __module__ and __qualname__ strings in the hash.
|
||||
# This means that two identical functions in different modules
|
||||
# will not share a hash; it also means that two identical *nested*
|
||||
# functions in the same module will not share a hash.
|
||||
# We do not pass `hash_funcs` here, because we don't want our function's
|
||||
# name to get an unexpected hash.
|
||||
update_hash(
|
||||
(func.__module__, func.__qualname__),
|
||||
hasher=func_hasher,
|
||||
hash_funcs=None,
|
||||
hash_reason=HashReason.CACHING_FUNC_BODY,
|
||||
hash_source=func,
|
||||
)
|
||||
|
||||
# Include the function's body in the hash. We *do* pass hash_funcs here,
|
||||
# because this step will be hashing any objects referenced in the function
|
||||
# body.
|
||||
update_hash(
|
||||
func,
|
||||
hasher=func_hasher,
|
||||
hash_funcs=hash_funcs,
|
||||
hash_reason=HashReason.CACHING_FUNC_BODY,
|
||||
hash_source=func,
|
||||
)
|
||||
cache_key = func_hasher.hexdigest()
|
||||
_LOGGER.debug(
|
||||
"mem_cache key for %s.%s: %s", func.__module__, func.__qualname__, cache_key
|
||||
)
|
||||
return cache_key
|
||||
|
||||
|
||||
def clear_cache() -> bool:
|
||||
"""Clear the memoization cache.
|
||||
|
||||
Returns
|
||||
-------
|
||||
boolean
|
||||
True if the disk cache was cleared. False otherwise (e.g. cache file
|
||||
doesn't exist on disk).
|
||||
"""
|
||||
_clear_mem_cache()
|
||||
return _clear_disk_cache()
|
||||
|
||||
|
||||
def get_cache_path() -> str:
|
||||
return file_util.get_streamlit_file_path("cache")
|
||||
|
||||
|
||||
def _clear_disk_cache() -> bool:
|
||||
# TODO: Only delete disk cache for functions related to the user's current
|
||||
# script.
|
||||
cache_path = get_cache_path()
|
||||
if os.path.isdir(cache_path):
|
||||
shutil.rmtree(cache_path)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _clear_mem_cache() -> None:
|
||||
_mem_caches.clear()
|
||||
|
||||
|
||||
class CacheError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CacheKeyNotFoundError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CachedObjectMutationError(ValueError):
|
||||
# This is used internally, but never shown to the user.
|
||||
# Users see CachedObjectMutationWarning instead.
|
||||
|
||||
def __init__(self, cached_value, func_or_code):
|
||||
self.cached_value = cached_value
|
||||
if inspect.iscode(func_or_code):
|
||||
self.cached_func_name = "a code block"
|
||||
else:
|
||||
self.cached_func_name = _get_cached_func_name_md(func_or_code)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
class CachedStFunctionWarning(StreamlitAPIWarning):
|
||||
def __init__(self, st_func_name, cached_func):
|
||||
msg = self._get_message(st_func_name, cached_func)
|
||||
super(CachedStFunctionWarning, self).__init__(msg)
|
||||
|
||||
def _get_message(self, st_func_name, cached_func):
|
||||
args = {
|
||||
"st_func_name": "`st.%s()` or `st.write()`" % st_func_name,
|
||||
"func_name": _get_cached_func_name_md(cached_func),
|
||||
}
|
||||
|
||||
return (
|
||||
"""
|
||||
Your script uses %(st_func_name)s to write to your Streamlit app from within
|
||||
some cached code at %(func_name)s. This code will only be called when we detect
|
||||
a cache "miss", which can lead to unexpected results.
|
||||
|
||||
How to fix this:
|
||||
* Move the %(st_func_name)s call outside %(func_name)s.
|
||||
* Or, if you know what you're doing, use `@st.cache(suppress_st_warning=True)`
|
||||
to suppress the warning.
|
||||
"""
|
||||
% args
|
||||
).strip("\n")
|
||||
|
||||
|
||||
class CachedObjectMutationWarning(StreamlitAPIWarning):
|
||||
def __init__(self, orig_exc):
|
||||
msg = self._get_message(orig_exc)
|
||||
super(CachedObjectMutationWarning, self).__init__(msg)
|
||||
|
||||
def _get_message(self, orig_exc):
|
||||
return (
|
||||
"""
|
||||
Return value of %(func_name)s was mutated between runs.
|
||||
|
||||
By default, Streamlit's cache should be treated as immutable, or it may behave
|
||||
in unexpected ways. You received this warning because Streamlit detected
|
||||
that an object returned by %(func_name)s was mutated outside of %(func_name)s.
|
||||
|
||||
How to fix this:
|
||||
* If you did not mean to mutate that return value:
|
||||
- If possible, inspect your code to find and remove that mutation.
|
||||
- Otherwise, you could also clone the returned value so you can freely
|
||||
mutate it.
|
||||
* If you actually meant to mutate the return value and know the consequences of
|
||||
doing so, annotate the function with `@st.cache(allow_output_mutation=True)`.
|
||||
|
||||
For more information and detailed solutions check out [our documentation.]
|
||||
(https://docs.streamlit.io/library/advanced-features/caching)
|
||||
"""
|
||||
% {"func_name": orig_exc.cached_func_name}
|
||||
).strip("\n")
|
||||
|
||||
|
||||
def _get_cached_func_name_md(func: Callable[..., Any]) -> str:
|
||||
"""Get markdown representation of the function name."""
|
||||
if hasattr(func, "__name__"):
|
||||
return "`%s()`" % func.__name__
|
||||
else:
|
||||
return "a cached function"
|
||||
@@ -0,0 +1,230 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Provides global MediaFileManager object as `media_file_manager`."""
|
||||
|
||||
import collections
|
||||
import threading
|
||||
from typing import Dict, Optional, Set, Union
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.media_file_storage import MediaFileKind, MediaFileStorage
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
def _get_session_id() -> str:
|
||||
"""Get the active AppSession's session_id."""
|
||||
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
# This is only None when running "python myscript.py" rather than
|
||||
# "streamlit run myscript.py". In which case the session ID doesn't
|
||||
# matter and can just be a constant, as there's only ever "session".
|
||||
return "dontcare"
|
||||
else:
|
||||
return ctx.session_id
|
||||
|
||||
|
||||
class MediaFileMetadata:
|
||||
"""Metadata that the MediaFileManager needs for each file it manages."""
|
||||
|
||||
def __init__(self, kind: MediaFileKind = MediaFileKind.MEDIA):
|
||||
self._kind = kind
|
||||
self._is_marked_for_delete = False
|
||||
|
||||
@property
|
||||
def kind(self) -> MediaFileKind:
|
||||
return self._kind
|
||||
|
||||
@property
|
||||
def is_marked_for_delete(self) -> bool:
|
||||
return self._is_marked_for_delete
|
||||
|
||||
def mark_for_delete(self) -> None:
|
||||
self._is_marked_for_delete = True
|
||||
|
||||
|
||||
class MediaFileManager:
|
||||
"""In-memory file manager for MediaFile objects.
|
||||
|
||||
This keeps track of:
|
||||
- Which files exist, and what their IDs are. This is important so we can
|
||||
serve files by ID -- that's the whole point of this class!
|
||||
- Which files are being used by which AppSession (by ID). This is
|
||||
important so we can remove files from memory when no more sessions need
|
||||
them.
|
||||
- The exact location in the app where each file is being used (i.e. the
|
||||
file's "coordinates"). This is is important so we can mark a file as "not
|
||||
being used by a certain session" if it gets replaced by another file at
|
||||
the same coordinates. For example, when doing an animation where the same
|
||||
image is constantly replaced with new frames. (This doesn't solve the case
|
||||
where the file's coordinates keep changing for some reason, though! e.g.
|
||||
if new elements keep being prepended to the app. Unlikely to happen, but
|
||||
we should address it at some point.)
|
||||
"""
|
||||
|
||||
def __init__(self, storage: MediaFileStorage):
|
||||
self._storage = storage
|
||||
|
||||
# Dict of [file_id -> MediaFileMetadata]
|
||||
self._file_metadata: Dict[str, MediaFileMetadata] = dict()
|
||||
|
||||
# Dict[session ID][coordinates] -> file_id.
|
||||
self._files_by_session_and_coord: Dict[
|
||||
str, Dict[str, str]
|
||||
] = collections.defaultdict(dict)
|
||||
|
||||
# MediaFileManager is used from multiple threads, so all operations
|
||||
# need to be protected with a Lock. (This is not an RLock, which
|
||||
# means taking it multiple times from the same thread will deadlock.)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _get_inactive_file_ids(self) -> Set[str]:
|
||||
"""Compute the set of files that are stored in the manager, but are
|
||||
not referenced by any active session. These are files that can be
|
||||
safely deleted.
|
||||
|
||||
Thread safety: callers must hold `self._lock`.
|
||||
"""
|
||||
# Get the set of all our file IDs.
|
||||
file_ids = set(self._file_metadata.keys())
|
||||
|
||||
# Subtract all IDs that are in use by each session
|
||||
for session_file_ids_by_coord in self._files_by_session_and_coord.values():
|
||||
file_ids.difference_update(session_file_ids_by_coord.values())
|
||||
|
||||
return file_ids
|
||||
|
||||
def remove_orphaned_files(self) -> None:
|
||||
"""Remove all files that are no longer referenced by any active session.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
LOGGER.debug("Removing orphaned files...")
|
||||
|
||||
with self._lock:
|
||||
for file_id in self._get_inactive_file_ids():
|
||||
file = self._file_metadata[file_id]
|
||||
if file.kind == MediaFileKind.MEDIA:
|
||||
self._delete_file(file_id)
|
||||
elif file.kind == MediaFileKind.DOWNLOADABLE:
|
||||
if file.is_marked_for_delete:
|
||||
self._delete_file(file_id)
|
||||
else:
|
||||
file.mark_for_delete()
|
||||
|
||||
def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete the given file from storage, and remove its metadata from
|
||||
self._file_metadata.
|
||||
|
||||
Thread safety: callers must hold `self._lock`.
|
||||
"""
|
||||
LOGGER.debug("Deleting File: %s", file_id)
|
||||
self._storage.delete_file(file_id)
|
||||
del self._file_metadata[file_id]
|
||||
|
||||
def clear_session_refs(self, session_id: Optional[str] = None) -> None:
|
||||
"""Remove the given session's file references.
|
||||
|
||||
(This does not remove any files from the manager - you must call
|
||||
`remove_orphaned_files` for that.)
|
||||
|
||||
Should be called whenever ScriptRunner completes and when a session ends.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
if session_id is None:
|
||||
session_id = _get_session_id()
|
||||
|
||||
LOGGER.debug("Disconnecting files for session with ID %s", session_id)
|
||||
|
||||
with self._lock:
|
||||
if session_id in self._files_by_session_and_coord:
|
||||
del self._files_by_session_and_coord[session_id]
|
||||
|
||||
LOGGER.debug(
|
||||
"Sessions still active: %r", self._files_by_session_and_coord.keys()
|
||||
)
|
||||
|
||||
LOGGER.debug(
|
||||
"Files: %s; Sessions with files: %s",
|
||||
len(self._file_metadata),
|
||||
len(self._files_by_session_and_coord),
|
||||
)
|
||||
|
||||
def add(
|
||||
self,
|
||||
path_or_data: Union[bytes, str],
|
||||
mimetype: str,
|
||||
coordinates: str,
|
||||
file_name: Optional[str] = None,
|
||||
is_for_static_download: bool = False,
|
||||
) -> str:
|
||||
"""Add a new MediaFile with the given parameters and return its URL.
|
||||
|
||||
If an identical file already exists, return the existing URL
|
||||
and register the current session as a user.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path_or_data : bytes or str
|
||||
If bytes: the media file's raw data. If str: the name of a file
|
||||
to load from disk.
|
||||
mimetype : str
|
||||
The mime type for the file. E.g. "audio/mpeg".
|
||||
This string will be used in the "Content-Type" header when the file
|
||||
is served over HTTP.
|
||||
coordinates : str
|
||||
Unique string identifying an element's location.
|
||||
Prevents memory leak of "forgotten" file IDs when element media
|
||||
is being replaced-in-place (e.g. an st.image stream).
|
||||
coordinates should be of the form: "1.(3.-14).5"
|
||||
file_name : str or None
|
||||
Optional file_name. Used to set the filename in the response header.
|
||||
is_for_static_download: bool
|
||||
Indicates that the data is stored for downloading as a file,
|
||||
not as media for rendering on the page. [default: False]
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The url that the frontend can use to fetch the media.
|
||||
|
||||
Raises
|
||||
------
|
||||
If a filename is passed, any Exception raised when trying to read the
|
||||
file will be re-raised.
|
||||
"""
|
||||
|
||||
session_id = _get_session_id()
|
||||
|
||||
with self._lock:
|
||||
kind = (
|
||||
MediaFileKind.DOWNLOADABLE
|
||||
if is_for_static_download
|
||||
else MediaFileKind.MEDIA
|
||||
)
|
||||
file_id = self._storage.load_and_get_id(
|
||||
path_or_data, mimetype, kind, file_name
|
||||
)
|
||||
metadata = MediaFileMetadata(kind=kind)
|
||||
|
||||
self._file_metadata[file_id] = metadata
|
||||
self._files_by_session_and_coord[session_id][coordinates] = file_id
|
||||
|
||||
return self._storage.get_url(file_id)
|
||||
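A brief usage sketch, not part of the vendored module, assuming the in-memory storage backend defined later in this commit; the coordinate string is a made-up example:

from streamlit.runtime.media_file_manager import MediaFileManager
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage

storage = MemoryMediaFileStorage(media_endpoint="/media")
manager = MediaFileManager(storage=storage)

# Register raw bytes for an element and get back a URL the frontend can fetch.
url = manager.add(b"fake-png-bytes", "image/png", coordinates="1.(3.-14).5")

# After a script run (or when a session ends): drop this session's references,
# then garbage-collect any files no longer referenced by any session.
manager.clear_session_refs()
manager.remove_orphaned_files()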
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
class MediaFileKind(Enum):
|
||||
# st.image, st.video, st.audio files
|
||||
MEDIA = "media"
|
||||
|
||||
# st.download_button files
|
||||
DOWNLOADABLE = "downloadable"
|
||||
|
||||
|
||||
class MediaFileStorageError(Exception):
|
||||
"""Exception class for errors raised by MediaFileStorage.
|
||||
|
||||
When running in "development mode", the full text of these errors
|
||||
is displayed in the frontend, so errors should be human-readable
|
||||
(and actionable).
|
||||
|
||||
When running in "release mode", errors are redacted on the
|
||||
frontend; we instead show a generic "Something went wrong!" message.
|
||||
"""
|
||||
|
||||
|
||||
class MediaFileStorage(Protocol):
|
||||
@abstractmethod
|
||||
def load_and_get_id(
|
||||
self,
|
||||
path_or_data: Union[str, bytes],
|
||||
mimetype: str,
|
||||
kind: MediaFileKind,
|
||||
filename: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Load the given file path or bytes into the manager and return
|
||||
an ID that uniquely identifies it.
|
||||
|
||||
It’s an error to pass a URL to this function. (Media stored at
|
||||
external URLs can be served directly to the Streamlit frontend;
|
||||
there’s no need to store this data in MediaFileStorage.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path_or_data
|
||||
A path to a file, or the file's raw data as bytes.
|
||||
|
||||
mimetype
|
||||
The media’s mimetype. Used to set the Content-Type header when
|
||||
serving the media over HTTP.
|
||||
|
||||
kind
|
||||
The kind of file this is: either MEDIA, or DOWNLOADABLE.
|
||||
|
||||
filename : str or None
|
||||
Optional filename. Used to set the filename in the response header.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The unique ID of the media file.
|
||||
|
||||
Raises
|
||||
------
|
||||
MediaFileStorageError
|
||||
Raised if the media can't be loaded (for example, if a file
|
||||
path is invalid).
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self, file_id: str) -> str:
|
||||
"""Return a URL for a file in the manager.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_id
|
||||
The file's ID, returned from load_and_get_id().
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
A URL that the frontend can load the file from. Because this
|
||||
URL may expire, it should not be cached!
|
||||
|
||||
Raises
|
||||
------
|
||||
MediaFileStorageError
|
||||
Raised if the manager doesn't contain an object with the given ID.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from the manager.
|
||||
|
||||
This should be called when a given file is no longer referenced
|
||||
by any connected client, so that the MediaFileStorage can free its
|
||||
resources.
|
||||
|
||||
Calling `delete_file` on a file_id that doesn't exist is allowed,
|
||||
and is a no-op. (This means that multiple `delete_file` calls with
|
||||
the same file_id is not an error.)
|
||||
|
||||
Note: implementations can choose to ignore `delete_file` calls -
|
||||
this function is a *suggestion*, not a *command*. Callers should
|
||||
not rely on file deletion happening immediately (or at all).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_id
|
||||
The file's ID, returned from load_and_get_id().
|
||||
|
||||
Returns
|
||||
-------
|
||||
None
|
||||
|
||||
Raises
|
||||
------
|
||||
MediaFileStorageError
|
||||
Raised if file deletion fails for any reason. Note that these
|
||||
failures will generally not be shown on the frontend (file
|
||||
deletion usually occurs on session disconnect).
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
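As a rough illustration (the class below is hypothetical and not part of Streamlit), a backend satisfies this protocol by providing just the three methods above:

from typing import Optional, Union

from streamlit.runtime.media_file_storage import (
    MediaFileKind,
    MediaFileStorage,
    MediaFileStorageError,
)


class NullMediaFileStorage(MediaFileStorage):
    """Hypothetical do-nothing backend; only shows the protocol surface."""

    def load_and_get_id(
        self,
        path_or_data: Union[str, bytes],
        mimetype: str,
        kind: MediaFileKind,
        filename: Optional[str] = None,
    ) -> str:
        # A real backend would persist the bytes; this sketch refuses file paths.
        if isinstance(path_or_data, str):
            raise MediaFileStorageError("NullMediaFileStorage cannot read files")
        return "null-file-id"

    def get_url(self, file_id: str) -> str:
        return f"/media/{file_id}"

    def delete_file(self, file_id: str) -> None:
        # Deletion is a suggestion, not a command, so ignoring it is allowed.
        pass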
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MediaFileStorage implementation that stores files in memory."""
|
||||
|
||||
import contextlib
|
||||
import hashlib
|
||||
import mimetypes
|
||||
import os.path
|
||||
from typing import Dict, List, NamedTuple, Optional, Union
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.media_file_storage import (
|
||||
MediaFileKind,
|
||||
MediaFileStorage,
|
||||
MediaFileStorageError,
|
||||
)
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
# Mimetype -> filename extension map for the `get_extension_for_mimetype`
|
||||
# function. We use Python's `mimetypes.guess_extension` for most mimetypes,
|
||||
# but (as of Python 3.9) `mimetypes.guess_extension("audio/wav")` returns None,
|
||||
# so we handle it ourselves.
|
||||
PREFERRED_MIMETYPE_EXTENSION_MAP: Final = {
|
||||
"audio/wav": ".wav",
|
||||
}
|
||||
|
||||
|
||||
def _calculate_file_id(
|
||||
data: bytes, mimetype: str, filename: Optional[str] = None
|
||||
) -> str:
|
||||
"""Hash data, mimetype, and an optional filename to generate a stable file ID.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data
|
||||
Content of in-memory file in bytes. Other types will throw TypeError.
|
||||
mimetype
|
||||
Any string. Will be converted to bytes and used to compute a hash.
|
||||
filename
|
||||
Any string. Will be converted to bytes and used to compute a hash.
|
||||
"""
|
||||
filehash = hashlib.new("sha224")
|
||||
filehash.update(data)
|
||||
filehash.update(bytes(mimetype.encode()))
|
||||
|
||||
if filename is not None:
|
||||
filehash.update(bytes(filename.encode()))
|
||||
|
||||
return filehash.hexdigest()
|
||||
|
||||
|
||||
def get_extension_for_mimetype(mimetype: str) -> str:
|
||||
if mimetype in PREFERRED_MIMETYPE_EXTENSION_MAP:
|
||||
return PREFERRED_MIMETYPE_EXTENSION_MAP[mimetype]
|
||||
|
||||
extension = mimetypes.guess_extension(mimetype, strict=False)
|
||||
if extension is None:
|
||||
return ""
|
||||
|
||||
return extension
|
||||
|
||||
|
||||
class MemoryFile(NamedTuple):
|
||||
"""A MediaFile stored in memory."""
|
||||
|
||||
content: bytes
|
||||
mimetype: str
|
||||
kind: MediaFileKind
|
||||
filename: Optional[str]
|
||||
|
||||
@property
|
||||
def content_size(self) -> int:
|
||||
return len(self.content)
|
||||
|
||||
|
||||
class MemoryMediaFileStorage(MediaFileStorage, CacheStatsProvider):
|
||||
def __init__(self, media_endpoint: str):
|
||||
"""Create a new MemoryMediaFileStorage instance
|
||||
|
||||
Parameters
|
||||
----------
|
||||
media_endpoint
|
||||
The name of the local endpoint that media is served from.
|
||||
This endpoint should start with a forward-slash (e.g. "/media").
|
||||
"""
|
||||
self._files_by_id: Dict[str, MemoryFile] = {}
|
||||
self._media_endpoint = media_endpoint
|
||||
|
||||
def load_and_get_id(
|
||||
self,
|
||||
path_or_data: Union[str, bytes],
|
||||
mimetype: str,
|
||||
kind: MediaFileKind,
|
||||
filename: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Add a file to the manager and return its ID."""
|
||||
file_data: bytes
|
||||
if isinstance(path_or_data, str):
|
||||
file_data = self._read_file(path_or_data)
|
||||
else:
|
||||
file_data = path_or_data
|
||||
|
||||
# Because our file_ids are stable, if we already have a file with the
|
||||
# given ID, we don't need to create a new one.
|
||||
file_id = _calculate_file_id(file_data, mimetype, filename)
|
||||
if file_id not in self._files_by_id:
|
||||
LOGGER.debug("Adding media file %s", file_id)
|
||||
media_file = MemoryFile(
|
||||
content=file_data, mimetype=mimetype, kind=kind, filename=filename
|
||||
)
|
||||
self._files_by_id[file_id] = media_file
|
||||
|
||||
return file_id
|
||||
|
||||
def get_file(self, filename: str) -> MemoryFile:
|
||||
"""Return the MemoryFile with the given filename. Filenames are of the
|
||||
form "file_id.extension". (Note that this is *not* the optional
|
||||
user-specified filename for download files.)
|
||||
|
||||
Raises a MediaFileStorageError if no such file exists.
|
||||
"""
|
||||
file_id = os.path.splitext(filename)[0]
|
||||
try:
|
||||
return self._files_by_id[file_id]
|
||||
except KeyError as e:
|
||||
raise MediaFileStorageError(
|
||||
f"Bad filename '{filename}'. (No media file with id '{file_id}')"
|
||||
) from e
|
||||
|
||||
def get_url(self, file_id: str) -> str:
|
||||
"""Get a URL for a given media file. Raise a MediaFileStorageError if
|
||||
no such file exists.
|
||||
"""
|
||||
media_file = self.get_file(file_id)
|
||||
extension = get_extension_for_mimetype(media_file.mimetype)
|
||||
return f"{self._media_endpoint}/{file_id}{extension}"
|
||||
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
"""Delete the file with the given ID."""
|
||||
# We swallow KeyErrors here - it's not an error to delete a file
|
||||
# that doesn't exist.
|
||||
with contextlib.suppress(KeyError):
|
||||
del self._files_by_id[file_id]
|
||||
|
||||
def _read_file(self, filename: str) -> bytes:
|
||||
"""Read a file into memory. Raise MediaFileStorageError if we can't."""
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
return f.read()
|
||||
except Exception as ex:
|
||||
raise MediaFileStorageError(f"Error opening '{filename}'") from ex
|
||||
|
||||
def get_stats(self) -> List[CacheStat]:
|
||||
# We operate on a copy of our dict, to avoid race conditions
|
||||
# with other threads that may be manipulating the cache.
|
||||
files_by_id = self._files_by_id.copy()
|
||||
|
||||
stats: List[CacheStat] = []
|
||||
for file_id, file in files_by_id.items():
|
||||
stats.append(
|
||||
CacheStat(
|
||||
category_name="st_memory_media_file_storage",
|
||||
cache_name="",
|
||||
byte_length=len(file.content),
|
||||
)
|
||||
)
|
||||
return stats
|
||||
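A small usage sketch (payload and endpoint invented) of the stable-ID behavior implemented above:

from streamlit.runtime.media_file_storage import MediaFileKind
from streamlit.runtime.memory_media_file_storage import MemoryMediaFileStorage

storage = MemoryMediaFileStorage(media_endpoint="/media")

# Identical bytes + mimetype + filename hash to the same ID, so re-adding the
# same payload reuses the existing entry instead of creating a new one.
file_id = storage.load_and_get_id(b"hello", "text/plain", MediaFileKind.MEDIA)
assert file_id == storage.load_and_get_id(b"hello", "text/plain", MediaFileKind.MEDIA)

print(storage.get_url(file_id))  # e.g. "/media/<sha224-hex>.txt"
storage.delete_file(file_id)     # deleting...
storage.delete_file(file_id)     # ...and deleting again is a no-op, not an error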
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, MutableMapping, Optional
|
||||
|
||||
from cachetools import TTLCache
|
||||
|
||||
from streamlit.runtime.session_manager import SessionInfo, SessionStorage
|
||||
|
||||
|
||||
class MemorySessionStorage(SessionStorage):
|
||||
"""A SessionStorage that stores sessions in memory.
|
||||
|
||||
At most maxsize sessions are stored, each with a TTL of ttl_seconds seconds. This class is really
|
||||
just a thin wrapper around cachetools.TTLCache that complies with the SessionStorage
|
||||
protocol.
|
||||
"""
|
||||
|
||||
# NOTE: The defaults for maxsize and ttl are chosen arbitrarily for now. These
|
||||
# numbers are reasonable as the main problems we're trying to solve at the moment are
|
||||
# caused by transient disconnects that are usually just short network blips. In the
|
||||
# future, we may want to increase both to support use cases such as saving state for
|
||||
# much longer periods of time. For example, we may want session state to persist if
|
||||
# a user closes their laptop lid and comes back to an app hours later.
|
||||
def __init__(
|
||||
self,
|
||||
maxsize: int = 128,
|
||||
ttl_seconds: int = 2 * 60, # 2 minutes
|
||||
) -> None:
|
||||
"""Instantiate a new MemorySessionStorage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maxsize
|
||||
The maximum number of sessions we allow to be stored in this
|
||||
MemorySessionStorage. If an entry needs to be removed because we have
|
||||
exceeded this number, either
|
||||
* an expired entry is removed, or
|
||||
* the least recently used entry is removed (if no entries have expired).
|
||||
|
||||
ttl_seconds
|
||||
The time in seconds for an entry added to a MemorySessionStorage to live.
|
||||
After this amount of time has passed for a given entry, it becomes
|
||||
inaccessible and will be removed eventually.
|
||||
"""
|
||||
|
||||
self._cache: MutableMapping[str, SessionInfo] = TTLCache(
|
||||
maxsize=maxsize, ttl=ttl_seconds
|
||||
)
|
||||
|
||||
def get(self, session_id: str) -> Optional[SessionInfo]:
|
||||
return self._cache.get(session_id, None)
|
||||
|
||||
def save(self, session_info: SessionInfo) -> None:
|
||||
self._cache[session_info.session.id] = session_info
|
||||
|
||||
def delete(self, session_id: str) -> None:
|
||||
del self._cache[session_id]
|
||||
|
||||
def list(self) -> List[SessionInfo]:
|
||||
return list(self._cache.values())
|
||||
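The expiry and eviction behavior comes directly from cachetools.TTLCache; a minimal sketch of those semantics (plain strings stand in for SessionInfo objects):

import time

from cachetools import TTLCache

cache = TTLCache(maxsize=2, ttl=1)  # deliberately tiny limits, for illustration only
cache["a"] = "session-a"
cache["b"] = "session-b"
cache["c"] = "session-c"        # over maxsize: the least recently used entry is evicted
print(list(cache.keys()))       # only two entries remain

time.sleep(1.1)                 # wait past the TTL
print(cache.get("c", None))     # None: expired entries become inaccessible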
@@ -0,0 +1,373 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Sized
|
||||
from functools import wraps
|
||||
from timeit import default_timer as timer
|
||||
from typing import Any, Callable, List, Optional, Set, TypeVar, Union, cast, overload
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
from streamlit import config, util
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.proto.PageProfile_pb2 import Argument, Command
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
# Limit the number of commands to keep the page profile message small
|
||||
# since Segment allows only a maximum of 32kb per event.
|
||||
_MAX_TRACKED_COMMANDS: Final = 200
|
||||
# Only track a maximum of 25 uses per unique command since some apps use
|
||||
# commands excessively (e.g. calling add_rows thousands of times in one rerun)
|
||||
# making the page profile useless.
|
||||
_MAX_TRACKED_PER_COMMAND: Final = 25
|
||||
|
||||
# A mapping to convert from the actual name to preferred/shorter representations
|
||||
_OBJECT_NAME_MAPPING: Final = {
|
||||
"streamlit.delta_generator.DeltaGenerator": "DG",
|
||||
"pandas.core.frame.DataFrame": "DataFrame",
|
||||
"plotly.graph_objs._figure.Figure": "PlotlyFigure",
|
||||
"bokeh.plotting.figure.Figure": "BokehFigure",
|
||||
"matplotlib.figure.Figure": "MatplotlibFigure",
|
||||
"pandas.io.formats.style.Styler": "PandasStyler",
|
||||
"pandas.core.indexes.base.Index": "PandasIndex",
|
||||
"pandas.core.series.Series": "PandasSeries",
|
||||
}
|
||||
|
||||
# A list of dependencies to check for attribution
|
||||
_ATTRIBUTIONS_TO_CHECK: Final = [
|
||||
"snowflake",
|
||||
"torch",
|
||||
"tensorflow",
|
||||
"streamlit_extras",
|
||||
"streamlit_pydantic",
|
||||
"plost",
|
||||
]
|
||||
|
||||
_ETC_MACHINE_ID_PATH = "/etc/machine-id"
|
||||
_DBUS_MACHINE_ID_PATH = "/var/lib/dbus/machine-id"
|
||||
|
||||
|
||||
def _get_machine_id_v3() -> str:
|
||||
"""Get the machine ID
|
||||
|
||||
This is a unique identifier for a user, used for tracking metrics in Segment.
|
||||
It is broken in different ways in some Linux distros and Docker images:
|
||||
- at times just a hash of '', which means many machines map to the same ID
|
||||
- at times a hash of the same string, when running in a Docker container
|
||||
"""
|
||||
|
||||
machine_id = str(uuid.getnode())
|
||||
if os.path.isfile(_ETC_MACHINE_ID_PATH):
|
||||
with open(_ETC_MACHINE_ID_PATH, "r") as f:
|
||||
machine_id = f.read()
|
||||
|
||||
elif os.path.isfile(_DBUS_MACHINE_ID_PATH):
|
||||
with open(_DBUS_MACHINE_ID_PATH, "r") as f:
|
||||
machine_id = f.read()
|
||||
|
||||
return machine_id
|
||||
|
||||
|
||||
class Installation:
|
||||
_instance_lock = threading.Lock()
|
||||
_instance: Optional["Installation"] = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> "Installation":
|
||||
"""Returns the singleton Installation"""
|
||||
# We use a double-checked locking optimization to avoid the overhead
|
||||
# of acquiring the lock in the common case:
|
||||
# https://en.wikipedia.org/wiki/Double-checked_locking
|
||||
if cls._instance is None:
|
||||
with cls._instance_lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = Installation()
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
self.installation_id_v3 = str(
|
||||
uuid.uuid5(uuid.NAMESPACE_DNS, _get_machine_id_v3())
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
@property
|
||||
def installation_id(self):
|
||||
return self.installation_id_v3
|
||||
|
||||
|
||||
def _get_type_name(obj: object) -> str:
|
||||
"""Get a simplified name for the type of the given object."""
|
||||
with contextlib.suppress(Exception):
|
||||
obj_type = type(obj)
|
||||
type_name = "unknown"
|
||||
if hasattr(obj_type, "__qualname__"):
|
||||
type_name = obj_type.__qualname__
|
||||
elif hasattr(obj_type, "__name__"):
|
||||
type_name = obj_type.__name__
|
||||
|
||||
if obj_type.__module__ != "builtins":
|
||||
# Add the full module path
|
||||
type_name = f"{obj_type.__module__}.{type_name}"
|
||||
|
||||
if type_name in _OBJECT_NAME_MAPPING:
|
||||
type_name = _OBJECT_NAME_MAPPING[type_name]
|
||||
return type_name
|
||||
return "failed"
|
||||
|
||||
|
||||
def _get_top_level_module(func: Callable[..., Any]) -> str:
|
||||
"""Get the top level module for the given function."""
|
||||
module = inspect.getmodule(func)
|
||||
if module is None or not module.__name__:
|
||||
return "unknown"
|
||||
return module.__name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_arg_metadata(arg: object) -> Optional[str]:
|
||||
"""Get metadata information related to the value of the given object."""
|
||||
with contextlib.suppress(Exception):
|
||||
if isinstance(arg, (bool)):
|
||||
return f"val:{arg}"
|
||||
|
||||
if isinstance(arg, Sized):
|
||||
return f"len:{len(arg)}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_command_telemetry(
|
||||
_command_func: Callable[..., Any], _command_name: str, *args, **kwargs
|
||||
) -> Command:
|
||||
"""Get telemetry information for the given callable and its arguments."""
|
||||
arg_keywords = inspect.getfullargspec(_command_func).args
|
||||
self_arg: Optional[Any] = None
|
||||
arguments: List[Argument] = []
|
||||
is_method = inspect.ismethod(_command_func)
|
||||
name = _command_name
|
||||
|
||||
for i, arg in enumerate(args):
|
||||
pos = i
|
||||
if is_method:
|
||||
# If func is a method, ignore the first argument (self)
|
||||
i = i + 1
|
||||
|
||||
keyword = arg_keywords[i] if len(arg_keywords) > i else f"{i}"
|
||||
if keyword == "self":
|
||||
self_arg = arg
|
||||
continue
|
||||
argument = Argument(k=keyword, t=_get_type_name(arg), p=pos)
|
||||
|
||||
arg_metadata = _get_arg_metadata(arg)
|
||||
if arg_metadata:
|
||||
argument.m = arg_metadata
|
||||
arguments.append(argument)
|
||||
for kwarg, kwarg_value in kwargs.items():
|
||||
argument = Argument(k=kwarg, t=_get_type_name(kwarg_value))
|
||||
|
||||
arg_metadata = _get_arg_metadata(kwarg_value)
|
||||
if arg_metadata:
|
||||
argument.m = arg_metadata
|
||||
arguments.append(argument)
|
||||
|
||||
top_level_module = _get_top_level_module(_command_func)
|
||||
if top_level_module != "streamlit":
|
||||
# If the gather_metrics decorator is used outside of streamlit library
|
||||
# we enforce a prefix to be added to the tracked command:
|
||||
name = f"external:{top_level_module}:{name}"
|
||||
|
||||
if (
|
||||
name == "create_instance"
|
||||
and self_arg
|
||||
and hasattr(self_arg, "name")
|
||||
and self_arg.name
|
||||
):
|
||||
name = f"component:{self_arg.name}"
|
||||
|
||||
return Command(name=name, args=arguments)
|
||||
|
||||
|
||||
def to_microseconds(seconds: float) -> int:
|
||||
"""Convert seconds into microseconds."""
|
||||
return int(seconds * 1_000_000)
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@overload
|
||||
def gather_metrics(
|
||||
name: str,
|
||||
func: F,
|
||||
) -> F:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def gather_metrics(
|
||||
name: str,
|
||||
func: None = None,
|
||||
) -> Callable[[F], F]:
|
||||
...
|
||||
|
||||
|
||||
def gather_metrics(name: str, func: Optional[F] = None) -> Union[Callable[[F], F], F]:
|
||||
"""Function decorator to add telemetry tracking to commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : callable
|
||||
The function to track for telemetry.
|
||||
|
||||
name : str or None
|
||||
Overwrite the function name with a custom name that is used for telemetry tracking.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> @st.gather_metrics
|
||||
... def my_command(url):
|
||||
... return url
|
||||
|
||||
>>> @st.gather_metrics(name="custom_name")
|
||||
... def my_command(url):
|
||||
... return url
|
||||
"""
|
||||
|
||||
if not name:
|
||||
_LOGGER.warning("gather_metrics: name is empty")
|
||||
name = "undefined"
|
||||
|
||||
if func is None:
|
||||
# Support passing the params via function decorator
|
||||
def wrapper(f: F) -> F:
|
||||
return gather_metrics(
|
||||
name=name,
|
||||
func=f,
|
||||
)
|
||||
|
||||
return wrapper
|
||||
else:
|
||||
# To make mypy type narrow Optional[F] -> F
|
||||
non_optional_func = func
|
||||
|
||||
@wraps(non_optional_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
exec_start = timer()
|
||||
# get_script_run_ctx gets imported here to prevent circular dependencies
|
||||
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
||||
|
||||
ctx = get_script_run_ctx(suppress_warning=True)
|
||||
|
||||
tracking_activated = (
|
||||
ctx is not None
|
||||
and ctx.gather_usage_stats
|
||||
and not ctx.command_tracking_deactivated
|
||||
and len(ctx.tracked_commands)
|
||||
< _MAX_TRACKED_COMMANDS # Prevent too much memory usage
|
||||
)
|
||||
command_telemetry: Optional[Command] = None
|
||||
|
||||
if ctx and tracking_activated:
|
||||
try:
|
||||
command_telemetry = _get_command_telemetry(
|
||||
non_optional_func, name, *args, **kwargs
|
||||
)
|
||||
|
||||
if (
|
||||
command_telemetry.name not in ctx.tracked_commands_counter
|
||||
or ctx.tracked_commands_counter[command_telemetry.name]
|
||||
< _MAX_TRACKED_PER_COMMAND
|
||||
):
|
||||
ctx.tracked_commands.append(command_telemetry)
|
||||
ctx.tracked_commands_counter.update([command_telemetry.name])
|
||||
# Deactivate tracking to prevent calls inside already tracked commands
|
||||
ctx.command_tracking_deactivated = True
|
||||
except Exception as ex:
|
||||
# Always capture all exceptions since we want to make sure that
|
||||
# the telemetry never causes any issues.
|
||||
_LOGGER.debug("Failed to collect command telemetry", exc_info=ex)
|
||||
try:
|
||||
result = non_optional_func(*args, **kwargs)
|
||||
finally:
|
||||
# Activate tracking again if command executes without any exceptions
|
||||
if ctx:
|
||||
ctx.command_tracking_deactivated = False
|
||||
|
||||
if tracking_activated and command_telemetry:
|
||||
# Set the execution time to the measured value
|
||||
command_telemetry.time = to_microseconds(timer() - exec_start)
|
||||
return result
|
||||
|
||||
with contextlib.suppress(AttributeError):
|
||||
# Make this a well-behaved decorator by preserving important function
|
||||
# attributes.
|
||||
wrapped_func.__dict__.update(non_optional_func.__dict__)
|
||||
wrapped_func.__signature__ = inspect.signature(non_optional_func) # type: ignore
|
||||
return cast(F, wrapped_func)
|
||||
|
||||
|
||||
def create_page_profile_message(
|
||||
commands: List[Command],
|
||||
exec_time: int,
|
||||
prep_time: int,
|
||||
uncaught_exception: Optional[str] = None,
|
||||
) -> ForwardMsg:
|
||||
"""Create and return the full PageProfile ForwardMsg."""
|
||||
msg = ForwardMsg()
|
||||
msg.page_profile.commands.extend(commands)
|
||||
msg.page_profile.exec_time = exec_time
|
||||
msg.page_profile.prep_time = prep_time
|
||||
|
||||
msg.page_profile.headless = config.get_option("server.headless")
|
||||
|
||||
# Collect all config options that have been manually set
|
||||
config_options: Set[str] = set()
|
||||
if config._config_options:
|
||||
for option_name in config._config_options.keys():
|
||||
if not config.is_manually_set(option_name):
|
||||
# We only care about manually defined options
|
||||
continue
|
||||
|
||||
config_option = config._config_options[option_name]
|
||||
if config_option.is_default:
|
||||
option_name = f"{option_name}:default"
|
||||
config_options.add(option_name)
|
||||
|
||||
msg.page_profile.config.extend(config_options)
|
||||
|
||||
# Check the predefined set of modules for attribution
|
||||
attributions: Set[str] = {
|
||||
attribution
|
||||
for attribution in _ATTRIBUTIONS_TO_CHECK
|
||||
if attribution in sys.modules
|
||||
}
|
||||
|
||||
msg.page_profile.os = str(sys.platform)
|
||||
msg.page_profile.timezone = str(time.tzname)
|
||||
msg.page_profile.attributions.extend(attributions)
|
||||
|
||||
if uncaught_exception:
|
||||
msg.page_profile.uncaught_exception = uncaught_exception
|
||||
|
||||
return msg
|
||||
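A quick sketch of the decorator's behavior outside a script run (the command name and function below are made up); with no script-run context, tracking is skipped and the call is a plain pass-through:

from streamlit.runtime.metrics_util import gather_metrics


@gather_metrics("load_data")
def load_data(path: str) -> str:
    # Placeholder body; a real command would do actual work here.
    return path


# Outside `streamlit run`, get_script_run_ctx() returns None, so
# tracking_activated is False and the wrapped function simply runs.
print(load_data("example.csv"))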
726
venv/lib/python3.9/site-packages/streamlit/runtime/runtime.py
Normal file
@@ -0,0 +1,726 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Awaitable, Dict, NamedTuple, Optional, Tuple, Type
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.BackMsg_pb2 import BackMsg
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
from streamlit.runtime.caching import (
|
||||
get_data_cache_stats_provider,
|
||||
get_resource_cache_stats_provider,
|
||||
)
|
||||
from streamlit.runtime.caching.storage.local_disk_cache_storage import (
|
||||
LocalDiskCacheStorageManager,
|
||||
)
|
||||
from streamlit.runtime.forward_msg_cache import (
|
||||
ForwardMsgCache,
|
||||
create_reference_msg,
|
||||
populate_hash_if_needed,
|
||||
)
|
||||
from streamlit.runtime.legacy_caching.caching import _mem_caches
|
||||
from streamlit.runtime.media_file_manager import MediaFileManager
|
||||
from streamlit.runtime.media_file_storage import MediaFileStorage
|
||||
from streamlit.runtime.memory_session_storage import MemorySessionStorage
|
||||
from streamlit.runtime.runtime_util import is_cacheable_msg
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.session_manager import (
|
||||
ActiveSessionInfo,
|
||||
SessionClient,
|
||||
SessionClientDisconnectedError,
|
||||
SessionManager,
|
||||
SessionStorage,
|
||||
)
|
||||
from streamlit.runtime.state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
SessionStateStatProvider,
|
||||
)
|
||||
from streamlit.runtime.stats import StatsManager
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
from streamlit.runtime.websocket_session_manager import WebsocketSessionManager
|
||||
from streamlit.watcher import LocalSourcesWatcher
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.caching.storage import CacheStorageManager
|
||||
|
||||
# Wait for the script run result for 60s and if no result is available give up
|
||||
SCRIPT_RUN_CHECK_TIMEOUT: Final = 60
|
||||
|
||||
LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class RuntimeStoppedError(Exception):
|
||||
"""Raised by operations on a Runtime instance that is stopped."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RuntimeConfig:
|
||||
"""Config options for StreamlitRuntime."""
|
||||
|
||||
# The filesystem path of the Streamlit script to run.
|
||||
script_path: str
|
||||
|
||||
# The (optional) command line that Streamlit was started with
|
||||
# (e.g. "streamlit run app.py")
|
||||
command_line: Optional[str]
|
||||
|
||||
# The storage backend for Streamlit's MediaFileManager.
|
||||
media_file_storage: MediaFileStorage
|
||||
|
||||
# The cache storage backend for Streamlit's st.cache_data.
|
||||
cache_storage_manager: CacheStorageManager = field(
|
||||
default_factory=LocalDiskCacheStorageManager
|
||||
)
|
||||
|
||||
# The SessionManager class to be used.
|
||||
session_manager_class: Type[SessionManager] = WebsocketSessionManager
|
||||
|
||||
# The SessionStorage instance for the SessionManager to use.
|
||||
session_storage: SessionStorage = field(default_factory=MemorySessionStorage)
|
||||
|
||||
|
||||
class RuntimeState(Enum):
|
||||
INITIAL = "INITIAL"
|
||||
NO_SESSIONS_CONNECTED = "NO_SESSIONS_CONNECTED"
|
||||
ONE_OR_MORE_SESSIONS_CONNECTED = "ONE_OR_MORE_SESSIONS_CONNECTED"
|
||||
STOPPING = "STOPPING"
|
||||
STOPPED = "STOPPED"
|
||||
|
||||
|
||||
class AsyncObjects(NamedTuple):
|
||||
"""Container for all asyncio objects that Runtime manages.
|
||||
These cannot be initialized until the Runtime's eventloop is assigned.
|
||||
"""
|
||||
|
||||
# The eventloop that Runtime is running on.
|
||||
eventloop: asyncio.AbstractEventLoop
|
||||
|
||||
# Set after Runtime.stop() is called. Never cleared.
|
||||
must_stop: asyncio.Event
|
||||
|
||||
# Set when a client connects; cleared when we have no connected clients.
|
||||
has_connection: asyncio.Event
|
||||
|
||||
# Set after a ForwardMsg is enqueued; cleared when we flush ForwardMsgs.
|
||||
need_send_data: asyncio.Event
|
||||
|
||||
# Completed when the Runtime has started.
|
||||
started: asyncio.Future[None]
|
||||
|
||||
# Completed when the Runtime has stopped.
|
||||
stopped: asyncio.Future[None]
|
||||
|
||||
|
||||
class Runtime:
|
||||
_instance: Optional[Runtime] = None
|
||||
|
||||
@classmethod
|
||||
def instance(cls) -> Runtime:
|
||||
"""Return the singleton Runtime instance. Raise an Error if the
|
||||
Runtime hasn't been created yet.
|
||||
"""
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("Runtime hasn't been created!")
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def exists(cls) -> bool:
|
||||
"""True if the singleton Runtime instance has been created.
|
||||
|
||||
When a Streamlit app is running in "raw mode" - that is, when the
|
||||
app is run via `python app.py` instead of `streamlit run app.py` -
|
||||
the Runtime will not exist, and various Streamlit functions need
|
||||
to adapt.
|
||||
"""
|
||||
return cls._instance is not None
|
||||
|
||||
def __init__(self, config: RuntimeConfig):
|
||||
"""Create a Runtime instance. It won't be started yet.
|
||||
|
||||
Runtime is *not* thread-safe. Its public methods are generally
|
||||
safe to call only on the same thread that its event loop runs on.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config
|
||||
Config options.
|
||||
"""
|
||||
if Runtime._instance is not None:
|
||||
raise RuntimeError("Runtime instance already exists!")
|
||||
Runtime._instance = self
|
||||
|
||||
# Will be created when we start.
|
||||
self._async_objs: Optional[AsyncObjects] = None
|
||||
|
||||
# The task that runs our main loop. We need to save a reference
|
||||
# to it so that it doesn't get garbage collected while running.
|
||||
self._loop_coroutine_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
self._main_script_path = config.script_path
|
||||
self._command_line = config.command_line or ""
|
||||
|
||||
self._state = RuntimeState.INITIAL
|
||||
|
||||
# Initialize managers
|
||||
self._message_cache = ForwardMsgCache()
|
||||
self._uploaded_file_mgr = UploadedFileManager()
|
||||
self._uploaded_file_mgr.on_files_updated.connect(self._on_files_updated)
|
||||
self._media_file_mgr = MediaFileManager(storage=config.media_file_storage)
|
||||
self._cache_storage_manager = config.cache_storage_manager
|
||||
|
||||
self._session_mgr = config.session_manager_class(
|
||||
session_storage=config.session_storage,
|
||||
uploaded_file_manager=self._uploaded_file_mgr,
|
||||
message_enqueued_callback=self._enqueued_some_message,
|
||||
)
|
||||
|
||||
self._stats_mgr = StatsManager()
|
||||
self._stats_mgr.register_provider(get_data_cache_stats_provider())
|
||||
self._stats_mgr.register_provider(get_resource_cache_stats_provider())
|
||||
self._stats_mgr.register_provider(_mem_caches)
|
||||
self._stats_mgr.register_provider(self._message_cache)
|
||||
self._stats_mgr.register_provider(self._uploaded_file_mgr)
|
||||
self._stats_mgr.register_provider(SessionStateStatProvider(self._session_mgr))
|
||||
|
||||
@property
|
||||
def state(self) -> RuntimeState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def message_cache(self) -> ForwardMsgCache:
|
||||
return self._message_cache
|
||||
|
||||
@property
|
||||
def uploaded_file_mgr(self) -> UploadedFileManager:
|
||||
return self._uploaded_file_mgr
|
||||
|
||||
@property
|
||||
def cache_storage_manager(self) -> CacheStorageManager:
|
||||
return self._cache_storage_manager
|
||||
|
||||
@property
|
||||
def media_file_mgr(self) -> MediaFileManager:
|
||||
return self._media_file_mgr
|
||||
|
||||
@property
|
||||
def stats_mgr(self) -> StatsManager:
|
||||
return self._stats_mgr
|
||||
|
||||
@property
|
||||
def stopped(self) -> Awaitable[None]:
|
||||
"""A Future that completes when the Runtime's run loop has exited."""
|
||||
return self._get_async_objs().stopped
|
||||
|
||||
# NOTE: A few Runtime methods listed as threadsafe (get_client, _on_files_updated,
|
||||
# and is_active_session) currently rely on the implementation detail that
|
||||
# WebsocketSessionManager's get_active_session_info and is_active_session methods
|
||||
# happen to be threadsafe. This may change with future SessionManager implementations,
|
||||
# at which point we'll need to formalize our thread safety rules for each
|
||||
# SessionManager method.
|
||||
def get_client(self, session_id: str) -> Optional[SessionClient]:
|
||||
"""Get the SessionClient for the given session_id, or None
|
||||
if no such session exists.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info is None:
|
||||
return None
|
||||
return session_info.client
|
||||
|
||||
def _on_files_updated(self, session_id: str) -> None:
|
||||
"""Event handler for UploadedFileManager.on_file_added.
|
||||
Ensures that uploaded files from stale sessions get deleted.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
if not self._session_mgr.is_active_session(session_id):
|
||||
# If an uploaded file doesn't belong to an active session,
|
||||
# remove it so it doesn't stick around forever.
|
||||
self._uploaded_file_mgr.remove_session_files(session_id)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the runtime. This must be called only once, before
|
||||
any other functions are called.
|
||||
|
||||
When this coroutine returns, Streamlit is ready to accept new sessions.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
|
||||
# Create our AsyncObjects. We need to have a running eventloop to
|
||||
# instantiate our various synchronization primitives.
|
||||
async_objs = AsyncObjects(
|
||||
eventloop=asyncio.get_running_loop(),
|
||||
must_stop=asyncio.Event(),
|
||||
has_connection=asyncio.Event(),
|
||||
need_send_data=asyncio.Event(),
|
||||
started=asyncio.Future(),
|
||||
stopped=asyncio.Future(),
|
||||
)
|
||||
self._async_objs = async_objs
|
||||
|
||||
if sys.version_info >= (3, 8, 0):
|
||||
# Python 3.8+ supports a create_task `name` parameter, which can
|
||||
# make debugging a bit easier.
|
||||
self._loop_coroutine_task = asyncio.create_task(
|
||||
self._loop_coroutine(), name="Runtime.loop_coroutine"
|
||||
)
|
||||
else:
|
||||
self._loop_coroutine_task = asyncio.create_task(self._loop_coroutine())
|
||||
|
||||
await async_objs.started
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Request that Streamlit close all sessions and stop running.
|
||||
Note that Streamlit won't stop running immediately.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called from any thread.
|
||||
"""
|
||||
|
||||
async_objs = self._get_async_objs()
|
||||
|
||||
def stop_on_eventloop():
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
return
|
||||
|
||||
LOGGER.debug("Runtime stopping...")
|
||||
self._set_state(RuntimeState.STOPPING)
|
||||
async_objs.must_stop.set()
|
||||
|
||||
async_objs.eventloop.call_soon_threadsafe(stop_on_eventloop)
|
||||
|
||||
def is_active_session(self, session_id: str) -> bool:
|
||||
"""True if the session_id belongs to an active session.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
return self._session_mgr.is_active_session(session_id)
|
||||
|
||||
def connect_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
user_info: Dict[str, Optional[str]],
|
||||
existing_session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Create a new session (or connect to an existing one) and return its unique ID.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client
|
||||
A concrete SessionClient implementation for communicating with
|
||||
the session's client.
|
||||
user_info
|
||||
A dict that contains information about the session's user. For now,
|
||||
it only (optionally) contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The session's unique string ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
raise RuntimeStoppedError(f"Can't connect_session (state={self._state})")
|
||||
|
||||
session_id = self._session_mgr.connect_session(
|
||||
client=client,
|
||||
script_data=ScriptData(self._main_script_path, self._command_line or ""),
|
||||
user_info=user_info,
|
||||
existing_session_id=existing_session_id,
|
||||
)
|
||||
self._set_state(RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED)
|
||||
self._get_async_objs().has_connection.set()
|
||||
|
||||
return session_id
|
||||
|
||||
def create_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
user_info: Dict[str, Optional[str]],
|
||||
existing_session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Create a new session (or connect to an existing one) and return its unique ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This method is simply an alias for connect_session added for backwards
|
||||
compatibility.
|
||||
"""
|
||||
LOGGER.warning("create_session is deprecated! Use connect_session instead.")
|
||||
return self.connect_session(
|
||||
client=client, user_info=user_info, existing_session_id=existing_session_id
|
||||
)
|
||||
|
||||
def close_session(self, session_id: str) -> None:
|
||||
"""Close and completely shut down a session.
|
||||
|
||||
This differs from disconnect_session in that it always completely shuts down a
|
||||
session, permanently losing any associated state (session state, uploaded files,
|
||||
etc.).
|
||||
|
||||
This function may be called multiple times for the same session,
|
||||
which is not an error. (Subsequent calls just no-op.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
self._session_mgr.close_session(session_id)
|
||||
self._on_session_disconnected()
|
||||
|
||||
def disconnect_session(self, session_id: str) -> None:
|
||||
"""Disconnect a session. It will stop producing ForwardMsgs.
|
||||
|
||||
Differs from close_session because disconnected sessions can be reconnected to
|
||||
for a brief window (depending on the SessionManager/SessionStorage
|
||||
implementations used by the runtime).
|
||||
|
||||
This function may be called multiple times for the same session,
|
||||
which is not an error. (Subsequent calls just no-op.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
self._session_mgr.disconnect_session(session_id)
|
||||
self._on_session_disconnected()
|
||||
|
||||
def handle_backmsg(self, session_id: str, msg: BackMsg) -> None:
|
||||
"""Send a BackMsg to an active session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
msg
|
||||
The BackMsg to deliver to the session.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
raise RuntimeStoppedError(f"Can't handle_backmsg (state={self._state})")
|
||||
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info is None:
|
||||
LOGGER.debug(
|
||||
"Discarding BackMsg for disconnected session (id=%s)", session_id
|
||||
)
|
||||
return
|
||||
|
||||
session_info.session.handle_backmsg(msg)
|
||||
|
||||
def handle_backmsg_deserialization_exception(
|
||||
self, session_id: str, exc: BaseException
|
||||
) -> None:
|
||||
"""Handle an Exception raised during deserialization of a BackMsg.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
exc
|
||||
The Exception.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
if self._state in (RuntimeState.STOPPING, RuntimeState.STOPPED):
|
||||
raise RuntimeStoppedError(
|
||||
f"Can't handle_backmsg_deserialization_exception (state={self._state})"
|
||||
)
|
||||
|
||||
session_info = self._session_mgr.get_active_session_info(session_id)
|
||||
if session_info is None:
|
||||
LOGGER.debug(
|
||||
"Discarding BackMsg Exception for disconnected session (id=%s)",
|
||||
session_id,
|
||||
)
|
||||
return
|
||||
|
||||
session_info.session.handle_backmsg_exception(exc)
|
||||
|
||||
@property
|
||||
async def is_ready_for_browser_connection(self) -> Tuple[bool, str]:
|
||||
if self._state not in (
|
||||
RuntimeState.INITIAL,
|
||||
RuntimeState.STOPPING,
|
||||
RuntimeState.STOPPED,
|
||||
):
|
||||
return True, "ok"
|
||||
|
||||
return False, "unavailable"
|
||||
|
||||
async def does_script_run_without_error(self) -> Tuple[bool, str]:
|
||||
"""Load and execute the app's script to verify it runs without an error.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(True, "ok") if the script completes without error, or (False, err_msg)
|
||||
if the script raises an exception.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
# NOTE: We create an AppSession directly here instead of using the
|
||||
# SessionManager intentionally. This isn't a "real" session and is only being
|
||||
# used to test that the script runs without error.
|
||||
session = AppSession(
|
||||
script_data=ScriptData(self._main_script_path, self._command_line),
|
||||
uploaded_file_manager=self._uploaded_file_mgr,
|
||||
message_enqueued_callback=self._enqueued_some_message,
|
||||
local_sources_watcher=LocalSourcesWatcher(self._main_script_path),
|
||||
user_info={"email": "test@test.com"},
|
||||
)
|
||||
|
||||
try:
|
||||
session.request_rerun(None)
|
||||
|
||||
now = time.perf_counter()
|
||||
while (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state
|
||||
and (time.perf_counter() - now) < SCRIPT_RUN_CHECK_TIMEOUT
|
||||
):
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
if SCRIPT_RUN_WITHOUT_ERRORS_KEY not in session.session_state:
|
||||
return False, "timeout"
|
||||
|
||||
ok = session.session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY]
|
||||
msg = "ok" if ok else "error"
|
||||
|
||||
return ok, msg
|
||||
finally:
|
||||
session.shutdown()
|
||||
|
||||
def _set_state(self, new_state: RuntimeState) -> None:
|
||||
LOGGER.debug("Runtime state: %s -> %s", self._state, new_state)
|
||||
self._state = new_state
|
||||
|
||||
async def _loop_coroutine(self) -> None:
|
||||
"""The main Runtime loop.
|
||||
|
||||
This function won't exit until `stop` is called.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
|
||||
async_objs = self._get_async_objs()
|
||||
|
||||
try:
|
||||
if self._state == RuntimeState.INITIAL:
|
||||
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
|
||||
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError(f"Bad Runtime state at start: {self._state}")
|
||||
|
||||
# Signal that we're started and ready to accept sessions
|
||||
async_objs.started.set_result(None)
|
||||
|
||||
while not async_objs.must_stop.is_set():
|
||||
if self._state == RuntimeState.NO_SESSIONS_CONNECTED:
|
||||
await asyncio.wait(
|
||||
(
|
||||
asyncio.create_task(async_objs.must_stop.wait()),
|
||||
asyncio.create_task(async_objs.has_connection.wait()),
|
||||
),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
elif self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED:
|
||||
async_objs.need_send_data.clear()
|
||||
|
||||
for active_session_info in self._session_mgr.list_active_sessions():
|
||||
msg_list = active_session_info.session.flush_browser_queue()
|
||||
for msg in msg_list:
|
||||
try:
|
||||
self._send_message(active_session_info, msg)
|
||||
except SessionClientDisconnectedError:
|
||||
self._session_mgr.disconnect_session(
|
||||
active_session_info.session.id
|
||||
)
|
||||
|
||||
# Yield for a tick after sending a message.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
# Yield for a few milliseconds between session message
|
||||
# flushing.
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
else:
|
||||
# Break out of the thread loop if we encounter any other state.
|
||||
break
|
||||
|
||||
await asyncio.wait(
|
||||
(
|
||||
asyncio.create_task(async_objs.must_stop.wait()),
|
||||
asyncio.create_task(async_objs.need_send_data.wait()),
|
||||
),
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
# Shut down all AppSessions.
|
||||
for session_info in self._session_mgr.list_sessions():
|
||||
# NOTE: We want to fully shut down sessions when the runtime stops for
|
||||
# now, but this may change in the future if/when our notion of a session
|
||||
# is no longer so tightly coupled to a browser tab.
|
||||
self._session_mgr.close_session(session_info.session.id)
|
||||
|
||||
self._set_state(RuntimeState.STOPPED)
|
||||
async_objs.stopped.set_result(None)
|
||||
|
||||
except Exception as e:
|
||||
async_objs.stopped.set_exception(e)
|
||||
traceback.print_exc()
|
||||
LOGGER.info(
|
||||
"""
|
||||
Please report this bug at https://github.com/streamlit/streamlit/issues.
|
||||
"""
|
||||
)
|
||||
|
||||
def _send_message(self, session_info: ActiveSessionInfo, msg: ForwardMsg) -> None:
|
||||
"""Send a message to a client.
|
||||
|
||||
If the client is likely to have already cached the message, we may
|
||||
instead send a "reference" message that contains only the hash of the
|
||||
message.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_info : ActiveSessionInfo
|
||||
The ActiveSessionInfo associated with the session's websocket
|
||||
msg : ForwardMsg
|
||||
The message to send to the client
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: UNSAFE. Must be called on the eventloop thread.
|
||||
"""
|
||||
msg.metadata.cacheable = is_cacheable_msg(msg)
|
||||
msg_to_send = msg
|
||||
if msg.metadata.cacheable:
|
||||
populate_hash_if_needed(msg)
|
||||
|
||||
if self._message_cache.has_message_reference(
|
||||
msg, session_info.session, session_info.script_run_count
|
||||
):
|
||||
# This session has probably cached this message. Send
|
||||
# a reference instead.
|
||||
LOGGER.debug("Sending cached message ref (hash=%s)", msg.hash)
|
||||
msg_to_send = create_reference_msg(msg)
|
||||
|
||||
# Cache the message so it can be referenced in the future.
|
||||
# If the message is already cached, this will reset its
|
||||
# age.
|
||||
LOGGER.debug("Caching message (hash=%s)", msg.hash)
|
||||
self._message_cache.add_message(
|
||||
msg, session_info.session, session_info.script_run_count
|
||||
)
|
||||
|
||||
# If this was a `script_finished` message, we increment the
|
||||
# script_run_count for this session, and update the cache
|
||||
if (
|
||||
msg.WhichOneof("type") == "script_finished"
|
||||
and msg.script_finished == ForwardMsg.FINISHED_SUCCESSFULLY
|
||||
):
|
||||
LOGGER.debug(
|
||||
"Script run finished successfully; "
|
||||
"removing expired entries from MessageCache "
|
||||
"(max_age=%s)",
|
||||
config.get_option("global.maxCachedMessageAge"),
|
||||
)
|
||||
session_info.script_run_count += 1
|
||||
self._message_cache.remove_expired_session_entries(
|
||||
session_info.session, session_info.script_run_count
|
||||
)
|
||||
|
||||
# Ship it off!
|
||||
session_info.client.write_forward_msg(msg_to_send)
|
||||
|
||||
def _enqueued_some_message(self) -> None:
|
||||
"""Callback called by AppSession after the AppSession has enqueued a
|
||||
message. Sets the "need_send_data" event, which causes our core
|
||||
loop to wake up and flush client message queues.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: SAFE. May be called on any thread.
|
||||
"""
|
||||
async_objs = self._get_async_objs()
|
||||
async_objs.eventloop.call_soon_threadsafe(async_objs.need_send_data.set)
|
||||
|
||||
def _get_async_objs(self) -> AsyncObjects:
|
||||
"""Return our AsyncObjects instance. If the Runtime hasn't been
|
||||
started, this will raise an error.
|
||||
"""
|
||||
if self._async_objs is None:
|
||||
raise RuntimeError("Runtime hasn't started yet!")
|
||||
return self._async_objs
|
||||
|
||||
def _on_session_disconnected(self) -> None:
|
||||
"""Set the runtime state to NO_SESSIONS_CONNECTED if the last active
|
||||
session was disconnected.
|
||||
"""
|
||||
if (
|
||||
self._state == RuntimeState.ONE_OR_MORE_SESSIONS_CONNECTED
|
||||
and self._session_mgr.num_active_sessions() == 0
|
||||
):
|
||||
self._get_async_objs().has_connection.clear()
|
||||
self._set_state(RuntimeState.NO_SESSIONS_CONNECTED)
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Runtime-related utility functions"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from streamlit import config
|
||||
from streamlit.errors import MarkdownFormattedException
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.forward_msg_cache import populate_hash_if_needed
|
||||
|
||||
|
||||
class MessageSizeError(MarkdownFormattedException):
|
||||
"""Exception raised when a websocket message is larger than the configured limit."""
|
||||
|
||||
def __init__(self, failed_msg_str: Any):
|
||||
msg = self._get_message(failed_msg_str)
|
||||
super(MessageSizeError, self).__init__(msg)
|
||||
|
||||
def _get_message(self, failed_msg_str: Any) -> str:
|
||||
# This needs to have zero indentation otherwise the markdown will render incorrectly.
|
||||
return (
|
||||
(
|
||||
"""
|
||||
**Data of size {message_size_mb:.1f} MB exceeds the message size limit of {message_size_limit_mb} MB.**
|
||||
|
||||
This is often caused by a large chart or dataframe. Please decrease the amount of data sent
|
||||
to the browser, or increase the limit by setting the config option `server.maxMessageSize`.
|
||||
[Click here to learn more about config options](https://docs.streamlit.io/library/advanced-features/configuration#set-configuration-options).
|
||||
|
||||
_Note that increasing the limit may lead to long loading times and large memory consumption
|
||||
of the client's browser and the Streamlit server._
|
||||
"""
|
||||
)
|
||||
.format(
|
||||
message_size_mb=len(failed_msg_str) / 1e6,
|
||||
message_size_limit_mb=(get_max_message_size_bytes() / 1e6),
|
||||
)
|
||||
.strip("\n")
|
||||
)
|
||||
|
||||
|
||||
def is_cacheable_msg(msg: ForwardMsg) -> bool:
|
||||
"""True if the given message qualifies for caching."""
|
||||
if msg.WhichOneof("type") in {"ref_hash", "initialize"}:
|
||||
# Some message types never get cached
|
||||
return False
|
||||
return msg.ByteSize() >= int(config.get_option("global.minCachedMessageSize"))
|
||||
|
||||
|
||||
def serialize_forward_msg(msg: ForwardMsg) -> bytes:
|
||||
"""Serialize a ForwardMsg to send to a client.
|
||||
|
||||
If the message is too large, it will be converted to an exception message
|
||||
instead.
|
||||
"""
|
||||
populate_hash_if_needed(msg)
|
||||
msg_str = msg.SerializeToString()
|
||||
|
||||
if len(msg_str) > get_max_message_size_bytes():
|
||||
import streamlit.elements.exception as exception
|
||||
|
||||
# Overwrite the offending ForwardMsg.delta with an error to display.
|
||||
# This assumes that the size limit wasn't exceeded due to metadata.
|
||||
exception.marshall(msg.delta.new_element.exception, MessageSizeError(msg_str))
|
||||
msg_str = msg.SerializeToString()
|
||||
|
||||
return msg_str
|
||||
|
||||
|
||||
# This needs to be initialized lazily to avoid calling config.get_option() and
|
||||
# thus initializing config options when this file is first imported.
|
||||
_max_message_size_bytes: Optional[int] = None
|
||||
|
||||
|
||||
def get_max_message_size_bytes() -> int:
|
||||
"""Returns the max websocket message size in bytes.
|
||||
|
||||
This lazily loads the value from the config and caches it in a module-level global.
|
||||
"""
|
||||
global _max_message_size_bytes
|
||||
|
||||
if _max_message_size_bytes is None:
|
||||
_max_message_size_bytes = config.get_option("server.maxMessageSize") * int(1e6)
|
||||
|
||||
return _max_message_size_bytes
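# Illustrative sketch (not part of this module; assumes the default
# `server.maxMessageSize` of 200): get_max_message_size_bytes() would return
# 200 * int(1e6) == 200_000_000 bytes. In serialize_forward_msg(), any
# ForwardMsg whose serialized size exceeds that limit has its delta replaced
# by a MessageSizeError exception element before being re-serialized, so the
# frontend shows the markdown-formatted error instead of the oversized data.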
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScriptData:
|
||||
"""Contains parameters related to running a script."""
|
||||
|
||||
main_script_path: str
|
||||
command_line: str
|
||||
script_folder: str = field(init=False)
|
||||
name: str = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Set some computed values derived from main_script_path.
|
||||
|
||||
The usage of object.__setattr__ is necessary because trying to set
|
||||
self.script_folder or self.name normally, even within the __init__ method, will
|
||||
explode since we declared this dataclass to be frozen.
|
||||
|
||||
We do this in __post_init__ so that we can use the auto-generated __init__
|
||||
method that most dataclasses use.
|
||||
"""
|
||||
main_script_path = os.path.abspath(self.main_script_path)
|
||||
script_folder = os.path.dirname(main_script_path)
|
||||
object.__setattr__(self, "script_folder", script_folder)
|
||||
|
||||
basename = os.path.basename(main_script_path)
|
||||
name = str(os.path.splitext(basename)[0])
|
||||
object.__setattr__(self, "name", name)
|
||||
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Explicitly export public symbols
|
||||
from streamlit.runtime.scriptrunner.script_requests import RerunData as RerunData
|
||||
from streamlit.runtime.scriptrunner.script_run_context import (
|
||||
ScriptRunContext as ScriptRunContext,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_run_context import (
|
||||
add_script_run_ctx as add_script_run_ctx,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_run_context import (
|
||||
get_script_run_ctx as get_script_run_ctx,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_runner import (
|
||||
RerunException as RerunException,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_runner import ScriptRunner as ScriptRunner
|
||||
from streamlit.runtime.scriptrunner.script_runner import (
|
||||
ScriptRunnerEvent as ScriptRunnerEvent,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_runner import StopException as StopException
|
||||
@@ -0,0 +1,197 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ast
|
||||
import sys
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
# When a Streamlit app is magicified, we insert a `magic_funcs` import near the top of
|
||||
# its module's AST:
|
||||
# import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
|
||||
MAGIC_MODULE_NAME: Final = "__streamlitmagic__"
|
||||
|
||||
|
||||
def add_magic(code, script_path):
|
||||
"""Modifies the code to support magic Streamlit commands.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code : str
|
||||
The Python code.
|
||||
script_path : str
|
||||
The path to the script file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ast.Module
|
||||
The syntax tree for the code.
|
||||
|
||||
"""
|
||||
# Pass script_path so we get pretty exceptions.
|
||||
tree = ast.parse(code, script_path, "exec")
|
||||
return _modify_ast_subtree(tree, is_root=True)
|
||||
|
||||
|
||||
def _modify_ast_subtree(tree, body_attr="body", is_root=False):
|
||||
"""Parses magic commands and modifies the given AST (sub)tree."""
|
||||
|
||||
body = getattr(tree, body_attr)
|
||||
|
||||
for i, node in enumerate(body):
|
||||
node_type = type(node)
|
||||
|
||||
# Parse the contents of functions, With statements, and for statements
|
||||
if (
|
||||
node_type is ast.FunctionDef
|
||||
or node_type is ast.With
|
||||
or node_type is ast.For
|
||||
or node_type is ast.While
|
||||
or node_type is ast.AsyncFunctionDef
|
||||
or node_type is ast.AsyncWith
|
||||
or node_type is ast.AsyncFor
|
||||
):
|
||||
_modify_ast_subtree(node)
|
||||
|
||||
# Parse the contents of try statements
|
||||
elif node_type is ast.Try:
|
||||
for j, inner_node in enumerate(node.handlers):
|
||||
node.handlers[j] = _modify_ast_subtree(inner_node)
|
||||
finally_node = _modify_ast_subtree(node, body_attr="finalbody")
|
||||
node.finalbody = finally_node.finalbody
|
||||
_modify_ast_subtree(node)
|
||||
|
||||
# Recurse into if statements so expressions in both branches get converted
|
||||
elif node_type is ast.If:
|
||||
_modify_ast_subtree(node)
|
||||
_modify_ast_subtree(node, "orelse")
|
||||
|
||||
# Convert standalone expression nodes to st.write
|
||||
elif node_type is ast.Expr:
|
||||
value = _get_st_write_from_expr(node, i, parent_type=type(tree))
|
||||
if value is not None:
|
||||
node.value = value
|
||||
|
||||
if is_root:
|
||||
# Import Streamlit so we can use it in the new_value above.
|
||||
_insert_import_statement(tree)
|
||||
|
||||
ast.fix_missing_locations(tree)
|
||||
|
||||
return tree
|
||||
|
||||
|
||||
def _insert_import_statement(tree):
|
||||
"""Insert Streamlit import statement at the top(ish) of the tree."""
|
||||
|
||||
st_import = _build_st_import_statement()
|
||||
|
||||
# If the 0th node is already an import statement, put the Streamlit
|
||||
# import below that, so we don't break "from __future__ import".
|
||||
if tree.body and type(tree.body[0]) in (ast.ImportFrom, ast.Import):
|
||||
tree.body.insert(1, st_import)
|
||||
|
||||
# If the 0th node is a docstring and the 1st is an import statement,
|
||||
# put the Streamlit import below those, so we don't break "from
|
||||
# __future__ import".
|
||||
elif (
|
||||
len(tree.body) > 1
|
||||
and (type(tree.body[0]) is ast.Expr and _is_docstring_node(tree.body[0].value))
|
||||
and type(tree.body[1]) in (ast.ImportFrom, ast.Import)
|
||||
):
|
||||
tree.body.insert(2, st_import)
|
||||
|
||||
else:
|
||||
tree.body.insert(0, st_import)
|
||||
|
||||
|
||||
def _build_st_import_statement():
|
||||
"""Build AST node for `import magic_funcs as __streamlitmagic__`."""
|
||||
return ast.Import(
|
||||
names=[
|
||||
ast.alias(
|
||||
name="streamlit.runtime.scriptrunner.magic_funcs",
|
||||
asname=MAGIC_MODULE_NAME,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _build_st_write_call(nodes):
|
||||
"""Build AST node for `__streamlitmagic__.transparent_write(*nodes)`."""
|
||||
return ast.Call(
|
||||
func=ast.Attribute(
|
||||
attr="transparent_write",
|
||||
value=ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load()),
|
||||
ctx=ast.Load(),
|
||||
),
|
||||
args=nodes,
|
||||
keywords=[],
|
||||
kwargs=None,
|
||||
starargs=None,
|
||||
)
|
||||
|
||||
|
||||
def _get_st_write_from_expr(node, i, parent_type):
|
||||
# Don't change function calls
|
||||
if type(node.value) is ast.Call:
|
||||
return None
|
||||
|
||||
# Don't change Docstring nodes
|
||||
if (
|
||||
i == 0
|
||||
and _is_docstring_node(node.value)
|
||||
and parent_type in (ast.FunctionDef, ast.Module)
|
||||
):
|
||||
return None
|
||||
|
||||
# Don't change yield nodes
|
||||
if type(node.value) is ast.Yield or type(node.value) is ast.YieldFrom:
|
||||
return None
|
||||
|
||||
# Don't change await nodes
|
||||
if type(node.value) is ast.Await:
|
||||
return None
|
||||
|
||||
# If tuple, call st.write on the 0th element (rather than the
|
||||
# whole tuple). This allows us to add a comma at the end of a statement
|
||||
# to turn it into an expression that should be st-written. Ex:
|
||||
# "np.random.randn(1000, 2),"
|
||||
if type(node.value) is ast.Tuple:
|
||||
args = node.value.elts
|
||||
st_write = _build_st_write_call(args)
|
||||
|
||||
# st.write all strings.
|
||||
elif type(node.value) is ast.Str:
|
||||
args = [node.value]
|
||||
st_write = _build_st_write_call(args)
|
||||
|
||||
# st.write all variables.
|
||||
elif type(node.value) is ast.Name:
|
||||
args = [node.value]
|
||||
st_write = _build_st_write_call(args)
|
||||
|
||||
# st.write everything else
|
||||
else:
|
||||
args = [node.value]
|
||||
st_write = _build_st_write_call(args)
|
||||
|
||||
return st_write
|
||||
|
||||
|
||||
def _is_docstring_node(node):
|
||||
if sys.version_info >= (3, 8, 0):
|
||||
return type(node) is ast.Constant and type(node.value) is str
|
||||
else:
|
||||
return type(node) is ast.Str
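# Illustrative usage sketch (hypothetical script contents; not part of this
# module). Given user code with bare expressions:
#
#     source = 'x = 1\nx\n"hello", x\n'
#     tree = add_magic(source, "/tmp/app.py")
#     code = compile(tree, "/tmp/app.py", mode="exec")
#
# After the transform, an `import streamlit.runtime.scriptrunner.magic_funcs
# as __streamlitmagic__` node is inserted near the top, the assignment
# `x = 1` is untouched, and the bare expressions become roughly
# `__streamlitmagic__.transparent_write(x)` and
# `__streamlitmagic__.transparent_write("hello", x)`.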
|
||||
@@ -0,0 +1,30 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
|
||||
|
||||
@gather_metrics("magic")
|
||||
def transparent_write(*args: Any) -> Any:
|
||||
"""The function that gets magic-ified into Streamlit apps.
|
||||
This is just st.write, but returns the arguments you passed to it.
|
||||
"""
|
||||
import streamlit as st
|
||||
|
||||
st.write(*args)
|
||||
if len(args) == 1:
|
||||
return args[0]
|
||||
return args
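# Illustrative behavior sketch (not part of this module):
#
#     transparent_write("hi")    # calls st.write("hi") and returns "hi"
#     transparent_write(1, 2)    # calls st.write(1, 2) and returns (1, 2)
#
# Returning the arguments keeps magic-rewritten expressions "transparent":
# the expression still evaluates to its original value after being displayed.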
|
||||
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, cast
|
||||
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates
|
||||
from streamlit.runtime.state import coalesce_widget_states
|
||||
|
||||
|
||||
class ScriptRequestType(Enum):
|
||||
# The ScriptRunner should continue running its script.
|
||||
CONTINUE = "CONTINUE"
|
||||
|
||||
# If the script is running, it should be stopped as soon
|
||||
# as the ScriptRunner reaches an interrupt point.
|
||||
# This is a terminal state.
|
||||
STOP = "STOP"
|
||||
|
||||
# A script rerun has been requested. The ScriptRunner should
|
||||
# handle this request as soon as it reaches an interrupt point.
|
||||
RERUN = "RERUN"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RerunData:
|
||||
"""Data attached to RERUN requests. Immutable."""
|
||||
|
||||
query_string: str = ""
|
||||
widget_states: Optional[WidgetStates] = None
|
||||
page_script_hash: str = ""
|
||||
page_name: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScriptRequest:
|
||||
"""A STOP or RERUN request and associated data."""
|
||||
|
||||
type: ScriptRequestType
|
||||
_rerun_data: Optional[RerunData] = None
|
||||
|
||||
@property
|
||||
def rerun_data(self) -> RerunData:
|
||||
if self.type is not ScriptRequestType.RERUN:
|
||||
raise RuntimeError("RerunData is only set for RERUN requests.")
|
||||
return cast(RerunData, self._rerun_data)
|
||||
|
||||
|
||||
class ScriptRequests:
|
||||
"""An interface for communicating with a ScriptRunner. Thread-safe.
|
||||
|
||||
AppSession makes requests of a ScriptRunner through this class, and
|
||||
ScriptRunner handles those requests.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self._state = ScriptRequestType.CONTINUE
|
||||
self._rerun_data = RerunData()
|
||||
|
||||
def request_stop(self) -> None:
|
||||
"""Request that the ScriptRunner stop running. A stopped ScriptRunner
|
||||
can't be used anymore. STOP requests succeed unconditionally.
|
||||
"""
|
||||
with self._lock:
|
||||
self._state = ScriptRequestType.STOP
|
||||
|
||||
def request_rerun(self, new_data: RerunData) -> bool:
|
||||
"""Request that the ScriptRunner rerun its script.
|
||||
|
||||
If the ScriptRunner has been stopped, this request can't be honored:
|
||||
return False.
|
||||
|
||||
Otherwise, record the request and return True. The ScriptRunner will
|
||||
handle the rerun request as soon as it reaches an interrupt point.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
if self._state == ScriptRequestType.STOP:
|
||||
# We can't rerun after being stopped.
|
||||
return False
|
||||
|
||||
if self._state == ScriptRequestType.CONTINUE:
|
||||
# If we're running, we can handle a rerun request
|
||||
# unconditionally.
|
||||
self._state = ScriptRequestType.RERUN
|
||||
self._rerun_data = new_data
|
||||
return True
|
||||
|
||||
if self._state == ScriptRequestType.RERUN:
|
||||
# If we have an existing Rerun request, we coalesce this
|
||||
# new request into it.
|
||||
if self._rerun_data.widget_states is None:
|
||||
# The existing request's widget_states is None, which
|
||||
# means it wants to rerun with whatever the most
|
||||
# recent script execution's widget state was.
|
||||
# We have no meaningful state to merge with, and
|
||||
# so we simply overwrite the existing request.
|
||||
self._rerun_data = new_data
|
||||
return True
|
||||
|
||||
if new_data.widget_states is not None:
|
||||
# Both the existing and the new request have
|
||||
# non-null widget_states. Merge them together.
|
||||
coalesced_states = coalesce_widget_states(
|
||||
self._rerun_data.widget_states, new_data.widget_states
|
||||
)
|
||||
self._rerun_data = RerunData(
|
||||
query_string=new_data.query_string,
|
||||
widget_states=coalesced_states,
|
||||
page_script_hash=new_data.page_script_hash,
|
||||
page_name=new_data.page_name,
|
||||
)
|
||||
return True
|
||||
|
||||
# If old widget_states is NOT None, and new widget_states IS
|
||||
# None, then this new request is entirely redundant. Leave
|
||||
# our existing rerun_data as is.
|
||||
return True
|
||||
|
||||
# We'll never get here
|
||||
raise RuntimeError(f"Unrecognized ScriptRunnerState: {self._state}")
|
||||
|
||||
def on_scriptrunner_yield(self) -> Optional[ScriptRequest]:
|
||||
"""Called by the ScriptRunner when it's at a yield point.
|
||||
|
||||
If we have no request, return None.
|
||||
|
||||
If we have a RERUN request, return the request and set our internal
|
||||
state to CONTINUE.
|
||||
|
||||
If we have a STOP request, return the request and remain stopped.
|
||||
"""
|
||||
if self._state == ScriptRequestType.CONTINUE:
|
||||
# We avoid taking a lock in the common case. If a STOP or RERUN
|
||||
# request is received between the `if` and `return`, it will be
|
||||
# handled at the next `on_scriptrunner_yield`, or when
|
||||
# `on_scriptrunner_ready` is called.
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if self._state == ScriptRequestType.RERUN:
|
||||
self._state = ScriptRequestType.CONTINUE
|
||||
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
|
||||
|
||||
assert self._state == ScriptRequestType.STOP
|
||||
return ScriptRequest(ScriptRequestType.STOP)
|
||||
|
||||
def on_scriptrunner_ready(self) -> ScriptRequest:
|
||||
"""Called by the ScriptRunner when it's about to run its script for
|
||||
the first time, and also after its script has successfully completed.
|
||||
|
||||
If we have a RERUN request, return the request and set
|
||||
our internal state to CONTINUE.
|
||||
|
||||
If we have a STOP request or no request, set our internal state
|
||||
to STOP.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state == ScriptRequestType.RERUN:
|
||||
self._state = ScriptRequestType.CONTINUE
|
||||
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
|
||||
|
||||
# If we don't have a rerun request, unconditionally change our
|
||||
# state to STOP.
|
||||
self._state = ScriptRequestType.STOP
|
||||
return ScriptRequest(ScriptRequestType.STOP)
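# Illustrative lifecycle sketch (hypothetical caller code, not part of this
# module). AppSession-side code drives ScriptRequests roughly like this:
#
#     requests = ScriptRequests()                 # state: CONTINUE
#     requests.request_rerun(RerunData())         # True; state: RERUN
#     requests.on_scriptrunner_ready().type       # RERUN; state back to CONTINUE
#     requests.on_scriptrunner_yield()            # None (nothing pending)
#     requests.request_stop()                     # state: STOP (terminal)
#     requests.request_rerun(RerunData())         # False; can't rerun when stopped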
|
||||
@@ -0,0 +1,170 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Counter, Dict, List, Optional, Set
|
||||
|
||||
from typing_extensions import Final, TypeAlias
|
||||
|
||||
from streamlit import runtime
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.proto.PageProfile_pb2 import Command
|
||||
from streamlit.runtime.state import SafeSessionState
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
|
||||
LOGGER: Final = get_logger(__name__)
|
||||
|
||||
UserInfo: TypeAlias = Dict[str, Optional[str]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptRunContext:
|
||||
"""A context object that contains data for a "script run" - that is,
|
||||
data that's scoped to a single ScriptRunner execution (and therefore also
|
||||
scoped to a single connected "session").
|
||||
|
||||
ScriptRunContext is used internally by virtually every `st.foo()` function.
|
||||
It is accessed only from the script thread that's created by ScriptRunner.
|
||||
|
||||
Streamlit code typically retrieves the active ScriptRunContext via the
|
||||
`get_script_run_ctx` function.
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
_enqueue: Callable[[ForwardMsg], None]
|
||||
query_string: str
|
||||
session_state: SafeSessionState
|
||||
uploaded_file_mgr: UploadedFileManager
|
||||
page_script_hash: str
|
||||
user_info: UserInfo
|
||||
|
||||
gather_usage_stats: bool = False
|
||||
command_tracking_deactivated: bool = False
|
||||
tracked_commands: List[Command] = field(default_factory=list)
|
||||
tracked_commands_counter: Counter[str] = field(default_factory=collections.Counter)
|
||||
_set_page_config_allowed: bool = True
|
||||
_has_script_started: bool = False
|
||||
widget_ids_this_run: Set[str] = field(default_factory=set)
|
||||
widget_user_keys_this_run: Set[str] = field(default_factory=set)
|
||||
form_ids_this_run: Set[str] = field(default_factory=set)
|
||||
cursors: Dict[int, "streamlit.cursor.RunningCursor"] = field(default_factory=dict)
|
||||
dg_stack: List["streamlit.delta_generator.DeltaGenerator"] = field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
def reset(self, query_string: str = "", page_script_hash: str = "") -> None:
|
||||
self.cursors = {}
|
||||
self.widget_ids_this_run = set()
|
||||
self.widget_user_keys_this_run = set()
|
||||
self.form_ids_this_run = set()
|
||||
self.query_string = query_string
|
||||
self.page_script_hash = page_script_hash
|
||||
# Permit set_page_config when the ScriptRunContext is reused on a rerun
|
||||
self._set_page_config_allowed = True
|
||||
self._has_script_started = False
|
||||
self.command_tracking_deactivated: bool = False
|
||||
self.tracked_commands = []
|
||||
self.tracked_commands_counter = collections.Counter()
|
||||
|
||||
def on_script_start(self) -> None:
|
||||
self._has_script_started = True
|
||||
|
||||
def enqueue(self, msg: ForwardMsg) -> None:
|
||||
"""Enqueue a ForwardMsg for this context's session."""
|
||||
if msg.HasField("page_config_changed") and not self._set_page_config_allowed:
|
||||
raise StreamlitAPIException(
|
||||
"`set_page_config()` can only be called once per app, "
|
||||
+ "and must be called as the first Streamlit command in your script.\n\n"
|
||||
+ "For more information refer to the [docs]"
|
||||
+ "(https://docs.streamlit.io/library/api-reference/utilities/st.set_page_config)."
|
||||
)
|
||||
|
||||
# We want to disallow set_page config if one of the following occurs:
|
||||
# - set_page_config was called on this message
|
||||
# - The script has already started and a different st call occurs (a delta)
|
||||
if msg.HasField("page_config_changed") or (
|
||||
msg.HasField("delta") and self._has_script_started
|
||||
):
|
||||
self._set_page_config_allowed = False
|
||||
|
||||
# Pass the message up to our associated ScriptRunner.
|
||||
self._enqueue(msg)
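# Illustrative consequence of the checks above (hypothetical user script, not
# part of this module): this ordering is accepted,
#
#     st.set_page_config(page_title="Demo")
#     st.write("hello")
#
# while calling st.write() first and st.set_page_config() second raises the
# StreamlitAPIException constructed above, because the delta produced by
# st.write() flips _set_page_config_allowed to False once the script has
# started.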
|
||||
|
||||
|
||||
SCRIPT_RUN_CONTEXT_ATTR_NAME: Final = "streamlit_script_run_ctx"
|
||||
|
||||
|
||||
def add_script_run_ctx(
|
||||
thread: Optional[threading.Thread] = None, ctx: Optional[ScriptRunContext] = None
|
||||
):
|
||||
"""Adds the current ScriptRunContext to a newly-created thread.
|
||||
|
||||
This should be called from this thread's parent thread,
|
||||
before the new thread starts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
thread : threading.Thread
|
||||
The thread to attach the current ScriptRunContext to.
|
||||
ctx : ScriptRunContext or None
|
||||
The ScriptRunContext to add, or None to use the current thread's
|
||||
ScriptRunContext.
|
||||
|
||||
Returns
|
||||
-------
|
||||
threading.Thread
|
||||
The same thread that was passed in, for chaining.
|
||||
|
||||
"""
|
||||
if thread is None:
|
||||
thread = threading.current_thread()
|
||||
if ctx is None:
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is not None:
|
||||
setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, ctx)
|
||||
return thread
|
||||
|
||||
|
||||
def get_script_run_ctx(suppress_warning: bool = False) -> Optional[ScriptRunContext]:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
suppress_warning : bool
|
||||
If True, don't log a warning if there's no ScriptRunContext.
|
||||
Returns
|
||||
-------
|
||||
ScriptRunContext | None
|
||||
The current thread's ScriptRunContext, or None if it doesn't have one.
|
||||
|
||||
"""
|
||||
thread = threading.current_thread()
|
||||
ctx: Optional[ScriptRunContext] = getattr(
|
||||
thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, None
|
||||
)
|
||||
if ctx is None and runtime.exists() and not suppress_warning:
|
||||
# Only warn about a missing ScriptRunContext if suppress_warning is False, and
|
||||
# we were started via `streamlit run`. Otherwise, the user is likely running a
|
||||
# script "bare", and doesn't need to be warned about streamlit
|
||||
# bits that are irrelevant when not connected to a session.
|
||||
LOGGER.warning("Thread '%s': missing ScriptRunContext", thread.name)
|
||||
|
||||
return ctx
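# Illustrative usage sketch (hypothetical user code, not part of this module):
# propagating the ScriptRunContext to a thread spawned from a Streamlit script
# so that st.* calls inside it can find a context:
#
#     import threading
#     worker = threading.Thread(target=do_work)
#     add_script_run_ctx(worker)   # attach the current thread's ctx
#     worker.start()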
|
||||
|
||||
|
||||
# Needed to avoid circular dependencies while running tests.
|
||||
import streamlit
|
||||
@@ -0,0 +1,704 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from timeit import default_timer as timer
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
from blinker import Signal
|
||||
|
||||
from streamlit import config, runtime, source_util, util
|
||||
from streamlit.error_util import handle_uncaught_app_exception
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.proto.ClientState_pb2 import ClientState
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.scriptrunner import magic
|
||||
from streamlit.runtime.scriptrunner.script_requests import (
|
||||
RerunData,
|
||||
ScriptRequests,
|
||||
ScriptRequestType,
|
||||
)
|
||||
from streamlit.runtime.scriptrunner.script_run_context import (
|
||||
ScriptRunContext,
|
||||
add_script_run_ctx,
|
||||
get_script_run_ctx,
|
||||
)
|
||||
from streamlit.runtime.state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
SafeSessionState,
|
||||
SessionState,
|
||||
)
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
from streamlit.vendor.ipython.modified_sys_path import modified_sys_path
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class ScriptRunnerEvent(Enum):
|
||||
## "Control" events. These are emitted when the ScriptRunner's state changes.
|
||||
|
||||
# The script started running.
|
||||
SCRIPT_STARTED = "SCRIPT_STARTED"
|
||||
|
||||
# The script run stopped because of a compile error.
|
||||
SCRIPT_STOPPED_WITH_COMPILE_ERROR = "SCRIPT_STOPPED_WITH_COMPILE_ERROR"
|
||||
|
||||
# The script run stopped because it ran to completion, or was
|
||||
# interrupted by the user.
|
||||
SCRIPT_STOPPED_WITH_SUCCESS = "SCRIPT_STOPPED_WITH_SUCCESS"
|
||||
|
||||
# The script run stopped in order to start a script run with newer widget state.
|
||||
SCRIPT_STOPPED_FOR_RERUN = "SCRIPT_STOPPED_FOR_RERUN"
|
||||
|
||||
# The ScriptRunner is done processing the ScriptEventQueue and
|
||||
# is shut down.
|
||||
SHUTDOWN = "SHUTDOWN"
|
||||
|
||||
## "Data" events. These are emitted when the ScriptRunner's script has
|
||||
## data to send to the frontend.
|
||||
|
||||
# The script has a ForwardMsg to send to the frontend.
|
||||
ENQUEUE_FORWARD_MSG = "ENQUEUE_FORWARD_MSG"
|
||||
|
||||
|
||||
"""
|
||||
Note [Threading]
|
||||
There are two kinds of threads in Streamlit, the main thread and script threads.
|
||||
The main thread is started by invoking the Streamlit CLI, and bootstraps the
|
||||
framework and runs the Tornado webserver.
|
||||
A script thread is created by a ScriptRunner when it starts. The script thread
|
||||
is where the ScriptRunner executes, including running the user script itself,
|
||||
processing messages to/from the frontend, and all the Streamlit library function
|
||||
calls in the user script.
|
||||
It is possible for the user script to spawn its own threads, which could call
|
||||
Streamlit functions. We restrict the ScriptRunner's execution control to the
|
||||
script thread. Calling Streamlit functions from other threads is unlikely to
|
||||
work correctly due to lack of ScriptRunContext, so we may add a guard against
|
||||
it in the future.
|
||||
"""
|
||||
|
||||
|
||||
class ScriptRunner:
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
main_script_path: str,
|
||||
client_state: ClientState,
|
||||
session_state: SessionState,
|
||||
uploaded_file_mgr: UploadedFileManager,
|
||||
initial_rerun_data: RerunData,
|
||||
user_info: Dict[str, Optional[str]],
|
||||
):
|
||||
"""Initialize the ScriptRunner.
|
||||
|
||||
(The ScriptRunner won't start executing until start() is called.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id : str
|
||||
The AppSession's id.
|
||||
|
||||
main_script_path : str
|
||||
Path to our main app script.
|
||||
|
||||
client_state : ClientState
|
||||
The current state from the client (widgets and query params).
|
||||
|
||||
uploaded_file_mgr : UploadedFileManager
|
||||
The File manager to store the data uploaded by the file_uploader widget.
|
||||
|
||||
user_info: Dict
|
||||
A dict that contains information about the current user. For now,
|
||||
it only contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
|
||||
Information about the current user is optionally provided when a
|
||||
websocket connection is initialized via the "X-Streamlit-User" header.
|
||||
|
||||
"""
|
||||
self._session_id = session_id
|
||||
self._main_script_path = main_script_path
|
||||
self._uploaded_file_mgr = uploaded_file_mgr
|
||||
self._user_info = user_info
|
||||
|
||||
# Initialize SessionState with the latest widget states
|
||||
session_state.set_widgets_from_proto(client_state.widget_states)
|
||||
|
||||
self._client_state = client_state
|
||||
self._session_state = SafeSessionState(session_state)
|
||||
|
||||
self._requests = ScriptRequests()
|
||||
self._requests.request_rerun(initial_rerun_data)
|
||||
|
||||
self.on_event = Signal(
|
||||
doc="""Emitted when a ScriptRunnerEvent occurs.
|
||||
|
||||
This signal is generally emitted on the ScriptRunner's script
|
||||
thread (which is *not* the same thread that the ScriptRunner was
|
||||
created on).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sender: ScriptRunner
|
||||
The sender of the event (this ScriptRunner).
|
||||
|
||||
event : ScriptRunnerEvent
|
||||
|
||||
forward_msg : ForwardMsg | None
|
||||
The ForwardMsg to send to the frontend. Set only for the
|
||||
ENQUEUE_FORWARD_MSG event.
|
||||
|
||||
exception : BaseException | None
|
||||
Our compile error. Set only for the
|
||||
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
|
||||
|
||||
widget_states : streamlit.proto.WidgetStates_pb2.WidgetStates | None
|
||||
The ScriptRunner's final WidgetStates. Set only for the
|
||||
SHUTDOWN event.
|
||||
"""
|
||||
)
|
||||
|
||||
# Set to true while we're executing. Used by
|
||||
# _maybe_handle_execution_control_request.
|
||||
self._execing = False
|
||||
|
||||
# This is initialized in start()
|
||||
self._script_thread: Optional[threading.Thread] = None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def request_stop(self) -> None:
|
||||
"""Request that the ScriptRunner stop running its script and
|
||||
shut down. The ScriptRunner will handle this request when it reaches
|
||||
an interrupt point.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
self._requests.request_stop()
|
||||
|
||||
# "Disconnect" our SafeSessionState wrapper from its underlying
|
||||
# SessionState instance. This will cause all further session_state
|
||||
# operations in this ScriptRunner to no-op.
|
||||
#
|
||||
# After `request_stop` is called, our script will continue executing
|
||||
# until it reaches a yield point. AppSession may also *immediately*
|
||||
# spin up a new ScriptRunner after this call, which means we'll
|
||||
# potentially have two active ScriptRunners for a brief period while
|
||||
# this one is shutting down. Disconnecting our SessionState ensures
|
||||
# that this ScriptRunner's thread won't introduce SessionState-
|
||||
# related race conditions during this script overlap.
|
||||
self._session_state.disconnect()
|
||||
|
||||
def request_rerun(self, rerun_data: RerunData) -> bool:
|
||||
"""Request that the ScriptRunner interrupt its currently-running
|
||||
script and restart it.
|
||||
|
||||
If the ScriptRunner has been stopped, this request can't be honored:
|
||||
return False.
|
||||
|
||||
Otherwise, record the request and return True. The ScriptRunner will
|
||||
handle the rerun request as soon as it reaches an interrupt point.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
return self._requests.request_rerun(rerun_data)
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start a new thread to process the ScriptEventQueue.
|
||||
|
||||
This must be called only once.
|
||||
|
||||
"""
|
||||
if self._script_thread is not None:
|
||||
raise Exception("ScriptRunner was already started")
|
||||
|
||||
self._script_thread = threading.Thread(
|
||||
target=self._run_script_thread,
|
||||
name="ScriptRunner.scriptThread",
|
||||
)
|
||||
self._script_thread.start()
|
||||
|
||||
def _get_script_run_ctx(self) -> ScriptRunContext:
|
||||
"""Get the ScriptRunContext for the current thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ScriptRunContext
|
||||
The ScriptRunContext for the current thread.
|
||||
|
||||
Raises
|
||||
------
|
||||
AssertionError
|
||||
If called outside of a ScriptRunner thread.
|
||||
RuntimeError
|
||||
If there is no ScriptRunContext for the current thread.
|
||||
|
||||
"""
|
||||
assert self._is_in_script_thread()
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
if ctx is None:
|
||||
# This should never be possible on the script_runner thread.
|
||||
raise RuntimeError(
|
||||
"ScriptRunner thread has a null ScriptRunContext. Something has gone very wrong!"
|
||||
)
|
||||
return ctx
|
||||
|
||||
def _run_script_thread(self) -> None:
|
||||
"""The entry point for the script thread.
|
||||
|
||||
Processes the ScriptRequestQueue, which will at least contain the RERUN
|
||||
request that will trigger the first script-run.
|
||||
|
||||
When the ScriptRequestQueue is empty, or when a SHUTDOWN request is
|
||||
dequeued, this function will exit and its thread will terminate.
|
||||
"""
|
||||
assert self._is_in_script_thread()
|
||||
|
||||
_LOGGER.debug("Beginning script thread")
|
||||
|
||||
# Create and attach the thread's ScriptRunContext
|
||||
ctx = ScriptRunContext(
|
||||
session_id=self._session_id,
|
||||
_enqueue=self._enqueue_forward_msg,
|
||||
query_string=self._client_state.query_string,
|
||||
session_state=self._session_state,
|
||||
uploaded_file_mgr=self._uploaded_file_mgr,
|
||||
page_script_hash=self._client_state.page_script_hash,
|
||||
user_info=self._user_info,
|
||||
gather_usage_stats=bool(config.get_option("browser.gatherUsageStats")),
|
||||
)
|
||||
add_script_run_ctx(threading.current_thread(), ctx)
|
||||
|
||||
request = self._requests.on_scriptrunner_ready()
|
||||
while request.type == ScriptRequestType.RERUN:
|
||||
# When the script thread starts, we'll have a pending rerun
|
||||
# request that we'll handle immediately. When the script finishes,
|
||||
# it's possible that another request has come in that we need to
|
||||
# handle, which is why we call _run_script in a loop.
|
||||
self._run_script(request.rerun_data)
|
||||
request = self._requests.on_scriptrunner_ready()
|
||||
|
||||
assert request.type == ScriptRequestType.STOP
|
||||
|
||||
# Send a SHUTDOWN event before exiting. This includes the widget values
|
||||
# as they existed after our last successful script run, which the
|
||||
# AppSession will pass on to the next ScriptRunner that gets
|
||||
# created.
|
||||
client_state = ClientState()
|
||||
client_state.query_string = ctx.query_string
|
||||
client_state.page_script_hash = ctx.page_script_hash
|
||||
widget_states = self._session_state.get_widget_states()
|
||||
client_state.widget_states.widgets.extend(widget_states)
|
||||
self.on_event.send(
|
||||
self, event=ScriptRunnerEvent.SHUTDOWN, client_state=client_state
|
||||
)
|
||||
|
||||
def _is_in_script_thread(self) -> bool:
|
||||
"""True if the calling function is running in the script thread"""
|
||||
return self._script_thread == threading.current_thread()
|
||||
|
||||
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Enqueue a ForwardMsg to our browser queue.
|
||||
This private function is called by ScriptRunContext only.
|
||||
|
||||
It may be called from the script thread OR the main thread.
|
||||
"""
|
||||
# Whenever we enqueue a ForwardMsg, we also handle any pending
|
||||
# execution control request. This means that a script can be
|
||||
# cleanly interrupted and stopped inside most `st.foo` calls.
|
||||
#
|
||||
# (If "runner.installTracer" is true, then we'll actually be
|
||||
# handling these requests in a callback called after every Python
|
||||
# instruction instead.)
|
||||
if not config.get_option("runner.installTracer"):
|
||||
self._maybe_handle_execution_control_request()
|
||||
|
||||
# Pass the message to our associated AppSession.
|
||||
self.on_event.send(
|
||||
self, event=ScriptRunnerEvent.ENQUEUE_FORWARD_MSG, forward_msg=msg
|
||||
)
|
||||
|
||||
def _maybe_handle_execution_control_request(self) -> None:
|
||||
"""Check our current ScriptRequestState to see if we have a
|
||||
pending STOP or RERUN request.
|
||||
|
||||
This function is called every time the app script enqueues a
|
||||
ForwardMsg, which means that most `st.foo` commands - which generally
|
||||
involve sending a ForwardMsg to the frontend - act as implicit
|
||||
yield points in the script's execution.
|
||||
"""
|
||||
if not self._is_in_script_thread():
|
||||
# We can only handle execution_control_request if we're on the
|
||||
# script execution thread. However, it's possible for deltas to
|
||||
# be enqueued (and, therefore, for this function to be called)
|
||||
# in separate threads, so we check for that here.
|
||||
return
|
||||
|
||||
if not self._execing:
|
||||
# If the _execing flag is not set, we're not actually inside
|
||||
# an exec() call. This happens when our script exec() completes,
|
||||
# we change our state to STOPPED, and a statechange-listener
|
||||
# enqueues a new ForwardMsg.
|
||||
return
|
||||
|
||||
request = self._requests.on_scriptrunner_yield()
|
||||
if request is None:
|
||||
# No RERUN or STOP request.
|
||||
return
|
||||
|
||||
if request.type == ScriptRequestType.RERUN:
|
||||
raise RerunException(request.rerun_data)
|
||||
|
||||
assert request.type == ScriptRequestType.STOP
|
||||
raise StopException()
|
||||
|
||||
def _install_tracer(self) -> None:
|
||||
"""Install function that runs before each line of the script."""
|
||||
|
||||
def trace_calls(frame, event, arg):
|
||||
self._maybe_handle_execution_control_request()
|
||||
return trace_calls
|
||||
|
||||
# Python interpreters are not required to implement sys.settrace.
|
||||
if hasattr(sys, "settrace"):
|
||||
sys.settrace(trace_calls)
|
||||
|
||||
@contextmanager
|
||||
def _set_execing_flag(self):
|
||||
"""A context for setting the ScriptRunner._execing flag.
|
||||
|
||||
Used by _maybe_handle_execution_control_request to ensure that
|
||||
we only handle requests while we're inside an exec() call
|
||||
"""
|
||||
if self._execing:
|
||||
raise RuntimeError("Nested set_execing_flag call")
|
||||
self._execing = True
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._execing = False
|
||||
|
||||
def _run_script(self, rerun_data: RerunData) -> None:
|
||||
"""Run our script.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rerun_data: RerunData
|
||||
The RerunData to use.
|
||||
|
||||
"""
|
||||
assert self._is_in_script_thread()
|
||||
|
||||
_LOGGER.debug("Running script %s", rerun_data)
|
||||
|
||||
start_time: float = timer()
|
||||
prep_time: float = 0 # This will be overwritten once preparations are done.
|
||||
|
||||
# Reset DeltaGenerators, widgets, media files.
|
||||
runtime.get_instance().media_file_mgr.clear_session_refs()
|
||||
|
||||
main_script_path = self._main_script_path
|
||||
pages = source_util.get_pages(main_script_path)
|
||||
# Safe because pages will at least contain the app's main page.
|
||||
main_page_info = list(pages.values())[0]
|
||||
current_page_info = None
|
||||
uncaught_exception = None
|
||||
|
||||
if rerun_data.page_script_hash:
|
||||
current_page_info = pages.get(rerun_data.page_script_hash, None)
|
||||
elif not rerun_data.page_script_hash and rerun_data.page_name:
|
||||
# If a user navigates directly to a non-main page of an app, we get
|
||||
# the first script run request before the list of pages has been
|
||||
# sent to the frontend. In this case, we choose the first script
|
||||
# with a name matching the requested page name.
|
||||
current_page_info = next(
|
||||
filter(
|
||||
# There seems to be this weird bug with mypy where it
|
||||
# thinks that p can be None (which is impossible given the
|
||||
# types of pages), so we add `p and` at the beginning of
|
||||
# the predicate to circumvent this.
|
||||
lambda p: p and (p["page_name"] == rerun_data.page_name),
|
||||
pages.values(),
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
# If no information about what page to run is given, default to
|
||||
# running the main page.
|
||||
current_page_info = main_page_info
|
||||
|
||||
page_script_hash = (
|
||||
current_page_info["page_script_hash"]
|
||||
if current_page_info is not None
|
||||
else main_page_info["page_script_hash"]
|
||||
)
|
||||
|
||||
ctx = self._get_script_run_ctx()
|
||||
ctx.reset(
|
||||
query_string=rerun_data.query_string,
|
||||
page_script_hash=page_script_hash,
|
||||
)
|
||||
|
||||
self.on_event.send(
|
||||
self,
|
||||
event=ScriptRunnerEvent.SCRIPT_STARTED,
|
||||
page_script_hash=page_script_hash,
|
||||
)
|
||||
|
||||
# Compile the script. Any errors thrown here will be surfaced
|
||||
# to the user via a modal dialog in the frontend, and won't result
|
||||
# in their previous script elements disappearing.
|
||||
try:
|
||||
if current_page_info:
|
||||
script_path = current_page_info["script_path"]
|
||||
else:
|
||||
script_path = main_script_path
|
||||
|
||||
# At this point, we know that either
|
||||
# * the script corresponding to the hash requested no longer
|
||||
# exists, or
|
||||
# * we were not able to find a script with the requested page
|
||||
# name.
|
||||
# In both of these cases, we want to send a page_not_found
|
||||
# message to the frontend.
|
||||
msg = ForwardMsg()
|
||||
msg.page_not_found.page_name = rerun_data.page_name
|
||||
ctx.enqueue(msg)
|
||||
|
||||
with source_util.open_python_file(script_path) as f:
|
||||
filebody = f.read()
|
||||
|
||||
if config.get_option("runner.magicEnabled"):
|
||||
filebody = magic.add_magic(filebody, script_path)
|
||||
|
||||
code = compile( # type: ignore
|
||||
filebody,
|
||||
# Pass in the file path so it can show up in exceptions.
|
||||
script_path,
|
||||
# We're compiling entire blocks of Python, so we need "exec"
|
||||
# mode (as opposed to "eval" or "single").
|
||||
mode="exec",
|
||||
# Don't inherit any flags or "future" statements.
|
||||
flags=0,
|
||||
dont_inherit=1,
|
||||
# Use the default optimization options.
|
||||
optimize=-1,
|
||||
)
|
||||
|
||||
except Exception as ex:
|
||||
# We got a compile error. Send an error event and bail immediately.
|
||||
_LOGGER.debug("Fatal script error: %s", ex)
|
||||
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
|
||||
self.on_event.send(
|
||||
self,
|
||||
event=ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
|
||||
exception=ex,
|
||||
)
|
||||
return
|
||||
|
||||
# If we get here, we've successfully compiled our script. The next step
|
||||
# is to run it. Errors thrown during execution will be shown to the
|
||||
# user as ExceptionElements.
|
||||
|
||||
if config.get_option("runner.installTracer"):
|
||||
self._install_tracer()
|
||||
|
||||
# This will be set to a RerunData instance if our execution
|
||||
# is interrupted by a RerunException.
|
||||
rerun_exception_data: Optional[RerunData] = None
|
||||
|
||||
try:
|
||||
# Create a fake module. This gives us a fresh global namespace to
|
||||
# execute the code in.
|
||||
# TODO(vdonato): Double-check that we're okay with naming the
|
||||
# module for every page `__main__`. I'm pretty sure this is
|
||||
# necessary given that people will likely often write
|
||||
# ```
|
||||
# if __name__ == "__main__":
|
||||
# ...
|
||||
# ```
|
||||
# in their scripts.
|
||||
module = _new_module("__main__")
|
||||
|
||||
# Install the fake module as the __main__ module. This allows
|
||||
# the pickle module to work inside the user's code, since it now
|
||||
# can know the module where the pickled objects stem from.
|
||||
# IMPORTANT: This means we can't use "if __name__ == '__main__'" in
|
||||
# our code, as it will point to the wrong module!!!
|
||||
sys.modules["__main__"] = module
|
||||
|
||||
# Add special variables to the module's globals dict.
|
||||
# Note: The following is a requirement for the CodeHasher to
|
||||
# work correctly. The CodeHasher is scoped to
|
||||
# files contained in the directory of __main__.__file__, which we
|
||||
# assume is the main script directory.
|
||||
module.__dict__["__file__"] = script_path
|
||||
|
||||
with modified_sys_path(self._main_script_path), self._set_execing_flag():
|
||||
# Run callbacks for widgets whose values have changed.
|
||||
if rerun_data.widget_states is not None:
|
||||
self._session_state.on_script_will_rerun(rerun_data.widget_states)
|
||||
|
||||
ctx.on_script_start()
|
||||
prep_time = timer() - start_time
|
||||
exec(code, module.__dict__)
|
||||
self._session_state.maybe_check_serializable()
|
||||
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = True
|
||||
except RerunException as e:
|
||||
rerun_exception_data = e.rerun_data
|
||||
|
||||
except StopException:
|
||||
# This is thrown when the script executes `st.stop()`.
|
||||
# We don't have to do anything here.
|
||||
pass
|
||||
|
||||
except Exception as ex:
|
||||
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
|
||||
uncaught_exception = ex
|
||||
handle_uncaught_app_exception(uncaught_exception)
|
||||
|
||||
finally:
|
||||
if rerun_exception_data:
|
||||
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN
|
||||
else:
|
||||
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
|
||||
|
||||
if ctx.gather_usage_stats:
|
||||
try:
|
||||
# Prevent issues with circular import
|
||||
from streamlit.runtime.metrics_util import (
|
||||
create_page_profile_message,
|
||||
to_microseconds,
|
||||
)
|
||||
|
||||
# Create and send page profile information
|
||||
ctx.enqueue(
|
||||
create_page_profile_message(
|
||||
ctx.tracked_commands,
|
||||
exec_time=to_microseconds(timer() - start_time),
|
||||
prep_time=to_microseconds(prep_time),
|
||||
uncaught_exception=type(uncaught_exception).__name__
|
||||
if uncaught_exception
|
||||
else None,
|
||||
)
|
||||
)
|
||||
except Exception as ex:
|
||||
# Always capture all exceptions since we want to make sure that
|
||||
# the telemetry never causes any issues.
|
||||
_LOGGER.debug("Failed to create page profile", exc_info=ex)
|
||||
self._on_script_finished(ctx, finished_event)
|
||||
|
||||
# Use _log_if_error() to make sure we never ever ever stop running the
|
||||
# script without meaning to.
|
||||
_log_if_error(_clean_problem_modules)
|
||||
|
||||
if rerun_exception_data is not None:
|
||||
self._run_script(rerun_exception_data)
|
||||
|
||||
def _on_script_finished(
|
||||
self, ctx: ScriptRunContext, event: ScriptRunnerEvent
|
||||
) -> None:
|
||||
"""Called when our script finishes executing, even if it finished
|
||||
early with an exception. We perform post-run cleanup here.
|
||||
"""
|
||||
# Tell session_state to update itself in response
|
||||
self._session_state.on_script_finished(ctx.widget_ids_this_run)
|
||||
|
||||
# Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
|
||||
# even if we were stopped with an exception.)
|
||||
self.on_event.send(self, event=event)
|
||||
|
||||
# Remove orphaned files now that the script has run and files in use
|
||||
# are marked as active.
|
||||
runtime.get_instance().media_file_mgr.remove_orphaned_files()
|
||||
|
||||
# Force garbage collection to run, to help avoid memory use building up
|
||||
# This is usually not an issue, but sometimes GC takes time to kick in and
|
||||
# causes apps to go over resource limits, and forcing it to run between
|
||||
# script runs is low cost, since we aren't doing much work anyway.
|
||||
if config.get_option("runner.postScriptGC"):
|
||||
gc.collect(2)
|
||||
|
||||
|
||||
class ScriptControlException(BaseException):
|
||||
"""Base exception for ScriptRunner."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class StopException(ScriptControlException):
|
||||
"""Silently stop the execution of the user's script."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RerunException(ScriptControlException):
|
||||
"""Silently stop and rerun the user's script."""
|
||||
|
||||
def __init__(self, rerun_data: RerunData):
|
||||
"""Construct a RerunException
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rerun_data : RerunData
|
||||
The RerunData that should be used to rerun the script
|
||||
"""
|
||||
self.rerun_data = rerun_data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
def _clean_problem_modules() -> None:
|
||||
"""Some modules are stateful, so we have to clear their state."""
|
||||
|
||||
if "keras" in sys.modules:
|
||||
try:
|
||||
keras = sys.modules["keras"]
|
||||
keras.backend.clear_session()
|
||||
except Exception:
|
||||
# We don't want to crash the app if we can't clear the Keras session.
|
||||
pass
|
||||
|
||||
if "matplotlib.pyplot" in sys.modules:
|
||||
try:
|
||||
plt = sys.modules["matplotlib.pyplot"]
|
||||
plt.close("all")
|
||||
except Exception:
|
||||
# We don't want to crash the app if we can't close matplotlib
|
||||
pass
|
||||
|
||||
|
||||
def _new_module(name: str) -> types.ModuleType:
|
||||
"""Create a new module with the given name."""
|
||||
return types.ModuleType(name)
|
||||
|
||||
|
||||
# The reason this is not a decorator is because we want to make it clear at the
|
||||
# calling location that this function is being used.
|
||||
def _log_if_error(fn: Callable[[], None]) -> None:
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
_LOGGER.warning(e)
|
||||
352
venv/lib/python3.9/site-packages/streamlit/runtime/secrets.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import threading
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
ItemsView,
|
||||
Iterator,
|
||||
KeysView,
|
||||
List,
|
||||
Mapping,
|
||||
NoReturn,
|
||||
Optional,
|
||||
ValuesView,
|
||||
)
|
||||
|
||||
import toml
|
||||
from blinker import Signal
|
||||
from typing_extensions import Final
|
||||
|
||||
import streamlit as st
|
||||
import streamlit.watcher.path_watcher
|
||||
from streamlit import file_util, runtime
|
||||
from streamlit.logger import get_logger
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
SECRETS_FILE_LOCS: Final[List[str]] = [
|
||||
file_util.get_streamlit_file_path("secrets.toml"),
|
||||
# NOTE: The order here is important! Project-level secrets should overwrite global
|
||||
# secrets.
|
||||
file_util.get_project_streamlit_file_path("secrets.toml"),
|
||||
]
|
||||
|
||||
|
||||
def _missing_attr_error_message(attr_name: str) -> str:
|
||||
return (
|
||||
f'st.secrets has no attribute "{attr_name}". '
|
||||
f"Did you forget to add it to secrets.toml or the app settings on Streamlit Cloud? "
|
||||
f"More info: https://docs.streamlit.io/streamlit-cloud/get-started/deploy-an-app/connect-to-data-sources/secrets-management"
|
||||
)
|
||||
|
||||
|
||||
def _missing_key_error_message(key: str) -> str:
|
||||
return (
|
||||
f'st.secrets has no key "{key}". '
|
||||
f"Did you forget to add it to secrets.toml or the app settings on Streamlit Cloud? "
|
||||
f"More info: https://docs.streamlit.io/streamlit-cloud/get-started/deploy-an-app/connect-to-data-sources/secrets-management"
|
||||
)
|
||||
|
||||
|
||||
class AttrDict(Mapping[str, Any]):
|
||||
"""
|
||||
We use AttrDict to wrap up dictionary values from secrets
|
||||
to provide dot access to nested secrets
|
||||
"""
|
||||
|
||||
def __init__(self, value):
|
||||
self.__dict__["__nested_secrets__"] = dict(value)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_wrap_in_attr_dict(value) -> Any:
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
else:
|
||||
return AttrDict(value)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.__nested_secrets__)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self.__nested_secrets__)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
try:
|
||||
value = self.__nested_secrets__[key]
|
||||
return self._maybe_wrap_in_attr_dict(value)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def __getattr__(self, attr_name: str) -> Any:
|
||||
try:
|
||||
value = self.__nested_secrets__[attr_name]
|
||||
return self._maybe_wrap_in_attr_dict(value)
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(attr_name))
|
||||
|
||||
def __repr__(self):
|
||||
return repr(self.__nested_secrets__)
|
||||
|
||||
def __setitem__(self, key, value) -> NoReturn:
|
||||
raise TypeError("Secrets does not support item assignment.")
|
||||
|
||||
def __setattr__(self, key, value) -> NoReturn:
|
||||
raise TypeError("Secrets does not support attribute assignment.")
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return deepcopy(self.__nested_secrets__)
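# Illustrative sketch of the dot access AttrDict provides for nested secrets
# (the values below are made up purely for demonstration):
_example = AttrDict({"api_key": "xyz", "db": {"user": "admin", "port": 5432}})
assert _example["api_key"] == "xyz"  # item access
assert _example.db.user == "admin"   # nested mappings are wrapped, so dot access chains
assert _example.to_dict() == {"api_key": "xyz", "db": {"user": "admin", "port": 5432}}
try:
    _example.api_key = "other"       # mutation is rejected by design
except TypeError:
    pass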
|
||||
|
||||
|
||||
class Secrets(Mapping[str, Any]):
|
||||
"""A dict-like class that stores secrets.
|
||||
Parses secrets.toml on-demand. Cannot be externally mutated.
|
||||
|
||||
Safe to use from multiple threads.
|
||||
"""
|
||||
|
||||
def __init__(self, file_paths: List[str]):
|
||||
# Our secrets dict.
|
||||
self._secrets: Optional[Mapping[str, Any]] = None
|
||||
self._lock = threading.RLock()
|
||||
self._file_watchers_installed = False
|
||||
self._file_paths = file_paths
|
||||
|
||||
self.file_change_listener = Signal(
|
||||
doc="Emitted when a `secrets.toml` file has been changed."
|
||||
)
|
||||
|
||||
def load_if_toml_exists(self) -> bool:
|
||||
"""Load secrets.toml files from disk if they exists. If none exist,
|
||||
no exception will be raised. (If a file exists but is malformed,
|
||||
an exception *will* be raised.)
|
||||
|
||||
Returns True if a secrets.toml file was successfully parsed, False otherwise.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
try:
|
||||
self._parse(print_exceptions=False)
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
# No secrets.toml files exist. That's fine.
|
||||
return False
|
||||
|
||||
def _reset(self) -> None:
|
||||
"""Clear the secrets dictionary and remove any secrets that were
|
||||
added to os.environ.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._secrets is None:
|
||||
return
|
||||
|
||||
for k, v in self._secrets.items():
|
||||
self._maybe_delete_environment_variable(k, v)
|
||||
self._secrets = None
|
||||
|
||||
def _parse(self, print_exceptions: bool) -> Mapping[str, Any]:
|
||||
"""Parse our secrets.toml files if they're not already parsed.
|
||||
This function is safe to call from multiple threads.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
print_exceptions : bool
|
||||
If True, then exceptions will be printed with `st.error` before
|
||||
being re-raised.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
Raised if secrets.toml doesn't exist.
|
||||
|
||||
"""
|
||||
# Avoid taking a lock for the common case where secrets are already
|
||||
# loaded.
|
||||
secrets = self._secrets
|
||||
if secrets is not None:
|
||||
return secrets
|
||||
|
||||
with self._lock:
|
||||
if self._secrets is not None:
|
||||
return self._secrets
|
||||
|
||||
# It's fine for a user to only have one secrets.toml file defined, so
|
||||
# we ignore individual FileNotFoundErrors when attempting to read files
|
||||
# below and only raise an exception if we weren't able to read *any* secrets
|
||||
# file.
|
||||
found_secrets_file = False
|
||||
secrets = {}
|
||||
|
||||
for path in self._file_paths:
|
||||
try:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
secrets_file_str = f.read()
|
||||
found_secrets_file = True
|
||||
except FileNotFoundError:
|
||||
continue
|
||||
|
||||
try:
|
||||
secrets.update(toml.loads(secrets_file_str))
|
||||
except:
|
||||
if print_exceptions:
|
||||
st.error(f"Error parsing secrets file at {path}")
|
||||
raise
|
||||
|
||||
if not found_secrets_file:
|
||||
err_msg = f"No secrets files found. Valid paths for a secrets.toml file are: {', '.join(self._file_paths)}"
|
||||
if print_exceptions:
|
||||
st.error(err_msg)
|
||||
raise FileNotFoundError(err_msg)
|
||||
|
||||
if len([p for p in self._file_paths if os.path.exists(p)]) > 1:
|
||||
_LOGGER.info(
|
||||
f"Secrets found in multiple locations: {', '.join(self._file_paths)}. "
|
||||
"When multiple secret.toml files exist, local secrets will take precedence over global secrets."
|
||||
)
|
||||
|
||||
for k, v in secrets.items():
|
||||
self._maybe_set_environment_variable(k, v)
|
||||
|
||||
self._secrets = secrets
|
||||
self._maybe_install_file_watchers()
|
||||
|
||||
return self._secrets
|
||||
|
||||
@staticmethod
|
||||
def _maybe_set_environment_variable(k: Any, v: Any) -> None:
|
||||
"""Add the given key/value pair to os.environ if the value
|
||||
is a string, int, or float.
|
||||
"""
|
||||
value_type = type(v)
|
||||
if value_type in (str, int, float):
|
||||
os.environ[k] = str(v)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_delete_environment_variable(k: Any, v: Any) -> None:
|
||||
"""Remove the given key/value pair from os.environ if the value
|
||||
is a string, int, or float.
|
||||
"""
|
||||
value_type = type(v)
|
||||
if value_type in (str, int, float) and os.environ.get(k) == v:
|
||||
del os.environ[k]
|
||||
|
||||
def _maybe_install_file_watchers(self) -> None:
|
||||
with self._lock:
|
||||
if self._file_watchers_installed:
|
||||
return
|
||||
|
||||
for path in self._file_paths:
|
||||
try:
|
||||
streamlit.watcher.path_watcher.watch_file(
|
||||
path,
|
||||
self._on_secrets_file_changed,
|
||||
watcher_type="poll",
|
||||
)
|
||||
except FileNotFoundError:
|
||||
# A user may only have one secrets.toml file defined, so we'd expect
|
||||
# FileNotFoundErrors to be raised when attempting to install a
|
||||
# watcher on the nonexistent ones.
|
||||
pass
|
||||
|
||||
# We set file_watchers_installed to True even if the installation attempt
|
||||
# failed to avoid repeatedly trying to install it.
|
||||
self._file_watchers_installed = True
|
||||
|
||||
def _on_secrets_file_changed(self, changed_file_path) -> None:
|
||||
with self._lock:
|
||||
_LOGGER.debug("Secrets file %s changed, reloading", changed_file_path)
|
||||
self._reset()
|
||||
self._parse(print_exceptions=True)
|
||||
|
||||
# Emit a signal to notify receivers that the `secrets.toml` file
|
||||
# has been changed.
|
||||
self.file_change_listener.send()
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
"""Return the value with the given key. If no such key
|
||||
exists, raise an AttributeError.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
try:
|
||||
value = self._parse(True)[key]
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
else:
|
||||
return AttrDict(value)
|
||||
# We add FileNotFoundError since __getattr__ is expected to only raise
|
||||
# AttributeError. Without handling FileNotFoundError, unittest.mock
# fails during mock creation on Python 3.9.
|
||||
except (KeyError, FileNotFoundError):
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
"""Return the value with the given key. If no such key
|
||||
exists, raise a KeyError.
|
||||
|
||||
Thread-safe.
|
||||
"""
|
||||
try:
|
||||
value = self._parse(True)[key]
|
||||
if not isinstance(value, Mapping):
|
||||
return value
|
||||
else:
|
||||
return AttrDict(value)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""A string representation of the contents of the dict. Thread-safe."""
# If the runtime is NOT initialized, this is being called from outside a
# Streamlit app, so we avoid reading the secrets file, as it may not exist.
# If the runtime is initialized, the secrets file must already exist, so we
# display its parsed contents.
|
||||
if not runtime.exists():
|
||||
return f"{self.__class__.__name__}(file_paths={self._file_paths!r})"
|
||||
return repr(self._parse(True))
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of entries in the dict. Thread-safe."""
|
||||
return len(self._parse(True))
|
||||
|
||||
def has_key(self, k: str) -> bool:
|
||||
"""True if the given key is in the dict. Thread-safe."""
|
||||
return k in self._parse(True)
|
||||
|
||||
def keys(self) -> KeysView[str]:
|
||||
"""A view of the keys in the dict. Thread-safe."""
|
||||
return self._parse(True).keys()
|
||||
|
||||
def values(self) -> ValuesView[Any]:
|
||||
"""A view of the values in the dict. Thread-safe."""
|
||||
return self._parse(True).values()
|
||||
|
||||
def items(self) -> ItemsView[str, Any]:
|
||||
"""A view of the key-value items in the dict. Thread-safe."""
|
||||
return self._parse(True).items()
|
||||
|
||||
def __contains__(self, key: Any) -> bool:
|
||||
"""True if the given key is in the dict. Thread-safe."""
|
||||
return key in self._parse(True)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
"""An iterator over the keys in the dict. Thread-safe."""
|
||||
return iter(self._parse(True))
|
||||
|
||||
|
||||
secrets_singleton: Final = Secrets(SECRETS_FILE_LOCS)
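# Sketch of the precedence rule encoded by SECRETS_FILE_LOCS and _parse()
# above: files are read in order and merged with dict.update(), so the
# project-level secrets.toml (listed last) wins on conflicting keys. The TOML
# snippets below are hypothetical:
import toml

_global_toml = 'api_key = "from-global"\nregion = "us-east-1"\n'
_project_toml = 'api_key = "from-project"\n'

_merged: dict = {}
for _raw in (_global_toml, _project_toml):  # same order as SECRETS_FILE_LOCS
    _merged.update(toml.loads(_raw))

assert _merged == {"api_key": "from-project", "region": "us-east-1"}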
|
||||
@@ -0,0 +1,381 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional, cast
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
|
||||
|
||||
class SessionClientDisconnectedError(Exception):
|
||||
"""Raised by operations on a disconnected SessionClient."""
|
||||
|
||||
|
||||
class SessionClient(Protocol):
|
||||
"""Interface for sending data to a session's client."""
|
||||
|
||||
@abstractmethod
|
||||
def write_forward_msg(self, msg: ForwardMsg) -> None:
|
||||
"""Deliver a ForwardMsg to the client.
|
||||
|
||||
If the SessionClient has been disconnected, it should raise a
|
||||
SessionClientDisconnectedError.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveSessionInfo:
|
||||
"""Type containing data related to an active session.
|
||||
|
||||
This type is nearly identical to SessionInfo. The difference is that when using it,
|
||||
we are guaranteed that SessionClient is not None.
|
||||
"""
|
||||
|
||||
client: SessionClient
|
||||
session: AppSession
|
||||
script_run_count: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionInfo:
|
||||
"""Type containing data related to an AppSession.
|
||||
|
||||
For each AppSession, the Runtime tracks that session's
|
||||
script_run_count. This is used to track the age of messages in
|
||||
the ForwardMsgCache.
|
||||
"""
|
||||
|
||||
client: Optional[SessionClient]
|
||||
session: AppSession
|
||||
script_run_count: int = 0
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return self.client is not None
|
||||
|
||||
def to_active(self) -> ActiveSessionInfo:
|
||||
assert self.is_active(), "A SessionInfo with no client cannot be active!"
|
||||
|
||||
# NOTE: The cast here (rather than copying this SessionInfo's fields into a new
|
||||
# ActiveSessionInfo) is important as the Runtime expects to be able to mutate
|
||||
# what's returned from get_active_session_info to increment script_run_count.
|
||||
return cast(ActiveSessionInfo, self)
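# Illustrative sketch of the active/inactive distinction above. A real
# AppSession is heavyweight, so a stand-in object is cast purely for
# demonstration:
_dummy_session = cast(AppSession, object())
_info = SessionInfo(client=None, session=_dummy_session)
assert not _info.is_active()  # no SessionClient attached -> not active
# _info.to_active() would fail its assertion here; the Runtime only promotes
# a SessionInfo once a client is connected.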
|
||||
|
||||
|
||||
class SessionStorageError(Exception):
|
||||
"""Exception class for errors raised by SessionStorage.
|
||||
|
||||
The original error that causes a SessionStorageError to be (re)raised will generally
|
||||
be an I/O error specific to the concrete SessionStorage implementation.
|
||||
"""
|
||||
|
||||
|
||||
class SessionStorage(Protocol):
|
||||
@abstractmethod
|
||||
def get(self, session_id: str) -> Optional[SessionInfo]:
|
||||
"""Return the SessionInfo corresponding to session_id, or None if one does not
|
||||
exist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The unique ID of the session being fetched.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SessionInfo or None
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while attempting to fetch the session. This will
|
||||
generally happen if there is an error with the underlying storage backend
|
||||
(e.g. if we lose our connection to it).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def save(self, session_info: SessionInfo) -> None:
|
||||
"""Save the given session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_info
|
||||
The SessionInfo being saved.
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while saving the given session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, session_id: str) -> None:
|
||||
"""Mark the session corresponding to session_id for deletion and stop tracking
|
||||
it.
|
||||
|
||||
Note that:
|
||||
* Calling delete on an ID corresponding to a nonexistent session is a no-op.
|
||||
* Calling delete on an ID should cause the given session to no longer be
|
||||
tracked by this SessionStorage, but exactly when and how the session's data
|
||||
is eventually cleaned up is a detail left up to the implementation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The unique ID of the session to delete.
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while attempting to delete the session.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list(self) -> List[SessionInfo]:
|
||||
"""List all sessions tracked by this SessionStorage.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[SessionInfo]
|
||||
|
||||
Raises
|
||||
------
|
||||
SessionStorageError
|
||||
Raised if an error occurs while attempting to list sessions.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SessionManager(Protocol):
|
||||
"""SessionManagers are responsible for encapsulating all session lifecycle behavior
|
||||
that the Streamlit Runtime may care about.
|
||||
|
||||
A SessionManager must define the following required methods:
|
||||
* __init__
|
||||
* connect_session
|
||||
* close_session
|
||||
* get_session_info
|
||||
* list_sessions
|
||||
|
||||
SessionManager implementations may also choose to define the notions of active and
|
||||
inactive sessions. The precise definitions of active/inactive are left to the
|
||||
concrete implementation. SessionManagers that wish to differentiate between active
|
||||
and inactive sessions should have the required methods listed above operate on *all*
|
||||
sessions. Additionally, they should define the following methods for working with
|
||||
active sessions:
|
||||
* disconnect_session
|
||||
* get_active_session_info
|
||||
* is_active_session
|
||||
* list_active_sessions
|
||||
|
||||
When the active-session-related methods are left undefined, their default
implementations delegate to the naturally corresponding required methods.
|
||||
|
||||
The Runtime, unless there's a good reason to do otherwise, should generally work
|
||||
with the active-session versions of a SessionManager's methods. There isn't currently
|
||||
a need for us to be able to operate on inactive sessions stored in SessionStorage
|
||||
outside of the SessionManager itself. However, it's highly likely that we'll
|
||||
eventually have to do so, which is why the abstractions allow for this now.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Threading: All SessionManager methods are *not* threadsafe -- they must be called
|
||||
from the runtime's eventloop thread.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
session_storage: SessionStorage,
|
||||
uploaded_file_manager: UploadedFileManager,
|
||||
message_enqueued_callback: Optional[Callable[[], None]],
|
||||
) -> None:
|
||||
"""Initialize a SessionManager with the given SessionStorage.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_storage
|
||||
The SessionStorage instance backing this SessionManager.
|
||||
|
||||
uploaded_file_manager
|
||||
Used to manage files uploaded by users via the Streamlit web client.
|
||||
|
||||
message_enqueued_callback
|
||||
A callback invoked after a message is enqueued to be sent to a web client.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def connect_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
script_data: ScriptData,
|
||||
user_info: Dict[str, Optional[str]],
|
||||
existing_session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Create a new session or connect to an existing one.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client
|
||||
A concrete SessionClient implementation for communicating with
|
||||
the session's client.
|
||||
script_data
|
||||
Contains parameters related to running a script.
|
||||
user_info
|
||||
A dict that contains information about the session's user. For now,
|
||||
it only (optionally) contains the user's email address.
|
||||
|
||||
{
|
||||
"email": "example@example.com"
|
||||
}
|
||||
existing_session_id
|
||||
The ID of an existing session to reconnect to. If one is not provided, a new
|
||||
session is created. Note that whether a SessionManager supports reconnection
|
||||
to an existing session is left up to the concrete SessionManager
|
||||
implementation. Those that do not support reconnection should simply ignore
|
||||
this argument.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
The session's unique string ID.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def close_session(self, session_id: str) -> None:
|
||||
"""Close and completely delete the session with the given id.
|
||||
|
||||
This function may be called multiple times for the same session,
|
||||
which is not an error. (Subsequent calls just no-op.)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
|
||||
"""Return the SessionInfo for the given id, or None if no such session
|
||||
exists.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
|
||||
Returns
|
||||
-------
|
||||
SessionInfo or None
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def list_sessions(self) -> List[SessionInfo]:
|
||||
"""Return the SessionInfo for all sessions managed by this SessionManager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[SessionInfo]
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def num_sessions(self) -> int:
|
||||
"""Return the number of sessions tracked by this SessionManager.
|
||||
|
||||
Subclasses of SessionManager shouldn't provide their own implementation of this
|
||||
method without a *very* good reason.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
return len(self.list_sessions())
|
||||
|
||||
# NOTE: The following methods only need to be overwritten when a concrete
|
||||
# SessionManager implementation has a notion of active vs inactive sessions.
|
||||
# If left unimplemented in a subclass, the default implementations of these methods
|
||||
# call corresponding SessionManager methods in a natural way.
|
||||
|
||||
def disconnect_session(self, session_id: str) -> None:
|
||||
"""Disconnect the given session.
|
||||
|
||||
This method should be idempotent.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The session's unique ID.
|
||||
"""
|
||||
self.close_session(session_id)
|
||||
|
||||
def get_active_session_info(self, session_id: str) -> Optional[ActiveSessionInfo]:
|
||||
"""Return the ActiveSessionInfo for the given id, or None if either no such
|
||||
session exists or the session is not active.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The active session's unique ID.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ActiveSessionInfo or None
|
||||
"""
|
||||
session = self.get_session_info(session_id)
|
||||
if session is None or not session.is_active():
|
||||
return None
|
||||
return session.to_active()
|
||||
|
||||
def is_active_session(self, session_id: str) -> bool:
|
||||
"""Return True if the given session exists and is active, False otherwise.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
"""
|
||||
return self.get_active_session_info(session_id) is not None
|
||||
|
||||
def list_active_sessions(self) -> List[ActiveSessionInfo]:
|
||||
"""Return the session info for all active sessions tracked by this SessionManager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[ActiveSessionInfo]
|
||||
"""
|
||||
return [s.to_active() for s in self.list_sessions()]
|
||||
|
||||
def num_active_sessions(self) -> int:
|
||||
"""Return the number of active sessions tracked by this SessionManager.
|
||||
|
||||
Subclasses of SessionManager shouldn't provide their own implementation of this
|
||||
method without a *very* good reason.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
"""
|
||||
return len(self.list_active_sessions())
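# Minimal, hypothetical SessionManager sketch: it implements only the required
# methods, so the default active-session methods above delegate to them
# (e.g. list_active_sessions() simply returns every tracked session). Session
# construction is stubbed out; this is not a production implementation.
import uuid


class SimpleSessionManager(SessionManager):
    def __init__(
        self,
        session_storage: SessionStorage,
        uploaded_file_manager: UploadedFileManager,
        message_enqueued_callback: Optional[Callable[[], None]],
    ) -> None:
        self._session_info_by_id: Dict[str, SessionInfo] = {}

    def connect_session(
        self,
        client: SessionClient,
        script_data: ScriptData,
        user_info: Dict[str, Optional[str]],
        existing_session_id: Optional[str] = None,
    ) -> str:
        session_id = str(uuid.uuid4())
        # A real implementation would construct an AppSession here and store
        # SessionInfo(client=client, session=...) under session_id.
        return session_id

    def close_session(self, session_id: str) -> None:
        self._session_info_by_id.pop(session_id, None)

    def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
        return self._session_info_by_id.get(session_id)

    def list_sessions(self) -> List[SessionInfo]:
        return list(self._session_info_by_id.values())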
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from streamlit.runtime.state.common import WidgetArgs as WidgetArgs
|
||||
from streamlit.runtime.state.common import WidgetCallback as WidgetCallback
|
||||
from streamlit.runtime.state.common import WidgetKwargs as WidgetKwargs
|
||||
|
||||
# Explicitly re-export public symbols
|
||||
from streamlit.runtime.state.safe_session_state import (
|
||||
SafeSessionState as SafeSessionState,
|
||||
)
|
||||
from streamlit.runtime.state.session_state import (
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY as SCRIPT_RUN_WITHOUT_ERRORS_KEY,
|
||||
)
|
||||
from streamlit.runtime.state.session_state import SessionState as SessionState
|
||||
from streamlit.runtime.state.session_state import (
|
||||
SessionStateStatProvider as SessionStateStatProvider,
|
||||
)
|
||||
from streamlit.runtime.state.session_state_proxy import (
|
||||
SessionStateProxy as SessionStateProxy,
|
||||
)
|
||||
from streamlit.runtime.state.session_state_proxy import (
|
||||
get_session_state as get_session_state,
|
||||
)
|
||||
from streamlit.runtime.state.widgets import NoValue as NoValue
|
||||
from streamlit.runtime.state.widgets import (
|
||||
coalesce_widget_states as coalesce_widget_states,
|
||||
)
|
||||
from streamlit.runtime.state.widgets import register_widget as register_widget
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Functions and data structures shared by session_state.py and widgets.py"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar, Union
|
||||
|
||||
from typing_extensions import Final, TypeAlias
|
||||
|
||||
from streamlit.errors import StreamlitAPIException
|
||||
from streamlit.proto.Arrow_pb2 import Arrow
|
||||
from streamlit.proto.Button_pb2 import Button
|
||||
from streamlit.proto.CameraInput_pb2 import CameraInput
|
||||
from streamlit.proto.Checkbox_pb2 import Checkbox
|
||||
from streamlit.proto.ColorPicker_pb2 import ColorPicker
|
||||
from streamlit.proto.Components_pb2 import ComponentInstance
|
||||
from streamlit.proto.DateInput_pb2 import DateInput
|
||||
from streamlit.proto.DownloadButton_pb2 import DownloadButton
|
||||
from streamlit.proto.FileUploader_pb2 import FileUploader
|
||||
from streamlit.proto.MultiSelect_pb2 import MultiSelect
|
||||
from streamlit.proto.NumberInput_pb2 import NumberInput
|
||||
from streamlit.proto.Radio_pb2 import Radio
|
||||
from streamlit.proto.Selectbox_pb2 import Selectbox
|
||||
from streamlit.proto.Slider_pb2 import Slider
|
||||
from streamlit.proto.TextArea_pb2 import TextArea
|
||||
from streamlit.proto.TextInput_pb2 import TextInput
|
||||
from streamlit.proto.TimeInput_pb2 import TimeInput
|
||||
from streamlit.type_util import ValueFieldName
|
||||
|
||||
# Protobuf types for all widgets.
|
||||
WidgetProto: TypeAlias = Union[
|
||||
Arrow,
|
||||
Button,
|
||||
CameraInput,
|
||||
Checkbox,
|
||||
ColorPicker,
|
||||
ComponentInstance,
|
||||
DateInput,
|
||||
DownloadButton,
|
||||
FileUploader,
|
||||
MultiSelect,
|
||||
NumberInput,
|
||||
Radio,
|
||||
Selectbox,
|
||||
Slider,
|
||||
TextArea,
|
||||
TextInput,
|
||||
TimeInput,
|
||||
]
|
||||
|
||||
GENERATED_WIDGET_ID_PREFIX: Final = "$$GENERATED_WIDGET_ID"
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
T_co = TypeVar("T_co", covariant=True)
|
||||
|
||||
|
||||
WidgetArgs: TypeAlias = Tuple[Any, ...]
|
||||
WidgetKwargs: TypeAlias = Dict[str, Any]
|
||||
WidgetCallback: TypeAlias = Callable[..., None]
|
||||
|
||||
# A deserializer receives the value from whatever field is set on the
|
||||
# WidgetState proto, and returns a regular python value. A serializer
|
||||
# receives a regular python value, and returns something suitable for
|
||||
# a value field on WidgetState proto. They should be inverses.
|
||||
WidgetDeserializer: TypeAlias = Callable[[Any, str], T]
|
||||
WidgetSerializer: TypeAlias = Callable[[T], Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WidgetMetadata(Generic[T]):
|
||||
"""Metadata associated with a single widget. Immutable."""
|
||||
|
||||
id: str
|
||||
deserializer: WidgetDeserializer[T] = field(repr=False)
|
||||
serializer: WidgetSerializer[T] = field(repr=False)
|
||||
value_type: ValueFieldName
|
||||
|
||||
# An optional user-code callback invoked when the widget's value changes.
|
||||
# Widget callbacks are called at the start of a script run, before the
|
||||
# body of the script is executed.
|
||||
callback: WidgetCallback | None = None
|
||||
callback_args: WidgetArgs | None = None
|
||||
callback_kwargs: WidgetKwargs | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterWidgetResult(Generic[T_co]):
|
||||
"""Result returned by the `register_widget` family of functions/methods.
|
||||
|
||||
Should be usable by widget code to determine what value to return, and
|
||||
whether to update the UI.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : T_co
|
||||
The widget's current value, or, in cases where the true widget value
|
||||
could not be determined, an appropriate fallback value.
|
||||
|
||||
This value should be returned by the widget call.
|
||||
value_changed : bool
|
||||
True if the widget's value is different from the value most recently
|
||||
returned from the frontend.
|
||||
|
||||
Implies an update to the frontend is needed.
|
||||
"""
|
||||
|
||||
value: T_co
|
||||
value_changed: bool
|
||||
|
||||
@classmethod
|
||||
def failure(
|
||||
cls, deserializer: WidgetDeserializer[T_co]
|
||||
) -> "RegisterWidgetResult[T_co]":
|
||||
"""The canonical way to construct a RegisterWidgetResult in cases
|
||||
where the true widget value could not be determined.
|
||||
"""
|
||||
return cls(value=deserializer(None, ""), value_changed=False)
|
||||
|
||||
|
||||
def compute_widget_id(
|
||||
element_type: str, element_proto: WidgetProto, user_key: Optional[str] = None
|
||||
) -> str:
|
||||
"""Compute the widget id for the given widget. This id is stable: a given
|
||||
set of inputs to this function will always produce the same widget id output.
|
||||
|
||||
The widget id includes the user_key so widgets with identical arguments can
|
||||
use it to be distinct.
|
||||
|
||||
The widget id includes an easily identified prefix, and the user_key as a
|
||||
suffix, to make it easy to identify it and know if a key maps to it.
|
||||
|
||||
Does not mutate the element_proto object.
|
||||
"""
|
||||
h = hashlib.new("md5")
|
||||
h.update(element_type.encode("utf-8"))
|
||||
h.update(element_proto.SerializeToString())
|
||||
return f"{GENERATED_WIDGET_ID_PREFIX}-{h.hexdigest()}-{user_key}"
|
||||
|
||||
|
||||
def user_key_from_widget_id(widget_id: str) -> Optional[str]:
|
||||
"""Return the user key portion of a widget id, or None if the id does not
|
||||
have a user key.
|
||||
|
||||
TODO This will incorrectly indicate no user key if the user actually provides
|
||||
"None" as a key, but we can't avoid this kind of problem while storing the
|
||||
string representation of the no-user-key sentinel as part of the widget id.
|
||||
"""
|
||||
user_key = widget_id.split("-", maxsplit=2)[-1]
|
||||
user_key = None if user_key == "None" else user_key
|
||||
return user_key
|
||||
|
||||
|
||||
def is_widget_id(key: str) -> bool:
|
||||
"""True if the given session_state key has the structure of a widget ID."""
|
||||
return key.startswith(GENERATED_WIDGET_ID_PREFIX)
|
||||
|
||||
|
||||
def is_keyed_widget_id(key: str) -> bool:
|
||||
"""True if the given session_state key has the structure of a widget ID with a user_key."""
|
||||
return is_widget_id(key) and not key.endswith("-None")
|
||||
|
||||
|
||||
def require_valid_user_key(key: str) -> None:
|
||||
"""Raise an Exception if the given user_key is invalid."""
|
||||
if is_widget_id(key):
|
||||
raise StreamlitAPIException(
|
||||
f"Keys beginning with {GENERATED_WIDGET_ID_PREFIX} are reserved."
|
||||
)
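# Illustrative round-trip of the widget-id helpers above (the Checkbox proto
# here is just a convenient, empty WidgetProto instance):
_proto = Checkbox()
_keyed_id = compute_widget_id("checkbox", _proto, user_key="agree")
_unkeyed_id = compute_widget_id("checkbox", _proto)

assert _keyed_id.startswith(GENERATED_WIDGET_ID_PREFIX)
assert user_key_from_widget_id(_keyed_id) == "agree"
assert is_keyed_widget_id(_keyed_id)
assert not is_keyed_widget_id(_unkeyed_id)  # its id ends with the literal "-None"
assert user_key_from_widget_id(_unkeyed_id) is None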
|
||||
@@ -0,0 +1,134 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
|
||||
from streamlit.runtime.state.common import RegisterWidgetResult, T, WidgetMetadata
|
||||
from streamlit.runtime.state.session_state import SessionState
|
||||
|
||||
|
||||
class SafeSessionState:
|
||||
"""Thread-safe wrapper around SessionState.
|
||||
|
||||
When AppSession gets a re-run request, it can interrupt its existing
|
||||
ScriptRunner and spin up a new ScriptRunner to handle the request.
|
||||
When this happens, the existing ScriptRunner will continue executing
|
||||
its script until it reaches a yield point - but during this time, it
|
||||
must not mutate its SessionState. To prevent that, an interrupted
ScriptRunner "disconnects" its wrapper, after which all further
SessionState interactions become no-ops.
|
||||
"""
|
||||
|
||||
def __init__(self, state: SessionState):
|
||||
self._state = state
|
||||
# TODO: we'd prefer this be a threading.Lock instead of RLock -
|
||||
# but `call_callbacks` first needs to be rewritten.
|
||||
self._lock = threading.RLock()
|
||||
self._disconnected = False
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect the wrapper from its underlying SessionState.
|
||||
ScriptRunner calls this when it gets a stop request. After this
|
||||
function is called, all future SessionState interactions are no-ops.
|
||||
"""
|
||||
with self._lock:
|
||||
self._disconnected = True
|
||||
|
||||
def register_widget(
|
||||
self, metadata: WidgetMetadata[T], user_key: Optional[str]
|
||||
) -> RegisterWidgetResult[T]:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return RegisterWidgetResult.failure(metadata.deserializer)
|
||||
|
||||
return self._state.register_widget(metadata, user_key)
|
||||
|
||||
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return
|
||||
|
||||
# TODO: rewrite this to copy the callbacks list into a local
|
||||
# variable so that we don't need to hold our lock for the
|
||||
# duration. (This will also allow us to downgrade our RLock
|
||||
# to a Lock.)
|
||||
self._state.on_script_will_rerun(latest_widget_states)
|
||||
|
||||
def on_script_finished(self, widget_ids_this_run: Set[str]) -> None:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return
|
||||
|
||||
self._state.on_script_finished(widget_ids_this_run)
|
||||
|
||||
def maybe_check_serializable(self) -> None:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return
|
||||
|
||||
self._state.maybe_check_serializable()
|
||||
|
||||
def get_widget_states(self) -> List[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return []
|
||||
|
||||
return self._state.get_widget_states()
|
||||
|
||||
def is_new_state_value(self, user_key: str) -> bool:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return False
|
||||
|
||||
return self._state.is_new_state_value(user_key)
|
||||
|
||||
@property
|
||||
def filtered_state(self) -> Dict[str, Any]:
|
||||
"""The combined session and widget state, excluding keyless widgets."""
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return {}
|
||||
|
||||
return self._state.filtered_state
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._state[key]
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return
|
||||
|
||||
self._state[key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
raise KeyError(key)
|
||||
|
||||
del self._state[key]
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
with self._lock:
|
||||
if self._disconnected:
|
||||
return False
|
||||
|
||||
return key in self._state
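# Illustrative sketch of the disconnect semantics above: after disconnect(),
# writes are silently dropped and reads fall back to harmless defaults, so an
# interrupted ScriptRunner can keep executing without mutating real state.
_safe = SafeSessionState(SessionState())
_safe["color"] = "blue"
assert _safe["color"] == "blue"

_safe.disconnect()
_safe["color"] = "red"              # silently ignored after disconnect
assert "color" not in _safe         # __contains__ reports False once disconnected
assert _safe.filtered_state == {}   # reads return empty defaults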
|
||||
@@ -0,0 +1,648 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pickle
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Iterator,
|
||||
KeysView,
|
||||
List,
|
||||
MutableMapping,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from pympler.asizeof import asizeof
|
||||
from typing_extensions import Final, TypeAlias
|
||||
|
||||
import streamlit as st
|
||||
from streamlit import config
|
||||
from streamlit.errors import StreamlitAPIException, UnserializableSessionStateError
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState as WidgetStateProto
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetStates as WidgetStatesProto
|
||||
from streamlit.runtime.state.common import (
|
||||
RegisterWidgetResult,
|
||||
T,
|
||||
WidgetMetadata,
|
||||
is_keyed_widget_id,
|
||||
is_widget_id,
|
||||
)
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
from streamlit.type_util import ValueFieldName, is_array_value_field_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.session_manager import SessionManager
|
||||
|
||||
|
||||
STREAMLIT_INTERNAL_KEY_PREFIX: Final = "$$STREAMLIT_INTERNAL_KEY"
|
||||
SCRIPT_RUN_WITHOUT_ERRORS_KEY: Final = (
|
||||
f"{STREAMLIT_INTERNAL_KEY_PREFIX}_SCRIPT_RUN_WITHOUT_ERRORS"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Serialized:
|
||||
"""A widget value that's serialized to a protobuf. Immutable."""
|
||||
|
||||
value: WidgetStateProto
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Value:
|
||||
"""A widget value that's not serialized. Immutable."""
|
||||
|
||||
value: Any
|
||||
|
||||
|
||||
WState: TypeAlias = Union[Value, Serialized]
|
||||
|
||||
|
||||
@dataclass
|
||||
class WStates(MutableMapping[str, Any]):
|
||||
"""A mapping of widget IDs to values. Widget values can be stored in
|
||||
serialized or deserialized form, but when values are retrieved from the
|
||||
mapping, they'll always be deserialized.
|
||||
"""
|
||||
|
||||
states: dict[str, WState] = field(default_factory=dict)
|
||||
widget_metadata: dict[str, WidgetMetadata[Any]] = field(default_factory=dict)
|
||||
|
||||
def __getitem__(self, k: str) -> Any:
|
||||
"""Return the value of the widget with the given key.
|
||||
If the widget's value is currently stored in serialized form, it
|
||||
will be deserialized first.
|
||||
"""
|
||||
wstate = self.states.get(k)
|
||||
if wstate is None:
|
||||
raise KeyError(k)
|
||||
|
||||
if isinstance(wstate, Value):
|
||||
# The widget's value is already deserialized - return it directly.
|
||||
return wstate.value
|
||||
|
||||
# The widget's value is serialized. We deserialize it, and return
|
||||
# the deserialized value.
|
||||
|
||||
metadata = self.widget_metadata.get(k)
|
||||
if metadata is None:
|
||||
# No deserializer, which should only happen if this state came
# from a reconnecting browser and the script is trying to access
# it. Pretend it doesn't exist.
|
||||
raise KeyError(k)
|
||||
value_field_name = cast(
|
||||
ValueFieldName,
|
||||
wstate.value.WhichOneof("value"),
|
||||
)
|
||||
value = wstate.value.__getattribute__(value_field_name)
|
||||
|
||||
if is_array_value_field_name(value_field_name):
|
||||
# Array types are messages with data in a `data` field
|
||||
value = value.data
|
||||
elif value_field_name == "json_value":
|
||||
value = json.loads(value)
|
||||
|
||||
deserialized = metadata.deserializer(value, metadata.id)
|
||||
|
||||
# Update metadata to reflect information from WidgetState proto
|
||||
self.set_widget_metadata(
|
||||
replace(
|
||||
metadata,
|
||||
value_type=value_field_name,
|
||||
)
|
||||
)
|
||||
|
||||
self.states[k] = Value(deserialized)
|
||||
return deserialized
|
||||
|
||||
def __setitem__(self, k: str, v: WState) -> None:
|
||||
self.states[k] = v
|
||||
|
||||
def __delitem__(self, k: str) -> None:
|
||||
del self.states[k]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.states)
|
||||
|
||||
def __iter__(self):
|
||||
# For this and many other methods, we can't simply delegate to the
|
||||
# states field, because we need to invoke `__getitem__` for any
|
||||
# values, to handle deserialization and unwrapping of values.
|
||||
for key in self.states:
|
||||
yield key
|
||||
|
||||
def keys(self) -> KeysView[str]:
|
||||
return KeysView(self.states)
|
||||
|
||||
def items(self) -> set[tuple[str, Any]]: # type: ignore[override]
|
||||
return {(k, self[k]) for k in self}
|
||||
|
||||
def values(self) -> set[Any]: # type: ignore[override]
|
||||
return {self[wid] for wid in self}
|
||||
|
||||
def update(self, other: "WStates") -> None: # type: ignore[override]
|
||||
"""Copy all widget values and metadata from 'other' into this mapping,
|
||||
overwriting any data in this mapping that's also present in 'other'.
|
||||
"""
|
||||
self.states.update(other.states)
|
||||
self.widget_metadata.update(other.widget_metadata)
|
||||
|
||||
def set_widget_from_proto(self, widget_state: WidgetStateProto) -> None:
|
||||
"""Set a widget's serialized value, overwriting any existing value it has."""
|
||||
self[widget_state.id] = Serialized(widget_state)
|
||||
|
||||
def set_from_value(self, k: str, v: Any) -> None:
|
||||
"""Set a widget's deserialized value, overwriting any existing value it has."""
|
||||
self[k] = Value(v)
|
||||
|
||||
def set_widget_metadata(self, widget_meta: WidgetMetadata[Any]) -> None:
|
||||
"""Set a widget's metadata, overwriting any existing metadata it has."""
|
||||
self.widget_metadata[widget_meta.id] = widget_meta
|
||||
|
||||
def remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
|
||||
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
|
||||
self.states = {k: v for k, v in self.states.items() if k in active_widget_ids}
|
||||
|
||||
def get_serialized(self, k: str) -> WidgetStateProto | None:
|
||||
"""Get the serialized value of the widget with the given id.
|
||||
|
||||
If the widget doesn't exist, return None. If the widget exists but
|
||||
is not in serialized form, it will be serialized first.
|
||||
"""
|
||||
|
||||
item = self.states.get(k)
|
||||
if item is None:
|
||||
# No such widget: return None.
|
||||
return None
|
||||
|
||||
if isinstance(item, Serialized):
|
||||
# Widget value is serialized: return it directly.
|
||||
return item.value
|
||||
|
||||
# Widget value is not serialized: serialize it first!
|
||||
metadata = self.widget_metadata.get(k)
|
||||
if metadata is None:
|
||||
# We're missing the widget's metadata. (Can this happen?)
|
||||
return None
|
||||
|
||||
widget = WidgetStateProto()
|
||||
widget.id = k
|
||||
|
||||
field = metadata.value_type
|
||||
serialized = metadata.serializer(item.value)
|
||||
if is_array_value_field_name(field):
|
||||
arr = getattr(widget, field)
|
||||
arr.data.extend(serialized)
|
||||
elif field == "json_value":
|
||||
setattr(widget, field, json.dumps(serialized))
|
||||
elif field == "file_uploader_state_value":
|
||||
widget.file_uploader_state_value.CopyFrom(serialized)
|
||||
else:
|
||||
setattr(widget, field, serialized)
|
||||
|
||||
return widget
|
||||
|
||||
def as_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
states = [
|
||||
self.get_serialized(widget_id)
|
||||
for widget_id in self.states.keys()
|
||||
if self.get_serialized(widget_id)
|
||||
]
|
||||
states = cast(List[WidgetStateProto], states)
|
||||
return states
|
||||
|
||||
def call_callback(self, widget_id: str) -> None:
|
||||
"""Call the given widget's callback and return the callback's
|
||||
return value. If the widget has no callback, return None.
|
||||
|
||||
If the widget doesn't exist, raise an Exception.
|
||||
"""
|
||||
metadata = self.widget_metadata.get(widget_id)
|
||||
assert metadata is not None
|
||||
callback = metadata.callback
|
||||
if callback is None:
|
||||
return
|
||||
|
||||
args = metadata.callback_args or ()
|
||||
kwargs = metadata.callback_kwargs or {}
|
||||
callback(*args, **kwargs)
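# Illustrative sketch of the lazy deserialization described above: a value
# stored as a Serialized proto is converted by the registered deserializer on
# first access and cached back as a Value. The id, JSON payload, and trivial
# (de)serializers below are made up for demonstration:
_proto = WidgetStateProto()
_proto.id = "demo-widget"
_proto.json_value = json.dumps({"clicks": 3})

_wstates = WStates()
_wstates.set_widget_metadata(
    WidgetMetadata(
        id="demo-widget",
        deserializer=lambda value, widget_id: value["clicks"],
        serializer=lambda value: {"clicks": value},
        value_type="json_value",
    )
)
_wstates.set_widget_from_proto(_proto)

assert isinstance(_wstates.states["demo-widget"], Serialized)
assert _wstates["demo-widget"] == 3                       # deserialized on access
assert isinstance(_wstates.states["demo-widget"], Value)  # and cached as a Value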
|
||||
|
||||
|
||||
def _missing_key_error_message(key: str) -> str:
|
||||
return (
|
||||
f'st.session_state has no key "{key}". Did you forget to initialize it? '
|
||||
f"More info: https://docs.streamlit.io/library/advanced-features/session-state#initialization"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionState:
|
||||
"""SessionState allows users to store values that persist between app
|
||||
reruns.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> if "num_script_runs" not in st.session_state:
|
||||
... st.session_state.num_script_runs = 0
|
||||
>>> st.session_state.num_script_runs += 1
|
||||
>>> st.write(st.session_state.num_script_runs) # writes 1
|
||||
|
||||
The next time your script runs, the value of
|
||||
st.session_state.num_script_runs will be preserved.
|
||||
>>> st.session_state.num_script_runs += 1
|
||||
>>> st.write(st.session_state.num_script_runs) # writes 2
|
||||
"""
|
||||
|
||||
# All the values from previous script runs, squished together to save memory
|
||||
_old_state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Values set in session state during the current script run, possibly for
|
||||
# setting a widget's value. Keyed by a user provided string.
|
||||
_new_session_state: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Widget values from the frontend; usually a change to one of these prompted the script rerun.
|
||||
_new_widget_state: WStates = field(default_factory=WStates)
|
||||
|
||||
# Keys used for widgets will be eagerly converted to the matching widget id
|
||||
_key_id_mapping: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# is it possible for a value to get through this without being deserialized?
|
||||
def _compact_state(self) -> None:
|
||||
"""Copy all current session_state and widget_state values into our
|
||||
_old_state dict, and then clear our current session_state and
|
||||
widget_state.
|
||||
"""
|
||||
for key_or_wid in self:
|
||||
self._old_state[key_or_wid] = self[key_or_wid]
|
||||
self._new_session_state.clear()
|
||||
self._new_widget_state.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Reset self completely, clearing all current and old values."""
|
||||
self._old_state.clear()
|
||||
self._new_session_state.clear()
|
||||
self._new_widget_state.clear()
|
||||
self._key_id_mapping.clear()
|
||||
|
||||
@property
|
||||
def filtered_state(self) -> dict[str, Any]:
|
||||
"""The combined session and widget state, excluding keyless widgets."""
|
||||
|
||||
wid_key_map = self._reverse_key_wid_map
|
||||
|
||||
state: dict[str, Any] = {}
|
||||
|
||||
# We can't write `for k, v in self.items()` here because doing so will
# run into a `KeyError` if widget metadata has been cleared (which
# happens when the Streamlit server restarts or the cache is cleared)
# and we then receive a widget's state from a browser.
|
||||
for k in self._keys():
|
||||
if not is_widget_id(k) and not _is_internal_key(k):
|
||||
state[k] = self[k]
|
||||
elif is_keyed_widget_id(k):
|
||||
try:
|
||||
key = wid_key_map[k]
|
||||
state[key] = self[k]
|
||||
except KeyError:
|
||||
# The widget id no longer maps to a key; it is a not-yet-cleared
# value in old state for a reset widget.
|
||||
pass
|
||||
|
||||
return state
|
||||
|
||||
@property
|
||||
def _reverse_key_wid_map(self) -> dict[str, str]:
|
||||
"""Return a mapping of widget_id : widget_key."""
|
||||
wid_key_map = {v: k for k, v in self._key_id_mapping.items()}
|
||||
return wid_key_map
|
||||
|
||||
def _keys(self) -> set[str]:
|
||||
"""All keys active in Session State, with widget keys converted
|
||||
to widget ids when one is known. (This includes autogenerated keys
|
||||
for widgets that don't have user_keys defined, and which aren't
|
||||
exposed to user code.)
|
||||
"""
|
||||
old_keys = {self._get_widget_id(k) for k in self._old_state.keys()}
|
||||
new_widget_keys = set(self._new_widget_state.keys())
|
||||
new_session_state_keys = {
|
||||
self._get_widget_id(k) for k in self._new_session_state.keys()
|
||||
}
|
||||
return old_keys | new_widget_keys | new_session_state_keys
|
||||
|
||||
def is_new_state_value(self, user_key: str) -> bool:
|
||||
"""True if a value with the given key is in the current session state."""
|
||||
return user_key in self._new_session_state
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Return an iterator over the keys of the SessionState.
|
||||
This is a shortcut for `iter(self.keys())`
|
||||
"""
|
||||
return iter(self._keys())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of items in SessionState."""
|
||||
return len(self._keys())
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
wid_key_map = self._reverse_key_wid_map
|
||||
widget_id = self._get_widget_id(key)
|
||||
|
||||
if widget_id in wid_key_map and widget_id == key:
|
||||
# the "key" is a raw widget id, so get its associated user key for lookup
|
||||
key = wid_key_map[widget_id]
|
||||
try:
|
||||
return self._getitem(widget_id, key)
|
||||
except KeyError:
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
def _getitem(self, widget_id: str | None, user_key: str | None) -> Any:
|
||||
"""Get the value of an entry in Session State, using either the
|
||||
user-provided key or a widget id as appropriate for the internal dict
|
||||
being accessed.
|
||||
|
||||
At least one of the arguments must have a value.
|
||||
"""
|
||||
assert user_key is not None or widget_id is not None
|
||||
|
||||
if user_key is not None:
|
||||
try:
|
||||
return self._new_session_state[user_key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if widget_id is not None:
|
||||
try:
|
||||
return self._new_widget_state[widget_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# Typically, there won't be both a widget id and an associated state key in
|
||||
# old state at the same time, so the order we check is arbitrary.
|
||||
# The exception is if session state is set and then a later run has
|
||||
# a widget created, so the widget id entry should be newer.
|
||||
# The opposite case shouldn't happen, because setting the value of a widget
|
||||
# through session state will result in the next widget state reflecting that
|
||||
# value.
|
||||
if widget_id is not None:
|
||||
try:
|
||||
return self._old_state[widget_id]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
if user_key is not None:
|
||||
try:
|
||||
return self._old_state[user_key]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We'll never get here
|
||||
raise KeyError
|
||||
|
||||
def __setitem__(self, user_key: str, value: Any) -> None:
|
||||
"""Set the value of the session_state entry with the given user_key.
|
||||
|
||||
If the key corresponds to a widget or form that's been instantiated
|
||||
during the current script run, raise a StreamlitAPIException instead.
|
||||
"""
|
||||
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
if ctx is not None:
|
||||
widget_id = self._key_id_mapping.get(user_key, None)
|
||||
widget_ids = ctx.widget_ids_this_run
|
||||
form_ids = ctx.form_ids_this_run
|
||||
|
||||
if widget_id in widget_ids or user_key in form_ids:
|
||||
raise StreamlitAPIException(
|
||||
f"`st.session_state.{user_key}` cannot be modified after the widget"
|
||||
f" with key `{user_key}` is instantiated."
|
||||
)
|
||||
|
||||
self._new_session_state[user_key] = value
|
||||
|
||||
def __delitem__(self, key: str) -> None:
|
||||
widget_id = self._get_widget_id(key)
|
||||
|
||||
if not (key in self or widget_id in self):
|
||||
raise KeyError(_missing_key_error_message(key))
|
||||
|
||||
if key in self._new_session_state:
|
||||
del self._new_session_state[key]
|
||||
|
||||
if key in self._old_state:
|
||||
del self._old_state[key]
|
||||
|
||||
if key in self._key_id_mapping:
|
||||
del self._key_id_mapping[key]
|
||||
|
||||
if widget_id in self._new_widget_state:
|
||||
del self._new_widget_state[widget_id]
|
||||
|
||||
if widget_id in self._old_state:
|
||||
del self._old_state[widget_id]
|
||||
|
||||
def set_widgets_from_proto(self, widget_states: WidgetStatesProto) -> None:
|
||||
"""Set the value of all widgets represented in the given WidgetStatesProto."""
|
||||
for state in widget_states.widgets:
|
||||
self._new_widget_state.set_widget_from_proto(state)
|
||||
|
||||
def on_script_will_rerun(self, latest_widget_states: WidgetStatesProto) -> None:
|
||||
"""Called by ScriptRunner before its script re-runs.
|
||||
|
||||
Update widget data and call callbacks on widgets whose value changed
|
||||
between the previous and current script runs.
|
||||
"""
|
||||
# Update ourselves with the new widget_states. The old widget states,
|
||||
# used to skip callbacks if values haven't changed, are also preserved.
|
||||
self._compact_state()
|
||||
self.set_widgets_from_proto(latest_widget_states)
|
||||
self._call_callbacks()
|
||||
|
||||
def _call_callbacks(self) -> None:
|
||||
"""Call any callback associated with each widget whose value
|
||||
changed between the previous and current script runs.
|
||||
"""
|
||||
from streamlit.runtime.scriptrunner import RerunException
|
||||
|
||||
changed_widget_ids = [
|
||||
wid for wid in self._new_widget_state if self._widget_changed(wid)
|
||||
]
|
||||
for wid in changed_widget_ids:
|
||||
try:
|
||||
self._new_widget_state.call_callback(wid)
|
||||
except RerunException:
|
||||
st.warning(
|
||||
"Calling st.experimental_rerun() within a callback is a no-op."
|
||||
)
|
||||
|
||||
def _widget_changed(self, widget_id: str) -> bool:
|
||||
"""True if the given widget's value changed between the previous
|
||||
script run and the current script run.
|
||||
"""
|
||||
new_value = self._new_widget_state.get(widget_id)
|
||||
old_value = self._old_state.get(widget_id)
|
||||
changed: bool = new_value != old_value
|
||||
return changed
|
||||
|
||||
def on_script_finished(self, widget_ids_this_run: set[str]) -> None:
|
||||
"""Called by ScriptRunner after its script finishes running.
|
||||
Updates widgets to prepare for the next script run.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
widget_ids_this_run: set[str]
|
||||
The IDs of the widgets that were accessed during the script
|
||||
run. Any widget state whose ID does *not* appear in this set
|
||||
is considered "stale" and will be removed.
|
||||
"""
|
||||
self._reset_triggers()
|
||||
self._remove_stale_widgets(widget_ids_this_run)
|
||||
|
||||
def _reset_triggers(self) -> None:
|
||||
"""Set all trigger values in our state dictionary to False."""
|
||||
for state_id in self._new_widget_state:
|
||||
metadata = self._new_widget_state.widget_metadata.get(state_id)
|
||||
if metadata is not None and metadata.value_type == "trigger_value":
|
||||
self._new_widget_state[state_id] = Value(False)
|
||||
|
||||
for state_id in self._old_state:
|
||||
metadata = self._new_widget_state.widget_metadata.get(state_id)
|
||||
if metadata is not None and metadata.value_type == "trigger_value":
|
||||
self._old_state[state_id] = False
|
||||
|
||||
def _remove_stale_widgets(self, active_widget_ids: set[str]) -> None:
|
||||
"""Remove widget state for widgets whose ids aren't in `active_widget_ids`."""
|
||||
self._new_widget_state.remove_stale_widgets(active_widget_ids)
|
||||
|
||||
# Remove entries from _old_state corresponding to
|
||||
# widgets not in widget_ids.
|
||||
self._old_state = {
|
||||
k: v
|
||||
for k, v in self._old_state.items()
|
||||
if (k in active_widget_ids or not is_widget_id(k))
|
||||
}
|
||||
|
||||
def _set_widget_metadata(self, widget_metadata: WidgetMetadata[Any]) -> None:
|
||||
"""Set a widget's metadata."""
|
||||
widget_id = widget_metadata.id
|
||||
self._new_widget_state.widget_metadata[widget_id] = widget_metadata
|
||||
|
||||
def get_widget_states(self) -> list[WidgetStateProto]:
|
||||
"""Return a list of serialized widget values for each widget with a value."""
|
||||
return self._new_widget_state.as_widget_states()
|
||||
|
||||
def _get_widget_id(self, k: str) -> str:
|
||||
"""Turns a value that might be a widget id or a user provided key into
|
||||
an appropriate widget id.
|
||||
"""
|
||||
return self._key_id_mapping.get(k, k)
|
||||
|
||||
def _set_key_widget_mapping(self, widget_id: str, user_key: str) -> None:
|
||||
self._key_id_mapping[user_key] = widget_id
|
||||
|
||||
def register_widget(
|
||||
self, metadata: WidgetMetadata[T], user_key: str | None
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget with the SessionState.
|
||||
|
||||
Returns
|
||||
-------
|
||||
RegisterWidgetResult[T]
|
||||
Contains the widget's current value, and a bool that will be True
|
||||
if the frontend needs to be updated with the current value.
|
||||
"""
|
||||
widget_id = metadata.id
|
||||
|
||||
self._set_widget_metadata(metadata)
|
||||
if user_key is not None:
|
||||
# If the widget has a user_key, update its user_key:widget_id mapping
|
||||
self._set_key_widget_mapping(widget_id, user_key)
|
||||
|
||||
if widget_id not in self and (user_key is None or user_key not in self):
|
||||
# This is the first time the widget is registered, so we save its
|
||||
# value in widget state.
|
||||
deserializer = metadata.deserializer
|
||||
initial_widget_value = deepcopy(deserializer(None, metadata.id))
|
||||
self._new_widget_state.set_from_value(widget_id, initial_widget_value)
|
||||
|
||||
# Get the current value of the widget for use as its return value.
|
||||
# We return a copy, so that reference types can't be accidentally
|
||||
# mutated by user code.
|
||||
widget_value = cast(T, self[widget_id])
|
||||
widget_value = deepcopy(widget_value)
|
||||
|
||||
# widget_value_changed indicates to the caller that the widget's
|
||||
# current value is different from what is in the frontend.
|
||||
widget_value_changed = user_key is not None and self.is_new_state_value(
|
||||
user_key
|
||||
)
|
||||
|
||||
return RegisterWidgetResult(widget_value, widget_value_changed)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
try:
|
||||
self[key]
|
||||
except KeyError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
stat = CacheStat("st_session_state", "", asizeof(self))
|
||||
return [stat]
|
||||
|
||||
def _check_serializable(self) -> None:
|
||||
"""Verify that everything added to session state can be serialized.
|
||||
We use pickleability as the metric for serializability, and test for
|
||||
pickleability by just trying it.
|
||||
"""
|
||||
for k in self:
|
||||
try:
|
||||
pickle.dumps(self[k])
|
||||
except Exception as e:
|
||||
err_msg = f"""Cannot serialize the value (of type `{type(self[k])}`) of '{k}' in st.session_state.
|
||||
Streamlit has been configured to use [pickle](https://docs.python.org/3/library/pickle.html) to
|
||||
serialize session_state values. Please convert the value to a pickle-serializable type. To learn
|
||||
more about this behavior, see [our docs](https://docs.streamlit.io/knowledge-base/using-streamlit/serializable-session-state). """
|
||||
raise UnserializableSessionStateError(err_msg) from e
|
||||
|
||||
def maybe_check_serializable(self) -> None:
|
||||
"""Verify that session state can be serialized, if the relevant config
|
||||
option is set.
|
||||
|
||||
See `_check_serializable` for details."""
|
||||
if config.get_option("runner.enforceSerializableSessionState"):
|
||||
self._check_serializable()
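# Hedged sketch of the behavior guarded above. The config option name is taken
# from the check itself; the CLI flag form is an assumption about how Streamlit
# maps config options onto flags:
#
#   streamlit run app.py --runner.enforceSerializableSessionState=true
#
# An app.py along these lines would then raise UnserializableSessionStateError,
# because lambdas cannot be pickled:
import streamlit as st

st.session_state["transform"] = lambda x: x * 2  # unpicklable; fails _check_serializable()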
|
||||
|
||||
|
||||
def _is_internal_key(key: str) -> bool:
|
||||
return key.startswith(STREAMLIT_INTERNAL_KEY_PREFIX)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionStateStatProvider(CacheStatsProvider):
|
||||
_session_mgr: "SessionManager"
|
||||
|
||||
def get_stats(self) -> list[CacheStat]:
|
||||
stats: list[CacheStat] = []
|
||||
for session_info in self._session_mgr.list_active_sessions():
|
||||
session_state = session_info.session.session_state
|
||||
stats.extend(session_state.get_stats())
|
||||
return stats
|
||||
@@ -0,0 +1,142 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Iterator, MutableMapping
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
from streamlit import logger as _logger
|
||||
from streamlit import runtime
|
||||
from streamlit.runtime.metrics_util import gather_metrics
|
||||
from streamlit.runtime.state.common import require_valid_user_key
|
||||
from streamlit.runtime.state.safe_session_state import SafeSessionState
|
||||
from streamlit.runtime.state.session_state import SessionState
|
||||
from streamlit.type_util import Key
|
||||
|
||||
LOGGER: Final = _logger.get_logger(__name__)
|
||||
|
||||
|
||||
_state_use_warning_already_displayed: bool = False
|
||||
|
||||
|
||||
def get_session_state() -> SafeSessionState:
|
||||
"""Get the SessionState object for the current session.
|
||||
|
||||
Note that in streamlit scripts, this function should not be called
|
||||
directly. Instead, SessionState objects should be accessed via
|
||||
st.session_state.
|
||||
"""
|
||||
global _state_use_warning_already_displayed
|
||||
from streamlit.runtime.scriptrunner import get_script_run_ctx
|
||||
|
||||
ctx = get_script_run_ctx()
|
||||
|
||||
# If there is no script run context because the script is run bare, have
|
||||
# session state act as an always empty dictionary, and print a warning.
|
||||
if ctx is None:
|
||||
if not _state_use_warning_already_displayed:
|
||||
_state_use_warning_already_displayed = True
|
||||
if not runtime.exists():
|
||||
LOGGER.warning(
|
||||
"Session state does not function when running a script without `streamlit run`"
|
||||
)
|
||||
return SafeSessionState(SessionState())
|
||||
return ctx.session_state
|
||||
|
||||
|
||||
class SessionStateProxy(MutableMapping[Key, Any]):
|
||||
"""A stateless singleton that proxies `st.session_state` interactions
|
||||
to the current script thread's SessionState instance.
|
||||
|
||||
The proxy API differs slightly from SessionState: it does not allow
|
||||
callers to get, set, or iterate over "keyless" widgets (that is, widgets
|
||||
that were created without a user_key, and have autogenerated keys).
|
||||
"""
|
||||
|
||||
def __iter__(self) -> Iterator[Any]:
|
||||
"""Iterator over user state and keyed widget values."""
|
||||
# TODO: this is unsafe if fastReruns is true! Let's deprecate/remove.
|
||||
return iter(get_session_state().filtered_state)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Number of user state and keyed widget values in session_state."""
|
||||
return len(get_session_state().filtered_state)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of user state and keyed widget values."""
|
||||
return str(get_session_state().filtered_state)
|
||||
|
||||
def __getitem__(self, key: Key) -> Any:
|
||||
"""Return the state or widget value with the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
return get_session_state()[key]
|
||||
|
||||
@gather_metrics("session_state.set_item")
|
||||
def __setitem__(self, key: Key, value: Any) -> None:
|
||||
"""Set the value of the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
get_session_state()[key] = value
|
||||
|
||||
def __delitem__(self, key: Key) -> None:
|
||||
"""Delete the value with the given key.
|
||||
|
||||
Raises
|
||||
------
|
||||
StreamlitAPIException
|
||||
If the key is not a valid SessionState user key.
|
||||
"""
|
||||
key = str(key)
|
||||
require_valid_user_key(key)
|
||||
del get_session_state()[key]
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
@gather_metrics("session_state.set_attr")
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
self[key] = value
|
||||
|
||||
def __delattr__(self, key: str) -> None:
|
||||
try:
|
||||
del self[key]
|
||||
except KeyError:
|
||||
raise AttributeError(_missing_attr_error_message(key))
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Return a dict containing all session_state and keyed widget values."""
|
||||
return get_session_state().filtered_state
|
||||
|
||||
|
||||
def _missing_attr_error_message(attr_name: str) -> str:
|
||||
return (
|
||||
f'st.session_state has no attribute "{attr_name}". Did you forget to initialize it? '
|
||||
f"More info: https://docs.streamlit.io/library/advanced-features/session-state#initialization"
|
||||
)
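# Minimal usage sketch (assumes an app launched with `streamlit run app.py`, so
# a ScriptRunContext exists and `st.session_state` resolves through
# get_session_state() above; under a bare `python app.py` run the empty-state
# branch applies instead). Initializing a key before reading it also avoids the
# AttributeError built by _missing_attr_error_message():
import streamlit as st

if "counter" not in st.session_state:     # __contains__ on the proxy
    st.session_state["counter"] = 0       # item access -> SessionState.__setitem__
st.session_state.counter += 1             # attribute access routes through __getattr__/__setattr__
st.write("Reruns this session:", st.session_state.counter)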
|
||||
@@ -0,0 +1,281 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import textwrap
|
||||
from types import MappingProxyType
|
||||
from typing import TYPE_CHECKING, Dict, Mapping, Optional
|
||||
|
||||
from typing_extensions import Final, TypeAlias
|
||||
|
||||
from streamlit.errors import DuplicateWidgetID
|
||||
from streamlit.proto.WidgetStates_pb2 import WidgetState, WidgetStates
|
||||
from streamlit.runtime.state.common import (
|
||||
RegisterWidgetResult,
|
||||
T,
|
||||
WidgetArgs,
|
||||
WidgetCallback,
|
||||
WidgetDeserializer,
|
||||
WidgetKwargs,
|
||||
WidgetMetadata,
|
||||
WidgetProto,
|
||||
WidgetSerializer,
|
||||
compute_widget_id,
|
||||
user_key_from_widget_id,
|
||||
)
|
||||
from streamlit.type_util import ValueFieldName
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from streamlit.runtime.scriptrunner import ScriptRunContext
|
||||
|
||||
ElementType: TypeAlias = str
|
||||
|
||||
# NOTE: We use this table to start with a best-effort guess for the value_type
|
||||
# of each widget. Once we actually receive a proto for a widget from the
|
||||
# frontend, the guess is updated to be the correct type. Unfortunately, we're
|
||||
# not able to always rely on the proto as the type may be needed earlier.
|
||||
# Thankfully, in these cases (when value_type == "trigger_value"), the static
|
||||
# table here being slightly inaccurate should never pose a problem.
|
||||
ELEMENT_TYPE_TO_VALUE_TYPE: Final[
|
||||
Mapping[ElementType, ValueFieldName]
|
||||
] = MappingProxyType(
|
||||
{
|
||||
"button": "trigger_value",
|
||||
"download_button": "trigger_value",
|
||||
"checkbox": "bool_value",
|
||||
"camera_input": "file_uploader_state_value",
|
||||
"color_picker": "string_value",
|
||||
"date_input": "string_array_value",
|
||||
"file_uploader": "file_uploader_state_value",
|
||||
"multiselect": "int_array_value",
|
||||
"number_input": "double_value",
|
||||
"radio": "int_value",
|
||||
"selectbox": "int_value",
|
||||
"slider": "double_array_value",
|
||||
"text_area": "string_value",
|
||||
"text_input": "string_value",
|
||||
"time_input": "string_value",
|
||||
"component_instance": "json_value",
|
||||
"data_editor": "string_value",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class NoValue:
|
||||
"""Return this from DeltaGenerator.foo_widget() when you want the st.foo_widget()
|
||||
call to return None. This is needed because `DeltaGenerator._enqueue`
|
||||
replaces `None` with a `DeltaGenerator` (for use in non-widget elements).
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def register_widget(
|
||||
element_type: ElementType,
|
||||
element_proto: WidgetProto,
|
||||
deserializer: WidgetDeserializer[T],
|
||||
serializer: WidgetSerializer[T],
|
||||
ctx: Optional["ScriptRunContext"],
|
||||
user_key: Optional[str] = None,
|
||||
widget_func_name: Optional[str] = None,
|
||||
on_change_handler: Optional[WidgetCallback] = None,
|
||||
args: Optional[WidgetArgs] = None,
|
||||
kwargs: Optional[WidgetKwargs] = None,
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget with Streamlit, and return its current value.
|
||||
NOTE: This function should be called after the proto has been filled.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
element_type : ElementType
|
||||
The type of the element as stored in proto.
|
||||
element_proto : WidgetProto
|
||||
The proto of the specified type (e.g. Button/Multiselect/Slider proto)
|
||||
deserializer : WidgetDeserializer[T]
|
||||
Called to convert a widget's protobuf value to the value returned by
|
||||
its st.<widget_name> function.
|
||||
serializer : WidgetSerializer[T]
|
||||
Called to convert a widget's value to its protobuf representation.
|
||||
ctx : Optional[ScriptRunContext]
|
||||
Used to ensure uniqueness of widget IDs, and to look up widget values.
|
||||
user_key : Optional[str]
|
||||
Optional user-specified string to use as the widget ID.
|
||||
If this is None, we'll generate an ID by hashing the element.
|
||||
widget_func_name : Optional[str]
|
||||
The widget's DeltaGenerator function name, if it's different from
|
||||
its element_type. Custom components are a special case: they all have
|
||||
the element_type "component_instance", but are instantiated with
|
||||
dynamically-named functions.
|
||||
on_change_handler : Optional[WidgetCallback]
|
||||
An optional callback invoked when the widget's value changes.
|
||||
args : Optional[WidgetArgs]
|
||||
args to pass to on_change_handler when invoked
|
||||
kwargs : Optional[WidgetKwargs]
|
||||
kwargs to pass to on_change_handler when invoked
|
||||
|
||||
Returns
|
||||
-------
|
||||
register_widget_result : RegisterWidgetResult[T]
|
||||
Provides information on which value to return to the widget caller,
|
||||
and whether the UI needs updating.
|
||||
|
||||
- Unhappy path:
|
||||
- Our ScriptRunContext doesn't exist (meaning that we're running
|
||||
as a "bare script" outside streamlit).
|
||||
- We are disconnected from the SessionState instance.
|
||||
In both cases we'll return a fallback RegisterWidgetResult[T].
|
||||
- Happy path:
|
||||
- The widget has already been registered on a previous run but the
|
||||
user hasn't interacted with it on the client. The widget will have
|
||||
the default value it was first created with. We then return a
|
||||
RegisterWidgetResult[T], containing this value.
|
||||
- The widget has already been registered and the user *has*
|
||||
interacted with it. The widget will have that most recent
|
||||
user-specified value. We then return a RegisterWidgetResult[T],
|
||||
containing this value.
|
||||
|
||||
For both paths a widget return value is provided, allowing the widgets
|
||||
to be used in a non-streamlit setting.
|
||||
"""
|
||||
widget_id = compute_widget_id(element_type, element_proto, user_key)
|
||||
element_proto.id = widget_id
|
||||
|
||||
# Create the widget's updated metadata, and register it with session_state.
|
||||
metadata = WidgetMetadata(
|
||||
widget_id,
|
||||
deserializer,
|
||||
serializer,
|
||||
value_type=ELEMENT_TYPE_TO_VALUE_TYPE[element_type],
|
||||
callback=on_change_handler,
|
||||
callback_args=args,
|
||||
callback_kwargs=kwargs,
|
||||
)
|
||||
return register_widget_from_metadata(metadata, ctx, widget_func_name, element_type)
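# Hedged sketch of a call site (the proto construction and the (de)serializers
# below are placeholders; real callers live in streamlit/elements/*):
#
#     from streamlit.proto.Checkbox_pb2 import Checkbox as CheckboxProto
#     from streamlit.runtime.scriptrunner import get_script_run_ctx
#
#     proto = CheckboxProto()
#     proto.label = "Enable feature"
#     result = register_widget(
#         "checkbox",
#         proto,
#         deserializer=lambda ui_value, _widget_id: bool(ui_value),
#         serializer=bool,
#         ctx=get_script_run_ctx(),
#         user_key="enable_feature",
#     )
#     # result.value is what the element returns to the script;
#     # result.value_changed tells it whether to sync the value to the frontend.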
|
||||
|
||||
|
||||
def register_widget_from_metadata(
|
||||
metadata: WidgetMetadata[T],
|
||||
ctx: Optional["ScriptRunContext"],
|
||||
widget_func_name: Optional[str],
|
||||
element_type: ElementType,
|
||||
) -> RegisterWidgetResult[T]:
|
||||
"""Register a widget and return its value, using an already constructed
|
||||
`WidgetMetadata`.
|
||||
|
||||
This is split out from `register_widget` to allow caching code to replay
|
||||
widgets by saving and reusing the completed metadata.
|
||||
|
||||
See `register_widget` for details on what this returns.
|
||||
"""
|
||||
# Local import to avoid import cycle
|
||||
import streamlit.runtime.caching as caching
|
||||
|
||||
if ctx is None:
|
||||
# Early-out if we don't have a script run context (which probably means
|
||||
# we're running as a "bare" Python script, and not via `streamlit run`).
|
||||
return RegisterWidgetResult.failure(deserializer=metadata.deserializer)
|
||||
|
||||
widget_id = metadata.id
|
||||
user_key = user_key_from_widget_id(widget_id)
|
||||
|
||||
# Ensure another widget with the same user key hasn't already been registered.
|
||||
if user_key is not None:
|
||||
if user_key not in ctx.widget_user_keys_this_run:
|
||||
ctx.widget_user_keys_this_run.add(user_key)
|
||||
else:
|
||||
raise DuplicateWidgetID(
|
||||
_build_duplicate_widget_message(
|
||||
widget_func_name if widget_func_name is not None else element_type,
|
||||
user_key,
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure another widget with the same id hasn't already been registered.
|
||||
new_widget = widget_id not in ctx.widget_ids_this_run
|
||||
if new_widget:
|
||||
ctx.widget_ids_this_run.add(widget_id)
|
||||
else:
|
||||
raise DuplicateWidgetID(
|
||||
_build_duplicate_widget_message(
|
||||
widget_func_name if widget_func_name is not None else element_type,
|
||||
user_key,
|
||||
)
|
||||
)
|
||||
# Save the widget metadata for cached result replay
|
||||
caching.save_widget_metadata(metadata)
|
||||
return ctx.session_state.register_widget(metadata, user_key)
|
||||
|
||||
|
||||
def coalesce_widget_states(
|
||||
old_states: WidgetStates, new_states: WidgetStates
|
||||
) -> WidgetStates:
|
||||
"""Coalesce an older WidgetStates into a newer one, and return a new
|
||||
WidgetStates containing the result.
|
||||
|
||||
For most widget values, we just take the latest version.
|
||||
|
||||
However, any trigger_values (which are set by buttons) that are True in
|
||||
`old_states` will be set to True in the coalesced result, so that button
|
||||
presses don't go missing.
|
||||
"""
|
||||
states_by_id: Dict[str, WidgetState] = {
|
||||
wstate.id: wstate for wstate in new_states.widgets
|
||||
}
|
||||
|
||||
for old_state in old_states.widgets:
|
||||
if old_state.WhichOneof("value") == "trigger_value" and old_state.trigger_value:
|
||||
|
||||
# Ensure the corresponding new_state is also a trigger;
|
||||
# otherwise, a widget that was previously a button but no longer is
|
||||
# could get a bad value.
|
||||
new_trigger_val = states_by_id.get(old_state.id)
|
||||
if (
|
||||
new_trigger_val
|
||||
and new_trigger_val.WhichOneof("value") == "trigger_value"
|
||||
):
|
||||
states_by_id[old_state.id] = old_state
|
||||
|
||||
coalesced = WidgetStates()
|
||||
coalesced.widgets.extend(states_by_id.values())
|
||||
|
||||
return coalesced
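# Quick check of the trigger-preservation rule above (widget ids are
# illustrative; this assumes the module's own coalesce_widget_states is in
# scope):
from streamlit.proto.WidgetStates_pb2 import WidgetStates

old, new = WidgetStates(), WidgetStates()
pressed = old.widgets.add()
pressed.id, pressed.trigger_value = "button-1", True        # press from the older message
unpressed = new.widgets.add()
unpressed.id, unpressed.trigger_value = "button-1", False   # newer message, no press

merged = coalesce_widget_states(old, new)
assert merged.widgets[0].trigger_value  # the old press is carried into the result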
|
||||
|
||||
|
||||
def _build_duplicate_widget_message(
|
||||
widget_func_name: str, user_key: Optional[str] = None
|
||||
) -> str:
|
||||
if user_key is not None:
|
||||
message = textwrap.dedent(
|
||||
"""
|
||||
There are multiple widgets with the same `key='{user_key}'`.
|
||||
|
||||
To fix this, please make sure that the `key` argument is unique for each
|
||||
widget you create.
|
||||
"""
|
||||
)
|
||||
else:
|
||||
message = textwrap.dedent(
|
||||
"""
|
||||
There are multiple identical `st.{widget_type}` widgets with the
|
||||
same generated key.
|
||||
|
||||
When a widget is created, it's assigned an internal key based on
|
||||
its structure. Multiple widgets with an identical structure will
|
||||
result in the same internal key, which causes this error.
|
||||
|
||||
To fix this error, please pass a unique `key` argument to
|
||||
`st.{widget_type}`.
|
||||
"""
|
||||
)
|
||||
|
||||
return message.strip("\n").format(widget_type=widget_func_name, user_key=user_key)
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import List, NamedTuple
|
||||
|
||||
from typing_extensions import Protocol, runtime_checkable
|
||||
|
||||
from streamlit.proto.openmetrics_data_model_pb2 import Metric as MetricProto
|
||||
|
||||
|
||||
class CacheStat(NamedTuple):
|
||||
"""Describes a single cache entry.
|
||||
|
||||
Properties
|
||||
----------
|
||||
category_name : str
|
||||
A human-readable name for the cache "category" that the entry belongs
|
||||
to - e.g. "st.memo", "session_state", etc.
|
||||
cache_name : str
|
||||
A human-readable name for cache instance that the entry belongs to.
|
||||
For "st.memo" and other function decorator caches, this might be the
|
||||
name of the cached function. If the cache category doesn't have
|
||||
multiple separate cache instances, this can just be the empty string.
|
||||
byte_length : int
|
||||
The entry's memory footprint in bytes.
|
||||
"""
|
||||
|
||||
category_name: str
|
||||
cache_name: str
|
||||
byte_length: int
|
||||
|
||||
def to_metric_str(self) -> str:
|
||||
return 'cache_memory_bytes{cache_type="%s",cache="%s"} %s' % (
|
||||
self.category_name,
|
||||
self.cache_name,
|
||||
self.byte_length,
|
||||
)
|
||||
|
||||
def marshall_metric_proto(self, metric: MetricProto) -> None:
|
||||
"""Fill an OpenMetrics `Metric` protobuf object."""
|
||||
label = metric.labels.add()
|
||||
label.name = "cache_type"
|
||||
label.value = self.category_name
|
||||
|
||||
label = metric.labels.add()
|
||||
label.name = "cache"
|
||||
label.value = self.cache_name
|
||||
|
||||
metric_point = metric.metric_points.add()
|
||||
metric_point.gauge_value.int_value = self.byte_length
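# Example of the OpenMetrics text line produced by to_metric_str() (values are
# illustrative):
#
#     stat = CacheStat(category_name="st.memo", cache_name="load_data", byte_length=2048)
#     stat.to_metric_str()
#     # -> 'cache_memory_bytes{cache_type="st.memo",cache="load_data"} 2048'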
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class CacheStatsProvider(Protocol):
|
||||
@abstractmethod
|
||||
def get_stats(self) -> List[CacheStat]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StatsManager:
|
||||
def __init__(self):
|
||||
self._cache_stats_providers: List[CacheStatsProvider] = []
|
||||
|
||||
def register_provider(self, provider: CacheStatsProvider) -> None:
|
||||
"""Register a CacheStatsProvider with the manager.
|
||||
This function is not thread-safe. Call it immediately after
|
||||
creation.
|
||||
"""
|
||||
self._cache_stats_providers.append(provider)
|
||||
|
||||
def get_stats(self) -> List[CacheStat]:
|
||||
"""Return a list containing all stats from each registered provider."""
|
||||
all_stats: List[CacheStat] = []
|
||||
for provider in self._cache_stats_providers:
|
||||
all_stats.extend(provider.get_stats())
|
||||
return all_stats
|
||||
@@ -0,0 +1,334 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import io
|
||||
import threading
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
|
||||
from blinker import Signal
|
||||
|
||||
from streamlit import util
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
|
||||
|
||||
LOGGER = get_logger(__name__)
|
||||
|
||||
|
||||
class UploadedFileRec(NamedTuple):
|
||||
"""Metadata and raw bytes for an uploaded file. Immutable."""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
type: str
|
||||
data: bytes
|
||||
|
||||
|
||||
class UploadedFile(io.BytesIO):
|
||||
"""A mutable uploaded file.
|
||||
|
||||
This class extends BytesIO, which has copy-on-write semantics when
|
||||
initialized with `bytes`.
|
||||
"""
|
||||
|
||||
def __init__(self, record: UploadedFileRec):
|
||||
# BytesIO's copy-on-write semantics doesn't seem to be mentioned in
|
||||
# the Python docs - possibly because it's a CPython-only optimization
|
||||
# and not guaranteed to be in other Python runtimes. But it's detailed
|
||||
# here: https://hg.python.org/cpython/rev/79a5fbe2c78f
|
||||
super(UploadedFile, self).__init__(record.data)
|
||||
self.id = record.id
|
||||
self.name = record.name
|
||||
self.type = record.type
|
||||
self.size = len(record.data)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, UploadedFile):
|
||||
return NotImplemented
|
||||
return self.id == other.id
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
|
||||
class UploadedFileManager(CacheStatsProvider):
|
||||
"""Holds files uploaded by users of the running Streamlit app,
|
||||
and emits an event signal when a file is added.
|
||||
|
||||
This class can be used safely from multiple threads simultaneously.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# List of files for a given widget in a given session.
|
||||
self._files_by_id: Dict[Tuple[str, str], List[UploadedFileRec]] = {}
|
||||
|
||||
# A counter that generates unique file IDs. Each file ID is greater
|
||||
# than the previous ID, which means we can use IDs to compare files
|
||||
# by age.
|
||||
self._file_id_counter = 1
|
||||
self._file_id_lock = threading.Lock()
|
||||
|
||||
# Prevents concurrent access to the _files_by_id dict.
|
||||
# In remove_session_files(), we iterate over the dict's keys. It's
|
||||
# an error to mutate a dict while iterating; this lock prevents that.
|
||||
self._files_lock = threading.Lock()
|
||||
self.on_files_updated = Signal(
|
||||
doc="""Emitted when a file list is added to the manager or updated.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id : str
|
||||
The session_id for the session whose files were updated.
|
||||
"""
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return util.repr_(self)
|
||||
|
||||
def add_file(
|
||||
self,
|
||||
session_id: str,
|
||||
widget_id: str,
|
||||
file: UploadedFileRec,
|
||||
) -> UploadedFileRec:
|
||||
"""Add a file to the FileManager, and return a new UploadedFileRec
|
||||
with its ID assigned.
|
||||
|
||||
The "on_files_updated" Signal will be emitted.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that owns the file.
|
||||
widget_id
|
||||
The widget ID of the FileUploader that created the file.
|
||||
file
|
||||
The file to add.
|
||||
|
||||
Returns
|
||||
-------
|
||||
UploadedFileRec
|
||||
The added file, which has its unique ID assigned.
|
||||
"""
|
||||
files_by_widget = session_id, widget_id
|
||||
|
||||
# Assign the file a unique ID
|
||||
file_id = self._get_next_file_id()
|
||||
file = UploadedFileRec(
|
||||
id=file_id, name=file.name, type=file.type, data=file.data
|
||||
)
|
||||
|
||||
with self._files_lock:
|
||||
file_list = self._files_by_id.get(files_by_widget, None)
|
||||
if file_list is not None:
|
||||
file_list.append(file)
|
||||
else:
|
||||
self._files_by_id[files_by_widget] = [file]
|
||||
|
||||
self.on_files_updated.send(session_id)
|
||||
return file
|
||||
|
||||
def get_all_files(self, session_id: str, widget_id: str) -> List[UploadedFileRec]:
|
||||
"""Return all the files stored for the given widget.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that owns the files.
|
||||
widget_id
|
||||
The widget ID of the FileUploader that created the files.
|
||||
"""
|
||||
file_list_id = (session_id, widget_id)
|
||||
with self._files_lock:
|
||||
return self._files_by_id.get(file_list_id, []).copy()
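# Minimal usage sketch (session/widget ids are illustrative; the id on the
# returned record is what a fresh manager would assign):
mgr = UploadedFileManager()
rec = mgr.add_file(
    "session-1",
    "file_uploader-1",
    UploadedFileRec(id=0, name="notes.txt", type="text/plain", data=b"hello"),
)
assert rec.id == 1                                               # re-keyed by the manager
assert mgr.get_all_files("session-1", "file_uploader-1") == [rec]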
|
||||
|
||||
def get_files(
|
||||
self, session_id: str, widget_id: str, file_ids: List[int]
|
||||
) -> List[UploadedFileRec]:
|
||||
"""Return the files with the given widget_id and file_ids.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id
|
||||
The ID of the session that owns the files.
|
||||
widget_id
|
||||
The widget ID of the FileUploader that created the files.
|
||||
file_ids
|
||||
List of file IDs. Only files whose IDs are in this list will be
|
||||
returned.
|
||||
"""
|
||||
return [
|
||||
f for f in self.get_all_files(session_id, widget_id) if f.id in file_ids
|
||||
]
|
||||
|
||||
def remove_orphaned_files(
|
||||
self,
|
||||
session_id: str,
|
||||
widget_id: str,
|
||||
newest_file_id: int,
|
||||
active_file_ids: List[int],
|
||||
) -> None:
|
||||
"""Remove 'orphaned' files: files that have been uploaded and
|
||||
subsequently deleted, but haven't yet been removed from memory.
|
||||
|
||||
Because FileUploader can live inside forms, file deletion is made a
|
||||
bit tricky: a file deletion should only happen after the form is
|
||||
submitted.
|
||||
|
||||
FileUploader's widget value is an array of numbers that has two parts:
|
||||
- The first number is always 'this.state.newestServerFileId'.
|
||||
- The remaining 0 or more numbers are the file IDs of all the
|
||||
uploader's uploaded files.
|
||||
|
||||
When the server receives the widget value, it deletes "orphaned"
|
||||
uploaded files. An orphaned file is any file associated with a given
|
||||
FileUploader whose file ID is not in the active_file_ids, and whose
|
||||
ID is <= `newestServerFileId`.
|
||||
|
||||
This logic ensures that a FileUploader within a form doesn't have any
|
||||
of its "unsubmitted" uploads prematurely deleted when the script is
|
||||
re-run.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
file_list_id = (session_id, widget_id)
|
||||
with self._files_lock:
|
||||
file_list = self._files_by_id.get(file_list_id)
|
||||
if file_list is None:
|
||||
return
|
||||
|
||||
# Remove orphaned files from the list:
|
||||
# - `f.id in active_file_ids`:
|
||||
# File is currently tracked by the widget. DON'T remove.
|
||||
# - `f.id > newest_file_id`:
|
||||
# file was uploaded *after* the widget was most recently
|
||||
# updated. (It's probably in a form.) DON'T remove.
|
||||
# - `f.id < newest_file_id and f.id not in active_file_ids`:
|
||||
# File is not currently tracked by the widget, and was uploaded
|
||||
# *before* this most recent update. This means it's been deleted
|
||||
# by the user on the frontend, and is now "orphaned". Remove!
|
||||
new_list = [
|
||||
f for f in file_list if f.id > newest_file_id or f.id in active_file_ids
|
||||
]
|
||||
self._files_by_id[file_list_id] = new_list
|
||||
num_removed = len(file_list) - len(new_list)
|
||||
|
||||
if num_removed > 0:
|
||||
LOGGER.debug("Removed %s orphaned files" % num_removed)
|
||||
|
||||
def remove_file(self, session_id: str, widget_id: str, file_id: int) -> bool:
|
||||
"""Remove the file list with the given ID, if it exists.
|
||||
|
||||
The "on_files_updated" Signal will be emitted.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the file was removed, or False if no such file exists.
|
||||
"""
|
||||
file_list_id = (session_id, widget_id)
|
||||
with self._files_lock:
|
||||
file_list = self._files_by_id.get(file_list_id, None)
|
||||
if file_list is None:
|
||||
return False
|
||||
|
||||
# Remove the file from its list.
|
||||
new_file_list = [file for file in file_list if file.id != file_id]
|
||||
self._files_by_id[file_list_id] = new_file_list
|
||||
|
||||
self.on_files_updated.send(session_id)
|
||||
return True
|
||||
|
||||
def _remove_files(self, session_id: str, widget_id: str) -> None:
|
||||
"""Remove the file list for the provided widget in the
|
||||
provided session, if it exists.
|
||||
|
||||
Does not emit any signals.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
files_by_widget = session_id, widget_id
|
||||
with self._files_lock:
|
||||
self._files_by_id.pop(files_by_widget, None)
|
||||
|
||||
def remove_files(self, session_id: str, widget_id: str) -> None:
|
||||
"""Remove the file list for the provided widget in the
|
||||
provided session, if it exists.
|
||||
|
||||
The "on_files_updated" Signal will be emitted.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id : str
|
||||
The ID of the session that owns the files.
|
||||
widget_id : str
|
||||
The widget ID of the FileUploader that created the files.
|
||||
"""
|
||||
self._remove_files(session_id, widget_id)
|
||||
self.on_files_updated.send(session_id)
|
||||
|
||||
def remove_session_files(self, session_id: str) -> None:
|
||||
"""Remove all files that belong to the given session.
|
||||
|
||||
Safe to call from any thread.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
session_id : str
|
||||
The ID of the session whose files we're removing.
|
||||
|
||||
"""
|
||||
# Copy the keys into a list, because we'll be mutating the dictionary.
|
||||
with self._files_lock:
|
||||
all_ids = list(self._files_by_id.keys())
|
||||
|
||||
for files_id in all_ids:
|
||||
if files_id[0] == session_id:
|
||||
self.remove_files(*files_id)
|
||||
|
||||
def _get_next_file_id(self) -> int:
|
||||
"""Return the next file ID and increment our ID counter."""
|
||||
with self._file_id_lock:
|
||||
file_id = self._file_id_counter
|
||||
self._file_id_counter += 1
|
||||
return file_id
|
||||
|
||||
def get_stats(self) -> List[CacheStat]:
|
||||
"""Return the manager's CacheStats.
|
||||
|
||||
Safe to call from any thread.
|
||||
"""
|
||||
with self._files_lock:
|
||||
# Flatten all files into a single list
|
||||
all_files: List[UploadedFileRec] = []
|
||||
for file_list in self._files_by_id.values():
|
||||
all_files.extend(file_list)
|
||||
|
||||
return [
|
||||
CacheStat(
|
||||
category_name="UploadedFileManager",
|
||||
cache_name="",
|
||||
byte_length=len(file.data),
|
||||
)
|
||||
for file in all_files
|
||||
]
|
||||
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, Dict, List, Optional, cast
|
||||
|
||||
from typing_extensions import Final
|
||||
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.app_session import AppSession
|
||||
from streamlit.runtime.script_data import ScriptData
|
||||
from streamlit.runtime.session_manager import (
|
||||
ActiveSessionInfo,
|
||||
SessionClient,
|
||||
SessionInfo,
|
||||
SessionManager,
|
||||
SessionStorage,
|
||||
)
|
||||
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
|
||||
from streamlit.watcher import LocalSourcesWatcher
|
||||
|
||||
LOGGER: Final = get_logger(__name__)
|
||||
|
||||
|
||||
class WebsocketSessionManager(SessionManager):
|
||||
"""A SessionManager used to manage sessions with lifecycles tied to those of a
|
||||
browser tab's websocket connection.
|
||||
|
||||
WebsocketSessionManagers differentiate between "active" and "inactive" sessions.
|
||||
Active sessions are those with a currently active websocket connection. Inactive
|
||||
sessions are those without one. Eventual cleanup of inactive sessions is a detail left
|
||||
to the specific SessionStorage that a WebsocketSessionManager is instantiated with.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_storage: SessionStorage,
|
||||
uploaded_file_manager: UploadedFileManager,
|
||||
message_enqueued_callback: Optional[Callable[[], None]],
|
||||
) -> None:
|
||||
self._session_storage = session_storage
|
||||
self._uploaded_file_mgr = uploaded_file_manager
|
||||
self._message_enqueued_callback = message_enqueued_callback
|
||||
|
||||
# Mapping of AppSession.id -> ActiveSessionInfo.
|
||||
self._active_session_info_by_id: Dict[str, ActiveSessionInfo] = {}
|
||||
|
||||
def connect_session(
|
||||
self,
|
||||
client: SessionClient,
|
||||
script_data: ScriptData,
|
||||
user_info: Dict[str, Optional[str]],
|
||||
existing_session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
if existing_session_id in self._active_session_info_by_id:
|
||||
LOGGER.warning(
|
||||
"Session with id %s is already connected! Connecting to a new session.",
|
||||
existing_session_id,
|
||||
)
|
||||
|
||||
session_info = (
|
||||
existing_session_id
|
||||
and existing_session_id not in self._active_session_info_by_id
|
||||
and self._session_storage.get(existing_session_id)
|
||||
)
|
||||
|
||||
if session_info:
|
||||
existing_session = session_info.session
|
||||
existing_session.register_file_watchers()
|
||||
|
||||
self._active_session_info_by_id[existing_session.id] = ActiveSessionInfo(
|
||||
client,
|
||||
existing_session,
|
||||
session_info.script_run_count,
|
||||
)
|
||||
self._session_storage.delete(existing_session.id)
|
||||
|
||||
return existing_session.id
|
||||
|
||||
session = AppSession(
|
||||
script_data=script_data,
|
||||
uploaded_file_manager=self._uploaded_file_mgr,
|
||||
message_enqueued_callback=self._message_enqueued_callback,
|
||||
local_sources_watcher=LocalSourcesWatcher(script_data.main_script_path),
|
||||
user_info=user_info,
|
||||
)
|
||||
|
||||
LOGGER.debug(
|
||||
"Created new session for client %s. Session ID: %s", id(client), session.id
|
||||
)
|
||||
|
||||
assert (
|
||||
session.id not in self._active_session_info_by_id
|
||||
), f"session.id '{session.id}' registered multiple times!"
|
||||
|
||||
self._active_session_info_by_id[session.id] = ActiveSessionInfo(client, session)
|
||||
return session.id
|
||||
|
||||
def disconnect_session(self, session_id: str) -> None:
|
||||
if session_id in self._active_session_info_by_id:
|
||||
active_session_info = self._active_session_info_by_id[session_id]
|
||||
session = active_session_info.session
|
||||
|
||||
session.request_script_stop()
|
||||
session.disconnect_file_watchers()
|
||||
|
||||
self._session_storage.save(
|
||||
SessionInfo(
|
||||
client=None,
|
||||
session=session,
|
||||
script_run_count=active_session_info.script_run_count,
|
||||
)
|
||||
)
|
||||
del self._active_session_info_by_id[session_id]
|
||||
|
||||
def get_active_session_info(self, session_id: str) -> Optional[ActiveSessionInfo]:
|
||||
return self._active_session_info_by_id.get(session_id)
|
||||
|
||||
def is_active_session(self, session_id: str) -> bool:
|
||||
return session_id in self._active_session_info_by_id
|
||||
|
||||
def list_active_sessions(self) -> List[ActiveSessionInfo]:
|
||||
return list(self._active_session_info_by_id.values())
|
||||
|
||||
def close_session(self, session_id: str) -> None:
|
||||
if session_id in self._active_session_info_by_id:
|
||||
active_session_info = self._active_session_info_by_id[session_id]
|
||||
del self._active_session_info_by_id[session_id]
|
||||
active_session_info.session.shutdown()
|
||||
return
|
||||
|
||||
session_info = self._session_storage.get(session_id)
|
||||
if session_info:
|
||||
self._session_storage.delete(session_id)
|
||||
session_info.session.shutdown()
|
||||
|
||||
def get_session_info(self, session_id: str) -> Optional[SessionInfo]:
|
||||
session_info = self.get_active_session_info(session_id)
|
||||
if session_info:
|
||||
return cast(SessionInfo, session_info)
|
||||
return self._session_storage.get(session_id)
|
||||
|
||||
def list_sessions(self) -> List[SessionInfo]:
|
||||
return (
|
||||
cast(List[SessionInfo], self.list_active_sessions())
|
||||
+ self._session_storage.list()
|
||||
)
|
||||