Allow changing entity ID by balloob · Pull Request #15637 · home-assistant/core · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Allow changing entity ID #15637

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions homeassistant/components/config/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
vol.Required('entity_id'): cv.entity_id,
# If passed in, we update value. Passing None will remove old value.
vol.Optional('name'): vol.Any(str, None),
vol.Optional('new_entity_id'): str,
})


Expand Down Expand Up @@ -74,13 +75,28 @@ async def update_entity():
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
return

entry = registry.async_update_entity(
msg['entity_id'], name=msg['name'])
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
changes = {}

hass.async_add_job(update_entity())
if 'name' in msg:
changes['name'] = msg['name']

if 'new_entity_id' in msg:
changes['new_entity_id'] = msg['new_entity_id']

try:
if changes:
entry = registry.async_update_entity(
msg['entity_id'], **changes)
except ValueError as err:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_info', str(err)
))
else:
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))

hass.async_create_task(update_entity())


@callback
Expand Down
26 changes: 25 additions & 1 deletion homeassistant/helpers/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class Entity:
# Name in the entity registry
registry_name = None

# Hold list for functions to call on remove.
_on_remove = None

@property
def should_poll(self) -> bool:
"""Return True if entity has to be polled for state.
Expand Down Expand Up @@ -324,8 +327,19 @@ def async_device_update(self, warning=True):
if self.parallel_updates:
self.parallel_updates.release()

@callback
def async_on_remove(self, func):
    """Register a function to be called when the entity is removed."""
    pending = self._on_remove
    if pending is None:
        # Create the list on first use and store it on this instance.
        pending = self._on_remove = []
    pending.append(func)

async def async_remove(self):
"""Remove entity from Home Assistant."""
if self._on_remove is not None:
while self._on_remove:
self._on_remove.pop()()

if self.platform is not None:
await self.platform.async_remove_entity(self.entity_id)
else:
Expand All @@ -335,7 +349,17 @@ async def async_remove(self):
def async_registry_updated(self, old, new):
"""Called when the entity registry has been updated."""
self.registry_name = new.name
self.async_schedule_update_ha_state()

if new.entity_id == self.entity_id:
self.async_schedule_update_ha_state()
return

async def readd():
"""Remove and add entity again."""
await self.async_remove()
await self.platform.async_add_entities([self])

self.hass.async_create_task(readd())

def __eq__(self, other):
"""Return the comparison."""
Expand Down
2 changes: 1 addition & 1 deletion homeassistant/helpers/entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ async def _async_add_entity(self, entity, update_before_add,

entity.entity_id = entry.entity_id
entity.registry_name = entry.name
entry.add_update_listener(entity)
entity.async_on_remove(entry.add_update_listener(entity))

# We won't generate an entity ID if the platform has already set one
# We will however make sure that platform cannot pick a registered ID
Expand Down
40 changes: 32 additions & 8 deletions homeassistant/helpers/entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@

import attr

from ..core import callback, split_entity_id
from ..loader import bind_hass
from ..util import ensure_unique_string, slugify
from ..util.yaml import load_yaml, save_yaml
from homeassistant.core import callback, split_entity_id, valid_entity_id
from homeassistant.loader import bind_hass
from homeassistant.util import ensure_unique_string, slugify
from homeassistant.util.yaml import load_yaml, save_yaml

PATH_REGISTRY = 'entity_registry.yaml'
DATA_REGISTRY = 'entity_registry'
Expand Down Expand Up @@ -63,8 +63,13 @@ def add_update_listener(self, listener):
"""Listen for when entry is updated.

Listener: Callback function(old_entry, new_entry)

