# usda-hass-config/custom_components/remote_homeassistant/__init__.py
#
# 746 lines
# 26 KiB
# Python

"""
Connect two Home Assistant instances via the Websocket API.
For more details about this component, please refer to the documentation at
https://home-assistant.io/components/remote_homeassistant/
"""
import asyncio
import copy
import fnmatch
import inspect
import logging
import re
from contextlib import suppress
import aiohttp
import homeassistant.components.websocket_api.auth as api
import homeassistant.helpers.config_validation as cv
import voluptuous as vol
from homeassistant.config import DATA_CUSTOMIZE
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import (CONF_ABOVE, CONF_ACCESS_TOKEN, CONF_BELOW,
CONF_DOMAINS, CONF_ENTITIES, CONF_ENTITY_ID,
CONF_EXCLUDE, CONF_HOST, CONF_INCLUDE,
CONF_PORT, CONF_UNIT_OF_MEASUREMENT,
CONF_VERIFY_SSL, EVENT_CALL_SERVICE,
EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED,
SERVICE_RELOAD)
from homeassistant.core import (Context, EventOrigin, HomeAssistant, callback,
split_entity_id)
from homeassistant.helpers import device_registry as dr
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.reload import async_integration_yaml_config
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
from homeassistant.setup import async_setup_component
from custom_components.remote_homeassistant.views import DiscoveryInfoView
from .const import (CONF_EXCLUDE_DOMAINS, CONF_EXCLUDE_ENTITIES,
CONF_INCLUDE_DOMAINS, CONF_INCLUDE_ENTITIES,
CONF_LOAD_COMPONENTS, CONF_OPTIONS, CONF_REMOTE_CONNECTION,
CONF_SERVICE_PREFIX, CONF_SERVICES, CONF_UNSUB_LISTENER,
DOMAIN, REMOTE_ID, DEFAULT_MAX_MSG_SIZE)
from .proxy_services import ProxyServices
from .rest_api import UnsupportedVersion, async_get_discovery_info
_LOGGER = logging.getLogger(__name__)

# Platforms set up for every remote-instance config entry.
PLATFORMS = ["sensor"]

# Configuration keys that are specific to this integration.
CONF_INSTANCES = "instances"
CONF_SECURE = "secure"
CONF_SUBSCRIBE_EVENTS = "subscribe_events"
CONF_ENTITY_PREFIX = "entity_prefix"
CONF_FILTER = "filter"
CONF_MAX_MSG_SIZE = "max_message_size"

# Connection states published through the dispatcher.
STATE_INIT = "initializing"
STATE_CONNECTING = "connecting"
STATE_CONNECTED = "connected"
STATE_AUTH_INVALID = "auth_invalid"
STATE_AUTH_REQUIRED = "auth_required"
STATE_RECONNECTING = "reconnecting"
STATE_DISCONNECTED = "disconnected"

DEFAULT_ENTITY_PREFIX = ""

# The "exclude" and "include" sections share the same shape: a list of
# entity ids and a list of domains, both defaulting to empty.
_EXCLUDE_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
        vol.Optional(CONF_DOMAINS, default=[]): vol.All(cv.ensure_list, [cv.string]),
    }
)
_INCLUDE_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
        vol.Optional(CONF_DOMAINS, default=[]): vol.All(cv.ensure_list, [cv.string]),
    }
)

# A single entry of the "filter" list; every key is optional.
_FILTER_RULE_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ENTITY_ID): cv.string,
        vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
        vol.Optional(CONF_ABOVE): vol.Coerce(float),
        vol.Optional(CONF_BELOW): vol.Coerce(float),
    }
)

