From 24a348987aa5112cad623b5e54700ac0df8214cc Mon Sep 17 00:00:00 2001
From: shadchin <shadchin@yandex-team.com>
Date: Fri, 9 Feb 2024 11:05:14 +0300
Subject: Rework joinpath for importlib.resources

---
 library/python/runtime_py3/importer.pxi           | 13 ++--
 library/python/runtime_py3/sitecustomize.pyx      | 75 ++++++++++++++++-------
 library/python/runtime_py3/test/test_resources.py | 40 ++++++++++++
 3 files changed, 98 insertions(+), 30 deletions(-)

(limited to 'library/python')

diff --git a/library/python/runtime_py3/importer.pxi b/library/python/runtime_py3/importer.pxi
index 66ecd35167..0a194308c3 100644
--- a/library/python/runtime_py3/importer.pxi
+++ b/library/python/runtime_py3/importer.pxi
@@ -27,7 +27,7 @@ Y_PYTHON_EXTENDED_SOURCE_SEARCH = _os.environ.get(env_extended_source_search) or
 
 def _init_venv():
     if not _path_isabs(executable):
-        raise RuntimeError('path in sys.executable is not absolute: {}'.format(executable))
+        raise RuntimeError(f'path in sys.executable is not absolute: {executable}')
 
     # Creative copy-paste from site.py
     exe_dir, _ = _path_split(executable)
@@ -50,7 +50,7 @@ def _init_venv():
         if _path_isfile(conffile)
         ]
     if not candidate_confs:
-        raise RuntimeError('{} not found'.format(conf_basename))
+        raise RuntimeError(f'{conf_basename} not found')
     virtual_conf = candidate_confs[0]
     with FileIO(virtual_conf, 'r') as f:
         for line in f:
@@ -60,7 +60,7 @@ def _init_venv():
                 value = value.strip()
                 if key == cfg_source_root:
                     return value
-    raise RuntimeError('{} key not found in {}'.format(cfg_source_root, virtual_conf))
+    raise RuntimeError(f'{cfg_source_root} key not found in {virtual_conf}')
 
 
 def _get_source_root():
@@ -175,7 +175,7 @@ def mod_path(mod):
     return py_prefix + _b(mod).replace(b'.', b'/') + b'.py'
 
 
-class ResourceImporter(object):
+class ResourceImporter:
 
     """ A meta_path importer that loads code from built-in resources.
     """
@@ -292,7 +292,7 @@ class ResourceImporter(object):
         path = path.replace(_b('\\'), _b('/'))
         data = resfs_read(path, builtin=True)
         if data is None:
-            raise IOError(path)  # Y_PYTHON_ENTRY_POINT=:resource_files
+            raise OSError(path)  # Y_PYTHON_ENTRY_POINT=:resource_files
         return data
 
     # PEP-302 extension 2 of 3: get __file__ without importing.
@@ -513,8 +513,7 @@ class ArcadiaSourceFinder:
                     m = rx.match(mod)
                     if m:
                         found.append((prefix + m.group(1), self.is_package(mod)))
-            for cm in found:
-                yield cm
+            yield from found
 
             # Yield from file system
             for path in paths:
diff --git a/library/python/runtime_py3/sitecustomize.pyx b/library/python/runtime_py3/sitecustomize.pyx
index 64b8b909b2..25b4ccb55a 100644
--- a/library/python/runtime_py3/sitecustomize.pyx
+++ b/library/python/runtime_py3/sitecustomize.pyx
@@ -5,7 +5,12 @@ import re
 import sys
 import warnings
 
-from importlib.metadata import Distribution, DistributionFinder, PackageNotFoundError, Prepared
+from importlib.metadata import (
+    Distribution,
+    DistributionFinder,
+    PackageNotFoundError,
+    Prepared,
+)
 from importlib.resources.abc import Traversable
 
 import __res
