Improve type hints in stream by uvjustin · Pull Request #51837 · home-assistant/core · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Improve type hints in stream #51837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits on Jun 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
.strict-typing — 1 change: 1 addition & 0 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ homeassistant.components.sensor.*
homeassistant.components.slack.*
homeassistant.components.sonos.media_player
homeassistant.components.ssdp.*
homeassistant.components.stream.*
homeassistant.components.sun.*
homeassistant.components.switch.*
homeassistant.components.synology_dsm.*
Expand Down
homeassistant/components/stream/__init__.py — 68 changes: 39 additions & 29 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
"""
from __future__ import annotations

from collections.abc import Mapping
import logging
import re
import secrets
import threading
import time
from types import MappingProxyType
from typing import cast

from homeassistant.const import EVENT_HOMEASSISTANT_STOP
from homeassistant.core import callback
from homeassistant.core import Event, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.typing import ConfigType

from .const import (
ATTR_ENDPOINTS,
Expand All @@ -40,18 +43,21 @@
)
from .core import PROVIDERS, IdleTimer, StreamOutput
from .hls import async_setup_hls
from .recorder import RecorderOutput

_LOGGER = logging.getLogger(__name__)

STREAM_SOURCE_RE = re.compile("//.*:.*@")


def redact_credentials(data):
def redact_credentials(data: str) -> str:
"""Redact credentials from string data."""
return STREAM_SOURCE_RE.sub("//****:****@", data)


def create_stream(hass, stream_source, options=None):
def create_stream(
hass: HomeAssistant, stream_source: str, options: dict[str, str]
) -> Stream:
"""Create a stream with the specified identfier based on the source url.

The stream_source is typically an rtsp url and options are passed into
Expand All @@ -60,9 +66,6 @@ def create_stream(hass, stream_source, options=None):
if DOMAIN not in hass.config.components:
raise HomeAssistantError("Stream integration is not set up.")

if options is None:
options = {}