# Schema of one remote instance listed under "instances".
INSTANCES_SCHEMA = vol.Schema(
    {
        vol.Required(CONF_HOST): cv.string,
        vol.Optional(CONF_PORT, default=8123): cv.port,
        vol.Optional(CONF_SECURE, default=False): cv.boolean,
        vol.Optional(CONF_VERIFY_SSL, default=True): cv.boolean,
        vol.Required(CONF_ACCESS_TOKEN): cv.string,
        vol.Optional(CONF_MAX_MSG_SIZE, default=DEFAULT_MAX_MSG_SIZE): vol.Coerce(int),
        vol.Optional(CONF_EXCLUDE, default={}): _EXCLUDE_SCHEMA,
        vol.Optional(CONF_INCLUDE, default={}): _INCLUDE_SCHEMA,
        vol.Optional(CONF_FILTER, default=[]): vol.All(
            cv.ensure_list, [_FILTER_RULE_SCHEMA]
        ),
        vol.Optional(CONF_SUBSCRIBE_EVENTS): cv.ensure_list,
        vol.Optional(CONF_ENTITY_PREFIX, default=DEFAULT_ENTITY_PREFIX): cv.string,
        vol.Optional(CONF_LOAD_COMPONENTS): cv.ensure_list,
        vol.Required(CONF_SERVICE_PREFIX, default="remote_"): cv.string,
        vol.Optional(CONF_SERVICES): cv.ensure_list,
    }
)

# Top-level YAML schema: the integration's section holds a list of instances.
CONFIG_SCHEMA = vol.Schema(
    {
        DOMAIN: vol.Schema(
            {
                vol.Required(CONF_INSTANCES): vol.All(
                    cv.ensure_list, [INSTANCES_SCHEMA]
                ),
            }
        ),
    },
    extra=vol.ALLOW_EXTRA,
)

# Seconds between heartbeat pings and seconds to wait for each pong.
HEARTBEAT_INTERVAL = 20
HEARTBEAT_TIMEOUT = 5

# Events that are always subscribed to, regardless of user configuration.
INTERNALLY_USED_EVENTS = [EVENT_STATE_CHANGED]
def async_yaml_to_config_entry(instance_conf):
    """Split one YAML instance config into config-entry data and options.

    The input mapping is shallow-copied and left untouched. The include and
    exclude sections are flattened into dedicated option keys, a fixed set of
    further keys is moved into the options, and everything else stays in the
    data part.

    Returns a ``(data, options)`` tuple of dictionaries.
    """
    data = instance_conf.copy()
    options = {}

    # (YAML section, option key for its entities, option key for its domains)
    sections = (
        (CONF_INCLUDE, CONF_INCLUDE_ENTITIES, CONF_INCLUDE_DOMAINS),
        (CONF_EXCLUDE, CONF_EXCLUDE_ENTITIES, CONF_EXCLUDE_DOMAINS),
    )
    for section_key, entities_option, domains_option in sections:
        if section_key not in data:
            continue
        section = data.pop(section_key)
        if CONF_ENTITIES in section:
            options[entities_option] = section[CONF_ENTITIES]
        if CONF_DOMAINS in section:
            options[domains_option] = section[CONF_DOMAINS]

    # Keys that are moved verbatim from the data into the options.
    moved_keys = (
        CONF_FILTER,
        CONF_SUBSCRIBE_EVENTS,
        CONF_ENTITY_PREFIX,
        CONF_LOAD_COMPONENTS,
        CONF_SERVICE_PREFIX,
        CONF_SERVICES,
    )
    for key in moved_keys:
        if key in data:
            options[key] = data.pop(key)

    return data, options
async def _async_update_config_entry_if_from_yaml(hass, entries_by_id, conf):
    """Refresh the config entry that belongs to one YAML instance.

    The remote instance is asked for its discovery info; the ``uuid`` found
    there is used to look up the matching entry in ``entries_by_id``. When an
    entry is found its data and options are replaced by the values derived
    from ``conf``. Any failure while fetching the discovery info is logged
    and the instance is skipped.
    """
    try:
        info = await async_get_discovery_info(
            hass,
            conf[CONF_HOST],
            conf[CONF_PORT],
            conf[CONF_SECURE],
            conf[CONF_ACCESS_TOKEN],
            conf[CONF_VERIFY_SSL],
        )
    except Exception:
        _LOGGER.exception(f"reload of {conf[CONF_HOST]} failed")
        return

    entry = entries_by_id.get(info["uuid"])
    if not entry:
        # No config entry is known for this remote instance.
        return

    new_data, new_options = async_yaml_to_config_entry(conf)
    hass.config_entries.async_update_entry(
        entry, data=new_data, options=new_options
    )