@@ -15,12 +20,12 @@ with warnings.catch_warnings(action="ignore", category=DeprecationWarning):
 
 ResourceReader.register(__res._ResfsResourceReader)
 
-METADATA_NAME = re.compile('^Name: (.*)$', re.MULTILINE)
+METADATA_NAME = re.compile("^Name: (.*)$", re.MULTILINE)
 
 
-class ArcadiaResourceHandle(Traversable):
-    def __init__(self, key):
-        self.resfs_key = key
+class ArcadiaResource(Traversable):
+    def __init__(self, resfs_key):
+        self.resfs_key = resfs_key
 
     def is_file(self):
         return True
@@ -28,14 +33,14 @@ class ArcadiaResourceHandle(Traversable):
     def is_dir(self):
         return False
 
-    def open(self, mode='r', *args, **kwargs):
+    def open(self, mode="r", *args, **kwargs):
         data = __res.find(self.resfs_key.encode("utf-8"))
         if data is None:
             raise FileNotFoundError(self.resfs_key)
 
         stream = io.BytesIO(data)
 
-        if 'b' not in mode:
+        if "b" not in mode:
             stream = io.TextIOWrapper(stream, *args, **kwargs)
 
         return stream
@@ -50,6 +55,9 @@ class ArcadiaResourceHandle(Traversable):
     def name(self):
         return os.path.basename(self.resfs_key)
 
+    def __repr__(self):
+        return f"ArcadiaResource({self.resfs_key!r})"
+
 
 class ArcadiaResourceContainer(Traversable):
     def __init__(self, prefix):
@@ -62,29 +70,50 @@ class ArcadiaResourceContainer(Traversable):
         return False
 
     def iterdir(self):
-        for key, path_without_prefix in __res.iter_keys(self.resfs_prefix.encode("utf-8")):
+        seen = set()
+        for key, path_without_prefix in __res.iter_keys(
+            self.resfs_prefix.encode("utf-8")
+        ):
             if b"/" in path_without_prefix:
-                name = path_without_prefix.decode("utf-8").split("/", maxsplit=1)[0]
-                yield ArcadiaResourceContainer(f"{self.resfs_prefix}{name}/")
+                subdir = path_without_prefix.split(b"/", maxsplit=1)[0].decode("utf-8")
+                if subdir not in seen:
+                    seen.add(subdir)
+                    yield ArcadiaResourceContainer(f"{self.resfs_prefix}{subdir}/")
             else:
-                yield ArcadiaResourceHandle(key.decode("utf-8"))
+                yield ArcadiaResource(key.decode("utf-8"))
 
     def open(self, *args, **kwargs):
         raise IsADirectoryError(self.resfs_prefix)
 
+    @staticmethod
+    def _flatten(compound_names):
+        for name in compound_names:
+            yield from name.split("/")
+
     def joinpath(self, *descendants):
         if not descendants:
             return self
 
-        return ArcadiaResourceHandle(os.path.join(self.resfs_prefix, *descendants))
+        target, *rest = self._flatten(descendants)  # first segment, then the remaining path
+        for traversable in self.iterdir():
+            if traversable.name != target:
+                continue
+            if not rest:  # whole path consumed: the match (file or container) is the answer
+                return traversable
+            if not isinstance(traversable, ArcadiaResource):  # only containers have children
+                return traversable.joinpath(*rest)
+
+        raise FileNotFoundError("/".join(self._flatten(descendants)))
 
     @property
     def name(self):
         return os.path.basename(self.resfs_prefix[:-1])
 
+    def __repr__(self):
+        return f"ArcadiaResourceContainer({self.resfs_prefix!r})"
 
-class ArcadiaDistribution(Distribution):
 
+class ArcadiaDistribution(Distribution):
     def __init__(self, prefix):
         self.prefix = prefix
 
@@ -93,17 +122,17 @@ class ArcadiaDistribution(Distribution):
         return pathlib.Path(self.prefix)
 
     def read_text(self, filename):
