372 lines
15 KiB
Python
372 lines
15 KiB
Python
|
"""Config flow for Remote Home-Assistant integration."""
|
||
|
import logging
|
||
|
import enum
|
||
|
|
||
|
from urllib.parse import urlparse
|
||
|
|
||
|
import homeassistant.helpers.config_validation as cv
|
||
|
import voluptuous as vol
|
||
|
from homeassistant import config_entries, core
|
||
|
from homeassistant.const import (CONF_ABOVE, CONF_ACCESS_TOKEN, CONF_BELOW,
|
||
|
CONF_ENTITY_ID, CONF_HOST, CONF_PORT,
|
||
|
CONF_UNIT_OF_MEASUREMENT, CONF_VERIFY_SSL, CONF_TYPE)
|
||
|
from homeassistant.core import callback
|
||
|
from homeassistant.helpers.instance_id import async_get
|
||
|
from homeassistant.util import slugify
|
||
|
|
||
|
from . import async_yaml_to_config_entry
|
||
|
from .const import (CONF_ENTITY_PREFIX, # pylint:disable=unused-import
|
||
|
CONF_EXCLUDE_DOMAINS, CONF_EXCLUDE_ENTITIES, CONF_FILTER,
|
||
|
CONF_INCLUDE_DOMAINS, CONF_INCLUDE_ENTITIES,
|
||
|
CONF_LOAD_COMPONENTS, CONF_MAIN, CONF_OPTIONS, CONF_REMOTE, CONF_REMOTE_CONNECTION,
|
||
|
CONF_SECURE, CONF_SERVICE_PREFIX, CONF_SERVICES, CONF_MAX_MSG_SIZE,
|
||
|
CONF_SUBSCRIBE_EVENTS, DOMAIN, REMOTE_ID, DEFAULT_MAX_MSG_SIZE)
|
||
|
from .rest_api import (ApiProblem, CannotConnect, EndpointMissing, InvalidAuth,
|
||
|
UnsupportedVersion, async_get_discovery_info)
|
||
|
|
||
|
_LOGGER = logging.getLogger(__name__)
|
||
|
|
||
|
ADD_NEW_EVENT = "add_new_event"
|
||
|
|
||
|
FILTER_OPTIONS = [CONF_ENTITY_ID, CONF_UNIT_OF_MEASUREMENT, CONF_ABOVE, CONF_BELOW]
|
||
|
|
||
|
|
||
|
def _filter_str(index, filter):
|
||
|
entity_id = filter[CONF_ENTITY_ID]
|
||
|
unit = filter[CONF_UNIT_OF_MEASUREMENT]
|
||
|
above = filter[CONF_ABOVE]
|
||
|
below = filter[CONF_BELOW]
|
||
|
return f"{index+1}. {entity_id}, unit: {unit}, above: {above}, below: {below}"
|
||
|
|
||
|
|
||
|
async def validate_input(hass: core.HomeAssistant, conf):
|
||
|
"""Validate the user input allows us to connect."""
|
||
|
try:
|
||
|
info = await async_get_discovery_info(
|
||
|
hass,
|
||
|
conf[CONF_HOST],
|
||
|
conf[CONF_PORT],
|
||
|
conf.get(CONF_SECURE, False),
|
||
|
conf[CONF_ACCESS_TOKEN],
|
||
|
conf.get(CONF_VERIFY_SSL, False),
|
||
|
)
|
||
|
except OSError:
|
||
|
raise CannotConnect()
|
||
|
|
||
|
return {"title": info["location_name"], "uuid": info["uuid"]}
|
||
|
|
||
|
|
||
|
class InstanceType(enum.Enum):
|
||
|
"""Possible options for instance type."""
|
||
|
|
||
|
remote = "Setup as remote node"
|
||
|
main = "Add a remote"
|
||
|
|
||
|
|
||
|
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||
|
"""Handle a config flow for Remote Home-Assistant."""
|
||
|
|
||
|
VERSION = 1
|
||
|
CONNECTION_CLASS = config_entries.CONN_CLASS_LOCAL_PUSH
|
||
|
|
||
|
def __init__(self):
|
||
|
"""Initialize a new ConfigFlow."""
|
||
|
self.prefill = {CONF_PORT: 8123, CONF_SECURE: True, CONF_MAX_MSG_SIZE: DEFAULT_MAX_MSG_SIZE}
|
||
|
|
||
|
@staticmethod
|
||
|
@callback
|
||
|
def async_get_options_flow(config_entry):
|
||
|
"""Get options flow for this handler."""
|
||
|
return OptionsFlowHandler(config_entry)
|
||
|
|
||
|
async def async_step_user(self, user_input=None):
|
||
|
"""Handle the initial step."""
|
||
|
errors = {}
|
||
|
|
||
|
if user_input is not None:
|
||
|
if user_input[CONF_TYPE] == CONF_REMOTE:
|
||
|
await self.async_set_unique_id(REMOTE_ID)
|
||
|
self._abort_if_unique_id_configured()
|
||
|
return self.async_create_entry(title="Remote instance", data=user_input)
|
||
|
|
||
|
elif user_input[CONF_TYPE] == CONF_MAIN:
|
||
|
return await self.async_step_connection_details()
|
||
|
|
||
|
errors["base"] = "unknown"
|
||
|
|
||
|
return self.async_show_form(
|
||
|
step_id="user",
|
||
|
data_schema=vol.Schema(
|
||
|
{
|
||
|
vol.Required(CONF_TYPE): vol.In([CONF_REMOTE, CONF_MAIN])
|
||
|
}
|
||
|
),
|
||
|
errors=errors,
|
||
|
)
|
||
|
|
||
|
|
||
|
async def async_step_connection_details(self, user_input=None):
|
||
|
"""Handle the connection details step."""
|
||
|
errors = {}
|
||
|
if user_input is not None:
|
||
|
try:
|
||
|
info = await validate_input(self.hass, user_input)
|
||
|
except ApiProblem:
|
||
|
errors["base"] = "api_problem"
|
||
|
except CannotConnect:
|
||
|
errors["base"] = "cannot_connect"
|
||
|
except InvalidAuth:
|
||
|
errors["base"] = "invalid_auth"
|
||
|
except UnsupportedVersion:
|
||
|
errors["base"] = "unsupported_version"
|
||
|
except EndpointMissing:
|
||
|
errors["base"] = "missing_endpoint"
|
||
|
except Exception: # pylint: disable=broad-except
|
||
|
_LOGGER.exception("Unexpected exception")
|
||
|
errors["base"] = "unknown"
|
||
|
else:
|
||
|
await self.async_set_unique_id(info["uuid"])
|
||
|
self._abort_if_unique_id_configured()
|
||
|
return self.async_create_entry(title=info["title"], data=user_input)
|
||
|
|
||
|
user_input = user_input or dict()
|
||
|
host = user_input.get(CONF_HOST, self.prefill.get(CONF_HOST) or vol.UNDEFINED)
|
||
|
port = user_input.get(CONF_PORT, self.prefill.get(CONF_PORT) or vol.UNDEFINED)
|
||
|
secure = user_input.get(CONF_SECURE, self.prefill.get(CONF_SECURE) or vol.UNDEFINED)
|
||
|
max_msg_size = user_input.get(CONF_MAX_MSG_SIZE, self.prefill.get(CONF_MAX_MSG_SIZE) or vol.UNDEFINED)
|
||
|
return self.async_show_form(
|
||
|
step_id="connection_details",
|
||
|
data_schema=vol.Schema(
|
||
|
{
|
||
|
vol.Required(CONF_HOST, default=host): str,
|
||
|
vol.Required(CONF_PORT, default=port): int,
|
||
|
vol.Required(CONF_ACCESS_TOKEN, default=user_input.get(CONF_ACCESS_TOKEN, vol.UNDEFINED)): str,
|
||
|
vol.Required(CONF_MAX_MSG_SIZE, default=max_msg_size): int,
|
||
|
vol.Optional(CONF_SECURE, default=secure): bool,
|
||
|
vol.Optional(CONF_VERIFY_SSL, default=user_input.get(CONF_VERIFY_SSL, True)): bool,
|
||
|
}
|
||
|
),
|
||
|
errors=errors,
|
||
|
)
|
||
|
|
||
|
async def async_step_zeroconf(self, info):
|
||
|
"""Handle instance discovered via zeroconf."""
|
||
|
properties = info.properties
|
||
|
port = info.port
|
||
|
uuid = properties["uuid"]
|
||
|
|
||
|
await self.async_set_unique_id(uuid)
|
||
|
self._abort_if_unique_id_configured()
|
||
|
|
||
|
if await async_get(self.hass) == uuid:
|
||
|
return self.async_abort(reason="already_configured")
|
||
|
|
||
|
url = properties.get("internal_url")
|
||
|
if not url:
|
||
|
url = properties.get("base_url")
|
||
|
url = urlparse(url)
|
||
|
|
||
|
self.prefill = {
|
||
|
CONF_HOST: url.hostname,
|
||
|
CONF_PORT: port,
|
||
|
CONF_SECURE: url.scheme == "https",
|
||
|
}
|
||
|
|
||
|
# pylint: disable=no-member # https://github.com/PyCQA/pylint/issues/3167
|
||
|
self.context["identifier"] = self.unique_id
|
||
|
self.context["title_placeholders"] = {"name": properties["location_name"]}
|
||
|
return await self.async_step_connection_details()
|
||
|
|
||
|
async def async_step_import(self, user_input):
|
||
|
"""Handle import from YAML."""
|
||
|
try:
|
||
|
info = await validate_input(self.hass, user_input)
|
||
|
except Exception:
|
||
|
_LOGGER.exception(f"import of {user_input[CONF_HOST]} failed")
|
||
|
return self.async_abort(reason="import_failed")
|
||
|
|
||
|
conf, options = async_yaml_to_config_entry(user_input)
|
||
|
|
||
|
# Options cannot be set here, so store them in a special key and import them
|
||
|
# before setting up an entry
|
||
|
conf[CONF_OPTIONS] = options
|
||
|
|
||
|
await self.async_set_unique_id(info["uuid"])
|
||
|
self._abort_if_unique_id_configured(updates=conf)
|
||
|
|
||
|
return self.async_create_entry(title=f"{info['title']} (YAML)", data=conf)
|
||
|
|
||
|
|
||
|
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||
|
"""Handle options flow for the Home Assistant remote integration."""
|
||
|
|
||
|
def __init__(self, config_entry):
|
||
|
"""Initialize remote_homeassistant options flow."""
|
||
|
self.config_entry = config_entry
|
||
|
self.filters = None
|
||
|
self.events = None
|
||
|
self.options = None
|
||
|
|
||
|
async def async_step_init(self, user_input=None):
|
||
|
"""Manage basic options."""
|
||
|
if self.config_entry.unique_id == REMOTE_ID:
|
||
|
return self.async_abort(reason="not_supported")
|
||
|
|
||
|
if user_input is not None:
|
||
|
self.options = user_input.copy()
|
||
|
return await self.async_step_domain_entity_filters()
|
||
|
|
||
|
domains, _ = self._domains_and_entities()
|
||
|
domains = set(domains + self.config_entry.options.get(CONF_LOAD_COMPONENTS, []))
|
||
|
|
||
|
remote = self.hass.data[DOMAIN][self.config_entry.entry_id][
|
||
|
CONF_REMOTE_CONNECTION
|
||
|
]
|
||
|
|
||
|
return self.async_show_form(
|
||
|
step_id="init",
|
||
|
data_schema=vol.Schema(
|
||
|
{
|
||
|
vol.Optional(
|
||
|
CONF_ENTITY_PREFIX,
|
||
|
description={
|
||
|
"suggested_value": self.config_entry.options.get(
|
||
|
CONF_ENTITY_PREFIX
|
||
|
)
|
||
|
},
|
||
|
): str,
|
||
|
vol.Optional(
|
||
|
CONF_LOAD_COMPONENTS,
|
||
|
default=self._default(CONF_LOAD_COMPONENTS),
|
||
|
): cv.multi_select(sorted(domains)),
|
||
|
vol.Required(
|
||
|
CONF_SERVICE_PREFIX, default=self.config_entry.options.get(CONF_SERVICE_PREFIX) or slugify(self.config_entry.title)
|
||
|
): str,
|
||
|
vol.Optional(
|
||
|
CONF_SERVICES,
|
||
|
default=self._default(CONF_SERVICES),
|
||
|
): cv.multi_select(remote.proxy_services.services),
|
||
|
}
|
||
|
),
|
||
|
)
|
||
|
|
||
|
async def async_step_domain_entity_filters(self, user_input=None):
|
||
|
"""Manage domain and entity filters."""
|
||
|
if user_input is not None:
|
||
|
self.options.update(user_input)
|
||
|
return await self.async_step_general_filters()
|
||
|
|
||
|
domains, entities = self._domains_and_entities()
|
||
|
return self.async_show_form(
|
||
|
step_id="domain_entity_filters",
|
||
|
data_schema=vol.Schema(
|
||
|
{
|
||
|
vol.Optional(
|
||
|
CONF_INCLUDE_DOMAINS,
|
||
|
default=self._default(CONF_INCLUDE_DOMAINS),
|
||
|
): cv.multi_select(domains),
|
||
|
vol.Optional(
|
||
|
CONF_INCLUDE_ENTITIES,
|
||
|
default=self._default(CONF_INCLUDE_ENTITIES),
|
||
|
): cv.multi_select(entities),
|
||
|
vol.Optional(
|
||
|
CONF_EXCLUDE_DOMAINS,
|
||
|
default=self._default(CONF_EXCLUDE_DOMAINS),
|
||
|
): cv.multi_select(domains),
|
||
|
vol.Optional(
|
||
|
CONF_EXCLUDE_ENTITIES,
|
||
|
default=self._default(CONF_EXCLUDE_ENTITIES),
|
||
|
): cv.multi_select(entities),
|
||
|
}
|
||
|
),
|
||
|
)
|
||
|
|
||
|
async def async_step_general_filters(self, user_input=None):
|
||
|
"""Manage domain and entity filters."""
|
||
|
if user_input is not None:
|
||
|
# Continue to next step if entity id is not specified
|
||
|
if CONF_ENTITY_ID not in user_input:
|
||
|
# Each filter string is prefixed with a number (index in self.filter+1).
|
||
|
# Extract all of them and build the final filter list.
|
||
|
selected_indices = [
|
||
|
int(filter.split(".")[0]) - 1
|
||
|
for filter in user_input.get(CONF_FILTER, [])
|
||
|
]
|
||
|
self.options[CONF_FILTER] = [self.filters[i] for i in selected_indices]
|
||
|
return await self.async_step_events()
|
||
|
|
||
|
selected = user_input.get(CONF_FILTER, [])
|
||
|
new_filter = {conf: user_input.get(conf) for conf in FILTER_OPTIONS}
|
||
|
selected.append(_filter_str(len(self.filters), new_filter))
|
||
|
self.filters.append(new_filter)
|
||
|
else:
|
||
|
self.filters = self.config_entry.options.get(CONF_FILTER, [])
|
||
|
selected = [_filter_str(i, filter) for i, filter in enumerate(self.filters)]
|
||
|
|
||
|
strings = [_filter_str(i, filter) for i, filter in enumerate(self.filters)]
|
||
|
return self.async_show_form(
|
||
|
step_id="general_filters",
|
||
|
data_schema=vol.Schema(
|
||
|
{
|
||
|
vol.Optional(CONF_FILTER, default=selected): cv.multi_select(
|
||
|
strings
|
||
|
),
|
||
|
vol.Optional(CONF_ENTITY_ID): str,
|
||
|
vol.Optional(CONF_UNIT_OF_MEASUREMENT): str,
|
||
|
vol.Optional(CONF_ABOVE): vol.Coerce(float),
|
||
|
vol.Optional(CONF_BELOW): vol.Coerce(float),
|
||
|
}
|
||
|
),
|
||
|
)
|
||
|
|
||
|
async def async_step_events(self, user_input=None):
|
||
|
"""Manage event options."""
|
||
|
if user_input is not None:
|
||
|
if ADD_NEW_EVENT not in user_input:
|
||
|
self.options[CONF_SUBSCRIBE_EVENTS] = user_input.get(
|
||
|
CONF_SUBSCRIBE_EVENTS, []
|
||
|
)
|
||
|
return self.async_create_entry(title="", data=self.options)
|
||
|
|
||
|
selected = user_input.get(CONF_SUBSCRIBE_EVENTS, [])
|
||
|
self.events.add(user_input[ADD_NEW_EVENT])
|
||
|
selected.append(user_input[ADD_NEW_EVENT])
|
||
|
else:
|
||
|
self.events = set(
|
||
|
self.config_entry.options.get(CONF_SUBSCRIBE_EVENTS) or []
|
||
|
)
|
||
|
selected = self._default(CONF_SUBSCRIBE_EVENTS)
|
||
|
|
||
|
return self.async_show_form(
|
||
|
step_id="events",
|
||
|
data_schema=vol.Schema(
|
||
|
{
|
||
|
vol.Optional(
|
||
|
CONF_SUBSCRIBE_EVENTS, default=selected
|
||
|
): cv.multi_select(self.events),
|
||
|
vol.Optional(ADD_NEW_EVENT): str,
|
||
|
}
|
||
|
),
|
||
|
)
|
||
|
|
||
|
def _default(self, conf):
|
||
|
"""Return default value for an option."""
|
||
|
return self.config_entry.options.get(conf) or vol.UNDEFINED
|
||
|
|
||
|
def _domains_and_entities(self):
|
||
|
"""Return all entities and domains exposed by remote instance."""
|
||
|
remote = self.hass.data[DOMAIN][self.config_entry.entry_id][
|
||
|
CONF_REMOTE_CONNECTION
|
||
|
]
|
||
|
|
||
|
# Include entities we have in the config explicitly, otherwise they will be
|
||
|
# pre-selected and not possible to remove if they are no lobger present on
|
||
|
# the remote host.
|
||
|
include_entities = set(self.config_entry.options.get(CONF_INCLUDE_ENTITIES, []))
|
||
|
exclude_entities = set(self.config_entry.options.get(CONF_EXCLUDE_ENTITIES, []))
|
||
|
entities = sorted(
|
||
|
remote._all_entity_names | include_entities | exclude_entities
|
||
|
)
|
||
|
domains = sorted(set([entity_id.split(".")[0] for entity_id in entities]))
|
||
|
return domains, entities
|