# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2026)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import errno
import os
import unittest
from unittest.mock import MagicMock, mock_open, patch

import pytest

from streamlit import errors, file_util

# Fake config-file path used throughout the read/write tests.
FILENAME = "/some/cache/file"
# Stand-in for get_streamlit_file_path: always resolves to FILENAME.
mock_get_path = MagicMock()
mock_get_path.return_value = FILENAME


class FileUtilTest(unittest.TestCase):
    """Tests for the read/write helpers and path utilities in ``file_util``."""

    def setUp(self):
        # The patch target resolves through ``streamlit.file_util.os`` to the
        # ``os`` module itself, so ``os.stat`` is replaced for each test.
        self.patch1 = patch("streamlit.file_util.os.stat")
        self.os_stat = self.patch1.start()
        # Register the stop as a cleanup so the patch is always undone, even
        # if something later in setUp raises (tearDown would be skipped then).
        self.addCleanup(self.patch1.stop)
        # Report a non-empty file by default; tests that need an empty file
        # override this. An explicit int also avoids MagicMock comparisons.
        self.os_stat.return_value.st_size = 1

    @patch("streamlit.file_util.get_streamlit_file_path", mock_get_path)
    @patch("streamlit.file_util.open", mock_open(read_data="data"))
    def test_streamlit_read(self):
        """Test that file_util.streamlit_read yields a handle to the text content."""
        with file_util.streamlit_read(FILENAME) as file_input:
            data = file_input.read()
        assert data == "data"

    @patch("streamlit.file_util.get_streamlit_file_path", mock_get_path)
    @patch("streamlit.file_util.open", mock_open(read_data=b"\xaa\xbb"))
    def test_streamlit_read_binary(self):
        """Test that file_util.streamlit_read(binary=True) yields raw bytes."""
        with file_util.streamlit_read(FILENAME, binary=True) as file_input:
            data = file_input.read()
        assert data == b"\xaa\xbb"

    @patch("streamlit.file_util.get_streamlit_file_path", mock_get_path)
    @patch("streamlit.file_util.open", mock_open(read_data="data"))
    def test_streamlit_read_zero_bytes(self):
        """Test that file_util.streamlit_read raises errors.Error on an empty file."""
        self.os_stat.return_value.st_size = 0
        with (
            pytest.raises(errors.Error) as e,
            file_util.streamlit_read(FILENAME) as file_input,
        ):
            file_input.read()
        assert str(e.value) == 'Read zero byte file: "/some/cache/file"'

    @patch("streamlit.file_util.get_streamlit_file_path", mock_get_path)
    def test_streamlit_write(self):
        """Test that file_util.streamlit_write creates the parent dir and writes."""

        dirname = os.path.dirname(file_util.get_streamlit_file_path(FILENAME))
        # makedirs is mocked so no directory is actually created on disk.
        with (
            patch("streamlit.file_util.open", mock_open()) as file_open,
            patch("streamlit.file_util.os.makedirs") as makedirs,
            file_util.streamlit_write(FILENAME) as output,
        ):
            output.write("some data")
            file_open.return_value.write.assert_called_once_with("some data")
            makedirs.assert_called_once_with(dirname, exist_ok=True)

    @patch("streamlit.file_util.get_streamlit_file_path", mock_get_path)
    @patch("streamlit.env_util.IS_DARWIN", True)
    def test_streamlit_write_exception(self):
        """Test that an EINVAL OSError on macOS becomes errors.Error with a hint.

        The expected message includes the note about Python's 2GB file limit
        on OSX.
        """
        with (
            patch("streamlit.file_util.open", mock_open()) as file_open,
            patch("streamlit.file_util.os.makedirs"),
        ):
            file_open.side_effect = OSError(errno.EINVAL, "[Errno 22] Invalid argument")
            with (
                pytest.raises(errors.Error) as e,
                file_util.streamlit_write(FILENAME) as output,
            ):
                output.write("some data")
            error_msg = (
                "Unable to write file: /some/cache/file\n"
                "Python is limited to files below 2GB on OSX. "
                "See https://bugs.python.org/issue24658"
            )
            assert str(e.value) == error_msg

    def test_get_project_streamlit_file_path(self):
        """Test that path parts resolve under the project's config folder.

        A single slash-separated path and the same path given as separate
        components must produce the same result.
        """
        expected = os.path.join(
            os.getcwd(), file_util.CONFIG_FOLDER_NAME, "some/random/file"
        )

        assert expected == file_util.get_project_streamlit_file_path("some/random/file")

        assert expected == file_util.get_project_streamlit_file_path(
            "some", "random", "file"
        )

    def test_get_app_static_dir(self):
        """Test that the static dir is a ``static`` folder beside the main script."""
        assert (
            file_util.get_app_static_dir("/some_path/to/app/myapp.py")
            == "/some_path/to/app/static"
        )

    @patch("os.path.getsize", MagicMock(return_value=42))
    @patch(
        "os.walk",
        MagicMock(
            return_value=[
                ("dir1", [], ["file1", "file2", "file3"]),
                ("dir2", [], ["file4", "file5"]),
            ]
        ),
    )
    def test_get_directory_size(self):
        """Test that get_directory_size sums getsize over every walked file."""
        # Five files across both directories, 42 bytes each.
        assert file_util.get_directory_size("the_dir") == 42 * 5


