package com.tekdiving.plotter

import io.data2viz.axis.Orient
import io.data2viz.axis.axis
import io.data2viz.color.Color
import io.data2viz.geom.Size
import io.data2viz.math.Percent
import io.data2viz.math.pct
import io.data2viz.scale.Scales
import io.data2viz.viz.*
import org.w3c.dom.HTMLCanvasElement
import kotlin.math.roundToInt
import kotlin.properties.Delegates.observable

/**
 * Interval of axis values bounded by [min] and [max].
 *
 * @property min lower bound of the interval.
 * @property max upper bound of the interval.
 */
data class Range(
    val min: Double,
    val max: Double,
)

/** Shorthand factory for [Range]: builds the interval going from [min] to [max]. */
fun range(min: Double, max: Double): Range = Range(min = min, max = max)

/** How consecutive points of a [Line] are joined together when drawn. */
sealed interface LineType

/** Joins consecutive points with straight segments. */
object DirectLine : LineType

/**
 * Joins consecutive points with quadratic curves.
 *
 * For a segment going from (x0, y0) to (x1, y1), the curve's control point is placed at
 * `x0 + (x1 - x0) * cpxPercent` horizontally and `y0 + (y1 - y0) * cpyPercent` vertically.
 *
 * @property cpxPercent horizontal position of the control point along the segment.
 * @property cpyPercent vertical position of the control point along the segment.
 */
data class QuadraticLine(
    val cpxPercent: Percent = 30.pct,
    val cpyPercent: Percent = 40.pct,
) : LineType

/**
 * One series drawn by a [LineChart].
 *
 * @property label label of the series.
 * @property strokeColor colour of the line stroke, also used to fill the point bullets.
 * @property fillColor fill applied to the line path, or `null` for no fill.
 * @property strokeWidth width of the line stroke.
 * @property lineType how consecutive points are joined.
 * @property pointRadius radius of the bullet drawn on each point; bullets are only drawn when
 *   their diameter is larger than [strokeWidth].
 * @property points data points, joined in list order.
 */
data class Line(
    val label: String,
    val strokeColor: Color,
    val fillColor: Color? = null,
    val strokeWidth: Double = 1.0,
    val lineType: LineType = DirectLine,
    val pointRadius: Double = 2.0,
    val points: List<LinePoint> = emptyList(),
)

/**
 * A single data point of a [Line].
 *
 * @property x integer step on the x axis.
 * @property y value on the y axis.
 */
data class LinePoint(
    val x: Int,
    val y: Double,
)

/**
 * Line chart drawn with data2viz on an HTML canvas.
 *
 * Every non-empty [Line] of [lines] is drawn as a path (plus optional point bullets) inside
 * [margins], with a y axis on the left and an x axis at the bottom. Changes to the scene only
 * mark the chart as dirty through [invalidate]; the canvas is repainted by the next
 * [renderRequest] (presumably called from a frame loop by the owner — TODO confirm).
 *
 * @param canvas canvas the [visual] renderer is bound to.
 * @param lines series to draw. Lines without points are not drawn and do not influence the
 *   default axis ranges.
 * @param size initial size of the whole visual, margins included.
 * @param onStepMove invoked with the x step under the mouse (rounded, clamped to [xRange])
 *   each time that step changes.
 */
