import marimo

# Application object for this notebook; each function decorated with
# @app.cell below is registered on it as one cell.
app = marimo.App()


@app.cell
def _(dist_sq_1, np):
    # Builds, for every row of dist_sq_1, an index array partitioned so that
    # the smallest entries of the row come first.
    # NOTE(review): dist_sq_1 is presumably a square matrix of pairwise
    # squared distances -- confirm against the cell that defines it.
    _K = 2
    # np.argpartition places the index of the kth-smallest entry at position
    # kth; every index before it refers to an entry that is no larger, in
    # unspecified order. With kth = _K + 1 the leading _K + 1 columns of each
    # row therefore index that row's _K + 1 smallest values.
    _kth = _K + 1
    nearest_partition = np.argpartition(dist_sq_1, _kth, axis=1)
    return (nearest_partition,)


@app.cell
def _(X_1, nearest_partition, plt):
    # Scatter the rows of X_1 (column 0 on the x axis, column 1 on the y axis)
    # and connect each row to the rows listed in the leading columns of its
    # nearest_partition entry with a black line.
    _K = 2
    _n_leading = _K + 1  # how many leading columns of nearest_partition to use
    plt.scatter(X_1[:, 0], X_1[:, 1], s=100)
    for i_1 in range(X_1.shape[0]):
        _current = X_1[i_1]
        for j in nearest_partition[i_1, :_n_leading]:
            # zip pairs the two endpoints coordinate by coordinate, so for
            # two-column data plot receives (x pair, y pair): one segment
            # from row j to row i_1.
            _endpoints = zip(X_1[j], _current)
            plt.plot(*_endpoints, color='black')
    return


# Run the notebook's cells when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
