#!/usr/bin/env python3
# (c) 2022 @pixie - Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License

# Import standard library modules.
import argparse
import pathlib
import io
import struct 
from collections import deque

# Import third-party packages.
from pixelblaze import *
from PIL import Image, ImageDraw

# ------------------------------------------------

def _parseMapData(mapData, pixelCount):
    """Decode a binary Pixelblaze pixel map into normalized per-axis coordinates.

    The map begins with a 12-byte header of three little-endian uint32 values
    (file version, number of dimensions, byte size of the coordinate data),
    followed by one record of `numDimensions` unsigned integers per pixel:
    uint8 for version 1, uint16 for version 2.

    Returns a list holding one list per dimension; each inner list has
    `pixelCount` coordinates scaled into the range [0, 1].

    Raises ValueError if the map format is not one this script can read, if the
    map's pixel count differs from `pixelCount`, or if the coordinate data is
    shorter than the header declares.
    """
    headerSize = 3 * 4  # first 3 longwords are the header.
    fileVersion, numDimensions, dataSize = struct.unpack('<3I', mapData[:headerSize])
    if fileVersion not in (1, 2) or numDimensions < 1:
        raise ValueError(f"Unsupported map format (version {fileVersion}, {numDimensions} dimensions).")
    wordSize = fileVersion  # v1 uses uint8, v2 uses uint16
    recordSize = wordSize * numDimensions
    numElements = dataSize // recordSize
    # If the number of elements in the worldMap doesn't match the pixelCount, it's stale and needs to be refreshed.
    if numElements != pixelCount:
        raise ValueError("Map does not match pixelCount; re-save map and try again.")

    # Decode only the declared coordinate data so trailing bytes can't add phantom pixels.
    coordinateData = mapData[headerSize:headerSize + numElements * recordSize]
    if len(coordinateData) != numElements * recordSize:
        raise ValueError("Map data is truncated; re-save map and try again.")

    maxCoordinate = (1 << (8 * wordSize)) - 1
    recordFormat = f"<{numDimensions}{'BH'[wordSize - 1]}"
    worldMap = [[] for _ in range(numDimensions)]
    for record in struct.iter_unpack(recordFormat, coordinateData):
        for dimension, value in enumerate(record):
            worldMap[dimension].append(value / maxCoordinate)
    return worldMap


def _snapToGrid(values):
    """Map one axis's normalized coordinates onto integer grid cells.

    The grid pitch is the smallest nonzero gap between two coordinates, so
    evenly spaced pixels land on consecutive cells.

    Returns (offsets, extent): offsets[i] is the cell index of values[i], and
    extent is the number of cells needed to span all of them (1 when every
    value is equal or `values` is empty).
    """
    if not values:
        return [], 1
    sortedValues = sorted(values)
    minValue, maxValue = sortedValues[0], sortedValues[-1]
    # Gaps between sorted neighbours; duplicates (zero gap) don't define a pitch.
    minDelta = min((upper - lower for lower, upper in zip(sortedValues, sortedValues[1:]) if upper > lower),
                   default=1)
    offsets = [round((value - minValue) / minDelta) for value in values]
    extent = 1 + round((maxValue - minValue) / minDelta)
    return offsets, extent


def animatePreview(imgPreview, mapData, outputScale, borderColor, fps):
    """Render captured preview scanlines as an animated PNG laid out by the pixel map.

    imgPreview: RGB image whose row N holds preview frame N; column i is the
        color of pixel i.
    mapData: binary pixel map read from the Pixelblaze.
    outputScale: edge length, in output pixels, of the square drawn for each
        pixel; squares are separated by a 1-pixel gutter of borderColor.
    borderColor: gray level (0-255) used for the background and gutters.
    fps: playback rate of the animation, in frames per second.

    Only the first two map axes are plotted: a 3D map is projected onto its
    X-Y plane, and a 1D map is drawn as a single row.

    Returns the encoded animated PNG file contents as bytes.
    Raises ValueError if there are no frames or the map can't be used.
    """
    patternFrames = imgPreview.height
    pixelCount = imgPreview.width
    if patternFrames == 0:
        raise ValueError("No preview frames to animate.")

    worldMap = _parseMapData(mapData, pixelCount)

    # Lay pixels out on the X-Y plane; an axis the map doesn't have collapses to one cell.
    offsetX, patternWidth = _snapToGrid(worldMap[0])
    if len(worldMap) > 1:
        offsetY, patternHeight = _snapToGrid(worldMap[1])
    else:
        offsetY, patternHeight = [0] * pixelCount, 1

    # Each cell is outputScale pixels plus a 1-pixel gutter; the extra +1 closes the far edge.
    cellPitch = 1 + outputScale
    outputWidth = patternWidth * cellPitch + 1
    outputHeight = patternHeight * cellPitch + 1
    gray = (borderColor, borderColor, borderColor)

    # Pantograph pixels from the captured previews into the animation frames.
    animationFrames = []
    for frameIndex in range(patternFrames):
        print(f"  rendering frame {frameIndex + 1} of {patternFrames}\r", end='')
        # Frames must outlive this iteration (they're saved below), so no `with` block here.
        animationFrame = Image.new('RGB', (outputWidth, outputHeight), gray)
        animationCanvas = ImageDraw.Draw(animationFrame)
        for pixelIndex in range(pixelCount):
            r, g, b = imgPreview.getpixel((pixelIndex, frameIndex))
            pixelX = 1 + offsetX[pixelIndex] * cellPitch
            pixelY = 1 + offsetY[pixelIndex] * cellPitch
            animationCanvas.rectangle(xy=[(pixelX, pixelY), (pixelX + outputScale - 1, pixelY + outputScale - 1)],
                                      outline=gray, fill=(r, g, b), width=0)
        animationFrames.append(animationFrame)

    # Encode all frames into one looping animated PNG in memory.
    with io.BytesIO() as imageBuffer:
        animationFrames[0].save(imageBuffer, format="PNG", save_all=True, append_images=animationFrames[1:],
                                duration=1000 / fps, loop=0)
        return imageBuffer.getvalue()

