Merging PR_218 openai_rev package with new streamlit chat app

This commit is contained in:
noptuno
2023-04-27 20:29:30 -04:00
parent 479b8d6d10
commit 355dee533b
8378 changed files with 2931636 additions and 3 deletions

View File

@@ -0,0 +1,33 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Explicitly export public symbols
from streamlit.runtime.scriptrunner.script_requests import RerunData as RerunData
from streamlit.runtime.scriptrunner.script_run_context import (
ScriptRunContext as ScriptRunContext,
)
from streamlit.runtime.scriptrunner.script_run_context import (
add_script_run_ctx as add_script_run_ctx,
)
from streamlit.runtime.scriptrunner.script_run_context import (
get_script_run_ctx as get_script_run_ctx,
)
from streamlit.runtime.scriptrunner.script_runner import (
RerunException as RerunException,
)
from streamlit.runtime.scriptrunner.script_runner import ScriptRunner as ScriptRunner
from streamlit.runtime.scriptrunner.script_runner import (
ScriptRunnerEvent as ScriptRunnerEvent,
)
from streamlit.runtime.scriptrunner.script_runner import StopException as StopException

View File

@@ -0,0 +1,197 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import sys
from typing_extensions import Final
# When a Streamlit app is magicified, we insert a `magic_funcs` import near the top of
# its module's AST:
# import streamlit.runtime.scriptrunner.magic_funcs as __streamlitmagic__
MAGIC_MODULE_NAME: Final = "__streamlitmagic__"
def add_magic(code, script_path):
"""Modifies the code to support magic Streamlit commands.
Parameters
----------
code : str
The Python code.
script_path : str
The path to the script file.
Returns
-------
ast.Module
The syntax tree for the code.
"""
# Pass script_path so we get pretty exceptions.
tree = ast.parse(code, script_path, "exec")
return _modify_ast_subtree(tree, is_root=True)
def _modify_ast_subtree(tree, body_attr="body", is_root=False):
"""Parses magic commands and modifies the given AST (sub)tree."""
body = getattr(tree, body_attr)
for i, node in enumerate(body):
node_type = type(node)
# Parse the contents of functions, With statements, and for statements
if (
node_type is ast.FunctionDef
or node_type is ast.With
or node_type is ast.For
or node_type is ast.While
or node_type is ast.AsyncFunctionDef
or node_type is ast.AsyncWith
or node_type is ast.AsyncFor
):
_modify_ast_subtree(node)
# Parse the contents of try statements
elif node_type is ast.Try:
for j, inner_node in enumerate(node.handlers):
node.handlers[j] = _modify_ast_subtree(inner_node)
finally_node = _modify_ast_subtree(node, body_attr="finalbody")
node.finalbody = finally_node.finalbody
_modify_ast_subtree(node)
# Convert if expressions to st.write
elif node_type is ast.If:
_modify_ast_subtree(node)
_modify_ast_subtree(node, "orelse")
# Convert standalone expression nodes to st.write
elif node_type is ast.Expr:
value = _get_st_write_from_expr(node, i, parent_type=type(tree))
if value is not None:
node.value = value
if is_root:
# Import Streamlit so we can use it in the new_value above.
_insert_import_statement(tree)
ast.fix_missing_locations(tree)
return tree
def _insert_import_statement(tree):
"""Insert Streamlit import statement at the top(ish) of the tree."""
st_import = _build_st_import_statement()
# If the 0th node is already an import statement, put the Streamlit
# import below that, so we don't break "from __future__ import".
if tree.body and type(tree.body[0]) in (ast.ImportFrom, ast.Import):
tree.body.insert(1, st_import)
# If the 0th node is a docstring and the 1st is an import statement,
# put the Streamlit import below those, so we don't break "from
# __future__ import".
elif (
len(tree.body) > 1
and (type(tree.body[0]) is ast.Expr and _is_docstring_node(tree.body[0].value))
and type(tree.body[1]) in (ast.ImportFrom, ast.Import)
):
tree.body.insert(2, st_import)
else:
tree.body.insert(0, st_import)
def _build_st_import_statement():
"""Build AST node for `import magic_funcs as __streamlitmagic__`."""
return ast.Import(
names=[
ast.alias(
name="streamlit.runtime.scriptrunner.magic_funcs",
asname=MAGIC_MODULE_NAME,
)
]
)
def _build_st_write_call(nodes):
"""Build AST node for `__streamlitmagic__.transparent_write(*nodes)`."""
return ast.Call(
func=ast.Attribute(
attr="transparent_write",
value=ast.Name(id=MAGIC_MODULE_NAME, ctx=ast.Load()),
ctx=ast.Load(),
),
args=nodes,
keywords=[],
kwargs=None,
starargs=None,
)
def _get_st_write_from_expr(node, i, parent_type):
# Don't change function calls
if type(node.value) is ast.Call:
return None
# Don't change Docstring nodes
if (
i == 0
and _is_docstring_node(node.value)
and parent_type in (ast.FunctionDef, ast.Module)
):
return None
# Don't change yield nodes
if type(node.value) is ast.Yield or type(node.value) is ast.YieldFrom:
return None
# Don't change await nodes
if type(node.value) is ast.Await:
return None
# If tuple, call st.write on the 0th element (rather than the
# whole tuple). This allows us to add a comma at the end of a statement
# to turn it into an expression that should be st-written. Ex:
# "np.random.randn(1000, 2),"
if type(node.value) is ast.Tuple:
args = node.value.elts
st_write = _build_st_write_call(args)
# st.write all strings.
elif type(node.value) is ast.Str:
args = [node.value]
st_write = _build_st_write_call(args)
# st.write all variables.
elif type(node.value) is ast.Name:
args = [node.value]
st_write = _build_st_write_call(args)
# st.write everything else
else:
args = [node.value]
st_write = _build_st_write_call(args)
return st_write
def _is_docstring_node(node):
if sys.version_info >= (3, 8, 0):
return type(node) is ast.Constant and type(node.value) is str
else:
return type(node) is ast.Str

