746 lines
26 KiB
Python
746 lines
26 KiB
Python
|
"""
|
||
|
Connect two Home Assistant instances via the Websocket API.
|
||
|
|
||
|
For more details about this component, please refer to the documentation at
|
||
|
https://home-assistant.io/components/remote_homeassistant/
|
||
|
"""
|
||
|
import asyncio
|
||
|
import copy
|
||
|
import fnmatch
|
||
|
import inspect
|
||
|
import logging
|
||
|
import re
|
||
|
from contextlib import suppress
|
||
|
|
||
|
import aiohttp
|
||
|
import homeassistant.components.websocket_api.auth as api
|
||
|
import homeassistant.helpers.config_validation as cv
|
||
|
import voluptuous as vol
|
||
|
from homeassistant.config import DATA_CUSTOMIZE
|
||
|
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
|
||
|
from homeassistant.const import (CONF_ABOVE, CONF_ACCESS_TOKEN, CONF_BELOW,
|
||
|
CONF_DOMAINS, CONF_ENTITIES, CONF_ENTITY_ID,
|
||
|
CONF_EXCLUDE, CONF_HOST, CONF_INCLUDE,
|
||
|
CONF_PORT, CONF_UNIT_OF_MEASUREMENT,
|
||
|
CONF_VERIFY_SSL, EVENT_CALL_SERVICE,
|
||
|
EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED,
|
||
|
SERVICE_RELOAD)
|
||
|
from homeassistant.core import (Context, EventOrigin, HomeAssistant, callback,
|
||
|
split_entity_id)
|
||
|
from homeassistant.helpers import device_registry as dr
|
||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||
|
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||
|
from homeassistant.helpers.reload import async_integration_yaml_config
|
||
|
from homeassistant.helpers.typing import ConfigType, HomeAssistantType
|
||
|
from homeassistant.setup import async_setup_component
|
||
|
|
||
|
from custom_components.remote_homeassistant.views import DiscoveryInfoView
|
||
|
|
||
|
from .const import (CONF_EXCLUDE_DOMAINS, CONF_EXCLUDE_ENTITIES,
|
||
|
CONF_INCLUDE_DOMAINS, CONF_INCLUDE_ENTITIES,
|
||
|
CONF_LOAD_COMPONENTS, CONF_OPTIONS, CONF_REMOTE_CONNECTION,
|
||
|
CONF_SERVICE_PREFIX, CONF_SERVICES, CONF_UNSUB_LISTENER,
|
||
|
DOMAIN, REMOTE_ID, DEFAULT_MAX_MSG_SIZE)
|
||
|
from .proxy_services import ProxyServices
|
||
|
from .rest_api import UnsupportedVersion, async_get_discovery_info
|
||
|
|
||
|
# Module-wide logger.
_LOGGER = logging.getLogger(__name__)

# Platforms forwarded to every config entry of this integration.
PLATFORMS = ["sensor"]

# YAML / config-entry keys that are specific to this integration.
CONF_INSTANCES = "instances"
CONF_SECURE = "secure"
CONF_SUBSCRIBE_EVENTS = "subscribe_events"
CONF_ENTITY_PREFIX = "entity_prefix"
CONF_FILTER = "filter"
CONF_MAX_MSG_SIZE = "max_message_size"

# Connection states published through the dispatcher.
STATE_INIT = "initializing"
STATE_CONNECTING = "connecting"
STATE_CONNECTED = "connected"
STATE_AUTH_INVALID = "auth_invalid"
STATE_AUTH_REQUIRED = "auth_required"
STATE_RECONNECTING = "reconnecting"
STATE_DISCONNECTED = "disconnected"

# By default remote entity ids are mirrored without any prefix.
DEFAULT_ENTITY_PREFIX = ""
|
||
|
|
||
|
# Entity/domain selection used by both the ``include`` and ``exclude`` sections.
_ENTITY_SELECTION_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids,
        vol.Optional(CONF_DOMAINS, default=[]): vol.All(cv.ensure_list, [cv.string]),
    }
)

