import React, { useCallback, useEffect, useRef, useState } from "react";
import {
  select,
  axisBottom,
  scaleLinear,
  axisLeft,
  extent,
  ScaleOrdinal,
  scaleBand,
  scaleOrdinal,
  timeFormat,
} from "d3";
import { GraphMargins, formatter } from "./common";
import { Swatches } from "./Swatches";
// https://codesandbox.io/p/sandbox/d3-react-stacked-barchart-hnyre

/**
 * An array that also carries a string `key` property.
 *
 * In this file it models one chart layer: the elements are one category's
 * bar segments and `key` is that category's name (used for React keys,
 * colour lookup and click/hover reporting).
 */
interface ArrayWithKey<T> extends Array<T> {
  /** Category name the layer belongs to. */
  key: string;
}

/**
 * Everything drawn for a single date: the date itself plus its stacked segments.
 *
 * Each entry is `[category, value, start, end]`, where `start` is the running
 * stack offset the segment begins at and `end = start + value`.
 */
interface Data {
  /** Date whose bar this is; used as the x (band) position. */
  date: Date;
  /** Stacked segments for this date, in the order they were stacked. */
  entries: Array<[category: string, value: number, start: number, end: number]>;
}

/**
 * A numeric array with the owning day's {@link Data} attached as `data`.
 *
 * In this file every instance is a two-element `[lower, upper]` pair giving
 * one bar segment's vertical extent (the smaller stack offset first), and
 * `data` is the day the segment belongs to (read for x position, click
 * reporting and the tooltip).
 */
interface ArrayWithData extends Array<number> {
  /** The day this segment was built from. */
  data: Data;
}

/** Signature of the callback fired when a bar segment is clicked. */
type StackClickHandler = (date: Date, category: string, total: number) => void;

/** Props accepted by the stacked bar chart component. */
interface StackedBarChartProps {
  /** Values keyed first by date, then by category name. */
  data: Map<Date, Map<string, number>>;
  /** Receives the clicked segment's date, its category, and that category's value for the date. */
  stackClicked: StackClickHandler;
}

/**
 * Pixel padding around the plot area. Left/right are subtracted from the SVG's
 * client width; the plot group is offset by `left` horizontally.
 */
const margins: GraphMargins = { top: 10, right: 50, bottom: 40, left: 100 };

/**
 * Fixed fill colour per known category.
 *
 * Insertion order matters: the keys and values are handed, in order, to an
 * ordinal colour scale, so each category is paired with the colour on its row.
 */
const COLOR_MAP = new Map<string, string>([
  ["Auto & Transport", "#222"],
  ["Clothing", "#80b1d3"],
  ["Dog", "#e5c494"],
  ["Entertainment", "#bebada"],
  ["Food & Drink", "#fb8072"],
  ["Gift", "#fccde5"],
  ["Health & Wellness", "#8dd3c7"],
  ["Home & Utilities", "#d9d9d9"],
  ["Income", "#b3de69"],
  ["Investments", "#ccebc5"],
  ["Taxes", "#bc80bd"],
  ["Vacation", "#ffed6f"],
]);

/**
 * Stacked layout derived from the raw chart data: one layer per visible
 * category, plus the domains the scales and the legend need.
 */
interface StackedLayout {
  /** One layer per visible category (sorted by name); each segment is `[lower, upper]` tagged with its day. */
  layers: ArrayWithKey<ArrayWithData>[];
  /** Dates in the input map's iteration order; used as the band-scale domain. */
  dates: Date[];
  /** `[min, max]` over every segment boundary, or `[0, 0]` when nothing is visible. */
  valueExtent: [number, number];
  /** Sorted, de-duplicated categories that have at least one non-zero, non-hidden value. */
  visibleCategories: string[];
}

/**
 * Converts per-date category values into stacked chart layers.
 *
 * For each date, hidden categories and zero values are dropped and the rest are
 * ordered by descending magnitude. Positive values are stacked upward from 0 and
 * negative values downward from 0, each with its own running offset.
 *
 * @param data - Values keyed by date, then by category.
 * @param hiddenCategories - Categories to leave out of the stacks entirely.
 * @returns The layers plus the x, y and legend domains.
 */
