import React, { FC, useMemo } from 'react'
import {
  ComposedChart,
  Line,
  Scatter,
  XAxis,
  YAxis,
  ResponsiveContainer,
  YAxisProps,
  Tooltip as RechartsTooltip,
  Area,
  Cell,
  Dot,
  DotProps
} from 'recharts'
import { CustomTooltipProps, LineFitData, ReturnCurveChartProps, ReturnCurveData } from './types'
import { useTheme } from '@mui/material/styles'
import { CurveType } from 'recharts/types/shape/Curve'

/** Radius of each scatter-plot dot (passed as `r` to every scatter `<Cell>`). */
const SMALL_DOT_RADIUS = 5
/** Radius of the highlighted current/optimal point markers. */
const LARGE_DOT_RADIUS = 10

/**
 * Splits the fitted curve into three series so the chart can style them
 * separately:
 * - `yFitA`: points with x <= min(currentX, optimalX)
 * - `yFitB`: points with min < x <= max (the stretch between the two points)
 * - `yFitC`: points with x > max(currentX, optimalX)
 *
 * Where two consecutive points fall into different series, the earlier point
 * is also written into the later series. Recharts breaks a `Line` wherever
 * its value is missing, so without this shared boundary point the segments
 * would be drawn with a gap between them (and a middle stretch containing a
 * single point would not be drawn at all).
 *
 * Assumes `lineFitData` is ordered by ascending x, which is the order in which
 * recharts connects the points. Any other fields on each point (presumably
 * `confidenceInterval`, which the chart's `Area` reads) are copied through
 * unchanged.
 */
const getLines = ({
  lineFitData,
  currentPoint,
  optimalPoint
}: {
  lineFitData: LineFitData[]
  currentPoint: ReturnCurveData
  optimalPoint: ReturnCurveData
}): Array<{ x: number; yFitA?: number; yFitB?: number; yFitC?: number }> => {
  type FitSeriesKey = 'yFitA' | 'yFitB' | 'yFitC'

  const minimumX = Math.min(currentPoint.x, optimalPoint.x)
  const maximumX = Math.max(currentPoint.x, optimalPoint.x)

  /** Series a point at `x` belongs to. */
  const seriesFor = (x: number): FitSeriesKey => {
    if (x <= minimumX) return 'yFitA'
    if (x <= maximumX) return 'yFitB'
    return 'yFitC'
  }

  return lineFitData.map(({ x, y, ...rest }, index) => {
    const series = seriesFor(x)
    const point: { x: number } & Partial<Record<FitSeriesKey, number>> = { x }
    point[series] = y

    // Share the boundary point with the next series so the segments join up.
    const next = lineFitData[index + 1]
    if (next != null) {
      const nextSeries = seriesFor(next.x)
      if (nextSeries !== series) point[nextSeries] = y
    }

    return { ...point, ...rest }
  })
}

/**
 * Shape renderer for every scatter point. Recharts supplies the point's
 * position (`cx`, `cy`) along with the `fill`/`r` set on the matching `<Cell>`.
 *
 * Declared at module scope so its component type is stable between renders.
 * A component defined inside `ReturnCurveChart` would be a new type on every
 * render, making React remount all of the dots each time the chart re-renders.
 */
const RenderDot: FC<DotProps> = ({ cx, cy, r, fill }) => {
  return <Dot cx={cx} cy={cy} fill={fill} r={r} />
}

/** Props recharts passes to tooltip content, plus the caller's renderer. */
type ChartTooltipProps = CustomTooltipProps & Pick<ReturnCurveChartProps, 'renderTooltip'>

/**
 * Tooltip content adapter. Recharts clones this element and supplies `active`
 * and `payload`. The first payload entry is forwarded to the caller's
 * `renderTooltip`.
 *
 * Renders nothing when the tooltip is inactive, has no payload entries, or no
 * renderer was provided. Kept at module scope (receiving `renderTooltip` as a
 * prop) for the same stable-identity reason as `RenderDot`.
 */
const CustomTooltip = ({ active = false, payload, renderTooltip }: ChartTooltipProps): React.ReactElement | null => {
  const firstEntry = payload?.[0]
  if (!active || firstEntry == null || renderTooltip == null) {
    return null
  }

  return renderTooltip({ dataKey: firstEntry.dataKey, payload: firstEntry.payload })
}

/**
 * Return-curve chart. It draws:
 * - a fitted curve as three `Line` series split at the current and optimal
 *   points, with the stretch between them dashed;
 * - a confidence-interval `Area`;
 * - optional scatter series;
 * - large markers at the current and optimal points.
 *
 * `showLine` toggles the confidence area, the fitted lines and the two markers
 * together. `showScatterPlot` toggles the scatter series, with one series per
 * field of the scatter records other than `x`, `label` and `date`.
 *
 * Caller-supplied `xAxis`, `yAxis`, `lineProps` and `confidenceIntervalProps`
 * are merged over theme-derived defaults. Axis `label` objects are merged key
 * by key. Any remaining props are forwarded to `ComposedChart`.
 */
