from datetime import datetime, timedelta
from typing import Optional
from bson import ObjectId
import pandas as pd
import numpy as np

import warnings
warnings.filterwarnings("ignore")

"""Helper Functions for Market Analysis API"""
def get_duration_range_ma(duration, start_date=None, end_date=None):
    """
    Calculate the (start, end) date range for a given duration.

    Args:
    duration (str): Either a predefined duration ('1W', '1M', '3M', '6M',
        '1Y', '3Y', '5Y') or a custom range 'YYYY-MM-DD:YYYY-MM-DD'.
    start_date: Unused; kept only for backward compatibility. The start date
        is always derived from ``duration``. Note that it is the SECOND
        positional parameter, so pass ``end_date`` by keyword.
    end_date (datetime | str, optional): End date for predefined durations.
        Defaults to the current local time. Ignored for custom ranges.

    Returns:
    tuple: Start date and end date. Predefined durations use calendar
    approximations (1M = 30 days, 1Y = 365 days, ...).

    Raises:
    ValueError: If ``duration`` is neither a predefined duration nor a
        parseable custom range.
    """
    if end_date is None:
        end_date = datetime.now()

    # Ensure end_date is a datetime-like object
    if isinstance(end_date, str):
        end_date = pd.to_datetime(end_date)

    duration_map = {
        '1W': timedelta(weeks=1),
        '1M': timedelta(days=30),
        '3M': timedelta(days=90),
        '6M': timedelta(days=180),
        '1Y': timedelta(days=365),
        '3Y': timedelta(days=365 * 3),
        '5Y': timedelta(days=365 * 5)
    }

    # Predefined duration: count back from end_date.
    if duration in duration_map:
        return end_date - duration_map[duration], end_date

    # Otherwise expect a custom range like "2022-01-01:2023-01-01".
    try:
        start_str, end_str = duration.split(':')
        return pd.to_datetime(start_str), pd.to_datetime(end_str)
    except (AttributeError, TypeError, ValueError, OverflowError) as err:
        # AttributeError: non-string duration (e.g. None); ValueError: wrong
        # number of ':'-separated parts or unparseable dates.
        raise ValueError(f"Invalid duration format: {duration}") from err

def determine_current_trend_ma(df, trend_threshold_percent=5.0):
    """
    Determine the current trend of the price data.

    The trend is derived from the percentage change between the first and the
    last ``close`` value of ``df``.

    Args:
    df (pandas.DataFrame): DataFrame with price data; needs a ``close`` column
        and at least one row.
    trend_threshold_percent (float): Minimum absolute change, expressed in
        percent (5.0 == 5%), required to report a trend. Defaults to 5.0.

    Returns:
    str: Current trend (Uptrend, Downtrend, or Neutral)
    """
    first_close = df.iloc[0]['close']
    last_close = df.iloc[-1]['close']

    # percentage_change is expressed in percent (x100), so the threshold it is
    # compared against must be in percent as well (5.0, not 0.05).
    percentage_change = (last_close - first_close) / first_close * 100

    if percentage_change > trend_threshold_percent:
        return "Uptrend"
    elif percentage_change < -trend_threshold_percent:
        return "Downtrend"
    else:
        return "Neutral"

def find_price_occurrence_ma(df, price, is_top=True, tolerance_percent=0.5):
    """
    Locate the candles whose extreme touched ``price`` within a tolerance band.

    Args:
    df (pandas.DataFrame): Price data indexed by date; uses the ``high``
        column when ``is_top`` is true, otherwise the ``low`` column.
    price (float): Reference price to look for.
    is_top (bool): Search highs (True) or lows (False).
    tolerance_percent (float): Half-width of the inclusive match band, in
        percent of ``price``.

    Returns:
    dict: ``price_searched``, ``total_occurrences``, and the matching candles
    as ``occurrences``, split into ``exact_matches`` (within 0.1% relative of
    ``price`` per ``np.isclose``) and ``near_matches`` (everything else).
    """
    column = 'high' if is_top else 'low'
    fraction = tolerance_percent / 100
    lower = price * (1 - fraction)
    upper = price * (1 + fraction)

    extremes = df[column]
    matching = df[(extremes >= lower) & (extremes <= upper)]

    occurrences = [
        {
            "date": stamp.strftime("%Y-%m-%d"),
            "price": row[column],
            "exact_match": bool(np.isclose(row[column], price, rtol=0.001)),
        }
        for stamp, row in matching.iterrows()
    ]

    exact = [entry for entry in occurrences if entry["exact_match"]]
    near = [entry for entry in occurrences if not entry["exact_match"]]

    return {
        "price_searched": price,
        "total_occurrences": len(occurrences),
        "exact_matches": exact,
        "near_matches": near,
        "occurrences": occurrences
    }

