""" Connect two Home Assistant instances via the Websocket API. For more details about this component, please refer to the documentation at https://home-assistant.io/components/remote_homeassistant/ """ from __future__ import annotations import asyncio from typing import Optional import copy import fnmatch import inspect import logging import re from contextlib import suppress import aiohttp from aiohttp import ClientWebSocketResponse import homeassistant.components.websocket_api.auth as api import homeassistant.helpers.config_validation as cv import voluptuous as vol try: from homeassistant.core_config import DATA_CUSTOMIZE except (ModuleNotFoundError, ImportError): # hass 2024.10 or older from homeassistant.config import DATA_CUSTOMIZE from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.const import (CONF_ABOVE, CONF_ACCESS_TOKEN, CONF_BELOW, CONF_DOMAINS, CONF_ENTITIES, CONF_ENTITY_ID, CONF_EXCLUDE, CONF_HOST, CONF_INCLUDE, CONF_PORT, CONF_UNIT_OF_MEASUREMENT, CONF_VERIFY_SSL, EVENT_CALL_SERVICE, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, SERVICE_RELOAD) from homeassistant.core import (Context, EventOrigin, HomeAssistant, callback, split_entity_id) from homeassistant.helpers import device_registry as dr from homeassistant.helpers import entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.reload import async_integration_yaml_config from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.typing import ConfigType from homeassistant.setup import async_setup_component from custom_components.remote_homeassistant.views import DiscoveryInfoView from .const import (CONF_EXCLUDE_DOMAINS, CONF_EXCLUDE_ENTITIES, CONF_INCLUDE_DOMAINS, CONF_INCLUDE_ENTITIES, CONF_LOAD_COMPONENTS, CONF_OPTIONS, CONF_REMOTE_CONNECTION, CONF_SERVICE_PREFIX, CONF_SERVICES, CONF_UNSUB_LISTENER, DOMAIN, REMOTE_ID, DEFAULT_MAX_MSG_SIZE) from .proxy_services import ProxyServices from .rest_api import UnsupportedVersion, async_get_discovery_info _LOGGER = logging.getLogger(__name__) PLATFORMS = ["sensor"] CONF_INSTANCES = "instances" CONF_SECURE = "secure" CONF_SUBSCRIBE_EVENTS = "subscribe_events" CONF_ENTITY_PREFIX = "entity_prefix" CONF_ENTITY_FRIENDLY_NAME_PREFIX = "entity_friendly_name_prefix" CONF_FILTER = "filter" CONF_MAX_MSG_SIZE = "max_message_size" STATE_INIT = "initializing" STATE_CONNECTING = "connecting" STATE_CONNECTED = "connected" STATE_AUTH_INVALID = "auth_invalid" STATE_AUTH_REQUIRED = "auth_required" STATE_RECONNECTING = "reconnecting" STATE_DISCONNECTED = "disconnected" DEFAULT_ENTITY_PREFIX = "" DEFAULT_ENTITY_FRIENDLY_NAME_PREFIX = "" INSTANCES_SCHEMA = vol.Schema( { vol.Required(CONF_HOST): cv.string, vol.Optional(CONF_PORT, default=8123): cv.port, vol.Optional(CONF_SECURE, default=False): cv.boolean, vol.Optional(CONF_VERIFY_SSL, default=True): cv.boolean, vol.Required(CONF_ACCESS_TOKEN): cv.string, vol.Optional(CONF_MAX_MSG_SIZE, default=DEFAULT_MAX_MSG_SIZE): vol.Coerce(int), vol.Optional(CONF_EXCLUDE, default={}): vol.Schema( { vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, vol.Optional(CONF_DOMAINS, default=[]): vol.All( cv.ensure_list, [cv.string] ), } ), vol.Optional(CONF_INCLUDE, default={}): vol.Schema( { vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, vol.Optional(CONF_DOMAINS, default=[]): vol.All( cv.ensure_list, [cv.string] ), } ), vol.Optional(CONF_FILTER, default=[]): vol.All( cv.ensure_list, [ vol.Schema( { vol.Optional(CONF_ENTITY_ID): cv.string, vol.Optional(CONF_UNIT_OF_MEASUREMENT): cv.string, vol.Optional(CONF_ABOVE): vol.Coerce(float), vol.Optional(CONF_BELOW): vol.Coerce(float), } ) ], ), vol.Optional(CONF_SUBSCRIBE_EVENTS): cv.ensure_list, vol.Optional(CONF_ENTITY_PREFIX, default=DEFAULT_ENTITY_PREFIX): cv.string, vol.Optional(CONF_ENTITY_FRIENDLY_NAME_PREFIX, default=DEFAULT_ENTITY_FRIENDLY_NAME_PREFIX): cv.string, vol.Optional(CONF_LOAD_COMPONENTS): cv.ensure_list, vol.Required(CONF_SERVICE_PREFIX, default="remote_"): cv.string, vol.Optional(CONF_SERVICES): cv.ensure_list, } ) CONFIG_SCHEMA = vol.Schema( { DOMAIN: vol.Schema( { vol.Required(CONF_INSTANCES): vol.All( cv.ensure_list, [INSTANCES_SCHEMA] ), } ), }, extra=vol.ALLOW_EXTRA, ) HEARTBEAT_INTERVAL = 20 HEARTBEAT_TIMEOUT = 5 INTERNALLY_USED_EVENTS = [EVENT_STATE_CHANGED] def async_yaml_to_config_entry(instance_conf): """Convert YAML config into data and options used by a config entry.""" conf = instance_conf.copy() options = {} if CONF_INCLUDE in conf: include = conf.pop(CONF_INCLUDE) if CONF_ENTITIES in include: options[CONF_INCLUDE_ENTITIES] = include[CONF_ENTITIES] if CONF_DOMAINS in include: options[CONF_INCLUDE_DOMAINS] = include[CONF_DOMAINS] if CONF_EXCLUDE in conf: exclude = conf.pop(CONF_EXCLUDE) if CONF_ENTITIES in exclude: options[CONF_EXCLUDE_ENTITIES] = exclude[CONF_ENTITIES] if CONF_DOMAINS in exclude: options[CONF_EXCLUDE_DOMAINS] = exclude[CONF_DOMAINS] for option in [ CONF_FILTER, CONF_SUBSCRIBE_EVENTS, CONF_ENTITY_PREFIX, CONF_ENTITY_FRIENDLY_NAME_PREFIX, CONF_LOAD_COMPONENTS, CONF_SERVICE_PREFIX, CONF_SERVICES, ]: if option in conf: options[option] = conf.pop(option) return conf, options async def _async_update_config_entry_if_from_yaml(hass, entries_by_id, conf): """Update a config entry with the latest yaml.""" try: info = await async_get_discovery_info( hass, conf[CONF_HOST], conf[CONF_PORT], conf[CONF_SECURE], conf[CONF_ACCESS_TOKEN], conf[CONF_VERIFY_SSL], ) except Exception: _LOGGER.exception(f"reload of {conf[CONF_HOST]} failed") else: entry = entries_by_id.get(info["uuid"]) if entry: data, options = async_yaml_to_config_entry(conf) hass.config_entries.async_update_entry(entry, data=data, options=options) async def setup_remote_instance(hass: HomeAssistant.core.HomeAssistant): hass.http.register_view(DiscoveryInfoView()) async def async_setup(hass: HomeAssistant.core.HomeAssistant, config: ConfigType): """Set up the remote_homeassistant component.""" hass.data.setdefault(DOMAIN, {}) async def _handle_reload(service): """Handle reload service call.""" config = await async_integration_yaml_config(hass, DOMAIN) if not config or DOMAIN not in config: return current_entries = hass.config_entries.async_entries(DOMAIN) entries_by_id = {entry.unique_id: entry for entry in current_entries} instances = config[DOMAIN][CONF_INSTANCES] update_tasks = [ _async_update_config_entry_if_from_yaml(hass, entries_by_id, instance) for instance in instances ] await asyncio.gather(*update_tasks) hass.async_create_task(setup_remote_instance(hass)) async_register_admin_service(hass, DOMAIN, SERVICE_RELOAD, _handle_reload, ) instances = config.get(DOMAIN, {}).get(CONF_INSTANCES, []) for instance in instances: hass.async_create_task( hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT}, data=instance ) ) return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): """Set up Remote Home-Assistant from a config entry.""" _async_import_options_from_yaml(hass, entry) if entry.unique_id == REMOTE_ID: hass.async_create_task(setup_remote_instance(hass)) return True else: remote = RemoteConnection(hass, entry) hass.data[DOMAIN][entry.entry_id] = { CONF_REMOTE_CONNECTION: remote, CONF_UNSUB_LISTENER: entry.add_update_listener(_update_listener), } async def setup_components_and_platforms(): """Set up platforms and initiate connection.""" for domain in entry.options.get(CONF_LOAD_COMPONENTS, []): hass.async_create_task(async_setup_component(hass, domain, {})) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await remote.async_connect() hass.async_create_task(setup_components_and_platforms()) return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): """Unload a config entry.""" unload_ok = all( await asyncio.gather( *[ hass.config_entries.async_forward_entry_unload(entry, platform) for platform in PLATFORMS ] ) ) if unload_ok: data = hass.data[DOMAIN].pop(entry.entry_id) await data[CONF_REMOTE_CONNECTION].async_stop() data[CONF_UNSUB_LISTENER]() return unload_ok @callback def _async_import_options_from_yaml(hass: HomeAssistant, entry: ConfigEntry): """Import options from YAML into options section of config entry.""" if CONF_OPTIONS in entry.data: data = entry.data.copy() options = data.pop(CONF_OPTIONS) hass.config_entries.async_update_entry(entry, data=data, options=options) async def _update_listener(hass, config_entry): """Update listener.""" await hass.config_entries.async_reload(config_entry.entry_id) class RemoteConnection: """A Websocket connection to a remote home-assistant instance.""" def __init__(self, hass, config_entry): """Initialize the connection.""" self._hass = hass self._entry = config_entry self._secure = config_entry.data.get(CONF_SECURE, False) self._verify_ssl = config_entry.data.get(CONF_VERIFY_SSL, False) self._access_token = config_entry.data.get(CONF_ACCESS_TOKEN) self._max_msg_size = config_entry.data.get(CONF_MAX_MSG_SIZE, DEFAULT_MAX_MSG_SIZE) # see homeassistant/components/influxdb/__init__.py # for include/exclude logic self._whitelist_e = set(config_entry.options.get(CONF_INCLUDE_ENTITIES, [])) self._whitelist_d = set(config_entry.options.get(CONF_INCLUDE_DOMAINS, [])) self._blacklist_e = set(config_entry.options.get(CONF_EXCLUDE_ENTITIES, [])) self._blacklist_d = set(config_entry.options.get(CONF_EXCLUDE_DOMAINS, [])) self._filter = [ { CONF_ENTITY_ID: re.compile(fnmatch.translate(f.get(CONF_ENTITY_ID))) if f.get(CONF_ENTITY_ID) else None, CONF_UNIT_OF_MEASUREMENT: f.get(CONF_UNIT_OF_MEASUREMENT), CONF_ABOVE: f.get(CONF_ABOVE), CONF_BELOW: f.get(CONF_BELOW), } for f in config_entry.options.get(CONF_FILTER, []) ] self._subscribe_events = set( config_entry.options.get(CONF_SUBSCRIBE_EVENTS, []) + INTERNALLY_USED_EVENTS ) self._entity_prefix = config_entry.options.get( CONF_ENTITY_PREFIX, "") self._entity_friendly_name_prefix = config_entry.options.get( CONF_ENTITY_FRIENDLY_NAME_PREFIX, "") self._connection : Optional[ClientWebSocketResponse] = None self._heartbeat_task = None self._is_stopping = False self._entities = set() self._all_entity_names = set() self._handlers = {} self._remove_listener = None self.proxy_services = ProxyServices(hass, config_entry, self) self.set_connection_state(STATE_CONNECTING) self.__id = 1 def _prefixed_entity_id(self, entity_id): if self._entity_prefix: domain, object_id = split_entity_id(entity_id) object_id = self._entity_prefix + object_id entity_id = domain + "." + object_id return entity_id return entity_id def _prefixed_entity_friendly_name(self, entity_friendly_name): if (self._entity_friendly_name_prefix and entity_friendly_name.startswith(self._entity_friendly_name_prefix) == False): entity_friendly_name = (self._entity_friendly_name_prefix + entity_friendly_name) return entity_friendly_name return entity_friendly_name def _full_picture_url(self, url): baseURL = "%s://%s:%s" % ( "https" if self._secure else "http", self._entry.data[CONF_HOST], self._entry.data[CONF_PORT], ) if url.startswith(baseURL) == False: url = baseURL + url return url return url def set_connection_state(self, state): """Change current connection state.""" signal = f"remote_homeassistant_{self._entry.unique_id}" async_dispatcher_send(self._hass, signal, state) @callback def _get_url(self): """Get url to connect to.""" return "%s://%s:%s/api/websocket" % ( "wss" if self._secure else "ws", self._entry.data[CONF_HOST], self._entry.data[CONF_PORT], ) async def async_connect(self): """Connect to remote home-assistant websocket...""" async def _async_stop_handler(event): """Stop when Home Assistant is shutting down.""" await self.async_stop() async def _async_instance_get_info(): """Fetch discovery info from remote instance.""" try: return await async_get_discovery_info( self._hass, self._entry.data[CONF_HOST], self._entry.data[CONF_PORT], self._secure, self._access_token, self._verify_ssl, ) except OSError: _LOGGER.exception("failed to connect") except UnsupportedVersion: _LOGGER.error("Unsupported version, at least 0.111 is required.") except Exception: _LOGGER.exception("failed to fetch instance info") return None @callback def _async_instance_id_match(info): """Verify if remote instance id matches the expected id.""" if not info: return False if info and info["uuid"] != self._entry.unique_id: _LOGGER.error( "instance id not matching: %s != %s", info["uuid"], self._entry.unique_id, ) return False return True url = self._get_url() session = async_get_clientsession(self._hass, self._verify_ssl) self.set_connection_state(STATE_CONNECTING) while True: info = await _async_instance_get_info() # Verify we are talking to correct instance if not _async_instance_id_match(info): self.set_connection_state(STATE_RECONNECTING) await asyncio.sleep(10) continue try: _LOGGER.info("Connecting to %s", url) self._connection = await session.ws_connect(url, max_msg_size = self._max_msg_size) except aiohttp.client_exceptions.ClientError: _LOGGER.error("Could not connect to %s, retry in 10 seconds...", url) self.set_connection_state(STATE_RECONNECTING) await asyncio.sleep(10) else: _LOGGER.info("Connected to home-assistant websocket at %s", url) break self._hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_stop_handler) device_registry = dr.async_get(self._hass) device_registry.async_get_or_create( config_entry_id=self._entry.entry_id, identifiers={(DOMAIN, f"remote_{self._entry.unique_id}")}, name=info.get("location_name"), manufacturer="Home Assistant", model=info.get("installation_type"), sw_version=info.get("ha_version"), ) asyncio.ensure_future(self._recv()) self._heartbeat_task = self._hass.loop.create_task(self._heartbeat_loop()) async def _heartbeat_loop(self): """Send periodic heartbeats to remote instance.""" while self._connection is not None and not self._connection.closed: await asyncio.sleep(HEARTBEAT_INTERVAL) _LOGGER.debug("Sending ping") event = asyncio.Event() def resp(message): _LOGGER.debug("Got pong: %s", message) event.set() await self.call(resp, "ping") try: await asyncio.wait_for(event.wait(), HEARTBEAT_TIMEOUT) except asyncio.TimeoutError: _LOGGER.warning("heartbeat failed") # Schedule closing on event loop to avoid deadlock asyncio.ensure_future(self._connection.close()) break async def async_stop(self): """Close connection.""" self._is_stopping = True if self._connection is not None: await self._connection.close() await self.proxy_services.unload() def _next_id(self): _id = self.__id self.__id += 1 return _id async def call(self, handler, message_type, **extra_args) -> None: if self._connection is None: _LOGGER.error("No remote websocket connection") return _id = self._next_id() self._handlers[_id] = handler try: await self._connection.send_json( {"id": _id, "type": message_type, **extra_args} ) except aiohttp.client_exceptions.ClientError as err: _LOGGER.error("remote websocket connection closed: %s", err) await self._disconnected() async def _disconnected(self): # Remove all published entries for entity in self._entities: self._hass.states.async_remove(entity) if self._heartbeat_task is not None: self._heartbeat_task.cancel() try: await self._heartbeat_task except asyncio.CancelledError: pass if self._remove_listener is not None: self._remove_listener() self.set_connection_state(STATE_DISCONNECTED) self._heartbeat_task = None self._remove_listener = None self._entities = set() self._all_entity_names = set() if not self._is_stopping: asyncio.ensure_future(self.async_connect()) async def _recv(self): while self._connection is not None and not self._connection.closed: try: data = await self._connection.receive() except aiohttp.client_exceptions.ClientError as err: _LOGGER.error("remote websocket connection closed: %s", err) break if not data: break if data.type in ( aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, ): _LOGGER.debug("websocket connection is closing") break if data.type == aiohttp.WSMsgType.ERROR: _LOGGER.error("websocket connection had an error") if data.data.code == aiohttp.WSCloseCode.MESSAGE_TOO_BIG: _LOGGER.error(f"please consider increasing message size with `{CONF_MAX_MSG_SIZE}`") break try: message = data.json() except TypeError as err: _LOGGER.error("could not decode data (%s) as json: %s", data, err) break if message is None: break _LOGGER.debug("received: %s", message) if message["type"] == api.TYPE_AUTH_OK: self.set_connection_state(STATE_CONNECTED) await self._init() elif message["type"] == api.TYPE_AUTH_REQUIRED: if self._access_token: json_data = {"type": api.TYPE_AUTH, "access_token": self._access_token} else: _LOGGER.error("Access token required, but not provided") self.set_connection_state(STATE_AUTH_REQUIRED) return try: await self._connection.send_json(json_data) except Exception as err: _LOGGER.error("could not send data to remote connection: %s", err) break elif message["type"] == api.TYPE_AUTH_INVALID: _LOGGER.error("Auth invalid, check your access token") self.set_connection_state(STATE_AUTH_INVALID) await self._connection.close() return else: handler = self._handlers.get(message["id"]) if handler is not None: if inspect.iscoroutinefunction(handler): await handler(message) else: handler(message) await self._disconnected() async def _init(self): async def forward_event(event): """Send local event to remote instance. The affected entity_id has to originate from that remote instance, otherwise the event is discarded. """ event_data = event.data service_data = event_data["service_data"] if not service_data: return entity_ids = service_data.get("entity_id", None) if not entity_ids: return if isinstance(entity_ids, str): entity_ids = (entity_ids.lower(),) entities = {entity_id.lower() for entity_id in self._entities} entity_ids = entities.intersection(entity_ids) if not entity_ids: return if self._entity_prefix: def _remove_prefix(entity_id): domain, object_id = split_entity_id(entity_id) object_id = object_id.replace(self._entity_prefix.lower(), "", 1) return domain + "." + object_id entity_ids = {_remove_prefix(entity_id) for entity_id in entity_ids} event_data = copy.deepcopy(event_data) event_data["service_data"]["entity_id"] = list(entity_ids) # Remove service_call_id parameter - websocket API # doesn't accept that one event_data.pop("service_call_id", None) _id = self._next_id() data = {"id": _id, "type": event.event_type, **event_data} _LOGGER.debug("forward event: %s", data) if self._connection is None: _LOGGER.error("There is no remote connecion to send send data to") return try: await self._connection.send_json(data) except Exception as err: _LOGGER.error("could not send data to remote connection: %s", err) await self._disconnected() def state_changed(entity_id, state, attr): """Publish remote state change on local instance.""" domain, _object_id = split_entity_id(entity_id) self._all_entity_names.add(entity_id) if entity_id in self._blacklist_e or domain in self._blacklist_d: return if ( (self._whitelist_e or self._whitelist_d) and entity_id not in self._whitelist_e and domain not in self._whitelist_d ): return for f in self._filter: if f[CONF_ENTITY_ID] and not f[CONF_ENTITY_ID].match(entity_id): continue if f[CONF_UNIT_OF_MEASUREMENT]: if CONF_UNIT_OF_MEASUREMENT not in attr: continue if f[CONF_UNIT_OF_MEASUREMENT] != attr[CONF_UNIT_OF_MEASUREMENT]: continue try: if f[CONF_BELOW] and float(state) < f[CONF_BELOW]: _LOGGER.info( "%s: ignoring state '%s', because below '%s'", entity_id, state, f[CONF_BELOW], ) return if f[CONF_ABOVE] and float(state) > f[CONF_ABOVE]: _LOGGER.info( "%s: ignoring state '%s', because above '%s'", entity_id, state, f[CONF_ABOVE], ) return except ValueError: pass entity_id = self._prefixed_entity_id(entity_id) # Add local unique id domain, object_id = split_entity_id(entity_id) attr['unique_id'] = f"{self._entry.unique_id[:16]}_{entity_id}" entity_registry = er.async_get(self._hass) entity_registry.async_get_or_create( domain=domain, platform='remote_homeassistant', unique_id=attr['unique_id'], suggested_object_id=object_id, ) # Add local customization data if DATA_CUSTOMIZE in self._hass.data: attr.update(self._hass.data[DATA_CUSTOMIZE].get(entity_id)) for attrId, value in attr.items(): if attrId == "friendly_name": attr[attrId] = self._prefixed_entity_friendly_name(value) if attrId == "entity_picture": attr[attrId] = self._full_picture_url(value) self._entities.add(entity_id) self._hass.states.async_set(entity_id, state, attr) def fire_event(message): """Publish remote event on local instance.""" if message["type"] == "result": return if message["type"] != "event": return if message["event"]["event_type"] == "state_changed": data = message["event"]["data"] entity_id = data["entity_id"] if not data["new_state"]: entity_id = self._prefixed_entity_id(entity_id) # entity was removed in the remote instance with suppress(ValueError, AttributeError, KeyError): self._entities.remove(entity_id) with suppress(ValueError, AttributeError, KeyError): self._all_entity_names.remove(entity_id) self._hass.states.async_remove(entity_id) return state = data["new_state"]["state"] attr = data["new_state"]["attributes"] state_changed(entity_id, state, attr) else: event = message["event"] self._hass.bus.async_fire( event_type=event["event_type"], event_data=event["data"], context=Context( id=event["context"].get("id"), user_id=event["context"].get("user_id"), parent_id=event["context"].get("parent_id"), ), origin=EventOrigin.remote, ) def got_states(message): """Called when list of remote states is available.""" for entity in message["result"]: entity_id = entity["entity_id"] state = entity["state"] attributes = entity["attributes"] for attr, value in attributes.items(): if attr == "friendly_name": attributes[attr] = self._prefixed_entity_friendly_name(value) if attr == "entity_picture": attributes[attr] = self._full_picture_url(value) state_changed(entity_id, state, attributes) self._remove_listener = self._hass.bus.async_listen( EVENT_CALL_SERVICE, forward_event ) for event in self._subscribe_events: await self.call(fire_event, "subscribe_events", event_type=event) await self.call(got_states, "get_states") await self.proxy_services.load()