Fuzz the HTTP endpoint, make sure it never hangs. by slongfield · Pull Request #44 · cceckman/http-accel · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

Fuzz the HTTP endpoint, make sure it never hangs. #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions http_server/simple_http_fuzz_examples_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
import string

from amaranth.sim import Simulator

from .simple_led_http import SimpleLedHttp
from stream_fixtures import StreamCollector
from hypothesis import given, strategies as st, settings, Phase, Verbosity
from hypothesis.errors import InvalidArgument

def run_fuzz_http_request(method, path, headers, body):
    """
    Same test as in simple_http_fuzz_test, but for manually re-running failures.

    Builds an HTTP/1.0 request from the given pieces, streams it into a fresh
    SimpleLedHttp one byte per accepted transfer, and checks that at least one
    response byte was collected. A VCD trace of the run is written to stdout.

    Parameters
    ----------
    method: str
        Request method, e.g. "GET".
    path: str
        Request target, e.g. "/led".
    headers: dict
        Maps header names to header values.
    body: str
        Request body; may be empty.

    Raises
    ------
    AssertionError
        If outbound data is still flagged valid once the outbound session has
        closed, or if nothing was collected from the outbound stream.
    """
    dut = SimpleLedHttp()
    sim = Simulator(dut)
    sim.add_clock(1e-6)

    header = "".join(f"{k} : {v}\r\n" for k, v in headers.items())

    request = f"{method} {path} HTTP/1.0\r\n{header}\r\n\r\n{body}"
    sys.stderr.write(f"Testing with {request}")

    # Drive the wire encoding of the request. ord() on a str yields code
    # points, which do not fit in one byte for any character above U+00FF.
    wire_bytes = request.encode("utf-8")

    async def driver(ctx):
        ctx.set(dut.session.inbound.active, 1)
        await ctx.tick().until(dut.session.outbound.active)
        in_stream = dut.session.inbound.data
        ctx.set(in_stream.valid, 1)
        idx = 0
        while idx < len(wire_bytes):
            ctx.set(in_stream.payload, wire_bytes[idx])
            # Only advance when the DUT takes the byte on this clock edge.
            if ctx.get(in_stream.ready):
                idx += 1
            await ctx.tick()
        # Stop offering data; otherwise the final byte keeps being presented
        # as a valid transfer while the session is closing.
        ctx.set(in_stream.valid, 0)
        ctx.set(dut.session.inbound.active, 0)
        await ctx.tick().until(~dut.session.outbound.active)
        assert not ctx.get(dut.session.outbound.data.valid)
        await ctx.tick()

    sim.add_testbench(driver)
    collector = StreamCollector(stream=dut.session.outbound.data)
    sim.add_process(collector.collect())
    with sim.write_vcd(sys.stdout):
        # Bounded run: the simulation ends after 10 ms of simulated time even
        # if the driver never finishes.
        sim.run_until(0.01)

    # All we're really checking is that every packet gets _some_ kind of response.
    assert len(collector) != 0

def test_simple_get():
    """
    Replay one fixed request: POST to /led with no headers and an empty body.

    NOTE(review): the name says "get" but the request is a POST -- confirm intent.
    """
    run_fuzz_http_request(method="POST", path="/led", headers={}, body="")
73 changes: 73 additions & 0 deletions http_server/simple_http_fuzz_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import sys
import string

from amaranth.sim import Simulator

from .simple_led_http import SimpleLedHttp
from stream_fixtures import StreamCollector
from hypothesis import given, strategies as st, settings, Phase, Verbosity
from hypothesis.errors import InvalidArgument

# Request-line pieces: a mix of methods and paths, including "BREW" and "/asdf".
st_methods = st.sampled_from(["GET", "POST", "PUT", "DELETE", "BREW"])
st_paths = st.sampled_from(["/", "/led", "/count", "/coffee", "/asdf"])

# Header names to draw from. Every entry needs its own trailing comma:
# adjacent string literals are silently concatenated into a single name.
st_header_names = st.sampled_from([
    "Host",
    "User-Agent",
    "Content-Type",
    "Content-Length",
    "Accept",
    "Accept-Additions",
    "Cookie",
])
# Header values: 1-32 characters, never containing CR or LF so a value cannot
# terminate its own header line.
st_header_values = st.text(
    alphabet=st.characters(codec='utf-8', exclude_characters="\r\n"),
    min_size=1,
    max_size=32)
st_headers = st.dictionaries(st_header_names, st_header_values, min_size=0, max_size=10)

# Bodies: up to 256 arbitrary UTF-8-encodable characters, possibly empty.
st_bodies = st.text(
    alphabet=st.characters(codec='utf-8'),
    min_size=0,
    max_size=256)