class LineChart(
    val canvas: HTMLCanvasElement,
    val lines: List<Line>,
    size: Size,
    val onStepMove: (Int) -> Unit = { },
) {

    /** Dirty flag: set by [invalidate], cleared by [renderRequest] once the visual is repainted. */
    private var redraw = false

    /** Size of the whole visual; changing it resizes [visual], rescales both axes and rebuilds. */
    var size: Size by observable(size) { _, old, new ->
        if (old != new) {
            visual.size = new
            xScale = newXScale()
            yScale = newYScale()
            rebuild()
        }
    }

    // TODO make margins and size mutable
    // TODO checks that margins aren't bigger than size
    val margins = Margins(40.5, 30.5, 50.5, 50.5)

    /** Width of the plot area: [size] width minus the horizontal margins. */
    val chartWidth get() = size.width - margins.hMargins

    /** Height of the plot area: [size] height minus the vertical margins. */
    val chartHeight get() = size.height - margins.vMargins

    /** Domain of the x axis; defaults to the x extent of all points. Changing it rebuilds. */
    var xRange: Range by observable(computeXRange()) { _, old, new ->
        if (old != new) {
            xScale = newXScale()
            rebuild()
        }
    }

    /** Domain of the y axis; defaults to the y extent of all points. Changing it rebuilds. */
    var yRange: Range by observable(computeYRange()) { _, old, new ->
        if (old != new) {
            yScale = newYScale()
            rebuild()
        }
    }

    /**
     * Smallest and largest x over every point of every line, or (0, 0) when there is no point.
     *
     * Points are pooled across lines so that a line without points does not pull the range
     * towards 0 (such lines are not drawn by [build] either).
     */
    private fun computeXRange(): Range {
        val allPoints = lines.flatMap { it.points }
        return range(
            (allPoints.minOfOrNull { it.x } ?: 0).toDouble(),
            (allPoints.maxOfOrNull { it.x } ?: 0).toDouble(),
        )
    }

    /**
     * Smallest and largest y over every point of every line, or (0.0, 0.0) when there is no point.
     *
     * Points are pooled across lines so that a line without points does not pull the range
     * towards 0 (such lines are not drawn by [build] either).
     */
    private fun computeYRange(): Range {
        val allPoints = lines.flatMap { it.points }
        return range(
            allPoints.minOfOrNull { it.y } ?: 0.0,
            allPoints.maxOfOrNull { it.y } ?: 0.0,
        )
    }

    // linear scale for x
    private var xScale = newXScale()

    /** Maps [xRange] onto `[0, chartWidth]` (rounded output). */
    private fun newXScale() = Scales.Continuous.linearRound {
        domain = listOf(xRange.min, xRange.max)
        range = listOf(0.0, chartWidth)
    }

    // linear scale for y
    private var yScale = newYScale()

    /** Maps [yRange] onto `[chartHeight, 0]`. */
    private fun newYScale() = Scales.Continuous.linear {
        domain = listOf(yRange.min, yRange.max)
        // y is mapped in reverse order because (0, 0) is the top-left corner (as in SVG and JavaFX).
        range = listOf(chartHeight, 0.0)
    }

    /** Bottom axis, placed 10 units below the plot area. */
    private fun GroupNode.xAxis() = group {
        transform { translate(y = chartHeight + 10) }
        axis(Orient.BOTTOM, xScale)
    }

    /** Left axis, placed 10 units left of the plot area. */
    private fun GroupNode.yAxis() = group {
        transform { translate(x = -10.0) }
        axis(Orient.LEFT, yScale)
    }

    /** Clears this viz and redraws axes, line paths and point bullets from the current scales. */
    private fun Viz.build() {
        clear()
        group {
            transform { translate(x = margins.left, y = margins.top) }
            group { yAxis() }
            group { xAxis() }

            group {
                for (line in lines) {
                    val points = line.points
                    if (points.isNotEmpty()) {
                        group {
                            // adds line path
                            path {
                                fill = line.fillColor
                                strokeColor = line.strokeColor
                                strokeWidth = line.strokeWidth

                                var previousX = xScale(points[0].x)
                                var previousY = yScale(points[0].y)
                                moveTo(previousX, previousY)
                                for (i in 1 until points.size) {
                                    val x = xScale(points[i].x)
                                    val y = yScale(points[i].y)

                                    when (line.lineType) {
                                        is DirectLine -> lineTo(x, y)
                                        is QuadraticLine -> {
                                            // control point is a fixed fraction of the way along the segment
                                            val cpx = previousX + (x - previousX) * line.lineType.cpxPercent.value
                                            val cpy = previousY + (y - previousY) * line.lineType.cpyPercent.value
                                            quadraticCurveTo(cpx, cpy, x, y)
                                        }
                                    }

                                    previousX = x
                                    previousY = y
                                }
                            }
                        }

                        // adds point bullet, only when it would be visible past the stroke
                        if (line.pointRadius * 2.0 > line.strokeWidth) {
                            for (point in points) {
                                circle {
                                    x = xScale(point.x)
                                    y = yScale(point.y)
                                    fill = line.strokeColor
                                    radius = line.pointRadius
                                }
                            }
                        }

                        // TODO adds fill path
                    }
                }
            }
        }
    }


    /** Last step reported to [onStepMove]; -1 until the mouse first moves over the chart. */
    private var currentStep = -1

    /** The data2viz visual, built once here and rebuilt by [rebuild]; bound to [canvas]. */
    val visual: Viz = viz {
        this.size = this@LineChart.size
        build()

        on(KMouseMove) { event ->
            // TODO there is a DPI problem with the position, it needs to be solved
            val step = xScale
                .invert(event.pos.x - margins.left)
                .coerceIn(xRange.min, xRange.max)
                .roundToInt()
            if (step != currentStep) onStepMove(step)
            currentStep = step
        }

        bindRendererOn(canvas)
    }

    /** Marks the chart as needing a repaint on the next [renderRequest]. */
    fun invalidate() {
        redraw = true
    }

    /** Repaints [visual] if [invalidate] was called since the last repaint; otherwise does nothing. */
    fun renderRequest() {
        if (redraw) {
            visual.render()
            redraw = false
        }
    }

    /** Rebuilds the scene from the current scales and lines, then schedules a repaint. */
    fun rebuild() {
        visual.build()
        invalidate()
    }

    /**
     * For each line, the y of the first point whose x is at least [x], or 0.0 when none.
     *
     * NOTE(review): this returns the nearest point at or after [x] only if points are sorted by x —
     * confirm callers guarantee that ordering.
     */
    fun pointsForLabel(x: Int): List<Double> = lines.map { line ->
        line.points.firstOrNull { point -> point.x >= x }?.y ?: 0.0
    }

}