View File

@@ -0,0 +1,30 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
from streamlit.runtime.metrics_util import gather_metrics
@gather_metrics("magic")
def transparent_write(*args: Any) -> Any:
"""The function that gets magic-ified into Streamlit apps.
This is just st.write, but returns the arguments you passed to it.
"""
import streamlit as st
st.write(*args)
if len(args) == 1:
return args[0]
return args

View File

@@ -0,0 +1,180 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Optional, cast
from streamlit.proto.WidgetStates_pb2 import WidgetStates
from streamlit.runtime.state import coalesce_widget_states
class ScriptRequestType(Enum):
# The ScriptRunner should continue running its script.
CONTINUE = "CONTINUE"
# If the script is running, it should be stopped as soon
# as the ScriptRunner reaches an interrupt point.
# This is a terminal state.
STOP = "STOP"
# A script rerun has been requested. The ScriptRunner should
# handle this request as soon as it reaches an interrupt point.
RERUN = "RERUN"
@dataclass(frozen=True)
class RerunData:
"""Data attached to RERUN requests. Immutable."""
query_string: str = ""
widget_states: Optional[WidgetStates] = None
page_script_hash: str = ""
page_name: str = ""
@dataclass(frozen=True)
class ScriptRequest:
"""A STOP or RERUN request and associated data."""
type: ScriptRequestType
_rerun_data: Optional[RerunData] = None
@property
def rerun_data(self) -> RerunData:
if self.type is not ScriptRequestType.RERUN:
raise RuntimeError("RerunData is only set for RERUN requests.")
return cast(RerunData, self._rerun_data)
class ScriptRequests:
"""An interface for communicating with a ScriptRunner. Thread-safe.
AppSession makes requests of a ScriptRunner through this class, and
ScriptRunner handles those requests.
"""
def __init__(self):
self._lock = threading.Lock()
self._state = ScriptRequestType.CONTINUE
self._rerun_data = RerunData()
def request_stop(self) -> None:
"""Request that the ScriptRunner stop running. A stopped ScriptRunner
can't be used anymore. STOP requests succeed unconditionally.
"""
with self._lock:
self._state = ScriptRequestType.STOP
def request_rerun(self, new_data: RerunData) -> bool:
"""Request that the ScriptRunner rerun its script.
If the ScriptRunner has been stopped, this request can't be honored:
return False.
Otherwise, record the request and return True. The ScriptRunner will
handle the rerun request as soon as it reaches an interrupt point.
"""
with self._lock:
if self._state == ScriptRequestType.STOP:
# We can't rerun after being stopped.
return False
if self._state == ScriptRequestType.CONTINUE:
# If we're running, we can handle a rerun request
# unconditionally.
self._state = ScriptRequestType.RERUN
self._rerun_data = new_data
return True
if self._state == ScriptRequestType.RERUN:
# If we have an existing Rerun request, we coalesce this
# new request into it.
if self._rerun_data.widget_states is None:
# The existing request's widget_states is None, which
# means it wants to rerun with whatever the most
# recent script execution's widget state was.
# We have no meaningful state to merge with, and
# so we simply overwrite the existing request.
self._rerun_data = new_data
return True
if new_data.widget_states is not None:
# Both the existing and the new request have
# non-null widget_states. Merge them together.
coalesced_states = coalesce_widget_states(
self._rerun_data.widget_states, new_data.widget_states
)
self._rerun_data = RerunData(
query_string=new_data.query_string,
widget_states=coalesced_states,
page_script_hash=new_data.page_script_hash,
page_name=new_data.page_name,
)
return True
# If old widget_states is NOT None, and new widget_states IS
# None, then this new request is entirely redundant. Leave
# our existing rerun_data as is.
return True
# We'll never get here
raise RuntimeError(f"Unrecognized ScriptRunnerState: {self._state}")
def on_scriptrunner_yield(self) -> Optional[ScriptRequest]:
"""Called by the ScriptRunner when it's at a yield point.
If we have no request, return None.
If we have a RERUN request, return the request and set our internal
state to CONTINUE.
If we have a STOP request, return the request and remain stopped.
"""
if self._state == ScriptRequestType.CONTINUE:
# We avoid taking a lock in the common case. If a STOP or RERUN
# request is received between the `if` and `return`, it will be
# handled at the next `on_scriptrunner_yield`, or when
# `on_scriptrunner_ready` is called.
return None
with self._lock:
if self._state == ScriptRequestType.RERUN:
self._state = ScriptRequestType.CONTINUE
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
assert self._state == ScriptRequestType.STOP
return ScriptRequest(ScriptRequestType.STOP)
def on_scriptrunner_ready(self) -> ScriptRequest:
"""Called by the ScriptRunner when it's about to run its script for
the first time, and also after its script has successfully completed.
If we have a RERUN request, return the request and set
our internal state to CONTINUE.
If we have a STOP request or no request, set our internal state
to STOP.
"""
with self._lock:
if self._state == ScriptRequestType.RERUN:
self._state = ScriptRequestType.CONTINUE
return ScriptRequest(ScriptRequestType.RERUN, self._rerun_data)
# If we don't have a rerun request, unconditionally change our
# state to STOP.
self._state = ScriptRequestType.STOP
return ScriptRequest(ScriptRequestType.STOP)

