from collections import defaultdict
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from collections import Counter
from typing import DefaultDict
from datetime import datetime, timedelta
from functools import lru_cache
import re

def calculate_date_range_rd(duration: str) -> tuple:
    """
    Calculate the start and end date for a relative duration ending today.

    Duration format: <integer><unit>, where unit can be:
    D - days, W - weeks, M - months (approx. 30 days), Y - years (365 days).

    NOTE: this function was previously wrapped in @lru_cache, which is a bug —
    the result depends on datetime.today(), so a cached entry for e.g. '5D'
    goes stale as soon as the calendar date changes. The cache is removed.

    :param duration: A string like '5D', '2W', '3M', '1Y' (case-insensitive).
    :return: A tuple (start_date, end_date) of strings in '%d-%m-%Y' format.
    :raises ValueError: If the duration format is invalid.
    """
    today = datetime.today()

    # <number><unit>; upper() makes the unit letter case-insensitive.
    match = re.match(r'^(\d+)([DWMY])$', duration.upper())
    if not match:
        raise ValueError("Invalid duration format. Use format like '5D', '2W', '3M', or '1Y'.")

    # Extract the count and the unit letter from the duration string.
    number, unit = int(match.group(1)), match.group(2)

    # One base timedelta per unit; months/years are calendar approximations.
    unit_mapping = {
        'D': timedelta(days=1),
        'W': timedelta(weeks=1),
        'M': timedelta(days=30),  # Approximate for months
        'Y': timedelta(days=365)  # Approximate for years
    }

    # Walk back `number` units from today.
    start_dt = today - (number * unit_mapping[unit])
    return start_dt.strftime('%d-%m-%Y'), today.strftime('%d-%m-%Y')

def calculate_report_values_rd(doc):
    """
    Compute per-day derived metrics (spreads and percentage moves) for one
    OHLC record and return them merged with the raw fields.

    :param doc: dict with 'symbol', 'date', 'prev_close', 'open', 'high',
        'low', 'close' (numeric price fields).
    :return: dict containing the raw fields plus six derived metrics.
    """
    high = doc['high']
    low = doc['low']
    close = doc['close']
    prev = doc['prev_close']

    # Carry the raw fields through unchanged, in the same order as before.
    report = {
        field: doc[field]
        for field in ('symbol', 'date', 'prev_close', 'open', 'high', 'low', 'close')
    }

    # Absolute differences.
    report['close_minus_prev_close'] = close - prev
    report['high_minus_low'] = high - low
    report['high_minus_prev_close'] = high - prev

    # Percentage moves; each is guarded against division by zero.
    report['high_minus_prev_close_perc'] = (high - prev) / prev * 100 if prev != 0 else 0
    report['high_minus_low_perc'] = (high - low) / high * 100 if high != 0 else 0
    report['close_minus_prev_close_perc'] = (close - prev) / prev * 100 if prev != 0 else 0

    return report

def filter_response_rd(data, condition):
    """
    Filter per-duration averaged metric rows against user-supplied conditions.

    :param data: {duration_key: {"date_range": ..., "data": [row, ...]}} where
        each row is a dict of metric-name -> value plus 'symbol'.
    :param condition: sequence positionally aligned with data's durations; each
        element is dict-convertible with a "conditions" list of entries shaped
        like {"filter": code, "condition": "greater"|"less"|"equal",
        "value": number-string, "operator": 1 (AND) | 2 (OR)}.
    :return: same structure as data with each duration's rows reduced to the
        matching ones. When condition is empty, data is returned unchanged.

    BUGFIX: the original only built `filtered_data` inside `if len(condition):`
    but unconditionally returned it, raising UnboundLocalError for an empty
    condition list. An empty condition now means "no filtering".
    """
    if not condition:
        # Nothing to filter against.
        return data

    # Filter codes (as sent by the client) to metric-row keys.
    filter_map = {
        "0": "high_minus_prev_close",
        "1": "high_minus_low",
        "2": "close_minus_prev_close",
        "4": "high_minus_prev_close_perc",
        "3": "high_minus_low_perc",
        "5": "close_minus_prev_close_perc",
    }

    def matches(item, conditions):
        """Evaluate one row against one duration's condition list."""
        results = []
        for filt in conditions:
            key = filter_map.get(filt["filter"])
            value = float(filt["value"])
            operator = filt.get("operator", 1)  # Default to AND
            if not key:
                # Unknown filter code: skip this condition entirely.
                continue

            condition_met = False
            if filt["condition"] == "greater":
                condition_met = item[key] > value
            elif filt["condition"] == "less":
                condition_met = item[key] < value
            elif filt["condition"] == "equal":
                condition_met = item[key] == value

            results.append(condition_met)

            # Short-circuit: AND fails on the first failed check,
            # OR succeeds on the first passed check.
            if operator == 1 and not all(results):
                return False
            elif operator == 2 and any(results):
                return True
        return all(results)

    # Filter the rows duration by duration; conditions pair positionally with
    # data's insertion order.
    filtered_data = {}
    for duration_idx, (duration, content) in enumerate(data.items()):
        filtered_data[duration] = {
            "date_range": content["date_range"],
            "data": []
        }

        # Normalize the condition objects (may be pydantic models) to dicts.
        condition_set = dict(condition[duration_idx])
        conditions_data = [dict(cond) for cond in condition_set["conditions"]]

        for symbol_data in content["data"]:
            if matches(symbol_data, conditions_data):
                filtered_data[duration]["data"].append(symbol_data)

    return filtered_data

