From 6daf88a437668394594b5ac2bf4cca68a4179483 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Mart=C3=ADnez?= Date: Thu, 24 Dec 2020 10:36:50 +0100 Subject: [PATCH] [FIX+IMP] contract: Fix tests + Batch modifications --- contract/models/contract.py | 10 ++++++++++ contract/models/contract_modification.py | 10 +++++----- contract/tests/test_contract.py | 2 +- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/contract/models/contract.py b/contract/models/contract.py index d8f5895d2..2993634fc 100644 --- a/contract/models/contract.py +++ b/contract/models/contract.py @@ -135,6 +135,16 @@ class ContractContract(models.Model): records._set_start_contract_modification() return records + def write(self, vals): + if 'modification_ids' in vals: + res = super(ContractContract, self.with_context( + bypass_modification_send=True + )).write(vals) + self._modification_mail_send() + else: + res = super(ContractContract, self).write(vals) + return res + @api.model def _set_start_contract_modification(self): for record in self: diff --git a/contract/models/contract_modification.py b/contract/models/contract_modification.py index 277047069..00bfbe8f1 100644 --- a/contract/models/contract_modification.py +++ b/contract/models/contract_modification.py @@ -33,15 +33,15 @@ class ContractModification(models.Model): @api.model_create_multi def create(self, vals_list): records = super().create(vals_list) - records.check_modification_ids_need_sent() + if not self.env.context.get('bypass_modification_send'): + records.check_modification_ids_need_sent() return records def write(self, vals): res = super().write(vals) - self.check_modification_ids_need_sent() + if not self.env.context.get('bypass_modification_send'): + self.check_modification_ids_need_sent() return res def check_modification_ids_need_sent(self): - records_not_sent = self.filtered(lambda x: not x.sent) - if records_not_sent: - records_not_sent.mapped('contract_id')._modification_mail_send() + self.mapped('contract_id')._modification_mail_send() diff --git a/contract/tests/test_contract.py b/contract/tests/test_contract.py index 6d6d24622..c3648bbd1 100644 --- a/contract/tests/test_contract.py +++ b/contract/tests/test_contract.py @@ -168,7 +168,7 @@ class TestContract(TestContractBase): ("model", "=", "contract.contract"), ("res_id", "=", self.contract.id), ]) - self.assertGreaterEqual(len(mail_messages), 3) + self.assertGreaterEqual(len(mail_messages), 2) def test_check_discount(self): with self.assertRaises(ValidationError):