function buildStackedLayout(
  data: Map<Date, Map<string, number>>,
  hiddenCategories: readonly string[],
): StackedLayout {
  const days: Data[] = [];
  const categories = new Set<string>();

  for (const [date, totals] of data) {
    const visible = [...totals].filter(
      ([category, value]) => value !== 0 && !hiddenCategories.includes(category),
    );
    // Largest magnitude first, so the biggest segment starts at the zero line.
    visible.sort((a, b) => Math.abs(b[1]) - Math.abs(a[1]));

    let positiveTop = 0;
    let negativeBottom = 0;
    const entries = visible.map(([category, value]): [string, number, number, number] => {
      const start = value >= 0 ? positiveTop : negativeBottom;
      const end = start + value;
      if (value >= 0) {
        positiveTop = end;
      } else {
        negativeBottom = end;
      }
      categories.add(category);
      return [category, value, start, end];
    });

    // Dates whose values were all filtered out are kept so the x domain stays stable.
    days.push({ date, entries });
  }

  const visibleCategories = [...categories].sort();

  const layers: ArrayWithKey<ArrayWithData>[] = visibleCategories.map(category => {
    const layer: ArrayWithKey<ArrayWithData> = Object.assign([] as ArrayWithData[], { key: category });
    for (const day of days) {
      const entry = day.entries.find(e => e[0] === category);
      if (entry === undefined) {
        continue;
      }
      const [, , start, end] = entry;
      // Store as [lower, upper] so negative segments still get a positive height.
      layer.push(Object.assign(end < start ? [end, start] : [start, end], { data: day }));
    }
    return layer;
  });

  // d3's extent yields [undefined, undefined] for no values; fall back to a 0 domain
  // instead of handing NaN bounds to the linear scale.
  const [low = 0, high = 0] = extent(layers.flat(2));

  return {
    layers,
    dates: [...data.keys()],
    valueExtent: [low, high],
    visibleCategories,
  };
}

/**
 * Stacked bar chart of per-date category values.
 *
 * Positive values stack upward from 0 and negative values downward. Clicking a
 * legend swatch toggles that category's visibility ("Investments" and "Taxes"
 * start hidden). Hovering a segment dims the other categories. Clicking a segment
 * reports its date, category and value through `stackClicked`.
 */