# One entry of the ``filter`` list: an entity-id pattern, a unit and numeric bounds.
_FILTER_ENTRY_SCHEMA = vol.Schema(
    {
        vol.Optional(CONF_ENTITY_ID): cv.string,
        vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string,
        vol.Optional(CONF_ABOVE): vol.Coerce(float),
        vol.Optional(CONF_BELOW): vol.Coerce(float),
    }
)

# Validation schema for a single remote instance declared in YAML.
INSTANCES_SCHEMA = vol.Schema(
    {
        # How to reach the remote instance and authenticate against it.
        vol.Required(CONF_HOST): cv.string,
        vol.Optional(CONF_PORT, default=8123): cv.port,
        vol.Optional(CONF_SECURE, default=False): cv.boolean,
        vol.Optional(CONF_VERIFY_SSL, default=True): cv.boolean,
        vol.Required(CONF_ACCESS_TOKEN): cv.string,
        vol.Optional(CONF_MAX_MSG_SIZE, default=DEFAULT_MAX_MSG_SIZE): vol.Coerce(int),
        # Which remote entities get mirrored locally.
        vol.Optional(CONF_EXCLUDE, default={}): _ENTITY_SELECTION_SCHEMA,
        vol.Optional(CONF_INCLUDE, default={}): _ENTITY_SELECTION_SCHEMA,
        vol.Optional(CONF_FILTER, default=[]): vol.All(
            cv.ensure_list, [_FILTER_ENTRY_SCHEMA]
        ),
        # Events, naming and proxied services.
        vol.Optional(CONF_SUBSCRIBE_EVENTS): cv.ensure_list,
        vol.Optional(CONF_ENTITY_PREFIX, default=DEFAULT_ENTITY_PREFIX): cv.string,
        vol.Optional(CONF_LOAD_COMPONENTS): cv.ensure_list,
        vol.Required(CONF_SERVICE_PREFIX, default="remote_"): cv.string,
        vol.Optional(CONF_SERVICES): cv.ensure_list,
    }
)
|
||
|
|
||
|
# Schema of this integration's section in configuration.yaml: a list of
# remote instances. Keys belonging to other integrations are allowed.
_DOMAIN_SCHEMA = vol.Schema(
    {vol.Required(CONF_INSTANCES): vol.All(cv.ensure_list, [INSTANCES_SCHEMA])}
)

CONFIG_SCHEMA = vol.Schema({DOMAIN: _DOMAIN_SCHEMA}, extra=vol.ALLOW_EXTRA)
|
||
|
|
||
|
# Seconds between two pings sent to the remote instance.
HEARTBEAT_INTERVAL = 20
# Seconds to wait for the pong before the connection is closed.
HEARTBEAT_TIMEOUT = 5

# Events that are always subscribed to, whatever the user configured.
INTERNALLY_USED_EVENTS = [EVENT_STATE_CHANGED]
|
||
|
|
||
|
|
||
|
def async_yaml_to_config_entry(instance_conf):
    """Split one YAML instance config into config-entry data and options.

    The ``include``/``exclude`` sections are flattened into dedicated option
    keys and the remaining option-like keys are moved over unchanged.
    Everything else stays in the data part. ``instance_conf`` itself is not
    modified (a shallow copy is taken first).

    Returns a ``(data, options)`` tuple of dicts.
    """
    data = instance_conf.copy()
    options = {}

    # YAML section -> {key inside the section: flattened option key}
    selection_keys = {
        CONF_INCLUDE: {
            CONF_ENTITIES: CONF_INCLUDE_ENTITIES,
            CONF_DOMAINS: CONF_INCLUDE_DOMAINS,
        },
        CONF_EXCLUDE: {
            CONF_ENTITIES: CONF_EXCLUDE_ENTITIES,
            CONF_DOMAINS: CONF_EXCLUDE_DOMAINS,
        },
    }
    for section_name, key_map in selection_keys.items():
        if section_name not in data:
            continue
        section = data.pop(section_name)
        for yaml_key, option_key in key_map.items():
            if yaml_key in section:
                options[option_key] = section[yaml_key]

    # Keys that are stored as options under their own name.
    passthrough = (
        CONF_FILTER,
        CONF_SUBSCRIBE_EVENTS,
        CONF_ENTITY_PREFIX,
        CONF_LOAD_COMPONENTS,
        CONF_SERVICE_PREFIX,
        CONF_SERVICES,
    )
    for name in passthrough:
        if name in data:
            options[name] = data.pop(name)

    return data, options