Returns function to unlisten.
"""
self.update_listeners.append(weakref.ref(listener))
weak_listener = weakref.ref(listener)
self.update_listeners.append(weak_listener)

return lambda: self.update_listeners.remove(weak_listener)


class EntityRegistry:
Expand Down Expand Up @@ -133,13 +138,18 @@ def async_get_or_create(self, domain, platform, unique_id, *,
return entity

@callback
def async_update_entity(self, entity_id, *, name=_UNDEF):
def async_update_entity(self, entity_id, *, name=_UNDEF,
new_entity_id=_UNDEF):
"""Update properties of an entity."""
return self._async_update_entity(entity_id, name=name)
return self._async_update_entity(
entity_id,
name=name,
new_entity_id=new_entity_id
)

@callback
def _async_update_entity(self, entity_id, *, name=_UNDEF,
config_entry_id=_UNDEF):
config_entry_id=_UNDEF, new_entity_id=_UNDEF):
"""Private facing update properties method."""
old = self.entities[entity_id]

Expand All @@ -152,6 +162,20 @@ def _async_update_entity(self, entity_id, *, name=_UNDEF,
config_entry_id != old.config_entry_id):
changes['config_entry_id'] = config_entry_id

if new_entity_id is not _UNDEF and new_entity_id != old.entity_id:
if self.async_is_registered(new_entity_id):
raise ValueError('Entity is already registered')

if not valid_entity_id(new_entity_id):
raise ValueError('Invalid entity ID')

if (split_entity_id(new_entity_id)[0] !=
split_entity_id(entity_id)[0]):
raise ValueError('New entity ID should be same domain')

self.entities.pop(entity_id)
entity_id = changes['entity_id'] = new_entity_id

if not changes:
return old

Expand Down
44 changes: 39 additions & 5 deletions tests/components/config/test_entity_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ async def test_get_entity(hass, client):
}


async def test_update_entity(hass, client):
"""Test get entry."""
async def test_update_entity_name(hass, client):
"""Test updating entity name."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
Expand Down Expand Up @@ -92,7 +92,7 @@ async def test_update_entity(hass, client):