def _rd_price_pipeline(match_stage, start_date, end_date):
    """
    Build the aggregation shared by both branches of overall_data_rd:
    match documents, keep only price entries dated within [start_date,
    end_date], explode to one document per date, and coerce OHLC fields
    to doubles. (The two inline copies in the original were identical
    except for the $match stage.)
    """
    return [
        {'$match': match_stage},
        {
            '$project': {
                '_id': 0,
                'symbol': 1,
                'prices': {
                    '$filter': {
                        'input': {'$objectToArray': "$prices"},
                        'as': 'date_item',
                        'cond': {
                            '$and': [
                                {
                                    '$gte': [
                                        {'$dateFromString': {'dateString': "$$date_item.k"}},
                                        {'$dateFromString': {'dateString': start_date}}
                                    ]
                                },
                                {
                                    '$lte': [
                                        {'$dateFromString': {'dateString': "$$date_item.k"}},
                                        {'$dateFromString': {'dateString': end_date}}
                                    ]
                                }
                            ]
                        }
                    }
                }
            }
        },
        {'$unwind': '$prices'},
        {
            '$project': {
                'symbol': 1,
                'date': '$prices.k',
                'prev_close': {'$toDouble': {'$ifNull': ['$prices.v.prev_close', 0]}},
                'open': {'$toDouble': '$prices.v.open'},
                'high': {'$toDouble': '$prices.v.high'},
                'low': {'$toDouble': '$prices.v.low'},
                'close': {'$toDouble': '$prices.v.close'}
            }
        }
    ]

