# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common cache logic shared by st.cache_data and st.cache_resource."""

from __future__ import annotations

import functools
import hashlib
import inspect
import math
import threading
import time
import types
from abc import abstractmethod
from collections import defaultdict
from datetime import timedelta
from typing import Any, Callable, overload

from typing_extensions import Literal

from streamlit import type_util
from streamlit.elements.spinner import spinner
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import (
    CacheError,
    CacheKeyNotFoundError,
    UnevaluatedDataFrameError,
    UnhashableParamError,
    UnhashableTypeError,
    UnserializableReturnValueError,
    get_cached_func_name_md,
)
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cached_message_replay import (
    CachedMessageReplayContext,
    CachedResult,
    MsgData,
    replay_cached_messages,
)
from streamlit.runtime.caching.hashing import update_hash

_LOGGER = get_logger(__name__)

# The timer function we use with TTLCache. This is the default timer func, but
# is exposed here as a constant so that it can be patched in unit tests.
TTLCACHE_TIMER = time.monotonic


@overload
def ttl_to_seconds(
    ttl: float | timedelta | None, *, coerce_none_to_inf: Literal[False]
) -> float | None:
    ...


@overload
def ttl_to_seconds(ttl: float | timedelta | None) -> float:
    ...


def ttl_to_seconds(
    ttl: float | timedelta | None, *, coerce_none_to_inf: bool = True
) -> float | None:
    """Convert a ttl value to a float representing "number of seconds"."""
    if coerce_none_to_inf and ttl is None:
        return math.inf
    if isinstance(ttl, timedelta):
        return ttl.total_seconds()
    return ttl
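
# Illustrative usage of `ttl_to_seconds` (a hedged sketch; the values simply
# follow from the conversion rules implemented above):
#
#   >>> ttl_to_seconds(90.5)
#   90.5
#   >>> ttl_to_seconds(timedelta(minutes=2))
#   120.0
#   >>> ttl_to_seconds(None)
#   inf
#   >>> ttl_to_seconds(None, coerce_none_to_inf=False) is None
#   True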

# We show a special "UnevaluatedDataFrame" warning for cached funcs
# that attempt to return one of these unserializable types:
UNEVALUATED_DATAFRAME_TYPES = (
    "snowflake.snowpark.table.Table",
    "snowflake.snowpark.dataframe.DataFrame",
    "pyspark.sql.dataframe.DataFrame",
)


class Cache:
    """Function cache interface. Caches persist across script runs."""

    def __init__(self):
        self._value_locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
        self._value_locks_lock = threading.Lock()

    @abstractmethod
    def read_result(self, value_key: str) -> CachedResult:
        """Read a value and associated messages from the cache.

        Raises
        ------
        CacheKeyNotFoundError
            Raised if value_key is not in the cache.
        """
        raise NotImplementedError

    @abstractmethod
    def write_result(self, value_key: str, value: Any, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache, overwriting any
        existing result that uses the value_key.
        """
        # We *could* `del self._value_locks[value_key]` here, since nobody will
        # be taking a compute_value_lock for this value_key after the result is
        # written.
        raise NotImplementedError

    def compute_value_lock(self, value_key: str) -> threading.Lock:
        """Return the lock that should be held while computing a new cached value.

        In a popular app with a cache that hasn't been pre-warmed, many sessions
        may try to access a not-yet-cached value simultaneously. We use a lock
        to ensure that only one of those sessions computes the value, and the
        others block until the value is computed.
        """
        with self._value_locks_lock:
            return self._value_locks[value_key]

    def clear(self):
        """Clear all values from this cache."""
        with self._value_locks_lock:
            self._value_locks.clear()
        self._clear()

    @abstractmethod
    def _clear(self) -> None:
        """Subclasses must implement this to perform cache-clearing logic."""
        raise NotImplementedError
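
# A minimal sketch of a concrete `Cache` subclass, for orientation only. The
# real implementations back st.cache_data and st.cache_resource; the class
# name, the storage dict, and the CachedResult field layout used below are
# assumptions, not this module's API:
#
#   class _InMemoryCache(Cache):
#       def __init__(self):
#           super().__init__()
#           self._results: dict[str, CachedResult] = {}
#
#       def read_result(self, value_key: str) -> CachedResult:
#           try:
#               return self._results[value_key]
#           except KeyError:
#               raise CacheKeyNotFoundError(value_key)
#
#       def write_result(self, value_key, value, messages) -> None:
#           self._results[value_key] = CachedResult(
#               value=value, messages=messages, main_id="", sidebar_id=""
#           )
#
#       def _clear(self) -> None:
#           self._results.clear()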

class CachedFuncInfo:
    """Encapsulates data for a cached function instance.

    CachedFuncInfo instances are scoped to a single script run - they're not
    persistent.
    """

    def __init__(
        self,
        func: types.FunctionType,
        show_spinner: bool | str,
        allow_widgets: bool,
    ):
        self.func = func
        self.show_spinner = show_spinner
        self.allow_widgets = allow_widgets

    @property
    def cache_type(self) -> CacheType:
        raise NotImplementedError

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        raise NotImplementedError

    def get_function_cache(self, function_key: str) -> Cache:
        """Get or create the function cache for the given key."""
        raise NotImplementedError


def make_cached_func_wrapper(info: CachedFuncInfo) -> Callable[..., Any]:
    """Create a callable wrapper around a CachedFuncInfo.

    Calling the wrapper will return the cached value if it's already been
    computed, and will call the underlying function to compute and cache the
    value otherwise.

    The wrapper also has a `clear` function that can be called to clear all of
    the wrapper's cached values.
    """
    cached_func = CachedFunc(info)

    # We'd like to simply return `cached_func`, which is already a Callable.
    # But using `functools.update_wrapper` on the CachedFunc instance itself
    # results in errors when our caching decorators are used to decorate
    # member functions. (See https://github.com/streamlit/streamlit/issues/6109)

    @functools.wraps(info.func)
    def wrapper(*args, **kwargs):
        return cached_func(*args, **kwargs)

    # Give our wrapper its `clear` function.
    # (This results in a spurious mypy error that we suppress.)
    wrapper.clear = cached_func.clear  # type: ignore

    return wrapper


class CachedFunc:
    def __init__(self, info: CachedFuncInfo):
        self._info = info
        self._function_key = _make_function_key(info.cache_type, info.func)

    def __call__(self, *args, **kwargs) -> Any:
        """The wrapper. We'll only call our underlying function on a cache miss."""
        name = self._info.func.__qualname__

        if isinstance(self._info.show_spinner, bool):
            if len(args) == 0 and len(kwargs) == 0:
                message = f"Running `{name}()`."
            else:
                message = f"Running `{name}(...)`."
        else:
            message = self._info.show_spinner

        if self._info.show_spinner or isinstance(self._info.show_spinner, str):
            with spinner(message):
                return self._get_or_create_cached_value(args, kwargs)
        else:
            return self._get_or_create_cached_value(args, kwargs)

    def _get_or_create_cached_value(
        self, func_args: tuple[Any, ...], func_kwargs: dict[str, Any]
    ) -> Any:
        # Retrieve the function's cache object. We must do this "just-in-time"
        # (as opposed to in the constructor), because caches can be invalidated
        # at any time.
        cache = self._info.get_function_cache(self._function_key)

        # Generate the key for the cached value. This is based on the
        # arguments passed to the function.
        value_key = _make_value_key(
            cache_type=self._info.cache_type,
            func=self._info.func,
            func_args=func_args,
            func_kwargs=func_kwargs,
        )

        try:
            cached_result = cache.read_result(value_key)
            return self._handle_cache_hit(cached_result)
        except CacheKeyNotFoundError:
            return self._handle_cache_miss(cache, value_key, func_args, func_kwargs)

    def _handle_cache_hit(self, result: CachedResult) -> Any:
        """Handle a cache hit: replay the result's cached messages, and return
        its value.
        """
        replay_cached_messages(
            result,
            self._info.cache_type,
            self._info.func,
        )
        return result.value

    def _handle_cache_miss(
        self,
        cache: Cache,
        value_key: str,
        func_args: tuple[Any, ...],
        func_kwargs: dict[str, Any],
    ) -> Any:
        """Handle a cache miss: compute a new cached value, write it back to
        the cache, and return that newly-computed value.
        """
        # Implementation notes:
        # - We take a "compute_value_lock" before computing our value. This
        #   ensures that multiple sessions don't try to compute the same value
        #   simultaneously.
        #
        # - We use a different lock for each value_key, as opposed to a single
        #   lock for the entire cache, so that unrelated value computations
        #   don't block on each other.
        #
        # - When retrieving a cache entry that may not yet exist, we use a
        #   "double-checked locking" strategy: first we try to retrieve the
        #   cache entry without taking a value lock. (This happens in
        #   `_get_or_create_cached_value()`.) If that fails because the value
        #   hasn't been computed yet, we take the value lock and then
        #   immediately try to retrieve the cache entry *again*, while holding
        #   the lock. If the cache entry exists at this point, it means that
        #   another thread computed the value before us.
        #
        #   This means that the happy path ("cache entry exists") is a wee bit
        #   faster because no lock is acquired. But the unhappy path ("cache
        #   entry needs to be recomputed") is a wee bit slower, because we do
        #   two lookups for the entry.

        with cache.compute_value_lock(value_key):
            # We've acquired the lock - but another thread may have acquired it
            # first and already computed the value. So we need to test for a
            # cache hit again, before computing.
            try:
                cached_result = cache.read_result(value_key)
                # Another thread computed the value before us. Early exit!
                return self._handle_cache_hit(cached_result)
            except CacheKeyNotFoundError:
                # We acquired the lock before any other thread. Compute the value!
                with self._info.cached_message_replay_ctx.calling_cached_function(
                    self._info.func, self._info.allow_widgets
                ):
                    computed_value = self._info.func(*func_args, **func_kwargs)

                # We've computed our value, and now we need to write it back to
                # the cache along with any "replay messages" that were generated
                # during value computation.
                messages = self._info.cached_message_replay_ctx._most_recent_messages
                try:
                    cache.write_result(value_key, computed_value, messages)
                    return computed_value
                except (CacheError, RuntimeError):
                    # An exception was thrown while we tried to write to the
                    # cache. Report it to the user. (We catch `RuntimeError`
                    # here because it will be raised by Apache Spark if we do
                    # not collect the dataframe before using `st.cache_data`.)
                    if any(
                        type_util.is_type(computed_value, type_name)
                        for type_name in UNEVALUATED_DATAFRAME_TYPES
                    ):
                        raise UnevaluatedDataFrameError(
                            f"""
                            The function {get_cached_func_name_md(self._info.func)} is decorated with `st.cache_data`
                            but it returns an unevaluated dataframe of type
                            `{type_util.get_fqn_type(computed_value)}`. Please call `collect()` or `to_pandas()` on
                            the dataframe before returning it, so `st.cache_data` can serialize and cache it."""
                        )
                    raise UnserializableReturnValueError(
                        return_value=computed_value, func=self._info.func
                    )

    def clear(self):
        """Clear the wrapped function's associated cache."""
        cache = self._info.get_function_cache(self._function_key)
        cache.clear()
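
# How these pieces compose at runtime (a hedged sketch; `MyFuncInfo` stands in
# for the concrete CachedFuncInfo subclasses that st.cache_data and
# st.cache_resource actually construct, and `load_data` is hypothetical):
#
#   info = MyFuncInfo(func=load_data, show_spinner=True, allow_widgets=False)
#   load_data_cached = make_cached_func_wrapper(info)
#
#   load_data_cached("a.csv")   # cache miss: runs load_data, caches the result
#   load_data_cached("a.csv")   # cache hit: replays messages, returns cached value
#   load_data_cached.clear()    # drops every cached value for this function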

def _make_value_key(
    cache_type: CacheType,
    func: types.FunctionType,
    func_args: tuple[Any, ...],
    func_kwargs: dict[str, Any],
) -> str:
    """Create the key for a value within a cache.

    This key is generated from the function's arguments. All arguments will be
    hashed, except for those named with a leading "_".

    Raises
    ------
    StreamlitAPIException
        Raised (with a nicely-formatted explanation message) if we encounter
        an un-hashable arg.
    """

    # Create a (name, value) list of all *args and **kwargs passed to the
    # function.
    arg_pairs: list[tuple[str | None, Any]] = []
    for arg_idx in range(len(func_args)):
        arg_name = _get_positional_arg_name(func, arg_idx)
        arg_pairs.append((arg_name, func_args[arg_idx]))

    for kw_name, kw_val in func_kwargs.items():
        # **kwargs ordering is preserved, per PEP 468
        # (https://www.python.org/dev/peps/pep-0468/), so this iteration is
        # deterministic.
        arg_pairs.append((kw_name, kw_val))

    # Create the hash from each arg value, except for those args whose name
    # starts with "_". (Underscore-prefixed args are deliberately excluded from
    # hashing.)
    args_hasher = hashlib.new("md5")
    for arg_name, arg_value in arg_pairs:
        if arg_name is not None and arg_name.startswith("_"):
            _LOGGER.debug("Not hashing %s because it starts with _", arg_name)
            continue

        try:
            update_hash(
                (arg_name, arg_value),
                hasher=args_hasher,
                cache_type=cache_type,
            )
        except UnhashableTypeError as exc:
            raise UnhashableParamError(cache_type, func, arg_name, arg_value, exc)

    value_key = args_hasher.hexdigest()
    _LOGGER.debug("Cache key: %s", value_key)

    return value_key


def _make_function_key(cache_type: CacheType, func: types.FunctionType) -> str:
    """Create the unique key for a function's cache.

    A function's key is stable across reruns of the app, and changes when the
    function's source code changes.
    """
    func_hasher = hashlib.new("md5")

    # Include the function's __module__ and __qualname__ strings in the hash.
    # This means that two identical functions in different modules
    # will not share a hash; it also means that two identical *nested*
    # functions in the same module will not share a hash.
    update_hash(
        (func.__module__, func.__qualname__),
        hasher=func_hasher,
        cache_type=cache_type,
    )

    # Include the function's source code in its hash. If the source code can't
    # be retrieved, fall back to the function's bytecode instead.
    source_code: str | bytes
    try:
        source_code = inspect.getsource(func)
    except OSError as e:
        _LOGGER.debug(
            "Failed to retrieve function's source code when building its key; "
            "falling back to bytecode. err=%s",
            e,
        )
        source_code = func.__code__.co_code

    update_hash(
        source_code,
        hasher=func_hasher,
        cache_type=cache_type,
    )

    cache_key = func_hasher.hexdigest()
    return cache_key
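
# Practical consequences of the two-level keying above, as a hedged sketch
# (the example function and its names are hypothetical):
#
#   def lookup(query, _conn):   # cached via st.cache_data or st.cache_resource
#       ...
#
#   - `_conn` is never hashed (leading underscore), so lookup(q, conn_a) and
#     lookup(q, conn_b) produce the same value_key and share one cache entry.
#   - Editing lookup's source code changes its function key, so all of its
#     previously cached values become unreachable.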
""" if arg_index < 0: return None params: list[inspect.Parameter] = list(inspect.signature(func).parameters.values()) if arg_index >= len(params): return None if params[arg_index].kind in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY, ): return params[arg_index].name return None