const ReturnCurveChart = ({
  containerWidth = '100%',
  containerHeight = '100%',
  showScatterPlot = false,
  showLine = true,
  lineProps,
  confidenceIntervalProps,
  scatterPlotData,
  currentPoint,
  optimalPoint,
  xAxis,
  yAxis,
  lineFitData,
  getDotColor,
  renderTooltip,
  ...rest
}: ReturnCurveChartProps): React.ReactElement => {
  const theme = useTheme()

  // Fit data tagged with yFitA/yFitB/yFitC so each segment gets its own <Line>.
  const processedLineFitData = useMemo(
    () => getLines({ lineFitData, currentPoint, optimalPoint }),
    [lineFitData, currentPoint, optimalPoint]
  )

  const defaultXAxisProps = {
    dataKey: 'x',
    tickLine: false,
    axisLine: { stroke: theme.palette.neutrals.stone100 },
    tick: { fontWeight: theme.typography.fontWeightBold },
    stroke: theme.palette.secondary.main,
    type: 'number' as const
  }

  const defaultYAxisProps: YAxisProps = {
    axisLine: false,
    tickLine: false,
    tick: { textAnchor: 'end' },
    stroke: theme.palette.secondary.main,
    type: 'number'
  }

  const defaultYLabelProps = {
    position: 'top',
    offset: 15,
    stroke: theme.palette.secondary.main,
    style: { textAnchor: 'middle' }
  }

  const defaultXLabelProps = {
    stroke: theme.palette.secondary.main,
    position: 'bottom'
  }

  const defaultLineProps = {
    strokeWidth: 4,
    stroke: theme.palette.secondary.main,
    dot: false,
    isAnimationActive: false,
    type: 'natural' as CurveType
  }

  const defaultConfidenceIntervalProps = {
    isAnimationActive: false,
    dataKey: 'confidenceInterval',
    fill: theme.palette.custom.grey,
    stroke: 'none',
    type: 'monotone' as CurveType
  }

  const linePropsWithDefaults = { ...defaultLineProps, ...lineProps }

  // One scatter series per field of a scatter record, excluding the x
  // coordinate and the `label`/`date` metadata. An empty or missing dataset
  // yields no series instead of throwing inside Object.keys.
  const firstScatterPoint = scatterPlotData?.[0]
  const yScatterDataKeys =
    firstScatterPoint == null
      ? []
      : Object.keys(firstScatterPoint).filter((key) => key !== 'x' && key !== 'label' && key !== 'date')

  const optimalCurrentPoints = [optimalPoint, currentPoint]

  return (
    <ResponsiveContainer width={containerWidth} height={containerHeight}>
      <ComposedChart data={processedLineFitData} {...rest}>
        <XAxis
          {...{
            ...defaultXAxisProps,
            ...xAxis,
            label: {
              ...defaultXLabelProps,
              ...xAxis?.label
            }
          }}
        />
        <RechartsTooltip content={<CustomTooltip renderTooltip={renderTooltip} />} />
        {showLine && <Area {...{ ...defaultConfidenceIntervalProps, ...confidenceIntervalProps }} />}
        {showScatterPlot &&
          yScatterDataKeys.map((dataKey) => (
            <Scatter
              key={dataKey}
              data={scatterPlotData}
              dataKey={dataKey}
              isAnimationActive={false}
              shape={<RenderDot />}
            >
              {scatterPlotData.map((entry, i) => (
                <Cell key={`cell-${i}`} fill={getDotColor(entry)} r={SMALL_DOT_RADIUS} />
              ))}
            </Scatter>
          ))}
        {showLine && (
          <>
            <Line dataKey="yFitA" {...linePropsWithDefaults} />
            {/* Segment between the current and optimal points is dashed. */}
            <Line dataKey="yFitB" strokeDasharray="3 3" {...linePropsWithDefaults} />
            <Line dataKey="yFitC" {...linePropsWithDefaults} />

            <Scatter data={optimalCurrentPoints} dataKey="y" isAnimationActive={false} shape={<RenderDot />}>
              {optimalCurrentPoints.map((entry, i) => (
                <Cell key={`cell-${i}`} fill={getDotColor(entry)} r={LARGE_DOT_RADIUS} />
              ))}
            </Scatter>
          </>
        )}

        <YAxis
          {...{
            ...defaultYAxisProps,
            ...yAxis,
            label: {
              ...defaultYLabelProps,
              ...yAxis?.label
            }
          }}
        />
      </ComposedChart>
    </ResponsiveContainer>
  )
}
export default ReturnCurveChart