async def overall_data_rd(request, collection):
    """
    Report endpoint helper. Resolves the date window (relative duration or
    explicit custom dates), then either:
      - input_symbol set: returns per-day derived metrics for that symbol, or
      - otherwise: returns per-symbol averages of the derived metrics across
        the window, optionally restricted by request.Instrument.

    :param request: object with duration, start_date, end_date, input_symbol,
        Instrument attributes (dates use '%d-%m-%Y').
    :param collection: async (Motor-style) Mongo collection.
    :return: JSONResponse with the data, or 400/500 error payloads.
    """
    try:
        # Get date range either from duration or custom dates.
        if request.duration and not (request.start_date and request.end_date):
            request.start_date, request.end_date = calculate_date_range_rd(request.duration)
        elif request.start_date and request.end_date:
            try:
                # Validate the custom date format only; values pass through.
                datetime.strptime(request.start_date, '%d-%m-%Y')
                datetime.strptime(request.end_date, '%d-%m-%Y')
            except ValueError:
                return JSONResponse({"detail": "Invalid date format. Use 'dd-mm-yyyy'."}, status_code=400)
        else:
            return JSONResponse({"detail": "Either duration or custom start/end dates must be provided."}, status_code=400)

        if request.input_symbol:
            # Single-symbol branch: exact (upper-cased) symbol match.
            pipeline = _rd_price_pipeline(
                {'symbol': f"{request.input_symbol.upper()}"},
                request.start_date,
                request.end_date,
            )

            cursor = collection.aggregate(pipeline, allowDiskUse=True)
            documents = []
            async for doc in cursor:
                documents.append(calculate_report_values_rd(doc=doc))

            return JSONResponse(content=documents, status_code=200)
        else:
            # All-symbols branch, optionally restricted by instrument type.
            match_stage = (
                {'Instrument': request.Instrument}
                if request.Instrument and request.Instrument != "ALL"
                else {}
            )
            pipeline = _rd_price_pipeline(match_stage, request.start_date, request.end_date)

            cursor = collection.aggregate(pipeline, allowDiskUse=True)
            # BUGFIX: the original instantiated typing.DefaultDict, a
            # deprecated alias; use collections.defaultdict directly.
            grouped_data = defaultdict(lambda: defaultdict(list))

            async for doc in cursor:
                symbol = doc['symbol']
                high = doc['high']
                low = doc['low']
                close = doc['close']
                prev_close = doc['prev_close']

                # Per-day derived metrics (zero-guarded percentages).
                metrics = {
                    'high_minus_prev_close': high - prev_close,
                    'high_minus_low': high - low,
                    'close_minus_prev_close': close - prev_close,
                    'high_minus_prev_close_perc': (high - prev_close) / prev_close * 100 if prev_close != 0 else 0,
                    'high_minus_low_perc': (high - low) / high * 100 if high != 0 else 0,
                    'close_minus_prev_close_perc': (close - prev_close) / prev_close * 100 if prev_close != 0 else 0,
                }

                # Group data by symbol.
                for key, value in metrics.items():
                    grouped_data[symbol][key].append(value)

            # Average each metric per symbol (0 when the list is empty).
            metric_keys = (
                'high_minus_prev_close',
                'high_minus_low',
                'close_minus_prev_close',
                'high_minus_prev_close_perc',
                'high_minus_low_perc',
                'close_minus_prev_close_perc',
            )
            result = []
            for symbol, fields in grouped_data.items():
                avg_data = {'symbol': symbol}
                for key in metric_keys:
                    values = fields[key]
                    avg_data[key] = sum(values) / len(values) if values else 0
                result.append(avg_data)

            return JSONResponse(content=result, status_code=200)

    except Exception as e:
        # Top-level boundary: report any unexpected failure as a 500 payload.
        return JSONResponse({"detail": f"An error occurred: {str(e)}"}, status_code=500)

async def filter_date_range(request, collection):
    """
    Run a single aggregation covering the widest date window implied by
    request.filters (min..max across every filter's timeRange), returning one
    flat document per symbol/date with numeric OHLC fields.

    :param request: object with .filters (each having a .timeRange of
        'yyyy-mm-dd' strings) and .Instrument.
    :param collection: async (Motor-style) Mongo collection.
    :return: list of {'symbol', 'date', 'prev_close', 'open', 'high', 'low',
        'close'} documents.
    :raises HTTPException: 400 on missing filters / bad dates, 500 on DB error.
    """
    rec_filters = request.filters

    # Filters are mandatory for this endpoint.
    if not rec_filters:
        raise HTTPException(status_code=400, detail="No filters provided.")

    # Parse every date mentioned by any filter's timeRange attribute.
    try:
        parsed_dates = [
            datetime.strptime(raw, "%Y-%m-%d")
            for flt in rec_filters
            for raw in flt.timeRange
        ]
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid date format. Use 'yyyy-mm-dd'.")

    # The overall query window spans the earliest to the latest date seen.
    start_date = min(parsed_dates).strftime("%Y-%m-%d")
    end_date = max(parsed_dates).strftime("%Y-%m-%d")

    def date_bound(op, bound):
        # One side of the $filter condition comparing the price-entry key
        # against a window bound, both parsed as dates server-side.
        return {op: [
            {'$dateFromString': {'dateString': "$$date_item.k"}},
            {'$dateFromString': {'dateString': bound}},
        ]}

    # Coerce every OHLC field to a double, defaulting missing values to 0.
    numeric_fields = {
        name: {'$toDouble': {'$ifNull': [f'$prices.v.{name}', 0]}}
        for name in ('prev_close', 'open', 'high', 'low', 'close')
    }

    pipeline = [
        {
            '$match': {
                **({'Instrument': request.Instrument} if request.Instrument != "ALL" else {})
            }
        },
        {
            '$project': {
                '_id': 0,
                'symbol': 1,
                'prices': {
                    '$filter': {
                        'input': {'$objectToArray': "$prices"},
                        'as': 'date_item',
                        'cond': {
                            '$and': [
                                date_bound('$gte', start_date),
                                date_bound('$lte', end_date),
                            ]
                        }
                    }
                }
            }
        },
        {'$unwind': '$prices'},
        {
            '$project': {
                'symbol': 1,
                'date': '$prices.k',
                **numeric_fields,
            }
        }
    ]

    # Execute the pipeline; surface DB failures as a 500.
    try:
        return await collection.aggregate(pipeline, allowDiskUse=True).to_list(length=None)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Database query failed: {str(e)}")