async def setup_remote_instance(hass: HomeAssistantType):
    """Register the discovery-info HTTP view on this instance."""
    discovery_view = DiscoveryInfoView()
    hass.http.register_view(discovery_view)
async def async_setup(hass: HomeAssistantType, config: ConfigType):
    """Set up the remote_homeassistant component.

    Registers the admin-only reload service, schedules registration of the
    discovery view and starts an import flow for every instance found in the
    YAML configuration. Always returns True.
    """
    hass.data.setdefault(DOMAIN, {})

    async def _handle_reload(service):
        """Re-read the YAML config and update the matching config entries."""
        yaml_config = await async_integration_yaml_config(hass, DOMAIN)
        if not yaml_config or DOMAIN not in yaml_config:
            return

        # Index existing entries by unique id so each YAML instance can be
        # matched against the uuid reported by its remote.
        entries_by_id = {
            entry.unique_id: entry
            for entry in hass.config_entries.async_entries(DOMAIN)
        }
        await asyncio.gather(
            *(
                _async_update_config_entry_if_from_yaml(hass, entries_by_id, yaml_instance)
                for yaml_instance in yaml_config[DOMAIN][CONF_INSTANCES]
            )
        )

    # NOTE(review): the original indentation was lost; this call is assumed
    # to belong to async_setup rather than to the reload handler - confirm.
    hass.async_create_task(setup_remote_instance(hass))

    hass.helpers.service.async_register_admin_service(
        DOMAIN,
        SERVICE_RELOAD,
        _handle_reload,
    )

    # Every YAML instance is turned into a config entry via the import flow.
    for instance in config.get(DOMAIN, {}).get(CONF_INSTANCES, []):
        import_flow = hass.config_entries.flow.async_init(
            DOMAIN, context={"source": SOURCE_IMPORT}, data=instance
        )
        hass.async_create_task(import_flow)

    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Set up Remote Home-Assistant from a config entry.

    An entry whose unique id equals REMOTE_ID only registers the discovery
    view. Any other entry gets a RemoteConnection, which is stored in
    ``hass.data`` together with the unsubscribe callback of its update
    listener; the platforms and the connection itself are then brought up in
    a background task. Always returns True.
    """
    _async_import_options_from_yaml(hass, entry)

    if entry.unique_id == REMOTE_ID:
        hass.async_create_task(setup_remote_instance(hass))
        return True

    remote = RemoteConnection(hass, entry)
    unsub_listener = entry.add_update_listener(_update_listener)
    hass.data[DOMAIN][entry.entry_id] = {
        CONF_REMOTE_CONNECTION: remote,
        CONF_UNSUB_LISTENER: unsub_listener,
    }

    async def setup_components_and_platforms():
        """Set up platforms and initiate connection."""
        # Components requested by the user are loaded in the background.
        for component in entry.options.get(CONF_LOAD_COMPONENTS, []):
            hass.async_create_task(async_setup_component(hass, component, {}))

        platform_setups = [
            hass.config_entries.async_forward_entry_setup(entry, platform)
            for platform in PLATFORMS
        ]
        await asyncio.gather(*platform_setups)

        # Connect only after the platforms have been forwarded.
        await remote.async_connect()

    hass.async_create_task(setup_components_and_platforms())
    return True
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Unload a config entry.

    Returns True when the entry was unloaded, False when at least one
    platform refused to unload (in which case the connection is kept).
    """
    if entry.unique_id == REMOTE_ID:
        # async_setup_entry neither forwards platforms nor stores anything in
        # hass.data for this entry, so there is nothing to tear down; popping
        # its entry_id below would raise KeyError.
        return True

    unload_ok = all(
        await asyncio.gather(
            *[
                hass.config_entries.async_forward_entry_unload(entry, platform)
                for platform in PLATFORMS
            ]
        )
    )
    if unload_ok:
        data = hass.data[DOMAIN].pop(entry.entry_id)
        await data[CONF_REMOTE_CONNECTION].async_stop()
        # Detach the update listener registered in async_setup_entry.
        data[CONF_UNSUB_LISTENER]()
    return unload_ok
