# controller.py
import uasyncio as asyncio
import struct, time, math
from machine import SPI, Pin
from protocol import *
from rs485stream import RS485Stream, crc16_ccitt
from st7789sync import ST7789, COLOR_BLACK
from gpio_chain import ChainIO

LCD_W, LCD_H = 240, 240

def compute_grid(n):
    """Factor ``n`` into the most square ``(rows, cols)`` pair.

    Only exact factorizations (``rows * cols == n``) are considered. The pair
    with the smallest ``abs(rows - cols)`` wins. On a tie the pair with the
    smaller ``rows`` is kept, so the result always satisfies ``rows <= cols``.

    Raises:
        ValueError: if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("n must be positive")
    # Every divisor r of n yields the factor pair (r, n // r), in ascending r.
    pairs = [(r, n // r) for r in range(1, n + 1) if n % r == 0]
    # min() returns the first minimal element, matching "keep the earliest on ties".
    return min(pairs, key=lambda rc: abs(rc[0] - rc[1]))

def even_align(x):
    """Round the integer ``x`` down to the nearest even value (clear bit 0)."""
    return (x >> 1) << 1

def tiles_for_nodes(n):
    """Split the ``LCD_W`` x ``LCD_H`` screen into ``n`` rectangular tiles.

    Interior cut positions are rounded down to even pixel coordinates with
    ``even_align``. The last column and row always end exactly at the screen
    edge, so the tiles cover the whole screen without gaps.

    Returns:
        A list of ``(x, y, w, h)`` tuples with at most ``n`` entries. The list
        is ordered column by column: every tile of the leftmost column from
        top to bottom, then the next column.
    """
    # NOTE(review): compute_grid returns (rows, cols) with rows <= cols, but
    # the first value is used as the *column* count here. The smaller factor
    # therefore becomes the number of columns (e.g. n=2 gives two full-width
    # horizontal strips). Confirm this orientation is intended.
    n_cols, n_rows = compute_grid(n)
    x_edges = [0] + [even_align(LCD_W * k // n_cols) for k in range(1, n_cols)] + [LCD_W]
    y_edges = [0] + [even_align(LCD_H * k // n_rows) for k in range(1, n_rows)] + [LCD_H]
    result = [
        (x_edges[c], y_edges[r],
         x_edges[c + 1] - x_edges[c], y_edges[r + 1] - y_edges[r])
        for c in range(n_cols)
        for r in range(n_rows)
    ]
    return result[:n]

async def assign_jobs(bus: RS485Stream, assigned_ids, tiles):
    """Send each worker its tile with a CMD_ASSIGN control frame.

    Workers and tiles are paired positionally with ``zip``, so surplus entries
    on either side are ignored. The payload is six little-endian unsigned
    16-bit values: screen width, screen height, tile x, tile y, tile width and
    tile height. The sequence bit passed to ``send_ctrl_with_retry`` starts at
    0 and alternates 0/1 from one frame to the next.

    A failed send is only logged; the remaining workers are still contacted.
    """
    for index, (addr, tile) in enumerate(zip(assigned_ids, tiles)):
        x0, y0, w, h = tile
        payload = struct.pack("<HHHHHH", LCD_W, LCD_H, x0, y0, w, h)
        print(f"job assigned addr = {addr}, payload = {payload}")
        ok = await bus.send_ctrl_with_retry(
            addr, CMD_ASSIGN, index & 1, payload, timeout_ms=300, retries=4
        )
        print(f"-> assign ok = {ok}")
        if not ok:
            print("ASSIGN failed:", addr)

async def gpio_measure_span(chain: ChainIO, start_timeout_ms=300000, done_timeout_ms=300000):
    """Run one START/DONE handshake around the GPIO chain and time it.

    Steps:
      1. Drive the chain output high (START wave) and wait until the chain
         input reads high.
      2. Drive the chain output low (DONE wave) and wait until the chain
         input reads low.

    Args:
        chain: GPIO chain with ``drive_high``/``drive_low``, an awaitable
            ``wait_in_level(level, timeout_ms)`` that returns falsy on
            timeout, and a readable ``pin_in``.
        start_timeout_ms: how long to wait for the START wave to come back
            (default 300000 ms = 5 minutes).
        done_timeout_ms: how long to wait for the DONE wave to come back
            (default 300000 ms = 5 minutes).

    Returns:
        Milliseconds between the START wave coming back and the DONE wave
        coming back.

    Raises:
        RuntimeError: if either wave does not return within its timeout.
    """
    print("gpio_measure_span START")
    # Emit the START wave (starts rendering).
    print("send START wave (start rendering)")
    chain.drive_high()
    print(f"... waiting for START send back (GPIO.pin_in = {chain.pin_in.value()})")
    if not await chain.wait_in_level(1, start_timeout_ms):
        raise RuntimeError("Start wave did not return High")
    print(f"receive START wave back (GPIO.pin_in = {chain.pin_in.value()})")
    t_all_start = time.ticks_ms()

    # Emit the DONE wave.
    # NOTE(review): presumably each worker holds the line until it finishes
    # rendering, so the low level only returns once all are done -- confirm.
    print("send DONE wave")
    chain.drive_low()
    print("... waiting for END send back")
    if not await chain.wait_in_level(0, done_timeout_ms):
        raise RuntimeError("Done wave did not return Low")
    print("receive END wave back")
    t_all_done = time.ticks_ms()
    # ticks_diff (rather than plain subtraction) handles ticks_ms wrap-around.
    return time.ticks_diff(t_all_done, t_all_start)

def recv_and_draw_one(bus: RS485Stream, lcd: ST7789, expected_tile, timeout_ms=5000) -> bool:
    """Receive one CMD_TILE frame from the bus and draw it on the LCD.

    This is a plain (non-async) function: every read uses the bus's
    ``*_sync`` methods, so the caller blocks until the frame is handled or a
    timeout fires.

    Frame layout consumed after the header returned by ``bus.recv_header_sync``:
      * 8 bytes of meta, unpacked as ``<HHHH`` = x0, y0, w, h of the tile;
      * ``h`` lines of ``2 * w`` bytes each (2 bytes per pixel, presumably
        RGB565 -- confirm against the worker);
      * a 2-byte little-endian CRC (``<H``).
    The header's length field must equal ``8 + 2*w*h``, i.e. it counts the
    meta plus the pixel bytes, not the CRC. The CRC is computed with
    ``crc16_ccitt`` over header bytes + meta + every pixel line, chaining the
    running value through its second argument.

    Args:
        bus: RS485 stream to read from.
        lcd: display that each received line is pushed to.
        expected_tile: ``(x0, y0, w, h)`` that the frame's meta must match.
        timeout_ms: overall budget for finding a matching header. The meta,
            per-line and CRC reads use their own fixed timeouts (2000 ms,
            2000 ms per line, 1000 ms).

    Returns:
        True if the tile was fully received, drawn, and its CRC matched.
        False on any timeout or mismatch.

    Note:
        Lines are drawn as soon as they arrive, before the CRC is checked. A
        False result can therefore leave partially drawn or corrupted pixels
        on screen. On an early False return the rest of the frame is not
        drained by this function.
    """
    print(f"recv_and_draw_one")
    x0, y0, w, h = expected_tile
    # Wait for a header addressed to the master carrying CMD_TILE. Each
    # recv_header_sync call waits up to 500 ms; the loop stops at the overall
    # deadline. Non-matching headers are skipped without reading their
    # payload here (presumably recv_header_sync resynchronises -- TODO confirm).
    deadline = time.ticks_add(time.ticks_ms(), timeout_ms)
    hdr = None
    while time.ticks_diff(deadline, time.ticks_ms()) > 0:
        r = bus.recv_header_sync(timeout_ms=500)
        if r is None:
            continue
        a, cmd, ln, hdr_bytes = r
        # NOTE(review): this accepts ln >= expected, but the check further
        # down rejects any ln != expected. A longer frame is therefore
        # accepted here only to fail later -- confirm which is intended.
        if a == MASTER_ADDR and cmd == CMD_TILE and ln >= 8 + 2*w*h:
            hdr = (a, cmd, ln, hdr_bytes)
            break
    if hdr is None:
        print("no header")
        return False

    # Meta: the tile position/size as reported by the sender.
    meta = bus._read_exact_sync(8, time.ticks_add(time.ticks_ms(), 2000))
    if meta is None:
        return False
    rx_x0, rx_y0, rx_w, rx_h = struct.unpack("<HHHH", meta)
    if (rx_x0, rx_y0, rx_w, rx_h) != (x0, y0, w, h):
        print("tile meta mismatch")
        return False

    # Take the length field and raw header bytes from the accepted header.
    _, _, ln, hdr_bytes = hdr
    if ln != 8 + 2*w*h:
        print("payload len mismatch")
        return False

    # Receive line by line and draw each line immediately, accumulating the
    # CRC as we go. One line-sized buffer is reused for every line.
    crc = crc16_ccitt(hdr_bytes + meta)
    line_bytes = 2 * w
    buf = bytearray(line_bytes)
    mv = memoryview(buf)

    for line in range(h):
        ok = bus.readinto_exact_sync(mv, line_bytes, timeout_ms=2000)
        if not ok:
            print("line read timeout")
            return False
        crc = crc16_ccitt(mv, crc)
        lcd.push_line_direct_sync(x0, y0 + line, w, mv)

    # Trailing CRC (little-endian uint16), compared to the running value.
    rx_crc_bytes = bus._read_exact_sync(2, time.ticks_add(time.ticks_ms(), 1000))
    if rx_crc_bytes is None:
        print("crc tail missing")
        return False
    rx_crc = struct.unpack("<H", rx_crc_bytes)[0]
    if rx_crc != crc:
        print("crc mismatch")
        return False

    return True

async def start(assigned_ids):
    """Run one full render cycle as the controller node.

    Steps:
      1. Initialise the ST7789 LCD on SPI0 and show "INIT".
      2. Open the RS485 bus (UART0) and the GPIO chain.
      3. Split the screen into one tile per entry of ``assigned_ids`` and
         send each worker its tile (CMD_ASSIGN).
      4. Time the GPIO START/DONE handshake with ``gpio_measure_span``.
      5. Send CMD_TOKEN to the first worker. Its payload is the list of
         assigned ids, in order.
      6. Receive and draw every tile in order, logging any failure.

    Args:
        assigned_ids: worker bus addresses. Each must fit in a byte, because
            the token payload is ``bytes(assigned_ids)``.

    Raises:
        ValueError: if ``assigned_ids`` is empty (raised while building the
            tiles, after the hardware has been initialised).
    """
    # LCD
    spi = SPI(0, baudrate=40_000_000, polarity=1, phase=1,
              sck=Pin(6), mosi=Pin(7), miso=Pin(4))
    lcd = ST7789(spi, LCD_W, LCD_H,
                 reset=Pin(14, Pin.OUT), dc=Pin(15, Pin.OUT),
                 cs=Pin(5, Pin.OUT), blk=Pin(13, Pin.OUT))
    lcd.backlight(True)
    lcd.clear(COLOR_BLACK)
    lcd.text("INIT", 10, 10)
    lcd.update()

    # RS485 bus and GPIO chain
    bus = RS485Stream(uart_id=0, tx=16, rx=17, dir_pin=18, baudrate=100_000)
    chain = ChainIO(in_pin=20, out_pin=21)

    # Build one tile per assigned worker.
    tiles = tiles_for_nodes(len(assigned_ids))

    # 1) ASSIGN
    print("1) ASSIGN")
    await assign_jobs(bus, assigned_ids, tiles)

    # 2) Measure the span via the GPIO chain.
    print("2) SPAN")
    span_ms = await gpio_measure_span(chain)

    # 3) Start the token. Its payload lists only the assigned workers.
    print("3) TOKEN")
    token = bytes(assigned_ids)
    await bus.send_frame(assigned_ids[0], CMD_TOKEN, token)

    # 4) Receive and draw tiles in order. Per the original design, the token
    #    then circulates among the workers on its own; the controller only
    #    listens. recv_and_draw_one is synchronous, so it blocks the event
    #    loop while a tile is received.
    print("4) DRAW")
    for tile in tiles:
        if not recv_and_draw_one(bus, lcd, tile, timeout_ms=15000):
            print("tile recv failed")
    # On-screen SPAN display is currently disabled:
    # lcd.text("SPAN: {} ms".format(span_ms), 10, 10); lcd.update()
    print(f"SPAN: {span_ms} ms")
    print("FRAME DONE")

def main(assigned_ids=None):
    """Run ``start`` once under the asyncio event loop.

    Args:
        assigned_ids: worker addresses to use. Defaults to ``[1, 3, 5, 7]``
            (example: use 4 of 8 units).

    The ``finally`` clause calls ``asyncio.new_event_loop()``. Presumably this
    resets uasyncio's global loop state after the run, even on error.
    """
    ids = [1, 3, 5, 7] if assigned_ids is None else assigned_ids
    try:
        asyncio.run(start(ids))
    finally:
        asyncio.new_event_loop()

if __name__ == "__main__":
    # Script entry point: drive workers 1, 3, 5 and 7.
    main(assigned_ids=[1, 3, 5, 7])

