import React, {useState, useRef} from "react";
import Tooltip from "../Tooltip";
import CrossHair from "../CrossHair";
import useD3 from "../../hooks/useD3";
import useImmediate from "../../hooks/useImmediate";
import * as d3 from "d3";
import styles from "./StackedBarChart.module.css";

// Bar width (px) the band scale's inner padding is derived from.
const BAR_WIDTH = 12;

// When the band step drops below this many pixels the x-axis labels switch
// to the smaller font and are rotated.
const MIN_LABEL_WIDTH = 35;

// Output range of the ordinal colour scale used to fill the stacked layers.
const PALETTE = [
    "#DBFF00",
    "#E07E46",
    "#FF2727",
    "#39FF4D",
    "#55C4FF",
    "#3B63F0",
    "#FB9FEC",
];

// Space (px) kept free around the plot; used for the scale ranges and
// handed to the CrossHair overlay.
const MARGIN = {
    top: 20,
    right: 72,
    bottom: 59,
    left: 77,
};

export default function StackedBarChart({data, width, height, detail, yAxisFormat, yDomainMin, normalize}) {
    const [tooltip, setTooltip] = useState(null);
    const tooltipRef = useRef();

    function handleMouseMove(e) {
        const t = tooltipRef.current;
        t.style.left = `${e.clientX}px`;
        t.style.bottom = `${window.innerHeight - e.clientY + 30}px`;
    }
    function handleMouseLeave(e) {
        setTooltip(null);
    }
    
    const animate = useImmediate(() => true, [data]);

    const ref = useD3((g) => {
        const {categories, series} = data;
        if (!series)
            return;
        
        function handleMouseEnter(e) {
            if (detail) {
                const d = d3.select(this).datum();
                const d1 = d3.select(this.parentNode).datum();
                setTooltip(detail(d, d1, categories));
                handleMouseMove(e);
            }
        }
        
        const xScale = d3.scaleBand()
            .domain(series.map(i => i.x))
            .range([MARGIN.left, width - MARGIN.right]);
        xScale.paddingInner(Math.max(0.1, (xScale.bandwidth() - BAR_WIDTH) / xScale.bandwidth()));
        const yScale = d3.scaleLinear()
            .domain(normalize ? [0, 1] : [0, Math.max(yDomainMin ?? 10, d3.max(series, d => d3.sum(d.y)))])
            .range([height - MARGIN.bottom, MARGIN.top]).nice();
        const cScale = d3.scaleOrdinal()
            .domain(categories)
            .range(PALETTE);
        
        const stacks = d3.stack()
            .keys(Array(categories.length).fill(1).map((e, i) => i))
            .value((d, key) => d.y[key])
            .order(d3.stackOrderDescending)
            .offset(normalize ? d3.stackOffsetExpand : d3.stackOffsetNone)
            (series);
        
        let sel = g.append("g").attr("class", "plot")
            .selectAll("g")
            .data(stacks)
            .enter()
            .append("g")
            .attr("fill", d => cScale(d.key))
            .selectAll("rect")
            .data(d => d)
            .enter()
            .append("rect")
            .attr("x", d => xScale(d.data.x))
            .attr("y", d => yScale(d[0]))
            .attr("width", xScale.bandwidth())
            .attr("height", 0)
            .on("mouseenter", handleMouseEnter)
            .on("mousemove", handleMouseMove)
            .on("mouseleave", handleMouseLeave);
        if (animate) sel = sel.transition().delay((d, i) => i + 10).duration(500);
        sel.attr("y", d => yScale(d[1]))
            .attr("height", d => yScale(d[0]) - yScale(d[1]));

        const text = g.append("g").attr("class", "xAxis")
            .attr("transform", `translate(0, ${height - MARGIN.bottom + 16})`)
            .style("font-size", xScale.step() < MIN_LABEL_WIDTH ? "10px" : "12px")
            .call(d3.axisBottom(xScale)
                .tickSize(0)
                .tickPadding(0)
            )
            .selectAll("text");
        if (xScale.step() < MIN_LABEL_WIDTH)
            text.style("text-anchor", "end").attr("transform", "rotate(-65)");
        g.select(".xAxis").select(".domain").remove();

        g.append("g").attr("class", "yAxis")
            .attr("transform", `translate(${MARGIN.left - 30}, 0)`)
            .call(d3.axisLeft(yScale)
                .ticks(Math.floor(height / 45), yAxisFormat)
                .tickSize(0)
                .tickPadding(0)
            );
        g.select(".yAxis").select(".domain").remove();
    }, [data, width, height]);

    return (
        <>
            <svg width={width} height={height} viewBox={`0 0 ${width} ${height}`}>
                <g ref={ref} className={styles.chart}>
                </g>
                <CrossHair width={width} height={height} margin={MARGIN}/>
            </svg>
            <Tooltip content={tooltip} ref={tooltipRef}/>
        </>
    );
}
