Merging PR_218 openai_rev package with new streamlit chat app

noptuno
2023-04-27 20:29:30 -04:00
parent 479b8d6d10
commit 355dee533b
8378 changed files with 2931636 additions and 3 deletions


@@ -0,0 +1,132 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Iterator, Union
from google.protobuf.message import Message
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_data_api import (
CACHE_DATA_MESSAGE_REPLAY_CTX,
CacheDataAPI,
_data_caches,
)
from streamlit.runtime.caching.cache_errors import CACHE_DOCS_URL as CACHE_DOCS_URL
from streamlit.runtime.caching.cache_resource_api import (
CACHE_RESOURCE_MESSAGE_REPLAY_CTX,
CacheResourceAPI,
_resource_caches,
)
from streamlit.runtime.state.common import WidgetMetadata
def save_element_message(
delta_type: str,
element_proto: Message,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Save the message for an element to a thread-local callstack, so it can
be used later to replay the element when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_element_message(
delta_type, element_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
def save_block_message(
block_proto: Block,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Save the message for a block to a thread-local callstack, so it can
be used later to replay the block when a cache-decorated function's
execution is skipped.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_block_message(
block_proto, invoked_dg_id, used_dg_id, returned_dg_id
)
def save_widget_metadata(metadata: WidgetMetadata[Any]) -> None:
"""Save a widget's metadata to a thread-local callstack, so the widget
can be registered again when that widget is replayed.
"""
CACHE_DATA_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_widget_metadata(metadata)
def save_media_data(
image_data: Union[bytes, str], mimetype: str, image_id: str
) -> None:
CACHE_DATA_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.save_image_data(image_data, mimetype, image_id)
def maybe_show_cached_st_function_warning(dg, st_func_name: str) -> None:
CACHE_DATA_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
dg, st_func_name
)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX.maybe_show_cached_st_function_warning(
dg, st_func_name
)
@contextlib.contextmanager
def suppress_cached_st_function_warning() -> Iterator[None]:
with CACHE_DATA_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning(), CACHE_RESOURCE_MESSAGE_REPLAY_CTX.suppress_cached_st_function_warning():
yield
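# Hedged usage sketch for the context manager above (hypothetical call site):
#
#   with suppress_cached_st_function_warning():
#       ...  # st.* calls made from cached code here won't emit
#            # CachedStFunctionWarning in either cache context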
# Explicitly export public symbols
from streamlit.runtime.caching.cache_data_api import (
get_data_cache_stats_provider as get_data_cache_stats_provider,
)
from streamlit.runtime.caching.cache_resource_api import (
get_resource_cache_stats_provider as get_resource_cache_stats_provider,
)
# Create and export public API singletons.
cache_data = CacheDataAPI(decorator_metric_name="cache_data")
cache_resource = CacheResourceAPI(decorator_metric_name="cache_resource")
# Deprecated singletons
_MEMO_WARNING = (
f"`st.experimental_memo` is deprecated. Please use the new command `st.cache_data` instead, "
f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
)
experimental_memo = CacheDataAPI(
decorator_metric_name="experimental_memo", deprecation_warning=_MEMO_WARNING
)
_SINGLETON_WARNING = (
f"`st.experimental_singleton` is deprecated. Please use the new command `st.cache_resource` instead, "
f"which has the same behavior. More information [in our docs]({CACHE_DOCS_URL})."
)
experimental_singleton = CacheResourceAPI(
decorator_metric_name="experimental_singleton",
deprecation_warning=_SINGLETON_WARNING,
)
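To make these exports concrete, a minimal hypothetical sketch of the two singletons in app code (function names and bodies are made up; semantics follow the docstrings in the API modules below):

import streamlit as st

@st.cache_data(ttl=600)  # pickled, per-caller copies; entries expire after 10 min
def load_rows(url):
    return [url]  # placeholder body

@st.cache_resource  # one shared instance across all sessions and reruns
def get_client(url):
    return object()  # placeholder body

st.cache_data.clear()  # drops every @st.cache_data entry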


@@ -0,0 +1,678 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_data: pickle-based caching"""
from __future__ import annotations
import pickle
import threading
import types
from datetime import timedelta
from typing import Any, Callable, TypeVar, Union, cast, overload
from typing_extensions import Literal, TypeAlias
import streamlit as st
from streamlit import runtime
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
ElementMsgData,
MsgData,
MultiCacheResults,
)
from streamlit.runtime.caching.storage import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
InvalidCacheStorageContext,
)
from streamlit.runtime.caching.storage.dummy_cache_storage import (
MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
_LOGGER = get_logger(__name__)
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)
# The cache persistence options we support: "disk" or None
CachePersistType: TypeAlias = Union[Literal["disk"], None]
class CachedDataFuncInfo(CachedFuncInfo):
"""Implements the CachedFuncInfo interface for @st.cache_data"""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
persist: CachePersistType,
max_entries: int | None,
ttl: float | timedelta | None,
allow_widgets: bool,
):
super().__init__(
func,
show_spinner=show_spinner,
allow_widgets=allow_widgets,
)
self.persist = persist
self.max_entries = max_entries
self.ttl = ttl
self.validate_params()
@property
def cache_type(self) -> CacheType:
return CacheType.DATA
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_DATA_MESSAGE_REPLAY_CTX
@property
def display_name(self) -> str:
"""A human-readable name for the cached function"""
return f"{self.func.__module__}.{self.func.__qualname__}"
def get_function_cache(self, function_key: str) -> Cache:
return _data_caches.get_cache(
key=function_key,
persist=self.persist,
max_entries=self.max_entries,
ttl=self.ttl,
display_name=self.display_name,
allow_widgets=self.allow_widgets,
)
def validate_params(self) -> None:
"""
Validate the params passed to @st.cache_data are compatible with cache storage
When called, this method could log warnings if cache params are invalid
for current storage.
"""
_data_caches.validate_cache_params(
function_name=self.func.__name__,
persist=self.persist,
max_entries=self.max_entries,
ttl=self.ttl,
)
class DataCaches(CacheStatsProvider):
"""Manages all DataCache instances"""
def __init__(self):
self._caches_lock = threading.Lock()
self._function_caches: dict[str, DataCache] = {}
def get_cache(
self,
key: str,
persist: CachePersistType,
max_entries: int | None,
ttl: int | float | timedelta | None,
display_name: str,
allow_widgets: bool,
) -> DataCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._caches_lock:
cache = self._function_caches.get(key)
if (
cache is not None
and cache.ttl_seconds == ttl_seconds
and cache.max_entries == max_entries
and cache.persist == persist
):
return cache
# Close the existing cache's storage, if it exists.
if cache is not None:
_LOGGER.debug(
"Closing existing DataCache storage "
"(key=%s, persist=%s, max_entries=%s, ttl=%s) "
"before creating new one with different params",
key,
persist,
max_entries,
ttl,
)
cache.storage.close()
# Create a new cache object and put it in our dict
_LOGGER.debug(
"Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
key,
persist,
max_entries,
ttl,
)
cache_context = self.create_cache_storage_context(
function_key=key,
function_name=display_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
cache_storage_manager = self.get_storage_manager()
storage = cache_storage_manager.create(cache_context)
cache = DataCache(
key=key,
storage=storage,
persist=persist,
max_entries=max_entries,
ttl_seconds=ttl_seconds,
display_name=display_name,
allow_widgets=allow_widgets,
)
self._function_caches[key] = cache
return cache
def clear_all(self) -> None:
"""Clear all in-memory and on-disk caches."""
with self._caches_lock:
try:
# Try to clear everything at once, if the storage manager
# implements clear_all; otherwise fall back to clearing the
# available storages one by one.
self.get_storage_manager().clear_all()
except NotImplementedError:
for data_cache in self._function_caches.values():
data_cache.clear()
data_cache.storage.close()
self._function_caches = {}
def get_stats(self) -> list[CacheStat]:
with self._caches_lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats: list[CacheStat] = []
for cache in function_caches.values():
stats.extend(cache.get_stats())
return stats
def validate_cache_params(
self,
function_name: str,
persist: CachePersistType,
max_entries: int | None,
ttl: int | float | timedelta | None,
) -> None:
"""Validate that the cache params are valid for given storage.
Raises
------
InvalidCacheStorageContext
Raised if the cache storage manager is not able to work with provided
CacheStorageContext.
"""
ttl_seconds = ttl_to_seconds(ttl, coerce_none_to_inf=False)
cache_context = self.create_cache_storage_context(
function_key="DUMMY_KEY",
function_name=function_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
try:
self.get_storage_manager().check_context(cache_context)
except InvalidCacheStorageContext as e:
_LOGGER.error(
"Cache params for function %s are incompatible with current "
"cache storage manager: %s",
function_name,
e,
)
raise
def create_cache_storage_context(
self,
function_key: str,
function_name: str,
persist: CachePersistType,
ttl_seconds: float | None,
max_entries: int | None,
) -> CacheStorageContext:
return CacheStorageContext(
function_key=function_key,
function_display_name=function_name,
ttl_seconds=ttl_seconds,
max_entries=max_entries,
persist=persist,
)
def get_storage_manager(self) -> CacheStorageManager:
if runtime.exists():
return runtime.get_instance().cache_storage_manager
else:
# When running in "raw mode", we can't access the CacheStorageManager,
# so we fall back to MemoryCacheStorageManager.
_LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
return MemoryCacheStorageManager()
# Singleton DataCaches instance
_data_caches = DataCaches()
def get_data_cache_stats_provider() -> CacheStatsProvider:
"""Return the StatsProvider for all @st.cache_data functions."""
return _data_caches
class CacheDataAPI:
"""Implements the public st.cache_data API: the @st.cache_data decorator, and
st.cache_data.clear().
"""
def __init__(
self, decorator_metric_name: str, deprecation_warning: str | None = None
):
"""Create a CacheDataAPI instance.
Parameters
----------
decorator_metric_name
The metric name to record for decorator usage. `@st.experimental_memo` is
deprecated, but we're still supporting it and tracking its usage separately
from `@st.cache_data`.
deprecation_warning
An optional deprecation warning to show when the API is accessed.
"""
# Parameterize the decorator metric name.
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
self._decorator = gather_metrics( # type: ignore
decorator_metric_name, self._decorator
)
self._deprecation_warning = deprecation_warning
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
def __call__(self, func: F) -> F:
...
# Decorator with arguments
@overload
def __call__(
self,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
) -> Callable[[F], F]:
...
def __call__(
self,
func: F | None = None,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
persist: CachePersistType | bool = None,
experimental_allow_widgets: bool = False,
):
return self._decorator(
func,
ttl=ttl,
max_entries=max_entries,
persist=persist,
show_spinner=show_spinner,
experimental_allow_widgets=experimental_allow_widgets,
)
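    # Both decorator spellings are routed through _decorator below; a sketch:
    #
    #   @st.cache_data            # bare: the function is passed in directly
    #   def f(): ...
    #
    #   @st.cache_data(ttl=60)    # with args: func is None here, so
    #   def g(): ...              # _decorator returns a wrapper to apply to g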
def _decorator(
self,
func: F | None = None,
*,
ttl: float | timedelta | None,
max_entries: int | None,
show_spinner: bool | str,
persist: CachePersistType | bool,
experimental_allow_widgets: bool,
):
"""Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).
Cached objects are stored in "pickled" form, which means that the return
value of a cached function must be pickleable. Each caller of the cached
function gets its own copy of the cached data.
You can clear a function's cache with ``func.clear()`` or clear the entire
cache with ``st.cache_data.clear()``.
To cache global resources, use ``st.cache_resource`` instead. Learn more
about caching at https://docs.streamlit.io/library/advanced-features/caching.
Parameters
----------
func : callable
The function to cache. Streamlit hashes the function's source code.
ttl : float or timedelta or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
Note that ttl is incompatible with ``persist="disk"`` - ``ttl`` will be
ignored if ``persist`` is specified.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is None.
show_spinner : boolean or string
Enable the spinner. Default is True to show a spinner when there is
a "cache miss" and the cached data is being created. If string,
value of show_spinner param will be used for spinner text.
persist : str or boolean or None
Optional location to persist cached data to. Passing "disk" (or True)
will persist the cached data to the local disk. None (or False) will disable
persistence. The default is None.
experimental_allow_widgets : boolean
Allow widgets to be used in the cached function. Defaults to False.
Support for widgets in cached functions is currently experimental.
Setting this parameter to True may lead to excessive memory use since the
widget value is treated as an additional input parameter to the cache.
We may remove support for this option at any time without notice.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
...
>>> d1 = fetch_and_clean_data(DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> d2 = fetch_and_clean_data(DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the data in d1 is the same as in d2.
>>>
>>> d3 = fetch_and_clean_data(DATA_URL_2)
>>> # This is a different URL, so the function executes.
To set the ``persist`` parameter, use this command as follows:
>>> import streamlit as st
>>>
>>> @st.cache_data(persist="disk")
... def fetch_and_clean_data(url):
... # Fetch data from URL here, and then clean it up.
... return data
By default, all parameters to a cached function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
...
>>> connection = make_database_connection()
>>> d1 = fetch_and_clean_data(connection, num_rows=10)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> another_connection = make_database_connection()
>>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _db_connection parameter was different
>>> # in both calls.
A cached function's cache can be procedurally cleared:
>>> import streamlit as st
>>>
>>> @st.cache_data
... def fetch_and_clean_data(_db_connection, num_rows):
... # Fetch data from _db_connection here, and then clean it up.
... return data
...
>>> fetch_and_clean_data.clear()
>>> # Clear all cached entries for this function.
"""
# Normalize the persist value to "disk" or None
persist_string: CachePersistType
if persist is True:
persist_string = "disk"
elif persist is False:
persist_string = None
else:
persist_string = persist
if persist_string not in (None, "disk"):
# We'll eventually have more persist options.
raise StreamlitAPIException(
f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
)
self._maybe_show_deprecation_warning()
def wrapper(f):
return make_cached_func_wrapper(
CachedDataFuncInfo(
func=f,
persist=persist_string,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
allow_widgets=experimental_allow_widgets,
)
)
if func is None:
return wrapper
return make_cached_func_wrapper(
CachedDataFuncInfo(
func=cast(types.FunctionType, func),
persist=persist_string,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
allow_widgets=experimental_allow_widgets,
)
)
@gather_metrics("clear_data_caches")
def clear(self) -> None:
"""Clear all in-memory and on-disk data caches."""
self._maybe_show_deprecation_warning()
_data_caches.clear_all()
def _maybe_show_deprecation_warning(self):
"""If the API is being accessed with the deprecated `st.experimental_memo` name,
show a deprecation warning.
"""
if self._deprecation_warning is not None:
show_deprecation_warning(self._deprecation_warning)
class DataCache(Cache):
"""Manages cached values for a single st.cache_data function."""
def __init__(
self,
key: str,
storage: CacheStorage,
persist: CachePersistType,
max_entries: int | None,
ttl_seconds: float | None,
display_name: str,
allow_widgets: bool = False,
):
super().__init__()
self.key = key
self.display_name = display_name
self.storage = storage
self.ttl_seconds = ttl_seconds
self.max_entries = max_entries
self.persist = persist
self.allow_widgets = allow_widgets
def get_stats(self) -> list[CacheStat]:
if isinstance(self.storage, CacheStatsProvider):
return self.storage.get_stats()
return []
def read_result(self, key: str) -> CachedResult:
"""Read a value and messages from the cache. Raise `CacheKeyNotFoundError`
if the value doesn't exist, and `CacheError` if the value exists but can't
be unpickled.
"""
try:
pickled_entry = self.storage.get(key)
except CacheStorageKeyNotFoundError as e:
raise CacheKeyNotFoundError(str(e)) from e
except CacheStorageError as e:
raise CacheError(str(e)) from e
try:
entry = pickle.loads(pickled_entry)
if not isinstance(entry, MultiCacheResults):
# Loaded an old cache file format, remove it and let the caller
# rerun the function.
self.storage.delete(key)
raise CacheKeyNotFoundError()
ctx = get_script_run_ctx()
if not ctx:
raise CacheKeyNotFoundError()
widget_key = entry.get_current_widget_key(ctx, CacheType.DATA)
if widget_key in entry.results:
return entry.results[widget_key]
else:
raise CacheKeyNotFoundError()
except pickle.UnpicklingError as exc:
raise CacheError(f"Failed to unpickle {key}") from exc
@gather_metrics("_cache_data_object")
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache.
The value must be pickleable.
"""
ctx = get_script_run_ctx()
if ctx is None:
return
main_id = st._main.id
sidebar_id = st.sidebar.id
if self.allow_widgets:
widgets = {
msg.widget_metadata.widget_id
for msg in messages
if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
}
else:
widgets = set()
multi_cache_results: MultiCacheResults | None = None
# Try to find the result in cache storage, falling back to a new result instance
try:
multi_cache_results = self._read_multi_results_from_storage(key)
except (CacheKeyNotFoundError, pickle.UnpicklingError):
pass
if multi_cache_results is None:
multi_cache_results = MultiCacheResults(widget_ids=widgets, results={})
multi_cache_results.widget_ids.update(widgets)
widget_key = multi_cache_results.get_current_widget_key(ctx, CacheType.DATA)
result = CachedResult(value, messages, main_id, sidebar_id)
multi_cache_results.results[widget_key] = result
try:
pickled_entry = pickle.dumps(multi_cache_results)
except (pickle.PicklingError, TypeError) as exc:
raise CacheError(f"Failed to pickle {key}") from exc
self.storage.set(key, pickled_entry)
def _clear(self) -> None:
self.storage.clear()
def _read_multi_results_from_storage(self, key: str) -> MultiCacheResults:
"""Look up the results from storage and ensure it has the right type.
Raises a `CacheKeyNotFoundError` if the key has no entry, or if the
entry is malformed.
"""
try:
pickled = self.storage.get(key)
except CacheStorageKeyNotFoundError as e:
raise CacheKeyNotFoundError(str(e)) from e
maybe_results = pickle.loads(pickled)
if isinstance(maybe_results, MultiCacheResults):
return maybe_results
else:
self.storage.delete(key)
raise CacheKeyNotFoundError()
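Because DataCache round-trips entries through pickle (write_result/read_result above), a cache hit always hands back a fresh copy rather than a shared object. A standalone analogy using only the standard library (the real cache additionally wraps values in MultiCacheResults):

import pickle

original = {"rows": [1, 2, 3]}
stored = pickle.dumps(original)       # roughly what write_result persists
hit_1 = pickle.loads(stored)          # roughly what each cache hit returns
hit_2 = pickle.loads(stored)
assert hit_1 == hit_2 and hit_1 is not hit_2  # equal values, distinct objects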


@@ -0,0 +1,176 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
from typing import Any, Optional
from streamlit import type_util
from streamlit.errors import (
MarkdownFormattedException,
StreamlitAPIException,
StreamlitAPIWarning,
)
from streamlit.runtime.caching.cache_type import CacheType, get_decorator_api_name
CACHE_DOCS_URL = "https://docs.streamlit.io/library/advanced-features/caching"
def get_cached_func_name_md(func: Any) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
elif hasattr(type(func), "__name__"):
return f"`{type(func).__name__}`"
return f"`{type(func)}`"
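# Illustration of the fallbacks above (fetch_rows is a hypothetical function):
#   get_cached_func_name_md(fetch_rows)  -> "`fetch_rows()`"  (has __name__)
#   get_cached_func_name_md(42)          -> "`int`"           (type-name fallback)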
def get_return_value_type(return_value: Any) -> str:
if hasattr(return_value, "__module__") and hasattr(type(return_value), "__name__"):
return f"`{return_value.__module__}.{type(return_value).__name__}`"
return get_cached_func_name_md(return_value)
class UnhashableTypeError(Exception):
pass
class UnhashableParamError(StreamlitAPIException):
def __init__(
self,
cache_type: CacheType,
func: types.FunctionType,
arg_name: Optional[str],
arg_value: Any,
orig_exc: BaseException,
):
msg = self._create_message(cache_type, func, arg_name, arg_value)
super().__init__(msg)
self.with_traceback(orig_exc.__traceback__)
@staticmethod
def _create_message(
cache_type: CacheType,
func: types.FunctionType,
arg_name: Optional[str],
arg_value: Any,
) -> str:
arg_name_str = arg_name if arg_name is not None else "(unnamed)"
arg_type = type_util.get_fqn_type(arg_value)
func_name = func.__name__
arg_replacement_name = f"_{arg_name}" if arg_name is not None else "_arg"
return (
f"""
Cannot hash argument '{arg_name_str}' (of type `{arg_type}`) in '{func_name}'.
To address this, you can tell Streamlit not to hash this argument by adding a
leading underscore to the argument's name in the function signature:
```
@st.{get_decorator_api_name(cache_type)}
def {func_name}({arg_replacement_name}, ...):
...
```
"""
).strip("\n")
class CacheKeyNotFoundError(Exception):
pass
class CacheError(Exception):
pass
class CachedStFunctionWarning(StreamlitAPIWarning):
def __init__(
self,
cache_type: CacheType,
st_func_name: str,
cached_func: types.FunctionType,
):
args = {
"st_func_name": f"`st.{st_func_name}()`",
"func_name": self._get_cached_func_name_md(cached_func),
"decorator_name": get_decorator_api_name(cache_type),
}
msg = (
"""
Your script uses %(st_func_name)s to write to your Streamlit app from within
some cached code at %(func_name)s. This code will only be called when we detect
a cache "miss", which can lead to unexpected results.
How to fix this:
* Move the %(st_func_name)s call outside %(func_name)s.
* Or, if you know what you're doing, use `@st.%(decorator_name)s(experimental_allow_widgets=True)`
to enable widget replay and suppress this warning.
"""
% args
).strip("\n")
super().__init__(msg)
@staticmethod
def _get_cached_func_name_md(func: types.FunctionType) -> str:
"""Get markdown representation of the function name."""
if hasattr(func, "__name__"):
return "`%s()`" % func.__name__
else:
return "a cached function"
class CacheReplayClosureError(StreamlitAPIException):
def __init__(
self,
cache_type: CacheType,
cached_func: types.FunctionType,
):
func_name = get_cached_func_name_md(cached_func)
decorator_name = get_decorator_api_name(cache_type)
msg = (
f"""
While running {func_name}, a streamlit element is called on some layout block created outside the function.
This is incompatible with replaying the cached effect of that element, because the
referenced block might not exist when the replay happens.
How to fix this:
* Move the creation of $THING inside {func_name}.
* Move the call to the streamlit element outside of {func_name}.
* Remove the `@st.{decorator_name}` decorator from {func_name}.
"""
).strip("\n")
super().__init__(msg)
class UnserializableReturnValueError(MarkdownFormattedException):
def __init__(self, func: types.FunctionType, return_value: types.FunctionType):
MarkdownFormattedException.__init__(
self,
f"""
Cannot serialize the return value (of type {get_return_value_type(return_value)}) in {get_cached_func_name_md(func)}.
`st.cache_data` uses [pickle](https://docs.python.org/3/library/pickle.html) to
serialize the function's return value and safely store it in the cache without mutating the original object. Please convert the return value to a pickle-serializable type.
If you want to cache unserializable objects such as database connections or TensorFlow
sessions, use `st.cache_resource` instead (see [our docs]({CACHE_DOCS_URL}) for differences).""",
)
class UnevaluatedDataFrameError(StreamlitAPIException):
"""Used to display a message about uncollected dataframe being used"""
pass


@@ -0,0 +1,520 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""@st.cache_resource implementation"""
from __future__ import annotations
import math
import threading
import types
from datetime import timedelta
from typing import Any, Callable, TypeVar, cast, overload
from cachetools import TTLCache
from pympler import asizeof
from typing_extensions import TypeAlias
import streamlit as st
from streamlit.deprecation_util import show_deprecation_warning
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.cache_errors import CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
Cache,
CachedFuncInfo,
make_cached_func_wrapper,
ttl_to_seconds,
)
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
ElementMsgData,
MsgData,
MultiCacheResults,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.scriptrunner.script_run_context import get_script_run_ctx
from streamlit.runtime.stats import CacheStat, CacheStatsProvider
_LOGGER = get_logger(__name__)
CACHE_RESOURCE_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.RESOURCE)
ValidateFunc: TypeAlias = Callable[[Any], bool]
def _equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
"""True if the two validate functions are equal for the purposes of
determining whether a given function cache needs to be recreated.
"""
# To "properly" test for function equality here, we'd need to compare function bytecode.
# For performance reasons, we've decided not to do that for now.
return (a is None and b is None) or (a is not None and b is not None)
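# A hedged sketch of the stricter (and slower) bytecode comparison mentioned
# above -- a hypothetical helper, not part of Streamlit:
def _strict_equal_validate_funcs(a: ValidateFunc | None, b: ValidateFunc | None) -> bool:
    if a is None or b is None:
        return a is b
    code_a = getattr(a, "__code__", None)  # builtins etc. have no __code__
    code_b = getattr(b, "__code__", None)
    return code_a is not None and code_b is not None and code_a.co_code == code_b.co_code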
class ResourceCaches(CacheStatsProvider):
"""Manages all ResourceCache instances"""
def __init__(self):
self._caches_lock = threading.Lock()
self._function_caches: dict[str, ResourceCache] = {}
def get_cache(
self,
key: str,
display_name: str,
max_entries: int | float | None,
ttl: float | timedelta | None,
validate: ValidateFunc | None,
allow_widgets: bool,
) -> ResourceCache:
"""Return the mem cache for the given key.
If it doesn't exist, create a new one with the given params.
"""
if max_entries is None:
max_entries = math.inf
ttl_seconds = ttl_to_seconds(ttl)
# Get the existing cache, if it exists, and validate that its params
# haven't changed.
with self._caches_lock:
cache = self._function_caches.get(key)
if (
cache is not None
and cache.ttl_seconds == ttl_seconds
and cache.max_entries == max_entries
and _equal_validate_funcs(cache.validate, validate)
):
return cache
# Create a new cache object and put it in our dict
_LOGGER.debug("Creating new ResourceCache (key=%s)", key)
cache = ResourceCache(
key=key,
display_name=display_name,
max_entries=max_entries,
ttl_seconds=ttl_seconds,
validate=validate,
allow_widgets=allow_widgets,
)
self._function_caches[key] = cache
return cache
def clear_all(self) -> None:
"""Clear all resource caches."""
with self._caches_lock:
self._function_caches = {}
def get_stats(self) -> list[CacheStat]:
with self._caches_lock:
# Shallow-clone our caches. We don't want to hold the global
# lock during stats-gathering.
function_caches = self._function_caches.copy()
stats: list[CacheStat] = []
for cache in function_caches.values():
stats.extend(cache.get_stats())
return stats
# Singleton ResourceCaches instance
_resource_caches = ResourceCaches()
def get_resource_cache_stats_provider() -> CacheStatsProvider:
"""Return the StatsProvider for all @st.cache_resource functions."""
return _resource_caches
class CachedResourceFuncInfo(CachedFuncInfo):
"""Implements the CachedFuncInfo interface for @st.cache_resource"""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
max_entries: int | None,
ttl: float | timedelta | None,
validate: ValidateFunc | None,
allow_widgets: bool,
):
super().__init__(
func,
show_spinner=show_spinner,
allow_widgets=allow_widgets,
)
self.max_entries = max_entries
self.ttl = ttl
self.validate = validate
@property
def cache_type(self) -> CacheType:
return CacheType.RESOURCE
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
return CACHE_RESOURCE_MESSAGE_REPLAY_CTX
@property
def display_name(self) -> str:
"""A human-readable name for the cached function"""
return f"{self.func.__module__}.{self.func.__qualname__}"
def get_function_cache(self, function_key: str) -> Cache:
return _resource_caches.get_cache(
key=function_key,
display_name=self.display_name,
max_entries=self.max_entries,
ttl=self.ttl,
validate=self.validate,
allow_widgets=self.allow_widgets,
)
class CacheResourceAPI:
"""Implements the public st.cache_resource API: the @st.cache_resource decorator,
and st.cache_resource.clear().
"""
def __init__(
self, decorator_metric_name: str, deprecation_warning: str | None = None
):
"""Create a CacheResourceAPI instance.
Parameters
----------
decorator_metric_name
The metric name to record for decorator usage. `@st.experimental_singleton` is
deprecated, but we're still supporting it and tracking its usage separately
from `@st.cache_resource`.
deprecation_warning
An optional deprecation warning to show when the API is accessed.
"""
# Parameterize the decorator metric name.
# (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
self._decorator = gather_metrics(decorator_metric_name, self._decorator) # type: ignore
self._deprecation_warning = deprecation_warning
# Type-annotate the decorator function.
# (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)
F = TypeVar("F", bound=Callable[..., Any])
# Bare decorator usage
@overload
def __call__(self, func: F) -> F:
...
# Decorator with arguments
@overload
def __call__(
self,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
) -> Callable[[F], F]:
...
def __call__(
self,
func: F | None = None,
*,
ttl: float | timedelta | None = None,
max_entries: int | None = None,
show_spinner: bool | str = True,
validate: ValidateFunc | None = None,
experimental_allow_widgets: bool = False,
):
return self._decorator(
func,
ttl=ttl,
max_entries=max_entries,
show_spinner=show_spinner,
validate=validate,
experimental_allow_widgets=experimental_allow_widgets,
)
def _decorator(
self,
func: F | None,
*,
ttl: float | timedelta | None,
max_entries: int | None,
show_spinner: bool | str,
validate: ValidateFunc | None,
experimental_allow_widgets: bool,
):
"""Decorator to cache functions that return global resources (e.g. database connections, ML models).
Cached objects are shared across all users, sessions, and reruns. They
must be thread-safe because they can be accessed from multiple threads
concurrently. If thread safety is an issue, consider using ``st.session_state``
to store resources per session instead.
You can clear a function's cache with ``func.clear()`` or clear the entire
cache with ``st.cache_resource.clear()``.
To cache data, use ``st.cache_data`` instead. Learn more about caching at
https://docs.streamlit.io/library/advanced-features/caching.
Parameters
----------
func : callable
The function that creates the cached resource. Streamlit hashes the
function's source code.
ttl : float or timedelta or None
The maximum number of seconds to keep an entry in the cache, or
None if cache entries should not expire. The default is None.
max_entries : int or None
The maximum number of entries to keep in the cache, or None
for an unbounded cache. (When a new entry is added to a full cache,
the oldest cached entry will be removed.) The default is None.
show_spinner : boolean or string
Enable the spinner. Default is True to show a spinner when there is
a "cache miss" and the cached resource is being created. If string,
value of show_spinner param will be used for spinner text.
validate : callable or None
An optional validation function for cached data. ``validate`` is called
each time the cached value is accessed. It receives the cached value as
its only parameter and it must return a boolean. If ``validate`` returns
False, the current cached value is discarded, and the decorated function
is called to compute a new value. This is useful e.g. to check the
health of database connections.
experimental_allow_widgets : boolean
Allow widgets to be used in the cached function. Defaults to False.
Support for widgets in cached functions is currently experimental.
Setting this parameter to True may lead to excessive memory use since the
widget value is treated as an additional input parameter to the cache.
We may remove support for this option at any time without notice.
Example
-------
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(url):
... # Create a database session object that points to the URL.
... return session
...
>>> s1 = get_database_session(SESSION_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> s2 = get_database_session(SESSION_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value. This means that now the connection object in s1 is the same as in s2.
>>>
>>> s3 = get_database_session(SESSION_URL_2)
>>> # This is a different URL, so the function executes.
By default, all parameters to a cache_resource function must be hashable.
Any parameter whose name begins with ``_`` will not be hashed. You can use
this as an "escape hatch" for parameters that are not hashable:
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(_sessionmaker, url):
... # Create a database connection object that points to the URL.
... return connection
...
>>> s1 = get_database_session(create_sessionmaker(), DATA_URL_1)
>>> # Actually executes the function, since this is the first time it was
>>> # encountered.
>>>
>>> s2 = get_database_session(create_sessionmaker(), DATA_URL_1)
>>> # Does not execute the function. Instead, returns its previously computed
>>> # value - even though the _sessionmaker parameter was different
>>> # in both calls.
A cache_resource function's cache can be procedurally cleared:
>>> import streamlit as st
>>>
>>> @st.cache_resource
... def get_database_session(_sessionmaker, url):
... # Create a database connection object that points to the URL.
... return connection
...
>>> get_database_session.clear()
>>> # Clear all cached entries for this function.
"""
self._maybe_show_deprecation_warning()
# Support passing the params via function decorator, e.g.
# @st.cache_resource(show_spinner=False)
if func is None:
return lambda f: make_cached_func_wrapper(
CachedResourceFuncInfo(
func=f,
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
validate=validate,
allow_widgets=experimental_allow_widgets,
)
)
return make_cached_func_wrapper(
CachedResourceFuncInfo(
func=cast(types.FunctionType, func),
show_spinner=show_spinner,
max_entries=max_entries,
ttl=ttl,
validate=validate,
allow_widgets=experimental_allow_widgets,
)
)
@gather_metrics("clear_resource_caches")
def clear(self) -> None:
"""Clear all cache_resource caches."""
self._maybe_show_deprecation_warning()
_resource_caches.clear_all()
def _maybe_show_deprecation_warning(self):
"""If the API is being accessed with the deprecated `st.experimental_singleton` name,
show a deprecation warning.
"""
if self._deprecation_warning is not None:
show_deprecation_warning(self._deprecation_warning)
class ResourceCache(Cache):
"""Manages cached values for a single st.cache_resource function."""
def __init__(
self,
key: str,
max_entries: float,
ttl_seconds: float,
validate: ValidateFunc | None,
display_name: str,
allow_widgets: bool,
):
super().__init__()
self.key = key
self.display_name = display_name
self._mem_cache: TTLCache[str, MultiCacheResults] = TTLCache(
maxsize=max_entries, ttl=ttl_seconds, timer=cache_utils.TTLCACHE_TIMER
)
self._mem_cache_lock = threading.Lock()
self.validate = validate
self.allow_widgets = allow_widgets
@property
def max_entries(self) -> float:
return cast(float, self._mem_cache.maxsize)
@property
def ttl_seconds(self) -> float:
return cast(float, self._mem_cache.ttl)
def read_result(self, key: str) -> CachedResult:
"""Read a value and associated messages from the cache.
Raise `CacheKeyNotFoundError` if the value doesn't exist.
"""
with self._mem_cache_lock:
if key not in self._mem_cache:
# key does not exist in cache.
raise CacheKeyNotFoundError()
multi_results: MultiCacheResults = self._mem_cache[key]
ctx = get_script_run_ctx()
if not ctx:
# ScriptRunCtx does not exist (we're probably running in "raw" mode).
raise CacheKeyNotFoundError()
widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
if widget_key not in multi_results.results:
# widget_key does not exist in cache (this combination of widgets hasn't been
# seen for the value_key yet).
raise CacheKeyNotFoundError()
result = multi_results.results[widget_key]
if self.validate is not None and not self.validate(result.value):
# Validate failed: delete the entry and raise an error.
del multi_results.results[widget_key]
raise CacheKeyNotFoundError()
return result
@gather_metrics("_cache_resource_object")
def write_result(self, key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache."""
ctx = get_script_run_ctx()
if ctx is None:
return
main_id = st._main.id
sidebar_id = st.sidebar.id
if self.allow_widgets:
widgets = {
msg.widget_metadata.widget_id
for msg in messages
if isinstance(msg, ElementMsgData) and msg.widget_metadata is not None
}
else:
widgets = set()
with self._mem_cache_lock:
try:
multi_results = self._mem_cache[key]
except KeyError:
multi_results = MultiCacheResults(widget_ids=widgets, results={})
multi_results.widget_ids.update(widgets)
widget_key = multi_results.get_current_widget_key(ctx, CacheType.RESOURCE)
result = CachedResult(value, messages, main_id, sidebar_id)
multi_results.results[widget_key] = result
self._mem_cache[key] = multi_results
def _clear(self) -> None:
with self._mem_cache_lock:
self._mem_cache.clear()
def get_stats(self) -> list[CacheStat]:
# Shallow clone our cache. Computing item sizes is potentially
# expensive, and we want to minimize the time we spend holding
# the lock.
with self._mem_cache_lock:
cache_entries = list(self._mem_cache.values())
return [
CacheStat(
category_name="st_cache_resource",
cache_name=self.display_name,
byte_length=asizeof.asizeof(entry),
)
for entry in cache_entries
]
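Unlike st.cache_data's pickle copies, ResourceCache keeps the object itself in an in-memory TTLCache, so every caller shares one instance. A hedged sketch (hypothetical function name; the identity claim holds inside a running Streamlit app, where a script run context exists):

import streamlit as st

@st.cache_resource
def get_shared_store():
    return {}  # one dict shared by all sessions and reruns

# Inside a running app: get_shared_store() is get_shared_store()  -> True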


@@ -0,0 +1,31 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
class CacheType(enum.Enum):
"""The function cache types we implement."""
DATA = "DATA"
RESOURCE = "RESOURCE"
def get_decorator_api_name(cache_type: CacheType) -> str:
"""Return the name of the public decorator API for the given CacheType."""
if cache_type is CacheType.DATA:
return "cache_data"
if cache_type is CacheType.RESOURCE:
return "cache_resource"
raise RuntimeError(f"Unrecognized CacheType '{cache_type}'")
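A quick sanity check of the mapping above, runnable within this module:

assert get_decorator_api_name(CacheType.DATA) == "cache_data"
assert get_decorator_api_name(CacheType.RESOURCE) == "cache_resource"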


@@ -0,0 +1,449 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common cache logic shared by st.cache_data and st.cache_resource."""
from __future__ import annotations
import functools
import hashlib
import inspect
import math
import threading
import time
import types
from abc import abstractmethod
from collections import defaultdict
from datetime import timedelta
from typing import Any, Callable, overload
from typing_extensions import Literal
from streamlit import type_util
from streamlit.elements.spinner import spinner
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import (
CacheError,
CacheKeyNotFoundError,
UnevaluatedDataFrameError,
UnhashableParamError,
UnhashableTypeError,
UnserializableReturnValueError,
get_cached_func_name_md,
)
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cached_message_replay import (
CachedMessageReplayContext,
CachedResult,
MsgData,
replay_cached_messages,
)
from streamlit.runtime.caching.hashing import update_hash
_LOGGER = get_logger(__name__)
# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
TTLCACHE_TIMER = time.monotonic
@overload
def ttl_to_seconds(
ttl: float | timedelta | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
...
@overload
def ttl_to_seconds(ttl: float | timedelta | None) -> float:
...
def ttl_to_seconds(
ttl: float | timedelta | None, *, coerce_none_to_inf: bool = True
) -> float | None:
"""
Convert a ttl value to a float representing "number of seconds".
"""
if coerce_none_to_inf and ttl is None:
return math.inf
if isinstance(ttl, timedelta):
return ttl.total_seconds()
return ttl
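# Examples of the conversion above:
#   ttl_to_seconds(None)                            -> math.inf
#   ttl_to_seconds(None, coerce_none_to_inf=False)  -> None
#   ttl_to_seconds(timedelta(minutes=5))            -> 300.0
#   ttl_to_seconds(42)                              -> 42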
# We raise a special UnevaluatedDataFrameError for cached funcs
# that attempt to return one of these unserializable types:
UNEVALUATED_DATAFRAME_TYPES = (
"snowflake.snowpark.table.Table",
"snowflake.snowpark.dataframe.DataFrame",
"pyspark.sql.dataframe.DataFrame",
)
class Cache:
"""Function cache interface. Caches persist across script runs."""
def __init__(self):
self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
self._value_locks_lock = threading.Lock()
@abstractmethod
def read_result(self, value_key: str) -> CachedResult:
"""Read a value and associated messages from the cache.
Raises
------
CacheKeyNotFoundError
Raised if value_key is not in the cache.
"""
raise NotImplementedError
@abstractmethod
def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
"""Write a value and associated messages to the cache, overwriting any existing
result that uses the value_key.
"""
# We *could* `del self._value_locks[value_key]` here, since nobody will be taking
# a compute_value_lock for this value_key after the result is written.
raise NotImplementedError
def compute_value_lock(self, value_key: str) -> threading.Lock:
"""Return the lock that should be held while computing a new cached value.
In a popular app with a cache that hasn't been pre-warmed, many sessions may try
to access a not-yet-cached value simultaneously. We use a lock to ensure that
only one of those sessions computes the value, and the others block until
the value is computed.
"""
with self._value_locks_lock:
return self._value_locks[value_key]
def clear(self):
"""Clear all values from this cache."""
with self._value_locks_lock:
self._value_locks.clear()
self._clear()
@abstractmethod
def _clear(self) -> None:
"""Subclasses must implement this to perform cache-clearing logic."""
raise NotImplementedError
class CachedFuncInfo:
"""Encapsulates data for a cached function instance.
CachedFuncInfo instances are scoped to a single script run - they're not
persistent.
"""
def __init__(
self,
func: types.FunctionType,
show_spinner: bool | str,
allow_widgets: bool,
):
self.func = func
self.show_spinner = show_spinner
self.allow_widgets = allow_widgets
@property
def cache_type(self) -> CacheType:
raise NotImplementedError
@property
def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
raise NotImplementedError
def get_function_cache(self, function_key: str) -> Cache:
"""Get or create the function cache for the given key."""
raise NotImplementedError
def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
"""Create a callable wrapper around a CachedFunctionInfo.
Calling the wrapper will return the cached value if it's already been
computed, and will call the underlying function to compute and cache the
value otherwise.
The wrapper also has a `clear` function that can be called to clear
all of the wrapper's cached values.
"""
cached_func = CachedFunc(info)
# We'd like to simply return `cached_func`, which is already a Callable.
# But using `functools.update_wrapper` on the CachedFunc instance
# itself results in errors when our caching decorators are used to decorate
# member functions. (See https://github.com/streamlit/streamlit/issues/6109)
@functools.wraps(info.func)
def wrapper(*args, **kwargs):
return cached_func(*args, **kwargs)
# Give our wrapper its `clear` function.
# (This results in a spurious mypy error that we suppress.)
wrapper.clear = cached_func.clear # type: ignore
return wrapper
class CachedFunc:
def __init__(self, info: CachedFuncInfo):
self._info = info
self._function_key = _make_function_key(info.cache_type, info.func)
def __call__(self, *args, **kwargs) -> Any:
"""The wrapper. We'll only call our underlying function on a cache miss."""
name = self._info.func.__qualname__
if isinstance(self._info.show_spinner, bool):
if len(args) == 0 and len(kwargs) == 0:
message = f"Running `{name}()`."
else:
message = f"Running `{name}(...)`."
else:
message = self._info.show_spinner
if self._info.show_spinner or isinstance(self._info.show_spinner, str):
with spinner(message):
return self._get_or_create_cached_value(args, kwargs)
else:
return self._get_or_create_cached_value(args, kwargs)
def _get_or_create_cached_value(
self, func_args: tuple[Any, ...], func_kwargs: dict[str, Any]
) -> Any:
# Retrieve the function's cache object. We must do this "just-in-time"
# (as opposed to in the constructor), because caches can be invalidated
# at any time.
cache = self._info.get_function_cache(self._function_key)
# Generate the key for the cached value. This is based on the
# arguments passed to the function.
value_key = _make_value_key(
cache_type=self._info.cache_type,
func=self._info.func,
func_args=func_args,
func_kwargs=func_kwargs,
)
try:
cached_result = cache.read_result(value_key)
return self._handle_cache_hit(cached_result)
except CacheKeyNotFoundError:
return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)
def _handle_cache_hit(self, result: CachedResult) -> Any:
"""Handle a cache hit: replay the result's cached messages, and return its value."""
replay_cached_messages(
result,
self._info.cache_type,
self._info.func,
)
return result.value
def _handle_cache_miss(
self,
cache: Cache,
value_key: str,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
) -> Any:
"""Handle a cache miss: compute a new cached value, write it back to the cache,
and return that newly-computed value.
"""
# Implementation notes:
# - We take a "compute_value_lock" before computing our value. This ensures that
# multiple sessions don't try to compute the same value simultaneously.
#
# - We use a different lock for each value_key, as opposed to a single lock for
# the entire cache, so that unrelated value computations don't block on each other.
#
# - When retrieving a cache entry that may not yet exist, we use a "double-checked locking"
# strategy: first we try to retrieve the cache entry without taking a value lock. (This
# happens in `_get_or_create_cached_value()`.) If that fails because the value hasn't
# been computed yet, we take the value lock and then immediately try to retrieve the cache entry
# *again*, while holding the lock. If the cache entry exists at this point, it means that
# another thread computed the value before us.
#
# This means that the happy path ("cache entry exists") is a wee bit faster because
# no lock is acquired. But the unhappy path ("cache entry needs to be recomputed") is
# a wee bit slower, because we do two lookups for the entry.
with cache.compute_value_lock(value_key):
# We've acquired the lock - but another thread may have acquired it first
# and already computed the value. So we need to test for a cache hit again,
# before computing.
try:
cached_result = cache.read_result(value_key)
# Another thread computed the value before us. Early exit!
return self._handle_cache_hit(cached_result)
except CacheKeyNotFoundError:
# We acquired the lock before any other thread. Compute the value!
with self._info.cached_message_replay_ctx.calling_cached_function(
self._info.func, self._info.allow_widgets
):
computed_value = self._info.func(*func_args, **func_kwargs)
# We've computed our value, and now we need to write it back to the cache
# along with any "replay messages" that were generated during value computation.
messages = self._info.cached_message_replay_ctx._most_recent_messages
try:
cache.write_result(value_key, computed_value, messages)
return computed_value
except (CacheError, RuntimeError):
# An exception was thrown while we tried to write to the cache. Report it to the user.
# (We catch `RuntimeError` here because it will be raised by Apache Spark if we do not
# collect the dataframe before using `st.cache_data`.)
if any(
    type_util.is_type(computed_value, type_name)
    for type_name in UNEVALUATED_DATAFRAME_TYPES
):
raise UnevaluatedDataFrameError(
f"""
The function {get_cached_func_name_md(self._info.func)} is decorated with `st.cache_data` but it returns an unevaluated dataframe
of type `{type_util.get_fqn_type(computed_value)}`. Please call `collect()` or `to_pandas()` on the dataframe before returning it,
so `st.cache_data` can serialize and cache it."""
)
raise UnserializableReturnValueError(
return_value=computed_value, func=self._info.func
)
def clear(self):
"""Clear the wrapped function's associated cache."""
cache = self._info.get_function_cache(self._function_key)
cache.clear()
def _make_value_key(
cache_type: CacheType,
func: types.FunctionType,
func_args: tuple[Any, ...],
func_kwargs: dict[str, Any],
) -> str:
"""Create the key for a value within a cache.
This key is generated from the function's arguments. All arguments
will be hashed, except for those named with a leading "_".
Raises
------
StreamlitAPIException
Raised (with a nicely-formatted explanation message) if we encounter
an un-hashable arg.
"""
# Create a (name, value) list of all *args and **kwargs passed to the
# function.
arg_pairs: list[tuple[str | None, Any]] = []
for arg_idx in range(len(func_args)):
arg_name = _get_positional_arg_name(func, arg_idx)
arg_pairs.append((arg_name, func_args[arg_idx]))
for kw_name, kw_val in func_kwargs.items():
# **kwargs ordering is preserved, per PEP 468
# https://www.python.org/dev/peps/pep-0468/, so this iteration is
# deterministic.
arg_pairs.append((kw_name, kw_val))
# Create the hash from each arg value, except for those args whose name
# starts with "_". (Underscore-prefixed args are deliberately excluded from
# hashing.)
args_hasher = hashlib.new("md5")
for arg_name, arg_value in arg_pairs:
if arg_name is not None and arg_name.startswith("_"):
_LOGGER.debug("Not hashing %s because it starts with _", arg_name)
continue
try:
update_hash(
(arg_name, arg_value),
hasher=args_hasher,
cache_type=cache_type,
)
except UnhashableTypeError as exc:
raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)
value_key = args_hasher.hexdigest()
_LOGGER.debug("Cache key: %s", value_key)
return value_key
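# For illustration, the underscore convention in practice (a sketch, assuming
# typical `st.cache_data` usage; `_conn` and `query` are hypothetical names):
#
#     @st.cache_data
#     def run_query(_conn, query: str):
#         # Only `query` is hashed into the value key; `_conn` is skipped
#         # because its name starts with "_", so an unhashable connection
#         # object can still be passed in.
#         return _conn.execute(query)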
def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
"""Create the unique key for a function's cache.
A function's key is stable across reruns of the app, and changes when
the function's source code changes.
"""
func_hasher = hashlib.new("md5")
# Include the function's __module__ and __qualname__ strings in the hash.
# This means that two identical functions in different modules
# will not share a hash; it also means that two identical *nested*
# functions in the same module will not share a hash.
update_hash(
(func.__module__, func.__qualname__),
hasher=func_hasher,
cache_type=cache_type,
)
# Include the function's source code in its hash. If the source code can't
# be retrieved, fall back to the function's bytecode instead.
source_code: str | bytes
try:
source_code = inspect.getsource(func)
except OSError as e:
_LOGGER.debug(
"Failed to retrieve function's source code when building its key; falling back to bytecode. err={0}",
e,
)
source_code = func.__code__.co_code
update_hash(
source_code,
hasher=func_hasher,
cache_type=cache_type,
)
cache_key = func_hasher.hexdigest()
return cache_key
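# A minimal sketch of the same module-plus-qualname-plus-source hashing using
# plain hashlib/inspect (the real implementation above also folds the cache
# type into the hash via update_hash):
def _sketch_function_key(func) -> str:
    h = hashlib.new("md5")
    h.update(f"{func.__module__}.{func.__qualname__}".encode())
    try:
        h.update(inspect.getsource(func).encode())
    except OSError:
        # Source unavailable (e.g. a function defined in a REPL):
        # fall back to the bytecode.
        h.update(func.__code__.co_code)
    return h.hexdigest()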
def _get_positional_arg_name(func: types.FunctionType, arg_index: int) -> str | None:
"""Return the name of a function's positional argument.
If arg_index is out of range, or refers to a parameter that is not a
named positional argument (e.g. an *args, **kwargs, or keyword-only param),
return None instead.
"""
if arg_index < 0:
return None
params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values())
if arg_index >= len(params):
return None
if params[arg_index].kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
):
return params[arg_index].name
return None
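# For example (a sketch):
#
#     def f(a, b, *args, c=1, **kwargs): ...
#
#     _get_positional_arg_name(f, 0)  # -> "a"
#     _get_positional_arg_name(f, 1)  # -> "b"
#     _get_positional_arg_name(f, 2)  # -> None (*args is not a named positional)
#     _get_positional_arg_name(f, 9)  # -> None (out of range)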

View File

@@ -0,0 +1,478 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import hashlib
import threading
import types
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, Union
from google.protobuf.message import Message
from typing_extensions import Protocol, runtime_checkable
import streamlit as st
from streamlit import runtime, util
from streamlit.elements import NONWIDGET_ELEMENTS, WIDGETS
from streamlit.logger import get_logger
from streamlit.proto.Block_pb2 import Block
from streamlit.runtime.caching.cache_errors import (
CachedStFunctionWarning,
CacheReplayClosureError,
)
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.hashing import update_hash
from streamlit.runtime.scriptrunner.script_run_context import (
ScriptRunContext,
get_script_run_ctx,
)
from streamlit.runtime.state.common import WidgetMetadata
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
_LOGGER = get_logger(__name__)
@runtime_checkable
class Widget(Protocol):
id: str
@dataclass(frozen=True)
class WidgetMsgMetadata:
"""Everything needed for replaying a widget and treating it as an implicit
argument to a cached function, beyond what is stored for all elements.
"""
widget_id: str
widget_value: Any
metadata: WidgetMetadata[Any]
@dataclass(frozen=True)
class MediaMsgData:
media: bytes | str
mimetype: str
media_id: str
@dataclass(frozen=True)
class ElementMsgData:
"""An element's message and related metadata for
replaying that element's function call.
widget_metadata is filled in if and only if this element is a widget.
media_data is filled in iff this is a media element (image, audio, video).
"""
delta_type: str
message: Message
id_of_dg_called_on: str
returned_dgs_id: str
widget_metadata: WidgetMsgMetadata | None = None
media_data: list[MediaMsgData] | None = None
@dataclass(frozen=True)
class BlockMsgData:
message: Block
id_of_dg_called_on: str
returned_dgs_id: str
MsgData = Union[ElementMsgData, BlockMsgData]
"""
Note [Cache result structure]
The cache for a decorated function's results is split into two parts to enable
handling widgets invoked by the function.
Widgets act as implicit additional inputs to the cached function, so they should
be used when deriving the cache key. However, we don't know what widgets are
involved without first calling the function! So, we use the first execution
of the function with a particular cache key to record what widgets are used,
and use the current values of those widgets to derive a second cache key to
look up the function execution's results. The combination of first and second
cache keys act as one true cache key, just split up because the second part depends
on the first.
We need to treat widgets as implicit arguments of the cached function, because
the behavior of the function, including what elements are created and what it
returns, can be and usually will be influenced by the values of those widgets.
For example:
> @st.memo
> def example_fn(x):
> y = x + 1
> if st.checkbox("hi"):
> st.write("you checked the thing")
> y = 0
> return y
>
> example_fn(2)
If the checkbox is checked, the function call should return 0 and the checkbox and
message should be rendered. If the checkbox isn't checked, only the checkbox should
render, and the function will return 3.
There is a small catch: since which widgets execute can depend on the values of
any prior widgets, if we replace the `st.write` call in the example with a slider,
the first run would miss the slider because it was never called.
When we later execute the function with the checkbox checked, the widget cache key
would not include the state of the slider, and we would incorrectly get a cache hit
for a different slider value.
In principle the cache could be function+args key -> 1st widget key -> 2nd widget key
... -> final widget key, with each widget dependent on the exact values of the widgets
seen prior. This would prevent unnecessary cache misses due to differing values of widgets
that wouldn't affect the function's execution because they aren't even created.
But this would add even more complexity and both conceptual and runtime overhead, so it is
unclear if it would be worth doing.
Instead, we can keep the widgets as one cache key, and if we encounter a new widget
while executing the function, we just update the list of widgets to include it.
This will cause existing cached results to be invalidated, which is bad, but to
avoid it we would need to keep around the full list of widgets and values for each
widget cache key so we could compute the updated key, which is probably too expensive
to be worth it.
"""
@dataclass
class CachedResult:
"""The full results of calling a cache-decorated function, enough to
replay the st functions called while executing it.
"""
value: Any
messages: list[MsgData]
main_id: str
sidebar_id: str
@dataclass
class MultiCacheResults:
"""Widgets called by a cache-decorated function, and a mapping of the
widget-derived cache key to the final results of executing the function.
"""
widget_ids: set[str]
results: dict[str, CachedResult]
def get_current_widget_key(
self, ctx: ScriptRunContext, cache_type: CacheType
) -> str:
state = ctx.session_state
# Compute the key using only widgets that have values. A missing widget
# can be ignored because we only care about getting different keys
# for different widget values, and for that purpose doing nothing
# to the running hash is just as good as including the widget with a
# sentinel value. But by excluding it, we might get to reuse a result
# saved before we knew about that widget.
widget_values = [
(wid, state[wid]) for wid in sorted(self.widget_ids) if wid in state
]
widget_key = _make_widget_key(widget_values, cache_type)
return widget_key
"""
Note [DeltaGenerator method invocation]
There are two top level DG instances defined for all apps:
`main`, which is for putting elements in the main part of the app
`sidebar`, for the sidebar
There are 3 different ways an st function can be invoked:
1. Implicitly on the main DG instance (plain `st.foo` calls)
2. Implicitly in an active contextmanager block (`st.foo` within a `with st.container` context)
3. Explicitly on a DG instance (`st.sidebar.foo`, `my_column_1.foo`)
To simplify replaying messages from a cached function result, we convert all of these
to explicit invocations. How they get rewritten depends on if the invocation was
implicit vs explicit, and if the target DG has been seen/produced during replay.
Implicit invocation on a known DG -> Explicit invocation on that DG
Implicit invocation on an unknown DG -> Rewrite as explicit invocation on main
with st.container():
my_cache_decorated_function()
This is situation 2 above, and the DG is a block entirely outside our function call,
so we interpret it as "put this element in the enclosing contextmanager block"
(or main if there isn't one), which is achieved by invoking on main.
Explicit invocation on a known DG -> No change needed
Explicit invocation on an unknown DG -> Raise an error
We have no way to identify the target DG, and it may not even be present in the
current script run, so the least surprising thing to do is raise an error.
"""
class CachedMessageReplayContext(threading.local):
"""A utility for storing messages generated by `st` commands called inside
a cached function.
Data is stored in a thread-local object, so it's safe to use an instance
of this class across multiple threads.
"""
def __init__(self, cache_type: CacheType):
self._cached_func_stack: list[types.FunctionType] = []
self._suppress_st_function_warning = 0
self._cached_message_stack: list[list[MsgData]] = []
self._seen_dg_stack: list[set[str]] = []
self._most_recent_messages: list[MsgData] = []
self._registered_metadata: WidgetMetadata[Any] | None = None
self._media_data: list[MediaMsgData] = []
self._cache_type = cache_type
self._allow_widgets: int = 0
def __repr__(self) -> str:
return util.repr_(self)
@contextlib.contextmanager
def calling_cached_function(
self, func: types.FunctionType, allow_widgets: bool
) -> Iterator[None]:
"""Context manager that should wrap the invocation of a cached function.
It allows us to track any `st.foo` messages that are generated from inside the function
for playback during cache retrieval.
"""
self._cached_func_stack.append(func)
self._cached_message_stack.append([])
self._seen_dg_stack.append(set())
if allow_widgets:
self._allow_widgets += 1
try:
yield
finally:
self._cached_func_stack.pop()
self._most_recent_messages = self._cached_message_stack.pop()
self._seen_dg_stack.pop()
if allow_widgets:
self._allow_widgets -= 1
def save_element_message(
self,
delta_type: str,
element_proto: Message,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
"""Record the element protobuf as having been produced during any currently
executing cached functions, so they can be replayed any time the function's
execution is skipped because they're in the cache.
"""
if not runtime.exists():
return
if len(self._cached_message_stack) >= 1:
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
# Arrow dataframes have an ID, but it is only set when they are used as
# data editor widgets, so we have to check that the ID has actually been
# set to know whether an element is a widget.
if isinstance(element_proto, Widget) and element_proto.id:
wid = element_proto.id
# TODO replace `Message` with a more precise type
if not self._registered_metadata:
_LOGGER.error(
"Trying to save widget message that wasn't registered. This should not be possible."
)
raise AttributeError
widget_meta = WidgetMsgMetadata(
wid, None, metadata=self._registered_metadata
)
else:
widget_meta = None
media_data = self._media_data
element_msg_data = ElementMsgData(
delta_type,
element_proto,
id_to_save,
returned_dg_id,
widget_meta,
media_data,
)
for msgs in self._cached_message_stack:
if self._allow_widgets or widget_meta is None:
msgs.append(element_msg_data)
# Reset instance state, now that it has been used for the
# associated element.
self._media_data = []
self._registered_metadata = None
for s in self._seen_dg_stack:
s.add(returned_dg_id)
def save_block_message(
self,
block_proto: Block,
invoked_dg_id: str,
used_dg_id: str,
returned_dg_id: str,
) -> None:
id_to_save = self.select_dg_to_save(invoked_dg_id, used_dg_id)
for msgs in self._cached_message_stack:
msgs.append(BlockMsgData(block_proto, id_to_save, returned_dg_id))
for s in self._seen_dg_stack:
s.add(returned_dg_id)
def select_dg_to_save(self, invoked_id: str, acting_on_id: str) -> str:
"""Select the id of the DG that this message should be invoked on
during message replay.
See Note [DeltaGenerator method invocation]
invoked_id is the DG the st function was called on, usually `st._main`.
acting_on_id is the DG the st function ultimately runs on, which may be different
if the invoked DG delegated to another one because it was in a `with` block.
"""
if len(self._seen_dg_stack) > 0 and acting_on_id in self._seen_dg_stack[-1]:
return acting_on_id
else:
return invoked_id
def save_widget_metadata(self, metadata: WidgetMetadata[Any]) -> None:
self._registered_metadata = metadata
def save_image_data(
self, image_data: bytes | str, mimetype: str, image_id: str
) -> None:
self._media_data.append(MediaMsgData(image_data, mimetype, image_id))
@contextlib.contextmanager
def suppress_cached_st_function_warning(self) -> Iterator[None]:
self._suppress_st_function_warning += 1
try:
yield
finally:
self._suppress_st_function_warning -= 1
assert self._suppress_st_function_warning >= 0
def maybe_show_cached_st_function_warning(
self,
dg: "DeltaGenerator",
st_func_name: str,
) -> None:
"""If appropriate, warn about calling st.foo inside @memo.
DeltaGenerator's @_with_element and @_widget wrappers use this to warn
the user when they're calling st.foo() from within a function that is
wrapped in @st.cache.
Parameters
----------
dg : DeltaGenerator
The DeltaGenerator to publish the warning to.
st_func_name : str
The name of the Streamlit function that was called.
"""
# There are some elements not in either list, which we still want to warn about.
# Ideally we will fix this by either updating the lists or creating a better
# way of categorizing elements.
if st_func_name in NONWIDGET_ELEMENTS:
return
if st_func_name in WIDGETS and self._allow_widgets > 0:
return
if len(self._cached_func_stack) > 0 and self._suppress_st_function_warning <= 0:
cached_func = self._cached_func_stack[-1]
self._show_cached_st_function_warning(dg, st_func_name, cached_func)
def _show_cached_st_function_warning(
self,
dg: "DeltaGenerator",
st_func_name: str,
cached_func: types.FunctionType,
) -> None:
# Avoid infinite recursion by suppressing additional cached
# function warnings from within the cached function warning.
with self.suppress_cached_st_function_warning():
e = CachedStFunctionWarning(self._cache_type, st_func_name, cached_func)
dg.exception(e)
def replay_cached_messages(
result: CachedResult, cache_type: CacheType, cached_func: types.FunctionType
) -> None:
"""Replay the st element function calls that happened when executing a
cache-decorated function.
When a cache function is executed, we record the element and block messages
produced, and use those to reproduce the DeltaGenerator calls, so the elements
will appear in the web app even when execution of the function is skipped
because the result was cached.
To make this work, for each st function call we record an identifier for the
DG it was effectively called on (see Note [DeltaGenerator method invocation]).
We also record the identifier for each DG returned by an st function call, if
it returns one. Then, for each recorded message, we get the current DG instance
corresponding to the DG the message was originally called on, and enqueue the
message using that, recording any new DGs produced in case a later st function
call is on one of them.
"""
from streamlit.delta_generator import DeltaGenerator
from streamlit.runtime.state.widgets import register_widget_from_metadata
# Maps originally recorded dg ids to this script run's version of that dg
returned_dgs: dict[str, DeltaGenerator] = {}
returned_dgs[result.main_id] = st._main
returned_dgs[result.sidebar_id] = st.sidebar
ctx = get_script_run_ctx()
try:
for msg in result.messages:
if isinstance(msg, ElementMsgData):
if msg.widget_metadata is not None:
register_widget_from_metadata(
msg.widget_metadata.metadata,
ctx,
None,
msg.delta_type,
)
if msg.media_data is not None:
for data in msg.media_data:
runtime.get_instance().media_file_mgr.add(
data.media, data.mimetype, data.media_id
)
dg = returned_dgs[msg.id_of_dg_called_on]
maybe_dg = dg._enqueue(msg.delta_type, msg.message)
if isinstance(maybe_dg, DeltaGenerator):
returned_dgs[msg.returned_dgs_id] = maybe_dg
elif isinstance(msg, BlockMsgData):
dg = returned_dgs[msg.id_of_dg_called_on]
new_dg = dg._block(msg.message)
returned_dgs[msg.returned_dgs_id] = new_dg
except KeyError:
raise CacheReplayClosureError(cache_type, cached_func)
def _make_widget_key(widgets: list[tuple[str, Any]], cache_type: CacheType) -> str:
"""Generate a key for the given list of widgets used in a cache-decorated function.
Keys are generated by hashing the IDs and values of the widgets in the given list.
"""
func_hasher = hashlib.new("md5")
for widget_id_val in widgets:
update_hash(widget_id_val, func_hasher, cache_type)
return func_hasher.hexdigest()
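# For example, two checkboxes and their current values produce (a sketch;
# the widget IDs are hypothetical):
#
#     _make_widget_key([("checkbox-1", True), ("checkbox-2", False)], cache_type)
#
# which folds each (id, value) pair into a single md5 hex digest.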

View File

@@ -0,0 +1,395 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hashing for st.memo and st.singleton."""
import collections
import dataclasses
import functools
import hashlib
import inspect
import io
import os
import pickle
import sys
import tempfile
import threading
import unittest.mock
import weakref
from enum import Enum
from typing import Any, Dict, List, Optional, Pattern
from streamlit import type_util, util
from streamlit.runtime.caching.cache_errors import UnhashableTypeError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.uploaded_file_manager import UploadedFile
# If a dataframe has more than this many rows, we consider it large and hash a sample.
_PANDAS_ROWS_LARGE = 100000
_PANDAS_SAMPLE_SIZE = 10000
# Similar to dataframes, we also sample large numpy arrays.
_NP_SIZE_LARGE = 1000000
_NP_SAMPLE_SIZE = 100000
# Arbitrary item to denote where we found a cycle in a hashed object.
# This allows us to hash self-referencing lists, dictionaries, etc.
_CYCLE_PLACEHOLDER = b"streamlit-57R34ML17-hesamagicalponyflyingthroughthesky-CYCLE"
def update_hash(val: Any, hasher, cache_type: CacheType) -> None:
"""Updates a hashlib hasher with the hash of val.
This is the main entrypoint to hashing.py.
"""
ch = _CacheFuncHasher(cache_type)
ch.update(hasher, val)
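# For example, hashing a (name, value) pair the way _make_value_key does
# (a sketch; `cache_type` is whatever CacheType member the caller is using):
#
#     hasher = hashlib.new("md5")
#     update_hash(("n_rows", 42), hasher=hasher, cache_type=cache_type)
#     key = hasher.hexdigest()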
class _HashStack:
"""Stack of what has been hashed, for debug and circular reference detection.
This internally keeps 1 stack per thread.
Internally, this stores the ID of pushed objects rather than the objects
themselves because otherwise the "in" operator inside __contains__ would
fail for objects that don't return a boolean for "==" operator. For
example, arr == 10 where arr is a NumPy array returns another NumPy array.
This causes the "in" to crash since it expects a boolean.
"""
def __init__(self):
self._stack: collections.OrderedDict[int, Any] = collections.OrderedDict()
def __repr__(self) -> str:
return util.repr_(self)
def push(self, val: Any):
self._stack[id(val)] = val
def pop(self):
self._stack.popitem()
def __contains__(self, val: Any):
return id(val) in self._stack
class _HashStacks:
"""Stacks of what has been hashed, with at most 1 stack per thread."""
def __init__(self):
self._stacks: weakref.WeakKeyDictionary[
threading.Thread, _HashStack
] = weakref.WeakKeyDictionary()
def __repr__(self) -> str:
return util.repr_(self)
@property
def current(self) -> _HashStack:
current_thread = threading.current_thread()
stack = self._stacks.get(current_thread, None)
if stack is None:
stack = _HashStack()
self._stacks[current_thread] = stack
return stack
hash_stacks = _HashStacks()
def _int_to_bytes(i: int) -> bytes:
num_bytes = (i.bit_length() + 8) // 8
return i.to_bytes(num_bytes, "little", signed=True)
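# For example, _int_to_bytes(300): 300 has a bit_length of 9, so
# num_bytes == (9 + 8) // 8 == 2, and the little-endian signed result is
# b"\x2c\x01". The "+ 8" (rather than "+ 7") always leaves room for the
# sign bit, so e.g. 255 is encoded in two bytes instead of overflowing one.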
def _key(obj: Optional[Any]) -> Any:
"""Return key for memoization."""
if obj is None:
return None
def is_simple(obj):
return (
isinstance(obj, bytes)
or isinstance(obj, bytearray)
or isinstance(obj, str)
or isinstance(obj, float)
or isinstance(obj, int)
or isinstance(obj, bool)
or obj is None
)
if is_simple(obj):
return obj
if isinstance(obj, tuple):
if all(map(is_simple, obj)):
return obj
if isinstance(obj, list):
if all(map(is_simple, obj)):
return ("__l", tuple(obj))
if (
type_util.is_type(obj, "pandas.core.frame.DataFrame")
or type_util.is_type(obj, "numpy.ndarray")
or inspect.isbuiltin(obj)
or inspect.isroutine(obj)
or inspect.iscode(obj)
):
return id(obj)
return NoResult
class _CacheFuncHasher:
"""A hasher that can hash objects with cycles."""
def __init__(self, cache_type: CacheType):
self._hashes: Dict[Any, bytes] = {}
# The number of bytes in the hash.
self.size = 0
self.cache_type = cache_type
def __repr__(self) -> str:
return util.repr_(self)
def to_bytes(self, obj: Any) -> bytes:
"""Add memoization to _to_bytes and protect against cycles in data structures."""
tname = type(obj).__qualname__.encode()
key = (tname, _key(obj))
# Memoize if possible.
if key[1] is not NoResult:
if key in self._hashes:
return self._hashes[key]
# Break recursive cycles.
if obj in hash_stacks.current:
return _CYCLE_PLACEHOLDER
hash_stacks.current.push(obj)
try:
# Hash the input
b = b"%s:%s" % (tname, self._to_bytes(obj))
# Hmmm... It's possible that the size calculation is wrong. When we
# call to_bytes inside _to_bytes things get double-counted.
self.size += sys.getsizeof(b)
if key[1] is not NoResult:
self._hashes[key] = b
finally:
# In case an UnhashableTypeError (or other) error is thrown, clean up the
# stack so we don't get false positives in future hashing calls
hash_stacks.current.pop()
return b
def update(self, hasher, obj: Any) -> None:
"""Update the provided hasher with the hash of an object."""
b = self.to_bytes(obj)
hasher.update(b)
def _to_bytes(self, obj: Any) -> bytes:
"""Hash objects to bytes, including code with dependencies.
Python's built in `hash` does not produce consistent results across
runs.
"""
if isinstance(obj, unittest.mock.Mock):
# Mock objects can appear to be infinitely
# deep, so we don't try to hash them at all.
return self.to_bytes(id(obj))
elif isinstance(obj, bytes) or isinstance(obj, bytearray):
return obj
elif isinstance(obj, str):
return obj.encode()
elif isinstance(obj, float):
return self.to_bytes(hash(obj))
elif isinstance(obj, int):
return _int_to_bytes(obj)
elif isinstance(obj, (list, tuple)):
h = hashlib.new("md5")
for item in obj:
self.update(h, item)
return h.digest()
elif isinstance(obj, dict):
h = hashlib.new("md5")
for item in obj.items():
self.update(h, item)
return h.digest()
elif obj is None:
return b"0"
elif obj is True:
return b"1"
elif obj is False:
return b"0"
elif dataclasses.is_dataclass(obj):
return self.to_bytes(dataclasses.asdict(obj))
elif isinstance(obj, Enum):
return str(obj).encode()
elif type_util.is_type(obj, "pandas.core.frame.DataFrame") or type_util.is_type(
obj, "pandas.core.series.Series"
):
import pandas as pd
if len(obj) >= _PANDAS_ROWS_LARGE:
obj = obj.sample(n=_PANDAS_SAMPLE_SIZE, random_state=0)
try:
return b"%s" % pd.util.hash_pandas_object(obj).sum()
except TypeError:
# Use pickle if pandas cannot hash the object for example if
# it contains unhashable objects.
return b"%s" % pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
elif type_util.is_type(obj, "numpy.ndarray"):
h = hashlib.new("md5")
self.update(h, obj.shape)
if obj.size >= _NP_SIZE_LARGE:
import numpy as np
state = np.random.RandomState(0)
obj = state.choice(obj.flat, size=_NP_SAMPLE_SIZE)
self.update(h, obj.tobytes())
return h.digest()
elif type_util.is_type(obj, "PIL.Image.Image"):
import numpy as np
# we don't just hash the results of obj.tobytes() because we want to use
# the sampling logic for numpy data
np_array = np.frombuffer(obj.tobytes(), dtype="uint8")
return self.to_bytes(np_array)
elif inspect.isbuiltin(obj):
return bytes(obj.__name__.encode())
elif type_util.is_type(obj, "builtins.mappingproxy") or type_util.is_type(
obj, "builtins.dict_items"
):
return self.to_bytes(dict(obj))
elif type_util.is_type(obj, "builtins.getset_descriptor"):
return bytes(obj.__qualname__.encode())
elif isinstance(obj, UploadedFile):
# UploadedFile is a BytesIO (thus IOBase) but has a name.
# It does not have a timestamp so this must come before
# temporary files
h = hashlib.new("md5")
self.update(h, obj.name)
self.update(h, obj.tell())
self.update(h, obj.getvalue())
return h.digest()
elif hasattr(obj, "name") and (
isinstance(obj, io.IOBase)
# Handle temporary files used during testing
or isinstance(obj, tempfile._TemporaryFileWrapper)
):
# Hash files as name + last modification date + offset.
# NB: we're using hasattr("name") to differentiate between
# on-disk and in-memory StringIO/BytesIO file representations.
# That means that this condition must come *before* the next
# condition, which just checks for StringIO/BytesIO.
h = hashlib.new("md5")
obj_name = getattr(obj, "name", "wonthappen") # Just to appease MyPy.
self.update(h, obj_name)
self.update(h, os.path.getmtime(obj_name))
self.update(h, obj.tell())
return h.digest()
elif isinstance(obj, Pattern):
return self.to_bytes([obj.pattern, obj.flags])
elif isinstance(obj, io.StringIO) or isinstance(obj, io.BytesIO):
# Hash in-memory StringIO/BytesIO by their full contents
# and seek position.
h = hashlib.new("md5")
self.update(h, obj.tell())
self.update(h, obj.getvalue())
return h.digest()
elif type_util.is_type(obj, "numpy.ufunc"):
# For numpy.remainder, this returns remainder.
return bytes(obj.__name__.encode())
elif inspect.ismodule(obj):
# TODO: Figure out how to best show this kind of warning to the
# user. In the meantime, show nothing. This scenario is too common,
# so the current warning is quite annoying...
# st.warning(('Streamlit does not support hashing modules. '
# 'We did not hash `%s`.') % obj.__name__)
# TODO: Hash more than just the name for internal modules.
return self.to_bytes(obj.__name__)
elif inspect.isclass(obj):
# TODO: Figure out how to best show this kind of warning to the
# user. In the meantime, show nothing. This scenario is too common,
# (e.g. in every "except" statement) so the current warning is
# quite annoying...
# st.warning(('Streamlit does not support hashing classes. '
# 'We did not hash `%s`.') % obj.__name__)
# TODO: Hash more than just the name of classes.
return self.to_bytes(obj.__name__)
elif isinstance(obj, functools.partial):
# The return value of functools.partial is not a plain function:
# it's a callable object that remembers the original function plus
# the values you pickled into it. So here we need to special-case it.
h = hashlib.new("md5")
self.update(h, obj.args)
self.update(h, obj.func)
self.update(h, obj.keywords)
return h.digest()
else:
# As a last resort, hash the output of the object's __reduce__ method
h = hashlib.new("md5")
try:
reduce_data = obj.__reduce__()
except Exception as ex:
raise UnhashableTypeError() from ex
for item in reduce_data:
self.update(h, item)
return h.digest()
class NoResult:
"""Placeholder class for return values when None is meaningful."""
pass

View File

@@ -0,0 +1,29 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage as CacheStorage,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageContext as CacheStorageContext,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageError as CacheStorageError,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageKeyNotFoundError as CacheStorageKeyNotFoundError,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorageManager as CacheStorageManager,
)

View File

@@ -0,0 +1,240 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Declares the CacheStorageContext dataclass, which contains parameter information for
each function decorated by `@st.cache_data` (for example: ttl, max_entries etc.)
Declares the CacheStorageManager protocol, which implementations are used
to create CacheStorage instances and to optionally clear all cache storages,
that were created by this manager, and to check if the context is valid for the storage.
Declares the CacheStorage protocol, which implementations are used to store cached
values for a single `@st.cache_data` decorated function serialized as bytes.
How these classes work together
-------------------------------
- CacheStorageContext : this is a dataclass that contains the parameters from
`@st.cache_data` that are passed to the CacheStorageManager.create() method.
- CacheStorageManager : each instance of this is able to create CacheStorage
instances, and optionally to clear data of all cache storages.
- CacheStorage : each instance of this is able to get, set, delete, and clear
entries for a single `@st.cache_data` decorated function.
┌───────────────────────────────┐
│ │
│ CacheStorageManager │
│ │
│ - clear_all(optional) │
│ - check_context │
│ │
└──┬────────────────────────────┘
│ ┌──────────────────────┐
│ │ CacheStorage │
│ create(context)│ │
└────────────────► - get │
│ - set │
│ - delete │
│ - close (optional)│
│ - clear │
└──────────────────────┘
"""
from __future__ import annotations
from abc import abstractmethod
from dataclasses import dataclass
from typing_extensions import Literal, Protocol
class CacheStorageError(Exception):
"""Base exception raised by the cache storage"""
class CacheStorageKeyNotFoundError(CacheStorageError):
"""Raised when the key is not found in the cache storage"""
class InvalidCacheStorageContext(CacheStorageError):
"""Raised if the cache storage manager is not able to work with
provided CacheStorageContext.
"""
@dataclass(frozen=True)
class CacheStorageContext:
"""Context passed to the cache storage during initialization
This is the normalized parameters that are passed to CacheStorageManager.create()
method.
Parameters
----------
function_key: str
A hash computed based on function name and source code decorated
by `@st.cache_data`
function_display_name: str
The display name of the function that is decorated by `@st.cache_data`
ttl_seconds : float or None
The time-to-live for the keys in storage, in seconds. If None, the entry
will never expire.
max_entries : int or None
The maximum number of entries to store in the cache storage.
If None, the cache storage will not limit the number of entries.
persist : Literal["disk"] or None
The persistence mode for the cache storage.
A legacy parameter used in Streamlit's current cache storage implementation.
It can be ignored by a cache storage implementation if the storage does not
support persistence or is persistent by default.
"""
function_key: str
function_display_name: str
ttl_seconds: float | None = None
max_entries: int | None = None
persist: Literal["disk"] | None = None
class CacheStorage(Protocol):
"""Cache storage protocol, that should be implemented by the concrete cache storages.
Used to store cached values for a single `@st.cache_data` decorated function
serialized as bytes.
CacheStorage instances should be created by `CacheStorageManager.create()` method.
Notes
-----
Threading: The methods of this protocol may be called from multiple threads.
It is the responsibility of the concrete implementation to provide thread
safety guarantees.
"""
@abstractmethod
def get(self, key: str) -> bytes:
"""Returns the stored value for the key.
Raises
------
CacheStorageKeyNotFoundError
Raised if the key is not in the storage.
"""
raise NotImplementedError
@abstractmethod
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key"""
raise NotImplementedError
@abstractmethod
def delete(self, key: str) -> None:
"""Delete a given key"""
raise NotImplementedError
@abstractmethod
def clear(self) -> None:
"""Remove all keys for the storage"""
raise NotImplementedError
def close(self) -> None:
"""Closes the cache storage, it is optional to implement, and should be used
to close open resources, before we delete the storage instance.
e.g. close the database connection etc.
"""
pass
class CacheStorageManager(Protocol):
"""Cache storage manager protocol, that should be implemented by the concrete
cache storage managers.
It is responsible for:
- Creating cache storage instances for the specific
decorated functions,
- Validating the context for the cache storages.
- Optionally clearing all cache storages in optimal way.
It should be created during Runtime initialization.
"""
@abstractmethod
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance
Please note that the ttl, max_entries and other context fields are specific
for whole storage, not for individual key.
Notes
-----
Threading: Should be safe to call from any thread.
"""
raise NotImplementedError
def clear_all(self) -> None:
"""Remove everything what possible from the cache storages in optimal way.
meaningful default behaviour is to raise NotImplementedError, so this is not
abstractmethod.
The method is optional to implement: cache data API will fall back to remove
all available storages one by one via storage.clear() method
if clear_all raises NotImplementedError.
Raises
------
NotImplementedError
Raised if the storage manager does not provide the ability to clear
all storages at once in an optimal way.
Notes
-----
Threading: This method may be called from multiple threads.
It is the responsibility of the concrete implementation to provide
thread safety guarantees.
"""
raise NotImplementedError
def check_context(self, context: CacheStorageContext) -> None:
"""Checks if the context is valid for the storage manager.
This method should not return anything; it should log a message or raise an
exception if the context is invalid.
If an exception is raised, we do not handle it and let it propagate.
check_context is called only once at the moment of creating `@st.cache_data`
decorator for specific function, so it is not called for every cache hit.
Parameters
----------
context: CacheStorageContext
The context to check for the storage manager. A dummy function_key is used
in the context, since the real one has not been computed when this method is called.
Raises
------
InvalidCacheStorageContext
Raised if the cache storage manager is not able to work with the provided
CacheStorageContext. When possible we should log a message instead, since
this exception will be propagated to the user.
Notes
-----
Threading: Should be safe to call from any thread.
"""
pass
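# A minimal sketch of a concrete storage satisfying the CacheStorage protocol
# above: a process-local dict with no TTL or max_entries handling
# (illustrative only, not one of Streamlit's shipped storages):
class _DictCacheStorage(CacheStorage):
    def __init__(self) -> None:
        self._data: dict[str, bytes] = {}

    def get(self, key: str) -> bytes:
        try:
            return self._data[key]
        except KeyError:
            raise CacheStorageKeyNotFoundError(f"{key} not found")

    def set(self, key: str, value: bytes) -> None:
        self._data[key] = value

    def delete(self, key: str) -> None:
        self._data.pop(key, None)

    def clear(self) -> None:
        self._data.clear()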

View File

@@ -0,0 +1,60 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
InMemoryCacheStorageWrapper,
)
class MemoryCacheStorageManager(CacheStorageManager):
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance wrapped with in-memory cache layer"""
persist_storage = DummyCacheStorage()
return InMemoryCacheStorageWrapper(
persist_storage=persist_storage, context=context
)
def clear_all(self) -> None:
raise NotImplementedError
def check_context(self, context: CacheStorageContext) -> None:
pass
class DummyCacheStorage(CacheStorage):
def get(self, key: str) -> bytes:
"""
Dummy implementation of get for a given key;
always raises a CacheStorageKeyNotFoundError, since nothing is ever stored.
"""
raise CacheStorageKeyNotFoundError("Key not found in dummy cache")
def set(self, key: str, value: bytes) -> None:
pass
def delete(self, key: str) -> None:
pass
def clear(self) -> None:
pass
def close(self) -> None:
pass

View File

@@ -0,0 +1,145 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import math
import threading
from cachetools import TTLCache
from streamlit.logger import get_logger
from streamlit.runtime.caching import cache_utils
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageKeyNotFoundError,
)
from streamlit.runtime.stats import CacheStat
_LOGGER = get_logger(__name__)
class InMemoryCacheStorageWrapper(CacheStorage):
"""
In-memory cache storage wrapper.
This class wraps a cache storage and adds an in-memory cache front layer,
which is used to reduce the number of calls to the storage.
The in-memory cache is a TTL cache, which means that the entries are
automatically removed if a given time to live (TTL) has passed.
The in-memory cache is also an LRU cache, which means that the entries
are automatically removed if the cache size exceeds a given maxsize.
If the storage implements its own maxsize strategy, it is recommended
(though not required) that it implement the same LRU strategy;
otherwise different items may end up being evicted from
the memory cache and from the storage.
Notes
-----
Threading: the in-memory caching layer is thread safe: we hold self._mem_cache_lock
when working with the self._mem_cache object.
However, we do not hold this lock when calling into the underlying storage,
so it is the responsibility of that storage to ensure that it is safe to use
from multiple threads.
"""
def __init__(self, persist_storage: CacheStorage, context: CacheStorageContext):
self.function_key = context.function_key
self.function_display_name = context.function_display_name
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
self._mem_cache: TTLCache[str, bytes] = TTLCache(
maxsize=self.max_entries,
ttl=self.ttl_seconds,
timer=cache_utils.TTLCACHE_TIMER,
)
self._mem_cache_lock = threading.Lock()
self._persist_storage = persist_storage
@property
def ttl_seconds(self) -> float:
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
@property
def max_entries(self) -> float:
return float(self._max_entries) if self._max_entries is not None else math.inf
def get(self, key: str) -> bytes:
"""
Returns the stored value for the key, or raises CacheStorageKeyNotFoundError if
the key is not found.
"""
try:
entry_bytes = self._read_from_mem_cache(key)
except CacheStorageKeyNotFoundError:
entry_bytes = self._persist_storage.get(key)
self._write_to_mem_cache(key, entry_bytes)
return entry_bytes
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key"""
self._write_to_mem_cache(key, value)
self._persist_storage.set(key, value)
def delete(self, key: str) -> None:
"""Delete a given key"""
self._remove_from_mem_cache(key)
self._persist_storage.delete(key)
def clear(self) -> None:
"""Delete all keys for the in memory cache, and also the persistent storage"""
with self._mem_cache_lock:
self._mem_cache.clear()
self._persist_storage.clear()
def get_stats(self) -> list[CacheStat]:
"""Returns a list of stats in bytes for the cache memory storage per item"""
stats = []
with self._mem_cache_lock:
for item in self._mem_cache.values():
stats.append(
CacheStat(
category_name="st_cache_data",
cache_name=self.function_display_name,
byte_length=len(item),
)
)
return stats
def close(self) -> None:
"""Closes the cache storage"""
self._persist_storage.close()
def _read_from_mem_cache(self, key: str) -> bytes:
with self._mem_cache_lock:
if key in self._mem_cache:
entry = bytes(self._mem_cache[key])
_LOGGER.debug("Memory cache HIT: %s", key)
return entry
else:
_LOGGER.debug("Memory cache MISS: %s", key)
raise CacheStorageKeyNotFoundError("Key not found in mem cache")
def _write_to_mem_cache(self, key: str, entry_bytes: bytes) -> None:
with self._mem_cache_lock:
self._mem_cache[key] = entry_bytes
def _remove_from_mem_cache(self, key: str) -> None:
with self._mem_cache_lock:
self._mem_cache.pop(key, None)
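# A sketch of wiring the wrapper to a persistent storage (`MyDiskStorage` is
# a hypothetical CacheStorage implementation; the context fields mirror the
# CacheStorageContext dataclass):
#
#     context = CacheStorageContext(
#         function_key="abc123",
#         function_display_name="my_func",
#         ttl_seconds=60.0,
#         max_entries=100,
#     )
#     storage = InMemoryCacheStorageWrapper(
#         persist_storage=MyDiskStorage(), context=context
#     )
#     storage.set("key", b"value")  # written to memory and to MyDiskStorage
#     storage.get("key")            # served from the in-memory TTL/LRU layer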

View File

@@ -0,0 +1,222 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Declares the LocalDiskCacheStorageManager class, which is used
to create LocalDiskCacheStorage instances wrapped by InMemoryCacheStorageWrapper,
InMemoryCacheStorageWrapper wrapper allows to have first layer of in-memory cache,
before accessing to LocalDiskCacheStorage itself.
Declares the LocalDiskCacheStorage class, which is used to store cached
values on disk.
How these classes work together
-------------------------------
- LocalDiskCacheStorageManager : each instance of this is able
to create LocalDiskCacheStorage instances wrapped in InMemoryCacheStorageWrapper,
and to clear data from the cache storage folder. It is also the
LocalDiskCacheStorageManager's responsibility to check if the context is valid for
the storage, and to log a warning if the context is not valid.
- LocalDiskCacheStorage : each instance of this is able to get, set, delete, and clear
entries from disk for a single `@st.cache_data` decorated function if `persist="disk"`
is used in CacheStorageContext.
┌───────────────────────────────┐
│ LocalDiskCacheStorageManager │
│ │
│ - clear_all │
│ - check_context │
│ │
└──┬────────────────────────────┘
│ ┌──────────────────────────────┐
│ │ │
│ create(context)│ InMemoryCacheStorageWrapper │
└────────────────► │
│ ┌─────────────────────┐ │
│ │ │ │
│ │ LocalDiskStorage │ │
│ │ │ │
│ └─────────────────────┘ │
│ │
└──────────────────────────────┘
"""
from __future__ import annotations
import math
import os
import shutil
from streamlit import util
from streamlit.file_util import get_streamlit_file_path, streamlit_read, streamlit_write
from streamlit.logger import get_logger
from streamlit.runtime.caching.storage.cache_storage_protocol import (
CacheStorage,
CacheStorageContext,
CacheStorageError,
CacheStorageKeyNotFoundError,
CacheStorageManager,
)
from streamlit.runtime.caching.storage.in_memory_cache_storage_wrapper import (
InMemoryCacheStorageWrapper,
)
# Streamlit directory where persisted @st.cache_data objects live.
# (This is the same directory that @st.cache persisted objects live.
# But @st.cache_data uses a different extension, so they don't overlap.)
_CACHE_DIR_NAME = "cache"
# The extension for our persisted @st.cache_data objects.
# (`@st.cache_data` was originally called `@st.memo`)
_CACHED_FILE_EXTENSION = "memo"
_LOGGER = get_logger(__name__)
class LocalDiskCacheStorageManager(CacheStorageManager):
def create(self, context: CacheStorageContext) -> CacheStorage:
"""Creates a new cache storage instance wrapped with in-memory cache layer"""
persist_storage = LocalDiskCacheStorage(context)
return InMemoryCacheStorageWrapper(
persist_storage=persist_storage, context=context
)
def clear_all(self) -> None:
cache_path = get_cache_folder_path()
if os.path.isdir(cache_path):
shutil.rmtree(cache_path)
def check_context(self, context: CacheStorageContext) -> None:
if (
context.persist == "disk"
and context.ttl_seconds is not None
and not math.isinf(context.ttl_seconds)
):
_LOGGER.warning(
f"The cached function '{context.function_display_name}' has a TTL "
"that will be ignored. Persistent cached functions currently don't "
"support TTL."
)
class LocalDiskCacheStorage(CacheStorage):
"""Cache storage that persists data to disk
This is the default cache persistence layer for `@st.cache_data`
"""
def __init__(self, context: CacheStorageContext):
self.function_key = context.function_key
self.persist = context.persist
self._ttl_seconds = context.ttl_seconds
self._max_entries = context.max_entries
@property
def ttl_seconds(self) -> float:
return self._ttl_seconds if self._ttl_seconds is not None else math.inf
@property
def max_entries(self) -> float:
return float(self._max_entries) if self._max_entries is not None else math.inf
def get(self, key: str) -> bytes:
"""
Returns the stored value for the key if it was persisted; raises
CacheStorageKeyNotFoundError if the key is not found or the storage is not
configured with persist="disk".
"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
with streamlit_read(path, binary=True) as input:
value = input.read()
_LOGGER.debug("Disk cache HIT: %s", key)
return bytes(value)
except FileNotFoundError:
raise CacheStorageKeyNotFoundError("Key not found in disk cache")
except Exception as ex:
_LOGGER.error(ex)
raise CacheStorageError("Unable to read from cache") from ex
else:
raise CacheStorageKeyNotFoundError(
f"Local disk cache storage is disabled (persist={self.persist})"
)
def set(self, key: str, value: bytes) -> None:
"""Sets the value for a given key"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
with streamlit_write(path, binary=True) as output:
output.write(value)
except util.Error as e:
_LOGGER.debug(e)
# Clean up file so we don't leave zero byte files.
try:
os.remove(path)
except (FileNotFoundError, IOError, OSError):
# If we can't remove the file, it's not a big deal.
pass
raise CacheStorageError("Unable to write to cache") from e
def delete(self, key: str) -> None:
"""Delete a cache file from disk. If the file does not exist on disk,
return silently. If another exception occurs, log it. Does not throw.
"""
if self.persist == "disk":
path = self._get_cache_file_path(key)
try:
os.remove(path)
except FileNotFoundError:
# The file is already removed.
pass
except Exception as ex:
_LOGGER.exception(
"Unable to remove a file from the disk cache", exc_info=ex
)
def clear(self) -> None:
"""Delete all keys for the current storage"""
cache_dir = get_cache_folder_path()
if os.path.isdir(cache_dir):
# We try to remove all files in the cache directory that start with
# the function key, whether `clear` was called for a persistent
# storage or not, to avoid leaving orphaned files in the cache directory.
for file_name in os.listdir(cache_dir):
if self._is_cache_file(file_name):
os.remove(os.path.join(cache_dir, file_name))
def close(self) -> None:
"""Dummy implementation of close, we don't need to actually "close" anything"""
def _get_cache_file_path(self, value_key: str) -> str:
"""Return the path of the disk cache file for the given value."""
cache_dir = get_cache_folder_path()
return os.path.join(
cache_dir, f"{self.function_key}-{value_key}.{_CACHED_FILE_EXTENSION}"
)
def _is_cache_file(self, fname: str) -> bool:
"""Return true if the given file name is a cache file for this storage."""
return fname.startswith(f"{self.function_key}-") and fname.endswith(
f".{_CACHED_FILE_EXTENSION}"
)
def get_cache_folder_path() -> str:
return get_streamlit_file_path(_CACHE_DIR_NAME)
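# For example, with function_key "f1a2" and value_key "9b8c", the persisted
# entry for a `persist="disk"` function ends up at (a sketch; the cache
# folder lives inside the per-user .streamlit directory resolved by
# get_streamlit_file_path):
#
#     <streamlit_dir>/cache/f1a2-9b8c.memo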