|
||
|
|
||
|
|
||
|
async def _async_update_config_entry_if_from_yaml(hass, entries_by_id, conf):
    """Refresh the config entry that belongs to one YAML instance.

    The remote instance is asked for its discovery info; the returned uuid is
    looked up in ``entries_by_id``. A matching entry gets its data and options
    replaced by the values derived from ``conf``. Failures while fetching the
    discovery info are logged and otherwise ignored.
    """
    try:
        info = await async_get_discovery_info(
            hass,
            conf[CONF_HOST],
            conf[CONF_PORT],
            conf[CONF_SECURE],
            conf[CONF_ACCESS_TOKEN],
            conf[CONF_VERIFY_SSL],
        )
    except Exception:
        _LOGGER.exception(f"reload of {conf[CONF_HOST]} failed")
        return

    entry = entries_by_id.get(info["uuid"])
    if not entry:
        # No config entry is known for this remote instance.
        return

    new_data, new_options = async_yaml_to_config_entry(conf)
    hass.config_entries.async_update_entry(
        entry, data=new_data, options=new_options
    )
|
||
|
|
||
|
|
||
|
async def setup_remote_instance(hass: HomeAssistantType):
    """Register the discovery info HTTP view on this instance."""
    discovery_view = DiscoveryInfoView()
    hass.http.register_view(discovery_view)
|
||
|
|
||
|
|
||
|
async def async_setup(hass: HomeAssistantType, config: ConfigType):
    """Set up the remote_homeassistant component.

    Registers the discovery view and the admin-only reload service, then
    starts an import flow for every instance declared in YAML.
    """
    hass.data.setdefault(DOMAIN, {})

    async def _handle_reload(service):
        """Re-read the YAML config and update the matching config entries."""
        yaml_config = await async_integration_yaml_config(hass, DOMAIN)
        if not yaml_config or DOMAIN not in yaml_config:
            return

        entries_by_id = {
            entry.unique_id: entry
            for entry in hass.config_entries.async_entries(DOMAIN)
        }

        # All instances are refreshed concurrently.
        await asyncio.gather(
            *[
                _async_update_config_entry_if_from_yaml(hass, entries_by_id, instance)
                for instance in yaml_config[DOMAIN][CONF_INSTANCES]
            ]
        )

    hass.async_create_task(setup_remote_instance(hass))

    hass.helpers.service.async_register_admin_service(
        DOMAIN, SERVICE_RELOAD, _handle_reload
    )

    # Every YAML instance is turned into a config entry through an import flow.
    for instance in config.get(DOMAIN, {}).get(CONF_INSTANCES, []):
        import_flow = hass.config_entries.flow.async_init(
            DOMAIN, context={"source": SOURCE_IMPORT}, data=instance
        )
        hass.async_create_task(import_flow)

    return True