# For RTSP streams, prefer TCP
if isinstance(stream_source, str) and stream_source[:7] == "rtsp://":
options = {
Expand All @@ -76,7 +79,7 @@ def create_stream(hass, stream_source, options=None):
return stream


async def async_setup(hass, config):
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up stream."""
# Set log level to error for libav
logging.getLogger("libav").setLevel(logging.ERROR)
Expand All @@ -98,7 +101,7 @@ async def async_setup(hass, config):
async_setup_recorder(hass)

@callback
def shutdown(event):
def shutdown(event: Event) -> None:
"""Stop all stream workers."""
for stream in hass.data[DOMAIN][ATTR_STREAMS]:
stream.keepalive = False
Expand All @@ -113,41 +116,43 @@ def shutdown(event):
class Stream:
"""Represents a single stream."""

def __init__(self, hass, source, options=None):
def __init__(
self, hass: HomeAssistant, source: str, options: dict[str, str]
) -> None:
"""Initialize a stream."""
self.hass = hass
self.source = source
self.options = options
self.keepalive = False
self.access_token = None
self._thread = None
self.access_token: str | None = None
self._thread: threading.Thread | None = None
self._thread_quit = threading.Event()
self._outputs: dict[str, StreamOutput] = {}
self._fast_restart_once = False

if self.options is None:
self.options = {}

def endpoint_url(self, fmt: str) -> str:
"""Start the stream and returns a url for the output format."""
if fmt not in self._outputs:
raise ValueError(f"Stream is not configured for format '{fmt}'")
if not self.access_token:
self.access_token = secrets.token_hex()
return self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt].format(self.access_token)
endpoint_fmt: str = self.hass.data[DOMAIN][ATTR_ENDPOINTS][fmt]
return endpoint_fmt.format(self.access_token)

def outputs(self):
def outputs(self) -> Mapping[str, StreamOutput]:
"""Return a copy of the stream outputs."""
# A copy is returned so the caller can iterate through the outputs
# without concern about self._outputs being modified from another thread.
return MappingProxyType(self._outputs.copy())

def add_provider(self, fmt, timeout=OUTPUT_IDLE_TIMEOUT):
def add_provider(
self, fmt: str, timeout: int = OUTPUT_IDLE_TIMEOUT
) -> StreamOutput:
"""Add provider output stream."""
if not self._outputs.get(fmt):

@callback
def idle_callback():
def idle_callback() -> None:
if (
not self.keepalive or fmt == RECORDER_PROVIDER
) and fmt in self._outputs:
Expand All @@ -160,7 +165,7 @@ def idle_callback():
self._outputs[fmt] = provider
return self._outputs[fmt]

def remove_provider(self, provider):
def remove_provider(self, provider: StreamOutput) -> None:
"""Remove provider output stream."""
if provider.name in self._outputs:
self._outputs[provider.name].cleanup()
Expand All @@ -169,12 +174,12 @@ def remove_provider(self, provider):
if not self._outputs:
self.stop()

def check_idle(self):
def check_idle(self) -> None:
"""Reset access token if all providers are idle."""
if all(p.idle for p in self._outputs.values()):
self.access_token = None

def start(self):
def start(self) -> None:
"""Start a stream."""
if self._thread is None or not self._thread.is_alive():
if self._thread is not None:
Expand All @@ -189,14 +194,14 @@ def start(self):
self._thread.start()
_LOGGER.info("Started stream: %s", redact_credentials(str(self.source)))

def update_source(self, new_source):
def update_source(self, new_source: str) -> None:
"""Restart the stream with a new stream source."""
_LOGGER.debug("Updating stream source %s", new_source)
self.source = new_source
self._fast_restart_once = True
self._thread_quit.set()

def _run_worker(self):
def _run_worker(self) -> None:
"""Handle consuming streams and restart keepalive streams."""
# Keep import here so that we can import stream integration without installing reqs
# pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -229,33 +234,35 @@ def _run_worker(self):
)
self._worker_finished()

def _worker_finished(self):
def _worker_finished(self) -> None:
"""Schedule cleanup of all outputs."""

@callback
def remove_outputs():
def remove_outputs() -> None:
for provider in self.outputs().values():
self.remove_provider(provider)

self.hass.loop.call_soon_threadsafe(remove_outputs)

def stop(self):
def stop(self) -> None:
"""Remove outputs and access token."""
self._outputs = {}
self.access_token = None

if not self.keepalive:
self._stop()

def _stop(self):
def _stop(self) -> None:
"""Stop worker thread."""
if self._thread is not None:
self._thread_quit.set()
self._thread.join()
self._thread = None
_LOGGER.info("Stopped stream: %s", redact_credentials(str(self.source)))

async def async_record(self, video_path, duration=30, lookback=5):
async def async_record(
self, video_path: str, duration: int = 30, lookback: int = 5
) -> None:
"""Make a .mp4 recording from a provided stream."""

# Check for file access
Expand All @@ -265,10 +272,13 @@ async def async_record(self, video_path, duration=30, lookback=5):
# Add recorder
recorder = self.outputs().get(RECORDER_PROVIDER)
if recorder:
assert isinstance(recorder, RecorderOutput)
raise HomeAssistantError(
f"Stream already recording to {recorder.video_path}!"
)
recorder = self.add_provider(RECORDER_PROVIDER, timeout=duration)
recorder = cast(
RecorderOutput, self.add_provider(RECORDER_PROVIDER, timeout=duration)
)
recorder.video_path = video_path

self.start()
Expand Down
homeassistant/components/stream/core.py — 36 changes: 23 additions & 13 deletions
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
import asyncio
from collections import deque
import datetime
from typing import Callable
from typing import TYPE_CHECKING

from aiohttp import web
import attr

from homeassistant.components.http import HomeAssistantView
from homeassistant.core import HomeAssistant, callback
from homeassistant.components.http.view import HomeAssistantView
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
from homeassistant.helpers.event import async_call_later
from homeassistant.util.decorator import Registry

from .const import ATTR_STREAMS, DOMAIN

if TYPE_CHECKING:
from . import Stream

PROVIDERS = Registry()


Expand Down Expand Up @@ -59,34 +62,34 @@ class IdleTimer:
"""

def __init__(
self, hass: HomeAssistant, timeout: int, idle_callback: Callable[[], None]
self, hass: HomeAssistant, timeout: int, idle_callback: CALLBACK_TYPE
) -> None:
"""Initialize IdleTimer."""
self._hass = hass
self._timeout = timeout
self._callback = idle_callback
self._unsub = None
self._unsub: CALLBACK_TYPE | None = None
self.idle = False

def start(self):
def start(self) -> None:
"""Start the idle timer if not already started."""
self.idle = False
if self._unsub is None:
self._unsub = async_call_later(self._hass, self._timeout, self.fire)

def awake(self):
def awake(self) -> None:
"""Keep the idle time alive by resetting the timeout."""
self.idle = False
# Reset idle timeout
self.clear()
self._unsub = async_call_later(self._hass, self._timeout, self.fire)

def clear(self):
def clear(self) -> None:
"""Clear and disable the timer if it has not already fired."""
if self._unsub is not None:
self._unsub()

def fire(self, _now=None):
def fire(self, _now: datetime.datetime) -> None:
"""Invoke the idle timeout callback, called when the alarm fires."""
self.idle = True
self._unsub = None
Expand All @@ -97,7 +100,10 @@ class StreamOutput:
"""Represents a stream output."""

def __init__(
self, hass: HomeAssistant, idle_timer: IdleTimer, deque_maxlen: int = None
self,
hass: HomeAssistant,
idle_timer: IdleTimer,
deque_maxlen: int | None = None,
) -> None:
"""Initialize a stream output."""
self._hass = hass
Expand Down Expand Up @@ -172,7 +178,7 @@ def _async_put(self, segment: Segment) -> None:
self._event.set()
self._event.clear()

def cleanup(self):
def cleanup(self) -> None:
"""Handle cleanup."""
self._event.set()
self.idle_timer.clear()
Expand All @@ -190,7 +196,9 @@ class StreamView(HomeAssistantView):
requires_auth = False
platform = None

async def get(self, request, token, sequence=None):
async def get(
self, request: web.Request, token: str, sequence: str = ""
) -> web.StreamResponse:
"""Start a GET request."""
hass = request.app["hass"]

Expand All @@ -207,6 +215,8 @@ async def get(self, request, token, sequence=None):

return await self.handle(request, stream, sequence)

async def handle(self, request, stream, sequence):
async def handle(
self, request: web.Request, stream: Stream, sequence: str
) -> web.StreamResponse:
"""Handle the stream request."""
raise NotImplementedError()
Loading
0