diff --git a/http_server/simple_http_fuzz_examples_test.py b/http_server/simple_http_fuzz_examples_test.py
new file mode 100644
index 0000000..7bd6aee
--- /dev/null
+++ b/http_server/simple_http_fuzz_examples_test.py
@@ -0,0 +1,62 @@
+import sys
+import string
+
+from amaranth.sim import Simulator
+
+from .simple_led_http import SimpleLedHttp
+from stream_fixtures import StreamCollector
+from hypothesis import given, strategies as st, settings, Phase, Verbosity
+from hypothesis.errors import InvalidArgument
+
+def run_fuzz_http_request(method, path, headers, body):
+    """
+    Same test as in simple_http_fuzz_test, but for manually re-running failures.
+
+    Builds an HTTP/1.0 request from ``method``, ``path``, ``headers`` (field
+    name -> value) and ``body``, streams its UTF-8 encoding into a new
+    SimpleLedHttp one byte per handshake, dumps a VCD trace to stdout, and
+    asserts that the server produced a non-empty response.
+    """
+    dut = SimpleLedHttp()
+    sim = Simulator(dut)
+    sim.add_clock(1e-6)
+
+    # "name: value" fields, each CRLF-terminated; a single empty line then
+    # ends the header section, so the body follows it directly.
+    header = "".join(f"{k}: {v}\r\n" for k, v in headers.items())
+    request = f"{method} {path} HTTP/1.0\r\n{header}\r\n{body}"
+    sys.stderr.write(f"Testing with {request!r}\n")
+    # The inbound stream carries 8-bit bytes, so send the UTF-8 encoding
+    # rather than raw code points (which can exceed 255).
+    request_bytes = request.encode("utf-8")
+
+    async def driver(ctx):
+        ctx.set(dut.session.inbound.active, 1)
+        await ctx.tick().until(dut.session.outbound.active)
+        in_stream = dut.session.inbound.data
+        ctx.set(in_stream.valid, 1)
+        idx = 0
+        while idx < len(request_bytes):
+            ctx.set(in_stream.payload, request_bytes[idx])
+            if ctx.get(in_stream.ready):
+                idx += 1
+            await ctx.tick()
+        # Stop offering data; otherwise the DUT would keep consuming copies
+        # of the last byte for as long as it stays ready.
+        ctx.set(in_stream.valid, 0)
+        ctx.set(dut.session.inbound.active, 0)
+        await ctx.tick().until(~dut.session.outbound.active)
+        assert not ctx.get(dut.session.outbound.data.valid)
+        await ctx.tick()
+
+    sim.add_testbench(driver)
+    collector = StreamCollector(stream=dut.session.outbound.data)
+    sim.add_process(collector.collect())
+    with sim.write_vcd(sys.stdout):
+        sim.run_until(0.01)
+
+    # All we're really checking is that every packet gets _some_ kind of response.
+    assert len(collector) != 0
+
+def test_simple_post():
+    run_fuzz_http_request("POST", "/led", {}, "")
diff --git a/http_server/simple_http_fuzz_test.py b/http_server/simple_http_fuzz_test.py
new file mode 100644
index 0000000..9354f52
--- /dev/null
+++ b/http_server/simple_http_fuzz_test.py
@@ -0,0 +1,83 @@
+import sys
+import string
+
+from amaranth.sim import Simulator
+
+from .simple_led_http import SimpleLedHttp
+from stream_fixtures import StreamCollector
+from hypothesis import given, strategies as st, settings, Phase, Verbosity
+from hypothesis.errors import InvalidArgument
+
+st_methods = st.sampled_from(["GET", "POST", "PUT", "DELETE", "BREW"])
+st_paths = st.sampled_from(["/", "/led", "/count", "/coffee", "/asdf"])
+
+st_header_names = st.sampled_from(["Host", "User-Agent", "Content-Type",
+                                   "Content-Length", "Accept", "Accept-Additions", "Cookie"])
+st_header_values = st.text(
+    alphabet=st.characters(codec='utf-8', exclude_characters="\r\n"),
+    min_size=1,
+    max_size=32)
+st_headers = st.dictionaries(st_header_names, st_header_values, min_size=0, max_size=10)
+
+st_bodies = st.text(
+    alphabet=st.characters(codec='utf-8'),
+    min_size=0,
+    max_size=256)
+
+@settings(
+    max_examples=2, # Increase for more testing.
+    verbosity=Verbosity.normal,
+    deadline=None,
+)
+@given(
+    method=st_methods,
+    path=st_paths,
+    headers=st_headers,
+    body=st_bodies
+)
+def test_fuzz_http_request(method, path, headers, body):
+    """
+    Send a randomly generated HTTP/1.0 request to SimpleLedHttp and check
+    that the server answers with a non-empty response.
+    """
+    dut = SimpleLedHttp()
+    sim = Simulator(dut)
+    sim.add_clock(1e-6)
+
+    # "name: value" fields, each CRLF-terminated; a single empty line then
+    # ends the header section, so the body follows it directly.
+    header = "".join(f"{k}: {v}\r\n" for k, v in headers.items())
+    request = f"{method} {path} HTTP/1.0\r\n{header}\r\n{body}"
+    sys.stderr.write(f"Testing with {request!r}\n")
+    # The inbound stream carries 8-bit bytes, so send the UTF-8 encoding
+    # rather than raw code points (which can exceed 255).
+    request_bytes = request.encode("utf-8")
+
+    async def driver(ctx):
+        ctx.set(dut.session.inbound.active, 1)
+        await ctx.tick().until(dut.session.outbound.active)
+        in_stream = dut.session.inbound.data
+        ctx.set(in_stream.valid, 1)
+        idx = 0
+        while idx < len(request_bytes):
+            ctx.set(in_stream.payload, request_bytes[idx])
+            if ctx.get(in_stream.ready):
+                idx += 1
+            await ctx.tick()
+        # Stop offering data; otherwise the DUT would keep consuming copies
+        # of the last byte for as long as it stays ready.
+        ctx.set(in_stream.valid, 0)
+        ctx.set(dut.session.inbound.active, 0)
+        await ctx.tick().until(~dut.session.outbound.active)
+        assert not ctx.get(dut.session.outbound.data.valid)
+        await ctx.tick()
+
+    sim.add_testbench(driver)
+    collector = StreamCollector(stream=dut.session.outbound.data)
+    sim.add_process(collector.collect())
+    sim.run_until(0.01)
+
+    # All we're really checking is that every packet gets _some_ kind of response.
+    sys.stderr.write(f"Got response {collector.body}\n")
+
+    assert len(collector) != 0
diff --git a/http_server/simple_led_http.py b/http_server/simple_led_http.py
index 220797f..0333b9a 100644
--- a/http_server/simple_led_http.py
+++ b/http_server/simple_led_http.py
@@ -1,4 +1,4 @@
-from amaranth import Module
+from amaranth import Module, Signal
 from amaranth.lib.wiring import In, Out, Component, connect
 
 from .count_body import CountBody
