import React, { useMemo, useEffect, useCallback } from 'react'
import { LinePath } from '@visx/shape'
import { curveLinear } from '@visx/curve'
import { Group } from '@visx/group'
import { GridRows } from '@visx/grid'
import { scaleLinear, scaleOrdinal, scalePoint } from '@visx/scale'
import { schemeTableau10 as schemeSet } from 'd3-scale-chromatic'
import { extent } from 'd3-array'
import { theme } from 'twin.macro'

import { getPrecision, valueFormatter } from '../util/charts'
import {
  MONTHS,
  QUARTERS,
  YEARS,
} from '../connectors/dataSelection/filters/useTime'
import useSvgTooltip from '../connectors/useSvgTooltip'
import Axes from './Axes'
import ChartTitle from './ChartTitle'
import Legend from './Legend'
import TooltipContent from './TooltipContent'
import { getRoundedString, getSubtitle } from '../util'
import useData from '../data/useData'

// Shared ordinal color scale, created once per module with the Tableau10
// palette. TimeSeries reassigns its domain (and, for org-code series, its
// range) while rendering, so its state is shared by every chart instance.
const colorScale = scaleOrdinal({ range: schemeSet })

// Chart margins used when the `margin` prop is not supplied
const defaultMargin = { top: 50, left: 90, right: 40, bottom: 90 }

// Point accessors: pull the y / x coordinate out of a `{ x, y }` data point
const getY = ({ y }) => y
const getX = ({ x }) => x

/**
 * Derive the time interval from the textual form of a date label.
 *
 * Labels containing `Q<digit>/<4 digits>` are quarters, labels containing
 * `<4 digits>-<2 digits>` are months, everything else is treated as years.
 *
 * @param {string} dateText - date label to classify
 * @returns QUARTERS, MONTHS or YEARS
 */
const getInterval = (dateText) => {
  // Ordered: the quarter pattern is tried before the month pattern
  const matchers = [
    [/(Q\d)\/(\d{4})/, QUARTERS],
    [/(\d{4})-(\d\d)/, MONTHS],
  ]

  const hit = matchers.find(([pattern]) => pattern.test(dateText))

  return hit ? hit[1] : YEARS
}

/**
 * Trim the leading keys of a time axis so that it starts on the first
 * quarter (`Q1…`) or the first month (`…01`) found in the list.
 *
 * For any other interval, or when no such starting key exists, the keys are
 * returned unchanged. (Previously a missing start key made `findIndex`
 * return -1, and `slice(-1)` silently dropped everything but the last key.)
 *
 * @param {string[]} level1Keys - ordered x-axis keys
 * @param interval - QUARTERS, MONTHS or another interval constant
 * @returns {string[]} the keys from the first period start onwards
 */
const getLevel1Keys = (level1Keys, interval) => {
  let isPeriodStart

  if (interval === QUARTERS) {
    isPeriodStart = (key) => key.startsWith('Q1')
  } else if (interval === MONTHS) {
    isPeriodStart = (key) => key.endsWith('01')
  } else {
    return level1Keys
  }

  const startIndex = level1Keys.findIndex(isPeriodStart)

  // No period start in the data: keep all keys rather than slicing from -1
  return startIndex === -1 ? level1Keys : level1Keys.slice(startIndex)
}

/**
 * Pick the subset of x-axis keys that should receive a tick label.
 *
 * Quarters keep keys starting with `Q1` or `Q3`; months keep keys ending in
 * `-01`, `-05` or `-09`. Any other interval yields `undefined`.
 *
 * @param {string[]} level1Keys - ordered x-axis keys
 * @param interval - QUARTERS, MONTHS or another interval constant
 * @returns {string[]|undefined} keys to label, or undefined for other intervals
 */
const getLabelValues = (level1Keys, interval) => {
  switch (interval) {
    case QUARTERS: {
      const labelledPrefixes = ['Q1', 'Q3']
      return level1Keys.filter((key) =>
        labelledPrefixes.some((prefix) => key.startsWith(prefix))
      )
    }
    case MONTHS: {
      const labelledMonth = /-0[159]$/
      return level1Keys.filter((key) => labelledMonth.test(key))
    }
    default:
      return undefined
  }
}

