Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fomu_http_accel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,5 +57,5 @@ def elaborate(self, platform):

if __name__ == "__main__":
FomuPVTPlatform().build(FomuHttpAccelerator(),
# do_program=True,
do_program=True,
verbose=True)
28 changes: 28 additions & 0 deletions host_sim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@

import asyncio

from sim_server import SimServer
from not_tcp.host import StreamProxy


class HostSimulator(SimServer, StreamProxy):
# Multiple inheritance is not a *crime*, it's just an abuse of the rules.
# Tax avoidance is not tax evasion!
pass


async def run_server(port):
import sys
import ntcp_http
dut = ntcp_http.NtcpHttpServer()

with HostSimulator(dut, dut.tx, dut.rx) as srv:
server = await asyncio.start_server(
client_connected_cb=srv.client_connected, host="localhost",
port=port)
sys.stderr.write(f"listening on port {port}\n")
await server.serve_forever()


if __name__ == "__main__":
asyncio.run(run_server(3278))
73 changes: 73 additions & 0 deletions not_tcp/host.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import struct
from enum import IntFlag
from typing import Optional
from asyncio import StreamReader, StreamWriter
import asyncio
import sys


class Flag(IntFlag):
Expand Down Expand Up @@ -96,3 +99,73 @@ def from_bytes(cls, buf: bytes) -> (Optional["Packet"], bytes):
body = body_and_remainder[:header.body_length]
remainder = body_and_remainder[header.body_length:]
return (Packet.from_header(header, body), remainder)


# Use as superclass; subclass to simulator or real
class StreamProxy:
lock = asyncio.Lock()

def send(self, b: bytes()):
# Must be implemented by subclass
pass

def recv(self) -> bytes:
# Must be implemented by subclass
pass

def client_connected(
self, reader: StreamReader, writer: StreamWriter):
asyncio.create_task(self.client_loop(reader, writer))

async def client_loop(self, reader: StreamReader, writer: StreamWriter):
async with self.lock, asyncio.TaskGroup() as tg:
tg.create_task(self.run_inbound(reader))
tg.create_task(self.run_outbound(writer))

async def run_inbound(self, reader: StreamReader):
p1 = Packet(flags=Flag.START, stream_id=1, body=bytes())
self.send(p1.to_bytes())
want_bytes = 256
while True:
try:
async with asyncio.timeout(1):
buffer = await reader.read(want_bytes)
if len(buffer) == 0:
# Zero bytes returned at EOF; but not a timeout.
# That's end-of-stream.
break
# On a successful read, keep that many bytes
want_bytes = 256
p2 = Packet(flags=0, stream_id=1, body=buffer)
self.send(p2.to_bytes())
except asyncio.TimeoutError:
want_bytes = want_bytes // 2
if want_bytes == 0:
want_bytes = 1
# Input is done, in theory
p3 = Packet(flags=Flag.END, stream_id=1, body=bytes())
self.send(p3.to_bytes())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be the cause of the issue where we can't send more than one request. If I get rid of this line, I can get as many /coffee (418) and /counts responses I want from curl

Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The rewrite/split of "separate state machines" also fixed the "only one request", without getting rid of this -- and I think the "stop" is important to capture, so I'd like to leave this.

Merging this with the bug intact, #39 has the rewrite.


async def run_outbound(self, writer: StreamWriter):
buffer = bytes()
packet_count = 0
while True:
rcvd = self.recv() # Has its own timeout, but isn't async. So:
await asyncio.sleep(0)
buffer += rcvd
(p, rem) = Packet.from_bytes(buffer)
if p is None:
continue
buffer = rem
if packet_count == 0:
assert p.start
packet_count += 1
if not p.to_host:
# Ignore the packet
continue
writer.write(p.body)
await writer.drain()
if p.end:
break
writer.close()
await writer.wait_closed()
133 changes: 133 additions & 0 deletions not_tcp/host_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import sys
import pytest
import asyncio
from ntcp_http import NtcpHttpServer

from amaranth import Module
from amaranth.lib.wiring import Component, In, Out
from amaranth.lib import stream

from host_sim import HostSimulator
from not_tcp.host import Packet, Flag
from sim_server import SimServer
from not_tcp.not_tcp import StreamStop
from http_server import capitalizer

pytest_plugins = ('pytest_asyncio',)


class Capitalize(Component):
"""
A Not TCP server that capitalizes its input.
"""

tx: Out(stream.Signature(8))
rx: In(stream.Signature(8))

def elaborate(self, platform):
m = Module()

stop = m.submodules.stop = StreamStop(1)
# Serial side:
m.d.comb += [
self.tx.valid.eq(stop.bus.downstream.valid),
self.tx.payload.eq(stop.bus.downstream.payload),
stop.bus.downstream.ready.eq(self.tx.ready),

stop.bus.upstream.valid.eq(self.rx.valid),
stop.bus.upstream.payload.eq(self.rx.payload),
self.rx.ready.eq(stop.bus.upstream.ready),
]

