# rs485stream.py
import uasyncio as asyncio
from machine import UART, Pin
import struct, time
from protocol import PREFIX, CMD_TILE, c2s

BITS_PER_BYTE = 1 + 8 + 1  # 8N1 line framing: start bit + 8 data bits + 1 stop bit

def crc16_ccitt(data, crc=0xFFFF):
    """Return the CRC-16/CCITT-FALSE of ``data``.

    Parameters: polynomial 0x1021, MSB-first (no bit reflection), initial
    value 0xFFFF, no final XOR.

    Args:
        data: iterable of byte values (``bytes``, ``bytearray``, ``memoryview``).
        crc: starting register value. Pass a previous return value here to
            continue a CRC over data that is processed in several pieces.

    Returns:
        The 16-bit CRC as an int.
    """
    poly = 0x1021
    for byte in data:
        crc ^= byte << 8
        for _bit in range(8):
            carry = crc & 0x8000           # top bit about to be shifted out
            crc = (crc << 1) & 0xFFFF
            if carry:
                crc ^= poly
    return crc

class RS485Stream:
    """Half-duplex RS-485 link over a MicroPython UART plus a direction pin.

    Frame layout, as built by ``send_frame`` / ``send_frame_streaming_sync``::

        PREFIX(2) | ADDR(1) | CMD(1) | LEN(4, LE) | PAYLOAD(LEN) | CRC16(2, LE)

    The CRC is ``crc16_ccitt`` over everything after PREFIX (ADDR..PAYLOAD).

    Two access styles coexist:

    * async methods (``recv_frame``, ``send_frame``, ``send_ctrl_with_retry``)
      for small control frames, using uasyncio streams;
    * blocking ``*_sync`` methods for large frames (e.g. ``CMD_TILE``) that
      are streamed as header / chunks / CRC.
    """

    def __init__(self, uart_id, tx, rx, dir_pin, baudrate):
        """Open the UART (8N1) and the direction pin, starting in receive mode.

        Args:
            uart_id: hardware UART number passed to ``machine.UART``.
            tx, rx: pin numbers for UART TX / RX.
            dir_pin: pin driving the transceiver direction (0 = RX, 1 = TX).
            baudrate: line rate in bit/s; also used to estimate wire time.
        """
        self.uart = UART(uart_id, baudrate=baudrate, tx=Pin(tx), rx=Pin(rx), bits=8, parity=None, stop=1)
        self.dir  = Pin(dir_pin, Pin.OUT, value=0)  # 0=RX, 1=TX
        self._baud = baudrate
        self.swriter = asyncio.StreamWriter(self.uart, {})
        self.sreader = asyncio.StreamReader(self.uart)
        # Serialises async senders so their DE high/low windows never interleave.
        self._tx_lock = asyncio.Lock()

    def _tx_ms(self, nbytes):
        """Return the wire time of ``nbytes`` at the configured baud rate, in whole ms (rounded up, min 1)."""
        tx_us = (nbytes * BITS_PER_BYTE * 1_000_000) // self._baud
        return max(1, (tx_us + 999) // 1000)

    # ========== Async: small frames ==========
    async def _recv_prefix(self):
        """Consume bytes until the 2-byte PREFIX has just been read.

        A byte that fails to complete the prefix is kept as a candidate first
        byte, so overlapping sequences such as ``P0 P0 P1`` still sync.
        """
        first = PREFIX[:1]
        second = PREFIX[1:2]
        prev = b""
        while True:
            b = await self.sreader.readexactly(1)
            if prev == first and b == second:
                return
            prev = b

    async def recv_frame(self, max_len=64):
        """Wait for the next small frame and return ``(addr, cmd, payload)``.

        Scans for PREFIX, reads the 6-byte header, then the payload and CRC.
        A frame whose CRC does not match is dropped and scanning resumes.

        Returns ``None`` when the header announces ``CMD_TILE`` or a length
        above ``max_len``, after discarding whatever arrives for ~5 ms. Large
        frames are meant to be read with the ``*_sync`` methods instead.

        This coroutine has no timeout of its own; wrap it (e.g. with
        ``asyncio.wait_for_ms``) if the caller needs one.
        """
        while True:
            await self._recv_prefix()
            hdr = await self.sreader.readexactly(6)  # ADDR(1), CMD(1), LEN32(4)
            addr, cmd, ln = hdr[0], hdr[1], struct.unpack("<I", hdr[2:6])[0]
            # Large frames (e.g. CMD_TILE) are not handled here.
            if cmd == CMD_TILE or ln > max_len:
                # Loosely flush the line (by design, bulk data goes through the sync path).
                # NOTE(review): the fixed 5 ms window only covers 5 ms of line time; any
                # remainder of a longer frame is left for the PREFIX scan / CRC check to skip.
                t_end = time.ticks_add(time.ticks_ms(), 5)
                while time.ticks_diff(t_end, time.ticks_ms()) > 0:
                    n = self.uart.any()
                    if n:
                        self.uart.read(n)
                    else:
                        await asyncio.sleep_ms(0)
                return None
            payload = await self.sreader.readexactly(ln) if ln else b""
            rx_crc = struct.unpack("<H", await self.sreader.readexactly(2))[0]
            if crc16_ccitt(hdr + payload) != rx_crc:
                # Corrupted frame: resynchronise on the next PREFIX.
                continue
            return (addr, cmd, payload)

    async def send_frame(self, addr, cmd, payload=b""):
        """Send one frame, driving DE high for its estimated wire time.

        ``addr`` and ``cmd`` are truncated to one byte each. Concurrent callers
        are serialised by ``_tx_lock``. DE is returned to receive even if this
        coroutine is cancelled or the write raises.
        """
        head = PREFIX + bytes([addr & 0xFF, cmd & 0xFF]) + struct.pack("<I", len(payload))
        crc = crc16_ccitt(head[2:] + payload)  # CRC excludes PREFIX
        pkt = head + payload + struct.pack("<H", crc)
        tx_ms = self._tx_ms(len(pkt))
        async with self._tx_lock:
            self.dir.value(1)
            try:
                self.swriter.write(pkt)
                # uasyncio streams only guarantee buffered data is handed to the UART on drain().
                await self.swriter.drain()
                # drain() may return before the last bits leave the UART, so keep
                # driving the bus for the frame's wire time plus a small margin.
                await asyncio.sleep_ms(tx_ms + 3)
            finally:
                self.dir.value(0)

    async def send_ctrl_with_retry(self, addr, cmd, seq, payload=b"", timeout_ms=300, retries=4,
                                   ack_cmd=0x11):
        """Send a control frame and wait for its ACK, resending on timeout.

        The sent payload is ``seq & 0xFF`` followed by ``payload``. An ACK is a
        frame from ``addr`` with command ``ack_cmd`` whose first payload byte
        equals ``seq & 0xFF``; other frames received meanwhile are ignored.

        Args:
            timeout_ms: how long to wait for the ACK after each transmission.
            retries: number of resends after the first attempt.
            ack_cmd: command byte that identifies an ACK (default 0x11).

        Returns:
            True once the matching ACK arrives, False after all attempts time out.
        """
        want_seq = seq & 0xFF
        body = bytes([want_seq]) + payload
        for _ in range(retries + 1):
            await self.send_frame(addr, cmd, body)
            deadline = time.ticks_add(time.ticks_ms(), timeout_ms)
            # Look only for the ACK until the deadline.
            while True:
                remaining = time.ticks_diff(deadline, time.ticks_ms())
                if remaining <= 0:
                    break
                # recv_frame blocks indefinitely on a silent line; bound it by the time left.
                try:
                    fr = await asyncio.wait_for_ms(self.recv_frame(max_len=16), remaining)
                except asyncio.TimeoutError:
                    break
                if not fr:
                    await asyncio.sleep_ms(0)
                    continue
                a, c, pl = fr
                if a == addr and c == ack_cmd and len(pl) >= 1 and pl[0] == want_seq:
                    return True
            # Timed out -> resend.
        return False

    # ========== Sync: large frames (header / lines / CRC) ==========
    def recv_header_sync(self, timeout_ms=5000):
        """Blocking: scan for PREFIX and read the 6-byte header that follows.

        One deadline covers both the PREFIX scan and the header read.

        Returns:
            ``(addr, cmd, length, raw_header)``, or ``None`` on timeout.
            ``raw_header`` is the 6 header bytes, presumably so the caller can
            start the CRC from them (TODO confirm with callers).
        """
        deadline = time.ticks_add(time.ticks_ms(), timeout_ms)
        state = 0  # 0: waiting for PREFIX[0], 1: PREFIX[0] seen
        while time.ticks_diff(deadline, time.ticks_ms()) > 0:
            if self.uart.any():
                b = self.uart.read(1)
                if not b:
                    continue
                if state == 0:
                    state = 1 if b == PREFIX[:1] else 0
                elif state == 1:
                    if b == PREFIX[1:2]:
                        break
                    # A repeated first byte may itself start the prefix.
                    state = 1 if b == PREFIX[:1] else 0
            else:
                time.sleep_ms(1)
        else:
            return None
        hdr = self._read_exact_sync(6, deadline)
        if hdr is None:
            return None
        addr, cmd, ln = hdr[0], hdr[1], struct.unpack("<I", hdr[2:6])[0]
        return (addr, cmd, ln, hdr)

    def readinto_exact_sync(self, mv, n, timeout_ms=5000):
        """Blocking: fill ``mv[:n]`` from the UART within ``timeout_ms``.

        Returns True if all ``n`` bytes arrived, False otherwise (any partial
        data is left in ``mv``).
        """
        deadline = time.ticks_add(time.ticks_ms(), timeout_ms)
        got = 0
        view = memoryview(mv)
        while got < n and time.ticks_diff(deadline, time.ticks_ms()) > 0:
            if self.uart.any():
                w = self.uart.readinto(view[got:n])
                if w:
                    got += w
            else:
                time.sleep_ms(1)
        return got == n

    def _read_exact_sync(self, n, deadline_ms):
        """Read exactly ``n`` bytes before the absolute ticks_ms ``deadline_ms``.

        Returns a ``bytearray`` of length ``n``, or ``None`` on timeout.
        """
        buf = bytearray(n)
        view = memoryview(buf)
        i = 0
        while i < n and time.ticks_diff(deadline_ms, time.ticks_ms()) > 0:
            if self.uart.any():
                r = self.uart.readinto(view[i:])
                if r:
                    i += r
            else:
                time.sleep_ms(1)
        return buf if i == n else None

    def send_frame_streaming_sync(self, addr, cmd, total_len, header=b"", chunk_iter=None,
                                  interline_gap_ms=2):
        """Blocking send of a large frame whose payload is streamed in pieces.

        Writes PREFIX + ADDR/CMD/LEN, then ``header``, then every item of
        ``chunk_iter``, then the CRC over everything after PREFIX. DE is held
        high for the whole transfer and released even if ``chunk_iter`` raises.

        Args:
            total_len: value written to the LEN field. This method does not
                check it; presumably it must equal ``len(header)`` plus the
                total chunk length (TODO confirm with the receiver).
            header: optional bytes sent before the chunks.
            chunk_iter: iterable of bytes-like chunks (``None`` = no chunks).
            interline_gap_ms: extra pause after each chunk's wire time.
        """
        if chunk_iter is None:
            chunk_iter = ()
        head = PREFIX + bytes([addr & 0xFF, cmd & 0xFF]) + struct.pack("<I", total_len)
        crc = crc16_ccitt(head[2:])
        header_len = len(header) if header else 0
        self.dir.value(1)
        try:
            self.uart.write(head)
            if header:
                self.uart.write(header)
                crc = crc16_ccitt(header, crc)
            for chunk in chunk_iter:
                self.uart.write(chunk)
                crc = crc16_ccitt(chunk, crc)
                # Pace the stream: this chunk's wire time plus an inter-line gap.
                time.sleep_ms(self._tx_ms(len(chunk)) + interline_gap_ms)
            self.uart.write(struct.pack("<H", crc))
            # Every chunk already had its own wire time waited, so at most the head,
            # header and CRC bytes can still be queued in the UART. Holding DE for the
            # whole frame again would keep driving the bus while the peer may already
            # be replying.
            time.sleep_ms(self._tx_ms(len(head) + header_len + 2) + 2)
        finally:
            self.dir.value(0)