-        data = __res.resfs_read(f'{self.prefix}{filename}')
+        data = __res.resfs_read(f"{self.prefix}{filename}")
         if data:
-            return data.decode('utf-8')
+            return data.decode("utf-8")
+
     read_text.__doc__ = Distribution.read_text.__doc__
 
     def locate_file(self, path):
-        return f'{self.prefix}{path}'
+        return f"{self.prefix}{path}"
 
 
 class ArcadiaMetadataFinder(DistributionFinder):
-
     prefixes = {}
 
     @classmethod
@@ -116,14 +145,14 @@ class ArcadiaMetadataFinder(DistributionFinder):
         cls.prefixes.clear()
 
         for resource in __res.resfs_files():
-            resource = resource.decode('utf-8')
-            if not resource.endswith('METADATA'):
+            resource = resource.decode("utf-8")
+            if not resource.endswith("METADATA"):
                 continue
-            data = __res.resfs_read(resource).decode('utf-8')
+            data = __res.resfs_read(resource).decode("utf-8")
             metadata_name = METADATA_NAME.search(data)
             if metadata_name:
                 metadata_name = Prepared(metadata_name.group(1))
-                cls.prefixes[metadata_name.normalized] = resource[:-len('METADATA')]
+                cls.prefixes[metadata_name.normalized] = resource[: -len("METADATA")]
 
     @classmethod
     def _search_prefixes(cls, name):
@@ -136,10 +165,10 @@ class ArcadiaMetadataFinder(DistributionFinder):
             except KeyError:
                 raise PackageNotFoundError(name)
         else:
-            for prefix in sorted(cls.prefixes.values()):
-                yield prefix
+            yield from sorted(cls.prefixes.values())
 
 
 # monkeypatch standart library
 import importlib.metadata
+
 importlib.metadata.MetadataPathFinder = ArcadiaMetadataFinder
diff --git a/library/python/runtime_py3/test/test_resources.py b/library/python/runtime_py3/test/test_resources.py
index 059cc039e6..75c1eb3549 100644
--- a/library/python/runtime_py3/test/test_resources.py
+++ b/library/python/runtime_py3/test/test_resources.py
@@ -71,3 +71,43 @@ def test_read_text_missing():
 )
 def test_contents_good_path(package, expected):
     assert sorted(ir.contents(package)) == sorted(expected)
+
+
+def test_files_joinpath():  # compare names: a bare Traversable is always truthy
+    assert (ir.files("resources") / "submodule").name == "submodule"
+    assert (ir.files("resources") / "foo.txt").name == "foo.txt"
+    assert (ir.files("resources") / "submodule" / "bar.txt").name == "bar.txt"
+    assert (ir.files("resources.submodule") / "bar.txt").name == "bar.txt"
+
+
+@pytest.mark.parametrize(
+    "package, resource, expected",
+    (
+        ("resources", "foo.txt", b"bar"),
+        ("resources.submodule", "bar.txt", b"foo"),
+    ),
+)
+def test_files_read_bytes(package, resource, expected):
+    assert (ir.files(package) / resource).read_bytes() == expected
+
+
+@pytest.mark.parametrize(
+    "package, resource, expected",
+    (
+        ("resources", "foo.txt", "bar"),
+        ("resources.submodule", "bar.txt", "foo"),
+    ),
+)
+def test_files_read_text(package, resource, expected):
+    assert (ir.files(package) / resource).read_text() == expected
+
+
+@pytest.mark.parametrize(
+    "package, expected",
+    (
+        ("resources", ("foo.txt", "submodule")),
+        ("resources.submodule", ("bar.txt",)),
+    ),
+)
+def test_files_iterdir(package, expected):  # iterdir() order is unspecified, so compare sorted
+    assert sorted(resource.name for resource in ir.files(package).iterdir()) == sorted(expected)
-- 
cgit v1.2.3