# ------------------------------------------------

# Record preview frames from a Pixelblaze identified by its IP address, then save them as an animated PNG.
def recordVideo(ipAddress, fileName):
    """Record preview frames from a Pixelblaze and save them as an animated PNG.

    Connects to the Pixelblaze at `ipAddress`, disables its LEDs and maximizes
    brightness for the duration of the capture, and collects preview frames until
    the user presses Ctrl-C. The saved settings are restored afterwards (even if
    capturing fails). Captured frames are then rendered with animatePreview and
    written to `fileName` with its suffix replaced by `.png`.
    """
    pb = Pixelblaze(ipAddress)
    majorVersion = pb.getVersionMajor()
    minorVersion = pb.getVersionMinor()
    if majorVersion < 2 or (majorVersion == 2 and minorVersion < 29) or (majorVersion == 3 and minorVersion < 24):
        print("Sorry; recordings can only be generated with firmware versions v2.29/v3.24 and higher.")
        return

    # Snapshot the current settings so they can be restored after capturing, and fetch
    # the pixel map used to lay out the rendered animation.
    savedConfigSettings = pb.getConfigSettings()
    mapData = pb.getMapData()

    allRows = []
    print(f"Capturing frames from {pb.getDeviceName()}...press [Ctrl]-[C] to stop.")
    try:
        # Disable the LEDs so we can maximize the brightness without affecting the current draw.
        pb.setLedType(pb.ledTypes.noLeds)
        pb.setBrightnessLimit(100)
        pb.setBrightnessSlider(1)
        pb.setSendPreviewFrames(True)

        while True:
            try:
                print(f"  capturing frame {len(allRows) + 1}\r", end='')
                scanLine = pb.getPreviewFrame()
                if scanLine is None:
                    # If we miss a frame, try starting over so it doesn't spoil the pattern.
                    print(f"missed a preview frame; retrying.")
                    pb.setSendPreviewFrames(True)
                else:
                    allRows.append(scanLine)
            except KeyboardInterrupt:
                # Ctrl-C is the normal way to end the capture.
                break
    finally:
        # Restore the settings even if capturing failed, so the device isn't left with its LEDs disabled.
        pb.setSendPreviewFrames(False)
        pb.setBrightnessLimit(pb.getBrightnessLimit(savedConfigSettings))
        pb.setBrightnessSlider(pb.getBrightnessSlider(savedConfigSettings))
        pb.setLedType(pb.getLedType(savedConfigSettings), dataSpeed=pb.getDataSpeed(savedConfigSettings))

    # If we got something, render it.
    frameCount = len(allRows)
    print(f"\rCaptured {frameCount} frames.".ljust(40, ' '))
    if frameCount > 0:
        # Each scanline holds three bytes (R, G, B) per pixel, in index order; infer the
        # width from the first one and stack the scanlines as rows of the preview image.
        previewWidth = len(allRows[0]) // 3
        rowBytes = previewWidth * 3
        pixelData = b"".join(bytes(row[:rowBytes]) for row in allRows)
        allRows.clear()
        with Image.frombytes("RGB", (previewWidth, frameCount), pixelData) as imgPreview:
            # ...generate an animated version of the pattern
            animatedImage = animatePreview(imgPreview, mapData, 5, 63, 30)
        pathlib.Path(fileName).with_suffix('.png').write_bytes(animatedImage)

    print("\rComplete!".ljust(40, ' '))

# ------------------------------------------------

# Here's where the magic happens.
if __name__ == "__main__":

    # Parse command line.
    parser = argparse.ArgumentParser(
        description="Record preview frames from a Pixelblaze and save them as an animated PNG.")
    parser.add_argument("--ipAddress", required=True, help="The IP address of the Pixelblaze to record")
    parser.add_argument("--fileName", required=True, help='The filename for the recording to be created')
    args = parser.parse_args()

    # Falling off the end exits with status 0; the site-module exit() is meant for interactive use only.
    recordVideo(args.ipAddress, args.fileName)