@@ -24,6 +24,9 @@ class SimpleLedHttp(Component):
 
     A GET from /count will return the number of requests and responses.
 
+    If a request sees no stream activity for 1024 cycles, the server assumes
+    something went wrong and replies with "500 Internal Server Error".
+ Attributes ---------- session: BidiSessionSignature @@ -72,7 +75,7 @@ def elaborate(self, _platform): m.d.comb += parser_demux.outs[HTTP_PARSER_SINK].ready.eq(1) ## Responders - response_mux = m.submodules.response_mux = StreamMux(mux_width=5, stream_width=8) + response_mux = m.submodules.response_mux = StreamMux(mux_width=6, stream_width=8) connect(m, response_mux.out, self.session.outbound.data) count_body = m.submodules.count_body = CountBody() @@ -143,7 +146,7 @@ def elaborate(self, _platform): teapot_printer.en.eq(1), count_body.inc_error.eq(1), ] - + RESPONSE_COUNT = 4 connect(m, count_body.output, response_mux.input[RESPONSE_COUNT]) send_count = [ @@ -152,39 +155,75 @@ def elaborate(self, _platform): count_body.en.eq(1), ] + # Response to send if something went wrong + bad_response = "\r\n".join( + ["HTTP/1.0 500 Internal Server Error", + "Host: Fomu", + "Content-Type: text/plain; charset=utf-8", + "", + "uh-oh"]) + "\r\n" + bad_response = bad_response.encode("utf-8") + bad_printer = m.submodules.bad_printer = Printer(bad_response) + RESPONSE_BAD = 5 + connect(m, bad_printer.output, response_mux.input[RESPONSE_BAD]) + send_bad = [ + response_mux.select.eq(RESPONSE_BAD), + parser_demux.select.eq(HTTP_PARSER_SINK), + bad_printer.en.eq(1), + count_body.inc_error.eq(1), + ] + + TIMEOUT_CYCLES = 1024 + # range(N + 1) holds 0..N, so the counter can actually reach the limit. + timeout_count = Signal(range(TIMEOUT_CYCLES + 1)) + timeout = Signal(1) + active = Signal(1) + m.d.comb += active.eq( (self.session.inbound.data.valid & self.session.inbound.data.ready) + |(self.session.outbound.data.valid)) + with m.If(active | timeout | ~self.session.outbound.active): # hold at 0 while outbound.active is low (idle), so every request starts with a full budget + m.d.sync += timeout_count.eq(0) + with m.Elif(timeout_count < TIMEOUT_CYCLES): + m.d.sync += timeout_count.eq(timeout_count + 1) + m.d.comb += timeout.eq(timeout_count >= TIMEOUT_CYCLES) + with m.FSM(): with m.State("reset"): m.d.comb += [ - start_matcher.reset.eq(1), - skip_headers.reset.eq(1), + start_matcher.reset.eq(1), + skip_headers.reset.eq(1), ] m.next = "idle" with
m.State("idle"): m.d.comb += [ - start_matcher.reset.eq(0), - skip_headers.reset.eq(0), + start_matcher.reset.eq(0), + skip_headers.reset.eq(0), ] m.d.sync += [ - parser_demux.select.eq(HTTP_PARSER_START), - response_mux.select.eq(RESPONSE_OK), + parser_demux.select.eq(HTTP_PARSER_START), + response_mux.select.eq(RESPONSE_OK), ] m.next = "idle" with m.If(self.session.inbound.active): m.next = "parsing_start" m.d.sync += [ - self.session.outbound.active.eq(1), - count_body.inc_requests.eq(1), + self.session.outbound.active.eq(1), + count_body.inc_requests.eq(1), ] with m.State("parsing_start"): m.next = "parsing_start" m.d.sync += count_body.inc_requests.eq(0) - # start line matched successfully - with m.If(start_matcher.done): + with m.If(timeout): + m.next = "writing" + m.d.sync += send_bad + with m.Elif(start_matcher.done): m.next = "parsing_header" m.d.sync += parser_demux.select.eq(HTTP_PARSER_HEADERS) with m.State("parsing_header"): m.next = "parsing_header" - with m.If(skip_headers.accepted): + with m.If(timeout): + m.next = "writing" + m.d.sync += send_bad + with m.Elif(skip_headers.accepted): with m.If(start_matcher.path[MATCHED_LED_PATH]): with m.If(start_matcher.method[start_matcher.METHOD_POST]): m.next = "parsing_led_body" @@ -200,8 +239,8 @@ def elaborate(self, _platform): m.next = "writing" m.d.sync += send_405 with m.Elif(start_matcher.path[MATCHED_COFFEE_PATH]): - with m.If(start_matcher.method[start_matcher.METHOD_GET] - | start_matcher.method[start_matcher.METHOD_BREW]): + with m.If(start_matcher.method[start_matcher.METHOD_GET] + | start_matcher.method[start_matcher.METHOD_BREW]): m.next = "writing" m.d.sync += send_teapot with m.Else(): @@ -212,25 +251,24 @@ def elaborate(self, _platform): m.d.sync += send_404 with m.Elif(~self.session.inbound.active): m.next = "writing" - # TODO: #4 - Should send a different error code besides 404 if the - # headers fail to parse before end-of-session.
m.d.sync += send_404 with m.State("parsing_led_body"): # TODO: #4 - Make body parsing state more generic. m.next = "parsing_led_body" - with m.If(led_body_handler.accepted): + with m.If(timeout): + m.next = "writing" + m.d.sync += send_bad + with m.Elif(led_body_handler.accepted): m.next = "writing" m.d.sync += send_ok with m.Elif(led_body_handler.rejected): m.next = "writing" - # TODO: #4 - Should send a different error code besides 404 if the - # body fails to parse before end-of-session. m.d.sync += send_404 with m.State("writing_count_ok"): m.next = "writing_count_ok" m.d.sync += [ - ok_printer.en.eq(0), - count_body.inc_ok.eq(0) - ] + ok_printer.en.eq(0), + count_body.inc_ok.eq(0) + ] with m.If(ok_printer.done): m.d.sync += send_count m.next = "writing" @@ -241,16 +279,18 @@ def elaborate(self, _platform): not_found_printer.en.eq(0), not_allowed_printer.en.eq(0), teapot_printer.en.eq(0), + bad_printer.en.eq(0), count_body.en.eq(0), self.session.outbound.active.eq(1), count_body.inc_ok.eq(0), count_body.inc_error.eq(0), - ] + ] with m.If( ((response_mux.select == RESPONSE_OK) & ok_printer.done) - | ((response_mux.select == RESPONSE_404) & not_found_printer.done) - | ((response_mux.select == RESPONSE_405) & not_allowed_printer.done) - | ((response_mux.select == RESPONSE_COUNT) & count_body.done) - | ((response_mux.select == RESPONSE_TEAPOT) & teapot_printer.done)): + | ((response_mux.select == RESPONSE_404) & not_found_printer.done) + | ((response_mux.select == RESPONSE_405) & not_allowed_printer.done) + | ((response_mux.select == RESPONSE_COUNT) & count_body.done) + | ((response_mux.select == RESPONSE_TEAPOT) & teapot_printer.done) + | ((response_mux.select == RESPONSE_BAD) & bad_printer.done)): m.d.sync += self.session.outbound.active.eq(0) # Can finish writing before all the input is collected, # since a bad request migh trigger an early 404. Wait