async def get_duration_data(request, collection):
    """
    Split the single wide-window query result into one bucket per filter.

    Fetches all rows covering the overall date span once, then slices them
    into 'duration_1', 'duration_2', ... keyed sub-lists according to each
    filter's own timeRange.

    :param request: object whose .filters each expose a two-element
        .timeRange of 'yyyy-mm-dd' strings.
    :param collection: async Mongo collection, passed through to
        filter_date_range.
    :return: {'duration_N': [row, ...], ...} in filter order.
    :raises HTTPException: 400 when a timeRange date fails to parse.
    """
    # One DB round-trip for the whole span; slicing happens in Python.
    duration_data = await filter_date_range(request, collection)

    seperated_data = {}
    for index, rfilter in enumerate(request.filters, start=1):
        sdt, edt = rfilter.timeRange

        try:
            start_date = datetime.strptime(sdt, "%Y-%m-%d")
            end_date = datetime.strptime(edt, "%Y-%m-%d")
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid date format. Use 'yyyy-mm-dd'.")

        # Keep only the rows whose date falls inside this filter's window.
        seperated_data[f"duration_{index}"] = [
            entry for entry in duration_data
            if start_date <= datetime.strptime(entry["date"], "%Y-%m-%d") <= end_date
        ]

    return seperated_data

# Metric keys in output order, shared by grouping and averaging below.
_RD_METRIC_KEYS = (
    'high_minus_prev_close',
    'high_minus_low',
    'close_minus_prev_close',
    'high_minus_prev_close_perc',
    'high_minus_low_perc',
    'close_minus_prev_close_perc',
)

def calculate_metrics(data):
    """
    For each duration bucket, compute per-symbol averages of the six derived
    OHLC metrics, restricted to symbols present in EVERY duration.

    :param data: {'duration_N': [record, ...]} where each record has
        'symbol', 'date', 'high', 'low', 'close', 'prev_close'.
    :return: {'duration_N': {"date_range": {"start_date", "end_date"},
        "data": [avg-row, ...]}}. date_range values are None for an empty
        duration (the original raised IndexError on records[0]).
    """
    results = {}
    common_symbols = None

    for duration, records in data.items():
        grouped_data = defaultdict(lambda: defaultdict(list))

        for record in records:
            symbol = record['symbol']
            high = record['high']
            low = record['low']
            close = record['close']
            prev_close = record['prev_close']

            # Per-day derived metrics (percentages guarded against /0).
            metrics = {
                'high_minus_prev_close': high - prev_close,
                'high_minus_low': high - low,
                'close_minus_prev_close': close - prev_close,
                'high_minus_prev_close_perc': (high - prev_close) / prev_close * 100 if prev_close != 0 else 0,
                'high_minus_low_perc': (high - low) / high * 100 if high != 0 else 0,
                'close_minus_prev_close_perc': (close - prev_close) / prev_close * 100 if prev_close != 0 else 0,
            }

            # Group data by symbol.
            for key in _RD_METRIC_KEYS:
                grouped_data[symbol][key].append(metrics[key])

        # Track the symbols common to every duration seen so far.
        if common_symbols is None:
            common_symbols = set(grouped_data.keys())
        else:
            common_symbols.intersection_update(grouped_data.keys())

        # BUGFIX: records may be empty; the original indexed records[0]/[-1]
        # unconditionally and raised IndexError.
        results[duration] = {
            "date_range": {
                "start_date": records[0]['date'] if records else None,
                "end_date": records[-1]['date'] if records else None,
            },
            "data": grouped_data
        }

    # Second pass: average each metric for the common symbols only.
    filtered_results = {}
    for duration, content in results.items():
        grouped_data = content["data"]
        filtered_results[duration] = {
            "date_range": content["date_range"],
            "data": []
        }
        for symbol in common_symbols:
            fields = grouped_data[symbol]
            avg_data = {'symbol': symbol}
            for key in _RD_METRIC_KEYS:
                values = fields[key]
                avg_data[key] = sum(values) / len(values) if values else 0
            filtered_results[duration]["data"].append(avg_data)

    return filtered_results