|
||
|
|
||
|
|
||
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Set up Remote Home-Assistant from a config entry.

    An entry whose unique id is ``REMOTE_ID`` only registers the discovery
    view. Any other entry gets a ``RemoteConnection`` that is stored in
    ``hass.data`` and connected in a background task.
    """
    _async_import_options_from_yaml(hass, entry)

    if entry.unique_id == REMOTE_ID:
        hass.async_create_task(setup_remote_instance(hass))
        return True

    remote = RemoteConnection(hass, entry)
    unsub_listener = entry.add_update_listener(_update_listener)
    hass.data[DOMAIN][entry.entry_id] = {
        CONF_REMOTE_CONNECTION: remote,
        CONF_UNSUB_LISTENER: unsub_listener,
    }

    async def setup_components_and_platforms():
        """Set up platforms and initiate connection."""
        # Components requested by the user are loaded in the background.
        for domain in entry.options.get(CONF_LOAD_COMPONENTS, []):
            hass.async_create_task(async_setup_component(hass, domain, {}))

        # Platforms must be ready before the connection is opened.
        await asyncio.gather(
            *(
                hass.config_entries.async_forward_entry_setup(entry, platform)
                for platform in PLATFORMS
            )
        )
        await remote.async_connect()

    hass.async_create_task(setup_components_and_platforms())

    return True
|
||
|
|
||
|
|
||
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Unload a config entry.

    The connection is stopped and the update listener removed only when every
    platform unloaded successfully. Returns whether the unload succeeded.
    """
    platform_results = await asyncio.gather(
        *(
            hass.config_entries.async_forward_entry_unload(entry, platform)
            for platform in PLATFORMS
        )
    )
    if not all(platform_results):
        return False

    entry_data = hass.data[DOMAIN].pop(entry.entry_id)
    await entry_data[CONF_REMOTE_CONNECTION].async_stop()
    entry_data[CONF_UNSUB_LISTENER]()
    return True
|
||
|
|
||
|
|
||
|
@callback
def _async_import_options_from_yaml(hass: HomeAssistant, entry: ConfigEntry):
    """Move options stored inside the entry data into the entry options.

    Nothing happens when the entry data carries no ``CONF_OPTIONS`` key.
    """
    if CONF_OPTIONS not in entry.data:
        return

    remaining_data = dict(entry.data)
    imported_options = remaining_data.pop(CONF_OPTIONS)
    hass.config_entries.async_update_entry(
        entry, data=remaining_data, options=imported_options
    )
|
||
|
|
||
|
|
||
|
async def _update_listener(hass, config_entry):
    """Reload the config entry after it has been updated."""
    entry_id = config_entry.entry_id
    await hass.config_entries.async_reload(entry_id)