View File

@@ -0,0 +1,170 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import threading
from dataclasses import dataclass, field
from typing import Callable, Counter, Dict, List, Optional, Set
from typing_extensions import Final, TypeAlias
from streamlit import runtime
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.PageProfile_pb2 import Command
from streamlit.runtime.state import SafeSessionState
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
LOGGER: Final = get_logger(__name__)
UserInfo: TypeAlias = Dict[str, Optional[str]]
@dataclass
class ScriptRunContext:
"""A context object that contains data for a "script run" - that is,
data that's scoped to a single ScriptRunner execution (and therefore also
scoped to a single connected "session").
ScriptRunContext is used internally by virtually every `st.foo()` function.
It is accessed only from the script thread that's created by ScriptRunner.
Streamlit code typically retrieves the active ScriptRunContext via the
`get_script_run_ctx` function.
"""
session_id: str
_enqueue: Callable[[ForwardMsg], None]
query_string: str
session_state: SafeSessionState
uploaded_file_mgr: UploadedFileManager
page_script_hash: str
user_info: UserInfo
gather_usage_stats: bool = False
command_tracking_deactivated: bool = False
tracked_commands: List[Command] = field(default_factory=list)
tracked_commands_counter: Counter[str] = field(default_factory=collections.Counter)
_set_page_config_allowed: bool = True
_has_script_started: bool = False
widget_ids_this_run: Set[str] = field(default_factory=set)
widget_user_keys_this_run: Set[str] = field(default_factory=set)
form_ids_this_run: Set[str] = field(default_factory=set)
cursors: Dict[int, "streamlit.cursor.RunningCursor"] = field(default_factory=dict)
dg_stack: List["streamlit.delta_generator.DeltaGenerator"] = field(
default_factory=list
)
def reset(self, query_string: str = "", page_script_hash: str = "") -> None:
self.cursors = {}
self.widget_ids_this_run = set()
self.widget_user_keys_this_run = set()
self.form_ids_this_run = set()
self.query_string = query_string
self.page_script_hash = page_script_hash
# Permit set_page_config when the ScriptRunContext is reused on a rerun
self._set_page_config_allowed = True
self._has_script_started = False
self.command_tracking_deactivated: bool = False
self.tracked_commands = []
self.tracked_commands_counter = collections.Counter()
def on_script_start(self) -> None:
self._has_script_started = True
def enqueue(self, msg: ForwardMsg) -> None:
"""Enqueue a ForwardMsg for this context's session."""
if msg.HasField("page_config_changed") and not self._set_page_config_allowed:
raise StreamlitAPIException(
"`set_page_config()` can only be called once per app, "
+ "and must be called as the first Streamlit command in your script.\n\n"
+ "For more information refer to the [docs]"
+ "(https://docs.streamlit.io/library/api-reference/utilities/st.set_page_config)."
)
# We want to disallow set_page config if one of the following occurs:
# - set_page_config was called on this message
# - The script has already started and a different st call occurs (a delta)
if msg.HasField("page_config_changed") or (
msg.HasField("delta") and self._has_script_started
):
self._set_page_config_allowed = False
# Pass the message up to our associated ScriptRunner.
self._enqueue(msg)
SCRIPT_RUN_CONTEXT_ATTR_NAME: Final = "streamlit_script_run_ctx"
def add_script_run_ctx(
thread: Optional[threading.Thread] = None, ctx: Optional[ScriptRunContext] = None
):
"""Adds the current ScriptRunContext to a newly-created thread.
This should be called from this thread's parent thread,
before the new thread starts.
Parameters
----------
thread : threading.Thread
The thread to attach the current ScriptRunContext to.
ctx : ScriptRunContext or None
The ScriptRunContext to add, or None to use the current thread's
ScriptRunContext.
Returns
-------
threading.Thread
The same thread that was passed in, for chaining.
"""
if thread is None:
thread = threading.current_thread()
if ctx is None:
ctx = get_script_run_ctx()
if ctx is not None:
setattr(thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, ctx)
return thread
def get_script_run_ctx(suppress_warning: bool = False) -> Optional[ScriptRunContext]:
"""
Parameters
----------
suppress_warning : bool
If True, don't log a warning if there's no ScriptRunContext.
Returns
-------
ScriptRunContext | None
The current thread's ScriptRunContext, or None if it doesn't have one.
"""
thread = threading.current_thread()
ctx: Optional[ScriptRunContext] = getattr(
thread, SCRIPT_RUN_CONTEXT_ATTR_NAME, None
)
if ctx is None and runtime.exists() and not suppress_warning:
# Only warn about a missing ScriptRunContext if suppress_warning is False, and
# we were started via `streamlit run`. Otherwise, the user is likely running a
# script "bare", and doesn't need to be warned about streamlit
# bits that are irrelevant when not connected to a session.
LOGGER.warning("Thread '%s': missing ScriptRunContext", thread.name)
return ctx
# Needed to avoid circular dependencies while running tests.
import streamlit

