move some functions around in preparation for backend module

This commit is contained in:
Pelle Koster 2020-10-10 16:06:16 +02:00
parent 0594c33e53
commit 0dac621a15
10 changed files with 345 additions and 325 deletions

@ -131,9 +131,8 @@ def app(**kwds):
sys.modules.pop("pypiserver._app", None)
kwds = default_config(**kwds)
config, packages = core.configure(**kwds)
config = core.configure(**kwds)
_app.config = config
_app.packages = packages
_app.app.module = _app # HACK for testing.
return _app.app

@ -1,10 +1,13 @@
from collections import namedtuple
import logging
import mimetypes
import os
import re
import zipfile
import xml.dom.minidom
import xmlrpc.client as xmlrpclib
import zipfile
from collections import namedtuple
from io import BytesIO
from urllib.parse import urljoin, urlparse
from . import __version__
from . import core
@ -17,25 +20,9 @@ from .bottle import (
Bottle,
template,
)
try:
import xmlrpc.client as xmlrpclib # py3
except ImportError:
import xmlrpclib # py2
try:
from io import BytesIO
except ImportError:
from StringIO import StringIO as BytesIO
try: # PY3
from urllib.parse import urljoin, urlparse
except ImportError: # PY2
from urlparse import urljoin, urlparse
from .pkg_utils import guess_pkgname_and_version, normalize_pkgname_for_url
log = logging.getLogger(__name__)
packages = None
config = None
app = Bottle()
@ -104,7 +91,7 @@ def root():
fp = request.custom_fullpath
try:
numpkgs = len(list(packages()))
numpkgs = len(list(core.packages()))
except:
numpkgs = 0
@ -150,7 +137,7 @@ def remove_pkg():
pkgs = list(
filter(
lambda pkg: pkg.pkgname == name and pkg.version == version,
core.find_packages(packages()),
core.find_packages(),
)
)
if len(pkgs) == 0:
@ -182,11 +169,11 @@ def file_upload():
continue
if (
not is_valid_pkg_filename(uf.raw_filename)
or core.guess_pkgname_and_version(uf.raw_filename) is None
or guess_pkgname_and_version(uf.raw_filename) is None
):
raise HTTPError(400, f"Bad filename: {uf.raw_filename}")
if not config.overwrite and core.exists(packages.root, uf.raw_filename):
if not config.overwrite and core.exists(uf.raw_filename):
log.warning(
f"Cannot upload {uf.raw_filename!r} since it already exists! \n"
" You may start server with `--overwrite` option. "
@ -197,7 +184,7 @@ def file_upload():
" You may start server with `--overwrite` option.",
)
core.store(packages.root, uf.raw_filename, uf.save)
core.store(uf.raw_filename, uf.save)
if request.auth:
user = request.auth[0]
else:
@ -254,7 +241,7 @@ def handle_rpc():
)
response = []
ordering = 0
for p in packages():
for p in core.packages():
if p.pkgname.count(value) > 0:
# We do not presently have any description/summary, returning
# version instead
@ -275,7 +262,7 @@ def handle_rpc():
@app.route("/simple/")
@auth("list")
def simpleindex():
links = sorted(core.get_prefixes(packages()))
links = sorted(core.get_prefixes())
tmpl = """\
<html>
<head>
@ -296,12 +283,12 @@ def simpleindex():
@auth("list")
def simple(prefix=""):
# PEP 503: require normalized prefix
normalized = core.normalize_pkgname_for_url(prefix)
normalized = normalize_pkgname_for_url(prefix)
if prefix != normalized:
return redirect("/simple/{0}/".format(normalized), 301)
files = sorted(
core.find_packages(packages(), prefix=prefix),
core.find_packages(prefix=prefix),
key=lambda x: (x.parsed_version, x.relfn),
)
if not files:
@ -338,7 +325,7 @@ def simple(prefix=""):
def list_packages():
fp = request.custom_fullpath
files = sorted(
core.find_packages(packages()),
core.find_packages(),
key=lambda x: (os.path.dirname(x.relfn), x.pkgname, x.parsed_version),
)
links = [
@ -364,7 +351,7 @@ def list_packages():
@app.route("/packages/:filename#.*#")
@auth("download")
def server_static(filename):
entries = core.find_packages(packages())
entries = core.find_packages()
for x in entries:
f = x.relfn_unix
if f == filename:

144
pypiserver/backend.py Normal file

@ -0,0 +1,144 @@
import hashlib
import os
from pathlib import Path
from typing import List, Union
from .pkg_utils import (
normalize_pkgname,
parse_version,
is_allowed_path,
guess_pkgname_and_version,
)
class Backend:
def find_packages(self):
raise NotImplementedError
def find_package(self, name, version):
raise NotImplementedError
def add_package(self, pkg):
raise NotImplementedError
def remove_package(self):
raise NotImplementedError
def digest_file(self):
raise NotImplementedError
class SimpleFileBackend(Backend):
def __init__(self, roots: List[Union[str, Path]] = None):
self.roots = roots
class PkgFile:
__slots__ = [
"fn",
"root",
"_fname_and_hash",
"relfn",
"relfn_unix",
"pkgname_norm",
"pkgname",
"version",
"parsed_version",
"replaces",
]
def __init__(
self, pkgname, version, fn=None, root=None, relfn=None, replaces=None
):
self.pkgname = pkgname
self.pkgname_norm = normalize_pkgname(pkgname)
self.version = version
self.parsed_version = parse_version(version)
self.fn = fn
self.root = root
self.relfn = relfn
self.relfn_unix = None if relfn is None else relfn.replace("\\", "/")
self.replaces = replaces
def __repr__(self):
return "{}({})".format(
self.__class__.__name__,
", ".join(
[
f"{k}={getattr(self, k, 'AttributeError')!r}"
for k in sorted(self.__slots__)
]
),
)
def fname_and_hash(self, hash_algo):
if not hasattr(self, "_fname_and_hash"):
if hash_algo:
self._fname_and_hash = (
f"{self.relfn_unix}#{hash_algo}="
f"{digest_file(self.fn, hash_algo)}"
)
else:
self._fname_and_hash = self.relfn_unix
return self._fname_and_hash
def _listdir(root):
root = os.path.abspath(root)
for dirpath, dirnames, filenames in os.walk(root):
dirnames[:] = [x for x in dirnames if is_allowed_path(x)]
for x in filenames:
fn = os.path.join(root, dirpath, x)
if not is_allowed_path(x) or not os.path.isfile(fn):
continue
res = guess_pkgname_and_version(x)
if not res:
# #Seems the current file isn't a proper package
continue
pkgname, version = res
if pkgname:
yield PkgFile(
pkgname=pkgname,
version=version,
fn=fn,
root=root,
relfn=fn[len(root) + 1 :],
)
def _digest_file(fpath, hash_algo):
"""
Reads and digests a file according to specified hashing-algorith.
:param str sha256: any algo contained in :mod:`hashlib`
:return: <hash_algo>=<hex_digest>
From http://stackoverflow.com/a/21565932/548792
"""
blocksize = 2 ** 16
digester = getattr(hashlib, hash_algo)()
with open(fpath, "rb") as f:
for block in iter(lambda: f.read(blocksize), b""):
digester.update(block)
return digester.hexdigest()
try:
from .cache import cache_manager
def listdir(root):
# root must be absolute path
return cache_manager.listdir(root, _listdir)
def digest_file(fpath, hash_algo):
# fpath must be absolute path
return cache_manager.digest_file(fpath, hash_algo, _digest_file)
except ImportError:
pass
listdir = _listdir
digest_file = _digest_file

@ -1,4 +1,4 @@
#! /usr/bin/env python
#! /usr/bin/env python3
"""minimal PyPI like server for use with pip/easy_install"""
import functools
@ -8,26 +8,26 @@ import itertools
import logging
import mimetypes
import os
import re
import sys
try: # PY3
from urllib.parse import quote
except ImportError: # PY2
from urllib import quote
from typing import Optional
from urllib.parse import quote
import pkg_resources
from pypiserver import Configuration
from .backend import listdir
from .pkg_utils import normalize_pkgname
log = logging.getLogger(__name__)
packages: Optional[callable] = None
def configure(**kwds):
"""
:return: a 2-tuple (Configure, package-list)
"""
global packages
c = Configuration(**kwds)
log.info(f"+++Pypiserver invoked with: {c}")
@ -87,7 +87,7 @@ def configure(**kwds):
log.info(f"+++Pypiserver started with: {c}")
return c, packages
return c
def auth_by_htpasswd_file(htPsswdFile, username, password):
@ -102,207 +102,9 @@ mimetypes.add_type("application/octet-stream", ".whl")
mimetypes.add_type("text/plain", ".asc")
# ### Next 2 functions adapted from :mod:`distribute.pkg_resources`.
#
component_re = re.compile(r"(\d+ | [a-z]+ | \.| -)", re.I | re.VERBOSE)
replace = {"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@"}.get
def _parse_version_parts(s):
for part in component_re.split(s):
part = replace(part, part)
if part in ["", "."]:
continue
if part[:1] in "0123456789":
yield part.zfill(8) # pad for numeric comparison
else:
yield "*" + part
yield "*final" # ensure that alpha/beta/candidate are before final
def parse_version(s):
parts = []
for part in _parse_version_parts(s.lower()):
if part.startswith("*"):
# remove trailing zeros from each series of numeric parts
while parts and parts[-1] == "00000000":
parts.pop()
parts.append(part)
return tuple(parts)
#
#### -- End of distribute's code.
_archive_suffix_rx = re.compile(
r"(\.zip|\.tar\.gz|\.tgz|\.tar\.bz2|-py[23]\.\d-.*|"
r"\.win-amd64-py[23]\.\d\..*|\.win32-py[23]\.\d\..*|\.egg)$",
re.I,
)
wheel_file_re = re.compile(
r"""^(?P<namever>(?P<name>.+?)-(?P<ver>\d.*?))
((-(?P<build>\d.*?))?-(?P<pyver>.+?)-(?P<abi>.+?)-(?P<plat>.+?)
\.whl|\.dist-info)$""",
re.VERBOSE,
)
_pkgname_re = re.compile(r"-\d+[a-z_.!+]", re.I)
_pkgname_parts_re = re.compile(
r"[\.\-](?=cp\d|py\d|macosx|linux|sunos|solaris|irix|aix|cygwin|win)", re.I
)
def _guess_pkgname_and_version_wheel(basename):
m = wheel_file_re.match(basename)
if not m:
return None, None
name = m.group("name")
ver = m.group("ver")
build = m.group("build")
if build:
return name, ver + "-" + build
else:
return name, ver
def guess_pkgname_and_version(path):
path = os.path.basename(path)
if path.endswith(".asc"):
path = path.rstrip(".asc")
if path.endswith(".whl"):
return _guess_pkgname_and_version_wheel(path)
if not _archive_suffix_rx.search(path):
return
path = _archive_suffix_rx.sub("", path)
if "-" not in path:
pkgname, version = path, ""
elif path.count("-") == 1:
pkgname, version = path.split("-", 1)
elif "." not in path:
pkgname, version = path.rsplit("-", 1)
else:
pkgname = _pkgname_re.split(path)[0]
ver_spec = path[len(pkgname) + 1 :]
parts = _pkgname_parts_re.split(ver_spec)
version = parts[0]
return pkgname, version
def normalize_pkgname(name):
"""Perform PEP 503 normalization"""
return re.sub(r"[-_.]+", "-", name).lower()
def normalize_pkgname_for_url(name):
"""Perform PEP 503 normalization and ensure the value is safe for URLs."""
return quote(re.sub(r"[-_.]+", "-", name).lower())
def is_allowed_path(path_part):
p = path_part.replace("\\", "/")
return not (p.startswith(".") or "/." in p)
class PkgFile:
__slots__ = [
"fn",
"root",
"_fname_and_hash",
"relfn",
"relfn_unix",
"pkgname_norm",
"pkgname",
"version",
"parsed_version",
"replaces",
]
def __init__(
self, pkgname, version, fn=None, root=None, relfn=None, replaces=None
):
self.pkgname = pkgname
self.pkgname_norm = normalize_pkgname(pkgname)
self.version = version
self.parsed_version = parse_version(version)
self.fn = fn
self.root = root
self.relfn = relfn
self.relfn_unix = None if relfn is None else relfn.replace("\\", "/")
self.replaces = replaces
def __repr__(self):
return "{}({})".format(
self.__class__.__name__,
", ".join(
[
f"{k}={getattr(self, k, 'AttributeError')!r}"
for k in sorted(self.__slots__)
]
),
)
def fname_and_hash(self, hash_algo):
if not hasattr(self, "_fname_and_hash"):
if hash_algo:
self._fname_and_hash = (
f"{self.relfn_unix}#{hash_algo}="
f"{digest_file(self.fn, hash_algo)}"
)
else:
self._fname_and_hash = self.relfn_unix
return self._fname_and_hash
def _listdir(root):
root = os.path.abspath(root)
for dirpath, dirnames, filenames in os.walk(root):
dirnames[:] = [x for x in dirnames if is_allowed_path(x)]
for x in filenames:
fn = os.path.join(root, dirpath, x)
if not is_allowed_path(x) or not os.path.isfile(fn):
continue
res = guess_pkgname_and_version(x)
if not res:
# #Seems the current file isn't a proper package
continue
pkgname, version = res
if pkgname:
yield PkgFile(
pkgname=pkgname,
version=version,
fn=fn,
root=root,
relfn=fn[len(root) + 1 :],
)
def read_lines(filename):
"""
Read the contents of `filename`, stripping empty lines and '#'-comments.
Return a list of strings, containing the lines of the file.
"""
lines = []
try:
with open(filename) as f:
lines = [
line
for line in (ln.strip() for ln in f.readlines())
if line and not line.startswith("#")
]
except Exception:
log.error(
f'Failed to read package blacklist file "{filename}". '
"Aborting server startup, please fix this."
)
raise
return lines
def find_packages(pkgs, prefix=""):
def find_packages(pkgs=None, prefix=""):
if pkgs is None:
pkgs = packages()
prefix = normalize_pkgname(prefix)
for x in pkgs:
if prefix and x.pkgname_norm != prefix:
@ -310,7 +112,8 @@ def find_packages(pkgs, prefix=""):
yield x
def get_prefixes(pkgs):
def get_prefixes():
pkgs = packages()
normalized_pkgnames = set()
for x in pkgs:
if x.pkgname:
@ -318,13 +121,15 @@ def get_prefixes(pkgs):
return normalized_pkgnames
def exists(root, filename):
def exists(filename):
root = packages.root
assert "/" not in filename
dest_fn = os.path.join(root, filename)
return os.path.exists(dest_fn)
def store(root, filename, save_method):
def store(filename, save_method):
root = packages.root
assert "/" not in filename
dest_fn = os.path.join(root, filename)
save_method(dest_fn, overwrite=True) # Overwite check earlier.
@ -341,35 +146,4 @@ def get_bad_url_redirect_path(request, prefix):
return p
def _digest_file(fpath, hash_algo):
"""
Reads and digests a file according to specified hashing-algorith.
:param str sha256: any algo contained in :mod:`hashlib`
:return: <hash_algo>=<hex_digest>
From http://stackoverflow.com/a/21565932/548792
"""
blocksize = 2 ** 16
digester = getattr(hashlib, hash_algo)()
with open(fpath, "rb") as f:
for block in iter(lambda: f.read(blocksize), b""):
digester.update(block)
return digester.hexdigest()
try:
from .cache import cache_manager
def listdir(root):
# root must be absolute path
return cache_manager.listdir(root, _listdir)
def digest_file(fpath, hash_algo):
# fpath must be absolute path
return cache_manager.digest_file(fpath, hash_algo, _digest_file)
except ImportError:
listdir = _listdir
digest_file = _digest_file

@ -7,12 +7,13 @@ import os
import sys
from distutils.version import LooseVersion
from subprocess import call
from xmlrpc.client import Server
import pip
from . import core
from xmlrpc.client import Server
from .backend import PkgFile, listdir
from .core import log
from .pkg_utils import normalize_pkgname, parse_version
def make_pypi_client(url):
@ -41,7 +42,7 @@ def filter_latest_pkgs(pkgs):
pkgname2latest = {}
for x in pkgs:
pkgname = core.normalize_pkgname(x.pkgname)
pkgname = normalize_pkgname(x.pkgname)
if pkgname not in pkgname2latest:
pkgname2latest[pkgname] = x
@ -53,9 +54,9 @@ def filter_latest_pkgs(pkgs):
def build_releases(pkg, versions):
for x in versions:
parsed_version = core.parse_version(x)
parsed_version = parse_version(x)
if parsed_version > pkg.parsed_version:
yield core.PkgFile(pkgname=pkg.pkgname, version=x, replaces=pkg)
yield PkgFile(pkgname=pkg.pkgname, version=x, replaces=pkg)
def find_updates(pkgset, stable_only=True):
@ -171,11 +172,11 @@ def update(pkgset, destdir=None, dry_run=False, stable_only=True):
def update_all_packages(
roots, destdir=None, dry_run=False, stable_only=True, blacklist_file=None
):
all_packages = itertools.chain(*[core.listdir(r) for r in roots])
all_packages = itertools.chain(*[listdir(r) for r in roots])
skip_packages = set()
if blacklist_file:
skip_packages = set(core.read_lines(blacklist_file))
skip_packages = set(read_lines(blacklist_file))
print(
'Skipping update of blacklisted packages (listed in "{}"): {}'.format(
blacklist_file, ", ".join(sorted(skip_packages))
@ -187,3 +188,26 @@ def update_all_packages(
)
update(packages, destdir, dry_run, stable_only)
def read_lines(filename):
"""
Read the contents of `filename`, stripping empty lines and '#'-comments.
Return a list of strings, containing the lines of the file.
"""
try:
with open(filename) as f:
lines = [
line
for line in (ln.strip() for ln in f.readlines())
if line and not line.startswith("#")
]
except Exception:
log.error(
f'Failed to read package blacklist file "{filename}". '
"Aborting server startup, please fix this."
)
raise
return lines

107
pypiserver/pkg_utils.py Normal file

@ -0,0 +1,107 @@
import os
import re
from urllib.parse import quote
def normalize_pkgname(name):
"""Perform PEP 503 normalization"""
return re.sub(r"[-_.]+", "-", name).lower()
def normalize_pkgname_for_url(name):
"""Perform PEP 503 normalization and ensure the value is safe for URLs."""
return quote(normalize_pkgname(name))
# ### Next 2 functions adapted from :mod:`distribute.pkg_resources`.
#
component_re = re.compile(r"(\d+ | [a-z]+ | \.| -)", re.I | re.VERBOSE)
replace = {"pre": "c", "preview": "c", "-": "final-", "rc": "c", "dev": "@"}.get
def _parse_version_parts(s):
for part in component_re.split(s):
part = replace(part, part)
if part in ["", "."]:
continue
if part[:1] in "0123456789":
yield part.zfill(8) # pad for numeric comparison
else:
yield "*" + part
yield "*final" # ensure that alpha/beta/candidate are before final
def parse_version(s):
parts = []
for part in _parse_version_parts(s.lower()):
if part.startswith("*"):
# remove trailing zeros from each series of numeric parts
while parts and parts[-1] == "00000000":
parts.pop()
parts.append(part)
return tuple(parts)
#
# ### -- End of distribute's code.
def is_allowed_path(path_part):
p = path_part.replace("\\", "/")
return not (p.startswith(".") or "/." in p)
_archive_suffix_rx = re.compile(
r"(\.zip|\.tar\.gz|\.tgz|\.tar\.bz2|-py[23]\.\d-.*|"
r"\.win-amd64-py[23]\.\d\..*|\.win32-py[23]\.\d\..*|\.egg)$",
re.I,
)
wheel_file_re = re.compile(
r"""^(?P<namever>(?P<name>.+?)-(?P<ver>\d.*?))
((-(?P<build>\d.*?))?-(?P<pyver>.+?)-(?P<abi>.+?)-(?P<plat>.+?)
\.whl|\.dist-info)$""",
re.VERBOSE,
)
_pkgname_re = re.compile(r"-\d+[a-z_.!+]", re.I)
_pkgname_parts_re = re.compile(
r"[\.\-](?=cp\d|py\d|macosx|linux|sunos|solaris|irix|aix|cygwin|win)", re.I
)
def _guess_pkgname_and_version_wheel(basename):
m = wheel_file_re.match(basename)
if not m:
return None, None
name = m.group("name")
ver = m.group("ver")
build = m.group("build")
if build:
return name, ver + "-" + build
else:
return name, ver
def guess_pkgname_and_version(path: str):
path = os.path.basename(path)
if path.endswith(".asc"):
path = path.rstrip(".asc")
if path.endswith(".whl"):
return _guess_pkgname_and_version_wheel(path)
if not _archive_suffix_rx.search(path):
return
path = _archive_suffix_rx.sub("", path)
if "-" not in path:
pkgname, version = path, ""
elif path.count("-") == 1:
pkgname, version = path.split("-", 1)
elif "." not in path:
pkgname, version = path.rsplit("-", 1)
else:
pkgname = _pkgname_re.split(path)[0]
ver_spec = path[len(pkgname) + 1 :]
parts = _pkgname_parts_re.split(ver_spec)
version = parts[0]
return pkgname, version

@ -4,19 +4,9 @@
import logging
import os
from html import unescape
import xmlrpc.client as xmlrpclib
try: # python 3
from html.parser import HTMLParser
from html import unescape
except ImportError:
from HTMLParser import HTMLParser
unescape = HTMLParser().unescape
try:
import xmlrpc.client as xmlrpclib
except ImportError:
import xmlrpclib # legacy Python
# Third party imports
import pytest
@ -24,7 +14,7 @@ import webtest
# Local Imports
from pypiserver import __main__, bottle
from pypiserver import __main__, bottle, backend
import tests.test_core as test_core

@ -6,7 +6,7 @@ import os
import pytest
from pypiserver import __main__, core
from pypiserver import __main__, core, backend, pkg_utils, manage
from tests.doubles import Namespace
@ -93,20 +93,20 @@ def _capitalize_ext(fpath):
@pytest.mark.parametrize(("filename", "pkgname", "version"), files)
def test_guess_pkgname_and_version(filename, pkgname, version):
exp = (pkgname, version)
assert core.guess_pkgname_and_version(filename) == exp
assert core.guess_pkgname_and_version(_capitalize_ext(filename)) == exp
assert pkg_utils.guess_pkgname_and_version(filename) == exp
assert pkg_utils.guess_pkgname_and_version(_capitalize_ext(filename)) == exp
@pytest.mark.parametrize(("filename", "pkgname", "version"), files)
def test_guess_pkgname_and_version_asc(filename, pkgname, version):
exp = (pkgname, version)
filename = f"{filename}.asc"
assert core.guess_pkgname_and_version(filename) == exp
assert pkg_utils.guess_pkgname_and_version(filename) == exp
def test_listdir_bad_name(tmpdir):
tmpdir.join("foo.whl").ensure()
res = list(core.listdir(tmpdir.strpath))
res = list(backend.listdir(tmpdir.strpath))
assert res == []
@ -123,7 +123,7 @@ def test_read_lines(tmpdir):
f = tmpdir.join(filename).ensure()
f.write(file_contents)
assert core.read_lines(f.strpath) == [
assert manage.read_lines(f.strpath) == [
"my_private_pkg",
"my_other_private_pkg",
]
@ -144,7 +144,7 @@ hashes = (
def test_hashfile(tmpdir, algo, digest):
f = tmpdir.join("empty")
f.ensure()
assert core.digest_file(f.strpath, algo) == digest
assert backend.digest_file(f.strpath, algo) == digest
@pytest.mark.parametrize("hash_algo", ("md5", "sha256", "sha512"))
@ -152,7 +152,7 @@ def test_fname_and_hash(tmpdir, hash_algo):
"""Ensure we are returning the expected hashes for files."""
f = tmpdir.join("tmpfile")
f.ensure()
pkgfile = core.PkgFile("tmp", "1.0.0", f.strpath, f.dirname, f.basename)
pkgfile = backend.PkgFile("tmp", "1.0.0", f.strpath, f.dirname, f.basename)
assert pkgfile.fname_and_hash(hash_algo) == "{}#{}={}".format(
f.basename, hash_algo, str(f.computehash(hashtype=hash_algo))
)
@ -168,6 +168,6 @@ def test_redirect_prefix_encodes_newlines():
def test_normalize_pkgname_for_url_encodes_newlines():
"""Ensure newlines are url encoded in package names for urls."""
assert "\n" not in core.normalize_pkgname_for_url(
assert "\n" not in pkg_utils.normalize_pkgname_for_url(
"/\nSet-Cookie:malicious=1;"
)

@ -1,7 +1,7 @@
#! /usr/bin/env py.test
import sys, os, pytest, logging
from pypiserver import __main__
from pypiserver import __main__, core
try:
from unittest import mock
@ -66,13 +66,13 @@ def test_server(main):
def test_root(main):
main(["--root", "."])
assert main.app.module.packages.root == os.path.abspath(".")
assert core.packages.root == os.path.abspath(".")
assert main.pkgdir == os.path.abspath(".")
def test_root_r(main):
main(["-r", "."])
assert main.app.module.packages.root == os.path.abspath(".")
assert core.packages.root == os.path.abspath(".")
assert main.pkgdir == os.path.abspath(".")

@ -1,22 +1,17 @@
#!/usr/bin/env py.test
"""Tests for manage.py."""
from __future__ import absolute_import, print_function, unicode_literals
try:
from unittest.mock import Mock
except ImportError:
from mock import Mock
import pypiserver.manage
from unittest.mock import Mock
import py
import pytest
from pypiserver import manage
from pypiserver.core import (
PkgFile,
guess_pkgname_and_version,
parse_version,
)
from pypiserver.pkg_utils import parse_version, guess_pkgname_and_version
from pypiserver.backend import PkgFile
from pypiserver.manage import (
PipCmd,
build_releases,
@ -220,8 +215,8 @@ def test_update_all_packages(monkeypatch):
def core_listdir_mock(directory):
return roots_mock.get(directory, [])
monkeypatch.setattr(manage.core, "listdir", core_listdir_mock)
monkeypatch.setattr(manage.core, "read_lines", Mock(return_value=[]))
monkeypatch.setattr(manage, "listdir", core_listdir_mock)
monkeypatch.setattr(manage, "read_lines", Mock(return_value=[]))
monkeypatch.setattr(manage, "update", Mock(return_value=None))
destdir = None
@ -237,7 +232,7 @@ def test_update_all_packages(monkeypatch):
blacklist_file=blacklist_file,
)
manage.core.read_lines.assert_not_called() # pylint: disable=no-member
pypiserver.manage.read_lines.assert_not_called() # pylint: disable=no-member
manage.update.assert_called_once_with( # pylint: disable=no-member
frozenset([public_pkg_1, public_pkg_2, private_pkg_1, private_pkg_2]),
destdir,
@ -264,9 +259,9 @@ def test_update_all_packages_with_blacklist(monkeypatch):
def core_listdir_mock(directory):
return roots_mock.get(directory, [])
monkeypatch.setattr(manage.core, "listdir", core_listdir_mock)
monkeypatch.setattr(manage, "listdir", core_listdir_mock)
monkeypatch.setattr(
manage.core,
manage,
"read_lines",
Mock(return_value=["my_private_pkg", "my_other_private_pkg"]),
)
@ -288,6 +283,6 @@ def test_update_all_packages_with_blacklist(monkeypatch):
manage.update.assert_called_once_with( # pylint: disable=no-member
frozenset([public_pkg_1, public_pkg_2]), destdir, dry_run, stable_only
)
manage.core.read_lines.assert_called_once_with(
pypiserver.manage.read_lines.assert_called_once_with(
blacklist_file
) # pylint: disable=no-member