cap = m.submodules.capitalizer = capitalizer.Capitalizer()

m.d.comb += [
# Data:
cap.input.eq(stop.stop.inbound.data.payload),
stop.stop.outbound.data.payload.eq(cap.output),
# Stream control:
stop.stop.outbound.data.valid.eq(stop.stop.inbound.data.valid),
stop.stop.inbound.data.ready.eq(stop.stop.outbound.data.ready),
# Session control:
stop.stop.outbound.active.eq(stop.stop.inbound.active),
]

return m


def DISABLED_test_capitalize_server():
dut = Capitalize()

with SimServer(dut, dut.tx, dut.rx) as srv:
p1 = Packet(flags=Flag.START, stream_id=1, body=b"hello world")
srv.send(p1.to_bytes())

received_bytes = bytes()
received_body = bytes()
packets = []
import sys
for i in range(100):
received_bytes += srv.recv()
(packet, remainder) = Packet.from_bytes(received_bytes)
if packet is not None:
sys.stderr.write(f"{packet}\n")
received_bytes = remainder
packets += [packet]
received_body += packet.body
if packet.end or len(received_body) == len("hello world"):
break
assert received_body == b"HELLO WORLD"

received_body = bytes()
# TODO: For now, we have to send an explicit "end" packet
p2 = Packet(stream_id=1, body=b"Goodbye for now")
p3 = Packet(flags=Flag.END, stream_id=1)
srv.send(p2.to_bytes())
srv.send(p3.to_bytes())
for i in range(100):
received_bytes += srv.recv()
(packet, remainder) = Packet.from_bytes(received_bytes)
if packet is not None:
sys.stderr.write(f"{packet}\n")
received_bytes = remainder
packets += [packet]
received_body += packet.body
if packet.end:
break
assert received_body == b"GOODBYE FOR NOW"

for i in range(len(packets)):
packet = packets[i]
assert packet.to_host
assert packet.start == (
i == 0), f"start {packet.start} for packet {i}"
assert packet.end == (
i == len(packets)-1), f"end {packet.end} for packet {i}"
assert packet.to_host


@pytest.mark.asyncio
async def test_tcp_proxy():
dut = NtcpHttpServer()

with HostSimulator(dut, dut.tx, dut.rx) as srv:
server = await asyncio.start_server(
client_connected_cb=srv.client_connected, host="localhost",
port=3278)
async with server:
reader, writer = await asyncio.open_connection("127.0.0.1", 3278)
writer.write(
"\r\n".join([
"POST /nothing-here HTTP/1.0",
"Cache-Control: private",
"",
"",
"lovely day today"
]).encode("utf-8")
)
await writer.drain()

read = await reader.read(-1)
response = read.decode("utf-8")
lines = response.split("\r\n")
assert lines[0] == "HTTP/1.0 404 Not Found"
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ regex
# Dev dependencies:
flake8
pytest
pytest-asyncio
hypothesis
12 changes: 7 additions & 5 deletions sim_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import queue
import sys
import traceback
from threading import Thread

from amaranth.sim import Simulator
Expand Down Expand Up @@ -98,25 +99,26 @@ def __enter__(self):
def _run_sim(self, sim):
def runnable():
try:
sys.stderr.write("running simulator\n")
# Uncomment this line, and indent the next, to get debug info.
# with sim.write_vcd("testout.vcd"):
sim.run()
sys.stderr.write("simulation complete\n")
except Exception as e:
sys.stderr.write(f"error in Amaranth simulation: {e}\n")
# Try to force shutdown:
self._sender.done = True
self._sender.die = True
raise e

return runnable

def __exit__(self, *args, **kwargs):
def __exit__(self, exe_type, exe_val, exe_traceback, **kwargs):
assert self._sim_thread is not None
# Shutting down the data input should shut down the simulator;
# the data input is driving the tick.
# self._data_in.shutdown()
# .shutdown() is not available on python3.11,
# so we have to use a flag.
self._sender.done = True
if exe_traceback is not None:
traceback.print_tb(exe_traceback)

self._sender.die = True
self._sim_thread.join()
10 changes: 9 additions & 1 deletion stream_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class StreamSender:
# Flag bit, signaled when all bytes from all packets have been delivered.
done: bool = False

# Flag bit, to kill the send_queue_active thread
die: bool = False


def __init__(self,
stream,
random_delay=False,
Expand Down Expand Up @@ -189,13 +193,17 @@ def send_queue_active(self, q: queue.Queue[bytes], idle_ticks=100):
stream = self._stream

async def sender(ctx):
while not self.done:
while not self.die:
try:
data = q.get_nowait()
except queue.Empty:
data = bytes()
except queue.ShutDown:
sys.stderr.write("queue is shut down\n")
return
except Exception as e:
sys.stderr.write(f"unexpected exception: {e}\n")
raise e

if isinstance(data, str):
data = str.encode(data, "utf-8")
Expand Down