const StackedBarChart: React.FC<StackedBarChartProps> = ({ data, stackClicked }) => {
  const [layers, setLayers] = useState<ArrayWithKey<ArrayWithData>[]>([]);
  const [filteredCategories, setFilteredCategories] = useState<string[]>(["Investments", "Taxes"]);
  const [selectedCategory, setSelectedCategory] = useState<string>();

  const colorsRef = useRef<ScaleOrdinal<string, string, string>>(scaleOrdinal(COLOR_MAP.keys(), COLOR_MAP.values()).unknown("#fdb462"));
  const svgRef = useRef<SVGSVGElement>(null);

  const xAxisRef = useRef<SVGGElement>(null);
  const yAxisRef = useRef<SVGGElement>(null);

  const xScale = useRef<d3.ScaleBand<d3.NumberValue | Date>>(scaleBand());
  const yScale = useRef<d3.ScaleLinear<number,number,any>>(scaleLinear());

  const xAxis = useRef<d3.Axis<d3.NumberValue | Date>>(axisBottom(xScale.current).tickPadding(6).tickFormat(timeFormat("%Y-%m") as any));
  const yAxis = useRef<d3.Axis<d3.NumberValue | Date>>(axisLeft(yScale.current).tickPadding(6).tickFormat((a: any) => formatter.format(a)));

  const [graphWidth, setGraphWidth] = useState(0);
  const [graphHeight, setGraphHeight] = useState(0);
  const [xDomain, setXDomain] = useState<Date[]>([new Date()]);
  const [yDomain, setYDomain] = useState<[number, number]>([0, 0]);
  const [categoriesDomain, setCategoriesDomain] = useState<string[]>([]);

  // The scales are mutable d3 objects held in refs. Sync them with the current
  // state on every render, because the <rect>s below read them directly.
  xScale.current.range([0, graphWidth]).domain(xDomain);
  yScale.current.range([graphHeight, 0]).domain(yDomain);

  // Draw the axes after commit, when the <g> refs are attached and the scales
  // already hold this render's domain. Redraw only when the scale inputs change,
  // not on every hover re-render.
  useEffect(() => {
    // Label every other month. Read the domain *after* the update above so the
    // ticks match the current data rather than the previous render's.
    const ticks = xScale.current
      .domain()
      .filter((d): d is Date => d instanceof Date && d.getMonth() % 2 === 0);
    xAxis.current.scale(xScale.current).tickSize(-graphHeight).tickValues(ticks);
    yAxis.current.scale(yScale.current).tickSize(-graphWidth);
    select(xAxisRef.current).call(xAxis.current as any).transition().style('color', '#CCC');
    select(yAxisRef.current).call(yAxis.current as any).transition().style('color', '#CCC');
  }, [xDomain, yDomain, graphWidth, graphHeight]);

  /** Recomputes the plot-area size from the SVG's current client size. */
  const resize = useCallback(() => {
    const svg = svgRef.current;
    if (!svg) {
      return;
    }
    // Clamp at 0 so a tiny container never produces negative ranges or rect sizes.
    setGraphWidth(Math.max(0, svg.clientWidth - margins.left - margins.right));
    // NOTE(review): only margins.bottom is reserved vertically, while the plot group
    // is shifted down by margins.top / 2. This preserves the existing layout.
    setGraphHeight(Math.max(0, svg.clientHeight - margins.bottom));
  }, []);

  // Rebuild the stacked layout whenever the data or the hidden-category set changes.
  useEffect(() => {
    const layout = buildStackedLayout(data, filteredCategories);
    setLayers(layout.layers);
    setXDomain(layout.dates);
    setYDomain(layout.valueExtent);
    // Hidden categories stay in the legend so they can be toggled back on.
    setCategoriesDomain(layout.visibleCategories.concat(filteredCategories).sort());
  }, [data, filteredCategories]);

  // Track the SVG's size. The observer is disconnected on unmount so it does not leak.
  useEffect(() => {
    const svg = svgRef.current;
    const observer = new ResizeObserver(resize);
    if (svg) {
      observer.observe(svg);
    }
    resize();
    return () => observer.disconnect();
  }, [resize]);

  return (
    <div
      style={{
        display: 'flex',
        flexDirection: 'column',
        width: '100%',
        height: '100%'
    }}>
      <Swatches
        columns="180px"
        colors={colorsRef.current}
        domain={categoriesDomain}
        filteredCategories={filteredCategories}
        swatchClicked={(value: string) => {
          // Toggle the clicked category in or out of the hidden set.
          setFilteredCategories(current =>
            current.includes(value)
              ? current.filter(category => category !== value)
              : [...current, value]
          );
        }}
        selectedSwatch={selectedCategory}
      />
      <div style={{ flexGrow: 1, width: '100%', height: '100%' }}>
        <svg style={{ width: '100%', height: '100%' }} ref={svgRef}>
          <g
            transform={`translate(${margins.left},${margins.top / 2})`}
          >
            <g
              ref={xAxisRef}
              transform={`translate(0, ${graphHeight})`}
            />
            <g
              ref={yAxisRef}
            />
            <g>
              {layers.map(layer => {
                // Categories missing from COLOR_MAP get the scale's fallback colour plus a thin grey outline.
                const layerInDomain = colorsRef.current.domain().includes(layer.key);
                const stroke = layerInDomain ? '' : 'gray';
                const strokeWidth = layerInDomain ? '' : '0.5px';
                return (
                  <g
                    key={layer.key}
                    fill={colorsRef.current(layer.key)}
                    opacity={(selectedCategory === undefined || selectedCategory === layer.key) ? 1 : 0.25}
                    stroke={stroke}
                    strokeWidth={strokeWidth}
                  >
                    {layer.map(segment => {
                      const { date, entries } = segment.data;
                      const entry = entries.find(e => e[0] === layer.key);
                      const total = entry ? entry[1] : 0;
                      const lower = segment[0];
                      const upper = segment[1];
                      return (
                        <rect
                          key={date.toString()}
                          x={xScale.current(date) || 0}
                          width={xScale.current.bandwidth()}
                          y={yScale.current(upper)}
                          height={(!isNaN(lower) && !isNaN(upper)) ?
                            yScale.current(lower) - yScale.current(upper) : 0}
                          onClick={() => {
                            stackClicked(date, layer.key, total);
                          }}
                          onMouseOver={(event) => {
                            setSelectedCategory(layer.key);
                            select(event.currentTarget).attr('stroke', 'gray').attr('stroke-width', 1.0);
                          }}
                          onMouseOut={(event) => {
                            setSelectedCategory(undefined);
                            select(event.currentTarget).attr('stroke', stroke).attr('stroke-width', strokeWidth);
                          }}
                        >
                          <title>
                            {layer.key} | {date.toISOString().substring(0,7)} | ${total.toFixed(2)}
                          </title>
                        </rect>
                      );
                    })}
                  </g>
                );
              })}
            </g>
          </g>
        </svg>
      </div>
    </div>
  );
}

// The chart component is this module's default export.
export { StackedBarChart as default };