@settings(
    max_examples=2,  # Increase for more testing.
    verbosity=Verbosity.normal,
    deadline=None,
)
@given(
    method=st_methods,
    path=st_paths,
    headers=st_headers,
    body=st_bodies
)
def test_fuzz_http_request(method, path, headers, body):
    """
    Fuzz SimpleLedHttp with a generated HTTP/1.0 request.

    Streams the request into a fresh DUT one byte per accepted transfer and
    checks that at least one response byte was collected.

    Parameters
    ----------
    method: str
        Request method drawn from st_methods.
    path: str
        Request target drawn from st_paths.
    headers: dict
        Header names to header values, drawn from st_headers.
    body: str
        Request body drawn from st_bodies; may be empty.

    Raises
    ------
    AssertionError
        If outbound data is still flagged valid once the outbound session has
        closed, or if nothing was collected from the outbound stream.
    """
    dut = SimpleLedHttp()
    sim = Simulator(dut)
    sim.add_clock(1e-6)

    header = "".join(f"{k} : {v}\r\n" for k, v in headers.items())

    request = f"{method} {path} HTTP/1.0\r\n{header}\r\n\r\n{body}"
    sys.stderr.write(f"Testing with {request}")

    # Drive the wire encoding of the request. ord() on a str yields code
    # points, which do not fit in one byte for any character above U+00FF.
    wire_bytes = request.encode("utf-8")

    async def driver(ctx):
        ctx.set(dut.session.inbound.active, 1)
        await ctx.tick().until(dut.session.outbound.active)
        in_stream = dut.session.inbound.data
        ctx.set(in_stream.valid, 1)
        idx = 0
        while idx < len(wire_bytes):
            ctx.set(in_stream.payload, wire_bytes[idx])
            # Only advance when the DUT takes the byte on this clock edge.
            if ctx.get(in_stream.ready):
                idx += 1
            await ctx.tick()
        # Stop offering data; otherwise the final byte keeps being presented
        # as a valid transfer while the session is closing.
        ctx.set(in_stream.valid, 0)
        ctx.set(dut.session.inbound.active, 0)
        await ctx.tick().until(~dut.session.outbound.active)
        assert not ctx.get(dut.session.outbound.data.valid)
        await ctx.tick()

    sim.add_testbench(driver)
    collector = StreamCollector(stream=dut.session.outbound.data)
    sim.add_process(collector.collect())
    # Bounded run: the simulation ends after 10 ms of simulated time even if
    # the driver never finishes.
    sim.run_until(0.01)

    # All we're really checking is that every packet gets _some_ kind of response.
    sys.stderr.write(f"Got response {collector.body}")

    assert len(collector) != 0
98 changes: 69 additions & 29 deletions http_server/simple_led_http.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from amaranth import Module
from amaranth import Module, Signal
from amaranth.lib.wiring import In, Out, Component, connect

from .count_body import CountBody
Expand All @@ -24,6 +24,9 @@ class SimpleLedHttp(Component):
A GET from /count will return the number of requests and
responses.

If there's no activity for 1024 cycles, will assume something
went wrong.

Attributes
----------
session: BidiSessionSignature
Expand Down Expand Up @@ -72,7 +75,7 @@ def elaborate(self, _platform):
m.d.comb += parser_demux.outs[HTTP_PARSER_SINK].ready.eq(1)

## Responders
response_mux = m.submodules.response_mux = StreamMux(mux_width=5, stream_width=8)
response_mux = m.submodules.response_mux = StreamMux(mux_width=6, stream_width=8)
connect(m, response_mux.out, self.session.outbound.data)
count_body = m.submodules.count_body = CountBody()

Expand Down Expand Up @@ -143,7 +146,7 @@ def elaborate(self, _platform):
teapot_printer.en.eq(1),
count_body.inc_error.eq(1),
]

RESPONSE_COUNT = 4
connect(m, count_body.output, response_mux.input[RESPONSE_COUNT])
send_count = [
Expand All @@ -152,39 +155,75 @@ def elaborate(self, _platform):
count_body.en.eq(1),
]

# Response to send if something went wrong
bad_response = "\r\n".join(
["HTTP/1.0 500 Internal Server Error",
"Host: Fomu",
"Content-Type: text/plain; charset=utf-8",
"",
"uh-oh"]) + "\r\n"
bad_response = bad_response.encode("utf-8")
bad_printer = m.submodules.bad_printer = Printer(bad_response)
RESPONSE_BAD = 5
connect(m, bad_printer.output, response_mux.input[RESPONSE_BAD])
send_bad = [
response_mux.select.eq(RESPONSE_BAD),
parser_demux.select.eq(HTTP_PARSER_SINK),
bad_printer.en.eq(1),
count_body.inc_error.eq(1),
]

TIMEOUT_CYCLES=1023
import math
timeout_count = Signal(math.ceil(math.log2(TIMEOUT_CYCLES)))
timeout = Signal(1)
active = Signal(1)
m.d.comb += active.eq( (self.session.inbound.data.valid & self.session.inbound.data.ready)
|(self.session.outbound.data.valid))
with m.If(active | timeout):
m.d.sync += timeout_count.eq(0)
with m.Elif(timeout_count < TIMEOUT_CYCLES):
m.d.sync += timeout_count.eq(timeout_count + 1)
m.d.comb += timeout.eq(timeout_count >= TIMEOUT_CYCLES)