async def test_update_entity_no_changes(hass, client):
"""Test get entry."""
"""Test update entity with no changes."""
mock_registry(hass, {
'test_domain.world': RegistryEntry(
entity_id='test_domain.world',
Expand Down Expand Up @@ -129,7 +129,7 @@ async def test_update_entity_no_changes(hass, client):


async def test_get_nonexisting_entity(client):
"""Test get entry."""
"""Test get entry with nonexisting entity."""
await client.send_json({
'id': 6,
'type': 'config/entity_registry/get',
Expand All @@ -141,7 +141,7 @@ async def test_get_nonexisting_entity(client):


async def test_update_nonexisting_entity(client):
"""Test get entry."""
"""Test update a nonexisting entity."""
await client.send_json({
'id': 6,
'type': 'config/entity_registry/update',
Expand All @@ -151,3 +151,37 @@ async def test_update_nonexisting_entity(client):
msg = await client.receive_json()

assert not msg['success']


async def test_update_entity_id(hass, client):
    """Test changing an entity ID via the registry update command."""
    old_id = 'test_domain.world'
    new_id = 'test_domain.planet'

    mock_registry(hass, {
        old_id: RegistryEntry(
            entity_id=old_id,
            unique_id='1234',
            # Using component.async_add_entities is equal to platform "domain"
            platform='test_platform',
        ),
    })

    # Add an entity whose unique ID matches the registered entry.
    platform = MockEntityPlatform(hass)
    await platform.async_add_entities([MockEntity(unique_id='1234')])

    assert hass.states.get(old_id) is not None

    await client.send_json({
        'id': 6,
        'type': 'config/entity_registry/update',
        'entity_id': old_id,
        'new_entity_id': new_id,
    })
    msg = await client.receive_json()

    assert msg['result'] == {'entity_id': new_id, 'name': None}

    # The state must be gone under the old ID and present under the new one.
    assert hass.states.get(old_id) is None
    assert hass.states.get(new_id) is not None
12 changes: 12 additions & 0 deletions tests/helpers/test_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,3 +400,15 @@ def test_async_remove_no_platform(hass):
assert len(hass.states.async_entity_ids()) == 1
yield from ent.async_remove()
assert len(hass.states.async_entity_ids()) == 0


async def test_async_remove_runs_callbacks(hass):
    """Test that async_remove runs callbacks added via async_on_remove."""
    result = []

    ent = entity.Entity()
    ent.hass = hass
    ent.entity_id = 'test.test'
    ent.async_on_remove(lambda: result.append(1))

    # Registering a callback must not invoke it.
    assert result == []

    await ent.async_remove()

    # The callback ran exactly once as part of the removal.
    assert result == [1]
76 changes: 75 additions & 1 deletion tests/helpers/test_entity_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from unittest.mock import patch, Mock, MagicMock
from datetime import timedelta

import pytest

from homeassistant.exceptions import PlatformNotReady
import homeassistant.loader as loader
from homeassistant.helpers.entity import generate_entity_id
Expand Down Expand Up @@ -487,7 +489,7 @@ def test_registry_respect_entity_disabled(hass):
assert hass.states.async_entity_ids() == []


async def test_entity_registry_updates(hass):
async def test_entity_registry_updates_name(hass):
"""Test that updates on the entity registry update platform entities."""
registry = mock_registry(hass, {
'test_domain.world': entity_registry.RegistryEntry(
Expand Down Expand Up @@ -602,3 +604,75 @@ def test_not_fails_with_adding_empty_entities_(hass):
yield from component.async_add_entities([])

assert len(hass.states.async_entity_ids()) == 0


async def test_entity_registry_updates_entity_id(hass):
    """Test that a registry entity ID change moves the entity state."""
    old_id = 'test_domain.world'
    new_id = 'test_domain.planet'

    registry = mock_registry(hass, {
        old_id: entity_registry.RegistryEntry(
            entity_id=old_id,
            unique_id='1234',
            # Using component.async_add_entities is equal to platform "domain"
            platform='test_platform',
            name='Some name'
        ),
    })
    platform = MockEntityPlatform(hass)
    await platform.async_add_entities([MockEntity(unique_id='1234')])

    initial = hass.states.get(old_id)
    assert initial is not None
    assert initial.name == 'Some name'

    registry.async_update_entity(old_id, new_entity_id=new_id)

    # Two passes are awaited before the state machine is inspected.
    for _ in range(2):
        await hass.async_block_till_done()

    assert hass.states.get(old_id) is None
    assert hass.states.get(new_id) is not None


async def test_entity_registry_updates_invalid_entity_id(hass):
    """Test that rejected entity ID changes raise and leave state alone."""
    current_id = 'test_domain.world'

    registry = mock_registry(hass, {
        current_id: entity_registry.RegistryEntry(
            entity_id=current_id,
            unique_id='1234',
            # Using component.async_add_entities is equal to platform "domain"
            platform='test_platform',
            name='Some name'
        ),
        'test_domain.existing': entity_registry.RegistryEntry(
            entity_id='test_domain.existing',
            unique_id='5678',
            platform='test_platform',
        ),
    })
    platform = MockEntityPlatform(hass)
    await platform.async_add_entities([MockEntity(unique_id='1234')])

    initial = hass.states.get(current_id)
    assert initial is not None
    assert initial.name == 'Some name'

    # In order: an ID that is already registered, an ID that is not a
    # valid entity ID, and an ID in a different domain.
    rejected_ids = (
        'test_domain.existing',
        'invalid_entity_id',
        'diff_domain.world',
    )
    for candidate in rejected_ids:
        with pytest.raises(ValueError):
            registry.async_update_entity(current_id,
                                         new_entity_id=candidate)

    for _ in range(2):
        await hass.async_block_till_done()

    # The entity keeps its original ID; no state appears under the others.
    assert hass.states.get(current_id) is not None
    assert hass.states.get('invalid_entity_id') is None
    assert hass.states.get('diff_domain.world') is None
0