View File

@@ -0,0 +1,704 @@
# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import sys
import threading
import types
from contextlib import contextmanager
from enum import Enum
from timeit import default_timer as timer
from typing import Callable, Dict, Optional
from blinker import Signal
from streamlit import config, runtime, source_util, util
from streamlit.error_util import handle_uncaught_app_exception
from streamlit.logger import get_logger
from streamlit.proto.ClientState_pb2 import ClientState
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.runtime.scriptrunner import magic
from streamlit.runtime.scriptrunner.script_requests import (
RerunData,
ScriptRequests,
ScriptRequestType,
)
from streamlit.runtime.scriptrunner.script_run_context import (
ScriptRunContext,
add_script_run_ctx,
get_script_run_ctx,
)
from streamlit.runtime.state import (
SCRIPT_RUN_WITHOUT_ERRORS_KEY,
SafeSessionState,
SessionState,
)
from streamlit.runtime.uploaded_file_manager import UploadedFileManager
from streamlit.vendor.ipython.modified_sys_path import modified_sys_path
_LOGGER = get_logger(__name__)
class ScriptRunnerEvent(Enum):
## "Control" events. These are emitted when the ScriptRunner's state changes.
# The script started running.
SCRIPT_STARTED = "SCRIPT_STARTED"
# The script run stopped because of a compile error.
SCRIPT_STOPPED_WITH_COMPILE_ERROR = "SCRIPT_STOPPED_WITH_COMPILE_ERROR"
# The script run stopped because it ran to completion, or was
# interrupted by the user.
SCRIPT_STOPPED_WITH_SUCCESS = "SCRIPT_STOPPED_WITH_SUCCESS"
# The script run stopped in order to start a script run with newer widget state.
SCRIPT_STOPPED_FOR_RERUN = "SCRIPT_STOPPED_FOR_RERUN"
# The ScriptRunner is done processing the ScriptEventQueue and
# is shut down.
SHUTDOWN = "SHUTDOWN"
## "Data" events. These are emitted when the ScriptRunner's script has
## data to send to the frontend.
# The script has a ForwardMsg to send to the frontend.
ENQUEUE_FORWARD_MSG = "ENQUEUE_FORWARD_MSG"
"""
Note [Threading]
There are two kinds of threads in Streamlit, the main thread and script threads.
The main thread is started by invoking the Streamlit CLI, and bootstraps the
framework and runs the Tornado webserver.
A script thread is created by a ScriptRunner when it starts. The script thread
is where the ScriptRunner executes, including running the user script itself,
processing messages to/from the frontend, and all the Streamlit library function
calls in the user script.
It is possible for the user script to spawn its own threads, which could call
Streamlit functions. We restrict the ScriptRunner's execution control to the
script thread. Calling Streamlit functions from other threads is unlikely to
work correctly due to lack of ScriptRunContext, so we may add a guard against
it in the future.
"""
class ScriptRunner:
def __init__(
self,
session_id: str,
main_script_path: str,
client_state: ClientState,
session_state: SessionState,
uploaded_file_mgr: UploadedFileManager,
initial_rerun_data: RerunData,
user_info: Dict[str, Optional[str]],
):
"""Initialize the ScriptRunner.
(The ScriptRunner won't start executing until start() is called.)
Parameters
----------
session_id : str
The AppSession's id.
main_script_path : str
Path to our main app script.
client_state : ClientState
The current state from the client (widgets and query params).
uploaded_file_mgr : UploadedFileManager
The File manager to store the data uploaded by the file_uploader widget.
user_info: Dict
A dict that contains information about the current user. For now,
it only contains the user's email address.
{
"email": "example@example.com"
}
Information about the current user is optionally provided when a
websocket connection is initialized via the "X-Streamlit-User" header.
"""
self._session_id = session_id
self._main_script_path = main_script_path
self._uploaded_file_mgr = uploaded_file_mgr
self._user_info = user_info
# Initialize SessionState with the latest widget states
session_state.set_widgets_from_proto(client_state.widget_states)
self._client_state = client_state
self._session_state = SafeSessionState(session_state)
self._requests = ScriptRequests()
self._requests.request_rerun(initial_rerun_data)
self.on_event = Signal(
doc="""Emitted when a ScriptRunnerEvent occurs.
This signal is generally emitted on the ScriptRunner's script
thread (which is *not* the same thread that the ScriptRunner was
created on).
Parameters
----------
sender: ScriptRunner
The sender of the event (this ScriptRunner).
event : ScriptRunnerEvent
forward_msg : ForwardMsg | None
The ForwardMsg to send to the frontend. Set only for the
ENQUEUE_FORWARD_MSG event.
exception : BaseException | None
Our compile error. Set only for the
SCRIPT_STOPPED_WITH_COMPILE_ERROR event.
widget_states : streamlit.proto.WidgetStates_pb2.WidgetStates | None
The ScriptRunner's final WidgetStates. Set only for the
SHUTDOWN event.
"""
)
# Set to true while we're executing. Used by
# _maybe_handle_execution_control_request.
self._execing = False
# This is initialized in start()
self._script_thread: Optional[threading.Thread] = None
def __repr__(self) -> str:
return util.repr_(self)
def request_stop(self) -> None:
"""Request that the ScriptRunner stop running its script and
shut down. The ScriptRunner will handle this request when it reaches
an interrupt point.
Safe to call from any thread.
"""
self._requests.request_stop()
# "Disconnect" our SafeSessionState wrapper from its underlying
# SessionState instance. This will cause all further session_state
# operations in this ScriptRunner to no-op.
#
# After `request_stop` is called, our script will continue executing
# until it reaches a yield point. AppSession may also *immediately*
# spin up a new ScriptRunner after this call, which means we'll
# potentially have two active ScriptRunners for a brief period while
# this one is shutting down. Disconnecting our SessionState ensures
# that this ScriptRunner's thread won't introduce SessionState-
# related race conditions during this script overlap.
self._session_state.disconnect()
def request_rerun(self, rerun_data: RerunData) -> bool:
"""Request that the ScriptRunner interrupt its currently-running
script and restart it.
If the ScriptRunner has been stopped, this request can't be honored:
return False.
Otherwise, record the request and return True. The ScriptRunner will
handle the rerun request as soon as it reaches an interrupt point.
Safe to call from any thread.
"""
return self._requests.request_rerun(rerun_data)
def start(self) -> None:
"""Start a new thread to process the ScriptEventQueue.
This must be called only once.
"""
if self._script_thread is not None:
raise Exception("ScriptRunner was already started")
self._script_thread = threading.Thread(
target=self._run_script_thread,
name="ScriptRunner.scriptThread",
)
self._script_thread.start()
def _get_script_run_ctx(self) -> ScriptRunContext:
"""Get the ScriptRunContext for the current thread.
Returns
-------
ScriptRunContext
The ScriptRunContext for the current thread.
Raises
------
AssertionError
If called outside of a ScriptRunner thread.
RuntimeError
If there is no ScriptRunContext for the current thread.
"""
assert self._is_in_script_thread()
ctx = get_script_run_ctx()
if ctx is None:
# This should never be possible on the script_runner thread.
raise RuntimeError(
"ScriptRunner thread has a null ScriptRunContext. Something has gone very wrong!"
)
return ctx
def _run_script_thread(self) -> None:
"""The entry point for the script thread.
Processes the ScriptRequestQueue, which will at least contain the RERUN
request that will trigger the first script-run.
When the ScriptRequestQueue is empty, or when a SHUTDOWN request is
dequeued, this function will exit and its thread will terminate.
"""
assert self._is_in_script_thread()
_LOGGER.debug("Beginning script thread")
# Create and attach the thread's ScriptRunContext
ctx = ScriptRunContext(
session_id=self._session_id,
_enqueue=self._enqueue_forward_msg,
query_string=self._client_state.query_string,
session_state=self._session_state,
uploaded_file_mgr=self._uploaded_file_mgr,
page_script_hash=self._client_state.page_script_hash,
user_info=self._user_info,
gather_usage_stats=bool(config.get_option("browser.gatherUsageStats")),
)
add_script_run_ctx(threading.current_thread(), ctx)
request = self._requests.on_scriptrunner_ready()
while request.type == ScriptRequestType.RERUN:
# When the script thread starts, we'll have a pending rerun
# request that we'll handle immediately. When the script finishes,
# it's possible that another request has come in that we need to
# handle, which is why we call _run_script in a loop.
self._run_script(request.rerun_data)
request = self._requests.on_scriptrunner_ready()
assert request.type == ScriptRequestType.STOP
# Send a SHUTDOWN event before exiting. This includes the widget values
# as they existed after our last successful script run, which the
# AppSession will pass on to the next ScriptRunner that gets
# created.
client_state = ClientState()
client_state.query_string = ctx.query_string
client_state.page_script_hash = ctx.page_script_hash
widget_states = self._session_state.get_widget_states()
client_state.widget_states.widgets.extend(widget_states)
self.on_event.send(
self, event=ScriptRunnerEvent.SHUTDOWN, client_state=client_state
)
def _is_in_script_thread(self) -> bool:
"""True if the calling function is running in the script thread"""
return self._script_thread == threading.current_thread()
def _enqueue_forward_msg(self, msg: ForwardMsg) -> None:
"""Enqueue a ForwardMsg to our browser queue.
This private function is called by ScriptRunContext only.
It may be called from the script thread OR the main thread.
"""
# Whenever we enqueue a ForwardMsg, we also handle any pending
# execution control request. This means that a script can be
# cleanly interrupted and stopped inside most `st.foo` calls.
#
# (If "runner.installTracer" is true, then we'll actually be
# handling these requests in a callback called after every Python
# instruction instead.)
if not config.get_option("runner.installTracer"):
self._maybe_handle_execution_control_request()
# Pass the message to our associated AppSession.
self.on_event.send(
self, event=ScriptRunnerEvent.ENQUEUE_FORWARD_MSG, forward_msg=msg
)
def _maybe_handle_execution_control_request(self) -> None:
"""Check our current ScriptRequestState to see if we have a
pending STOP or RERUN request.
This function is called every time the app script enqueues a
ForwardMsg, which means that most `st.foo` commands - which generally
involve sending a ForwardMsg to the frontend - act as implicit
yield points in the script's execution.
"""
if not self._is_in_script_thread():
# We can only handle execution_control_request if we're on the
# script execution thread. However, it's possible for deltas to
# be enqueued (and, therefore, for this function to be called)
# in separate threads, so we check for that here.
return
if not self._execing:
# If the _execing flag is not set, we're not actually inside
# an exec() call. This happens when our script exec() completes,
# we change our state to STOPPED, and a statechange-listener
# enqueues a new ForwardEvent
return
request = self._requests.on_scriptrunner_yield()
if request is None:
# No RERUN or STOP request.
return
if request.type == ScriptRequestType.RERUN:
raise RerunException(request.rerun_data)
assert request.type == ScriptRequestType.STOP
raise StopException()
def _install_tracer(self) -> None:
"""Install function that runs before each line of the script."""
def trace_calls(frame, event, arg):
self._maybe_handle_execution_control_request()
return trace_calls
# Python interpreters are not required to implement sys.settrace.
if hasattr(sys, "settrace"):
sys.settrace(trace_calls)
@contextmanager
def _set_execing_flag(self):
"""A context for setting the ScriptRunner._execing flag.
Used by _maybe_handle_execution_control_request to ensure that
we only handle requests while we're inside an exec() call
"""
if self._execing:
raise RuntimeError("Nested set_execing_flag call")
self._execing = True
try:
yield
finally:
self._execing = False
def _run_script(self, rerun_data: RerunData) -> None:
"""Run our script.
Parameters
----------
rerun_data: RerunData
The RerunData to use.
"""
assert self._is_in_script_thread()
_LOGGER.debug("Running script %s", rerun_data)
start_time: float = timer()
prep_time: float = 0 # This will be overwritten once preparations are done.
# Reset DeltaGenerators, widgets, media files.
runtime.get_instance().media_file_mgr.clear_session_refs()
main_script_path = self._main_script_path
pages = source_util.get_pages(main_script_path)
# Safe because pages will at least contain the app's main page.
main_page_info = list(pages.values())[0]
current_page_info = None
uncaught_exception = None
if rerun_data.page_script_hash:
current_page_info = pages.get(rerun_data.page_script_hash, None)
elif not rerun_data.page_script_hash and rerun_data.page_name:
# If a user navigates directly to a non-main page of an app, we get
# the first script run request before the list of pages has been
# sent to the frontend. In this case, we choose the first script
# with a name matching the requested page name.
current_page_info = next(
filter(
# There seems to be this weird bug with mypy where it
# thinks that p can be None (which is impossible given the
# types of pages), so we add `p and` at the beginning of
# the predicate to circumvent this.
lambda p: p and (p["page_name"] == rerun_data.page_name),
pages.values(),
),
None,
)
else:
# If no information about what page to run is given, default to
# running the main page.
current_page_info = main_page_info
page_script_hash = (
current_page_info["page_script_hash"]
if current_page_info is not None
else main_page_info["page_script_hash"]
)
ctx = self._get_script_run_ctx()
ctx.reset(
query_string=rerun_data.query_string,
page_script_hash=page_script_hash,
)
self.on_event.send(
self,
event=ScriptRunnerEvent.SCRIPT_STARTED,
page_script_hash=page_script_hash,
)
# Compile the script. Any errors thrown here will be surfaced
# to the user via a modal dialog in the frontend, and won't result
# in their previous script elements disappearing.
try:
if current_page_info:
script_path = current_page_info["script_path"]
else:
script_path = main_script_path
# At this point, we know that either
# * the script corresponding to the hash requested no longer
# exists, or
# * we were not able to find a script with the requested page
# name.
# In both of these cases, we want to send a page_not_found
# message to the frontend.
msg = ForwardMsg()
msg.page_not_found.page_name = rerun_data.page_name
ctx.enqueue(msg)
with source_util.open_python_file(script_path) as f:
filebody = f.read()
if config.get_option("runner.magicEnabled"):
filebody = magic.add_magic(filebody, script_path)
code = compile( # type: ignore
filebody,
# Pass in the file path so it can show up in exceptions.
script_path,
# We're compiling entire blocks of Python, so we need "exec"
# mode (as opposed to "eval" or "single").
mode="exec",
# Don't inherit any flags or "future" statements.
flags=0,
dont_inherit=1,
# Use the default optimization options.
optimize=-1,
)
except Exception as ex:
# We got a compile error. Send an error event and bail immediately.
_LOGGER.debug("Fatal script error: %s", ex)
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
self.on_event.send(
self,
event=ScriptRunnerEvent.SCRIPT_STOPPED_WITH_COMPILE_ERROR,
exception=ex,
)
return
# If we get here, we've successfully compiled our script. The next step
# is to run it. Errors thrown during execution will be shown to the
# user as ExceptionElements.
if config.get_option("runner.installTracer"):
self._install_tracer()
# This will be set to a RerunData instance if our execution
# is interrupted by a RerunException.
rerun_exception_data: Optional[RerunData] = None
try:
# Create fake module. This gives us a name global namespace to
# execute the code in.
# TODO(vdonato): Double-check that we're okay with naming the
# module for every page `__main__`. I'm pretty sure this is
# necessary given that people will likely often write
# ```
# if __name__ == "__main__":
# ...
# ```
# in their scripts.
module = _new_module("__main__")
# Install the fake module as the __main__ module. This allows
# the pickle module to work inside the user's code, since it now
# can know the module where the pickled objects stem from.
# IMPORTANT: This means we can't use "if __name__ == '__main__'" in
# our code, as it will point to the wrong module!!!
sys.modules["__main__"] = module
# Add special variables to the module's globals dict.
# Note: The following is a requirement for the CodeHasher to
# work correctly. The CodeHasher is scoped to
# files contained in the directory of __main__.__file__, which we
# assume is the main script directory.
module.__dict__["__file__"] = script_path
with modified_sys_path(self._main_script_path), self._set_execing_flag():
# Run callbacks for widgets whose values have changed.
if rerun_data.widget_states is not None:
self._session_state.on_script_will_rerun(rerun_data.widget_states)
ctx.on_script_start()
prep_time = timer() - start_time
exec(code, module.__dict__)
self._session_state.maybe_check_serializable()
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = True
except RerunException as e:
rerun_exception_data = e.rerun_data
except StopException:
# This is thrown when the script executes `st.stop()`.
# We don't have to do anything here.
pass
except Exception as ex:
self._session_state[SCRIPT_RUN_WITHOUT_ERRORS_KEY] = False
uncaught_exception = ex
handle_uncaught_app_exception(uncaught_exception)
finally:
if rerun_exception_data:
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_FOR_RERUN
else:
finished_event = ScriptRunnerEvent.SCRIPT_STOPPED_WITH_SUCCESS
if ctx.gather_usage_stats:
try:
# Prevent issues with circular import
from streamlit.runtime.metrics_util import (
create_page_profile_message,
to_microseconds,
)
# Create and send page profile information
ctx.enqueue(
create_page_profile_message(
ctx.tracked_commands,
exec_time=to_microseconds(timer() - start_time),
prep_time=to_microseconds(prep_time),
uncaught_exception=type(uncaught_exception).__name__
if uncaught_exception
else None,
)
)
except Exception as ex:
# Always capture all exceptions since we want to make sure that
# the telemetry never causes any issues.
_LOGGER.debug("Failed to create page profile", exc_info=ex)
self._on_script_finished(ctx, finished_event)
# Use _log_if_error() to make sure we never ever ever stop running the
# script without meaning to.
_log_if_error(_clean_problem_modules)
if rerun_exception_data is not None:
self._run_script(rerun_exception_data)
def _on_script_finished(
self, ctx: ScriptRunContext, event: ScriptRunnerEvent
) -> None:
"""Called when our script finishes executing, even if it finished
early with an exception. We perform post-run cleanup here.
"""
# Tell session_state to update itself in response
self._session_state.on_script_finished(ctx.widget_ids_this_run)
# Signal that the script has finished. (We use SCRIPT_STOPPED_WITH_SUCCESS
# even if we were stopped with an exception.)
self.on_event.send(self, event=event)
# Remove orphaned files now that the script has run and files in use
# are marked as active.
runtime.get_instance().media_file_mgr.remove_orphaned_files()
# Force garbage collection to run, to help avoid memory use building up
# This is usually not an issue, but sometimes GC takes time to kick in and
# causes apps to go over resource limits, and forcing it to run between
# script runs is low cost, since we aren't doing much work anyway.
if config.get_option("runner.postScriptGC"):
gc.collect(2)
class ScriptControlException(BaseException):
"""Base exception for ScriptRunner."""
pass
class StopException(ScriptControlException):
"""Silently stop the execution of the user's script."""
pass
class RerunException(ScriptControlException):
"""Silently stop and rerun the user's script."""
def __init__(self, rerun_data: RerunData):
"""Construct a RerunException
Parameters
----------
rerun_data : RerunData
The RerunData that should be used to rerun the script
"""
self.rerun_data = rerun_data
def __repr__(self) -> str:
return util.repr_(self)
def _clean_problem_modules() -> None:
"""Some modules are stateful, so we have to clear their state."""
if "keras" in sys.modules:
try:
keras = sys.modules["keras"]
keras.backend.clear_session()
except Exception:
# We don't want to crash the app if we can't clear the Keras session.
pass
if "matplotlib.pyplot" in sys.modules:
try:
plt = sys.modules["matplotlib.pyplot"]
plt.close("all")
except Exception:
# We don't want to crash the app if we can't close matplotlib
pass
def _new_module(name: str) -> types.ModuleType:
"""Create a new module with the given name."""
return types.ModuleType(name)
# The reason this is not a decorator is because we want to make it clear at the
# calling location that this function is being used.
def _log_if_error(fn: Callable[[], None]) -> None:
try:
fn()
except Exception as e:
_LOGGER.warning(e)