def process_price_data_ma(data):
    """
    Run the graph analysis for one symbol and return a serializable result.

    Args:
    data (dict): Dictionary containing price and instrument data.

    Returns:
    dict | None: Output of analyze_price_data_graph with ObjectId values
    converted to strings, or None if any step raised (the error is printed).
    """
    try:
        analysis = analyze_price_data_graph(data)
        return convert_objectid_ma(analysis)
    except Exception as exc:
        symbol = data.get('symbol', 'Unknown')
        print(f"Error processing data for {symbol}: {str(exc)}")
        return None

async def process_single_symbol_ma(data, duration):
    """
    Analyze one symbol's price data and return a JSON-safe result.

    Although declared ``async``, the body awaits nothing: the analysis runs
    synchronously on the calling event loop.

    Args:
    data (dict): Dictionary containing price and instrument data.
    duration (str): Duration to analyze (forwarded to analyze_price_data_ma).

    Returns:
    dict: The JSON-safe analysis, or ``{"error": <message>}`` if any step raised.
    """
    try:
        return make_json_safe_ma(analyze_price_data_ma(data, duration))
    except Exception as exc:
        return {"error": str(exc)}

def convert_objectid_ma(data):
    """
    Recursively replace bson ``ObjectId`` values with their string form.

    Lists and dicts are rebuilt as new containers (dict keys are left
    untouched); any other value is returned unchanged.
    """
    if isinstance(data, list):
        return list(map(convert_objectid_ma, data))
    if isinstance(data, dict):
        return {k: convert_objectid_ma(v) for k, v in data.items()}
    return str(data) if isinstance(data, ObjectId) else data

