diff --git a/helper_functions.py b/helper_functions.py index 88f65521..e1657ebc 100644 --- a/helper_functions.py +++ b/helper_functions.py @@ -66,7 +66,7 @@ def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Ten # Reshape preds and plot y_pred = y_pred.reshape(xx.shape).detach().numpy() plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7) - plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu) + plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.gist_heat) plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max())