def remove_duplicates(data_list):
    """
    Return data_list with exact duplicate dicts removed, preserving the
    first occurrence and the original order.

    :param data_list: list of dicts with hashable values.
    :return: new list containing only the first copy of each distinct dict.
    """
    seen_fingerprints = set()
    unique_data = []
    for record in data_list:
        # A dict isn't hashable, so hash a canonical (sorted) item tuple.
        fingerprint = tuple(sorted(record.items()))
        if fingerprint in seen_fingerprints:
            continue
        seen_fingerprints.add(fingerprint)
        unique_data.append(record)
    return unique_data

async def calculate_multidura(request, collection):
    """
    Orchestrate the multi-duration report: split data into per-filter
    duration buckets, average the metrics, apply the per-duration conditions,
    then return one row per symbol common to ALL durations, carrying the
    duration-suffixed values of the filtered metrics.

    :param request: object with .filters, each exposing .timeRange and a
        .conditions list whose items expose .filter.
    :param collection: async Mongo collection, passed down the call chain.
    :return: {"common_data": [{"symbol": ..., "<metric>_duration<i>": value,
        ...}, ...]}.
    """
    rec_filters = request.filters

    # Collect the raw filter codes ("0".."5") from every condition object.
    filters = []
    for filter_item in rec_filters:
        if hasattr(filter_item, 'conditions'):
            for condition in filter_item.conditions:
                if hasattr(condition, 'filter'):
                    filters.append(condition.filter)

    seperated_data = await get_duration_data(request, collection)
    calculated_data = calculate_metrics(seperated_data)  # per-duration averages

    # Apply the per-duration conditions, then drop exact duplicate rows.
    filtered_data = filter_response_rd(data=calculated_data, condition=rec_filters)
    for content in filtered_data.values():
        content["data"] = remove_duplicates(content["data"])

    # Filter codes to metric-row keys (same mapping as filter_response_rd).
    filter_map = {
        "0": "high_minus_prev_close",
        "1": "high_minus_low",
        "2": "close_minus_prev_close",
        "4": "high_minus_prev_close_perc",
        "3": "high_minus_low_perc",
        "5": "close_minus_prev_close_perc",
    }

    # Translate codes to metric names; unknown or unhashable codes are
    # skipped (the original wrapped dict.get in a dead try/except).
    filter_names = []
    for code in filters:
        try:
            name = filter_map.get(code)
        except TypeError:
            continue  # unhashable code — ignore it
        if name is not None:
            filter_names.append(name)

    # Group duration-specific values per symbol under suffixed keys,
    # e.g. "high_minus_low_duration0", and record each duration's symbols.
    symbol_data = defaultdict(lambda: defaultdict(list))
    duration_symbols = []
    for duration_index, duration in enumerate(filtered_data.values()):
        current_symbols = set()
        for entry in duration["data"]:
            symbol = entry["symbol"]
            current_symbols.add(symbol)
            for key, value in entry.items():
                # Keep only the metrics the user actually filtered on.
                if key != "symbol" and key in filter_names:
                    symbol_data[symbol][f"{key}_duration{duration_index}"].append(value)
        duration_symbols.append(current_symbols)

    # Only symbols present in every duration survive.
    # BUGFIX: set.intersection(*[]) raises TypeError when no durations
    # produced data; treat that case as "no common symbols".
    common_symbols = set.intersection(*duration_symbols) if duration_symbols else set()

    # Format the result: one flat row per common symbol, keeping only the
    # attributes that occurred exactly once for that symbol.
    return {
        "common_data": [
            {
                "symbol": symbol,
                **{key: values[0] for key, values in attributes.items() if len(values) == 1},
            }
            for symbol, attributes in symbol_data.items() if symbol in common_symbols
        ]
    }

