# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2026)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""@st.cache_data: pickle-based caching."""

from __future__ import annotations

import pickle
import threading
from typing import (
    TYPE_CHECKING,
    Any,
    Final,
    Literal,
    TypeAlias,
    TypeVar,
    cast,
    overload,
)

from typing_extensions import ParamSpec

import streamlit as st
from streamlit import runtime
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.caching.cache_errors import CacheError, CacheKeyNotFoundError
from streamlit.runtime.caching.cache_type import CacheType
from streamlit.runtime.caching.cache_utils import (
    Cache,
    CachedFunc,
    CachedFuncInfo,
    CacheScope,
    get_session_id_or_throw,
    make_cached_func_wrapper,
)
from streamlit.runtime.caching.cached_message_replay import (
    CachedMessageReplayContext,
    CachedResult,
    MsgData,
)
from streamlit.runtime.caching.storage import (
    CacheStorage,
    CacheStorageContext,
    CacheStorageError,
    CacheStorageKeyNotFoundError,
    CacheStorageManager,
)
from streamlit.runtime.caching.storage.cache_storage_protocol import (
    InvalidCacheStorageContextError,
)
from streamlit.runtime.caching.storage.dummy_cache_storage import (
    MemoryCacheStorageManager,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.runtime.stats import (
    CACHE_MEMORY_FAMILY,
    CacheStat,
    StatsProvider,
    group_cache_stats,
)
from streamlit.time_util import time_to_seconds

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from datetime import timedelta

    from streamlit.runtime.caching.hashing import HashFuncsDict

# Module-level logger, named after this module.
_LOGGER: Final = get_logger(__name__)

# Message-replay context for the DATA cache type. This single instance is what
# CachedDataFuncInfo.cached_message_replay_ctx returns for every cached function.
CACHE_DATA_MESSAGE_REPLAY_CTX = CachedMessageReplayContext(CacheType.DATA)

# The cache persistence options we support: "disk" or None
CachePersistType: TypeAlias = Literal["disk"] | None


# Generic placeholders for a cached function's parameters (P) and return
# type (R), as used in ``Callable[P, R]`` throughout this module.
P = ParamSpec("P")
R = TypeVar("R")


class CachedDataFuncInfo(CachedFuncInfo[P, R]):
    """CachedFuncInfo implementation that backs @st.cache_data."""

    persist: CachePersistType
    max_entries: int | None
    ttl: float | timedelta | str | None

    def __init__(
        self,
        func: Callable[P, R],
        persist: CachePersistType,
        max_entries: int | None,
        ttl: float | timedelta | str | None,
        show_spinner: bool | str,
        show_time: bool = False,
        hash_funcs: HashFuncsDict | None = None,
        scope: CacheScope = "global",
    ) -> None:
        """Record the cache settings for ``func`` and validate them.

        The spinner, hashing and scope options are forwarded to the base class;
        ``persist``, ``max_entries`` and ``ttl`` are stored on this instance.
        """
        super().__init__(
            func,
            scope=scope,
            show_time=show_time,
            show_spinner=show_spinner,
            hash_funcs=hash_funcs,
        )
        self.persist, self.max_entries, self.ttl = persist, max_entries, ttl

        # Check the storage-related params as soon as the info object exists.
        self.validate_params()

    @property
    def cache_type(self) -> CacheType:
        """The cache type of this info object: always ``CacheType.DATA``."""
        return CacheType.DATA

    @property
    def cached_message_replay_ctx(self) -> CachedMessageReplayContext:
        """The module-wide replay context shared by all data caches."""
        return CACHE_DATA_MESSAGE_REPLAY_CTX

    @property
    def display_name(self) -> str:
        """Human-readable ``<module>.<qualname>`` label of the cached function."""
        target = self.func
        return f"{target.__module__}.{target.__qualname__}"

    def get_function_cache(self, function_key: str) -> Cache[R]:
        """Fetch (or lazily create) the DataCache registered under ``function_key``."""
        return _data_caches.get_cache(
            key=function_key,
            display_name=self.display_name,
            scope=self.scope,
            ttl=self.ttl,
            max_entries=self.max_entries,
            persist=self.persist,
        )

    def validate_params(self) -> None:
        """
        Check that the @st.cache_data params are usable with the cache storage.

        Delegates to ``DataCaches.validate_cache_params``, identifying the
        function by its plain ``__name__``.
        """
        function_name = self.func.__name__
        _data_caches.validate_cache_params(
            function_name=function_name,
            ttl=self.ttl,
            max_entries=self.max_entries,
            persist=self.persist,
        )


class DataCaches(StatsProvider):
    """Manages all DataCache instances.

    Caches are grouped by session ID and then by function key. Globally-scoped
    caches live under the ``None`` session ID.
    """

    def __init__(self) -> None:
        # Guards every access to ``_function_caches``.
        self._caches_lock = threading.Lock()
        # Map of session IDs to map of function keys to caches.
        # The ``None`` session ID holds the globally-scoped caches.
        self._function_caches: dict[str | None, dict[str, DataCache[Any]]] = {}

    @property
    def stats_families(self) -> Sequence[str]:
        """The stat families this provider reports: only the cache-memory family."""
        return (CACHE_MEMORY_FAMILY,)

    def get_cache(
        self,
        key: str,
        persist: CachePersistType,
        max_entries: int | None,
        ttl: int | float | timedelta | str | None,
        display_name: str,
        scope: CacheScope = "global",
    ) -> DataCache[Any]:
        """Return the mem cache for the given key.

        If it doesn't exist, create a new one with the given params. If a cache
        exists under ``key`` but was created with a different ``ttl``,
        ``max_entries`` or ``persist``, its storage is closed and a fresh cache
        replaces it.

        Parameters
        ----------
        key
            The function key identifying the cache.
        persist
            Persistence option for the cache storage.
        max_entries
            Maximum number of entries, or None for unbounded.
        ttl
            Entry time-to-live; converted to seconds with ``time_to_seconds``.
        display_name
            Human-readable name of the cached function.
        scope
            ``"global"`` to use the shared cache map; any other value looks up
            the cache map of the current session.

        Raises
        ------
        StreamlitAPIException
            Raised when ``scope`` is ``"session"`` and there is no thread-local run
            context.
        """

        ttl_seconds = time_to_seconds(ttl, coerce_none_to_inf=False)

        # Fetch the session ID. Note that this will throw an exception if there is no
        # session associated with the current thread.
        session_id: str | None = (
            None if scope == "global" else get_session_id_or_throw()
        )

        with self._caches_lock:
            session_caches = self._function_caches.setdefault(session_id, {})

            cache = session_caches.get(key)
            if cache is not None:
                # Reuse the existing cache only if its params haven't changed.
                if (
                    cache.ttl_seconds == ttl_seconds
                    and cache.max_entries == max_entries
                    and cache.persist == persist
                ):
                    return cache

                # Log the params of the cache being closed (not the new ones,
                # which are logged just below).
                _LOGGER.debug(
                    "Closing existing DataCache storage "
                    "(key=%s, persist=%s, max_entries=%s, ttl=%s) "
                    "before creating new one with different params",
                    key,
                    cache.persist,
                    cache.max_entries,
                    cache.ttl_seconds,
                )
                # Unregister the stale cache before closing it, so that a
                # failure while building its replacement can never leave a
                # cache with closed storage reachable through this map.
                del session_caches[key]
                cache.storage.close()

            # Create a new cache object and put it in our dict
            _LOGGER.debug(
                "Creating new DataCache (key=%s, persist=%s, max_entries=%s, ttl=%s)",
                key,
                persist,
                max_entries,
                ttl,
            )

            cache_context = self.create_cache_storage_context(
                function_key=key,
                function_name=display_name,
                ttl_seconds=ttl_seconds,
                max_entries=max_entries,
                persist=persist,
            )
            storage = self.get_storage_manager().create(cache_context)

            cache = DataCache(
                key=key,
                storage=storage,
                persist=persist,
                max_entries=max_entries,
                ttl_seconds=ttl_seconds,
                display_name=display_name,
            )
            session_caches[key] = cache
            return cache

    def clear_session(self, session_id: str) -> None:
        """Clears all caches for the given session ID.

        Does nothing if no caches are registered for ``session_id``.
        """
        # Hold the lock while removing the cache, but release it while clearing.
        with self._caches_lock:
            session_caches = self._function_caches.pop(session_id, None)

        if session_caches is None:
            return

        for cache in session_caches.values():
            cache.clear()
            cache.storage.close()

    def clear_all(self) -> None:
        """Clear all in-memory and on-disk caches."""
        with self._caches_lock:
            try:
                # try to remove in optimal way if such ability provided by
                # storage manager clear_all method;
                # if not implemented, fallback to remove all
                # available storages one by one
                self.get_storage_manager().clear_all()
            except NotImplementedError:
                for data_caches in self._function_caches.values():
                    for data_cache in data_caches.values():
                        data_cache.clear()
                        data_cache.storage.close()
            # Forget every cache (global and per-session) in either case.
            self._function_caches = {}

    def get_stats(
        self, _family_names: Sequence[str] | None = None
    ) -> dict[str, list[CacheStat]]:
        """Collect the stats of every registered cache.

        Returns an empty dict when no cache reports any stats; otherwise a
        single-entry dict mapping ``CACHE_MEMORY_FAMILY`` to the grouped stats.
        ``_family_names`` is not consulted.
        """
        with self._caches_lock:
            # Shallow-clone our caches. We don't want to hold the global
            # lock during stats-gathering.
            function_caches = [
                cache
                for caches in self._function_caches.values()
                for cache in caches.values()
            ]

        stats: list[CacheStat] = []
        for cache in function_caches:
            for family_stats in cache.get_stats().values():
                stats.extend(family_stats)
        if not stats:
            return {}
        # In general, get_stats methods need to be able to return only requested stat
        # families, but this method only returns a single family, and we're guaranteed
        # that it was one of those requested if we make it here.
        return {CACHE_MEMORY_FAMILY: group_cache_stats(stats)}

    def validate_cache_params(
        self,
        function_name: str,
        persist: CachePersistType,
        max_entries: int | None,
        ttl: int | float | timedelta | str | None,
    ) -> None:
        """Validate that the cache params are valid for given storage.

        A throwaway storage context (with a placeholder function key) is built
        from the params and handed to the storage manager's ``check_context``.

        Raises
        ------
        InvalidCacheStorageContextError
            Raised if the cache storage manager is not able to work with provided
            CacheStorageContext. The error is logged before being re-raised.
        """

        ttl_seconds = time_to_seconds(ttl, coerce_none_to_inf=False)

        cache_context = self.create_cache_storage_context(
            function_key="DUMMY_KEY",
            function_name=function_name,
            ttl_seconds=ttl_seconds,
            max_entries=max_entries,
            persist=persist,
        )
        try:
            self.get_storage_manager().check_context(cache_context)
        except InvalidCacheStorageContextError:
            _LOGGER.exception(
                "Cache params for function %s are incompatible with current "
                "cache storage manager.",
                function_name,
            )
            raise

    def create_cache_storage_context(
        self,
        function_key: str,
        function_name: str,
        persist: CachePersistType,
        ttl_seconds: float | None,
        max_entries: int | None,
    ) -> CacheStorageContext:
        """Bundle the given cache params into a ``CacheStorageContext``.

        ``function_name`` is passed on as the context's ``function_display_name``.
        """
        return CacheStorageContext(
            function_key=function_key,
            function_display_name=function_name,
            ttl_seconds=ttl_seconds,
            max_entries=max_entries,
            persist=persist,
        )

    def get_storage_manager(self) -> CacheStorageManager:
        """Return the runtime's cache storage manager.

        When no runtime exists, a warning is logged and a new
        ``MemoryCacheStorageManager`` is returned instead.
        """
        if runtime.exists():
            return runtime.get_instance().cache_storage_manager
        # When running in "raw mode", we can't access the CacheStorageManager,
        # so we're falling back to InMemoryCache.
        _LOGGER.warning("No runtime found, using MemoryCacheStorageManager")
        return MemoryCacheStorageManager()


# Singleton DataCaches instance. CachedDataFuncInfo, CacheDataAPI.clear() and
# the module-level helper functions below all operate on this one registry.
_data_caches = DataCaches()


def clear_session_cache(session_id: str) -> None:
    """Clears all caches for the given session ID.

    Delegates to ``DataCaches.clear_session`` on the module-level singleton,
    which clears each of the session's caches and closes its storage.

    Parameters
    ----------
    session_id
        ID of the session whose caches should be dropped.
    """
    _data_caches.clear_session(session_id)


def get_data_cache_stats_provider() -> StatsProvider:
    """Return the StatsProvider for all @st.cache_data functions.

    This is the module-level ``DataCaches`` singleton itself.
    """
    return _data_caches


class CacheDataAPI:
    """Implements the public st.cache_data API: the @st.cache_data decorator, and
    st.cache_data.clear().
    """

    def __init__(self, decorator_metric_name: str) -> None:
        """Create a CacheDataAPI instance.

        Parameters
        ----------
        decorator_metric_name
            The metric name to record for decorator usage.
        """

        # Parameterize the decorator metric name: shadow the ``_decorator``
        # method on this instance with a metrics-gathering wrapper around it.
        # (Ignore spurious mypy complaints - https://github.com/python/mypy/issues/2427)
        self._decorator = gather_metrics(  # type: ignore
            decorator_metric_name, self._decorator
        )

    # Type-annotate the decorator function.
    # (See https://mypy.readthedocs.io/en/stable/generics.html#decorator-factories)

    # Bare decorator usage
    @overload
    def __call__(self, func: Callable[P, R]) -> CachedFunc[P, R]: ...

    # Decorator with arguments
    @overload
    def __call__(
        self,
        *,
        ttl: float | timedelta | str | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        show_time: bool = False,
        persist: CachePersistType | bool = None,
        hash_funcs: HashFuncsDict | None = None,
        scope: CacheScope = "global",
    ) -> Callable[[Callable[P, R]], CachedFunc[P, R]]: ...

    # Runtime implementation behind both overloads: forwards everything to the
    # metrics-wrapped ``_decorator`` installed in ``__init__``.
    def __call__(
        self,
        func: Callable[P, R] | None = None,
        *,
        ttl: float | timedelta | str | None = None,
        max_entries: int | None = None,
        show_spinner: bool | str = True,
        show_time: bool = False,
        persist: CachePersistType | bool = None,
        hash_funcs: HashFuncsDict | None = None,
        scope: CacheScope = "global",
    ) -> CachedFunc[P, R] | Callable[[Callable[P, R]], CachedFunc[P, R]]:
        return self._decorator(
            func,  # ty: ignore[invalid-argument-type]
            ttl=ttl,
            max_entries=max_entries,
            persist=persist,
            show_spinner=show_spinner,
            show_time=show_time,
            hash_funcs=hash_funcs,
            scope=scope,
        )

    def _decorator(
        self,
        func: Callable[P, R] | None = None,
        *,
        ttl: float | timedelta | str | None,
        max_entries: int | None,
        show_spinner: bool | str,
        show_time: bool = False,
        persist: CachePersistType | bool,
        hash_funcs: HashFuncsDict | None = None,
        scope: CacheScope = "global",
    ) -> CachedFunc[P, R] | Callable[[Callable[P, R]], CachedFunc[P, R]]:
        """Decorator to cache functions that return data (e.g. dataframe transforms, database queries, ML inference).

        Cached objects can be global or session-scoped. Global cached data is
        available across all users, sessions, and reruns. Session-scoped cached
        data is only available in the current session and removed when the
        session disconnects.

        Cached objects are stored in "pickled" form, which means that the return
        value of a cached function must be pickleable. Each caller of the cached
        function gets its own copy of the cached data.

        You can clear a function's cache with ``func.clear()`` or clear the entire
        cache with ``st.cache_data.clear()``.

        A function's arguments must be hashable to cache it. Streamlit makes a
        best effort to hash a variety of objects, but the fallback hashing method
        requires that the argument be pickleable, also. If you have an unhashable
        argument (like a database connection) or an argument you want to exclude
        from caching, use an underscore prefix in the argument name. In this case,
        Streamlit will return a cached value when all other arguments match a
        previous function call. Alternatively, you can declare custom hashing
        functions with ``hash_funcs``.

        Objects cached by ``st.cache_data`` are returned as copies. To cache a
        shared resource or something you want to mutate in place, use
        ``st.cache_resource`` instead. To learn more about caching, see
        `Caching overview
        <https://docs.streamlit.io/develop/concepts/architecture/caching>`_.

        .. note::
            Caching async functions is not supported. To upvote this feature,
            see GitHub issue `#8308
            <https://github.com/streamlit/streamlit/issues/8308>`_.

        Parameters
        ----------
        func : callable
            The function to cache. Streamlit hashes the function's source code.

        ttl : float, timedelta, str, or None
            The maximum time to keep an entry in the cache. Can be one of:

            - ``None`` if cache entries should never expire (default).
            - A number specifying the time in seconds.
            - A string specifying the time in a format supported by `Pandas's
              Timedelta constructor <https://pandas.pydata.org/docs/reference/api/pandas.Timedelta.html>`_,
              e.g. ``"1d"``, ``"1.5 days"``, or ``"1h23s"``.
            - A ``timedelta`` object from `Python's built-in datetime library
              <https://docs.python.org/3/library/datetime.html#timedelta-objects>`_,
              e.g. ``timedelta(days=1)``.

            Note that ``ttl`` will be ignored if ``persist="disk"`` or ``persist=True``.

        max_entries : int or None
            The maximum number of entries to keep in the cache, or None
            for an unbounded cache. When a new entry is added to a full cache,
            the oldest cached entry will be removed. Defaults to None.

        show_spinner : bool or str
            Enable the spinner. Default is True to show a spinner when there is
            a "cache miss" and the cached data is being created. If string,
            value of show_spinner param will be used for spinner text.

        show_time : bool
            Whether to show the elapsed time next to the spinner text. If this is
            ``False`` (default), no time is displayed. If this is ``True``,
            elapsed time is displayed with a precision of 0.1 seconds. The time
            format is not configurable.

        persist : "disk", bool, or None
            Optional location to persist cached data to. Passing "disk" (or True)
            will persist the cached data to the local disk. None (or False) will disable
            persistence. The default is None.

        hash_funcs : dict or None
            Mapping of types or fully qualified names to hash functions.
            This is used to override the behavior of the hasher inside Streamlit's
            caching mechanism: when the hasher encounters an object, it will first
            check to see if its type matches a key in this dict and, if so, will use
            the provided function to generate a hash for it. See below for an example
            of how this can be used.

        scope : "global" or "session"
            The scope for the data cache. If this is ``"global"`` (default),
            the data is cached globally. If this is ``"session"``, the data is
            removed from the cache when the session disconnects.

            Because a session-scoped cache is cleared when a session disconnects,
            an unstable network connection can cause the cache to populate
            multiple times in a single session. If this is a problem, you might
            consider adjusting the ``server.websocketPingInterval``
            configuration option.

        Example
        -------
        >>> import streamlit as st
        >>>
        >>> @st.cache_data
        ... def fetch_and_clean_data(url):
        ...     # Fetch data from URL here, and then clean it up.
        ...     return data
        >>>
        >>> d1 = fetch_and_clean_data(DATA_URL_1)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> d2 = fetch_and_clean_data(DATA_URL_1)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value. This means that now the data in d1 is the same as in d2.
        >>>
        >>> d3 = fetch_and_clean_data(DATA_URL_2)
        >>> # This is a different URL, so the function executes.

        To set the ``persist`` parameter, use this command as follows:

        >>> import streamlit as st
        >>>
        >>> @st.cache_data(persist="disk")
        ... def fetch_and_clean_data(url):
        ...     # Fetch data from URL here, and then clean it up.
        ...     return data

        By default, all parameters to a cached function must be hashable.
        Any parameter whose name begins with ``_`` will not be hashed. You can use
        this as an "escape hatch" for parameters that are not hashable:

        >>> import streamlit as st
        >>>
        >>> @st.cache_data
        ... def fetch_and_clean_data(_db_connection, num_rows):
        ...     # Fetch data from _db_connection here, and then clean it up.
        ...     return data
        >>>
        >>> connection = make_database_connection()
        >>> d1 = fetch_and_clean_data(connection, num_rows=10)
        >>> # Actually executes the function, since this is the first time it was
        >>> # encountered.
        >>>
        >>> another_connection = make_database_connection()
        >>> d2 = fetch_and_clean_data(another_connection, num_rows=10)
        >>> # Does not execute the function. Instead, returns its previously computed
        >>> # value - even though the _db_connection parameter was different
        >>> # in both calls.

        A cached function's cache can be procedurally cleared:

        >>> import streamlit as st
        >>>
        >>> @st.cache_data
        ... def fetch_and_clean_data(_db_connection, num_rows):
        ...     # Fetch data from _db_connection here, and then clean it up.
        ...     return data
        >>>
        >>> fetch_and_clean_data.clear(_db_connection, 50)
        >>> # Clear the cached entry for the arguments provided.
        >>>
        >>> fetch_and_clean_data.clear()
        >>> # Clear all cached entries for this function.

        To override the default hashing behavior, pass a custom hash function.
        You can do that by mapping a type (e.g. ``datetime.datetime``) to a hash
        function (``lambda dt: dt.isoformat()``) like this:

        >>> import streamlit as st
        >>> import datetime
        >>>
        >>> @st.cache_data(hash_funcs={datetime.datetime: lambda dt: dt.isoformat()})
        ... def convert_to_utc(dt: datetime.datetime):
        ...     return dt.astimezone(datetime.timezone.utc)

        Alternatively, you can map the type's fully-qualified name
        (e.g. ``"datetime.datetime"``) to the hash function instead:

        >>> import streamlit as st
        >>> import datetime
        >>>
        >>> @st.cache_data(hash_funcs={"datetime.datetime": lambda dt: dt.isoformat()})
        ... def convert_to_utc(dt: datetime.datetime):
        ...     return dt.astimezone(datetime.timezone.utc)

        """

        # Parse our persist value into a string
        persist_string: CachePersistType
        if persist is True:
            persist_string = "disk"
        elif persist is False:
            persist_string = None
        else:
            persist_string = persist

        if persist_string not in (None, "disk"):
            # We'll eventually have more persist options.
            raise StreamlitAPIException(
                f"Unsupported persist option '{persist}'. Valid values are 'disk' or None."
            )

        if scope not in ("global", "session"):
            raise StreamlitAPIException(
                f"Unsupported scope option '{scope}'. Valid values are 'global' or 'session'."
            )

        # Single place where the cached-function wrapper is built, shared by the
        # bare and the parameterized decorator forms so the two cannot drift.
        def wrapper(f: Callable[P, R]) -> CachedFunc[P, R]:
            return make_cached_func_wrapper(
                CachedDataFuncInfo(
                    func=f,
                    persist=persist_string,
                    show_spinner=show_spinner,
                    show_time=show_time,
                    max_entries=max_entries,
                    ttl=ttl,
                    hash_funcs=hash_funcs,
                    scope=scope,
                )
            )

        # Parameterized usage (``@st.cache_data(...)``) hands back the decorator;
        # bare usage (``@st.cache_data``) applies it to ``func`` right away.
        if func is None:
            return wrapper

        return wrapper(func)

    @gather_metrics("clear_data_caches")
    def clear(self) -> None:
        """Clear all in-memory and on-disk data caches."""
        _data_caches.clear_all()


class DataCache(Cache[R]):
    """Manages cached values for a single st.cache_data function."""

    def __init__(
        self,
        key: str,
        storage: CacheStorage,
        persist: CachePersistType,
        max_entries: int | None,
        ttl_seconds: float | None,
        display_name: str,
    ) -> None:
        """Bind the cache to its storage backend and remember its settings.

        ``ttl_seconds``, ``max_entries`` and ``persist`` are kept on the
        instance; ``DataCaches.get_cache`` compares them against the requested
        params to decide whether an existing cache can be reused.
        """
        super().__init__()
        self.key = key
        self.display_name = display_name
        self.storage = storage
        self.ttl_seconds = ttl_seconds
        self.max_entries = max_entries
        self.persist = persist

    def get_stats(
        self, _family_names: Sequence[str] | None = None
    ) -> dict[str, list[CacheStat]]:
        """Return the storage's stats, or an empty dict if it provides none.

        Stats are only available when the storage is itself a ``StatsProvider``.
        """
        # In general, get_stats methods need to be able to return only requested stat
        # families, but this method only returns a single family, and we're guaranteed
        # that it was one of those requested if we make it here.
        if isinstance(self.storage, StatsProvider):
            return cast("dict[str, list[CacheStat]]", self.storage.get_stats())
        return {}

    def read_result(self, key: str) -> CachedResult[R]:
        """Read a value and messages from the cache.

        Raises
        ------
        CacheKeyNotFoundError
            Raised if the value doesn't exist, or if the stored entry is not a
            ``CachedResult`` (old cache file format); in the latter case the
            entry is deleted from storage first.
        CacheError
            Raised if the storage fails, or if the value exists but can't be
            unpickled.
        """
        try:
            pickled_entry = self.storage.get(key)
        except CacheStorageKeyNotFoundError as e:
            raise CacheKeyNotFoundError(str(e)) from e
        except CacheStorageError as e:
            raise CacheError(str(e)) from e

        # Only the unpickling itself is guarded here, so errors raised by the
        # storage in the old-format branch below are never misreported.
        # Besides UnpicklingError, pickle.loads is documented to raise
        # AttributeError, EOFError, ImportError and IndexError (e.g. truncated
        # data, or a class that has since been moved or removed).
        # NOTE: unpickling trusts whatever bytes the storage backend returns.
        try:
            entry = pickle.loads(pickled_entry)  # noqa: S301
        except (
            pickle.UnpicklingError,
            AttributeError,
            EOFError,
            ImportError,
            IndexError,
        ) as exc:
            raise CacheError(f"Failed to unpickle {key}") from exc

        if not isinstance(entry, CachedResult):
            # Loaded an old cache file format, remove it and let the caller
            # rerun the function.
            self.storage.delete(key)
            raise CacheKeyNotFoundError()
        return entry

    @gather_metrics("_cache_data_object")
    def write_result(self, key: str, value: R, messages: list[MsgData]) -> None:
        """Write a value and associated messages to the cache.
        The value must be pickleable.

        The value and messages are wrapped in a ``CachedResult`` together with
        the current ``st._main`` and ``st.sidebar`` IDs, pickled, and stored
        under ``key``.

        Raises
        ------
        CacheError
            Raised if building or pickling the entry fails with a
            ``pickle.PicklingError`` or ``TypeError``.
        """
        try:
            main_id = st._main.id
            sidebar_id = st.sidebar.id
            entry = CachedResult(value, messages, main_id, sidebar_id)
            pickled_entry = pickle.dumps(entry)
        except (pickle.PicklingError, TypeError) as exc:
            raise CacheError(f"Failed to pickle {key}") from exc
        self.storage.set(key, pickled_entry)

    def _clear(self, key: str | None = None) -> None:
        """Delete the entry stored under ``key``, or clear the whole storage.

        Any falsy ``key`` (``None`` or an empty string) clears everything.
        """
        if not key:
            self.storage.clear()
        else:
            self.storage.delete(key)
