tor-browser

The Tor Browser
git clone https://git.dasho.dev/tor-browser.git
Log | Files | Refs | README | LICENSE

test_base_script.py (36594B)


      1 import gc
      2 import os
      3 import re
      4 import shutil
      5 import tempfile
      6 import types
      7 import unittest
      8 from unittest import mock
      9 
     10 import mozunit
     11 
     12 PYWIN32 = False
     13 if os.name == "nt":
     14    try:
     15        import win32file
     16 
     17        PYWIN32 = True
     18    except ImportError:
     19        pass
     20 
     21 
     22 from mozharness.base import errors, log, script
     23 from mozharness.base.config import parse_config_file
     24 from mozharness.base.log import CRITICAL, DEBUG, ERROR, FATAL, IGNORE, INFO, WARNING
     25 
     26 here = os.path.dirname(os.path.abspath(__file__))
     27 
     28 test_string = """foo
     29 bar
     30 baz"""
     31 
     32 
     33 class CleanupObj(script.ScriptMixin, log.LogMixin):
     34    def __init__(self):
     35        super().__init__()
     36        self.log_obj = None
     37        self.config = {"log_level": ERROR}
     38 
     39 
     40 def cleanup(files=None):
     41    files = files or []
     42    files.extend(("test_logs", "test_dir", "tmpfile_stdout", "tmpfile_stderr"))
     43    gc.collect()
     44    c = CleanupObj()
     45    for f in files:
     46        c.rmtree(f)
     47 
     48 
     49 def get_debug_script_obj():
     50    s = script.BaseScript(
     51        config={"log_type": "multi", "log_level": DEBUG},
     52        initial_config_file="test/test.json",
     53    )
     54    return s
     55 
     56 
     57 def _post_fatal(self, **kwargs):
     58    fh = open("tmpfile_stdout", "w")
     59    print(test_string, file=fh)
     60    fh.close()
     61 
     62 
     63 # TestScript {{{1
     64 class TestScript(unittest.TestCase):
     65    def setUp(self):
     66        cleanup()
     67        self.s = None
     68        self.tmpdir = tempfile.mkdtemp(suffix=".mozharness")
     69 
     70    def tearDown(self):
     71        # Close the logfile handles, or windows can't remove the logs
     72        if hasattr(self, "s") and isinstance(self.s, object):
     73            del self.s
     74        cleanup([self.tmpdir])
     75 
     76    # test _dump_config_hierarchy() when --dump-config-hierarchy is passed
     77    def test_dump_config_hierarchy_valid_files_len(self):
     78        try:
     79            self.s = script.BaseScript(
     80                initial_config_file="test/test.json",
     81                option_args=["--cfg", "test/test_override.py,test/test_override2.py"],
     82                config={"dump_config_hierarchy": True},
     83            )
     84        except SystemExit:
     85            local_cfg_files = parse_config_file("test_logs/localconfigfiles.json")
     86            # first let's see if the correct number of config files were
     87            # realized
     88            self.assertEqual(
     89                len(local_cfg_files),
     90                4,
     91                msg="--dump-config-hierarchy dumped wrong number of config files",
     92            )
     93 
     94    def test_dump_config_hierarchy_keys_unique_and_valid(self):
     95        try:
     96            self.s = script.BaseScript(
     97                initial_config_file="test/test.json",
     98                option_args=["--cfg", "test/test_override.py,test/test_override2.py"],
     99                config={"dump_config_hierarchy": True},
    100            )
    101        except SystemExit:
    102            local_cfg_files = parse_config_file("test_logs/localconfigfiles.json")
    103            # now let's see if only unique items were added from each config
    104            t_override = local_cfg_files.get("test/test_override.py", {})
    105            self.assertTrue(
    106                t_override.get("keep_string") == "don't change me"
    107                and len(t_override.keys()) == 1,
    108                msg="--dump-config-hierarchy dumped wrong keys/value for "
    109                "`test/test_override.py`. There should only be one "
    110                "item and it should be unique to all the other "
    111                "items in test_log/localconfigfiles.json.",
    112            )
    113 
    114    def test_dump_config_hierarchy_matches_self_config(self):
    115        try:
    116            ######
    117            # we need temp_cfg because self.s will be gcollected (NoneType) by
    118            # the time we get to SystemExit exception
    119            # temp_cfg will differ from self.s.config because of
    120            # 'dump_config_hierarchy'. we have to make a deepcopy because
    121            # config is a locked dict
    122            temp_s = script.BaseScript(
    123                initial_config_file="test/test.json",
    124                option_args=["--cfg", "test/test_override.py,test/test_override2.py"],
    125            )
    126            from copy import deepcopy
    127 
    128            temp_cfg = deepcopy(temp_s.config)
    129            temp_cfg.update({"dump_config_hierarchy": True})
    130            ######
    131            self.s = script.BaseScript(
    132                initial_config_file="test/test.json",
    133                option_args=["--cfg", "test/test_override.py,test/test_override2.py"],
    134                config={"dump_config_hierarchy": True},
    135            )
    136        except SystemExit:
    137            local_cfg_files = parse_config_file("test_logs/localconfigfiles.json")
    138            # finally let's just make sure that all the items added up, equals
    139            # what we started with: self.config
    140            target_cfg = {}
    141            for cfg_file in local_cfg_files:
    142                target_cfg.update(local_cfg_files[cfg_file])
    143            self.assertEqual(
    144                target_cfg,
    145                temp_cfg,
    146                msg="all of the items (combined) in each cfg file dumped via "
    147                "--dump-config-hierarchy does not equal self.config ",
    148            )
    149 
    150    # test _dump_config() when --dump-config is passed
    151    def test_dump_config_equals_self_config(self):
    152        try:
    153            ######
    154            # we need temp_cfg because self.s will be gcollected (NoneType) by
    155            # the time we get to SystemExit exception
    156            # temp_cfg will differ from self.s.config because of
    157            # 'dump_config_hierarchy'. we have to make a deepcopy because
    158            # config is a locked dict
    159            temp_s = script.BaseScript(
    160                initial_config_file="test/test.json",
    161                option_args=["--cfg", "test/test_override.py,test/test_override2.py"],
    162            )
    163            from copy import deepcopy
    164 
    165            temp_cfg = deepcopy(temp_s.config)
    166            temp_cfg.update({"dump_config": True})
    167            ######
    168            self.s = script.BaseScript(
    169                initial_config_file="test/test.json",
    170                option_args=["--cfg", "test/test_override.py,test/test_override2.py"],
    171                config={"dump_config": True},
    172            )
    173        except SystemExit:
    174            target_cfg = parse_config_file("test_logs/localconfig.json")
    175            self.assertEqual(
    176                target_cfg,
    177                temp_cfg,
    178                msg="all of the items (combined) in each cfg file dumped via "
    179                "--dump-config does not equal self.config ",
    180            )
    181 
    182    def test_nonexistent_mkdir_p(self):
    183        self.s = script.BaseScript(initial_config_file="test/test.json")
    184        self.s.mkdir_p("test_dir/foo/bar/baz")
    185        self.assertTrue(os.path.isdir("test_dir/foo/bar/baz"), msg="mkdir_p error")
    186 
    187    def test_existing_mkdir_p(self):
    188        self.s = script.BaseScript(initial_config_file="test/test.json")
    189        os.makedirs("test_dir/foo/bar/baz")
    190        self.s.mkdir_p("test_dir/foo/bar/baz")
    191        self.assertTrue(
    192            os.path.isdir("test_dir/foo/bar/baz"), msg="mkdir_p error when dir exists"
    193        )
    194 
    195    def test_chdir(self):
    196        self.s = script.BaseScript(initial_config_file="test/test.json")
    197        cwd = os.getcwd()
    198        self.s.chdir("test_logs")
    199        self.assertEqual(os.path.join(cwd, "test_logs"), os.getcwd(), msg="chdir error")
    200        self.s.chdir(cwd)
    201 
    202    def _test_log_helper(self, obj):
    203        obj.debug("Testing DEBUG")
    204        obj.warning("Testing WARNING")
    205        obj.error("Testing ERROR")
    206        obj.critical("Testing CRITICAL")
    207        try:
    208            obj.fatal("Testing FATAL")
    209        except SystemExit:
    210            pass
    211        else:
    212            self.assertTrue(False, msg="fatal() didn't SystemExit!")
    213 
    214    def test_log(self):
    215        self.s = get_debug_script_obj()
    216        self.s.log_obj = None
    217        self._test_log_helper(self.s)
    218        del self.s
    219        self.s = script.BaseScript(initial_config_file="test/test.json")
    220        self._test_log_helper(self.s)
    221 
    222    def test_run_nonexistent_command(self):
    223        self.s = get_debug_script_obj()
    224        self.s.run_command(
    225            command="this_cmd_should_not_exist --help",
    226            env={"GARBLE": "FARG"},
    227            error_list=errors.PythonErrorList,
    228        )
    229        error_logsize = os.path.getsize("test_logs/test_info.log")
    230        self.assertTrue(error_logsize > 0, msg="command not found error not hit")
    231 
    232    def test_run_command_in_bad_dir(self):
    233        self.s = get_debug_script_obj()
    234        self.s.run_command(
    235            command="ls",
    236            cwd="/this_dir_should_not_exist",
    237            error_list=errors.PythonErrorList,
    238        )
    239        error_logsize = os.path.getsize("test_logs/test_error.log")
    240        self.assertTrue(error_logsize > 0, msg="bad dir error not hit")
    241 
    242    def test_get_output_from_command_in_bad_dir(self):
    243        self.s = get_debug_script_obj()
    244        self.s.get_output_from_command(command="ls", cwd="/this_dir_should_not_exist")
    245        error_logsize = os.path.getsize("test_logs/test_error.log")
    246        self.assertTrue(error_logsize > 0, msg="bad dir error not hit")
    247 
    248    def test_get_output_from_command_with_missing_file(self):
    249        self.s = get_debug_script_obj()
    250        self.s.get_output_from_command(command="ls /this_file_should_not_exist")
    251        error_logsize = os.path.getsize("test_logs/test_error.log")
    252        self.assertTrue(error_logsize > 0, msg="bad file error not hit")
    253 
    254    def test_get_output_from_command_with_missing_file2(self):
    255        self.s = get_debug_script_obj()
    256        self.s.run_command(
    257            command="cat mozharness/base/errors.py",
    258            error_list=[
    259                {"substr": "error", "level": ERROR},
    260                {
    261                    "regex": re.compile(",$"),
    262                    "level": IGNORE,
    263                },
    264                {
    265                    "substr": "]$",
    266                    "level": WARNING,
    267                },
    268            ],
    269        )
    270        error_logsize = os.path.getsize("test_logs/test_error.log")
    271        self.assertTrue(error_logsize > 0, msg="error list not working properly")
    272 
    273    def test_download_unpack(self):
    274        # NOTE: The action is called *download*, however, it can work for files in disk
    275        self.s = get_debug_script_obj()
    276 
    277        archives_path = os.path.join(here, "helper_files", "archives")
    278 
    279        # Test basic decompression
    280        for archive in (
    281            "archive.tar",
    282            "archive.tar.bz2",
    283            "archive.tar.gz",
    284            "archive.tar.xz",
    285            "archive.zip",
    286        ):
    287            self.s.download_unpack(
    288                url=os.path.join(archives_path, archive), extract_to=self.tmpdir
    289            )
    290            self.assertIn("script.sh", os.listdir(os.path.join(self.tmpdir, "bin")))
    291            self.assertIn("lorem.txt", os.listdir(self.tmpdir))
    292            shutil.rmtree(self.tmpdir)
    293 
    294        # Test permissions for extracted entries from zip archive
    295        self.s.download_unpack(
    296            url=os.path.join(archives_path, "archive.zip"),
    297            extract_to=self.tmpdir,
    298        )
    299        file_stats = os.stat(os.path.join(self.tmpdir, "bin", "script.sh"))
    300        orig_fstats = os.stat(
    301            os.path.join(archives_path, "reference", "bin", "script.sh")
    302        )
    303        self.assertEqual(file_stats.st_mode, orig_fstats.st_mode)
    304        shutil.rmtree(self.tmpdir)
    305 
    306        # Test unzip specific dirs only
    307        self.s.download_unpack(
    308            url=os.path.join(archives_path, "archive.zip"),
    309            extract_to=self.tmpdir,
    310            extract_dirs=["bin/*"],
    311        )
    312        self.assertIn("bin", os.listdir(self.tmpdir))
    313        self.assertNotIn("lorem.txt", os.listdir(self.tmpdir))
    314        shutil.rmtree(self.tmpdir)
    315 
    316        # Test for invalid filenames (Windows only)
    317        if PYWIN32:
    318            with self.assertRaises(IOError):
    319                self.s.download_unpack(
    320                    url=os.path.join(archives_path, "archive_invalid_filename.zip"),
    321                    extract_to=self.tmpdir,
    322                )
    323 
    324        for archive in (
    325            "archive-setuid.tar",
    326            "archive-escape.tar",
    327            "archive-link.tar",
    328            "archive-link-abs.tar",
    329            "archive-double-link.tar",
    330        ):
    331            with self.assertRaises(Exception):
    332                self.s.download_unpack(
    333                    url=os.path.join(archives_path, archive),
    334                    extract_to=self.tmpdir,
    335                )
    336 
    337    def test_unpack(self):
    338        self.s = get_debug_script_obj()
    339 
    340        archives_path = os.path.join(here, "helper_files", "archives")
    341 
    342        # Test basic decompression
    343        for archive in (
    344            "archive.tar",
    345            "archive.tar.bz2",
    346            "archive.tar.gz",
    347            "archive.tar.xz",
    348            "archive.zip",
    349        ):
    350            self.s.unpack(os.path.join(archives_path, archive), self.tmpdir)
    351            self.assertIn("script.sh", os.listdir(os.path.join(self.tmpdir, "bin")))
    352            self.assertIn("lorem.txt", os.listdir(self.tmpdir))
    353            shutil.rmtree(self.tmpdir)
    354 
    355        # Test permissions for extracted entries from zip archive
    356        self.s.unpack(os.path.join(archives_path, "archive.zip"), self.tmpdir)
    357        file_stats = os.stat(os.path.join(self.tmpdir, "bin", "script.sh"))
    358        orig_fstats = os.stat(
    359            os.path.join(archives_path, "reference", "bin", "script.sh")
    360        )
    361        self.assertEqual(file_stats.st_mode, orig_fstats.st_mode)
    362        shutil.rmtree(self.tmpdir)
    363 
    364        # Test extract specific dirs only
    365        self.s.unpack(
    366            os.path.join(archives_path, "archive.zip"),
    367            self.tmpdir,
    368            extract_dirs=["bin/*"],
    369        )
    370        self.assertIn("bin", os.listdir(self.tmpdir))
    371        self.assertNotIn("lorem.txt", os.listdir(self.tmpdir))
    372        shutil.rmtree(self.tmpdir)
    373 
    374        # Test for invalid filenames (Windows only)
    375        if PYWIN32:
    376            with self.assertRaises(IOError):
    377                self.s.unpack(
    378                    os.path.join(archives_path, "archive_invalid_filename.zip"),
    379                    self.tmpdir,
    380                )
    381 
    382        for archive in (
    383            "archive-setuid.tar",
    384            "archive-escape.tar",
    385            "archive-link.tar",
    386            "archive-link-abs.tar",
    387            "archive-double-link.tar",
    388        ):
    389            with self.assertRaises(Exception):
    390                self.s.unpack(os.path.join(archives_path, archive), self.tmpdir)
    391 
    392 
    393 # TestHelperFunctions {{{1
    394 class TestHelperFunctions(unittest.TestCase):
    395    temp_file = "test_dir/mozilla"
    396 
    397    def setUp(self):
    398        cleanup()
    399        self.s = None
    400 
    401    def tearDown(self):
    402        # Close the logfile handles, or windows can't remove the logs
    403        if hasattr(self, "s") and isinstance(self.s, object):
    404            del self.s
    405        cleanup()
    406 
    407    def _create_temp_file(self, contents=test_string):
    408        os.mkdir("test_dir")
    409        fh = open(self.temp_file, "w+")
    410        fh.write(contents)
    411        fh.close
    412 
    413    def test_mkdir_p(self):
    414        self.s = script.BaseScript(initial_config_file="test/test.json")
    415        self.s.mkdir_p("test_dir")
    416        self.assertTrue(os.path.isdir("test_dir"), msg="mkdir_p error")
    417 
    418    def test_get_output_from_command(self):
    419        self._create_temp_file()
    420        self.s = script.BaseScript(initial_config_file="test/test.json")
    421        contents = self.s.get_output_from_command([
    422            "bash",
    423            "-c",
    424            "cat %s" % self.temp_file,
    425        ])
    426        self.assertEqual(
    427            test_string,
    428            contents,
    429            msg="get_output_from_command('cat file') differs from fh.write",
    430        )
    431 
    432    def test_run_command(self):
    433        self._create_temp_file()
    434        self.s = script.BaseScript(initial_config_file="test/test.json")
    435        temp_file_name = os.path.basename(self.temp_file)
    436        self.assertEqual(
    437            self.s.run_command("cat %s" % temp_file_name, cwd="test_dir"),
    438            0,
    439            msg="run_command('cat file') did not exit 0",
    440        )
    441 
    442    def test_move1(self):
    443        self._create_temp_file()
    444        self.s = script.BaseScript(initial_config_file="test/test.json")
    445        temp_file2 = "%s2" % self.temp_file
    446        self.s.move(self.temp_file, temp_file2)
    447        self.assertFalse(
    448            os.path.exists(self.temp_file),
    449            msg="%s still exists after move()" % self.temp_file,
    450        )
    451 
    452    def test_move2(self):
    453        self._create_temp_file()
    454        self.s = script.BaseScript(initial_config_file="test/test.json")
    455        temp_file2 = "%s2" % self.temp_file
    456        self.s.move(self.temp_file, temp_file2)
    457        self.assertTrue(
    458            os.path.exists(temp_file2), msg="%s doesn't exist after move()" % temp_file2
    459        )
    460 
    461    def test_copyfile(self):
    462        self._create_temp_file()
    463        self.s = script.BaseScript(initial_config_file="test/test.json")
    464        temp_file2 = "%s2" % self.temp_file
    465        self.s.copyfile(self.temp_file, temp_file2)
    466        self.assertEqual(
    467            os.path.getsize(self.temp_file),
    468            os.path.getsize(temp_file2),
    469            msg="%s and %s are different sizes after copyfile()"
    470            % (self.temp_file, temp_file2),
    471        )
    472 
    473    def test_existing_rmtree(self):
    474        self._create_temp_file()
    475        self.s = script.BaseScript(initial_config_file="test/test.json")
    476        self.s.mkdir_p("test_dir/foo/bar/baz")
    477        self.s.rmtree("test_dir")
    478        self.assertFalse(os.path.exists("test_dir"), msg="rmtree unsuccessful")
    479 
    480    def test_nonexistent_rmtree(self):
    481        self.s = script.BaseScript(initial_config_file="test/test.json")
    482        status = self.s.rmtree("test_dir")
    483        self.assertFalse(status, msg="nonexistent rmtree error")
    484 
    485    @unittest.skipUnless(PYWIN32, "PyWin32 specific")
    486    def test_long_dir_rmtree(self):
    487        self.s = script.BaseScript(initial_config_file="test/test.json")
    488        # create a very long path that the command-prompt cannot delete
    489        # by using unicode format (max path length 32000)
    490        path = "\\\\?\\%s\\test_dir" % os.getcwd()
    491        win32file.CreateDirectoryExW(".", path)
    492 
    493        for x in range(0, 20):
    494            print("path=%s" % path)
    495            path = path + "\\%sxxxxxxxxxxxxxxxxxxxx" % x
    496            win32file.CreateDirectoryExW(".", path)
    497        self.s.rmtree("test_dir")
    498        self.assertFalse(os.path.exists("test_dir"), msg="rmtree unsuccessful")
    499 
    500    @unittest.skipUnless(PYWIN32, "PyWin32 specific")
    501    def test_chmod_rmtree(self):
    502        self._create_temp_file()
    503        win32file.SetFileAttributesW(self.temp_file, win32file.FILE_ATTRIBUTE_READONLY)
    504        self.s = script.BaseScript(initial_config_file="test/test.json")
    505        self.s.rmtree("test_dir")
    506        self.assertFalse(os.path.exists("test_dir"), msg="rmtree unsuccessful")
    507 
    508    @unittest.skipIf(os.name == "nt", "Not for Windows")
    509    def test_chmod(self):
    510        self._create_temp_file()
    511        self.s = script.BaseScript(initial_config_file="test/test.json")
    512        self.s.chmod(self.temp_file, 0o100700)
    513        self.assertEqual(os.stat(self.temp_file)[0], 33216, msg="chmod unsuccessful")
    514 
    515    def test_env_normal(self):
    516        self.s = script.BaseScript(initial_config_file="test/test.json")
    517        script_env = self.s.query_env()
    518        self.assertEqual(
    519            script_env,
    520            os.environ,
    521            msg="query_env() != env\n%s\n%s" % (script_env, os.environ),
    522        )
    523 
    524    def test_env_normal2(self):
    525        self.s = script.BaseScript(initial_config_file="test/test.json")
    526        self.s.query_env()
    527        script_env = self.s.query_env()
    528        self.assertEqual(
    529            script_env,
    530            os.environ,
    531            msg="Second query_env() != env\n%s\n%s" % (script_env, os.environ),
    532        )
    533 
    534    def test_env_partial(self):
    535        self.s = script.BaseScript(initial_config_file="test/test.json")
    536        script_env = self.s.query_env(partial_env={"foo": "bar"})
    537        self.assertTrue("foo" in script_env and script_env["foo"] == "bar")
    538 
    539    def test_env_path(self):
    540        self.s = script.BaseScript(initial_config_file="test/test.json")
    541        partial_path = "yaddayadda:%(PATH)s"
    542        full_path = partial_path % {"PATH": os.environ["PATH"]}
    543        script_env = self.s.query_env(partial_env={"PATH": partial_path})
    544        self.assertEqual(script_env["PATH"], full_path)
    545 
    546    def test_query_exe(self):
    547        self.s = script.BaseScript(
    548            initial_config_file="test/test.json",
    549            config={"exes": {"foo": "bar"}},
    550        )
    551        path = self.s.query_exe("foo")
    552        self.assertEqual(path, "bar")
    553 
    554    def test_query_exe_string_replacement(self):
    555        self.s = script.BaseScript(
    556            initial_config_file="test/test.json",
    557            config={
    558                "base_work_dir": "foo",
    559                "work_dir": "bar",
    560                "exes": {"foo": os.path.join("%(abs_work_dir)s", "baz")},
    561            },
    562        )
    563        path = self.s.query_exe("foo")
    564        self.assertEqual(path, os.path.join("foo", "bar", "baz"))
    565 
    566    def test_read_from_file(self):
    567        self._create_temp_file()
    568        self.s = script.BaseScript(initial_config_file="test/test.json")
    569        contents = self.s.read_from_file(self.temp_file)
    570        self.assertEqual(contents, test_string)
    571 
    572    def test_read_from_nonexistent_file(self):
    573        self.s = script.BaseScript(initial_config_file="test/test.json")
    574        contents = self.s.read_from_file("nonexistent_file!!!")
    575        self.assertEqual(contents, None)
    576 
    577 
    578 # TestScriptLogging {{{1
    579 class TestScriptLogging(unittest.TestCase):
    580    # I need a log watcher helper function, here and in test_log.
    581    def setUp(self):
    582        cleanup()
    583        self.s = None
    584 
    585    def tearDown(self):
    586        # Close the logfile handles, or windows can't remove the logs
    587        if hasattr(self, "s") and isinstance(self.s, object):
    588            del self.s
    589        cleanup()
    590 
    591    def test_info_logsize(self):
    592        self.s = script.BaseScript(
    593            config={"log_type": "multi"}, initial_config_file="test/test.json"
    594        )
    595        info_logsize = os.path.getsize("test_logs/test_info.log")
    596        self.assertTrue(info_logsize > 0, msg="initial info logfile missing/size 0")
    597 
    598    def test_add_summary_info(self):
    599        self.s = script.BaseScript(
    600            config={"log_type": "multi"}, initial_config_file="test/test.json"
    601        )
    602        info_logsize = os.path.getsize("test_logs/test_info.log")
    603        self.s.add_summary("one")
    604        info_logsize2 = os.path.getsize("test_logs/test_info.log")
    605        self.assertTrue(
    606            info_logsize < info_logsize2, msg="add_summary() info not logged"
    607        )
    608 
    609    def test_add_summary_warning(self):
    610        self.s = script.BaseScript(
    611            config={"log_type": "multi"}, initial_config_file="test/test.json"
    612        )
    613        warning_logsize = os.path.getsize("test_logs/test_warning.log")
    614        self.s.add_summary("two", level=WARNING)
    615        warning_logsize2 = os.path.getsize("test_logs/test_warning.log")
    616        self.assertTrue(
    617            warning_logsize < warning_logsize2,
    618            msg="add_summary(level=%s) not logged in warning log" % WARNING,
    619        )
    620 
    621    def test_summary(self):
    622        self.s = script.BaseScript(
    623            config={"log_type": "multi"}, initial_config_file="test/test.json"
    624        )
    625        self.s.add_summary("one")
    626        self.s.add_summary("two", level=WARNING)
    627        info_logsize = os.path.getsize("test_logs/test_info.log")
    628        warning_logsize = os.path.getsize("test_logs/test_warning.log")
    629        self.s.summary()
    630        info_logsize2 = os.path.getsize("test_logs/test_info.log")
    631        warning_logsize2 = os.path.getsize("test_logs/test_warning.log")
    632        msg = ""
    633        if info_logsize >= info_logsize2:
    634            msg += "summary() didn't log to info!\n"
    635        if warning_logsize >= warning_logsize2:
    636            msg += "summary() didn't log to warning!\n"
    637        self.assertEqual(msg, "", msg=msg)
    638 
    639    def _test_log_level(self, log_level, log_level_file_list):
    640        self.s = script.BaseScript(
    641            config={"log_type": "multi"}, initial_config_file="test/test.json"
    642        )
    643        if log_level != FATAL:
    644            self.s.log("testing", level=log_level)
    645        else:
    646            self.s._post_fatal = types.MethodType(_post_fatal, self.s)
    647            try:
    648                self.s.fatal("testing")
    649            except SystemExit:
    650                contents = None
    651                if os.path.exists("tmpfile_stdout"):
    652                    fh = open("tmpfile_stdout")
    653                    contents = fh.read()
    654                    fh.close()
    655                self.assertEqual(contents.rstrip(), test_string, "_post_fatal failed!")
    656        del self.s
    657        msg = ""
    658        for level in log_level_file_list:
    659            log_path = "test_logs/test_%s.log" % level
    660            if not os.path.exists(log_path):
    661                msg += "%s doesn't exist!\n" % log_path
    662            else:
    663                filesize = os.path.getsize(log_path)
    664                if not filesize > 0:
    665                    msg += "%s is size 0!\n" % log_path
    666        self.assertEqual(msg, "", msg=msg)
    667 
    668    def test_debug(self):
    669        self._test_log_level(DEBUG, [])
    670 
    671    def test_ignore(self):
    672        self._test_log_level(IGNORE, [])
    673 
    674    def test_info(self):
    675        self._test_log_level(INFO, [INFO])
    676 
    677    def test_warning(self):
    678        self._test_log_level(WARNING, [INFO, WARNING])
    679 
    680    def test_error(self):
    681        self._test_log_level(ERROR, [INFO, WARNING, ERROR])
    682 
    683    def test_critical(self):
    684        self._test_log_level(CRITICAL, [INFO, WARNING, ERROR, CRITICAL])
    685 
    686    def test_fatal(self):
    687        self._test_log_level(FATAL, [INFO, WARNING, ERROR, CRITICAL, FATAL])
    688 
    689 
    690 # TestRetry {{{1
    691 class NewError(Exception):
    692    pass
    693 
    694 
    695 class OtherError(Exception):
    696    pass
    697 
    698 
    699 class TestRetry(unittest.TestCase):
    def setUp(self):
        """Build a fresh BaseScript and reset the attempt counter to 1."""
        self.s = script.BaseScript(initial_config_file="test/test.json")
        self.ATTEMPT_N = 1
    703 
    def tearDown(self):
        """Drop the script and remove the scratch paths it created."""
        # Close the logfile handles, or windows can't remove the logs.
        # (``isinstance(self.s, object)`` is always true, so only hasattr matters.)
        if hasattr(self, "s"):
            del self.s
        cleanup()
    709 
    710    def _succeedOnSecondAttempt(self, foo=None, exception=Exception):
    711        if self.ATTEMPT_N == 2:
    712            self.ATTEMPT_N += 1
    713            return
    714        self.ATTEMPT_N += 1
    715        raise exception("Fail")
    716 
    def _raiseCustomException(self):
        """Same as _succeedOnSecondAttempt, but failures raise NewError."""
        outcome = self._succeedOnSecondAttempt(exception=NewError)
        return outcome
    719 
    720    def _alwaysPass(self):
    721        self.ATTEMPT_N += 1
    722        return True
    723 
    724    def _mirrorArgs(self, *args, **kwargs):
    725        return args, kwargs
    726 
    727    def _alwaysFail(self):
    728        raise Exception("Fail")
    729 
    730    def testRetrySucceed(self):
    731        # Will raise if anything goes wrong
    732        self.s.retry(self._succeedOnSecondAttempt, attempts=2, sleeptime=0)
    733 
    734    def testRetryFailWithoutCatching(self):
    735        self.assertRaises(
    736            Exception, self.s.retry, self._alwaysFail, sleeptime=0, exceptions=()
    737        )
    738 
    739    def testRetryFailEnsureRaisesLastException(self):
    740        self.assertRaises(
    741            SystemExit, self.s.retry, self._alwaysFail, sleeptime=0, error_level=FATAL
    742        )
    743 
    744    def testRetrySelectiveExceptionSucceed(self):
    745        self.s.retry(
    746            self._raiseCustomException,
    747            attempts=2,
    748            sleeptime=0,
    749            retry_exceptions=(NewError,),
    750        )
    751 
    752    def testRetrySelectiveExceptionFail(self):
    753        self.assertRaises(
    754            NewError,
    755            self.s.retry,
    756            self._raiseCustomException,
    757            attempts=2,
    758            sleeptime=0,
    759            retry_exceptions=(OtherError,),
    760        )
    761 
    762    # TODO: figure out a way to test that the sleep actually happened
    763    def testRetryWithSleep(self):
    764        self.s.retry(self._succeedOnSecondAttempt, attempts=2, sleeptime=1)
    765 
    766    def testRetryOnlyRunOnce(self):
    767        """Tests that retry() doesn't call the action again after success"""
    768        self.s.retry(self._alwaysPass, attempts=3, sleeptime=0)
    769        # self.ATTEMPT_N gets increased regardless of pass/fail
    770        self.assertEqual(2, self.ATTEMPT_N)
    771 
    772    def testRetryReturns(self):
    773        ret = self.s.retry(self._alwaysPass, sleeptime=0)
    774        self.assertEqual(ret, True)
    775 
    776    def testRetryCleanupIsCalled(self):
    777        cleanup = mock.Mock()
    778        self.s.retry(self._succeedOnSecondAttempt, cleanup=cleanup, sleeptime=0)
    779        self.assertEqual(cleanup.call_count, 1)
    780 
    781    def testRetryArgsPassed(self):
    782        args = (1, "two", 3)
    783        kwargs = dict(foo="a", bar=7)
    784        ret = self.s.retry(
    785            self._mirrorArgs, args=args, kwargs=kwargs.copy(), sleeptime=0
    786        )
    787        print(ret)
    788        self.assertEqual(ret[0], args)
    789        self.assertEqual(ret[1], kwargs)
    790 
    791 
    792 class BaseScriptWithDecorators(script.BaseScript):
    793    def __init__(self, *args, **kwargs):
    794        self._tmpdir = tempfile.mkdtemp(suffix=".mozharness")
    795        option_args = kwargs.get("option_args", [])
    796        option_args.extend(["--base-work-dir", self._tmpdir])
    797        kwargs["option_args"] = option_args
    798 
    799        super().__init__(*args, **kwargs)
    800 
    801        self.pre_run_1_args = []
    802        self.raise_during_pre_run_1 = False
    803        self.pre_action_1_args = []
    804        self.raise_during_pre_action_1 = False
    805        self.pre_action_2_args = []
    806        self.pre_action_3_args = []
    807        self.post_action_1_args = []
    808        self.raise_during_post_action_1 = False
    809        self.post_action_2_args = []
    810        self.post_action_3_args = []
    811        self.post_run_1_args = []
    812        self.raise_during_post_run_1 = False
    813        self.post_run_2_args = []
    814        self.raise_during_build = False
    815 
    816    @script.PreScriptRun
    817    def pre_run_1(self, *args, **kwargs):
    818        self.pre_run_1_args.append((args, kwargs))
    819 
    820        if self.raise_during_pre_run_1:
    821            raise Exception(self.raise_during_pre_run_1)
    822 
    823    @script.PreScriptAction
    824    def pre_action_1(self, *args, **kwargs):
    825        self.pre_action_1_args.append((args, kwargs))
    826 
    827        if self.raise_during_pre_action_1:
    828            raise Exception(self.raise_during_pre_action_1)
    829 
    830    @script.PreScriptAction
    831    def pre_action_2(self, *args, **kwargs):
    832        self.pre_action_2_args.append((args, kwargs))
    833 
    834    @script.PreScriptAction("clobber")
    835    def pre_action_3(self, *args, **kwargs):
    836        self.pre_action_3_args.append((args, kwargs))
    837 
    838    @script.PostScriptAction
    839    def post_action_1(self, *args, **kwargs):
    840        self.post_action_1_args.append((args, kwargs))
    841 
    842        if self.raise_during_post_action_1:
    843            raise Exception(self.raise_during_post_action_1)
    844 
    845    @script.PostScriptAction
    846    def post_action_2(self, *args, **kwargs):
    847        self.post_action_2_args.append((args, kwargs))
    848 
    849    @script.PostScriptAction("build")
    850    def post_action_3(self, *args, **kwargs):
    851        self.post_action_3_args.append((args, kwargs))
    852 
    853    @script.PostScriptRun
    854    def post_run_1(self, *args, **kwargs):
    855        self.post_run_1_args.append((args, kwargs))
    856 
    857        if self.raise_during_post_run_1:
    858            raise Exception(self.raise_during_post_run_1)
    859 
    860    @script.PostScriptRun
    861    def post_run_2(self, *args, **kwargs):
    862        self.post_run_2_args.append((args, kwargs))
    863 
    864    def build(self):
    865        if self.raise_during_build:
    866            raise Exception(self.raise_during_build)
    867 
    868 
    869 class TestScriptDecorators(unittest.TestCase):
    870    def setUp(self):
    871        cleanup()
    872        self.s = None
    873 
    874    def tearDown(self):
    875        if isinstance(getattr(self, "s", None), BaseScriptWithDecorators):
    876            cleanup([self.s._tmpdir])
    877 
    878        if hasattr(self, "s") and isinstance(self.s, object):
    879            del self.s
    880 
    881        cleanup()
    882 
    883    def test_decorators_registered(self):
    884        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    885 
    886        self.assertEqual(len(self.s._listeners["pre_run"]), 1)
    887        self.assertEqual(len(self.s._listeners["pre_action"]), 3)
    888        self.assertEqual(len(self.s._listeners["post_action"]), 3)
    889        self.assertEqual(len(self.s._listeners["post_run"]), 2)
    890 
    891    def test_pre_post_fired(self):
    892        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    893        self.s.run()
    894 
    895        self.assertEqual(len(self.s.pre_run_1_args), 1)
    896        self.assertEqual(len(self.s.pre_action_1_args), 2)
    897        self.assertEqual(len(self.s.pre_action_2_args), 2)
    898        self.assertEqual(len(self.s.pre_action_3_args), 1)
    899        self.assertEqual(len(self.s.post_action_1_args), 2)
    900        self.assertEqual(len(self.s.post_action_2_args), 2)
    901        self.assertEqual(len(self.s.post_action_3_args), 1)
    902        self.assertEqual(len(self.s.post_run_1_args), 1)
    903 
    904        self.assertEqual(self.s.pre_run_1_args[0], ((), {}))
    905 
    906        self.assertEqual(self.s.pre_action_1_args[0], (("clobber",), {}))
    907        self.assertEqual(self.s.pre_action_1_args[1], (("build",), {}))
    908 
    909        # pre_action_3 should only get called for the action it is registered
    910        # with.
    911        self.assertEqual(self.s.pre_action_3_args[0], (("clobber",), {}))
    912 
    913        self.assertEqual(self.s.post_action_1_args[0][0], ("clobber",))
    914        self.assertEqual(self.s.post_action_1_args[0][1], dict(success=True))
    915        self.assertEqual(self.s.post_action_1_args[1][0], ("build",))
    916        self.assertEqual(self.s.post_action_1_args[1][1], dict(success=True))
    917 
    918        # post_action_3 should only get called for the action it is registered
    919        # with.
    920        self.assertEqual(self.s.post_action_3_args[0], (("build",), dict(success=True)))
    921 
    922        self.assertEqual(self.s.post_run_1_args[0], ((), {}))
    923 
    924    def test_post_always_fired(self):
    925        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    926        self.s.raise_during_build = "Testing post always fired."
    927 
    928        with self.assertRaises(SystemExit):
    929            self.s.run()
    930 
    931        self.assertEqual(len(self.s.pre_run_1_args), 1)
    932        self.assertEqual(len(self.s.pre_action_1_args), 2)
    933        self.assertEqual(len(self.s.post_action_1_args), 2)
    934        self.assertEqual(len(self.s.post_action_2_args), 2)
    935        self.assertEqual(len(self.s.post_run_1_args), 1)
    936        self.assertEqual(len(self.s.post_run_2_args), 1)
    937 
    938        self.assertEqual(self.s.post_action_1_args[0][1], dict(success=True))
    939        self.assertEqual(self.s.post_action_1_args[1][1], dict(success=False))
    940        self.assertEqual(self.s.post_action_2_args[1][1], dict(success=False))
    941 
    942    def test_pre_run_exception(self):
    943        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    944        self.s.raise_during_pre_run_1 = "Error during pre run 1"
    945 
    946        with self.assertRaises(SystemExit):
    947            self.s.run()
    948 
    949        self.assertEqual(len(self.s.pre_run_1_args), 1)
    950        self.assertEqual(len(self.s.pre_action_1_args), 0)
    951        self.assertEqual(len(self.s.post_run_1_args), 1)
    952        self.assertEqual(len(self.s.post_run_2_args), 1)
    953 
    954    def test_pre_action_exception(self):
    955        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    956        self.s.raise_during_pre_action_1 = "Error during pre 1"
    957 
    958        with self.assertRaises(SystemExit):
    959            self.s.run()
    960 
    961        self.assertEqual(len(self.s.pre_run_1_args), 1)
    962        self.assertEqual(len(self.s.pre_action_1_args), 1)
    963        self.assertEqual(len(self.s.pre_action_2_args), 0)
    964        self.assertEqual(len(self.s.post_action_1_args), 1)
    965        self.assertEqual(len(self.s.post_action_2_args), 1)
    966        self.assertEqual(len(self.s.post_run_1_args), 1)
    967        self.assertEqual(len(self.s.post_run_2_args), 1)
    968 
    969    def test_post_action_exception(self):
    970        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    971        self.s.raise_during_post_action_1 = "Error during post 1"
    972 
    973        with self.assertRaises(SystemExit):
    974            self.s.run()
    975 
    976        self.assertEqual(len(self.s.pre_run_1_args), 1)
    977        self.assertEqual(len(self.s.post_action_1_args), 1)
    978        self.assertEqual(len(self.s.post_action_2_args), 1)
    979        self.assertEqual(len(self.s.post_run_1_args), 1)
    980        self.assertEqual(len(self.s.post_run_2_args), 1)
    981 
    982    def test_post_run_exception(self):
    983        self.s = BaseScriptWithDecorators(initial_config_file="test/test.json")
    984        self.s.raise_during_post_run_1 = "Error during post run 1"
    985 
    986        with self.assertRaises(SystemExit):
    987            self.s.run()
    988 
    989        self.assertEqual(len(self.s.post_run_1_args), 1)
    990        self.assertEqual(len(self.s.post_run_2_args), 1)
    991 
    992 
# main {{{1
# Entry point when this file is executed directly: delegate to mozunit.main()
# (imported at the top of the file), expected to run the test cases above.
if __name__ == "__main__":
    mozunit.main()