with m.FSM():
with m.State("reset"):
m.d.comb += [
start_matcher.reset.eq(1),
skip_headers.reset.eq(1),
start_matcher.reset.eq(1),
skip_headers.reset.eq(1),
]
m.next = "idle"
with m.State("idle"):
m.d.comb += [
start_matcher.reset.eq(0),
skip_headers.reset.eq(0),
start_matcher.reset.eq(0),
skip_headers.reset.eq(0),
]
m.d.sync += [
parser_demux.select.eq(HTTP_PARSER_START),
response_mux.select.eq(RESPONSE_OK),
parser_demux.select.eq(HTTP_PARSER_START),
response_mux.select.eq(RESPONSE_OK),
]
m.next = "idle"
with m.If(self.session.inbound.active):
m.next = "parsing_start"
m.d.sync += [
self.session.outbound.active.eq(1),
count_body.inc_requests.eq(1),
self.session.outbound.active.eq(1),
count_body.inc_requests.eq(1),
]
with m.State("parsing_start"):
m.next = "parsing_start"
m.d.sync += count_body.inc_requests.eq(0)
# start line matched successfully
with m.If(start_matcher.done):
with m.If(timeout):
m.next = "writing"
m.d.sync += send_bad
with m.Elif(start_matcher.done):
m.next = "parsing_header"
m.d.sync += parser_demux.select.eq(HTTP_PARSER_HEADERS)
with m.State("parsing_header"):
m.next = "parsing_header"
with m.If(skip_headers.accepted):
with m.If(timeout):
m.next = "writing"
m.d.sync += send_bad
with m.Elif(skip_headers.accepted):
with m.If(start_matcher.path[MATCHED_LED_PATH]):
with m.If(start_matcher.method[start_matcher.METHOD_POST]):
m.next = "parsing_led_body"
Expand All @@ -200,8 +239,8 @@ def elaborate(self, _platform):
m.next = "writing"
m.d.sync += send_405
with m.Elif(start_matcher.path[MATCHED_COFFEE_PATH]):
with m.If(start_matcher.method[start_matcher.METHOD_GET]
| start_matcher.method[start_matcher.METHOD_BREW]):
with m.If(start_matcher.method[start_matcher.METHOD_GET]
| start_matcher.method[start_matcher.METHOD_BREW]):
m.next = "writing"
m.d.sync += send_teapot
with m.Else():
Expand All @@ -212,25 +251,24 @@ def elaborate(self, _platform):
m.d.sync += send_404
with m.Elif(~self.session.inbound.active):
m.next = "writing"
# TODO: #4 - Should send a different error code besides 404 if the
# headers fail to parse before end-of-session.
m.d.sync += send_404
with m.State("parsing_led_body"): # TODO: #4 - Make body parsing state more generic.
m.next = "parsing_led_body"
with m.If(led_body_handler.accepted):
with m.If(timeout):
m.next = "writing"
m.d.sync += send_bad
with m.Elif(led_body_handler.accepted):
m.next = "writing"
m.d.sync += send_ok
with m.Elif(led_body_handler.rejected):
m.next = "writing"
# TODO: #4 - Should send a different error code besides 404 if the
# body fails to parse before end-of-session.
m.d.sync += send_404
with m.State("writing_count_ok"):
m.next = "writing_count_ok"
m.d.sync += [
ok_printer.en.eq(0),
count_body.inc_ok.eq(0)
]
ok_printer.en.eq(0),
count_body.inc_ok.eq(0)
]
with m.If(ok_printer.done):
m.d.sync += send_count
m.next = "writing"
Expand All @@ -241,16 +279,18 @@ def elaborate(self, _platform):
not_found_printer.en.eq(0),
not_allowed_printer.en.eq(0),
teapot_printer.en.eq(0),
bad_printer.en.eq(0),
count_body.en.eq(0),
self.session.outbound.active.eq(1),
count_body.inc_ok.eq(0),
count_body.inc_error.eq(0),
]
]
with m.If( ((response_mux.select == RESPONSE_OK) & ok_printer.done)
| ((response_mux.select == RESPONSE_404) & not_found_printer.done)
| ((response_mux.select == RESPONSE_405) & not_allowed_printer.done)
| ((response_mux.select == RESPONSE_COUNT) & count_body.done)
| ((response_mux.select == RESPONSE_TEAPOT) & teapot_printer.done)):
| ((response_mux.select == RESPONSE_404) & not_found_printer.done)
| ((response_mux.select == RESPONSE_405) & not_allowed_printer.done)
| ((response_mux.select == RESPONSE_COUNT) & count_body.done)
| ((response_mux.select == RESPONSE_TEAPOT) & teapot_printer.done)
| ((response_mux.select == RESPONSE_BAD) & bad_printer.done)):
m.d.sync += self.session.outbound.active.eq(0)
# Can finish writing before all the input is collected,
# since a bad request might trigger an early 404. Wait
Expand Down
0