class FileIsInFolderTest(unittest.TestCase):
    """Checks for file_util.file_is_in_folder_glob with plain and glob folders."""

    def _check(self, filepath, folders, expected):
        # Every folder variant (e.g. with and without a trailing slash) must
        # give the same truthiness for ``filepath``.
        for folder in folders:
            result = file_util.file_is_in_folder_glob(filepath, folder)
            assert bool(result) is expected

    def test_file_in_folder(self):
        # A trailing slash on the folder must not matter.
        self._check("/a/b/c/foo.py", ("/a/b/c/", "/a/b/c"), True)

    def test_file_in_subfolder(self):
        # Any ancestor folder matches, with or without a trailing slash.
        self._check("/a/b/c/foo.py", ("/a", "/a/", "/a/b", "/a/b/"), True)

    def test_file_not_in_folder(self):
        # An unrelated folder never matches, with or without a trailing slash.
        self._check("/a/b/c/foo.py", ("/d/e/f/", "/d/e/f"), False)

    def test_rel_file_not_in_folder(self):
        # A bare relative filename is not inside an absolute folder.
        self._check("foo.py", ("/d/e/f/", "/d/e/f"), False)

    def test_file_in_folder_glob(self):
        self._check("/a/b/c/foo.py", ("**/c",), True)

    def test_file_not_in_folder_glob(self):
        self._check("/a/b/c/foo.py", ("**/f",), False)

    def test_rel_file_not_in_folder_glob(self):
        self._check("foo.py", ("**/f",), False)

    def test_rel_file_in_folder_glob(self):
        # An empty folder pattern matches a bare relative filename.
        self._check("foo.py", ("",), True)


class FileInPythonPathTest(unittest.TestCase):
    """Tests for file_util.file_in_pythonpath and related path helpers."""

    @staticmethod
    def _make_it_absolute(relative):
        # Join by hand instead of using os.path.abspath so the result stays
        # non-normalized (it may still contain "..").
        return os.path.join(os.getcwd(), relative)

    def _check_pythonpath(self, environ, expectations):
        # Replace os.environ with ``environ`` and check each
        # (relative path, expected membership) pair against file_in_pythonpath.
        with patch("os.environ", environ):
            for relative, expected in expectations:
                found = file_util.file_in_pythonpath(self._make_it_absolute(relative))
                assert bool(found) is expected

    def test_no_pythonpath(self):
        self._check_pythonpath({}, [("../something/dir1/dir2/module", False)])

    def test_empty_pythonpath(self):
        self._check_pythonpath(
            {"PYTHONPATH": ""}, [("something/dir1/dir2/module", False)]
        )

    def test_python_path_relative(self):
        self._check_pythonpath(
            {"PYTHONPATH": "something"},
            [
                ("something/dir1/dir2/module", True),
                ("something_else/module", False),
                ("../something/dir1/dir2/module", False),
            ],
        )

    def test_python_path_absolute(self):
        environ = {"PYTHONPATH": self._make_it_absolute("something")}
        self._check_pythonpath(
            environ,
            [
                ("something/dir1/dir2/module", True),
                ("something_else/module", False),
                ("../something/dir1/dir2/module", False),
            ],
        )

    def test_python_path_mixed(self):
        # One absolute and one relative entry pointing at the same folder.
        entries = [self._make_it_absolute("something"), "something"]
        self._check_pythonpath(
            {"PYTHONPATH": os.pathsep.join(entries)},
            [
                ("something/dir1/dir2/module", True),
                ("something_else/module", False),
            ],
        )

    def test_current_directory(self):
        self._check_pythonpath(
            {"PYTHONPATH": "."},
            [
                ("something/dir1/dir2/module", True),
                ("something_else/module", True),
                ("../something_else/module", False),
            ],
        )

    def test_get_main_script_directory(self):
        """Test file_util.get_main_script_directory."""
        # Script path -> expected directory, with the cwd pinned below.
        cases = {
            "app.py": "/some/random",
            "./app.py": "/some/random",
            "../app.py": "/some",
            "/path/to/my/app.py": "/path/to/my",
        }
        with patch("os.getcwd", return_value="/some/random"):
            for script, expected_dir in cases.items():
                assert file_util.get_main_script_directory(script) == expected_dir

    def test_normalize_path_join(self):
        """Test file_util.normalize_path_join."""
        # (path components, expected normalized result)
        cases = [
            (("/some", "random", "path"), "/some/random/path"),
            (("some", "random", "path"), "some/random/path"),
            (("/some", "random", "./path"), "/some/random/path"),
            (("some", "random", "./path"), "some/random/path"),
            (("/some", "random", "../path"), "/some/path"),
            (("some", "random", "../path"), "some/path"),
            (("/some", "random", "/path"), "/path"),
            (("some", "random", "/path"), "/path"),
            (("some", "random", "path", ".."), "some/random"),
            (("some", "random", "path", "../.."), "some"),
        ]
        for parts, expected in cases:
            assert file_util.normalize_path_join(*parts) == expected
