feat(anta)!: Use HTTP HEAD request instead of a TCP connection to check if device is online by dlobato · Pull Request #851 · aristanetworks/anta · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

feat(anta)!: Use HTTP HEAD request instead of a TCP connection to check if device is online #851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 25 additions & 18 deletions anta/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,29 +539,36 @@ async def refresh(self) -> None:
"""Update attributes of an AsyncEOSDevice instance.

This coroutine must update the following attributes of AsyncEOSDevice:
- is_online: When a device IP is reachable and a port can be open
- is_online: When a device eAPI HTTP endpoint is accessible
- established: When a command execution succeeds
- hw_model: The hardware model of the device
"""
logger.debug("Refreshing device %s", self.name)
self.is_online = await self._session.check_connection()
if self.is_online:
show_version = AntaCommand(command="show version")
await self._collect(show_version)
if not show_version.collected:
logger.warning("Cannot get hardware information from device %s", self.name)
else:
self.hw_model = show_version.json_output.get("modelName", None)
if self.hw_model is None:
logger.critical("Cannot parse 'show version' returned by device %s", self.name)
# in some cases it is possible that 'modelName' comes back empty
# and it is nice to get a meaninfule error message
elif self.hw_model == "":
logger.critical("Got an empty 'modelName' in the 'show version' returned by device %s", self.name)
try:
self.is_online = await self._session.check_api_endpoint()
except HTTPError as e:
self.is_online = False
self.established = False
logger.warning("Could not connect to device %s: %s", self.name, e)
return

show_version = AntaCommand(command="show version")
await self._collect(show_version)
if not show_version.collected:
self.established = False
logger.warning("Cannot get hardware information from device %s", self.name)
return

self.hw_model = show_version.json_output.get("modelName", None)
if self.hw_model is None:
self.established = False
logger.critical("Cannot parse 'show version' returned by device %s", self.name)
# in some cases it is possible that 'modelName' comes back empty
elif self.hw_model == "":
self.established = False
logger.critical("Got an empty 'modelName' in the 'show version' returned by device %s", self.name)
else:
logger.warning("Could not connect to device %s: cannot open eAPI port", self.name)

self.established = bool(self.is_online and self.hw_model)
self.established = True

async def copy(self, sources: list[Path], destination: Path, direction: Literal["to", "from"] = "from") -> None:
"""Copy files to and from the device using asyncssh.scp().
Expand Down
21 changes: 20 additions & 1 deletion asynceapi/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# Public Imports
# -----------------------------------------------------------------------------
import httpx
from typing_extensions import deprecated

# -----------------------------------------------------------------------------
# Private Imports
Expand Down Expand Up @@ -51,6 +52,7 @@ class Device(httpx.AsyncClient):
"""

auth = None
EAPI_COMMAND_API_URL = "/command-api"
EAPI_OFMT_OPTIONS = ("json", "text")
EAPI_DEFAULT_OFMT = "json"