def make_json_safe_ma(data):
    """
    Recursively convert ``data`` into values that serialize to valid JSON.

    - dicts, lists and tuples are rebuilt recursively (tuples become lists;
      dict keys are kept as-is)
    - None, bool, int and str pass through unchanged
    - numpy bools become ``bool``; numpy integers become ``int``
    - Python/numpy floats become ``float``; NaN and +/-inf become None (null)
    - anything else (e.g. datetimes, Timestamps) is converted with ``str()``
    """
    if isinstance(data, dict):
        return {key: make_json_safe_ma(value) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        return [make_json_safe_ma(item) for item in data]
    # None must stay None (JSON null), not become the string "None".
    if data is None or isinstance(data, (bool, int, str)):
        return data
    # numpy scalars are not subclasses of bool/int, so convert explicitly
    # instead of letting them fall through to str().
    if isinstance(data, np.bool_):
        return bool(data)
    if isinstance(data, np.integer):
        return int(data)
    if isinstance(data, (float, np.floating)):
        # NaN/inf are not valid JSON numbers.
        return float(data) if np.isfinite(data) else None
    return str(data)  # Convert unexpected types to strings


def parse_time_range_ma(time_range: Optional[str] = None,
                     start_date: Optional[str] = None,
                     end_date: Optional[str] = None) -> tuple:
    """
    Parse a relative time range or a custom date range.

    Relative ranges have the form ``<positive integer><unit>`` (case-insensitive,
    surrounding whitespace ignored), where unit is one of:
    - D: days
    - W: weeks (7 days)
    - M: months (approximated as 30 days)
    - Y: years (approximated as 365 days)
    e.g. 1W, 1M, 3M, 6M, 1Y, 2Y, 3Y, 5Y.

    Custom dates ('YYYY-MM-DD') take precedence, but only when BOTH start_date
    and end_date are given; a lone start or end date is ignored.

    Returns:
    tuple: (start datetime, end datetime). Relative ranges end at the current
    local time; with no arguments the range is the last 365 days.

    Raises:
    ValueError: If a custom date does not match 'YYYY-MM-DD', or the time
        range is malformed, has an unknown unit, or a non-positive value.
    """
    now = datetime.now()

    # If custom dates are provided
    if start_date and end_date:
        return (datetime.strptime(start_date, "%Y-%m-%d"),
                datetime.strptime(end_date, "%Y-%m-%d"))

    # Parse relative time ranges
    if time_range:
        time_range = time_range.strip().upper()
        multipliers = {
            'D': 1,
            'W': 7,
            'M': 30,
            'Y': 365
        }

        if len(time_range) < 2:
            raise ValueError("Invalid time range format")

        unit = time_range[-1]
        try:
            value = int(time_range[:-1])
        except ValueError as err:
            raise ValueError("Invalid time range value") from err

        if unit not in multipliers:
            raise ValueError("Invalid time range unit")

        # Zero or negative values would yield an empty or future-dated range.
        if value <= 0:
            raise ValueError("Invalid time range value")

        days = value * multipliers[unit]
        return now - timedelta(days=days), now

    # Default to 1 year if no range specified
    return now - timedelta(days=365), now

def filter_prices_by_date_range_ma(prices, start_date, end_date):
    """
    Return the entries of ``prices`` whose 'YYYY-MM-DD' key lies within
    [start_date, end_date], inclusive on both ends.

    Insertion order of the kept entries is preserved. A key that is not in
    'YYYY-MM-DD' format raises ValueError (from ``datetime.strptime``).
    """
    return {
        day: payload
        for day, payload in prices.items()
        if start_date <= datetime.strptime(day, "%Y-%m-%d") <= end_date
    }

def mark_special_levels_ma(df):
    """
    Find swing lows and swing highs using two neighbouring candles per side.

    A candle is a bottom level when its ``low`` is strictly below the lows of
    the two candles before it and the two candles after it; it is a top level
    when its ``high`` is strictly above the corresponding highs. The first and
    last two candles are never evaluated.

    Returns:
    dict: ``{"bottom_levels": [...], "top_levels": [...]}``, each entry holding
    ``date`` ('YYYY-MM-DD'), ``price`` and ``candle_data`` (open/high/low/close).
    """
    def _snapshot(row, label, level):
        # Serializable description of one candle that formed a level.
        return {
            "date": label.strftime("%Y-%m-%d"),
            "price": level,
            "candle_data": {key: row[key] for key in ("open", "high", "low", "close")},
        }

    lows = df["low"]
    highs = df["high"]
    bottoms = []
    tops = []

    for pos in range(2, len(df) - 2):
        row = df.iloc[pos]
        label = df.index[pos]

        low_here = row["low"]
        if low_here < lows.iloc[pos - 2:pos].min() and low_here < lows.iloc[pos + 1:pos + 3].min():
            bottoms.append(_snapshot(row, label, low_here))

        high_here = row["high"]
        if high_here > highs.iloc[pos - 2:pos].max() and high_here > highs.iloc[pos + 1:pos + 3].max():
            tops.append(_snapshot(row, label, high_here))

    return {"bottom_levels": bottoms, "top_levels": tops}

def find_support_resistance_ma(prices, window=3):
    """
    Find local support and resistance candles in the price data.

    The candle at row position ``i`` is a support when its ``low`` equals the
    minimum ``low`` of the ``2 * window + 1`` candles centred on it, and a
    resistance when its ``high`` equals the corresponding maximum ``high``.
    The first and last ``window`` candles are never classified.

    Args:
    prices (pandas.DataFrame): Price data with ``low`` and ``high`` columns;
        any index type is accepted.
    window (int): Number of candles on each side to compare against.

    Returns:
    tuple: (supports, resistances), each a list of integer row positions
    (not index labels).
    """
    # Index positionally via plain arrays. ``Series[i]`` is a label lookup on
    # integer indexes (wrong row or KeyError when the index is not 0..n-1) and
    # a deprecated positional fallback on a DatetimeIndex.
    lows = np.asarray(prices["low"])
    highs = np.asarray(prices["high"])
    supports = []
    resistances = []
    for i in range(window, len(lows) - window):
        span = slice(i - window, i + window + 1)
        if lows[i] == min(lows[span]):
            supports.append(i)
        if highs[i] == max(highs[span]):
            resistances.append(i)
    return supports, resistances

def analyze_price_data_ma(data, duration='1Y'):
    """
    Analyze price data for the specified duration.

    Predefined durations are anchored to the most recent candle in
    ``data["prices"]``; custom 'YYYY-MM-DD:YYYY-MM-DD' ranges are used as given.

    Args:
    data (dict): Reads ``prices``, ``symbol``, ``Instrument``,
        ``instrument_token`` and ``exchange_token``. ``prices`` maps date
        strings to per-day dicts with at least ``high``/``low``/``close``;
        ``prev_close`` is optional.
    duration (str): Duration to analyze ('1W', '1M', '3M', '6M', '1Y', '3Y',
        '5Y' or a custom range).

    Returns:
    dict: Comprehensive price analysis, or ``{"error": <message>}`` when there
    is no price data, the duration is invalid, or no candles fall inside the
    analysis window.
    """
    # Extracting price data
    prices = data["prices"]
    df = pd.DataFrame.from_dict(prices, orient="index")
    if df.empty:
        return {"error": "No price data available"}
    df.index = pd.to_datetime(df.index)
    df.sort_index(inplace=True)

    # Calculate duration range. end_date must be passed by keyword: the second
    # positional parameter of get_duration_range_ma is start_date, so passing
    # the last candle positionally would anchor the window to "now" instead.
    try:
        start_date, end_date = get_duration_range_ma(duration, end_date=df.index[-1])
    except ValueError as e:
        return {"error": str(e)}

    # Filter DataFrame for the specified duration
    df_duration = df.loc[start_date:end_date]
    if df_duration.empty:
        # Every computation below indexes the first/last row.
        return {"error": f"No price data available for duration: {duration}"}

    # Prepare output dictionary
    output = {
        "symbol": data["symbol"],
        "Instrument": data["Instrument"],
        "instrument_token": data["instrument_token"],
        "exchange_token": data["exchange_token"],
        "duration": duration,
        "analysis_period": {
            "start_date": start_date.strftime("%Y-%m-%d"),
            "end_date": end_date.strftime("%Y-%m-%d")
        }
    }

    # Determine current trend
    output["current_trend"] = determine_current_trend_ma(df_duration)

    # Find top and bottom prices with their occurrences
    top_price = df_duration['high'].max()
    bottom_price = df_duration['low'].min()

    # Find price occurrences with different tolerance levels
    output["price_details"] = {
        "top_price": {
            "price": top_price,
            "exact_matches": find_price_occurrence_ma(df_duration, top_price, is_top=True, tolerance_percent=0.1),
            "near_matches_0.5%": find_price_occurrence_ma(df_duration, top_price, is_top=True, tolerance_percent=0.5),
            "near_matches_1%": find_price_occurrence_ma(df_duration, top_price, is_top=True, tolerance_percent=1),
        },
        "bottom_price": {
            "price": bottom_price,
            "exact_matches": find_price_occurrence_ma(df_duration, bottom_price, is_top=False, tolerance_percent=0.1),
            "near_matches_0.5%": find_price_occurrence_ma(df_duration, bottom_price, is_top=False, tolerance_percent=0.5),
            "near_matches_1%": find_price_occurrence_ma(df_duration, bottom_price, is_top=False, tolerance_percent=1),
        }
    }

    # Calculate price change and percentage change. Prefer the first candle's
    # prev_close; fall back to its close when prev_close is absent (missing
    # column) or null. pandas stores nulls in numeric columns as NaN rather
    # than None, hence pd.isna.
    first_row = df_duration.iloc[0]
    first_close = first_row.get('prev_close')
    if first_close is None or pd.isna(first_close):
        first_close = first_row['close']
    last_close = df_duration.iloc[-1]['close']
    price_change = last_close - first_close
    percentage_change = (price_change / first_close) * 100

    output["price_performance"] = {
        "first_close": first_close,
        "last_close": last_close,
        "absolute_change": price_change,
        "percentage_change": round(percentage_change, 2)
    }

    return output


def analyze_price_data_graph(data):
    """
    Build a chart-oriented analysis of one instrument's candles.

    Steps:
    1. Load ``data["prices"]`` (date string -> candle dict) into a DataFrame
       sorted by date.
    2. Collect swing levels via mark_special_levels_ma and local supports /
       resistances via find_support_resistance_ma.
    3. Walk supports and resistances in chronological order. Each
       support -> resistance hop is recorded as an "uptrend" and each
       resistance -> support hop as a "downtrend". Consecutive points of the
       same type produce no trend.

    NOTE(review): a candle that is both a support and a resistance appears
    twice in the walk and is labelled "support" both times, so its resistance
    side never contributes to a trend. Confirm whether that is intended.

    Args:
    data (dict): Reads ``prices``, ``symbol`` and ``Instrument``.

    Returns:
    dict: ``bottom`` / ``top`` (one ``{date: OHLC}`` entry per support /
    resistance candle), ``trends``, ``special_bottom_levels``,
    ``special_top_levels``, ``symbol`` and ``Instrument``.
    """
    frame = pd.DataFrame.from_dict(data["prices"], orient="index")
    frame.index = pd.to_datetime(frame.index)
    frame.sort_index(inplace=True)

    special = mark_special_levels_ma(frame)
    support_positions, resistance_positions = find_support_resistance_ma(frame)

    def _ohlc_entry(pos):
        # {'YYYY-MM-DD': {open, high, low, close}} for the candle at row ``pos``.
        candle = frame.iloc[pos]
        return {
            frame.index[pos].strftime("%Y-%m-%d"): {
                field: candle[field] for field in ("open", "high", "low", "close")
            }
        }

    result = {
        "bottom": [_ohlc_entry(pos) for pos in support_positions],
        "top": [_ohlc_entry(pos) for pos in resistance_positions],
        "trends": [],
        "special_bottom_levels": special["bottom_levels"],
        "special_top_levels": special["top_levels"],
    }

    # Chronological pivot points; supports are priced at the low, resistances
    # at the high.
    support_lookup = set(support_positions)
    pivots = []
    for pos in sorted(support_positions + resistance_positions):
        candle = frame.iloc[pos]
        is_support = pos in support_lookup
        pivots.append({
            "type": "support" if is_support else "resistance",
            "date": frame.index[pos],
            "price": candle["low"] if is_support else candle["high"],
        })

    hop_direction = {
        ("support", "resistance"): "uptrend",
        ("resistance", "support"): "downtrend",
    }
    for begin, finish in zip(pivots, pivots[1:]):
        direction = hop_direction.get((begin["type"], finish["type"]))
        if not direction:
            continue
        start_price = begin["price"]
        end_price = finish["price"]
        change = end_price - start_price
        result["trends"].append({
            "start_date": begin["date"].strftime("%Y-%m-%d"),
            "end_date": finish["date"].strftime("%Y-%m-%d"),
            "trend": direction,
            "days": (finish["date"] - begin["date"]).days,
            "start_price": start_price,
            "end_price": end_price,
            "price_change": change,
            "percentage_change": round((change / start_price) * 100, 2),
        })

    result["symbol"] = data["symbol"]
    result["Instrument"] = data["Instrument"]
    return result