/**
 * Build one continuous series of `{ x, y }` points per level2 name.
 *
 * Every series contains a point for each key in `level1Keys` (y defaults to
 * 0) so that all lines span the whole x axis; a level2 node then overwrites
 * the y of the point whose x equals its parent's name. Nodes whose parent
 * name is not in `level1Keys` leave the series untouched.
 *
 * @param {object[]} level2 - nodes exposing `data.name`, `parent.data.name` and `value`
 * @param {string[]} level1Keys - ordered x-axis keys
 * @returns {Object<string, {x: string, y: number}[]>} points keyed by level2 name
 */
const createSeries = (level2, level1Keys) => {
  // Unique level2 names, in order of first appearance
  const names = new Set()
  for (const node of level2) {
    names.add(node.data.name)
  }

  // Zero-filled series: one fresh point per x-axis key for every name
  const series = Object.fromEntries(
    [...names].map((name) => [name, level1Keys.map((x) => ({ x, y: 0 }))])
  )

  // Write each node's value into the point that matches its parent
  for (const node of level2) {
    const x = node.parent.data.name
    const y = node.value

    for (const point of series[node.data.name]) {
      if (point.x === x) {
        point.y = y
      }
    }
  }

  return series
}

/*
 * the data comes in getRoot, not root, because storybook doesn't tolerate
 * props that contain cycles.
 */