@callback
def _async_import_options_from_yaml(hass: HomeAssistant, entry: ConfigEntry):
    """Move options stored inside the entry data into the entry options.

    Does nothing unless the entry data contains the CONF_OPTIONS key.
    """
    if CONF_OPTIONS not in entry.data:
        return

    remaining_data = entry.data.copy()
    imported_options = remaining_data.pop(CONF_OPTIONS)
    hass.config_entries.async_update_entry(
        entry, data=remaining_data, options=imported_options
    )
async def _update_listener(hass, config_entry):
    """Reload the config entry after it has been updated."""
    entry_id = config_entry.entry_id
    await hass.config_entries.async_reload(entry_id)
class RemoteConnection(object):
    """A Websocket connection to a remote home-assistant instance."""

    def __init__(self, hass, config_entry):
        """Initialize the connection.

        Connection parameters are read from ``config_entry.data``; entity
        include/exclude lists, filters, subscribed events and the entity
        prefix are read from ``config_entry.options``.
        """
        self._hass = hass
        self._entry = config_entry
        self._secure = config_entry.data.get(CONF_SECURE, False)
        self._verify_ssl = config_entry.data.get(CONF_VERIFY_SSL, False)
        self._access_token = config_entry.data.get(CONF_ACCESS_TOKEN)

        # Entries that lack the key must not hand ``None`` to ws_connect.
        max_msg_size = config_entry.data.get(CONF_MAX_MSG_SIZE)
        if max_msg_size is None:
            max_msg_size = DEFAULT_MAX_MSG_SIZE
        self._max_msg_size = max_msg_size

        # see homeassistant/components/influxdb/__init__.py
        # for include/exclude logic
        self._whitelist_e = set(config_entry.options.get(CONF_INCLUDE_ENTITIES, []))
        self._whitelist_d = set(config_entry.options.get(CONF_INCLUDE_DOMAINS, []))
        self._blacklist_e = set(config_entry.options.get(CONF_EXCLUDE_ENTITIES, []))
        self._blacklist_d = set(config_entry.options.get(CONF_EXCLUDE_DOMAINS, []))

        # Entity-id globs are compiled once here instead of on every state.
        self._filter = [
            {
                CONF_ENTITY_ID: re.compile(fnmatch.translate(f.get(CONF_ENTITY_ID)))
                if f.get(CONF_ENTITY_ID)
                else None,
                CONF_UNIT_OF_MEASUREMENT: f.get(CONF_UNIT_OF_MEASUREMENT),
                CONF_ABOVE: f.get(CONF_ABOVE),
                CONF_BELOW: f.get(CONF_BELOW),
            }
            for f in config_entry.options.get(CONF_FILTER, [])
        ]

        self._subscribe_events = set(
            config_entry.options.get(CONF_SUBSCRIBE_EVENTS, []) + INTERNALLY_USED_EVENTS
        )
        self._entity_prefix = config_entry.options.get(CONF_ENTITY_PREFIX, "")

        self._connection = None
        self._heartbeat_task = None
        self._is_stopping = False
        # Local (prefixed) ids of the entities published by this connection.
        self._entities = set()
        # Remote (unprefixed) ids of every entity seen, filtered or not.
        self._all_entity_names = set()
        # Maps message id -> handler invoked for responses with that id.
        self._handlers = {}
        self._remove_listener = None
        self.proxy_services = ProxyServices(hass, config_entry, self)
        self.set_connection_state(STATE_CONNECTING)

        self.__id = 1

    def _prefixed_entity_id(self, entity_id):
        """Return ``entity_id`` with the configured prefix on its object id."""
        if self._entity_prefix:
            domain, object_id = split_entity_id(entity_id)
            return domain + "." + self._entity_prefix + object_id
        return entity_id

    def set_connection_state(self, state):
        """Change current connection state."""
        signal = f"remote_homeassistant_{self._entry.unique_id}"
        async_dispatcher_send(self._hass, signal, state)

    @callback
    def _get_url(self):
        """Get url to connect to."""
        return "%s://%s:%s/api/websocket" % (
            "wss" if self._secure else "ws",
            self._entry.data[CONF_HOST],
            self._entry.data[CONF_PORT],
        )

    async def async_connect(self):
        """Connect to remote home-assistant websocket...

        Retries every 10 seconds until the remote reports the expected
        instance id and the websocket could be opened, then registers the
        device and starts the receive and heartbeat tasks.
        """

        async def _async_stop_handler(event):
            """Stop when Home Assistant is shutting down."""
            await self.async_stop()

        async def _async_instance_get_info():
            """Fetch discovery info from remote instance, None on failure."""
            try:
                return await async_get_discovery_info(
                    self._hass,
                    self._entry.data[CONF_HOST],
                    self._entry.data[CONF_PORT],
                    self._secure,
                    self._access_token,
                    self._verify_ssl,
                )
            except OSError:
                _LOGGER.exception("failed to connect")
            except UnsupportedVersion:
                _LOGGER.error("Unsupported version, at least 0.111 is required.")
            except Exception:
                _LOGGER.exception("failed to fetch instance info")
            return None

        @callback
        def _async_instance_id_match(info):
            """Verify if remote instance id matches the expected id."""
            if not info:
                return False
            if info["uuid"] != self._entry.unique_id:
                _LOGGER.error(
                    "instance id not matching: %s != %s",
                    info["uuid"],
                    self._entry.unique_id,
                )
                return False
            return True

        url = self._get_url()
        session = async_get_clientsession(self._hass, self._verify_ssl)
        self.set_connection_state(STATE_CONNECTING)

        while True:
            info = await _async_instance_get_info()

            # Verify we are talking to correct instance
            if not _async_instance_id_match(info):
                self.set_connection_state(STATE_RECONNECTING)
                await asyncio.sleep(10)
                continue

            try:
                _LOGGER.info("Connecting to %s", url)
                self._connection = await session.ws_connect(
                    url, max_msg_size=self._max_msg_size
                )
            except aiohttp.client_exceptions.ClientError:
                _LOGGER.error("Could not connect to %s, retry in 10 seconds...", url)
                self.set_connection_state(STATE_RECONNECTING)
                await asyncio.sleep(10)
            else:
                _LOGGER.info("Connected to home-assistant websocket at %s", url)
                break

        self._hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop_handler)

        device_registry = dr.async_get(self._hass)
        device_registry.async_get_or_create(
            config_entry_id=self._entry.entry_id,
            identifiers={(DOMAIN, f"remote_{self._entry.unique_id}")},
            name=info.get("location_name"),
            manufacturer="Home Assistant",
            model=info.get("installation_type"),
            sw_version=info.get("ha_version"),
        )

        asyncio.ensure_future(self._recv())
        self._heartbeat_task = self._hass.loop.create_task(self._heartbeat_loop())

    async def _heartbeat_loop(self):
        """Send periodic heartbeats to remote instance."""
        while not self._connection.closed:
            await asyncio.sleep(HEARTBEAT_INTERVAL)

            _LOGGER.debug("Sending ping")
            event = asyncio.Event()

            def resp(message):
                _LOGGER.debug("Got pong: %s", message)
                # A pong handler is single-use; drop it so that _handlers
                # does not gain one entry per heartbeat.
                self._handlers.pop(message.get("id"), None)
                event.set()

            await self.call(resp, "ping")

            try:
                await asyncio.wait_for(event.wait(), HEARTBEAT_TIMEOUT)
            except asyncio.TimeoutError:
                _LOGGER.error("heartbeat failed")

                # Schedule closing on event loop to avoid deadlock
                asyncio.ensure_future(self._connection.close())
                break

    async def async_stop(self):
        """Close connection."""
        self._is_stopping = True
        if self._connection is not None:
            await self._connection.close()
        await self.proxy_services.unload()

    def _next_id(self):
        """Return the next message id and advance the counter."""
        _id = self.__id
        self.__id += 1
        return _id

    async def call(self, callback, message_type, **extra_args):
        """Send a message of ``message_type`` and register its handler.

        ``callback`` is invoked by the receive loop for every incoming
        message that carries the id assigned here. A send failure is logged
        and treated as a disconnect.
        """
        _id = self._next_id()
        self._handlers[_id] = callback
        try:
            await self._connection.send_json(
                {"id": _id, "type": message_type, **extra_args}
            )
        except aiohttp.client_exceptions.ClientError as err:
            _LOGGER.error("remote websocket connection closed: %s", err)
            await self._disconnected()

    async def _disconnected(self):
        """Clean up local state after the connection was lost.

        Removes the published entities, stops the heartbeat, detaches the
        service-call listener and, unless stopping, schedules a reconnect.
        """
        # Remove all published entries
        for entity in self._entities:
            self._hass.states.async_remove(entity)
        if self._heartbeat_task is not None:
            self._heartbeat_task.cancel()
            try:
                await self._heartbeat_task
            except asyncio.CancelledError:
                pass
        if self._remove_listener is not None:
            self._remove_listener()

        self.set_connection_state(STATE_DISCONNECTED)
        self._heartbeat_task = None
        self._remove_listener = None
        self._entities = set()
        self._all_entity_names = set()
        if not self._is_stopping:
            asyncio.ensure_future(self.async_connect())

    async def _recv(self):
        """Receive and dispatch messages until the connection ends."""
        while not self._connection.closed:
            try:
                data = await self._connection.receive()
            except aiohttp.client_exceptions.ClientError as err:
                _LOGGER.error("remote websocket connection closed: %s", err)
                break

            if not data:
                break

            if data.type in (
                aiohttp.WSMsgType.CLOSE,
                aiohttp.WSMsgType.CLOSED,
                aiohttp.WSMsgType.CLOSING,
            ):
                _LOGGER.debug("websocket connection is closing")
                break

            if data.type == aiohttp.WSMsgType.ERROR:
                _LOGGER.error("websocket connection had an error")
                # The payload of an error message is an exception that does
                # not necessarily carry a close code.
                if getattr(data.data, "code", None) == aiohttp.WSCloseCode.MESSAGE_TOO_BIG:
                    _LOGGER.error(
                        "please consider increasing message size with `%s`",
                        CONF_MAX_MSG_SIZE,
                    )
                break

            try:
                message = data.json()
            except (TypeError, ValueError) as err:
                # ValueError covers payloads that are not valid JSON.
                _LOGGER.error("could not decode data (%s) as json: %s", data, err)
                break

            if message is None:
                break

            _LOGGER.debug("received: %s", message)

            if message["type"] == api.TYPE_AUTH_OK:
                self.set_connection_state(STATE_CONNECTED)
                await self._init()

            elif message["type"] == api.TYPE_AUTH_REQUIRED:
                if self._access_token:
                    auth_message = {
                        "type": api.TYPE_AUTH,
                        "access_token": self._access_token,
                    }
                else:
                    _LOGGER.error("Access token required, but not provided")
                    self.set_connection_state(STATE_AUTH_REQUIRED)
                    return
                try:
                    await self._connection.send_json(auth_message)
                except Exception as err:
                    _LOGGER.error("could not send data to remote connection: %s", err)
                    break

            elif message["type"] == api.TYPE_AUTH_INVALID:
                _LOGGER.error("Auth invalid, check your access token")
                self.set_connection_state(STATE_AUTH_INVALID)
                await self._connection.close()
                return

            else:
                # Messages without a registered handler are ignored.
                handler = self._handlers.get(message.get("id"))
                if handler is not None:
                    if inspect.iscoroutinefunction(handler):
                        await handler(message)
                    else:
                        handler(message)

        await self._disconnected()

    async def _init(self):
        """Start mirroring after successful authentication.

        Listens for local service calls to forward, subscribes to the
        configured remote events, requests the current remote states and
        loads the proxy services.
        """

        async def forward_event(event):
            """Send local event to remote instance.

            The affected entity_id has to originate from that remote
            instance, otherwise the event is discarded.
            """
            event_data = event.data
            service_data = event_data["service_data"]

            if not service_data:
                return

            entity_ids = service_data.get("entity_id", None)

            if not entity_ids:
                return

            if isinstance(entity_ids, str):
                entity_ids = (entity_ids.lower(),)

            # Only entities published by this connection are forwarded.
            entities = {entity_id.lower() for entity_id in self._entities}

            entity_ids = entities.intersection(entity_ids)

            if not entity_ids:
                return

            if self._entity_prefix:

                def _remove_prefix(entity_id):
                    domain, object_id = split_entity_id(entity_id)
                    object_id = object_id.replace(self._entity_prefix.lower(), "", 1)
                    return domain + "." + object_id

                entity_ids = {_remove_prefix(entity_id) for entity_id in entity_ids}

            # Copy before modifying so the local event stays untouched.
            event_data = copy.deepcopy(event_data)
            event_data["service_data"]["entity_id"] = list(entity_ids)

            # Remove service_call_id parameter - websocket API
            # doesn't accept that one
            event_data.pop("service_call_id", None)

            _id = self._next_id()
            data = {"id": _id, "type": event.event_type, **event_data}

            _LOGGER.debug("forward event: %s", data)

            try:
                await self._connection.send_json(data)
            except Exception as err:
                _LOGGER.error("could not send data to remote connection: %s", err)
                await self._disconnected()

        def state_changed(entity_id, state, attr):
            """Publish remote state change on local instance."""
            domain, object_id = split_entity_id(entity_id)

            self._all_entity_names.add(entity_id)

            if entity_id in self._blacklist_e or domain in self._blacklist_d:
                return

            if (
                (self._whitelist_e or self._whitelist_d)
                and entity_id not in self._whitelist_e
                and domain not in self._whitelist_d
            ):
                return

            for f in self._filter:
                if f[CONF_ENTITY_ID] and not f[CONF_ENTITY_ID].match(entity_id):
                    continue
                if f[CONF_UNIT_OF_MEASUREMENT]:
                    if CONF_UNIT_OF_MEASUREMENT not in attr:
                        continue
                    if f[CONF_UNIT_OF_MEASUREMENT] != attr[CONF_UNIT_OF_MEASUREMENT]:
                        continue
                try:
                    # Compare against None: a threshold of 0 is a valid limit.
                    if f[CONF_BELOW] is not None and float(state) < f[CONF_BELOW]:
                        _LOGGER.info(
                            "%s: ignoring state '%s', because below '%s'",
                            entity_id,
                            state,
                            f[CONF_BELOW],
                        )
                        return
                    if f[CONF_ABOVE] is not None and float(state) > f[CONF_ABOVE]:
                        _LOGGER.info(
                            "%s: ignoring state '%s', because above '%s'",
                            entity_id,
                            state,
                            f[CONF_ABOVE],
                        )
                        return
                except ValueError:
                    # Non-numeric states cannot be range-filtered.
                    pass

            entity_id = self._prefixed_entity_id(entity_id)

            # Add local customization data
            if DATA_CUSTOMIZE in self._hass.data:
                attr.update(self._hass.data[DATA_CUSTOMIZE].get(entity_id))

            self._entities.add(entity_id)
            self._hass.states.async_set(entity_id, state, attr)

        def fire_event(message):
            """Publish remote event on local instance."""
            # Results only acknowledge the subscription; anything that is not
            # an event is ignored as well.
            if message["type"] != "event":
                return

            if message["event"]["event_type"] == "state_changed":
                data = message["event"]["data"]
                entity_id = data["entity_id"]
                if not data["new_state"]:
                    # entity was removed in the remote instance
                    # _all_entity_names holds remote (unprefixed) ids while
                    # _entities holds local (prefixed) ids.
                    self._all_entity_names.discard(entity_id)
                    entity_id = self._prefixed_entity_id(entity_id)
                    self._entities.discard(entity_id)
                    self._hass.states.async_remove(entity_id)
                    return

                state = data["new_state"]["state"]
                attr = data["new_state"]["attributes"]
                state_changed(entity_id, state, attr)
            else:
                event = message["event"]
                self._hass.bus.async_fire(
                    event_type=event["event_type"],
                    event_data=event["data"],
                    context=Context(
                        id=event["context"].get("id"),
                        user_id=event["context"].get("user_id"),
                        parent_id=event["context"].get("parent_id"),
                    ),
                    origin=EventOrigin.remote,
                )

        def got_states(message):
            """Called when list of remote states is available."""
            for entity in message["result"]:
                entity_id = entity["entity_id"]
                state = entity["state"]
                attributes = entity["attributes"]

                state_changed(entity_id, state, attributes)

        self._remove_listener = self._hass.bus.async_listen(
            EVENT_CALL_SERVICE, forward_event
        )

        for event in self._subscribe_events:
            await self.call(fire_event, "subscribe_events", event_type=event)

        await self.call(got_states, "get_states")

        await self.proxy_services.load()