Add and restore context in recorder [MIGRATION] by balloob · Pull Request #15859 · home-assistant/core · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Add and restore context in recorder [MIGRATION] #15859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions homeassistant/components/recorder/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,27 @@ def _drop_index(engine, table_name, index_name):
"critical operation.", index_name, table_name)


def _add_columns(engine, table_name, columns_def):
"""Add columns to a table."""
from sqlalchemy import text
from sqlalchemy.exc import SQLAlchemyError

columns_def = ['ADD COLUMN {}'.format(col_def) for col_def in columns_def]

try:
engine.execute(text("ALTER TABLE {table} {columns_def}".format(
table=table_name,
columns_def=', '.join(columns_def))))
return
except SQLAlchemyError:
pass

for column_def in columns_def:
engine.execute(text("ALTER TABLE {table} {column_def}".format(
table=table_name,
column_def=column_def)))


def _apply_update(engine, new_version, old_version):
"""Perform operations to bring schema up to date."""
if new_version == 1:
Expand Down Expand Up @@ -146,6 +167,19 @@ def _apply_update(engine, new_version, old_version):
elif new_version == 5:
# Create supporting index for States.event_id foreign key
_create_index(engine, "states", "ix_states_event_id")
elif new_version == 6:
_add_columns(engine, "events", [
'context_id CHARACTER(36)',
'context_user_id CHARACTER(36)',
])
_create_index(engine, "events", "ix_events_context_id")
_create_index(engine, "events", "ix_events_context_user_id")
_add_columns(engine, "states", [
'context_id CHARACTER(36)',
'context_user_id CHARACTER(36)',
])
_create_index(engine, "states", "ix_states_context_id")
_create_index(engine, "states", "ix_states_context_user_id")
else:
raise ValueError("No schema migration defined for version {}"
.format(new_version))
Expand Down
33 changes: 27 additions & 6 deletions homeassistant/components/recorder/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@
from sqlalchemy.ext.declarative import declarative_base

import homeassistant.util.dt as dt_util
from homeassistant.core import Event, EventOrigin, State, split_entity_id
from homeassistant.core import (
Context, Event, EventOrigin, State, split_entity_id)
from homeassistant.remote import JSONEncoder

# SQLAlchemy Schema
# pylint: disable=invalid-name
Base = declarative_base()

SCHEMA_VERSION = 5
SCHEMA_VERSION = 6

_LOGGER = logging.getLogger(__name__)

Expand All @@ -31,23 +32,32 @@ class Events(Base): # type: ignore
origin = Column(String(32))
time_fired = Column(DateTime(timezone=True), index=True)
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)

@staticmethod
def from_event(event):
"""Create an event database object from a native event."""
return Events(event_type=event.event_type,
event_data=json.dumps(event.data, cls=JSONEncoder),
origin=str(event.origin),
time_fired=event.time_fired)
time_fired=event.time_fired,
context_id=event.context.id,
context_user_id=event.context.user_id)

def to_native(self):
"""Convert to a natve HA Event."""
context = Context(
id=self.context_id,
user_id=self.context_user_id
)
try:
return Event(
self.event_type,
json.loads(self.event_data),
EventOrigin(self.origin),
_process_timestamp(self.time_fired)
_process_timestamp(self.time_fired),
context=context,
)
except ValueError:
# When json.loads fails
Expand All @@ -69,6 +79,8 @@ class States(Base): # type: ignore
last_updated = Column(DateTime(timezone=True), default=datetime.utcnow,
index=True)
created = Column(DateTime(timezone=True), default=datetime.utcnow)
context_id = Column(String(36), index=True)
context_user_id = Column(String(36), index=True)

__table_args__ = (
# Used for fetching the state of entities at a specific time
Expand All @@ -82,7 +94,11 @@ def from_event(event):
entity_id = event.data['entity_id']
state = event.data.get('new_state')

dbstate = States(entity_id=entity_id)
dbstate = States(
entity_id=entity_id,
context_id=event.context.id,
context_user_id=event.context.user_id,
)

# State got deleted
if state is None:
Expand All @@ -103,12 +119,17 @@ def from_event(event):

def to_native(self):
"""Convert to an HA state object."""
context = Context(
id=self.context_id,
user_id=self.context_user_id
)
try:
return State(
self.entity_id, self.state,
json.loads(self.attributes),
_process_timestamp(self.last_changed),
_process_timestamp(self.last_updated)
_process_timestamp(self.last_updated),
context=context,
)
except ValueError:
# When json.loads fails
Expand Down
6 changes: 4 additions & 2 deletions homeassistant/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,8 @@ def __eq__(self, other: Any) -> bool:
self.event_type == other.event_type and
self.data == other.data and
self.origin == other.origin and
self.time_fired == other.time_fired)
self.time_fired == other.time_fired and
self.context == other.context)


class EventBus:
Expand Down Expand Up @@ -695,7 +696,8 @@ def __eq__(self, other: Any) -> bool:
return (self.__class__ == other.__class__ and # type: ignore
self.entity_id == other.entity_id and
self.state == other.state and
self.attributes == other.attributes)
self.attributes == other.attributes and
self.context == other.context)

def __repr__(self) -> str:
"""Return the representation of the states."""
Expand Down
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def mock_state_change_event(hass, new_state, old_state=None):
if old_state:
event_data['old_state'] = old_state

hass.bus.fire(EVENT_STATE_CHANGED, event_data)
hass.bus.fire(EVENT_STATE_CHANGED, event_data, context=new_state.context)


@asyncio.coroutine
Expand Down
2 changes: 1 addition & 1 deletion tests/components/recorder/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_from_event(self):
'entity_id': 'sensor.temperature',
'old_state': None,
'new_state': state,
})
}, context=state.context)
assert state == States.from_event(event).to_native()

def test_from_event_to_delete_state(self):
Expand Down
7 changes: 4 additions & 3 deletions tests/components/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ def test_get_states(self):
self.wait_recording_done()

# Get states returns everything before POINT
self.assertEqual(states,
sorted(history.get_states(self.hass, future),
key=lambda state: state.entity_id))
for state1, state2 in zip(
states, sorted(history.get_states(self.hass, future),
key=lambda state: state.entity_id)):
assert state1 == state2

# Test get_state here because we have a DB setup
self.assertEqual(
Expand Down
3 changes: 2 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def test_eq(self):
"""Test events."""
now = dt_util.utcnow()
data = {'some': 'attr'}
context = ha.Context()
event1, event2 = [
ha.Event('some_type', data, time_fired=now)
ha.Event('some_type', data, time_fired=now, context=context)
for _ in range(2)
]

Expand Down
0