Expand Down Expand Up @@ -109,6 +111,7 @@ def __init__(
super().__init__(**kwargs)
self.headers["Content-Type"] = "application/json-rpc"

@deprecated("This method is deprecated, use `Device.check_api_endpoint` method instead. This will be removed in ANTA v2.0.0.", category=DeprecationWarning)
async def check_connection(self) -> bool:
"""Check the target device to ensure that the eAPI port is open and accepting connections.

Expand All @@ -122,6 +125,22 @@ async def check_connection(self) -> bool:
"""
return await port_check_url(self.base_url)

async def check_api_endpoint(self, *, timeout: float = 5) -> bool:
    """Check the target device eAPI HTTP endpoint with a HEAD request.

    It is recommended that a caller checks the endpoint before invoking CLI
    commands, but this step is not required.

    Parameters
    ----------
    timeout
        Timeout in seconds for the HEAD request. Defaults to 5 seconds.

    Returns
    -------
    bool
        True when the device eAPI HTTP endpoint is accessible (2xx status code).
        This method never returns False: every failure is reported by raising.

    Raises
    ------
    httpx.HTTPStatusError
        If the response is not successful, as determined by
        `httpx.Response.raise_for_status()` (e.g. 401 Unauthorized).
    httpx.TransportError
        If the request itself fails (e.g. connection refused or timeout).
        Both exception types derive from `httpx.HTTPError`, so callers can
        catch that single base class.
    """
    # A HEAD request carries no JSON-RPC body, so it probes the eAPI endpoint
    # without sending any command to the device.
    response = await self.head(self.EAPI_COMMAND_API_URL, timeout=timeout)
    # Convert unsuccessful status codes into an exception for the caller.
    response.raise_for_status()
    return True

# Single command, JSON output, no suppression
@overload
async def cli(
Expand Down Expand Up @@ -416,7 +435,7 @@ async def jsonrpc_exec(self, jsonrpc: JsonRpc) -> list[EapiJsonOutput] | list[Ea
The list of command results; either dict or text depending on the
JSON-RPC format parameter.
"""
res = await self.post("/command-api", json=jsonrpc)
res = await self.post(self.EAPI_COMMAND_API_URL, json=jsonrpc)
res.raise_for_status()
body = res.json()

Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
# that can be found in the LICENSE file.
"""See https://docs.pytest.org/en/stable/reference/fixtures.html#conftest-py-sharing-fixtures-across-multiple-files."""

import asyncio
from collections.abc import Iterator
from pathlib import Path
from unittest.mock import AsyncMock, Mock, patch

import pytest
import respx
Expand Down Expand Up @@ -42,7 +40,8 @@ def inventory(request: pytest.FixtureRequest) -> Iterator[AntaInventory]:
)
if reachable:
# This context manager makes all devices reachable
with patch("asyncio.open_connection", AsyncMock(spec=asyncio.open_connection, return_value=(Mock(), Mock()))), respx.mock:
with respx.mock:
respx.head(path="/command-api")
respx.post(path="/command-api", headers={"Content-Type": "application/json-rpc"}, json__params__cmds__0__cmd="show version").respond(
json={
"result": [
Expand All @@ -54,5 +53,6 @@ def inventory(request: pytest.FixtureRequest) -> Iterator[AntaInventory]:
)
yield inv
else:
with patch("asyncio.open_connection", AsyncMock(spec=asyncio.open_connection, side_effect=TimeoutError)):
with respx.mock:
respx.head(path="/command-api").respond(status_code=401)
yield inv
2 changes: 1 addition & 1 deletion tests/units/cli/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def get_output(command: str | dict[str, Any]) -> dict[str, Any]:

# Patch asynceapi methods used by AsyncEOSDevice. See tests/units/test_device.py
with (
patch("asynceapi.device.Device.check_connection", return_value=True),
patch("asynceapi.device.Device.check_api_endpoint", return_value=True),
patch("asynceapi.device.Device.cli", side_effect=cli),
patch("asyncssh.connect"),
patch(
Expand Down
29 changes: 4 additions & 25 deletions tests/units/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,29 +429,8 @@
pytest.param(
{},
(
{"return_value": False},
{
"return_value": {
"mfgName": "Arista",
"modelName": "DCS-7280CR3-32P4-F",
"hardwareRevision": "11.00",
"serialNumber": "JPE19500066",
"systemMacAddress": "fc:bd:67:3d:13:c5",
"hwMacAddress": "fc:bd:67:3d:13:c5",
"configMacAddress": "00:00:00:00:00:00",
"version": "4.31.1F-34361447.fraserrel (engineering build)",
"architecture": "x86_64",
"internalVersion": "4.31.1F-34361447.fraserrel",
"internalBuildId": "4940d112-a2fc-4970-8b5a-a16cd03fd08c",
"imageFormatVersion": "3.0",
"imageOptimization": "Default",
"bootupTimestamp": 1700729434.5892005,
"uptime": 20666.78,
"memTotal": 8099732,
"memFree": 4989568,
"isIntlVersion": False,
}
},
{"side_effect": HTTPError(message="Unauthorized")},
{},
),
{"is_online": False, "established": False, "hw_model": None},
id="is not online",
Expand Down Expand Up @@ -653,9 +632,9 @@ def test__eq(self, device1: dict[str, Any], device2: dict[str, Any], expected: b
)
async def test_refresh(self, async_device: AsyncEOSDevice, patch_kwargs: list[dict[str, Any]], expected: dict[str, Any]) -> None:
"""Test AsyncEOSDevice.refresh()."""
with patch.object(async_device._session, "check_connection", **patch_kwargs[0]), patch.object(async_device._session, "cli", **patch_kwargs[1]):
with patch.object(async_device._session, "check_api_endpoint", **patch_kwargs[0]), patch.object(async_device._session, "cli", **patch_kwargs[1]):
await async_device.refresh()
async_device._session.check_connection.assert_called_once() # type: ignore[attr-defined] # asynceapi.Device.check_connection is patched
async_device._session.check_api_endpoint.assert_called_once() # type: ignore[attr-defined] # asynceapi.Device.check_api_endpoint is patched
if expected["is_online"]:
async_device._session.cli.assert_called_once() # type: ignore[attr-defined] # asynceapi.Device.cli is patched
assert async_device.is_online == expected["is_online"]
Expand Down
0