feat: support gpu
This commit is contained in:
@@ -8,11 +8,7 @@ from research.statistics.utils import LETTERS, build_letter_frequencies
|
||||
|
||||
def plot_transition_matrix(ax, df_probs, title=""):
|
||||
hm = sns.heatmap(
|
||||
df_probs.loc[list(LETTERS), list(LETTERS)],
|
||||
cmap="Reds",
|
||||
annot=False,
|
||||
cbar=False,
|
||||
ax=ax
|
||||
df_probs.loc[list(LETTERS), list(LETTERS)], cmap="Reds", annot=False, cbar=False, ax=ax
|
||||
)
|
||||
ax.set_title(title, fontsize=12)
|
||||
return hm
|
||||
@@ -20,8 +16,8 @@ def plot_transition_matrix(ax, df_probs, title=""):
|
||||
|
||||
def plot_letter_frequencies(males, females, sort_values=False, title=None):
|
||||
# Compute frequencies
|
||||
L_m = build_letter_frequencies(males['name']).set_index("letter")["freq"]
|
||||
L_f = build_letter_frequencies(females['name']).set_index("letter")["freq"]
|
||||
L_m = build_letter_frequencies(males["name"]).set_index("letter")["freq"]
|
||||
L_f = build_letter_frequencies(females["name"]).set_index("letter")["freq"]
|
||||
|
||||
# Combine into one DataFrame
|
||||
df_plot = pd.DataFrame({"Male": L_m, "Female": L_f}).fillna(0).reset_index()
|
||||
@@ -35,8 +31,8 @@ def plot_letter_frequencies(males, females, sort_values=False, title=None):
|
||||
x = np.arange(len(df_plot))
|
||||
w = 0.4
|
||||
fig, ax = plt.subplots(figsize=(16, 6))
|
||||
ax.bar(x - w/2, df_plot["Male"], width=w, label="Male", color="steelblue", alpha=0.8)
|
||||
ax.bar(x + w/2, df_plot["Female"], width=w, label="Female", color="salmon", alpha=0.8)
|
||||
ax.bar(x - w / 2, df_plot["Male"], width=w, label="Male", color="steelblue", alpha=0.8)
|
||||
ax.bar(x + w / 2, df_plot["Female"], width=w, label="Female", color="salmon", alpha=0.8)
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(df_plot["letter"])
|
||||
|
||||
Reference in New Issue
Block a user