__call__ function within models/_components/graphy.py passing parent function kwargs unintentionally · Issue #445 · mckinsey/vizro · GitHub

__call__ function within models/_components/graphy.py passing parent function kwargs unintentionally #445

Closed
1 task done
pablo-fence opened this issue Apr 30, 2024 · 11 comments
Labels
General Question ❓ Issue contains a general question

Comments

@pablo-fence
pablo-fence commented Apr 30, 2024

Description

The function call

fig = self.figure(**kwargs)

passes down kwargs from the parent function origination_total.

@capture('graph')
def origination_total(
    data_frame: pd.DataFrame,
    id_column: Optional[str] = "client_id",
    date_column: Optional[str] = "date",
    amount_column: Optional[Union[float, int]] = "outstanding_balance",
    sum_count: Optional[str] = "Count",
    cumulative: Optional[str] = "True",
    currency: Optional[str] = "EUR"
):

    if cumulative == "False":
        cumulative = False
    else:
        cumulative = True
    
    if sum_count == "Sum":
        var_to_plot = amount_column
        grouped = data_frame.groupby([date_column])[var_to_plot].sum().reset_index()
    elif sum_count == "Count":
        var_to_plot = id_column
        grouped = data_frame.groupby([date_column])[var_to_plot].nunique().reset_index()
    
    if cumulative:
        grouped = grouped.sort_values([date_column])
        grouped["cumulative"] = grouped[var_to_plot].cumsum()
        var_to_plot = 'cumulative'

    grouped["percentage"] = grouped.groupby(date_column)[var_to_plot].transform(lambda x:x.sum())
    grouped["percentage"] = 10000 * grouped[var_to_plot] / grouped["percentage"] // 1 / 100
    
    # Initialize figure
    fig = px.bar(
        grouped, 
        x=date_column,
        y=var_to_plot,
        hover_data=[var_to_plot, 'percentage']
        )

    annotations = []
    for date, d_group in grouped.groupby(date_column):
        total_volume = d_group[var_to_plot].sum()
        formatted_volume = format_volume(total_volume)
        annotations.append(
            go.layout.Annotation(
                x=date,
                y=total_volume,
                text=formatted_volume,
                showarrow=False,
                yshift=20
            )
        )

    y1_max = grouped[var_to_plot].max()*1.1

    # Update layout with dynamic max range for y-axis based on the max cumulative sum
    fig.update_layout(
        annotations=annotations,
        xaxis=dict(title=f'Time'),
        yaxis=dict(title=f'Total Origination ({currency})', side='left', nticks=5, range=[0, y1_max])
    )

    return fig

I debugged this by adding a print of kwargs at line 54 and got:

{'data_frame':    year  client_id  percentage
0  2023        801       100.0
1  2024         98       100.0, 'x': 'year', 'y': 'client_id', 'hover_data': ['client_id', 'percentage'], 'sum_count': 'Count', 'id_column': 'client_id', 'amount_column': 'outstanding_balance'}

where 'sum_count': 'Count', 'id_column': 'client_id' and 'amount_column': 'outstanding_balance' are not expected in the function call.

Expected behavior

No response

Which package?

vizro

Package version

0.1.15

Python version

3.12.1

OS

MacOS Sonoma 14.4

How to Reproduce

I can't share the full repository due to privacy issues, but here is sample data that can be used to reproduce the issue:

    data = {
        'client_id': [1, 2, 1, 2, 3, 3],
        'date': ['2021-01', '2021-01', '2021-02', '2021-02', '2021-03', '2021-03'],
        'outstanding_balance': [100, 150, 200, 250, 300, 350]
    }

and the function that generates the dashboard page:

    def summary_origination_page(
    data_frame: pd.DataFrame,
    id_columns: List[str],
    date_columns: List[str],
    amount_columns: List[str]
    ) -> vm.Page:
    
    layout = vm.Layout(
        grid = [
        [0] * 5 + [2] * 1,
        [1] * 5 + [3] * 1,
        [1] * 5 + [3] * 1,
        [1] * 5 + [3] * 1
        ],
    row_min_height="100px",
    row_gap="24px"
    )

    components = [
            vm.Card(
                text="""
                #### __How to Read this Chart:__ (in `pink`, `adjustable parameters`)
                This bar chart shows the evolution of the origination volume over the selected `time axis`. If `type of aggregation` is set to `Sum`, the chart will show each period's 
                sum of `column to aggregate`, else if `type of aggregation` is set to `Count`, the unique count of `id column` will be displayed. Additionally, volumes can be `accumulated` or not.
                """
            ),
            vm.Graph(id='origination_total', figure=origination_total(data_frame=data_frame)
            ),
            vm.Card(
                text="""
                #### __How to Read this Chart:__
                This bar chart shows the comparison of the current year vs. the previous year. This chart is affected by the same parameters as the chart on the left.
                """
            ),
            vm.Graph(id='origination_total_yoy_comparison', figure=origination_total_yoy_comparison(data_frame=data_frame)
            )

        ]
    
    controls = [
        vm.Parameter(id='Summary Origination: date axis', 
                    targets=['origination_total.date_column'],
                    selector=vm.Dropdown(
                        options=date_columns,
                        multi=False,
                        value=date_columns[0],
                        title="Choose column to use as time axis:"
                    )
        ),
        vm.Parameter(id='Summary Origination: sum_count', 
                    targets=['origination_total.sum_count', 'origination_total_yoy_comparison.sum_count'],
                    selector=vm.Dropdown(
                        options=["Sum", "Count"],
                        multi=False,
                        value="Sum",
                        title="Choose type of aggregation:"
                    ),
        ),
        vm.Parameter(id='Summary Origination: id', 
                    targets=['origination_total.id_column', 'origination_total_yoy_comparison.id_column'],
                    selector=vm.Dropdown(
                        options=id_columns,
                        multi=False,
                        value=id_columns[0],
                        title="Choose id column for counts:"
                    )
        ),
        vm.Parameter(id='Summary Origination: column to aggregate', 
                    targets=['origination_total.amount_column', 'origination_total_yoy_comparison.amount_column'],
                    selector=vm.Dropdown(
                        options=amount_columns,
                        multi=False,
                        value=amount_columns[0],
                        title="Choose column to aggregate for sums:"
                    )
        ),vm.Parameter(id='Summary Origination: cumulative', 
                    targets=['origination_total.cumulative'],
                    selector=vm.Dropdown(
                        options=["True", "False"],
                        multi=False,
                        value="False",
                        title="Choose whether to accumulate or not:"
                    )
        )
            ]
    
    # Build page
    page = vm.Page(
        title='Summary: Origination',
        layout=layout,
        components=components,
        controls=controls
    )
    return page

Output

No response

@pablo-fence pablo-fence added Bug Report 🐛 Issue contains a bug report Needs triage 🔍 Issue needs triaging labels Apr 30, 2024
@maxschulz-COL
Contributor

Hi @pablo-fence ,

thanks for reaching out! I hope I understand the issue correctly, but you are saying that if you do not provide keyword arguments, e.g. sum_count because it's optional, then it should never appear under the debugging statement you mentioned.

If that is the case I think I understand where the confusion comes from: Using a vm.Parameter with a vm.Graph as target will modify (and in this case insert) the relevant keyword arguments to the function call. In that case it would also appear in a print of those kwargs.

Did I understand you correctly?
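
To make the behavior described above concrete, here is a minimal sketch in plain Python. It is a toy model and not vizro's actual implementation: the helper name apply_parameter and the way the target string is split are assumptions made purely for illustration. It shows how a parameter target such as origination_total.sum_count could end up inserting a sum_count entry into the kwargs stored for a graph, which is why that key shows up in a debug print.

```python
def apply_parameter(stored_kwargs, target, value):
    """Toy model: merge one parameter value into the kwargs stored for a graph.

    `target` has the form "<component id>.<argument name>".
    This helper is hypothetical and only mimics the idea, not vizro's code.
    """
    component_id, argument = target.split(".", 1)
    # Existing keys are overwritten, missing keys are inserted.
    updated = {**stored_kwargs, argument: value}
    return component_id, updated


component_id, kwargs = apply_parameter(
    {"data_frame": "df"},           # kwargs given when the figure was created
    "origination_total.sum_count",  # target string from the Parameter
    "Sum",                          # value picked in the dropdown
)
print(component_id, kwargs)
```

The original kwargs only contained data_frame; after the parameter is applied, sum_count is present as well.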

@maxschulz-COL maxschulz-COL added General Question ❓ Issue contains a general question Community and removed Bug Report 🐛 Issue contains a bug report Needs triage 🔍 Issue needs triaging labels Apr 30, 2024
@pablo-fence
Author
pablo-fence commented May 1, 2024

Hi @maxschulz-COL,

Sorry I wasn't clear enough in my explanation. I understand the behavior of using a vm.Parameter with a vm.Graph that you mention above. However, the error I am encountering occurs during one of the app callbacks to px.bar() once the app is active (it is not raised during loading, but it is raised when I move to the page built with the code above). The error is TypeError: bar() got an unexpected keyword argument 'sum_count', so I am assuming (and actually checked with the debug print) that the kwargs belonging to the parent origination_total() function are being passed on to bar(), which raises the error since it does not expect sum_count as a kwarg.

@maxschulz-COL
Contributor

Understood - would it be possible to post a minimal working example that I can copy and paste to debug? Currently it is a little difficult to patch things together and guess which are the important bits to keep. Then I will try to debug this ASAP. 💪

@pablo-fence
Author
pablo-fence commented May 1, 2024

Of course, please find below:

import vizro.plotly.express as px
import vizro.models as vm
import plotly.graph_objects as go
import pandas as pd
import os
from vizro import Vizro
from vizro.models.types import capture
from typing import List, Union, Optional

import pandas as pd

def format_volume(volume):
    if volume >= 1_000_000_000:
        return f'{volume/1_000_000_000:.1f}B'
    elif volume >= 1_000_000:
        return f'{volume/1_000_000:.1f}M'
    elif volume >= 1_000:
        return f'{volume/1_000:.0f}K'
    else:
        return f'{volume}'

@capture('graph')
def origination_total(
    data_frame: pd.DataFrame,
    id_column: Optional[str] = "client_id",
    date_column: Optional[str] = "date",
    amount_column: Optional[Union[float, int]] = "outstanding_balance",
    sum_count: Optional[str] = "Count",
    cumulative: Optional[str] = "True",
    currency: Optional[str] = "EUR"
):

    if cumulative == "False":
        cumulative = False
    else:
        cumulative = True
    
    if sum_count == "Sum":
        var_to_plot = amount_column
        grouped = data_frame.groupby([date_column])[var_to_plot].sum().reset_index()
    elif sum_count == "Count":
        var_to_plot = id_column
        grouped = data_frame.groupby([date_column])[var_to_plot].nunique().reset_index()
    
    if cumulative:
        grouped = grouped.sort_values([date_column])
        grouped["cumulative"] = grouped[var_to_plot].cumsum()
        var_to_plot = 'cumulative'

    grouped["percentage"] = grouped.groupby(date_column)[var_to_plot].transform(lambda x:x.sum())
    grouped["percentage"] = 10000 * grouped[var_to_plot] / grouped["percentage"] // 1 / 100
    
    # Initialize figure
    fig = px.bar(
        grouped, 
        x=date_column,
        y=var_to_plot,
        hover_data=[var_to_plot, 'percentage']
        )

    annotations = []
    for date, d_group in grouped.groupby(date_column):
        total_volume = d_group[var_to_plot].sum()
        formatted_volume = format_volume(total_volume)
        annotations.append(
            go.layout.Annotation(
                x=date,
                y=total_volume,
                text=formatted_volume,
                showarrow=False,
                yshift=20
            )
        )

    y1_max = grouped[var_to_plot].max()*1.1

    # Update layout with dynamic max range for y-axis based on the max cumulative sum
    fig.update_layout(
        annotations=annotations,
        xaxis=dict(title=f'Time'),
        yaxis=dict(title=f'Total Origination ({currency})', side='left', nticks=5, range=[0, y1_max])
    )

    return fig

def origination_total_yoy_comparison(
    data_frame: pd.DataFrame,
    id_column: Optional[str] = "client_id",
    date_column: Optional[str] = "date",
    amount_column: Optional[Union[float, int]] = "outstanding_balance",
    sum_count = "Count",
    currency = "EUR"
):
    
    # Generate year column
    data_frame['year'] = data_frame[date_column].str[:4].astype(int)
    
    # Keep only relevant columns and filter for the current and previous years
    current_year = data_frame['year'].max()
    previous_year = current_year - 1
    data_frame = data_frame[
        data_frame['year'].isin([current_year, previous_year])
    ][[id_column, 'year', date_column, amount_column]].drop_duplicates()

    if sum_count == "Sum":
        var_to_plot = amount_column
        grouped = data_frame.groupby(['year'])[var_to_plot].sum().reset_index()
    elif sum_count == "Count":
        var_to_plot = id_column
        grouped = data_frame.groupby(['year'])[var_to_plot].nunique().reset_index()

    grouped["percentage"] = grouped.groupby('year')[var_to_plot].transform(lambda x:x.sum())
    grouped["percentage"] = 10000 * grouped[var_to_plot] / grouped["percentage"] // 1 / 100
    
    # Initialize figure
    fig = px.bar(
        grouped, 
        x='year',
        y=var_to_plot,
        hover_data=[var_to_plot, 'percentage']
        )

    total_volumes = grouped.groupby('year')[var_to_plot].sum().reset_index()
    annotations = [
        go.layout.Annotation(
            x=row['year'],
            y=row[var_to_plot],
            text=format_volume(row[var_to_plot]),
            showarrow=False,
            yshift=10,
            font=dict(color="black")
        ) for _ , row in total_volumes.iterrows()
    ]
    
    if sum_count == "Count":
        currency = "N"

    y1_max = grouped[var_to_plot].max()*1.1

    # Update layout with dynamic max range for y-axis based on the max cumulative sum
    fig.update_layout(
        annotations=annotations,
        xaxis=dict(
            title=f'Time',
            tickmode='array',  # Explicitly set the tick mode
            tickvals=grouped['year'].unique()  # Set the tick values to unique years from data
            ),
        yaxis=dict(title=f'Total Origination ({currency})', side='left', nticks=5, range=[0, y1_max])
    )

    return fig

def summary_origination_page(
    data_frame: pd.DataFrame,
    id_columns: List[str],
    date_columns: List[str],
    amount_columns: List[str]
    ) -> vm.Page:
    
    layout = vm.Layout(
        grid = [
        [0] * 6 + [2] * 2,
        [1] * 6 + [3] * 2,
        [1] * 6 + [3] * 2,
        [1] * 6 + [3] * 2
        ],
    row_min_height="100px",
    row_gap="24px"
    )

    components = [
            vm.Card(
                text="""
                #### __How to Read this Chart:__ (in `pink`, `adjustable parameters`)
                This bar chart shows the evolution of the origination volume over the selected `time axis`. If `type of aggregation` is set to `Sum`, the chart will show each period's 
                sum of `column to aggregate`, else if `type of aggregation` is set to `Count`, the unique count of `id column` will be displayed. Additionally, volumes can be `accumulated` or not.
                """
            ),
            vm.Graph(id='origination_total', figure=origination_total(data_frame=data_frame)
            ),
            vm.Card(
                text="""
                #### __How to Read this Chart:__
                This bar chart shows the comparison of the current year vs. the previous year. This chart is affected by the same parameters as the chart on the left.
                """
            ),
            vm.Graph(id='origination_total_yoy_comparison', figure=origination_total_yoy_comparison(data_frame=data_frame)
            )

        ]
    
    controls = [
        vm.Parameter(id='Summary Origination: date axis', 
                    targets=['origination_total.date_column'],
                    selector=vm.Dropdown(
                        options=date_columns,
                        multi=False,
                        value=date_columns[0],
                        title="Choose column to use as time axis:"
                    )
        ),
        vm.Parameter(id='Summary Origination: sum_count', 
                    targets=['origination_total.sum_count', 'origination_total_yoy_comparison.sum_count'],
                    selector=vm.Dropdown(
                        options=["Sum", "Count"],
                        multi=False,
                        value="Sum",
                        title="Choose type of aggregation:"
                    ),
        ),
        vm.Parameter(id='Summary Origination: id', 
                    targets=['origination_total.id_column', 'origination_total_yoy_comparison.id_column'],
                    selector=vm.Dropdown(
                        options=id_columns,
                        multi=False,
                        value=id_columns[0],
                        title="Choose id column for counts:"
                    )
        ),
        vm.Parameter(id='Summary Origination: column to aggregate', 
                    targets=['origination_total.amount_column', 'origination_total_yoy_comparison.amount_column'],
                    selector=vm.Dropdown(
                        options=amount_columns,
                        multi=False,
                        value=amount_columns[0],
                        title="Choose column to aggregate for sums:"
                    )
        ),vm.Parameter(id='Summary Origination: cumulative', 
                    targets=['origination_total.cumulative'],
                    selector=vm.Dropdown(
                        options=["True", "False"],
                        multi=False,
                        value="False",
                        title="Choose whether to accumulate or not:"
                    )
        )
            ]
    
    # Build page
    page = vm.Page(
        title='Summary: Origination',
        layout=layout,
        components=components,
        controls=controls
    )
    return page


if __name__ == "__main__":
    
    # Sample data to use with the function
    data = {
        'client_id': [1, 2, 1, 2, 3, 3],
        'date': ['2021-01', '2021-01', '2021-02', '2021-02', '2021-03', '2021-03'],
        'outstanding_balance': [100, 150, 200, 250, 300, 350]
    }

    df = pd.DataFrame(data)
    
    page1 = summary_origination_page(
        data_frame=df,
        id_columns=["client_id"],
        date_columns=["date"],
        amount_columns=["outstanding_balance"]
    )
    
    # Building and running the dashboard
    dashboard = vm.Dashboard(
    title="Debug Example",
    pages=[page1],
    navigation=vm.Navigation(
        nav_selector=vm.NavBar(
            items=[
                vm.NavLink(label="Summary", pages=["Summary: Origination"], icon="Tenancy"),
            ]
        )
        ),
    theme='vizro_dark'
    )

    current_dir = os.path.dirname(os.path.realpath(__file__))
    project_root_path = os.path.join(os.path.dirname(current_dir))

    vizro_app = Vizro(assets_folder=os.path.join(current_dir, "assets/"))

    vizro_app.build(dashboard).run(debug=False)

Also, here is the requirements file I am using requirements.txt

@maxschulz-COL maxschulz-COL added Bug Report 🐛 Issue contains a bug report and removed Bug Report 🐛 Issue contains a bug report labels May 2, 2024
@maxschulz-COL
Contributor

I figured out what is going on, and it took me a while 😄

There is one line missing: @capture('graph') needs to be added before def origination_total_yoy_comparison(, because vm.Graph only accepts CapturedCallables, i.e. something decorated with the capture decorator.

The problem is that the error message is super obscure, because there are a few things starting to mix. If you'd like to bear with me, here is the explanation:

  • by importing vizro.plotly.express as px (as one rightly should! so definitely not wrong to do so) we decorate all px charts behind the scenes with @capture
  • now if you use one such decorated px chart in another chart (inception style 😉 ), the returned custom object is of type CapturedCallable, although the main function (the custom chart you are building) has not directly been decorated.
  • this has two consequences: it stops the vm.Graph from failing, as it receives a CapturedCallable, but it is the wrong CapturedCallable (the px chart you used to build the other chart) and thus it messes up the arguments
  • what led me to this: if one does not directly insert px charts into the vm.Graph, there is no need to import vizro.plotly.express; one can just use plotly.express to build the chart (as you do) and then decorate the final function with @capture. Not doing so then raises the correct error.

TLDR: This is not a bug, just a typo in the script, but THANKS for raising that as it exposes a very confusing and obscure error message 🙏

Hope that makes sense 😄

P.S. We are soon merging the bug fix for your other question
P.P.S. @antonymilne we should probably do something about fixing this behaviour as it was really really confusing
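
For readers who want to see the mechanism in isolation, here is a small self-contained sketch in plain Python. It is a deliberately simplified toy model, not vizro's real code: the Figure class, the toy capture decorator, the stand-in bar function and the apply_parameter helper are all hypothetical names invented for this illustration. The assumption is only that the decorator records which function built the figure and with which kwargs, and that a parameter later re-runs that recorded function with extra kwargs.

```python
class Figure:
    """Toy figure object (a real chart function would return a plotly figure)."""

    def __init__(self, description):
        self.description = description
        self.captured = None  # (function, kwargs) recorded by the toy decorator


def capture(function):
    """Toy decorator: build the figure and remember who built it and how."""

    def wrapper(**kwargs):
        fig = function(**kwargs)
        # The outermost decorated function wins and overwrites any inner record.
        fig.captured = (function, kwargs)
        return fig

    return wrapper


@capture
def bar(data_frame, x=None, y=None):
    # Stand-in for an already-decorated px.bar.
    return Figure(f"bar of {y} by {x}")


def yoy_without_capture(data_frame, sum_count="Count"):
    # Decorator forgotten: the figure still carries bar's record.
    return bar(data_frame=data_frame, x="year", y="client_id")


@capture
def yoy_with_capture(data_frame, sum_count="Count"):
    return bar(data_frame=data_frame, x="year", y="client_id")


def apply_parameter(fig, **new_kwargs):
    # Re-run the recorded function with the parameter values merged in.
    function, kwargs = fig.captured
    return function(**{**kwargs, **new_kwargs})


good = apply_parameter(yoy_with_capture(data_frame=[]), sum_count="Sum")

message = None
try:
    apply_parameter(yoy_without_capture(data_frame=[]), sum_count="Sum")
except TypeError as error:
    message = str(error)
print(message)
```

In this toy model the undecorated function hands over a figure whose record points at bar, so sum_count is sent to bar and a TypeError about an unexpected keyword argument is raised. With the decorator in place, the record points at the custom chart function, which does accept sum_count.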

@antonymilne
Contributor

@maxschulz-COL great job debugging this. Agree this can be super confusing (as it is here) but I'm not immediately sure whether we'll be able to do anything that can fix it unfortunately 🤔 Other than changing our recommendation to be that you use import plotly.express as px when you're creating a custom chart, which has its own problems.

@pablo-fence thanks for raising this indeed, and also for posting your dashboard here. It's always really interesting to see what people are doing with vizro, and this is a really cool example. Just a small question: for the parameter target origination_total.cumulative you want boolean True/False, but you supply strings for options and then convert back to boolean inside the graph function. Did you try to set options = [True, False] as boolean and find it didn't work?

@pablo-fence
Author
pablo-fence commented May 2, 2024

Oh, I see! That makes a lot of sense, many thanks @maxschulz-COL for looking into it so promptly - glad my (silly) mistake can be useful.

P.S. Great! Once it is merged, I should remove the line I added at the top of my custom ag_grid function to return an empty dataframe, right?

@antonymilne yes, I had to use strings since if I changed the options to boolean type:

                    targets=['origination_total.cumulative'],
                    selector=vm.Dropdown(
                        options=[True, False],
                        multi=False,
                        value=False,
                        title="Choose whether to accumulate or not:"
                    )

the dashboard displays the dropdown like this (and clicking either cell has no effect)
[image: screenshot of the dropdown as displayed in the dashboard]

@antonymilne
Contributor

@pablo-fence thanks for checking the boolean parameter thing. This sounded familiar so I searched back through our git history and found that actually the right way to do this is:

options=[{'label': 'True', 'value': True}, {'label': 'False', 'value': False}]

This is explained in the note on the page https://vizro.readthedocs.io/en/stable/pages/user-guides/selectors/#categorical-selectors 🙂
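
As a small convenience, the label/value list above can be generated rather than typed by hand. The helper below is only a sketch; the name labelled_bool_options is made up for this example and is not part of vizro. The point is that the label is a string for display while the value stays a real boolean.

```python
def labelled_bool_options(values=(True, False)):
    """Build label/value option dicts so the values stay real booleans."""
    return [{"label": str(value), "value": value} for value in values]


options = labelled_bool_options()
print(options)
```

The resulting list could then be passed as options= to the selector in place of the plain string list, so the chart function would receive a boolean and would not need to compare against "True"/"False" strings.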

@huong-li-nguyen since you did the original PR in https://github.com/McK-Internal/vizx-hyphen/pull/246 and I can't remember any more... Do you know why we still allow StrictBool as a type for these fields? Feels like it should fail pydantic validation (maybe with a useful error message explaining the right way to do it) to avoid people falling into this trap.

@pablo-fence
Author
pablo-fence commented May 2, 2024

Thanks @antonymilne ! I see the default value with this implementation is the first dict (parameter value in vm.Dropdown has no effect) - in the above case default will be True. Just FYI.

@pablo-fence
Author

Closed successfully.

Thank you @maxschulz-COL and @antonymilne !! 👍

@antonymilne
Contributor

Hmm, this seems to work ok for me like this:

vm.Dropdown(
    options=[{"label": "True", "value": True}, {"label": "False", "value": False}],
    value=False,
)

Both with multi=True and multi=False, I get the default option selected being "False" here.
