Merging PR_218 openai_rev package with new streamlit chat app
@@ -0,0 +1,132 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
from typing import Any, Iterator, Union

from google.protobuf.message import Message

from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_data_api import (
    CACHE_DATA_MESSAGE_REPLAY_CTX,
    CacheDataAPI,
    _data_caches,
)
from streamlit.runtime.caching.cache_errors import CACHE_DOCS_URL as CACHE_DOCS_URL
from streamlit.runtime.caching.cache_resource_api import (
    CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
    CacheResourceAPI,
    _resource_caches,
)
from streamlit.runtime.state.common import WidgetMetadata


def save_element_message(
    delta_type: str,
    element_proto: Message,
    invoked_dg_id: str,
    used_dg_id: str,
    returned_dg_id: str,
) -> None:
    """Save the message for an element to a thread-local callstack, so it can
    be used later to replay the element when a cache-decorated function's
    execution is skipped.
    """
    CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
        delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
    )
    CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
        delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
    )


def save_block_message(
    block_proto: Block,
    invoked_dg_id: str,
    used_dg_id: str,
    returned_dg_id: str,
) -> None:
    """Save the message for a block to a thread-local callstack, so it can
    be used later to replay the block when a cache-decorated function's
    execution is skipped.
    """
    CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
        block_proto, invoked_dg_id, used_dg_id, returned_dg_id
    )
    CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
        block_proto, invoked_dg_id, used_dg_id, returned_dg_id
    )


def save_widget_metadata(metadata: WidgetMetadata[Any]) -> None:
    """Save a widget's metadata to a thread-local callstack, so the widget
    can be registered again when that widget is replayed.
    """
    CACHE_DATA_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
    CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)


def save_media_data(
    image_data: Union[bytes, str], mimetype: str, image_id: str
) -> None:
    CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
    CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)


def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
    CACHE_DATA_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
        dg, st_func_name
    )
    CACHE_RESOURCE_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
        dg, st_func_name
    )


@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
    with CACHE_DATA_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning(), CACHE_RESOURCE_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning():
        yield


# Explicitly export public symbols
from streamlit.runtime.caching.cache_data_api import (
    get_data_cache_stats_provider as get_data_cache_stats_provider,
)
from streamlit.runtime.caching.cache_resource_api import (
    get_resource_cache_stats_provider as get_resource_cache_stats_provider,
)

# Create and export public API singletons.
cache_data = CacheDataAPI(decorator_metric_name="cache_data")
cache_resource = CacheResourceAPI(decorator_metric_name="cache_resource")

# Deprecated singletons
_MEMO_WARNING = (
    f"`st.experimental_memo` is deprecated. Please use the new command `st.cache_data` instead, "
    f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
)

experimental_memo = CacheDataAPI(
    decorator_metric_name="experimental_memo", deprecation_warning=_MEMO_WARNING
)

_SINGLETON_WARNING = (
    f"`st.experimental_singleton` is deprecated. Please use the new command `st.cache_resource` instead, "
    f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
)

experimental_singleton = CacheResourceAPI(
    decorator_metric_name="experimental_singleton",
    deprecation_warning=_SINGLETON_WARNING,
)
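For context, a minimal usage sketch of the singletons exported above (the function names and bodies are illustrative, not part of this commit):

```python
import streamlit as st

# Bare decorator usage: the pickled return value is cached per input args.
@st.cache_data
def load_data(url):
    ...  # fetch and return data (illustrative placeholder)

# Parameterized usage, routed through CacheDataAPI.__call__ below.
@st.cache_data(ttl=600, max_entries=100, show_spinner="Loading...")
def load_other_data(url):
    ...

st.cache_data.clear()  # clears every @st.cache_data cache
```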
@@ -0,0 +1,678 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""@st.cache_data: pickle-based caching"""

from __future__ import annotations

import pickle
import threading
import types
from datetime import timedelta
from typing import Any, Callable, TypeVar, Union, cast, overload

from typing_extensions import Literal, TypeAlias

import streamlit as st
from streamlit import runtime
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
    Cache,
    CachedFuncInfo,
    make_cached_func_wrapper,
    ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
    CachedMessageReplayContext,
    CachedResult,
    ElementMsgData,
    MsgData,
    MultiCacheResults,
)
from streamlit.runtime.caching.storage import (
    CacheStorage,
    CacheStorageContext,
    CacheStorageError,
    CacheStorageKeyNotFoundError,
    CacheStorageManager,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    InvalidCacheStorageContext,
)
from streamlit.runtime.caching.storage.dummy_cache_storage import (
    MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider

_LOGGER = get_logger(__name__)

CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)

# The cache persistence options we support: "disk" or None
CachePersistType: TypeAlias = Union[Literal["disk"], None]


class CachedDataFuncInfo(CachedFuncInfo):
    """Implements the CachedFuncInfo interface for @st.cache_data"""

    def __init__(
        self,
        func: types.FunctionType,
        show_spinner: bool | str,
        persist: CachePersistType,
        max_entries: int | None,
        ttl: float | timedelta | None,
        allow_widgets: bool,
    ):
        super().__init__(
            func,
            show_spinner=show_spinner,
            allow_widgets=allow_widgets,
        )
        self.persist = persist
        self.max_entries = max_entries
        self.ttl = ttl

        self.validate_params()

    @property
    def cache_type(self) -> CacheType:
        return CacheType.DATA

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        return CACHE_DATA_MESSAGE_REPLAY_CTX

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function"""
        return f"{self.func.__module__}.{self.func.__qualname__}"

    def get_function_cache(self, function_key: str) -> Cache:
        return _data_caches.get_cache(
            key=function_key,
            persist=self.persist,
            max_entries=self.max_entries,
            ttl=self.ttl,
            display_name=self.display_name,
            allow_widgets=self.allow_widgets,
        )

    def validate_params(self) -> None:
        """
        Validate that the params passed to @st.cache_data are compatible with cache storage.

        When called, this method may log warnings if the cache params are invalid
        for the current storage.
        """
        _data_caches.validate_cache_params(
            function_name=self.func.__name__,
            persist=self.persist,
            max_entries=self.max_entries,
            ttl=self.ttl,
        )


class DataCaches(CacheStatsProvider):
    """Manages all DataCache instances"""

    def __init__(self):
        self._caches_lock = threading.Lock()
        self._function_caches: dict[str, DataCache] = {}

    def get_cache(
        self,
        key: str,
        persist: CachePersistType,
        max_entries: int | None,
        ttl: int | float | timedelta | None,
        display_name: str,
        allow_widgets: bool,
    ) -> DataCache:
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        """

        ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)

        # Get the existing cache, if it exists, and validate that its params
        # haven't changed.
        with self._caches_lock:
            cache = self._function_caches.get(key)
            if (
                cache is not None
                and cache.ttl_seconds == ttl_seconds
                and cache.max_entries == max_entries
                and cache.persist == persist
            ):
                return cache

            # Close the existing cache's storage, if it exists.
            if cache is not None:
                _LOGGER.debug(
                    "Closing existing DataCache storage "
                    "(key=%s, persist=%s, max_entries=%s, ttl=%s) "
                    "before creating new one with different params",
                    key,
                    persist,
                    max_entries,
                    ttl,
                )
                cache.storage.close()

            # Create a new cache object and put it in our dict
            _LOGGER.debug(
                "Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
                key,
                persist,
                max_entries,
                ttl,
            )

            cache_context = self.create_cache_storage_context(
                function_key=key,
                function_name=display_name,
                ttl_seconds=ttl_seconds,
                max_entries=max_entries,
                persist=persist,
            )
            cache_storage_manager = self.get_storage_manager()
            storage = cache_storage_manager.create(cache_context)

            cache = DataCache(
                key=key,
                storage=storage,
                persist=persist,
                max_entries=max_entries,
                ttl_seconds=ttl_seconds,
                display_name=display_name,
                allow_widgets=allow_widgets,
            )
            self._function_caches[key] = cache
            return cache

    def clear_all(self) -> None:
        """Clear all in-memory and on-disk caches."""
        with self._caches_lock:
            try:
                # Try to clear everything in one optimized call, if the storage
                # manager's clear_all method provides one; if it's not
                # implemented, fall back to clearing the available
                # storages one by one.
                self.get_storage_manager().clear_all()
            except NotImplementedError:
                for data_cache in self._function_caches.values():
                    data_cache.clear()
                    data_cache.storage.close()
            self._function_caches = {}

    def get_stats(self) -> list[CacheStat]:
        with self._caches_lock:
            # Shallow-clone our caches. We don't want to hold the global
            # lock during stats-gathering.
            function_caches = self._function_caches.copy()

        stats: list[CacheStat] = []
        for cache in function_caches.values():
            stats.extend(cache.get_stats())
        return stats

    def validate_cache_params(
        self,
        function_name: str,
        persist: CachePersistType,
        max_entries: int | None,
        ttl: int | float | timedelta | None,
    ) -> None:
        """Validate that the cache params are valid for the given storage.

        Raises
        ------
        InvalidCacheStorageContext
            Raised if the cache storage manager is not able to work with the
            provided CacheStorageContext.
        """

        ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)

        cache_context = self.create_cache_storage_context(
            function_key="DUMMY_KEY",
            function_name=function_name,
            ttl_seconds=ttl_seconds,
            max_entries=max_entries,
            persist=persist,
        )
        try:
            self.get_storage_manager().check_context(cache_context)
        except InvalidCacheStorageContext as e:
            _LOGGER.error(
                "Cache params for function %s are incompatible with current "
                "cache storage manager: %s",
                function_name,
                e,
            )
            raise

    def create_cache_storage_context(
        self,
        function_key: str,
        function_name: str,
        persist: CachePersistType,
        ttl_seconds: float | None,
        max_entries: int | None,
    ) -> CacheStorageContext:
        return CacheStorageContext(
            function_key=function_key,
            function_display_name=function_name,
            ttl_seconds=ttl_seconds,
            max_entries=max_entries,
            persist=persist,
        )

    def get_storage_manager(self) -> CacheStorageManager:
        if runtime.exists():
            return runtime.get_instance().cache_storage_manager
        else:
            # When running in "raw mode", we can't access the CacheStorageManager,
            # so we fall back to the in-memory MemoryCacheStorageManager.
            _LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
            return MemoryCacheStorageManager()


# Singleton DataCaches instance
_data_caches = DataCaches()


def get_data_cache_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all @st.cache_data functions."""
    return _data_caches


class CacheDataAPI:
    """Implements the public st.cache_data API: the @st.cache_data decorator, and
    st.cache_data.clear().
    """

    def __init__(
        self, decorator_metric_name: str, deprecation_warning: str | None = None
    ):
        """Create a CacheDataAPI instance.

        Parameters
        ----------
        decorator_metric_name
            The metric name to record for decorator usage. `@st.experimental_memo` is
            deprecated, but we're still supporting it and tracking its usage separately
            from `@st.cache_data`.

        deprecation_warning
            An optional deprecation warning to show when the API is accessed.
        """

        # Parameterize the decorator metric name.
        # (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
        self._decorator = gather_metrics(  # type: ignore
            decorator_metric_name, self._decorator
        )
        self._deprecation_warning = deprecation_warning

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
    F = TypeVar("F", bound=Callable[..., Any])

    # Bare decorator usage
    @overload
    def __call__(self, func: F) -> F:
        ...

    # Decorator with arguments
    @overload
    def __call__(
        self,
        *,
        ttl: float | timedelta | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        persist: CachePersistType | bool = None,
        experimental_allow_widgets: bool = False,
    ) -> Callable[[F], F]:
        ...

    def __call__(
        self,
        func: F | None = None,
        *,
        ttl: float | timedelta | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        persist: CachePersistType | bool = None,
        experimental_allow_widgets: bool = False,
    ):
        return self._decorator(
            func,
            ttl=ttl,
            max_entries=max_entries,
            persist=persist,
            show_spinner=show_spinner,
            experimental_allow_widgets=experimental_allow_widgets,
        )

    def _decorator(
        self,
        func: F | None = None,
        *,
        ttl: float | timedelta | None,
        max_entries: int | None,
        show_spinner: bool | str,
        persist: CachePersistType | bool,
        experimental_allow_widgets: bool,
    ):
        """Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).

        Cached objects are stored in "pickled" form, which means that the return
        value of a cached function must be pickleable. Each caller of the cached
        function gets its own copy of the cached data.

        You can clear a function's cache with ``func.clear()`` or clear the entire
        cache with ``st.cache_data.clear()``.

        To cache global resources, use ``st.cache_resource`` instead. Learn more
        about caching at https://docs.streamlit.io/library/advanced-features/caching.

        Parameters
        ----------
        func : callable
            The function to cache. Streamlit hashes the function's source code.

        ttl : float or timedelta or None
            The maximum number of seconds to keep an entry in the cache, or
            None if cache entries should not expire. The default is None.
            Note that ttl is incompatible with ``persist="disk"`` - ``ttl`` will be
            ignored if ``persist`` is specified.

        max_entries : int or None
            The maximum number of entries to keep in the cache, or None
            for an unbounded cache. (When a new entry is added to a full cache,
            the oldest cached entry will be removed.) The default is None.

        show_spinner : boolean or string
            Enable the spinner. Default is True to show a spinner when there is
            a "cache miss" and the cached data is being created. If string,
            value of show_spinner param will be used for spinner text.

        persist : str or boolean or None
            Optional location to persist cached data to. Passing "disk" (or True)
            will persist the cached data to the local disk. None (or False) will disable
            persistence. The default is None.

        experimental_allow_widgets : boolean
            Allow widgets to be used in the cached function. Defaults to False.
            Support for widgets in cached functions is currently experimental.
            Setting this parameter to True may lead to excessive memory use since the
            widget value is treated as an additional input parameter to the cache.
            We may remove support for this option at any time without notice.

        Example
        -------
        >>> import streamlit as st
        >>>
        >>> @st.cache_data
        ... def fetch_and_clean_data(url):
        ...     # Fetch data from URL here, and then clean it up.
        ...     return data
        ...
        >>> d1 = fetch_and_clean_data(DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> d2 = fetch_and_clean_data(DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the data in d1 is the same as in d2.
        >>>
        >>> d3 = fetch_and_clean_data(DATA_URL_2)
        >>> # This is a different URL, so the function executes.

        To set the ``persist`` parameter, use this command as follows:

        >>> import streamlit as st
        >>>
        >>> @st.cache_data(persist="disk")
        ... def fetch_and_clean_data(url):
        ...     # Fetch data from URL here, and then clean it up.
        ...     return data

        By default, all parameters to a cached function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> import streamlit as st
        >>>
        >>> @st.cache_data
        ... def fetch_and_clean_data(_db_connection, num_rows):
        ...     # Fetch data from _db_connection here, and then clean it up.
        ...     return data
        ...
        >>> connection = make_database_connection()
        >>> d1 = fetch_and_clean_data(connection, num_rows=10)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> another_connection = make_database_connection()
        >>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _db_connection parameter was different
        >>> # in both calls.

        A cached function's cache can be procedurally cleared:

        >>> import streamlit as st
        >>>
        >>> @st.cache_data
        ... def fetch_and_clean_data(_db_connection, num_rows):
        ...     # Fetch data from _db_connection here, and then clean it up.
        ...     return data
        ...
        >>> fetch_and_clean_data.clear()
        >>> # Clear all cached entries for this function.

        """

        # Parse our persist value into a string
        persist_string: CachePersistType
        if persist is True:
            persist_string = "disk"
        elif persist is False:
            persist_string = None
        else:
            persist_string = persist

        if persist_string not in (None, "disk"):
            # We'll eventually have more persist options.
            raise StreamlitAPIException(
                f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
            )

        self._maybe_show_deprecation_warning()

        def wrapper(f):
            return make_cached_func_wrapper(
                CachedDataFuncInfo(
                    func=f,
                    persist=persist_string,
                    show_spinner=show_spinner,
                    max_entries=max_entries,
                    ttl=ttl,
                    allow_widgets=experimental_allow_widgets,
                )
            )

        if func is None:
            return wrapper

        return make_cached_func_wrapper(
            CachedDataFuncInfo(
                func=cast(types.FunctionType, func),
                persist=persist_string,
                show_spinner=show_spinner,
                max_entries=max_entries,
                ttl=ttl,
                allow_widgets=experimental_allow_widgets,
            )
        )

    @gather_metrics("clear_data_caches")
    def clear(self) -> None:
        """Clear all in-memory and on-disk data caches."""
        self._maybe_show_deprecation_warning()
        _data_caches.clear_all()

    def _maybe_show_deprecation_warning(self):
        """If the API is being accessed with the deprecated `st.experimental_memo` name,
        show a deprecation warning.
        """
        if self._deprecation_warning is not None:
            show_deprecation_warning(self._deprecation_warning)


class DataCache(Cache):
    """Manages cached values for a single st.cache_data function."""

    def __init__(
        self,
        key: str,
        storage: CacheStorage,
        persist: CachePersistType,
        max_entries: int | None,
        ttl_seconds: float | None,
        display_name: str,
        allow_widgets: bool = False,
    ):
        super().__init__()
        self.key = key
        self.display_name = display_name
        self.storage = storage
        self.ttl_seconds = ttl_seconds
        self.max_entries = max_entries
        self.persist = persist
        self.allow_widgets = allow_widgets

    def get_stats(self) -> list[CacheStat]:
        if isinstance(self.storage, CacheStatsProvider):
            return self.storage.get_stats()
        return []

    def read_result(self, key: str) -> CachedResult:
        """Read a value and messages from the cache. Raise `CacheKeyNotFoundError`
        if the value doesn't exist, and `CacheError` if the value exists but can't
        be unpickled.
        """
        try:
            pickled_entry = self.storage.get(key)
        except CacheStorageKeyNotFoundError as e:
            raise CacheKeyNotFoundError(str(e)) from e
        except CacheStorageError as e:
            raise CacheError(str(e)) from e

        try:
            entry = pickle.loads(pickled_entry)
            if not isinstance(entry, MultiCacheResults):
                # Loaded an old cache file format, remove it and let the caller
                # rerun the function.
                self.storage.delete(key)
                raise CacheKeyNotFoundError()

            ctx = get_script_run_ctx()
            if not ctx:
                raise CacheKeyNotFoundError()

            widget_key = entry.get_current_widget_key(ctx, CacheType.DATA)
            if widget_key in entry.results:
                return entry.results[widget_key]
            else:
                raise CacheKeyNotFoundError()
        except pickle.UnpicklingError as exc:
            raise CacheError(f"Failed to unpickle {key}") from exc

    @gather_metrics("_cache_data_object")
    def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache.
        The value must be pickleable.
        """
        ctx = get_script_run_ctx()
        if ctx is None:
            return

        main_id = st._main.id
        sidebar_id = st.sidebar.id

        if self.allow_widgets:
            widgets = {
                msg.widget_metadata.widget_id
                for msg in messages
                if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
            }
        else:
            widgets = set()

        multi_cache_results: MultiCacheResults | None = None

        # Try to find the entry in cache storage, then fall back to a new result instance
        try:
            multi_cache_results = self._read_multi_results_from_storage(key)
        except (CacheKeyNotFoundError, pickle.UnpicklingError):
            pass

        if multi_cache_results is None:
            multi_cache_results = MultiCacheResults(widget_ids=widgets, results={})
        multi_cache_results.widget_ids.update(widgets)
        widget_key = multi_cache_results.get_current_widget_key(ctx, CacheType.DATA)

        result = CachedResult(value, messages, main_id, sidebar_id)
        multi_cache_results.results[widget_key] = result

        try:
            pickled_entry = pickle.dumps(multi_cache_results)
        except (pickle.PicklingError, TypeError) as exc:
            raise CacheError(f"Failed to pickle {key}") from exc

        self.storage.set(key, pickled_entry)

    def _clear(self) -> None:
        self.storage.clear()

    def _read_multi_results_from_storage(self, key: str) -> MultiCacheResults:
        """Look up the results from storage and ensure it has the right type.

        Raises a `CacheKeyNotFoundError` if the key has no entry, or if the
        entry is malformed.
        """
        try:
            pickled = self.storage.get(key)
        except CacheStorageKeyNotFoundError as e:
            raise CacheKeyNotFoundError(str(e)) from e

        maybe_results = pickle.loads(pickled)

        if isinstance(maybe_results, MultiCacheResults):
            return maybe_results
        else:
            self.storage.delete(key)
            raise CacheKeyNotFoundError()
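A small illustration of the persist coercion handled in `CacheDataAPI._decorator` above (sketch; the function names are placeholders):

```python
import streamlit as st

# persist=True is coerced to "disk" and persist=False to None by the
# parsing logic in _decorator, so these two declarations are equivalent:
@st.cache_data(persist=True)
def load_a(path: str):
    ...

@st.cache_data(persist="disk")
def load_b(path: str):
    ...
```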
@@ -0,0 +1,176 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types
from typing import Any, Optional

from streamlit import type_util
from streamlit.errors import (
    MarkdownFormattedException,
    StreamlitAPIException,
    StreamlitAPIWarning,
)
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name

CACHE_DOCS_URL = "https://docs.streamlit.io/library/advanced-features/caching"


def get_cached_func_name_md(func: Any) -> str:
    """Get markdown representation of the function name."""
    if hasattr(func, "__name__"):
        return "`%s()`" % func.__name__
    elif hasattr(type(func), "__name__"):
        return f"`{type(func).__name__}`"
    return f"`{type(func)}`"


def get_return_value_type(return_value: Any) -> str:
    if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
        return f"`{return_value.__module__}.{type(return_value).__name__}`"
    return get_cached_func_name_md(return_value)


class UnhashableTypeError(Exception):
    pass


class UnhashableParamError(StreamlitAPIException):
    def __init__(
        self,
        cache_type: CacheType,
        func: types.FunctionType,
        arg_name: Optional[str],
        arg_value: Any,
        orig_exc: BaseException,
    ):
        msg = self._create_message(cache_type, func, arg_name, arg_value)
        super().__init__(msg)
        self.with_traceback(orig_exc.__traceback__)

    @staticmethod
    def _create_message(
        cache_type: CacheType,
        func: types.FunctionType,
        arg_name: Optional[str],
        arg_value: Any,
    ) -> str:
        arg_name_str = arg_name if arg_name is not None else "(unnamed)"
        arg_type = type_util.get_fqn_type(arg_value)
        func_name = func.__name__
        arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"

        return (
            f"""
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.

To address this, you can tell Streamlit not to hash this argument by adding a
leading underscore to the argument's name in the function signature:

```
@st.{get_decorator_api_name(cache_type)}
def {func_name}({arg_replacement_name}, ...):
    ...
```
"""
        ).strip("\n")


class CacheKeyNotFoundError(Exception):
    pass


class CacheError(Exception):
    pass


class CachedStFunctionWarning(StreamlitAPIWarning):
    def __init__(
        self,
        cache_type: CacheType,
        st_func_name: str,
        cached_func: types.FunctionType,
    ):
        args = {
            "st_func_name": f"`st.{st_func_name}()`",
            "func_name": self._get_cached_func_name_md(cached_func),
            "decorator_name": get_decorator_api_name(cache_type),
        }

        msg = (
            """
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.

How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.%(decorator_name)s(experimental_allow_widgets=True)`
to enable widget replay and suppress this warning.
"""
            % args
        ).strip("\n")

        super().__init__(msg)

    @staticmethod
    def _get_cached_func_name_md(func: types.FunctionType) -> str:
        """Get markdown representation of the function name."""
        if hasattr(func, "__name__"):
            return "`%s()`" % func.__name__
        else:
            return "a cached function"


class CacheReplayClosureError(StreamlitAPIException):
    def __init__(
        self,
        cache_type: CacheType,
        cached_func: types.FunctionType,
    ):
        func_name = get_cached_func_name_md(cached_func)
        decorator_name = get_decorator_api_name(cache_type)

        msg = (
            f"""
While running {func_name}, a streamlit element was called on a layout block created outside the function.
This is incompatible with replaying the cached effect of that element, because the
referenced block might not exist when the replay happens.

How to fix this:
* Move the creation of $THING inside {func_name}.
* Move the call to the streamlit element outside of {func_name}.
* Remove the `@st.{decorator_name}` decorator from {func_name}.
"""
        ).strip("\n")

        super().__init__(msg)


class UnserializableReturnValueError(MarkdownFormattedException):
    def __init__(self, func: types.FunctionType, return_value: types.FunctionType):
        MarkdownFormattedException.__init__(
            self,
            f"""
            Cannot serialize the return value (of type {get_return_value_type(return_value)}) in {get_cached_func_name_md(func)}.
            `st.cache_data` uses [pickle](https://docs.python.org/3/library/pickle.html) to
            serialize the function's return value and safely store it in the cache without mutating the original object. Please convert the return value to a pickle-serializable type.
            If you want to cache unserializable objects such as database connections or Tensorflow
            sessions, use `st.cache_resource` instead (see [our docs]({CACHE_DOCS_URL}) for differences).""",
        )


class UnevaluatedDataFrameError(StreamlitAPIException):
    """Used to display a message about uncollected dataframe being used"""

    pass
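In practice, the fix suggested by `UnhashableParamError`'s message looks like this (sketch; `get_data` and the connection object are illustrative):

```python
import streamlit as st

@st.cache_data
def get_data(_conn, query: str):
    # `_conn` starts with an underscore, so it is skipped by the hasher;
    # only `query` (plus the function's source) contributes to the cache key.
    return _conn.execute(query)
```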
@@ -0,0 +1,520 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""@st.cache_resource implementation"""

from __future__ import annotations

import math
import threading
import types
from datetime import timedelta
from typing import Any, Callable, TypeVar, cast, overload

from cachetools import TTLCache
from pympler import asizeof
from typing_extensions import TypeAlias

import streamlit as st
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
    Cache,
    CachedFuncInfo,
    make_cached_func_wrapper,
    ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
    CachedMessageReplayContext,
    CachedResult,
    ElementMsgData,
    MsgData,
    MultiCacheResults,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider

_LOGGER = get_logger(__name__)


CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)

ValidateFunc: TypeAlias = Callable[[Any], bool]


def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
    """True if the two validate functions are equal for the purposes of
    determining whether a given function cache needs to be recreated.
    """
    # To "properly" test for function equality here, we'd need to compare function bytecode.
    # For performance reasons, we've decided not to do that for now.
    return (a is None and b is None) or (a is not None and b is not None)


class ResourceCaches(CacheStatsProvider):
    """Manages all ResourceCache instances"""

    def __init__(self):
        self._caches_lock = threading.Lock()
        self._function_caches: dict[str, ResourceCache] = {}

    def get_cache(
        self,
        key: str,
        display_name: str,
        max_entries: int | float | None,
        ttl: float | timedelta | None,
        validate: ValidateFunc | None,
        allow_widgets: bool,
    ) -> ResourceCache:
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params.
        """
        if max_entries is None:
            max_entries = math.inf

        ttl_seconds = ttl_to_seconds(ttl)

        # Get the existing cache, if it exists, and validate that its params
        # haven't changed.
        with self._caches_lock:
            cache = self._function_caches.get(key)
            if (
                cache is not None
                and cache.ttl_seconds == ttl_seconds
                and cache.max_entries == max_entries
                and _equal_validate_funcs(cache.validate, validate)
            ):
                return cache

            # Create a new cache object and put it in our dict
            _LOGGER.debug("Creating new ResourceCache (key=%s)", key)
            cache = ResourceCache(
                key=key,
                display_name=display_name,
                max_entries=max_entries,
                ttl_seconds=ttl_seconds,
                validate=validate,
                allow_widgets=allow_widgets,
            )
            self._function_caches[key] = cache
            return cache

    def clear_all(self) -> None:
        """Clear all resource caches."""
        with self._caches_lock:
            self._function_caches = {}

    def get_stats(self) -> list[CacheStat]:
        with self._caches_lock:
            # Shallow-clone our caches. We don't want to hold the global
            # lock during stats-gathering.
            function_caches = self._function_caches.copy()

        stats: list[CacheStat] = []
        for cache in function_caches.values():
            stats.extend(cache.get_stats())
        return stats


# Singleton ResourceCaches instance
_resource_caches = ResourceCaches()


def get_resource_cache_stats_provider() -> CacheStatsProvider:
    """Return the StatsProvider for all @st.cache_resource functions."""
    return _resource_caches


class CachedResourceFuncInfo(CachedFuncInfo):
    """Implements the CachedFuncInfo interface for @st.cache_resource"""

    def __init__(
        self,
        func: types.FunctionType,
        show_spinner: bool | str,
        max_entries: int | None,
        ttl: float | timedelta | None,
        validate: ValidateFunc | None,
        allow_widgets: bool,
    ):
        super().__init__(
            func,
            show_spinner=show_spinner,
            allow_widgets=allow_widgets,
        )
        self.max_entries = max_entries
        self.ttl = ttl
        self.validate = validate

    @property
    def cache_type(self) -> CacheType:
        return CacheType.RESOURCE

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        return CACHE_RESOURCE_MESSAGE_REPLAY_CTX

    @property
    def display_name(self) -> str:
        """A human-readable name for the cached function"""
        return f"{self.func.__module__}.{self.func.__qualname__}"

    def get_function_cache(self, function_key: str) -> Cache:
        return _resource_caches.get_cache(
            key=function_key,
            display_name=self.display_name,
            max_entries=self.max_entries,
            ttl=self.ttl,
            validate=self.validate,
            allow_widgets=self.allow_widgets,
        )


class CacheResourceAPI:
    """Implements the public st.cache_resource API: the @st.cache_resource decorator,
    and st.cache_resource.clear().
    """

    def __init__(
        self, decorator_metric_name: str, deprecation_warning: str | None = None
    ):
        """Create a CacheResourceAPI instance.

        Parameters
        ----------
        decorator_metric_name
            The metric name to record for decorator usage. `@st.experimental_singleton` is
            deprecated, but we're still supporting it and tracking its usage separately
            from `@st.cache_resource`.

        deprecation_warning
            An optional deprecation warning to show when the API is accessed.
        """

        # Parameterize the decorator metric name.
        # (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
        self._decorator = gather_metrics(decorator_metric_name, self._decorator)  # type: ignore
        self._deprecation_warning = deprecation_warning

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)

    F = TypeVar("F", bound=Callable[..., Any])

    # Bare decorator usage
    @overload
    def __call__(self, func: F) -> F:
        ...

    # Decorator with arguments
    @overload
    def __call__(
        self,
        *,
        ttl: float | timedelta | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        validate: ValidateFunc | None = None,
        experimental_allow_widgets: bool = False,
    ) -> Callable[[F], F]:
        ...

    def __call__(
        self,
        func: F | None = None,
        *,
        ttl: float | timedelta | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        validate: ValidateFunc | None = None,
        experimental_allow_widgets: bool = False,
    ):
        return self._decorator(
            func,
            ttl=ttl,
            max_entries=max_entries,
            show_spinner=show_spinner,
            validate=validate,
            experimental_allow_widgets=experimental_allow_widgets,
        )

    def _decorator(
        self,
        func: F | None,
        *,
        ttl: float | timedelta | None,
        max_entries: int | None,
        show_spinner: bool | str,
        validate: ValidateFunc | None,
        experimental_allow_widgets: bool,
    ):
        """Decorator to cache functions that return global resources (e.g. database connections, ML models).

        Cached objects are shared across all users, sessions, and reruns. They
        must be thread-safe because they can be accessed from multiple threads
        concurrently. If thread safety is an issue, consider using ``st.session_state``
        to store resources per session instead.

        You can clear a function's cache with ``func.clear()`` or clear the entire
        cache with ``st.cache_resource.clear()``.

        To cache data, use ``st.cache_data`` instead. Learn more about caching at
        https://docs.streamlit.io/library/advanced-features/caching.

        Parameters
        ----------
        func : callable
            The function that creates the cached resource. Streamlit hashes the
            function's source code.

        ttl : float or timedelta or None
            The maximum number of seconds to keep an entry in the cache, or
            None if cache entries should not expire. The default is None.

        max_entries : int or None
            The maximum number of entries to keep in the cache, or None
            for an unbounded cache. (When a new entry is added to a full cache,
            the oldest cached entry will be removed.) The default is None.

        show_spinner : boolean or string
            Enable the spinner. Default is True to show a spinner when there is
            a "cache miss" and the cached resource is being created. If string,
            value of show_spinner param will be used for spinner text.

        validate : callable or None
            An optional validation function for cached data. ``validate`` is called
            each time the cached value is accessed. It receives the cached value as
            its only parameter and it must return a boolean. If ``validate`` returns
            False, the current cached value is discarded, and the decorated function
            is called to compute a new value. This is useful e.g. to check the
            health of database connections.

        experimental_allow_widgets : boolean
            Allow widgets to be used in the cached function. Defaults to False.
            Support for widgets in cached functions is currently experimental.
            Setting this parameter to True may lead to excessive memory use since the
            widget value is treated as an additional input parameter to the cache.
            We may remove support for this option at any time without notice.

        Example
        -------
        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(url):
        ...     # Create a database session object that points to the URL.
        ...     return session
        ...
        >>> s1 = get_database_session(SESSION_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(SESSION_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the connection object in s1 is the same as in s2.
        >>>
        >>> s3 = get_database_session(SESSION_URL_2)
        >>> # This is a different URL, so the function executes.

        By default, all parameters to a cache_resource function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        ...
        >>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _sessionmaker parameter was different
        >>> # in both calls.

        A cache_resource function's cache can be procedurally cleared:

        >>> import streamlit as st
        >>>
        >>> @st.cache_resource
        ... def get_database_session(_sessionmaker, url):
        ...     # Create a database connection object that points to the URL.
        ...     return connection
        ...
        >>> get_database_session.clear()
        >>> # Clear all cached entries for this function.

        """
        self._maybe_show_deprecation_warning()

        # Support passing the params via function decorator, e.g.
        # @st.cache_resource(show_spinner=False)
        if func is None:
            return lambda f: make_cached_func_wrapper(
                CachedResourceFuncInfo(
                    func=f,
                    show_spinner=show_spinner,
                    max_entries=max_entries,
                    ttl=ttl,
                    validate=validate,
                    allow_widgets=experimental_allow_widgets,
                )
            )

        return make_cached_func_wrapper(
            CachedResourceFuncInfo(
                func=cast(types.FunctionType, func),
                show_spinner=show_spinner,
                max_entries=max_entries,
                ttl=ttl,
                validate=validate,
                allow_widgets=experimental_allow_widgets,
            )
        )

    @gather_metrics("clear_resource_caches")
    def clear(self) -> None:
        """Clear all cache_resource caches."""
        self._maybe_show_deprecation_warning()
        _resource_caches.clear_all()

    def _maybe_show_deprecation_warning(self):
        """If the API is being accessed with the deprecated `st.experimental_singleton` name,
        show a deprecation warning.
        """
        if self._deprecation_warning is not None:
            show_deprecation_warning(self._deprecation_warning)


class ResourceCache(Cache):
    """Manages cached values for a single st.cache_resource function."""

    def __init__(
        self,
        key: str,
        max_entries: float,
        ttl_seconds: float,
        validate: ValidateFunc | None,
        display_name: str,
        allow_widgets: bool,
    ):
        super().__init__()
        self.key = key
        self.display_name = display_name
        self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache(
            maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
        )
        self._mem_cache_lock = threading.Lock()
        self.validate = validate
        self.allow_widgets = allow_widgets

    @property
    def max_entries(self) -> float:
        return cast(float, self._mem_cache.maxsize)

    @property
    def ttl_seconds(self) -> float:
        return cast(float, self._mem_cache.ttl)

    def read_result(self, key: str) -> CachedResult:
        """Read a value and associated messages from the cache.
        Raise `CacheKeyNotFoundError` if the value doesn't exist.
        """
        with self._mem_cache_lock:
            if key not in self._mem_cache:
                # key does not exist in cache.
                raise CacheKeyNotFoundError()

            multi_results: MultiCacheResults = self._mem_cache[key]

            ctx = get_script_run_ctx()
            if not ctx:
                # ScriptRunCtx does not exist (we're probably running in "raw" mode).
                raise CacheKeyNotFoundError()

            widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
            if widget_key not in multi_results.results:
                # widget_key does not exist in cache (this combination of widgets hasn't been
                # seen for the value_key yet).
                raise CacheKeyNotFoundError()

            result = multi_results.results[widget_key]

            if self.validate is not None and not self.validate(result.value):
                # Validate failed: delete the entry and raise an error.
                del multi_results.results[widget_key]
                raise CacheKeyNotFoundError()

            return result

    @gather_metrics("_cache_resource_object")
    def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache."""
        ctx = get_script_run_ctx()
        if ctx is None:
            return

        main_id = st._main.id
        sidebar_id = st.sidebar.id
        if self.allow_widgets:
            widgets = {
                msg.widget_metadata.widget_id
                for msg in messages
                if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
            }
        else:
            widgets = set()

        with self._mem_cache_lock:
            try:
                multi_results = self._mem_cache[key]
            except KeyError:
                multi_results = MultiCacheResults(widget_ids=widgets, results={})

            multi_results.widget_ids.update(widgets)
            widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)

            result = CachedResult(value, messages, main_id, sidebar_id)
            multi_results.results[widget_key] = result
            self._mem_cache[key] = multi_results

    def _clear(self) -> None:
        with self._mem_cache_lock:
            self._mem_cache.clear()

    def get_stats(self) -> list[CacheStat]:
        # Shallow clone our cache. Computing item sizes is potentially
        # expensive, and we want to minimize the time we spend holding
        # the lock.
        with self._mem_cache_lock:
            cache_entries = list(self._mem_cache.values())

        return [
            CacheStat(
                category_name="st_cache_resource",
                cache_name=self.display_name,
                byte_length=asizeof.asizeof(entry),
            )
            for entry in cache_entries
        ]
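A sketch of the `validate` hook documented above (the connection class and health check are illustrative, not part of this commit):

```python
import streamlit as st

class FakeConnection:
    """Hypothetical resource type, for illustration only."""
    healthy = True

def _is_healthy(conn: FakeConnection) -> bool:
    # Called with the cached value on each access; returning False makes
    # ResourceCache.read_result discard the entry and rerun get_connection.
    return conn.healthy

@st.cache_resource(ttl=3600, validate=_is_healthy)
def get_connection(url: str) -> FakeConnection:
    return FakeConnection()
```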
@@ -0,0 +1,31 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum


class CacheType(enum.Enum):
    """The function cache types we implement."""

    DATA = "DATA"
    RESOURCE = "RESOURCE"


def get_decorator_api_name(cache_type: CacheType) -> str:
    """Return the name of the public decorator API for the given CacheType."""
    if cache_type is CacheType.DATA:
        return "cache_data"
    if cache_type is CacheType.RESOURCE:
        return "cache_resource"
    raise RuntimeError(f"Unrecognized CacheType '{cache_type}'")
@@ -0,0 +1,449 @@
|
||||
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Common cache logic shared by st.cache_data and st.cache_resource."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import inspect
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from datetime import timedelta
|
||||
from typing import Any, Callable, overload
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from streamlit import type_util
|
||||
from streamlit.elements.spinner import spinner
|
||||
from streamlit.logger import get_logger
|
||||
from streamlit.runtime.caching.cache_errors import (
|
||||
CacheError,
|
||||
CacheKeyNotFoundError,
|
||||
UnevaluatedDataFrameError,
|
||||
UnhashableParamError,
|
||||
UnhashableTypeError,
|
||||
UnserializableReturnValueError,
|
||||
get_cached_func_name_md,
|
||||
)
|
||||
from streamlit.runtime.caching.cache_type import CacheType
|
||||
from streamlit.runtime.caching.cached_message_replay import (
|
||||
CachedMessageReplayContext,
|
||||
CachedResult,
|
||||
MsgData,
|
||||
replay_cached_messages,
|
||||
)
|
||||
from streamlit.runtime.caching.hashing import update_hash
|
||||
|
||||
_LOGGER = get_logger(__name__)
|
||||
|
||||
# The timer function we use with TTLCache. This is the default timer func, but
|
||||
# is exposed here as a constant so that it can be patched in unit tests.
|
||||
TTLCACHE_TIMER = time.monotonic
|
||||
|
||||
|
||||
@overload
|
||||
def ttl_to_seconds(
|
||||
ttl: float | timedelta | None, *, coerce_none_to_inf: Literal[False]
|
||||
) -> float | None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def ttl_to_seconds(ttl: float | timedelta | None) -> float:
|
||||
...
|
||||
|
||||
|
||||
def ttl_to_seconds(
|
||||
ttl: float | timedelta | None, *, coerce_none_to_inf: bool = True
|
||||
) -> float | None:
|
||||
"""
|
||||
Convert a ttl value to a float representing "number of seconds".
|
||||
"""
|
||||
if coerce_none_to_inf and ttl is None:
|
||||
return math.inf
|
||||
if isinstance(ttl, timedelta):
|
||||
return ttl.total_seconds()
|
||||
return ttl
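
# Illustrative only (comments, not shipped code): how `ttl_to_seconds`
# behaves under its two overloads.
#
#     ttl_to_seconds(timedelta(minutes=5))            # -> 300.0
#     ttl_to_seconds(60)                              # -> 60
#     ttl_to_seconds(None)                            # -> math.inf
#     ttl_to_seconds(None, coerce_none_to_inf=False)  # -> None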


# We show a special "UnevaluatedDataFrame" warning for cached funcs
# that attempt to return one of these unserializable types:
UNEVALUATED_DATAFRAME_TYPES = (
    "snowflake.snowpark.table.Table",
    "snowflake.snowpark.dataframe.DataFrame",
    "pyspark.sql.dataframe.DataFrame",
)


class Cache:
    """Function cache interface. Caches persist across script runs."""

    def __init__(self):
        self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
        self._value_locks_lock = threading.Lock()

    @abstractmethod
    def read_result(self, value_key: str) -> CachedResult:
        """Read a value and associated messages from the cache.

        Raises
        ------
        CacheKeyNotFoundError
            Raised if value_key is not in the cache.

        """
        raise NotImplementedError

    @abstractmethod
    def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache, overwriting any existing
        result that uses the value_key.
        """
        # We *could* `del self._value_locks[value_key]` here, since nobody will be taking
        # a compute_value_lock for this value_key after the result is written.
        raise NotImplementedError

    def compute_value_lock(self, value_key: str) -> threading.Lock:
        """Return the lock that should be held while computing a new cached value.
        In a popular app with a cache that hasn't been pre-warmed, many sessions may try
        to access a not-yet-cached value simultaneously. We use a lock to ensure that
        only one of those sessions computes the value, and the others block until
        the value is computed.
        """
        with self._value_locks_lock:
            return self._value_locks[value_key]

    def clear(self):
        """Clear all values from this cache."""
        with self._value_locks_lock:
            self._value_locks.clear()
        self._clear()

    @abstractmethod
    def _clear(self) -> None:
        """Subclasses must implement this to perform cache-clearing logic."""
        raise NotImplementedError
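

# A minimal sketch of a concrete Cache subclass (illustrative only;
# `_ExampleInMemoryCache` is a hypothetical name - the shipped
# implementations live in the cache_data / cache_resource modules):
class _ExampleInMemoryCache(Cache):
    """Hypothetical dict-backed cache, shown to make the interface concrete."""

    def __init__(self):
        super().__init__()
        self._store: dict[str, CachedResult] = {}

    def read_result(self, value_key: str) -> CachedResult:
        try:
            return self._store[value_key]
        except KeyError:
            raise CacheKeyNotFoundError()

    def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
        # Empty main/sidebar ids are fine for illustration purposes.
        self._store[value_key] = CachedResult(value, messages, "", "")

    def _clear(self) -> None:
        self._store.clear()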


class CachedFuncInfo:
    """Encapsulates data for a cached function instance.

    CachedFuncInfo instances are scoped to a single script run - they're not
    persistent.
    """

    def __init__(
        self,
        func: types.FunctionType,
        show_spinner: bool | str,
        allow_widgets: bool,
    ):
        self.func = func
        self.show_spinner = show_spinner
        self.allow_widgets = allow_widgets

    @property
    def cache_type(self) -> CacheType:
        raise NotImplementedError

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        raise NotImplementedError

    def get_function_cache(self, function_key: str) -> Cache:
        """Get or create the function cache for the given key."""
        raise NotImplementedError


def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
    """Create a callable wrapper around a CachedFuncInfo.

    Calling the wrapper will return the cached value if it's already been
    computed, and will call the underlying function to compute and cache the
    value otherwise.

    The wrapper also has a `clear` function that can be called to clear
    all of the wrapper's cached values.
    """
    cached_func = CachedFunc(info)

    # We'd like to simply return `cached_func`, which is already a Callable.
    # But using `functools.update_wrapper` on the CachedFunc instance
    # itself results in errors when our caching decorators are used to decorate
    # member functions. (See https://github.com/streamlit/streamlit/issues/6109)

    @functools.wraps(info.func)
    def wrapper(*args, **kwargs):
        return cached_func(*args, **kwargs)

    # Give our wrapper its `clear` function.
    # (This results in a spurious mypy error that we suppress.)
    wrapper.clear = cached_func.clear  # type: ignore

    return wrapper


class CachedFunc:
    def __init__(self, info: CachedFuncInfo):
        self._info = info
        self._function_key = _make_function_key(info.cache_type, info.func)

    def __call__(self, *args, **kwargs) -> Any:
        """The wrapper. We'll only call our underlying function on a cache miss."""

        name = self._info.func.__qualname__

        if isinstance(self._info.show_spinner, bool):
            if len(args) == 0 and len(kwargs) == 0:
                message = f"Running `{name}()`."
            else:
                message = f"Running `{name}(...)`."
        else:
            message = self._info.show_spinner

        if self._info.show_spinner or isinstance(self._info.show_spinner, str):
            with spinner(message):
                return self._get_or_create_cached_value(args, kwargs)
        else:
            return self._get_or_create_cached_value(args, kwargs)
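
    # Illustrative usage of `show_spinner` (comments only; the decorators that
    # feed this flag live in the cache_data / cache_resource API modules):
    #
    #     @st.cache_data(show_spinner=False)           # no spinner at all
    #     @st.cache_data(show_spinner="Crunching...")  # custom spinner text
    #
    # With the default `show_spinner=True`, the message is derived from the
    # function's qualified name, as computed above.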

    def _get_or_create_cached_value(
        self, func_args: tuple[Any, ...], func_kwargs: dict[str, Any]
    ) -> Any:
        # Retrieve the function's cache object. We must do this "just-in-time"
        # (as opposed to in the constructor), because caches can be invalidated
        # at any time.
        cache = self._info.get_function_cache(self._function_key)

        # Generate the key for the cached value. This is based on the
        # arguments passed to the function.
        value_key = _make_value_key(
            cache_type=self._info.cache_type,
            func=self._info.func,
            func_args=func_args,
            func_kwargs=func_kwargs,
        )

        try:
            cached_result = cache.read_result(value_key)
            return self._handle_cache_hit(cached_result)
        except CacheKeyNotFoundError:
            return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)

    def _handle_cache_hit(self, result: CachedResult) -> Any:
        """Handle a cache hit: replay the result's cached messages, and return its value."""
        replay_cached_messages(
            result,
            self._info.cache_type,
            self._info.func,
        )
        return result.value

    def _handle_cache_miss(
        self,
        cache: Cache,
        value_key: str,
        func_args: tuple[Any, ...],
        func_kwargs: dict[str, Any],
    ) -> Any:
        """Handle a cache miss: compute a new cached value, write it back to the cache,
        and return that newly-computed value.
        """

        # Implementation notes:
        # - We take a "compute_value_lock" before computing our value. This ensures that
        #   multiple sessions don't try to compute the same value simultaneously.
        #
        # - We use a different lock for each value_key, as opposed to a single lock for
        #   the entire cache, so that unrelated value computations don't block on each other.
        #
        # - When retrieving a cache entry that may not yet exist, we use a "double-checked locking"
        #   strategy: first we try to retrieve the cache entry without taking a value lock. (This
        #   happens in `_get_or_create_cached_value()`.) If that fails because the value hasn't
        #   been computed yet, we take the value lock and then immediately try to retrieve the
        #   cache entry *again*, while holding the lock. If the cache entry exists at this point,
        #   it means that another thread computed the value before us.
        #
        #   This means that the happy path ("cache entry exists") is a wee bit faster because
        #   no lock is acquired. But the unhappy path ("cache entry needs to be recomputed") is
        #   a wee bit slower, because we do two lookups for the entry.

        with cache.compute_value_lock(value_key):
            # We've acquired the lock - but another thread may have acquired it first
            # and already computed the value. So we need to test for a cache hit again,
            # before computing.
            try:
                cached_result = cache.read_result(value_key)
                # Another thread computed the value before us. Early exit!
                return self._handle_cache_hit(cached_result)

            except CacheKeyNotFoundError:
                # We acquired the lock before any other thread. Compute the value!
                with self._info.cached_message_replay_ctx.calling_cached_function(
                    self._info.func, self._info.allow_widgets
                ):
                    computed_value = self._info.func(*func_args, **func_kwargs)

                # We've computed our value, and now we need to write it back to the cache
                # along with any "replay messages" that were generated during value computation.
                messages = self._info.cached_message_replay_ctx._most_recent_messages
                try:
                    cache.write_result(value_key, computed_value, messages)
                    return computed_value
                except (CacheError, RuntimeError):
                    # An exception was thrown while we tried to write to the cache. Report it to the user.
                    # (We catch `RuntimeError` here because it will be raised by Apache Spark if we do not
                    # collect the dataframe before using `st.cache_data`.)
                    if any(
                        type_util.is_type(computed_value, type_name)
                        for type_name in UNEVALUATED_DATAFRAME_TYPES
                    ):
                        raise UnevaluatedDataFrameError(
                            f"""
                            The function {get_cached_func_name_md(self._info.func)} is decorated with `st.cache_data` but it returns an unevaluated dataframe
                            of type `{type_util.get_fqn_type(computed_value)}`. Please call `collect()` or `to_pandas()` on the dataframe before returning it,
                            so `st.cache_data` can serialize and cache it."""
                        )
                    raise UnserializableReturnValueError(
                        return_value=computed_value, func=self._info.func
                    )

    def clear(self):
        """Clear the wrapped function's associated cache."""
        cache = self._info.get_function_cache(self._function_key)
        cache.clear()
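

# A minimal sketch of the double-checked locking strategy described in
# `_handle_cache_miss` above (illustrative only; `_example_cache_lookup`
# is a hypothetical helper, not part of this module's API):
def _example_cache_lookup(cache: Cache, value_key: str, compute: Callable[[], Any]) -> Any:
    try:
        # Fast path: no lock taken at all on a cache hit.
        return cache.read_result(value_key).value
    except CacheKeyNotFoundError:
        pass
    with cache.compute_value_lock(value_key):
        try:
            # Re-check under the lock: another thread may have won the race.
            return cache.read_result(value_key).value
        except CacheKeyNotFoundError:
            value = compute()
            cache.write_result(value_key, value, [])
            return value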


def _make_value_key(
    cache_type: CacheType,
    func: types.FunctionType,
    func_args: tuple[Any, ...],
    func_kwargs: dict[str, Any],
) -> str:
    """Create the key for a value within a cache.

    This key is generated from the function's arguments. All arguments
    will be hashed, except for those named with a leading "_".

    Raises
    ------
    StreamlitAPIException
        Raised (with a nicely-formatted explanation message) if we encounter
        an un-hashable arg.
    """

    # Create a (name, value) list of all *args and **kwargs passed to the
    # function.
    arg_pairs: list[tuple[str | None, Any]] = []
    for arg_idx in range(len(func_args)):
        arg_name = _get_positional_arg_name(func, arg_idx)
        arg_pairs.append((arg_name, func_args[arg_idx]))

    for kw_name, kw_val in func_kwargs.items():
        # **kwargs ordering is preserved, per PEP 468
        # https://www.python.org/dev/peps/pep-0468/, so this iteration is
        # deterministic.
        arg_pairs.append((kw_name, kw_val))

    # Create the hash from each arg value, except for those args whose name
    # starts with "_". (Underscore-prefixed args are deliberately excluded from
    # hashing.)
    args_hasher = hashlib.new("md5")
    for arg_name, arg_value in arg_pairs:
        if arg_name is not None and arg_name.startswith("_"):
            _LOGGER.debug("Not hashing %s because it starts with _", arg_name)
            continue

        try:
            update_hash(
                (arg_name, arg_value),
                hasher=args_hasher,
                cache_type=cache_type,
            )
        except UnhashableTypeError as exc:
            raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)

    value_key = args_hasher.hexdigest()
    _LOGGER.debug("Cache key: %s", value_key)

    return value_key
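
# Illustrative only: with the underscore convention above, `_db` is skipped
# when hashing, so the value key depends only on `query`.
#
#     @st.cache_data
#     def run_query(_db, query: str):
#         return _db.execute(query)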


def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
    """Create the unique key for a function's cache.

    A function's key is stable across reruns of the app, and changes when
    the function's source code changes.
    """
    func_hasher = hashlib.new("md5")

    # Include the function's __module__ and __qualname__ strings in the hash.
    # This means that two identical functions in different modules
    # will not share a hash; it also means that two identical *nested*
    # functions in the same module will not share a hash.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=func_hasher,
        cache_type=cache_type,
    )

    # Include the function's source code in its hash. If the source code can't
    # be retrieved, fall back to the function's bytecode instead.
    source_code: str | bytes
    try:
        source_code = inspect.getsource(func)
    except OSError as e:
        _LOGGER.debug(
            "Failed to retrieve function's source code when building its key; falling back to bytecode. err=%s",
            e,
        )
        source_code = func.__code__.co_code

    update_hash(
        source_code,
        hasher=func_hasher,
        cache_type=cache_type,
    )

    cache_key = func_hasher.hexdigest()
    return cache_key


def _get_positional_arg_name(func: types.FunctionType, arg_index: int) -> str | None:
    """Return the name of a function's positional argument.

    If arg_index is out of range, or refers to a parameter that is not a
    named positional argument (e.g. an *args, **kwargs, or keyword-only param),
    return None instead.
    """
    if arg_index < 0:
        return None

    params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values())
    if arg_index >= len(params):
        return None

    if params[arg_index].kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.POSITIONAL_ONLY,
    ):
        return params[arg_index].name

    return None
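
# Illustrative only, assuming `def f(a, b, *args, c=1, **kwargs)`:
# arg_index 0 -> "a" and 1 -> "b"; index 2 hits *args and returns None,
# as would the keyword-only and **kwargs parameters.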
@@ -0,0 +1,478 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import hashlib
import threading
import types
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Union

from google.protobuf.message import Message
from typing_extensions import Protocol, runtime_checkable

import streamlit as st
from streamlit import runtime, util
from streamlit.elements import NONWIDGET_ELEMENTS, WIDGETS
from streamlit.logger import get_logger
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_errors import (
    CachedStFunctionWarning,
    CacheReplayClosureError,
)
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.hashing import update_hash
from streamlit.runtime.scriptrunner.script_run_context import (
    ScriptRunContext,
    get_script_run_ctx,
)
from streamlit.runtime.state.common import WidgetMetadata

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator

_LOGGER = get_logger(__name__)


@runtime_checkable
class Widget(Protocol):
    id: str


@dataclass(frozen=True)
class WidgetMsgMetadata:
    """Everything needed for replaying a widget and treating it as an implicit
    argument to a cached function, beyond what is stored for all elements.
    """

    widget_id: str
    widget_value: Any
    metadata: WidgetMetadata[Any]


@dataclass(frozen=True)
class MediaMsgData:
    media: bytes | str
    mimetype: str
    media_id: str


@dataclass(frozen=True)
class ElementMsgData:
    """An element's message and related metadata for
    replaying that element's function call.

    widget_metadata is filled in if and only if this element is a widget.
    media_data is filled in iff this is a media element (image, audio, video).
    """

    delta_type: str
    message: Message
    id_of_dg_called_on: str
    returned_dgs_id: str
    widget_metadata: WidgetMsgMetadata | None = None
    media_data: list[MediaMsgData] | None = None


@dataclass(frozen=True)
class BlockMsgData:
    message: Block
    id_of_dg_called_on: str
    returned_dgs_id: str


MsgData = Union[ElementMsgData, BlockMsgData]

"""
Note [Cache result structure]

The cache for a decorated function's results is split into two parts to enable
handling widgets invoked by the function.

Widgets act as implicit additional inputs to the cached function, so they should
be used when deriving the cache key. However, we don't know what widgets are
involved without first calling the function! So, we use the first execution
of the function with a particular cache key to record what widgets are used,
and use the current values of those widgets to derive a second cache key to
look up the function execution's results. The combination of first and second
cache keys acts as one true cache key, just split up because the second part depends
on the first.

We need to treat widgets as implicit arguments of the cached function, because
the behavior of the function, including what elements are created and what it
returns, can be and usually will be influenced by the values of those widgets.
For example:
> @st.memo
> def example_fn(x):
>     y = x + 1
>     if st.checkbox("hi"):
>         st.write("you checked the thing")
>         y = 0
>     return y
>
> example_fn(2)

If the checkbox is checked, the function call should return 0 and the checkbox and
message should be rendered. If the checkbox isn't checked, only the checkbox should
render, and the function will return 3.


There is a small catch in this. Since what widgets execute could depend on the values of
any prior widgets, if we replace the `st.write` call in the example with a slider,
the first time it runs, we would miss the slider because it wasn't called,
so when we later execute the function with the checkbox checked, the widget cache key
would not include the state of the slider, and would incorrectly get a cache hit
for a different slider value.

In principle the cache could be function+args key -> 1st widget key -> 2nd widget key
... -> final widget key, with each widget dependent on the exact values of the widgets
seen prior. This would prevent unnecessary cache misses due to differing values of widgets
that wouldn't affect the function's execution because they aren't even created.
But this would add even more complexity and both conceptual and runtime overhead, so it is
unclear if it would be worth doing.

Instead, we can keep the widgets as one cache key, and if we encounter a new widget
while executing the function, we just update the list of widgets to include it.
This will cause existing cached results to be invalidated, which is bad, but to
avoid it we would need to keep around the full list of widgets and values for each
widget cache key so we could compute the updated key, which is probably too expensive
to be worth it.
"""


@dataclass
class CachedResult:
    """The full results of calling a cache-decorated function, enough to
    replay the st functions called while executing it.
    """

    value: Any
    messages: list[MsgData]
    main_id: str
    sidebar_id: str


@dataclass
class MultiCacheResults:
    """Widgets called by a cache-decorated function, and a mapping of the
    widget-derived cache key to the final results of executing the function.
    """

    widget_ids: set[str]
    results: dict[str, CachedResult]

    def get_current_widget_key(
        self, ctx: ScriptRunContext, cache_type: CacheType
    ) -> str:
        state = ctx.session_state
        # Compute the key using only widgets that have values. A missing widget
        # can be ignored because we only care about getting different keys
        # for different widget values, and for that purpose doing nothing
        # to the running hash is just as good as including the widget with a
        # sentinel value. But by excluding it, we might get to reuse a result
        # saved before we knew about that widget.
        widget_values = [
            (wid, state[wid]) for wid in sorted(self.widget_ids) if wid in state
        ]
        widget_key = _make_widget_key(widget_values, cache_type)
        return widget_key


"""
Note [DeltaGenerator method invocation]
There are two top level DG instances defined for all apps:
`main`, which is for putting elements in the main part of the app
`sidebar`, for the sidebar

There are 3 different ways an st function can be invoked:
1. Implicitly on the main DG instance (plain `st.foo` calls)
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)

To simplify replaying messages from a cached function result, we convert all of these
to explicit invocations. How they get rewritten depends on if the invocation was
implicit vs explicit, and if the target DG has been seen/produced during replay.

Implicit invocation on a known DG -> Explicit invocation on that DG
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
    with st.container():
        my_cache_decorated_function()

    This is situation 2 above, and the DG is a block entirely outside our function call,
    so we interpret it as "put this element in the enclosing contextmanager block"
    (or main if there isn't one), which is achieved by invoking on main.
Explicit invocation on a known DG -> No change needed
Explicit invocation on an unknown DG -> Raise an error
    We have no way to identify the target DG, and it may not even be present in the
    current script run, so the least surprising thing to do is raise an error.

"""


class CachedMessageReplayContext(threading.local):
    """A utility for storing messages generated by `st` commands called inside
    a cached function.

    Data is stored in a thread-local object, so it's safe to use an instance
    of this class across multiple threads.
    """

    def __init__(self, cache_type: CacheType):
        self._cached_func_stack: list[types.FunctionType] = []
        self._suppress_st_function_warning = 0
        self._cached_message_stack: list[list[MsgData]] = []
        self._seen_dg_stack: list[set[str]] = []
        self._most_recent_messages: list[MsgData] = []
        self._registered_metadata: WidgetMetadata[Any] | None = None
        self._media_data: list[MediaMsgData] = []
        self._cache_type = cache_type
        self._allow_widgets: int = 0

    def __repr__(self) -> str:
        return util.repr_(self)

    @contextlib.contextmanager
    def calling_cached_function(
        self, func: types.FunctionType, allow_widgets: bool
    ) -> Iterator[None]:
        """Context manager that should wrap the invocation of a cached function.
        It allows us to track any `st.foo` messages that are generated from inside the function
        for playback during cache retrieval.
        """
        self._cached_func_stack.append(func)
        self._cached_message_stack.append([])
        self._seen_dg_stack.append(set())
        if allow_widgets:
            self._allow_widgets += 1
        try:
            yield
        finally:
            self._cached_func_stack.pop()
            self._most_recent_messages = self._cached_message_stack.pop()
            self._seen_dg_stack.pop()
            if allow_widgets:
                self._allow_widgets -= 1

    def save_element_message(
        self,
        delta_type: str,
        element_proto: Message,
        invoked_dg_id: str,
        used_dg_id: str,
        returned_dg_id: str,
    ) -> None:
        """Record the element protobuf as having been produced during any currently
        executing cached functions, so they can be replayed any time the function's
        execution is skipped because they're in the cache.
        """
        if not runtime.exists():
            return
        if len(self._cached_message_stack) >= 1:
            id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
            # Arrow dataframes have an ID but only set it when used as data editor
            # widgets, so we have to check that the ID has been actually set to
            # know if an element is a widget.
            if isinstance(element_proto, Widget) and element_proto.id:
                wid = element_proto.id
                # TODO replace `Message` with a more precise type
                if not self._registered_metadata:
                    _LOGGER.error(
                        "Trying to save widget message that wasn't registered. This should not be possible."
                    )
                    raise AttributeError
                widget_meta = WidgetMsgMetadata(
                    wid, None, metadata=self._registered_metadata
                )
            else:
                widget_meta = None

            media_data = self._media_data

            element_msg_data = ElementMsgData(
                delta_type,
                element_proto,
                id_to_save,
                returned_dg_id,
                widget_meta,
                media_data,
            )
            for msgs in self._cached_message_stack:
                if self._allow_widgets or widget_meta is None:
                    msgs.append(element_msg_data)

        # Reset instance state, now that it has been used for the
        # associated element.
        self._media_data = []
        self._registered_metadata = None

        for s in self._seen_dg_stack:
            s.add(returned_dg_id)

    def save_block_message(
        self,
        block_proto: Block,
        invoked_dg_id: str,
        used_dg_id: str,
        returned_dg_id: str,
    ) -> None:
        id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
        for msgs in self._cached_message_stack:
            msgs.append(BlockMsgData(block_proto, id_to_save, returned_dg_id))
        for s in self._seen_dg_stack:
            s.add(returned_dg_id)

    def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
        """Select the id of the DG that this message should be invoked on
        during message replay.

        See Note [DeltaGenerator method invocation]

        invoked_id is the DG the st function was called on, usually `st._main`.
        acting_on_id is the DG the st function ultimately runs on, which may be different
        if the invoked DG delegated to another one because it was in a `with` block.
        """
        if len(self._seen_dg_stack) > 0 and acting_on_id in self._seen_dg_stack[-1]:
            return acting_on_id
        else:
            return invoked_id

    def save_widget_metadata(self, metadata: WidgetMetadata[Any]) -> None:
        self._registered_metadata = metadata

    def save_image_data(
        self, image_data: bytes | str, mimetype: str, image_id: str
    ) -> None:
        self._media_data.append(MediaMsgData(image_data, mimetype, image_id))

    @contextlib.contextmanager
    def suppress_cached_st_function_warning(self) -> Iterator[None]:
        self._suppress_st_function_warning += 1
        try:
            yield
        finally:
            self._suppress_st_function_warning -= 1
            assert self._suppress_st_function_warning >= 0

    def maybe_show_cached_st_function_warning(
        self,
        dg: "DeltaGenerator",
        st_func_name: str,
    ) -> None:
        """If appropriate, warn about calling st.foo inside @memo.

        DeltaGenerator's @_with_element and @_widget wrappers use this to warn
        the user when they're calling st.foo() from within a function that is
        wrapped in @st.cache.

        Parameters
        ----------
        dg : DeltaGenerator
            The DeltaGenerator to publish the warning to.

        st_func_name : str
            The name of the Streamlit function that was called.

        """
        # There are some elements not in either list, which we still want to warn about.
        # Ideally we will fix this by either updating the lists or creating a better
        # way of categorizing elements.
        if st_func_name in NONWIDGET_ELEMENTS:
            return
        if st_func_name in WIDGETS and self._allow_widgets > 0:
            return

        if len(self._cached_func_stack) > 0 and self._suppress_st_function_warning <= 0:
            cached_func = self._cached_func_stack[-1]
            self._show_cached_st_function_warning(dg, st_func_name, cached_func)

    def _show_cached_st_function_warning(
        self,
        dg: "DeltaGenerator",
        st_func_name: str,
        cached_func: types.FunctionType,
    ) -> None:
        # Avoid infinite recursion by suppressing additional cached
        # function warnings from within the cached function warning.
        with self.suppress_cached_st_function_warning():
            e = CachedStFunctionWarning(self._cache_type, st_func_name, cached_func)
            dg.exception(e)


def replay_cached_messages(
    result: CachedResult, cache_type: CacheType, cached_func: types.FunctionType
) -> None:
    """Replay the st element function calls that happened when executing a
    cache-decorated function.

    When a cache function is executed, we record the element and block messages
    produced, and use those to reproduce the DeltaGenerator calls, so the elements
    will appear in the web app even when execution of the function is skipped
    because the result was cached.

    To make this work, for each st function call we record an identifier for the
    DG it was effectively called on (see Note [DeltaGenerator method invocation]).
    We also record the identifier for each DG returned by an st function call, if
    it returns one. Then, for each recorded message, we get the current DG instance
    corresponding to the DG the message was originally called on, and enqueue the
    message using that, recording any new DGs produced in case a later st function
    call is on one of them.
    """
    from streamlit.delta_generator import DeltaGenerator
    from streamlit.runtime.state.widgets import register_widget_from_metadata

    # Maps originally recorded dg ids to this script run's version of that dg
    returned_dgs: dict[str, DeltaGenerator] = {}
    returned_dgs[result.main_id] = st._main
    returned_dgs[result.sidebar_id] = st.sidebar
    ctx = get_script_run_ctx()

    try:
        for msg in result.messages:
            if isinstance(msg, ElementMsgData):
                if msg.widget_metadata is not None:
                    register_widget_from_metadata(
                        msg.widget_metadata.metadata,
                        ctx,
                        None,
                        msg.delta_type,
                    )
                if msg.media_data is not None:
                    for data in msg.media_data:
                        runtime.get_instance().media_file_mgr.add(
                            data.media, data.mimetype, data.media_id
                        )
                dg = returned_dgs[msg.id_of_dg_called_on]
                maybe_dg = dg._enqueue(msg.delta_type, msg.message)
                if isinstance(maybe_dg, DeltaGenerator):
                    returned_dgs[msg.returned_dgs_id] = maybe_dg
            elif isinstance(msg, BlockMsgData):
                dg = returned_dgs[msg.id_of_dg_called_on]
                new_dg = dg._block(msg.message)
                returned_dgs[msg.returned_dgs_id] = new_dg
    except KeyError:
        raise CacheReplayClosureError(cache_type, cached_func)


def _make_widget_key(widgets: list[tuple[str, Any]], cache_type: CacheType) -> str:
    """Generate a key for the given list of widgets used in a cache-decorated function.

    Keys are generated by hashing the IDs and values of the widgets in the given list.
    """
    func_hasher = hashlib.new("md5")
    for widget_id_val in widgets:
        update_hash(widget_id_val, func_hasher, cache_type)

    return func_hasher.hexdigest()
@@ -0,0 +1,395 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hashing for st.memo and st.singleton."""
import collections
import dataclasses
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import unittest.mock
import weakref
from enum import Enum
from typing import Any, Dict, List, Optional, Pattern

from streamlit import type_util, util
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.uploaded_file_manager import UploadedFile

# If a dataframe has more than this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE = 100000
_PANDAS_SAMPLE_SIZE = 10000

# Similar to dataframes, we also sample large numpy arrays.
_NP_SIZE_LARGE = 1000000
_NP_SAMPLE_SIZE = 100000

# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"


def update_hash(val: Any, hasher, cache_type: CacheType) -> None:
    """Updates a hashlib hasher with the hash of val.

    This is the main entrypoint to hashing.py.
    """
    ch = _CacheFuncHasher(cache_type)
    ch.update(hasher, val)
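

# A minimal usage sketch (illustrative only; `_example_update_hash` is a
# hypothetical name, not part of this module's API):
def _example_update_hash() -> str:
    # Build a digest the same way cache keys are built from argument values.
    h = hashlib.new("md5")
    update_hash(("answer", 42), hasher=h, cache_type=CacheType.DATA)
    return h.hexdigest()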


class _HashStack:
    """Stack of what has been hashed, for debug and circular reference detection.

    This internally keeps 1 stack per thread.

    Internally, this stores the ID of pushed objects rather than the objects
    themselves because otherwise the "in" operator inside __contains__ would
    fail for objects that don't return a boolean for the "==" operator. For
    example, arr == 10 where arr is a NumPy array returns another NumPy array.
    This causes the "in" to crash since it expects a boolean.
    """

    def __init__(self):
        self._stack: collections.OrderedDict[int, List[Any]] = collections.OrderedDict()

    def __repr__(self) -> str:
        return util.repr_(self)

    def push(self, val: Any):
        self._stack[id(val)] = val

    def pop(self):
        self._stack.popitem()

    def __contains__(self, val: Any):
        return id(val) in self._stack


class _HashStacks:
    """Stacks of what has been hashed, with at most 1 stack per thread."""

    def __init__(self):
        self._stacks: weakref.WeakKeyDictionary[
            threading.Thread, _HashStack
        ] = weakref.WeakKeyDictionary()

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def current(self) -> _HashStack:
        current_thread = threading.current_thread()

        stack = self._stacks.get(current_thread, None)

        if stack is None:
            stack = _HashStack()
            self._stacks[current_thread] = stack

        return stack


hash_stacks = _HashStacks()


def _int_to_bytes(i: int) -> bytes:
    num_bytes = (i.bit_length() + 8) // 8
    return i.to_bytes(num_bytes, "little", signed=True)
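
# Worked examples (illustrative): _int_to_bytes(0) -> b"\x00",
# _int_to_bytes(-1) -> b"\xff", and _int_to_bytes(255) -> b"\xff\x00";
# the extra high byte keeps the signed little-endian encoding unambiguous.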


def _key(obj: Optional[Any]) -> Any:
    """Return key for memoization."""

    if obj is None:
        return None

    def is_simple(obj):
        return (
            isinstance(obj, bytes)
            or isinstance(obj, bytearray)
            or isinstance(obj, str)
            or isinstance(obj, float)
            or isinstance(obj, int)
            or isinstance(obj, bool)
            or obj is None
        )

    if is_simple(obj):
        return obj

    if isinstance(obj, tuple):
        if all(map(is_simple, obj)):
            return obj

    if isinstance(obj, list):
        if all(map(is_simple, obj)):
            return ("__l", tuple(obj))

    if (
        type_util.is_type(obj, "pandas.core.frame.DataFrame")
        or type_util.is_type(obj, "numpy.ndarray")
        or inspect.isbuiltin(obj)
        or inspect.isroutine(obj)
        or inspect.iscode(obj)
    ):
        return id(obj)

    return NoResult


class _CacheFuncHasher:
    """A hasher that can hash objects with cycles."""

    def __init__(self, cache_type: CacheType):
        self._hashes: Dict[Any, bytes] = {}

        # The number of bytes in the hash.
        self.size = 0

        self.cache_type = cache_type

    def __repr__(self) -> str:
        return util.repr_(self)

    def to_bytes(self, obj: Any) -> bytes:
        """Add memoization to _to_bytes and protect against cycles in data structures."""
        tname = type(obj).__qualname__.encode()
        key = (tname, _key(obj))

        # Memoize if possible.
        if key[1] is not NoResult:
            if key in self._hashes:
                return self._hashes[key]

        # Break recursive cycles.
        if obj in hash_stacks.current:
            return _CYCLE_PLACEHOLDER

        hash_stacks.current.push(obj)

        try:
            # Hash the input
            b = b"%s:%s" % (tname, self._to_bytes(obj))

            # Hmmm... It's possible that the size calculation is wrong. When we
            # call to_bytes inside _to_bytes things get double-counted.
            self.size += sys.getsizeof(b)

            if key[1] is not NoResult:
                self._hashes[key] = b

        finally:
            # In case an UnhashableTypeError (or other) error is thrown, clean up the
            # stack so we don't get false positives in future hashing calls
            hash_stacks.current.pop()

        return b

    def update(self, hasher, obj: Any) -> None:
        """Update the provided hasher with the hash of an object."""
        b = self.to_bytes(obj)
        hasher.update(b)

    def _to_bytes(self, obj: Any) -> bytes:
        """Hash objects to bytes, including code with dependencies.

        Python's built-in `hash` does not produce consistent results across
        runs.
        """

        if isinstance(obj, unittest.mock.Mock):
            # Mock objects can appear to be infinitely
            # deep, so we don't try to hash them at all.
            return self.to_bytes(id(obj))

        elif isinstance(obj, bytes) or isinstance(obj, bytearray):
            return obj

        elif isinstance(obj, str):
            return obj.encode()

        elif isinstance(obj, float):
            return self.to_bytes(hash(obj))

        elif isinstance(obj, int):
            return _int_to_bytes(obj)

        elif isinstance(obj, (list, tuple)):
            h = hashlib.new("md5")
            for item in obj:
                self.update(h, item)
            return h.digest()

        elif isinstance(obj, dict):
            h = hashlib.new("md5")
            for item in obj.items():
                self.update(h, item)
            return h.digest()

        elif obj is None:
            return b"0"

        elif obj is True:
            return b"1"

        elif obj is False:
            return b"0"

        elif dataclasses.is_dataclass(obj):
            return self.to_bytes(dataclasses.asdict(obj))

        elif isinstance(obj, Enum):
            return str(obj).encode()

        elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
            obj, "pandas.core.series.Series"
        ):
            import pandas as pd

            if len(obj) >= _PANDAS_ROWS_LARGE:
                obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
            try:
                return b"%s" % pd.util.hash_pandas_object(obj).sum()
            except TypeError:
                # Use pickle if pandas cannot hash the object, for example if
                # it contains unhashable objects.
                return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)

        elif type_util.is_type(obj, "numpy.ndarray"):
            h = hashlib.new("md5")
            self.update(h, obj.shape)

            if obj.size >= _NP_SIZE_LARGE:
                import numpy as np

                state = np.random.RandomState(0)
                obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)

            self.update(h, obj.tobytes())
            return h.digest()
        elif type_util.is_type(obj, "PIL.Image.Image"):
            import numpy as np

            # we don't just hash the results of obj.tobytes() because we want to use
            # the sampling logic for numpy data
            np_array = np.frombuffer(obj.tobytes(), dtype="uint8")
            return self.to_bytes(np_array)

        elif inspect.isbuiltin(obj):
            return bytes(obj.__name__.encode())

        elif type_util.is_type(obj, "builtins.mappingproxy") or type_util.is_type(
            obj, "builtins.dict_items"
        ):
            return self.to_bytes(dict(obj))

        elif type_util.is_type(obj, "builtins.getset_descriptor"):
            return bytes(obj.__qualname__.encode())

        elif isinstance(obj, UploadedFile):
            # UploadedFile is a BytesIO (thus IOBase) but has a name.
            # It does not have a timestamp so this must come before
            # temporary files.
            h = hashlib.new("md5")
            self.update(h, obj.name)
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif hasattr(obj, "name") and (
            isinstance(obj, io.IOBase)
            # Handle temporary files used during testing
            or isinstance(obj, tempfile._TemporaryFileWrapper)
        ):
            # Hash files as name + last modification date + offset.
            # NB: we're using hasattr("name") to differentiate between
            # on-disk and in-memory StringIO/BytesIO file representations.
            # That means that this condition must come *before* the next
            # condition, which just checks for StringIO/BytesIO.
            h = hashlib.new("md5")
            obj_name = getattr(obj, "name", "wonthappen")  # Just to appease MyPy.
            self.update(h, obj_name)
            self.update(h, os.path.getmtime(obj_name))
            self.update(h, obj.tell())
            return h.digest()

        elif isinstance(obj, Pattern):
            return self.to_bytes([obj.pattern, obj.flags])

        elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
            # Hash in-memory StringIO/BytesIO by their full contents
            # and seek position.
            h = hashlib.new("md5")
            self.update(h, obj.tell())
            self.update(h, obj.getvalue())
            return h.digest()

        elif type_util.is_type(obj, "numpy.ufunc"):
            # For numpy.remainder, this returns remainder.
            return bytes(obj.__name__.encode())

        elif inspect.ismodule(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # so the current warning is quite annoying...
            # st.warning(('Streamlit does not support hashing modules. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name for internal modules.
            return self.to_bytes(obj.__name__)

        elif inspect.isclass(obj):
            # TODO: Figure out how to best show this kind of warning to the
            # user. In the meantime, show nothing. This scenario is too common,
            # (e.g. in every "except" statement) so the current warning is
            # quite annoying...
            # st.warning(('Streamlit does not support hashing classes. '
            #             'We did not hash `%s`.') % obj.__name__)
            # TODO: Hash more than just the name of classes.
            return self.to_bytes(obj.__name__)

        elif isinstance(obj, functools.partial):
            # The return value of functools.partial is not a plain function:
            # it's a callable object that remembers the original function plus
            # the values you pickled into it. So here we need to special-case it.
            h = hashlib.new("md5")
            self.update(h, obj.args)
            self.update(h, obj.func)
            self.update(h, obj.keywords)
            return h.digest()

        else:
            # As a last resort, hash the output of the object's __reduce__ method
            h = hashlib.new("md5")
            try:
                reduce_data = obj.__reduce__()
            except Exception as ex:
                raise UnhashableTypeError() from ex

            for item in reduce_data:
                self.update(h, item)
            return h.digest()


class NoResult:
    """Placeholder class for return values when None is meaningful."""

    pass
@@ -0,0 +1,29 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorage as CacheStorage,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorageContext as CacheStorageContext,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorageError as CacheStorageError,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorageKeyNotFoundError as CacheStorageKeyNotFoundError,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorageManager as CacheStorageManager,
)
@@ -0,0 +1,240 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Declares the CacheStorageContext dataclass, which contains parameter information for
each function decorated by `@st.cache_data` (for example: ttl, max_entries, etc.)

Declares the CacheStorageManager protocol, whose implementations are used
to create CacheStorage instances, to optionally clear all cache storages
that were created by this manager, and to check if the context is valid for the storage.

Declares the CacheStorage protocol, whose implementations are used to store cached
values for a single `@st.cache_data` decorated function serialized as bytes.

How these classes work together
-------------------------------
- CacheStorageContext : a dataclass that contains the parameters from
  `@st.cache_data` that are passed to the CacheStorageManager.create() method.

- CacheStorageManager : each instance of this is able to create CacheStorage
  instances, and optionally to clear data of all cache storages.

- CacheStorage : each instance of this is able to get, set, delete, and clear
  entries for a single `@st.cache_data` decorated function.

    ┌───────────────────────────────┐
    │                               │
    │     CacheStorageManager       │
    │                               │
    │ - clear_all (optional)        │
    │ - check_context               │
    │                               │
    └──┬────────────────────────────┘
       │
       │                ┌──────────────────────┐
       │                │     CacheStorage     │
       │ create(context)│                      │
       └────────────────►  - get               │
                        │  - set               │
                        │  - delete            │
                        │  - close  (optional) │
                        │  - clear             │
                        └──────────────────────┘
"""

from __future__ import annotations

from abc import abstractmethod
from dataclasses import dataclass

from typing_extensions import Literal, Protocol


class CacheStorageError(Exception):
    """Base exception raised by the cache storage"""


class CacheStorageKeyNotFoundError(CacheStorageError):
    """Raised when the key is not found in the cache storage"""


class InvalidCacheStorageContext(CacheStorageError):
    """Raised if the cache storage manager is not able to work with the
    provided CacheStorageContext.
    """


@dataclass(frozen=True)
class CacheStorageContext:
    """Context passed to the cache storage during initialization.
    These are the normalized parameters that are passed to the
    CacheStorageManager.create() method.

    Parameters
    ----------
    function_key: str
        A hash computed from the name and source code of the function
        decorated by `@st.cache_data`

    function_display_name: str
        The display name of the function that is decorated by `@st.cache_data`

    ttl_seconds : float or None
        The time-to-live for the keys in storage, in seconds. If None, the entry
        will never expire.

    max_entries : int or None
        The maximum number of entries to store in the cache storage.
        If None, the cache storage will not limit the number of entries.

    persist : Literal["disk"] or None
        The persistence mode for the cache storage.
        This is a legacy parameter used by Streamlit's current cache storage
        implementation. It may be ignored by a cache storage implementation
        if the storage does not support persistence, or if it is persistent
        by default.
    """

    function_key: str
    function_display_name: str
    ttl_seconds: float | None = None
    max_entries: int | None = None
    persist: Literal["disk"] | None = None


class CacheStorage(Protocol):
    """Cache storage protocol that should be implemented by concrete cache storages.
    Used to store cached values for a single `@st.cache_data` decorated function
    serialized as bytes.

    CacheStorage instances should be created by the `CacheStorageManager.create()`
    method.

    Notes
    -----
    Threading: The methods of this protocol may be called from multiple threads.
    It is the responsibility of the concrete implementation to provide
    thread safety guarantees.
    """

    @abstractmethod
    def get(self, key: str) -> bytes:
        """Returns the stored value for the key.

        Raises
        ------
        CacheStorageKeyNotFoundError
            Raised if the key is not in the storage.
        """
        raise NotImplementedError

    @abstractmethod
    def set(self, key: str, value: bytes) -> None:
        """Sets the value for a given key"""
        raise NotImplementedError

    @abstractmethod
    def delete(self, key: str) -> None:
        """Delete a given key"""
        raise NotImplementedError

    @abstractmethod
    def clear(self) -> None:
        """Remove all keys for the storage"""
        raise NotImplementedError

    def close(self) -> None:
        """Close the cache storage. This method is optional to implement, and
        should be used to release open resources (e.g. a database connection)
        before the storage instance is deleted.
        """
        pass


class CacheStorageManager(Protocol):
    """Cache storage manager protocol that should be implemented by concrete
    cache storage managers.

    It is responsible for:
    - Creating cache storage instances for the specific
      decorated functions,
    - Validating the context for the cache storages.
    - Optionally clearing all cache storages in an optimal way.

    It should be created during Runtime initialization.
    """

    @abstractmethod
    def create(self, context: CacheStorageContext) -> CacheStorage:
        """Creates a new cache storage instance.
        Please note that the ttl, max_entries and other context fields are specific
        to the whole storage, not to individual keys.

        Notes
        -----
        Threading: Should be safe to call from any thread.
        """
        raise NotImplementedError

    def clear_all(self) -> None:
        """Remove everything possible from the cache storages in an optimal way.
        The meaningful default behaviour is to raise NotImplementedError, so this
        is not an abstractmethod.

        The method is optional to implement: the cache data API will fall back to
        removing all available storages one by one via the storage.clear() method
        if clear_all raises NotImplementedError.

        Raises
        ------
        NotImplementedError
            Raised if the storage manager does not provide an ability to clear
            all storages at once in an optimal way.

        Notes
        -----
        Threading: This method may be called from multiple threads.
        It is the responsibility of the concrete implementation to ensure
        thread safety guarantees.
        """
        raise NotImplementedError

    def check_context(self, context: CacheStorageContext) -> None:
        """Checks if the context is valid for the storage manager.
        This method should not return anything, but should log a message or raise
        an exception if the context is invalid.

        If an exception is raised, we do not handle it and let it propagate.

        check_context is called only once, at the moment the `@st.cache_data`
        decorator is created for a specific function, so it is not called for
        every cache hit.

        Parameters
        ----------
        context: CacheStorageContext
            The context to check for the storage manager. A dummy function_key is
            used in the context, since it has not been computed at the point this
            method is called.

        Raises
        ------
        InvalidCacheStorageContext
            Raised if the cache storage manager is not able to work with the
            provided CacheStorageContext. When possible we should log a message
            instead, since this exception will be propagated to the user.

        Notes
        -----
        Threading: Should be safe to call from any thread.
        """

        pass
|
||||
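
# A minimal sketch of a matching manager (illustrative only; the class name
# is hypothetical). clear_all is deliberately left raising
# NotImplementedError to exercise the documented fallback: the cache data API
# then clears each storage one by one via storage.clear().
class _ExampleCacheStorageManager:
    def create(self, context: CacheStorageContext) -> CacheStorage:
        return _ExampleDictCacheStorage()

    def clear_all(self) -> None:
        raise NotImplementedError

    def check_context(self, context: CacheStorageContext) -> None:
        # Accept every context; a real manager might log a warning here for
        # options it cannot honor (compare LocalDiskCacheStorageManager below).
        pass
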
@@ -0,0 +1,60 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorage,
    CacheStorageContext,
    CacheStorageKeyNotFoundError,
    CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
    InMemoryCacheStorageWrapper,
)


class MemoryCacheStorageManager(CacheStorageManager):
    def create(self, context: CacheStorageContext) -> CacheStorage:
        """Creates a new cache storage instance wrapped with an in-memory cache layer"""
        persist_storage = DummyCacheStorage()
        return InMemoryCacheStorageWrapper(
            persist_storage=persist_storage, context=context
        )

    def clear_all(self) -> None:
        raise NotImplementedError

    def check_context(self, context: CacheStorageContext) -> None:
        pass


class DummyCacheStorage(CacheStorage):
    def get(self, key: str) -> bytes:
        """
        Dummy implementation of get: always raises a
        CacheStorageKeyNotFoundError, since nothing is ever stored.
        """
        raise CacheStorageKeyNotFoundError("Key not found in dummy cache")

    def set(self, key: str, value: bytes) -> None:
        pass

    def delete(self, key: str) -> None:
        pass

    def clear(self) -> None:
        pass

    def close(self) -> None:
        pass

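
# A short usage sketch (illustrative only; the values are made up): the
# manager wraps DummyCacheStorage in an InMemoryCacheStorageWrapper, so a
# value written via set() is served from the in-memory layer even though the
# dummy persistence layer drops every write.
def _example_roundtrip() -> bytes:
    context = CacheStorageContext(
        function_key="0123abcd",  # hypothetical values
        function_display_name="load_data",
        ttl_seconds=None,
        max_entries=None,
    )
    storage = MemoryCacheStorageManager().create(context)
    storage.set("key", b"value")
    return storage.get("key")  # b"value", served from the memory layer
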
@@ -0,0 +1,145 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import math
import threading

from cachetools import TTLCache

from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorage,
    CacheStorageContext,
    CacheStorageKeyNotFoundError,
)
from streamlit.runtime.stats import CacheStat

_LOGGER = get_logger(__name__)


class InMemoryCacheStorageWrapper(CacheStorage):
    """
    In-memory cache storage wrapper.

    This class wraps a cache storage and adds an in-memory cache front layer,
    which is used to reduce the number of calls to the storage.

    The in-memory cache is a TTL cache, which means that entries are
    automatically removed once a given time to live (TTL) has passed.

    The in-memory cache is also an LRU cache, which means that entries are
    automatically removed if the cache size exceeds a given maxsize.

    If the underlying storage implements its own maxsize strategy, it is
    recommended (but not required) that it implement the same LRU strategy;
    otherwise different items may be evicted from the memory cache and from
    the storage.

    Notes
    -----
    Threading: the in-memory caching layer is thread safe: we hold
    self._mem_cache_lock while working with the self._mem_cache object.
    However, we do not hold this lock when calling into the underlying
    storage, so it is the responsibility of that storage to ensure that it
    is safe to use from multiple threads.
    """

    def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
        self.function_key = context.function_key
        self.function_display_name = context.function_display_name
        self._ttl_seconds = context.ttl_seconds
        self._max_entries = context.max_entries
        self._mem_cache: TTLCache[str, bytes] = TTLCache(
            maxsize=self.max_entries,
            ttl=self.ttl_seconds,
            timer=cache_utils.TTLCACHE_TIMER,
        )
        self._mem_cache_lock = threading.Lock()
        self._persist_storage = persist_storage

    @property
    def ttl_seconds(self) -> float:
        return self._ttl_seconds if self._ttl_seconds is not None else math.inf

    @property
    def max_entries(self) -> float:
        return float(self._max_entries) if self._max_entries is not None else math.inf

    def get(self, key: str) -> bytes:
        """
        Returns the stored value for the key, or raises
        CacheStorageKeyNotFoundError if the key is not found.
        """
        try:
            entry_bytes = self._read_from_mem_cache(key)
        except CacheStorageKeyNotFoundError:
            entry_bytes = self._persist_storage.get(key)
            self._write_to_mem_cache(key, entry_bytes)
        return entry_bytes

    def set(self, key: str, value: bytes) -> None:
        """Sets the value for a given key"""
        self._write_to_mem_cache(key, value)
        self._persist_storage.set(key, value)

    def delete(self, key: str) -> None:
        """Delete a given key"""
        self._remove_from_mem_cache(key)
        self._persist_storage.delete(key)

    def clear(self) -> None:
        """Delete all keys from the in-memory cache, and also from the persistent storage"""
        with self._mem_cache_lock:
            self._mem_cache.clear()
        self._persist_storage.clear()

    def get_stats(self) -> list[CacheStat]:
        """Returns a list of stats, in bytes, for each item in the in-memory cache storage"""
        stats = []

        with self._mem_cache_lock:
            for item in self._mem_cache.values():
                stats.append(
                    CacheStat(
                        category_name="st_cache_data",
                        cache_name=self.function_display_name,
                        byte_length=len(item),
                    )
                )
        return stats

    def close(self) -> None:
        """Closes the cache storage"""
        self._persist_storage.close()

    def _read_from_mem_cache(self, key: str) -> bytes:
        with self._mem_cache_lock:
            if key in self._mem_cache:
                entry = bytes(self._mem_cache[key])
                _LOGGER.debug("Memory cache HIT: %s", key)
                return entry
            else:
                _LOGGER.debug("Memory cache MISS: %s", key)
                raise CacheStorageKeyNotFoundError("Key not found in mem cache")

    def _write_to_mem_cache(self, key: str, entry_bytes: bytes) -> None:
        with self._mem_cache_lock:
            self._mem_cache[key] = entry_bytes

    def _remove_from_mem_cache(self, key: str) -> None:
        with self._mem_cache_lock:
            self._mem_cache.pop(key, None)

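
# A short sketch (illustrative only; names and values are made up) of the
# read-through behaviour documented above: the first get() after a cold start
# misses the memory layer, falls through to the persistent storage, and the
# result is then promoted into the memory cache for subsequent reads. It
# assumes persist_storage already holds an entry for "key".
def _example_read_through(persist_storage: CacheStorage) -> None:
    context = CacheStorageContext(
        function_key="0123abcd",  # hypothetical values
        function_display_name="load_data",
        ttl_seconds=60.0,
        max_entries=10,
    )
    wrapper = InMemoryCacheStorageWrapper(
        persist_storage=persist_storage, context=context
    )
    first = wrapper.get("key")  # memory MISS -> persist_storage.get("key")
    second = wrapper.get("key")  # memory HIT; persistent storage not touched
    assert first == second
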
@@ -0,0 +1,222 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Declares the LocalDiskCacheStorageManager class, which is used
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper;
the InMemoryCacheStorageWrapper provides a first, in-memory cache layer in
front of the LocalDiskCacheStorage itself.

Declares the LocalDiskCacheStorage class, which is used to store cached
values on disk.

How these classes work together
-------------------------------

- LocalDiskCacheStorageManager : each instance of this is able
  to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
  and to clear data from the cache storage folder. It is also the
  LocalDiskCacheStorageManager's responsibility to check whether the context is
  valid for the storage, and to log a warning if it is not.

- LocalDiskCacheStorage : each instance of this is able to get, set, delete, and clear
  entries from disk for a single `@st.cache_data` decorated function if `persist="disk"`
  is used in CacheStorageContext.


┌───────────────────────────────┐
│ LocalDiskCacheStorageManager  │
│                               │
│ - clear_all                   │
│ - check_context               │
│                               │
└──┬────────────────────────────┘
   │
   │                ┌──────────────────────────────┐
   │                │                              │
   │ create(context)│ InMemoryCacheStorageWrapper  │
   └────────────────►                              │
                    │    ┌─────────────────────┐   │
                    │    │                     │   │
                    │    │  LocalDiskStorage   │   │
                    │    │                     │   │
                    │    └─────────────────────┘   │
                    │                              │
                    └──────────────────────────────┘

"""

from __future__ import annotations

import math
import os
import shutil

from streamlit import util
from streamlit.file_util import get_streamlit_file_path, streamlit_read, streamlit_write
from streamlit.logger import get_logger
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    CacheStorage,
    CacheStorageContext,
    CacheStorageError,
    CacheStorageKeyNotFoundError,
    CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
    InMemoryCacheStorageWrapper,
)

# Streamlit directory where persisted @st.cache_data objects live.
# (This is the same directory that @st.cache persisted objects live in.
# But @st.cache_data uses a different extension, so they don't overlap.)
_CACHE_DIR_NAME = "cache"

# The extension for our persisted @st.cache_data objects.
# (`@st.cache_data` was originally called `@st.memo`)
_CACHED_FILE_EXTENSION = "memo"

_LOGGER = get_logger(__name__)


class LocalDiskCacheStorageManager(CacheStorageManager):
    def create(self, context: CacheStorageContext) -> CacheStorage:
        """Creates a new cache storage instance wrapped with an in-memory cache layer"""
        persist_storage = LocalDiskCacheStorage(context)
        return InMemoryCacheStorageWrapper(
            persist_storage=persist_storage, context=context
        )

    def clear_all(self) -> None:
        cache_path = get_cache_folder_path()
        if os.path.isdir(cache_path):
            shutil.rmtree(cache_path)

    def check_context(self, context: CacheStorageContext) -> None:
        if (
            context.persist == "disk"
            and context.ttl_seconds is not None
            and not math.isinf(context.ttl_seconds)
        ):
            _LOGGER.warning(
                f"The cached function '{context.function_display_name}' has a TTL "
                "that will be ignored. Persistent cached functions currently don't "
                "support TTL."
            )


class LocalDiskCacheStorage(CacheStorage):
    """Cache storage that persists data to disk.

    This is the default cache persistence layer for `@st.cache_data`.
    """

    def __init__(self, context: CacheStorageContext):
        self.function_key = context.function_key
        self.persist = context.persist
        self._ttl_seconds = context.ttl_seconds
        self._max_entries = context.max_entries

    @property
    def ttl_seconds(self) -> float:
        return self._ttl_seconds if self._ttl_seconds is not None else math.inf

    @property
    def max_entries(self) -> float:
        return float(self._max_entries) if self._max_entries is not None else math.inf

    def get(self, key: str) -> bytes:
        """
        Returns the stored value for the key if persisted; raises
        CacheStorageKeyNotFoundError if the key is not found, or if the
        storage is not configured with persist="disk".
        """
        if self.persist == "disk":
            path = self._get_cache_file_path(key)
            try:
                with streamlit_read(path, binary=True) as input:
                    value = input.read()
                    _LOGGER.debug("Disk cache HIT: %s", key)
                    return bytes(value)
            except FileNotFoundError:
                raise CacheStorageKeyNotFoundError("Key not found in disk cache")
            except Exception as ex:
                _LOGGER.error(ex)
                raise CacheStorageError("Unable to read from cache") from ex
        else:
            raise CacheStorageKeyNotFoundError(
                f"Local disk cache storage is disabled (persist={self.persist})"
            )

    def set(self, key: str, value: bytes) -> None:
        """Sets the value for a given key"""
        if self.persist == "disk":
            path = self._get_cache_file_path(key)
            try:
                with streamlit_write(path, binary=True) as output:
                    output.write(value)
            except util.Error as e:
                _LOGGER.debug(e)
                # Clean up the file so we don't leave zero-byte files.
                try:
                    os.remove(path)
                except (FileNotFoundError, IOError, OSError):
                    # If we can't remove the file, it's not a big deal.
                    pass
                raise CacheStorageError("Unable to write to cache") from e

    def delete(self, key: str) -> None:
        """Delete a cache file from disk. If the file does not exist on disk,
        return silently. If another exception occurs, log it. Does not throw.
        """
        if self.persist == "disk":
            path = self._get_cache_file_path(key)
            try:
                os.remove(path)
            except FileNotFoundError:
                # The file is already removed.
                pass
            except Exception as ex:
                _LOGGER.exception(
                    "Unable to remove a file from the disk cache", exc_info=ex
                )

    def clear(self) -> None:
        """Delete all keys for the current storage"""
        cache_dir = get_cache_folder_path()

        if os.path.isdir(cache_dir):
            # We try to remove all files in the cache directory that start with
            # the function key, whether `clear` was called for a `persist`
            # storage or not, to avoid leaving orphaned files in the cache
            # directory.
            for file_name in os.listdir(cache_dir):
                if self._is_cache_file(file_name):
                    os.remove(os.path.join(cache_dir, file_name))

    def close(self) -> None:
        """Dummy implementation of close; we don't need to actually "close" anything"""

    def _get_cache_file_path(self, value_key: str) -> str:
        """Return the path of the disk cache file for the given value."""
        cache_dir = get_cache_folder_path()
        return os.path.join(
            cache_dir, f"{self.function_key}-{value_key}.{_CACHED_FILE_EXTENSION}"
        )

    def _is_cache_file(self, fname: str) -> bool:
        """Return True if the given file name is a cache file for this storage."""
        return fname.startswith(f"{self.function_key}-") and fname.endswith(
            f".{_CACHED_FILE_EXTENSION}"
        )


def get_cache_folder_path() -> str:
    return get_streamlit_file_path(_CACHE_DIR_NAME)

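
# A short sketch (illustrative only; values are made up) of where
# LocalDiskCacheStorage puts its files: one file per (function_key, value_key)
# pair inside the Streamlit cache folder, e.g.
# <streamlit dir>/cache/<function_key>-<value_key>.memo, where the base path
# comes from get_streamlit_file_path.
def _example_cache_file_path() -> str:
    context = CacheStorageContext(
        function_key="0123abcd",  # hypothetical values
        function_display_name="load_data",
        persist="disk",
    )
    storage = LocalDiskCacheStorage(context)
    # Accesses a private helper purely for demonstration.
    return storage._get_cache_file_path("value42")
    # -> os.path.join(get_cache_folder_path(), "0123abcd-value42.memo")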