export default function TimeSeries({
  getRoot,
  parentWidth: width,
  parentHeight: height,
  margin = defaultMargin,
  setFilterDefinitions,
  sortByLevel,
  getKeyFromDepth,
}) {
  if (typeof getRoot !== 'function') {
    throw new Error('Need to supply a getRoot function')
  }
  const { orgCodes, orgColors } = useData()
  const root = useMemo(getRoot, [getRoot])

  const xMax = width - margin.left - margin.right
  const yMax = height - margin.top - margin.bottom

  const xScale = scalePoint({
    range: [0, xMax],
    padding: 0.2,
  })
  const yScale = scaleLinear({
    domain: [0, root.value],
    range: [yMax, 0],
    nice: true,
  })

  // Update scale based on children
  const level1 = root.descendants().filter((node) => node.depth === 1)
  const level2 = root.descendants().filter((node) => node.depth === 2)

  let data
  let keys
  let labelValues
  const unit = root.data.unit

  if (level1.length) {
    let level1Keys = level1.map((d) => d.data.name)

    const level1FilterKey = getKeyFromDepth(level1[0].depth - 1)
    if (level1FilterKey === 'time') {
      const interval = getInterval(level1[0].data.name)
      level1Keys = getLevel1Keys(level1Keys, interval)
      labelValues = getLabelValues(level1Keys, interval)
    }

    xScale.domain(level1Keys)

    if (!level2.length) {
      // Create data array with points where x = child.data.name and y = value
      data = root.children.reduce(
        (acc, curr) => [
          ...acc,
          {
            x: curr.data.name,
            y: curr.value,
          },
        ],
        []
      )

      // Set yScale
      yScale.domain(extent(level1, (d) => d.value)).nice()
    }

    if (level2.length) {
      const series = createSeries(level2, level1Keys)

      data = Object.values(series)

      // Set yScale
      const allValues = data.reduce(
        (all, curr) => [...all, ...curr.map((d) => d.y)],
        []
      )
      yScale.domain(extent(allValues)).nice()

      // Set colorScale according to the number of unique level2 options and fix sorting order
      keys = Object.keys(series).sort((a, b) => sortByLevel(a, b, 1))

      // Check if level2 key === org_code
      const isOrgCode = level2.find((d) => d.data.color)
      colorScale.domain(keys)
      if (isOrgCode) {
        const domain = [...new Set(level2.map((d) => d.data.name))].sort()
        const range = domain.map((d) => {
          const key = Object.keys(orgCodes).find((k) => orgCodes[k] === d)
          return orgColors[key] ?? theme`colors.blue.DEFAULT`
        })
        colorScale.domain(domain).range(range)
      }
    }
  }

  useEffect(() => {
    if (!level1.length) {
      setFilterDefinitions((prev) =>
        prev.map((filter) => ({
          ...filter,
          active: filter.level === 1,
        }))
      )
    } else if (!level2.length) {
      setFilterDefinitions((prev) =>
        prev.map((filter) => ({
          ...filter,
          active: filter.level === 2,
        }))
      )
    } else {
      setFilterDefinitions((prev) =>
        prev.map((filter) => ({
          ...filter,
          active: false,
        }))
      )
    }
  }, [level1.length, level2.length, setFilterDefinitions])

  const {
    containerRef,
    handlePointer,
    TooltipInPortal,
    hideTooltip,
    tooltipLeft,
    tooltipTop,
    tooltipOpen,
    tooltipData,
    tooltipStyles,
  } = useSvgTooltip()

  const handleTooltip = useCallback(
    (e, { label, value }) => {
      handlePointer(e, {
        label,
        value: getRoundedString(value, getPrecision(value)),
        unit,
      })
    },
    [handlePointer, unit]
  )

  return width < 10 || root == null ? null : (
    <div tw="relative overflow-hidden">
      <svg width={width} height={height} ref={containerRef}>
        <rect width={width} height={height} fill="none" />
        <Group top={margin.top} left={margin.left}>
          <GridRows
            scale={yScale}
            width={xMax}
            stroke={theme`colors.grey.4`}
            pointerEvents="none"
          />
          <ChartTitle name={root.data.name} unit={root.data.unit} />

          {level1.length && !level2.length && (
            <Group>
              <CustomLinePath
                data={data}
                x={(d) => xScale(getX(d)) ?? 0}
                y={(d) => yScale(getY(d)) ?? 0}
                stroke={colorScale('root')}
              />
              {data.map((d, j) => (
                <LineCircle
                  key={`circle-${j}`}
                  cx={xScale(getX(d)) ?? 0}
                  cy={yScale(getY(d)) ?? 0}
                  fill={colorScale('root')}
                  onPointerEnter={(e) =>
                    handleTooltip(e, { label: getX(d), value: getY(d) })
                  }
                  onPointerLeave={hideTooltip}
                />
              ))}
            </Group>
          )}

          {level2.length &&
            data.map((item, i) => {
              return (
                <Group key={`line-item-${i}`}>
                  <CustomLinePath
                    data={item}
                    x={(d) => xScale(getX(d)) ?? 0}
                    y={(d) => yScale(getY(d)) ?? 0}
                    stroke={colorScale(colorScale.domain()[i])}
                  />
                  {item.map((d, j) => (
                    <LineCircle
                      key={`circle-${i}-${j}`}
                      cx={xScale(getX(d)) ?? 0}
                      cy={yScale(getY(d)) ?? 0}
                      fill={colorScale(colorScale.domain()[i])}
                      onPointerEnter={(e) =>
                        handleTooltip(e, { label: getX(d), value: getY(d) })
                      }
                      onPointerLeave={hideTooltip}
                    />
                  ))}
                </Group>
              )
            })}

          <Axes
            xScale={xScale}
            yScale={yScale}
            top={yMax}
            hasLabels={level1.length}
            tickValues={labelValues}
            leftTickFormat={valueFormatter}
          />
        </Group>

        {tooltipOpen && (
          <>
            <TooltipInPortal
              left={tooltipLeft}
              top={tooltipTop}
              style={tooltipStyles}
            >
              <TooltipContent
                indicator={root.data.name}
                label={tooltipData.label}
                subLabel={getSubtitle(tooltipData.value, tooltipData.unit)}
              />
            </TooltipInPortal>
          </>
        )}
      </svg>

      {/* Only show if level2 selected */}
      {!!level2.length && (
        <div tw="absolute top-10 left-24">
          <Legend scale={colorScale} isHorizontal={false} />
        </div>
      )}
    </div>
  )
}

function CustomLinePath({ data, x, y, stroke }) {
  return (
    <LinePath
      data={data}
      x={x}
      y={y}
      stroke={stroke}
      strokeWidth={3}
      curve={curveLinear}
    />
  )
}

function LineCircle({ cx, cy, fill, ...rest }) {
  return (
    <circle
      r={4}
      cx={cx}
      cy={cy}
      fill={fill}
      stroke="white"
      strokeWidth={1}
      {...rest}
    />
  )
}