|
||
|
|
||
|
|
||
|
class RemoteConnection(object):
    """A Websocket connection to a remote home-assistant instance.

    The connection authenticates with an access token, subscribes to the
    configured events, mirrors remote entity states into the local state
    machine and forwards local service calls that target mirrored entities
    back to the remote instance. A heartbeat closes unresponsive connections
    and a lost connection is re-established unless the connection is stopping.
    """

    def __init__(self, hass, config_entry):
        """Initialize the connection from a config entry (no I/O is done)."""
        data = config_entry.data
        options = config_entry.options

        self._hass = hass
        self._entry = config_entry
        self._secure = data.get(CONF_SECURE, False)
        self._verify_ssl = data.get(CONF_VERIFY_SSL, False)
        self._access_token = data.get(CONF_ACCESS_TOKEN)
        # An entry may carry no stored message size; fall back to the
        # default instead of handing None to the websocket client.
        max_msg_size = data.get(CONF_MAX_MSG_SIZE)
        self._max_msg_size = (
            DEFAULT_MAX_MSG_SIZE if max_msg_size is None else max_msg_size
        )

        # see homeassistant/components/influxdb/__init__.py
        # for include/exclude logic
        self._whitelist_e = set(options.get(CONF_INCLUDE_ENTITIES, []))
        self._whitelist_d = set(options.get(CONF_INCLUDE_DOMAINS, []))
        self._blacklist_e = set(options.get(CONF_EXCLUDE_ENTITIES, []))
        self._blacklist_d = set(options.get(CONF_EXCLUDE_DOMAINS, []))

        # Entity-id globs are compiled once; a missing pattern matches everything.
        self._filter = [
            {
                CONF_ENTITY_ID: re.compile(fnmatch.translate(f.get(CONF_ENTITY_ID)))
                if f.get(CONF_ENTITY_ID)
                else None,
                CONF_UNIT_OF_MEASUREMENT: f.get(CONF_UNIT_OF_MEASUREMENT),
                CONF_ABOVE: f.get(CONF_ABOVE),
                CONF_BELOW: f.get(CONF_BELOW),
            }
            for f in options.get(CONF_FILTER, [])
        ]

        self._subscribe_events = set(
            options.get(CONF_SUBSCRIBE_EVENTS, []) + INTERNALLY_USED_EVENTS
        )
        self._entity_prefix = options.get(CONF_ENTITY_PREFIX, "")

        self._connection = None
        self._heartbeat_task = None
        self._is_stopping = False
        # Local (prefixed) ids of the entities published by this connection.
        self._entities = set()
        # Remote (unprefixed) ids of every entity seen on the remote instance.
        self._all_entity_names = set()
        # Message id -> callback invoked with the matching responses.
        self._handlers = {}
        self._remove_listener = None
        self.proxy_services = ProxyServices(hass, config_entry, self)

        self.set_connection_state(STATE_CONNECTING)

        self.__id = 1

    def _prefixed_entity_id(self, entity_id):
        """Return ``entity_id`` with the configured prefix on its object id."""
        if not self._entity_prefix:
            return entity_id
        domain, object_id = split_entity_id(entity_id)
        return domain + "." + self._entity_prefix + object_id

    def set_connection_state(self, state):
        """Change current connection state."""
        signal = f"remote_homeassistant_{self._entry.unique_id}"
        async_dispatcher_send(self._hass, signal, state)

    @callback
    def _get_url(self):
        """Get url to connect to."""
        return "%s://%s:%s/api/websocket" % (
            "wss" if self._secure else "ws",
            self._entry.data[CONF_HOST],
            self._entry.data[CONF_PORT],
        )

    async def async_connect(self):
        """Connect to remote home-assistant websocket...

        Retries every 10 seconds until the remote instance answers with the
        expected instance id and the websocket is open, or until the
        connection has been stopped. On success the receive loop and the
        heartbeat are started.
        """

        async def _async_stop_handler(event):
            """Stop when Home Assistant is shutting down."""
            await self.async_stop()

        async def _async_instance_get_info():
            """Fetch discovery info from remote instance (None on failure)."""
            try:
                return await async_get_discovery_info(
                    self._hass,
                    self._entry.data[CONF_HOST],
                    self._entry.data[CONF_PORT],
                    self._secure,
                    self._access_token,
                    self._verify_ssl,
                )
            except OSError:
                _LOGGER.exception("failed to connect")
            except UnsupportedVersion:
                _LOGGER.error("Unsupported version, at least 0.111 is required.")
            except Exception:
                _LOGGER.exception("failed to fetch instance info")
            return None

        @callback
        def _async_instance_id_match(info):
            """Verify if remote instance id matches the expected id."""
            if not info:
                return False
            if info["uuid"] != self._entry.unique_id:
                _LOGGER.error(
                    "instance id not matching: %s != %s",
                    info["uuid"],
                    self._entry.unique_id,
                )
                return False
            return True

        url = self._get_url()

        session = async_get_clientsession(self._hass, self._verify_ssl)
        self.set_connection_state(STATE_CONNECTING)

        while True:
            # Without this check an unloaded entry would keep retrying forever.
            if self._is_stopping:
                return

            info = await _async_instance_get_info()

            # Verify we are talking to correct instance
            if not _async_instance_id_match(info):
                self.set_connection_state(STATE_RECONNECTING)
                await asyncio.sleep(10)
                continue

            try:
                _LOGGER.info("Connecting to %s", url)
                self._connection = await session.ws_connect(
                    url, max_msg_size=self._max_msg_size
                )
            except aiohttp.client_exceptions.ClientError:
                _LOGGER.error("Could not connect to %s, retry in 10 seconds...", url)
                self.set_connection_state(STATE_RECONNECTING)
                await asyncio.sleep(10)
            else:
                _LOGGER.info("Connected to home-assistant websocket at %s", url)
                break

        self._hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop_handler)

        device_registry = dr.async_get(self._hass)
        device_registry.async_get_or_create(
            config_entry_id=self._entry.entry_id,
            identifiers={(DOMAIN, f"remote_{self._entry.unique_id}")},
            name=info.get("location_name"),
            manufacturer="Home Assistant",
            model=info.get("installation_type"),
            sw_version=info.get("ha_version"),
        )

        asyncio.ensure_future(self._recv())
        self._heartbeat_task = self._hass.loop.create_task(self._heartbeat_loop())

    async def _heartbeat_loop(self):
        """Send periodic heartbeats to remote instance.

        The connection is closed when a pong does not arrive within
        ``HEARTBEAT_TIMEOUT`` seconds.
        """
        while not self._connection.closed:
            await asyncio.sleep(HEARTBEAT_INTERVAL)

            _LOGGER.debug("Sending ping")
            event = asyncio.Event()

            def resp(message):
                _LOGGER.debug("Got pong: %s", message)
                event.set()

            ping_id = await self.call(resp, "ping")

            try:
                await asyncio.wait_for(event.wait(), HEARTBEAT_TIMEOUT)
            except asyncio.TimeoutError:
                _LOGGER.error("heartbeat failed")

                # Schedule closing on event loop to avoid deadlock
                asyncio.ensure_future(self._connection.close())
                break
            finally:
                # A ping handler is single-use; drop it so the handler table
                # does not grow by one entry per heartbeat.
                self._handlers.pop(ping_id, None)

    async def async_stop(self):
        """Close connection."""
        self._is_stopping = True
        if self._connection is not None:
            await self._connection.close()
        await self.proxy_services.unload()

    def _next_id(self):
        """Return the next message id and advance the counter."""
        _id = self.__id
        self.__id += 1
        return _id

    async def call(self, callback, message_type, **extra_args):
        """Send a message of ``message_type`` to the remote instance.

        ``callback`` (plain function or coroutine function) is registered for
        the responses carrying the id of the message; ``extra_args`` are
        merged into the message. A failed send tears the connection down.

        Returns the id used for the message.
        """
        _id = self._next_id()
        self._handlers[_id] = callback
        try:
            await self._connection.send_json(
                {"id": _id, "type": message_type, **extra_args}
            )
        except aiohttp.client_exceptions.ClientError as err:
            _LOGGER.error("remote websocket connection closed: %s", err)
            await self._disconnected()
        return _id

    async def _disconnected(self):
        """Clean up after a lost connection and reconnect unless stopping."""
        # Remove all published entries
        for entity in self._entities:
            self._hass.states.async_remove(entity)

        heartbeat_task = self._heartbeat_task
        if heartbeat_task is not None:
            heartbeat_task.cancel()
            # This method is also reached from inside the heartbeat task
            # (heartbeat -> call -> _disconnected); a task must not await itself.
            if heartbeat_task is not asyncio.current_task():
                try:
                    await heartbeat_task
                except asyncio.CancelledError:
                    pass
        if self._remove_listener is not None:
            self._remove_listener()

        self.set_connection_state(STATE_DISCONNECTED)
        self._heartbeat_task = None
        self._remove_listener = None
        self._entities = set()
        self._all_entity_names = set()
        if not self._is_stopping:
            asyncio.ensure_future(self.async_connect())

    async def _recv(self):
        """Receive and dispatch messages until the connection ends.

        Handles the authentication handshake and routes every other message
        to the handler registered for its id. When the loop ends the
        connection is cleaned up through ``_disconnected``, except on the two
        authentication failures, which return without reconnecting.
        """
        closing_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSED,
            aiohttp.WSMsgType.CLOSING,
        )

        while not self._connection.closed:
            try:
                data = await self._connection.receive()
            except aiohttp.client_exceptions.ClientError as err:
                _LOGGER.error("remote websocket connection closed: %s", err)
                break

            if not data:
                break

            if data.type in closing_types:
                _LOGGER.debug("websocket connection is closing")
                break

            if data.type == aiohttp.WSMsgType.ERROR:
                _LOGGER.error("websocket connection had an error")
                # The payload of an error frame is an exception object; not
                # every exception type carries a close code.
                error_code = getattr(data.data, "code", None)
                if error_code == aiohttp.WSCloseCode.MESSAGE_TOO_BIG:
                    _LOGGER.error(
                        "please consider increasing message size with `%s`",
                        CONF_MAX_MSG_SIZE,
                    )
                break

            try:
                message = data.json()
            except (TypeError, ValueError) as err:
                # ValueError covers payloads that are not valid JSON.
                _LOGGER.error("could not decode data (%s) as json: %s", data, err)
                break

            if message is None:
                break

            _LOGGER.debug("received: %s", message)

            if message["type"] == api.TYPE_AUTH_OK:
                self.set_connection_state(STATE_CONNECTED)
                await self._init()

            elif message["type"] == api.TYPE_AUTH_REQUIRED:
                if not self._access_token:
                    _LOGGER.error("Access token required, but not provided")
                    self.set_connection_state(STATE_AUTH_REQUIRED)
                    return
                auth_message = {
                    "type": api.TYPE_AUTH,
                    "access_token": self._access_token,
                }
                try:
                    await self._connection.send_json(auth_message)
                except Exception as err:
                    _LOGGER.error("could not send data to remote connection: %s", err)
                    break

            elif message["type"] == api.TYPE_AUTH_INVALID:
                _LOGGER.error("Auth invalid, check your access token")
                self.set_connection_state(STATE_AUTH_INVALID)
                await self._connection.close()
                return

            else:
                # Messages without an id cannot belong to any handler.
                handler = self._handlers.get(message.get("id"))
                if handler is not None:
                    if inspect.iscoroutinefunction(handler):
                        await handler(message)
                    else:
                        handler(message)

        await self._disconnected()

    async def _init(self):
        """Start mirroring after a successful authentication.

        Registers the local service-call listener, subscribes to the
        configured remote events, requests the current remote states and
        loads the proxy services.
        """

        async def forward_event(event):
            """Send local event to remote instance.

            The affected entity_id has to originate from that remote instance,
            otherwise the event is discarded.
            """
            event_data = event.data
            service_data = event_data["service_data"]

            if not service_data:
                return

            entity_ids = service_data.get("entity_id", None)

            if not entity_ids:
                return

            if isinstance(entity_ids, str):
                entity_ids = (entity_ids.lower(),)

            entities = {entity_id.lower() for entity_id in self._entities}

            # Only entities published by this connection are forwarded.
            entity_ids = entities.intersection(entity_ids)

            if not entity_ids:
                return

            if self._entity_prefix:

                def _remove_prefix(entity_id):
                    domain, object_id = split_entity_id(entity_id)
                    object_id = object_id.replace(self._entity_prefix.lower(), "", 1)
                    return domain + "." + object_id

                entity_ids = {_remove_prefix(entity_id) for entity_id in entity_ids}

            # The event object is shared with other listeners; work on a copy.
            event_data = copy.deepcopy(event_data)
            event_data["service_data"]["entity_id"] = list(entity_ids)

            # Remove service_call_id parameter - websocket API
            # doesn't accept that one
            event_data.pop("service_call_id", None)

            _id = self._next_id()
            data = {"id": _id, "type": event.event_type, **event_data}

            _LOGGER.debug("forward event: %s", data)

            try:
                await self._connection.send_json(data)
            except Exception as err:
                _LOGGER.error("could not send data to remote connection: %s", err)
                await self._disconnected()

        def state_changed(entity_id, state, attr):
            """Publish remote state change on local instance."""
            domain, object_id = split_entity_id(entity_id)

            self._all_entity_names.add(entity_id)

            if entity_id in self._blacklist_e or domain in self._blacklist_d:
                return

            if (
                (self._whitelist_e or self._whitelist_d)
                and entity_id not in self._whitelist_e
                and domain not in self._whitelist_d
            ):
                return

            for f in self._filter:
                if f[CONF_ENTITY_ID] and not f[CONF_ENTITY_ID].match(entity_id):
                    continue
                if f[CONF_UNIT_OF_MEASUREMENT]:
                    if CONF_UNIT_OF_MEASUREMENT not in attr:
                        continue
                    if f[CONF_UNIT_OF_MEASUREMENT] != attr[CONF_UNIT_OF_MEASUREMENT]:
                        continue
                try:
                    # Compare against None: a threshold of 0 is a valid bound.
                    if f[CONF_BELOW] is not None and float(state) < f[CONF_BELOW]:
                        _LOGGER.info(
                            "%s: ignoring state '%s', because below '%s'",
                            entity_id,
                            state,
                            f[CONF_BELOW],
                        )
                        return
                    if f[CONF_ABOVE] is not None and float(state) > f[CONF_ABOVE]:
                        _LOGGER.info(
                            "%s: ignoring state '%s', because above '%s'",
                            entity_id,
                            state,
                            f[CONF_ABOVE],
                        )
                        return
                except ValueError:
                    # Non-numeric states are never filtered by thresholds.
                    pass

            entity_id = self._prefixed_entity_id(entity_id)

            # Add local customization data
            if DATA_CUSTOMIZE in self._hass.data:
                attr.update(self._hass.data[DATA_CUSTOMIZE].get(entity_id))

            self._entities.add(entity_id)
            self._hass.states.async_set(entity_id, state, attr)

        def fire_event(message):
            """Publish remote event on local instance."""
            if message["type"] == "result":
                return

            if message["type"] != "event":
                return

            if message["event"]["event_type"] == "state_changed":
                data = message["event"]["data"]
                entity_id = data["entity_id"]
                if not data["new_state"]:
                    # entity was removed in the remote instance
                    # _all_entity_names holds remote ids, _entities local ones.
                    self._all_entity_names.discard(entity_id)
                    entity_id = self._prefixed_entity_id(entity_id)
                    self._entities.discard(entity_id)
                    self._hass.states.async_remove(entity_id)
                    return

                state = data["new_state"]["state"]
                attr = data["new_state"]["attributes"]
                state_changed(entity_id, state, attr)
            else:
                event = message["event"]
                self._hass.bus.async_fire(
                    event_type=event["event_type"],
                    event_data=event["data"],
                    context=Context(
                        id=event["context"].get("id"),
                        user_id=event["context"].get("user_id"),
                        parent_id=event["context"].get("parent_id"),
                    ),
                    origin=EventOrigin.remote,
                )

        def got_states(message):
            """Called when list of remote states is available."""
            for entity in message["result"]:
                entity_id = entity["entity_id"]
                state = entity["state"]
                attributes = entity["attributes"]

                state_changed(entity_id, state, attributes)

        self._remove_listener = self._hass.bus.async_listen(
            EVENT_CALL_SERVICE, forward_event
        )

        for event in self._subscribe_events:
            await self.call(fire_event, "subscribe_events", event_type=event)

        await self.call(got_states, "get_states")

        await self.proxy_services.load()
|